Repository: google/flax Branch: main Commit: 1572c84e34a8 Files: 618 Total size: 9.9 MB Directory structure: gitextract_07u909eo/ ├── .git-blame-ignore-revs ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ └── bug_report.md │ ├── analytics/ │ │ ├── README.md │ │ ├── get_repo_metrics.py │ │ ├── issue_activity_since_date.gql │ │ ├── pr_data_query.gql │ │ └── requirements.txt │ ├── pull_request_template.md │ └── workflows/ │ ├── flax_publish.yml │ ├── flax_test.yml │ ├── flaxlib_publish.yml │ └── jax_nightly.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── AUTHORS ├── CHANGELOG.md ├── LICENSE ├── README.md ├── benchmarks/ │ ├── README.md │ ├── nnx_graph_overhead.py │ ├── nnx_mlpmixer_training.py │ ├── nnx_simple_training.py │ ├── nnx_state_traversal.py │ └── tracing/ │ ├── README.md │ ├── __init__.py │ ├── gemma.py │ ├── imagenet.py │ ├── lm1b.py │ ├── mnist.py │ ├── nlp_seq.py │ ├── ogbg_molpcba.py │ ├── ppo.py │ ├── requirements.txt │ ├── run_all_benchmarks.sh │ ├── seq2seq.py │ ├── sst2.py │ ├── tracing_benchmark.py │ ├── vae.py │ └── wmt.py ├── contributing.md ├── docs/ │ ├── .gitignore │ ├── .readthedocs.yaml │ ├── Makefile │ ├── README.md │ ├── _ext/ │ │ ├── codediff.py │ │ ├── codediff_test.py │ │ └── flax_module.py │ ├── _static/ │ │ └── css/ │ │ └── flax_theme.css │ ├── _templates/ │ │ └── autosummary/ │ │ └── flax_module.rst │ ├── api_reference/ │ │ ├── flax.core.frozen_dict.rst │ │ ├── flax.cursor.rst │ │ ├── flax.errors.rst │ │ ├── flax.jax_utils.rst │ │ ├── flax.linen/ │ │ │ ├── activation_functions.rst │ │ │ ├── decorators.rst │ │ │ ├── index.rst │ │ │ ├── init_apply.rst │ │ │ ├── initializers.rst │ │ │ ├── inspection.rst │ │ │ ├── layers.rst │ │ │ ├── module.rst │ │ │ ├── profiling.rst │ │ │ ├── spmd.rst │ │ │ ├── transformations.rst │ │ │ └── variable.rst │ │ ├── flax.serialization.rst │ │ ├── flax.struct.rst │ │ ├── flax.traceback_util.rst │ │ ├── flax.training.rst │ │ └── index.rst │ ├── conf.py │ ├── conf_sphinx_patch.py │ ├── developer_notes/ │ │ ├── index.rst │ │ ├── lift.md │ │ └── module_lifecycle.rst │ ├── examples/ │ │ ├── community_examples.rst │ │ ├── core_examples.rst │ │ ├── google_research_examples.rst │ │ ├── index.rst │ │ └── repositories_that_use_flax.rst │ ├── faq.rst │ ├── flip/ │ │ ├── 0000-template.md │ │ ├── 1009-optimizer-api.md │ │ ├── 1777-default-dtype.md │ │ ├── 2396-rnn.md │ │ ├── 2434-general-metadata.md │ │ ├── 2974-kw-only-dataclasses.md │ │ ├── 3099-rnnbase-refactor.md │ │ ├── 4105-jax-style-nnx-transforms.md │ │ └── README.md │ ├── glossary.rst │ ├── guides/ │ │ ├── converting_and_upgrading/ │ │ │ ├── convert_pytorch_to_flax.rst │ │ │ ├── haiku_migration_guide.rst │ │ │ ├── index.rst │ │ │ ├── linen_upgrade_guide.rst │ │ │ ├── optax_update_guide.rst │ │ │ ├── orbax_upgrade_guide.rst │ │ │ ├── regular_dict_upgrade_guide.rst │ │ │ └── rnncell_upgrade_guide.rst │ │ ├── data_preprocessing/ │ │ │ ├── full_eval.rst │ │ │ ├── index.rst │ │ │ ├── loading_datasets.ipynb │ │ │ └── loading_datasets.md │ │ ├── flax_fundamentals/ │ │ │ ├── arguments.md │ │ │ ├── flax_basics.ipynb │ │ │ ├── flax_basics.md │ │ │ ├── index.rst │ │ │ ├── rng_guide.ipynb │ │ │ ├── rng_guide.md │ │ │ ├── setup_or_nncompact.rst │ │ │ └── state_params.rst │ │ ├── flax_sharp_bits.ipynb │ │ ├── flax_sharp_bits.md │ │ ├── index.rst │ │ ├── model_inspection/ │ │ │ ├── extracting_intermediates.rst │ │ │ ├── index.rst │ │ │ ├── model_surgery.ipynb │ │ │ └── model_surgery.md │ │ ├── parallel_training/ │ │ │ ├── ensembling.rst │ │ │ ├── flax_on_pjit.ipynb │ │ │ ├── flax_on_pjit.md │ │ │ └── index.rst │ │ ├── quantization/ │ │ │ ├── fp8_basics.ipynb │ │ │ ├── fp8_basics.md │ │ │ └── index.rst │ │ └── training_techniques/ │ │ ├── batch_norm.rst │ │ ├── dropout.rst │ │ ├── index.rst │ │ ├── lr_schedule.rst │ │ ├── transfer_learning.ipynb │ │ ├── transfer_learning.md │ │ ├── use_checkpointing.ipynb │ │ └── use_checkpointing.md │ ├── index.rst │ ├── linen_intro.ipynb │ ├── linen_intro.md │ ├── quick_start.ipynb │ ├── quick_start.md │ └── robots.txt ├── docs_nnx/ │ ├── .gitignore │ ├── .readthedocs.yaml │ ├── Makefile │ ├── README.md │ ├── _ext/ │ │ ├── codediff.py │ │ ├── codediff_test.py │ │ └── flax_module.py │ ├── _static/ │ │ └── css/ │ │ └── flax_theme.css │ ├── _templates/ │ │ └── autosummary/ │ │ └── flax_module.rst │ ├── api_reference/ │ │ ├── flax.config.rst │ │ ├── flax.core.frozen_dict.rst │ │ ├── flax.nnx/ │ │ │ ├── bridge.rst │ │ │ ├── filterlib.rst │ │ │ ├── graph.rst │ │ │ ├── helpers.rst │ │ │ ├── index.rst │ │ │ ├── module.rst │ │ │ ├── nn/ │ │ │ │ ├── activations.rst │ │ │ │ ├── attention.rst │ │ │ │ ├── dtypes.rst │ │ │ │ ├── index.rst │ │ │ │ ├── initializers.rst │ │ │ │ ├── linear.rst │ │ │ │ ├── lora.rst │ │ │ │ ├── normalization.rst │ │ │ │ ├── recurrent.rst │ │ │ │ └── stochastic.rst │ │ │ ├── object.rst │ │ │ ├── rnglib.rst │ │ │ ├── spmd.rst │ │ │ ├── state.rst │ │ │ ├── summary.rst │ │ │ ├── training/ │ │ │ │ ├── index.rst │ │ │ │ ├── metrics.rst │ │ │ │ └── optimizer.rst │ │ │ ├── transforms.rst │ │ │ ├── variables.rst │ │ │ └── visualization.rst │ │ ├── flax.struct.rst │ │ ├── flax.training.rst │ │ ├── flax.traverse_util.rst │ │ └── index.rst │ ├── conf.py │ ├── conf_sphinx_patch.py │ ├── contributing.md │ ├── examples/ │ │ ├── core_examples.rst │ │ ├── gemma.ipynb │ │ ├── gemma.md │ │ └── index.rst │ ├── faq.rst │ ├── flip/ │ │ ├── 0000-template.md │ │ ├── 1009-optimizer-api.md │ │ ├── 1777-default-dtype.md │ │ ├── 2396-rnn.md │ │ ├── 2434-general-metadata.md │ │ ├── 2974-kw-only-dataclasses.md │ │ ├── 3099-rnnbase-refactor.md │ │ ├── 4105-jax-style-nnx-transforms.md │ │ ├── 4844-var-eager-sharding.md │ │ ├── 5310-tree-mode-nnx.md │ │ └── README.md │ ├── guides/ │ │ ├── blog.md │ │ ├── bridge_guide.ipynb │ │ ├── bridge_guide.md │ │ ├── checkpointing.ipynb │ │ ├── checkpointing.md │ │ ├── demo.ipynb │ │ ├── demo.md │ │ ├── extracting_intermediates.ipynb │ │ ├── extracting_intermediates.md │ │ ├── filters_guide.ipynb │ │ ├── filters_guide.md │ │ ├── flax_gspmd.ipynb │ │ ├── flax_gspmd.md │ │ ├── index.rst │ │ ├── jax_and_nnx_transforms.rst │ │ ├── performance.ipynb │ │ ├── performance.md │ │ ├── pytree.ipynb │ │ ├── pytree.md │ │ ├── randomness.ipynb │ │ ├── randomness.md │ │ ├── surgery.ipynb │ │ ├── surgery.md │ │ ├── tiny_nnx.ipynb │ │ ├── transforms.ipynb │ │ ├── transforms.md │ │ ├── view.ipynb │ │ └── view.md │ ├── guides_advanced.rst │ ├── guides_basic.rst │ ├── hijax/ │ │ ├── hijax.ipynb │ │ ├── hijax.md │ │ └── index.rst │ ├── index.rst │ ├── key_concepts.ipynb │ ├── key_concepts.md │ ├── migrating/ │ │ ├── convert_pytorch_to_flax.rst │ │ ├── haiku_to_flax.rst │ │ ├── index.rst │ │ ├── linen_to_nnx.rst │ │ └── nnx_010_to_nnx_011.rst │ ├── mnist_tutorial.ipynb │ ├── mnist_tutorial.md │ ├── nnx_basics.ipynb │ ├── nnx_basics.md │ ├── nnx_glossary.rst │ ├── philosophy.md │ ├── robots.txt │ └── why.rst ├── examples/ │ ├── README.md │ ├── __init__.py │ ├── cloud/ │ │ ├── README.md │ │ ├── launch_gce.py │ │ └── startup_script.sh │ ├── gemma/ │ │ ├── README.md │ │ ├── configs/ │ │ │ ├── default.py │ │ │ ├── gemma3_4b.py │ │ │ ├── small.py │ │ │ └── tiny.py │ │ ├── helpers.py │ │ ├── helpers_test.py │ │ ├── input_pipeline.py │ │ ├── input_pipeline_test.py │ │ ├── layers.py │ │ ├── layers_test.py │ │ ├── main.py │ │ ├── modules.py │ │ ├── modules_test.py │ │ ├── params.py │ │ ├── positional_embeddings.py │ │ ├── positional_embeddings_test.py │ │ ├── requirements.txt │ │ ├── sampler.py │ │ ├── sampler_test.py │ │ ├── sow_lib.py │ │ ├── tokenizer.py │ │ ├── train.py │ │ ├── transformer.py │ │ ├── transformer_test.py │ │ └── utils.py │ ├── imagenet/ │ │ ├── README.md │ │ ├── configs/ │ │ │ ├── default.py │ │ │ ├── fake_data_benchmark.py │ │ │ ├── tpu.py │ │ │ ├── v100_x8.py │ │ │ └── v100_x8_mixed_precision.py │ │ ├── imagenet.ipynb │ │ ├── imagenet_benchmark.py │ │ ├── imagenet_fake_data_benchmark.py │ │ ├── input_pipeline.py │ │ ├── main.py │ │ ├── models.py │ │ ├── models_test.py │ │ ├── requirements.txt │ │ ├── train.py │ │ └── train_test.py │ ├── linen_design_test/ │ │ ├── attention_simple.py │ │ ├── autoencoder.py │ │ ├── dense.py │ │ ├── linear_regression.py │ │ ├── mlp_explicit.py │ │ ├── mlp_inline.py │ │ └── mlp_lazy.py │ ├── lm1b/ │ │ ├── README.md │ │ ├── configs/ │ │ │ └── default.py │ │ ├── input_pipeline.py │ │ ├── input_pipeline_test.py │ │ ├── main.py │ │ ├── models.py │ │ ├── requirements.txt │ │ ├── temperature_sampler.py │ │ ├── temperature_sampler_test.py │ │ ├── tokenizer.py │ │ ├── train.py │ │ ├── train_test.py │ │ └── utils.py │ ├── mnist/ │ │ ├── README.md │ │ ├── configs/ │ │ │ └── default.py │ │ ├── main.py │ │ ├── mnist.ipynb │ │ ├── mnist_benchmark.py │ │ ├── requirements.txt │ │ ├── train.py │ │ └── train_test.py │ ├── nlp_seq/ │ │ ├── README.md │ │ ├── configs/ │ │ │ └── default.py │ │ ├── input_pipeline.py │ │ ├── input_pipeline_test.py │ │ ├── main.py │ │ ├── models.py │ │ ├── requirements.txt │ │ └── train.py │ ├── nnx_toy_examples/ │ │ ├── 01_functional_api.py │ │ ├── 02_lifted_transforms.py │ │ ├── 03_train_state.py │ │ ├── 04_data_parallel_with_jit.py │ │ ├── 05_vae.py │ │ ├── 06_scan_over_layers.py │ │ ├── 07_array_leaves.py │ │ ├── 08_save_load_checkpoints.py │ │ ├── 09_parameter_surgery.py │ │ ├── 10_fsdp_and_optimizer.py │ │ ├── hijax_basic.py │ │ ├── hijax_demo.py │ │ └── requirements.txt │ ├── ogbg_molpcba/ │ │ ├── README.md │ │ ├── configs/ │ │ │ ├── default.py │ │ │ ├── default_graph_net.py │ │ │ ├── hparam_sweep.py │ │ │ └── test.py │ │ ├── input_pipeline.py │ │ ├── input_pipeline_test.py │ │ ├── main.py │ │ ├── models.py │ │ ├── models_test.py │ │ ├── ogbg_molpcba.ipynb │ │ ├── ogbg_molpcba_benchmark.py │ │ ├── requirements.txt │ │ ├── train.py │ │ └── train_test.py │ ├── ppo/ │ │ ├── README.md │ │ ├── agent.py │ │ ├── configs/ │ │ │ └── default.py │ │ ├── env_utils.py │ │ ├── models.py │ │ ├── ppo_lib.py │ │ ├── ppo_lib_test.py │ │ ├── ppo_main.py │ │ ├── requirements.txt │ │ ├── seed_rl_atari_preprocessing.py │ │ └── test_episodes.py │ ├── seq2seq/ │ │ ├── README.md │ │ ├── configs/ │ │ │ └── default.py │ │ ├── input_pipeline.py │ │ ├── main.py │ │ ├── models.py │ │ ├── requirements.txt │ │ ├── seq2seq.ipynb │ │ ├── train.py │ │ └── train_test.py │ ├── sst2/ │ │ ├── README.md │ │ ├── build_vocabulary.py │ │ ├── configs/ │ │ │ └── default.py │ │ ├── input_pipeline.py │ │ ├── input_pipeline_test.py │ │ ├── main.py │ │ ├── models.py │ │ ├── models_test.py │ │ ├── requirements.txt │ │ ├── sst2.ipynb │ │ ├── train.py │ │ ├── train_test.py │ │ ├── vocab.txt │ │ └── vocabulary.py │ ├── vae/ │ │ ├── README.md │ │ ├── configs/ │ │ │ └── default.py │ │ ├── input_pipeline.py │ │ ├── main.py │ │ ├── models.py │ │ ├── requirements.txt │ │ ├── results/ │ │ │ └── .gitignore │ │ ├── train.py │ │ └── utils.py │ └── wmt/ │ ├── README.md │ ├── bleu.py │ ├── configs/ │ │ └── default.py │ ├── decode.py │ ├── input_pipeline.py │ ├── input_pipeline_test.py │ ├── main.py │ ├── models.py │ ├── requirements.txt │ ├── tokenizer.py │ ├── train.py │ └── train_test.py ├── flax/ │ ├── __init__.py │ ├── configurations.py │ ├── core/ │ │ ├── __init__.py │ │ ├── axes_scan.py │ │ ├── flax_functional_engine.ipynb │ │ ├── frozen_dict.py │ │ ├── lift.py │ │ ├── meta.py │ │ ├── nn/ │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── linear.py │ │ │ ├── normalization.py │ │ │ └── stochastic.py │ │ ├── partial_eval.py │ │ ├── scope.py │ │ ├── spmd.py │ │ ├── tracers.py │ │ └── variables.py │ ├── cursor.py │ ├── errors.py │ ├── experimental/ │ │ ├── __init__.py │ │ └── nnx.py │ ├── ids.py │ ├── io.py │ ├── jax_utils.py │ ├── linen/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── activation.py │ │ ├── attention.py │ │ ├── batch_apply.py │ │ ├── combinators.py │ │ ├── dtypes.py │ │ ├── experimental/ │ │ │ └── layers_with_named_axes.py │ │ ├── fp8_ops.py │ │ ├── initializers.py │ │ ├── kw_only_dataclasses.py │ │ ├── linear.py │ │ ├── module.py │ │ ├── normalization.py │ │ ├── partitioning.py │ │ ├── pooling.py │ │ ├── recurrent.py │ │ ├── spmd.py │ │ ├── stochastic.py │ │ ├── summary.py │ │ └── transforms.py │ ├── metrics/ │ │ ├── __init__.py │ │ └── tensorboard.py │ ├── nnx/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── bridge/ │ │ │ ├── __init__.py │ │ │ ├── interop.py │ │ │ ├── module.py │ │ │ ├── variables.py │ │ │ └── wrappers.py │ │ ├── compat.py │ │ ├── extract.py │ │ ├── filterlib.py │ │ ├── graph.py │ │ ├── graphlib.py │ │ ├── helpers.py │ │ ├── ids.py │ │ ├── module.py │ │ ├── nn/ │ │ │ ├── __init__.py │ │ │ ├── activations.py │ │ │ ├── attention.py │ │ │ ├── dtypes.py │ │ │ ├── initializers.py │ │ │ ├── linear.py │ │ │ ├── lora.py │ │ │ ├── normalization.py │ │ │ ├── recurrent.py │ │ │ └── stochastic.py │ │ ├── proxy_caller.py │ │ ├── pytreelib.py │ │ ├── reprlib.py │ │ ├── rnglib.py │ │ ├── scripts/ │ │ │ ├── requirements.txt │ │ │ └── run-all-examples.bash │ │ ├── spmd.py │ │ ├── statelib.py │ │ ├── summary.py │ │ ├── tracers.py │ │ ├── training/ │ │ │ ├── __init__.py │ │ │ ├── metrics.py │ │ │ └── optimizer.py │ │ ├── transforms/ │ │ │ ├── __init__.py │ │ │ ├── autodiff.py │ │ │ ├── compilation.py │ │ │ ├── general.py │ │ │ ├── iteration.py │ │ │ └── transforms.py │ │ ├── traversals.py │ │ ├── variablelib.py │ │ └── visualization.py │ ├── oss/ │ │ └── .git-blame-ignore-revs │ ├── py.typed │ ├── serialization.py │ ├── struct.py │ ├── testing/ │ │ ├── __init__.py │ │ └── benchmark.py │ ├── traceback_util.py │ ├── training/ │ │ ├── __init__.py │ │ ├── checkpoints.py │ │ ├── common_utils.py │ │ ├── dynamic_scale.py │ │ ├── early_stopping.py │ │ ├── lr_schedule.py │ │ ├── orbax_utils.py │ │ ├── prefetch_iterator.py │ │ └── train_state.py │ ├── traverse_util.py │ ├── typing.py │ └── version.py ├── flaxlib_src/ │ ├── .gitignore │ ├── CMakeLists.txt │ ├── Cargo.toml │ ├── LICENSE │ ├── README.md │ ├── pyproject.toml │ └── src/ │ ├── flaxlib/ │ │ ├── __init__.py │ │ └── flaxlib_cpp.pyi │ └── lib.cc ├── nnx.py ├── pylintrc ├── pyproject.toml └── tests/ ├── checkpoints_test.py ├── colab_tpu_jax_version.ipynb ├── configurations_test.py ├── core/ │ ├── core_frozen_dict_test.py │ ├── core_lift_test.py │ ├── core_meta_test.py │ ├── core_scope_test.py │ └── design/ │ ├── core_attention_test.py │ ├── core_auto_encoder_test.py │ ├── core_big_resnets_test.py │ ├── core_custom_vjp_test.py │ ├── core_dense_test.py │ ├── core_flow_test.py │ ├── core_resnet_test.py │ ├── core_scan_test.py │ ├── core_tied_autoencoder_test.py │ ├── core_vmap_test.py │ └── core_weight_std_test.py ├── cursor_test.py ├── download_dataset_metadata.sh ├── early_stopping_test.py ├── flaxlib_test.py ├── import_test.ipynb ├── io_test.py ├── jax_utils_test.py ├── linen/ │ ├── initializers_test.py │ ├── kw_only_dataclasses_test.py │ ├── linen_activation_test.py │ ├── linen_attention_test.py │ ├── linen_batch_apply_test.py │ ├── linen_combinators_test.py │ ├── linen_dtypes_test.py │ ├── linen_linear_test.py │ ├── linen_meta_test.py │ ├── linen_module_test.py │ ├── linen_recurrent_test.py │ ├── linen_test.py │ ├── linen_transforms_test.py │ ├── partitioning_test.py │ └── summary_test.py ├── nnx/ │ ├── __init__.py │ ├── bridge/ │ │ ├── module_test.py │ │ └── wrappers_test.py │ ├── containers_test.py │ ├── filters_test.py │ ├── graph_utils_test.py │ ├── helpers_test.py │ ├── ids_test.py │ ├── integration_test.py │ ├── metrics_test.py │ ├── module_test.py │ ├── mutable_array_test.py │ ├── nn/ │ │ ├── attention_test.py │ │ ├── conv_test.py │ │ ├── embed_test.py │ │ ├── linear_test.py │ │ ├── lora_test.py │ │ ├── normalization_test.py │ │ ├── recurrent_test.py │ │ └── stochastic_test.py │ ├── optimizer_test.py │ ├── partitioning_test.py │ ├── rngs_test.py │ ├── spmd_test.py │ ├── state_test.py │ ├── summary_test.py │ ├── test_traversals.py │ ├── transforms_test.py │ └── variable_test.py ├── pickle_test.py ├── run_all_tests.sh ├── serialization_test.py ├── struct_test.py ├── tensorboard_test.py ├── traceback_util_test.py └── traverse_util_test.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .git-blame-ignore-revs ================================================ # apply pyink 40a6e074e5224d733f964be00e21e0a1cb98bd2e ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report to help us improve title: '' labels: bug assignees: '' --- Provide as much information as possible. At least, this should include a description of your issue and steps to reproduce the problem. If possible also provide a summary of what steps or workarounds you have already tried. ### System information - OS Platform and Distribution (e.g., Linux Ubuntu 16.04): - Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`: - Python version: - GPU/TPU model and memory: - CUDA version (if applicable): ### Problem you have encountered: ### What you expected to happen: ### Logs, error messages, etc: ### Steps to reproduce: Whenever possible, please provide a *minimal example*. Please consider submitting it as a Colab link. ================================================ FILE: .github/analytics/README.md ================================================ # Repo Analytics To run the repo analytics follow the steps below: 1. You must have a Github token, if you don't have one you can create one by following [this guide](https://docs.github.com/en/enterprise-server@3.4/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token). 2. Install the requirements: ```bash pip install -r .github/analytics/requirements.txt ``` 3. Run the analytics: ```bash GITHUB_TOKEN= \ python .github/analytics/get_repo_metrics.py \ --repo-owner google \ --repo-name flax ``` ================================================ FILE: .github/analytics/get_repo_metrics.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json import os from datetime import datetime from collections.abc import Callable import matplotlib.dates as mdates import matplotlib.pyplot as plt import pandas as pd import requests from absl import app, flags token = os.environ['GITHUB_TOKEN'] endpoint = r'https://api.github.com/graphql' headers = {'Authorization': f'bearer {token}'} # ------------------------------------------------------------------------------ # GraphQL # ------------------------------------------------------------------------------ # NOTE: This GraphQL logic was ported and adapted from this script: # https://github.com/scientific-python/devstats-data/blob/4c022961abc4ca6061f8719d9c3387e98734b90c/query.py # It contains style differences from Google's style guide. def load_query_from_file(fname, repo_owner, repo_name) -> str: with open(fname) as fh: query = fh.read() # Set target repo from template query = query.replace('_REPO_OWNER_', repo_owner) query = query.replace('_REPO_NAME_', repo_name) return query def send_query(query, query_type, cursor=None): """ Sends a GraphQL to the GitHub API. No validation is done on the query before sending. GitHub GraphQL is supported with the `cursor` argument. Parameters ---------- query : str The GraphQL query to be sent query_type : {"issues", "pullRequests"} The object being queried according to the GitHub GraphQL schema. Currently only issues and pullRequests are supported cursor : str, optional If given, then the cursor is injected into the query to support GitHub's GraphQL pagination. Returns ------- dict The result of the query (json) parsed by `json.loads` Notes ----- This is intended mostly for internal use within `get_all_responses`. """ # TODO: Expand this, either by parsing the query type from the query # directly or manually adding more query_types to the set if query_type not in {'issues', 'pullRequests'}: raise ValueError( "Only 'issues' and 'pullRequests' queries are currently supported" ) # TODO: Generalize this # WARNING: The cursor injection depends on the specific structure of the # query, this is the main reason why query types are limited to issues/PRs if cursor is not None: cursor_insertion_key = query_type + '(' cursor_ind = query.find(cursor_insertion_key) + len(cursor_insertion_key) query = query[:cursor_ind] + f'after:"{cursor}", ' + query[cursor_ind:] # Build request payload payload = {'query': query} response = requests.post(endpoint, json=payload, headers=headers) return json.loads(response.content) def get_all_responses(query, query_type): 'Helper function to bypass GitHub GraphQL API node limit.' # Get data from a single response initial_data = send_query(query, query_type) data, last_cursor, total_count = parse_single_query(initial_data, query_type) print(f'Retrieving {len(data)} out of {total_count} values...') # Continue requesting data (with pagination) until all are acquired while len(data) < total_count: rdata = send_query(query, query_type, cursor=last_cursor) pdata, last_cursor, _ = parse_single_query(rdata, query_type) data.extend(pdata) print(f'Retrieving {len(data)} out of {total_count} values...') print('Done.') return data def parse_single_query(data, query_type): """ Parses the data returned by `send_query` .. warning:: Like `send_query`, the logic here depends on the specific structure of the query (e.g. it must be an issue or PR query, and must have a total count). """ try: total_count = data['data']['repository'][query_type]['totalCount'] data = data['data']['repository'][query_type]['edges'] last_cursor = data[-1]['cursor'] except KeyError as e: print(data) raise e return data, last_cursor, total_count class GithubGrabber: """ Pulls down data via the GitHub APIv.4 given a valid GraphQL query. """ def __init__(self, query_fname, query_type, repo_owner, repo_name): """ Create an object to send/recv queries related to the issue tracker for the given repository via the GitHub API v.4. The repository to query against is given by: https://github.com// Parameters ---------- query_fname : str Path to a valid GraphQL query conforming to the GitHub GraphQL schema query_type : {"issues", "pullRequests"} Type of object that is being queried according to the GitHub GraphQL schema. Currently only "issues" and "pullRequests" are supported. repo_owner : str Repository owner. repo_name : str Repository name. """ self.query_fname = query_fname self.query_type = query_type # TODO: Parse this directly from query self.repo_owner = repo_owner self.repo_name = repo_name self.raw_data = None self.load_query() def load_query(self): self.query = load_query_from_file( self.query_fname, self.repo_owner, self.repo_name ) def get(self): self.raw_data = get_all_responses(self.query, self.query_type) # ------------------------------------------------------------------------------ # metrics helpers # ------------------------------------------------------------------------------ def _to_datetime(date_str: str) -> datetime: return datetime.fromisoformat(date_str.replace('Z', '')) def _get_issues_features(issues): for issue in issues: issue = issue['node'] created_at = _to_datetime(issue['createdAt']) time_labeled_or_converted = None time_issue_closed = None for event in issue['timelineItems']['edges']: event = event['node'] if event['__typename'] in {'LabeledEvent', 'ConvertedToDiscussionEvent'}: time_labeled_or_converted = _to_datetime(event['createdAt']) if event['__typename'] == 'ClosedEvent': time_issue_closed = _to_datetime(event['createdAt']) yield { 'created_at': created_at, 'time_labeled_or_converted': time_labeled_or_converted, 'time_issue_closed': time_issue_closed, 'issue_closed': issue['state'] == 'CLOSED', } def _get_pr_features(prs): for pr in prs: pr = pr['node'] created_at = _to_datetime(pr['createdAt']) ready_for_review_at = _to_datetime(pr['createdAt']) time_labeled_or_assigned = None time_merged_or_closed = None time_review = None if pr['reviews']['nodes']: review = pr['reviews']['nodes'][0] time_review = _to_datetime(review['createdAt']) for event in pr['timelineItems']['edges']: event = event['node'] if ( time_labeled_or_assigned is None and event['__typename'] == 'LabeledEvent' and 'cla:' not in event['label']['name'] ): time_labeled_or_assigned = _to_datetime(event['createdAt']) if ( time_labeled_or_assigned is None and event['__typename'] == 'AssignedEvent' ): time_labeled_or_assigned = _to_datetime(event['createdAt']) if event['__typename'] in {'ClosedEvent', 'MergedEvent'}: time_merged_or_closed = _to_datetime(event['createdAt']) if event['__typename'] == 'ReadyForReviewEvent': ready_for_review_at = _to_datetime(event['createdAt']) yield { 'created_at': created_at, 'ready_for_review_at': ready_for_review_at, 'time_labeled_or_assigned': time_labeled_or_assigned, 'time_merged_or_closed': time_merged_or_closed, 'time_review': time_review, 'pr_closed': pr['state'] != 'OPEN', } def _start_of_month(date: datetime) -> datetime: return date.replace(day=1, hour=0, minute=0, second=0, microsecond=0) def _shift_n_months(date: datetime, n: int) -> datetime: month = ((date.month + n - 1) % 12) + 1 # shift to next year if necessary if date.month > month: date = date.replace(year=date.year + 1) date = date.replace(month=month) return date def _rolling_window( df: pd.DataFrame, f: Callable[[pd.DataFrame], pd.Series], window_size: int = 6, step: int = 1, ) -> pd.DataFrame: # start of month of the first issue start: datetime = df.iloc[0]['created_at'].replace( day=1, hour=0, minute=0, second=0, microsecond=0 ) end = _shift_n_months(start, window_size) last_month = _start_of_month(df.iloc[-1]['created_at']) last_month = _shift_n_months(last_month, 1) rows: list[pd.Series] = [] while end < last_month: row = f(df[(df['created_at'] >= start) & (df['created_at'] < end)]) row['period_start'] = start row['period_end'] = end rows.append(row) start = _shift_n_months(start, step) end = _shift_n_months(end, step) df = pd.DataFrame(rows) df = df[['period_start', 'period_end'] + list(df.columns[:-2])] return df def _process_prs(df: pd.DataFrame) -> pd.Series: return pd.Series( { 'pr_response_time': df['pr_response_time'].dt.days.mean(), 'pr_resolution_time': df['pr_resolution_time'].dt.days.mean(), } ) def _process_issues(df: pd.DataFrame) -> pd.Series: return pd.Series( { 'issue_response_time': df['issue_response_time'].dt.days.mean(), 'issue_resolution_time': df['issue_resolution_time'].dt.days.mean(), } ) # ----------------------------------------------------------------------------- # main # ----------------------------------------------------------------------------- FLAGS = flags.FLAGS flags.DEFINE_string('repo_owner', 'google', 'User name or organization') flags.DEFINE_string('repo_name', 'flax', 'Name of the repository') def main(_): repo_owner: str = FLAGS.repo_owner repo_name: str = FLAGS.repo_name # Download issue data issues = GithubGrabber( '.github/analytics/issue_activity_since_date.gql', 'issues', repo_owner=repo_owner, repo_name=repo_name, ) issues.get() df_issues = df_issues0 = pd.DataFrame( list(_get_issues_features(issues.raw_data)) ) df_issues['issue_response_time'] = ( df_issues['time_labeled_or_converted'] - df_issues['created_at'] ) df_issues['issue_resolution_time'] = ( df_issues['time_issue_closed'] - df_issues['created_at'] ) df_issues = _rolling_window(df_issues, _process_issues) prs = GithubGrabber( '.github/analytics/pr_data_query.gql', 'pullRequests', repo_owner=repo_owner, repo_name=repo_name, ) prs.get() df_prs = df_prs0 = pd.DataFrame(list(_get_pr_features(prs.raw_data))) time_response = df_prs[['time_labeled_or_assigned', 'time_review']].min( axis=1 ) df_prs['pr_response_time'] = time_response - df_prs['ready_for_review_at'] df_prs['pr_resolution_time'] = ( df_prs['time_merged_or_closed'] - df_prs['ready_for_review_at'] ) df_prs = _rolling_window(df_prs, _process_prs) # get cummulative issues df_issues0 = df_issues0.copy() df_issues0['number_of_issues'] = 1 df_issues0['number_of_issues'] = df_issues0['number_of_issues'].cumsum() # get cummulative prs df_prs0 = df_prs0.copy() df_prs0['number_of_prs'] = 1 df_prs0['number_of_prs'] = df_prs0['number_of_prs'].cumsum() # plot cumulative issues plt.figure() plt.plot(df_issues0['created_at'], df_issues0['number_of_issues']) plt.xlabel('Date') plt.ylabel('Number of issues') plt.title('Number of issues') plt.gca().xaxis.set_major_locator(plt.MaxNLocator(5)) plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%b %Y')) # plot cumulative prs plt.figure() plt.plot(df_prs0['created_at'], df_prs0['number_of_prs']) plt.xlabel('Date') plt.ylabel('Number of PRs') plt.title('Number of PRs') plt.gca().xaxis.set_major_locator(plt.MaxNLocator(5)) plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%b %Y')) # plot for isssue_response_time plt.figure() plt.plot(df_issues['period_end'], df_issues['issue_response_time']) plt.xlabel('Date') plt.ylabel('Issue Response Time (days)') plt.title('Issue Response Time') plt.gca().xaxis.set_major_locator(plt.MaxNLocator(5)) plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%b %Y')) plt.ylim(0) # plot for issue_resolution_time plt.figure() plt.plot(df_issues['period_end'], df_issues['issue_resolution_time']) plt.xlabel('Date') plt.ylabel('Issue Resolution Time (days)') plt.title('Issue Resolution Time') plt.gca().xaxis.set_major_locator(plt.MaxNLocator(5)) plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%b %Y')) plt.ylim(0) # plot for pr_response_time plt.figure() plt.plot(df_prs['period_end'], df_prs['pr_response_time']) plt.xlabel('Date') plt.ylabel('Pull Request Response Time (days)') plt.title('Pull Request Response Time') plt.gca().xaxis.set_major_locator(plt.MaxNLocator(5)) plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%b %Y')) plt.ylim(0) # plot for pr_resolution_time plt.figure() plt.plot(df_prs['period_end'], df_prs['pr_resolution_time']) plt.xlabel('Date') plt.ylabel('Pull Request Resolution Time (days)') plt.title('Pull Request Resolution Time') plt.gca().xaxis.set_major_locator(plt.MaxNLocator(5)) plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%b %Y')) plt.ylim(0) # show plots plt.show() if __name__ == '__main__': app.run(main) ================================================ FILE: .github/analytics/issue_activity_since_date.gql ================================================ { # Queries all the issues in a repo. For each issue, we get some basic data such as # the number, state, labels, and title. The most important part is the 'timelineItems' # which are the events that happened to the issue, we can use the information about # the datetime about certain key events to define some metrics. Note that we are # getting more information than is probably needed but its fine for now. repository(owner: "_REPO_OWNER_", name: "_REPO_NAME_") { issues(first: 100) { totalCount edges { cursor node { number title createdAt state closedAt updatedAt url labels(first: 100) { edges { node { name } } } timelineItems(first: 100, itemTypes: [LABELED_EVENT, CONVERTED_TO_DISCUSSION_EVENT, ISSUE_COMMENT, CLOSED_EVENT]) { totalCount edges { node { __typename ... on ConvertedToDiscussionEvent { createdAt } ... on IssueComment { author { login } createdAt } ... on ClosedEvent { actor { login } createdAt } ... on LabeledEvent { label { name } createdAt } } } } } } } } } ================================================ FILE: .github/analytics/pr_data_query.gql ================================================ query { # Queries all the Pull Requests in a repo. For each issue, we get some basic data such as # the number, state, reviews, and title. The most important part is the 'timelineItems' # which are the events that happened to the issue, we can use the information about # the datetime about certain key events to define some metrics. We also use the 'reviews' # as indicators for certain metrics. Note that we are getting more information than is # probably needed but its fine for now. repository(owner:"_REPO_OWNER_", name:"_REPO_NAME_") { pullRequests(first:100) { totalCount edges { cursor node { number state title createdAt author{ login } mergedAt reviews(first: 100){ nodes { createdAt } } timelineItems(first: 100, itemTypes: [LABELED_EVENT, ASSIGNED_EVENT, MERGED_EVENT, READY_FOR_REVIEW_EVENT, CLOSED_EVENT]) { edges { node { __typename ... on ClosedEvent { actor { login } createdAt } ... on LabeledEvent { label { name } actor { login } createdAt } ... on MergedEvent { actor { login } createdAt } ... on ReadyForReviewEvent { actor { login } createdAt } ... on AssignedEvent { actor { login } createdAt } } } } } } } } } ================================================ FILE: .github/analytics/requirements.txt ================================================ pandas absl-py requests matplotlib ================================================ FILE: .github/pull_request_template.md ================================================ # What does this PR do? Fixes # (issue) ## Checklist - [ ] This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other checks if that's the case). - [ ] This change is discussed in a Github issue/[discussion](https://github.com/google/flax/discussions) (please add a link). - [ ] The documentation and docstrings adhere to the [documentation guidelines](https://github.com/google/flax/blob/main/docs/README.md#how-to-write-code-documentation). - [ ] This change includes necessary high-coverage tests. (No quality testing = no merge!) ================================================ FILE: .github/workflows/flax_publish.yml ================================================ # This workflows will upload a Python Package using Twine when a release is created # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries name: Flax - Build and upload to PyPI on: release: types: [published] jobs: build_package: name: Build package runs-on: ubuntu-latest steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Build package run: pipx run build - name: List files run: ls -l dist/ - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 with: name: distribution path: | dist/*.tar.gz dist/*.whl upload_pypi: name: Release & Upload to PyPI # Only publish release to PyPI when a github release is created. if: github.event_name == 'release' && github.event.action == 'published' needs: build_package runs-on: ubuntu-latest environment: release permissions: id-token: write steps: - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: name: distribution path: dist - name: List files run: ls -l dist/ - name: Publish package distributions to PyPI uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4 discord_release: if: github.repository_owner == 'google' runs-on: ubuntu-latest steps: - name: Get release URL id: get-release-url run: | URL="https://github.com/google/flax/releases" echo "::set-output name=URL::$URL" - name: Get content uses: 2428392/gh-truncate-string-action@b3ff790d21cf42af3ca7579146eedb93c8fb0757 # v1.4.1 id: get-content with: stringToTruncate: | Flax [${{ github.event.release.tag_name }}](<${{ steps.get-release-url.outputs.URL }}>) was just released! ${{ github.event.release.body }} maxLength: 2000 truncationSymbol: "..." - name: Discord Webhook Action uses: tsickert/discord-webhook@c840d45a03a323fbc3f7507ac7769dbd91bfb164 # v5.3.0 with: webhook-url: ${{ secrets.DISCORD_WEBHOOK_URL }} content: ${{ steps.get-content.outputs.string }} ================================================ FILE: .github/workflows/flax_test.yml ================================================ # This workflow will install Python dependencies, run tests and lint with a variety of Python versions # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions name: Flax - Test concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true on: push: branches: - main - 'test_*' pull_request: branches: - main jobs: pre-commit: name: Test pre-commit hooks runs-on: ubuntu-latest steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: '3.11' - run: python -m pip install pre-commit - uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ~/.cache/pre-commit key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'pyproject.toml') }} - run: pre-commit run --show-diff-on-failure --color=always --all-files commit-count: name: Check commit count runs-on: ubuntu-latest steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 # We allow at most 5 commits in a branch to ensure our CI doesn't break. - name: Check commit count in PR if: always() shell: bash run: | set -x # $GITHUB_REF is in format `refs/heads/`. We fetch it under # the name `commit-count` so we can refer to it below. # Do an unshallow fetch so we retrieve all commits (this is necessary # because ations/checkout@v2 fetches a shallow copy). git fetch origin --unshallow $GITHUB_REF:commit-count git fetch origin main diff=$(git rev-list --count origin/main...commit-count) # $GITHUB_REF adds an additional commit to the commit tree, so $diff is # one too high when executing this as a Github Action. if (( $diff > 6)); then echo "ERROR! More than 5 commits in PR -- please squash your commits." url=https://flax.readthedocs.io/en/latest/contributing.html#too-many-commits-in-a-pull-request echo "See $url for help on how to resolve this." exit 1 fi test-import: name: Test import standalone runs-on: ubuntu-latest strategy: matrix: python-version: ['3.11', '3.12', '3.13', '3.14'] steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} - uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0 with: version: "0.8.13" - name: Install standalone dependencies only run: | uv sync - name: Test importing Flax run: | uv run --no-sync python -c "import flax" tests: name: Run Tests needs: [pre-commit, commit-count] runs-on: ubuntu-24.04-16core strategy: # Make sure to change `github_check_runs` in `copy.bara.sky` if you change the tests here. matrix: python-version: ['3.11', '3.12', '3.13'] test-type: [doctest, pytest] jax-version: [newest] include: # keep in sync with internal type checking - python-version: '3.12' test-type: pytype jax-version: 'newest' - python-version: '3.12' test-type: mypy jax-version: 'newest' steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Setup uv uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0 with: version: "0.8.13" python-version: ${{ matrix.python-version }} enable-cache: true - name: Install dependencies run: | rm -fr .venv uv sync --extra testing --extra docs - name: Install JAX run: | if [[ "${{ matrix.jax-version }}" == "newest" ]]; then uv pip install -U jax jaxlib else uv pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" fi if [[ "${{ matrix.test-type }}" == "pytest" ]]; then uv pip install -U tensorflow-datasets fi - name: Test with ${{ matrix.test-type }} run: | if [[ "${{ matrix.test-type }}" == "doctest" ]]; then uv run --no-sync tests/run_all_tests.sh --only-doctest elif [[ "${{ matrix.test-type }}" == "pytest" ]]; then uv run --no-sync tests/run_all_tests.sh --only-pytest elif [[ "${{ matrix.test-type }}" == "pytype" ]]; then uv run --no-sync tests/run_all_tests.sh --only-pytype elif [[ "${{ matrix.test-type }}" == "mypy" ]]; then uv run --no-sync tests/run_all_tests.sh --only-mypy else echo "Unknown test type: ${{ matrix.test-type }}" exit 1 fi - name: Upload coverage to Codecov if: matrix.test-type == 'pytest' uses: codecov/codecov-action@1e68e06f1dbfde0e4cefc87efeba9e4643565303 # v5.1.2 env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} with: file: ./coverage.xml # The below step just reports the success or failure of tests as a "commit status". # This is needed for copybara integration. - name: Report success or failure as github status if: always() shell: bash run: | status="${{ job.status }}" lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') curl -sS --request POST \ --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ --header 'content-type: application/json' \ --data '{ "state": "'$lowercase_status'", "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", "description": "'$status'", "context": "github-actions/Build" }' # This is a temporary workflow to test flax on Python 3.14 and # skipping deps like tensorstore, tensorflow etc tests-python314: name: Run Tests on Python 3.14 needs: [pre-commit, commit-count] runs-on: ubuntu-24.04-16core steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Setup uv uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0 with: version: "0.9.2" python-version: "3.14" activate-environment: true enable-cache: true - name: Install dependencies run: | rm -fr .venv uv sync --extra testing --extra docs - name: Test with pytest run: | export XLA_FLAGS='--xla_force_host_platform_device_count=4' find tests/ -name "*.py" | grep -vE 'io_test|tensorboard' | xargs pytest -n auto ================================================ FILE: .github/workflows/flaxlib_publish.yml ================================================ name: Flaxlib - Build and upload to PyPI # for testing only: on: push: branches: [main] paths: ['flaxlib/**'] release: types: [published] jobs: build_wheels: if: github.event_name == 'push' && contains(github.event.head_commit.modified, 'flaxlib/') name: Build wheels on ${{ matrix.os }} runs-on: ${{ matrix.os }} strategy: matrix: # macos-13 is an intel runner, macos-14 is apple silicon os: [ubuntu-latest, windows-latest, macos-13, macos-14] steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - name: Setup Rust uses: actions-rust-lang/setup-rust-toolchain@11df97af8e8102fd60b60a77dfbf58d40cd843b8 # v1.10.1 - name: Install cibuildwheel run: python -m pip install cibuildwheel==2.21.0 - name: Build wheels run: python -m cibuildwheel --output-dir ./flaxlib/wheelhouse ./flaxlib env: # rust doesn't seem to be available for musl linux on i686 CIBW_SKIP: "*-musllinux_i686" CIBW_ENVIRONMENT: 'PATH="$HOME/.cargo/bin:$PATH" CARGO_TERM_COLOR="always"' CIBW_ENVIRONMENT_WINDOWS: 'PATH="$UserProfile\.cargo\bin;$PATH"' CIBW_BEFORE_BUILD: rustup show CIBW_BEFORE_BUILD_LINUX: | curl https://sh.rustup.rs -sSf | sh -s -- --default-toolchain=stable --profile=minimal -y && rustup show - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 with: name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }} path: ./flaxlib/wheelhouse/*.whl build_sdist: if: github.event_name == 'push' && contains(github.event.head_commit.modified, 'flaxlib/') name: Build source distribution runs-on: ubuntu-latest steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Setup Rust uses: actions-rust-lang/setup-rust-toolchain@11df97af8e8102fd60b60a77dfbf58d40cd843b8 # v1.10.1 - name: Build sdist run: pipx run build --sdist flaxlib - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 with: name: cibw-sdist path: ./flaxlib/dist/*.tar.gz upload_pypi: if: github.event_name == 'push' && contains(github.event.head_commit.modified, 'flaxlib/') name: Upload to PyPI needs: [build_wheels, build_sdist] runs-on: ubuntu-latest permissions: id-token: write steps: - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: '3.x' - name: Install dependencies run: | python -m pip install --upgrade pip pip install setuptools build wheel twine - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: # unpacks all CIBW artifacts into dist/ pattern: cibw-* path: ./flaxlib/dist merge-multiple: true - name: Build and publish env: TWINE_USERNAME: __token__ TWINE_PASSWORD: ${{ secrets.FLAXLIB_PYPI_TOKEN }} run: | twine upload flaxlib/dist/* ================================================ FILE: .github/workflows/jax_nightly.yml ================================================ name: CI - with JAX nightly concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true on: schedule: - cron: "0 12 * * *" # Daily at 12:00 UTC workflow_dispatch: # allows triggering the workflow run manually pull_request: # Automatically trigger on pull requests affecting this file branches: - main paths: - '**workflows/jax_nightly.yml' jobs: jax-nightly: runs-on: ubuntu-latest permissions: contents: read issues: write # for failed-build-issue strategy: fail-fast: false matrix: python-version: ["3.11"] steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python ${{ matrix.python-version }} id: setup_python uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} - name: Setup uv uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0 with: version: "0.8.13" - name: Install dependencies run: | uv sync --extra testing --extra docs - name: Install JAX run: | uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ - name: Run test suite if: success() run: | uv run tests/run_all_tests.sh --only-pytest - name: Notify failed build uses: jayqi/failed-build-issue-action@1a893bbf43ef1c2a8705e2b115cd4f0fe3c5649b # v1.2.0 if: failure() && github.event.pull_request == null with: github-token: ${{ secrets.GITHUB_TOKEN }} ================================================ FILE: .gitignore ================================================ *~ \#*\# *.pyc .tfds .DS_Store dist/ build/ *.egg-info *.rej .pytype .vscode/* /.devcontainer docs*/**/_autosummary docs*/_build docs*/**/tmp flaxlib_src/build flaxlib_src/builddir flaxlib_src/dist flaxlib_src/subprojects .venv venv/ venv.bak/ # used by direnv .envrc # uv uv.lock # custom /tmp-files .agent/rules/flax.md .github/copilot-instructions.md .env ================================================ FILE: .pre-commit-config.yaml ================================================ # Install the pre-commit hooks below with # 'pre-commit install' # Auto-update the version of the hooks with # 'pre-commit autoupdate' # Run the hooks on all files with # 'pre-commit run --all' repos: - repo: https://github.com/mwouts/jupytext rev: v1.13.8 hooks: - id: jupytext args: [--sync] # diable pyink for now # - repo: https://github.com/google/pyink # rev: 23.5.0 # hooks: # - id: pyink - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 hooks: - id: check-toml - id: trailing-whitespace exclude: ^docs.*\.md$ - repo: https://github.com/kynan/nbstripout rev: 0.6.1 hooks: - id: nbstripout exclude: ^examples/.* args: [ --keep-output, --keep-count, --extra-keys, "cell.metadata.executionInfo cell.metadata.id metadata.kernelspec metadata.vscode metadata.colab cell.metadata.executionInfo.user cell.metadata.executionInfo.user_tz cell.metadata.colab", ] - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. rev: v0.1.3 hooks: # Run the Ruff linter. - id: ruff args: [--fix, --exit-non-zero-on-fix] # Disable Ruff formatter for now # # Run the Ruff formatter. # - id: ruff-format - repo: https://github.com/asottile/pyupgrade rev: v3.16.0 hooks: - id: pyupgrade args: [--py310-plus] ================================================ FILE: .readthedocs.yml ================================================ # deprecated ================================================ FILE: AUTHORS ================================================ # This is the list the Flax authors for copyright purposes. # # This does not necessarily list everyone who has contributed code, since in # some cases, their employer may be the copyright holder. To see the full list # of contributors, see the revision history in source control. Google LLC ================================================ FILE: CHANGELOG.md ================================================ Changelog ---------- vNext ------ (Add your change to a random empty line to avoid merge conflicts) - - - - - removed GeGLU simplistic activation, it should be implemented manually. - - - - removed FLAX_LAZY_RNG flag support for old non-lazy PRNG derivation mode - - - - - - - - - - 0.8.2 ----- - fixed rng guide outputs by @chiamp in https://github.com/google/flax/pull/3685 - enforce mask kwarg in norm layers by @chiamp in https://github.com/google/flax/pull/3663 - added kwargs to self.param and self.variable by @chiamp in https://github.com/google/flax/pull/3675 - added nnx normalization tests by @chiamp in https://github.com/google/flax/pull/3689 - added NNX init_cache docstring example by @chiamp in https://github.com/google/flax/pull/3688 - added nnx attention equivalence test by @chiamp in https://github.com/google/flax/pull/3687 - Fix bug that assumed frozen-dict keys were strings. by @copybara-service in https://github.com/google/flax/pull/3692 - added nnx rmsnorm by @chiamp in https://github.com/google/flax/pull/3691 - updated nnx compute_stats by @chiamp in https://github.com/google/flax/pull/3693 - fixed intercept_methods docstring by @chiamp in https://github.com/google/flax/pull/3694 - [nnx] Add Sphinx Docs by @cgarciae in https://github.com/google/flax/pull/3678 - Fix pointless docstring example of nn.checkpoint / nn.remat. by @levskaya in https://github.com/google/flax/pull/3703 - added default params rng to .apply by @chiamp in https://github.com/google/flax/pull/3698 - [nnx] add partial_init by @cgarciae in https://github.com/google/flax/pull/3674 - make make_rng default to 'params' by @chiamp in https://github.com/google/flax/pull/3699 - Add SimpleCell. by @carlosgmartin in https://github.com/google/flax/pull/3697 - fix Module.module_paths docstring by @cgarciae in https://github.com/google/flax/pull/3709 - Guarantee the latest JAX version on CI by @cgarciae in https://github.com/google/flax/pull/3705 - Replace deprecated API `jax.tree.map` by @copybara-service in https://github.com/google/flax/pull/3715 - Use `jax.tree_util.tree_map` instead of deprecated `jax.tree.map`. by @copybara-service in https://github.com/google/flax/pull/3714 - [nnx] simplify readme by @cgarciae in https://github.com/google/flax/pull/3707 - [nnx] add demo.ipynb by @cgarciae in https://github.com/google/flax/pull/3680 - Fix Tabulate's compute_flops by @cgarciae in https://github.com/google/flax/pull/3721 - [nnx] simplify TraceState by @cgarciae in https://github.com/google/flax/pull/3724 - Add broadcast of `strides` and `kernel_dilation` to `nn.ConvTranspose` by @IvyZX in https://github.com/google/flax/pull/3731 - [nnx] Fix State.__sub__ by @cgarciae in https://github.com/google/flax/pull/3704 - [nnx] always fold_in on fork + new ForkedKeys return type by @cgarciae in https://github.com/google/flax/pull/3722 - [nnx] explicit Variables by @cgarciae in https://github.com/google/flax/pull/3720 - Improves fingerprint definition for Modules in nn.jit. by @copybara-service in https://github.com/google/flax/pull/3736 - Flax: avoid key reuse in tests by @copybara-service in https://github.com/google/flax/pull/3740 - added Einsum layer by @chiamp in https://github.com/google/flax/pull/3710 - nn.jit: automatic fingerprint definition for dataclass attributes by @cgarciae in https://github.com/google/flax/pull/3737 - [NVIDIA] Use custom grad accumulation for FP8 params by @kaixih in https://github.com/google/flax/pull/3623 - removed nnx dataclass by @chiamp in https://github.com/google/flax/pull/3742 - [nnx] cleanup graph_utils by @cgarciae in https://github.com/google/flax/pull/3728 - Fix doctest and unbreak head by @IvyZX in https://github.com/google/flax/pull/3753 - [nnx] add pytree support by @cgarciae in https://github.com/google/flax/pull/3732 - fixed intercept_methods docstring by @chiamp in https://github.com/google/flax/pull/3752 - Add ConvLSTMCell to docs. by @carlosgmartin in https://github.com/google/flax/pull/3712 - [nnx] remove flagslib by @cgarciae in https://github.com/google/flax/pull/3733 - Fix tests after applying JAX key-reuse checker. See: by @copybara-service in https://github.com/google/flax/pull/3748 0.8.1 ----- - Added default collection in `make_rng`. - Added `InstanceNorm` and renamed `channel_axes` to `feature_axes`. - Added norm equivalence tests. - Added `Module.module_paths` and doc. - make `Sequential.__call__` compact. - Added `nn.compact_name_scope` v3. - Add explicit control over frozen/slots setting in `flax.struct.dataclass`. - Replacing `jax.tree_util.tree_map` with mapping over leafs. - Fixed docs and docstrings. 0.8.0 ----- - Added [NNX](https://github.com/google/flax/tree/main/flax/nnx#nnx), a neural network library for JAX that provides a simple yet powerful module system that adheres to standard Python semantics. Its aim is to combine the robustness of Linen with a simplified, Pythonic API akin to that of PyTorch. - Added `nn.compact_name_scope` decorator that enables methods to act as compact name scopes as with regular Haiku methods. This makes porting Haiku code easier. - Add copy() method to Module. This is a user-friendly version of the internal clone() method with better defaults for common use cases. - Added [`BatchApply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#batchapply) class. - Added `sow_weights` option in attention layer. - Added [`MultiHeadAttention`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.MultiHeadAttention.html) alias. - Added kwargs support for `nn.jit`. - Deprecated `normalize` activation function, in favor of `standardize`. - Added `GeGLU` activation function. - Added `Enum` support for `tabulate` function. - Added simple argument-only lifted `nn.grad` function. 0.7.5 ----- - Report forward and backward pass FLOPs of modules and submodules in `linen.Module.tabulate` and `summary.tabulate` (in new `flops` and `vjp_flops` table columns). Pass `compute_flops=True` and/or `compute_vjp_flops=True` to include these columns. - Re-factored `MultiHeadDotProductAttention`'s call method signature, by adding `inputs_k` and `inputs_v` args and switching `inputs_kv`, `mask` and `determistic` to keyword arguments. See more details in [#3389](https://github.com/google/flax/discussions/3389). - Use new typed PRNG keys throughout flax: this essentially involved changing uses of `jax.random.PRNGKey` to `jax.random.key`. (See [JEP 9263](https://github.com/jax-ml/jax/pull/17297) for details). If you notice dispatch performance regressions after this change, be sure you update `jax` to version 0.4.16 or newer. - Added `has_improved` field to EarlyStopping and changed the return signature of `EarlyStopping.update` from returning a tuple to returning just the updated class. See more details in [#3385](https://github.com/google/flax/pull/3385) 0.7.4 ----- New features: - Add QK-normalization to MultiHeadDotProductAttention - Allow apply's method argument to accept submodules - Add module path to nn.module. - [JAX] Generate new type of PRNG keys Bug fixes: - Directly call original method if method interceptor stack is empty. - fix stackoverflow when loading pickled module - Improve kw_only_dataclass. - Allow pass-through implementation of state dict - Promote dot_general injections from a function to a module. 0.7.2 ----- New features: - make `flax.core.copy` `add_or_replace` optional - Add `use_fast_variance` option to `GroupNorm` and `BatchNorm` to allow disabling it. Bug fixes: - Use `field_specifiers` instead of `field_descriptors` in `@dataclass_transform`. - Fix `nn.Module` typing. - [JAX] Replace uses of `jax.experimental.pjit.with_sharding_constraint` with `jax.lax.with_sharding_constraint`. 0.7.1 ----- Breaking changes: - Migrating Flax from returning FrozenDicts to returning regular dicts. More details can be found in this [announcement](https://github.com/google/flax/discussions/3191) New features: - Use pyink - added dict migration guide to index - add scan over layers section - Expose options to customize rich.Table - add support for initializing carry variables in scan - Let Flax-Orbax to not port the shape of `target` arrays when they port the `target` shardings. Bug fixes: - Use import `orbax.checkpoint` which is a better import pattern. - Use import `orbax.checkpoint as ocp` to avoid the verbosity of using 'orbax.checkpoint` every time. - [linen] Add alternative, more numerically stable, variance calculation to `LayerNorm`. - [linen] Minor cleanup to normalization code. - Fix norm calculation bug for 0-rank arrays. - [JAX] Remove references to jax.config.jax_array. - [linen] Use `stack` instead of `concatenate` in `compute_stats`, to handle scalar stats case. - [linen] More minor cleanup in normalization `compute_stats`. - Fix warnings from atari gym. - Refactor TypeHandler to operate over batches of values, rather than individual ones. This allows more flexibility for implementations that may operate more efficiently on batches. - Fix carry slice logic - make flax_basics guide use utility fns - Fix checkpointing guide error at head - Improve scan docs 0.7.0 ----- - RNNCellBase refactor. 0.6.11 ----- - Set Orbax-as-backend to be the default checkpointing method. - Fix setup trigger issue under sharing and transforms. - Add collection to self.scope.reserve(name, col) so that sow works with the same name in different collections. - Minor improvements for Sequential. - Improve the error message in MultiHeadDotProductAttention. - Allow manually specifying the rng key for Dropout. - RNN refactor. - fixed missing separator for rng fold in. 0.6.10 ----- - Rudimentary quantization support: some layers can be parametrized with custom dot_general and conv_general_dilated. 0.6.9 ----- - Depend on `orbax-checkpoint` package instead of `orbax`. - Refactored setup scripts to `project.toml`. - Added pretty_repr utility fn. - Fix get_partition_spec on replicated array. - Updates imagenet.ipynb to use GPU Colab runtime, and fixed config. - Upgrade checkpointing code to `jax.sharding`, and with more warnings. 0.6.8 ----- - The automatic checkpoint migration was temporarily rolled back due to legacy compatibility issues. - We still recommend you to use the [upgrade guide](https://flax.readthedocs.io/en/latest/guides/orbax_upgrade_guide.html) and migrate completely to the Orbax API to ensure stability. - Or alternatively, add `flax.config.update('flax_use_orbax_checkpointing', True)` to your project to avoid being impacted by the automatic migration process. - Added utility functions to frozen_dict api. - Migrated Flax away from `register_keypaths`. - Fixes kwargs in convert_to_graphs_tuple_fn. - Fixed examples in a few ways: - Bumped the TF version - Used latest checkpoint formats - Other misc fixes. 0.6.7 ----- - New checkpoints will be saved using Orbax! Please check out [upgrade guide](https://flax.readthedocs.io/en/latest/guides/orbax_upgrade_guide.html) and consider migrating completely to the Orbax API. - You could `flax.config.update('flax_use_orbax_checkpointing', False)` to temporarily disable this migration, but note that Flax legacy checkpointing will be removed 3 months from Mar 10, 2023. - Migrating `FrozenDict` to regular dict: utility functions now work on both. - Migrated Flax dataclass and `FrozenDict` to JAX pytree keypath API. - Fixed pytype and improved typing for `Module` - Fixed up uses of PyTree and PyTreeDef types. 0.6.6 ----- - 0.6.5 was yanked so this release contains all that was in 0.6.5 as well. - Migrated regular dict to FrozenDict, currently controlled by a flag. - Refactored and separate out name relaxation policy changes. - Added RMS normalization layer. 0.6.5 ----- - Added logical partitioning helpers for using pjit with Flax. - Add ``Module.lazy_init`` to avoid compute during Module initialization. 0.6.4 ----- New features: - Our [ReadTheDoc](https://flax.readthedocs.io/en/latest/index.html) site is a lot more organized now! More improvements on the way. - Flax auto-SPMD parallelism API to work seamlessly with `jax.pjit`: https://flax.readthedocs.io/en/latest/guides/flax_on_pjit.html - Added new `zeros_init` and `ones_init` initializers. - Adds standardize initializer. - Allowed specifying method as a string. - Allowed runtime overwrite of `flax.config` flags. Bug fixes: - Added missing `dataclass.fields` from `__repr__`. - Renamed ConvLSTM to ConvLSTMCell. - Fix some tiny inconsistencies between scope.py and module.py. - Improved many many docstrings, comments and error messages. 0.6.3 ----- New features: - Flax checkpointing now uses [Orbax](https://github.com/google/orbax) for more flexiblity and features. - Added support for python 3.10 and removed support for 3.7. Bug fixes: - Fixed rng generation in DenseGeneral init. - Improved support for Mac M1 chip. - Bumped package versions for a bunch of examples. - Improved many docstrings and error messages. 0.6.2 ----- New features: - Add rng_collection argument to Dropout. - Fix flax.linen.stochastic.Dropout. - Add flag allow_partial_mpa_restoration in checkpointing. - Use `gfile.remove` for files because it doesn't work on GCS files. - Added guides for: Flax the Sharp Bits, Checkpointing, Extracting Gradients - Improved existed documentation pages. - Improved errors, error messages and tests. - Removed codebase's trailing whitespaces. Bug fixes: - Fixes launch_gce.sh with imagenet example. - Properly report AttributeErrors from descriptors. - Fixes usages of `pmap`. - Return None if no _parent_ref is set. - Cap dynamic scale to float32 max. - no-op when double wrapping with struct.dataclass. - Allow variable_with_axes to have empty axes when axes is set to an empty tuple. - Don't create reference cycles among Modules. 0.6.1 ----- - Adds axis_name and axis_index_groups to LayerNorm and GroupNorm. by @copybara-service in [#2402](https://github.com/google/flax/pull/2402) - Plumb spmd_axis_name through transforms.vmap through to JAX vmap by @copybara-service in [#2398](https://github.com/google/flax/pull/2398) - Support multiple inputs in flax lifted vjp/custom_vjp by @copybara-service in [#2399](https://github.com/google/flax/pull/2399) - Improve tabulate by @cgarciae in [#2316](https://github.com/google/flax/pull/2316) - Add path_aware_map function by @cgarciae in [#2371](https://github.com/google/flax/pull/2371) - Add static_argnums to nn.checkpoint by @cgarciae in [#2457](https://github.com/google/flax/pull/2457) - Adding "count_include_pad" argument to flax.linen.pooling.avg_pool by @dslisleedh in [#2451](https://github.com/google/flax/pull/2451) - Add perturb() to allow capturing intermediate gradients by @IvyZX in [#2476](https://github.com/google/flax/pull/2476) 0.6.0 ----- - Removed deprecated optimizers in `flax.optim` package. - Moved `flax.optim.dynamic_scale` to `flax.training.dynamic_scale`. - Switched to using `jax.named_scope` for all profile naming, cut some pointless stack traces out. 0.5.3 ----- New features: - Added `nn.switch` as a lifted version of `jax.lax.switch`. - Added a method for detecting the use of "init" functions. - Added checkpointing support for `jax.experimental.GlobalDeviceArray`, a useful array type for multiprocess/multihost computing. - Added async option to `save_checkpoints()` on single-process scenario. - Improved documentation pages. Bug fixes: - Fixed variable aliasing in put_variable - Fixed missing passthrough of nn.scan unroll arg - Fixed the MNIST example 0.5.2 ----- - Fixes missing PyYAML dependency. 0.5.1 ----- New features: - Added `nn.tabulate` and `Module.tabulate` to generate rich representations of the network structure. 0.5.0 ----- - Added `flax.jax_utils.ad_shard_unpad()` by @lucasb-eyer - Implemented [default dtype FLIP](https://github.com/google/flax/blob/main/docs/flip/1777-default-dtype.md). This means the default dtype is now inferred from inputs and params rather than being hard-coded to float32. This is especially useful for dealing with complex numbers because the standard Modules will no longer truncate complex numbers to their real component by default. Instead the complex dtype is preserved by default. Bug fixes: - Fix support for JAX's experimental_name_stack. Breaking changes: - In rare cases the dtype of a layer can change due to [default dtype FLIP](https://github.com/google/flax/blob/main/docs/flip/1777-default-dtype.md). See the "Backward compatibility" section of the proposal for more information. 0.4.2 ----- New features: - Add lifted conditional `nn.cond`. - Improved error messages: parameters not found, loading checkpoints. - Replace `jax.tree_multimap` (deprecated) with `jax.tree.map`. - Add the "Module Lifecycle" design note. - Add support for JAX dynamic stack-based named_call Bug fixes: - Handle rate==1.0 edgecase in Dropout. - Fix bug where Linen Module state is reused. - Bug fixes and generalizations of nn.partitioning API. 0.4.1 ----- New features: - Added locally-connected (unshared CNN) layer `flax.linen.ConvLocal`. - Improved seq2seq example: Factored our model and input pipeline code. - Added Optax update guide and deprecated `flax.optim`. - Added `sep` argument to `flax.traverse_util.flatten_dict()`. - Implemented Sequential module, in `flax.linen.combinators`. 0.4.0 ------ Breaking changes: - flax.deprecated.nn is removed. Please pin to flax==0.3.6 if you are still using it. - PixelCNN++ example is removed. It was not working well on TPU. - linen Normalization layers no longer downcast double and complex floats tofloat32 when computing the mean and variance. New features: - Added `flax.linen.custom_vjp` for custom derivatives inside a `Module`. - Add `param_dtype` attribute to standard Linen Modules for specifying parameter dtypes. 0.3.6 ------ Breaking changes: - Move `flax.nn` to `flax.deprecated.nn`. New features: - Add experimental checkpoint policy argument. See `flax.linen.checkpoint` - Add lifted versions of jvp and vjp. - Add lifted transformation for mapping variables. See `flax.linen.map_variables`. 0.3.5 ------ Breaking changes: - You can no longer pass an int as the `kernel_size` for a `flax.linen.Conv. Instead a type error is raised stating that a tuple/list should be provided. Stride and dilation arguments do support broadcasting a single int value now because this is not ambigious when the kernel rank is known. - `flax.linen.enable_named_call` and `flax.linen.disable_named_call` now work anywhere instead of only affecting Modules constructed after the enable/disable call. Additionally, there is now `flax.linen.override_named_call` that provided a context manager to locally disable/enable named_call. - NamedTuples are no longer converted to tuples on assignment to a `linen.Module`. New features: - Flax internal stack frames are now removed from exception state traces. - Added `flax.linen.nowrap` to decorate method that should not be transformed because they are stateful. - Flax no longer uses implicit rank broadcasting. Thus, you can now use Flax with `--jax_numpy_rank_promotion=raise`. Bugfixes: - linen Modules and dataclasses made with `flax.struct.dataclass` or `flax.struct.PyTreeNode` are now correctly recognized as dataclasses by static analysis tools like PyLance. Autocomplete of constructors has been verified to work with VSCode. - Fixed a bug in FrozenDict which didn't allow copying dicts with reserved names. - Fix the serialization of named tuples. Tuple fields are no longer stored in the state dict and the named tuple class is no longer recreated ([bug](https://github.com/google/flax/issues/1429)). - Mixed precision training with float16 now works correctly with the attention layers. - auto-generated linen Module `__hash__`, `__eq__`, `__repr__` no longer fail by default on non-init attributes. 0.3.4 ------ Possibly breaking changes: - When calling `init` the 'intermediates' collection is no longer mutable. Therefore, intermediates will no longer be returned from initialization by default. - Don't update batch statistics during initialization. - When not using any non-determinism (e.g., dropout), it is not longer necessary to specify the `deterministic` argument in `MultiHeadDotProductAttention`. Other changes: - Rewrote various examples to use Optax instead of Flax optimizers (e.g., Imagenet, SST2). - Added an NLP text classification example (on the SST-2 dataset) to [`examples/sst2`](https://github.com/google/flax/tree/main/examples/sst2). that uses a bidirectional LSTM (BiLSTM) to encode the input text. - Added `flax.training.train_state` to simplify using Optax optimizers. - `mutable` argument is now available on `Module.init` and `Module.init_with_outputs` - Bug fix: Correctly handle non-default parameters of Linen Modules with nested inheritance. - Expose `dot_product_attention_weights`, allowing access to attention weights. - `BatchNorm` instances will behave correctly during init when called multiple times. - Added a more extensive "how to contribute" guide in `contributing.md`. - Add proper cache behavior for [`lift.jit`](https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.jit.html#flax.linen.jit), fixing cache misses. - Fix bug in Embed layer: make sure it behaves correctly when embedding is np.array. - Fix `linen.Module` for deep inheritance chains. - Fix bug in DenseGeneral: correctly expand bias to account for batch & noncontracting dimensions. - Allow Flax lifted transforms to work on partially applied Modules. - Make `MultiOptimizer` use `apply_gradient` instead of `apply_param_gradient`. 0.3.3 ------ Possible breaking changes: - Bug Fix: Disallow modifying attributes in Modules after they are initialized. - Raise an error when saving a checkpoint which has a smaller step than the latest checkpoint already saved. - MultiOptimizer now rejects the case where multiple sub optimizers update the same parameter. Other changes: - Added custom error classes to many Linen errors. See: https://flax.readthedocs.io/en/latest/flax.errors.html - Adds `Module.bind` for binding variables and RNGs to an interactive Module. - Adds `nn.apply` and `nn.init` for transforming arbitrary functions that take a `linen.Module` as their first argument. - Add option to overwrite existing checkpoints in `save_checkpoint`. - Remove JAX omnistaging check for forward compatibility. - Pathlib compatibility for checkpoint paths. - `is_leaf` argument in `traverse_util.flatten_dict` 0.3.2 ------ `flax.nn` deprecation message no longer appears if you import flax directly. NOTE: You must now explicitly import `flax.nn` if you want to use the old pre-Linen `flax.nn.Module`. 0.3.1 ------ Many improvements to Linen, and the old `flax.nn` is officially deprecated! Notably, there's a clean API for extracting intermediates from modules defined using `@nn.compact`, a more ergonomic API for using Batch Norm and Dropout in modules defined using `setup`, support for `MultiOptimizer` with Linen, and multiple safety, performance and error message improvements. Possible breaking changes: - Call setup lazily. See #938 for motivation and more details. - Linen `Module` instances are now frozen after `setup` has been called. Previously mutations after setup could be dropped silently. Now the stateless requirement is enforced by raising a TypeError in `__setattr__` after `setup`. - Pytrees of dicts and lists are transformed into FrozenDict and tuples during attribute assignment. This avoids undetected submodules and inner state. - Bug Fix `flax.core.apply` and `Module.apply`. Now it returns a tuple containing the output and a frozen empty collection when `mutable` is specified as an empty list. - `broadcast_dims` is now an attribute to `Dropout` instead of a `__call__` argument. - `use_running_average` and `deterministic` no longer have a default. They should be passed explicitly - Bug Fix `Scope.variable` mutability check, before a variable could only be initialized if the 'params' collection was mutable. Other Improvements: - Re-introduced the `lm1b` language modeling example - Recognizes batch free inputs in pooling layers. (for use with vmap) - Add Adadelta optimizer - Fully deprecate all "pre-Linen" `flax.nn` classes and methods. - Some Module arguments can now be passed either as dataclass attribute or as argument to `__call__`. See [design note](https://flax.readthedocs.io/en/latest/guides/arguments.html) - Add `sow` method to `Module` and `capture_intermediates` argument to `Module.apply`. See [howto](https://flax.readthedocs.io/en/latest/howtos/extracting_intermediates.html) for usage patterns. - Support passing in modules directly as attributes to other modules, and deal with them correctly both in top-level modules and in submodules. - Don't require the `variable` argument to `Module.apply` to be a FrozenDict - Add support for dict/FrozenDict when using `ModelParamTraversal` As a result `MultiOptimizer` can be used properly with linen modules. - Added OptimizedLSTM: ~33% faster than the original LSTM when using <=1024 units - Fix dtype handling for Adam and LAMB optimizers in 64bit mode. - Added `is_mutable()` method to `Variable` and `is_mutable_collection()` to `flax.linen.Module`. - Add `axis_name` arg to `flax.linen.vmap` - Enable broadcast in `flax.linen.scan` - Fix behavior when inner module classes were defined in another module - Add automatic giant array chunking in msgpack checkpoints. - Log info message when a checkpoint is not found in the directory. v0.3 ----- Linen is now out of Alpha (flax.nn is being deprecated)! - `flax.core.apply` and linen `Module.apply` will now only return the variables collections that were specified as mutable. - Fixed handling of multiple separate subclasses of a Module. - We now allow assignment of mixed Module pytrees in setup. - Refactored collection creation to fail early when modifying an undefined collection as before an non-existing non-mutable collection would just be silently ignored. - Added the silu activation function. - Add offset argument to Adafactor optimizer for fine-tuning schedules. - Relaxed limit on calling methods on unbound modules. - Relaxed parameter attribute check - Added centered version of RMSProp. - Added GCE getting started kit. - Renamed -gpu_type to -accelerator_type. - Fixed bug in MultiOptimizer causing it to throw away empty dictionary ### Improvements - Made FrozenDict constructor freeze correctly. - Made freeze a synonym of the FrozenDict constructor - Optimize freezing FrozenDicts by sharing immutable internal state. - We simplified __setattr__ handling of trees with Modules. - Minor improvements in dtype handling, broadcast option for dropout. - Added a dtype specification to Embed layer, made Adafactor use float32 state consistently, and added a broadcasting option to the Dropout layer. - Improved frozen dict performance. - (Massive) docs improvements - End to end benchmarks added. - Examples were updated to Linen. v0.2.2 ---- - Added Reinforcement Learning example (examples/ppo). - Fix Adafactor bug that prevented factorization. - Fix scan broadcast issue in functional core. - Fix initialization RNGs to work with omnistaging for jitted inits. - Replaces usage of 'param' kind to 'params' collection. - Fix LARS optimizer for zero param initialization. - Added various examples in Linen API. See [README.md](https://github.com/google/flax/blob/main/flax/linen/README.md) for more information. - Full JAX omnistaging compatibility. v0.2 ---- - Added JAX trace-level checks for transforms. - BatchNorm added axis_index_groups for control in parallel training. - Optimizers broken out into separate directory with base class and implementations. - traverse_util added flatten_dict and unflatten_dict utility methods for nested dicts. v0.1 ---- ### API Changes - Add ConvTranspose Module to nn.linear - Rename the following optional arguments to nn.linear.Conv: `lhs_dilation` -> `input_dilation`, `rhs_dilation` -> `kernel_dilation` - Change default layer names from numbers '0', '1', etc. to include the Module class name, e.g. 'Dense_0', 'LayerNorm_1'. ================================================ FILE: LICENSE ================================================ Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================
logo
# Flax: A neural network library and ecosystem for JAX designed for flexibility [![Flax - Test](https://github.com/google/flax/actions/workflows/flax_test.yml/badge.svg)](https://github.com/google/flax/actions/workflows/flax_test.yml) [![PyPI version](https://img.shields.io/pypi/v/flax)](https://pypi.org/project/flax/) [**Overview**](#overview) | [**Quick install**](#quick-install) | [**What does Flax look like?**](#what-does-flax-look-like) | [**Documentation**](https://flax.readthedocs.io/) Released in 2024, Flax NNX is a new simplified Flax API that is designed to make it easier to create, inspect, debug, and analyze neural networks in [JAX](https://jax.readthedocs.io/). It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, enabling reference sharing and mutability. Flax NNX evolved from the [Flax Linen API](https://flax-linen.readthedocs.io/), which was released in 2020 by engineers and researchers at Google Brain in close collaboration with the JAX team. You can learn more about Flax NNX on the [dedicated Flax documentation site](https://flax.readthedocs.io/). Make sure you check out: * [Flax NNX basics](https://flax.readthedocs.io/en/latest/nnx_basics.html) * [MNIST tutorial](https://flax.readthedocs.io/en/latest/mnist_tutorial.html) * [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html) * [Evolution from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html) **Note:** Flax Linen's [documentation has its own site](https://flax-linen.readthedocs.io/). The Flax team's mission is to serve the growing JAX neural network research ecosystem - both within Alphabet and with the broader community, and to explore the use-cases where JAX shines. We use GitHub for almost all of our coordination and planning, as well as where we discuss upcoming design changes. We welcome feedback on any of our discussion, issue and pull request threads. You can make feature requests, let us know what you are working on, report issues, ask questions in our [Flax GitHub discussion forum](https://github.com/google/flax/discussions). We expect to improve Flax, but we don't anticipate significant breaking changes to the core API. We use [Changelog](https://github.com/google/flax/tree/main/CHANGELOG.md) entries and deprecation warnings when possible. In case you want to reach us directly, we're at flax-dev@google.com. ## Overview Flax is a high-performance neural network library and ecosystem for JAX that is **designed for flexibility**: Try new forms of training by forking an example and by modifying the training loop, not adding features to a framework. Flax is being developed in close collaboration with the JAX team and comes with everything you need to start your research, including: * **Neural network API** (`flax.nnx`): Including [`Linear`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Linear), [`Conv`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Conv), [`BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm), [`LayerNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.LayerNorm), [`GroupNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.GroupNorm), [Attention](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html) ([`MultiHeadAttention`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html#flax.nnx.MultiHeadAttention)), [`LSTMCell`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/recurrent.html#flax.nnx.nn.recurrent.LSTMCell), [`GRUCell`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/recurrent.html#flax.nnx.nn.recurrent.GRUCell), [`Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout). * **Utilities and patterns**: replicated training, serialization and checkpointing, metrics, prefetching on device. * **Educational examples**: [MNIST](https://flax.readthedocs.io/en/latest/mnist_tutorial.html), [Inference/sampling with the Gemma language model (transformer)](https://github.com/google/flax/tree/main/examples/gemma). ## Quick install Flax uses JAX, so do check out [JAX installation instructions on CPUs, GPUs and TPUs](https://jax.readthedocs.io/en/latest/installation.html). You will need Python 3.8 or later. Install Flax from PyPi: ``` pip install flax ``` To upgrade to the latest version of Flax, you can use: ``` pip install --upgrade git+https://github.com/google/flax.git ``` To install some additional dependencies (like `matplotlib`) that are required but not included by some dependencies, you can use: ```bash pip install "flax[all]" ``` ## What does Flax look like? We provide three examples using the Flax API: a simple multi-layer perceptron, a CNN and an auto-encoder. To learn more about the `Module` abstraction, check out our [docs](https://flax.readthedocs.io/), our [broad intro to the Module abstraction](https://github.com/google/flax/blob/main/docs/linen_intro.ipynb). For additional concrete demonstrations of best practices, refer to our [guides](https://flax.readthedocs.io/en/latest/guides/index.html) and [developer notes](https://flax.readthedocs.io/en/latest/developer_notes/index.html). Example of an MLP: ```py class MLP(nnx.Module): def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): self.linear1 = nnx.Linear(din, dmid, rngs=rngs) self.dropout = nnx.Dropout(rate=0.1, rngs=rngs) self.bn = nnx.BatchNorm(dmid, rngs=rngs) self.linear2 = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x: jax.Array): x = nnx.gelu(self.dropout(self.bn(self.linear1(x)))) return self.linear2(x) ``` Example of a CNN: ```py class CNN(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs) self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs) self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2)) self.linear1 = nnx.Linear(3136, 256, rngs=rngs) self.linear2 = nnx.Linear(256, 10, rngs=rngs) def __call__(self, x): x = self.avg_pool(nnx.relu(self.conv1(x))) x = self.avg_pool(nnx.relu(self.conv2(x))) x = x.reshape(x.shape[0], -1) # flatten x = nnx.relu(self.linear1(x)) x = self.linear2(x) return x ``` Example of an autoencoder: ```py Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs) Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs) class AutoEncoder(nnx.Module): def __init__(self, rngs): self.encoder = Encoder(rngs) self.decoder = Decoder(rngs) def __call__(self, x) -> jax.Array: return self.decoder(self.encoder(x)) def encode(self, x) -> jax.Array: return self.encoder(x) ``` ## Citing Flax To cite this repository: ``` @software{flax2020github, author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee}, title = {{F}lax: A neural network library and ecosystem for {JAX}}, url = {http://github.com/google/flax}, version = {0.12.6}, year = {2024}, } ``` In the above bibtex entry, names are in alphabetical order, the version number is intended to be that from [flax/version.py](https://github.com/google/flax/blob/main/flax/version.py), and the year corresponds to the project's open-source release. ## Note Flax is an open source project maintained by a dedicated team at Google DeepMind, but is not an official Google product. ================================================ FILE: benchmarks/README.md ================================================ # Benchmarks These are mini benchmarks to measure the performance of NNX operations. Sample profile command: ```shell python -m cProfile -o ~/tmp/overhead.prof benchmarks/nnx_graph_overhead.py --mode=nnx --depth=100 --total_steps=1000 ``` Sample profile inspection: ```shell snakeviz ~/tmp/overhead.prof ``` ================================================ FILE: benchmarks/nnx_graph_overhead.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax import jax.numpy as jnp import numpy as np import optax from time import time from flax import nnx from absl import flags from absl import app FLAGS = flags.FLAGS flags.DEFINE_enum( 'mode', 'nnx', ['all', 'nnx', 'jax'], 'Mode to run the script in' ) flags.DEFINE_integer('total_steps', 100, 'Total number of training steps') flags.DEFINE_integer('width', 32, 'Hidden layer size') flags.DEFINE_integer('depth', 5, 'Depth of the model') class Linear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): return x @ self.w + self.b class Block(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.linear = Linear(din, dout, rngs=rngs) self.bn = nnx.BatchNorm(dout, rngs=rngs) def __call__(self, x): return nnx.relu(self.bn(self.linear(x))) class Count(nnx.Variable): pass class MLP(nnx.Module): def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs): self.count = Count(jnp.array(0)) self.linear_in = Block(din, dhidden, rngs=rngs) self.intermediates = nnx.List([ Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2) ]) self.linear_out = Block(dhidden, dout, rngs=rngs) def __call__(self, x): self.count.value += 1 x = nnx.relu(self.linear_in(x)) for layer in self.intermediates: x = nnx.relu(layer(x)) x = self.linear_out(x) return x def main(argv): print(argv) mode: str = FLAGS.mode total_steps: int = FLAGS.total_steps width: int = FLAGS.width depth: int = FLAGS.depth print(f'{mode=}, {total_steps=}, {width=}') X = np.linspace(0, 1, 100)[:, None] Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) # ------------------------------------------------------------ # NNX # ------------------------------------------------------------ if mode in ['all', 'nnx']: model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) tx = optax.sgd(1e-3) optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) t0 = time() @nnx.jit def step_nnx(model: MLP, optimizer: nnx.Optimizer): pass t0 = time() for _ in range(total_steps): step_nnx(model, optimizer) total_time = time() - t0 time_per_step = total_time / total_steps time_per_layer = time_per_step / depth print('### NNX ###') print('total time:', total_time) print(f'time per step: {time_per_step * 1e6:.2f} µs') print(f'time per layer: {time_per_layer * 1e6:.2f} µs') # ------------------------------------------------------------ # JAX # ------------------------------------------------------------ if mode in ['all', 'jax']: model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) tx = optax.sgd(1e-3) optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) t0 = time() @jax.jit def step_jax(graphdef, state): return graphdef, state graphdef, state = nnx.split((model, optimizer)) t0 = time() for _ in range(total_steps): graphdef, state = step_jax(graphdef, state) total_time = time() - t0 time_per_step = total_time / total_steps time_per_layer = time_per_step / depth print('### JAX ###') print('total time:', total_time) print(f'time per step: {time_per_step * 1e6:.2f} µs') print(f'time per layer: {time_per_layer * 1e6:.2f} µs') print() if __name__ == '__main__': app.run(main) ================================================ FILE: benchmarks/nnx_mlpmixer_training.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # %% from functools import partial import jax import jax.numpy as jnp from flax import nnx import optax import numpy as np from einop import einop from time import time from tqdm import tqdm from flax import nnx from absl import flags from absl import app FLAGS = flags.FLAGS flags.DEFINE_enum( 'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in' ) flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps') flags.DEFINE_integer('batch_size', 32, 'Batch size') flags.DEFINE_integer('width', 32, 'Hidden layer size') flags.DEFINE_integer('depth', 4, 'Depth of the model') class MlpBlock(nnx.Module): def __init__(self, din: int, mlp_dim: int, rngs: nnx.Rngs): self.din, self.mlp_dim = din, mlp_dim self.linear_in = nnx.Linear(din, mlp_dim, rngs=rngs) self.linear_out = nnx.Linear(mlp_dim, din, rngs=rngs) def __call__(self, x): return self.linear_out(nnx.gelu(self.linear_in(x))) class MixerBlock(nnx.Module): def __init__( self, tokens_mlp_dim: int, channels_mlp_dim: int, hidden_dim: int, rngs: nnx.Rngs, ): self.tokens_mlp_dim = tokens_mlp_dim self.channels_mlp_dim = channels_mlp_dim self.hidden_dim = hidden_dim self.token_mixing = MlpBlock(tokens_mlp_dim, hidden_dim, rngs=rngs) self.channel_mixing = MlpBlock(channels_mlp_dim, hidden_dim, rngs=rngs) self.ln1 = nnx.LayerNorm(channels_mlp_dim, rngs=rngs) self.ln2 = nnx.LayerNorm(channels_mlp_dim, rngs=rngs) def __call__(self, x): y = self.ln1(x) y = y.swapaxes(1, 2) y = self.token_mixing(y) y = y.swapaxes(1, 2) x = x + y y = self.ln2(x) return x + self.channel_mixing(y) class MlpMixer(nnx.Module): def __init__( self, din: int, kernel_size: tuple[int, int], strides: tuple[int, int], num_blocks: int, hidden_dim: int, tokens_mlp_dim: int, channels_mlp_dim: int, rngs: nnx.Rngs, ): self.din = din self.kernel_size = kernel_size self.num_blocks = num_blocks self.hidden_dim = hidden_dim self.tokens_mlp_dim = tokens_mlp_dim self.channels_mlp_dim = channels_mlp_dim self.stem = nnx.Conv( din + 1, channels_mlp_dim, kernel_size=kernel_size, strides=strides, rngs=rngs, ) self.blocks = [ MixerBlock(tokens_mlp_dim, channels_mlp_dim, hidden_dim, rngs=rngs) for _ in range(num_blocks) ] self.pre_head_layer_norm = nnx.LayerNorm(channels_mlp_dim, rngs=rngs) self.conv_t = nnx.ConvTranspose( channels_mlp_dim, din, kernel_size=kernel_size, strides=strides, rngs=rngs ) def __call__(self, *, x, t): # add time feature to input t = einop(t, 'n -> n h w c', h=x.shape[1], w=x.shape[2], c=1) x = jnp.concatenate([x, t], axis=-1) # create patches x = self.stem(x) h, w = x.shape[1], x.shape[2] x = einop(x, 'n h w c -> n (h w) c') # apply blocks for block in self.blocks: x = block(x) x = self.pre_head_layer_norm(x) # recreate image x = einop(x, 'n (h w) c -> n h w c', h=h, w=w) x = self.conv_t(x) return x def main(argv): print(argv) mode: str = FLAGS.mode total_steps: int = FLAGS.total_steps batch_size: int = FLAGS.batch_size width: int = FLAGS.width depth: int = FLAGS.depth print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}') X = np.random.uniform(size=(batch_size, 28, 28, 1)) if mode == 'nnx' or mode == 'all': rngs = nnx.Rngs(0) flow = MlpMixer( din=1, kernel_size=(2, 2), strides=(2, 2), num_blocks=4, hidden_dim=512, tokens_mlp_dim=196, channels_mlp_dim=512, rngs=rngs, ) optimizer = nnx.Optimizer( flow, tx=optax.adamw(1e-4), wrt=nnx.Param ) t0 = time() mse = lambda a, b: jnp.mean((a - b) ** 2) @nnx.jit(donate_argnums=(0, 1, 2)) def train_step_nnx(flow, optimizer, rngs, x_1): print('JITTING NNX') x_0 = jax.random.normal(rngs(), x_1.shape) t = jax.random.uniform(rngs(), (len(x_1),)) x_t = jax.vmap(lambda x_0, x_1, t: (1 - t) * x_0 + t * x_1)(x_0, x_1, t) dx_t = x_1 - x_0 loss, grads = nnx.value_and_grad( lambda flow: mse(flow(x=x_t, t=t), dx_t) )(flow) optimizer.update(flow, grads) return loss losses = [] t0 = time() for step in tqdm(range(total_steps), desc='NNX'): loss = train_step_nnx(flow, optimizer, rngs, X) losses.append(loss) total_time = time() - t0 print('### NNX ###') print(f'final loss: {losses[-1]}') print('total time:', total_time) print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') if mode == 'jax' or mode == 'all': rngs = nnx.Rngs(0) flow = MlpMixer( din=1, kernel_size=(2, 2), strides=(2, 2), num_blocks=depth, hidden_dim=width, tokens_mlp_dim=196, channels_mlp_dim=width, rngs=rngs, ) optimizer = nnx.Optimizer( flow, tx=optax.adamw(1e-4), wrt=nnx.Param ) graphdef, state = nnx.split((flow, optimizer, rngs)) t0 = time() mse = lambda a, b: jnp.mean((a - b) ** 2) @partial(nnx.jit, donate_argnums=0) def train_step_jax(state, x_1): print('JITTING JAX') flow, optimizer, rngs = nnx.merge(graphdef, state) x_0 = jax.random.normal(rngs(), x_1.shape) t = jax.random.uniform(rngs(), (len(x_1),)) x_t = jax.vmap(lambda x_0, x_1, t: (1 - t) * x_0 + t * x_1)(x_0, x_1, t) dx_t = x_1 - x_0 loss, grads = nnx.value_and_grad( lambda flow: mse(flow(x=x_t, t=t), dx_t) )(flow) optimizer.update(flow, grads) state = nnx.state((flow, optimizer, rngs)) return loss, state losses = [] t0 = time() for step in tqdm(range(total_steps), desc='JAX'): loss, state = train_step_jax(state, X) losses.append(loss) nnx.update((flow, optimizer, rngs), state) total_time = time() - t0 print('### JAX ###') print(f'final loss: {losses[-1]}') print('total time:', total_time) print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') if __name__ == '__main__': app.run(main) ================================================ FILE: benchmarks/nnx_simple_training.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # %% from functools import partial import jax import jax.numpy as jnp import numpy as np import optax from time import time from flax import nnx from absl import flags from absl import app FLAGS = flags.FLAGS flags.DEFINE_enum( 'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in' ) flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps') flags.DEFINE_integer('batch_size', 32, 'Batch size') flags.DEFINE_integer('width', 32, 'Hidden layer size') flags.DEFINE_integer('depth', 5, 'Depth of the model') def dataset(X, Y, batch_size): while True: idx = np.random.choice(len(X), size=batch_size) yield X[idx], Y[idx] class Linear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): return x @ self.w + self.b class Block(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.linear = Linear(din, dout, rngs=rngs) self.bn = nnx.BatchNorm(dout, rngs=rngs) def __call__(self, x): return nnx.relu(self.bn(self.linear(x))) class Count(nnx.Variable): pass class MLP(nnx.Module): def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs): self.count = Count(jnp.array(0)) self.linear_in = Block(din, dhidden, rngs=rngs) self.intermediates = [ Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2) ] self.linear_out = Block(dhidden, dout, rngs=rngs) def __call__(self, x): self.count.value += 1 x = nnx.relu(self.linear_in(x)) for layer in self.intermediates: x = nnx.relu(layer(x)) x = self.linear_out(x) return x def main(argv): print(argv) mode: str = FLAGS.mode total_steps: int = FLAGS.total_steps batch_size: int = FLAGS.batch_size width: int = FLAGS.width depth: int = FLAGS.depth print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}') X = np.linspace(0, 1, 100)[:, None] Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) if mode == 'nnx' or mode == 'all': model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) tx = optax.sgd(1e-3) optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) t0 = time() @nnx.jit(donate_argnums=(0, 1)) def train_step_nnx(model: MLP, optimizer: nnx.Optimizer, batch): x, y = batch def loss_fn(model: MLP): y_pred = model(x) return jnp.mean((y - y_pred) ** 2) grads: nnx.State = nnx.grad(loss_fn)(model) optimizer.update(model, grads) @nnx.jit(donate_argnums=0) def test_step_nnx(model: MLP, batch): x, y = batch y_pred = model(x) loss = jnp.mean((y - y_pred) ** 2) return {'loss': loss} for step, batch in enumerate(dataset(X, Y, batch_size)): train_step_nnx(model, optimizer, batch) if step % 1000 == 0: logs = test_step_nnx(model, (X, Y)) if step >= total_steps - 1: break print('### NNX ###') print(f'final loss: {logs["loss"]}') total_time = time() - t0 print('total time:', total_time) print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') print('times called:', model.count.value) if mode == 'jax' or mode == 'all': model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) tx = optax.sgd(1e-3) optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) t0 = time() @partial(jax.jit, donate_argnums=0) def train_step_jax(state, batch): model, optimizer = nnx.merge(graphdef, state) x, y = batch def loss_fn(model: MLP): y_pred = model(x) return jnp.mean((y - y_pred) ** 2) grads = nnx.grad(loss_fn)(model) optimizer.update(model,grads) return nnx.state((model, optimizer)) @partial(jax.jit, donate_argnums=0) def test_step_jax(state, batch): model, optimizer = nnx.merge(graphdef, state) x, y = batch y_pred = model(x) loss = jnp.mean((y - y_pred) ** 2) state = nnx.state((model, optimizer)) return state, {'loss': loss} graphdef, state = nnx.split((model, optimizer)) for step, batch in enumerate(dataset(X, Y, batch_size)): state = train_step_jax(state, batch) if step % 1000 == 0: state, logs = test_step_jax(state, (X, Y)) if step >= total_steps - 1: break model, optimizer = nnx.merge(graphdef, state) print('### JAX ###') print(f'final loss: {logs["loss"]}') total_time = time() - t0 print('total time:', total_time) print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') print('times called:', model.count.value) if __name__ == '__main__': app.run(main) ================================================ FILE: benchmarks/nnx_state_traversal.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Example profile command: # python -m cProfile -o ~/tmp/overhead.prof benchmarks/nnx_graph_overhead.py --mode=nnx --depth=100 --total_steps=1000 # View profile (need to install snakeviz): # snakeviz ~/tmp/overhead.prof import jax from time import time from flax import nnx from absl import flags from absl import app FLAGS = flags.FLAGS flags.DEFINE_integer('total_steps', 1000, 'Total number of training steps') flags.DEFINE_integer('width', 4, 'Width of each level') flags.DEFINE_integer('depth', 4, 'Depth of the model') class NestedClass(nnx.Module): def __init__(self, width, depth): self.x = nnx.Variable(jax.numpy.ones((depth+1, ))) if depth > 0: for i in range(width): setattr(self, f'child{i}', NestedClass(width, depth-1)) def main(argv): print(argv) total_steps: int = FLAGS.total_steps width: int = FLAGS.width depth: int = FLAGS.depth model = NestedClass(width, depth) to_test = nnx.state(model) print(f'{total_steps=}, {width=}') #------------------------------------------------------------ # tree_flatten_with_path #------------------------------------------------------------ t0 = time() for _ in range(total_steps): jax.tree_util.tree_flatten_with_path(to_test) total_time = time() - t0 time_per_step = total_time / total_steps time_per_layer = time_per_step / depth print("### tree_flatten_with_path ###") print('total time:', total_time) print(f'time per step: {time_per_step * 1e6:.2f} µs') print(f'time per layer: {time_per_layer * 1e6:.2f} µs') #------------------------------------------------------------ # tree_map_with_path #------------------------------------------------------------ t0 = time() for _ in range(total_steps): jax.tree_util.tree_map_with_path(lambda _, x: x, to_test) total_time = time() - t0 time_per_step = total_time / total_steps time_per_layer = time_per_step / depth print("### tree_map_with_path ###") print('total time:', total_time) print(f'time per step: {time_per_step * 1e6:.2f} µs') print(f'time per layer: {time_per_layer * 1e6:.2f} µs') #------------------------------------------------------------ # tree_flatten #------------------------------------------------------------ t0 = time() for _ in range(total_steps): jax.tree_util.tree_flatten(to_test) total_time = time() - t0 time_per_step = total_time / total_steps time_per_layer = time_per_step / depth print("### tree_flatten ###") print('total time:', total_time) print(f'time per step: {time_per_step * 1e6:.2f} µs') print(f'time per layer: {time_per_layer * 1e6:.2f} µs') if __name__ == '__main__': app.run(main) ================================================ FILE: benchmarks/tracing/README.md ================================================ # Tracing and lowering benchmarks for Flax examples See Flax [documentation](https://flax.readthedocs.io/en/latest/examples/index.html) on their examples. ## Getting started bash ``` pip install -r benchmarks/tracing/requirements.txt # Benchmark trace and lower timing for all workloads. python tracing_benchmark.py # Profile a single example. python tracing_benchmark.py --example=wmt # Profile just tracing for a single example. python tracing_benchmark.py --example=wmt --mode=trace ``` ================================================ FILE: benchmarks/tracing/__init__.py ================================================ # Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: benchmarks/tracing/gemma.py ================================================ # Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Gemma helper functions.""" from typing import Any from flax import nnx from flax.benchmarks.tracing import tracing_benchmark from flax.examples.gemma import transformer as transformer_lib from flax.examples.gemma import utils from flax.examples.gemma.configs import default as gemma_config from flax.training import common_utils import google_benchmark import jax import jax.numpy as jnp import ml_collections import numpy as np import optax def rsqrt_schedule(init_value: float, shift: int = 0): def schedule(count): return init_value * (count + shift) ** -0.5 * shift**0.5 return schedule def create_learning_rate_schedule(learning_rate: float, warmup_steps: int): return optax.join_schedules( [ optax.linear_schedule( init_value=0, end_value=learning_rate, transition_steps=warmup_steps, ), rsqrt_schedule(init_value=learning_rate, shift=warmup_steps), ], boundaries=[warmup_steps], ) def compute_weighted_cross_entropy( logits, targets, weights=None, label_smoothing=0.0 ): if logits.ndim != targets.ndim + 1: raise ValueError( 'Incorrect shapes. Got shape %s logits and %s targets' % (str(logits.shape), str(targets.shape)) ) vocab_size = logits.shape[-1] confidence = 1.0 - label_smoothing low_confidence = (1.0 - confidence) / (vocab_size - 1) normalizing_constant = -( confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) ) soft_targets = common_utils.onehot( targets, vocab_size, on_value=confidence, off_value=low_confidence ) loss = -jnp.sum(soft_targets * nnx.log_softmax(logits), axis=-1) loss = loss - normalizing_constant normalizing_factor = np.prod(targets.shape) if weights is not None: loss = loss * weights normalizing_factor = weights.sum() return loss.sum(), normalizing_factor def compute_weighted_accuracy(logits, targets, weights=None): if logits.ndim != targets.ndim + 1: raise ValueError( 'Incorrect shapes. Got shape %s logits and %s targets' % (str(logits.shape), str(targets.shape)) ) loss = jnp.equal(jnp.argmax(logits, axis=-1), targets) normalizing_factor = np.prod(logits.shape[:-1]) if weights is not None: loss = loss * weights normalizing_factor = weights.sum() return loss.sum(), normalizing_factor def compute_metrics(logits, labels, weights, label_smoothing=0.0): loss, weight_sum = compute_weighted_cross_entropy( logits, labels, weights, label_smoothing ) acc, _ = compute_weighted_accuracy(logits, labels, weights) metrics = { 'loss': loss, 'accuracy': acc, 'denominator': weight_sum, } return metrics def train_step( state: utils.TrainState, batch, learning_rate_fn, label_smoothing=0.0, ): train_keys = ['inputs', 'inputs_position', 'inputs_segmentation', 'targets'] (inputs, inputs_positions, inputs_segmentation, targets) = ( batch.get(k, None) for k in train_keys ) pad_id = 0 weights = jnp.where(inputs > pad_id, 1, 0).astype(jnp.float32) input_mask = inputs > pad_id attention_mask = transformer_lib.make_causal_attn_mask(input_mask) mask = ( inputs_segmentation[:, :, None] == inputs_segmentation[:, None, :] ) attention_mask = jnp.logical_and(mask, attention_mask) def loss_fn(params): module = nnx.merge(state.graphdef, params) logits, _ = module( inputs, positions=inputs_positions, attention_mask=attention_mask, cache=None, ) loss, weight_sum = compute_weighted_cross_entropy( logits, targets, weights, label_smoothing ) mean_loss = loss / weight_sum return mean_loss, logits step = state.step lr = learning_rate_fn(step) grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, logits), grads = grad_fn(state.params) new_state = state.apply_gradients(grads=grads) metrics = compute_metrics(logits, targets, weights) metrics['learning_rate'] = lr return new_state, metrics def get_fake_batch(batch_size: int) -> Any: rng = jax.random.PRNGKey(0) batch = {} for k in ( 'inputs', 'inputs_position', 'inputs_segmentation', 'targets', 'targets_position', 'targets_segmentation', ): batch[k] = jax.random.randint(rng, (batch_size, 128), 0, 9999999, jnp.int32) return batch def get_apply_fn_and_args( config: ml_collections.ConfigDict, vocab_size: int | None = None, ) -> tuple[Any, tuple[Any, ...], dict[str, Any]]: if vocab_size is None: vocab_size = config.vocab_size if config.transformer_name is not None: model_config = transformer_lib.TransformerConfig.from_version_name( config.transformer_name, num_embed=vocab_size, dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, axis_rules=config.axis_rules, ) else: assert config.transformer_params is not None model_config = transformer_lib.TransformerConfig.from_dict( **config.transformer_params, num_embed=vocab_size, dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, axis_rules=config.axis_rules, ) devices_array = utils.create_device_mesh(config) mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) rng = jax.random.PRNGKey(config.seed) rng, init_rng = jax.random.split(rng) def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): return transformer_lib.Transformer(config, rngs=nnx.Rngs(params=key)) learning_rate_fn = create_learning_rate_schedule( learning_rate=config.learning_rate, warmup_steps=config.warmup_steps ) optimizer = optax.adamw( learning_rate_fn, b1=0.9, b2=0.98, eps=1e-9, weight_decay=config.weight_decay, ) state, state_sharding = utils.setup_initial_state( constructor, optimizer, model_config, init_rng, mesh ) data_sharding = jax.NamedSharding(mesh, jax.P(config.data_sharding)) jit_train_step = jax.jit( train_step, in_shardings=( state_sharding, data_sharding, ), # type: ignore out_shardings=(state_sharding, None), # type: ignore static_argnames=('learning_rate_fn', 'label_smoothing'), donate_argnums=0, ) batch = get_fake_batch(config.per_device_batch_size) batch = jax.tree.map(lambda x: jnp.asarray(x, device=data_sharding), batch) return ( jit_train_step, (state, batch, learning_rate_fn, 0.0), dict(), ) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_gemma_trace(state): tracing_benchmark.benchmark_tracing( get_apply_fn_and_args, gemma_config.get_config, state ) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_gemma_lower(state): tracing_benchmark.benchmark_lowering( get_apply_fn_and_args, gemma_config.get_config, state ) if __name__ == '__main__': tracing_benchmark.run_benchmarks() ================================================ FILE: benchmarks/tracing/imagenet.py ================================================ # Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ImageNet helper functions for benchmarking.""" import functools from typing import Any from flax import linen as nn from flax.benchmarks.tracing import tracing_benchmark from flax.examples.imagenet import models from flax.examples.imagenet.configs import default as imagenet_config from flax.training import common_utils from flax.training import dynamic_scale as dynamic_scale_lib from flax.training import train_state import google_benchmark import jax import jax.numpy as jnp import ml_collections import optax NUM_CLASSES = 1000 class TrainState(train_state.TrainState): batch_stats: Any dynamic_scale: dynamic_scale_lib.DynamicScale def create_model(*, model_cls, half_precision, **kwargs): platform = jax.local_devices()[0].platform if half_precision: if platform == 'tpu': model_dtype = jnp.bfloat16 else: model_dtype = jnp.float16 else: model_dtype = jnp.float32 return model_cls(num_classes=NUM_CLASSES, dtype=model_dtype, **kwargs) def initialized(key, image_size, model): input_shape = (1, image_size, image_size, 3) @jax.jit def init(*args): return model.init(*args) variables = init({'params': key}, jnp.ones(input_shape, model.dtype)) return variables['params'], variables['batch_stats'] def cross_entropy_loss(logits, labels): one_hot_labels = common_utils.onehot(labels, num_classes=NUM_CLASSES) xentropy = optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels) return jnp.mean(xentropy) def create_train_state( rng, config: ml_collections.ConfigDict, model, image_size, learning_rate_fn ): dynamic_scale = None platform = jax.local_devices()[0].platform if config.half_precision and platform == 'gpu': dynamic_scale = dynamic_scale_lib.DynamicScale() params, batch_stats = initialized(rng, image_size, model) tx = optax.sgd( learning_rate=learning_rate_fn, momentum=config.momentum, nesterov=True, ) state = TrainState.create( apply_fn=model.apply, params=params, tx=tx, batch_stats=batch_stats, dynamic_scale=dynamic_scale, ) return state def get_fake_batch(batch_size: int = 128) -> dict[str, jnp.ndarray]: images = jax.random.uniform( jax.random.key(0), (batch_size, 224, 224, 3), dtype=jnp.float32 ) labels = jax.random.randint( jax.random.key(1), (batch_size,), minval=0, maxval=1000, dtype=jnp.int32 ) return {'image': images, 'label': labels} class BenchmarkResNet(models.ResNet): @nn.compact def __call__(self, x, train: bool = True): conv = functools.partial(self.conv, use_bias=False, dtype=self.dtype) norm = functools.partial( nn.BatchNorm, use_running_average=not train, momentum=0.9, epsilon=1e-5, dtype=self.dtype, axis_name=None, ) x = conv( self.num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)], name='conv_init', )(x) x = norm(name='bn_init')(x) x = nn.relu(x) x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME') for i, block_size in enumerate(self.stage_sizes): for j in range(block_size): strides = (2, 2) if i > 0 and j == 0 else (1, 1) x = self.block_cls( self.num_filters * 2**i, strides=strides, conv=conv, norm=norm, act=self.act, )(x) x = jnp.mean(x, axis=(1, 2)) x = nn.Dense(self.num_classes, dtype=self.dtype)(x) x = jnp.asarray(x, self.dtype) return x def get_apply_fn_and_args( config: ml_collections.ConfigDict, ) -> tuple[Any, tuple[Any, ...], dict[str, Any]]: if config.model == 'ResNet50': model_cls = functools.partial( BenchmarkResNet, stage_sizes=[3, 4, 6, 3], block_cls=models.BottleneckResNetBlock, ) else: model_cls = getattr(models, config.model) model = create_model( model_cls=model_cls, half_precision=config.half_precision ) learning_rate_fn = lambda step: 0.1 rng = jax.random.key(0) image_size = 224 state = create_train_state( rng, config, model, image_size, learning_rate_fn ) batch = get_fake_batch(config.batch_size) return ( bench_train_step, (state, batch, learning_rate_fn), {}, ) @functools.partial(jax.jit, static_argnums=(2,)) def bench_train_step(state, batch, learning_rate_fn): def compute_metrics(logits, labels): loss = cross_entropy_loss(logits, labels) accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) metrics = { 'loss': loss, 'accuracy': accuracy, } return metrics def loss_fn(params): logits, new_model_state = state.apply_fn( {'params': params, 'batch_stats': state.batch_stats}, batch['image'], mutable=['batch_stats'], ) loss = cross_entropy_loss(logits, batch['label']) weight_penalty_params = jax.tree_util.tree_leaves(params) weight_decay = 0.0001 weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1) weight_penalty = weight_decay * 0.5 * weight_l2 loss = loss + weight_penalty return loss, (new_model_state, logits) step = state.step dynamic_scale = state.dynamic_scale lr = learning_rate_fn(step) if dynamic_scale: grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True) dynamic_scale, is_fin, aux, grads = grad_fn(state.params) else: grad_fn = jax.value_and_grad(loss_fn, has_aux=True) aux, grads = grad_fn(state.params) new_model_state, logits = aux[1] metrics = compute_metrics(logits, batch['label']) metrics['learning_rate'] = lr new_state = state.apply_gradients( grads=grads, batch_stats=new_model_state['batch_stats'], ) if dynamic_scale: new_state = new_state.replace( opt_state=jax.tree_util.tree_map( functools.partial(jnp.where, is_fin), new_state.opt_state, state.opt_state, ), params=jax.tree_util.tree_map( functools.partial(jnp.where, is_fin), new_state.params, state.params ), dynamic_scale=dynamic_scale, ) metrics['scale'] = dynamic_scale.scale return new_state, metrics @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_imagenet_trace(state): tracing_benchmark.benchmark_tracing( get_apply_fn_and_args, imagenet_config.get_config, state ) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_imagenet_lower(state): tracing_benchmark.benchmark_lowering( get_apply_fn_and_args, imagenet_config.get_config, state ) if __name__ == '__main__': tracing_benchmark.run_benchmarks() ================================================ FILE: benchmarks/tracing/lm1b.py ================================================ # Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """LM1B helper functions for benchmarking.""" import functools from typing import Any from flax import linen as nn from flax.benchmarks.tracing import tracing_benchmark from flax.examples.lm1b import models from flax.examples.lm1b.configs import default as lm1b_config from flax.training import common_utils from flax.training import train_state import google_benchmark import jax import jax.numpy as jnp import ml_collections import numpy as np import optax def rsqrt_schedule(init_value: float, shift: int = 0): def schedule(count): return init_value * (count + shift) ** -0.5 * shift**0.5 return schedule def create_learning_rate_schedule(learning_rate: float, warmup_steps: int): return optax.join_schedules( [ optax.linear_schedule( init_value=0, end_value=learning_rate, transition_steps=warmup_steps, ), rsqrt_schedule(init_value=learning_rate, shift=warmup_steps), ], boundaries=[warmup_steps], ) def compute_weighted_cross_entropy( logits, targets, weights=None, label_smoothing=0.0 ): if logits.ndim != targets.ndim + 1: raise ValueError( 'Incorrect shapes. Got shape %s logits and %s targets' % (str(logits.shape), str(targets.shape)) ) vocab_size = logits.shape[-1] confidence = 1.0 - label_smoothing low_confidence = (1.0 - confidence) / (vocab_size - 1) normalizing_constant = -( confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) ) soft_targets = common_utils.onehot( targets, vocab_size, on_value=confidence, off_value=low_confidence ) loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1) loss = loss - normalizing_constant normalizing_factor = np.prod(targets.shape) if weights is not None: loss = loss * weights normalizing_factor = weights.sum() return loss.sum(), normalizing_factor def compute_weighted_accuracy(logits, targets, weights=None): if logits.ndim != targets.ndim + 1: raise ValueError( 'Incorrect shapes. Got shape %s logits and %s targets' % (str(logits.shape), str(targets.shape)) ) loss = jnp.equal(jnp.argmax(logits, axis=-1), targets) normalizing_factor = np.prod(logits.shape[:-1]) if weights is not None: loss = loss * weights normalizing_factor = weights.sum() return loss.sum(), normalizing_factor def get_fake_batch(config: ml_collections.ConfigDict) -> dict[str, jnp.ndarray]: batch_size = config.per_device_batch_size max_len = config.max_target_length inputs = jax.random.randint( jax.random.key(0), (batch_size, max_len), minval=0, maxval=config.vocab_size, dtype=jnp.int32, ) inputs_position = jnp.tile( jnp.arange(max_len, dtype=jnp.int32), (batch_size, 1) ) inputs_segmentation = jnp.ones((batch_size, max_len), dtype=jnp.int32) return { 'inputs': inputs, 'inputs_position': inputs_position, 'inputs_segmentation': inputs_segmentation, } @functools.partial(jax.jit, static_argnums=(2, 3)) def bench_train_step(state, batch, config, learning_rate_fn): def compute_metrics(logits, labels, weights): loss, weight_sum = compute_weighted_cross_entropy( logits, labels, weights, 0.0 ) acc, _ = compute_weighted_accuracy(logits, labels, weights) metrics = { 'loss': loss, 'accuracy': acc, 'denominator': weight_sum, } return metrics inputs = batch['inputs'] inputs_positions = batch['inputs_position'] inputs_segmentation = batch['inputs_segmentation'] weights = jnp.where(inputs > 0, 1, 0).astype(jnp.float32) dropout_rng = jax.random.fold_in(jax.random.key(0), state.step) def loss_fn(params): logits = models.TransformerLM(config).apply( {'params': params}, inputs, inputs_positions=inputs_positions, inputs_segmentation=inputs_segmentation, rngs={'dropout': dropout_rng}, ) loss, weight_sum = compute_weighted_cross_entropy( logits, inputs, weights, 0.0 ) mean_loss = loss / weight_sum return mean_loss, logits step = state.step lr = learning_rate_fn(step) grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, logits), grads = grad_fn(state.params) new_state = state.apply_gradients(grads=grads) metrics = compute_metrics(logits, inputs, weights) metrics['learning_rate'] = lr return new_state, metrics def get_apply_fn_and_args( config: ml_collections.ConfigDict, ) -> tuple[Any, tuple[Any, ...], dict[str, Any]]: train_config = models.TransformerConfig( vocab_size=config.vocab_size, output_vocab_size=config.vocab_size, logits_via_embedding=config.logits_via_embedding, dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, emb_dim=config.emb_dim, num_heads=config.num_heads, num_layers=config.num_layers, qkv_dim=config.qkv_dim, mlp_dim=config.mlp_dim, max_len=max(config.max_target_length, config.max_eval_target_length), dropout_rate=config.dropout_rate, attention_dropout_rate=config.attention_dropout_rate, deterministic=False, decode=False, kernel_init=jax.nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.normal(stddev=1e-6), ) model = models.TransformerLM(train_config) learning_rate_fn = create_learning_rate_schedule( learning_rate=config.learning_rate, warmup_steps=config.warmup_steps ) optimizer = optax.adamw( learning_rate_fn, b1=0.9, b2=0.98, eps=1e-9, weight_decay=config.weight_decay, ) rng = jax.random.key(0) init_rng, _ = jax.random.split(rng) initial_variables = model.init( init_rng, jnp.ones( (config.per_device_batch_size, config.max_target_length), jnp.int32 ), jnp.ones( (config.per_device_batch_size, config.max_target_length), jnp.int32 ), jnp.ones( (config.per_device_batch_size, config.max_target_length), jnp.int32 ), ) state = train_state.TrainState.create( apply_fn=model.apply, params=initial_variables['params'], tx=optimizer, ) batch = get_fake_batch(config) return ( bench_train_step, (state, batch, train_config, learning_rate_fn), {}, ) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_lm1b_trace(state): tracing_benchmark.benchmark_tracing( get_apply_fn_and_args, lm1b_config.get_config, state ) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_lm1b_lower(state): tracing_benchmark.benchmark_lowering( get_apply_fn_and_args, lm1b_config.get_config, state ) if __name__ == '__main__': tracing_benchmark.run_benchmarks() ================================================ FILE: benchmarks/tracing/mnist.py ================================================ # Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """MNIST helper functions.""" from functools import partial from typing import Any from flax import nnx from flax.benchmarks.tracing import tracing_benchmark from flax.examples.mnist.configs import default as mnist_config import google_benchmark import jax import jax.numpy as jnp import ml_collections import optax class CNN(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs) self.batch_norm1 = nnx.BatchNorm(32, rngs=rngs) self.dropout1 = nnx.Dropout(rate=0.025) self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs) self.batch_norm2 = nnx.BatchNorm(64, rngs=rngs) self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2)) self.linear1 = nnx.Linear(3136, 256, rngs=rngs) self.dropout2 = nnx.Dropout(rate=0.025) self.linear2 = nnx.Linear(256, 10, rngs=rngs) def __call__(self, x, rngs: nnx.Rngs): x = self.avg_pool(nnx.relu(self.batch_norm1(self.dropout1(self.conv1(x), rngs=rngs)))) x = self.avg_pool(nnx.relu(self.batch_norm2(self.conv2(x)))) x = x.reshape(x.shape[0], -1) # flatten x = nnx.relu(self.dropout2(self.linear1(x), rngs=rngs)) x = self.linear2(x) return x def loss_fn(model: CNN, batch, rngs): logits = model(batch['image'], rngs) loss = optax.softmax_cross_entropy_with_integer_labels( logits=logits, labels=batch['label'] ).mean() return loss, logits def get_fake_batch(batch_size: int) -> dict[str, Any]: rng = jax.random.PRNGKey(0) images = jax.random.normal(rng, (batch_size, 28, 28, 1), jnp.float32) labels = jax.random.randint(rng, (batch_size,), 0, 10, jnp.int32) return {'image': images, 'label': labels} def get_apply_fn_and_args( config: ml_collections.ConfigDict, ) -> tuple[Any, tuple[Any, ...], dict[str, Any]]: model = CNN(rngs=nnx.Rngs(0)) batch = get_fake_batch(config.batch_size) rngs = nnx.Rngs(0) loss_fn_jit = jax.jit(loss_fn) return ( loss_fn_jit, (model, batch, rngs), dict(), ) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_mnist_trace(state): tracing_benchmark.benchmark_tracing( get_apply_fn_and_args, mnist_config.get_config, state ) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_mnist_lower(state): tracing_benchmark.benchmark_lowering( get_apply_fn_and_args, mnist_config.get_config, state ) if __name__ == '__main__': tracing_benchmark.run_benchmarks() ================================================ FILE: benchmarks/tracing/nlp_seq.py ================================================ # Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """NLP Sequence Tagging helper functions for benchmarking.""" import functools from typing import Any from flax import linen as nn from flax.benchmarks.tracing import tracing_benchmark from flax.examples.nlp_seq import models from flax.examples.nlp_seq.configs import default as nlp_seq_config from flax.training import common_utils from flax.training import train_state import google_benchmark import jax import jax.numpy as jnp import ml_collections import numpy as np import optax def create_learning_rate_scheduler( factors='constant * linear_warmup * rsqrt_decay', base_learning_rate=0.5, warmup_steps=8000, decay_factor=0.5, steps_per_decay=20000, steps_per_cycle=100000, ): factors = [n.strip() for n in factors.split('*')] def step_fn(step): ret = 1.0 for name in factors: if name == 'constant': ret *= base_learning_rate elif name == 'linear_warmup': ret *= jnp.minimum(1.0, step / warmup_steps) elif name == 'rsqrt_decay': ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) elif name == 'rsqrt_normalized_decay': ret *= jnp.sqrt(warmup_steps) ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) elif name == 'decay_every': ret *= decay_factor ** (step // steps_per_decay) elif name == 'cosine_decay': progress = jnp.maximum( 0.0, (step - warmup_steps) / float(steps_per_cycle) ) ret *= jnp.maximum( 0.0, 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))) ) else: raise ValueError('Unknown factor %s.' % name) return jnp.asarray(ret, dtype=jnp.float32) return step_fn def compute_weighted_cross_entropy(logits, targets, weights=None): if logits.ndim != targets.ndim + 1: raise ValueError( 'Incorrect shapes. Got shape %s logits and %s targets' % (str(logits.shape), str(targets.shape)) ) onehot_targets = common_utils.onehot(targets, logits.shape[-1]) loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1) normalizing_factor = onehot_targets.sum() if weights is not None: loss = loss * weights normalizing_factor = weights.sum() return loss.sum(), normalizing_factor def compute_weighted_accuracy(logits, targets, weights=None): if logits.ndim != targets.ndim + 1: raise ValueError( 'Incorrect shapes. Got shape %s logits and %s targets' % (str(logits.shape), str(targets.shape)) ) loss = jnp.equal(jnp.argmax(logits, axis=-1), targets) normalizing_factor = np.prod(logits.shape[:-1]) if weights is not None: loss = loss * weights normalizing_factor = weights.sum() return loss.sum(), normalizing_factor def get_fake_batch(config: ml_collections.ConfigDict) -> dict[str, jnp.ndarray]: batch_size = config.batch_size max_len = config.max_length vocab_size = config.vocab_size inputs = jax.random.randint( jax.random.key(0), (batch_size, max_len), minval=0, maxval=vocab_size, dtype=jnp.int32, ) targets = jax.random.randint( jax.random.key(1), (batch_size, max_len), minval=0, maxval=config.output_vocab_size, dtype=jnp.int32, ) return { 'inputs': inputs, 'targets': targets, } @functools.partial(jax.jit, static_argnums=(2, 3)) def bench_train_step(state, batch, config, learning_rate_fn): def local_compute_metrics(logits, labels, weights): loss, weight_sum = compute_weighted_cross_entropy( logits, labels, weights ) acc, _ = compute_weighted_accuracy(logits, labels, weights) metrics = { 'loss': loss, 'accuracy': acc, 'denominator': weight_sum, } return metrics inputs = batch['inputs'] targets = batch['targets'] weights = jnp.where(targets > 0, 1, 0).astype(jnp.float32) dropout_rng = jax.random.fold_in(jax.random.key(0), state.step) def loss_fn(params): model = models.Transformer(config) logits = model.apply( {'params': params}, inputs=inputs, train=True, rngs={'dropout': dropout_rng}, ) loss, weight_sum = compute_weighted_cross_entropy( logits, targets, weights ) mean_loss = loss / weight_sum return mean_loss, logits step = state.step lr = learning_rate_fn(step) grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, logits), grads = grad_fn(state.params) new_state = state.apply_gradients(grads=grads) metrics = local_compute_metrics(logits, targets, weights) metrics['learning_rate'] = lr return new_state, metrics def get_apply_fn_and_args( config: ml_collections.ConfigDict, ) -> tuple[Any, tuple[Any, ...], dict[str, Any]]: if not hasattr(config, 'vocab_size'): config.vocab_size = 30000 if not hasattr(config, 'output_vocab_size'): config.output_vocab_size = 50 model_config = models.TransformerConfig( vocab_size=config.vocab_size, output_vocab_size=config.output_vocab_size, max_len=config.max_length, ) model = models.Transformer(model_config) learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=config.learning_rate ) optimizer = optax.adamw( learning_rate_fn, b1=0.9, b2=0.98, eps=1e-9, weight_decay=config.weight_decay, ) rng = jax.random.key(0) init_rng, _ = jax.random.split(rng) init_batch = jnp.ones((config.batch_size, config.max_length), jnp.int32) initial_variables = model.init(init_rng, inputs=init_batch, train=False) state = train_state.TrainState.create( apply_fn=model.apply, params=initial_variables['params'], tx=optimizer, ) batch = get_fake_batch(config) return ( bench_train_step, (state, batch, model_config, learning_rate_fn), {}, ) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_nlp_seq_trace(state): tracing_benchmark.benchmark_tracing( get_apply_fn_and_args, nlp_seq_config.get_config, state ) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_nlp_seq_lower(state): tracing_benchmark.benchmark_lowering( get_apply_fn_and_args, nlp_seq_config.get_config, state ) if __name__ == '__main__': tracing_benchmark.run_benchmarks() ================================================ FILE: benchmarks/tracing/ogbg_molpcba.py ================================================ # Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """OGBG-MolPCBA helper functions for benchmarking.""" from typing import Any from clu import metrics import flax import flax.linen as nn from flax.benchmarks.tracing import tracing_benchmark from flax.examples.ogbg_molpcba import models from flax.examples.ogbg_molpcba.configs import default as ogbg_config from flax.training import train_state import google_benchmark import jax import jax.numpy as jnp import jraph import ml_collections import optax def create_model( config: ml_collections.ConfigDict, deterministic: bool ) -> nn.Module: if config.model == 'GraphNet': return models.GraphNet( latent_size=config.latent_size, num_mlp_layers=config.num_mlp_layers, message_passing_steps=config.message_passing_steps, output_globals_size=config.num_classes, dropout_rate=config.dropout_rate, skip_connections=config.skip_connections, layer_norm=config.layer_norm, use_edge_model=config.use_edge_model, deterministic=deterministic, ) if config.model == 'GraphConvNet': return models.GraphConvNet( latent_size=config.latent_size, num_mlp_layers=config.num_mlp_layers, message_passing_steps=config.message_passing_steps, output_globals_size=config.num_classes, dropout_rate=config.dropout_rate, skip_connections=config.skip_connections, layer_norm=config.layer_norm, deterministic=deterministic, ) raise ValueError(f'Unsupported model: {config.model}.') def create_optimizer( config: ml_collections.ConfigDict, ) -> optax.GradientTransformation: if config.optimizer == 'adam': return optax.adam(learning_rate=config.learning_rate) if config.optimizer == 'sgd': return optax.sgd( learning_rate=config.learning_rate, momentum=config.momentum ) raise ValueError(f'Unsupported optimizer: {config.optimizer}.') def binary_cross_entropy_with_mask( *, logits: jnp.ndarray, labels: jnp.ndarray, mask: jnp.ndarray ): assert logits.shape == labels.shape == mask.shape assert len(logits.shape) == 2 labels = jnp.where(mask, labels, -1) positive_logits = logits >= 0 relu_logits = jnp.where(positive_logits, logits, 0) abs_logits = jnp.where(positive_logits, logits, -logits) return relu_logits - (logits * labels) + (jnp.log(1 + jnp.exp(-abs_logits))) def predictions_match_labels( *, logits: jnp.ndarray, labels: jnp.ndarray, **kwargs ) -> jnp.ndarray: del kwargs preds = logits > 0 return (preds == labels).astype(jnp.float32) def replace_globals(graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: return graphs._replace(globals=jnp.ones([graphs.n_node.shape[0], 1])) def get_predicted_logits( state: train_state.TrainState, graphs: jraph.GraphsTuple, rngs: dict[str, jnp.ndarray] | None, ) -> jnp.ndarray: pred_graphs = state.apply_fn(state.params, graphs, rngs=rngs) logits = pred_graphs.globals return logits def get_valid_mask( labels: jnp.ndarray, graphs: jraph.GraphsTuple ) -> jnp.ndarray: labels_mask = ~jnp.isnan(labels) graph_mask = jraph.get_graph_padding_mask(graphs) return labels_mask & graph_mask[:, None] @flax.struct.dataclass class TrainMetrics(metrics.Collection): accuracy: metrics.Average.from_fun(predictions_match_labels) loss: metrics.Average.from_output('loss') @jax.jit def ogbg_train_step( state: train_state.TrainState, graphs: jraph.GraphsTuple, rngs: dict[str, jnp.ndarray], ) -> tuple[train_state.TrainState, metrics.Collection]: def loss_fn(params, graphs): curr_state = state.replace(params=params) labels = graphs.globals graphs = replace_globals(graphs) logits = get_predicted_logits(curr_state, graphs, rngs) mask = get_valid_mask(labels, graphs) loss = binary_cross_entropy_with_mask( logits=logits, labels=labels, mask=mask ) mean_loss = jnp.sum(jnp.where(mask, loss, 0)) / jnp.sum(mask) return mean_loss, (loss, logits, labels, mask) grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, (loss, logits, labels, mask)), grads = grad_fn(state.params, graphs) state = state.apply_gradients(grads=grads) metrics_update = TrainMetrics.single_from_model_output( loss=loss, logits=logits, labels=labels, mask=mask ) return state, metrics_update def get_fake_graphs(config: ml_collections.ConfigDict) -> jraph.GraphsTuple: rng = jax.random.key(0) n_graphs = config.batch_size n_total_nodes = n_graphs * 20 n_total_edges = n_graphs * 40 nodes = jax.random.normal(rng, (n_total_nodes, 9)) edges = jax.random.normal(rng, (n_total_edges, 3)) senders = jax.random.randint(rng, (n_total_edges,), 0, n_total_nodes) receivers = jax.random.randint(rng, (n_total_edges,), 0, n_total_nodes) n_node = jnp.full((n_graphs,), 20, dtype=jnp.int32) n_edge = jnp.full((n_graphs,), 40, dtype=jnp.int32) globals_ = jax.random.bernoulli( rng, shape=(n_graphs, config.num_classes) ).astype(jnp.float32) return jraph.GraphsTuple( nodes=nodes, edges=edges, senders=senders, receivers=receivers, n_node=n_node, n_edge=n_edge, globals=globals_, ) def get_apply_fn_and_args( config: ml_collections.ConfigDict, ) -> tuple[Any, tuple[Any, ...], dict[str, Any]]: rng = jax.random.key(0) rng, init_rng, dropout_rng = jax.random.split(rng, 3) graphs = get_fake_graphs(config) init_net = create_model(config, deterministic=True) init_graphs = replace_globals(graphs) params = jax.jit(init_net.init)(init_rng, init_graphs) tx = create_optimizer(config) net = create_model(config, deterministic=False) state = train_state.TrainState.create( apply_fn=net.apply, params=params, tx=tx ) return ( ogbg_train_step, (state, graphs, {'dropout': dropout_rng}), {}, ) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_ogbg_molpcba_trace(state): tracing_benchmark.benchmark_tracing( get_apply_fn_and_args, ogbg_config.get_config, state ) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_ogbg_molpcba_lower(state): tracing_benchmark.benchmark_lowering( get_apply_fn_and_args, ogbg_config.get_config, state ) if __name__ == '__main__': tracing_benchmark.run_benchmarks() ================================================ FILE: benchmarks/tracing/ppo.py ================================================ # Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PPO helper functions for RL benchmarking.""" from typing import Any from flax.benchmarks.tracing import tracing_benchmark from flax.examples.ppo import models from flax.examples.ppo import ppo_lib from flax.examples.ppo.configs import default as ppo_config import google_benchmark import jax import jax.numpy as jnp import ml_collections def get_fake_batch(batch_size: int = 256) -> tuple[jnp.ndarray, ...]: """Generate a batch of fake Atari observations and trajectories. Args: batch_size: Size of the minibatch. Returns: A tuple of (states, actions, old_log_probs, returns, advantages). """ # Atari observations: (batch_size, height, width, stacked_frames) states = jax.random.randint( jax.random.key(0), (batch_size, 84, 84, 4), minval=0, maxval=256, dtype=jnp.int32, ) # Actions: discrete action space (e.g., 6 actions for Pong) actions = jax.random.randint( jax.random.key(1), (batch_size,), minval=0, maxval=6, dtype=jnp.int32 ) # Old log probabilities from behavior policy old_log_probs = jax.random.normal( jax.random.key(2), (batch_size,), dtype=jnp.float32 ) # Returns (discounted cumulative rewards) returns = jax.random.normal( jax.random.key(3), (batch_size,), dtype=jnp.float32 ) # Advantages (GAE advantages) advantages = jax.random.normal( jax.random.key(4), (batch_size,), dtype=jnp.float32 ) return (states, actions, old_log_probs, returns, advantages) def get_apply_fn_and_args( config: ml_collections.ConfigDict, ) -> tuple[Any, tuple[Any, ...], dict[str, Any]]: """Returns the apply function and args for the given config. Args: config: The training configuration. Returns: A tuple of the apply function, args, and kwargs. """ # Create model (6 actions for Pong) num_outputs = 6 model = models.ActorCritic(num_outputs=num_outputs) # Initialize model parameters rng = jax.random.key(0) init_shape = jnp.ones((1, 84, 84, 4), jnp.float32) initial_params = model.init(rng, init_shape)['params'] # Create train state # For benchmarking, we don't need actual training loops, just one step train_steps = 1000 # Dummy value for state creation state = ppo_lib.create_train_state(initial_params, model, config, train_steps) # Generate fake trajectories trajectories = get_fake_batch(config.batch_size) # PPO hyperparameters clip_param = config.clip_param vf_coeff = config.vf_coeff entropy_coeff = config.entropy_coeff # ppo_lib.train_step is already JIT-compiled with static batch_size return ( ppo_lib.train_step, (state, trajectories, config.batch_size), { 'clip_param': clip_param, 'vf_coeff': vf_coeff, 'entropy_coeff': entropy_coeff, }, ) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_ppo_trace(state): tracing_benchmark.benchmark_tracing( get_apply_fn_and_args, ppo_config.get_config, state ) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_ppo_lower(state): tracing_benchmark.benchmark_lowering( get_apply_fn_and_args, ppo_config.get_config, state ) if __name__ == '__main__': tracing_benchmark.run_benchmarks() ================================================ FILE: benchmarks/tracing/requirements.txt ================================================ absl-py flax google-benchmark jax ml_collections numpy optax ================================================ FILE: benchmarks/tracing/run_all_benchmarks.sh ================================================ #!/bin/bash set -e export XLA_FLAGS=--xla_force_host_platform_device_count=8 TARGETS=( mnist vae sst2 gemma imagenet seq2seq lm1b nlp_seq ogbg_molpcba wmt ppo ) for target in "${TARGETS[@]}"; do echo "============================================" echo "Running benchmark: ${target}" echo "============================================" benchy "third_party/py/flax/benchmarks/tracing:${target}" echo "" done ================================================ FILE: benchmarks/tracing/seq2seq.py ================================================ # Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Seq2Seq helper functions.""" from typing import Any from flax import linen as nn from flax.benchmarks.tracing import tracing_benchmark from flax.examples.seq2seq import models from flax.examples.seq2seq.configs import default as seq2seq_config from flax.examples.seq2seq.input_pipeline import CharacterTable from flax.examples.seq2seq.input_pipeline import get_sequence_lengths from flax.examples.seq2seq.input_pipeline import mask_sequences from flax.training import train_state import google_benchmark import jax import jax.numpy as jnp import ml_collections import optax def cross_entropy_loss(logits, labels, lengths): xe = jnp.sum(nn.log_softmax(logits) * labels, axis=-1) masked_xe = jnp.mean(mask_sequences(xe, lengths)) return -masked_xe def compute_metrics(logits, labels, eos_id): lengths = get_sequence_lengths(labels, eos_id) loss = cross_entropy_loss(logits, labels, lengths) token_accuracy = jnp.argmax(logits, -1) == jnp.argmax(labels, -1) sequence_accuracy = ( jnp.sum(mask_sequences(token_accuracy, lengths), axis=-1) == lengths ) accuracy = jnp.mean(sequence_accuracy) metrics = { 'loss': loss, 'accuracy': accuracy, } return metrics @jax.jit def seq2seq_train_step(state, batch, lstm_rng, eos_id): labels = batch['answer'][:, 1:] lstm_key = jax.random.fold_in(lstm_rng, state.step) def loss_fn(params): logits, _ = state.apply_fn( {'params': params}, batch['query'], batch['answer'], rngs={'lstm': lstm_key}, ) loss = cross_entropy_loss( logits, labels, get_sequence_lengths(labels, eos_id) ) return loss, logits grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, logits), grads = grad_fn(state.params) state = state.apply_gradients(grads=grads) metrics = compute_metrics(logits, labels, eos_id) return state, metrics def get_fake_batch(batch_size: int, ctable: CharacterTable) -> dict[str, Any]: return ctable.get_batch(batch_size) def get_apply_fn_and_args( config: ml_collections.ConfigDict, ) -> tuple[Any, tuple[Any, ...], dict[str, Any]]: rng = jax.random.key(0) ctable = CharacterTable("0123456789+= ", config.max_len_query_digit) model = models.Seq2seq( teacher_force=False, hidden_size=config.hidden_size, eos_id=ctable.eos_id, vocab_size=ctable.vocab_size, ) rng1, rng2 = jax.random.split(rng) params = model.init( {'params': rng1, 'lstm': rng2}, jnp.ones(ctable.encoder_input_shape, jnp.float32), jnp.ones(ctable.decoder_input_shape, jnp.float32), )['params'] tx = optax.adam(config.learning_rate) state = train_state.TrainState.create( apply_fn=model.apply, params=params, tx=tx ) batch = get_fake_batch(config.batch_size, ctable) return seq2seq_train_step, (state, batch, rng, ctable.eos_id), {} @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_seq2seq_trace(state): tracing_benchmark.benchmark_tracing( get_apply_fn_and_args, seq2seq_config.get_config, state ) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_seq2seq_lower(state): tracing_benchmark.benchmark_lowering( get_apply_fn_and_args, seq2seq_config.get_config, state ) if __name__ == "__main__": tracing_benchmark.run_benchmarks() ================================================ FILE: benchmarks/tracing/sst2.py ================================================ # Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """SST2 helper functions.""" from typing import Any from flax.benchmarks.tracing import tracing_benchmark from flax.examples.sst2 import models from flax.examples.sst2.configs import default as sst2_config from flax.training import train_state as train_state_lib import google_benchmark import jax import jax.numpy as jnp import ml_collections import optax Array = jnp.ndarray TrainState = train_state_lib.TrainState @jax.vmap def sigmoid_cross_entropy_with_logits(*, labels: Array, logits: Array) -> Array: zeros = jnp.zeros_like(logits, dtype=logits.dtype) condition = logits >= zeros relu_logits = jnp.where(condition, logits, zeros) neg_abs_logits = jnp.where(condition, -logits, logits) return relu_logits - logits * labels + jnp.log1p(jnp.exp(neg_abs_logits)) def model_from_config(config: ml_collections.ConfigDict): model = models.TextClassifier( embedding_size=config.embedding_size, hidden_size=config.hidden_size, vocab_size=config.vocab_size, output_size=config.output_size, dropout_rate=config.dropout_rate, word_dropout_rate=config.word_dropout_rate, unk_idx=config.unk_idx, ) return model def get_initial_params(rng, model): token_ids = jnp.ones((2, 3), jnp.int32) lengths = jnp.ones((2,), dtype=jnp.int32) variables = model.init(rng, token_ids, lengths, deterministic=True) return variables['params'] def create_train_state(rng, config: ml_collections.ConfigDict, model): params = get_initial_params(rng, model) tx = optax.chain( optax.sgd(learning_rate=config.learning_rate, momentum=config.momentum), optax.add_decayed_weights(weight_decay=config.weight_decay), ) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx) return state def compute_metrics(*, labels: Array, logits: Array): if labels.ndim == 1: labels = jnp.expand_dims(labels, axis=1) loss = sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) binary_predictions = logits >= 0.0 binary_accuracy = jnp.equal(binary_predictions, labels) return { 'loss': jnp.sum(loss), 'accuracy': jnp.sum(binary_accuracy), 'count': logits.shape[0], } def train_step( state: TrainState, batch: dict[str, Array], rngs: dict[str, Any], ) -> tuple[TrainState, Any]: step = state.step rngs = {name: jax.random.fold_in(rng, step) for name, rng in rngs.items()} def loss_fn(params): variables = {'params': params} logits = state.apply_fn( variables, batch['token_ids'], batch['length'], deterministic=False, rngs=rngs, ) labels = batch['label'] if labels.ndim == 1: labels = jnp.expand_dims(labels, 1) loss = jnp.mean( sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) ) return loss, logits grad_fn = jax.value_and_grad(loss_fn, has_aux=True) value, grads = grad_fn(state.params) (_, logits) = value new_state = state.apply_gradients(grads=grads) metrics = compute_metrics(labels=batch['label'], logits=logits) return new_state, metrics def get_fake_batch(batch_size: int) -> dict[str, Any]: rng = jax.random.key(0) max_length = 60 token_ids = jax.random.randint( rng, (batch_size, max_length), 0, 1000, jnp.int32 ) lengths = jnp.full((batch_size,), max_length, jnp.int32) labels = jax.random.uniform(rng, (batch_size,), jnp.float32) return { 'token_ids': token_ids, 'length': lengths, 'label': labels, } def get_apply_fn_and_args( config: ml_collections.ConfigDict, ) -> tuple[Any, tuple[Any, ...], dict[str, Any]]: rng = jax.random.key(0) config = config.copy_and_resolve_references() if config.vocab_size is None: config.vocab_size = 1000 model = model_from_config(config) state = create_train_state(rng, config, model) batch = get_fake_batch(config.batch_size) _, dropout_rng = jax.random.split(rng) rngs = {'dropout': dropout_rng} train_step_jit = jax.jit(train_step) return train_step_jit, (state, batch, rngs), {} @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_sst2_trace(state): tracing_benchmark.benchmark_tracing( get_apply_fn_and_args, sst2_config.get_config, state ) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_sst2_lower(state): tracing_benchmark.benchmark_lowering( get_apply_fn_and_args, sst2_config.get_config, state ) if __name__ == '__main__': tracing_benchmark.run_benchmarks() ================================================ FILE: benchmarks/tracing/tracing_benchmark.py ================================================ # Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Shared benchmark utilities for Jax tracing flax examples.""" from collections.abc import Callable import sys from typing import Any from absl import app from absl import flags from absl import logging import google_benchmark import jax flags.DEFINE_enum( "mode", "trace_and_lower", ["trace", "lower", "trace_and_lower"], "Measure trace, lower, or trace_and_lower.", ) def clear_caches(state): state.pause_timing() jax.clear_caches() state.resume_timing() def benchmark_tracing( get_apply_fn_and_args: Callable[..., Any], get_config: Callable[[], Any], state: Any, ) -> None: """Benchmark for tracing a flax example.""" config = get_config() apply_fn, args, kwargs = get_apply_fn_and_args(config) while state: if flags.FLAGS.mode == 'trace' or flags.FLAGS.mode == 'trace_and_lower': _ = apply_fn.trace(*args, **kwargs) clear_caches(state) def benchmark_lowering( get_apply_fn_and_args: Callable[..., Any], get_config: Callable[[], Any], state: Any, platform: str = 'tpu', ) -> None: """Benchmark for lowering a flax example.""" config = get_config() apply_fn, args, kwargs = get_apply_fn_and_args(config) traced = apply_fn.trace(*args, **kwargs) while state: if flags.FLAGS.mode == 'lower' or flags.FLAGS.mode == 'trace_and_lower': _ = traced.lower(lowering_platforms=(platform,)) clear_caches(state) def run_single_example( get_apply_fn_and_args: Callable[..., Any], get_config: Callable[[], Any], ) -> None: """Run a single example for profiling.""" def main(argv): del argv if flags.FLAGS.mode == 'lower': raise ValueError( '`--mode=lower` is not supported when profiling a single example.' ) config = get_config() apply_fn, args, kwargs, *_ = get_apply_fn_and_args(config) traced = apply_fn.trace(*args, **kwargs) lowered = traced.lower(lowering_platforms=('tpu',)) logging.info('lowered: %s', lowered.as_text('hlo')) app.run(main) def run_benchmarks() -> None: """Run registered google_benchmark benchmarks.""" flags.FLAGS(sys.argv, known_only=True) flags.FLAGS.mark_as_parsed() google_benchmark.main() ================================================ FILE: benchmarks/tracing/vae.py ================================================ # Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """VAE helper functions.""" from typing import Any from flax import linen as nn from flax.benchmarks.tracing import tracing_benchmark from flax.examples.vae import models from flax.examples.vae.configs import default as vae_config from flax.training import train_state import google_benchmark import jax import jax.numpy as jnp import ml_collections import optax @jax.vmap def binary_cross_entropy_with_logits(logits, labels): logits = nn.log_sigmoid(logits) return -jnp.sum( labels * logits + (1.0 - labels) * jnp.log(-jnp.expm1(logits)) ) @jax.vmap def kl_divergence(mean, logvar): return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar)) def train_step(state, batch, z_rng, latents): def loss_fn(params): recon_x, mean, logvar = models.model(latents).apply( {'params': params}, batch, z_rng ) bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean() kld_loss = kl_divergence(mean, logvar).mean() loss = bce_loss + kld_loss return loss grads = jax.grad(loss_fn)(state.params) return state.apply_gradients(grads=grads) def get_fake_batch(batch_size: int) -> Any: return jnp.ones((batch_size, 784), jnp.float32) def get_apply_fn_and_args( config: ml_collections.ConfigDict, ) -> tuple[Any, tuple[Any, ...], dict[str, Any]]: rng = jax.random.key(0) rng, key = jax.random.split(rng) batch = get_fake_batch(config.batch_size) params = models.model(config.latents).init(key, batch, rng)['params'] state = train_state.TrainState.create( apply_fn=models.model(config.latents).apply, params=params, tx=optax.adam(config.learning_rate), ) train_step_jit = jax.jit(train_step, static_argnames=('latents',)) return ( train_step_jit, (state, batch, rng, config.latents), dict(), ) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_vae_trace(state): tracing_benchmark.benchmark_tracing( get_apply_fn_and_args, vae_config.get_config, state ) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_vae_lower(state): tracing_benchmark.benchmark_lowering( get_apply_fn_and_args, vae_config.get_config, state ) if __name__ == '__main__': tracing_benchmark.run_benchmarks() ================================================ FILE: benchmarks/tracing/wmt.py ================================================ # Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """WMT helper functions for benchmarking.""" import functools from typing import Any from flax import linen as nn from flax.benchmarks.tracing import tracing_benchmark from flax.examples.wmt import models from flax.examples.wmt.configs import default as wmt_config from flax.training import common_utils from flax.training import dynamic_scale as dynamic_scale_lib from flax.training import train_state import google_benchmark import jax import jax.numpy as jnp import ml_collections import numpy as np import optax class TrainState(train_state.TrainState): dynamic_scale: dynamic_scale_lib.DynamicScale def rsqrt_schedule(init_value: float, shift: int = 0): def schedule(count): return init_value * (count + shift) ** -0.5 * shift**0.5 return schedule def create_learning_rate_schedule(learning_rate: float, warmup_steps: int): return optax.join_schedules( [ optax.linear_schedule( init_value=0, end_value=learning_rate, transition_steps=warmup_steps, ), rsqrt_schedule(init_value=learning_rate, shift=warmup_steps), ], boundaries=[warmup_steps], ) def preferred_dtype(config): platform = jax.local_devices()[0].platform if config.use_mixed_precision: if platform == 'tpu': return jnp.bfloat16 elif platform == 'gpu': return jnp.float16 return jnp.float32 def compute_weighted_cross_entropy( logits, targets, weights=None, label_smoothing=0.0 ): if logits.ndim != targets.ndim + 1: raise ValueError( 'Incorrect shapes. Got shape %s logits and %s targets' % (str(logits.shape), str(targets.shape)) ) vocab_size = logits.shape[-1] confidence = 1.0 - label_smoothing low_confidence = (1.0 - confidence) / (vocab_size - 1) normalizing_constant = -( confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) ) soft_targets = common_utils.onehot( targets, vocab_size, on_value=confidence, off_value=low_confidence ) loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1) loss = loss - normalizing_constant normalizing_factor = np.prod(targets.shape) if weights is not None: loss = loss * weights normalizing_factor = weights.sum() return loss.sum(), normalizing_factor def compute_weighted_accuracy(logits, targets, weights=None): if logits.ndim != targets.ndim + 1: raise ValueError( 'Incorrect shapes. Got shape %s logits and %s targets' % (str(logits.shape), str(targets.shape)) ) loss = jnp.equal(jnp.argmax(logits, axis=-1), targets) normalizing_factor = np.prod(logits.shape[:-1]) if weights is not None: loss = loss * weights normalizing_factor = weights.sum() return loss.sum(), normalizing_factor def compute_metrics(logits, labels, weights, label_smoothing=0.0): loss, weight_sum = compute_weighted_cross_entropy( logits, labels, weights, label_smoothing ) acc, _ = compute_weighted_accuracy(logits, labels, weights) metrics = { 'loss': loss, 'accuracy': acc, 'denominator': weight_sum, } return metrics def wmt_train_step( state, batch, config, learning_rate_fn, label_smoothing=0.0, dropout_rng=None, ): train_keys = [ 'inputs', 'targets', 'inputs_position', 'targets_position', 'inputs_segmentation', 'targets_segmentation', ] ( inputs, targets, inputs_positions, targets_positions, inputs_segmentation, targets_segmentation, ) = (batch.get(k, None) for k in train_keys) weights = jnp.where(targets > 0, 1, 0).astype(jnp.float32) dropout_rng = jax.random.fold_in(dropout_rng, state.step) def loss_fn(params): logits = models.Transformer(config).apply( {'params': params}, inputs, targets, inputs_positions=inputs_positions, targets_positions=targets_positions, inputs_segmentation=inputs_segmentation, targets_segmentation=targets_segmentation, rngs={'dropout': dropout_rng}, ) loss, weight_sum = compute_weighted_cross_entropy( logits, targets, weights, label_smoothing ) mean_loss = loss / weight_sum return mean_loss, logits step = state.step if state.dynamic_scale: grad_fn = state.dynamic_scale.value_and_grad(loss_fn, has_aux=True) dynamic_scale, is_fin, (_, logits), grads = grad_fn(state.params) state = state.replace(dynamic_scale=dynamic_scale) else: grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, logits), grads = grad_fn(state.params) new_state = state.apply_gradients(grads=grads) metrics = compute_metrics(logits, targets, weights) metrics['learning_rate'] = learning_rate_fn(step) if state.dynamic_scale: select_fn = functools.partial(jnp.where, is_fin) new_state = new_state.replace( opt_state=jax.tree_util.tree_map( select_fn, new_state.opt_state, state.opt_state ), params=jax.tree_util.tree_map( select_fn, new_state.params, state.params ), ) metrics['loss_scale'] = dynamic_scale.scale * metrics['denominator'] return new_state, metrics def get_fake_batch(config: ml_collections.ConfigDict) -> dict[str, Any]: rng = jax.random.key(0) batch_size = config.per_device_batch_size max_len = config.max_target_length inputs = jax.random.randint( rng, (batch_size, max_len), 0, config.vocab_size, jnp.int32 ) targets = jax.random.randint( rng, (batch_size, max_len), 0, config.vocab_size, jnp.int32 ) return { 'inputs': inputs, 'targets': targets, 'inputs_position': None, 'targets_position': None, 'inputs_segmentation': None, 'targets_segmentation': None, } def get_apply_fn_and_args( config: ml_collections.ConfigDict, ) -> tuple[Any, tuple[Any, ...], dict[str, Any]]: dtype = preferred_dtype(config) train_config = models.TransformerConfig( vocab_size=config.vocab_size, output_vocab_size=config.vocab_size, share_embeddings=config.share_embeddings, logits_via_embedding=config.logits_via_embedding, dtype=dtype, emb_dim=config.emb_dim, num_heads=config.num_heads, num_layers=config.num_layers, qkv_dim=config.qkv_dim, mlp_dim=config.mlp_dim, max_len=max(config.max_target_length, config.max_eval_target_length), dropout_rate=config.dropout_rate, attention_dropout_rate=config.attention_dropout_rate, deterministic=False, decode=False, kernel_init=jax.nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.normal(stddev=1e-6), ) model = models.Transformer(train_config) learning_rate_fn = create_learning_rate_schedule( learning_rate=config.learning_rate, warmup_steps=config.warmup_steps ) optimizer = optax.adamw( learning_rate_fn, b1=0.9, b2=0.98, eps=1e-9, weight_decay=config.weight_decay, ) rng = jax.random.key(0) init_rng, dropout_rng = jax.random.split(rng) batch = get_fake_batch(config) inputs = batch['inputs'] targets = batch['targets'] initial_variables = model.init(init_rng, inputs, targets) dynamic_scale = None platform = jax.local_devices()[0].platform if config.use_mixed_precision and platform == 'gpu': dynamic_scale = dynamic_scale_lib.DynamicScale() state = TrainState.create( apply_fn=model.apply, params=initial_variables['params'], tx=optimizer, dynamic_scale=dynamic_scale, ) jit_train_step = jax.jit( wmt_train_step, static_argnums=(2, 3, 4), ) return ( jit_train_step, (state, batch, train_config, learning_rate_fn, 0.0, dropout_rng), {}, ) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_wmt_trace(state): tracing_benchmark.benchmark_tracing( get_apply_fn_and_args, wmt_config.get_config, state ) @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) def test_flax_wmt_lower(state): tracing_benchmark.benchmark_lowering( get_apply_fn_and_args, wmt_config.get_config, state ) if __name__ == '__main__': tracing_benchmark.run_benchmarks() ================================================ FILE: contributing.md ================================================ # How to Contribute Please see https://flax.readthedocs.io/en/latest/contributing.html for more information. ================================================ FILE: docs/.gitignore ================================================ _formatted_howtos ================================================ FILE: docs/.readthedocs.yaml ================================================ # .readthedocs.yml # Read the Docs configuration file # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details # Required version: 2 build: os: ubuntu-22.04 tools: python: "3.12" # Build documentation in the docs/ directory with Sphinx sphinx: configuration: docs/conf.py # Optionally build your docs in additional formats such as PDF and ePub formats: - htmlzip - epub # - pdf # Optionally set the version of Python and requirements required to build your docs python: install: - method: pip path: . extra_requirements: - all - testing - docs ================================================ FILE: docs/Makefile ================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs/README.md ================================================ # Deprecation This folder contains the deprecated Flax Linen documentation. For the latest Flax NNX docs, check out the `docs_nnx` folder. # Where to find the docs The FLAX Linen documentation can be found here: https://flax-linen.readthedocs.io/en/latest/ # How to build the docs 1. Clone the `flax` repository with `git clone https://github.com/google/flax.git`. 1. In the main `flax` folder, install the required dependencies using `uv pip install -e .[docs]`. 1. [Optional] If you need to make any local changes to the docs, create and switch to a branch. Make your changes to the docs in that branch. 1. To build the docs, in the `flax/docs` folder run the make script: `make html`. Alternatively, install [`entr`](https://github.com/eradman/entr/), which helps run arbitrary commands when files change. Then run `find ../ ! -regex '.*/[\.|\_].*' | entr -s 'make html'`. 1. If the build is successful, you should get the `The HTML pages are in _build/html.` message. You can preview the docs in `flax/docs/_build/html`. # How to run embedded code tests We use `doctest` blocks for embedded code in documents, that are also tested. Learn more at https://www.sphinx-doc.org/en/master/usage/extensions/doctest.html To run tests locally, run `make doctest` # How to write code documentation Our documentation is written in reStructuredText for Sphinx. It is a meta-language that is compiled into online documentation. For more details, check out [Sphinx's documentation](https://www.sphinx-doc.org/en/master/usage/restructuredtext/index.html). As a result, our docstrings adhere to a specific syntax that has to be kept in mind. Below we provide some guidelines. To learn how to contribute to Jupyter Notebooks or other formats in Flax docs, refer to the dedicated [Contributing](https://flax.readthedocs.io/en/latest/contributing.html) page. ## How much information to put in a docstring Docstring should be informative. We prefer to err on the side of too much documentation than too little. For instance, providing a one-line explanation to a new `Module` which implements new functionality is not sufficient. Furthermore, we highly encourage adding examples to your docstrings, so users can directly see how code can be used. ## How to write inline tested code We use [doctest](https://docs.python.org/3/library/doctest.html) syntax for writing examples in documentation. These examples are ran as tests as part of our CI process. In order to write `doctest` code in your documentation, please use the following notation: ```bash # Example code:: # # def sum(a, b): # return a + b # # sum(0, 1) ``` The `Example code` string at the beginning can be replaced by anything as long as there are two semicolons and a newline following it, and the code is indented. ## How to use "code font" When writing code font in a docstring, please use double backticks. Example: ```bash # This returns a ``str`` object. ``` Note that argument names and objects like True, None or any strings should usually be put in `code`. ## How to create cross-references/links It is possible to create cross-references to other classes, functions, and methods. In the following, `obj_typ` is either `class`, `func`, or `meth`. ```bash # First method: # :`path_to_obj` # Second method: # ::`description ` ``` You can use the second method if the `path_to_obj` is very long. Some examples: ```bash # Create: a reference to class flax.linen.Module. # :class:`flax.linen.Module` # Create a reference to local function my_func. # :func:`my_func` # Create a reference "Module.apply()" to method flax.linen.Module.apply. # :meth:`Module.apply() ` # ``` To creata a hyperlink, use the following syntax: ```bash # Note the double underscore at the end: # `Link to Google `__ ``` ### How to specify arguments for classes and methods * Class attributes should be specified using the `Attributes:` tag. * Method argument should be specified using the `Args:` tags. * All attributes and arguments should have types. Here is an example from our library: ```python class DenseGeneral(Module): """A linear transformation with flexible axes. Attributes: features: int or tuple with number of output features. axis: int or tuple with axes to apply the transformation on. For instance, (-2, -1) will apply the transformation to the last two axes. batch_dims: tuple with batch axes. use_bias: whether to add a bias to the output (default: True). dtype: the dtype of the computation (default: float32). kernel_init: initializer function for the weight matrix. bias_init: initializer function for the bias. precision: numerical precision of the computation see `jax.lax.Precision` for details. """ features: Union[int, Iterable[int]] axis: Union[int, Iterable[int]] = -1 batch_dims: Iterable[int] = () use_bias: bool = True dtype: Dtype = jnp.float32 kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros precision: Any = None @compact def __call__(self, inputs: Array) -> Array: """Applies a linear transformation to the inputs along multiple dimensions. Args: inputs: The nd-array to be transformed. Returns: The transformed input. """ ... ``` ================================================ FILE: docs/_ext/codediff.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Sphinx directive for creating code diff tables. Use directive as follows: .. codediff:: :title: , --- In order to highlight a line of code, append "#!" to it. """ import sphinx from docutils import nodes from docutils.parsers.rst import directives from docutils.statemachine import ViewList from sphinx.util.docutils import SphinxDirective MISSING = object() class CodeDiffParser: def parse( self, lines: list[str], title: str, groups: list[str] | None = None, skip_test: str | None = None, code_sep: str = '---', sync: object = MISSING, ): """Parse the code diff block and format it so that it renders in different tabs and is tested by doctest. For example: .. testcode:: tab0, tab2, tab3 .. codediff:: :title: Tab 0, Tab 1, Tab 2, Tab 3 :groups: tab0, tab1, tab2, tab3 :skip_test: tab1, tab3 --- --- --- For group tab0: and are executed. For group tab1: Nothing is executed. For group tab2: and are executed. For group tab3: is executed. Arguments: lines: a string list, where each element is a single string code line title: a single string that contains the titles of each tab (they should be separated by commas) groups: a single string that contains the group of each tab (they should be separated by commas). Code snippets that are part of the same group will be executed together. If groups=None, then the group names will default to the tab title names. skip_test: a single string denoting which group(s) to skip testing (they should be separated by commas). This is useful for legacy code snippets that no longer run correctly anymore. If skip_test=None, then no tests are skipped. code_sep: the separator character(s) used to denote a separate code block for a new tab. The default code separator is '---'. sync: an option for Sphinx directives, that will sync all tabs together. This means that if the user clicks to switch to another tab, all tabs will switch to the new tab. """ titles = [t.strip() for t in title.split(',')] num_tabs = len(titles) sync = sync is not MISSING # skip legacy code snippets in upgrade guides if skip_test is not None: skip_tests = {index.strip() for index in skip_test.split(',')} else: skip_tests = set() code_blocks = '\n'.join(lines) if code_blocks.count(code_sep) != num_tabs - 1: raise ValueError( f'Expected {num_tabs-1} code separator(s) for {num_tabs} tab(s), but got {code_blocks.count(code_sep)} code separator(s) instead.' ) code_blocks = [ code_block.split('\n') for code_block in code_blocks.split(code_sep + '\n') ] # list[code_tab_list1[string_line1, ...], ...] # by default, put each code snippet in a different group denoted by an index number, to be executed separately if groups is not None: groups = [group_name.strip() for group_name in groups.split(',')] else: groups = titles if len(groups) != num_tabs: raise ValueError( f'Expected {num_tabs} group assignment(s) for {num_tabs} tab(s), but got {len(groups)} group assignment(s) instead.' ) tabs = [] test_codes = [] for i, code_block in enumerate(code_blocks): if groups[i] not in skip_tests: test_codes.append((code_block, groups[i])) tabs.append((titles[i], self._code_block(code_block))) output = self._tabs(*tabs, sync=sync) return output, test_codes def _code_block(self, lines): """Creates a codeblock.""" # Remove right trailing whitespace so we can detect the comments. lines = [x.rstrip() for x in lines] highlight = lambda x: x.endswith('#!') code = map(lambda x: x[:-2].rstrip() if highlight(x) else x, lines) highlights = [i + 1 for i in range(len(lines)) if highlight(lines[i])] highlights = ','.join(str(i) for i in highlights) directive = ['.. code-block:: python'] if highlights: directive += [f' :emphasize-lines: {highlights}'] # Indent code and add empty line so the code is picked up by the directive. return directive + [''] + list(map(lambda x: ' ' + x, code)) def _tabs(self, *contents: tuple[str, list[str]], sync): output = ['.. tab-set::'] + [' '] for title, content in contents: output += [f' .. tab-item:: {title}'] if sync: key = title.strip() output += [f' :sync: {key}'] output += [' '] output += [' ' + line for line in content] return output class CodeDiffDirective(SphinxDirective): has_content = True option_spec = { 'title': directives.unchanged, 'groups': directives.unchanged, 'skip_test': directives.unchanged, 'code_sep': directives.unchanged, 'sync': directives.flag, } def run(self): table_code, test_codes = CodeDiffParser().parse( list(self.content), **self.options ) # Create a test node as a comment node so it won't show up in the docs. # We add attribute "testnodetype" so it is be picked up by the doctest # builder. This functionality is not officially documented but can be found # in the source code: # https://github.com/sphinx-doc/sphinx/blob/master/sphinx/ext/doctest.py # (search for 'testnodetype'). test_nodes = [] for test_code, group in test_codes: test_node = nodes.comment( '\n'.join(test_code), '\n'.join(test_code), testnodetype='testcode', groups=[group], ) self.set_source_info(test_node) test_node['options'] = {} test_node['language'] = 'python3' test_nodes.append(test_node) # The table node is the side-by-side diff view that will be shown on RTD. table_node = nodes.paragraph() self.content = ViewList(table_code, self.content.parent) self.state.nested_parse(self.content, self.content_offset, table_node) return [table_node] + test_nodes def setup(app): app.add_directive('codediff', CodeDiffDirective) return { 'version': sphinx.__display_version__, 'parallel_read_safe': True, 'parallel_write_safe': True, } ================================================ FILE: docs/_ext/codediff_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for codediff Sphinx extension.""" from absl.testing import parameterized from codediff import CodeDiffParser class CodeDiffTest(parameterized.TestCase): def test_parse(self): input_text = r"""@jax.jit #! def get_initial_params(key): #! init_val = jnp.ones((1, 28, 28, 1), jnp.float32) initial_params = CNN().init(key, init_val)['params'] extra_line return initial_params --- @jax.pmap #! def get_initial_params(key): init_val = jnp.ones((1, 28, 28, 1), jnp.float32) initial_params = CNN().init(key, init_val)['params'] return initial_params""" expected_table = """.. tab-set::\n \n .. tab-item:: Single device\n \n .. code-block:: python\n :emphasize-lines: 1,2\n \n @jax.jit\n def get_initial_params(key):\n init_val = jnp.ones((1, 28, 28, 1), jnp.float32)\n initial_params = CNN().init(key, init_val)['params']\n extra_line\n return initial_params\n \n .. tab-item:: Ensembling on multiple devices\n \n .. code-block:: python\n :emphasize-lines: 1\n \n @jax.pmap\n def get_initial_params(key):\n init_val = jnp.ones((1, 28, 28, 1), jnp.float32)\n initial_params = CNN().init(key, init_val)['params']\n return initial_params""" expected_testcodes = [ r"""@jax.jit #! def get_initial_params(key): #! init_val = jnp.ones((1, 28, 28, 1), jnp.float32) initial_params = CNN().init(key, init_val)['params'] extra_line return initial_params """, r"""@jax.pmap #! def get_initial_params(key): init_val = jnp.ones((1, 28, 28, 1), jnp.float32) initial_params = CNN().init(key, init_val)['params'] return initial_params""", ] title_left = 'Single device' title_right = 'Ensembling on multiple devices' actual_table, actual_testcodes = CodeDiffParser().parse( lines=input_text.split('\n'), title=f'{title_left}, {title_right}', ) actual_table = '\n'.join(actual_table) actual_testcodes = ['\n'.join(testcode) for testcode, _ in actual_testcodes] self.assertEqual(expected_table, actual_table) self.assertEqual(expected_testcodes[0], actual_testcodes[0]) self.assertEqual(expected_testcodes[1], actual_testcodes[1]) @parameterized.parameters( { 'input_text': r"""x = 1 --- x = 2 """, 'title': 'Tab 0, Tab1, Tab2', 'groups': None, 'error_msg': 'Expected 2 code separator\\(s\\) for 3 tab\\(s\\), but got 1 code separator\\(s\\) instead.', }, { 'input_text': r"""x = 1 --- x = 2 --- x = 3 --- x = 4 """, 'title': 'Tab 0, Tab1, Tab2', 'groups': None, 'error_msg': 'Expected 2 code separator\\(s\\) for 3 tab\\(s\\), but got 3 code separator\\(s\\) instead.', }, { 'input_text': r"""x = 1 --- x = 2 --- x = 3 """, 'title': 'Tab 0, Tab1, Tab2', 'groups': 'tab0, tab2', 'error_msg': 'Expected 3 group assignment\\(s\\) for 3 tab\\(s\\), but got 2 group assignment\\(s\\) instead.', }, { 'input_text': r"""x = 1 --- x = 2 --- x = 3 """, 'title': 'Tab 0, Tab1, Tab2', 'groups': 'tab0, tab1, tab2, tab3', 'error_msg': 'Expected 3 group assignment\\(s\\) for 3 tab\\(s\\), but got 4 group assignment\\(s\\) instead.', }, ) def test_parse_errors(self, input_text, title, groups, error_msg): with self.assertRaisesRegex(ValueError, error_msg): _, _ = CodeDiffParser().parse( lines=input_text.split('\n'), title=title, groups=groups, ) ================================================ FILE: docs/_ext/flax_module.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Sphinx directive for visualizing Flax modules. Use directive as follows: .. flax_module:: :module: flax.linen :class: Dense """ import importlib import sphinx import sphinx.ext.autosummary.generate as ag from docutils import nodes from docutils.parsers.rst import directives from docutils.statemachine import ViewList from sphinx.util.docutils import SphinxDirective from docs.conf_sphinx_patch import generate_autosummary_content def render_module(modname: str, qualname: str, app): parent = importlib.import_module(modname) obj = getattr(parent, qualname) template = ag.AutosummaryRenderer(app) template_name = 'flax_module' imported_members = False recursive = False context = {} return generate_autosummary_content( qualname, obj, parent, template, template_name, imported_members, app, recursive, context, modname, qualname, ) class FlaxModuleDirective(SphinxDirective): has_content = True option_spec = { 'module': directives.unchanged, 'class': directives.unchanged, } def run(self): module_template = render_module( self.options['module'], self.options['class'], self.env.app ) module_template = module_template.splitlines() # Create a container for the rendered nodes container_node = nodes.container() self.content = ViewList(module_template, self.content.parent) self.state.nested_parse(self.content, self.content_offset, container_node) return [container_node] def setup(app): app.add_directive('flax_module', FlaxModuleDirective) return { 'version': sphinx.__display_version__, 'parallel_read_safe': True, 'parallel_write_safe': True, } ================================================ FILE: docs/_static/css/flax_theme.css ================================================ @import url("theme.css"); .wy-nav-content { max-width: 1290px; } .rst-content table.docutils { width: 100%; } .rst-content table.docutils td { vertical-align: top; padding: 0; } .rst-content table.docutils td p { padding: 8px; } .rst-content div[class^=highlight] { border: 0; margin: 0; } ================================================ FILE: docs/_templates/autosummary/flax_module.rst ================================================ {{ fullname | escape | underline }} .. currentmodule:: {{ module }} .. autoclass:: {{ objname }} :exclude-members: .. automethod:: __call__ {% block methods %} {% for item in methods %} {%- if item not in inherited_members and item not in annotations and not item in ['__init__', 'setup'] %} .. automethod:: {{ item }} {%- endif %} {%- endfor %} {% if methods %} .. rubric:: Methods .. autosummary:: {% for item in methods %} {%- if item not in inherited_members and item not in annotations and not item in ['__init__', 'setup'] %} ~{{ name }}.{{ item }} {%- endif %} {%- endfor %} {% endif %} {% endblock %} ================================================ FILE: docs/api_reference/flax.core.frozen_dict.rst ================================================ flax.core.frozen_dict package ============================= .. currentmodule:: flax.core.frozen_dict .. autoclass:: FrozenDict :members: pretty_repr, copy, pop, unfreeze, tree_flatten .. autofunction:: freeze .. autofunction:: unfreeze .. autofunction:: copy .. autofunction:: pop .. autofunction:: pretty_repr ================================================ FILE: docs/api_reference/flax.cursor.rst ================================================ flax.cursor package ============================= The Cursor API allows for mutability of pytrees. This API provides a more ergonomic solution to making partial-updates of deeply nested immutable data structures, compared to making many nested ``dataclasses.replace`` calls. To illustrate, consider the example below:: >>> from flax.cursor import cursor >>> import dataclasses >>> from typing import Any >>> @dataclasses.dataclass(frozen=True) >>> class A: ... x: Any >>> a = A(A(A(A(A(A(A(0))))))) To replace the int ``0`` using ``dataclasses.replace``, we would have to write many nested calls:: >>> a2 = dataclasses.replace( ... a, ... x=dataclasses.replace( ... a.x, ... x=dataclasses.replace( ... a.x.x, ... x=dataclasses.replace( ... a.x.x.x, ... x=dataclasses.replace( ... a.x.x.x.x, ... x=dataclasses.replace( ... a.x.x.x.x.x, ... x=dataclasses.replace(a.x.x.x.x.x.x, x=1), ... ), ... ), ... ), ... ), ... ), ... ) The equivalent can be achieved much more simply using the Cursor API:: >>> a3 = cursor(a).x.x.x.x.x.x.x.set(1) >>> assert a2 == a3 The Cursor object keeps tracks of changes made to it and when ``.build`` is called, generates a new object with the accumulated changes. Basic usage involves wrapping the object in a Cursor, making changes to the Cursor object and generating a new copy of the original object with the accumulated changes. .. currentmodule:: flax.cursor .. autofunction:: cursor .. autoclass:: Cursor :members: apply_update, build, find, find_all, set ================================================ FILE: docs/api_reference/flax.errors.rst ================================================ flax.errors package =================== Flax has the following classes of errors. .. automodule:: flax.errors :members: :exclude-members: FlaxError ================================================ FILE: docs/api_reference/flax.jax_utils.rst ================================================ flax.jax_utils package ======================== .. currentmodule:: flax.jax_utils .. automodule:: flax.jax_utils .. autofunction:: partial_eval_by_shape Multi device utilities ------------------------ .. autofunction:: replicate .. autofunction:: unreplicate .. autofunction:: prefetch_to_device .. autofunction:: pmean .. autofunction:: pad_shard_unpad ================================================ FILE: docs/api_reference/flax.linen/activation_functions.rst ================================================ Activation functions ------------------------ .. automodule:: flax.linen.activation .. currentmodule:: flax.linen.activation .. autoclass:: PReLU :members: :special-members: __call__ .. autofunction:: celu .. autofunction:: elu .. autofunction:: gelu .. autofunction:: glu .. autofunction:: hard_sigmoid .. autofunction:: hard_silu .. autofunction:: hard_swish .. autofunction:: hard_tanh .. autofunction:: leaky_relu .. autofunction:: log_sigmoid .. autofunction:: log_softmax .. autofunction:: logsumexp .. autofunction:: one_hot .. autofunction:: relu .. autofunction:: relu6 as relu6, .. autofunction:: selu .. autofunction:: sigmoid .. autofunction:: silu .. autofunction:: soft_sign .. autofunction:: softmax .. autofunction:: softplus .. autofunction:: standardize .. autofunction:: swish .. autofunction:: tanh ================================================ FILE: docs/api_reference/flax.linen/decorators.rst ================================================ Decorators ---------------------- .. currentmodule:: flax.linen .. autofunction:: compact .. autofunction:: nowrap ================================================ FILE: docs/api_reference/flax.linen/index.rst ================================================ flax.linen ========== Linen is the Flax Module system. Read more about our design goals in the `Linen README `_. .. toctree:: :maxdepth: 2 module init_apply layers activation_functions initializers transformations inspection variable spmd decorators profiling ================================================ FILE: docs/api_reference/flax.linen/init_apply.rst ================================================ Init/Apply ============== .. currentmodule:: flax.linen .. autofunction:: apply .. autofunction:: init .. autofunction:: init_with_output ================================================ FILE: docs/api_reference/flax.linen/initializers.rst ================================================ Initializers ------------------------ .. automodule:: flax.linen.initializers .. currentmodule:: flax.linen.initializers .. autofunction:: constant .. autofunction:: delta_orthogonal .. autofunction:: glorot_normal .. autofunction:: glorot_uniform .. autofunction:: he_normal .. autofunction:: he_uniform .. autofunction:: kaiming_normal .. autofunction:: kaiming_uniform .. autofunction:: lecun_normal .. autofunction:: lecun_uniform .. autofunction:: normal .. autofunction:: truncated_normal .. autofunction:: ones .. autofunction:: ones_init .. autofunction:: orthogonal .. autofunction:: uniform .. autofunction:: variance_scaling .. autofunction:: xavier_normal .. autofunction:: xavier_uniform .. autofunction:: zeros .. autofunction:: zeros_init ================================================ FILE: docs/api_reference/flax.linen/inspection.rst ================================================ Inspection ---------------------- .. currentmodule:: flax.linen .. autofunction:: tabulate ================================================ FILE: docs/api_reference/flax.linen/layers.rst ================================================ Layers ====== .. currentmodule:: flax.linen Linear Modules ------------------------ .. flax_module:: :module: flax.linen :class: Dense .. flax_module:: :module: flax.linen :class: DenseGeneral .. flax_module:: :module: flax.linen :class: Conv .. flax_module:: :module: flax.linen :class: ConvTranspose .. flax_module:: :module: flax.linen :class: ConvLocal .. flax_module:: :module: flax.linen :class: Einsum .. flax_module:: :module: flax.linen :class: Embed Pooling ------------------------ .. autofunction:: max_pool .. autofunction:: avg_pool .. autofunction:: pool Normalization ------------------------ .. flax_module:: :module: flax.linen :class: BatchNorm .. flax_module:: :module: flax.linen :class: LayerNorm .. flax_module:: :module: flax.linen :class: GroupNorm .. flax_module:: :module: flax.linen :class: RMSNorm .. flax_module:: :module: flax.linen :class: InstanceNorm .. flax_module:: :module: flax.linen :class: SpectralNorm .. flax_module:: :module: flax.linen :class: WeightNorm Combinators ------------------------ .. flax_module:: :module: flax.linen :class: Sequential Stochastic ------------------------ .. flax_module:: :module: flax.linen :class: Dropout Attention ------------------------ .. flax_module:: :module: flax.linen :class: MultiHeadDotProductAttention .. flax_module:: :module: flax.linen :class: MultiHeadAttention .. flax_module:: :module: flax.linen :class: SelfAttention .. autofunction:: dot_product_attention_weights .. autofunction:: dot_product_attention .. autofunction:: make_attention_mask .. autofunction:: make_causal_mask Recurrent ------------------------ .. flax_module:: :module: flax.linen :class: RNNCellBase .. flax_module:: :module: flax.linen :class: LSTMCell .. flax_module:: :module: flax.linen :class: OptimizedLSTMCell .. flax_module:: :module: flax.linen :class: ConvLSTMCell .. flax_module:: :module: flax.linen :class: SimpleCell .. flax_module:: :module: flax.linen :class: GRUCell .. flax_module:: :module: flax.linen :class: MGUCell .. flax_module:: :module: flax.linen :class: RNN .. flax_module:: :module: flax.linen :class: Bidirectional BatchApply ------------------------ .. flax_module:: :module: flax.linen :class: BatchApply ================================================ FILE: docs/api_reference/flax.linen/module.rst ================================================ Module ------------------------ .. automodule:: flax.linen .. currentmodule:: flax.linen .. autoclass:: Module :members: setup, variable, param, bind, unbind, apply, init, init_with_output, copy, make_rng, sow, variables, Variable, __setattr__, tabulate, module_paths, is_initializing, perturb, put_variable, has_variable, has_rng, lazy_init, get_variable, path, is_mutable_collection .. autofunction:: apply .. autofunction:: init .. autofunction:: init_with_output .. autofunction:: intercept_methods .. autofunction:: share_scope ================================================ FILE: docs/api_reference/flax.linen/profiling.rst ================================================ Profiling ---------------------- .. currentmodule:: flax.linen .. autofunction:: enable_named_call .. autofunction:: disable_named_call .. autofunction:: override_named_call ================================================ FILE: docs/api_reference/flax.linen/spmd.rst ================================================ SPMD ---------------------- .. automodule:: flax.linen.spmd .. currentmodule:: flax.linen .. autofunction:: Partitioned .. autofunction:: with_partitioning .. autofunction:: get_partition_spec .. autofunction:: get_sharding .. autofunction:: LogicallyPartitioned .. autofunction:: logical_axis_rules .. autofunction:: set_logical_axis_rules .. autofunction:: get_logical_axis_rules .. autofunction:: logical_to_mesh_axes .. autofunction:: logical_to_mesh .. autofunction:: logical_to_mesh_sharding .. autofunction:: with_logical_constraint .. autofunction:: with_logical_partitioning ================================================ FILE: docs/api_reference/flax.linen/transformations.rst ================================================ Transformations ---------------------- .. automodule:: flax.linen.transforms .. currentmodule:: flax.linen .. autofunction:: vmap .. autofunction:: scan .. autofunction:: jit .. autofunction:: remat .. autofunction:: remat_scan .. autofunction:: map_variables .. autofunction:: jvp .. autofunction:: vjp .. autofunction:: custom_vjp .. autofunction:: while_loop .. autofunction:: cond .. autofunction:: switch ================================================ FILE: docs/api_reference/flax.linen/variable.rst ================================================ Variable dictionary ---------------------- .. automodule:: flax.core.variables .. autoclass:: flax.linen.Variable ================================================ FILE: docs/api_reference/flax.serialization.rst ================================================ flax.serialization package ============================ .. currentmodule:: flax.serialization .. automodule:: flax.serialization State dicts ------------------------ .. autofunction:: from_state_dict .. autofunction:: to_state_dict .. autofunction:: register_serialization_state Serialization with MessagePack -------------------------------- .. autofunction:: msgpack_serialize .. autofunction:: msgpack_restore .. autofunction:: to_bytes .. autofunction:: from_bytes ================================================ FILE: docs/api_reference/flax.struct.rst ================================================ flax.struct package ===================== .. currentmodule:: flax.struct .. automodule:: flax.struct .. autofunction:: dataclass .. autoclass:: PyTreeNode ================================================ FILE: docs/api_reference/flax.traceback_util.rst ================================================ flax.traceback_util package ============================ .. currentmodule:: flax.traceback_util .. automodule:: flax.traceback_util Traceback filtering utils -------------------------- .. autofunction:: hide_flax_in_tracebacks .. autofunction:: show_flax_in_tracebacks ================================================ FILE: docs/api_reference/flax.training.rst ================================================ flax.training package ===================== Checkpoints ------------------------ .. currentmodule:: flax.training.checkpoints .. automodule:: flax.training.checkpoints .. autofunction:: save_checkpoint .. autofunction:: save_checkpoint_multiprocess .. autofunction:: latest_checkpoint .. autofunction:: restore_checkpoint .. autofunction:: convert_pre_linen Learning rate schedules ------------------------ .. currentmodule:: flax.training.lr_schedule .. automodule:: flax.training.lr_schedule .. autofunction:: create_constant_learning_rate_schedule .. autofunction:: create_stepped_learning_rate_schedule .. autofunction:: create_cosine_learning_rate_schedule Train state ------------------------ .. currentmodule:: flax.training.train_state .. autoclass:: TrainState :members: apply_gradients, create Early Stopping ------------------------ .. currentmodule:: flax.training.early_stopping .. autoclass:: EarlyStopping :members: reset, update Common Utilities ------------------------ .. currentmodule:: flax.training.common_utils .. autofunction:: shard .. autofunction:: shard_prng_key .. autofunction:: stack_forest .. autofunction:: get_metrics .. autofunction:: onehot ================================================ FILE: docs/api_reference/index.rst ================================================ API Reference ============= .. toctree:: :maxdepth: 4 flax.config flax.core.frozen_dict flax.cursor flax.errors flax.jax_utils flax.linen/index flax.serialization flax.struct flax.traceback_util flax.training flax.traverse_util ================================================ FILE: docs/conf.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Configuration file for the Sphinx documentation builder.""" # This file only contains a selection of the most common options. For a full # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # # import os # import sys # sys.path.insert(0, os.path.abspath('.')) import os import sys import doctest sys.path.insert(0, os.path.abspath('..')) # Include local extension. sys.path.append(os.path.abspath('./_ext')) # patch sphinx # -- Project information ----------------------------------------------------- project = 'Flax' copyright = '2023, The Flax authors' # pylint: disable=redefined-builtin author = 'The Flax authors' # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ 'sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.autosectionlabel', 'sphinx.ext.doctest', 'sphinx.ext.intersphinx', 'sphinx.ext.mathjax', 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', 'myst_nb', 'codediff', 'flax_module', 'sphinx_design', ] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # The suffix(es) of source filenames. # Note: important to list ipynb before md here: we have both md and ipynb # copies of each notebook, and myst will choose which to convert based on # the order in the source_suffix list. Notebooks which are not executed have # outputs stored in ipynb but not in md, so we must convert the ipynb. source_suffix = ['.rst', '.ipynb', '.md'] autosummary_generate = True master_doc = 'index' autodoc_typehints = 'none' # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # # html_theme = 'pydata_sphinx_theme' html_theme = 'sphinx_book_theme' html_css_files = ['css/flax_theme.css'] # The name of an image file (relative to this directory) to place at the top # of the sidebar. html_logo = './flax.png' html_favicon = './flax.png' # title of the website html_title = '' # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named 'default.css' will overwrite the builtin 'default.css'. html_static_path = ['_static'] html_extra_path = ['robots.txt'] # href with no underline and white bold text color announcement = """ This site covers the old Flax Linen API. [Explore the new Flax NNX API ✨] """ html_theme_options = { 'repository_url': 'https://github.com/google/flax', 'use_repository_button': True, # add a 'link to repository' button 'use_issues_button': False, # add an 'Open an Issue' button 'path_to_docs': ( 'docs' ), # used to compute the path to launch notebooks in colab 'launch_buttons': { 'colab_url': 'https://colab.research.google.com/', }, 'prev_next_buttons_location': None, 'show_navbar_depth': 1, 'announcement': announcement, } # -- Options for myst ---------------------------------------------- # uncomment line below to avoid running notebooks during development # nb_execution_mode = 'off' # Notebook cell execution timeout; defaults to 30. nb_execution_timeout = 100 # List of patterns, relative to source directory, that match notebook # files that will not be executed. myst_enable_extensions = ['dollarmath'] nb_execution_excludepatterns = [ 'quick_start.ipynb', # <-- times out 'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0 'flax/nnx', # exclude nnx 'guides/quantization/fp8_basics.ipynb', 'guides/training_techniques/use_checkpointing.ipynb', # TODO(IvyZX): needs to be updated ] # raise exceptions on execution so CI can catch errors nb_execution_allow_errors = False nb_execution_raise_on_error = True # -- Extension configuration ------------------------------------------------- # Tell sphinx-autodoc-typehints to generate stub parameter annotations including # types, even if the parameters aren't explicitly documented. always_document_param_types = True # -- doctest configuration ------------------------------------------------- doctest_default_flags = doctest.NORMALIZE_WHITESPACE doctest_global_setup = """ import jax import jax.numpy as jnp from flax import nnx import logging as slog from absl import logging as alog # Avoid certain absl logging messages to break doctest filtered_message = [ 'SaveArgs.aggregate is deprecated', '', ] class _CustomLogFilter(slog.Formatter): def format(self, record): message = super(_CustomLogFilter, self).format(record) for m in filtered_message: if m in message: return '' return message alog.use_absl_handler() alog.get_absl_handler().setFormatter(_CustomLogFilter()) """ ================================================ FILE: docs/conf_sphinx_patch.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Patch Sphinx to improve documentation aesthetics.""" # TODO(cgarciae): Send a PR to sphinx to upstream this fix. Issue: https://github.com/google/flax/issues/2196 # This patch is needed to make autosummary provide the "annotations" # variable so we can exclude function attributes from the methods list # in flax_module.rst. The patch as such only adds this single line: # # ns['annotations'] = list(getattr(obj, '__annotations__', {}).keys())' # # We should consider sending a PR to sphinx so we can get rid of this. # Original source: https://github.com/sphinx-doc/sphinx/blob/0aedcc9a916daa92d477226da67d33ce1831822e/sphinx/ext/autosummary/generate.py#L211-L351 from typing import Any import sphinx.ext.autodoc import sphinx.ext.autosummary.generate as ag def generate_autosummary_content( name: str, obj: Any, parent: Any, template: ag.AutosummaryRenderer, template_name: str, imported_members: bool, app: Any, recursive: bool, context: dict, modname: str = None, qualname: str = None, ) -> str: doc = ag.get_documenter(app, obj, parent) def skip_member(obj: Any, name: str, objtype: str) -> bool: try: return app.emit_firstresult( 'autodoc-skip-member', objtype, name, obj, False, {} ) except Exception as exc: ag.logger.warning( __( 'autosummary: failed to determine %r to be documented, ' 'the following exception was raised:\n%s' ), name, exc, type='autosummary', ) return False def get_class_members(obj: Any) -> dict[str, Any]: members = sphinx.ext.autodoc.get_class_members( obj, [qualname], ag.safe_getattr ) return {name: member.object for name, member in members.items()} def get_module_members(obj: Any) -> dict[str, Any]: members = {} for name in ag.members_of(obj, app.config): try: members[name] = ag.safe_getattr(obj, name) except AttributeError: continue return members def get_all_members(obj: Any) -> dict[str, Any]: if doc.objtype == 'module': return get_module_members(obj) elif doc.objtype == 'class': return get_class_members(obj) return {} def get_members( obj: Any, types: set[str], include_public: list[str] = [], imported: bool = True, ) -> tuple[list[str], list[str]]: items: list[str] = [] public: list[str] = [] all_members = get_all_members(obj) for name, value in all_members.items(): documenter = ag.get_documenter(app, value, obj) if documenter.objtype in types: # skip imported members if expected if imported or getattr(value, '__module__', None) == obj.__name__: skipped = skip_member(value, name, documenter.objtype) if skipped is True: pass elif skipped is False: # show the member forcedly items.append(name) public.append(name) else: items.append(name) if name in include_public or not name.startswith('_'): # considers member as public public.append(name) return public, items def get_module_attrs(members: Any) -> tuple[list[str], list[str]]: """Find module attributes with docstrings.""" attrs, public = [], [] try: analyzer = ag.ModuleAnalyzer.for_module(name) attr_docs = analyzer.find_attr_docs() for namespace, attr_name in attr_docs: if namespace == '' and attr_name in members: attrs.append(attr_name) if not attr_name.startswith('_'): public.append(attr_name) except ag.PycodeError: pass # give up if ModuleAnalyzer fails to parse code return public, attrs def get_modules(obj: Any) -> tuple[list[str], list[str]]: items: list[str] = [] for _, modname, _ispkg in ag.pkgutil.iter_modules(obj.__path__): fullname = name + '.' + modname try: module = ag.import_module(fullname) if module and hasattr(module, '__sphinx_mock__'): continue except ImportError: pass items.append(fullname) public = [x for x in items if not x.split('.')[-1].startswith('_')] return public, items ns: dict[str, Any] = {} ns.update(context) if doc.objtype == 'module': scanner = ag.ModuleScanner(app, obj) ns['members'] = scanner.scan(imported_members) ns['functions'], ns['all_functions'] = get_members( obj, {'function'}, imported=imported_members ) ns['classes'], ns['all_classes'] = get_members( obj, {'class'}, imported=imported_members ) ns['exceptions'], ns['all_exceptions'] = get_members( obj, {'exception'}, imported=imported_members ) ns['attributes'], ns['all_attributes'] = get_module_attrs(ns['members']) ispackage = hasattr(obj, '__path__') if ispackage and recursive: ns['modules'], ns['all_modules'] = get_modules(obj) elif doc.objtype == 'class': ns['members'] = dir(obj) ns['inherited_members'] = set(dir(obj)) - set(obj.__dict__.keys()) ns['methods'], ns['all_methods'] = get_members( obj, {'method'}, ['__init__'] ) ns['attributes'], ns['all_attributes'] = get_members( obj, {'attribute', 'property'} ) ns['annotations'] = list(getattr(obj, '__annotations__', {}).keys()) if modname is None or qualname is None: modname, qualname = ag.split_full_qualified_name(name) if doc.objtype in ('method', 'attribute', 'property'): ns['class'] = qualname.rsplit('.', 1)[0] if doc.objtype in ('class',): shortname = qualname else: shortname = qualname.rsplit('.', 1)[-1] ns['fullname'] = name ns['module'] = modname ns['objname'] = qualname ns['name'] = shortname ns['objtype'] = doc.objtype ns['underline'] = len(name) * '=' if template_name: return template.render(template_name, ns) else: return template.render(doc.objtype, ns) ag.generate_autosummary_content = generate_autosummary_content ================================================ FILE: docs/developer_notes/index.rst ================================================ Developer notes =============== .. toctree:: :maxdepth: 1 module_lifecycle lift FLIPs ================================================ FILE: docs/developer_notes/lift.md ================================================ # Lifted transformations ⚠️ Advanced topic ⚠️ This design note explains the underlying implementation of `flax.linen.transform`, which enables JAX transformations inside Flax `Module`s. ## Introduction JAX uses a functional API meaning that it only guarantees correct behavior when using functions without side effects ([JAX docs](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html#differences-from-numpy)). Typically, these side effects are the result of mutating an object that lives outside the function. The functional paradigm has some advantages like the ability to explicitly reason about state and stochasticity. The function output only changes when an input argument changes. Therefore, a function is guaranteed to behave deterministically. But pure functions offer another big advantage to JAX: specifically, they enable functional transformations. For example `jax.vmap(f)` will vectorize a function `f`. Because `f` cannot have side effects the vectorized/parallel version of `f` is well-defined. To see why we need this restriction, consider what happens if `f` would increment a counter or draw a random number. Would `f` draw the same or a different random number for each item in the vector? Would each item in the batch have its own counter or is the counter shared among the items? And in what order is the counter incremented if `f` is computed in parallel? The answer to all these questions is "it depends". The behavior is ambiguous and the functional constraint elegantly avoids this problem. Flax introduces a safe way to have limited randomness and stateful variables in a JAX-compatible form. The reason why the state in Flax is not problematic is because it is local: inside a Flax `Module` there are variables and PRNG sequences, but on the outside there are only JAX Arrays and PRNG keys. For most use cases, Flax is used to define models in a stateful way. Because a `Module` behaves like a pure function externally, we can fully utilize JAX with all of its transformations. There are, however, cases when we want to have the best of both worlds by using transformations and `Module` together. This design note explains how we extend JAX's functional transformation to work on `Module`s that have internal state and randomness. ## Functionalization Before we jump into the details let's consider a simple example where we would like to use `vmap` inside a `Module`. First, we define a simple MLP without any transformations: ```python import jax from jax import random, numpy as jnp from flax import linen as nn class MLP(nn.Module): @nn.compact def __call__(self, xs): h = nn.Dense(4, name='hidden')(xs) h = nn.relu(h) return nn.Dense(1, name='out')(h) ``` Now what if we want to have separate MLP parameters for each item in `xs`? If this were "vanilla JAX" we could imagine writing something like `jax.vmap(apply_mlp)(mlp_params, xs)`. But doing something like this in Linen will actually fail: ```python class NaiveVmapMLP(nn.Module): @nn.compact def __call__(self, xs): mlp = MLP() return jax.vmap(lambda mlp, x: mlp(x))(mlp, xs) # fails ``` JAX will raise an error when `vmap` is used on `mlp` because it's not a JAX array or a simple container of arrays. We can not really blame JAX for refusing to perform this under-specified job. After all, it's not even clear what should happen here. The parameters inside the MLP are not even initialized yet and we will need a separate PRNG key for each group of parameters. `jax.vmap` can only broadcast or map over an axis but it cannot automatically split an PRNG key. Therefore, we have to call `jax.random.split` manually. We can fix this problem by first turning `MLP` into a pure init and apply function. Afterwards, we use the `param` method to store the parameters: ```python class ManualVmapMLP(nn.Module): @nn.compact def __call__(self, xs): mlp = MLP(parent=None) init_fn = lambda rng, xs: jax.vmap(mlp.init, in_axes=0)(random.split(rng, xs.shape[0]), xs)['params'] apply_fn = jax.vmap(mlp.apply, in_axes=0) mlp_params = self.param('mlp', init_fn, xs) return apply_fn({'params': mlp_params}, xs) xs = jnp.ones((3, 4)) variables = ManualVmapMLP().init(random.key(0), xs) print(jax.tree_util.tree_map(jnp.shape, variables['params'])) """==> { mlp: { hidden: { bias: (3, 4), kernel: (3, 4, 4), }, out: { bias: (3, 1), kernel: (3, 4, 1), }, }, } """ ``` Here, `MLP(parent=None)` creates a detached instance of `MLP`. This avoids reserving a name for the submodule inside the current module. Although not strictly necessary, this also ensures we cannot accidentally use the MLP instance in a stateful way and we are forced to use it through either `.init` or `.apply`. This example is still relatively concise but it already takes a few extra "bookkeeping" statements to make it work. However, this implementation has a number of limitations: 1. During initialization, we call the submodule twice through `init_fn` and `apply_fn`. If the submodule used the same trick to do functional transformation we will end up executing a lot of code as the number of module calls grows like 2^d where d is the number of nested function transformations. 2. The implementation assumes the submodule only requires the parameter RNG sequence. 3. The implementation assumes we only create variables in the "params" collection during `init`. However, it does not support other variable collections and creating/updating variables in `apply`. Point 3 in particular makes manual functionalization cumbersome. Feel free to try and extend the above example with a `nn.BatchNorm` layer in the `MLP` module. This will require dealing with some additional complexity like storing the updated batch stats and making sure the batch stats are not mutable inside `vmap` when it should be immutable (e.g.: eval mode). We call the process of transforming a stateful Module into a pure function "functionalization". By temporarily turning a stateful `Module` into a function we make it compatible with JAX's functional transformations. ## Lifting Flax provides an alternative for manual functionalization which we call lifted transformation. Lifted transformations are defined in `flax.core.lift`. All the lifted JAX transformations are defined with a single generic lifting API called `pack`. A number of decisions had to be made in order to define `pack`. The implementation of `pack` controls how variables and rngs are lifted and how fine-grained the user control is. It must also decide whether lifting decisions are made at variable or transformation definition. ### Lifting granularity With the Linen API, users can define arbitrary variable collections and PRNG sequences. Each variable in a collection is lifted in the same way. Collections are typically given a semantically meaningful name like "params" or "batch_stats" rather than a general purpose name like "state". Because collections carry semantic meaning we can decide at the transformation level how each collection should be lifted. For example, we want to share all parameter variables when we add a batch dimension to a model. At the same time we can write generic code that uses transformations without knowing exactly what kind of variables the submodules will create. Collections thus strike a balance between fine-grained control and generality. We also avoid brittle string matching code that loops over all variables and tries to split up collections in an ad-hoc way based on naming conventions like: target all variables with the name prefix "kernel". If more fine-grained control is necessary a user can simply split up a set of variables over multiple collections that should be handled differently. ### Transformation vs variable control Lifting behavior could be defined either at the transformation level or during variable definition. We use transformation level definitions of lifting behavior. The reason for this choice is that there are many different transformations with various behaviors. For example: `vmap` has broadcasted and vectorized arguments, while `scan` has scan, carry, and broadcast arguments. A variable would have to define its behavior for all these transformations otherwise a `Module` would not be compatible with these transformations. Alternatively, we would have to make default decisions for how transformations are handled. However, this could lead to silent bugs because the behavior might not actually be valid given the users intent. The lift package also provides a general purpose `transform`, which allows an arbitrary function to transform a variable collection. For example, this can be used to tie the weights in a tied auto-encoder by transposing the weights. It is unclear whether a similar general purpose transform could be defined if lifting decisions were made at variable definition. ### Linen The lifting module does not know about the Linen `Module` API. Instead it operates directly on instances of `flax.core.Scope`. A `Scope` instance contains the variables and PRNG sequences of a `Module`. Each `Module` instance has a `Scope` instance in the `.scope` field if it has a parent or it was created using `init` or `apply`. Typically, the top-level `Module` instance — on which you call `init` or `apply` — is the only `Module` instance that does not have a `Scope` bound to it. When a `Module` is transformed, we use the `flax.core.lift` APIs to lift the scope and use `Module.clone()` to create a new `Module` instance with the lifted scope bound to it. `flax.linen.transforms` exposes wrappers for the transformations in `flax.core.lift`. The core lifting APIs operate on functions while the Linen wrappers can transform either a `Module` class or a `Module` method. Thus, lifting is implemented independently from the Linen API. This separation of concern simplifies the implementation, while potentially allowing alternative `Module` abstractions to build upon a common core for lifting and state management. ### Implementation The `pack(fn, in_vars, out_vars, rngs)` API goes through the following stages: 1. *Scope de-duplication* This stage is only relevant if multiple Scopes are lifted together. In this case we must first find the set of root scopes. A scope is a root if none of its ancestors are in the set of scopes that need to be lifted. By only lifting roots we avoid lifting the same variables twice. For non-root scopes we store a reference to its ancestor scope and a path such that we can later reconstruct it (stage 4). 2. *Filter stage* Variables and PRNG sequences are split up into groups. This way `fn` can lift each group into the transformation separately. A group is defined by a filter specified as: - a list of collections/prng names - `True` (match everything) - `False` (match nothing) - `DenyList(filter)` (match everything but the specified collections (e.g.: `DenyList(['params'])` matches everything except the 'params' collection.)). A collection or PRNG sequence can only be put into a single group. If a collection matches multiple filters, it will be put into the first group with a matching filter. If a collection or PRNG sequence does not match any filter it will not be lifted. This means that it cannot be used inside the transformation and attempting to do this will cause an error to be raised. For example, `in_vars = (["params"], True)` will cause the "params" collection to be put in the first group and all other collection to be put in the second group. For each PRNG sequence that is matched we seed a new PRNG sequence by calling `make_rng`. This avoids the need to update the PRNG state after the lifted transformation is complete. 3. *Transform-specific lifting* `fn` is called with the variable and PRNG groups. JAX transforms have varying signatures and lifting options. Arguably the cleanest example is `vmap`. In the case of vmap the function arguments, PRNGs and variable collections are passed into a `jax.vmap` wrapped function. 4. *Scope reconstruction* Now that the variables and PRNGs are lifted inside the transformation, we want to recreate the lifted scopes. Pack calls `fn` with a `scope_fn` that takes the lifted variables and PRNGs and returns the reconstructed scopes with the lifted variables and rng sequences. 5. *Repack stage* After we have used the lifted scopes we have to retrieve the updated variables (PRNG sequences can simply be discarded). pack passes the `repack_fn` to support this. This stage is similar to stage 2 except that we only lift variables and immutable variables are ignored. Immutable variables cannot be updated. Therefore, they should not be returned from the transformed function. 6. *Commit stage* `pack` expects `fn` to return a pair where the first item will simply be returned from pack and the second item should be the repacked variables. The updated variables are stored in the original/un-lifted scopes such that the mutations that happen inside the transformation survive after the transformation completes. ### Using pack example A minimal example of using `pack` to transpose each matrix in a variable collection: ```python from flax.core import lift from flax.core import Scope, init, apply, nn as core_nn def lift_transpose(fn, target='params', variables=True, rngs=True): # by default we transpose 'params' and simply pass through all other variables. def wrapper(scope_fn, repack_fn, variable_groups, rng_groups, *args): # normally we would first call into a JAX transformed function here... target, rest = variable_groups def trans(x): if x.ndim == 2: return x.T return x target = jax.tree_util.tree_map(trans, target) variable_groups = (target, rest) scope = scope_fn(variable_groups, rng_groups) y = fn(scope, *args) out_variables = repack_fn(scope) return y, out_variables return lift.pack( wrapper, in_variable_filters=(target, variables), out_variable_filters=(variables,), rng_filters=(rngs,)) x = jnp.ones((3, 2)) y, params = init(lift_transpose(core_nn.dense))(random.key(0), x, 4) ``` NOTE that most users should not need to interact with `pack` directly. Please open a GitHub issue when you find a use case that is not supported yet by the existing lifted transformations. ### Supported transformations | Jax Transform | Supported in Linen? | Comments | |-|-|-| | vmap | ✅ | | | scan | ✅ | Carry variables cannot be initialized inside the scan body. | | remat | ✅ | | | jit | ✅ | Current implementation might cause unnecessary recompilation. | | jvp | ✅ | | | vjp | ✅ | | | custom_vjp | ✅ | | | custom_jvp | ❌ | | | while_loop | ✅ | Carry variables cannot be initialized inside the while_loop body. | | cond | ✅ | Variable initialization / mutation must structurally match across branches. | | switch | ✅ | Variable initialization / mutation must structurally match across branches. | | pmap | ❌ | | | xmap | ❌ | | References: - [Linen transforms documentation](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html) - [Linen transforms source code](https://github.com/google/flax/blob/main/flax/linen/transforms.py) - [Core lifting source code](https://github.com/google/flax/blob/main/flax/core/lift.py) ### Linen examples Going back to our original example, we can now use `nn.vmap` to simplify our implementation: ```python class LinenVmapMLP(nn.Module): @nn.compact def __call__(self, xs): VmapMLP = nn.vmap(MLP, variable_axes={'params': 0}, split_rngs={'params': True}, in_axes=0) return VmapMLP(name='mlp')(xs) variables = LinenVmapMLP().init(random.key(0), xs) print(jax.tree_util.tree_map(jnp.shape, variables['params'])) """==> { mlp: { Dense_0: { bias: (3, 4), kernel: (3, 2, 4), }, Dense_1: { bias: (3, 1), kernel: (3, 4, 1), }, }, } """ ``` Here we use `variable_axes={'params': 0}` to indicate that parameters are vectorized rather than shared and `split_rngs={'params': True}` means each set of parameters is initialized independently. We can also extend the example with some inner state by adding a `BatchNorm` layer: ```python class StatefulMLP(nn.Module): @nn.compact def __call__(self, x, *, train): h = nn.Dense(4, name='hidden')(x) h = nn.BatchNorm(axis_name='batch')(h, use_running_average=not train) h = nn.relu(h) return nn.Dense(1, name='out')(h) class LinenStatefulVmapMLP(nn.Module): @nn.compact def __call__(self, xs, *, train): VmapMLP = nn.vmap(StatefulMLP, variable_axes={'params': 0, 'batch_stats': 0}, split_rngs={'params': True}, in_axes=0) return VmapMLP(name='mlp')(xs, train=train) variables = LinenStatefulVmapMLP().init(random.key(0), xs) ``` All we had to add to `nn.vmap` is `'batch_stats': 0`, indicating that the batch stats are vectorized rather than shared along the first axis. ## Alternatives Other numerical computation frameworks consider variables a first-class citizen. An alternative to functionalization would be to use a variable system either integrated or on top of JAX. An advantage of this is that per-variable lifting becomes easier. If variables are part of the JAX IR (JAXPR), we could inspect which variables have to be lifted in a certain computation. Optionally, they could be annotated with a collection tag to decide on various lifting options. The downside of this approach is that a variable system is more complicated. Variables are related references and break a core assumption of Functional Programming (see [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency)) Other APIs that currently have a functional interface would probably require integration as well (e.g.: checkpointing and optimization APIs). ================================================ FILE: docs/developer_notes/module_lifecycle.rst ================================================ The Flax Module lifecycle ######################### .. testsetup:: from typing import Any, Callable, Iterable import flax from flax import linen as nn from jax import random import jax This design note is intended for users who are already familiar with Flax Linen Modules but want to understand more about the design principles behind the abstraction. This note should give you a good understanding of the assumptions and guarantees the Module API is built upon. If you have no practical experience with Modules yet, check out the `Quickstart guide `_. Flax Linen Modules offer a Pythonic abstraction on top of Flax core. The `Module `_ abstraction allows you to create classes that have state, parameters and randomness on top of JAX. This is a practical guide to the design and behavior of the ``Module`` class. By the end, you should feel comfortable to go off the beaten track and use Modules in new ways. Overview *********** Definition ============= Let's start with a high-level overview of the Module lifecycle. First, define a simple Module: .. testcode:: class MLP(nn.Module): # 1. Attribute annotations hidden_size: int out_size: int # 2. The ``setup`` method def setup(self): self.hidden = nn.Dense(self.hidden_size) self.out = nn.Dense(self.out_size) # 3. User methods def __call__(self, x): a = self.hidden(x) h = nn.relu(a) return self.out(h) This Module consists of: #. **Attribute annotations**, defined as `dataclass `_ fields. These annotations automatically define a constructor. #. **The ``setup`` method**, which creates submodules and assigns them to attributes. #. **User methods**. By convention, most Modules have just one ``__call__`` method, but you can define multiple methods or use different method names. Construction/initialization ============================= Now we want to construct and use the ``MLP`` Module: .. testcode:: mlp = MLP(hidden_size=5, out_size=3) x = jax.numpy.ones((1, 2)) variables = mlp.init(random.key(0), x) y = mlp.apply(variables, x) First, we construct an instance of ``MLP`` and pass the construction attributes. Note that construction here is different from what you might expect if you are not used to Functional Programming patterns. The ``MLP`` constructor does not actually create variables or any internal state whatsoever. It's best to think of it as a specification or template of the Module that contains functionality but no data. Let's take a closer look at initialization. Surprisingly, there actually is no separate initialization path in Flax. Calling ``init`` is just a special case of ``apply``, which you can also write as: .. testcode:: # equivalent to: variables = mlp.init(random.key(0), x) _, variables = mlp.apply({}, x, rngs={"params": random.key(0)}, mutable=True) Thus, ``init`` is nothing more than a wrapper around ``apply`` where: #. We call a Module without any initial variables (an empty dict). #. A PRNG generator named ``"params"`` is always passed for randomly initializing parameters (using the parameter initialization function). #. All variable collections are set to mutable (``mutable=True``). When a collection is mutable, existing variables can be updated and new variables can be created. Thus, inside ``init`` variables can be initialized in any variable collection and they are all added to the returned variable dictionary. Lifecycle ============= Now that you have learned about ``init`` being a special case of ``apply``, let's look at ``.apply(...)`` in more detail. In fact, most of the complexity of Modules resides in the ``apply`` method. The "Module lifecycle" consists of constructing and ``apply``-ing a Module. We can summarize the Module lifecycle as follows: #. We construct ``mlp = MLP(hidden_size=5, out_size=3)``, such that ``mlp.hidden_size=5`` and ``mlp.out_size=3``. #. Then, call ``mlp.apply``, which: #. Makes a clone of ``mlp``, let's call it ``mlp_copy``. #. Calls ``mlp_copy.setup()``. #. Returns the output of ``mlp_copy.__call__()`` and optionally the variable collections that were specified as mutable using the keyword argument ``mutable=``. Notice that the lifecycle includes cloning the Module instance. This is done to ensure that ``apply`` can be treated as a pure function (i.e., if you pass the same arguments in, it will return the same outputs). You will learn about this in more detail later in the :ref:`Top-level Modules` section. Variables ========== The word “variable” is ubiquitous in programming and math. However, it's important to have a good understanding of what variables are in the context of JAX and Flax. Inside Flax Modules, `variables `_ act like you expect from Python. They are initialized once, read, and perhaps even updated every so often. However, JAX has no concept of variables. Instead, values are stored in arrays similar to NumPy arrays - with one important difference: they are immutable. The ``init`` and ``apply`` methods return the variables as a nested dictionary with string keys and JAX arrays at the leaves. At the top level each key corresponds to a variable collection. Inside each collection the nested dict structure corresponds with the ``Module`` hierarchy. The variable dict is immutable and therefore really just a snapshot of state the variables are in. When ``apply`` is called again, the variable dict is passed as an argument. Such that the variables are in the same state as when the previous ``init`` / ``apply`` call finished. .. note:: Module fields are declared using the `field_name: TypeHint` syntax (same as dataclasses). Without a type hint, an attribute is considered a static property of the class. In case you cannot specify the type you can use ``typing.Any`` as a wildcard type. Compact Modules ****************** Linen provides an alternative API for defining modules more compactly. This is especially useful for the common case where the Module consists of only one method that uses parameters and/or sub-modules. Using the compact API the MLP can be rewritten as follows: .. testcode:: class CompactMLP(nn.Module): hidden_size: int out_size: int @nn.compact def __call__(self, x): a = nn.Dense(self.hidden_size)(x) h = nn.relu(a) return nn.Dense(self.out_size)(h) A compact ``Module`` is similar in spirit to a function. It offers a concise notation and restricts external interaction to the inputs and return values of the function. In this case the concise notation might make it easier for others to understand what the Module does. There is no need to jump back and forth between the ``setup`` and ``__call__`` method to understand what the submodules are doing. Instead, simply reading the ``__call__`` method from top to bottom once should provide a concise overview. This can make a significant difference if you are implementing complex Modules with many hyperparameters. See `setup or compact `_ for a practical guide on deciding between setup and compact. Another benefit of defining submodules and/or variables inline is that you can add arguments to your method when constructing variables. The most common example of this is using shape information to determine the shape of a parameter like this: .. testcode:: class CompactScaledMLP(nn.Module): hidden_size: int out_size: int @nn.compact def __call__(self, x): scale = self.param("scale", nn.initializers.ones_init(), x.shape[-1:]) x *= scale[None] a = nn.Dense(self.hidden_size)(x) h = nn.relu(a) return nn.Dense(self.out_size)(h) .. testcode:: :hide: mdl = CompactScaledMLP(hidden_size=4, out_size=5) x = jax.numpy.ones((3, 2)) vars = mdl.init(random.key(0), x) assert vars["params"]["scale"].shape == (2,) Many of the standard Linen Modules like ``nn.Dense`` use shape inference already to avoid the need to specify input shapes (like the number of input features to a Dense layer). Compact control flow ===================== The order in which you define submodules determines the name of a submodule if none is provided explicitly (using the ``name=`` keyword argument passed to the Module's constructor). Because the ``name`` determines how parameters are mapped to submodules, you must be careful about mixing control flow with auto-generated names. Using control flow can change the order or remove certain submodules altogether. This is useful in case a submodule should only exist depending on some construction argument. However, when control flow depends on the input arguments to the Module, you should be careful. For example, the following Module will break: .. testcode:: class WrongModule(nn.Module): @nn.compact def __call__(self, x, mode): if mode == "encode": return nn.Dense(features=8)(x) elif mode == "decode": return nn.Dense(features=4)(x) The above Module will break because either the encoder or decoder path will construct a Module named "Dense_0". This means the two Modules will share parameters which is not intended here. Actually, the two Modules cannot share parameters because they each have a different number of features. This problem can be solved in various ways: - Provide explicit names - create the modules in ``setup`` - or move the constructor out of the control flow. The latter is done as follows: .. testcode:: class CorrectModule(nn.Module): @nn.compact def __call__(self, x, mode): encoder = nn.Dense(8) decoder = nn.Dense(4) if mode == "encode": return encoder(x) elif mode == "decode": return decoder(x) .. testcode:: :hide: def init_fn(mdl): x = jax.numpy.ones((3, 2)) z = mdl(x, "encode") return mdl(z, "decode") mdl = CorrectModule() vars = nn.init(init_fn, mdl)(random.key(0)) assert vars["params"]["Dense_0"]["kernel"].shape == (2, 8) assert vars["params"]["Dense_1"]["kernel"].shape == (8, 4) In the above example the construction order is fixed. After construction the submodules can be used in an arbitrary order. .. note:: compact modules show a strong resemblance to `React hooks `_. Top-level Modules ***************** When a Module instance is created at the "top-level", it will be in an "unbound" state - that is, it has no variables attached. "Top-level" means it is not constructed as a sub-Module inside another Module class. Apart from calling ``init`` and ``apply``, there is not much you can do with an unbound Module. Note also that ``setup`` is not called on unbound Modules, so you can only access the construction arguments. Refer to the :ref:`Future work` section to learn how this might change in the future. Why are top-level Modules always unbound? =============================================== When we call ``apply``, a copy of the top-level Module is created which will actually hold the variables and PRNG sequences. This stateful, "bound", clone only exists while we are executing the apply method. The reason for this is that if you create a stateful object and destroy it before the apply function returns, the ``apply`` function itself behaves like a pure function. A pure function has two constraints: #. If you put the same arguments in, it will return the same outputs #. It does not change anything outside the function. This means you cannot manipulate stateful objects that are accessible outside the pure function. Pure functions have many advantages but when using JAX they are often essential. For example, most code requires compilation using ``jax.jit`` to be fast and once you created a Module you probably want to optimize its parameters using ``jax.grad``. However, these APIs expect a pure function and don't work on stateful bound ``Module`` instances directly. Moreover, pure functions allow for flexible interoperability with other libraries. For example, We recommend `Optax `_ for optimizing parameters. The optimizers in Optax expect and return a PyTree of JAX arrays to optimize, just like the ``apply`` function of a Linen Module. Cloning =============================================== To make this approach work reliably we need well-defined cloning behavior. Rather than relying on a complex nested cloning procedure like Python's ``deepcopy``, Flax enforces that a ``Module`` is exactly defined by its construction arguments. Therefore cloning a Module reduces to calling the constructor with its original construction arguments. Because ``Module`` acts as an immutable dataclass, the construction arguments are mapped directly to instance attributes. Non-construction attributes that are computed in ``setup`` or ``__post_init__`` should also depend only on the construciton arguments to ensure a well-defined clone. Bind =============================================== Sometimes it's useful to have a bound, top-level Module without having to wrap the code in a function. For example: to interact with a Module inside a Jupyter notebook. The `bind `_ method returns a bound clone with an unlimited lifetime. The downside of this is that you cannot combine it with JAX transformations or integrate it into a vanilla JAX codebase that expects stateless code. For example, `Optax `_ can optimize a Pytree of parameters but it cannot directly optimize a bound ``Module`` instance created with ``.bind`` (because that's not a Pytree). Thus, you cannot combine the ``bind`` API with a functional optimizer API like Optax. Setup ********** The ``setup`` method is often used like the constructor hook (``__init__``) in normal Python classes. However, for more advanced use cases it's good to realize that it is not quite the same as a constructor. ``setup`` is only called after a Module becomes bound. Normally, this is not an issue because most Modules are bound (almost) immediately (as part of ``init`` and ``apply``). Inside ``setup``, sub-modules become bound when they are assigned to an attribute. Inside an ``nn.compact`` decorated method, sub-modules are bound immediately when constructed. As explained in the previous section, top-level Modules are never bound and thus setup is not called when they are constructed. This means you cannot access attributes assigned in setup from an unbound, top-level module. .. testcode:: class TopLevelAccess(nn.Module): def setup(self): self.foo = nn.Dense(2) mdl = TopLevelAccess() assert not hasattr(mdl, "foo") # foo is not defined because setup is not called The ``setup`` method is not called immediately after the ``Module`` becomes bound but only when you interact with the ``Module`` instance (e.g.: call a method or access an attribute). This should not impact the behavior of a ``Module`` but the lazy execution does sometimes affect log statements and stack traces during debugging. The section on :ref:`Functionalization` will explain why we need ``setup`` to be lazy in the first place. Functionalization ****************** So far we had a pure ``apply`` function that is typically transformed with some JAX transformations and inside ``apply`` we have a stateful Module instance to work with. In other words: Outside of a Module we are in a functional world where we have the power of JAX's functional transformations and inside the Module we get the power of Flax's stateful variables and PRNG sequence, and the ``apply`` method is our bridge between these two worlds. But what if we want to use JAX transformations **inside** Modules? The answer to this is functionalization. This procedure itself is tedious and error-prone but handled internally by Flax. At a high-level we can summarize it as follows. For a method ``fn`` defined within a Module: #. Collect the state (variables & PRNG sequences) of the Module(s) that should be available inside the JAX transformation and take a snapshot of it. #. Call the JAX transformation with the original arguments and the collected state. Then inside the transformation: #. Unpack the state and recreate the Modules #. Call the user code ``fn`` #. Collect the updated variables and rng and return it together with the original return values from ``fn`` #. Update the original state with the updated state returned from the transformation. A more in depth explanation of functionalization and lifting can be found in the `Lifted Transformation `_ design note. Practical consequences ========================== For the most part functionalization is something that is handled automatically for you. Still there are some constraints that you must take into account. Most importantly, Flax only handles the stateful primitives (Linen variables and RNGs) and not arbitrary stateful Python code. Most importantly: You cannot close over stateful objects and ``Module`` objects because they are invisible to Flax's internals (and to JAX in general). .. testcode:: class Foo(nn.Module): @nn.compact def __call__(self, x): dense = nn.Dense(x.shape[-1]) fn = lambda x: dense(x) + 1 # simply calling inner works fine # return self.inner(x, fn) # but applying a transformation doesn't: vmap_inner = nn.vmap(Foo.inner, in_axes=0, variable_axes={"params": 0}, split_rngs={"params": True}) return vmap_inner(self, x, fn) def inner(self, x, fn): for i in range(3): x = fn(x) return x Here ``inner`` takes a function that closes over a Module instance. In this example, that works fine because we are not transforming the inner method with a lifted transformation. Most methods are not transformed but it is good to know how to make Module methods transformable. The main obstacle for transformability are types that JAX does not recognize. JAX only understands `Pytree `_ arguments; i.e. arbitrarily nested Python containers (dict, list, tuple) of (Jax) numpy ndarrays and Python numbers/bools. Flax allows to define dataclasses which are Pytree compatible using the `flax.struct `_ API. Function closure is the most common way to accidentally hide a JAX array or Linen Module from a transformation. There is however an easy workaround if you want to pass closures that are also compatible with JAX and Linen transformations: .. testcode:: class Partial(flax.struct.PyTreeNode): fn: Callable = flax.struct.field(pytree_node=False) args: Iterable[Any] def __call__(self, *args, **kwargs): return self.fn(*(tuple(self.args) + args), **kwargs) class Foo(nn.Module): @nn.compact def __call__(self, x): dense = nn.Dense(x.shape[-1]) fn = lambda mdl, x: mdl(x) + 1 vmap_inner = nn.vmap(Foo.inner, in_axes=0, variable_axes={"params": 0}, split_rngs={"params": True}) return vmap_inner(self, x, Partial(fn, [dense])) def inner(self, x, fn): for i in range(3): x = fn(x) return x .. testcode:: :hide: x = jax.numpy.ones((3, 2)) mdl = Foo() vars = mdl.init(random.key(0), x) assert vars['params']['Dense_0']['kernel'].shape == (3, 2, 2) Here the closure is implemented using a Flax dataclass. The function itself is annotated with ``flax.struct.field(pytree_node=False)`` to indicate that it does not contain JAX Arrays or Linen Modules. The partially applied ``args`` on the other hand is treated as a pytree container. We rewrite the closure to use Partial. Now the inner method can be transformed using lifted transformations. Future work ************* Setup for unbound Modules =========================== The current Module abstraction is particularly restrictive when it comes to initializing fields after construction. In the current Module API, the ``setup`` method is the place to initialize the fields of the Module instance. Because ``setup`` is only called on a bound Module, the full Module API is available inside ``setup``, including variable declaration. However, oftentimes we don't actually require any stateful API's to initialize a field. In fact, most commonly we simply want to declare a submodule. More importantly, it's often useful to inspect submodules for debugging or to partially run the model. Consider for example: .. testcode:: class AutoEncoder(nn.Module): def setup(self): self.encoder = Encoder(...) self.decoder = Decoder(...) Imagine we want to call just the decoder using `auto_encoder.decoder.apply(decoder_variables, x)`. With the current setup API this does not work because we must first bind the variables before setup is called and the decoder attribute is defined. Of course we can manually construct the Decoder Module with the same attributes as in setup but this is not ideal in many cases. There are two possible solutions to make this use case more ergonomic. First, setup could be made to run immediately after construction before it becomes bound. This means you can still create sub modules but you can no longer define or manipulate variables. Therefore, this would be a breaking change and it would require a new API for defining variables lazily Alternatively, an additional special method could be introduced that runs right away after Module construction and before it becomes bound. In this case, the ``setup`` method would preserve its original semantics. ================================================ FILE: docs/examples/community_examples.rst ================================================ Community examples ================== In addition to the `curated list of official Flax examples on GitHub `__, there is a growing community of people using Flax to build new types of machine learning models. We are happy to showcase any example built by the community here! If you want to submit your own Flax example, you can start by forking one of the `official Flax examples on GitHub `__. Models ****** .. list-table:: :header-rows: 1 * - Link - Author - Task type - Reference * - `matthias-wright/flaxmodels `__ - `@matthias-wright `__ - Various - GPT-2, ResNet, StyleGAN-2, VGG, ... * - `DarshanDeshpande/jax-models `__ - `@DarshanDeshpande `__ - Various - Segformer, Swin Transformer, ... also some stand-alone layers * - `google/vision_transformer `__ - `@andsteing `__ - Image classification, image/text - https://arxiv.org/abs/2010.11929, https://arxiv.org/abs/2105.01601, https://arxiv.org/abs/2111.07991, ... * - `jax-resnet `__ - `@n2cholas `__ - Various resnet implementations - `torch.hub `__ * - `Wav2Vec2 finetuning `__ - `@vasudevgupta7 `__ - Automatic Speech Recognition - https://arxiv.org/abs/2006.11477 Examples ******** .. list-table:: :header-rows: 1 * - Link - Author - Task type - Reference * - `JAX-RL `__ - `@henry-prior `__ - Reinforcement learning - N/A * - `BigBird Fine-tuning `__ - `@vasudevgupta7 `__ - Question-Answering - https://arxiv.org/abs/2007.14062 * - `DCGAN `__ - `@bkkaggle `__ - Image Synthesis - https://arxiv.org/abs/1511.06434 * - `denoising-diffusion-flax `__ - `@yiyixuxu `__ - Image generation - https://arxiv.org/abs/2006.11239 Tutorials ********* .. currently left empty as a placeholder for tutorials .. list-table:: :header-rows: 1 * - Link - Author - Task type - Reference * - - - - Contributing policy ******************* If you are interested in adding a project to the Community Examples section, take the following into consideration: * **Code examples**: Examples must contain a README that is helpful, clear, and explains how to run the code. The code itself should be easy to follow. * **Tutorials**: These docs should preferrably be a Jupyter Notebook format (refer to `Contributing `__ to learn how to convert a Jupyter Notebook into a Markdown file with `jupytext`). Your tutorial should be well-written, and discuss/describe an interesting topic/task. To avoid duplication, the content of these docs must be different from `existing docs on the Flax documentation site `__ or other community examples mentioned in this document. * **Models**: repositories with models ported to Flax must provide at least one of the following: * Metrics that are comparable to the original work when the model is trained to completion. Having available plots of the metric's history during training is highly encouraged. * Tests to verify numerical equivalence against a well known implementation (same inputs + weights = same outputs) preferably using pretrained weights. In all cases mentioned above, the code must work with the latest stable versions of the following packages: ``jax``, ``flax``, and ``optax``, and make substantial use of Flax. Note that both ``jax`` and ``optax`` are `required packages `__ of ``flax`` (refer to the `installation instructions `__ for more details). ================================================ FILE: docs/examples/core_examples.rst ================================================ Core examples ============= Core examples are hosted on the GitHub Flax repository in the `examples `__ directory. Each example is designed to be **self-contained and easily forkable**, while reproducing relevant results in different areas of machine learning. As discussed in `#231 `__, we decided to go for a standard pattern for all examples including the simplest ones (like MNIST). This makes every example a bit more verbose, but once you know one example, you know the structure of all of them. Having unit tests and integration tests is also very useful when you fork these examples. Some of the examples below have a link "Interactive🕹" that lets you run them directly in Colab. Image classification ******************** - :octicon:`mark-github;0.9em` `MNIST `__ - `Interactive🕹 `__: Convolutional neural network for MNIST classification (featuring simple code). - :octicon:`mark-github;0.9em` `ImageNet `__ - `Interactive🕹 `__: Resnet-50 on ImageNet with weight decay (featuring multi-host SPMD, custom preprocessing, checkpointing, dynamic scaling, mixed precision). Reinforcement learning ********************** - :octicon:`mark-github;0.9em` `Proximal Policy Optimization `__: Learning to play Atari games (featuring single host SPMD, RL setup). Natural language processing *************************** - :octicon:`mark-github;0.9em` `Sequence to sequence for number addition `__: (featuring simple code, LSTM state handling, on the fly data generation). - :octicon:`mark-github;0.9em` `Parts-of-speech tagging `__: Simple transformer encoder model using the universal dependency dataset. - :octicon:`mark-github;0.9em` `Sentiment classification `__: with a LSTM model. - :octicon:`mark-github;0.9em` `Transformer encoder/decoder model trained on WMT `__: Translating English/German (featuring multihost SPMD, dynamic bucketing, attention cache, packed sequences, recipe for TPU training on GCP). - :octicon:`mark-github;0.9em` `Transformer encoder trained on one billion word benchmark `__: for autoregressive language modeling, based on the WMT example above. Generative models ***************** - :octicon:`mark-github;0.9em` `Variational auto-encoder `__: Trained on binarized MNIST (featuring simple code, vmap). Graph modeling ************** - :octicon:`mark-github;0.9em` `Graph Neural Networks `__: Molecular predictions on ogbg-molpcba from the Open Graph Benchmark. Contributing to core Flax examples ********************************** Most of the `core Flax examples on GitHub `__ follow a structure that the Flax dev team found works well with Flax projects. The team strives to make these examples easy to explore and fork. In particular (as per GitHub Issue `#231 `__): - README: contains links to paper, command line, `TensorBoard `__ metrics. - Focus: an example is about a single model/dataset. - Configs: we use ``ml_collections.ConfigDict`` stored under ``configs/``. - Tests: executable ``main.py`` loads ``train.py`` which has ``train_test.py``. - Data: is read from `TensorFlow Datasets `__. - Standalone: every directory is self-contained. - Requirements: versions are pinned in ``requirements.txt``. - Boilerplate: is reduced by using `clu `__. - Interactive: the example can be explored with a `Colab `__. ================================================ FILE: docs/examples/google_research_examples.rst ================================================ ######################## Google Research examples ######################## A collection of research by Google Research made with Flax. Attention ********* Fast Attention (FAVOR+) and Rethinking Attention with Performers ================================================================ - Code on GitHub: - `Performer's Fast Attention (FAVOR+) module `__ - Research paper: - `Rethinking Attention with Performers `__ (Choromanski et al., 2020) - Introduces *"Performers, Transformer architectures which can estimate regular (softmax) full-rank-attention Transformers with provable accuracy, but using only linear (as opposed to quadratic) space and time complexity, without relying on any priors such as sparsity or low-rankness. To approximate softmax attention-kernels, Performers use a novel Fast Attention Via positive Orthogonal Random features approach (FAVOR+), which may be of independent interest for scalable kernel methods. FAVOR+ can be also used to efficiently model kernelizable attention mechanisms beyond softmax."* Self-attention Does Not Need O(n^2) Memory ========================================== - `Code on GitHub `__ - `Colab notebook `__ - Research paper: - `Self-attention Does Not Need O(n^2) Memory `__ (Rabe and Staats, 2021) - *"We present a very simple algorithm for attention that requires O(1) memory with respect to sequence length and an extension to self-attention that requires O(log n) memory. This is in contrast with the frequently stated belief that self-attention requires O(n^2) memory. While the time complexity is still O(n^2), device memory rather than compute capability is often the limiting factor on modern accelerators. Thus, reducing the memory requirements of attention allows processing of longer sequences than might otherwise be feasible..."* Computer vision *************** Colorization Transformer (ColTran) ================================== - `Code on GitHub `__ - Research paper: - `Colorization Transformer `__ (Kumar et al., 2020) - *"We presented the Colorization Transformer (ColTran), an architecture that entirely relies on self-attention for image colorization. We introduce conditional transformer layers, a novel building block for conditional, generative models based on self-attention. Our ablations show the superiority of employing this mechanism over a number of different baselines. Finally, we demonstrate that ColTran can generate diverse, high-fidelity colorizations on ImageNet, which are largely indistinguishable from the ground-truth even for human raters."* Vision Transformer (ViT), MLP-Mixer Architectures *and* Big Vision ================================================================== - Code on GitHub: - `Vision Transformer and MLP-Mixer Architectures `__ - `Big Vision `__ - *"This codebase is designed for training large-scale vision models using Cloud TPU VMs or GPU machines. It is based on Jax/Flax libraries, and uses tf.data and TensorFlow Datasets for scalable and reproducible input pipelines."* - `Colab notebooks `__: - The JAX code of Vision Transformers and MLP Mixers - More than 50k Vision Transformer and hybrid checkpoints that were used to generate the data of "How to train your ViT?" - Research papers: - `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `__ (Dosovitskiy et al., 2020) - *"In vision, attention is either applied in conjunction with convolutional networks, or used to replace certain components of convolutional networks while keeping their overall structure in place. We show that this reliance on CNNs is not necessary and a pure transformer applied directly to sequences of image patches can perform very well on image classification tasks. When pre-trained on large amounts of data and transferred to multiple mid-sized or small image recognition benchmarks (ImageNet, CIFAR-100, VTAB, etc.), Vision Transformer (ViT) attains excellent results compared to state-of-the-art convolutional networks while requiring substantially fewer computational resources to train."* - `MLP-Mixer: An All-MLP Architecture for Vision `__ (Tolstikhin et al., 2021) - *"In this paper we show that while convolutions and attention are both sufficient for good performance, neither of them are necessary. We present MLP-Mixer, an architecture based exclusively on multi-layer perceptrons (MLPs). MLP-Mixer contains two types of layers: one with MLPs applied independently to image patches (i.e. "mixing" the per-location features), and one with MLPs applied across patches (i.e. "mixing" spatial information). When trained on large datasets, or with modern regularization schemes, MLP-Mixer attains competitive scores on image classification benchmarks, with pre-training and inference cost comparable to state-of-the-art models."* - `How to Train Your ViT? Data, Augmentation, and Regularization in Vision Transformers `__ (Steiner et al., 2021) - *"Vision Transformers (ViT) have been shown to attain highly competitive performance for a wide range of vision applications, such as image classification, object detection and semantic image segmentation. In comparison to convolutional neural networks, the Vision Transformer's weaker inductive bias is generally found to cause an increased reliance on model regularization or data augmentation ("AugReg" for short) when training on smaller training datasets. We conduct a systematic empirical study in order to better understand the interplay between the amount of training data, AugReg, model size and compute budget."* - `When Vision Transformers Outperform ResNets without Pretraining or Strong Data Augmentations `__ (X. Chen et al., 2021) - *"Vision Transformers (ViTs) and MLPs signal further efforts on replacing hand-wired features or inductive biases with general-purpose neural architectures. Existing works empower the models by massive data, such as large-scale pre-training and/or repeated strong data augmentations, and still report optimization-related problems (e.g., sensitivity to initialization and learning rates). Hence, this paper investigates ViTs and MLP-Mixers from the lens of loss geometry, intending to improve the models' data efficiency at training and generalization at inference."* - `LiT: Zero-Shot Transfer with Locked-image Text Tuning `__ (X. Zhai et al., 2021) - *"This paper presents contrastive-tuning, a simple method employing contrastive training to align image and text models while still taking advantage of their pre-training. In our empirical study we find that locked pre-trained image models with unlocked text models work best. We call this instance of contrastive-tuning "Locked-image Tuning" (LiT), which just teaches a text model to read out good representations from a pre-trained image model for new tasks. A LiT model gains the capability of zero-shot transfer to new vision tasks, such as image classification or retrieval. The proposed LiT is widely applicable; it works reliably with multiple pre-training methods (supervised and unsupervised) and across diverse architectures (ResNet, Vision Transformers and MLP-Mixer) using three different image-text datasets."* Scaling Vision with Sparse Mixture of Experts (MoE) =================================================== - `Code on GitHub `__ - Research paper: - `Scaling Vision with Sparse Mixture of Experts `__ (Riquelme et al., 2021) - *"Sparsely-gated Mixture of Experts networks (MoEs) have demonstrated excellent scalability in Natural Language Processing. In Computer Vision, however, almost all performant networks are "dense", that is, every input is processed by every parameter. We present a Vision MoE (V-MoE), a sparse version of the Vision Transformer, that is scalable and competitive with the largest dense networks... we demonstrate the potential of V-MoE to scale vision models, and train a 15B parameter model that attains 90.35% on ImageNet..."* Diffusion ********* Variational Diffusion Models ============================ - `Code on GitHub `__ - `Colab notebooks `__ - Research paper: - `Variational Diffusion Models `__ (Kingma et al., 2021) - *"Diffusion-based generative models have demonstrated a capacity for perceptually impressive synthesis, but can they also be great likelihood-based models? We answer this in the affirmative, and introduce a family of diffusion-based generative models that obtain state-of-the-art likelihoods on standard image density estimation benchmarks. Unlike other diffusion-based models, our method allows for efficient optimization of the noise schedule jointly with the rest of the model. We show that the variational lower bound (VLB) simplifies to a remarkably short expression in terms of the signal-to-noise ratio of the diffused data, thereby improving our theoretical understanding of this model class. Using this insight, we prove an equivalence between several models proposed in the literature. In addition, we show that the continuous-time VLB is invariant to the noise schedule, except for the signal-to-noise ratio at its endpoints. This enables us to learn a noise schedule that minimizes the variance of the resulting VLB estimator, leading to faster optimization..."* Domain adaptation ***************** GIFT (Gradual Interpolation of Features toward Target) ====================================================== - `Code on GitHub `__ - Research paper: - `Gradual Domain Adaptation in the Wild: When Intermediate Distributions are Absent `__ (Abnar et al., 2021) - *"We focus on the problem of domain adaptation when the goal is shifting the model towards the target distribution, rather than learning domain invariant representations. It has been shown that under the following two assumptions: (a) access to samples from intermediate distributions, and (b) samples being annotated with the amount of change from the source distribution, self-training can be successfully applied on gradually shifted samples to adapt the model toward the target distribution. We hypothesize having (a) is enough to enable iterative self-training to slowly adapt the model to the target distribution, by making use of an implicit curriculum. In the case where (a) does not hold, we observe that iterative self-training falls short. We propose GIFT, a method that creates virtual samples from intermediate distributions by interpolating representations of examples from source and target domains..."* Generalization ************** Surrogate Gap Minimization Improves Sharpness-Aware Training ============================================================ - `Code on GitHub `__ - Research paper: - `Surrogate Gap Minimization Improves Sharpness-Aware Training `__ (J. Zhuang et al., 2022) - *"The recently proposed Sharpness-Aware Minimization (SAM) improves generalization by minimizing a perturbed loss defined as the maximum loss within a neighborhood in the parameter space. However, we show that both sharp and flat minima can have a low perturbed loss, implying that SAM does not always prefer flat minima. Instead, we define a surrogate gap, a measure equivalent to the dominant eigenvalue of Hessian at a local minimum when the radius of neighborhood (to derive the perturbed loss) is small. The surrogate gap is easy to compute and feasible for direct minimization during training. Based on the above observations, we propose Surrogate Gap Guided Sharpness-Aware Minimization (GSAM), a novel improvement over SAM with negligible computation overhead..."* Meta learning ************* ``learned_optimization`` ======================= - Code on GitHub: `learned_optimization `__ - `Colab notebooks `__ - Research papers: - `Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies `__ (Vicol et al., 2021) - *"We introduce a method called Persistent Evolution Strategies (PES), which divides the computation graph into a series of truncated unrolls, and performs an evolution strategies-based update step after each unroll. PES eliminates bias from these truncations by accumulating correction terms over the entire sequence of unrolls. PES allows for rapid parameter updates, has low memory usage, is unbiased, and has reasonable variance characteristics."* - `Gradients Are Not All You Need `__ (Metz et al., 2021) - *"...In this short report, we discuss a common chaos based failure mode which appears in a variety of differentiable circumstances, ranging from recurrent neural networks and numerical physics simulation to training learned optimizers. We trace this failure to the spectrum of the Jacobian of the system under study, and provide criteria for when a practitioner might expect this failure to spoil their differentiation based optimization algorithms."* Model efficiency **************** Efficiently Scaling Transformer Inference ========================================= - Code on GitHub: - `T5X `__ - `AQT: Accurate Quantized Training `__ - Research paper: - `Efficiently Scaling Transformer Inference `__ (Pope et al., 2022) - *"We develop a simple analytical model for inference efficiency to select the best multi-dimensional partitioning techniques optimized for TPU v4 slices based on the application requirements. We combine these with a suite of low-level optimizations to achieve a new Pareto frontier on the latency and model FLOPS utilization (MFU) tradeoffs on 500B+ parameter models that outperforms the FasterTransformer suite of benchmarks. We further show that with appropriate partitioning, the lower memory requirements of multiquery attention (i.e. multiple query heads share single key/value head) enables scaling up to 32× larger context lengths."* Neural rendering / NeRF *********************** Generalizable Patch-Based Neural Rendering ========================================== - `Code on GitHub `__ - Research paper: - `Generalizable Patch-Based Neural Rendering `__ (Suhail et al., 2022) - *"...We propose a different paradigm, where no deep features and no NeRF-like volume rendering are needed. Our method is capable of predicting the color of a target ray in a novel scene directly, just from a collection of patches sampled from the scene."* Voxel-based Radiance Fields in JAX and Flax =========================================== - `Colab notebook `__ (Velez and Dellaert, 2022) - *"In this notebook we show how with JAX/Flax, it is relatively easy to quickly get a voxel-based NeRF variant up and running. Specifically, we will develop a simplified version of DVGO that directly regresses color instead of having a small MLP. It works remarkably well."* Optimization ************ Amos Optimizer *and* JEstimator =============================== - Code on GitHub: - `Amos and JEstimator `__ - *"... implements Amos, an optimizer compatible with the optax library, and JEstimator, a light-weight library with a tf.Estimator-like interface to manage T5X-compatible checkpoints for machine learning programs in JAX, which we use to run experiments in the paper."* - Research paper: - `Amos: An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale `__ (Tian and Parikh, 2022) - Presents *"Amos, an optimizer compatible with the optax library, and JEstimator, a light-weight library with a tf.Estimator-like interface to manage T5X-compatible checkpoints for machine learning programs in JAX."* *"When used for pre-training BERT variants and T5, Amos consistently converges faster than the state-of-the-art settings of AdamW, achieving better validation loss within <=70% training steps and time, while requiring <=51% memory for slot variables."* Quantization ************ Pareto-Optimal Quantized ResNet Is Mostly 4-bit *and* AQT: Accurate Quantized Training ====================================================================================== - Code on GitHub: - `AQT: Accurate Quantized Training `__ - Research paper: - `Pareto-Optimal Quantized ResNet Is Mostly 4-bit `__ (Abdolrashidi et al., 2021) - *"In this work, we use ResNet as a case study to systematically investigate the effects of quantization on inference compute cost-quality tradeoff curves. Our results suggest that for each bfloat16 ResNet model, there are quantized models with lower cost and higher accuracy; in other words, the bfloat16 compute cost-quality tradeoff curve is Pareto-dominated by the 4-bit and 8-bit curves, with models primarily quantized to 4-bit yielding the best Pareto curve... The quantization method we used is optimized for practicality: It requires little tuning and is designed with hardware capabilities in mind... As part of this work, we contribute a quantization library written in JAX..."* Reinforcement learning ********************** Continuous Control with Action Quantization from Demonstrations (AQuaDem) ========================================================================= - `Code on GitHub `__ - Research paper: - `Continuous Control with Action Quantization from Demonstrations `__ (Dadashi et al., 2021) - Proposes *"a novel Reinforcement Learning (RL) framework for problems with continuous action spaces: Action Quantization from Demonstrations (AQuaDem). The proposed approach consists in learning a discretization of continuous action spaces from human demonstrations. This discretization returns a set of plausible actions (in light of the demonstrations) for each input state, thus capturing the priors of the demonstrator and their multimodal behavior. By discretizing the action space, any discrete action deep RL technique can be readily applied to the continuous control problem. Experiments show that the proposed approach outperforms state-of-the-art methods such as SAC in the RL setup, and GAIL in the Imitation Learning setup."* Sequence models / Model parallelism *********************************** T5X: Scaling Up Models and Data with ``t5x`` and ``seqio`` ========================================================== - `Code on GitHub `__ - *"T5X is a modular, composable, research-friendly framework for high-performance, configurable, self-service training, evaluation, and inference of sequence models (starting with language) at many scales."* - Research paper: - `T5X: Scaling Up Models and Data with t5x and seqio `__ (Roberts et al., 2022) - *"Recent neural network-based language models have benefited greatly from scaling up the size of training datasets and the number of parameters in the models themselves. Scaling can be complicated due to various factors including the need to distribute computation on supercomputer clusters (e.g., TPUs), prevent bottlenecks when infeeding data, and ensure reproducible results. In this work, we present two software libraries that ease these issues: t5x simplifies the process of building and training large language models at scale while maintaining ease of use, and seqio provides a task-based API for simple creation of fast and reproducible training data and evaluation pipelines. These open-source libraries have been used to train models with hundreds of billions of parameters on datasets with multiple terabytes of training data. Along with the libraries, we release configurations and instructions for T5-like encoder-decoder models as well as GPT-like decoder-only architectures."* Simulation ********** Brax - A Differentiable Physics Engine for Large Scale Rigid Body Simulation ============================================================================ - `Code on GitHub `__ - `Colab notebooks `__ - Research paper: - `Brax - A Differentiable Physics Engine for Large Scale Rigid Body Simulation `__ (Freeman et al., 2021) - *"We present Brax, an open source library for rigid body simulation with a focus on performance and parallelism on accelerators, written in JAX. We present results on a suite of tasks inspired by the existing reinforcement learning literature, but remade in our engine. Additionally, we provide reimplementations of PPO, SAC, ES, and direct policy optimization in JAX that compile alongside our environments, allowing the learning algorithm and the environment processing to occur on the same device, and to scale seamlessly on accelerators."* ================================================ FILE: docs/examples/index.rst ================================================ Examples ======== .. toctree:: :maxdepth: 2 core_examples google_research_examples repositories_that_use_flax community_examples ================================================ FILE: docs/examples/repositories_that_use_flax.rst ================================================ Repositories that use Flax ========================== The following code bases use Flax and provide training frameworks and a wealth of examples. In many cases, you can also find pre-trained weights: 🤗 Hugging Face *************** `🤗 Hugging Face `__ is a very popular library for building, training, and deploying state of the art machine learning models. These models can be applied on text, images, and audio. After organizing the `JAX/Flax community week `__, they have now over 5,000 `Flax/JAX models `__ in their repository. 🥑 DALLE Mini ************* `🥑 DALLE Mini `__ is a Transformer-based text-to-image model implemented in JAX/Flax that follows the ideas from the original `DALLE `__ paper by OpenAI. Scenic ****** `Scenic `__ is a codebase/library for computer vision research and beyond. Scenic's main focus is around attention-based models. Scenic has been successfully used to develop classification, segmentation, and detection models for multiple modalities including images, video, audio, and multimodal combinations of them. Big Vision ********** `Big Vision `__ is a codebase designed for training large-scale vision models using Cloud TPU VMs or GPU machines. It is based on Jax/Flax libraries, and uses tf.data and TensorFlow Datasets for scalable and reproducible input pipelines. This is the original codebase of ViT, MLP-Mixer, LiT, UViM, and many more models. T5X *** `T5X `__ is a modular, composable, research-friendly framework for high-performance, configurable, self-service training, evaluation, and inference of sequence models (starting with language) at many scales. ================================================ FILE: docs/faq.rst ================================================ Frequently Asked Questions (FAQ) ================================ This is a collection of answers to frequently asked questions (FAQ). You can contribute to the Flax FAQ by starting a new topic in `GitHub Discussions `__. Where to search for an answer to a Flax-related question? ********************************************************* There are a number of official Flax resources to search for information: - `Flax Documentation on ReadTheDocs `__ (this site): Use the `search bar `__ or the table of contents on the left-hand side. - `google/flax GitHub Discussions `__: Search for an existing topic or start a new one. If you can't find what you're looking for, feel free to ask the Flax team or community a question. - `google/flax GitHub Issues `__: Use the search bar to look for an existing issue or a feature request, or start a new one. How to take the derivative with respect to an intermediate value (using :code:`Module.perturb`)? ************************************************************************************************ To take the derivative(s) or gradient(s) of the output with respect to a hidden/intermediate activation inside a model layer, you can use :meth:`flax.linen.Module.perturb`. You define a zero-value :class:`flax.linen.Module` "perturbation" parameter – :code:`perturb(...)` – in the forward pass with the same shape as the intermediate activation, define the loss function with :code:`'perturbations'` as an added standalone argument, perform a JAX derivative operation with :code:`jax.grad` on the perturbation argument. For full examples and detailed documentation, go to: - The :meth:`flax.linen.Module.perturb` API docs - The `Extracting gradients of intermediate values `_ guide - `Flax GitHub Discussions #1152 `__ Is Flax Linen :code:`remat_scan()` the same as :code:`scan(remat(...))`? ************************************************************************ Flax :code:`remat_scan()` (:meth:`flax.linen.remat_scan()`) and :code:`scan(remat(...))` (:meth:`flax.linen.scan` over :meth:`flax.linen.remat`) are not the same, and :code:`remat_scan()` is limited in cases it supports. Namely, :code:`remat_scan()` treats the inputs and outputs as carries (hidden states that are carried through the training loop). You are recommended to use :code:`scan(remat(...))`, as typically you would need the extra parameters, such as ``in_axes`` (for input array axes) or ``out_axes`` (output array axes), which :meth:`flax.linen.remat_scan` does not expose. What are the recommended training loop libraries? ************************************************* Consider using CLU (Common Loop Utils) `google/CommonLoopUtils `__. To get started, go to this `CLU Synopsis Colab `__. You can find answers to common questions about CLU with Flax on `google/flax GitHub Discussions `__. Check out the official `google/flax Examples `__ for examples of using the training loop with (CLU) metrics. For example, this is `Flax ImageNet's train.py `__. For computer vision research, consider `google-research/scenic `__. Scenic is a set of shared light-weight libraries solving commonly encountered tasks when training large-scale vision models (with examples of several projects). Scenic is developed in JAX with Flax. To get started, go to the `README page on GitHub `__. ================================================ FILE: docs/flip/0000-template.md ================================================ - Start Date: (fill me in with today's date, YYYY-MM-DD) - FLIP PR: [#0000](https://github.com/google/flax/pull/0000) - FLIP Issue: [#0000](https://github.com/google/flax/issues/0000) (Below sections are just a possible structure - please adapt to your FLIP.) # Summary [summary]: #summary One paragraph explanation of the FLIP. # Motivation [motivation]: #motivation Why are we doing this? What use cases does it support? What is the expected outcome? # Implementation [implementation]: #implementation The technical part. # Discussion [discussion]: #discussion Summarize the discussion from the original issue and from the pull request. ================================================ FILE: docs/flip/1009-optimizer-api.md ================================================ - Start Date: 2021-02-08 - FLIP PR: [#1011](https://github.com/google/flax/pull/1011) - FLIP Issue: [#1009](https://github.com/google/flax/issues/1009) Table of contents: - [Summary] - [Motivation] - [Using Optax] - [Gradient Transformations] - [Optax Training Step] - [Multi Optimizer] - [Train State] - [Previous API] - [Optimizer and OptimizerDef] - [Previous Training Step] - [Update Plan] - [Appendix] - [Setup Code] # Summary [Summary]: #summary This FLIP proposes to replace our current `flax.optim` API (referred to as [previous API] in this document) with [Optax], DeepMind's optimizer library. # Motivation [motivation]: #motivation Our current API (referred to as [previous API] in this document) uses a pattern where an `Optimizer` dataclass is created from a pytree of `target` variables and from an `OptimizerDef` that defines how to update optimizer state, hyperparameters, and target variables. This pattern is relatively complex for implementing a simple optimizer, while being quite verbose in the typical Linen train step (especially when using mutable state collections). This package `flax.optim` contains some optimizers, but that list is far from exhaustive and ideally we would instead use JAX optimizers from a dedicated PyPi package. DeepMind already has a dedicated library — [Optax] — that implements a wide range of interesting optimizers and provides a framework to compose new optimizers from reusable gradient transformations. [Optax]: https://github.com/deepmind/optax # Using Optax [Using Optax]: #using-optax ## Gradient Transformations [Gradient Transformations]: #gradient-transformations While [Optax] does provide predefined optimizers (like `optax.adam`, or `optax.sgd` with momentum), it is really a library of *gradient transformations* and the idiomatic way of instantiating an optimizer is by providing a combination of these gradient transformations. To emulate the momentum optimizer from the example when using the [previous API] we would write: ```python import optax tx = optax.chain( optax.trace(decay=0.9, nesterov=False), optax.scale_by_schedule(lambda step: -get_learning_rate(step)), ) ``` Remarks: - Above gradient transformation would be equivalent with the example under [Optimizer and OptimizerDef] where we define a Momentum optimizer without Nesterov momentum (note that the `beta` parameter corresponds to the `decay` parameter of the `optax.trace()` transformation, and the learning rate is applied in a second chained transformation). - Note that hyper parameters like `decay` or `nesterov` only exist in the inner scope of the higher order functions returning the `GradientTransformation`. Such a gradient transformation is currently defined as a `NamedTuple` of the `init()` and the `update()` function. In principle this pattern could be extended to also store hyperparameters, maybe a point to discuss on the [Optax] repo. - We can use a `get_learning_rate()` that returns the learning rate depending on the step number when defining the Optax gradient update transformation. Above code illustrates how this can be a drop-in replacement for a function we also use in our [previous training step], where this update function already exists (notice how we need to invert the sign because we add the gradient update to the parameters). In addition, you can use [`inject_hyperparams()`](https://github.com/deepmind/optax/pull/48) to schedule arbitrary hyper parameters with Optax. ## Optax Training Step [Optax Training Step]: #optax-training-step ```python @functools.partial(jax.jit, static_argnums=(4, 5)) def train_step(opt_state, variables, inputs, labels, apply_fn, tx_update_fn): def loss_fn(params): logits, new_model_state = apply_fn( {**variables, 'params': params}, inputs, mutable=['batch_stats']) loss = xent_loss(logits, labels) return loss, new_model_state variables, params = variables.pop('params') (loss, new_model_state), grads = jax.value_and_grad(loss_fn, has_aux=True)( params) updates, new_opt_state = tx_update_fn(grads, opt_state, params) new_params = optax.apply_updates(params, updates) new_variables = {**variables, **new_model_state, 'params': new_params} return new_opt_state, new_variables, loss opt_state = tx.init(variables['params']) for batch in ds.as_numpy_iterator(): opt_state, variables, loss = train_step( opt_state, variables, batch['image'], batch['label'], model.apply, tx.update) print(loss) ``` Remarks: - Since `tx.update()` only transforms the gradient, we still need to call `optax.apply_updates()` to apply these transformed gradients to the parameters. - Compared with the [previous API], we can now keep the entire `variables` including the `params` as an input and output to the `train_step()`. - Splitting `params` from `variables` is still necessary inside the train step because we only want to compute gradients with respect to `params` and not the entire `variables`. - We can still log internal optimizer state, such as the learning rate, as long as Optax transformations expose that information in their respective state. For example, `optax.scale_by_schedule()` currently only exposes `opt_state.count` but could easily be extend to also expose the `step_size`. The same is true for internal optimizer states that change over time. ## Multi Optimizer [Multi Optimizer]: #multi-optimizer The [previous API] defined `flax.optim.MultiOptimizer` for processing different parts of the parameter tree with different optimizers: ```python biases_traversal = flax.optim.ModelParamTraversal( lambda path, _: path.endswith('/bias')) not_biases_traversal = flax.optim.ModelParamTraversal( lambda path, _: not path.endswith('/bias')) optimizer_def = flax.optim.MultiOptimizer( (biases_traversal, flax.optim.GradientDescent(learning_rate=0.1)), (not_biases_traversal, flax.optim.GradientDescent(learning_rate=0.05)), ) ``` Note how we first define a traversal that selects parameters based on their path (which is the concatenation of module scopes and variable name), and then create a `MultiOptimizer` that binds a different optimizer for each of these separate traversals. Optax has recently implemented `optax.masked()` that can be used for specifying gradient transformations that only applied to a subset of the gradients: ```python def flattened_traversal(fn): def mask(data): flat = traverse_util.flatten_dict(data) return traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()}) return mask tx = optax.chain( optax.masked(optax.sgd(learning_rate=0.1), mask=flattened_traversal(lambda path, _: path[-1] == 'bias')), optax.masked(optax.sgd(learning_rate=0.05), mask=flattened_traversal(lambda path, _: path[-1] != 'bias')), ) ``` ## Train State [Train State]: #train-state In Flax it is common to hand around a `TrainState` object that can then be used for checkpointing. This simplifies the above [Optax training step] a bit by reducing the number of arguments and getting rid of the `static_argnums`. We can define a `TrainState` dataclass that wraps the common pattern of updating the optimizer state and parameters by applying the gradients. ```python # Small helper class in flax.training class TrainState(flax.struct.PyTreeNode): step: int apply_fn: Callable = flax.struct.field(pytree_node=False) params: flax.core.FrozenDict[str, Any] tx: optax.GradientTransformation = flax.struct.field(pytree_node=False) opt_state: optax.OptState def apply_gradients(self, *, grads, **kwargs): updates, new_opt_state = self.tx.update( grads, self.opt_state, self.params) new_params = optax.apply_updates(self.params, updates) return self.replace( step=self.step + 1, params=new_params, opt_state=new_opt_state, **kwargs, ) @classmethod def create(cls, *, apply_fn, params, tx, **kwargs): opt_state = tx.init(params) return cls( step=0, apply_fn=apply_fn, params=params, tx=tx, opt_state=opt_state, **kwargs, ) ``` Users can then derive from this dataclass and add more fields, for example mutable model state: ```python from flax.training import train_state class TrainState(train_state.TrainState): batch_stats: flax.core.FrozenDict[str, Any] ``` With this the [Optax Training Step] becomes: ```python @jax.jit def train_step(state, inputs, labels): def loss_fn(params): outputs, new_model_state = state.apply_fn( {'params': params, 'batch_stats': state.batch_stats}, inputs, mutable=['batch_stats']) loss = xent_loss(outputs, labels) return loss, new_model_state (loss, new_model_state), grads = jax.value_and_grad( loss_fn, has_aux=True)(state.params) new_state = state.apply_gradients( grads=grads, batch_stats=new_model_state['batch_stats'], ) return new_state, loss state = TrainState.create( apply_fn=model.apply, params=variables['params'], tx=tx, batch_stats=variables['batch_stats'], ) for batch in ds.as_numpy_iterator(): state, loss = train_step(state, batch['image'], batch['label']) ``` The train step without mutable state reduces to: ```python @jax.jit def train_step(state, inputs, labels): def loss_fn(params): outputs = state.apply_fn({'params': params}, inputs) loss = xent_loss(outputs, labels) return loss loss, grads = jax.value_and_grad(loss_fn)(state.params) new_state = state.update(grads=grads) return new_state, loss state = flax.training.TrainState.create( apply_fn=model.apply, params=variables['params'], tx=tx, ) for batch in ds.as_numpy_iterator(): state, loss = train_step(state, batch['image'], batch['label']) ``` Remarks: - It is a common pattern in Flax training loops to have a `TrainState` dataclass that is updated with new state after every step. - The simple solution proposed in `flax.training.train_state` an be extended with additional data, but advanced usecases (e.g. multiple different models and/or optimizers) are not supported. Users should instead fork the dataclass and re-implement it to their needs. - As opposed to the `Optimizer` abstraction in the [previous API], the `TrainState` now directly contains the `.params`, without having to to through `.optimizer` # Previous API [previous API]: #previous-api ## Optimizer and OptimizerDef [Optimizer and OptimizerDef]: #optimizer-and-optimizerdef The optimizer itself would be implemented by creating a new class derived from `OpimizerDef`: ```python # flax/optim/momentum.py @flax.struct.dataclass class _MomentumHyperParams: learning_rate: jnp.ndarray beta: jnp.ndarray @flax.struct.dataclass class _MomentumParamState: momentum: np.ndarray class Momentum(flax.optim.OptimizerDef): def __init__(self, learning_rate=None, beta=0.9): super().__init__( _MomentumHyperParams(learning_rate, beta) ) def init_param_state(self, param): return _MomentumParamState(jnp.zeros_like(param)) def apply_param_gradient(self, step, hyper_params, param, state, grad): del step assert hyper_params.learning_rate is not None new_momentum = state.momentum * hyper_params.beta + grad new_params = param - hyper_params.learning_rate * new_momentum return new_params, _MomentumParamState(new_momentum) ``` Remarks: - Note the relationship between `OptimizerDef` and `Optimizer` : When the function `Optimizer.apply_gradient()` is called from the user code, it calls into `OptimizerDef.apply_gradient()` (among other things) which in turn will call `OptimizerDef.apply_param_gradient()` (implemented by subclasses of `OptimizerDef`). - The functions `init_param_state()` and `apply_param_gradient()` are called for every leaf in the params/grads pytree. This makes it possible to write the calculations directly without `jax.tree_util.tree_map()`. - The interface was defined in pre-Linen without the distinction of `params` vs. other collections in `variables` in mind. The original API was elegant because one only needed to pass around the optimizer, which included the parameters, optimizer state, optimizer hyperparameters, and a reference to the `OptimizerDef` to perform the param/state update. ## Previous Training Step [Previous Training Step]: #previous-training-step An optimizer would first be constructed from its definition and the pytree of target params: ```python optimizer_def = flax.optim.Momentum(learning_rate=0.1, beta=0.9) optimizer = optimizer_def.create(variables['params']) ``` Then, the target variables would optimized in the train step (assuming a single non-params collection "batch_stats"): ```python def make_train_step(apply_fn): @jax.jit def train_step(optimizer, batch_stats, inputs, labels): def loss_fn(params): variables = {'params': params, 'batch_stats': batch_stats} logits, new_model_state = apply_fn( variables, inputs, mutable=['batch_stats']) loss = xent_loss(logits, labels) return loss, new_model_state['batch_stats'] (loss, new_batch_stats), grad = jax.value_and_grad(loss_fn, has_aux=True)( optimizer.target) lr = get_learning_rate(step) new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr) return new_optimizer, new_batch_stats, loss return train_step batch_stats = variables['batch_stats'] train_step = make_train_step(model.apply) for step, batch in enumerate(ds) optimizer, batch_stats, loss = train_step( optimizer, batch_stats, batch['image'], batch['label']) ``` Remarks: - Notice how `optimizer.apply_gradient()` can take additional arguments to update hyperparameters, such as learning rate from an independent function `get_learning_rate()` in this case. # Update Plan [Update Plan]: #update-plan 1. Finalize discussions on this FLIP 2. Add [equivalence tests] to Optax that guarantee that existing `flax.optim` optimizers return identical values with corresponding `optax` optimizers. 3. Update examples to use Optax and verify that they reach the same final performance with the same computational cost. 4. Port missing optimizers to Optax (e.g. Adafactor) - and verify above points. 5. Update all documentation (including README, Flax guided tour, HOWTOs, ...) to talk exclusively about Optax optimizers. 6. Create a transition guide for updating users from `flax.optim` to using Optax. This transition guide should also point to Optax's [equivalence tests] and the pull requests updating the examples. 7. Mark optimizers in `flax.optim` as deprecated. [equivalence tests]: https://github.com/deepmind/optax/blob/master/optax/_src/equivalence_test.py Note that all current Flax examples use an optimizer that is already available in Optax: | Example | Flax | Optax | Comments | | -------- | -------------- | ----------- | ----------------------------------- | | imagenet | optim.Momentum | optax.sgd | DynamicScale can be used unchanged. | | mnist | optim.Momentum | optax.sgd | | | nlp_seq | optim.Adam | optax.adamw | | | pixelcnn | optim.Adam | optax.adam | | | ppo | optim.Adam | optax.adam | | | seq2seq | optim.Adam | optax.adam | | | vae | optim.Adam | optax.adam | | | wmt | optim.Adam | optax.adamw | | (Flax's Adam implementation has an optional parameter for weight decay, but in Optax Adam with and without weight decay are two different aliases.) # Appendix [Appendix]: #appendix ## Setup Code [Setup Code]: #setup-code The following setup code can be used for running the code snippets in this FLIP: ```python import functools from typing import Callable, Sequence import jax import jax.numpy as jnp import flax import flax.linen as nn import tensorflow as tf import tensorflow_datasets as tfds def pp(features): return { 'image': tf.cast(features['image'], tf.float32) / 255 - 0.5, 'label': features['label'], } class Model(nn.Module): @nn.compact def __call__(self, inputs): x = inputs.reshape([inputs.shape[0], -1]) x = nn.normalization.BatchNorm(True)(x) x = nn.Dense(10)(x) x = nn.log_softmax(x) return x def onehot(labels, num_classes, on_value=1.0, off_value=0.0): x = (labels[..., None] == jnp.arange(num_classes)[None]) x = jax.lax.select( x, jnp.full(x.shape, on_value), jnp.full(x.shape, off_value)) return x.astype(jnp.float32) def xent_loss(logits, labels): return -jnp.sum( onehot(labels, num_classes=10) * logits) / labels.size def get_learning_rate(step): return 0.1 model = Model() rng = jax.random.key(0) ds = tfds.load('mnist')['train'].take(160).map(pp).batch(16) batch = next(iter(ds)) variables = model.init(rng, jnp.array(batch['image'][:1])) jax.tree_util.tree_map(jnp.shape, variables) ``` ================================================ FILE: docs/flip/1777-default-dtype.md ================================================ # FLIP: Default dtypes - Start Date: 2022-01-11 - FLIP PR: [#1776](https://github.com/google/flax/pull/1776) - FLIP Issue: [#1777](https://github.com/google/flax/issues/1777) - Status: Implemented ## Summary This FLIP proposes to replace the default dtype which is currently fixed to float32, and instead use the JAX type promotion results to derive a default dtype from the input and parameters of a layer. ## Motivation Currently, Linen Modules always produce `module.dtype` (defaults to float32) outputs regardless of input and parameter dtypes. Half-precision types like float16 and bfloat16 are supported by explicitly passing the half-precision type to each Module. The way this is currently implemented is that each Module has a dtype argument with float32 as the default value. The layer guarantees that this dtype will be the return type of the result returned by `__call__`. The current behavior is problematic and results in silent bugs, especially for dtypes that do not fit inside float32 (complex, float64). Also, the Linen dtype behavior is significantly different from how NumPy and by extension JAX handle dtypes. ### Dtypes in JAX JAX uses a NumPy-inspired [dtype promotion](https://github.com/jax-ml/jax/blob/main/jax/_src/dtypes.py) mechanism as explained [here](https://jax.readthedocs.io/en/latest/type_promotion.html?highlight=lattice#type-promotion-semantics). The type promotion rules are summarized by the following type lattice: ![JAX type promotion lattice](https://jax.readthedocs.io/en/latest/_images/type_lattice.svg) ## Dtypes in Linen Besides input arguments, state and in particular parameters could affect dtype promotion. For example: we might feed a float64 input to a Dense layer with float32 parameters. Currently, the result would be truncated to float32. If the input is a complex number the result is even worse because the imaginary part will be silently dropped when casting to float32. By using the dtype promotion rules already available in JAX we can avoid this issue. A public API is available called `jax.numpy.result_dtype(*args)`, which returns the dtype that JAX would promote the given arguments to, in accordance with the type promotion lattice. For Linen layers the arguments would be the layer inputs together with the parameters. For example, for a linear layer this would be inputs, kernel, and bias. Note that there is also a `param_dtype` attribute in standard Linen Modules that also defaults to flaot32. This behavior is left untouched and encodes the common case of having float32 parameters. There are a few reasons why float32 is almost always the correct dtype for parameters: 1. Storing weights in half-precision often leads to underflow during optimization. 2. Double precision is rarely used because it severely slows down modern accelerators (GPU, TPU). Therefore, such a cost should be explicitly opted-in for. 3. Complex Modules are relatively uncommon. Even within complex networks, the complex inputs can be projected with a real matrix. # Implementation A simplified example implementation: ```python def promote_arrays(*xs, dtype): if dtype is None: dtype = jnp.result_type(*jax.tree_util.tree_leaves(xs)) return jax.tree_util.tree_map(lambda x: jnp.asarray(x, dtype), xs) Dtype = Any class Dense(nn.Module): features: int kernel_init: Callable bias_init: Callable dtype: Optional[Dtype] = None param_dtype: Dtype = jnp.float32 @nn.compact def __call__(self, x): kernel = self.param("kernel", self.kernel_init, (x.shape[-1], self.features), self.param_dtype) bias = self.param("bias", self.bias_init, (self.features,), self.param_dtype) x, kernel, bias = promote_arrays(x, kernel, bias, dtype=self.dtype) return x @ kernel + bias ``` ## Half-precision dtypes Some layers don’t work with half-precision dtypes internally. For example: The normalization layers currently compute mean and variance in float32 even when a half-precision dtype is specified to avoid numerical issues. We can replicate this behavior by calling result_dtype with a dummy argument that has the minimum precision for the sub computation to work correctly. ## Backward compatibility This proposal causes some layers to behave differently in cases where the dtype is not specified to a Linen Module. By default, parameters are in float32. Therefore, passing in half or float32 precision inputs will cause a float32 dtype and no functional differences with current behavior. When passing complex or float64 precision, the result will no longer truncate the imaginary component or the precision. The silent truncation is problematic and has caused [user complaints](https://github.com/google/flax/issues/805#issuecomment-981468837). Therefore, this change can be considered a bugfix. Thus, although this proposal strictly speaking changes behavior it is unlikely to cause problems for users. There are 2 exceptions to this which should be rare and easy to fix: 1. A user relies on the enforced float32 to downcast a double precision value. 2. A user relies on the float32 to explicitly upcast a half precision value even though the weights are in half precision. ## Corner cases In this section we describe corner cases where the implementation of the proposal is not obvious. The two main concerns are how complex numbers are handled in existing layers and how to determine the dtype of state variables. **Autoregressive decoding cache** Currently, only attention implements autoregressive caching and the stored key and value mirror the dtype of the key and value passed to the layer. Forcing the cache dtype to be the same as the output dtype could result in reduced precision during cached decoding vs uncached. This seems undesirable. Decision: keep the current behavior. **Batch statistics** BatchNorm layers are often used with a half precision output dtype. However, calculating statistics is by default always done in float32 to avoid numerical precision issues and over/underflow for float16. With float64 this would actually cause a downcast so we should now use `np.promote_types(float32, dtype)` such that the precision is at least float32. The running batch statistics will be stored with the same dtype for consistency. **Complex number support** Currently, our complex number support is brittle because the default behavior is to truncate the output to the real part. This issue will be fixed by the automatic type promotion proposed in this FLIP. However, some layers require some additional thought to extend to complex numbers correctly: 1. Normalization layers use the complex conjugate to calculate norms instead of normal squaring. 2. Attention: It’s not exactly clear how the dot product and softmax are defined in this case. Raise an error on complex inputs. 3. Recurrent layers: might require special gating / activation functions to function correctly, but these can be specified by the user. # Discussion Summarizing the main points from the discussion: ## Consider implicit complex truncation an error Q: I'm wondering if we should always raise an error if one of the xs tree leaves is complex but dtype is not. Users should maybe remove imaginary part by themselves if that's really what they want to do. (Maybe it's a contrived example, but I can imagine cases where layers have their dtype set by parent modules based on assumptions without complex numbers in mind) A: This is worth considering in a follow-up CL but this might as well be solved in JAX directly where the safeguard would apply more generally. In NumPy this was also considered but abandoned because it is not backwards compatible. ## Dtype attribute names Q: Are the dtype and param_dtype arguments confusing? In particular, should dtype perhaps be called output_dtype to make the difference between the two dtypes more explicit? A: This would be a large and orthogonal change wrt to this proposal so leaving it out for now. Also, this breaks with the standard dtype argument in NumPY/JAX. Although dtype indeed constrains the output dtype it is also a hint for the dtype we would like the computation to happen in. ================================================ FILE: docs/flip/2396-rnn.md ================================================ # RNN Flip - Start Date: 2022-08-18 - FLIP PR: [#2604](https://github.com/google/flax/pull/2604) - FLIP Issue: [#2396](https://github.com/google/flax/issues/2396) - Authors: Jasmijn Bastings (@bastings) and Cristian Garcia (@cgarciae) ## Summary This FLIP adds support for higher-level recurrent layers (RNN, GRU, LSTM) that can help users process input sequences using the recurrent cells already available in Flax. ## Motivation Implementing well known recurrent architectures is tricky and prone to user errors, even a simple LSTM layers involves the manual creation and handling of the carry/memory and correctly setting up `nn.scan`: ```python @nn.compact def __call__(self, x): LSTM = nn.scan( nn.LSTMCell, variable_broadcast="params", split_rngs={"params": False} ) carry = LSTM.initialize_carry( jax.random.key(0), batch_dims=x.shape[:1], size=self.hidden_size ) carry, x = LSTM()(carry, x) return x ``` Slightly more complicated cases involving padding like in the [seq2seq](https://github.com/google/flax/blob/main/examples/seq2seq/models.py) example require even more work but couple potentially be simplified to a couple of lines with the right abstractions. We propose providing users with clean, correct, and efficient abstractions to use recurrent cells. ## Requirements * **Masking**: We need to support a batch of sequences that contain padding at the end of each sequence. * We do not intend to support non-contiguous padding, i.e. padding that is not at the end of a sequence, for performance reasons, except in the case of packing (see below). * **Bidirectionality**: The ability to process a sequence in both the forward and reverse directions, respecting padding (i.e., the reverse direction should start with the actual inputs, not with padding values). * **Performance**: The proposed classes should be benchmarked to provide the best performance in terms of step time and/or memory use. * **Recurrent Dropout**: Support for recurrent dropout in cells (e.g. dropout on the state of the cell). ## Implementation ### High-level structure We propose to have these 3 levels of abstraction: * **Cells (unchanged)**: all RNNCellBase subclasses such as LSTMCell and GRUCell, these implement the stepwise logic. These already exist in Flax today. * **Layers (new)**: a class (RNN) that takes a cell and scans over a sequence respecting possible padding values and optionally also allows packed sequences. * **Bidirectional (new)**: a single class that takes a forward and a backward RNN instance and correctly processes the input sequence in both directions and merges the results. ### Example of proposed API We start with a code example of what you could do with the proposed API, and then we discuss the API in detail below. ```python cell = nn.LSTMCell() # Encodes a batch of input sequences. carry, outputs = nn.RNN(cell, cell_size)(inputs, seq_lengths) ``` A Bidirectional layer with a LSTM RNNs for the forward and backward directions respectively would look like this: ```python forward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32) backward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32) # Bidirectional combinator. bi_rnn = nn.Bidirectional(forward_rnn, backward_rnn) # Encodes a batch of input sequences in both directions. carry, outputs = bi_rnn(inputs, seq_lengths) ``` Next we will discuss `RNN`, `Bidirectional`, and proposed changes to `RNNCellBase`. ### RNNBase The `RNNBase` class serves as a base class for the `RNN` class, it specifies the API that all RNN layers should implement to be compatible with the `Bidirectional`. `RNNBase` contains the `__call__` and `flip_sequences` methods: ```python class RNNBase(Protocol): def __call__( self, inputs: jax.Array, *, initial_carry: Optional[Carry] = None, init_key: Optional[random.KeyArray] = None, seq_lengths: Optional[Array] = None, return_carry: Optional[bool] = None, time_major: Optional[bool] = None, reverse: Optional[bool] = None, keep_order: Optional[bool] = None, ) -> Union[Output, Tuple[Carry, Output]]: ... ``` Where: * `inputs`: the input sequence. * `initial_carry`: the initial carry, if not provided it will be initialized using the cell's :meth:`RNNCellBase.initialize_carry` method. * `init_key`: a PRNG key used to initialize the carry, if not provided ``jax.random.key(0)`` will be used. Most cells will ignore this argument. * `seq_lengths`: an optional integer array of shape ``(*batch)`` indicating the length of each sequence, elements whose index in the time dimension is greater than the corresponding length will be considered padding and will be ignored. * `return_carry`: if ``return_carry=False`` (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence. * `time_major`: if ``time_major=False`` (default) it will expect inputs with shape ``(*batch, time, *features)``, else it will expect inputs with shape ``(time, *batch, *features)``. * `reverse`: if ``reverse=False`` (default) the sequence is processed from left to right and returned in the original order, else it will be processed from right to left, and returned in reverse order. If ``seq_lengths`` is passed, padding will always remain at the end of the sequence. * `keep_order`: if ``keep_order=True``, when ``reverse=True`` the output will be reversed back to the original order after processing, this is useful to align sequences in bidirectional RNNs. If ``keep_order=False`` (default), the output will remain in the order specified by ``reverse``. * `Returns`: if ``return_carry=False`` (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence. ### RNN The `RNN` module inherits from `RNNBase`, it main function is to apply an `RNNCellBase` instance over a batch of input sequences, it can be used with any type of cell (e.g., `GRUCell`, `LSTMCell`, etc). It accepts the following parameters: ```python class RNN(RNNBase): cell: RNNCellBase, cell_size: int | Tuple[int, ...] time_axis: int = -2, variable_axes = FrozenDict(), variable_broadcast: CollectionFilter = 'params' variable_carry: CollectionFilter = False split_rngs = FrozenDict({'params': False}) # implement RNNBase ... ``` Attributes like `variable_axes`, `variable_broadcast`, `variable_carry`, and `split_rngs` are directly passed to `nn.scan`, their default values are set such that common cells like `LSTMCell` and `GRUCell` work out of the box. ### Masking `seq_lengths` is defined as an integer array of shape `(*batch,)` indicating the length of each sequence.
Discussion There are various masking formats found in other frameworks, here are some of the most popular ones: * **Binary masking**: specifies per-sample and timestep whether that data point should be included or not in the computation, it can be non-contigous (e.g., [1, 1, 0, 1]). This is used by Keras. * **Sequence length masking**: specifies per-sample the number of non-padding examples contained in the sequence, any padding contained in the sequence should be stacked at the end. This is used by FlaxFormer. * **Segmentation Mask**: specifies row and timestep to which sample the data point belongs to, this format allows more than one sample per row which potentially reduces the total amount of padding needed (e.g. [1, 1, 1, 2, 2, 0, 0]). Pytorch uses this representation (see [pack_padded_sequence](https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pack_padded_sequence.html)). While Sequence packing (see [LM1B example](https://github.com/google/flax/blob/main/examples/lm1b/input_pipeline.py#L90-L92)) is is more powerful, its implementation is more complex and it is not clear whether it is worth the effort. The simplest format is sequence length masking, which is the one we propose to use.
### Bidirectional Bidirectional processing can be achieved via a Module that accepts a `forward_rnn` Module and a `backward_rnn` Module, both of which should be `RNN` instances, in order to process the input sequence in both directions. Here we present some pseudo code of the implementation: ```python def __call__(self, inputs, seq_lengths): # Encode in the forward direction. carry_forward, outputs_forward = self.forward_rnn( inputs, seq_lengths=seq_lengths, return_carry=True, reverse=False, ) # Encode in the reverse order. carry_backward, outputs_backward = self.backward_rnn( inputs, seq_lengths=seq_lengths, return_carry=True, reverse=True, # process in reverse order keep_order=True, # but return the sequence in the original order ) # Merge both sequences. outputs = jax.tree.map(self.merge_fn, outputs_forward, outputs_backward) return (carry_forward, carry_backward), outputs ``` Here `merge_fn` a function that takes both outputs and fuses them (`concat` by default). As showcased in the beginning of this document, usage would look like this: ```python forward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32) backward_rnn = nn.RNN(nn.GRUCell(), cell_size=32) # Bidirectional combinator. bi_rnn = nn.Bidirectional(forward_rnn, backward_rnn) # Encodes a batch of input sequences in both directions. carry, outputs = bi_rnn(inputs, seq_lengths) ``` ### Recurrent Dropout There are two main uses of dropout in RNNs: 1. Input dropout: regular dropout applied to the inputs, different for every step. 4. Recurrent dropout: applies dropout to a recurrent input/output, same for every step. Flax's `nn.scan` can easily express both types of dropout via `split_rns`, input dropout would split rngs while recurrent dropout would not. [#2540](https://github.com/google/flax/pull/2540) was introduces such that the `rng_name` in `nn.Dropout` can now be defined by the user, this way Cells could define both types of dropout e.g: ```python self.dropout = nn.Dropout(...) # input dropout self.recurrent_dropout = nn.Dropout(..., rng_collection='recurrent_dropout') ``` Based on this, `nn.scan` / `nn.RNN` can now specify `split_rngs` accordingly e.g: ``` nn.scan(scan_fn, ..., split_rngs={'dropout': True, 'recurrent_dropout': False}) ``` # Future ideas
show ### Sequence Packing Allow packing multiple sequences to make efficient use of space/memory. This might result in a trade-off where step time is higher (because at each step we need to check whether we are starting a new sequence and reset the carry/initial state), but where less padding is used increasing efficiency overall. ### RNNCell redesign #### Make initialize_state an instance method First altenative is to make `initialize_carry` a instance method. With this change hyperparameters can be passed directly to the cell, it signature would look like this: ```python def initialize_carry(self, sample_input) -> Carry: ... ``` Usage would look like this: ```python LSTM = nn.scan( nn.LSTMCell, variable_broadcast='params', split_rngs={'dropout': True}) lstm = LSTM(features=32) carry = lstm.initialize_carry(x[:, 0]) carry, y = lstm(carry, x) ``` #### Remove initialize_carry An alternative is to remove `initialize_carry` entirely and have the carry state be handled as a carry collection. This would simplify usage quite a bit: ```python LSTM = nn.scan( nn.LSTMCell, variable_broadcast='params', split_rngs={'dropout': True}) y = LSTM(features=32)(carry, x) ``` However, this would require `nn.scan` to support initialization of carry collections which is currently not possible. Also, users would have to specify that a collection is mutable e.g. `mutable=['carry']`, even if they are not interested in the output carry state.
================================================ FILE: docs/flip/2434-general-metadata.md ================================================ # FLIP: Axis Metadata - Start Date: 2022-08-08 - FLIP Issue: [#2434](https://github.com/google/flax/issues/2434) - FLIP PR: [#2435](https://github.com/google/flax/pull/2435) - Status: Proposal ## Summary This FLIP proposes to extend Flax's variable collections with a generic axis metadata API. The core of the API is an abstract base class that is recognized by lifting transformations that can add an axis (vmap, scan). Users can extend the base class to keep track of per-axis metadata in a way that works with lifted transformations. ## Motivation Generally, there is no way in Flax to track metadata for variables across lifted transformations. Axis metadata is used to keep track of semantic information about axes into other (Flax independent) APIs. For example, optimizers like AdaFactor can be configured on a per-axis level and partitioning APIs in JAX like xmap or pjit require per variable annotations to map effectiently to parallel hardware. Currently, there is an experimental [API](https://github.com/google/flax/blob/main/flax/linen/partitioning.py) supporting partitioning annotations with wrappers around lifted transforms that change axes (``nn.scan_with_axes``, ``nn.vmap_with_axes``) and a special APIs to create variables (``param_with_axes`` and ``variable_with_axes``). The experimental partitioning API stores the metadata in a separate collection named "[collection]_axes". The experimental API has a number of shortcomings that we like to solve: 1. The current API works for tracking PartitionSpecs but not for other types of metadata like optimizer annotations. 2. The implementation using an "xxx_axes" collection requires error-prone and non-composable string manipulation. 3. Special, partioning-aware variable creators and lifted transforms are required 4. The partioning API is hard to use with pre-existing Modules that aren't partioning aware. ## Proposal To generalize metadata tracking and keep the specific metadata out of core Flax we propose the following abstract base class: ```python TAxisMetadata = TypeVar("TAxisMetadata", bound="AxisMetadata") class AxisMetadata(metaclass=abc.ABCMeta): """Abstract base class for boxed Metadata. ``AxisMetadata`` enables arbitrary, per axis metadata for variables. By using ``unbox`` the metadata is stripped away to obtain the original variables. By using unboxing, most code handling variables does not need to handle ``AxisMetadata`` specifically, but can directly operate on the JAX arrays that they wrap. Additionally, ``AxisMetadata`` supports updating metadata whenever an axis is added or removed by a functional transformation (e.g.: ``nn.scan`` or ``nn.vmap``) using the ``add_axis`` and ``remove_axis`` methods. By extending ``AxisMetadata``, custom metadata can be stored. See ``Partitioned`` for a specific implementation. """ @abc.abstractmethod def unbox(self) -> Any: """Returns the content of the AxisMetadata box. Note that unlike ``meta.unbox`` the unbox call should recursively unbox metadata. It should simply return value that it wraps directly even if that value itself is an instance of AxisMetadata. In practise, AxisMetadata subclasses should be registred as PyTree nodes to support passing instances to JAX and Flax APIs. The leaves returned for this note should correspond to the value returned by unbox. Returns: The unboxed value. """ pass @abc.abstractmethod def add_axis(self: TAxisMetadata, index: int, params: Dict[Any, Any]) -> TAxisMetadata: """Adds a new axis to the axis metadata. Note that add_axis and remove_axis should act as each other's inverse (meaning: ``x.add_axis(i, p).remove_axis(i, p) == x``) Args: index: The position at which the new axis will be inserted params: An arbitrary dictionary of parameters passed by the transformation that introduces the new axis (e.g.: ``nn.scan`` or ``nn.vmap``). The user passes this dictionary as the `metadata_param` argument to the transformation. Returns: A new instance of the same type as self and with the same ``unbox`` content with updated axis metadata. """ pass @abc.abstractmethod def remove_axis(self: TAxisMetadata, index: int, params: Dict[Any, Any]) -> TAxisMetadata: """Removes an axis from the axis metadata. Note that add_axis and remove_axis should act as each other's inverse (meaning: ``x.remove_axis(i, p).add_axis(i, p) == x``) Args: index: The position of the axis that is to be removed params: An arbitrary dictionary of parameters passed by the transformation that introduced the axis (e.g.: ``nn.scan`` or ``nn.vmap``). The user passes this dictionary as the `metadata_param` argument to the transformation. Returns: A new instance of the same type as self and with the same ``unbox`` content with updated axis metadata. """ pass ``` We call this type of class wrapping a value and keeping track of some additional data a **box**. By defining an abstract base class for this box, the API does not need to be aware of the specifics of the metadata that is tracked. This should make the API future proof and modular. The ``add_axis`` and ``remove_axis`` method return an instance of their own type instead of mutating in-place. Typically, an implementation would be a ``flax.struct.PyTreeNode`` because the box should still be a valid JAX value and must therefore be handled by the PyTree API. Calling ``jax.tree.map`` on a boxed value will simply map over the value in the box. The lifted transforms that need to handle metadata will call ``jax.tree.map(..., is_leaf=lambda x: isinstance(x, AxisMetadata))`` to find the AxisMetadata instances within a PyTree. Advantages of the boxing approach: 1. Boxing can be used outside of Flax and metadata is automatically "inherited". For example, the optimizer state will have the same partitioning spec as the parameters, because the state is initialized using a ``jax.tree.map`` over the boxed parameters. 2. Boxes are composable. 3. Boxing avoids string manipulation and generally avoids having to handle additional auxiliary collections like "param_axes" in the current partitioning API. 4. No need to lift metadata collections separately. Disadvantages: 1. Adding the boxes changes the PyTree hierarchy and introduces dataclasses within the otherwise plain, nested dict of variables. 3. Custom Pytree nodes have a small runtime overhead. It's hard to observe this in practise because JAX calls are async. ### Init syntax Boxes can be created directly by the init function of a variable. Therefore, we propose to create metadata using higher-order initializers. The main advantage of this is that we can decouple metadata handling completely from the Module definition. Also, most Modules already overwrite attributes to override the default initialzers so users can add metadata to existing Modules without requiring any code changes. To illustrate this, let's consider a metadata class that keeps track of PartitionSpecs used by ``pjit``: ```python class Partitioned(flax.struct.PyTreeNode, AxisMetadata): value: Any names: Tuple[Optional[str], ...] = flax.struct.field(pytree_node=False) def add_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata: axis_name = self._get_partition_name(params) names = list(self.names) names.insert(index, axis_name) return self.replace(names=tuple(names)) def remove_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata: axis_name = self._get_partition_name(params) names = list(self.names) assert names.pop(index) == axis_name return self.replace(names=tuple(names)) def with_partitioning(init_fn, names): def wrapper(*args, **kwargs): return Partitioned(init_fn(*args, **kwargs), names) return wrapper ``` Here we also defined a small utility called ``with_partitioning`` that we can use to wrap existing initialzers to add metadata: ```python # init kernel with lecun normal and split the output features over the data axis partitioned_dense = nn.Dense(features, kernel_init=with_partitioning(nn.initializers.lecun_normal, (None, "data"))) ``` Initializing a model that creates partitioned weights would result in the following variable structure: ```python variables = partitioned_dense.init(rng, jnp.ones((4,))) jax.tree.map(np.shape, variables) # => {"params": {"kernel": Partitioned(value=(4, 8), names=(None, "data")), bias: (8,)}} ``` The variable tree with metadata can be used to integrate with other libraries and APIs. For example, we can turn the ``Partitioned`` metadata into ``jax.pjit`` sharding annotations: ```python def to_sharding_spec(x): if isinstance(x, Partitioned): return PartitionSpec(*x.names) else: # fully replicated return PartitionSpec() # Result: {"params": {"kernel": PartitionSpec(None, "data"), bias: PartitionSpec()}} variables_pspec = jax.tree.map(to_sharding_spec, variables, is_leaf=lambda x: isinstance(x, Partitioned)) ``` ### Unbox syntax Metadata typically doesn't need to be handled by Modules directly. Therefore, we propose to make Modules agnostic to Metadata boxes by default. The ``unbox`` method can be used to unpack a variable such that only the original JAX arrays remain. Users can manually call unbox but to make sure Module classes don't have to call it everywhere we add an unbox keyword arg to variable returning APIs (e.g.: ``.param``, ``.variable``, ``.get_variable``). The keyword arg ``unbox`` will default to ``True`` such that a Modules are metadata agnostic by default. This also means existing Modules will be backward compatible with the new API. ```python kernel = self.param("kernel", self.kernel_init, shape) # No AxisMetadata instances kernel_box = self.get_variable("param", "kernel", unbox=False) # AxisMetadata boxes are preserved ``` ### Lift syntax When calling a lifted transformation that adds an axis you will now be able to pass a dictionary with arguments. These params will be passed to ``AxisMetadata`` add_axis/remove_axis callbacks: ```python nn.scan(..., variable_axes={"params": 0}, metadata_params={nn.Partitioned.AXIS_NAME: "layers"}) ``` A dict is used such that users can add their own arguments to custom AxisMetadata classes. ================================================ FILE: docs/flip/2974-kw-only-dataclasses.md ================================================ # FLIP: kw_only dataclasses Authors: Brennan Saeta, Ivy Zheng - Start Date: Mar 23, 2023 - FLIP Issue: [TBD] - FLIP PR: #2974 - Status: Implementing ## Summary Python 3.10 adds support for `kw_only` dataclasses. Subclasses of `flax.linen.Module` are automatically converted to `dataclasses` on users' behalf, but today, Flax doesn't allow setting the `kw_only` parameter to this dataclass transform, even if users are running Python 3.10. This proposal allows users to use this new feature with `nn.Module`'s. ## Motivation In larger Flax-based codebases (e.g. [`PaxML`](https://github.com/google/paxml) / [`Praxis`](https://github.com/google/praxis)), it’s not uncommon to define an (abstract) subclass of nn.Module that contains shared functionality that is itself further subclassed for specific implementations (e.g. [`BaseLayer`](https://github.com/google/praxis/blob/main/praxis/base_layer.py), or [`StackedTransformerRepeat`](https://github.com/google/praxis/blob/81479b260fcc13de8549cdbfb0fdf5c3f188ac90/praxis/layers/transformers.py#L1836) which is further subclassed by [`PipelineCompatibleStackedTransformerRepeat`](https://github.com/google/praxis/blob/81479b260fcc13de8549cdbfb0fdf5c3f188ac90/praxis/layers/transformers.py#L2198)). Often, these parent types define hyperparameters (constructor arguments), often with default values. Without `kw_only` on the `dataclass` transform, default values must be specified for all child layers hyperparameters. This is suboptimal, because users could forget to set them when instantiating the modules. For example, `Child` must set a default value for `num_heads` (because a non-defaulted argument can’t come after a defaulted argument if they are positional), but no reasonable default is available: ```python class BaseLayer(nn.Module): mesh: Optional[jax.experimental.mesh.Mesh] = None def with_sharding(self, some_variable, some_sharding): if self.mesh: # Do something useful here. class Child(BaseLayer): num_heads: int # Don't want to have to set a default argument! def __call__(self, x): ... ``` Note: Flax already has this problem, which is why `nn.Module` has its own fancy `kw_only_dataclasses.dataclass` transform: it moves the `name` and `parent` dataclass fields to the end, so they can have defaults. ## Implementation To allow modules to optionally opt into this `kw_only` dataclass behavior, we leverage arguments to `__init_subclass__`. This would look as follows: ```python class BaseLayer(nn.Module, kw_only=True): ... class Child(BaseLayer): ... ``` The implementation of `nn.Module`’s `__init_subclass__` will be tweaked as follows: ```python class Module(ModuleBase): def __init_subclass__(self, kw_only: Optional[bool] = None): # ... if kw_only: if is_python_310_or_above(): dataclass_transform_args = {'kw_only': True} else: raise TypeError("Can't use `kw_only` before Py3.10.") else: dataclass_transform_args = {} kw_only_dataclasses.dataclass( cls, unsafe_hash='__hash__' not in cls.__dict__, repr=False, **dataclass_transform_args) ``` ### Forward compatibility For future simplification, if `kw_only` is requested and the Python version is 3.10 or above, bypass the `kw_only_dataclasses` implementation and just use the regular `dataclasses` transform. That means we may one day remove `flax/linen/kw_only_dataclasses.py` when Flax rolls over 3.10. ## Discussion ### Aligned with Python `dataclass` We prefer to keep the behavior of `nn.Module`’s `kw_only` aligned with the Python dataclasses. Note that this means `kw_only` will not be inheritable, and this could happen: ```python class BaseLayer(nn.Module, kw_only=True): base_muliplier: Optional[int] = -1 class ChildLayer(BaseLayer): child_multiplier: int BaseLayer(2) # This will throw error ChildLayer(2) # But this will not ``` ### `flax.struct.dataclass` There’s a potentially related feature to allow `kw_only` to be specified for `flax.struct.dataclass`. This should be considered an orthogonal decision. ================================================ FILE: docs/flip/3099-rnnbase-refactor.md ================================================ # Refactor RNNCellBase in FLIP Authors: Cristian Garcia, Marcus Chiam, Jasmijn Bastings - Start Date: May 1, 2023 - FLIP Issue: [TBD] - FLIP PR: #3053 - Status: Implemented ## Summary This proposal aims to improve the usability of the `RNNCellBase` class by refactoring the `initialize_carry` method and other relevant components. ## Motivation Currently, `initialize_carry` is used to both initialize the carry and pass crucial metadata like the number of features. The API can be unintuitive as it requires users to manually calculate things that could typically be inferred by the modules themselves, such as the shape of batch dimensions and the shape of feature dimensions. ### Example: ConvLSTM The current API can be unintuitive in cases like `ConvLSTM` where a the `size` parameter contains both the input image shape and output feature dimensions: ```python x = jnp.ones((2, 4, 4, 3)) # (batch, *image_shape, channels) # image shape: vvvvvvv carry = nn.ConvLSTMCell.initialize_carry(key1, (16,), (64, 64, 16)) # batch size: ^^ ^^ :output features lstm = nn.ConvLSTMCell(features=6, kernel_size=(3, 3)) (carry, y), initial_params = lstm.init_with_output(key2, carry, x) ``` This FLIP will propose some changes to `initialize_carry` such that the previous example can be simplified to: ```python x = jnp.ones((2, 4, 4, 3)) # (batch, *image_shape, channels) lstm = nn.ConvLSTMCell(features=6, kernel_size=(3, 3)) carry = lstm.initialize_carry(key1, input_shape=x.shape) (carry, y), initial_params = lstm.init_with_output(key2, carry, x) ``` ## Implementation The proposal suggests the following changes: ### initialize_carry `initialize_carry` should be refactored as an instance method with the following signature: ```python def initialize_carry(self, key, sample_input): ``` `sample_input` should be an array of the same shape that will be processed by the cell, excluding the time axis. ### Refactor RNNCellBase subclasses `RNNCellBase` should be refactored to include the metadata required to initialize the cell and execute its forward pass. For `LSTMCell` and `GRUCell`, this means adding a `features` attribute that should be provided by the user upon construction. This change aligns with the structure of most other `Module`s, making them more familiar to users. ```python x = jnp.ones((2, 100, 10)) # (batch, time, features) cell = nn.LSTMCell(features=32) carry = cell.initialize_carry(PRNGKey(0), x[:, 0]) # sample input (carry, y), variables = cell.init_with_output(PRNGKey(1), carry, x) ``` ### num_feature_dims To simplify the handling of `RNNCellBase` instances in abstractions like `RNN`, each cell should implement the `num_feature_dims` property. For most cells, such as `LSTMCell` and `GRUCell`, this is always 1. For cells like `ConvLSTM`, this depends on their `kernel_size`. ## Discussion ### Alternative Approaches * To eliminate the need for `num_feature_dims`, `RNN` could support only a single batch dimension, i.e., inputs of the form `(batch, time, *features)`. Currently, it supports both multiple batch dimensions and multiple feature dimensions. * Another approach could be a complete redesign of how Flax deals with recurrent states. For example, a `memory` collection could be handled as part of the variables. However, this introduces challenges such as handling stateless cells during training, passing state from one layer to another, and performing initialization inside `scan`. ### Refactor Cost Initial TGP results showed 761 broken and 110 failed tests. However, after fixing one test, TGP results in 231 broken and 13 failed tests so there seems to be a lot of overlap between the broken tests. To minimize refactor costs, the current implementation will be kept for Google internal users under a deprecated name. This will allow users to migrate to the new API at their own pace. For Open Source users we should bump Flax version to `0.7.0` so existing users can continue to depend on `0.6.x` versions. ================================================ FILE: docs/flip/4105-jax-style-nnx-transforms.md ================================================ # JAX-style NNX Transforms - Authors: Cristian Garcia, Anselm Levskaya - Date: Jun/2024 - FLIP PR: #4107 - Status: Implementing ## Motivation NNX allows users to utilize Modules at the top level due to their eager initialization and self-contained state. This naturally leads users to want to use them with transforms and soon start playing with NNX transforms. Since NNX Modules resemble PyTrees in that they contain Arrays, new users often attempt to apply JAX conventions, for example: ```py @nnx.vmap(in_axes=(1, 0)) def f(m1: Module, m2: Module): ... ``` However, this can be misleading. Currently, NNX transforms follow Linen's convention of treating input Modules as a single unit (all Modules are split together to preserve shared references) and provide APIs for transforming that State separately. The previous example effectively translates to: ```py # this is what is really happening @nnx.vmap(in_axes=(IGNORE, IGNORE), state_axes={BatchStat: None, ...: 0}) def f(m1: Module, m2: Module): ... ``` Note that `IGNORE` is not a real symbol, but represents the fact that any value placed here won't affect the outcome, as Modules are replaced by empty PyTree placeholders (similar to `None`). The `state_axes` parameter controls how the State is vectorized through a mapping of high-level `Filter`s to their desired axes. In this example, `...` (ellipsis) is a filter that accepts everything, so by default all States are vectorized on the 0th axis. To express their original intention, users must resort to more complex custom filters that guess the index of each Module in the monolith. While this is straightforward in simple cases, users generally need to calculate the index (Modules appear in the order specified by `jax.tree.leaves` over the `args`): ```py select_m1 = lambda path, value: path[0] == 0 select_m2 = lambda path, value: path[0] == 1 # To select modules individually, you must create a filter (which can be tricky) @nnx.vmap(state_axes={select_m1: 1, select_m2: 0}) def f(m1: Module, m2: Module): ... ``` ## What if JAX conventions Just Worked™? This proposal aims to align NNX transforms with user's expectations based on their JAX experience, making the syntax work as intuitively as possible. The original example would function **as if** `m1` and `m2` were PyTrees vectorized in axes `1` and `0` respectively: ```py @nnx.vmap(in_axes=(1, 0)) def f(m1: Module, m2: Module): ... ``` The primary advantage of this approach is that for `vmap` and `scan`, we could eliminate the `state_axes` and `split_rngs` arguments, relying solely on the `in_axes` API. This syntax alone would likely suffice for 80-90% of use cases, as users tend to manage state in predictable ways. ### The Lift symbols To enable more fine-grained state control within each Module, we introduce the `Lift` API. By using special types containing State Filters in place of a tree prefix, state lifting can now be done **structurally**. This allows different Filters to be applied to different Modules in the arguments without the need for complex path-based filters. Ideally, each transform would support its own Lift type, adding the desired behavior through existing JAX APIs. For example, in `vmap`, we could allow `StateAxes` instances (vmap's Lift type) to be accepted by `in/out_axes` to control how substates are handled by mapping state `Filter`s to an axis specifier: ```py state_axes = StateAxes({Param: 1, BatchStat: None}) @nnx.vmap(in_axes=(state_axes, 0)) def f(m1: Module, m2: Module): ... ``` In this case, `m1`'s `Param`s are vectorized in axis `1` while its `BatchStat`s are broadcasted, and `m2`'s entire state is vectorized in axis `0`. For `nnx.grad`, we could allow `DiffState` to be used in the `argnums` parameter to specify both the position of the argument to be differentiated and a Filter specifying the differentiable State of the Module: ```py grads = nnx.grad(loss_fn, argnums=(DiffState(0, LoRAParam),))(model, x, y) ``` ## Rng Handling To simplify RNG state handling, we propose removing the separate `split_rngs` parameter in `vmap` and `scan`. Instead, we suggest introducing a new `nnx.split_rngs` API that would manage RNG handling before and after the transformation. This approach provides more explicit control to the user and aligns better with JAX transform behavior. ## Consistent Aliasing To ensure the correctness of transformations with objects that obey reference semantics, we must enforce consistent lifting/lowering specifications for all aliases of a reference. Transforms must adhere to two rules: 1. All aliases of a reference must receive the **exact same** lifting/lowering specification. 2. Captured references are not allowed on the output of transformed functions. For example: ```py @nnx.vmap(in_axes=(m1_axes, m2_axes, m1_axes), out_axes=m2_axes) def f(m1, m2, m1_alias): return m2 m2 = f(m1, m2, m1) ``` Here, `m1` has two input aliases as it is passed as the first and third input to `f`, but this is acceptable because `m1_axes` is assigned to both in `in_axes`. `m2` is passed as the second input and has an output alias, which is also acceptable because `m2_axes` is assigned in both `in_axes` and `out_axes`. Let's examine some examples of programs that should be **rejected** based on these criteria: ### Inconsistent input aliases Consider a function with two arguments `m1` and `m2` being vectorized in axis `0` and `1` respectively. Passing the same Module as both arguments would be inconsistent: ```py @nnx.vmap(in_axes=(0, 1)) def f(m1: Module, m2: Module): ... f(m, m) # This should be rejected ``` ### Inconsistent input / output aliases Now consider an identity function `g` under `vmap` with `in_axes=0` and `out_axes=1`. In JAX, this would result in transposing the arrays in the inputs: ```py @nnx.vmap(in_axes=0, out_axes=1) def g(m: Module): return m ``` While this appears correct, in NNX this behavior is not well-defined because shared mutable references behave as auxiliary outputs. Under the hood, `g` is converted into a function that has the inputs as an extra first output, and `out_axes` is set to the same values as `in_axes` for that output: ```py @nnx.vmap(in_axes=0, out_axes=(0, 1)) def g_real(m: Module): return m, m ``` This return structure reveals an inconsistency: we're attempting to lower `m` with both `out_axes=0` and `out_axes=1`. ### Inconsistent aliases in nested structures Similar issues can arise in less obvious cases, such as when `m` is contained within another structure: ```py @nnx.vmap(in_axes=0, out_axes=1) def f(m: Module): return SomeModule(m) ``` This means we must traverse the entire graph of both inputs and outputs to check for consistent assignments. The same problem occurs when passing shared reference inputs/outputs with different specifications: ```py shared = Shared() m1, m2 = Foo(shared), Foo(shared) @nnx.vmap(in_axes=(0, 1)) def f(m1, m2): # shared is passed through both ... ``` ### Captured Modules cannot be outputs Finally, let's consider the second consistent aliasing rule, which states that captured Modules cannot be outputs. The main issue here is that NNX needs to split all input references together to track changes, but captured Modules bypass this process. Treating them as new references would result in **implicit cloning**: ```py m = SomeModule() @nnx.vmap(out_axes=0, axis_size=5) def f(): return m assert m is not f() # implicit cloning ``` To preserve reference identity, we must disallow captured Modules as outputs. In practice, we can detect captured Modules using the trace level context machinery used to restrict stateful updates on Modules from a different level. ## Recap In this document, we have: * Discussed issues with the current implementation that make it unintuitive for JAX users. * Proposed refactoring NNX transforms to allow users to use regular JAX semantics when interacting with objects, removing extra arguments introduced by NNX transforms. * Introduced the use of Lift types in JAX APIs to compensate for the lack of a "prefix" notion in NNX objects, enabling independent lifting of Module substates. * Proposed a new `nnx.split_rngs` API to replace the `split_rngs` arguments in `vmap` and `scan`, making RNG handling an explicit operation and giving users more control. * Analyzed edge cases resulting from aliasing shared mutable references and proposed enforcing **consistent aliasing** on all transforms with semantics over the inputs. ================================================ FILE: docs/flip/README.md ================================================ # FLIP: Flax Improvement Process Most changes can be discussed with simple issues/discussions and pull requests. Some changes though are a bit larger in scope or require more discussion, and these should be implemented as FLIPs. This allows for writing longer documents that can be discussed in a pull request themselves. The structure of FLIPs is kept as lightweight as possible to start and might be extended later on. ## When you should use a FLIP - When your change requires a design doc. We prefer collecting the designs as FLIPs for better discoverability and further reference. - When your change requires extensive discussion. It's fine to have relatively short discussions on issues or pull requests, but when the discussion gets longer this becomes unpractical for later digestion. FLIPs allow to update the main document with a summary of the discussion and these updates can be discussed themselves in the pull request adding the FLIP. ## How to start a FLIP First, create an issue with the [FLIP label]. All pull requests that relate to the FLIP (i.e. adding the FLIP itself as well as any implementing pull requests) should be linked to this issue. Then create a pull request that consists of a copy of the `0000-template.md` renamed to `%04d-{short-title}.md` - with the number being the issue number. [FLIP label]: https://github.com/google/flax/issues?q=label%3AFLIP ================================================ FILE: docs/glossary.rst ================================================ ********* Glossary ********* For additional terms, refer to the `Jax glossary `__. .. glossary:: Bound Module When a :class:`Module ` is created through regular Python object construction (e.g. `module = SomeModule(args...)`, it is in an *unbound* state. This means that only dataclass attributes are set, and no variables are bound to the module. When the pure functions :meth:`Module.init() ` or :meth:`Module.apply() ` are called, Flax clones the Module and binds the variables to it, and the module's method code is executed in a locally bound state, allowing things like calling submodules directly without providing variables. For more details, refer to the `module lifecycle `__. Compact / Non-compact Module Modules with a single method are able to declare submodules and variables inline by using the :func:`@nn.compact ` decorator. These are referred to as “compact-style modules”, whereas modules defining a :meth:`setup() ` method (usually but not always with multiple callable methods) are referred to as “setup-style modules”. To learn more, refer to the `setup vs compact guide `__. `Folding in `__ Generating a new PRNG key given an input PRNG key and integer. Typically used when you want to generate a new key but still be able to use the original rng key afterwards. You can also do this with `jax.random.split `__ but this will effectively create two RNG keys, which is slower. See how Flax generates new PRNG keys automatically within ``Modules`` in our `RNG guide `__. `FrozenDict `__ An immutable dictionary which can be “`unfrozen `__” to a regular, mutable dictionary. Internally, Flax uses FrozenDicts to ensure variable dicts aren't accidentally mutated. Note: We are considering returning to regular dicts from our APIs, and only using FrozenDicts internally. (see `#1223 `__). Functional core The flax core library implements the simple container Scope API for threading variables and PRNGs through a model, as well as the lifting machinery needed to transform functions passing Scope objects. The python class-based module API is built on top of this core library. Lazy initialization Variables in Flax are initialized late, only when needed. That is, during normal execution of a module, if a requested variable name isn’t found in the provided variable collection data, we call the initializer function to create it. This allows us to treat initialization and application under the same code-paths, simplifying the use of JAX transforms with layers. Lifted transformation Refer to the `Flax docs `__. Module A dataclass allowing the definition and initialization of parameters in a referentially-transparent form. This is responsible for storing and updating variables and parameters within itself. Modules can be readily transformed into functions, allowing them to be trivially used with JAX transformations like `vmap` and `scan`. Params / parameters "params" is the canonical variable collection in the variable dictionary (dict). The “params” collection generally contains the trainable weights. RNG sequences Inside Flax :class:`Modules `, you can obtain a new `PRNG `__ key through :meth:`Module.make_rng() `. These keys can be used to generate random numbers through `JAX's functional random number generators `__. Having different RNG sequences (e.g. for "params" and "dropout") allows fine-grained control in a multi-host setup (e.g. initializing parameters identically on different hosts, but have different dropout masks) and treating these sequences differently when `lifting transformations `__. See the `RNG guide `__ for more details. Scope A container class for holding the variables and PRNG keys for each layer. Shape inference Modules do not need to specify the shape of the input array in their definitions. Flax upon initialization inspects the input array, and infers the correct shapes for parameters in the model. TrainState Refer to :class:`flax.training.train_state.TrainState`. Variable The `weights / parameters / data / arrays `__ residing in the leaves of :term:`variable collections`. Variables are defined inside modules using :meth:`Module.variable() `. A variable of collection "params" is simply called a param and can be set using :meth:`Module.param() `. Variable collections Entries in the variable dict, containing weights / parameters / data / arrays that are used by the model. “params” is the canonical collection in the variable dict. They are typically differentiable, updated by an outer SGD-like loop / optimizer, rather than modified directly by forward-pass code. `Variable dictionary `__ A dictionary containing :term:`variable collections`. Each variable collection is a mapping from a string name (e.g., ":term:`params`" or "batch_stats") to a (possibly nested) dictionary with :term:`Variables` as leaves, matching the submodule tree structure. Read more about pytrees and leaves in the `Jax docs `__. ================================================ FILE: docs/guides/converting_and_upgrading/convert_pytorch_to_flax.rst ================================================ Convert PyTorch models to Flax ============================== .. testsetup:: import numpy as np import jax from jax import random, numpy as jnp import flax from flax import linen as nn import torch We will show how to convert PyTorch models to Flax. We will cover convolutions, fc layers, batch norm, and average pooling. FC Layers -------------------------------- Let's start with fc layers. The only thing to be aware of here is that the PyTorch kernel has shape [outC, inC] and the Flax kernel has shape [inC, outC]. Transposing the kernel will do the trick. .. testcode:: t_fc = torch.nn.Linear(in_features=3, out_features=4) kernel = t_fc.weight.detach().cpu().numpy() bias = t_fc.bias.detach().cpu().numpy() # [outC, inC] -> [inC, outC] kernel = jnp.transpose(kernel, (1, 0)) key = random.key(0) x = random.normal(key, (1, 3)) variables = {'params': {'kernel': kernel, 'bias': bias}} j_fc = nn.Dense(features=4) j_out = j_fc.apply(variables, x) t_x = torch.from_numpy(np.array(x)) t_out = t_fc(t_x) t_out = t_out.detach().cpu().numpy() np.testing.assert_almost_equal(j_out, t_out, decimal=6) Convolutions -------------------------------- Let's now look at 2D convolutions. PyTorch uses the NCHW format and Flax uses NHWC. Consequently, the kernels will have different shapes. The kernel in PyTorch has shape [outC, inC, kH, kW] and the Flax kernel has shape [kH, kW, inC, outC]. Transposing the kernel will do the trick. .. testcode:: t_conv = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding='valid') kernel = t_conv.weight.detach().cpu().numpy() bias = t_conv.bias.detach().cpu().numpy() # [outC, inC, kH, kW] -> [kH, kW, inC, outC] kernel = jnp.transpose(kernel, (2, 3, 1, 0)) key = random.key(0) x = random.normal(key, (1, 6, 6, 3)) variables = {'params': {'kernel': kernel, 'bias': bias}} j_conv = nn.Conv(features=4, kernel_size=(2, 2), padding='valid') j_out = j_conv.apply(variables, x) # [N, H, W, C] -> [N, C, H, W] t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2))) t_out = t_conv(t_x) # [N, C, H, W] -> [N, H, W, C] t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1)) np.testing.assert_almost_equal(j_out, t_out, decimal=6) Convolutions and FC Layers -------------------------------- We have to be careful, when we have a model that uses convolutions followed by fc layers (ResNet, VGG, etc). In PyTorch, the activations will have shape [N, C, H, W] after the convolutions and are then reshaped to [N, C * H * W] before being fed to the fc layers. When we port our weights from PyTorch to Flax, the activations after the convolutions will be of shape [N, H, W, C] in Flax. Before we reshape the activations for the fc layers, we have to transpose them to [N, C, H, W]. Consider this PyTorch model: .. testcode:: class TModel(torch.nn.Module): def __init__(self): super(TModel, self).__init__() self.conv = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding='valid') self.fc = torch.nn.Linear(in_features=100, out_features=2) def forward(self, x): x = self.conv(x) x = x.reshape(x.shape[0], -1) x = self.fc(x) return x t_model = TModel() Now, if you want to use the weights from this model in Flax, the corresponding Flax model has to look like this: .. testcode:: class JModel(nn.Module): @nn.compact def __call__(self, x): x = nn.Conv(features=4, kernel_size=(2, 2), padding='valid', name='conv')(x) # [N, H, W, C] -> [N, C, H, W] x = jnp.transpose(x, (0, 3, 1, 2)) x = jnp.reshape(x, (x.shape[0], -1)) x = nn.Dense(features=2, name='fc')(x) return x j_model = JModel() The model looks very similar to the PyTorch model, except that we included a transpose operation before reshaping our activations for the fc layer. We can omit the transpose operation if we apply pooling before reshaping such that the spatial dimensions are 1x1. Other than the transpose operation before reshaping, we can convert the weights the same way as we did before: .. testcode:: conv_kernel = t_model.state_dict()['conv.weight'].detach().cpu().numpy() conv_bias = t_model.state_dict()['conv.bias'].detach().cpu().numpy() fc_kernel = t_model.state_dict()['fc.weight'].detach().cpu().numpy() fc_bias = t_model.state_dict()['fc.bias'].detach().cpu().numpy() # [outC, inC, kH, kW] -> [kH, kW, inC, outC] conv_kernel = jnp.transpose(conv_kernel, (2, 3, 1, 0)) # [outC, inC] -> [inC, outC] fc_kernel = jnp.transpose(fc_kernel, (1, 0)) variables = {'params': {'conv': {'kernel': conv_kernel, 'bias': conv_bias}, 'fc': {'kernel': fc_kernel, 'bias': fc_bias}}} key = random.key(0) x = random.normal(key, (1, 6, 6, 3)) j_out = j_model.apply(variables, x) # [N, H, W, C] -> [N, C, H, W] t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2))) t_out = t_model(t_x) t_out = t_out.detach().cpu().numpy() np.testing.assert_almost_equal(j_out, t_out, decimal=6) Batch Norm -------------------------------- ``torch.nn.BatchNorm2d`` uses ``0.1`` as the default value for the ``momentum`` parameter while |nn.BatchNorm|_ uses ``0.9``. However, this corresponds to the same computation, because PyTorch multiplies the estimated statistic with ``(1 − momentum)`` and the new observed value with ``momentum``, while Flax multiplies the estimated statistic with ``momentum`` and the new observed value with ``(1 − momentum)``. .. |nn.BatchNorm| replace:: ``nn.BatchNorm`` .. _nn.BatchNorm: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.BatchNorm .. testcode:: t_bn = torch.nn.BatchNorm2d(num_features=3, momentum=0.1) t_bn.eval() scale = t_bn.weight.detach().cpu().numpy() bias = t_bn.bias.detach().cpu().numpy() mean = t_bn.running_mean.detach().cpu().numpy() var = t_bn.running_var.detach().cpu().numpy() variables = {'params': {'scale': scale, 'bias': bias}, 'batch_stats': {'mean': mean, 'var': var}} key = random.key(0) x = random.normal(key, (1, 6, 6, 3)) j_bn = nn.BatchNorm(momentum=0.9, use_running_average=True) j_out = j_bn.apply(variables, x) # [N, H, W, C] -> [N, C, H, W] t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2))) t_out = t_bn(t_x) # [N, C, H, W] -> [N, H, W, C] t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1)) np.testing.assert_almost_equal(j_out, t_out, decimal=6) Average Pooling -------------------------------- ``torch.nn.AvgPool2d`` and |nn.avg_pool()|_ are compatible when using default parameters. However, ``torch.nn.AvgPool2d`` has a parameter ``count_include_pad``. When ``count_include_pad=False``, the zero-padding will not be considered for the average calculation. There does not exist a similar parameter for |nn.avg_pool()|_. However, we can easily implement a wrapper around the pooling operation. ``nn.pool()`` is the core function behind |nn.avg_pool()|_ and |nn.max_pool()|_. .. |nn.avg_pool()| replace:: ``nn.avg_pool()`` .. _nn.avg_pool(): https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.avg_pool .. |nn.max_pool()| replace:: ``nn.max_pool()`` .. _nn.max_pool(): https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.max_pool .. testcode:: def avg_pool(inputs, window_shape, strides=None, padding='VALID'): """ Pools the input by taking the average over a window. In comparison to nn.avg_pool(), this pooling operation does not consider the padded zero's for the average computation. """ assert len(window_shape) == 2 y = nn.pool(inputs, 0., jax.lax.add, window_shape, strides, padding) counts = nn.pool(jnp.ones_like(inputs), 0., jax.lax.add, window_shape, strides, padding) y = y / counts return y key = random.key(0) x = random.normal(key, (1, 6, 6, 3)) j_out = avg_pool(x, window_shape=(2, 2), strides=(1, 1), padding=((1, 1), (1, 1))) t_pool = torch.nn.AvgPool2d(kernel_size=2, stride=1, padding=1, count_include_pad=False) # [N, H, W, C] -> [N, C, H, W] t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2))) t_out = t_pool(t_x) # [N, C, H, W] -> [N, H, W, C] t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1)) np.testing.assert_almost_equal(j_out, t_out, decimal=6) Transposed Convolutions -------------------------------- ``torch.nn.ConvTranspose2d`` and |nn.ConvTranspose|_ are not compatible. |nn.ConvTranspose|_ is a wrapper around |jax.lax.conv_transpose|_ which computes a fractionally strided convolution, while ``torch.nn.ConvTranspose2d`` computes a gradient based transposed convolution. Currently, there is no implementation of a gradient based transposed convolution is ``Jax``. However, there is a pending `pull request`_ that contains an implementation. To load ``torch.nn.ConvTranspose2d`` parameters into Flax, we need to use the ``transpose_kernel`` arg in Flax's ``nn.ConvTranspose`` layer. .. testcode:: # padding is inverted torch_padding = 0 flax_padding = 1 - torch_padding t_conv = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=torch_padding) kernel = t_conv.weight.detach().cpu().numpy() bias = t_conv.bias.detach().cpu().numpy() # [inC, outC, kH, kW] -> [kH, kW, outC, inC] kernel = jnp.transpose(kernel, (2, 3, 1, 0)) key = random.key(0) x = random.normal(key, (1, 6, 6, 3)) variables = {'params': {'kernel': kernel, 'bias': bias}} # ConvTranspose expects the kernel to be [kH, kW, inC, outC], # but with `transpose_kernel=True`, it expects [kH, kW, outC, inC] instead j_conv = nn.ConvTranspose(features=4, kernel_size=(2, 2), padding=flax_padding, transpose_kernel=True) j_out = j_conv.apply(variables, x) # [N, H, W, C] -> [N, C, H, W] t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2))) t_out = t_conv(t_x) # [N, C, H, W] -> [N, H, W, C] t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1)) np.testing.assert_almost_equal(j_out, t_out, decimal=6) .. _`pull request`: https://github.com/jax-ml/jax/pull/5772 .. |nn.ConvTranspose| replace:: ``nn.ConvTranspose`` .. _nn.ConvTranspose: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.ConvTranspose .. |jax.lax.conv_transpose| replace:: ``jax.lax.conv_transpose`` .. _jax.lax.conv_transpose: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_transpose.html ================================================ FILE: docs/guides/converting_and_upgrading/haiku_migration_guide.rst ================================================ Migrating from Haiku to Flax ============================ This guide will walk through the process of migrating Haiku models to Flax, and highlight the differences between the two libraries. .. testsetup:: Haiku, Flax import jax import jax.numpy as jnp from jax import random import optax import flax.linen as nn import haiku as hk Basic Example ----------------- To create custom Modules you subclass from a ``Module`` base class in both Haiku and Flax. However, Haiku classes use a regular ``__init__`` method whereas Flax classes are ``dataclasses``, meaning you define some class attributes that are used to automatically generate a constructor. Also, all Flax Modules accept a ``name`` argument without needing to define it, whereas in Haiku ``name`` must be explicitly defined in the constructor signature and passed to the superclass constructor. .. codediff:: :title: Haiku, Flax :sync: import haiku as hk class Block(hk.Module): def __init__(self, features: int, name=None): super().__init__(name=name) self.features = features def __call__(self, x, training: bool): x = hk.Linear(self.features)(x) x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x) x = jax.nn.relu(x) return x class Model(hk.Module): def __init__(self, dmid: int, dout: int, name=None): super().__init__(name=name) self.dmid = dmid self.dout = dout def __call__(self, x, training: bool): x = Block(self.dmid)(x, training) x = hk.Linear(self.dout)(x) return x --- import flax.linen as nn class Block(nn.Module): features: int @nn.compact def __call__(self, x, training: bool): x = nn.Dense(self.features)(x) x = nn.Dropout(0.5, deterministic=not training)(x) x = jax.nn.relu(x) return x class Model(nn.Module): dmid: int dout: int @nn.compact def __call__(self, x, training: bool): x = Block(self.dmid)(x, training) x = nn.Dense(self.dout)(x) return x The ``__call__`` method looks very similar in both libraries, however, in Flax you have to use the ``@nn.compact`` decorator in order to be able to define submodules inline. In Haiku, this is the default behavior. Now, a place where Haiku and Flax differ substantially is in how you construct the model. In Haiku, you use ``hk.transform`` over a function that calls your Module, ``transform`` will return an object with ``init`` and ``apply`` methods. In Flax, you simply instantiate your Module. .. codediff:: :title: Haiku, Flax :sync: def forward(x, training: bool): return Model(256, 10)(x, training) model = hk.transform(forward) --- ... model = Model(256, 10) To get the model parameters in both libraries you use the ``init`` method with a ``random.key`` plus some inputs to run the model. The main difference here is that Flax returns a mapping from collection names to nested array dictionaries, ``params`` is just one of these possible collections. In Haiku, you get the ``params`` structure directly. .. codediff:: :title: Haiku, Flax :sync: sample_x = jax.numpy.ones((1, 784)) params = model.init( random.key(0), sample_x, training=False # <== inputs ) ... --- sample_x = jax.numpy.ones((1, 784)) variables = model.init( random.key(0), sample_x, training=False # <== inputs ) params = variables["params"] One very important thing to note is that in Flax the parameters structure is hierarchical, with one level per nested module and a final level for the parameter name. In Haiku the parameters structure is a python dictionary with a two level hierarchy: the fully qualified module name mapping to the parameter name. The module name consists of a ``/`` separated string path of all the nested Modules. .. tab-set:: .. tab-item:: Haiku :sync: Haiku .. code-block:: python ... { 'model/block/linear': { 'b': (256,), 'w': (784, 256), }, 'model/linear': { 'b': (10,), 'w': (256, 10), } } ... .. tab-item:: Flax :sync: Flax .. code-block:: python FrozenDict({ Block_0: { Dense_0: { bias: (256,), kernel: (784, 256), }, }, Dense_0: { bias: (10,), kernel: (256, 10), }, }) During training in both frameworks you pass the parameters structure to the ``apply`` method to run the forward pass. Since we are using dropout, in both cases we must provide a ``key`` to ``apply`` in order to generate the random dropout masks. .. codediff:: :title: Haiku, Flax :sync: def train_step(key, params, inputs, labels): def loss_fn(params): logits = model.apply( params, key, inputs, training=True # <== inputs ) return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() grads = jax.grad(loss_fn)(params) params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads) return params --- def train_step(key, params, inputs, labels): def loss_fn(params): logits = model.apply( {'params': params}, inputs, training=True, # <== inputs rngs={'dropout': key} ) return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() grads = jax.grad(loss_fn)(params) params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads) return params .. testcode:: Haiku, Flax :hide: train_step(random.key(0), params, sample_x, jnp.ones((1,), dtype=jnp.int32)) The most notable differences is that in Flax you have to pass the parameters inside a dictionary with a ``params`` key, and the key inside a dictionary with a ``dropout`` key. This is because in Flax you can have many types of model state and random state. In Haiku, you just pass the parameters and the key directly. Handling State ----------------- Now let's see how mutable state is handled in both libraries. We will take the same model as before, but now we will replace Dropout with BatchNorm. .. codediff:: :title: Haiku, Flax :sync: class Block(hk.Module): def __init__(self, features: int, name=None): super().__init__(name=name) self.features = features def __call__(self, x, training: bool): x = hk.Linear(self.features)(x) x = hk.BatchNorm( create_scale=True, create_offset=True, decay_rate=0.99 )(x, is_training=training) x = jax.nn.relu(x) return x --- class Block(nn.Module): features: int @nn.compact def __call__(self, x, training: bool): x = nn.Dense(self.features)(x) x = nn.BatchNorm( momentum=0.99 )(x, use_running_average=not training) x = jax.nn.relu(x) return x The code is very similar in this case as both libraries provide a BatchNorm layer. The most notable difference is that Haiku uses ``is_training`` to control whether or not to update the running statistics, whereas Flax uses ``use_running_average`` for the same purpose. To instantiate a stateful model in Haiku you use ``hk.transform_with_state``, which changes the signature for ``init`` and ``apply`` to accept and return state. As before, in Flax you construct the Module directly. .. codediff:: :title: Haiku, Flax :sync: def forward(x, training: bool): return Model(256, 10)(x, training) model = hk.transform_with_state(forward) --- ... model = Model(256, 10) To initialize both the parameters and state you just call the ``init`` method as before. However, in Haiku you now get ``state`` as a second return value, and in Flax you get a new ``batch_stats`` collection in the ``variables`` dictionary. Note that since ``hk.BatchNorm`` only initializes batch statistics when ``is_training=True``, we must set ``training=True`` when initializing parameters of a Haiku model with an ``hk.BatchNorm`` layer. In Flax, we can set ``training=False`` as usual. .. codediff:: :title: Haiku, Flax :sync: sample_x = jax.numpy.ones((1, 784)) params, state = model.init( random.key(0), sample_x, training=True # <== inputs #! ) ... --- sample_x = jax.numpy.ones((1, 784)) variables = model.init( random.key(0), #! sample_x, training=False # <== inputs ) params, batch_stats = variables["params"], variables["batch_stats"] In general, in Flax you might find other state collections in the ``variables`` dictionary such as ``cache`` for auto-regressive transformers models, ``intermediates`` for intermediate values added using ``Module.sow``, or other collection names defined by custom layers. Haiku only makes a distinction between ``params`` (variables which do not change while running ``apply``) and ``state`` (variables which can change while running ``apply``). Now, training looks very similar in both frameworks as you use the same ``apply`` method to run the forward pass. In Haiku, now pass the ``state`` as the second argument to ``apply``, and get the new state as the second return value. In Flax, you instead add ``batch_stats`` as a new key to the input dictionary, and get the ``updates`` variables dictionary as the second return value. .. codediff:: :title: Haiku, Flax :sync: def train_step(params, state, inputs, labels): def loss_fn(params): logits, new_state = model.apply( params, state, None, # <== rng inputs, training=True # <== inputs ) loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() return loss, new_state grads, new_state = jax.grad(loss_fn, has_aux=True)(params) params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads) return params, new_state --- def train_step(params, batch_stats, inputs, labels): def loss_fn(params): logits, updates = model.apply( {'params': params, 'batch_stats': batch_stats}, inputs, training=True, # <== inputs mutable='batch_stats', ) loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() return loss, updates["batch_stats"] grads, batch_stats = jax.grad(loss_fn, has_aux=True)(params) params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads) return params, batch_stats .. testcode:: Flax :hide: train_step(params, batch_stats, sample_x, jnp.ones((1,), dtype=jnp.int32)) One major difference is that in Flax a state collection can be mutable or immutable. During ``init`` all collections are mutable by default, however, during ``apply`` you have to explicitly specify which collections are mutable. In this example, we specify that ``batch_stats`` is mutable. Here a single string is passed but a list can also be given if there are more mutable collections. If this is not done an error will be raised at runtime when trying to mutate ``batch_stats``. Also, when ``mutable`` is anything other than ``False``, the ``updates`` dictionary is returned as the second return value of ``apply``, else only the model output is returned. Haiku makes the mutable/immutable distinction through having ``params`` (immutable) and ``state`` (mutable) and using either ``hk.transform`` or ``hk.transform_with_state`` Using Multiple Methods ----------------------- In this section we will take a look at how to use multiple methods in Haiku and Flax. As an example, we will implement an auto-encoder model with three methods: ``encode``, ``decode``, and ``__call__``. In Haiku, we can just define the submodules that ``encode`` and ``decode`` need directly in ``__init__``, in this case each will just use a ``Linear`` layer. In Flax, we will define an ``encoder`` and a ``decoder`` Module ahead of time in ``setup``, and use them in the ``encode`` and ``decode`` respectively. .. codediff:: :title: Haiku, Flax :sync: class AutoEncoder(hk.Module): def __init__(self, embed_dim: int, output_dim: int, name=None): super().__init__(name=name) self.encoder = hk.Linear(embed_dim, name="encoder") self.decoder = hk.Linear(output_dim, name="decoder") def encode(self, x): return self.encoder(x) def decode(self, x): return self.decoder(x) def __call__(self, x): x = self.encode(x) x = self.decode(x) return x --- class AutoEncoder(nn.Module): embed_dim: int output_dim: int def setup(self): self.encoder = nn.Dense(self.embed_dim) self.decoder = nn.Dense(self.output_dim) def encode(self, x): return self.encoder(x) def decode(self, x): return self.decoder(x) def __call__(self, x): x = self.encode(x) x = self.decode(x) return x Note that in Flax ``setup`` doesn't run after ``__init__``, instead it runs when ``init`` or ``apply`` are called. Now, we want to be able to call any method from our ``AutoEncoder`` model. In Haiku we can define multiple ``apply`` methods for a module through ``hk.multi_transform``. The function passed to ``multi_transform`` defines how to initialize the module and which different apply methods to generate. .. codediff:: :title: Haiku, Flax :sync: def forward(): module = AutoEncoder(256, 784) init = lambda x: module(x) return init, (module.encode, module.decode) model = hk.multi_transform(forward) --- ... model = AutoEncoder(256, 784) To initialize the parameters of our model, ``init`` can be used to trigger the ``__call__`` method, which uses both the ``encode`` and ``decode`` method. This will create all the necessary parameters for the model. .. codediff:: :title: Haiku, Flax :sync: params = model.init( random.key(0), x=jax.numpy.ones((1, 784)), ) ... --- variables = model.init( random.key(0), x=jax.numpy.ones((1, 784)), ) params = variables["params"] This generates the following parameter structure. .. tab-set:: .. tab-item:: Haiku :sync: Haiku .. code-block:: python { 'auto_encoder/~/decoder': { 'b': (784,), 'w': (256, 784) }, 'auto_encoder/~/encoder': { 'b': (256,), 'w': (784, 256) } } .. tab-item:: Flax :sync: Flax .. code-block:: python FrozenDict({ decoder: { bias: (784,), kernel: (256, 784), }, encoder: { bias: (256,), kernel: (784, 256), }, }) Finally, let's explore how we can employ the ``apply`` function to invoke the ``encode`` method: .. codediff:: :title: Haiku, Flax :sync: encode, decode = model.apply z = encode( params, None, # <== rng x=jax.numpy.ones((1, 784)), ) --- ... z = model.apply( {"params": params}, x=jax.numpy.ones((1, 784)), method="encode", ) Because the Haiku ``apply`` function is generated through ``hk.multi_transform``, it's a tuple of two functions which we can unpack into an ``encode`` and ``decode`` function which correspond to the methods on the ``AutoEncoder`` module. In Flax we call the ``encode`` method through passing the method name as a string. Another noteworthy distinction here is that in Haiku, ``rng`` needs to be explicitly passed, even though the module does not use any stochastic operations during ``apply``. In Flax this is not necessary (check out `Randomness and PRNGs in Flax `_). The Haiku ``rng`` is set to ``None`` here, but you could also use ``hk.without_apply_rng`` on the ``apply`` function to remove the ``rng`` argument. Lifted Transforms ----------------- Both Flax and Haiku provide a set of transforms, which we will refer to as lifted transforms, that wrap JAX transformations in such a way that they can be used with Modules and sometimes provide additional functionality. In this section we will take a look at how to use the lifted version of ``scan`` in both Flax and Haiku to implement a simple RNN layer. To begin, we will first define a ``RNNCell`` module that will contain the logic for a single step of the RNN. We will also define a ``initial_state`` method that will be used to initialize the state (a.k.a. ``carry``) of the RNN. Like with ``jax.lax.scan``, the ``RNNCell.__call__`` method will be a function that takes the carry and input, and returns the new carry and output. In this case, the carry and the output are the same. .. codediff:: :title: Haiku, Flax :sync: class RNNCell(hk.Module): def __init__(self, hidden_size: int, name=None): super().__init__(name=name) self.hidden_size = hidden_size def __call__(self, carry, x): x = jnp.concatenate([carry, x], axis=-1) x = hk.Linear(self.hidden_size)(x) x = jax.nn.relu(x) return x, x def initial_state(self, batch_size: int): return jnp.zeros((batch_size, self.hidden_size)) --- class RNNCell(nn.Module): hidden_size: int @nn.compact def __call__(self, carry, x): x = jnp.concatenate([carry, x], axis=-1) x = nn.Dense(self.hidden_size)(x) x = jax.nn.relu(x) return x, x def initial_state(self, batch_size: int): return jnp.zeros((batch_size, self.hidden_size)) Next, we will define a ``RNN`` Module that will contain the logic for the entire RNN. In Haiku, we will first initialze the ``RNNCell``, then use it to construct the ``carry``, and finally use ``hk.scan`` to run the ``RNNCell`` over the input sequence. In Flax its done a bit differently, we will use ``nn.scan`` to define a new temporary type that wraps ``RNNCell``. During this process we will also specify instruct ``nn.scan`` to broadcast the ``params`` collection (all steps share the same parameters) and to not split the ``params`` rng stream (so all steps intialize with the same parameters), and finally we will specify that we want scan to run over the second axis of the input and stack the outputs along the second axis as well. We will then use this temporary type immediately to create an instance of the lifted ``RNNCell`` and use it to create the ``carry`` and the run the ``__call__`` method which will ``scan`` over the sequence. .. codediff:: :title: Haiku, Flax :sync: class RNN(hk.Module): def __init__(self, hidden_size: int, name=None): super().__init__(name=name) self.hidden_size = hidden_size def __call__(self, x): cell = RNNCell(self.hidden_size) carry = cell.initial_state(x.shape[0]) carry, y = hk.scan(cell, carry, jnp.swapaxes(x, 1, 0)) y = jnp.swapaxes(y, 0, 1) return y --- class RNN(nn.Module): hidden_size: int @nn.compact def __call__(self, x): rnn = nn.scan(RNNCell, variable_broadcast='params', split_rngs={'params': False}, in_axes=1, out_axes=1)(self.hidden_size) carry = rnn.initial_state(x.shape[0]) carry, y = rnn(carry, x) return y In general, the main difference between lifted transforms between Flax and Haiku is that in Haiku the lifted transforms don't operate over the state, that is, Haiku will handle the ``params`` and ``state`` in such a way that it keeps the same shape inside and outside of the transform. In Flax, the lifted transforms can operate over both variable collections and rng streams, the user must define how different collections are treated by each transform according to the transform's semantics. Finally, let's quickly view how the ``RNN`` Module would be used in both Haiku and Flax. .. codediff:: :title: Haiku, Flax :sync: def forward(x): return RNN(64)(x) model = hk.without_apply_rng(hk.transform(forward)) params = model.init( random.key(0), x=jax.numpy.ones((3, 12, 32)), ) y = model.apply( params, x=jax.numpy.ones((3, 12, 32)), ) --- ... model = RNN(64) variables = model.init( random.key(0), x=jax.numpy.ones((3, 12, 32)), ) params = variables['params'] y = model.apply( {'params': params}, x=jax.numpy.ones((3, 12, 32)), ) The only notable change with respect to the examples in the previous sections is that this time around we used ``hk.without_apply_rng`` in Haiku so we didn't have to pass the ``rng`` argument as ``None`` to the ``apply`` method. Scan over layers ---------------- One very important application of ``scan`` is apply a sequence of layers iteratively over an input, passing the output of each layer as the input to the next layer. This is very useful to reduce compilation time for big models. As an example we will create a simple ``Block`` Module, and then use it inside an ``MLP`` Module that will apply the ``Block`` Module ``num_layers`` times. In Haiku, we define the ``Block`` Module as usual, and then inside ``MLP`` we will use ``hk.experimental.layer_stack`` over a ``stack_block`` function to create a stack of ``Block`` Modules. In Flax, the definition of ``Block`` is a little different, ``__call__`` will accept and return a second dummy input/output that in both cases will be ``None``. In ``MLP``, we will use ``nn.scan`` as in the previous example, but by setting ``split_rngs={'params': True}`` and ``variable_axes={'params': 0}`` we are telling ``nn.scan`` create different parameters for each step and slice the ``params`` collection along the first axis, effectively implementing a stack of ``Block`` Modules as in Haiku. .. codediff:: :title: Haiku, Flax :sync: class Block(hk.Module): def __init__(self, features: int, name=None): super().__init__(name=name) self.features = features def __call__(self, x, training: bool): x = hk.Linear(self.features)(x) x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x) x = jax.nn.relu(x) return x class MLP(hk.Module): def __init__(self, features: int, num_layers: int, name=None): super().__init__(name=name) self.features = features self.num_layers = num_layers def __call__(self, x, training: bool): @hk.experimental.layer_stack(self.num_layers) def stack_block(x): return Block(self.features)(x, training) stack = hk.experimental.layer_stack(self.num_layers) return stack_block(x) --- class Block(nn.Module): features: int training: bool @nn.compact def __call__(self, x, _): x = nn.Dense(self.features)(x) x = nn.Dropout(0.5)(x, deterministic=not self.training) x = jax.nn.relu(x) return x, None class MLP(nn.Module): features: int num_layers: int @nn.compact def __call__(self, x, training: bool): ScanBlock = nn.scan( Block, variable_axes={'params': 0}, split_rngs={'params': True}, length=self.num_layers) y, _ = ScanBlock(self.features, training)(x, None) return y Notice how in Flax we pass ``None`` as the second argument to ``ScanBlock`` and ignore its second output. These represent the inputs/outputs per-step but they are ``None`` because in this case we don't have any. Initializing each model is the same as in previous examples. In this case, we will be specifying that we want to use ``5`` layers each with ``64`` features. .. codediff:: :title: Haiku, Flax :sync: def forward(x, training: bool): return MLP(64, num_layers=5)(x, training) model = hk.transform(forward) sample_x = jax.numpy.ones((1, 64)) params = model.init( random.key(0), sample_x, training=False # <== inputs ) ... --- ... model = MLP(64, num_layers=5) sample_x = jax.numpy.ones((1, 64)) variables = model.init( random.key(0), sample_x, training=False # <== inputs ) params = variables['params'] When using scan over layers the one thing you should notice is that all layers are fused into a single layer whose parameters have an extra "layer" dimension on the first axis. In this case, the shape of all parameters will start with ``(5, ...)`` as we are using ``5`` layers. .. tab-set:: .. tab-item:: Haiku :sync: Haiku .. code-block:: python ... { 'mlp/__layer_stack_no_per_layer/block/linear': { 'b': (5, 64), 'w': (5, 64, 64) } } ... .. tab-item:: Flax :sync: Flax .. code-block:: python FrozenDict({ ScanBlock_0: { Dense_0: { bias: (5, 64), kernel: (5, 64, 64), }, }, }) Top-level Haiku functions vs top-level Flax modules ----------------------------------- In Haiku, it is possible to write the entire model as a single function by using the raw ``hk.{get,set}_{parameter,state}`` to define/access model parameters and states. It very common to write the top-level "Module" as a function instead: The Flax team recommends a more Module-centric approach that uses `__call__` to define the forward function. The corresponding accessor will be `nn.module.param` and `nn.module.variable` (go to `Handling State <#handling-state>`__ for an explanaion on collections). .. codediff:: :title: Haiku, Flax :sync: def forward(x): counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones) multiplier = hk.get_parameter('multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones) output = x + multiplier * counter hk.set_state("counter", counter + 1) return output model = hk.transform_with_state(forward) params, state = model.init(random.key(0), jax.numpy.ones((1, 64))) --- class FooModule(nn.Module): @nn.compact def __call__(self, x): counter = self.variable('counter', 'count', lambda: jnp.ones((), jnp.int32)) multiplier = self.param('multiplier', nn.initializers.ones_init(), [1,], x.dtype) output = x + multiplier * counter.value if not self.is_initializing(): # otherwise model.init() also increases it counter.value += 1 return output model = FooModule() variables = model.init(random.key(0), jax.numpy.ones((1, 64))) params, counter = variables['params'], variables['counter'] ================================================ FILE: docs/guides/converting_and_upgrading/index.rst ================================================ Converting and upgrading ======================== .. toctree:: :maxdepth: 1 haiku_migration_guide convert_pytorch_to_flax orbax_upgrade_guide optax_update_guide linen_upgrade_guide rnncell_upgrade_guide regular_dict_upgrade_guide ================================================ FILE: docs/guides/converting_and_upgrading/linen_upgrade_guide.rst ================================================ Upgrading my codebase to Linen ============================== As of Flax v0.4.0, ``flax.nn`` no longer exists, and is replaced with the new Linen API at ``flax.linen``. If your codebase is still using the old API, you can use this upgrade guide to upgrade it to Linen. .. testsetup:: Linen from flax.training import train_state from jax import random import optax import jax import flax.linen as nn from flax.linen import initializers from jax import lax import jax.numpy as jnp import numpy as np from typing import Any, Callable, Sequence, Tuple PRNGKey = Any Shape = Tuple[int, ...] Dtype = Any Array = Any default_kernel_init = initializers.lecun_normal() Defining simple Flax Modules ---------------------------- .. codediff:: :title: Old Flax, Linen :skip_test: Old Flax :sync: from flax import nn class Dense(base.Module): def apply(self, inputs, features, use_bias=True, kernel_init=default_kernel_init, bias_init=initializers.zeros_init()): kernel = self.param('kernel', (inputs.shape[-1], features), kernel_init) y = jnp.dot(inputs, kernel) if use_bias: bias = self.param( 'bias', (features,), bias_init) y = y + bias return y --- from flax import linen as nn # [1] #! class Dense(nn.Module): features: int # [2] #! use_bias: bool = True kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros_init() @nn.compact def __call__(self, inputs): # [3] #! kernel = self.param('kernel', self.kernel_init, (inputs.shape[-1], self.features)) # [4] #! y = jnp.dot(inputs, kernel) if self.use_bias: bias = self.param( 'bias', self.bias_init, (self.features,)) # [5] #! y = y + bias return y 1. Replace ``from flax import nn`` with ``from flax import linen as nn``. 2. Move arguments to ``apply`` into dataclass attributes. Add type annotations (or use type ``Any`` to bypass). 3. Rename method ``apply`` to ``__call__`` and (optionally) wrap with |@compact|_. Methods wrapped in |@compact|_ can define submodules directly within the method (like in old Flax). You can only wrap a single method with |@compact|_. Alternatively, you can define a ``setup`` method. For more details, please see our other HOWTO `Should I use setup or nn.compact?`_. 4. Access dataclass attributes values by ``self.`` inside methods, e.g. ``self.features``. 5. Move shape to the end of the arguments to |self.param|_ (initializer functions can take arbitrary argument lists). Using Flax Modules inside other Modules --------------------------------------- .. codediff:: :title: Old Flax, Linen :skip_test: Old Flax :sync: class Encoder(nn.Module): def apply(self, x): x = nn.Dense(x, 500) x = nn.relu(x) z = nn.Dense(x, 500, name="latents") return z --- class Encoder(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense(500)(x) # [1] #! x = nn.relu(x) z = nn.Dense(500, name='latents')(x) # [2] #! return z 1. Module constructors no longer return the outputs. Instead, they work like normal constructors and return module instances. These instances can be shared like in normal Python (instead of using ``.shared()`` in old Flax). Since most modules implement ``__call__``, you can retain the conciseness of old Flax. 2. Names can be optionally passed to all module constructors. Sharing submodules and defining multiple methods -------------------------------- .. codediff:: :title: Old Flax, Linen :skip_test: Old Flax :sync: class AutoEncoder(nn.Module): def _create_submodules(self): return Decoder.shared(name="encoder") def apply(self, x, z_rng, latents=20): decoder = self._create_decoder() z = Encoder(x, latents, name="encoder") return decoder(z) @nn.module_method def generate(self, z, **unused_kwargs): decoder = self._create_decoder() return nn.sigmoid(decoder(z)) --- class AutoEncoder(nn.Module): latents: int = 20 def setup(self): # [1] #! self.encoder = Encoder(self.latents) # [2] #! self.decoder = Decoder() def __call__(self, x): # [3] #! z = self.encoder(x) return self.decoder(z) def generate(self, z): # [4] #! return nn.sigmoid(self.decoder(z)) 1. Use |setup|_ instead of ``__init__``, which is already defined in the dataclasses library. Flax calls setup right after modules are ready to be used. (You can do this for all modules if you like instead of using |@compact|, but we like how |@compact| co-locates where modules are defined and used, especially if you have loops or conditionals). 2. Like regular Python, share submodules by assigning to self during initialization. Similar to PyTorch, ``self.encoder`` automatically has the name ``"encoder"``. 3. We don't use |@compact|_ here because we're not defining any inline submodules (all submodules are defined in setup). 4. Define additional methods just like in regular Python. ``Module.partial`` inside other modules --------------------------------------- .. codediff:: :title: Old Flax, Linen :skip_test: Old Flax :sync: # no import #! class ResNet(nn.Module): """ResNetV1.""" def apply(self, x, stage_sizes, num_filters=64, train=True): conv = nn.Conv.partial(bias=False) norm = nn.BatchNorm.partial( use_running_average=not train, momentum=0.9, epsilon=1e-5) x = conv(x, num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)], name='conv_init') x = norm(x, name='bn_init') # [...] return x --- from functools import partial #! class ResNet(nn.Module): """ResNetV1.""" stage_sizes: Sequence[int] num_filters: int = 64 train: bool = True @nn.compact def __call__(self, x): conv = partial(nn.Conv, use_bias=False) #! norm = partial(nn.BatchNorm, #! use_running_average=not self.train, #! momentum=0.9, epsilon=1e-5) #! x = conv(self.num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)], name='conv_init')(x) x = norm(name='bn_init')(x) # [...] return x Use normal ``functools.partial`` instead of ``Module.partial``. The rest stays the same. Top-level training code patterns -------------------------------- .. codediff:: :title: Old Flax, Linen :skip_test: Old Flax :sync: def create_model(key): _, initial_params = CNN.init_by_shape( key, [((1, 28, 28, 1), jnp.float32)]) model = nn.Model(CNN, initial_params) return model def create_optimizer(model, learning_rate): optimizer_def = optim.Momentum(learning_rate=learning_rate) optimizer = optimizer_def.create(model) return optimizer def cross_entropy_loss(*, logits, labels): one_hot_labels = jax.nn.one_hot(labels, num_classes=10) return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1)) def loss_fn(model): logits = model(batch['image']) one_hot = jax.nn.one_hot(batch['label'], num_classes=10) loss = -jnp.mean(jnp.sum(one_hot_labels * batch['label'], axis=-1)) return loss, logits --- def create_train_state(rng, config): # [1] #! variables = CNN().init(rng, jnp.ones([1, 28, 28, 1])) # [2] #! params = variables['params'] # [3] #! tx = optax.sgd(config.learning_rate, config.momentum) # [4] #! return train_state.TrainState.create( apply_fn=CNN.apply, params=params, tx=tx) def loss_fn(params): logits = CNN().apply({'params': params}, batch['image']) # [5] #! one_hot = jax.nn.one_hot(batch['label'], 10) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits 1. We no longer use the ``Model`` abstraction -- instead we pass parameters around directly, usually encapsulated in a `TrainState`_ object, which can directly be passed to JAX transformations. 2. To compute initial parameters, construct a module instance and call |init|_ or |init_with_output|_. We haven't ported over ``init_by_shape`` because this function did some magic we did not like (it evaluated the function by shape. but returned real values anyway). Therefore, you should now pass concrete values to the initializer functions, and you can optimize the initialization by wrapping it with |jax.jit|_, which is highly recommended to avoid running a full forward pass. 3. Linen generalizes parameters into variables. Parameters are one "collection" of variables. Variables are nested dicts, where the top-level keys reflect the different variable collections, of which "param" is one of. See the `Variables documentation`_ for more details. 4. We recommend using Optax optimizers. See our separate HOWTO called `Upgrading my codebase to Optax`_ for more details. 5. To make predictions with your model, make an instance at the top level (this is free -- just a wrapper around constructor attributes) and call the ``apply`` method (which will call ``__call__`` internally). Non-trainable variables ("state"): Use within Modules ----------------------------------------------------- .. codediff:: :title: Old Flax, Linen :skip_test: Old Flax :sync: class BatchNorm(nn.Module): def apply(self, x): # [...] ra_mean = self.state( 'mean', (x.shape[-1], ), initializers.zeros_init()) ra_var = self.state( 'var', (x.shape[-1], ), initializers.ones_init()) # [...] --- class BatchNorm(nn.Module): def __call__(self, x): # [...] ra_mean = self.variable( #! 'batch_stats', 'mean', initializers.zeros_init(), (x.shape[-1], )) ra_var = self.variable( 'batch_stats', 'var', initializers.ones_init(), (x.shape[-1], )) # [...] The first argument is the name of the variable collection ("param" is the only variable collection that's always available). Some colllections may be treated as mutable, and others as immutable at top-level training code (see next section for details). Flax also lets you treat each variable collection differently when using JAX transformations inside modules. Non-trainable variables ("state"): Top-level training code patterns ------------------------------------------------------------------- .. codediff:: :title: Old Flax, Linen :skip_test: Old Flax :sync: # initial params and state def initial_model(key, init_batch): with nn.stateful() as initial_state: _, initial_params = ResNet.init(key, init_batch) model = nn.Model(ResNet, initial_params) return model, init_state # updates batch statistics during training def loss_fn(model, model_state): with nn.stateful(model_state) as new_model_state: logits = model(batch['image']) # [...] # reads immutable batch statistics during evaluation def eval_step(model, model_state, batch): with nn.stateful(model_state, mutable=False): logits = model(batch['image'], train=False) return compute_metrics(logits, batch['label']) --- # initial variables ({"param": ..., "batch_stats": ...}) def initial_variables(key, init_batch): return ResNet().init(key, init_batch) # [1] #! # updates batch statistics during training def loss_fn(params, batch_stats): variables = {'params': params, 'batch_stats': batch_stats} # [2] #! logits, new_variables = ResNet(train=true).apply( variables, batch['image'], mutable=['batch_stats']) # [3] #! new_batch_stats = new_variables['batch_stats'] # [...] # reads immutable batch statistics during evaluation def eval_step(params, batch_stats, batch): variables = {'params': params, 'batch_stats': batch_stats} logits = ResNet(train=False).apply( variables, batch['image'], mutable=False) # [4] #! return compute_metrics(logits, batch['label']) 1. |init|_ returns a variable dict, e.g. ``{"param": ..., "batch_stats": ...}`` (see `Variables documentation`_). 2. Combine the different variable collections into a variable dict. 3. During training, the ``batch_stats`` variable collection changes. Since we specify that in the mutable argument, the return value from ``module.apply`` becomes an ordered pair of ``output, new_variables``. 4. During evaluation, we want to raise an error if we're accidentally applying Batch Norm in training mode. By passing ``mutable=False`` into ``module.apply`` we enforce that. Since no variables are mutated, the return value is once again just the output. Loading pre-Linen checkpoints ----------------------------- While most Linen modules should be able to use pre-Linen weights without any modification, there is one catch: In pre-Linen API submodules were numbered incrementally, independent of the submodule class. With Linen this behavior has changed to keep separate submodule counts per module class. In pre-Linen, params have the following structure: ``{'Conv_0': { ... }, 'Dense_1': { ... } }`` In Linen this is instead: ``{'Conv_0': { ... }, 'Dense_0': { ... } }`` TODO: Add an example here how to load a new ``TrainState`` object. Randomness ---------- .. codediff:: :title: Old Flax, Linen :skip_test: Old Flax :sync: def dropout(inputs, rate, deterministic=False): keep_prob = 1. - rate if deterministic: return inputs else: mask = random.bernoulli( make_rng(), p=keep_prob, shape=inputs.shape) return lax.select( mask, inputs / keep_prob, jnp.zeros_like(inputs)) def loss_fn(model, dropout_rng): with nn.stochastic(dropout_rng): logits = model(inputs) --- class Dropout(nn.Module): rate: float @nn.compact def __call__(self, inputs, deterministic=False): keep_prob = 1. - self.rate if deterministic: return inputs else: mask = random.bernoulli( self.make_rng('dropout'), p=keep_prob, shape=inputs.shape) # [1] #! return lax.select( mask, inputs / keep_prob, jnp.zeros_like(inputs)) def loss_fn(params, dropout_rng): logits = Transformer().apply( {'params': params}, inputs, rngs={'dropout': dropout_rng}) # [2] #! 1. RNGs in Linen have "kinds" -- in this case ``'dropout'``. Different kinds can be treated different in JAX transformations (for example, do you want the same dropout mask for each timestep in a sequence model or a different one?) 2. Instead of using the ``nn.stochastic`` context manager, you pass in RNGs explicitly to ``module.apply``. During evaluation you wouldn't pass any RNGs -- then if you accidentally use dropout in non-deterministic mode, ``self.make_rng('dropout')`` would raise an error. Lifted transformations ---------------------- In Linen, rather than using JAX transformation directly, we are using "lifted transforms", which are JAX transformations applied to Flax Modules. For more information, please see the design note on `Lifted transformations`_. TODO: Given an example of ``jax.scan_in_dim`` (pre-Linen) vs. ``nn.scan`` (Linen). .. _`Should I use setup or nn.compact?`: https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/setup_or_nncompact.html .. _`Variables documentation`: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/variable.html .. _`TrainState`: https://flax.readthedocs.io/en/latest/flax.training.html#train-state .. _`Upgrading my codebase to Optax`: https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/optax_update_guide.html .. _`Lifted transformations`: https://flax.readthedocs.io/en/latest/developer_notes/lift.html .. |@compact| replace:: ``@compact`` .. _@compact: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/decorators.html#flax.linen.compact .. |init| replace:: ``init`` .. _init: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init .. |init_with_output| replace:: ``init_with_output`` .. _init_with_output: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init_with_output .. |jax.jit| replace:: ``jax.jit`` .. _jax.jit: https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit .. |self.param| replace:: ``self.param`` .. _self.param: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.param .. |setup| replace:: ``setup`` .. _setup: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.setup .. |@flax.struct.dataclass| replace:: ``@flax.struct.dataclass`` .. _@flax.struct.dataclass: https://flax.readthedocs.io/en/latest/flax.struct.html#flax.struct.dataclass .. |checkpoints.convert_pre_linen()| replace:: ``checkpoints.convert_pre_linen()`` .. _checkpoints.convert_pre_linen(): https://flax.readthedocs.io/en/latest/flax.training.html#flax.training.checkpoints.convert_pre_linen ================================================ FILE: docs/guides/converting_and_upgrading/optax_update_guide.rst ================================================ Upgrading my codebase to Optax ============================== We have proposed to replace :py:mod:`flax.optim` with `Optax `_ in 2021 with `FLIP #1009 `_ and the Flax optimizers have been removed in v0.6.0 - this guide is targeted towards :py:mod:`flax.optim` users to help them update their code to Optax. See also Optax's quick start documentation: https://optax.readthedocs.io/en/latest/getting_started.html .. testsetup:: default, flax.optim, optax import flax import jax import jax.numpy as jnp import flax.linen as nn import optax # Note: this is the minimal code required to make below code run. See in the # Colab linked above for a more meaningful definition of datasets etc. batch = {'image': jnp.ones([1, 28, 28, 1]), 'label': jnp.array([0])} ds_train = [batch] get_ds_train = lambda: [batch] model = nn.Dense(1) variables = model.init(jax.random.key(0), batch['image']) learning_rate, momentum, weight_decay, grad_clip_norm = .1, .9, 1e-3, 1. loss = lambda params, batch: jnp.array(0.) Replacing ``flax.optim`` with ``optax`` --------------------------------------- Optax has drop-in replacements for all of Flax's optimizers. Refer to Optax's documentation `Common Optimizers `_ for API details. The usage is very similar, with the difference that ``optax`` does not keep a copy of the ``params``, so they need to be passed around separately. Flax provides the utility :py:class:`~flax.training.train_state.TrainState` to store optimizer state, parameters, and other associated data in a single dataclass (not used in code below). .. codediff:: :title: flax.optim, optax :skip_test: flax.optim :sync: @jax.jit def train_step(optimizer, batch): grads = jax.grad(loss)(optimizer.target, batch) return optimizer.apply_gradient(grads) optimizer_def = flax.optim.Momentum( learning_rate, momentum) optimizer = optimizer_def.create(variables['params']) for batch in get_ds_train(): optimizer = train_step(optimizer, batch) --- @jax.jit def train_step(params, opt_state, batch): grads = jax.grad(loss)(params, batch) updates, opt_state = tx.update(grads, opt_state) params = optax.apply_updates(params, updates) return params, opt_state tx = optax.sgd(learning_rate, momentum) params = variables['params'] opt_state = tx.init(params) for batch in ds_train: params, opt_state = train_step(params, opt_state, batch) Composable Gradient Transformations ----------------------------------- The function |optax.sgd()|_ used in the code snippet above is simply a wrapper for the sequential application of two gradient transformations. Instead of using this alias, it is common to use |optax.chain()|_ to combine multiple of these generic building blocks. .. |optax.sgd()| replace:: ``optax.sgd()`` .. _optax.sgd(): https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.sgd .. |optax.chain()| replace:: ``optax.chain()`` .. _optax.chain(): https://optax.readthedocs.io/en/latest/api/combining_optimizers.html#optax.chain .. codediff:: :title: Pre-defined alias, Combining transformations :groups: default, default # Note that the aliases follow the convention to use positive # values for the learning rate by default. tx = optax.sgd(learning_rate, momentum) --- # tx = optax.chain( # 1. Step: keep a trace of past updates and add to gradients. optax.trace(decay=momentum), # 2. Step: multiply result from step 1 with negative learning rate. # Note that `optax.apply_updates()` simply adds the final updates to the # parameters, so we must make sure to flip the sign here for gradient # descent. optax.scale(-learning_rate), ) Weight Decay ------------ Some of Flax's optimizers also include a weight decay. In Optax, some optimizers also have a weight decay parameter (such as |optax.adamw()|_), and to others the weight decay can be added as another "gradient transformation" |optax.add_decayed_weights()|_ that adds an update derived from the parameters. .. |optax.adamw()| replace:: ``optax.adamw()`` .. _optax.adamw(): https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.adamw .. |optax.add_decayed_weights()| replace:: ``optax.add_decayed_weights()`` .. _optax.add_decayed_weights(): https://optax.readthedocs.io/en/latest/api/transformations.html#optax.add_decayed_weights .. codediff:: :title: flax.optim, optax :skip_test: flax.optim :sync: optimizer_def = flax.optim.Adam( learning_rate, weight_decay=weight_decay) optimizer = optimizer_def.create(variables['params']) --- # (Note that you could also use `optax.adamw()` in this case) tx = optax.chain( optax.scale_by_adam(), optax.add_decayed_weights(weight_decay), # params -= learning_rate * (adam(grads) + params * weight_decay) optax.scale(-learning_rate), ) # Note that you'll need to specify `params` when computing the udpates: # tx.update(grads, opt_state, params) Gradient Clipping ----------------- Training can be stabilized by clipping gradients to a global norm (`Pascanu et al, 2012 `_). In Flax this is often done by processing the gradients before passing them to the optimizer. With Optax this becomes just another gradient transformation |optax.clip_by_global_norm()|_. .. |optax.clip_by_global_norm()| replace:: ``optax.clip_by_global_norm()`` .. _optax.clip_by_global_norm(): https://optax.readthedocs.io/en/latest/api/transformations.html#optax.clip_by_global_norm .. codediff:: :title: flax.optim, optax :skip_test: flax.optim :sync: def train_step(optimizer, batch): grads = jax.grad(loss)(optimizer.target, batch) grads_flat, _ = jax.tree_util.tree_flatten(grads) global_l2 = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads_flat])) g_factor = jnp.minimum(1.0, grad_clip_norm / global_l2) grads = jax.tree_util.tree_map(lambda g: g * g_factor, grads) return optimizer.apply_gradient(grads) --- tx = optax.chain( optax.clip_by_global_norm(grad_clip_norm), optax.trace(decay=momentum), optax.scale(-learning_rate), ) Learning Rate Schedules ----------------------- For learning rate schedules, Flax allows overwriting hyper parameters when applying the gradients. Optax maintains a step counter and provides this as an argument to a function for scaling the updates added with |optax.scale_by_schedule()|_. Optax also allows specifying a functions to inject arbitrary scalar values for other gradient updates via |optax.inject_hyperparams()|_. Read more about learning rate schedules in the :doc:`lr_schedule` guide. Read more about schedules defined in Optax under `Optimizer Schedules `_. the standard optimizers (like ``optax.adam()``, ``optax.sgd()`` etc.) also accept a learning rate schedule as a parameter for ``learning_rate``. .. |optax.scale_by_schedule()| replace:: ``optax.scale_by_schedule()`` .. _optax.scale_by_schedule(): https://optax.readthedocs.io/en/latest/api/transformations.html#optax.scale_by_schedule .. |optax.inject_hyperparams()| replace:: ``optax.inject_hyperparams()`` .. _optax.inject_hyperparams(): https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.inject_hyperparams .. codediff:: :title: flax.optim, optax :skip_test: flax.optim :sync: def train_step(step, optimizer, batch): grads = jax.grad(loss)(optimizer.target, batch) return step + 1, optimizer.apply_gradient(grads, learning_rate=schedule(step)) --- tx = optax.chain( optax.trace(decay=momentum), # Note that we still want a negative value for scaling the updates! optax.scale_by_schedule(lambda step: -schedule(step)), ) Multiple Optimizers / Updating a Subset of Parameters ----------------------------------------------------- In Flax, traversals are used to specify which parameters should be updated by an optimizer. And you can combine traversals using :py:class:`flax.optim.MultiOptimizer` to apply different optimizers on different parameters. The equivalent in Optax is |optax.masked()|_ and |optax.chain()|_. Note that the example below is using :py:mod:`flax.traverse_util` to create the boolean masks required by |optax.masked()|_ - alternatively you could also create them manually, or use |optax.multi_transform()|_ that takes a multivalent pytree to specify gradient transformations. Beware that |optax.masked()|_ flattens the pytree internally and the inner gradient transformations will only be called with that partial flattened view of the params/gradients. This is not a problem usually, but it makes it hard to nest multiple levels of masked gradient transformations (because the inner masks will expect the mask to be defined in terms of the partial flattened view that is not readily available outside the outer mask). .. |optax.masked()| replace:: ``optax.masked()`` .. _optax.masked(): https://optax.readthedocs.io/en/latest/api/optimizer_wrappers.html#optax.masked .. |optax.multi_transform()| replace:: ``optax.multi_transform()`` .. _optax.multi_transform(): https://optax.readthedocs.io/en/latest/api/combining_optimizers.html#optax.multi_transform .. codediff:: :title: flax.optim, optax :skip_test: flax.optim :sync: kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p) biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p) kernel_opt = flax.optim.Momentum(learning_rate, momentum) bias_opt = flax.optim.Momentum(learning_rate * 0.1, momentum) optimizer = flax.optim.MultiOptimizer( (kernels, kernel_opt), (biases, bias_opt) ).create(variables['params']) --- kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p) biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p) all_false = jax.tree_util.tree_map(lambda _: False, params) kernels_mask = kernels.update(lambda _: True, all_false) biases_mask = biases.update(lambda _: True, all_false) tx = optax.chain( optax.trace(decay=momentum), optax.masked(optax.scale(-learning_rate), kernels_mask), optax.masked(optax.scale(-learning_rate * 0.1), biases_mask), ) Final Words ----------- All above patterns can of course also be mixed and Optax makes it possible to encapsulate all these transformations into a single place outside the main training loop, which makes testing much easier. ================================================ FILE: docs/guides/converting_and_upgrading/orbax_upgrade_guide.rst ================================================ Migrate checkpointing to Orbax ============================== This guide shows how to convert Flax's checkpoint saving and restoring calls — `flax.training.checkpoints.save_checkpoint `__ and `restore_checkpoint `__ — to the equivalent `Orbax `__ methods. Orbax provides a flexible and customizable API for managing checkpoints for various objects. Note that as Flax's checkpointing is being migrated to Orbax from ``flax.training.checkpoints``, all existing features in the Flax API will continue to be supported, but the API will change. You will learn how to migrate to Orbax through the following scenarios: * The most common use case: Saving/loading and managing checkpoints * A "lightweight" use case: "Pure" saving/loading without the top-level checkpoint manager * Restoring checkpoints without a target pytree * Async checkpointing * Saving/loading a single JAX or NumPy Array To learn more about Orbax, check out the `quick start introductory Colab notebook `__ and `the official Orbax documentation `_. You can click on "Open in Colab" above to run the code from this guide. Throughout the guide, you will be able to compare code examples with and without the Orbax code. .. testsetup:: orbax.checkpoint import flax from flax.training import checkpoints, orbax_utils import orbax import jax import jax.numpy as jnp import numpy as np # Orbax needs to have asyncio enabled in the Colab environment. import nest_asyncio nest_asyncio.apply() # Set up the directory. import os import shutil if os.path.exists('/tmp/orbax_upgrade'): shutil.rmtree('/tmp/orbax_upgrade') os.makedirs('/tmp/orbax_upgrade') Setup ***** .. testcode:: orbax.checkpoint # Create some dummy variables for this example. MAX_STEPS = 5 CKPT_PYTREE = [12, {'bar': np.array((2, 3))}, [1, 4, 10]] TARGET_PYTREE = [0, {'bar': np.array((0))}, [0, 0, 0]] Most common use case: Saving/loading and managing checkpoints ************************************************************* This section covers the following scenario: * Your original Flax ``save_checkpoint()`` or ``save_checkpoint_multiprocess()`` call contains the following arguments: ``prefix``, ``keep``, ``keep_every_n_steps``; or * You want to use some automatic management logic for your checkpoints (for example, for deleting old data, deleting data based on metrics/loss, and so on). In this case, you need to use ``orbax.CheckpointManager``. This allows you to not only save and load your model, but also manage your checkpoints and delete outdated checkpoints *automatically*. To upgrade your code: 1. Create and keep an ``orbax.CheckpointManager`` instance at the top level, customized with ``orbax.CheckpointManagerOptions``. 2. At runtime, call ``orbax.CheckpointManager.save()`` to save your data. 3. Then, call ``orbax.CheckpointManager.restore()`` to restore your data. 4. And, if your checkpoint includes some multi-host/multi-process array, pass the correct ``mesh`` into ``flax.training.orbax_utils.restore_args_from_target()`` to generate the correct ``restore_args`` before restoring. For example: .. codediff:: :title: flax.checkpoints, orbax.checkpoint :skip_test: flax.checkpoints :sync: CKPT_DIR = '/tmp/orbax_upgrade/' flax.config.update('flax_use_orbax_checkpointing', False) # Inside your training loop for step in range(MAX_STEPS): # do training checkpoints.save_checkpoint(CKPT_DIR, CKPT_PYTREE, step=step, prefix='test_', keep=3, keep_every_n_steps=2) checkpoints.restore_checkpoint(CKPT_DIR, target=TARGET_PYTREE, step=4, prefix='test_') --- CKPT_DIR = '/tmp/orbax_upgrade/orbax' # At the top level mgr_options = orbax.checkpoint.CheckpointManagerOptions( create=True, max_to_keep=3, keep_period=2, step_prefix='test') ckpt_mgr = orbax.checkpoint.CheckpointManager( CKPT_DIR, orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()), mgr_options) # Inside your training loop for step in range(MAX_STEPS): # do training save_args = flax.training.orbax_utils.save_args_from_target(CKPT_PYTREE) ckpt_mgr.save(step, CKPT_PYTREE, save_kwargs={'save_args': save_args}) restore_args = flax.training.orbax_utils.restore_args_from_target(TARGET_PYTREE, mesh=None) ckpt_mgr.restore(4, items=TARGET_PYTREE, restore_kwargs={'restore_args': restore_args}) A "lightweight" use case: "Pure" saving/loading without the top-level checkpoint manager **************************************************************************************** If you prefer to not maintain a top-level checkpoint manager, you can still save and restore any individual checkpoint with an ``orbax.checkpoint.Checkpointer``. Note that this means you cannot use all the Orbax management features. To migrate to Orbax code, instead of using the ``overwrite`` argument in ``flax.save_checkpoint()`` use the ``force`` argument in ``orbax.checkpoint.Checkpointer.save()``. For example: .. codediff:: :title: flax.checkpoints, orbax.checkpoint :skip_test: flax.checkpoints :sync: PURE_CKPT_DIR = '/tmp/orbax_upgrade/pure' flax.config.update('flax_use_orbax_checkpointing', False) checkpoints.save_checkpoint(PURE_CKPT_DIR, CKPT_PYTREE, step=0, overwrite=True) checkpoints.restore_checkpoint(PURE_CKPT_DIR, target=TARGET_PYTREE) --- PURE_CKPT_DIR = '/tmp/orbax_upgrade/pure' ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()) # A stateless object, can be created on the fly. ckptr.save(PURE_CKPT_DIR, CKPT_PYTREE, save_args=flax.training.orbax_utils.save_args_from_target(CKPT_PYTREE), force=True) ckptr.restore(PURE_CKPT_DIR, item=TARGET_PYTREE, restore_args=flax.training.orbax_utils.restore_args_from_target(TARGET_PYTREE, mesh=None)) Restoring checkpoints without a target pytree ********************************************* If you need to restore your checkpoints without a target pytree, pass ``item=None`` to ``orbax.checkpoint.Checkpointer`` or ``items=None`` to ``orbax.CheckpointManager``'s ``.restore()`` method, which should trigger the restoration. For example: .. codediff:: :title: flax.checkpoints, orbax.checkpoint :skip_test: flax.checkpoints :sync: NOTARGET_CKPT_DIR = '/tmp/orbax_upgrade/no_target' flax.config.update('flax_use_orbax_checkpointing', False) checkpoints.save_checkpoint(NOTARGET_CKPT_DIR, CKPT_PYTREE, step=0) checkpoints.restore_checkpoint(NOTARGET_CKPT_DIR, target=None) --- NOTARGET_CKPT_DIR = '/tmp/orbax_upgrade/no_target' # A stateless object, can be created on the fly. ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()) ckptr.save(NOTARGET_CKPT_DIR, CKPT_PYTREE, save_args=flax.training.orbax_utils.save_args_from_target(CKPT_PYTREE)) ckptr.restore(NOTARGET_CKPT_DIR, item=None) Async checkpointing ******************* To make your checkpoint-saving asynchronous, substitute ``orbax.checkpoint.Checkpointer`` with ``orbax.checkpoint.AsyncCheckpointer``. Then, you can call ``orbax.checkpoint.AsyncCheckpointer.wait_until_finished()`` or Orbax's ``CheckpointerManager.wait_until_finished()`` to wait for the save the complete. For more details, read the `checkpoint guide `_. You can also use Orbax AsyncCheckpointer with Flax APIs through async manager. Async manager internally calls wait_until_finished(). This solution is not actively maintained and the recommedation is to use Orbax async checkpointing. For example: .. codediff:: :title: flax.checkpoints, orbax.checkpoint :skip_test: flax.checkpoints :sync: ASYNC_CKPT_DIR = '/tmp/orbax_upgrade/async' flax.config.update('flax_use_orbax_checkpointing', True) async_manager = checkpoints.AsyncManager() checkpoints.save_checkpoint(ASYNC_CKPT_DIR, CKPT_PYTREE, step=0, overwrite=True, async_manager=async_manager) checkpoints.restore_checkpoint(ASYNC_CKPT_DIR, target=TARGET_PYTREE) --- ASYNC_CKPT_DIR = '/tmp/orbax_upgrade/async' import orbax.checkpoint as ocp ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler()) ckptr.save(ASYNC_CKPT_DIR, args=ocp.args.StandardSave(CKPT_PYTREE)) # ... Continue with your work... # ... Until a time when you want to wait until the save completes: ckptr.wait_until_finished() # Blocks until the checkpoint saving is completed. ckptr.restore(ASYNC_CKPT_DIR, args=ocp.args.StandardRestore(TARGET_PYTREE)) Saving/loading a single JAX or NumPy Array ****************************************** The ``orbax.checkpoint.PyTreeCheckpointHandler`` class, as the name suggests, can only be used for pytrees. Therefore, if you need to save/restore a single pytree leaf (for example, an array), use ``orbax.checkpoint.ArrayCheckpointHandler`` instead. For example: .. codediff:: :title: flax.checkpoints, orbax.checkpoint :skip_test: flax.checkpoints :sync: ARR_CKPT_DIR = '/tmp/orbax_upgrade/singleton' flax.config.update('flax_use_orbax_checkpointing', False) checkpoints.save_checkpoint(ARR_CKPT_DIR, jnp.arange(10), step=0) checkpoints.restore_checkpoint(ARR_CKPT_DIR, target=None) --- ARR_CKPT_DIR = '/tmp/orbax_upgrade/singleton' ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.ArrayCheckpointHandler()) ckptr.save(ARR_CKPT_DIR, jnp.arange(10)) ckptr.restore(ARR_CKPT_DIR, item=None) Final words *********** This guide provides an overview of how to migrate from the "legacy" Flax checkpointing API to the Orbax API. Orbax provides more functionalities and the Orbax team is actively developing new features. Stay tuned and follow the `official Orbax GitHub repository `__ for more! ================================================ FILE: docs/guides/converting_and_upgrading/regular_dict_upgrade_guide.rst ================================================ Migrate to regular dicts ======================== Flax will migrate from returning ``FrozenDicts`` to regular dicts when calling :meth:`.init `, :meth:`.init_with_output ` and :meth:`.apply ` ``Module`` methods. The original issue is outlined `here `__. This guide shows some common upgrade patterns. Utility functions ----------------- ``FrozenDicts`` are immutable dictionaries that implement an additional 4 methods: * :meth:`copy ` * :meth:`pop ` * :meth:`pretty_repr ` * :meth:`unfreeze ` To accommodate the regular dict change, replace usage of ``FrozenDict`` methods with their utility function equivalent from ``flax.core.frozen_dict``. These utility functions mimic the behavior of their corresponding ``FrozenDict`` method, and can be called on either ``FrozenDicts`` or regular dicts. The following are the utility functions and example upgrade patterns: .. testsetup:: default, Only ``FrozenDict``, Both ``FrozenDict`` and regular dict import flax import flax.linen as nn import jax import jax.numpy as jnp x = jnp.empty((1,3)) variables = flax.core.freeze(nn.Dense(5).init(jax.random.key(0), x)) other_variables = jnp.array([1, 1, 1, 1, 1], dtype=jnp.float32) :meth:`copy ` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. codediff:: :title: Only ``FrozenDict``, Both ``FrozenDict`` and regular dict :sync: variables = variables.copy(add_or_replace={'other_variables': other_variables}) --- variables = flax.core.copy(variables, add_or_replace={'other_variables': other_variables}) :meth:`pop ` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. codediff:: :title: Only ``FrozenDict``, Both ``FrozenDict`` and regular dict :sync: state, params = variables.pop('params') --- state, params = flax.core.pop(variables, 'params') :meth:`pretty_repr ` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. codediff:: :title: Only ``FrozenDict``, Both ``FrozenDict`` and regular dict :sync: str_repr = variables.pretty_repr() --- str_repr = flax.core.pretty_repr(variables) :meth:`unfreeze ` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. codediff:: :title: Only ``FrozenDict``, Both ``FrozenDict`` and regular dict :sync: variables = variables.unfreeze() --- variables = flax.core.unfreeze(variables) Modifying config values ----------------------- A temporary feature flag ``flax_return_frozendict`` is set up to help with the migration. To toggle behavior between returning FrozenDict and regular dict variables at runtime, run ``flax.config.update('flax_return_frozendict', )`` in your code. For example: .. testcode:: x = jnp.empty((1,3)) flax.config.update('flax_return_frozendict', True) # set Flax to return FrozenDicts variables = nn.Dense(5).init(jax.random.key(0), x) assert isinstance(variables, flax.core.FrozenDict) flax.config.update('flax_return_frozendict', False) # set Flax to return regular dicts variables = nn.Dense(5).init(jax.random.key(0), x) assert isinstance(variables, dict) Alternatively, the environment variable ``flax_return_frozendict`` (found `here `__) can be directly modified in the Flax source code. Migration status -------------- As of July 19th, 2023, ``flax_return_frozendict`` is set to ``False`` (see `#3193 `__), meaning Flax will default to returning regular dicts from version `0.7.1 `__ onward. This flag can be flipped to ``True`` temporarily to have Flax return ``Frozendicts``. However this feature flag will eventually be removed in the future. ================================================ FILE: docs/guides/converting_and_upgrading/rnncell_upgrade_guide.rst ================================================ RNNCellBase Upgrade Guide ========================= The ``RNNCellBase`` API has undergone some key updates aimed at enhancing usability: - The ``initialize_carry`` method has transitioned from a class method to an instance method, simplifying its application. - All necessary metadata is now stored directly within the cell instance, providing a streamlined method signature. This guide will walk you through these changes, demonstrating how to update your existing code to align with these enhancements. Basic Usage ----------- .. testsetup:: New import flax.linen as nn import jax.numpy as jnp import jax import functools Let's begin by defining some variables and a sample input that represents a batch of sequences: .. testcode:: New batch_size = 32 seq_len = 10 in_features = 64 out_features = 128 x = jnp.ones((batch_size, seq_len, in_features)) First and foremost, it's important to note that all metadata, including the number of features, carry initializer, and so on, is now stored within the cell instance: .. codediff:: :title: Legacy, New :skip_test: Legacy :sync: cell = nn.LSTMCell() --- cell = nn.LSTMCell(features=out_features) A significant change is that ``initialize_carry`` has been transitioned into an instance method. Given that the cell instance now contains all metadata, the ``initialize_carry`` method's signature only requires a PRNG key and a sample input: .. codediff:: :title: Legacy, New :skip_test: Legacy :sync: carry = nn.LSTMCell.initialize_carry(jax.random.key(0), (batch_size,), out_features) --- carry = cell.initialize_carry(jax.random.key(0), x[:, 0].shape) Here, ``x[:, 0].shape`` represents the input for the cell (without the time dimension). You can also just create the input shape directly when its more convenient: .. testcode:: New carry = cell.initialize_carry(jax.random.key(0), (batch_size, in_features)) Upgrade Patterns ----------------- The following sections will demonstrate some useful patterns for updating your code to align with the new API. First, we will show how to upgrade a ``Module`` that wraps a cell, applies the scan logic during ``__call__``, and has a static ``initialize_carry`` method. Here, we will try to make the minimal amount of changes to the code to get it working, albeit not in the most idiomatic way: .. codediff:: :title: Legacy, New :skip_test: Legacy :sync: class SimpleLSTM(nn.Module): @functools.partial( nn.transforms.scan, variable_broadcast='params', in_axes=1, out_axes=1, split_rngs={'params': False}) @nn.compact def __call__(self, carry, x): return nn.OptimizedLSTMCell()(carry, x) @staticmethod def initialize_carry(batch_dims, hidden_size): return nn.OptimizedLSTMCell.initialize_carry( jax.random.key(0), batch_dims, hidden_size) --- class SimpleLSTM(nn.Module): @functools.partial( nn.transforms.scan, variable_broadcast='params', in_axes=1, out_axes=1, split_rngs={'params': False}) @nn.compact def __call__(self, carry, x): features = carry[0].shape[-1] return nn.OptimizedLSTMCell(features)(carry, x) @staticmethod def initialize_carry(batch_dims, hidden_size): return nn.OptimizedLSTMCell(hidden_size, parent=None).initialize_carry( jax.random.key(0), (*batch_dims, hidden_size)) Notice how in the new version, we have to extract the number of features from the carry during ``__call__``, and use ``parent=None`` during ``initialize_carry`` to avoid some potential side effects. Next, we will show a more idiomatic way of writing a similar LSTM module. The main change here will be that we will add a ``features`` attribute to the module and use it to initialize a ``nn.scan``-ed version of the cell in the ``setup`` method: .. codediff:: :title: Legacy, New :skip_test: Legacy :sync: class SimpleLSTM(nn.Module): @functools.partial( nn.transforms.scan, variable_broadcast='params', in_axes=1, out_axes=1, split_rngs={'params': False}) @nn.compact def __call__(self, carry, x): return nn.OptimizedLSTMCell()(carry, x) @staticmethod def initialize_carry(batch_dims, hidden_size): return nn.OptimizedLSTMCell.initialize_carry( jax.random.key(0), batch_dims, hidden_size) model = SimpleLSTM() carry = SimpleLSTM.initialize_carry((batch_size,), out_features) variables = model.init(jax.random.key(0), carry, x) --- class SimpleLSTM(nn.Module): features: int def setup(self): self.scan_cell = nn.transforms.scan( nn.OptimizedLSTMCell, variable_broadcast='params', in_axes=1, out_axes=1, split_rngs={'params': False})(self.features) @nn.compact def __call__(self, x): carry = self.scan_cell.initialize_carry(jax.random.key(0), x[:, 0].shape) return self.scan_cell(carry, x)[1] # only return the output model = SimpleLSTM(features=out_features) variables = model.init(jax.random.key(0), x) Because the ``carry`` can be easily initialized from the sample input, we can move the call to ``initialize_carry`` into the ``__call__`` method, somewhat simplifying the code. Development Notes ----------------- When developing a new cell, consider the following: * Include necessary metadata as instance attributes. * The ``initialize_carry`` now only requires a PRNG key and a sample input. * A new ``num_feature_axes`` property is required to specify the number of feature dimensions. .. code-block:: class LSTMCell(nn.RNNCellBase): features: int # ← All metadata is now stored within the cell instance ... # ↓ carry_init: Initializer def initialize_carry(self, rng, input_shape) -> Carry: ... @property def num_feature_axes(self): return 1 ``num_feature_axes`` is a new API feature that allows code handling arbitrary ``RNNCellBase`` instances, such as the ``RNN`` Module, to infer the number of batch dimensions and determine the position of the time axis. ================================================ FILE: docs/guides/data_preprocessing/full_eval.rst ================================================ Processing the entire Dataset ============================= For efficiency reasons, we form batches that contain multiple examples and process them in parallel. Especially when evaluating a model, it is important that we process all examples and **avoid losing the remainder** of examples that does not form a complete batch at the end. The problem ----------- When evaluating on a single device, one can either drop the last incomplete batch, or one can form a last batch with a shape different from the preceding batches. Doing the latter has the disadvantage that this will trigger a **recompilation** of the ``eval_step()`` because XLA is not shape polymorphic. .. code-block:: python collections.Counter( tuple(batch['image'].shape) for batch in tfds.load('mnist', split='test').batch(per_device_batch_size) ) # output: # Counter({(272, 28, 28, 1): 1, (512, 28, 28, 1): 19}) The problem is accentuated when using multiple devices for data parallelism. If the batch size is not **divisible by the number devices**, then that last step must be executed on a single device (or a subset of devices). Usually one would drop the last batch, but this will lead to incorrect results. .. code-block:: python sum( np.prod(batch['label'].shape) for batch in tfds.load('mnist', split='test') .batch(per_device_batch_size, drop_remainder=True) .batch(jax.local_device_count()) ) # output: # 9728 Using multiple hosts further complicates the situation because JAX uses the SPMD paradigm and every host must execute the same program. We would usually form non-overlapping splits for different hosts with |tfds.split_for_jax_process()|_, but this can lead to **different numbers for different hosts**, resulting in different JAX programs when all examples are to be processed. .. code-block:: python process_count = 6 [ len(tfds.load(dataset_name, split=tfds.split_for_jax_process( 'test', process_index=process_index, process_count=process_count))) for process_index in range(process_count) ] # output: # [1667, 1667, 1667, 1667, 1666, 1666] .. |tfds.split_for_jax_process()| replace:: ``tfds.split_for_jax_process()`` .. _tfds.split_for_jax_process(): https://www.tensorflow.org/datasets/api_docs/python/tfds/split_for_jax_process The solution: padding --------------------- Even though it's possible to solve this problem by cleverly adjusting the number of batches executed by different devices on different hosts, such a solution quickly becomes complicated and makes the main eval loop hard to read with a lot of cumbersome logic. The more straightforward solution to this problem is to use padding at the end of the dataset to make sure that the last batch has the same size as the preceding batches. Manual implementation ~~~~~~~~~~~~~~~~~~~~~ The last batch is manually padded to contain the same number of examples as in the preceding batches. The predictions for the padded examples are discarded from the computation. .. code-block:: python shard = lambda x: einops.rearrange( x, '(d b) ... -> d b ...', d=jax.local_device_count()) unshard = lambda x: einops.rearrange(x, 'd b ... -> (d b) ...') correct = total = 0 for batch in ds.as_numpy_iterator(): images = batch['image'] n = len(images) padding = np.zeros([per_host_batch_size - n, *images.shape[1:]], images.dtype) padded_images = np.concatenate([images, padding]) preds = unshard(get_preds(variables, shard(padded_images)))[:n] total += n correct += (batch['label'] == preds.argmax(axis=-1)).sum() Using ``pad_shard_unpad()`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~ The above pattern, namely the pad→shard→predict→unshard→unpad sequence, can be extracted into a utility wrapper ``pad_shard_unpad()``, which greatly simplifies above evaluation loop. .. code-block:: python correct = total = 0 for batch in ds.as_numpy_iterator(): preds = flax.jax_utils.pad_shard_unpad(get_preds)( vs, batch['image'], min_device_batch=per_device_batch_size) total += len(batch['image']) correct += (batch['label'] == preds.argmax(axis=-1)).sum() Computing metrics in ``eval_step()`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Instead of returning the predictions and computing the metrics in the main evaluation loop, we would often want to make the metric computation part of the evaluation step, especially when using libraries like |jax_metrics|_, or |clu.metrics|_. In that case we would want to pass the metrics as a ``static_argnums`` (i.e. do not shard/pad it), and treat the return value as ``static_return`` too (i.e. no un-sharding or un-padding): .. code-block:: python def eval_step(metrics, variables, batch): print('retrigger compilation', {k: v.shape for k, v in batch.items()}) preds = model.apply(variables, batch['image']) correct = (batch['mask'] & (batch['label'] == preds.argmax(axis=-1))).sum() total = batch['mask'].sum() return dict( correct=metrics['correct'] + jax.lax.psum(correct, axis_name='batch'), total=metrics['total'] + jax.lax.psum(total, axis_name='batch'), ) eval_step = jax.pmap(eval_step, axis_name='batch') eval_step = flax.jax_utils.pad_shard_unpad( eval_step, static_argnums=(0, 1), static_return=True) .. |jax_metrics| replace:: ``clu.metrics`` .. _jax_metrics: https://github.com/cgarciae/jax_metrics .. |clu.metrics| replace:: ``clu.metrics`` .. _clu.metrics: https://github.com/google/CommonLoopUtils/blob/main/clu/metrics.py Adding "infinite padding" ~~~~~~~~~~~~~~~~~~~~~~~~~ The above solution works in most cases, but it has some limitations: 1. In the rare case where even splitting of the dataset on multiple hosts leads to a different number of batches. Imagine having a dataset of ``n=4097`` examples, and evaluating this on ``h=8``, each having ``d=8`` local devices, and forming on-device batch sizes of ``b=128``. With even dataset splitting, the first host would get ``4096/8+1==513`` examples, and all other hosts would get ``4096/8==512`` examples. Forming per-host batches of ``d*b==512`` this would lead to two batches on the first host, and a single batch on all other hosts, violating SPMD principles and hanging the multi-host setup in the last ``psum()`` directive (which would only be executed by the first host, but not the others). 2. When dropping examples dynamically by using ``ds.filter()``. In these more complicated cases we could add "infinite padding" to the dataset, on each of the hosts independently, and continuing processing examples until *all* hosts run out of unpadded examples. .. code-block:: python correct = total = 0 for batch in ds.as_numpy_iterator(): n = count_p(batch['mask'])[0].item() # adds sync barrier if not n: break preds = get_preds(vs, batch['image']).argmax(axis=-1) total += n correct += count_correct_p(batch['label'], preds, batch['mask'])[0] ================================================ FILE: docs/guides/data_preprocessing/index.rst ================================================ Data preprocessing ================= .. toctree:: :maxdepth: 1 full_eval loading_datasets ================================================ FILE: docs/guides/data_preprocessing/loading_datasets.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Loading datasets\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/guides/data_preprocessing/loading_datasets.ipynb)\n", "\n", "A neural net written in Jax+Flax expects its input data as `jax.numpy` array instances. Therefore, loading a dataset from any source is as simple as converting it to `jax.numpy` types and reshaping it to the appropriate dimensions for your network.\n", "\n", "As an example, this guide demonstrates how to import [MNIST](http://yann.lecun.com/exdb/mnist/) using the APIs from Torchvision, Tensorflow, and Hugging Face. We'll load the whole dataset into memory. For datasets that don't fit into memory the process is analogous but should be done in a batchwise fashion.\n", "\n", "The MNIST dataset consists of greyscale images of 28x28 pixels of handwritten digits, and has a designated 60k/10k train/test split. The task is to predict the correct class (digit 0, ..., 9) of each image.\n", "\n", "Assuming a CNN-based classifier, the input data should have shape `(B, 28, 28, 1)`, where the trailing singleton dimension denotes the greyscale image channel.\n", "\n", "The labels are simply the integer denoting the digit corresponding to the image. Labels should therefore have shape `(B,)`, to enable loss computation with [`optax.softmax_cross_entropy_with_integer_labels`](https://optax.readthedocs.io/en/latest/api.html#optax.softmax_cross_entropy_with_integer_labels)." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "tags": [ "skip-execution" ] }, "outputs": [], "source": [ "import numpy as np\n", "import jax.numpy as jnp" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading from `torchvision.datasets`" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "tags": [ "skip-execution" ] }, "outputs": [], "source": [ "import torchvision" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "tags": [ "skip-execution" ] }, "outputs": [], "source": [ "def get_dataset_torch():\n", " mnist = {\n", " 'train': torchvision.datasets.MNIST('./data', train=True, download=True),\n", " 'test': torchvision.datasets.MNIST('./data', train=False, download=True)\n", " }\n", "\n", " ds = {}\n", "\n", " for split in ['train', 'test']:\n", " ds[split] = {\n", " 'image': mnist[split].data.numpy(),\n", " 'label': mnist[split].targets.numpy()\n", " }\n", "\n", " # cast from np to jnp and rescale the pixel values from [0,255] to [0,1]\n", " ds[split]['image'] = jnp.float32(ds[split]['image']) / 255\n", " ds[split]['label'] = jnp.int16(ds[split]['label'])\n", "\n", " # torchvision returns shape (B, 28, 28).\n", " # hence, append the trailing channel dimension.\n", " ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3)\n", "\n", " return ds['train'], ds['test']" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "outputId": "be39b756-d13e-4380-b99e-a5cbf61458cc", "tags": [ "skip-execution" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(60000, 28, 28, 1) float32\n", "(60000,) int16\n", "(10000, 28, 28, 1) float32\n", "(10000,) int16\n" ] } ], "source": [ "train, test = get_dataset_torch()\n", "print(train['image'].shape, train['image'].dtype)\n", "print(train['label'].shape, train['label'].dtype)\n", "print(test['image'].shape, test['image'].dtype)\n", "print(test['label'].shape, test['label'].dtype)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading from `tensorflow_datasets`" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "tags": [ "skip-execution" ] }, "outputs": [], "source": [ "import tensorflow_datasets as tfds" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "tags": [ "skip-execution" ] }, "outputs": [], "source": [ "def get_dataset_tf():\n", " mnist = tfds.builder('mnist')\n", " mnist.download_and_prepare()\n", "\n", " ds = {}\n", "\n", " for split in ['train', 'test']:\n", " ds[split] = tfds.as_numpy(mnist.as_dataset(split=split, batch_size=-1))\n", "\n", " # cast to jnp and rescale pixel values\n", " ds[split]['image'] = jnp.float32(ds[split]['image']) / 255\n", " ds[split]['label'] = jnp.int16(ds[split]['label'])\n", "\n", " return ds['train'], ds['test']" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "outputId": "25d2c468-cbc8-4971-a738-1295ce8c6f16", "tags": [ "skip-execution" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(60000, 28, 28, 1) float32\n", "(60000,) int16\n", "(10000, 28, 28, 1) float32\n", "(10000,) int16\n" ] } ], "source": [ "train, test = get_dataset_tf()\n", "print(train['image'].shape, train['image'].dtype)\n", "print(train['label'].shape, train['label'].dtype)\n", "print(test['image'].shape, test['image'].dtype)\n", "print(test['label'].shape, test['label'].dtype)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading from 🤗 Hugging Face `datasets`" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "skip-execution" ] }, "outputs": [], "source": [ "#!pip install datasets # datasets isn't preinstalled on Colab; uncomment to install\n", "from datasets import load_dataset" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "tags": [ "skip-execution" ] }, "outputs": [], "source": [ "def get_dataset_hf():\n", " mnist = load_dataset(\"mnist\")\n", "\n", " ds = {}\n", "\n", " for split in ['train', 'test']:\n", " ds[split] = {\n", " 'image': np.array([np.array(im) for im in mnist[split]['image']]),\n", " 'label': np.array(mnist[split]['label'])\n", " }\n", "\n", " # cast to jnp and rescale pixel values\n", " ds[split]['image'] = jnp.float32(ds[split]['image']) / 255\n", " ds[split]['label'] = jnp.int16(ds[split]['label'])\n", "\n", " # append trailing channel dimension\n", " ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3)\n", "\n", " return ds['train'], ds['test']" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "outputId": "b026b33f-3bdd-4d26-867c-49400fff1c96", "tags": [ "skip-execution" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(60000, 28, 28, 1) float32\n", "(60000,) int16\n", "(10000, 28, 28, 1) float32\n", "(10000,) int16\n" ] } ], "source": [ "train, test = get_dataset_hf()\n", "print(train['image'].shape, train['image'].dtype)\n", "print(train['label'].shape, train['label'].dtype)\n", "print(test['image'].shape, test['image'].dtype)\n", "print(test['label'].shape, test['label'].dtype)" ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: docs/guides/data_preprocessing/loading_datasets.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # Loading datasets [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/guides/data_preprocessing/loading_datasets.ipynb) A neural net written in Jax+Flax expects its input data as `jax.numpy` array instances. Therefore, loading a dataset from any source is as simple as converting it to `jax.numpy` types and reshaping it to the appropriate dimensions for your network. As an example, this guide demonstrates how to import [MNIST](http://yann.lecun.com/exdb/mnist/) using the APIs from Torchvision, Tensorflow, and Hugging Face. We'll load the whole dataset into memory. For datasets that don't fit into memory the process is analogous but should be done in a batchwise fashion. The MNIST dataset consists of greyscale images of 28x28 pixels of handwritten digits, and has a designated 60k/10k train/test split. The task is to predict the correct class (digit 0, ..., 9) of each image. Assuming a CNN-based classifier, the input data should have shape `(B, 28, 28, 1)`, where the trailing singleton dimension denotes the greyscale image channel. The labels are simply the integer denoting the digit corresponding to the image. Labels should therefore have shape `(B,)`, to enable loss computation with [`optax.softmax_cross_entropy_with_integer_labels`](https://optax.readthedocs.io/en/latest/api.html#optax.softmax_cross_entropy_with_integer_labels). ```{code-cell} ipython3 :tags: [skip-execution] import numpy as np import jax.numpy as jnp ``` ## Loading from `torchvision.datasets` ```{code-cell} ipython3 :tags: [skip-execution] import torchvision ``` ```{code-cell} ipython3 :tags: [skip-execution] def get_dataset_torch(): mnist = { 'train': torchvision.datasets.MNIST('./data', train=True, download=True), 'test': torchvision.datasets.MNIST('./data', train=False, download=True) } ds = {} for split in ['train', 'test']: ds[split] = { 'image': mnist[split].data.numpy(), 'label': mnist[split].targets.numpy() } # cast from np to jnp and rescale the pixel values from [0,255] to [0,1] ds[split]['image'] = jnp.float32(ds[split]['image']) / 255 ds[split]['label'] = jnp.int16(ds[split]['label']) # torchvision returns shape (B, 28, 28). # hence, append the trailing channel dimension. ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3) return ds['train'], ds['test'] ``` ```{code-cell} ipython3 :outputId: be39b756-d13e-4380-b99e-a5cbf61458cc :tags: [skip-execution] train, test = get_dataset_torch() print(train['image'].shape, train['image'].dtype) print(train['label'].shape, train['label'].dtype) print(test['image'].shape, test['image'].dtype) print(test['label'].shape, test['label'].dtype) ``` ## Loading from `tensorflow_datasets` ```{code-cell} ipython3 :tags: [skip-execution] import tensorflow_datasets as tfds ``` ```{code-cell} ipython3 :tags: [skip-execution] def get_dataset_tf(): mnist = tfds.builder('mnist') mnist.download_and_prepare() ds = {} for split in ['train', 'test']: ds[split] = tfds.as_numpy(mnist.as_dataset(split=split, batch_size=-1)) # cast to jnp and rescale pixel values ds[split]['image'] = jnp.float32(ds[split]['image']) / 255 ds[split]['label'] = jnp.int16(ds[split]['label']) return ds['train'], ds['test'] ``` ```{code-cell} ipython3 :outputId: 25d2c468-cbc8-4971-a738-1295ce8c6f16 :tags: [skip-execution] train, test = get_dataset_tf() print(train['image'].shape, train['image'].dtype) print(train['label'].shape, train['label'].dtype) print(test['image'].shape, test['image'].dtype) print(test['label'].shape, test['label'].dtype) ``` ## Loading from 🤗 Hugging Face `datasets` ```{code-cell} ipython3 :tags: [skip-execution] #!pip install datasets # datasets isn't preinstalled on Colab; uncomment to install from datasets import load_dataset ``` ```{code-cell} ipython3 :tags: [skip-execution] def get_dataset_hf(): mnist = load_dataset("mnist") ds = {} for split in ['train', 'test']: ds[split] = { 'image': np.array([np.array(im) for im in mnist[split]['image']]), 'label': np.array(mnist[split]['label']) } # cast to jnp and rescale pixel values ds[split]['image'] = jnp.float32(ds[split]['image']) / 255 ds[split]['label'] = jnp.int16(ds[split]['label']) # append trailing channel dimension ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3) return ds['train'], ds['test'] ``` ```{code-cell} ipython3 :outputId: b026b33f-3bdd-4d26-867c-49400fff1c96 :tags: [skip-execution] train, test = get_dataset_hf() print(train['image'].shape, train['image'].dtype) print(train['label'].shape, train['label'].dtype) print(test['image'].shape, test['image'].dtype) print(test['label'].shape, test['label'].dtype) ``` ================================================ FILE: docs/guides/flax_fundamentals/arguments.md ================================================ # Dealing with Flax Module arguments ## Introduction In Flax Linen we can define `Module` arguments either as dataclass attributes or as arguments to methods (usually `__call__`). Typically the distinction is clear: * Completely fixed properties, such as the choice of kernel initializer or number of output features, are hyperparameters and should be defined as dataclass attributes. Typically two Module instances with different hyperparameters cannot share in a meaningful way. * Dynamic properties, such as input data and top-level "mode switches" like `train=True/False`, should be passed as arguments to `__call__` or another method. Some cases are however less clear cut. Take for example the `Dropout` module. We have a number of clear hyperparameters: 1. The dropout rate 2. The axes for which a dropout mask is generated And some clear call time arguments: 1. The input that should be masked using dropout 2. The (optional) rng used to sample the random mask There is however one property that is ambiguous -- the `deterministic` property in a Dropout module. If `deterministic` is `True` no dropout mask is sampled. This is typically used during model evaluation. However, if we pass `eval=True` or `train=False` to a top-level Module. The `deterministic` argument needs to be applied everywhere and the boolean argument needs to be passed down to all the layers that might use `Dropout`. If instead `deterministic` is a dataclass attribute, we might do the following: ```python from functools import partial from flax import linen as nn class ResidualModel(nn.Module): drop_rate: float @nn.compact def __call__(self, x, *, train): dropout = partial(nn.Dropout, rate=self.drop_rate, deterministic=not train) for i in range(10): x += ResidualBlock(dropout=dropout, ...)(x) ``` It makes sense to pass `determinstic` to the constructor here because this way we can pass the dropout template to the sub-modules. Now the sub-module no longer needs to take care of train vs eval mode and can simply use the `dropout` argument. Note that because the dropout layer can only be constructed in the sub-module we can only partially apply `deterministic` to the constructor but not to `__call__`. However, if `deterministic` is a dataclass attribute we run into trouble when using the setup pattern. We would **want** to write our module code like this: ```python class SomeModule(nn.Module): drop_rate: float def setup(self): self.dropout = nn.Dropout(rate=self.drop_rate) @nn.compact def __call__(self, x, *, train): # ... x = self.dropout(x, deterministic=not train) # ... ``` But, as defined above, `deterministic` would be an attribute, so this doesn't work. Here it makes sense to pass `deterministic` during `__call__` because it depends on the `train` argument. ## Solution We can support both use cases described before by allowing certain properties to be passed as dataclass attributes or as method argument (but not both!). This can be implemented as follows: ```python class MyDropout(nn.Module): drop_rate: float deterministic: Optional[bool] = None @nn.compact def __call__(self, x, deterministic=None): deterministic = nn.merge_param('deterministic', self.deterministic, deterministic) # ... ``` In this example `nn.merge_param` will ensure that either `self.deterministic` or `deterministic` is set but not both. An error is raised if both values are `None` or both values are not `None`. This avoids confusing behavior where 2 different parts of the code set the same parameter and one is overruled by the other. It also avoids a default value which would probably cause either the train step or eval step of a training procedure to be broken by default. ## Functional Core Functional core defines functions rather than classes. Therefore, there is no clear distinction between hyperparameters and call-time arguments. The only way to pre-determine the hyperparameters is by using `partial`. On the upside, there are no ambiguous cases where method arguments could also be attributes. ================================================ FILE: docs/guides/flax_fundamentals/flax_basics.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/guides/flax_fundamentals/flax_basics.ipynb)\n", "[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/guides/flax_fundamentals/flax_basics.ipynb)\n", "\n", "# Flax Basics\n", "\n", "This notebook will walk you through the following workflow:\n", "\n", "* Instantiating a model from Flax built-in layers or third-party models.\n", "* Initializing parameters of the model and manually written training.\n", "* Using optimizers provided by Flax to ease training.\n", "* Serialization of parameters and other objects.\n", "* Creating your own models and managing state." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setting up our environment\n", "\n", "Here we provide the code needed to set up the environment for our notebook." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "outputId": "e30aa464-fa52-4f35-df96-716c68a4b3ee", "tags": [ "skip-execution" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Running pip as root will break packages and permissions. You should install packages reliably by using venv: https://pip.pypa.io/warnings/venv\u001b[0m\n", "\u001b[33mWARNING: Running pip as root will break packages and permissions. You should install packages reliably by using venv: https://pip.pypa.io/warnings/venv\u001b[0m\n" ] } ], "source": [ "# Install the latest JAXlib version.\n", "!pip install --upgrade -q pip jax jaxlib\n", "# Install Flax at head:\n", "!pip install --upgrade -q git+https://github.com/google/flax.git" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import jax\n", "from typing import Any, Callable, Sequence\n", "from jax import random, numpy as jnp\n", "import flax\n", "from flax import linen as nn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Linear regression with Flax\n", "\n", "In the previous *JAX for the impatient* notebook, we finished up with a linear regression example. As we know, linear regression can also be written as a single dense neural network layer, which we will show in the following so that we can compare how it's done.\n", "\n", "A dense layer is a layer that has a kernel parameter $W\\in\\mathcal{M}_{m,n}(\\mathbb{R})$ where $m$ is the number of features as an output of the model, and $n$ the dimensionality of the input, and a bias parameter $b\\in\\mathbb{R}^m$. The dense layers returns $Wx+b$ from an input $x\\in\\mathbb{R}^n$.\n", "\n", "This dense layer is already provided by Flax in the `flax.linen` module (here imported as `nn`)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# We create one dense layer instance (taking 'features' parameter as input)\n", "model = nn.Dense(features=5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Layers (and models in general, we'll use that word from now on) are subclasses of the `linen.Module` class.\n", "\n", "### Model parameters & initialization\n", "\n", "Parameters are not stored with the models themselves. You need to initialize parameters by calling the `init` function, using a PRNGKey and dummy input data." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "outputId": "06feb9d2-db50-4f41-c169-6df4336f43a5" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, { "data": { "text/plain": [ "FrozenDict({\n", " params: {\n", " bias: (5,),\n", " kernel: (10, 5),\n", " },\n", "})" ] }, "execution_count": 4, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "key1, key2 = random.split(random.key(0))\n", "x = random.normal(key1, (10,)) # Dummy input data\n", "params = model.init(key2, x) # Initialization call\n", "jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*Note: JAX and Flax, like NumPy, are row-based systems, meaning that vectors are represented as row vectors and not column vectors. This can be seen in the shape of the kernel here.*\n", "\n", "The result is what we expect: bias and kernel parameters of the correct size. Under the hood:\n", "\n", "* The dummy input data `x` is used to trigger shape inference: we only declared the number of features we wanted in the output of the model, not the size of the input. Flax finds out by itself the correct size of the kernel.\n", "* The random PRNG key is used to trigger the initialization functions (those have default values provided by the module here).\n", "* Initialization functions are called to generate the initial set of parameters that the model will use. Those are functions that take as arguments `(PRNG Key, shape, dtype)` and return an Array of shape `shape`.\n", "* The init function returns the initialized set of parameters (you can also get the output of the forward pass on the dummy input with the same syntax by using the `init_with_output` method instead of `init`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To conduct a forward pass with the model with a given set of parameters (which are never stored with the model), we just use the `apply` method by providing it the parameters to use as well as the input:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "outputId": "7bbe6bb4-94d5-4574-fbb5-aa0fcd1c84ae" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray([-0.7358944, 1.3583755, -0.7976872, 0.8168598, 0.6297793], dtype=float32)" ] }, "execution_count": 6, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "model.apply(params, x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Gradient descent\n", "\n", "If you jumped here directly without going through the JAX part, here is the linear regression formulation we're going to use: from a set of data points $\\{(x_i,y_i), i\\in \\{1,\\ldots, k\\}, x_i\\in\\mathbb{R}^n,y_i\\in\\mathbb{R}^m\\}$, we try to find a set of parameters $W\\in \\mathcal{M}_{m,n}(\\mathbb{R}), b\\in\\mathbb{R}^m$ such that the function $f_{W,b}(x)=Wx+b$ minimizes the mean squared error:\n", "\n", "$$\\mathcal{L}(W,b)\\rightarrow\\frac{1}{k}\\sum_{i=1}^{k} \\frac{1}{2}\\|y_i-f_{W,b}(x_i)\\|^2_2$$\n", "\n", "Here, we see that the tuple $(W,b)$ matches the parameters of the Dense layer. We'll perform gradient descent using those. Let's first generate the fake data we'll use. The data is exactly the same as in the JAX part's linear regression pytree example." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "outputId": "6eae59dc-0632-4f53-eac8-c22a7c646a52" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x shape: (20, 10) ; y shape: (20, 5)\n" ] } ], "source": [ "# Set problem dimensions.\n", "n_samples = 20\n", "x_dim = 10\n", "y_dim = 5\n", "\n", "# Generate random ground truth W and b.\n", "key = random.key(0)\n", "k1, k2 = random.split(key)\n", "W = random.normal(k1, (x_dim, y_dim))\n", "b = random.normal(k2, (y_dim,))\n", "# Store the parameters in a FrozenDict pytree.\n", "true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})\n", "\n", "# Generate samples with additional noise.\n", "key_sample, key_noise = random.split(k1)\n", "x_samples = random.normal(key_sample, (n_samples, x_dim))\n", "y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))\n", "print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We copy the same training loop that we used in the JAX pytree linear regression example with `jax.value_and_grad()`, but here we can use `model.apply()` instead of having to define our own feed-forward function (`predict_pytree()` in the [JAX example](https://flax.readthedocs.io/en/latest/guides/jax_for_the_impatient.html#linear-regression-with-pytrees))." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Same as JAX version but using model.apply().\n", "@jax.jit\n", "def mse(params, x_batched, y_batched):\n", " # Define the squared loss for a single pair (x,y)\n", " def squared_error(x, y):\n", " pred = model.apply(params, x)\n", " return jnp.inner(y-pred, y-pred) / 2.0\n", " # Vectorize the previous to compute the average of the loss on all samples.\n", " return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And finally perform the gradient descent." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "outputId": "50d975b3-4706-4d8a-c4b8-2629ab8e3ac4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss for \"true\" W,b: 0.023639778\n", "Loss step 0: 38.094772\n", "Loss step 10: 0.44692168\n", "Loss step 20: 0.10053458\n", "Loss step 30: 0.035822745\n", "Loss step 40: 0.018846875\n", "Loss step 50: 0.013864839\n", "Loss step 60: 0.012312559\n", "Loss step 70: 0.011812928\n", "Loss step 80: 0.011649306\n", "Loss step 90: 0.011595251\n", "Loss step 100: 0.0115773035\n" ] } ], "source": [ "learning_rate = 0.3 # Gradient step size.\n", "print('Loss for \"true\" W,b: ', mse(true_params, x_samples, y_samples))\n", "loss_grad_fn = jax.value_and_grad(mse)\n", "\n", "@jax.jit\n", "def update_params(params, learning_rate, grads):\n", " params = jax.tree_util.tree_map(\n", " lambda p, g: p - learning_rate * g, params, grads)\n", " return params\n", "\n", "for i in range(101):\n", " # Perform one gradient update.\n", " loss_val, grads = loss_grad_fn(params, x_samples, y_samples)\n", " params = update_params(params, learning_rate, grads)\n", " if i % 10 == 0:\n", " print(f'Loss step {i}: ', loss_val)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Optimizing with Optax\n", "\n", "Flax used to use its own `flax.optim` package for optimization, but with\n", "[FLIP #1009](https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md)\n", "this was deprecated in favor of\n", "[Optax](https://github.com/deepmind/optax).\n", "\n", "Basic usage of Optax is straightforward:\n", "\n", "1. Choose an optimization method (e.g. `optax.adam`).\n", "2. Create optimizer state from parameters (for the Adam optimizer, this state will contain the [momentum values](https://optax.readthedocs.io/en/latest/api.html#optax.adam)).\n", "3. Compute the gradients of your loss with `jax.value_and_grad()`.\n", "4. At every iteration, call the Optax `update` function to update the internal\n", " optimizer state and create an update to the parameters. Then add the update\n", " to the parameters with Optax's `apply_updates` method.\n", "\n", "Note that Optax can do a lot more: it's designed for composing simple gradient\n", "transformations into more complex transformations that allows to implement a\n", "wide range of optimizers. There is also support for changing optimizer\n", "hyperparameters over time (\"schedules\"), applying different updates to different\n", "parts of the parameter tree (\"masking\") and much more. For details please refer\n", "to the\n", "[official documentation](https://optax.readthedocs.io/en/latest/)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "import optax\n", "tx = optax.adam(learning_rate=learning_rate)\n", "opt_state = tx.init(params)\n", "loss_grad_fn = jax.value_and_grad(mse)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "outputId": "eec0c096-1d9e-4b3c-f8e5-942ee63828ec" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss step 0: 0.011576377\n", "Loss step 10: 0.0115710115\n", "Loss step 20: 0.011569244\n", "Loss step 30: 0.011568661\n", "Loss step 40: 0.011568454\n", "Loss step 50: 0.011568379\n", "Loss step 60: 0.011568358\n", "Loss step 70: 0.01156836\n", "Loss step 80: 0.01156835\n", "Loss step 90: 0.011568353\n", "Loss step 100: 0.011568348\n" ] } ], "source": [ "for i in range(101):\n", " loss_val, grads = loss_grad_fn(params, x_samples, y_samples)\n", " updates, opt_state = tx.update(grads, opt_state)\n", " params = optax.apply_updates(params, updates)\n", " if i % 10 == 0:\n", " print('Loss step {}: '.format(i), loss_val)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Serializing the result\n", "\n", "Now that we're happy with the result of our training, we might want to save the model parameters to load them back later. Flax provides a serialization package to enable you to do that." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "outputId": "b97e7d83-3e40-4a80-b1fe-1f6ceff30a0c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dict output\n", "{'params': {'bias': DeviceArray([-1.4540135, -2.0262308, 2.0806582, 1.2201802, -0.9964547], dtype=float32), 'kernel': DeviceArray([[ 1.0106664 , 0.19014716, 0.04533899, -0.92722285,\n", " 0.34720102],\n", " [ 1.7320251 , 0.9901233 , 1.1662225 , 1.1027892 ,\n", " -0.10574618],\n", " [-1.2009128 , 0.28837162, 1.4176372 , 0.12073109,\n", " -1.3132601 ],\n", " [-1.1944956 , -0.18993308, 0.03379077, 1.3165942 ,\n", " 0.07996067],\n", " [ 0.14103189, 1.3737966 , -1.3162128 , 0.53401774,\n", " -2.239638 ],\n", " [ 0.5643044 , 0.813604 , 0.31888172, 0.5359193 ,\n", " 0.90352124],\n", " [-0.37948322, 1.7408353 , 1.0788013 , -0.5041964 ,\n", " 0.9286919 ],\n", " [ 0.9701384 , -1.3158673 , 0.33630812, 0.80941117,\n", " -1.202457 ],\n", " [ 1.0198247 , -0.6198277 , 1.0822718 , -1.8385581 ,\n", " -0.45790705],\n", " [-0.64384323, 0.4564892 , -1.1331053 , -0.68556863,\n", " 0.17010891]], dtype=float32)}}\n", "Bytes output\n", "b'\\x81\\xa6params\\x82\\xa4bias\\xc7!\\x01\\x93\\x91\\x05\\xa7float32\\xc4\\x14\\x1d\\x1d\\xba\\xbf\\xc4\\xad\\x01\\xc0\\x81)\\x05@\\xdd.\\x9c?\\xa8\\x17\\x7f\\xbf\\xa6kernel\\xc7\\xd6\\x01\\x93\\x92\\n\\x05\\xa7float32\\xc4\\xc8\\x84]\\x81?\\xf0\\xb5B>`\\xb59=z^m\\xbfU\\xc4\\xb1>\\x00\\xb3\\xdd?\\xb8x}?\\xc7F\\x95?2(\\x8d?t\\x91\\xd8\\xbd\\x83\\xb7\\x99\\xbfr\\xa5\\x93>#u\\xb5?\\xdcA\\xf7=\\xe8\\x18\\xa8\\xbf;\\xe5\\x98\\xbf\\xd1}B\\xbe0h\\n=)\\x86\\xa8?k\\xc2\\xa3=\\xaaj\\x10>\\x91\\xd8\\xaf?\\xa9y\\xa8\\xbfc\\xb5\\x08?;V\\x0f\\xc0Av\\x10?ZHP?wD\\xa3>\\x022\\t?+Mg?\\xa0K\\xc2\\xbe\\xb1\\xd3\\xde?)\\x16\\x8a?\\x04\\x13\\x01\\xbf\\xc1\\xbem?\\xfdZx?Wn\\xa8\\xbf\\x940\\xac>\\x925O?\\x1c\\xea\\x99\\xbf\\x9e\\x89\\x82?\\x07\\xad\\x1e\\xbf\\xe2\\x87\\x8a?\\xdfU\\xeb\\xbf\\xcbr\\xea\\xbe\\xe9\\xd2$\\xbf\\xf4\\xb8\\xe9>\\x98\\t\\x91\\xbfm\\x81/\\xbf\\x081.>'\n" ] } ], "source": [ "from flax import serialization\n", "bytes_output = serialization.to_bytes(params)\n", "dict_output = serialization.to_state_dict(params)\n", "print('Dict output')\n", "print(dict_output)\n", "print('Bytes output')\n", "print(bytes_output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To load the model back, you'll need to use a template of the model parameter structure, like the one you would get from the model initialization. Here, we use the previously generated `params` as a template. Note that this will produce a new variable structure, and not mutate in-place.\n", "\n", "*The point of enforcing structure through template is to avoid users issues downstream, so you need to first have the right model that generates the parameters structure.*" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "outputId": "13acc4e1-8757-4554-e2c8-d594ba6e67dc" }, "outputs": [ { "data": { "text/plain": [ "FrozenDict({\n", " params: {\n", " bias: array([-1.4540135, -2.0262308, 2.0806582, 1.2201802, -0.9964547],\n", " dtype=float32),\n", " kernel: array([[ 1.0106664 , 0.19014716, 0.04533899, -0.92722285, 0.34720102],\n", " [ 1.7320251 , 0.9901233 , 1.1662225 , 1.1027892 , -0.10574618],\n", " [-1.2009128 , 0.28837162, 1.4176372 , 0.12073109, -1.3132601 ],\n", " [-1.1944956 , -0.18993308, 0.03379077, 1.3165942 , 0.07996067],\n", " [ 0.14103189, 1.3737966 , -1.3162128 , 0.53401774, -2.239638 ],\n", " [ 0.5643044 , 0.813604 , 0.31888172, 0.5359193 , 0.90352124],\n", " [-0.37948322, 1.7408353 , 1.0788013 , -0.5041964 , 0.9286919 ],\n", " [ 0.9701384 , -1.3158673 , 0.33630812, 0.80941117, -1.202457 ],\n", " [ 1.0198247 , -0.6198277 , 1.0822718 , -1.8385581 , -0.45790705],\n", " [-0.64384323, 0.4564892 , -1.1331053 , -0.68556863, 0.17010891]],\n", " dtype=float32),\n", " },\n", "})" ] }, "execution_count": 14, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "serialization.from_bytes(params, bytes_output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Defining your own models\n", "\n", "Flax allows you to define your own models, which should be a bit more complicated than a linear regression. In this section, we'll show you how to build simple models. To do so, you'll need to create subclasses of the base `nn.Module` class.\n", "\n", "*Keep in mind that we imported* `linen as nn` *and this only works with the new linen API*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Module basics\n", "\n", "The base abstraction for models is the `nn.Module` class, and every type of predefined layers in Flax (like the previous `Dense`) is a subclass of `nn.Module`. Let's take a look and start by defining a simple but custom multi-layer perceptron i.e. a sequence of Dense layers interleaved with calls to a non-linear activation function." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "outputId": "b59c679c-d164-4fd6-92db-b50f0d310ec3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameter shapes:\n", " {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n", "output:\n", " [[ 4.2292815e-02 -4.3807115e-02 2.9323792e-02 6.5492536e-03\n", " -1.7147182e-02]\n", " [ 1.2967806e-01 -1.4551792e-01 9.4432183e-02 1.2521387e-02\n", " -4.5417298e-02]\n", " [ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n", " 0.0000000e+00]\n", " [ 9.3024032e-04 2.7864395e-05 2.4478821e-04 8.1344310e-04\n", " -1.0110770e-03]]\n" ] } ], "source": [ "class ExplicitMLP(nn.Module):\n", " features: Sequence[int]\n", "\n", " def setup(self):\n", " # we automatically know what to do with lists, dicts of submodules\n", " self.layers = [nn.Dense(feat) for feat in self.features]\n", " # for single submodules, we would just write:\n", " # self.layer1 = nn.Dense(feat1)\n", "\n", " def __call__(self, inputs):\n", " x = inputs\n", " for i, lyr in enumerate(self.layers):\n", " x = lyr(x)\n", " if i != len(self.layers) - 1:\n", " x = nn.relu(x)\n", " return x\n", "\n", "key1, key2 = random.split(random.key(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = ExplicitMLP(features=[3,4,5])\n", "params = model.init(key2, x)\n", "y = model.apply(params, x)\n", "\n", "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))\n", "print('output:\\n', y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we can see, a `nn.Module` subclass is made of:\n", "\n", "* A collection of data fields (`nn.Module` are Python dataclasses) - here we only have the `features` field of type `Sequence[int]`.\n", "* A `setup()` method that is being called at the end of the `__postinit__` where you can register submodules, variables, parameters you will need in your model.\n", "* A `__call__` function that returns the output of the model from a given input.\n", "* The model structure defines a pytree of parameters following the same tree structure as the model: the params tree contains one `layers_n` sub dict per layer, and each of those contain the parameters of the associated Dense layer. The layout is very explicit.\n", "\n", "*Note: lists are mostly managed as you would expect (WIP), there are corner cases you should be aware of as pointed out* [here](https://github.com/google/flax/issues/524)\n", "\n", "Since the module structure and its parameters are not tied to each other, you can't directly call `model(x)` on a given input as it will return an error. The `__call__` function is being wrapped up in the `apply` one, which is the one to call on an input:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "outputId": "4af16ec5-b52a-43b0-fc47-1f8ab25e7058" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\"ExplicitMLP\" object has no attribute \"layers\"\n" ] } ], "source": [ "try:\n", " y = model(x) # Returns an error\n", "except AttributeError as e:\n", " print(e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since here we have a very simple model, we could have used an alternative (but equivalent) way of declaring the submodules inline in the `__call__` using the `@nn.compact` annotation like so:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "outputId": "183a74ef-f54e-4848-99bf-fee4c174ba6d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameter shapes:\n", " {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n", "output:\n", " [[ 4.2292815e-02 -4.3807115e-02 2.9323792e-02 6.5492536e-03\n", " -1.7147182e-02]\n", " [ 1.2967806e-01 -1.4551792e-01 9.4432183e-02 1.2521387e-02\n", " -4.5417298e-02]\n", " [ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n", " 0.0000000e+00]\n", " [ 9.3024032e-04 2.7864395e-05 2.4478821e-04 8.1344310e-04\n", " -1.0110770e-03]]\n" ] } ], "source": [ "class SimpleMLP(nn.Module):\n", " features: Sequence[int]\n", "\n", " @nn.compact\n", " def __call__(self, inputs):\n", " x = inputs\n", " for i, feat in enumerate(self.features):\n", " x = nn.Dense(feat, name=f'layers_{i}')(x)\n", " if i != len(self.features) - 1:\n", " x = nn.relu(x)\n", " # providing a name is optional though!\n", " # the default autonames would be \"Dense_0\", \"Dense_1\", ...\n", " return x\n", "\n", "key1, key2 = random.split(random.key(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = SimpleMLP(features=[3,4,5])\n", "params = model.init(key2, x)\n", "y = model.apply(params, x)\n", "\n", "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))\n", "print('output:\\n', y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are, however, a few differences you should be aware of between the two declaration modes:\n", "\n", "* In `setup`, you are able to name some sublayers and keep them around for further use (e.g. encoder/decoder methods in autoencoders).\n", "* If you want to have multiple methods, then you **need** to declare the module using `setup`, as the `@nn.compact` annotation only allows one method to be annotated.\n", "* The last initialization will be handled differently. See these notes for more details (TODO: add notes link)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Module parameters\n", "\n", "In the previous MLP example, we relied only on predefined layers and operators (`Dense`, `relu`). Let's imagine that you didn't have a Dense layer provided by Flax and you wanted to write it on your own. Here is what it would look like using the `@nn.compact` way to declare a new modules:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "outputId": "83b5fea4-071e-4ea0-8fa8-610e69fb5fd5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameters:\n", " FrozenDict({\n", " params: {\n", " kernel: DeviceArray([[ 0.6503669 , 0.86789787, 0.4604268 ],\n", " [ 0.05673932, 0.9909285 , -0.63536596],\n", " [ 0.76134115, -0.3250529 , -0.65221626],\n", " [-0.82430327, 0.4150194 , 0.19405058]], dtype=float32),\n", " bias: DeviceArray([0., 0., 0.], dtype=float32),\n", " },\n", "})\n", "output:\n", " [[ 0.5035518 1.8548558 -0.4270195 ]\n", " [ 0.0279097 0.5589246 -0.43061772]\n", " [ 0.3547128 1.5740999 -0.32865518]\n", " [ 0.5264864 1.2928858 0.10089308]]\n" ] } ], "source": [ "class SimpleDense(nn.Module):\n", " features: int\n", " kernel_init: Callable = nn.initializers.lecun_normal()\n", " bias_init: Callable = nn.initializers.zeros_init()\n", "\n", " @nn.compact\n", " def __call__(self, inputs):\n", " kernel = self.param('kernel',\n", " self.kernel_init, # Initialization function\n", " (inputs.shape[-1], self.features)) # shape info.\n", " y = jnp.dot(inputs, kernel)\n", " bias = self.param('bias', self.bias_init, (self.features,))\n", " y = y + bias\n", " return y\n", "\n", "key1, key2 = random.split(random.key(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = SimpleDense(features=3)\n", "params = model.init(key2, x)\n", "y = model.apply(params, x)\n", "\n", "print('initialized parameters:\\n', params)\n", "print('output:\\n', y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here, we see how to both declare and assign a parameter to the model using the `self.param` method. It takes as input `(name, init_fn, *init_args, **init_kwargs)` :\n", "\n", "* `name` is simply the name of the parameter that will end up in the parameter structure.\n", "* `init_fn` is a function with input `(PRNGKey, *init_args, **init_kwargs)` returning an Array, with `init_args` and `init_kwargs` being the arguments needed to call the initialisation function.\n", "* `init_args` and `init_kwargs` are the arguments to provide to the initialization function.\n", "\n", "Such params can also be declared in the `setup` method; it won't be able to use shape inference because Flax is using lazy initialization at the first call site." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Variables and collections of variables\n", "\n", "As we've seen so far, working with models means working with:\n", "\n", "* A subclass of `nn.Module`;\n", "* A pytree of parameters for the model (typically from `model.init()`);\n", "\n", "However this is not enough to cover everything that we would need for machine learning, especially neural networks. In some cases, you might want your neural network to keep track of some internal state while it runs (e.g. batch normalization layers). There is a way to declare variables beyond the parameters of the model with the `variable` method.\n", "\n", "For demonstration purposes, we'll implement a simplified but similar mechanism to batch normalization: we'll store running averages and subtract those to the input at training time. For proper batchnorm, you should use (and look at) the implementation [here](https://github.com/google/flax/blob/main/flax/linen/normalization.py)." ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "outputId": "75465fd6-cdc8-497c-a3ec-7f709b5dde7a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized variables:\n", " FrozenDict({\n", " batch_stats: {\n", " mean: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),\n", " },\n", " params: {\n", " bias: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),\n", " },\n", "})\n", "updated state:\n", " FrozenDict({\n", " batch_stats: {\n", " mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),\n", " },\n", "})\n" ] } ], "source": [ "class BiasAdderWithRunningMean(nn.Module):\n", " decay: float = 0.99\n", "\n", " @nn.compact\n", " def __call__(self, x):\n", " # easy pattern to detect if we're initializing via empty variable tree\n", " is_initialized = self.has_variable('batch_stats', 'mean')\n", " ra_mean = self.variable('batch_stats', 'mean',\n", " lambda s: jnp.zeros(s),\n", " x.shape[1:])\n", " bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])\n", " if is_initialized:\n", " ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)\n", "\n", " return x - ra_mean.value + bias\n", "\n", "\n", "key1, key2 = random.split(random.key(0), 2)\n", "x = jnp.ones((10,5))\n", "model = BiasAdderWithRunningMean()\n", "variables = model.init(key1, x)\n", "print('initialized variables:\\n', variables)\n", "y, updated_state = model.apply(variables, x, mutable=['batch_stats'])\n", "print('updated state:\\n', updated_state)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here, `updated_state` returns only the state variables that are being mutated by the model while applying it on data. To update the variables and get the new parameters of the model, we can use the following pattern:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "outputId": "09a8bdd1-eaf8-401a-cf7c-386a7a5aa87b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "updated state:\n", " FrozenDict({\n", " batch_stats: {\n", " mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),\n", " },\n", "})\n", "updated state:\n", " FrozenDict({\n", " batch_stats: {\n", " mean: DeviceArray([[0.0299, 0.0299, 0.0299, 0.0299, 0.0299]], dtype=float32),\n", " },\n", "})\n", "updated state:\n", " FrozenDict({\n", " batch_stats: {\n", " mean: DeviceArray([[0.059601, 0.059601, 0.059601, 0.059601, 0.059601]], dtype=float32),\n", " },\n", "})\n" ] } ], "source": [ "for val in [1.0, 2.0, 3.0]:\n", " x = val * jnp.ones((10,5))\n", " y, updated_state = model.apply(variables, x, mutable=['batch_stats'])\n", " old_state, params = flax.core.pop(variables, 'params')\n", " variables = flax.core.freeze({'params': params, **updated_state})\n", " print('updated state:\\n', updated_state) # Shows only the mutable part" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From this simplified example, you should be able to derive a full BatchNorm implementation, or any layer involving a state. To finish, let's add an optimizer to see how to play with both parameters updated by an optimizer and state variables.\n", "\n", "*This example isn't doing anything and is only for demonstration purposes.*" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "outputId": "0906fbab-b866-4956-d231-b1374415d448" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Updated state: FrozenDict({\n", " batch_stats: {\n", " mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),\n", " },\n", "})\n", "Updated state: FrozenDict({\n", " batch_stats: {\n", " mean: DeviceArray([[0.0199, 0.0199, 0.0199, 0.0199, 0.0199]], dtype=float32),\n", " },\n", "})\n", "Updated state: FrozenDict({\n", " batch_stats: {\n", " mean: DeviceArray([[0.029701, 0.029701, 0.029701, 0.029701, 0.029701]], dtype=float32),\n", " },\n", "})\n" ] } ], "source": [ "from functools import partial\n", "\n", "@partial(jax.jit, static_argnums=(0, 1))\n", "def update_step(tx, apply_fn, x, opt_state, params, state):\n", "\n", " def loss(params):\n", " y, updated_state = apply_fn({'params': params, **state},\n", " x, mutable=list(state.keys()))\n", " l = ((x - y) ** 2).sum()\n", " return l, updated_state\n", "\n", " (l, state), grads = jax.value_and_grad(loss, has_aux=True)(params)\n", " updates, opt_state = tx.update(grads, opt_state)\n", " params = optax.apply_updates(params, updates)\n", " return opt_state, params, state\n", "\n", "x = jnp.ones((10,5))\n", "variables = model.init(random.key(0), x)\n", "state, params = flax.core.pop(variables, 'params')\n", "del variables\n", "tx = optax.sgd(learning_rate=0.02)\n", "opt_state = tx.init(params)\n", "\n", "for _ in range(3):\n", " opt_state, params, state = update_step(tx, model.apply, x, opt_state, params, state)\n", " print('Updated state: ', state)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the above function has a quite verbose signature and it would not actually\n", "work with `jax.jit()` because the function arguments are not \"valid JAX types\".\n", "\n", "Flax provides a handy wrapper - `TrainState` - that simplifies the above code. Check out [`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) to learn more." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Exporting to Tensorflow's SavedModel with jax2tf\n", "\n", "JAX released an experimental converter called [jax2tf](https://github.com/jax-ml/jax/tree/main/jax/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax." ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "language_info": { "name": "python", "version": "3.8.15" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: docs/guides/flax_fundamentals/flax_basics.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/guides/flax_fundamentals/flax_basics.ipynb) [![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/guides/flax_fundamentals/flax_basics.ipynb) # Flax Basics This notebook will walk you through the following workflow: * Instantiating a model from Flax built-in layers or third-party models. * Initializing parameters of the model and manually written training. * Using optimizers provided by Flax to ease training. * Serialization of parameters and other objects. * Creating your own models and managing state. +++ ## Setting up our environment Here we provide the code needed to set up the environment for our notebook. ```{code-cell} :outputId: e30aa464-fa52-4f35-df96-716c68a4b3ee :tags: [skip-execution] # Install the latest JAXlib version. !pip install --upgrade -q pip jax jaxlib # Install Flax at head: !pip install --upgrade -q git+https://github.com/google/flax.git ``` ```{code-cell} import jax from typing import Any, Callable, Sequence from jax import random, numpy as jnp import flax from flax import linen as nn ``` ## Linear regression with Flax In the previous *JAX for the impatient* notebook, we finished up with a linear regression example. As we know, linear regression can also be written as a single dense neural network layer, which we will show in the following so that we can compare how it's done. A dense layer is a layer that has a kernel parameter $W\in\mathcal{M}_{m,n}(\mathbb{R})$ where $m$ is the number of features as an output of the model, and $n$ the dimensionality of the input, and a bias parameter $b\in\mathbb{R}^m$. The dense layers returns $Wx+b$ from an input $x\in\mathbb{R}^n$. This dense layer is already provided by Flax in the `flax.linen` module (here imported as `nn`). ```{code-cell} # We create one dense layer instance (taking 'features' parameter as input) model = nn.Dense(features=5) ``` Layers (and models in general, we'll use that word from now on) are subclasses of the `linen.Module` class. ### Model parameters & initialization Parameters are not stored with the models themselves. You need to initialize parameters by calling the `init` function, using a PRNGKey and dummy input data. ```{code-cell} :outputId: 06feb9d2-db50-4f41-c169-6df4336f43a5 key1, key2 = random.split(random.key(0)) x = random.normal(key1, (10,)) # Dummy input data params = model.init(key2, x) # Initialization call jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes ``` *Note: JAX and Flax, like NumPy, are row-based systems, meaning that vectors are represented as row vectors and not column vectors. This can be seen in the shape of the kernel here.* The result is what we expect: bias and kernel parameters of the correct size. Under the hood: * The dummy input data `x` is used to trigger shape inference: we only declared the number of features we wanted in the output of the model, not the size of the input. Flax finds out by itself the correct size of the kernel. * The random PRNG key is used to trigger the initialization functions (those have default values provided by the module here). * Initialization functions are called to generate the initial set of parameters that the model will use. Those are functions that take as arguments `(PRNG Key, shape, dtype)` and return an Array of shape `shape`. * The init function returns the initialized set of parameters (you can also get the output of the forward pass on the dummy input with the same syntax by using the `init_with_output` method instead of `init`. +++ To conduct a forward pass with the model with a given set of parameters (which are never stored with the model), we just use the `apply` method by providing it the parameters to use as well as the input: ```{code-cell} :outputId: 7bbe6bb4-94d5-4574-fbb5-aa0fcd1c84ae model.apply(params, x) ``` ### Gradient descent If you jumped here directly without going through the JAX part, here is the linear regression formulation we're going to use: from a set of data points $\{(x_i,y_i), i\in \{1,\ldots, k\}, x_i\in\mathbb{R}^n,y_i\in\mathbb{R}^m\}$, we try to find a set of parameters $W\in \mathcal{M}_{m,n}(\mathbb{R}), b\in\mathbb{R}^m$ such that the function $f_{W,b}(x)=Wx+b$ minimizes the mean squared error: $$\mathcal{L}(W,b)\rightarrow\frac{1}{k}\sum_{i=1}^{k} \frac{1}{2}\|y_i-f_{W,b}(x_i)\|^2_2$$ Here, we see that the tuple $(W,b)$ matches the parameters of the Dense layer. We'll perform gradient descent using those. Let's first generate the fake data we'll use. The data is exactly the same as in the JAX part's linear regression pytree example. ```{code-cell} :outputId: 6eae59dc-0632-4f53-eac8-c22a7c646a52 # Set problem dimensions. n_samples = 20 x_dim = 10 y_dim = 5 # Generate random ground truth W and b. key = random.key(0) k1, k2 = random.split(key) W = random.normal(k1, (x_dim, y_dim)) b = random.normal(k2, (y_dim,)) # Store the parameters in a FrozenDict pytree. true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}}) # Generate samples with additional noise. key_sample, key_noise = random.split(k1) x_samples = random.normal(key_sample, (n_samples, x_dim)) y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim)) print('x shape:', x_samples.shape, '; y shape:', y_samples.shape) ``` We copy the same training loop that we used in the JAX pytree linear regression example with `jax.value_and_grad()`, but here we can use `model.apply()` instead of having to define our own feed-forward function (`predict_pytree()` in the [JAX example](https://flax.readthedocs.io/en/latest/guides/jax_for_the_impatient.html#linear-regression-with-pytrees)). ```{code-cell} # Same as JAX version but using model.apply(). @jax.jit def mse(params, x_batched, y_batched): # Define the squared loss for a single pair (x,y) def squared_error(x, y): pred = model.apply(params, x) return jnp.inner(y-pred, y-pred) / 2.0 # Vectorize the previous to compute the average of the loss on all samples. return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0) ``` And finally perform the gradient descent. ```{code-cell} :outputId: 50d975b3-4706-4d8a-c4b8-2629ab8e3ac4 learning_rate = 0.3 # Gradient step size. print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples)) loss_grad_fn = jax.value_and_grad(mse) @jax.jit def update_params(params, learning_rate, grads): params = jax.tree_util.tree_map( lambda p, g: p - learning_rate * g, params, grads) return params for i in range(101): # Perform one gradient update. loss_val, grads = loss_grad_fn(params, x_samples, y_samples) params = update_params(params, learning_rate, grads) if i % 10 == 0: print(f'Loss step {i}: ', loss_val) ``` ### Optimizing with Optax Flax used to use its own `flax.optim` package for optimization, but with [FLIP #1009](https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md) this was deprecated in favor of [Optax](https://github.com/deepmind/optax). Basic usage of Optax is straightforward: 1. Choose an optimization method (e.g. `optax.adam`). 2. Create optimizer state from parameters (for the Adam optimizer, this state will contain the [momentum values](https://optax.readthedocs.io/en/latest/api.html#optax.adam)). 3. Compute the gradients of your loss with `jax.value_and_grad()`. 4. At every iteration, call the Optax `update` function to update the internal optimizer state and create an update to the parameters. Then add the update to the parameters with Optax's `apply_updates` method. Note that Optax can do a lot more: it's designed for composing simple gradient transformations into more complex transformations that allows to implement a wide range of optimizers. There is also support for changing optimizer hyperparameters over time ("schedules"), applying different updates to different parts of the parameter tree ("masking") and much more. For details please refer to the [official documentation](https://optax.readthedocs.io/en/latest/). ```{code-cell} import optax tx = optax.adam(learning_rate=learning_rate) opt_state = tx.init(params) loss_grad_fn = jax.value_and_grad(mse) ``` ```{code-cell} :outputId: eec0c096-1d9e-4b3c-f8e5-942ee63828ec for i in range(101): loss_val, grads = loss_grad_fn(params, x_samples, y_samples) updates, opt_state = tx.update(grads, opt_state) params = optax.apply_updates(params, updates) if i % 10 == 0: print('Loss step {}: '.format(i), loss_val) ``` ### Serializing the result Now that we're happy with the result of our training, we might want to save the model parameters to load them back later. Flax provides a serialization package to enable you to do that. ```{code-cell} :outputId: b97e7d83-3e40-4a80-b1fe-1f6ceff30a0c from flax import serialization bytes_output = serialization.to_bytes(params) dict_output = serialization.to_state_dict(params) print('Dict output') print(dict_output) print('Bytes output') print(bytes_output) ``` To load the model back, you'll need to use a template of the model parameter structure, like the one you would get from the model initialization. Here, we use the previously generated `params` as a template. Note that this will produce a new variable structure, and not mutate in-place. *The point of enforcing structure through template is to avoid users issues downstream, so you need to first have the right model that generates the parameters structure.* ```{code-cell} :outputId: 13acc4e1-8757-4554-e2c8-d594ba6e67dc serialization.from_bytes(params, bytes_output) ``` ## Defining your own models Flax allows you to define your own models, which should be a bit more complicated than a linear regression. In this section, we'll show you how to build simple models. To do so, you'll need to create subclasses of the base `nn.Module` class. *Keep in mind that we imported* `linen as nn` *and this only works with the new linen API* +++ ### Module basics The base abstraction for models is the `nn.Module` class, and every type of predefined layers in Flax (like the previous `Dense`) is a subclass of `nn.Module`. Let's take a look and start by defining a simple but custom multi-layer perceptron i.e. a sequence of Dense layers interleaved with calls to a non-linear activation function. ```{code-cell} :outputId: b59c679c-d164-4fd6-92db-b50f0d310ec3 class ExplicitMLP(nn.Module): features: Sequence[int] def setup(self): # we automatically know what to do with lists, dicts of submodules self.layers = [nn.Dense(feat) for feat in self.features] # for single submodules, we would just write: # self.layer1 = nn.Dense(feat1) def __call__(self, inputs): x = inputs for i, lyr in enumerate(self.layers): x = lyr(x) if i != len(self.layers) - 1: x = nn.relu(x) return x key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (4,4)) model = ExplicitMLP(features=[3,4,5]) params = model.init(key2, x) y = model.apply(params, x) print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params))) print('output:\n', y) ``` As we can see, a `nn.Module` subclass is made of: * A collection of data fields (`nn.Module` are Python dataclasses) - here we only have the `features` field of type `Sequence[int]`. * A `setup()` method that is being called at the end of the `__postinit__` where you can register submodules, variables, parameters you will need in your model. * A `__call__` function that returns the output of the model from a given input. * The model structure defines a pytree of parameters following the same tree structure as the model: the params tree contains one `layers_n` sub dict per layer, and each of those contain the parameters of the associated Dense layer. The layout is very explicit. *Note: lists are mostly managed as you would expect (WIP), there are corner cases you should be aware of as pointed out* [here](https://github.com/google/flax/issues/524) Since the module structure and its parameters are not tied to each other, you can't directly call `model(x)` on a given input as it will return an error. The `__call__` function is being wrapped up in the `apply` one, which is the one to call on an input: ```{code-cell} :outputId: 4af16ec5-b52a-43b0-fc47-1f8ab25e7058 try: y = model(x) # Returns an error except AttributeError as e: print(e) ``` Since here we have a very simple model, we could have used an alternative (but equivalent) way of declaring the submodules inline in the `__call__` using the `@nn.compact` annotation like so: ```{code-cell} :outputId: 183a74ef-f54e-4848-99bf-fee4c174ba6d class SimpleMLP(nn.Module): features: Sequence[int] @nn.compact def __call__(self, inputs): x = inputs for i, feat in enumerate(self.features): x = nn.Dense(feat, name=f'layers_{i}')(x) if i != len(self.features) - 1: x = nn.relu(x) # providing a name is optional though! # the default autonames would be "Dense_0", "Dense_1", ... return x key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (4,4)) model = SimpleMLP(features=[3,4,5]) params = model.init(key2, x) y = model.apply(params, x) print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params))) print('output:\n', y) ``` There are, however, a few differences you should be aware of between the two declaration modes: * In `setup`, you are able to name some sublayers and keep them around for further use (e.g. encoder/decoder methods in autoencoders). * If you want to have multiple methods, then you **need** to declare the module using `setup`, as the `@nn.compact` annotation only allows one method to be annotated. * The last initialization will be handled differently. See these notes for more details (TODO: add notes link). +++ ### Module parameters In the previous MLP example, we relied only on predefined layers and operators (`Dense`, `relu`). Let's imagine that you didn't have a Dense layer provided by Flax and you wanted to write it on your own. Here is what it would look like using the `@nn.compact` way to declare a new modules: ```{code-cell} :outputId: 83b5fea4-071e-4ea0-8fa8-610e69fb5fd5 class SimpleDense(nn.Module): features: int kernel_init: Callable = nn.initializers.lecun_normal() bias_init: Callable = nn.initializers.zeros_init() @nn.compact def __call__(self, inputs): kernel = self.param('kernel', self.kernel_init, # Initialization function (inputs.shape[-1], self.features)) # shape info. y = jnp.dot(inputs, kernel) bias = self.param('bias', self.bias_init, (self.features,)) y = y + bias return y key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (4,4)) model = SimpleDense(features=3) params = model.init(key2, x) y = model.apply(params, x) print('initialized parameters:\n', params) print('output:\n', y) ``` Here, we see how to both declare and assign a parameter to the model using the `self.param` method. It takes as input `(name, init_fn, *init_args, **init_kwargs)` : * `name` is simply the name of the parameter that will end up in the parameter structure. * `init_fn` is a function with input `(PRNGKey, *init_args, **init_kwargs)` returning an Array, with `init_args` and `init_kwargs` being the arguments needed to call the initialisation function. * `init_args` and `init_kwargs` are the arguments to provide to the initialization function. Such params can also be declared in the `setup` method; it won't be able to use shape inference because Flax is using lazy initialization at the first call site. +++ ### Variables and collections of variables As we've seen so far, working with models means working with: * A subclass of `nn.Module`; * A pytree of parameters for the model (typically from `model.init()`); However this is not enough to cover everything that we would need for machine learning, especially neural networks. In some cases, you might want your neural network to keep track of some internal state while it runs (e.g. batch normalization layers). There is a way to declare variables beyond the parameters of the model with the `variable` method. For demonstration purposes, we'll implement a simplified but similar mechanism to batch normalization: we'll store running averages and subtract those to the input at training time. For proper batchnorm, you should use (and look at) the implementation [here](https://github.com/google/flax/blob/main/flax/linen/normalization.py). ```{code-cell} :outputId: 75465fd6-cdc8-497c-a3ec-7f709b5dde7a class BiasAdderWithRunningMean(nn.Module): decay: float = 0.99 @nn.compact def __call__(self, x): # easy pattern to detect if we're initializing via empty variable tree is_initialized = self.has_variable('batch_stats', 'mean') ra_mean = self.variable('batch_stats', 'mean', lambda s: jnp.zeros(s), x.shape[1:]) bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:]) if is_initialized: ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True) return x - ra_mean.value + bias key1, key2 = random.split(random.key(0), 2) x = jnp.ones((10,5)) model = BiasAdderWithRunningMean() variables = model.init(key1, x) print('initialized variables:\n', variables) y, updated_state = model.apply(variables, x, mutable=['batch_stats']) print('updated state:\n', updated_state) ``` Here, `updated_state` returns only the state variables that are being mutated by the model while applying it on data. To update the variables and get the new parameters of the model, we can use the following pattern: ```{code-cell} :outputId: 09a8bdd1-eaf8-401a-cf7c-386a7a5aa87b for val in [1.0, 2.0, 3.0]: x = val * jnp.ones((10,5)) y, updated_state = model.apply(variables, x, mutable=['batch_stats']) old_state, params = flax.core.pop(variables, 'params') variables = flax.core.freeze({'params': params, **updated_state}) print('updated state:\n', updated_state) # Shows only the mutable part ``` From this simplified example, you should be able to derive a full BatchNorm implementation, or any layer involving a state. To finish, let's add an optimizer to see how to play with both parameters updated by an optimizer and state variables. *This example isn't doing anything and is only for demonstration purposes.* ```{code-cell} :outputId: 0906fbab-b866-4956-d231-b1374415d448 from functools import partial @partial(jax.jit, static_argnums=(0, 1)) def update_step(tx, apply_fn, x, opt_state, params, state): def loss(params): y, updated_state = apply_fn({'params': params, **state}, x, mutable=list(state.keys())) l = ((x - y) ** 2).sum() return l, updated_state (l, state), grads = jax.value_and_grad(loss, has_aux=True)(params) updates, opt_state = tx.update(grads, opt_state) params = optax.apply_updates(params, updates) return opt_state, params, state x = jnp.ones((10,5)) variables = model.init(random.key(0), x) state, params = flax.core.pop(variables, 'params') del variables tx = optax.sgd(learning_rate=0.02) opt_state = tx.init(params) for _ in range(3): opt_state, params, state = update_step(tx, model.apply, x, opt_state, params, state) print('Updated state: ', state) ``` Note that the above function has a quite verbose signature and it would not actually work with `jax.jit()` because the function arguments are not "valid JAX types". Flax provides a handy wrapper - `TrainState` - that simplifies the above code. Check out [`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) to learn more. +++ ### Exporting to Tensorflow's SavedModel with jax2tf JAX released an experimental converter called [jax2tf](https://github.com/jax-ml/jax/tree/main/jax/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax. ================================================ FILE: docs/guides/flax_fundamentals/index.rst ================================================ Flax fundamentals ================= .. toctree:: :maxdepth: 1 JAX 101 flax_basics state_params setup_or_nncompact arguments rng_guide ================================================ FILE: docs/guides/flax_fundamentals/rng_guide.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Randomness and PRNGs in Flax" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this guide, you will learn how Flax uses [JAX's explicit pseudorandom number generator (PRNG) keys](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#jax-prng) to emulate randomness, and adds some additional features to make it easier for users to thread PRNG keys through different Flax `Module`s.\n", "\n", "If you are new to JAX PRNG keys or need a refresher, check out:\n", "- [JAX 101: PRNGs in JAX](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html)\n", "- [🔪 JAX - The Sharp Bits 🔪: Random Numbers](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Install or upgrade Flax, and then import some necessary dependencies.\n", "\n", "**Note:** This guide uses the `--xla_force_host_platform_device_count=8` flag to emulate multiple devices in a CPU environment in a Google Colab/Jupyter Notebook. You don’t need this if you are already using a multi-device Google Cloud TPU environment, for example, on Google Cloud or in a Kaggle VM with a TPU." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "tags": [ "skip-execution" ] }, "outputs": [], "source": [ "!pip install -q flax" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ[\"XLA_FLAGS\"] = '--xla_force_host_platform_device_count=8'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import flax, flax.linen as nn\n", "import jax, jax.numpy as jnp\n", "from jax.sharding import Mesh, PartitionSpec, NamedSharding\n", "from jax.experimental import mesh_utils\n", "from jax.experimental.shard_map import shard_map\n", "\n", "import hashlib" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "outputId": "ec904f6b-0e87-4efe-87c4-fea0f8e8ec23" }, "outputs": [ { "data": { "text/plain": [ "[CpuDevice(id=0),\n", " CpuDevice(id=1),\n", " CpuDevice(id=2),\n", " CpuDevice(id=3),\n", " CpuDevice(id=4),\n", " CpuDevice(id=5),\n", " CpuDevice(id=6),\n", " CpuDevice(id=7)]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jax.devices()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Set the JAX config variable `jax_threefry_partitionable` to `True`. This will be the default value in the future and makes the PRNG more efficiently auto-parallelizable under `jax.jit`. Refer to [JAX discussion](https://github.com/jax-ml/jax/discussions/18480) for more details." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "jax.config.update('jax_threefry_partitionable', True)\n", "assert jax.config.jax_threefry_partitionable == True\n", "assert jax.config.jax_default_prng_impl == 'threefry2x32'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Receiving, manipulating and creating PRNG keys with `Module.make_rng`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The primary method Flax uses to receive, manipulate and create PRNG keys is via the `Module` method [`self.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.make_rng). It is a method that accepts a string name that represents an \"RNG stream\". Each RNG stream has an initial starting seed PRNG key, which the user passes in as a dictionary argument (i.e. into an [`.init`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init) or [`.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply) function), and the starting seed is used by `self.make_rng` to generate more PRNG keys for that stream. If `self.make_rng` is called on a string name that does not have an initial starting seed PRNG key (i.e. the user did not pass a key with the corresponding name into `.init` or `.apply`), then `self.make_rng` will use the `'params'` key as the initial starting seed by default.\n", "\n", "Note that this method can only be called with bounded modules (see [The Flax Module lifecycle](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html#top-level-modules))." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "outputId": "2a16435b-e92a-480a-f9fb-e6effc42c4c2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Array((), dtype=key) overlaying:\n", "[1428664606 3351135085]\n", "Array((), dtype=key) overlaying:\n", "[3456700291 3873160899]\n", "Array((), dtype=key) overlaying:\n", "[2411773124 4124888837]\n" ] } ], "source": [ "class RNGModule(nn.Module):\n", " @nn.compact\n", " def __call__(self):\n", " print(self.make_rng('rng_stream'))\n", " print(self.make_rng('rng_stream'))\n", " print(self.make_rng('rng_stream'))\n", "\n", "rng_module = RNGModule()\n", "variables = rng_module.init({'rng_stream': jax.random.key(0)})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now if we use a different starting seed PRNG key, we will generate different values (as intended)." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "outputId": "985b0f62-dfde-4f0f-fad4-a31927fc9f59" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Array((), dtype=key) overlaying:\n", "[3077990774 2166202870]\n", "Array((), dtype=key) overlaying:\n", "[3825832496 2886313970]\n", "Array((), dtype=key) overlaying:\n", "[ 791337683 1373966058]\n" ] } ], "source": [ "variables = rng_module.init({'rng_stream': jax.random.key(1)})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Calling `self.make_rng` for one stream will not affect the random values generated from another stream; i.e. the call order doesn't matter." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "outputId": "7e8ce538-e380-4db9-db23-bc4a8da577da" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "rng_stream1: Array((), dtype=key) overlaying:\n", "[1428664606 3351135085]\n", "rng_stream2: Array((), dtype=key) overlaying:\n", "[3077990774 2166202870]\n", "rng_stream1: Array((), dtype=key) overlaying:\n", "[3456700291 3873160899]\n", "rng_stream2: Array((), dtype=key) overlaying:\n", "[3825832496 2886313970]\n", "rng_stream1: Array((), dtype=key) overlaying:\n", "[2411773124 4124888837]\n", "rng_stream2: Array((), dtype=key) overlaying:\n", "[ 791337683 1373966058]\n" ] } ], "source": [ "class RNGModuleTwoStreams(nn.Module):\n", " @nn.compact\n", " def __call__(self):\n", " # same value as first code snippet above\n", " print(f\"rng_stream1: {self.make_rng('rng_stream1')}\")\n", " # same value as second code snippet above\n", " print(f\"rng_stream2: {self.make_rng('rng_stream2')}\")\n", " # same value as first code snippet above\n", " print(f\"rng_stream1: {self.make_rng('rng_stream1')}\")\n", " # same value as second code snippet above\n", " print(f\"rng_stream2: {self.make_rng('rng_stream2')}\")\n", " # same value as first code snippet above\n", " print(f\"rng_stream1: {self.make_rng('rng_stream1')}\")\n", " # same value as second code snippet above\n", " print(f\"rng_stream2: {self.make_rng('rng_stream2')}\")\n", "\n", "rng_module_two_streams = RNGModuleTwoStreams()\n", "variables = rng_module_two_streams.init(\n", " {'rng_stream1': jax.random.key(0), 'rng_stream2': jax.random.key(1)}\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Providing the same seed PRNG key will result in the same values being generated (provided that the same operations are used for those keys)." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "outputId": "b70be039-589a-48f7-dc54-65e78c449c65" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "rng_stream1: Array((), dtype=key) overlaying:\n", "[1428664606 3351135085]\n", "rng_stream2: Array((), dtype=key) overlaying:\n", "[1428664606 3351135085]\n", "rng_stream1: Array((), dtype=key) overlaying:\n", "[3456700291 3873160899]\n", "rng_stream2: Array((), dtype=key) overlaying:\n", "[3456700291 3873160899]\n", "rng_stream1: Array((), dtype=key) overlaying:\n", "[2411773124 4124888837]\n", "rng_stream2: Array((), dtype=key) overlaying:\n", "[2411773124 4124888837]\n" ] } ], "source": [ "variables = rng_module_two_streams.init(\n", " {'rng_stream1': jax.random.key(0), 'rng_stream2': jax.random.key(0)}\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### How `self.make_rng` works under the hood" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is what happens when `self.make_rng` (`flax.linen.Module.make_rng`) is called:\n", "* The following data is collected:\n", " * The path of the `Module` as provided by `self.scope.path` (the top-level root module has an empty path `()`).\n", " * The `self.make_rng` call count. That is, the number of times `self.make_rng` has been called for this specific stream (including this call).\n", " * **Note:** Each sub-`Module` will have its own individual call count that's separate from other `Module`s. For example, a `Module` that has called `self.make_rng('params')` twice and contains a sub-`Module` that has called `self.make_rng('params')` once, will have a call count of 2 and 1 for each of the RNG stream `'params'`, respectively.\n", "* The data is bundled into a tuple and fed into a hash function and produces an integer.\n", "* The generated integer is folded into the RNG stream's starting seed PRNG key to generate a new, unique PRNG key." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below is a slightly simplified version of the hash function that Flax uses for `self.make_rng`:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def produce_hash(data):\n", " m = hashlib.sha1()\n", " for x in data:\n", " if isinstance(x, str):\n", " m.update(x.encode('utf-8'))\n", " elif isinstance(x, int):\n", " m.update(x.to_bytes((x.bit_length() + 7) // 8, byteorder='big'))\n", " else:\n", " raise ValueError(f'Expected int or string, got: {x}')\n", " d = m.digest()\n", " hash_int = int.from_bytes(d[:4], byteorder='big')\n", " return hash_int" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And now you can manually reproduce the PRNG keys generated from `self.make_rng`:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "outputId": "d26b7355-9e8b-4954-b2f4-cf7520d5c5a3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Array((), dtype=key) overlaying:\n", "[1428664606 3351135085]\n", "Array((), dtype=key) overlaying:\n", "[3456700291 3873160899]\n", "Array((), dtype=key) overlaying:\n", "[2411773124 4124888837]\n" ] } ], "source": [ "stream_seed = jax.random.key(0)\n", "for call_count in range(1, 4):\n", " hash_int = produce_hash(data=(call_count,))\n", " print(jax.random.fold_in(stream_seed, jnp.uint32(hash_int)))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "outputId": "dec627a6-4c5a-4e3e-ce11-ce4f72775261" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Array((), dtype=key) overlaying:\n", "[1428664606 3351135085]\n", "Array((), dtype=key) overlaying:\n", "[3456700291 3873160899]\n", "Array((), dtype=key) overlaying:\n", "[2411773124 4124888837]\n" ] } ], "source": [ "variables = rng_module.init({'rng_stream': jax.random.key(0)})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sub-`Module`s and `self.make_rng`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This section explores how `self.make_rng` (`flax.linen.Module.make_rng`) behaves with sub-`Module`s.\n", "\n", "Consider the following example:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "outputId": "5b7a9ae9-ca49-4ac0-d007-5caeee739ff0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RNGModule, count 1: Array((), dtype=key) overlaying:\n", "[1428664606 3351135085]\n", "RNGModule, count 2: Array((), dtype=key) overlaying:\n", "[3456700291 3873160899]\n", "RNGSubModule_0, count 1: Array((), dtype=key) overlaying:\n", "[3858825717 2323087578]\n", "RNGSubModule_0, count 2: Array((), dtype=key) overlaying:\n", "[ 601859108 3782857444]\n", "RNGSubSubModule_0, count 1: Array((), dtype=key) overlaying:\n", "[ 234240654 1028548813]\n", "RNGSubSubModule_0, count 2: Array((), dtype=key) overlaying:\n", "[3650462303 2124609379]\n" ] } ], "source": [ "class RNGSubSubModule(nn.Module):\n", " def __call__(self):\n", " print(f\"{self.name}, count 1: {self.make_rng('rng_stream')}\")\n", " print(f\"{self.name}, count 2: {self.make_rng('rng_stream')}\")\n", "\n", "class RNGSubModule(nn.Module):\n", " @nn.compact\n", " def __call__(self):\n", " print(f\"{self.name}, count 1: {self.make_rng('rng_stream')}\")\n", " print(f\"{self.name}, count 2: {self.make_rng('rng_stream')}\")\n", " RNGSubSubModule()()\n", "\n", "class RNGModule(nn.Module):\n", " @nn.compact\n", " def __call__(self):\n", " print(f\"RNGModule, count 1: {self.make_rng('rng_stream')}\")\n", " print(f\"RNGModule, count 2: {self.make_rng('rng_stream')}\")\n", " RNGSubModule()()\n", "\n", "rng_module = RNGModule()\n", "variables = rng_module.init({'rng_stream': jax.random.key(0)})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As previously discussed, the data that is fed into the Flax hash function consists of:\n", "\n", " * The path of the `Module`, provided by `self.scope.path` (the top-level root module has an empty path `()`); and\n", " * The call count for the specific RNG stream.\n", "\n", "In addition, note that each Flax `Module` and sub-`Module` have their own individual call counts, even for the same RNG stream. The convention for sub-`Module` names is: `f'{module_name}_{module_number}'`. For example, the first `Dense` sub-`Module` will be called `Dense_0`, the second one will be called `Dense_1`, and so on.\n", "\n", "Therefore, the following data will be fed into the hash function:\n", "\n", " * For `RNGModule`: The data is just the call count, such as `(1,)` and `(2,)`, since the root `Module` has an empty path.\n", " * For `RNGSubModule`: The data is `('RNGSubModule_0', 1)` and `('RNGSubModule_0', 2)`.\n", " * For `RNGSubSubModule`: The data is `('RNGSubModule_0', 'RNGSubSubModule_0', 1)` and `('RNGSubModule_0', 'RNGSubSubModule_0', 2)`.\n", "\n", "With this data, you can manually reproduce the PRNG keys generated from the `Module` and sub-`Module`s using `self.make_rng`.\n", "\n", "For example:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "outputId": "c0de4d37-0f00-4e58-bdfd-e8a6454ed681" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RNGModule, count 1: Array((), dtype=key) overlaying:\n", "[1428664606 3351135085]\n", "RNGModule, count 2: Array((), dtype=key) overlaying:\n", "[3456700291 3873160899]\n", "RNGSubModule_0, count 1: Array((), dtype=key) overlaying:\n", "[3858825717 2323087578]\n", "RNGSubModule_0, count 2: Array((), dtype=key) overlaying:\n", "[ 601859108 3782857444]\n", "RNGSubSubModule_0, count 1: Array((), dtype=key) overlaying:\n", "[ 234240654 1028548813]\n", "RNGSubSubModule_0, count 2: Array((), dtype=key) overlaying:\n", "[3650462303 2124609379]\n" ] } ], "source": [ "stream_seed = jax.random.key(0)\n", "for initial_data in ((), ('RNGSubModule_0',), ('RNGSubModule_0', 'RNGSubSubModule_0')):\n", " if initial_data:\n", " module_name = initial_data[-1]\n", " else:\n", " module_name = 'RNGModule'\n", " for call_count in (1, 2):\n", " hash_int = produce_hash(data=initial_data+(call_count,))\n", " rng_key = jax.random.fold_in(stream_seed, jnp.uint32(hash_int))\n", " print(f\"{module_name}, count {call_count}: {rng_key}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If the same sub-`Module` class is used multiple times, you can increment the suffix of the sub-`Module` name accordingly. For example: `RNGSubModule_0`, `RNGSubModule_1`, and so on." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "outputId": "0b77a038-7000-407b-c5b8-a28dea7951d1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RNGModule, count 1: Array((), dtype=key) overlaying:\n", "[1428664606 3351135085]\n", "RNGModule, count 2: Array((), dtype=key) overlaying:\n", "[3456700291 3873160899]\n", "RNGSubModule_0, count 1: Array((), dtype=key) overlaying:\n", "[3858825717 2323087578]\n", "RNGSubModule_0, count 2: Array((), dtype=key) overlaying:\n", "[ 601859108 3782857444]\n", "RNGSubModule_1, count 1: Array((), dtype=key) overlaying:\n", "[ 426957352 2006350344]\n", "RNGSubModule_1, count 2: Array((), dtype=key) overlaying:\n", "[4006253729 4205356731]\n" ] } ], "source": [ "class RNGSubModule(nn.Module):\n", " @nn.compact\n", " def __call__(self):\n", " print(f\"{self.name}, count 1: {self.make_rng('rng_stream')}\")\n", " print(f\"{self.name}, count 2: {self.make_rng('rng_stream')}\")\n", "\n", "class RNGModule(nn.Module):\n", " @nn.compact\n", " def __call__(self):\n", " print(f\"RNGModule, count 1: {self.make_rng('rng_stream')}\")\n", " print(f\"RNGModule, count 2: {self.make_rng('rng_stream')}\")\n", " RNGSubModule()()\n", " RNGSubModule()()\n", "\n", "rng_module = RNGModule()\n", "variables = rng_module.init({'rng_stream': jax.random.key(0)})" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "outputId": "d189d25e-425d-4fd7-fe18-2dfd63f28b87" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RNGModule, count 1: Array((), dtype=key) overlaying:\n", "[1428664606 3351135085]\n", "RNGModule, count 2: Array((), dtype=key) overlaying:\n", "[3456700291 3873160899]\n", "RNGSubModule_0, count 1: Array((), dtype=key) overlaying:\n", "[3858825717 2323087578]\n", "RNGSubModule_0, count 2: Array((), dtype=key) overlaying:\n", "[ 601859108 3782857444]\n", "RNGSubModule_1, count 1: Array((), dtype=key) overlaying:\n", "[ 426957352 2006350344]\n", "RNGSubModule_1, count 2: Array((), dtype=key) overlaying:\n", "[4006253729 4205356731]\n" ] } ], "source": [ "stream_seed = jax.random.key(0)\n", "for initial_data in ((), ('RNGSubModule_0',), ('RNGSubModule_1',)):\n", " if initial_data:\n", " module_name = initial_data[-1]\n", " else:\n", " module_name = 'RNGModule'\n", " for call_count in (1, 2):\n", " hash_int = produce_hash(data=initial_data+(call_count,))\n", " rng_key = jax.random.fold_in(stream_seed, jnp.uint32(hash_int))\n", " print(f\"{module_name}, count {call_count}: {rng_key}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using `self.param` and `self.variable`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Flax users have the option of creating additional parameters and variables in their modules by using the [`self.param`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.param) and [`self.variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.variable) `Module` methods. An `init_fn` argument must be passed to these methods so that it can generate the initial value of the parameter/variable. `self.make_rng` is commonly used implicitly or explicitly in this `init_fn`, since many initializer functions are stochastic in nature and require a PRNG key. See the full list of Flax initializers [here](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/initializers.html).\n", "\n", "There are a couple of differences between the two methods that the user should take note of:\n", "* `self.param` always creates a parameter in the `'params'` [collection](https://flax.readthedocs.io/en/latest/glossary.html#term-Params-parameters), whereas `self.variable` creates a variable in any [collection](https://flax.readthedocs.io/en/latest/glossary.html#term-Variable-collections) the user specifies\n", "* `self.param` will automatically call `self.make_rng('params')` and pass in the generated PRNG key implicitly to the `init_fn` of the parameter you instantiated (it will be passed in as the first argument), whereas users will have to manually specify what RNG stream to call `self.make_rng` on in the `init_fn` of `self.variable` (it could be `'params'` or something different).\n", "\n", "Below is an example using both `self.param` and `self.variable`:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "outputId": "a7816385-0e08-48e2-dc51-055d7bcd0bab" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[-1.6185919 0.700908 ]\n", " [-1.3146383 -0.79342234]]\n", "[[ 0.0761425 -1.6157459]\n", " [-1.6857724 0.7126891]]\n", "[[ 0.60175574 0.2553228 ]\n", " [ 0.27367848 -2.1975214 ]]\n", "[[1.6249592 0.30813068]\n", " [1.6613585 1.0404155 ]]\n", "[[ 0.0030665 0.29551846]\n", " [ 0.16670242 -0.78252524]]\n", "[1.582462 0.15216611]\n" ] } ], "source": [ "class Model(nn.Module):\n", " @nn.compact\n", " def __call__(self, x):\n", " # kernel will use 'params' seed, initial data will include 'Dense_0', call count 1\n", " x = nn.Dense(2, kernel_init=jax.random.normal, use_bias=False)(x)\n", " # model_param will use 'params' seed, call count 1\n", " model_param = self.param('model_param', jax.random.normal, x.shape)\n", " # model_variable1 will use 'params' seed, call count 2\n", " model_variable1 = self.variable(\n", " 'other_collection',\n", " 'model_variable1',\n", " lambda: jax.random.normal(self.make_rng('params'), x.shape),\n", " )\n", " # model_variable2 will use 'other' seed, call count 1\n", " model_variable2 = self.variable(\n", " 'other_collection',\n", " 'model_variable2',\n", " lambda: jax.random.normal(self.make_rng('other'), x.shape),\n", " )\n", " # kernel will use 'params' seed, initial data will include 'Dense_1', call count 1\n", " # bias will use 'params' seed, initial data will include 'Dense_1', call count 2\n", " x = nn.Dense(2, kernel_init=jax.random.normal, bias_init=jax.random.normal)(\n", " x\n", " )\n", " return x\n", "\n", "model = Model()\n", "variables = model.init(\n", " {'params': jax.random.key(0), 'other': jax.random.key(1)}, jnp.ones((2, 2))\n", ")\n", "print(variables['params']['Dense_0']['kernel'])\n", "print(variables['params']['model_param'])\n", "print(variables['other_collection']['model_variable1'])\n", "print(variables['other_collection']['model_variable2'])\n", "print(variables['params']['Dense_1']['kernel'])\n", "print(variables['params']['Dense_1']['bias'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Remember:\n", "* there is a separate count for each RNG stream; this is why the count for `self.make_rng('other')` starts at 1 even though there were earlier calls of `self.make_rng('params')`\n", "* each submodule has their own separate count for each rng stream; this is why each `Dense` layer has their own separate count for `self.make_rng('params')` and why `model_param` and `model_variable1` share the same count (since they are defined within the same top-level parent module)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "outputId": "ccec9d64-9a27-47f7-adaf-b36a5ea655db" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[-1.6185919 0.700908 ]\n", " [-1.3146383 -0.79342234]]\n", "[[ 0.0761425 -1.6157459]\n", " [-1.6857724 0.7126891]]\n", "[[ 0.60175574 0.2553228 ]\n", " [ 0.27367848 -2.1975214 ]]\n", "[[1.6249592 0.30813068]\n", " [1.6613585 1.0404155 ]]\n", "[[ 0.0030665 0.29551846]\n", " [ 0.16670242 -0.78252524]]\n", "[[1.582462 0.15216611]]\n" ] } ], "source": [ "params_seed = jax.random.key(0)\n", "other_seed = jax.random.key(1)\n", "for initial_data, count, seed, shape in (\n", " (('Dense_0',), 1, params_seed, (2, 2)),\n", " ((), 1, params_seed, (2, 2)),\n", " ((), 2, params_seed, (2, 2)),\n", " ((), 1, other_seed, (2, 2)),\n", " (('Dense_1',), 1, params_seed, (2, 2)),\n", " (('Dense_1',), 2, params_seed, (1, 2)),\n", "):\n", " hash_int = produce_hash(data=(*initial_data, count))\n", " rng_key = jax.random.fold_in(seed, jnp.uint32(hash_int))\n", " print(jax.random.normal(rng_key, shape))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Managing RNG streams inside a training loop" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below is an example of managing RNG streams from `self.make_rng`, `self.param`, `self.variable` and `nn.Dropout` in a training loop (note: `nn.Dropout` requires a seed PRNG key to be passed in the `'dropout'` RNG stream, since it implicitly calls `self.make_rng('dropout')`):" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "class SubModule(nn.Module):\n", " @nn.compact\n", " def __call__(self, x, train):\n", " # variables created using `self.param` will use `self.make_rng('params')`\n", " kernel = self.param('submodule_kernel', jax.random.normal, x.shape)\n", " x = x + kernel\n", " # `nn.Dropout` will use self.make_rng('dropout')\n", " x = nn.Dropout(0.2)(x, deterministic=not train)\n", " # `nn.Dense` will use self.make_rng('params')\n", " x = nn.Dense(3)(x)\n", " return x\n", "\n", "class Model(nn.Module):\n", " @nn.compact\n", " def __call__(self, x, train):\n", " # make kernel use `self.make_rng('other')`\n", " kernel = self.variable(\n", " 'other_collection',\n", " 'module_kernel',\n", " lambda: jax.random.normal(self.make_rng('other'), x.shape),\n", " )\n", " x = (\n", " x + kernel.value\n", " ) # `.value` will extract the underlying value of the variable\n", " x = SubModule()(x, train)\n", " # `nn.Dropout` will use self.make_rng('dropout')\n", " x = nn.Dropout(0.2)(x, deterministic=not train)\n", " # `nn.Dense` will use self.make_rng('params')\n", " x = nn.Dense(2)(x)\n", " return x\n", "\n", "params_rng, other_rng, train_rng = jax.random.split(jax.random.key(0), 3)\n", "init_rngs = {'params': params_rng, 'other': other_rng}\n", "\n", "x = jnp.ones((1, 3))\n", "y = jnp.ones((1, 2))\n", "\n", "module = Model()\n", "variables = module.init(init_rngs, x, train=False)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "outputId": "e9da8228-acba-403d-bcb5-33a39d4d530d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.518454\n", "2.4859657\n", "2.4171872\n", "2.412684\n", "2.3435805\n", "2.2773488\n", "2.2592616\n", "2.2009292\n", "2.1839895\n", "2.1707344\n" ] } ], "source": [ "def update(variables, rng):\n", " # we don't need to provide a 'params' or 'other' rng, as only 'dropout' rng will be used during training\n", " # split the rng to get a dropout_rng to be used for this training iteration,\n", " # and to get another rng key to be used for the next training iteration\n", " dropout_rng, next_rng = jax.random.split(rng)\n", " def loss(params):\n", " out = module.apply(\n", " {'params': params, 'other_collection': variables['other_collection']},\n", " x,\n", " train=True,\n", " rngs={'dropout': dropout_rng},\n", " )\n", " return jnp.mean((y - out) ** 2)\n", " grads = jax.grad(loss)(variables['params'])\n", " params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, variables['params'], grads)\n", " return {\n", " 'params': params,\n", " 'other_collection': variables['other_collection'],\n", " }, next_rng\n", "\n", "for _ in range(10):\n", " variables, train_rng = update(variables, train_rng)\n", " out = module.apply(variables, x, train=False)\n", " print(jnp.mean((y - out)**2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 🔪 Sharp edge 🔪 - unintentionally generating the same values\n", "\n", "There is an edge case where the same value can be unintentionally generated.\n", "See the [Flax issue](https://github.com/google/flax/issues/2157) for more details." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "outputId": "887142ff-c9ca-4aae-d9fa-cc9993d809c5" }, "outputs": [ { "data": { "text/plain": [ "Array(True, dtype=bool)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class Leaf(nn.Module):\n", " def __call__(self, x):\n", " return x + jax.random.randint(self.make_rng(\"rng\"), (), 0, 100)\n", "\n", "class Node(nn.Module):\n", " leaf_name: str\n", " @nn.compact\n", " def __call__(self, x):\n", " return Leaf(name=self.leaf_name)(x)\n", "\n", "class Model(nn.Module):\n", " @nn.compact\n", " def __call__(self, x):\n", " return (Node(name=\"ab\", leaf_name=\"cdef\")(x),\n", " Node(name=\"abc\", leaf_name=\"def\")(x),\n", " )\n", "\n", "out1, out2 = Model().apply({}, 0, rngs={\"rng\": jax.random.key(33)})\n", "out1 == out2 # same output, despite having different submodule names" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This occurs because the hash function [concatenates strings together](https://docs.python.org/3/library/hashlib.html#hashlib.hash.update), so the data `('AB', 'C')` is equivalent to data `('A', 'BC')` when fed into the hash function, therefore producing the same hash int." ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "outputId": "001cbd49-129b-4474-c6a1-3255a4ee3dfe" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "947574064\n", "947574064\n", "947574064\n", "947574064\n" ] } ], "source": [ "print(produce_hash(data=('A', 'B', 'C', 1)))\n", "print(produce_hash(data=('AB', 'C', 1)))\n", "print(produce_hash(data=('A', 'BC', 1)))\n", "print(produce_hash(data=('ABC', 1)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To avoid this edge case, users can flip the `flax_fix_rng_separator` [configuration flag](https://flax.readthedocs.io/en/latest/api_reference/flax.config.html#flax.configurations.Config.flax_fix_rng_separator) to `True`." ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "outputId": "35a2b204-bdfd-4f83-8e98-ba723963cb0c" }, "outputs": [ { "data": { "text/plain": [ "Array(False, dtype=bool)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "flax.config.update('flax_fix_rng_separator', True)\n", "out1, out2 = Model().apply({}, 0, rngs={\"rng\": jax.random.key(33)})\n", "out1 == out2 # different output" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Managing RNG's on multiple devices" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This section will show examples on how to use `jit` and `shard_map` to use RNG's in multi-device settings." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using `jax.jit`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When using [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html), we can use RNG's as we did before, but we now include `in_shardings` and `out_shardings` arguments to specify how to shard input and output data. The RNG key itself gets replicated (not sharded); `jax.jit` makes each device use it as appropriate for its shard of the data.\n", "\n", "For more details on training on multiple devices in Flax using `jax.jit`, see our [Scale up Flax Modules on multiple devices guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html#) and [lm1b example](https://github.com/google/flax/tree/main/examples/lm1b)." ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "outputId": "6c280522-4b43-4b82-f40a-b73986659b2c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)\n", " CpuDevice(id=4) CpuDevice(id=5) CpuDevice(id=6) CpuDevice(id=7)]\n", "Mesh(device_ids=array([0, 1, 2, 3, 4, 5, 6, 7]), axis_names=('data',))\n", "NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec('data',))\n" ] } ], "source": [ "# Create a mesh and annotate the axis with a name.\n", "device_mesh = mesh_utils.create_device_mesh((8,))\n", "print(device_mesh)\n", "\n", "mesh = Mesh(devices=device_mesh, axis_names=('data',))\n", "print(mesh)\n", "\n", "data_sharding = NamedSharding(mesh, PartitionSpec('data',))\n", "print(data_sharding)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "outputId": "d1bfbcad-e28a-4fae-8136-98bd1efb9332" }, "outputs": [ { "data": { "text/plain": [ "Array([[-2.2187614 ],\n", " [-2.8055234 ],\n", " [-2.5464187 ],\n", " [ 1.0270392 ],\n", " [-3.5243359 ],\n", " [-2.2795477 ],\n", " [-0.6504516 ],\n", " [ 0.17373264]], dtype=float32)" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class Model(nn.Module):\n", " @nn.compact\n", " def __call__(self, x, add_noise):\n", " x = nn.Dense(1)(x)\n", " # use jnp.where for control flow; for more details see: https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError\n", " return jnp.where(\n", " add_noise, x + jax.random.normal(self.make_rng('params'), x.shape), x\n", " )\n", "\n", "module = Model()\n", "init_rng, apply_rng = jax.random.split(jax.random.key(0))\n", "x = jnp.ones((8, 1))\n", "variables = module.init(init_rng, x, False)\n", "\n", "# create custom forward function, since jit does not support kwargs when in_shardings is specified\n", "def forward(variables, x, add_noise, rng):\n", " return module.apply(variables, x, add_noise, rngs={'params': rng})\n", "\n", "# shard the inputs x across devices\n", "# replicate the variables, add_noise boolean and rng key across devices\n", "# shard the output across devices\n", "jit_forward = jax.jit(\n", " forward,\n", " in_shardings=(None, data_sharding, None, None),\n", " out_shardings=data_sharding,\n", ")\n", "out = jit_forward(variables, x, True, apply_rng)\n", "out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The output is different given the same input, meaning the RNG key was used to add noise to the output.\n", "\n", "We can also confirm that the output is sharded across devices:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "outputId": "b672b85f-7a2d-44b5-afc1-bbf9426655ed" }, "outputs": [ { "data": { "text/plain": [ "[Shard(device=CpuDevice(id=0), index=(slice(0, 1, None), slice(None, None, None)), replica_id=0, data=[[-2.2187614]]),\n", " Shard(device=CpuDevice(id=1), index=(slice(1, 2, None), slice(None, None, None)), replica_id=0, data=[[-2.8055234]]),\n", " Shard(device=CpuDevice(id=2), index=(slice(2, 3, None), slice(None, None, None)), replica_id=0, data=[[-2.5464187]]),\n", " Shard(device=CpuDevice(id=3), index=(slice(3, 4, None), slice(None, None, None)), replica_id=0, data=[[1.0270392]]),\n", " Shard(device=CpuDevice(id=4), index=(slice(4, 5, None), slice(None, None, None)), replica_id=0, data=[[-3.5243359]]),\n", " Shard(device=CpuDevice(id=5), index=(slice(5, 6, None), slice(None, None, None)), replica_id=0, data=[[-2.2795477]]),\n", " Shard(device=CpuDevice(id=6), index=(slice(6, 7, None), slice(None, None, None)), replica_id=0, data=[[-0.6504516]]),\n", " Shard(device=CpuDevice(id=7), index=(slice(7, 8, None), slice(None, None, None)), replica_id=0, data=[[0.17373264]])]" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out.addressable_shards" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another way to visualize the output sharding:" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "outputId": "1c0a16ce-fa3f-4b95-d794-58464bbaa9ae" }, "outputs": [ { "data": { "text/html": [ "
  CPU 0  \n",
       "         \n",
       "  CPU 1  \n",
       "         \n",
       "  CPU 2  \n",
       "         \n",
       "  CPU 3  \n",
       "         \n",
       "  CPU 4  \n",
       "         \n",
       "  CPU 5  \n",
       "         \n",
       "  CPU 6  \n",
       "         \n",
       "  CPU 7  \n",
       "         \n",
       "
\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mCPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mCPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mCPU 4\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mCPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mCPU 6\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mCPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "jax.debug.visualize_array_sharding(out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we choose not to add noise, then the output is the same across all batches (as expected, since the input is the same for all batches):" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "outputId": "fe9ec875-3e7f-4861-babc-f07064737276" }, "outputs": [ { "data": { "text/plain": [ "Array([[-1.2839764],\n", " [-1.2839764],\n", " [-1.2839764],\n", " [-1.2839764],\n", " [-1.2839764],\n", " [-1.2839764],\n", " [-1.2839764],\n", " [-1.2839764]], dtype=float32)" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out = jit_forward(variables, x, False, apply_rng)\n", "out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can confirm the un-jitted function produces the same values, albeit unsharded (note there may be small numerical differences due to compiler optimizations from jitting):" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "outputId": "0a9e5f2c-d4bf-4051-bf71-f32a9c32dc06" }, "outputs": [ { "data": { "text/plain": [ "Array([[-2.2187614 ],\n", " [-2.8055234 ],\n", " [-2.5464187 ],\n", " [ 1.0270392 ],\n", " [-3.5243359 ],\n", " [-2.2795477 ],\n", " [-0.6504516 ],\n", " [ 0.17373264]], dtype=float32)" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out = forward(variables, x, True, apply_rng)\n", "out" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "outputId": "772a5063-1bd5-46b4-f6f6-cae9b4b81a26" }, "outputs": [ { "data": { "text/plain": [ "Array([[-1.2839764],\n", " [-1.2839764],\n", " [-1.2839764],\n", " [-1.2839764],\n", " [-1.2839764],\n", " [-1.2839764],\n", " [-1.2839764],\n", " [-1.2839764]], dtype=float32)" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out = forward(variables, x, False, apply_rng)\n", "out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using `shard_map`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When using [`jax.experimental.shard_map.shard_map`](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html), the important parts to remember are to:\n", "* split your PRNG key to produce a different key for each device\n", "* the PRNG keys will be sharded automatically to each device (provided you use the correct partition specification), but the [**rank of the original batched PRNG key array will not be reduced**](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#rank-reducing-vs-rank-preserving-maps-over-array-axes); e.g.\n", "with a batch of 8 PRNG keys and 8 devices, each device will see a PRNG key batch of size 1 within the `shard_map`-ed function\n", " * therefore to access the PRNG key itself, we need to index slice into it (see the example below)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "outputId": "aa00b9a3-24ba-4048-ed8c-afbb9070f039" }, "outputs": [ { "data": { "text/plain": [ "Array([[-1.2605132 ],\n", " [-1.2405176 ],\n", " [-0.99350417],\n", " [-1.0277128 ],\n", " [-1.4154483 ],\n", " [-0.3905797 ],\n", " [-2.417677 ],\n", " [ 0.9023453 ]], dtype=float32)" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def forward(variables, x, add_noise, rng_key_batch):\n", " # rng_key_batch is a batch of size 1 containing 1 PRNG key\n", " # index slice into the rng_key_batch to access the PRNG key\n", " return module.apply(\n", " variables, x, add_noise, rngs={'params': rng_key_batch[0]}\n", " )\n", "\n", "# define partition specifications\n", "data_pspec = PartitionSpec('data')\n", "no_pspec = PartitionSpec()\n", "\n", "# shard the inputs x and rng keys across devices\n", "# replicate the variables and add_noise boolean across devices\n", "# shard the output across devices\n", "shmap_forward = shard_map(\n", " forward,\n", " mesh=mesh,\n", " in_specs=(no_pspec, data_pspec, no_pspec, data_pspec),\n", " out_specs=data_pspec,\n", ")\n", "# get 8 different rng's that will be used by the 8 devices when doing forward inference\n", "apply_rngs = jax.random.split(apply_rng, 8)\n", "out = shmap_forward(variables, x, True, apply_rngs)\n", "out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Confirm that the output is sharded across devices:" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "outputId": "e304289b-ef1c-4e4a-d4c1-4c41613bfa62" }, "outputs": [ { "data": { "text/plain": [ "[Shard(device=CpuDevice(id=0), index=(slice(0, 1, None), slice(None, None, None)), replica_id=0, data=[[-1.2605132]]),\n", " Shard(device=CpuDevice(id=1), index=(slice(1, 2, None), slice(None, None, None)), replica_id=0, data=[[-1.2405176]]),\n", " Shard(device=CpuDevice(id=2), index=(slice(2, 3, None), slice(None, None, None)), replica_id=0, data=[[-0.99350417]]),\n", " Shard(device=CpuDevice(id=3), index=(slice(3, 4, None), slice(None, None, None)), replica_id=0, data=[[-1.0277128]]),\n", " Shard(device=CpuDevice(id=4), index=(slice(4, 5, None), slice(None, None, None)), replica_id=0, data=[[-1.4154483]]),\n", " Shard(device=CpuDevice(id=5), index=(slice(5, 6, None), slice(None, None, None)), replica_id=0, data=[[-0.3905797]]),\n", " Shard(device=CpuDevice(id=6), index=(slice(6, 7, None), slice(None, None, None)), replica_id=0, data=[[-2.417677]]),\n", " Shard(device=CpuDevice(id=7), index=(slice(7, 8, None), slice(None, None, None)), replica_id=0, data=[[0.9023453]])]" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out.addressable_shards" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "outputId": "52fdb6d2-4c4f-44b3-feee-4bc5363c8f2f" }, "outputs": [ { "data": { "text/html": [ "
  CPU 0  \n",
       "         \n",
       "  CPU 1  \n",
       "         \n",
       "  CPU 2  \n",
       "         \n",
       "  CPU 3  \n",
       "         \n",
       "  CPU 4  \n",
       "         \n",
       "  CPU 5  \n",
       "         \n",
       "  CPU 6  \n",
       "         \n",
       "  CPU 7  \n",
       "         \n",
       "
\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mCPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mCPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mCPU 4\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mCPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mCPU 6\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mCPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "jax.debug.visualize_array_sharding(out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Lifted transforms" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Flax lifted transforms](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html) allow you to use [JAX transforms](https://github.com/jax-ml/jax#transformations) with `Module` arguments. This section will show you how to control how PRNG keys are split in Flax lifted transforms.\n", "\n", "Refer to [Lifted transformations](https://flax.readthedocs.io/en/latest/developer_notes/lift.html) for more detail." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `nn.vmap`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can use [`nn.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.vmap.html) to create a batched `Dense` layer:" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "outputId": "f0830f6b-659c-446f-c933-7b2a430f8004" }, "outputs": [ { "data": { "text/plain": [ "{'params': {'bias': Array([0., 0.], dtype=float32),\n", " 'kernel': Array([[-1.2488099 , -0.6127134 ],\n", " [-0.07084481, 0.60130936]], dtype=float32)}}" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = jnp.ones((3, 2))\n", "\n", "BatchDense = nn.vmap(\n", " nn.Dense,\n", " in_axes=0, out_axes=0,\n", " variable_axes={'params': None},\n", " split_rngs={'params': False})\n", "\n", "BatchDense(2).init(jax.random.key(0), x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By denoting `variable_axes={'params': 0}'`, we vectorize the `params` Arrays on the first axis. However the parameter values generated are all identical to each other:" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "outputId": "eef5c0ca-f8d5-4f25-8ce6-9f2f60622daf" }, "outputs": [ { "data": { "text/plain": [ "{'params': {'bias': Array([[0., 0.],\n", " [0., 0.],\n", " [0., 0.]], dtype=float32),\n", " 'kernel': Array([[[-1.2488099 , -0.6127134 ],\n", " [-0.07084481, 0.60130936]],\n", " \n", " [[-1.2488099 , -0.6127134 ],\n", " [-0.07084481, 0.60130936]],\n", " \n", " [[-1.2488099 , -0.6127134 ],\n", " [-0.07084481, 0.60130936]]], dtype=float32)}}" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "BatchDense = nn.vmap(\n", " nn.Dense,\n", " in_axes=0, out_axes=0,\n", " variable_axes={'params': 0},\n", " split_rngs={'params': False})\n", "\n", "BatchDense(2).init(jax.random.key(0), x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we also make `split_rngs={'params': True}`, then the PRNG key we provide is split across the variable axis (in this case, the batch axis 0), and we can generate different parameters for each batch input:" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "outputId": "275699c3-ba48-403e-877d-07b65981cff5" }, "outputs": [ { "data": { "text/plain": [ "{'params': {'bias': Array([[0., 0.],\n", " [0., 0.],\n", " [0., 0.]], dtype=float32),\n", " 'kernel': Array([[[-0.2526208 , -0.15088455],\n", " [-1.1987205 , -0.40843305]],\n", " \n", " [[-0.7064888 , -1.108805 ],\n", " [-0.938775 , 1.4812315 ]],\n", " \n", " [[-0.59468937, -0.2502723 ],\n", " [-1.33515 , 0.5067442 ]]], dtype=float32)}}" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "BatchDense = nn.vmap(\n", " nn.Dense,\n", " in_axes=0, out_axes=0,\n", " variable_axes={'params': 0},\n", " split_rngs={'params': True})\n", "\n", "BatchDense(2).init(jax.random.key(0), x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Adding a variable via `self.variable` is straightforward:" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "outputId": "c11a80bc-d865-4e2e-e059-4d6bcea79e09" }, "outputs": [ { "data": { "text/plain": [ "{'params': {'Dense_0': {'bias': Array([[0., 0.],\n", " [0., 0.],\n", " [0., 0.]], dtype=float32),\n", " 'kernel': Array([[[-0.9079084 , 0.76390624],\n", " [-0.01285526, 0.4320353 ]],\n", " \n", " [[ 0.12398645, 0.7884565 ],\n", " [ 1.5344163 , 1.3186085 ]],\n", " \n", " [[-0.44171348, 0.43430036],\n", " [-0.40732604, 0.29774475]]], dtype=float32)}},\n", " 'other_collection': {'kernel': Array([[-0.8193048 , 0.711106 ],\n", " [-0.37802765, -0.66705877],\n", " [-0.44808003, 0.93031347]], dtype=float32)}}" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class Model(nn.Module):\n", " @nn.compact\n", " def __call__(self, x):\n", " x = nn.Dense(2)(x)\n", " kernel = self.variable(\n", " 'other_collection',\n", " 'kernel',\n", " lambda: jax.random.normal(self.make_rng('other'), x.shape),\n", " )\n", " return x + kernel.value\n", "\n", "BatchModel = nn.vmap(\n", " Model,\n", " in_axes=0,\n", " out_axes=0,\n", " variable_axes={'params': 0, 'other_collection': 0},\n", " split_rngs={'params': True, 'other': True},\n", ")\n", "\n", "BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can control which RNG stream to split, for example, if we only wanted to split the `'params'` RNG stream, then the variables generated from `self.variable` will be the same for each batch input:" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "outputId": "fb16619c-c975-497d-c867-6fd5143b4507" }, "outputs": [ { "data": { "text/plain": [ "{'params': {'Dense_0': {'bias': Array([[0., 0.],\n", " [0., 0.],\n", " [0., 0.]], dtype=float32),\n", " 'kernel': Array([[[-0.9079084 , 0.76390624],\n", " [-0.01285526, 0.4320353 ]],\n", " \n", " [[ 0.12398645, 0.7884565 ],\n", " [ 1.5344163 , 1.3186085 ]],\n", " \n", " [[-0.44171348, 0.43430036],\n", " [-0.40732604, 0.29774475]]], dtype=float32)}},\n", " 'other_collection': {'kernel': Array([[ 0.44956833, -1.1854612 ],\n", " [ 0.44956833, -1.1854612 ],\n", " [ 0.44956833, -1.1854612 ]], dtype=float32)}}" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "BatchModel = nn.vmap(\n", " Model,\n", " in_axes=0, out_axes=0,\n", " variable_axes={'params': 0, 'other_collection': 0},\n", " split_rngs={'params': True, 'other': False})\n", "\n", "BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also control which parameters / variables should be generated for each batch input, for example, if we only wanted `'params'` to generate separate parameters for each batch input:" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "outputId": "f3a17d59-6f75-4408-caba-5769d4589263" }, "outputs": [ { "data": { "text/plain": [ "{'params': {'Dense_0': {'bias': Array([[0., 0.],\n", " [0., 0.],\n", " [0., 0.]], dtype=float32),\n", " 'kernel': Array([[[-0.9079084 , 0.76390624],\n", " [-0.01285526, 0.4320353 ]],\n", " \n", " [[ 0.12398645, 0.7884565 ],\n", " [ 1.5344163 , 1.3186085 ]],\n", " \n", " [[-0.44171348, 0.43430036],\n", " [-0.40732604, 0.29774475]]], dtype=float32)}},\n", " 'other_collection': {'kernel': Array([ 0.44956833, -1.1854612 ], dtype=float32)}}" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "BatchModel = nn.vmap(\n", " Model,\n", " in_axes=0, out_axes=0,\n", " variable_axes={'params': 0, 'other_collection': None},\n", " split_rngs={'params': True, 'other': False})\n", "\n", "BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `nn.scan`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can use [`nn.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.scan.html) to create a scanned `Module` layer (this is useful for simplifying repetitively stacked submodules):" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "outputId": "29d1863b-809f-42ce-894c-1b0810faa41e" }, "outputs": [ { "data": { "text/plain": [ "{'params': {'Dense_0': {'bias': Array([[0., 0.],\n", " [0., 0.],\n", " [0., 0.]], dtype=float32),\n", " 'kernel': Array([[[-0.07838312, -0.7422982 ],\n", " [ 0.87488323, 0.13773395]],\n", " \n", " [[ 0.97309333, 0.9087693 ],\n", " [-0.12564984, -1.0920651 ]],\n", " \n", " [[-0.99055105, 1.1499453 ],\n", " [-0.15721127, -0.62520015]]], dtype=float32)}}}" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = jnp.ones((3, 2))\n", "\n", "class ResidualMLPBlock(nn.Module):\n", " @nn.compact\n", " def __call__(self, x, _):\n", " h = nn.Dense(features=2)(x)\n", " h = nn.relu(h)\n", " return x + h, None # return an empty carry\n", "\n", "ScanMLP = nn.scan(\n", " ResidualMLPBlock, variable_axes={'params': 0},\n", " variable_broadcast=False, split_rngs={'params': True},\n", " length=3)\n", "\n", "ScanMLP().init(jax.random.key(0), x, None) # pass in an empty carry" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Similar to before, we can control whether to split the RNG stream or not, for example, if we wanted all the stacked modules to be initialized to the same parameter values, we can pass in `split_rngs={'params': False}`:" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "outputId": "6a825bcd-9c3b-43c2-afd2-42500d89fb26" }, "outputs": [ { "data": { "text/plain": [ "{'params': {'Dense_0': {'bias': Array([[0., 0.],\n", " [0., 0.],\n", " [0., 0.]], dtype=float32),\n", " 'kernel': Array([[[-0.66715515, -0.0484313 ],\n", " [ 0.9867164 , 0.75408363]],\n", " \n", " [[-0.66715515, -0.0484313 ],\n", " [ 0.9867164 , 0.75408363]],\n", " \n", " [[-0.66715515, -0.0484313 ],\n", " [ 0.9867164 , 0.75408363]]], dtype=float32)}}}" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ScanMLP = nn.scan(\n", " ResidualMLPBlock, variable_axes={'params': 0},\n", " variable_broadcast=False, split_rngs={'params': False},\n", " length=3)\n", "\n", "ScanMLP().init(jax.random.key(0), x, None)" ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: docs/guides/flax_fundamentals/rng_guide.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # Randomness and PRNGs in Flax +++ In this guide, you will learn how Flax uses [JAX's explicit pseudorandom number generator (PRNG) keys](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#jax-prng) to emulate randomness, and adds some additional features to make it easier for users to thread PRNG keys through different Flax `Module`s. If you are new to JAX PRNG keys or need a refresher, check out: - [JAX 101: PRNGs in JAX](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html) - [🔪 JAX - The Sharp Bits 🔪: Random Numbers](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers). +++ ## Setup +++ Install or upgrade Flax, and then import some necessary dependencies. **Note:** This guide uses the `--xla_force_host_platform_device_count=8` flag to emulate multiple devices in a CPU environment in a Google Colab/Jupyter Notebook. You don’t need this if you are already using a multi-device Google Cloud TPU environment, for example, on Google Cloud or in a Kaggle VM with a TPU. ```{code-cell} :tags: [skip-execution] !pip install -q flax ``` ```{code-cell} import os os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' ``` ```{code-cell} import flax, flax.linen as nn import jax, jax.numpy as jnp from jax.sharding import Mesh, PartitionSpec, NamedSharding from jax.experimental import mesh_utils from jax.experimental.shard_map import shard_map import hashlib ``` ```{code-cell} :outputId: ec904f6b-0e87-4efe-87c4-fea0f8e8ec23 jax.devices() ``` Set the JAX config variable `jax_threefry_partitionable` to `True`. This will be the default value in the future and makes the PRNG more efficiently auto-parallelizable under `jax.jit`. Refer to [JAX discussion](https://github.com/jax-ml/jax/discussions/18480) for more details. ```{code-cell} jax.config.update('jax_threefry_partitionable', True) assert jax.config.jax_threefry_partitionable == True assert jax.config.jax_default_prng_impl == 'threefry2x32' ``` ## Receiving, manipulating and creating PRNG keys with `Module.make_rng` +++ The primary method Flax uses to receive, manipulate and create PRNG keys is via the `Module` method [`self.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.make_rng). It is a method that accepts a string name that represents an "RNG stream". Each RNG stream has an initial starting seed PRNG key, which the user passes in as a dictionary argument (i.e. into an [`.init`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init) or [`.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply) function), and the starting seed is used by `self.make_rng` to generate more PRNG keys for that stream. If `self.make_rng` is called on a string name that does not have an initial starting seed PRNG key (i.e. the user did not pass a key with the corresponding name into `.init` or `.apply`), then `self.make_rng` will use the `'params'` key as the initial starting seed by default. Note that this method can only be called with bounded modules (see [The Flax Module lifecycle](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html#top-level-modules)). ```{code-cell} :outputId: 2a16435b-e92a-480a-f9fb-e6effc42c4c2 class RNGModule(nn.Module): @nn.compact def __call__(self): print(self.make_rng('rng_stream')) print(self.make_rng('rng_stream')) print(self.make_rng('rng_stream')) rng_module = RNGModule() variables = rng_module.init({'rng_stream': jax.random.key(0)}) ``` Now if we use a different starting seed PRNG key, we will generate different values (as intended). ```{code-cell} :outputId: 985b0f62-dfde-4f0f-fad4-a31927fc9f59 variables = rng_module.init({'rng_stream': jax.random.key(1)}) ``` Calling `self.make_rng` for one stream will not affect the random values generated from another stream; i.e. the call order doesn't matter. ```{code-cell} :outputId: 7e8ce538-e380-4db9-db23-bc4a8da577da class RNGModuleTwoStreams(nn.Module): @nn.compact def __call__(self): # same value as first code snippet above print(f"rng_stream1: {self.make_rng('rng_stream1')}") # same value as second code snippet above print(f"rng_stream2: {self.make_rng('rng_stream2')}") # same value as first code snippet above print(f"rng_stream1: {self.make_rng('rng_stream1')}") # same value as second code snippet above print(f"rng_stream2: {self.make_rng('rng_stream2')}") # same value as first code snippet above print(f"rng_stream1: {self.make_rng('rng_stream1')}") # same value as second code snippet above print(f"rng_stream2: {self.make_rng('rng_stream2')}") rng_module_two_streams = RNGModuleTwoStreams() variables = rng_module_two_streams.init( {'rng_stream1': jax.random.key(0), 'rng_stream2': jax.random.key(1)} ) ``` Providing the same seed PRNG key will result in the same values being generated (provided that the same operations are used for those keys). ```{code-cell} :outputId: b70be039-589a-48f7-dc54-65e78c449c65 variables = rng_module_two_streams.init( {'rng_stream1': jax.random.key(0), 'rng_stream2': jax.random.key(0)} ) ``` ### How `self.make_rng` works under the hood +++ This is what happens when `self.make_rng` (`flax.linen.Module.make_rng`) is called: * The following data is collected: * The path of the `Module` as provided by `self.scope.path` (the top-level root module has an empty path `()`). * The `self.make_rng` call count. That is, the number of times `self.make_rng` has been called for this specific stream (including this call). * **Note:** Each sub-`Module` will have its own individual call count that's separate from other `Module`s. For example, a `Module` that has called `self.make_rng('params')` twice and contains a sub-`Module` that has called `self.make_rng('params')` once, will have a call count of 2 and 1 for each of the RNG stream `'params'`, respectively. * The data is bundled into a tuple and fed into a hash function and produces an integer. * The generated integer is folded into the RNG stream's starting seed PRNG key to generate a new, unique PRNG key. +++ Below is a slightly simplified version of the hash function that Flax uses for `self.make_rng`: ```{code-cell} def produce_hash(data): m = hashlib.sha1() for x in data: if isinstance(x, str): m.update(x.encode('utf-8')) elif isinstance(x, int): m.update(x.to_bytes((x.bit_length() + 7) // 8, byteorder='big')) else: raise ValueError(f'Expected int or string, got: {x}') d = m.digest() hash_int = int.from_bytes(d[:4], byteorder='big') return hash_int ``` And now you can manually reproduce the PRNG keys generated from `self.make_rng`: ```{code-cell} :outputId: d26b7355-9e8b-4954-b2f4-cf7520d5c5a3 stream_seed = jax.random.key(0) for call_count in range(1, 4): hash_int = produce_hash(data=(call_count,)) print(jax.random.fold_in(stream_seed, jnp.uint32(hash_int))) ``` ```{code-cell} :outputId: dec627a6-4c5a-4e3e-ce11-ce4f72775261 variables = rng_module.init({'rng_stream': jax.random.key(0)}) ``` ### Sub-`Module`s and `self.make_rng` +++ This section explores how `self.make_rng` (`flax.linen.Module.make_rng`) behaves with sub-`Module`s. Consider the following example: ```{code-cell} :outputId: 5b7a9ae9-ca49-4ac0-d007-5caeee739ff0 class RNGSubSubModule(nn.Module): def __call__(self): print(f"{self.name}, count 1: {self.make_rng('rng_stream')}") print(f"{self.name}, count 2: {self.make_rng('rng_stream')}") class RNGSubModule(nn.Module): @nn.compact def __call__(self): print(f"{self.name}, count 1: {self.make_rng('rng_stream')}") print(f"{self.name}, count 2: {self.make_rng('rng_stream')}") RNGSubSubModule()() class RNGModule(nn.Module): @nn.compact def __call__(self): print(f"RNGModule, count 1: {self.make_rng('rng_stream')}") print(f"RNGModule, count 2: {self.make_rng('rng_stream')}") RNGSubModule()() rng_module = RNGModule() variables = rng_module.init({'rng_stream': jax.random.key(0)}) ``` As previously discussed, the data that is fed into the Flax hash function consists of: * The path of the `Module`, provided by `self.scope.path` (the top-level root module has an empty path `()`); and * The call count for the specific RNG stream. In addition, note that each Flax `Module` and sub-`Module` have their own individual call counts, even for the same RNG stream. The convention for sub-`Module` names is: `f'{module_name}_{module_number}'`. For example, the first `Dense` sub-`Module` will be called `Dense_0`, the second one will be called `Dense_1`, and so on. Therefore, the following data will be fed into the hash function: * For `RNGModule`: The data is just the call count, such as `(1,)` and `(2,)`, since the root `Module` has an empty path. * For `RNGSubModule`: The data is `('RNGSubModule_0', 1)` and `('RNGSubModule_0', 2)`. * For `RNGSubSubModule`: The data is `('RNGSubModule_0', 'RNGSubSubModule_0', 1)` and `('RNGSubModule_0', 'RNGSubSubModule_0', 2)`. With this data, you can manually reproduce the PRNG keys generated from the `Module` and sub-`Module`s using `self.make_rng`. For example: ```{code-cell} :outputId: c0de4d37-0f00-4e58-bdfd-e8a6454ed681 stream_seed = jax.random.key(0) for initial_data in ((), ('RNGSubModule_0',), ('RNGSubModule_0', 'RNGSubSubModule_0')): if initial_data: module_name = initial_data[-1] else: module_name = 'RNGModule' for call_count in (1, 2): hash_int = produce_hash(data=initial_data+(call_count,)) rng_key = jax.random.fold_in(stream_seed, jnp.uint32(hash_int)) print(f"{module_name}, count {call_count}: {rng_key}") ``` If the same sub-`Module` class is used multiple times, you can increment the suffix of the sub-`Module` name accordingly. For example: `RNGSubModule_0`, `RNGSubModule_1`, and so on. ```{code-cell} :outputId: 0b77a038-7000-407b-c5b8-a28dea7951d1 class RNGSubModule(nn.Module): @nn.compact def __call__(self): print(f"{self.name}, count 1: {self.make_rng('rng_stream')}") print(f"{self.name}, count 2: {self.make_rng('rng_stream')}") class RNGModule(nn.Module): @nn.compact def __call__(self): print(f"RNGModule, count 1: {self.make_rng('rng_stream')}") print(f"RNGModule, count 2: {self.make_rng('rng_stream')}") RNGSubModule()() RNGSubModule()() rng_module = RNGModule() variables = rng_module.init({'rng_stream': jax.random.key(0)}) ``` ```{code-cell} :outputId: d189d25e-425d-4fd7-fe18-2dfd63f28b87 stream_seed = jax.random.key(0) for initial_data in ((), ('RNGSubModule_0',), ('RNGSubModule_1',)): if initial_data: module_name = initial_data[-1] else: module_name = 'RNGModule' for call_count in (1, 2): hash_int = produce_hash(data=initial_data+(call_count,)) rng_key = jax.random.fold_in(stream_seed, jnp.uint32(hash_int)) print(f"{module_name}, count {call_count}: {rng_key}") ``` ### Using `self.param` and `self.variable` +++ Flax users have the option of creating additional parameters and variables in their modules by using the [`self.param`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.param) and [`self.variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.variable) `Module` methods. An `init_fn` argument must be passed to these methods so that it can generate the initial value of the parameter/variable. `self.make_rng` is commonly used implicitly or explicitly in this `init_fn`, since many initializer functions are stochastic in nature and require a PRNG key. See the full list of Flax initializers [here](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/initializers.html). There are a couple of differences between the two methods that the user should take note of: * `self.param` always creates a parameter in the `'params'` [collection](https://flax.readthedocs.io/en/latest/glossary.html#term-Params-parameters), whereas `self.variable` creates a variable in any [collection](https://flax.readthedocs.io/en/latest/glossary.html#term-Variable-collections) the user specifies * `self.param` will automatically call `self.make_rng('params')` and pass in the generated PRNG key implicitly to the `init_fn` of the parameter you instantiated (it will be passed in as the first argument), whereas users will have to manually specify what RNG stream to call `self.make_rng` on in the `init_fn` of `self.variable` (it could be `'params'` or something different). Below is an example using both `self.param` and `self.variable`: ```{code-cell} :outputId: a7816385-0e08-48e2-dc51-055d7bcd0bab class Model(nn.Module): @nn.compact def __call__(self, x): # kernel will use 'params' seed, initial data will include 'Dense_0', call count 1 x = nn.Dense(2, kernel_init=jax.random.normal, use_bias=False)(x) # model_param will use 'params' seed, call count 1 model_param = self.param('model_param', jax.random.normal, x.shape) # model_variable1 will use 'params' seed, call count 2 model_variable1 = self.variable( 'other_collection', 'model_variable1', lambda: jax.random.normal(self.make_rng('params'), x.shape), ) # model_variable2 will use 'other' seed, call count 1 model_variable2 = self.variable( 'other_collection', 'model_variable2', lambda: jax.random.normal(self.make_rng('other'), x.shape), ) # kernel will use 'params' seed, initial data will include 'Dense_1', call count 1 # bias will use 'params' seed, initial data will include 'Dense_1', call count 2 x = nn.Dense(2, kernel_init=jax.random.normal, bias_init=jax.random.normal)( x ) return x model = Model() variables = model.init( {'params': jax.random.key(0), 'other': jax.random.key(1)}, jnp.ones((2, 2)) ) print(variables['params']['Dense_0']['kernel']) print(variables['params']['model_param']) print(variables['other_collection']['model_variable1']) print(variables['other_collection']['model_variable2']) print(variables['params']['Dense_1']['kernel']) print(variables['params']['Dense_1']['bias']) ``` Remember: * there is a separate count for each RNG stream; this is why the count for `self.make_rng('other')` starts at 1 even though there were earlier calls of `self.make_rng('params')` * each submodule has their own separate count for each rng stream; this is why each `Dense` layer has their own separate count for `self.make_rng('params')` and why `model_param` and `model_variable1` share the same count (since they are defined within the same top-level parent module) ```{code-cell} :outputId: ccec9d64-9a27-47f7-adaf-b36a5ea655db params_seed = jax.random.key(0) other_seed = jax.random.key(1) for initial_data, count, seed, shape in ( (('Dense_0',), 1, params_seed, (2, 2)), ((), 1, params_seed, (2, 2)), ((), 2, params_seed, (2, 2)), ((), 1, other_seed, (2, 2)), (('Dense_1',), 1, params_seed, (2, 2)), (('Dense_1',), 2, params_seed, (1, 2)), ): hash_int = produce_hash(data=(*initial_data, count)) rng_key = jax.random.fold_in(seed, jnp.uint32(hash_int)) print(jax.random.normal(rng_key, shape)) ``` ### Managing RNG streams inside a training loop +++ Below is an example of managing RNG streams from `self.make_rng`, `self.param`, `self.variable` and `nn.Dropout` in a training loop (note: `nn.Dropout` requires a seed PRNG key to be passed in the `'dropout'` RNG stream, since it implicitly calls `self.make_rng('dropout')`): ```{code-cell} class SubModule(nn.Module): @nn.compact def __call__(self, x, train): # variables created using `self.param` will use `self.make_rng('params')` kernel = self.param('submodule_kernel', jax.random.normal, x.shape) x = x + kernel # `nn.Dropout` will use self.make_rng('dropout') x = nn.Dropout(0.2)(x, deterministic=not train) # `nn.Dense` will use self.make_rng('params') x = nn.Dense(3)(x) return x class Model(nn.Module): @nn.compact def __call__(self, x, train): # make kernel use `self.make_rng('other')` kernel = self.variable( 'other_collection', 'module_kernel', lambda: jax.random.normal(self.make_rng('other'), x.shape), ) x = ( x + kernel.value ) # `.value` will extract the underlying value of the variable x = SubModule()(x, train) # `nn.Dropout` will use self.make_rng('dropout') x = nn.Dropout(0.2)(x, deterministic=not train) # `nn.Dense` will use self.make_rng('params') x = nn.Dense(2)(x) return x params_rng, other_rng, train_rng = jax.random.split(jax.random.key(0), 3) init_rngs = {'params': params_rng, 'other': other_rng} x = jnp.ones((1, 3)) y = jnp.ones((1, 2)) module = Model() variables = module.init(init_rngs, x, train=False) ``` ```{code-cell} :outputId: e9da8228-acba-403d-bcb5-33a39d4d530d def update(variables, rng): # we don't need to provide a 'params' or 'other' rng, as only 'dropout' rng will be used during training # split the rng to get a dropout_rng to be used for this training iteration, # and to get another rng key to be used for the next training iteration dropout_rng, next_rng = jax.random.split(rng) def loss(params): out = module.apply( {'params': params, 'other_collection': variables['other_collection']}, x, train=True, rngs={'dropout': dropout_rng}, ) return jnp.mean((y - out) ** 2) grads = jax.grad(loss)(variables['params']) params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, variables['params'], grads) return { 'params': params, 'other_collection': variables['other_collection'], }, next_rng for _ in range(10): variables, train_rng = update(variables, train_rng) out = module.apply(variables, x, train=False) print(jnp.mean((y - out)**2)) ``` ### 🔪 Sharp edge 🔪 - unintentionally generating the same values There is an edge case where the same value can be unintentionally generated. See the [Flax issue](https://github.com/google/flax/issues/2157) for more details. ```{code-cell} :outputId: 887142ff-c9ca-4aae-d9fa-cc9993d809c5 class Leaf(nn.Module): def __call__(self, x): return x + jax.random.randint(self.make_rng("rng"), (), 0, 100) class Node(nn.Module): leaf_name: str @nn.compact def __call__(self, x): return Leaf(name=self.leaf_name)(x) class Model(nn.Module): @nn.compact def __call__(self, x): return (Node(name="ab", leaf_name="cdef")(x), Node(name="abc", leaf_name="def")(x), ) out1, out2 = Model().apply({}, 0, rngs={"rng": jax.random.key(33)}) out1 == out2 # same output, despite having different submodule names ``` This occurs because the hash function [concatenates strings together](https://docs.python.org/3/library/hashlib.html#hashlib.hash.update), so the data `('AB', 'C')` is equivalent to data `('A', 'BC')` when fed into the hash function, therefore producing the same hash int. ```{code-cell} :outputId: 001cbd49-129b-4474-c6a1-3255a4ee3dfe print(produce_hash(data=('A', 'B', 'C', 1))) print(produce_hash(data=('AB', 'C', 1))) print(produce_hash(data=('A', 'BC', 1))) print(produce_hash(data=('ABC', 1))) ``` To avoid this edge case, users can flip the `flax_fix_rng_separator` [configuration flag](https://flax.readthedocs.io/en/latest/api_reference/flax.config.html#flax.configurations.Config.flax_fix_rng_separator) to `True`. ```{code-cell} :outputId: 35a2b204-bdfd-4f83-8e98-ba723963cb0c flax.config.update('flax_fix_rng_separator', True) out1, out2 = Model().apply({}, 0, rngs={"rng": jax.random.key(33)}) out1 == out2 # different output ``` ## Managing RNG's on multiple devices +++ This section will show examples on how to use `jit` and `shard_map` to use RNG's in multi-device settings. +++ ### Using `jax.jit` +++ When using [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html), we can use RNG's as we did before, but we now include `in_shardings` and `out_shardings` arguments to specify how to shard input and output data. The RNG key itself gets replicated (not sharded); `jax.jit` makes each device use it as appropriate for its shard of the data. For more details on training on multiple devices in Flax using `jax.jit`, see our [Scale up Flax Modules on multiple devices guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html#) and [lm1b example](https://github.com/google/flax/tree/main/examples/lm1b). ```{code-cell} :outputId: 6c280522-4b43-4b82-f40a-b73986659b2c # Create a mesh and annotate the axis with a name. device_mesh = mesh_utils.create_device_mesh((8,)) print(device_mesh) mesh = Mesh(devices=device_mesh, axis_names=('data',)) print(mesh) data_sharding = NamedSharding(mesh, PartitionSpec('data',)) print(data_sharding) ``` ```{code-cell} :outputId: d1bfbcad-e28a-4fae-8136-98bd1efb9332 class Model(nn.Module): @nn.compact def __call__(self, x, add_noise): x = nn.Dense(1)(x) # use jnp.where for control flow; for more details see: https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError return jnp.where( add_noise, x + jax.random.normal(self.make_rng('params'), x.shape), x ) module = Model() init_rng, apply_rng = jax.random.split(jax.random.key(0)) x = jnp.ones((8, 1)) variables = module.init(init_rng, x, False) # create custom forward function, since jit does not support kwargs when in_shardings is specified def forward(variables, x, add_noise, rng): return module.apply(variables, x, add_noise, rngs={'params': rng}) # shard the inputs x across devices # replicate the variables, add_noise boolean and rng key across devices # shard the output across devices jit_forward = jax.jit( forward, in_shardings=(None, data_sharding, None, None), out_shardings=data_sharding, ) out = jit_forward(variables, x, True, apply_rng) out ``` The output is different given the same input, meaning the RNG key was used to add noise to the output. We can also confirm that the output is sharded across devices: ```{code-cell} :outputId: b672b85f-7a2d-44b5-afc1-bbf9426655ed out.addressable_shards ``` Another way to visualize the output sharding: ```{code-cell} :outputId: 1c0a16ce-fa3f-4b95-d794-58464bbaa9ae jax.debug.visualize_array_sharding(out) ``` If we choose not to add noise, then the output is the same across all batches (as expected, since the input is the same for all batches): ```{code-cell} :outputId: fe9ec875-3e7f-4861-babc-f07064737276 out = jit_forward(variables, x, False, apply_rng) out ``` We can confirm the un-jitted function produces the same values, albeit unsharded (note there may be small numerical differences due to compiler optimizations from jitting): ```{code-cell} :outputId: 0a9e5f2c-d4bf-4051-bf71-f32a9c32dc06 out = forward(variables, x, True, apply_rng) out ``` ```{code-cell} :outputId: 772a5063-1bd5-46b4-f6f6-cae9b4b81a26 out = forward(variables, x, False, apply_rng) out ``` ### Using `shard_map` +++ When using [`jax.experimental.shard_map.shard_map`](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html), the important parts to remember are to: * split your PRNG key to produce a different key for each device * the PRNG keys will be sharded automatically to each device (provided you use the correct partition specification), but the [**rank of the original batched PRNG key array will not be reduced**](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html#rank-reducing-vs-rank-preserving-maps-over-array-axes); e.g. with a batch of 8 PRNG keys and 8 devices, each device will see a PRNG key batch of size 1 within the `shard_map`-ed function * therefore to access the PRNG key itself, we need to index slice into it (see the example below) ```{code-cell} :outputId: aa00b9a3-24ba-4048-ed8c-afbb9070f039 def forward(variables, x, add_noise, rng_key_batch): # rng_key_batch is a batch of size 1 containing 1 PRNG key # index slice into the rng_key_batch to access the PRNG key return module.apply( variables, x, add_noise, rngs={'params': rng_key_batch[0]} ) # define partition specifications data_pspec = PartitionSpec('data') no_pspec = PartitionSpec() # shard the inputs x and rng keys across devices # replicate the variables and add_noise boolean across devices # shard the output across devices shmap_forward = shard_map( forward, mesh=mesh, in_specs=(no_pspec, data_pspec, no_pspec, data_pspec), out_specs=data_pspec, ) # get 8 different rng's that will be used by the 8 devices when doing forward inference apply_rngs = jax.random.split(apply_rng, 8) out = shmap_forward(variables, x, True, apply_rngs) out ``` Confirm that the output is sharded across devices: ```{code-cell} :outputId: e304289b-ef1c-4e4a-d4c1-4c41613bfa62 out.addressable_shards ``` ```{code-cell} :outputId: 52fdb6d2-4c4f-44b3-feee-4bc5363c8f2f jax.debug.visualize_array_sharding(out) ``` ## Lifted transforms +++ [Flax lifted transforms](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html) allow you to use [JAX transforms](https://github.com/jax-ml/jax#transformations) with `Module` arguments. This section will show you how to control how PRNG keys are split in Flax lifted transforms. Refer to [Lifted transformations](https://flax.readthedocs.io/en/latest/developer_notes/lift.html) for more detail. +++ ### `nn.vmap` +++ We can use [`nn.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.vmap.html) to create a batched `Dense` layer: ```{code-cell} :outputId: f0830f6b-659c-446f-c933-7b2a430f8004 x = jnp.ones((3, 2)) BatchDense = nn.vmap( nn.Dense, in_axes=0, out_axes=0, variable_axes={'params': None}, split_rngs={'params': False}) BatchDense(2).init(jax.random.key(0), x) ``` By denoting `variable_axes={'params': 0}'`, we vectorize the `params` Arrays on the first axis. However the parameter values generated are all identical to each other: ```{code-cell} :outputId: eef5c0ca-f8d5-4f25-8ce6-9f2f60622daf BatchDense = nn.vmap( nn.Dense, in_axes=0, out_axes=0, variable_axes={'params': 0}, split_rngs={'params': False}) BatchDense(2).init(jax.random.key(0), x) ``` If we also make `split_rngs={'params': True}`, then the PRNG key we provide is split across the variable axis (in this case, the batch axis 0), and we can generate different parameters for each batch input: ```{code-cell} :outputId: 275699c3-ba48-403e-877d-07b65981cff5 BatchDense = nn.vmap( nn.Dense, in_axes=0, out_axes=0, variable_axes={'params': 0}, split_rngs={'params': True}) BatchDense(2).init(jax.random.key(0), x) ``` Adding a variable via `self.variable` is straightforward: ```{code-cell} :outputId: c11a80bc-d865-4e2e-e059-4d6bcea79e09 class Model(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense(2)(x) kernel = self.variable( 'other_collection', 'kernel', lambda: jax.random.normal(self.make_rng('other'), x.shape), ) return x + kernel.value BatchModel = nn.vmap( Model, in_axes=0, out_axes=0, variable_axes={'params': 0, 'other_collection': 0}, split_rngs={'params': True, 'other': True}, ) BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x) ``` We can control which RNG stream to split, for example, if we only wanted to split the `'params'` RNG stream, then the variables generated from `self.variable` will be the same for each batch input: ```{code-cell} :outputId: fb16619c-c975-497d-c867-6fd5143b4507 BatchModel = nn.vmap( Model, in_axes=0, out_axes=0, variable_axes={'params': 0, 'other_collection': 0}, split_rngs={'params': True, 'other': False}) BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x) ``` We can also control which parameters / variables should be generated for each batch input, for example, if we only wanted `'params'` to generate separate parameters for each batch input: ```{code-cell} :outputId: f3a17d59-6f75-4408-caba-5769d4589263 BatchModel = nn.vmap( Model, in_axes=0, out_axes=0, variable_axes={'params': 0, 'other_collection': None}, split_rngs={'params': True, 'other': False}) BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x) ``` ### `nn.scan` +++ We can use [`nn.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.scan.html) to create a scanned `Module` layer (this is useful for simplifying repetitively stacked submodules): ```{code-cell} :outputId: 29d1863b-809f-42ce-894c-1b0810faa41e x = jnp.ones((3, 2)) class ResidualMLPBlock(nn.Module): @nn.compact def __call__(self, x, _): h = nn.Dense(features=2)(x) h = nn.relu(h) return x + h, None # return an empty carry ScanMLP = nn.scan( ResidualMLPBlock, variable_axes={'params': 0}, variable_broadcast=False, split_rngs={'params': True}, length=3) ScanMLP().init(jax.random.key(0), x, None) # pass in an empty carry ``` Similar to before, we can control whether to split the RNG stream or not, for example, if we wanted all the stacked modules to be initialized to the same parameter values, we can pass in `split_rngs={'params': False}`: ```{code-cell} :outputId: 6a825bcd-9c3b-43c2-afd2-42500d89fb26 ScanMLP = nn.scan( ResidualMLPBlock, variable_axes={'params': 0}, variable_broadcast=False, split_rngs={'params': False}, length=3) ScanMLP().init(jax.random.key(0), x, None) ``` ================================================ FILE: docs/guides/flax_fundamentals/setup_or_nncompact.rst ================================================ ``setup`` vs ``compact`` ========================================= In Flax's module system (named `Linen`_), submodules and variables (parameters or others) can be defined in two ways: 1. **Explicitly** (using ``setup``): Assign submodules or variables to ``self.`` inside a :meth:`setup ` method. Then use the submodules and variables assigned to ``self.`` in ``setup`` from any "forward pass" method defined on the class. This resembles how modules are defined in PyTorch. 2. **In-line** (using ``nn.compact``): Write your network's logic directly within a single "forward pass" method annotated with :meth:`nn.compact `. This allows you to define your whole module in a single method, and "co-locate" submodules and variables next to where they are used. **Both of these approaches are perfectly valid, behave the same way, and interoperate with all of Flax**. Here is a short example of a module defined in both ways, with exactly the same functionality. .. testsetup:: Using ``setup``, Using ``nn.compact`` import flax.linen as nn .. codediff:: :title: Using ``setup``, Using ``nn.compact`` class MLP(nn.Module): def setup(self): # Submodule names are derived by the attributes you assign to. In this # case, "dense1" and "dense2". This follows the logic in PyTorch. self.dense1 = nn.Dense(32) self.dense2 = nn.Dense(32) def __call__(self, x): x = self.dense1(x) x = nn.relu(x) x = self.dense2(x) return x --- class MLP(nn.Module): @nn.compact #! def __call__(self, x): x = nn.Dense(32, name="dense1")(x) #! x = nn.relu(x) x = nn.Dense(32, name="dense2")(x) #! return x So, how would you decide which style to use? It can be a matter of taste, but here are some pros and cons: Reasons to prefer using ``nn.compact``: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 1. Allows defining submodules, parameters and other variables next to where they are used: less scrolling up/down to see how everything is defined. 2. Reduces code duplication when there are conditionals or for loops that conditionally define submodules, parameters or variables. 3. Code typically looks more like mathematical notation: ``y = self.param('W', ...) @ x + self.param('b', ...)`` looks similar to :math:`y=Wx+b``) 4. If you are using shape inference, i.e. using parameters whose shape/value depend on shapes of the inputs (which are unknown at initialization), this is not possible using ``setup``. Reasons to prefer using ``setup``: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 1. Closer to the PyTorch convention, thus easier when porting models from PyTorch 2. Some people find it more natural to explicitly separate the definition of submodules and variables from where they are used 3. Allows defining more than one "forward pass" method (see :class:`MultipleMethodsCompactError `) .. _`Linen`: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#JIT-mechanics:-tracing-and-static-variables ================================================ FILE: docs/guides/flax_fundamentals/state_params.rst ================================================ Managing Parameters and State ============================= We will show you how to... * manage the variables from initialization to updates. * split and re-assemble parameters and state. * use :code:`vmap` with batch-dependant state. .. testsetup:: import flax from flax import linen as nn from jax import random import jax.numpy as jnp import jax import optax # Create some fake data and run only for one epoch for testing. dummy_input = jnp.ones((3, 4)) num_epochs = 1 .. testcode:: class BiasAdderWithRunningMean(nn.Module): momentum: float = 0.9 @nn.compact def __call__(self, x): is_initialized = self.has_variable('batch_stats', 'mean') mean = self.variable('batch_stats', 'mean', jnp.zeros, x.shape[1:]) bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:]) if is_initialized: mean.value = (self.momentum * mean.value + (1.0 - self.momentum) * jnp.mean(x, axis=0, keepdims=True)) return mean.value + bias This example model is a minimal example that contains both parameters (declared with :code:`self.param`) and state variables (declared with :code:`self.variable`). The tricky part with initialization here is that we need to split the state variables and the parameters we're going to optimize for. First we define ``update_step`` as follows (with a dummy loss that should be replaced with yours): .. testcode:: def update_step(apply_fn, x, opt_state, params, state): def loss(params): y, updated_state = apply_fn({'params': params, **state}, x, mutable=list(state.keys())) l = ((x - y) ** 2).sum() # Replace with your loss here. return l, updated_state (l, updated_state), grads = jax.value_and_grad( loss, has_aux=True)(params) updates, opt_state = tx.update(grads, opt_state) # Defined below. params = optax.apply_updates(params, updates) return opt_state, params, updated_state Then we can write the actual training code. .. testcode:: model = BiasAdderWithRunningMean() variables = model.init(random.key(0), dummy_input) # Split state and params (which are updated by optimizer). state, params = flax.core.pop(variables, 'params') del variables # Delete variables to avoid wasting resources tx = optax.sgd(learning_rate=0.02) opt_state = tx.init(params) for _ in range(num_epochs): opt_state, params, state = update_step( model.apply, dummy_input, opt_state, params, state) :code:`vmap` accross the batch dimension ---------------------------------------- When using :code:`vmap` and managing state that depends on the batch dimension, for example when using :code:`BatchNorm`, the setup above must be modified slightly. This is because any layer whose state depends on the batch dimension is not strictly vectorizable. In the case of :code:`BatchNorm`, :code:`lax.pmean()` must be used to average the statistics over the batch dimension so that the state is in sync for each item in the batch. This requires two small changes. Firstly, we need to name the batch axis in our model definition. Here, this is done by specifying the :code:`axis_name` argument of :code:`BatchNorm`. In your own code this might require specifying the :code:`axis_name` argument of :code:`lax.pmean()` directly. .. testsetup:: from functools import partial from flax import linen as nn from jax import random import jax.numpy as jnp import jax import optax # Create some fake data and run only for one epoch for testing. dummy_input = jnp.ones((100,)) key1, key2 = random.split(random.key(0), num=2) batch_size = 64 X = random.normal(key1, (batch_size, 100)) Y = random.normal(key2, (batch_size, 1)) num_epochs = 1 .. testcode:: class MLP(nn.Module): hidden_size: int out_size: int @nn.compact def __call__(self, x, train=False): norm = partial( nn.BatchNorm, use_running_average=not train, momentum=0.9, epsilon=1e-5, axis_name="batch", # Name batch dim ) x = nn.Dense(self.hidden_size)(x) x = norm()(x) x = nn.relu(x) x = nn.Dense(self.hidden_size)(x) x = norm()(x) x = nn.relu(x) y = nn.Dense(self.out_size)(x) return y Secondly, we need to specify the same name when calling :code:`vmap` in our training code: .. testcode:: def update_step(apply_fn, x_batch, y_batch, opt_state, params, state): def batch_loss(params): def loss_fn(x, y): pred, updated_state = apply_fn( {'params': params, **state}, x, mutable=list(state.keys()) ) return (pred - y) ** 2, updated_state loss, updated_state = jax.vmap( loss_fn, out_axes=(0, None), # Do not vmap `updated_state`. axis_name='batch' # Name batch dim )(x_batch, y_batch) # vmap only `x`, `y`, but not `state`. return jnp.mean(loss), updated_state (loss, updated_state), grads = jax.value_and_grad( batch_loss, has_aux=True )(params) updates, opt_state = tx.update(grads, opt_state) # Defined below. params = optax.apply_updates(params, updates) return opt_state, params, updated_state, loss Note that we also need to specify that the model state does not have a batch dimension. Now we are able to train the model: .. testcode:: model = MLP(hidden_size=10, out_size=1) variables = model.init(random.key(0), dummy_input) # Split state and params (which are updated by optimizer). state, params = flax.core.pop(variables, 'params') del variables # Delete variables to avoid wasting resources tx = optax.sgd(learning_rate=0.02) opt_state = tx.init(params) for _ in range(num_epochs): opt_state, params, state, loss = update_step( model.apply, X, Y, opt_state, params, state) ================================================ FILE: docs/guides/flax_sharp_bits.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 🔪 Flax - The Sharp Bits 🔪\n", "\n", "Flax exposes the full power of JAX. And just like when using JAX, there are certain _[\"sharp bits\"](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)_ you may experience when working with Flax. This evolving document is designed to assist you with them.\n", "\n", "First, install and/or update Flax:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "tags": [ "skip-execution" ] }, "outputs": [], "source": [ "! pip install -qq flax" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🔪 `flax.linen.Dropout` layer and randomness\n", "\n", "### TL;DR\n", "\n", "When working on a model with dropout (subclassed from [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html#module-basics)), add the `'dropout'` PRNGkey only during the forward pass.\n", "\n", "1. Start with [`jax.random.split()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.split.html#jax-random-split) to explicitly create PRNG keys for `'params'` and `'dropout'`.\n", "2. Add the [`flax.linen.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.Dropout) layer(s) to your model (subclassed from Flax [`Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html#module-basics)).\n", "3. When initializing the model ([`flax.linen.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/init_apply.html)), there's no need to pass in an extra `'dropout'` PRNG key—just the `'params'` key like in a \"simpler\" model.\n", "4. During the forward pass with [`flax.linen.apply()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/init_apply.html), pass in `rngs={'dropout': dropout_key}`.\n", "\n", "Check out a full example below.\n", "\n", "### Why this works\n", "\n", "- Internally, `flax.linen.Dropout` makes use of [`flax.linen.Module.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.make_rng) to create a key for dropout (check out the [source code](https://github.com/google/flax/blob/5714e57a0dc8146eb58a7a06ed768ed3a17672f9/flax/linen/stochastic.py#L72)).\n", "- Every time `make_rng` is called (in this case, it's done implicitly in `Dropout`), you get a new PRNG key split from the main/root PRNG key.\n", "- `make_rng` still _guarantees full reproducibility_.\n", "\n", "### Background \n", "\n", "The [dropout](https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf) stochastic regularization technique randomly removes hidden and visible units in a network. Dropout is a random operation, requiring a PRNG state, and Flax (like JAX) uses [Threefry](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) PRNG that is splittable. \n", "\n", "> Note: Recall that JAX has an explicit way of giving you PRNG keys: you can fork the main PRNG state (such as `key = jax.random.key(seed=0)`) into multiple new PRNG keys with `key, subkey = jax.random.split(key)`. Refresh your memory in [🔪 JAX - The Sharp Bits 🔪 Randomness and PRNG keys](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers).\n", "\n", "Flax provides an _implicit_ way of handling PRNG key streams via [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html#module-basics)'s [`flax.linen.Module.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.make_rng) helper function. It allows the code in Flax `Module`s (or its sub-`Module`s) to \"pull PRNG keys\". `make_rng` guarantees to provide a unique key each time you call it. See the [RNG guide](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html) for more details.\n", "\n", "> Note: Recall that [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) is the base class for all neural network modules. All layers and models are subclassed from it.\n", "\n", "### Example\n", "\n", "Remember that each of the Flax PRNG streams has a name. The example below uses the `'params'` stream for initializing parameters, as well as the `'dropout'` stream. The PRNG key provided to [`flax.linen.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/init_apply.html) is the one that seeds the `'params'` PRNG key stream. To draw PRNG keys during the forward pass (with dropout), provide a PRNG key to seed that stream (`'dropout'`) when you call `Module.apply()`." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Setup.\n", "import jax\n", "import jax.numpy as jnp\n", "import flax.linen as nn" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Randomness.\n", "seed = 0\n", "root_key = jax.random.key(seed=seed)\n", "main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)\n", "\n", "# A simple network.\n", "class MyModel(nn.Module):\n", " num_neurons: int\n", " training: bool\n", " @nn.compact\n", " def __call__(self, x):\n", " x = nn.Dense(self.num_neurons)(x)\n", " # Set the dropout layer with a rate of 50% .\n", " # When the `deterministic` flag is `True`, dropout is turned off.\n", " x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)\n", " return x\n", "\n", "# Instantiate `MyModel` (you don't need to set `training=True` to\n", "# avoid performing the forward pass computation).\n", "my_model = MyModel(num_neurons=3, training=False)\n", "\n", "x = jax.random.uniform(key=main_key, shape=(3, 4, 4))\n", "\n", "# Initialize with `flax.linen.init()`.\n", "# The `params_key` is equivalent to a dictionary of PRNGs.\n", "# (Here, you are providing only one PRNG key.) \n", "variables = my_model.init(params_key, x)\n", "\n", "# Perform the forward pass with `flax.linen.apply()`.\n", "y = my_model.apply(variables, x, rngs={'dropout': dropout_key})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Real-life examples:\n", "\n", "* Applying word dropout to a batch of input IDs (in a [text classification](https://github.com/google/flax/blob/main/examples/sst2/models.py) context).\n", "* Defining a prediction token in a decoder of a [sequence-to-sequence model](https://github.com/google/flax/blob/main/examples/seq2seq/models.py)." ] } ], "metadata": { "accelerator": "GPU", "jupytext": { "formats": "ipynb,md:myst" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: docs/guides/flax_sharp_bits.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # 🔪 Flax - The Sharp Bits 🔪 Flax exposes the full power of JAX. And just like when using JAX, there are certain _["sharp bits"](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)_ you may experience when working with Flax. This evolving document is designed to assist you with them. First, install and/or update Flax: ```{code-cell} ipython3 :tags: [skip-execution] ! pip install -qq flax ``` ## 🔪 `flax.linen.Dropout` layer and randomness ### TL;DR When working on a model with dropout (subclassed from [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html#module-basics)), add the `'dropout'` PRNGkey only during the forward pass. 1. Start with [`jax.random.split()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.split.html#jax-random-split) to explicitly create PRNG keys for `'params'` and `'dropout'`. 2. Add the [`flax.linen.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.Dropout) layer(s) to your model (subclassed from Flax [`Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html#module-basics)). 3. When initializing the model ([`flax.linen.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/init_apply.html)), there's no need to pass in an extra `'dropout'` PRNG key—just the `'params'` key like in a "simpler" model. 4. During the forward pass with [`flax.linen.apply()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/init_apply.html), pass in `rngs={'dropout': dropout_key}`. Check out a full example below. ### Why this works - Internally, `flax.linen.Dropout` makes use of [`flax.linen.Module.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.make_rng) to create a key for dropout (check out the [source code](https://github.com/google/flax/blob/5714e57a0dc8146eb58a7a06ed768ed3a17672f9/flax/linen/stochastic.py#L72)). - Every time `make_rng` is called (in this case, it's done implicitly in `Dropout`), you get a new PRNG key split from the main/root PRNG key. - `make_rng` still _guarantees full reproducibility_. ### Background The [dropout](https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf) stochastic regularization technique randomly removes hidden and visible units in a network. Dropout is a random operation, requiring a PRNG state, and Flax (like JAX) uses [Threefry](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) PRNG that is splittable. > Note: Recall that JAX has an explicit way of giving you PRNG keys: you can fork the main PRNG state (such as `key = jax.random.key(seed=0)`) into multiple new PRNG keys with `key, subkey = jax.random.split(key)`. Refresh your memory in [🔪 JAX - The Sharp Bits 🔪 Randomness and PRNG keys](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers). Flax provides an _implicit_ way of handling PRNG key streams via [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html#module-basics)'s [`flax.linen.Module.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.make_rng) helper function. It allows the code in Flax `Module`s (or its sub-`Module`s) to "pull PRNG keys". `make_rng` guarantees to provide a unique key each time you call it. See the [RNG guide](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html) for more details. > Note: Recall that [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) is the base class for all neural network modules. All layers and models are subclassed from it. ### Example Remember that each of the Flax PRNG streams has a name. The example below uses the `'params'` stream for initializing parameters, as well as the `'dropout'` stream. The PRNG key provided to [`flax.linen.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/init_apply.html) is the one that seeds the `'params'` PRNG key stream. To draw PRNG keys during the forward pass (with dropout), provide a PRNG key to seed that stream (`'dropout'`) when you call `Module.apply()`. ```{code-cell} ipython3 # Setup. import jax import jax.numpy as jnp import flax.linen as nn ``` ```{code-cell} ipython3 # Randomness. seed = 0 root_key = jax.random.key(seed=seed) main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3) # A simple network. class MyModel(nn.Module): num_neurons: int training: bool @nn.compact def __call__(self, x): x = nn.Dense(self.num_neurons)(x) # Set the dropout layer with a rate of 50% . # When the `deterministic` flag is `True`, dropout is turned off. x = nn.Dropout(rate=0.5, deterministic=not self.training)(x) return x # Instantiate `MyModel` (you don't need to set `training=True` to # avoid performing the forward pass computation). my_model = MyModel(num_neurons=3, training=False) x = jax.random.uniform(key=main_key, shape=(3, 4, 4)) # Initialize with `flax.linen.init()`. # The `params_key` is equivalent to a dictionary of PRNGs. # (Here, you are providing only one PRNG key.) variables = my_model.init(params_key, x) # Perform the forward pass with `flax.linen.apply()`. y = my_model.apply(variables, x, rngs={'dropout': dropout_key}) ``` Real-life examples: * Applying word dropout to a batch of input IDs (in a [text classification](https://github.com/google/flax/blob/main/examples/sst2/models.py) context). * Defining a prediction token in a decoder of a [sequence-to-sequence model](https://github.com/google/flax/blob/main/examples/seq2seq/models.py). ================================================ FILE: docs/guides/index.rst ================================================ Guides ====== .. toctree:: :maxdepth: 2 flax_fundamentals/index data_preprocessing/index training_techniques/index parallel_training/index model_inspection/index converting_and_upgrading/index quantization/index The Sharp Bits ================================================ FILE: docs/guides/model_inspection/extracting_intermediates.rst ================================================ Extracting intermediate values ============================== This guide will show you how to extract intermediate values from a module. Let's start with this simple CNN that uses :code:`nn.compact`. .. testsetup:: default, sow import flax import flax.linen as nn import jax import jax.numpy as jnp from flax.core import FrozenDict from typing import Sequence batch = jnp.ones((4, 32, 32, 3)) .. testcode:: from flax import linen as nn import jax import jax.numpy as jnp from typing import Sequence class CNN(nn.Module): @nn.compact def __call__(self, x): x = nn.Conv(features=32, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = nn.Conv(features=64, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = x.reshape((x.shape[0], -1)) # flatten x = nn.Dense(features=256)(x) x = nn.relu(x) x = nn.Dense(features=10)(x) x = nn.log_softmax(x) return x Because this module uses ``nn.compact``, we don't have direct access to intermediate values. There are a few ways to expose them: Store intermediate values in a new variable collection ------------------------------------------------------ The CNN can be augmented with calls to ``sow`` to store intermediates as following: .. codediff:: :title: Default CNN, CNN using sow API :groups: default, sow class CNN(nn.Module): @nn.compact def __call__(self, x): x = nn.Conv(features=32, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = nn.Conv(features=64, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = x.reshape((x.shape[0], -1)) # flatten x = nn.Dense(features=256)(x) x = nn.relu(x) x = nn.Dense(features=10)(x) x = nn.log_softmax(x) return x --- class SowCNN(nn.Module): @nn.compact def __call__(self, x): x = nn.Conv(features=32, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = nn.Conv(features=64, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = x.reshape((x.shape[0], -1)) # flatten self.sow('intermediates', 'features', x) #! x = nn.Dense(features=256)(x) x = nn.relu(x) x = nn.Dense(features=10)(x) x = nn.log_softmax(x) return x ``sow`` acts as a no-op when the variable collection is not mutable. Therefore, it works perfectly for debugging and optional tracking of intermediates. The 'intermediates' collection is also used by the ``capture_intermediates`` API (see the :ref:`Use ``capture_intermediates``` section). Note that, by default ``sow`` appends values every time it is called: * This is necessary because once instantiated, a module could be called multiple times in its parent module, and we want to catch all the sowed values. * Therefore you want to make sure that you **do not** feed intermediate values back into ``variables``. Otherwise every call will increase the length of that tuple and trigger a recompile. * To override the default append behavior, specify ``init_fn`` and ``reduce_fn`` - see :meth:`Module.sow() `. .. testcode:: sow class SowCNN2(nn.Module): @nn.compact def __call__(self, x): mod = SowCNN(name='SowCNN') return mod(x) + mod(x) # Calling same module instance twice. @jax.jit def init(key, x): variables = SowCNN2().init(key, x) # By default the 'intermediates' collection is not mutable during init. # So variables will only contain 'params' here. return variables @jax.jit def predict(variables, x): # If mutable='intermediates' is not specified, then .sow() acts as a noop. output, mod_vars = SowCNN2().apply(variables, x, mutable='intermediates') features = mod_vars['intermediates']['SowCNN']['features'] return output, features batch = jnp.ones((1,28,28,1)) variables = init(jax.random.key(0), batch) preds, feats = predict(variables, batch) assert len(feats) == 2 # Tuple with two values since module was called twice. Refactor module into submodules ------------------------------- This is a useful pattern for cases where it's clear in what particular way you want to split your submodules. Any submodule you expose in ``setup`` can be used directly. In the limit, you can define all submodules in ``setup`` and avoid using ``nn.compact`` altogether. .. testcode:: class RefactoredCNN(nn.Module): def setup(self): self.features = Features() self.classifier = Classifier() def __call__(self, x): x = self.features(x) x = self.classifier(x) return x class Features(nn.Module): @nn.compact def __call__(self, x): x = nn.Conv(features=32, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = nn.Conv(features=64, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = x.reshape((x.shape[0], -1)) # flatten return x class Classifier(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense(features=256)(x) x = nn.relu(x) x = nn.Dense(features=10)(x) x = nn.log_softmax(x) return x @jax.jit def init(key, x): variables = RefactoredCNN().init(key, x) return variables['params'] @jax.jit def features(params, x): return RefactoredCNN().apply({"params": params}, x, method=lambda module, x: module.features(x)) params = init(jax.random.key(0), batch) features(params, batch) Use ``capture_intermediates`` ----------------------------- Linen supports the capture of intermediate return values from submodules automatically without any code changes. This pattern should be considered the "sledge hammer" approach to capturing intermediates. As a debugging and inspection tool it is very useful, but using the other patterns described in this guide will give you more fine-grained control over what intermediates you want to extract. In the following code example we check if any intermediate activations are non-finite (NaN or infinite): .. testcode:: @jax.jit def init(key, x): variables = CNN().init(key, x) return variables @jax.jit def predict(variables, x): y, state = CNN().apply(variables, x, capture_intermediates=True, mutable=["intermediates"]) intermediates = state['intermediates'] fin = jax.tree_util.tree_map(lambda xs: jnp.all(jnp.isfinite(xs)), intermediates) return y, fin variables = init(jax.random.key(0), batch) y, is_finite = predict(variables, batch) all_finite = all(jax.tree_util.tree_leaves(is_finite)) assert all_finite, "non-finite intermediate detected!" By default only the intermediates of ``__call__`` methods are collected. Alternatively, you can pass a custom filter function based on the ``Module`` instance and the method name. .. testcode:: filter_Dense = lambda mdl, method_name: isinstance(mdl, nn.Dense) filter_encodings = lambda mdl, method_name: method_name == "encode" y, state = CNN().apply(variables, batch, capture_intermediates=filter_Dense, mutable=["intermediates"]) dense_intermediates = state['intermediates'] Note that ``capture_intermediates`` will only apply to layers. You can use ``self.sow`` to manually store non-layer intermediates, but the filter function won't be applied to it. .. codediff:: :title: Capturing all layer intermediates, Using filter function and ``self.sow()`` :groups: default, sow class Model(nn.Module): @nn.compact def __call__(self, x): a = nn.Dense(4)(x) # Dense_0 b = nn.Dense(4)(x) # Dense_1 c = a + b # not a Flax layer, so won't be stored as an intermediate d = nn.Dense(4)(c) # Dense_2 return d @jax.jit def init(key, x): variables = Model().init(key, x) return variables['params'] @jax.jit def predict(params, x): return Model().apply({"params": params}, x, capture_intermediates=True) batch = jax.random.uniform(jax.random.key(1), (1,3)) params = init(jax.random.key(0), batch) preds, feats = predict(params, batch) feats # intermediate c in Model was not stored because it's not a Flax layer --- class Model(nn.Module): @nn.compact def __call__(self, x): a = nn.Dense(4)(x) # Dense_0 b = nn.Dense(4)(x) # Dense_1 c = a + b self.sow('intermediates', 'c', c) # store intermediate c #! d = nn.Dense(4)(c) # Dense_2 return d @jax.jit def init(key, x): variables = Model().init(key, x) return variables['params'] @jax.jit def predict(params, x): # filter specifically for only the Dense_0 and Dense_2 layer #! filter_fn = lambda mdl, method_name: isinstance(mdl.name, str) and (mdl.name in {'Dense_0', 'Dense_2'}) #! return Model().apply({"params": params}, x, capture_intermediates=filter_fn) #! batch = jax.random.uniform(jax.random.key(1), (1,3)) params = init(jax.random.key(0), batch) preds, feats = predict(params, batch) feats # intermediate c in Model is stored and isn't filtered out by the filter function #! To separate the intermediates extracted from ``self.sow`` from the intermediates extracted from ``capture_intermediates``, we can either define a separate collection like ``self.sow('sow_intermediates', 'c', c)``, or manually filter out the intermediates after calling ``.apply()``. For example: .. testcode:: sow flattened_dict = flax.traverse_util.flatten_dict(feats['intermediates'], sep='/') flattened_dict['c'] In terms of efficiency, as long as everything is jitted, then any intermediates you don't end up using should be optimized away by XLA. Use ``Sequential`` --------------------- You could also define ``CNN`` using a simple implementation of a ``Sequential`` combinator (this is quite common in more stateful approaches). This may be useful for very simple models and gives you arbitrary model surgery. But it can be very limiting -- if you even want to add one conditional, you are forced to refactor away from ``Sequential`` and structure your model more explicitly. .. testcode:: class Sequential(nn.Module): layers: Sequence[nn.Module] def __call__(self, x): for layer in self.layers: x = layer(x) return x def SeqCNN(): return Sequential([ nn.Conv(features=32, kernel_size=(3, 3)), nn.relu, lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)), nn.Conv(features=64, kernel_size=(3, 3)), nn.relu, lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)), lambda x: x.reshape((x.shape[0], -1)), # flatten nn.Dense(features=256), nn.relu, nn.Dense(features=10), nn.log_softmax, ]) @jax.jit def init(key, x): variables = SeqCNN().init(key, x) return variables['params'] @jax.jit def features(params, x): return Sequential(SeqCNN().layers[0:7]).apply({"params": params}, x) batch = jnp.ones((1,28,28,1)) params = init(jax.random.key(0), batch) features(params, batch) Extracting gradients of intermediate values =========================================== For debugging purposes, it can be useful to extract the gradients of intermediate values. This can be done by using the :meth:`Module.perturb() ` method over the desired values. .. testcode:: class Model(nn.Module): @nn.compact def __call__(self, x): x = nn.relu(nn.Dense(8)(x)) x = self.perturb('hidden', x) x = nn.Dense(2)(x) x = self.perturb('logits', x) return x ``perturb`` adds a variable to a ``perturbations`` collection by default, it behaves like an identity function and the gradient of the perturbation matches the gradient of the input. To get the perturbations just initialize the model: .. testcode:: x = jnp.empty((1, 4)) # random data y = jnp.empty((1, 2)) # random data model = Model() variables = model.init(jax.random.key(1), x) params, perturbations = variables['params'], variables['perturbations'] Finally compute the gradients of the loss with respect to the perturbations, these will match the gradients of the intermediates: .. testcode:: def loss_fn(params, perturbations, x, y): y_pred = model.apply({'params': params, 'perturbations': perturbations}, x) return jnp.mean((y_pred - y) ** 2) intermediate_grads = jax.grad(loss_fn, argnums=1)(params, perturbations, x, y) ================================================ FILE: docs/guides/model_inspection/index.rst ================================================ Model inspection ================ .. toctree:: :maxdepth: 1 model_surgery extracting_intermediates ================================================ FILE: docs/guides/model_inspection/model_surgery.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "120e57f5", "metadata": {}, "source": [ "Model surgery\n", "==============================\n", "\n", "Usually, Flax modules and optimizers track and update the params for you. But there may be some time when you want to do some model surgery and tweak the param tensors yourself. This guide shows you how to do the trick." ] }, { "cell_type": "markdown", "id": "9c3bfb0e", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "id": "413f8b2d", "metadata": { "tags": [ "skip-execution" ] }, "outputs": [], "source": [ "!pip install --upgrade -q pip jax jaxlib flax" ] }, { "cell_type": "code", "execution_count": null, "id": "5b002c8d", "metadata": {}, "outputs": [], "source": [ "import functools\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "from flax import traverse_util\n", "from flax import linen as nn\n", "from flax.core import freeze\n", "import jax\n", "import optax" ] }, { "cell_type": "markdown", "id": "1060b519", "metadata": {}, "source": [ "Surgery with Flax Modules\n", "--------------------------------\n", "\n", "Let's create a small convolutional neural network model for our demo.\n", "\n", "As usual, you can run `CNN.init(...)['params']` to get the `params` to pass and modify it in every step of your training.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "755ae323", "metadata": {}, "outputs": [], "source": [ "class CNN(nn.Module):\n", " @nn.compact\n", " def __call__(self, x):\n", " x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n", " x = nn.relu(x)\n", " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", " x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n", " x = nn.relu(x)\n", " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", " x = x.reshape((x.shape[0], -1))\n", " x = nn.Dense(features=256)(x)\n", " x = nn.relu(x)\n", " x = nn.Dense(features=10)(x)\n", " x = nn.log_softmax(x)\n", " return x\n", "\n", "def get_initial_params(key):\n", " init_shape = jnp.ones((1, 28, 28, 1), jnp.float32)\n", " initial_params = CNN().init(key, init_shape)['params']\n", " return initial_params\n", "\n", "key = jax.random.key(0)\n", "params = get_initial_params(key)\n", "\n", "jax.tree_util.tree_map(jnp.shape, params)" ] }, { "cell_type": "markdown", "id": "170273f8", "metadata": {}, "source": [ "Note that what returned as `params` is a `FrozenDict`, which contains a few JAX arrays as kernel and bias. \n", "\n", "A `FrozenDict` is nothing more than a read-only dict, and Flax made it read-only because of the functional nature of JAX: JAX arrays are immutable, and the new `params` need to replace the old `params`. Making the dict read-only ensures that no in-place mutation of the dict can happen accidentally during the training and updating.\n", "\n", "One way to actually modify the params outside of a Flax module is to explicitly flatten it and creates a mutable dict. Note that you can use a separator `sep` to join all nested keys. If no `sep` is given, the key will be a tuple of all nested keys." ] }, { "cell_type": "code", "execution_count": null, "id": "c7ec7741", "metadata": {}, "outputs": [], "source": [ "# Get a flattened key-value list.\n", "flat_params = traverse_util.flatten_dict(params, sep='/')\n", "\n", "jax.tree_util.tree_map(jnp.shape, flat_params)" ] }, { "cell_type": "markdown", "id": "2adda656", "metadata": {}, "source": [ "Now you can do whatever you want with the params. When you are done, unflatten it back and use it in future training." ] }, { "cell_type": "code", "execution_count": null, "id": "bb975feb", "metadata": {}, "outputs": [], "source": [ "# Somehow modify a layer\n", "dense_kernel = flat_params['Dense_1/kernel']\n", "flat_params['Dense_1/kernel'] = dense_kernel / jnp.linalg.norm(dense_kernel)\n", "\n", "# Unflatten.\n", "unflat_params = traverse_util.unflatten_dict(flat_params, sep='/')\n", "# Refreeze.\n", "unflat_params = freeze(unflat_params)\n", "jax.tree_util.tree_map(jnp.shape, unflat_params)" ] }, { "cell_type": "markdown", "id": "f3462cd8", "metadata": {}, "source": [ "Surgery with Optimizers\n", "--------------------------------\n", "\n", "When using `Optax` as an optimizer, the ``opt_state`` is actually a nested tuple\n", "of the states of individual gradient transformations that compose the optimizer.\n", "These states contain pytrees that mirror the parameter tree, and can be modified\n", "the same way: flattening, modifying, unflattening, and then recreating a new\n", "optimizer state that mirrors the original state." ] }, { "cell_type": "code", "execution_count": null, "id": "3cbecb63", "metadata": {}, "outputs": [], "source": [ "tx = optax.adam(1.0)\n", "opt_state = tx.init(params)\n", "\n", "# The optimizer state is a tuple of gradient transformation states.\n", "jax.tree_util.tree_map(jnp.shape, opt_state)" ] }, { "cell_type": "markdown", "id": "18f1cebb", "metadata": {}, "source": [ "The pytrees inside the optimizer state follow the same structure as the\n", "parameters and can be flattened / modified exactly the same way." ] }, { "cell_type": "code", "execution_count": null, "id": "13b5e25f", "metadata": {}, "outputs": [], "source": [ "flat_mu = traverse_util.flatten_dict(opt_state[0].mu, sep='/')\n", "flat_nu = traverse_util.flatten_dict(opt_state[0].nu, sep='/')\n", "\n", "jax.tree_util.tree_map(jnp.shape, flat_mu)" ] }, { "cell_type": "markdown", "id": "e5c4479e", "metadata": {}, "source": [ "After modification, re-create optimizer state. Use this for future training." ] }, { "cell_type": "code", "execution_count": null, "id": "9dcac8cd", "metadata": {}, "outputs": [], "source": [ "opt_state = (\n", " opt_state[0]._replace(\n", " mu=traverse_util.unflatten_dict(flat_mu, sep='/'),\n", " nu=traverse_util.unflatten_dict(flat_nu, sep='/'),\n", " ),\n", ") + opt_state[1:]\n", "jax.tree_util.tree_map(jnp.shape, opt_state)" ] } ], "metadata": { "jupytext": { "formats": "md,ipynb", "main_language": "python" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.15" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: docs/guides/model_inspection/model_surgery.md ================================================ --- jupyter: jupytext: formats: md,ipynb main_language: python text_representation: extension: .md format_name: markdown format_version: '1.3' jupytext_version: 1.13.8 --- Model surgery ============================== Usually, Flax modules and optimizers track and update the params for you. But there may be some time when you want to do some model surgery and tweak the param tensors yourself. This guide shows you how to do the trick. ## Setup ```python tags=["skip-execution"] !pip install --upgrade -q pip jax jaxlib flax ``` ```python import functools import jax import jax.numpy as jnp from flax import traverse_util from flax import linen as nn from flax.core import freeze import jax import optax ``` Surgery with Flax Modules -------------------------------- Let's create a small convolutional neural network model for our demo. As usual, you can run `CNN.init(...)['params']` to get the `params` to pass and modify it in every step of your training. ```python class CNN(nn.Module): @nn.compact def __call__(self, x): x = nn.Conv(features=32, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = nn.Conv(features=64, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = x.reshape((x.shape[0], -1)) x = nn.Dense(features=256)(x) x = nn.relu(x) x = nn.Dense(features=10)(x) x = nn.log_softmax(x) return x def get_initial_params(key): init_shape = jnp.ones((1, 28, 28, 1), jnp.float32) initial_params = CNN().init(key, init_shape)['params'] return initial_params key = jax.random.key(0) params = get_initial_params(key) jax.tree_util.tree_map(jnp.shape, params) ``` Note that what returned as `params` is a `FrozenDict`, which contains a few JAX arrays as kernel and bias. A `FrozenDict` is nothing more than a read-only dict, and Flax made it read-only because of the functional nature of JAX: JAX arrays are immutable, and the new `params` need to replace the old `params`. Making the dict read-only ensures that no in-place mutation of the dict can happen accidentally during the training and updating. One way to actually modify the params outside of a Flax module is to explicitly flatten it and creates a mutable dict. Note that you can use a separator `sep` to join all nested keys. If no `sep` is given, the key will be a tuple of all nested keys. ```python # Get a flattened key-value list. flat_params = traverse_util.flatten_dict(params, sep='/') jax.tree_util.tree_map(jnp.shape, flat_params) ``` Now you can do whatever you want with the params. When you are done, unflatten it back and use it in future training. ```python # Somehow modify a layer dense_kernel = flat_params['Dense_1/kernel'] flat_params['Dense_1/kernel'] = dense_kernel / jnp.linalg.norm(dense_kernel) # Unflatten. unflat_params = traverse_util.unflatten_dict(flat_params, sep='/') # Refreeze. unflat_params = freeze(unflat_params) jax.tree_util.tree_map(jnp.shape, unflat_params) ``` Surgery with Optimizers -------------------------------- When using `Optax` as an optimizer, the ``opt_state`` is actually a nested tuple of the states of individual gradient transformations that compose the optimizer. These states contain pytrees that mirror the parameter tree, and can be modified the same way: flattening, modifying, unflattening, and then recreating a new optimizer state that mirrors the original state. ```python tx = optax.adam(1.0) opt_state = tx.init(params) # The optimizer state is a tuple of gradient transformation states. jax.tree_util.tree_map(jnp.shape, opt_state) ``` The pytrees inside the optimizer state follow the same structure as the parameters and can be flattened / modified exactly the same way. ```python flat_mu = traverse_util.flatten_dict(opt_state[0].mu, sep='/') flat_nu = traverse_util.flatten_dict(opt_state[0].nu, sep='/') jax.tree_util.tree_map(jnp.shape, flat_mu) ``` After modification, re-create optimizer state. Use this for future training. ```python opt_state = ( opt_state[0]._replace( mu=traverse_util.unflatten_dict(flat_mu, sep='/'), nu=traverse_util.unflatten_dict(flat_nu, sep='/'), ), ) + opt_state[1:] jax.tree_util.tree_map(jnp.shape, opt_state) ``` ================================================ FILE: docs/guides/parallel_training/ensembling.rst ================================================ Ensembling on multiple devices ============================== We show how to train an ensemble of CNNs on the MNIST dataset, where the size of the ensemble is equal to the number of available devices. In short, this change be described as: * make a number of functions parallel using |jax.pmap()|_, * split the random seed to obtain different parameter initialization, * replicate the inputs and unreplicate the outputs where necessary, * average probabilities across devices to compute the predictions. In this HOWTO we omit some of the code such as imports, the CNN module, and metrics computation, but they can be found in the `MNIST example`_. .. testsetup:: Single-model, Ensemble import functools from flax import jax_utils # Copied from examples/mnist/train.py from absl import logging from flax import linen as nn from flax.training import train_state import jax import jax.numpy as jnp import numpy as np import optax class CNN(nn.Module): """A simple CNN model.""" @nn.compact def __call__(self, x): x = nn.Conv(features=32, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = nn.Conv(features=64, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = x.reshape((x.shape[0], -1)) # flatten x = nn.Dense(features=256)(x) x = nn.relu(x) x = nn.Dense(features=10)(x) return x # Fake data for faster execution. def get_datasets(): train_ds = test_ds = { 'image': jnp.zeros([64, 28, 28, 1]), 'label': jnp.zeros([64], jnp.int32), } return train_ds, test_ds # Modified from examples/mnist/configs.default.py learning_rate = 0.1 momentum = 0.9 batch_size = 32 num_epochs = 1 Parallel functions ------------------ We start by creating a parallel version of ``create_train_state()``, which retrieves the initial parameters of the models. We do this using |jax.pmap()|_. The effect of "pmapping" a function is that it will compile the function with XLA (similar to |jax.jit()|_), but execute it in parallel on XLA devices (e.g., GPUs/TPUs). .. codediff:: :title: Single-model, Ensemble :sync: #! def create_train_state(rng, learning_rate, momentum): cnn = CNN() params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params'] tx = optax.sgd(learning_rate, momentum) return train_state.TrainState.create( apply_fn=cnn.apply, params=params, tx=tx) --- @functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2)) #! def create_train_state(rng, learning_rate, momentum): cnn = CNN() params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params'] tx = optax.sgd(learning_rate, momentum) return train_state.TrainState.create( apply_fn=cnn.apply, params=params, tx=tx) Note that for the single-model code above, we use |jax.jit()|_ to lazily initialize the model (see `Module.init`_'s documentation for more details). For the ensembling case, |jax.pmap()|_ will map over the first axis of the provided argument ``rng`` by default, so we should make sure that we provide a different value for each device when we call this function later on. Note also how we specify that ``learning_rate`` and ``momentum`` are static arguments, which means the concrete values of these arguments will be used, rather than abstract shapes. This is necessary because the provided arguments will be scalar values. For more details see `JIT mechanics: tracing and static variables`_. Next we simply do the same for the functions ``apply_model()`` and ``update_model()``. To compute the predictions from the ensemble, we take the average of the individual probabilities. We use |jax.lax.pmean()|_ to compute the average *across devices*. This also requires us to specify the ``axis_name`` to both |jax.pmap()|_ and |jax.lax.pmean()|_. .. codediff:: :title: Single-model, Ensemble :sync: @jax.jit #! def apply_model(state, images, labels): def loss_fn(params): logits = CNN().apply({'params': params}, images) one_hot = jax.nn.one_hot(labels, 10) loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean() return loss, logits grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (loss, logits), grads = grad_fn(state.params) #! accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) #! return grads, loss, accuracy @jax.jit #! def update_model(state, grads): return state.apply_gradients(grads=grads) --- @functools.partial(jax.pmap, axis_name='ensemble') #! def apply_model(state, images, labels): def loss_fn(params): logits = CNN().apply({'params': params}, images) one_hot = jax.nn.one_hot(labels, 10) loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean() return loss, logits grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (loss, logits), grads = grad_fn(state.params) probs = jax.lax.pmean(jax.nn.softmax(logits), axis_name='ensemble') #! accuracy = jnp.mean(jnp.argmax(probs, -1) == labels) #! return grads, loss, accuracy @jax.pmap #! def update_model(state, grads): return state.apply_gradients(grads=grads) Training the Ensemble --------------------- Next we transform the ``train_epoch()`` function. When calling the pmapped functions from above, we mainly need to take care of duplicating the arguments for all devices where necessary, and de-duplicating the return values. .. codediff:: :title: Single-model, Ensemble :sync: def train_epoch(state, train_ds, batch_size, rng): train_ds_size = len(train_ds['image']) steps_per_epoch = train_ds_size // batch_size perms = jax.random.permutation(rng, len(train_ds['image'])) perms = perms[:steps_per_epoch * batch_size] perms = perms.reshape((steps_per_epoch, batch_size)) epoch_loss = [] epoch_accuracy = [] for perm in perms: batch_images = train_ds['image'][perm, ...] #! batch_labels = train_ds['label'][perm, ...] #! grads, loss, accuracy = apply_model(state, batch_images, batch_labels) state = update_model(state, grads) epoch_loss.append(loss) #! epoch_accuracy.append(accuracy) #! train_loss = np.mean(epoch_loss) train_accuracy = np.mean(epoch_accuracy) return state, train_loss, train_accuracy --- def train_epoch(state, train_ds, batch_size, rng): train_ds_size = len(train_ds['image']) steps_per_epoch = train_ds_size // batch_size perms = jax.random.permutation(rng, len(train_ds['image'])) perms = perms[:steps_per_epoch * batch_size] perms = perms.reshape((steps_per_epoch, batch_size)) epoch_loss = [] epoch_accuracy = [] for perm in perms: batch_images = jax_utils.replicate(train_ds['image'][perm, ...]) #! batch_labels = jax_utils.replicate(train_ds['label'][perm, ...]) #! grads, loss, accuracy = apply_model(state, batch_images, batch_labels) state = update_model(state, grads) epoch_loss.append(jax_utils.unreplicate(loss)) #! epoch_accuracy.append(jax_utils.unreplicate(accuracy)) #! train_loss = np.mean(epoch_loss) train_accuracy = np.mean(epoch_accuracy) return state, train_loss, train_accuracy As can be seen, we do not have to make any changes to the logic around the ``state``. This is because, as we will see below in our training code, the train state is replicated already, so when we pass it to ``train_step()``, things will just work fine since ``train_step()`` is pmapped. However, the train dataset is not yet replicated, so we do that here. Since replicating the entire train dataset is too memory intensive we do it at the batch level. We can now rewrite the actual training logic. This consists of two simple changes: making sure the RNGs are replicated when we pass them to ``create_train_state()``, and replicating the test dataset, which is much smaller than the train dataset so we can do this for the entire dataset directly. .. codediff:: :title: Single-model, Ensemble :sync: train_ds, test_ds = get_datasets() #! rng = jax.random.key(0) rng, init_rng = jax.random.split(rng) state = create_train_state(init_rng, learning_rate, momentum) #! #! for epoch in range(1, num_epochs + 1): rng, input_rng = jax.random.split(rng) state, train_loss, train_accuracy = train_epoch( state, train_ds, batch_size, input_rng) _, test_loss, test_accuracy = apply_model( #! state, test_ds['image'], test_ds['label']) #! logging.info( 'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, ' 'test_loss: %.4f, test_accuracy: %.2f' % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100)) --- train_ds, test_ds = get_datasets() test_ds = jax_utils.replicate(test_ds) #! rng = jax.random.key(0) rng, init_rng = jax.random.split(rng) state = create_train_state(jax.random.split(init_rng, jax.device_count()), #! learning_rate, momentum) #! for epoch in range(1, num_epochs + 1): rng, input_rng = jax.random.split(rng) state, train_loss, train_accuracy = train_epoch( state, train_ds, batch_size, input_rng) _, test_loss, test_accuracy = jax_utils.unreplicate( #! apply_model(state, test_ds['image'], test_ds['label'])) #! logging.info( 'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, ' 'test_loss: %.4f, test_accuracy: %.2f' % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100)) .. |jax.jit()| replace:: ``jax.jit()`` .. _jax.jit(): https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#To-JIT-or-not-to-JIT .. |jax.pmap()| replace:: ``jax.pmap()`` .. _jax.pmap(): https://jax.readthedocs.io/en/latest/jax.html#jax.pmap .. |jax.lax.pmean()| replace:: ``jax.lax.pmean()`` .. _jax.lax.pmean(): https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.pmean.html .. _Module.init: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init .. _`JIT mechanics: tracing and static variables`: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#JIT-mechanics:-tracing-and-static-variables .. _`MNIST example`: https://github.com/google/flax/blob/main/examples/mnist/train.py ================================================ FILE: docs/guides/parallel_training/flax_on_pjit.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Scale up Flax Modules on multiple devices\n", "\n", "This guide shows how to scale up [Flax Modules](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html) on multiple devices and hosts using [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) (formerly [`experimental.pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html#module-jax.experimental.pjit)) and [`flax.linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/index.html)." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Flax and `jax.jit` scaled up\n", "\n", "[`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm and automatically compiles your code to run it on multiple devices. You need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.\n", "\n", "Flax provides several functionalities that can help you use auto-SPMD on [Flax Modules](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html), including:\n", "\n", "1. An interface to specify partitions of your data when defining [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html).\n", "2. Utility functions to generate the sharding information that `jax.jit` requires to run.\n", "3. An interface to customize your axis names called \"logical axis annotations\" to decouple both your Module code and partition plan to experiment with different partition layouts more easily.\n", "\n", "You can learn more about `jax.jit` APIs for scaling up in [JAX in multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html) and [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) on JAX's documentation site." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Setup\n", "\n", "Import some necessary dependencies.\n", "\n", "**Note:** This guide uses the `--xla_force_host_platform_device_count=8` flag to emulate multiple devices in a CPU environment in a Google Colab/Jupyter Notebook. You don't need this if you are already using a multi-device TPU environment." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ[\"XLA_FLAGS\"] = '--xla_force_host_platform_device_count=8'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Tensorflow library not found, tensorflow.io.gfile operations will use native shim calls. GCS paths (i.e. 'gs://...') cannot be accessed.\n" ] } ], "source": [ "import functools\n", "from typing import Optional, Callable\n", "\n", "import numpy as np\n", "import jax\n", "from jax import lax, random, numpy as jnp\n", "\n", "import flax\n", "from flax import struct, traverse_util, linen as nn\n", "from flax.core import freeze, unfreeze\n", "from flax.training import train_state, checkpoints\n", "\n", "import optax # Optax for common losses and optimizers." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "We have 8 fake JAX devices now: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]\n" ] } ], "source": [ "print(f'We have 8 fake JAX devices now: {jax.devices()}')" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The code below shows how to import and set up the JAX-level device API, following JAX's [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) guide:\n", "\n", "1. Start a 2x4 device `mesh` (8 devices) using JAX's `mesh_utils.create_device_mesh`. This layout is the same as the one of a [TPU v3-8](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#single_tpu_board).\n", "\n", "2. Annotate each axis with a name using the `axis_names` parameter in `jax.sharding.Mesh`. A typical way to annotate axis names is `axis_name=('data', 'model')`, where:\n", " * `'data'`: the mesh dimension used for data-parallel sharding of the batch dimension of inputs and activations.\n", " * `'model'`: the mesh dimension used for sharding parameters of the model across devices.\n", "\n", "3. Make a simple utility function `mesh_sharding` for generating a sharding object from the mesh and any layout." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from jax.sharding import Mesh, PartitionSpec, NamedSharding\n", "from jax.lax import with_sharding_constraint\n", "from jax.experimental import mesh_utils" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)]\n", " [CpuDevice(id=4) CpuDevice(id=5) CpuDevice(id=6) CpuDevice(id=7)]]\n", "Mesh('data': 2, 'model': 4)\n" ] } ], "source": [ "# Create a mesh and annotate each axis with a name.\n", "device_mesh = mesh_utils.create_device_mesh((2, 4))\n", "print(device_mesh)\n", "\n", "mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))\n", "print(mesh)\n", "\n", "def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:\n", " return NamedSharding(mesh, pspec)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Define a layer\n", "\n", "Before defining a simple model, create an example layer called `DotReluDot` (by subclassing `flax.linen.Module`). The layer creates two parameters `W1` and `W2` for dot product multiplication, and uses the `jax.nn.relu` (ReLU) activation function in-between.\n", "\n", "To shard the parameters efficiently, apply the following APIs to annotate the parameters and intermediate variables:\n", "\n", "1. Use [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) to decorate the initializer function when creating sub-layers or raw parameters.\n", "\n", "2. Apply [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) (formerly, `pjit.with_sharding_constraint`) to annotate intermediate variables like `y` and `z` to force a particular sharding pattern when the ideal constraint is known.\n", "\n", " * This step is optional, but can sometimes help auto-SPMD to partition efficiently. In the example below, the call is not required, because XLA will figure out the same sharding layout for `y` and `z` regardless." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "class DotReluDot(nn.Module):\n", " depth: int\n", " dense_init: Callable = nn.initializers.xavier_normal()\n", " @nn.compact\n", " def __call__(self, x):\n", "\n", " y = nn.Dense(self.depth,\n", " kernel_init=nn.with_partitioning(self.dense_init, (None, 'model')),\n", " use_bias=False, # or overwrite with `bias_init`\n", " )(x)\n", "\n", " y = jax.nn.relu(y)\n", " # Force a local sharding annotation.\n", " y = with_sharding_constraint(y, mesh_sharding(PartitionSpec('data', 'model')))\n", "\n", " W2 = self.param(\n", " 'W2',\n", " nn.with_partitioning(self.dense_init, ('model', None)),\n", " (self.depth, x.shape[-1]))\n", "\n", " z = jnp.dot(y, W2)\n", " # Force a local sharding annotation.\n", " z = with_sharding_constraint(z, mesh_sharding(PartitionSpec('data', None)))\n", "\n", " # Return a tuple to conform with the API `flax.linen.scan` as shown in the cell below.\n", " return z, None" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Note that device axis names like `'data'`, `'model'` or `None` are passed into both [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) and [`jax.lax.with_sharding_constraint`](https://github.com/jax-ml/jax/blob/main/jax/_src/pjit.py#L1516) API calls. This refers to how each dimension of this data should be sharded — either across one of the device mesh dimensions, or not sharded at all.\n", "\n", "For example:\n", "\n", "* When you define `W1` with shape `(x.shape[-1], self.depth)` and annotate as `(None, 'model')`:\n", "\n", " * The first dimension (of length `x.shape[-1]`) will be replicated across all devices.\n", " * The second dimension (of length `self.depth`) will be sharded over the `'model'` axis of the device mesh. This means `W1` will be sharded 4-way on devices `(0, 4)`, `(1, 5)`, `(2, 6)` and `(3, 7)`, on this dimension.\n", "\n", "* When you annotate the output `z` as `('data', None)`:\n", "\n", " * The first dimension — the batch dimension — will be sharded over the `'data'` axis. This means half of the batch will be processed on devices `0-3` (first four devices), and another half on devices `4-7` (the remaining four devices).\n", " * The second dimension — the data depth dimension — will be replicated across all devices." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Define a model with `flax.linen.scan` lifted transformation\n", "\n", "Having created `DotReluDot`, you can now define the `MLP` model (by subclassing [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module)) as multiple layers of `DotReluDot`.\n", "\n", "To replicate identical layers, you can either use [`flax.linen.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html#flax.linen.scan), or a for-loop:\n", "\n", "* `flax.linen.scan` can provide faster compilation times.\n", "* The for-loop can be faster on runtime.\n", "\n", "The code below shows how to apply both methods, and default with the for-loop, so that all the parameters are two-dimensional and you can visualize their sharding.\n", "\n", "The `flax.linen.scan` code is just to show that this API works with [Flax lifted transforms](https://flax.readthedocs.io/en/latest/developer_notes/lift.html#supported-transformations)." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class MLP(nn.Module):\n", " num_layers: int\n", " depth: int\n", " use_scan: bool\n", " @nn.compact\n", " def __call__(self, x):\n", " if self.use_scan:\n", " x, _ = nn.scan(DotReluDot, length=self.num_layers,\n", " variable_axes={\"params\": 0},\n", " split_rngs={\"params\": True},\n", " metadata_params={nn.PARTITION_NAME: None}\n", " )(self.depth)(x)\n", " else:\n", " for i in range(self.num_layers):\n", " x, _ = DotReluDot(self.depth)(x)\n", " return x" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now, create a `model` instance, and a sample input `x`." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# MLP hyperparameters.\n", "BATCH, LAYERS, DEPTH, USE_SCAN = 8, 4, 1024, False\n", "# Create fake inputs.\n", "x = jnp.ones((BATCH, DEPTH))\n", "# Initialize a PRNG key.\n", "k = random.key(0)\n", "\n", "# Create an Optax optimizer.\n", "optimizer = optax.adam(learning_rate=0.001)\n", "# Instantiate the model.\n", "model = MLP(LAYERS, DEPTH, USE_SCAN)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Specify sharding\n", "\n", "Next, you need to tell `jax.jit` how to shard our data across devices.\n", "\n", "### The input's sharding\n", "\n", "For data parallelism, you can shard the batched _input_ `x` across the `data` axis by denoting the batch axis as `'data'`. Then, use [`jax.device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html) to place it onto the correct `device`s." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
┌──────────────────────────────────────────────────────────────────────────────┐\n",
       "│                                                                              │\n",
       "│                                 CPU 0,1,2,3                                  │\n",
       "│                                                                              │\n",
       "│                                                                              │\n",
       "├──────────────────────────────────────────────────────────────────────────────┤\n",
       "│                                                                              │\n",
       "│                                 CPU 4,5,6,7                                  │\n",
       "│                                                                              │\n",
       "│                                                                              │\n",
       "└──────────────────────────────────────────────────────────────────────────────┘\n",
       "
\n" ], "text/plain": [ "┌──────────────────────────────────────────────────────────────────────────────┐\n", "│ │\n", "│ CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n", "│ │\n", "│ │\n", "├──────────────────────────────────────────────────────────────────────────────┤\n", "│ │\n", "│ CPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m │\n", "│ │\n", "│ │\n", "└──────────────────────────────────────────────────────────────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "x_sharding = mesh_sharding(PartitionSpec('data', None)) # dimensions: (batch, length)\n", "x = jax.device_put(x, x_sharding)\n", "jax.debug.visualize_array_sharding(x)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### The output's sharding\n", "\n", "You need to compile `model.init()` (that is, [`flax.linen.Module.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init)), and its output as a pytree of parameters. Additionally, you may sometimes need wrap it with a [`flax.training.train_state`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) to track other variables, such as optimizer states, and that would make the output an even more complex pytree.\n", "\n", "To achieve this, luckily, you don't have to hardcode the output's sharding by hand. Instead, you can:\n", "\n", "1. Evaluate `model.init` (in this case, a wrapper of it) abstractly using [`jax.eval_shape`](https://jax.readthedocs.io/en/latest/_autosummary/jax.eval_shape.html).\n", "\n", "1. Use [`flax.linen.get_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.get_sharding) to automatically generate the `jax.sharding.NamedSharding`.\n", " * This step utilizes the [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) annotations in the earlier definition to generate the correct sharding for the parameters." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "def init_fn(k, x, model, optimizer):\n", " variables = model.init(k, x) # Initialize the model.\n", " state = train_state.TrainState.create( # Create a `TrainState`.\n", " apply_fn=model.apply,\n", " params=variables['params'],\n", " tx=optimizer)\n", " return state" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TrainState(step=NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec()), apply_fn=, params={'DotReluDot_0': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_1': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_2': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_3': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}}, tx=GradientTransformationExtraArgs(init=.init_fn at 0x33e134280>, update=.update_fn at 0x33e134430>), opt_state=(ScaleByAdamState(count=NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec()), mu={'DotReluDot_0': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_1': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_2': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_3': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}}, nu={'DotReluDot_0': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_1': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_2': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}, 'DotReluDot_3': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None))}}), EmptyState()))" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create an abstract closure to wrap the function before feeding it in\n", "# because `jax.eval_shape` only takes pytrees as arguments.\n", "abstract_variables = jax.eval_shape(\n", " functools.partial(init_fn, model=model, optimizer=optimizer), k, x)\n", "\n", "# This `state_sharding` has the same pytree structure as `state`, the output\n", "# of the `init_fn`.\n", "state_sharding = nn.get_sharding(abstract_variables, mesh)\n", "state_sharding" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compile the code\n", "\n", "Now you can apply [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) to your `init_fn`, but with two extra arguments: `in_shardings` and `out_shardings`.\n", "\n", "Run it to get the `initialized_state`, in which parameters are sharded exactly as instructed:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
┌───────┬───────┬───────┬───────┐\n",
       "│       │       │       │       │\n",
       "│       │       │       │       │\n",
       "│       │       │       │       │\n",
       "│       │       │       │       │\n",
       "│CPU 0,4│CPU 1,5│CPU 2,6│CPU 3,7│\n",
       "│       │       │       │       │\n",
       "│       │       │       │       │\n",
       "│       │       │       │       │\n",
       "│       │       │       │       │\n",
       "└───────┴───────┴───────┴───────┘\n",
       "
\n" ], "text/plain": [ "┌───────┬───────┬───────┬───────┐\n", "│ │ │ │ │\n", "│ │ │ │ │\n", "│ │ │ │ │\n", "│ │ │ │ │\n", "│CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m4\u001b[0m│CPU \u001b[1;36m1\u001b[0m,\u001b[1;36m5\u001b[0m│CPU \u001b[1;36m2\u001b[0m,\u001b[1;36m6\u001b[0m│CPU \u001b[1;36m3\u001b[0m,\u001b[1;36m7\u001b[0m│\n", "│ │ │ │ │\n", "│ │ │ │ │\n", "│ │ │ │ │\n", "│ │ │ │ │\n", "└───────┴───────┴───────┴───────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┌───────────────────────┐\n",
       "│        CPU 0,4        │\n",
       "├───────────────────────┤\n",
       "│        CPU 1,5        │\n",
       "├───────────────────────┤\n",
       "│        CPU 2,6        │\n",
       "├───────────────────────┤\n",
       "│        CPU 3,7        │\n",
       "└───────────────────────┘\n",
       "
\n" ], "text/plain": [ "┌───────────────────────┐\n", "│ CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m4\u001b[0m │\n", "├───────────────────────┤\n", "│ CPU \u001b[1;36m1\u001b[0m,\u001b[1;36m5\u001b[0m │\n", "├───────────────────────┤\n", "│ CPU \u001b[1;36m2\u001b[0m,\u001b[1;36m6\u001b[0m │\n", "├───────────────────────┤\n", "│ CPU \u001b[1;36m3\u001b[0m,\u001b[1;36m7\u001b[0m │\n", "└───────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3),\n", " in_shardings=(mesh_sharding(PartitionSpec()), x_sharding), # PRNG key and x\n", " out_shardings=state_sharding)\n", "\n", "initialized_state = jit_init_fn(k, x, model, optimizer)\n", "\n", "# for weight, partitioned in initialized_state.params['DotReluDot_0'].items():\n", "# print(f'Sharding of {weight}: {partitioned.names}')\n", "jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)\n", "jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inspect the Module output\n", "\n", "Note that in the output of `initialized_state`, the `params` `W1` and `W2` are of type [`flax.linen.Partitioned`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.Partitioned). This is a wrapper around the actual `jax.Array` that allows Flax to record the axis names associated with it.\n", "\n", "You can access the raw `jax.Array`s by calling `flax.linen.meta.unbox()` upon the dictionary, or call `.value` upon individual variable. You can also use `flax.linen.meta.replace_boxed()` to change the underlying `jax.Array` without modifying the sharding annotations." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "(None, 'model')\n", "(1024, 1024)\n" ] } ], "source": [ "print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel']))\n", "print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value))\n", "print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].names)\n", "print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.shape)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# Say for some unknown reason you want to make the whole param tree all-zero\n", "unboxed_params = nn.meta.unbox(initialized_state.params)\n", "all_zero = jax.tree.map(jnp.zeros_like, unboxed_params)\n", "all_zero_params = nn.meta.replace_boxed(initialized_state.params, all_zero)\n", "assert jnp.sum(nn.meta.unbox(all_zero_params['DotReluDot_0']['Dense_0']['kernel'])) == 0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also check the underlying [`jax.sharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html) of each parameter, which is now more internal than `NamedSharding`. Note that numbers like `initialized_state.step` are replicated across all devices." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'))" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.sharding" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0\n" ] }, { "data": { "text/plain": [ "NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec())" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(initialized_state.step)\n", "initialized_state.step.sharding" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can use [`jax.tree_util.tree_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_map.html) to perform mass computation on a dict of boxed params, in the same way as on a dict of JAX arrays." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'Dense_0': {'kernel': Partitioned(value=(1024, 1024), names=(None, 'model'), mesh=None)}, 'W2': Partitioned(value=(1024, 1024), names=('model', None), mesh=None)}\n", "\n", "(1024, 1024)\n" ] } ], "source": [ "diff = jax.tree_util.tree_map(\n", " lambda a, b: a - b,\n", " initialized_state.params['DotReluDot_0'], initialized_state.params['DotReluDot_0'])\n", "print(jax.tree_util.tree_map(jnp.shape, diff))\n", "diff_array = diff['Dense_0']['kernel'].value\n", "print(type(diff_array))\n", "print(diff_array.shape)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Compile the train step and inference\n", "\n", "Create a `jit`ted training step as follows:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "@functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding),\n", " out_shardings=state_sharding)\n", "def train_step(state, x):\n", " # A fake loss function.\n", " def loss_unrolled(params):\n", " y = model.apply({'params': params}, x)\n", " return y.sum()\n", " grad_fn = jax.grad(loss_unrolled)\n", " grads = grad_fn(state.params)\n", " state = state.apply_gradients(grads=grads)\n", " return state\n", "\n", "with jax.set_mesh(mesh):\n", " new_state = train_step(initialized_state, x)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sharding of Weight 1:\n" ] }, { "data": { "text/html": [ "
┌───────┬───────┬───────┬───────┐\n",
       "│       │       │       │       │\n",
       "│       │       │       │       │\n",
       "│       │       │       │       │\n",
       "│       │       │       │       │\n",
       "│CPU 0,4│CPU 1,5│CPU 2,6│CPU 3,7│\n",
       "│       │       │       │       │\n",
       "│       │       │       │       │\n",
       "│       │       │       │       │\n",
       "│       │       │       │       │\n",
       "└───────┴───────┴───────┴───────┘\n",
       "
\n" ], "text/plain": [ "┌───────┬───────┬───────┬───────┐\n", "│ │ │ │ │\n", "│ │ │ │ │\n", "│ │ │ │ │\n", "│ │ │ │ │\n", "│CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m4\u001b[0m│CPU \u001b[1;36m1\u001b[0m,\u001b[1;36m5\u001b[0m│CPU \u001b[1;36m2\u001b[0m,\u001b[1;36m6\u001b[0m│CPU \u001b[1;36m3\u001b[0m,\u001b[1;36m7\u001b[0m│\n", "│ │ │ │ │\n", "│ │ │ │ │\n", "│ │ │ │ │\n", "│ │ │ │ │\n", "└───────┴───────┴───────┴───────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Sharding of Weight 2:\n" ] }, { "data": { "text/html": [ "
┌───────────────────────┐\n",
       "│        CPU 0,4        │\n",
       "├───────────────────────┤\n",
       "│        CPU 1,5        │\n",
       "├───────────────────────┤\n",
       "│        CPU 2,6        │\n",
       "├───────────────────────┤\n",
       "│        CPU 3,7        │\n",
       "└───────────────────────┘\n",
       "
\n" ], "text/plain": [ "┌───────────────────────┐\n", "│ CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m4\u001b[0m │\n", "├───────────────────────┤\n", "│ CPU \u001b[1;36m1\u001b[0m,\u001b[1;36m5\u001b[0m │\n", "├───────────────────────┤\n", "│ CPU \u001b[1;36m2\u001b[0m,\u001b[1;36m6\u001b[0m │\n", "├───────────────────────┤\n", "│ CPU \u001b[1;36m3\u001b[0m,\u001b[1;36m7\u001b[0m │\n", "└───────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print(f'Sharding of Weight 1:')\n", "jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)\n", "print(f'Sharding of Weight 2:')\n", "jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Then, create a compiled inference step. Note that the output is also sharded along `(data, None)`." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "float32\n", "(8, 1024)\n" ] }, { "data": { "text/html": [ "
┌──────────────────────────────────────────────────────────────────────────────┐\n",
       "│                                                                              │\n",
       "│                                 CPU 0,1,2,3                                  │\n",
       "│                                                                              │\n",
       "│                                                                              │\n",
       "├──────────────────────────────────────────────────────────────────────────────┤\n",
       "│                                                                              │\n",
       "│                                 CPU 4,5,6,7                                  │\n",
       "│                                                                              │\n",
       "│                                                                              │\n",
       "└──────────────────────────────────────────────────────────────────────────────┘\n",
       "
\n" ], "text/plain": [ "┌──────────────────────────────────────────────────────────────────────────────┐\n", "│ │\n", "│ CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n", "│ │\n", "│ │\n", "├──────────────────────────────────────────────────────────────────────────────┤\n", "│ │\n", "│ CPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m │\n", "│ │\n", "│ │\n", "└──────────────────────────────────────────────────────────────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "@functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding),\n", " out_shardings=x_sharding)\n", "def apply_fn(state, x):\n", " return state.apply_fn({'params': state.params}, x)\n", "\n", "with jax.set_mesh(mesh):\n", " y = apply_fn(new_state, x)\n", "print(type(y))\n", "print(y.dtype)\n", "print(y.shape)\n", "jax.debug.visualize_array_sharding(y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Profiling\n", "\n", "If you are running on a TPU pod or a pod slice, you can use a custom `block_all` utility function, as defined below, to measure the performance:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "20.9 ms ± 319 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], "source": [ "%%timeit\n", "\n", "def block_all(xs):\n", " jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)\n", " return xs\n", "\n", "with jax.set_mesh(mesh):\n", " new_state = block_all(train_step(initialized_state, x))" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Logical axis annotation\n", "\n", "JAX's automatic SPMD encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you actually can annotate the dimensions of any data with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`).\n", "\n", "The `LogicalDotReluDot` and `LogicalMLP` Module definition below are similar to the Modules you created earlier, except for the following:\n", "\n", "1. All axes are annotated with more concrete, meaningful names, such as `'embed'`, `'hidden'`, `'batch'` and `'layer'`. These names are referred to as _logical axis names_ in Flax. They make the dimensional changes inside model definitions more readable.\n", "\n", "2. [`flax.linen.with_logical_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_logical_partitioning) replaces `flax.linen.with_partitioning`; and [`flax.linen.with_logical_constraint`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_logical_constraint) replaces `jax.lax.with_sharding_constraint`, to recognize the logical axis names." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "class LogicalDotReluDot(nn.Module):\n", " depth: int\n", " dense_init: Callable = nn.initializers.xavier_normal()\n", " @nn.compact\n", " def __call__(self, x):\n", " y = nn.Dense(self.depth,\n", " kernel_init=nn.with_logical_partitioning(self.dense_init, ('embed', 'hidden')),\n", " use_bias=False, # or overwrite with `bias_init`\n", " )(x)\n", "\n", " y = jax.nn.relu(y)\n", " # Force a local sharding annotation.\n", " y = with_sharding_constraint(y, mesh_sharding(PartitionSpec('data', 'model')))\n", "\n", " W2 = self.param(\n", " 'W2',\n", " nn.with_logical_partitioning(self.dense_init, ('hidden', 'embed')),\n", " (self.depth, x.shape[-1]))\n", "\n", " z = jnp.dot(y, W2)\n", " # Force a local sharding annotation.\n", " z = nn.with_logical_constraint(z, ('batch', 'embed'))\n", " return z, None\n", "\n", "class LogicalMLP(nn.Module):\n", " num_layers: int\n", " depth: int\n", " use_scan: bool\n", " @nn.compact\n", " def __call__(self, x):\n", " if self.use_scan:\n", " x, _ = nn.scan(LogicalDotReluDot, length=self.num_layers,\n", " variable_axes={\"params\": 0},\n", " split_rngs={\"params\": True},\n", " metadata_params={nn.PARTITION_NAME: 'layer'}\n", " )(self.depth)(x)\n", " else:\n", " for i in range(self.num_layers):\n", " x, _ = LogicalDotReluDot(self.depth)(x)\n", " return x" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now, initiate a model and try to figure out what sharding its `state` should have.\n", "\n", "To allow the device mesh to take your model correctly, you need to decide which of these logical axis names are mapped to the device axis `'data'` or `'model'`. This rule is a list of (`logical_axis_name`, `device_axis_name`) tuples, and [`flax.linen.logical_to_mesh_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.logical_to_mesh_sharding) will convert them to the kind of sharding that the device mesh can understand.\n", "\n", "This allows you to change the rules and try out new partition layouts without modifying the model definition." ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "annotations are logical, not mesh-specific: PartitionSpec('embed', 'hidden')\n", "sharding annotations are mesh-specific: PartitionSpec(None, 'model')\n" ] } ], "source": [ "# Unspecified rule means unsharded by default, so no need to specify `('embed', None)` and `('layer', None)`.\n", "rules = (('batch', 'data'),\n", " ('hidden', 'model'))\n", "\n", "logical_model = LogicalMLP(LAYERS, DEPTH, USE_SCAN)\n", "\n", "logical_abstract_variables = jax.eval_shape(\n", " functools.partial(init_fn, model=logical_model, optimizer=optimizer), k, x)\n", "logical_state_spec = nn.get_partition_spec(logical_abstract_variables)\n", "print('annotations are logical, not mesh-specific: ',\n", " logical_state_spec.params['LogicalDotReluDot_0']['Dense_0']['kernel'])\n", "\n", "logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, rules)\n", "print('sharding annotations are mesh-specific: ',\n", " logical_state_sharding.params['LogicalDotReluDot_0']['Dense_0']['kernel'].spec)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "You can verify that the `logical_state_spec` here has the same content as `state_spec` in the previous (\"non-logical\") example. This allows you to `jax.jit` your Module's [`flax.linen.Module.init`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init) and [`flax.linen.Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply) the same way in the above above." ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "state_sharding.params['DotReluDot_0'] == logical_state_sharding.params['LogicalDotReluDot_0']" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "logical_jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3),\n", " in_shardings=(mesh_sharding(PartitionSpec()), x_sharding), # PRNG key and x\n", " out_shardings=logical_state_sharding)\n", "\n", "logical_initialized_state = logical_jit_init_fn(k, x, logical_model, optimizer)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sharding of Weight 1:\n" ] }, { "data": { "text/html": [ "
┌───────┬───────┬───────┬───────┐\n",
       "│       │       │       │       │\n",
       "│       │       │       │       │\n",
       "│       │       │       │       │\n",
       "│       │       │       │       │\n",
       "│CPU 0,4│CPU 1,5│CPU 2,6│CPU 3,7│\n",
       "│       │       │       │       │\n",
       "│       │       │       │       │\n",
       "│       │       │       │       │\n",
       "│       │       │       │       │\n",
       "└───────┴───────┴───────┴───────┘\n",
       "
\n" ], "text/plain": [ "┌───────┬───────┬───────┬───────┐\n", "│ │ │ │ │\n", "│ │ │ │ │\n", "│ │ │ │ │\n", "│ │ │ │ │\n", "│CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m4\u001b[0m│CPU \u001b[1;36m1\u001b[0m,\u001b[1;36m5\u001b[0m│CPU \u001b[1;36m2\u001b[0m,\u001b[1;36m6\u001b[0m│CPU \u001b[1;36m3\u001b[0m,\u001b[1;36m7\u001b[0m│\n", "│ │ │ │ │\n", "│ │ │ │ │\n", "│ │ │ │ │\n", "│ │ │ │ │\n", "└───────┴───────┴───────┴───────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Sharding of Weight 2:\n" ] }, { "data": { "text/html": [ "
┌───────────────────────┐\n",
       "│        CPU 0,4        │\n",
       "├───────────────────────┤\n",
       "│        CPU 1,5        │\n",
       "├───────────────────────┤\n",
       "│        CPU 2,6        │\n",
       "├───────────────────────┤\n",
       "│        CPU 3,7        │\n",
       "└───────────────────────┘\n",
       "
\n" ], "text/plain": [ "┌───────────────────────┐\n", "│ CPU \u001b[1;36m0\u001b[0m,\u001b[1;36m4\u001b[0m │\n", "├───────────────────────┤\n", "│ CPU \u001b[1;36m1\u001b[0m,\u001b[1;36m5\u001b[0m │\n", "├───────────────────────┤\n", "│ CPU \u001b[1;36m2\u001b[0m,\u001b[1;36m6\u001b[0m │\n", "├───────────────────────┤\n", "│ CPU \u001b[1;36m3\u001b[0m,\u001b[1;36m7\u001b[0m │\n", "└───────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print(f'Sharding of Weight 1:')\n", "jax.debug.visualize_array_sharding(logical_initialized_state.params['LogicalDotReluDot_0']['Dense_0']['kernel'].value)\n", "print(f'Sharding of Weight 2:')\n", "jax.debug.visualize_array_sharding(logical_initialized_state.params['LogicalDotReluDot_0']['W2'].value)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## When to use device axis / logical axis\n", "\n", "Choosing when to use a device or logical axis depends on how much you want to control the partitioning of your model:\n", "\n", "* **Device mesh axis**: If you want a very simple model, or you are very confident of your way of partitioning, defining it with __device mesh axis__ can potentially save you a few extra lines of code of converting the logical naming back to the device naming.\n", "\n", "* **Logical naming**: On the other hand, the __logical naming__ helpers can be useful for exploring different sharding layouts. Use this if you want to experiment around and find the most optimal partition layout for your model.\n", "\n", "* **Device axis names**: In really advanced use cases, you may have more complicated sharding patterns that require annotating *activation* dimension names differently from *parameter* dimension names. If you wish to have more fine-grained control on manual mesh assignments, directly using __device axis names__ could be more helpful." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Save the data\n", "\n", "To save the cross-device array, you can use Orbax as shown in the [Save and load checkpoints guide - Multi-host/multi-process checkpointing](https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#multi-host-multi-process-checkpointing). This is especially required if you are running on a multi-host environment (for example, a TPU pod).\n", "\n", "In practice, you might want to save the raw `jax.Array` pytree as checkpoint, instead of the wrapped `Partitioned` values, to reduce complexity. You can restore it as-is and put it back into an annotated pytree with `flax.linen.meta.replace_boxed()`.\n", "\n", "Keep in mind that to restore the arrays to the desired partition, you need to provide a sample `target` pytree that has the same structure and has the desired [`jax.sharding.Sharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Sharding) in place for each JAX array. The sharding you use to restore the array doesn't necessarily need to be the same as the ones you used to store the array." ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: docs/guides/parallel_training/flax_on_pjit.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # Scale up Flax Modules on multiple devices This guide shows how to scale up [Flax Modules](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html) on multiple devices and hosts using [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) (formerly [`experimental.pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html#module-jax.experimental.pjit)) and [`flax.linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/index.html). +++ ## Flax and `jax.jit` scaled up [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm and automatically compiles your code to run it on multiple devices. You need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications. Flax provides several functionalities that can help you use auto-SPMD on [Flax Modules](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html), including: 1. An interface to specify partitions of your data when defining [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html). 2. Utility functions to generate the sharding information that `jax.jit` requires to run. 3. An interface to customize your axis names called "logical axis annotations" to decouple both your Module code and partition plan to experiment with different partition layouts more easily. You can learn more about `jax.jit` APIs for scaling up in [JAX in multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html) and [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) on JAX's documentation site. +++ ## Setup Import some necessary dependencies. **Note:** This guide uses the `--xla_force_host_platform_device_count=8` flag to emulate multiple devices in a CPU environment in a Google Colab/Jupyter Notebook. You don't need this if you are already using a multi-device TPU environment. ```{code-cell} ipython3 import os os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' ``` ```{code-cell} ipython3 import functools from typing import Optional, Callable import numpy as np import jax from jax import lax, random, numpy as jnp import flax from flax import struct, traverse_util, linen as nn from flax.core import freeze, unfreeze from flax.training import train_state, checkpoints import optax # Optax for common losses and optimizers. ``` ```{code-cell} ipython3 print(f'We have 8 fake JAX devices now: {jax.devices()}') ``` The code below shows how to import and set up the JAX-level device API, following JAX's [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) guide: 1. Start a 2x4 device `mesh` (8 devices) using JAX's `mesh_utils.create_device_mesh`. This layout is the same as the one of a [TPU v3-8](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#single_tpu_board). 2. Annotate each axis with a name using the `axis_names` parameter in `jax.sharding.Mesh`. A typical way to annotate axis names is `axis_name=('data', 'model')`, where: * `'data'`: the mesh dimension used for data-parallel sharding of the batch dimension of inputs and activations. * `'model'`: the mesh dimension used for sharding parameters of the model across devices. 3. Make a simple utility function `mesh_sharding` for generating a sharding object from the mesh and any layout. ```{code-cell} ipython3 from jax.sharding import Mesh, PartitionSpec, NamedSharding from jax.lax import with_sharding_constraint from jax.experimental import mesh_utils ``` ```{code-cell} ipython3 # Create a mesh and annotate each axis with a name. device_mesh = mesh_utils.create_device_mesh((2, 4)) print(device_mesh) mesh = Mesh(devices=device_mesh, axis_names=('data', 'model')) print(mesh) def mesh_sharding(pspec: PartitionSpec) -> NamedSharding: return NamedSharding(mesh, pspec) ``` ## Define a layer Before defining a simple model, create an example layer called `DotReluDot` (by subclassing `flax.linen.Module`). The layer creates two parameters `W1` and `W2` for dot product multiplication, and uses the `jax.nn.relu` (ReLU) activation function in-between. To shard the parameters efficiently, apply the following APIs to annotate the parameters and intermediate variables: 1. Use [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) to decorate the initializer function when creating sub-layers or raw parameters. 2. Apply [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) (formerly, `pjit.with_sharding_constraint`) to annotate intermediate variables like `y` and `z` to force a particular sharding pattern when the ideal constraint is known. * This step is optional, but can sometimes help auto-SPMD to partition efficiently. In the example below, the call is not required, because XLA will figure out the same sharding layout for `y` and `z` regardless. ```{code-cell} ipython3 class DotReluDot(nn.Module): depth: int dense_init: Callable = nn.initializers.xavier_normal() @nn.compact def __call__(self, x): y = nn.Dense(self.depth, kernel_init=nn.with_partitioning(self.dense_init, (None, 'model')), use_bias=False, # or overwrite with `bias_init` )(x) y = jax.nn.relu(y) # Force a local sharding annotation. y = with_sharding_constraint(y, mesh_sharding(PartitionSpec('data', 'model'))) W2 = self.param( 'W2', nn.with_partitioning(self.dense_init, ('model', None)), (self.depth, x.shape[-1])) z = jnp.dot(y, W2) # Force a local sharding annotation. z = with_sharding_constraint(z, mesh_sharding(PartitionSpec('data', None))) # Return a tuple to conform with the API `flax.linen.scan` as shown in the cell below. return z, None ``` Note that device axis names like `'data'`, `'model'` or `None` are passed into both [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) and [`jax.lax.with_sharding_constraint`](https://github.com/jax-ml/jax/blob/main/jax/_src/pjit.py#L1516) API calls. This refers to how each dimension of this data should be sharded — either across one of the device mesh dimensions, or not sharded at all. For example: * When you define `W1` with shape `(x.shape[-1], self.depth)` and annotate as `(None, 'model')`: * The first dimension (of length `x.shape[-1]`) will be replicated across all devices. * The second dimension (of length `self.depth`) will be sharded over the `'model'` axis of the device mesh. This means `W1` will be sharded 4-way on devices `(0, 4)`, `(1, 5)`, `(2, 6)` and `(3, 7)`, on this dimension. * When you annotate the output `z` as `('data', None)`: * The first dimension — the batch dimension — will be sharded over the `'data'` axis. This means half of the batch will be processed on devices `0-3` (first four devices), and another half on devices `4-7` (the remaining four devices). * The second dimension — the data depth dimension — will be replicated across all devices. +++ ## Define a model with `flax.linen.scan` lifted transformation Having created `DotReluDot`, you can now define the `MLP` model (by subclassing [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module)) as multiple layers of `DotReluDot`. To replicate identical layers, you can either use [`flax.linen.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html#flax.linen.scan), or a for-loop: * `flax.linen.scan` can provide faster compilation times. * The for-loop can be faster on runtime. The code below shows how to apply both methods, and default with the for-loop, so that all the parameters are two-dimensional and you can visualize their sharding. The `flax.linen.scan` code is just to show that this API works with [Flax lifted transforms](https://flax.readthedocs.io/en/latest/developer_notes/lift.html#supported-transformations). ```{code-cell} ipython3 class MLP(nn.Module): num_layers: int depth: int use_scan: bool @nn.compact def __call__(self, x): if self.use_scan: x, _ = nn.scan(DotReluDot, length=self.num_layers, variable_axes={"params": 0}, split_rngs={"params": True}, metadata_params={nn.PARTITION_NAME: None} )(self.depth)(x) else: for i in range(self.num_layers): x, _ = DotReluDot(self.depth)(x) return x ``` Now, create a `model` instance, and a sample input `x`. ```{code-cell} ipython3 # MLP hyperparameters. BATCH, LAYERS, DEPTH, USE_SCAN = 8, 4, 1024, False # Create fake inputs. x = jnp.ones((BATCH, DEPTH)) # Initialize a PRNG key. k = random.key(0) # Create an Optax optimizer. optimizer = optax.adam(learning_rate=0.001) # Instantiate the model. model = MLP(LAYERS, DEPTH, USE_SCAN) ``` ## Specify sharding Next, you need to tell `jax.jit` how to shard our data across devices. ### The input's sharding For data parallelism, you can shard the batched _input_ `x` across the `data` axis by denoting the batch axis as `'data'`. Then, use [`jax.device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html) to place it onto the correct `device`s. ```{code-cell} ipython3 x_sharding = mesh_sharding(PartitionSpec('data', None)) # dimensions: (batch, length) x = jax.device_put(x, x_sharding) jax.debug.visualize_array_sharding(x) ``` ### The output's sharding You need to compile `model.init()` (that is, [`flax.linen.Module.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init)), and its output as a pytree of parameters. Additionally, you may sometimes need wrap it with a [`flax.training.train_state`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) to track other variables, such as optimizer states, and that would make the output an even more complex pytree. To achieve this, luckily, you don't have to hardcode the output's sharding by hand. Instead, you can: 1. Evaluate `model.init` (in this case, a wrapper of it) abstractly using [`jax.eval_shape`](https://jax.readthedocs.io/en/latest/_autosummary/jax.eval_shape.html). 1. Use [`flax.linen.get_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.get_sharding) to automatically generate the `jax.sharding.NamedSharding`. * This step utilizes the [`flax.linen.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) annotations in the earlier definition to generate the correct sharding for the parameters. ```{code-cell} ipython3 def init_fn(k, x, model, optimizer): variables = model.init(k, x) # Initialize the model. state = train_state.TrainState.create( # Create a `TrainState`. apply_fn=model.apply, params=variables['params'], tx=optimizer) return state ``` ```{code-cell} ipython3 # Create an abstract closure to wrap the function before feeding it in # because `jax.eval_shape` only takes pytrees as arguments. abstract_variables = jax.eval_shape( functools.partial(init_fn, model=model, optimizer=optimizer), k, x) # This `state_sharding` has the same pytree structure as `state`, the output # of the `init_fn`. state_sharding = nn.get_sharding(abstract_variables, mesh) state_sharding ``` ## Compile the code Now you can apply [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) to your `init_fn`, but with two extra arguments: `in_shardings` and `out_shardings`. Run it to get the `initialized_state`, in which parameters are sharded exactly as instructed: ```{code-cell} ipython3 jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3), in_shardings=(mesh_sharding(PartitionSpec()), x_sharding), # PRNG key and x out_shardings=state_sharding) initialized_state = jit_init_fn(k, x, model, optimizer) # for weight, partitioned in initialized_state.params['DotReluDot_0'].items(): # print(f'Sharding of {weight}: {partitioned.names}') jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value) jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value) ``` ## Inspect the Module output Note that in the output of `initialized_state`, the `params` `W1` and `W2` are of type [`flax.linen.Partitioned`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.Partitioned). This is a wrapper around the actual `jax.Array` that allows Flax to record the axis names associated with it. You can access the raw `jax.Array`s by calling `flax.linen.meta.unbox()` upon the dictionary, or call `.value` upon individual variable. You can also use `flax.linen.meta.replace_boxed()` to change the underlying `jax.Array` without modifying the sharding annotations. ```{code-cell} ipython3 print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'])) print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)) print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].names) print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.shape) ``` ```{code-cell} ipython3 # Say for some unknown reason you want to make the whole param tree all-zero unboxed_params = nn.meta.unbox(initialized_state.params) all_zero = jax.tree.map(jnp.zeros_like, unboxed_params) all_zero_params = nn.meta.replace_boxed(initialized_state.params, all_zero) assert jnp.sum(nn.meta.unbox(all_zero_params['DotReluDot_0']['Dense_0']['kernel'])) == 0 ``` You can also check the underlying [`jax.sharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html) of each parameter, which is now more internal than `NamedSharding`. Note that numbers like `initialized_state.step` are replicated across all devices. ```{code-cell} ipython3 initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.sharding ``` ```{code-cell} ipython3 print(initialized_state.step) initialized_state.step.sharding ``` You can use [`jax.tree_util.tree_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_map.html) to perform mass computation on a dict of boxed params, in the same way as on a dict of JAX arrays. ```{code-cell} ipython3 diff = jax.tree_util.tree_map( lambda a, b: a - b, initialized_state.params['DotReluDot_0'], initialized_state.params['DotReluDot_0']) print(jax.tree_util.tree_map(jnp.shape, diff)) diff_array = diff['Dense_0']['kernel'].value print(type(diff_array)) print(diff_array.shape) ``` ## Compile the train step and inference Create a `jit`ted training step as follows: ```{code-cell} ipython3 @functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding), out_shardings=state_sharding) def train_step(state, x): # A fake loss function. def loss_unrolled(params): y = model.apply({'params': params}, x) return y.sum() grad_fn = jax.grad(loss_unrolled) grads = grad_fn(state.params) state = state.apply_gradients(grads=grads) return state with jax.set_mesh(mesh): new_state = train_step(initialized_state, x) ``` ```{code-cell} ipython3 print(f'Sharding of Weight 1:') jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value) print(f'Sharding of Weight 2:') jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value) ``` Then, create a compiled inference step. Note that the output is also sharded along `(data, None)`. ```{code-cell} ipython3 @functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding), out_shardings=x_sharding) def apply_fn(state, x): return state.apply_fn({'params': state.params}, x) with jax.set_mesh(mesh): y = apply_fn(new_state, x) print(type(y)) print(y.dtype) print(y.shape) jax.debug.visualize_array_sharding(y) ``` ## Profiling If you are running on a TPU pod or a pod slice, you can use a custom `block_all` utility function, as defined below, to measure the performance: ```{code-cell} ipython3 %%timeit def block_all(xs): jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs) return xs with jax.set_mesh(mesh): new_state = block_all(train_step(initialized_state, x)) ``` ## Logical axis annotation JAX's automatic SPMD encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you actually can annotate the dimensions of any data with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`). The `LogicalDotReluDot` and `LogicalMLP` Module definition below are similar to the Modules you created earlier, except for the following: 1. All axes are annotated with more concrete, meaningful names, such as `'embed'`, `'hidden'`, `'batch'` and `'layer'`. These names are referred to as _logical axis names_ in Flax. They make the dimensional changes inside model definitions more readable. 2. [`flax.linen.with_logical_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_logical_partitioning) replaces `flax.linen.with_partitioning`; and [`flax.linen.with_logical_constraint`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_logical_constraint) replaces `jax.lax.with_sharding_constraint`, to recognize the logical axis names. ```{code-cell} ipython3 class LogicalDotReluDot(nn.Module): depth: int dense_init: Callable = nn.initializers.xavier_normal() @nn.compact def __call__(self, x): y = nn.Dense(self.depth, kernel_init=nn.with_logical_partitioning(self.dense_init, ('embed', 'hidden')), use_bias=False, # or overwrite with `bias_init` )(x) y = jax.nn.relu(y) # Force a local sharding annotation. y = with_sharding_constraint(y, mesh_sharding(PartitionSpec('data', 'model'))) W2 = self.param( 'W2', nn.with_logical_partitioning(self.dense_init, ('hidden', 'embed')), (self.depth, x.shape[-1])) z = jnp.dot(y, W2) # Force a local sharding annotation. z = nn.with_logical_constraint(z, ('batch', 'embed')) return z, None class LogicalMLP(nn.Module): num_layers: int depth: int use_scan: bool @nn.compact def __call__(self, x): if self.use_scan: x, _ = nn.scan(LogicalDotReluDot, length=self.num_layers, variable_axes={"params": 0}, split_rngs={"params": True}, metadata_params={nn.PARTITION_NAME: 'layer'} )(self.depth)(x) else: for i in range(self.num_layers): x, _ = LogicalDotReluDot(self.depth)(x) return x ``` Now, initiate a model and try to figure out what sharding its `state` should have. To allow the device mesh to take your model correctly, you need to decide which of these logical axis names are mapped to the device axis `'data'` or `'model'`. This rule is a list of (`logical_axis_name`, `device_axis_name`) tuples, and [`flax.linen.logical_to_mesh_sharding`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.logical_to_mesh_sharding) will convert them to the kind of sharding that the device mesh can understand. This allows you to change the rules and try out new partition layouts without modifying the model definition. ```{code-cell} ipython3 # Unspecified rule means unsharded by default, so no need to specify `('embed', None)` and `('layer', None)`. rules = (('batch', 'data'), ('hidden', 'model')) logical_model = LogicalMLP(LAYERS, DEPTH, USE_SCAN) logical_abstract_variables = jax.eval_shape( functools.partial(init_fn, model=logical_model, optimizer=optimizer), k, x) logical_state_spec = nn.get_partition_spec(logical_abstract_variables) print('annotations are logical, not mesh-specific: ', logical_state_spec.params['LogicalDotReluDot_0']['Dense_0']['kernel']) logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, rules) print('sharding annotations are mesh-specific: ', logical_state_sharding.params['LogicalDotReluDot_0']['Dense_0']['kernel'].spec) ``` You can verify that the `logical_state_spec` here has the same content as `state_spec` in the previous ("non-logical") example. This allows you to `jax.jit` your Module's [`flax.linen.Module.init`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init) and [`flax.linen.Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply) the same way in the above above. ```{code-cell} ipython3 state_sharding.params['DotReluDot_0'] == logical_state_sharding.params['LogicalDotReluDot_0'] ``` ```{code-cell} ipython3 logical_jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3), in_shardings=(mesh_sharding(PartitionSpec()), x_sharding), # PRNG key and x out_shardings=logical_state_sharding) logical_initialized_state = logical_jit_init_fn(k, x, logical_model, optimizer) ``` ```{code-cell} ipython3 print(f'Sharding of Weight 1:') jax.debug.visualize_array_sharding(logical_initialized_state.params['LogicalDotReluDot_0']['Dense_0']['kernel'].value) print(f'Sharding of Weight 2:') jax.debug.visualize_array_sharding(logical_initialized_state.params['LogicalDotReluDot_0']['W2'].value) ``` ## When to use device axis / logical axis Choosing when to use a device or logical axis depends on how much you want to control the partitioning of your model: * **Device mesh axis**: If you want a very simple model, or you are very confident of your way of partitioning, defining it with __device mesh axis__ can potentially save you a few extra lines of code of converting the logical naming back to the device naming. * **Logical naming**: On the other hand, the __logical naming__ helpers can be useful for exploring different sharding layouts. Use this if you want to experiment around and find the most optimal partition layout for your model. * **Device axis names**: In really advanced use cases, you may have more complicated sharding patterns that require annotating *activation* dimension names differently from *parameter* dimension names. If you wish to have more fine-grained control on manual mesh assignments, directly using __device axis names__ could be more helpful. +++ ## Save the data To save the cross-device array, you can use Orbax as shown in the [Save and load checkpoints guide - Multi-host/multi-process checkpointing](https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#multi-host-multi-process-checkpointing). This is especially required if you are running on a multi-host environment (for example, a TPU pod). In practice, you might want to save the raw `jax.Array` pytree as checkpoint, instead of the wrapped `Partitioned` values, to reduce complexity. You can restore it as-is and put it back into an annotated pytree with `flax.linen.meta.replace_boxed()`. Keep in mind that to restore the arrays to the desired partition, you need to provide a sample `target` pytree that has the same structure and has the desired [`jax.sharding.Sharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Sharding) in place for each JAX array. The sharding you use to restore the array doesn't necessarily need to be the same as the ones you used to store the array. ================================================ FILE: docs/guides/parallel_training/index.rst ================================================ Parallel training ================= .. toctree:: :maxdepth: 1 ensembling flax_on_pjit ================================================ FILE: docs/guides/quantization/fp8_basics.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "ca360491", "metadata": {}, "source": [ "# User Guide on Using FP8\n", "\n", "JAX supports various FP8 formats, including E4M3 (jnp.float8_e4m3fn) and E5M2\n", "(jnp.float8_e5m2). Due to the limited range of FP8 data types, higher-precision\n", "data must be scaled to fit within the FP8 representable range, a process known\n", "as quantization (Q). Conversely, de-quantization (DQ) rescales the FP8 data back\n", "to its original type.\n", "\n", "While jnp.dot supports FP8 inputs directly, proper quantization and\n", "dequantization is needed for optimal performance. Flax provides\n", "nn.fp8_ops.Fp8DotGeneral and nn.fp8_ops.Fp8Einsum modules that handle\n", "this automatically and can be used with existing layers like nn.Dense.\n", "\n", "This tutorial will walk you through the basics of how to use it.\n", "\n", "## Setting up our environment\n", "\n", "Here, we provide the code necessary to set up the environment for our notebook.\n", "Additionally, we define a function to check if the XLA-optimized HLO will indeed\n", "call an FP8 dot operation under the hood.\n", "\n", "*Note: This tutorial relies on the XLA-FP8 feature, which is only supported on\n", "NVIDIA Hopper GPUs or later.*" ] }, { "cell_type": "code", "execution_count": null, "id": "177b91c4", "metadata": {}, "outputs": [], "source": [ "import flax\n", "import jax\n", "import re\n", "import pprint\n", "from jax import random\n", "from jax import numpy as jnp\n", "from jax._src import test_util as jtu\n", "from flax import linen as nn\n", "from flax.linen import fp8_ops\n", "\n", "e4m3 = jnp.float8_e4m3fn\n", "f32 = jnp.float32\n", "E4M3_MAX = jnp.finfo(e4m3).max.astype(f32)\n", "\n", "assert jtu.is_cuda_compute_capability_at_least(\"9.0\")\n", "\n", "def check_fp8_call(lowered):\n", " hlo = lowered.compile()\n", " if re.search(r\"custom-call\\(f8e4m3fn.*, f8e4m3fn.*\", hlo.as_text()):\n", " print(\"Fp8 call detected!\")\n", " else:\n", " print(\"No Fp8 call!\")" ] }, { "cell_type": "markdown", "id": "4adc021f", "metadata": {}, "source": [ "## FLAX Low Level API\n", "\n", "The JAX dot operations (e.g. `jnp.dot`) support the FP8 dtype inputs. So it is\n", "legal to do the following call:" ] }, { "cell_type": "code", "execution_count": null, "id": "c54c374e", "metadata": {}, "outputs": [], "source": [ "k0, k1 = random.split(random.key(0), 2)\n", "a = random.uniform(k0, (16, 32))\n", "b = random.uniform(k1, (32, 64))\n", "@jax.jit\n", "def dot_fp8(a, b):\n", " return jnp.dot(a.astype(e4m3), b.astype(e4m3), preferred_element_type=f32)\n", "check_fp8_call(dot_fp8.lower(a, b))" ] }, { "cell_type": "markdown", "id": "adb22878", "metadata": {}, "source": [ "However, this approach has two key limitations:\n", "\n", "1. `jnp.dot` does not support custom scaling factors for operands, defaulting to\n", " a scale of 1.0\n", "2. The autodiff does not automatically use E5M2 for gradients and E4M3 for\n", " activations/weights during training, which is the recommended practice\n", "\n", "To overcome these limitations and implement proper FP8 matrix multiplication, we\n", "recommend using the Flax FP8 APIs. Let's start with a basic scaling approach.\n", "\n", "\n", "### Current Scaling\n", "\n", "Scaling factors are usually defined as `scale = amax(x) / MAX`, where `amax` is\n", "an operation to find the absolute maximum value of the tensor, and `MAX` is the\n", "maximum value of the representable range of the target dtype. This scaling\n", "approach allows us to derive the scaling factors directly from the current\n", "operand tensors of the dot product." ] }, { "cell_type": "code", "execution_count": null, "id": "f0e746e3", "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def dot_fp8(a, b):\n", " a_scale = jnp.max(jnp.abs(A)) / E4M3_MAX\n", " b_scale = jnp.max(jnp.abs(B)) / E4M3_MAX\n", " a = fp8_ops.quantize(a, e4m3, a_scale, f32)\n", " b = fp8_ops.quantize(b, e4m3, b_scale, f32)\n", "\n", " c = jnp.dot(a, b, preferred_element_type=f32)\n", " c = fp8_ops.dequantize(c, f32, a_scale * b_scale)\n", " return c\n", "\n", "c = dot_fp8(a, b)\n", "check_fp8_call(dot_fp8.lower(a, b))" ] }, { "cell_type": "markdown", "id": "59aca6fe", "metadata": {}, "source": [ "As shown in the code, we perform quantization (`fp8_ops.quantize`) on the\n", "tensors to get the lower precision operands. The `jnp.dot` processes them and\n", "accumulates the output in high precision (i.e., the `preferred_element_type`).\n", "After that, we multiply the result by the scaling factors to dequantize back to\n", "the original range (`fp8_ops.dequantize`). Note that while this example uses\n", "E4M3 for both inputs, it is possible to use different FP8 dtypes like E4M3 and\n", "E5M2 for the inputs. The quantization method and the scaling factors can also be\n", "customized based on application needs.\n", "\n", "One major issue with the current scaling method is the performance overhead\n", "introduced by computing `a_scale` and `b_scale`, which requires additional\n", "loading of the operand tensors. To overcome this issue, we recommend the delayed\n", "scaling.\n", "\n", "### Delayed Scaling\n", "\n", "In delayed scaling, we use a scaling factor associated with an amax history. The\n", "scaling factor remains a scalar, but the amax history is a list that stores amax\n", "values from recent steps (e.g., 1024 steps). Both tensors are computed from\n", "previous steps and maintained in the model parameters.\n", "\n", "The quantization and dequantization operations for delayed scaling are provided\n", "by `fp8_ops.in_q` and `fp8_ops.out_dq` respectively. `fp8_ops.in_q` handles\n", "input quantization and update the amax history and scaling factor, while\n", "`fp8_ops.out_dq` performs output dequantization." ] }, { "cell_type": "code", "execution_count": null, "id": "cf466308", "metadata": {}, "outputs": [], "source": [ "a_scale = jnp.array(1.0)\n", "b_scale = jnp.array(1.0)\n", "a_amax_hist = jnp.zeros((1024,))\n", "b_amax_hist = jnp.zeros((1024,))\n", "\n", "@jax.jit\n", "def dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist):\n", " a, a_scale = fp8_ops.in_q(f32, e4m3, a, a_scale, a_amax_hist)\n", " b, b_scale = fp8_ops.in_q(f32, e4m3, b, b_scale, b_amax_hist)\n", " \n", " c = jnp.dot(a, b, preferred_element_type=f32)\n", " c = fp8_ops.out_dq(f32, a_scale, b_scale, c)\n", " return c\n", "\n", "c = dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist)\n", "check_fp8_call(dot_fp8.lower(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist))" ] }, { "cell_type": "markdown", "id": "b3bdc038", "metadata": {}, "source": [ "In this example, we first prepare three pairs of scaling factors and amax\n", "histories, treating them as results computed from previous steps. Then, we apply\n", "`fp8_ops.in_q` to the input operands of `jnp.dot`, followed by `fp8_ops.out_dq`\n", "to the output of `jnp.dot`.\n", "\n", "\n", "## FLAX High Level API\n", "Flax provides high-level operations to seamlessly integrate FP8 quantization\n", "into existing layers. Instead of manually handling quantization of the delayed\n", "scaling (e.g., the maintanence of the amax history and scaling factors), users\n", "can simply use these drop-in replacements:\n", "\n", "* `fp8_ops.Fp8DotGeneral` for `lax.dot_general` operations\n", "* `fp8_ops.Fp8Einsum` for `jnp.einsum` operations \n", "\n", "These operations automatically handle all FP8-related functionality, including\n", "quantization/dequantization, scale factor updates, and FP8 dtype selection for\n", "both forward and backward passes.\n", "\n", "Consider the following example:" ] }, { "cell_type": "code", "execution_count": null, "id": "bd8d9dba", "metadata": {}, "outputs": [], "source": [ "model = nn.Dense(features=64, dot_general_cls=fp8_ops.Fp8DotGeneral)\n", "params = model.init(k0, A)\n", "\n", "@jax.jit\n", "def train_step(var, a): \n", " c = model.apply(var, a)\n", " return jnp.sum(c)\n", "\n", "check_fp8_call(train_step.lower(params, A))" ] }, { "cell_type": "markdown", "id": "ba280e79", "metadata": {}, "source": [ "By setting `dot_general_cls=fp8_ops.Fp8DotGeneral`, we replace the\n", "default `lax.dot_general` operation in `nn.Dense` with an FP8-enabled version.\n", "The model usage remains similar, but now includes additional parameters for FP8\n", "quantization: scaling factors and amax history values. The next section explains\n", "how to update these FP8-specific parameters.\n", "\n", "For models that use `jnp.einsum` operations, such as Mixture of Experts (MoE)\n", "layers, users can replace them with `fp8_ops.Fp8Einsum` to enable FP8\n", "quantization. Here's an example:" ] }, { "cell_type": "code", "execution_count": null, "id": "961b4549", "metadata": {}, "outputs": [], "source": [ "from typing import Any\n", "class FooModule(nn.Module):\n", " einsum: Any = None\n", " @nn.compact\n", " def __call__(self, a, b):\n", " if self.einsum is not None:\n", " einsum_fn = self.einsum()\n", " elif self.einsum is None:\n", " einsum_fn = jnp.einsum\n", " c = einsum_fn(\"mk,kn->mn\", a, b)\n", " return c\n", "\n", "model = FooModule(einsum=fp8_ops.Fp8Einsum)\n", "params = model.init(k0, a, b)\n", "\n", "@jax.jit\n", "def train_step(var, a, b):\n", " c = model.apply(var, a, b)\n", " return jnp.sum(c)\n", "\n", "check_fp8_call(train_step.lower(params, a, b))" ] }, { "cell_type": "markdown", "id": "a83b0851", "metadata": {}, "source": [ "## Manipulate FP8 params\n", "\n", "The following sections explain the internal FP8 parameters managed by\n", "`fp8_ops.Fp8DotGeneral` and `fp8_ops.Fp8Einsum`. These parameters\n", "include scaling factors and amax history values that control the FP8\n", "quantization process. While most users don't need to interact with these\n", "directly, understanding them can be valuable for advanced optimization and\n", "debugging.\n", "\n", "Let's first examine the data structure of `params`. In the code below, we redact\n", "the parameter values and then display the PyTree structure." ] }, { "cell_type": "code", "execution_count": null, "id": "873799fe", "metadata": {}, "outputs": [], "source": [ "params_structure = flax.core.unfreeze(params).copy()\n", "params_structure = flax.traverse_util.flatten_dict(params_structure, sep='/')\n", "for key, value in params_structure.items():\n", " params_structure[key] = '*'\n", "params_structure = flax.traverse_util.unflatten_dict(params_structure, sep='/')\n", "pprint.pprint(params_structure)" ] }, { "cell_type": "markdown", "id": "031894dc", "metadata": {}, "source": [ "The output is as follows:\n", "\n", "```plaintext\n", "{'_overwrite_with_gradient': {'Fp8Einsum_0': {'input_amax_history': '*',\n", " 'input_scale': '*',\n", " 'kernel_amax_history': '*',\n", " 'kernel_scale': '*',\n", " 'output_grad_amax_history': '*',\n", " 'output_grad_scale': '*'}}}\n", "```\n", "\n", "In addition to the expected `params`, there is an additional category called\n", "`_overwrite_with_gradient`. This category includes three pairs of `amax_history`\n", "and `scale` for the activation, kernel, and dot gradient, respectively.\n", "\n", "### Update gradient of FP8 params\n", "Now, we perform one training step to obtain the gradients and see how to use\n", "them to update the parameters." ] }, { "cell_type": "code", "execution_count": null, "id": "593fc35f", "metadata": {}, "outputs": [], "source": [ "step_fn = jax.jit(jax.grad(train_step, (0, 1)))\n", "\n", "grads = step_fn(params, A)\n", "\n", "params = flax.core.unfreeze(params)\n", "params = flax.traverse_util.flatten_dict(params, sep='/')\n", "grads = flax.traverse_util.flatten_dict(grads[0], sep='/')\n", "\n", "for key, value in params.items():\n", " if key.startswith('params'):\n", " params[key] = value + 0.01 * grads[key]\n", " if key.startswith('_overwrite_with_gradient'):\n", " params[key] = grads[key]\n", "\n", "params = flax.traverse_util.unflatten_dict(params, sep='/')\n", "params = flax.core.freeze(params)" ] }, { "cell_type": "markdown", "id": "1a8e2153", "metadata": {}, "source": [ "The above code demonstrates how to update both `params` and\n", "`_overwrite_with_gradient`. For `params`, we use the formula `new_param =\n", "old_param + 0.01 * grads`, where `0.01` is the learning rate (or users can use\n", "whatever optimizers from `optax`). For `_overwrite_with_gradient`, we simply use\n", "the gradient to overwrite the old values.\n", "\n", "Note that `flax.training.train_state.TrainState` conveniently supports the\n", "category of `_overwrite_with_gradient`, so users do not need to modify their\n", "scripts if they don't use custom `TrainState`.\n", "\n", "## Accumulate gradient of FP8 params\n", "When the same parameter is used in a branched manner, the autograd mechanism\n", "will add their gradients from these branches. This is common in scenarios like\n", "pipeline parallelism, where each microbatch shares the same set of parameters\n", "for the minibatch. However, for the `_overwrite_with_gradient` parameters, this\n", "accumulation by addition is not meaningful. Instead, we prefer custom\n", "accumulation by taking the maximum value.\n", "\n", "To address this, we introduce a custom dtype `fp8_ops.fp32_max_grad`. The basic\n", "usage is demonstrated below:" ] }, { "cell_type": "code", "execution_count": null, "id": "2d3a86e9", "metadata": {}, "outputs": [], "source": [ "fmax32 = fp8_ops.fp32_max_grad\n", "\n", "def reuse_fp8_param(x, y, scale, amax_history):\n", " scale = scale.astype(fmax32)\n", " amax_history = amax_history.astype(fmax32)\n", "\n", " x = fp8_ops.in_qdq(f32, e4m3, x, scale, amax_history)\n", " y = fp8_ops.in_qdq(f32, e4m3, y, scale, amax_history)\n", " return x + y\n", "\n", "reuse_fp8_param_fn = jax.grad(reuse_fp8_param, (0, 1, 2, 3))\n", "reuse_fp8_param_fn = jax.jit(reuse_fp8_param_fn)\n", "\n", "_, _, new_ah, new_sf = reuse_fp8_param_fn(2.0, 3.0, a_scale, a_amax_hist)\n", "print(new_ah, new_sf)" ] }, { "cell_type": "markdown", "id": "2321a9bb", "metadata": {}, "source": [ "In this example, we first cast the `scale` and `amax_history` to\n", "`fp8_ops.fp32_max_grad` and then call `fp8_ops.in_qdq` twice using the same pair\n", "of `scale` and `amax_history`. During autograd, their gradients from each branch\n", "will be taken as the maximum, giving us the correct results of:\n", "\n", "```plaintext\n", "1.0 [3. 0. 0. ... 0. 0. 0.]\n", "```\n", "\n", "If we do not perform the type casting, we get the following result, meaning the\n", "gradients of the two branches are added:\n", "\n", "```plaintext\n", "2.0 [5. 0. 0. ... 0. 0. 0.]\n", "```\n", "\n", "This casting is already included if users choose to use the high-level APIs.\n", "\n", "## Deprecated APIs\n", "Previously, we provided APIs like `fp8_ops.quantize_dequantize` for current\n", "scaling and `fp8_ops.[in|out]_qdq` for delayed scaling. These were used with\n", "high precision dot operations, leveraging an XLA-FP8 feature that\n", "pattern-matched QDQ->dot sequences to Q->fp8_cublas_gemm. The corresponding\n", "high-level API was called `fp8_ops.Fp8DotGeneralOp`. However, this pattern\n", "matching-based solution proved brittle, as the patterns could be easily broken\n", "by other XLA optimizations. We recommend users migrate from these deprecated\n", "APIs to the newer ones described above.\n", "\n", "For migration, users should replace:\n", "* `fp8_ops.quantize_dequantize -> jnp.dot` with `fp8_ops.quantize -> jnp.dot ->\n", " fp8_ops.dequantize`\n", "* `fp8_ops.in_qdq -> jnp.dot -> fp8_ops.out_qdq` with `fp8_ops.in_q -> jnp.dot\n", " -> fp8_ops.out_dq`\n", "* `fp8_ops.Fp8DotGeneralOp` with `fp8_ops.Fp8DotGeneral`\n", "\n", "Additionally, we provide an einsum variant through `fp8_ops.Fp8Einsum`." ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: docs/guides/quantization/fp8_basics.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # User Guide on Using FP8 JAX supports various FP8 formats, including E4M3 (jnp.float8_e4m3fn) and E5M2 (jnp.float8_e5m2). Due to the limited range of FP8 data types, higher-precision data must be scaled to fit within the FP8 representable range, a process known as quantization (Q). Conversely, de-quantization (DQ) rescales the FP8 data back to its original type. While jnp.dot supports FP8 inputs directly, proper quantization and dequantization is needed for optimal performance. Flax provides nn.fp8_ops.Fp8DotGeneral and nn.fp8_ops.Fp8Einsum modules that handle this automatically and can be used with existing layers like nn.Dense. This tutorial will walk you through the basics of how to use it. ## Setting up our environment Here, we provide the code necessary to set up the environment for our notebook. Additionally, we define a function to check if the XLA-optimized HLO will indeed call an FP8 dot operation under the hood. *Note: This tutorial relies on the XLA-FP8 feature, which is only supported on NVIDIA Hopper GPUs or later.* ```{code-cell} import flax import jax import re import pprint from jax import random from jax import numpy as jnp from jax._src import test_util as jtu from flax import linen as nn from flax.linen import fp8_ops e4m3 = jnp.float8_e4m3fn f32 = jnp.float32 E4M3_MAX = jnp.finfo(e4m3).max.astype(f32) assert jtu.is_cuda_compute_capability_at_least("9.0") def check_fp8_call(lowered): hlo = lowered.compile() if re.search(r"custom-call\(f8e4m3fn.*, f8e4m3fn.*", hlo.as_text()): print("Fp8 call detected!") else: print("No Fp8 call!") ``` ## FLAX Low Level API The JAX dot operations (e.g. `jnp.dot`) support the FP8 dtype inputs. So it is legal to do the following call: ```{code-cell} k0, k1 = random.split(random.key(0), 2) a = random.uniform(k0, (16, 32)) b = random.uniform(k1, (32, 64)) @jax.jit def dot_fp8(a, b): return jnp.dot(a.astype(e4m3), b.astype(e4m3), preferred_element_type=f32) check_fp8_call(dot_fp8.lower(a, b)) ``` However, this approach has two key limitations: 1. `jnp.dot` does not support custom scaling factors for operands, defaulting to a scale of 1.0 2. The autodiff does not automatically use E5M2 for gradients and E4M3 for activations/weights during training, which is the recommended practice To overcome these limitations and implement proper FP8 matrix multiplication, we recommend using the Flax FP8 APIs. Let's start with a basic scaling approach. ### Current Scaling Scaling factors are usually defined as `scale = amax(x) / MAX`, where `amax` is an operation to find the absolute maximum value of the tensor, and `MAX` is the maximum value of the representable range of the target dtype. This scaling approach allows us to derive the scaling factors directly from the current operand tensors of the dot product. ```{code-cell} @jax.jit def dot_fp8(a, b): a_scale = jnp.max(jnp.abs(A)) / E4M3_MAX b_scale = jnp.max(jnp.abs(B)) / E4M3_MAX a = fp8_ops.quantize(a, e4m3, a_scale, f32) b = fp8_ops.quantize(b, e4m3, b_scale, f32) c = jnp.dot(a, b, preferred_element_type=f32) c = fp8_ops.dequantize(c, f32, a_scale * b_scale) return c c = dot_fp8(a, b) check_fp8_call(dot_fp8.lower(a, b)) ``` As shown in the code, we perform quantization (`fp8_ops.quantize`) on the tensors to get the lower precision operands. The `jnp.dot` processes them and accumulates the output in high precision (i.e., the `preferred_element_type`). After that, we multiply the result by the scaling factors to dequantize back to the original range (`fp8_ops.dequantize`). Note that while this example uses E4M3 for both inputs, it is possible to use different FP8 dtypes like E4M3 and E5M2 for the inputs. The quantization method and the scaling factors can also be customized based on application needs. One major issue with the current scaling method is the performance overhead introduced by computing `a_scale` and `b_scale`, which requires additional loading of the operand tensors. To overcome this issue, we recommend the delayed scaling. ### Delayed Scaling In delayed scaling, we use a scaling factor associated with an amax history. The scaling factor remains a scalar, but the amax history is a list that stores amax values from recent steps (e.g., 1024 steps). Both tensors are computed from previous steps and maintained in the model parameters. The quantization and dequantization operations for delayed scaling are provided by `fp8_ops.in_q` and `fp8_ops.out_dq` respectively. `fp8_ops.in_q` handles input quantization and update the amax history and scaling factor, while `fp8_ops.out_dq` performs output dequantization. ```{code-cell} a_scale = jnp.array(1.0) b_scale = jnp.array(1.0) a_amax_hist = jnp.zeros((1024,)) b_amax_hist = jnp.zeros((1024,)) @jax.jit def dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist): a, a_scale = fp8_ops.in_q(f32, e4m3, a, a_scale, a_amax_hist) b, b_scale = fp8_ops.in_q(f32, e4m3, b, b_scale, b_amax_hist) c = jnp.dot(a, b, preferred_element_type=f32) c = fp8_ops.out_dq(f32, a_scale, b_scale, c) return c c = dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist) check_fp8_call(dot_fp8.lower(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist)) ``` In this example, we first prepare three pairs of scaling factors and amax histories, treating them as results computed from previous steps. Then, we apply `fp8_ops.in_q` to the input operands of `jnp.dot`, followed by `fp8_ops.out_dq` to the output of `jnp.dot`. ## FLAX High Level API Flax provides high-level operations to seamlessly integrate FP8 quantization into existing layers. Instead of manually handling quantization of the delayed scaling (e.g., the maintanence of the amax history and scaling factors), users can simply use these drop-in replacements: * `fp8_ops.Fp8DotGeneral` for `lax.dot_general` operations * `fp8_ops.Fp8Einsum` for `jnp.einsum` operations These operations automatically handle all FP8-related functionality, including quantization/dequantization, scale factor updates, and FP8 dtype selection for both forward and backward passes. Consider the following example: ```{code-cell} model = nn.Dense(features=64, dot_general_cls=fp8_ops.Fp8DotGeneral) params = model.init(k0, A) @jax.jit def train_step(var, a): c = model.apply(var, a) return jnp.sum(c) check_fp8_call(train_step.lower(params, A)) ``` By setting `dot_general_cls=fp8_ops.Fp8DotGeneral`, we replace the default `lax.dot_general` operation in `nn.Dense` with an FP8-enabled version. The model usage remains similar, but now includes additional parameters for FP8 quantization: scaling factors and amax history values. The next section explains how to update these FP8-specific parameters. For models that use `jnp.einsum` operations, such as Mixture of Experts (MoE) layers, users can replace them with `fp8_ops.Fp8Einsum` to enable FP8 quantization. Here's an example: ```{code-cell} from typing import Any class FooModule(nn.Module): einsum: Any = None @nn.compact def __call__(self, a, b): if self.einsum is not None: einsum_fn = self.einsum() elif self.einsum is None: einsum_fn = jnp.einsum c = einsum_fn("mk,kn->mn", a, b) return c model = FooModule(einsum=fp8_ops.Fp8Einsum) params = model.init(k0, a, b) @jax.jit def train_step(var, a, b): c = model.apply(var, a, b) return jnp.sum(c) check_fp8_call(train_step.lower(params, a, b)) ``` ## Manipulate FP8 params The following sections explain the internal FP8 parameters managed by `fp8_ops.Fp8DotGeneral` and `fp8_ops.Fp8Einsum`. These parameters include scaling factors and amax history values that control the FP8 quantization process. While most users don't need to interact with these directly, understanding them can be valuable for advanced optimization and debugging. Let's first examine the data structure of `params`. In the code below, we redact the parameter values and then display the PyTree structure. ```{code-cell} params_structure = flax.core.unfreeze(params).copy() params_structure = flax.traverse_util.flatten_dict(params_structure, sep='/') for key, value in params_structure.items(): params_structure[key] = '*' params_structure = flax.traverse_util.unflatten_dict(params_structure, sep='/') pprint.pprint(params_structure) ``` The output is as follows: ```plaintext {'_overwrite_with_gradient': {'Fp8Einsum_0': {'input_amax_history': '*', 'input_scale': '*', 'kernel_amax_history': '*', 'kernel_scale': '*', 'output_grad_amax_history': '*', 'output_grad_scale': '*'}}} ``` In addition to the expected `params`, there is an additional category called `_overwrite_with_gradient`. This category includes three pairs of `amax_history` and `scale` for the activation, kernel, and dot gradient, respectively. ### Update gradient of FP8 params Now, we perform one training step to obtain the gradients and see how to use them to update the parameters. ```{code-cell} step_fn = jax.jit(jax.grad(train_step, (0, 1))) grads = step_fn(params, A) params = flax.core.unfreeze(params) params = flax.traverse_util.flatten_dict(params, sep='/') grads = flax.traverse_util.flatten_dict(grads[0], sep='/') for key, value in params.items(): if key.startswith('params'): params[key] = value + 0.01 * grads[key] if key.startswith('_overwrite_with_gradient'): params[key] = grads[key] params = flax.traverse_util.unflatten_dict(params, sep='/') params = flax.core.freeze(params) ``` The above code demonstrates how to update both `params` and `_overwrite_with_gradient`. For `params`, we use the formula `new_param = old_param + 0.01 * grads`, where `0.01` is the learning rate (or users can use whatever optimizers from `optax`). For `_overwrite_with_gradient`, we simply use the gradient to overwrite the old values. Note that `flax.training.train_state.TrainState` conveniently supports the category of `_overwrite_with_gradient`, so users do not need to modify their scripts if they don't use custom `TrainState`. ## Accumulate gradient of FP8 params When the same parameter is used in a branched manner, the autograd mechanism will add their gradients from these branches. This is common in scenarios like pipeline parallelism, where each microbatch shares the same set of parameters for the minibatch. However, for the `_overwrite_with_gradient` parameters, this accumulation by addition is not meaningful. Instead, we prefer custom accumulation by taking the maximum value. To address this, we introduce a custom dtype `fp8_ops.fp32_max_grad`. The basic usage is demonstrated below: ```{code-cell} fmax32 = fp8_ops.fp32_max_grad def reuse_fp8_param(x, y, scale, amax_history): scale = scale.astype(fmax32) amax_history = amax_history.astype(fmax32) x = fp8_ops.in_qdq(f32, e4m3, x, scale, amax_history) y = fp8_ops.in_qdq(f32, e4m3, y, scale, amax_history) return x + y reuse_fp8_param_fn = jax.grad(reuse_fp8_param, (0, 1, 2, 3)) reuse_fp8_param_fn = jax.jit(reuse_fp8_param_fn) _, _, new_ah, new_sf = reuse_fp8_param_fn(2.0, 3.0, a_scale, a_amax_hist) print(new_ah, new_sf) ``` In this example, we first cast the `scale` and `amax_history` to `fp8_ops.fp32_max_grad` and then call `fp8_ops.in_qdq` twice using the same pair of `scale` and `amax_history`. During autograd, their gradients from each branch will be taken as the maximum, giving us the correct results of: ```plaintext 1.0 [3. 0. 0. ... 0. 0. 0.] ``` If we do not perform the type casting, we get the following result, meaning the gradients of the two branches are added: ```plaintext 2.0 [5. 0. 0. ... 0. 0. 0.] ``` This casting is already included if users choose to use the high-level APIs. ## Deprecated APIs Previously, we provided APIs like `fp8_ops.quantize_dequantize` for current scaling and `fp8_ops.[in|out]_qdq` for delayed scaling. These were used with high precision dot operations, leveraging an XLA-FP8 feature that pattern-matched QDQ->dot sequences to Q->fp8_cublas_gemm. The corresponding high-level API was called `fp8_ops.Fp8DotGeneralOp`. However, this pattern matching-based solution proved brittle, as the patterns could be easily broken by other XLA optimizations. We recommend users migrate from these deprecated APIs to the newer ones described above. For migration, users should replace: * `fp8_ops.quantize_dequantize -> jnp.dot` with `fp8_ops.quantize -> jnp.dot -> fp8_ops.dequantize` * `fp8_ops.in_qdq -> jnp.dot -> fp8_ops.out_qdq` with `fp8_ops.in_q -> jnp.dot -> fp8_ops.out_dq` * `fp8_ops.Fp8DotGeneralOp` with `fp8_ops.Fp8DotGeneral` Additionally, we provide an einsum variant through `fp8_ops.Fp8Einsum`. ================================================ FILE: docs/guides/quantization/index.rst ================================================ Quantization ============ .. toctree:: :maxdepth: 1 fp8_basics ================================================ FILE: docs/guides/training_techniques/batch_norm.rst ================================================ Batch normalization =================== In this guide, you will learn how to apply `batch normalization `__ using :meth:`flax.linen.BatchNorm `. Batch normalization is a regularization technique used to speed up training and improve convergence. During training, it computes running averages over feature dimensions. This adds a new form of non-differentiable state that must be handled appropriately. Throughout the guide, you will be able to compare code examples with and without Flax ``BatchNorm``. .. testsetup:: No BatchNorm, With BatchNorm import flax.linen as nn import jax.numpy as jnp import jax import optax from typing import Any from flax.core import FrozenDict Defining the model with ``BatchNorm`` ************************************* In Flax, ``BatchNorm`` is a :meth:`flax.linen.Module ` that exhibits different runtime behavior between training and inference. You explicitly specify it via the ``use_running_average`` argument, as demonstrated below. A common pattern is to accept a ``train`` (``training``) argument in the parent Flax ``Module``, and use it to define ``BatchNorm``'s ``use_running_average`` argument. Note: In other machine learning frameworks, like PyTorch or TensorFlow (Keras), this is specified via a mutable state or a call flag (for example, in `torch.nn.Module.eval `__ or ``tf.keras.Model`` by setting the `training `__ flag). .. codediff:: :title: No BatchNorm, With BatchNorm :sync: class MLP(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense(features=4)(x) x = nn.relu(x) x = nn.Dense(features=1)(x) return x --- class MLP(nn.Module): @nn.compact def __call__(self, x, train: bool): #! x = nn.Dense(features=4)(x) x = nn.BatchNorm(use_running_average=not train)(x) #! x = nn.relu(x) x = nn.Dense(features=1)(x) return x Once you create your model, initialize it by calling :meth:`flax.linen.init() ` to get the ``variables`` structure. Here, the main difference between the code without ``BatchNorm`` and with ``BatchNorm`` is that the ``train`` argument must be provided. The ``batch_stats`` collection ****************************** In addition to the ``params`` collection, ``BatchNorm`` also adds a ``batch_stats`` collection that contains the running average of the batch statistics. Note: You can learn more in the ``flax.linen`` `variables `__ API documentation. The ``batch_stats`` collection must be extracted from the ``variables`` for later use. .. codediff:: :title: No BatchNorm, With BatchNorm :sync: mlp = MLP() x = jnp.ones((1, 3)) variables = mlp.init(jax.random.key(0), x) params = variables['params'] jax.tree_util.tree_map(jnp.shape, variables) --- mlp = MLP() x = jnp.ones((1, 3)) variables = mlp.init(jax.random.key(0), x, train=False) #! params = variables['params'] batch_stats = variables['batch_stats'] #! jax.tree_util.tree_map(jnp.shape, variables) Flax ``BatchNorm`` adds a total of 4 variables: ``mean`` and ``var`` that live in the ``batch_stats`` collection, and ``scale`` and ``bias`` that live in the ``params`` collection. .. codediff:: :title: No BatchNorm, With BatchNorm :sync: FrozenDict({ 'params': { 'Dense_0': { 'bias': (4,), 'kernel': (3, 4), }, 'Dense_1': { 'bias': (1,), 'kernel': (4, 1), }, }, }) --- FrozenDict({ 'batch_stats': { #! 'BatchNorm_0': { #! 'mean': (4,), #! 'var': (4,), #! }, #! }, #! 'params': { 'BatchNorm_0': { #! 'bias': (4,), #! 'scale': (4,), #! }, #! 'Dense_0': { 'bias': (4,), 'kernel': (3, 4), }, 'Dense_1': { 'bias': (1,), 'kernel': (4, 1), }, }, }) Modifying ``flax.linen.apply`` ****************************** When using :meth:`flax.linen.apply ` to run your model with the ``train=True`` argument (that is, you have ``use_running_average=False`` in the call to ``BatchNorm``), you need to consider the following: * ``batch_stats`` must be passed as an input variable. * The ``batch_stats`` collection needs to be marked as mutable by setting ``mutable=['batch_stats']``. * The mutated variables are returned as a second output. The updated ``batch_stats`` must be extracted from here. .. codediff:: :title: No BatchNorm, With BatchNorm :sync: y = mlp.apply( {'params': params}, x, ) ... --- y, updates = mlp.apply( #! {'params': params, 'batch_stats': batch_stats}, #! x, train=True, mutable=['batch_stats'] #! ) batch_stats = updates['batch_stats'] #! Training and evaluation *********************** When integrating models that use ``BatchNorm`` into a training loop, the main challenge is handling the additional ``batch_stats`` state. To do this, you need to: * Add a ``batch_stats`` field to a custom :meth:`flax.training.train_state.TrainState ` class. * Pass the ``batch_stats`` values to the :meth:`train_state.TrainState.create ` method. .. codediff:: :title: No BatchNorm, With BatchNorm :sync: from flax.training import train_state state = train_state.TrainState.create( apply_fn=mlp.apply, params=params, tx=optax.adam(1e-3), ) --- from flax.training import train_state class TrainState(train_state.TrainState): #! batch_stats: Any #! state = TrainState.create( #! apply_fn=mlp.apply, params=params, batch_stats=batch_stats, #! tx=optax.adam(1e-3), ) In addition, update your ``train_step`` function to reflect these changes: * Pass all new parameters to ``flax.linen.apply`` (as previously discussed). * The ``updates`` to the ``batch_stats`` must be propagated out of the ``loss_fn``. * The ``batch_stats`` from the ``TrainState`` must be updated. .. codediff:: :title: No BatchNorm, With BatchNorm :sync: @jax.jit def train_step(state: train_state.TrainState, batch): """Train for a single step.""" def loss_fn(params): logits = state.apply_fn( {'params': params}, x=batch['image']) loss = optax.softmax_cross_entropy_with_integer_labels( logits=logits, labels=batch['label']).mean() return loss, logits grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (loss, logits), grads = grad_fn(state.params) state = state.apply_gradients(grads=grads) metrics = { 'loss': loss, 'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']), } return state, metrics --- @jax.jit def train_step(state: TrainState, batch): """Train for a single step.""" def loss_fn(params): logits, updates = state.apply_fn( #! {'params': params, 'batch_stats': state.batch_stats}, #! x=batch['image'], train=True, mutable=['batch_stats']) #! loss = optax.softmax_cross_entropy_with_integer_labels( logits=logits, labels=batch['label']).mean() return loss, (logits, updates) #! grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (loss, (logits, updates)), grads = grad_fn(state.params) #! state = state.apply_gradients(grads=grads) state = state.replace(batch_stats=updates['batch_stats']) #! metrics = { 'loss': loss, 'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']), } return state, metrics The ``eval_step`` is much simpler. Because ``batch_stats`` is not mutable, no updates need to be propagated. Make sure you pass the ``batch_stats`` to ``flax.linen.apply``, and the ``train`` argument is set to ``False``: .. codediff:: :title: No BatchNorm, With BatchNorm :sync: @jax.jit def eval_step(state: train_state.TrainState, batch): """Train for a single step.""" logits = state.apply_fn( {'params': params}, x=batch['image']) loss = optax.softmax_cross_entropy_with_integer_labels( logits=logits, labels=batch['label']).mean() metrics = { 'loss': loss, 'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']), } return state, metrics --- @jax.jit def eval_step(state: TrainState, batch): """Evaluate for a single step.""" logits = state.apply_fn( {'params': state.params, 'batch_stats': state.batch_stats}, #! x=batch['image'], train=False) #! loss = optax.softmax_cross_entropy_with_integer_labels( logits=logits, labels=batch['label']).mean() metrics = { 'loss': loss, 'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']), } return state, metrics ================================================ FILE: docs/guides/training_techniques/dropout.rst ================================================ Dropout ======= This guide provides an overview of how to apply `dropout `__ using :meth:`flax.linen.Dropout`. Dropout is a stochastic regularization technique that randomly removes hidden and visible units in a network. Throughout the guide, you will be able to compare code examples with and without Flax ``Dropout``. .. testsetup:: No Dropout, With Dropout import flax.linen as nn import jax.numpy as jnp import jax import optax Split the PRNG key ****************** Since dropout is a random operation, it requires a pseudorandom number generator (PRNG) state. Flax uses JAX's (splittable) PRNG keys, which have a number of desirable properties for neural networks. To learn more, refer to the `Pseudorandom numbers in JAX tutorial `__. **Note:** Recall that JAX has an explicit way of giving you PRNG keys: you can fork the main PRNG state (such as ``key = jax.random.key(seed=0)``) into multiple new PRNG keys with ``key, subkey = jax.random.split(key)``. You can refresh your memory in `🔪 JAX - The Sharp Bits 🔪 Randomness and PRNG keys `__. Begin by splitting the PRNG key using `jax.random.split() `__ into three keys, including one for Flax Linen ``Dropout``. .. codediff:: :title: No Dropout, With Dropout :sync: root_key = jax.random.key(seed=0) main_key, params_key = jax.random.split(key=root_key) --- root_key = jax.random.key(seed=0) main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3) #! **Note:** In Flax, you provide *PRNG streams* with *names*, so that you can use them later in your :meth:`flax.linen.Module`. For example, you pass the stream ``'params'`` for initializing parameters, and ``'dropout'`` for applying :meth:`flax.linen.Dropout`. Define your model with ``Dropout`` ********************************** To create a model with dropout: * Subclass :meth:`flax.linen.Module`, and then use :meth:`flax.linen.Dropout` to add a dropout layer. Recall that :meth:`flax.linen.Module` is the `base class for all neural network Modules `__, and all layers and models are subclassed from it. * In :meth:`flax.linen.Dropout`, the ``deterministic`` argument is required to be passed as a keyword argument, either: * When constructing the :meth:`flax.linen.Module`; or * When calling :meth:`flax.linen.init()` or :meth:`flax.linen.apply()` on a constructed ``Module``. (Refer to :meth:`flax.linen.module.merge_param` for more details.) * Because ``deterministic`` is a boolean: * If it's set to ``False``, the inputs are masked (that is, set to zero) with a probability set by ``rate``. And the remaining inputs are scaled by ``1 / (1 - rate)``, which ensures that the means of the inputs are preserved. * If it's set to ``True``, no mask is applied (the dropout is turned off), and the inputs are returned as-is. A common pattern is to accept a ``training`` (or ``train``) argument (a boolean) in the parent Flax ``Module``, and use it to enable or disable dropout (as demonstrated in later sections of this guide). In other machine learning frameworks, like PyTorch or TensorFlow (Keras), this is specified via a mutable state or a call flag (for example, in `torch.nn.Module.eval `__ or ``tf.keras.Model`` by setting the `training `__ flag). **Note:** Flax provides an implicit way of handling PRNG key streams via Flax :meth:`flax.linen.Module`'s :meth:`flax.linen.Module.make_rng` method. This allows you to split off a fresh PRNG key inside Flax Modules (or their sub-Modules) from the PRNG stream. The ``make_rng`` method guarantees to provide a unique key each time you call it. Internally, :meth:`flax.linen.Dropout` makes use of :meth:`flax.linen.Module.make_rng` to create a key for dropout. You can check out the `source code `__. In short, :meth:`flax.linen.Module.make_rng` *guarantees full reproducibility*. .. codediff:: :title: No Dropout, With Dropout :sync: class MyModel(nn.Module): num_neurons: int @nn.compact def __call__(self, x): x = nn.Dense(self.num_neurons)(x) return x --- class MyModel(nn.Module): num_neurons: int @nn.compact def __call__(self, x, training: bool): #! x = nn.Dense(self.num_neurons)(x) # Set the dropout layer with a `rate` of 50%. #! # When the `deterministic` flag is `True`, dropout is turned off. #! x = nn.Dropout(rate=0.5, deterministic=not training)(x) #! return x Initialize the model ******************** After creating your model: * Instantiate the model. * Then, in the :meth:`flax.linen.init()` call, set ``training=False``. * Finally, extract the ``params`` from the `variable dictionary `__. Here, the main difference between the code without Flax ``Dropout`` and with ``Dropout`` is that the ``training`` (or ``train``) argument must be provided if you need dropout enabled. .. codediff:: :title: No Dropout, With Dropout :sync: my_model = MyModel(num_neurons=3) x = jnp.empty((3, 4, 4)) variables = my_model.init(params_key, x) params = variables['params'] --- my_model = MyModel(num_neurons=3) x = jnp.empty((3, 4, 4)) # Dropout is disabled with `training=False` (that is, `deterministic=True`). #! variables = my_model.init(params_key, x, training=False) #! params = variables['params'] Perform the forward pass during training **************************************** When using :meth:`flax.linen.apply()` to run your model: * Pass ``training=True`` to :meth:`flax.linen.apply()`. * Then, to draw PRNG keys during the forward pass (with dropout), provide a PRNG key to seed the ``'dropout'`` stream when you call :meth:`flax.linen.apply()`. .. codediff:: :title: No Dropout, With Dropout :sync: # No need to pass the `training` and `rngs` flags. y = my_model.apply({'params': params}, x) --- # Dropout is enabled with `training=True` (that is, `deterministic=False`). #! y = my_model.apply({'params': params}, x, training=True, rngs={'dropout': dropout_key}) #! Here, the main difference between the code without Flax ``Dropout`` and with ``Dropout`` is that the ``training`` (or ``train``) and ``rngs`` arguments must be provided if you need dropout enabled. During evaluation, use the above code with no dropout enabled (this means you do not have to pass a RNG either). ``TrainState`` and the training step ************************************ This section explains how to amend your code inside the training step function if you have dropout enabled. **Note:** Recall that Flax has a common pattern where you create a dataclass that represents the whole training state, including parameters and the optimizer state. Then, you can pass a single parameter, ``state: TrainState``, to the training step function. Refer to the :meth:`flax.training.train_state.TrainState` API docs to learn more. * First, add a ``key`` field to a custom :meth:`flax.training.train_state.TrainState` class. * Then, pass the ``key`` value—in this case, the ``dropout_key``—to the :meth:`train_state.TrainState.create` method. .. codediff:: :title: No Dropout, With Dropout :sync: from flax.training import train_state state = train_state.TrainState.create( apply_fn=my_model.apply, params=params, tx=optax.adam(1e-3) ) --- from flax.training import train_state class TrainState(train_state.TrainState): #! key: jax.Array #! state = TrainState.create( #! apply_fn=my_model.apply, params=params, key=dropout_key, #! tx=optax.adam(1e-3) ) * Next, in the Flax training step function, ``train_step``, generate a new PRNG key from the ``dropout_key`` to apply dropout at each step. This can be done with one of the following: * `jax.random.split() `__; or * `jax.random.fold_in() `__ Using ``jax.random.fold_in()`` is generally faster. When you use ``jax.random.split()`` you split off a PRNG key that can be reused afterwards. However, using ``jax.random.fold_in()`` makes sure to 1) fold in unique data; and 2) can result in longer sequences of PRNG streams. * Finally, when performing the forward pass, pass the new PRNG key to ``state.apply_fn()`` as an extra parameter. .. codediff:: :title: No Dropout, With Dropout :sync: @jax.jit def train_step(state: train_state.TrainState, batch): def loss_fn(params): logits = state.apply_fn( {'params': params}, x=batch['image'], ) loss = optax.softmax_cross_entropy_with_integer_labels( logits=logits, labels=batch['label']) return loss, logits grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (loss, logits), grads = grad_fn(state.params) state = state.apply_gradients(grads=grads) return state --- @jax.jit def train_step(state: TrainState, batch, dropout_key): #! dropout_train_key = jax.random.fold_in(key=dropout_key, data=state.step) #! def loss_fn(params): logits = state.apply_fn( {'params': params}, x=batch['image'], training=True, #! rngs={'dropout': dropout_train_key} #! ) loss = optax.softmax_cross_entropy_with_integer_labels( logits=logits, labels=batch['label']) return loss, logits grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (loss, logits), grads = grad_fn(state.params) state = state.apply_gradients(grads=grads) return state Flax examples with dropout ************************** * A `Transformer-based model `__ trained on the WMT Machine Translation dataset. This example uses dropout and attention dropout. * Applying word dropout to a batch of input IDs in a `text classification `__ context. This example uses a custom :meth:`flax.linen.Dropout` layer. More Flax examples that use Module ``make_rng()`` ************************************************* * Defining a prediction token in a decoder of a `sequence-to-sequence model `__. ================================================ FILE: docs/guides/training_techniques/index.rst ================================================ Training techniques =================== .. toctree:: :maxdepth: 1 batch_norm dropout lr_schedule transfer_learning use_checkpointing ================================================ FILE: docs/guides/training_techniques/lr_schedule.rst ================================================ Learning rate scheduling ============================= The learning rate is considered one of the most important hyperparameters for training deep neural networks, but choosing it can be quite hard. Rather than simply using a fixed learning rate, it is common to use a learning rate scheduler. In this example, we will use the *cosine scheduler*. Before the cosine scheduler comes into play, we start with a so-called *warmup* period in which the learning rate increases linearly for ``warmup_epochs`` epochs. For more information about the cosine scheduler, check out the paper `"SGDR: Stochastic Gradient Descent with Warm Restarts" `_. We will show you how to... * define a learning rate schedule * train a simple model using that schedule .. testsetup:: Default learning rate, Learning rate schedule import jax import jax.numpy as jnp import flax.linen as nn from flax.training import train_state import optax import numpy as np import tensorflow_datasets as tfds import functools import ml_collections from absl import logging class CNN(nn.Module): """A simple CNN model.""" @nn.compact def __call__(self, x): x = nn.Conv(features=32, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = nn.Conv(features=64, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = x.reshape((x.shape[0], -1)) # flatten x = nn.Dense(features=256)(x) x = nn.relu(x) x = nn.Dense(features=10)(x) return x def get_dummy_data(ds_size): image = np.random.rand(ds_size, 28, 28, 1) label = np.random.randint(low=0, high=10, size=(ds_size,)) return {'image': image, 'label': label} def get_config(): """Get the default hyperparameter configuration.""" config = ml_collections.ConfigDict() config.learning_rate = 0.001 config.momentum = 0.9 config.batch_size = 128 config.num_epochs = 10 config.warmup_epochs = 2 config.train_ds_size = 128 return config def compute_metrics(logits, labels): one_hot = jax.nn.one_hot(labels, 10) loss = jnp.mean(optax.softmax_cross_entropy(logits, one_hot)) accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) metrics = { 'loss': loss, 'accuracy': accuracy, } return metrics .. testcode:: Default learning rate, Learning rate schedule def create_learning_rate_fn(config, base_learning_rate, steps_per_epoch): """Creates learning rate schedule.""" warmup_fn = optax.linear_schedule( init_value=0., end_value=base_learning_rate, transition_steps=config.warmup_epochs * steps_per_epoch) cosine_epochs = max(config.num_epochs - config.warmup_epochs, 1) cosine_fn = optax.cosine_decay_schedule( init_value=base_learning_rate, decay_steps=cosine_epochs * steps_per_epoch) schedule_fn = optax.join_schedules( schedules=[warmup_fn, cosine_fn], boundaries=[config.warmup_epochs * steps_per_epoch]) return schedule_fn To use the schedule, we must create a learning rate function by passing the hyperparameters to the ``create_learning_rate_fn`` function and then pass the function to your |Optax|_ optimizer. For example using this schedule on MNIST would require changing the ``train_step`` function: .. |Optax| replace:: ``Optax`` .. _Optax: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules .. codediff:: :title: Default learning rate, Learning rate schedule :sync: @jax.jit def train_step(state, batch): def loss_fn(params): logits = CNN().apply({'params': params}, batch['image']) one_hot = jax.nn.one_hot(batch['label'], 10) loss = jnp.mean(optax.softmax_cross_entropy(logits, one_hot)) return loss, logits grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, logits), grads = grad_fn(state.params) new_state = state.apply_gradients(grads=grads) metrics = compute_metrics(logits, batch['label']) return new_state, metrics --- @functools.partial(jax.jit, static_argnums=2) #! def train_step(state, batch, learning_rate_fn): #! def loss_fn(params): logits = CNN().apply({'params': params}, batch['image']) one_hot = jax.nn.one_hot(batch['label'], 10) loss = jnp.mean(optax.softmax_cross_entropy(logits, one_hot)) return loss, logits grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, logits), grads = grad_fn(state.params) new_state = state.apply_gradients(grads=grads) metrics = compute_metrics(logits, batch['label']) lr = learning_rate_fn(state.step) #! metrics['learning_rate'] = lr #! return new_state, metrics And the ``train_epoch`` function: .. codediff:: :title: Default learning rate, Learning rate schedule :sync: def train_epoch(state, train_ds, batch_size, epoch, rng): """Trains for a single epoch.""" train_ds_size = len(train_ds['image']) steps_per_epoch = train_ds_size // batch_size perms = jax.random.permutation(rng, len(train_ds['image'])) perms = perms[:steps_per_epoch * batch_size] perms = perms.reshape((steps_per_epoch, batch_size)) batch_metrics = [] for perm in perms: batch = {k: v[perm, ...] for k, v in train_ds.items()} state, metrics = train_step(state, batch) batch_metrics.append(metrics) # compute mean of metrics across each batch in epoch. batch_metrics = jax.device_get(batch_metrics) epoch_metrics = { k: np.mean([metrics[k] for metrics in batch_metrics]) for k in batch_metrics[0]} logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch, epoch_metrics['loss'], epoch_metrics['accuracy'] * 100) return state, epoch_metrics --- def train_epoch(state, train_ds, batch_size, epoch, learning_rate_fn, rng): #! """Trains for a single epoch.""" train_ds_size = len(train_ds['image']) steps_per_epoch = train_ds_size // batch_size perms = jax.random.permutation(rng, len(train_ds['image'])) perms = perms[:steps_per_epoch * batch_size] perms = perms.reshape((steps_per_epoch, batch_size)) batch_metrics = [] for perm in perms: batch = {k: v[perm, ...] for k, v in train_ds.items()} state, metrics = train_step(state, batch, learning_rate_fn) #! batch_metrics.append(metrics) # compute mean of metrics across each batch in epoch. batch_metrics = jax.device_get(batch_metrics) epoch_metrics = { k: np.mean([metrics[k] for metrics in batch_metrics]) for k in batch_metrics[0]} logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch, epoch_metrics['loss'], epoch_metrics['accuracy'] * 100) return state, epoch_metrics And the ``create_train_state`` function: .. codediff:: :title: Default learning rate, Learning rate schedule :sync: def create_train_state(rng, config): """Creates initial `TrainState`.""" cnn = CNN() params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params'] tx = optax.sgd(config.learning_rate, config.momentum) return train_state.TrainState.create( apply_fn=cnn.apply, params=params, tx=tx) --- def create_train_state(rng, config, learning_rate_fn): #! """Creates initial `TrainState`.""" cnn = CNN() params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params'] tx = optax.sgd(learning_rate_fn, config.momentum) #! return train_state.TrainState.create( apply_fn=cnn.apply, params=params, tx=tx) .. testcleanup:: Learning rate schedule config = get_config() train_ds_size = config.train_ds_size steps_per_epoch = train_ds_size // config.batch_size learning_rate_fn = create_learning_rate_fn(config, config.learning_rate, steps_per_epoch) rng = jax.random.key(0) state = create_train_state(rng, config, learning_rate_fn) train_ds = get_dummy_data(config.train_ds_size) rng, _ = jax.random.split(rng) state, epoch_metrics = train_epoch(state, train_ds, config.batch_size, 0, learning_rate_fn, rng) assert 'accuracy' in epoch_metrics and 'learning_rate' in epoch_metrics ================================================ FILE: docs/guides/training_techniques/transfer_learning.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Transfer learning" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This guide demonstrates various parts of the transfer learning workflow with Flax. Depending on the task, a pretrained model can be used just as a feature extractor or it can be fine-tuned as part of a larger model.\n", "\n", "This guide demonstrates how to:\n", "\n", "* Load a pretrained model from HuggingFace [Transformers](https://huggingface.co/docs/transformers/index) and extract a specific sub-module from that pretrained model.\n", "* Create a classifier model.\n", "* Transfer the pretrained parameters to the new model structure.\n", "* Create an optimizer for training different parts of the model separately with [Optax](https://optax.readthedocs.io/).\n", "* Set up the model for training.\n", "\n", "
Performance Note\n", "\n", "Depending on your task, some of the content in this guide may be suboptimal. For example, if you are only going to train a linear classifier on top of a pretrained model, it may be better to just extract the feature embeddings once, which can result in much faster training, and you can use specialized algorithms for linear regression or logistic classification. This guide shows how to do transfer learning with all the model parameters.\n", "\n", "

" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "skip-execution" ] }, "outputs": [], "source": [ "# Note that the Transformers library doesn't use the latest Flax version.\n", "! pip install -q \"transformers[flax]\"\n", "# Install/upgrade Flax and JAX. For JAX installation with GPU/TPU support,\n", "# visit https://github.com/jax-ml/jax#installation.\n", "! pip install -U -q flax jax jaxlib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create a function for model loading\n", "\n", "To load a pre-trained classifier, for convenience first create a function that returns a [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html#module-basics) and its pretrained variables.\n", "\n", "In the code below, the `load_model` function uses HuggingFace's `FlaxCLIPVisionModel` model from the [Transformers](https://huggingface.co/docs/transformers/index) library and extracts a `FlaxCLIPModule` module." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "%%capture\n", "from IPython.display import clear_output\n", "from transformers import FlaxCLIPModel\n", "\n", "# Note: FlaxCLIPModel is not a Flax Module\n", "def load_model():\n", " clip = FlaxCLIPModel.from_pretrained('openai/clip-vit-base-patch32')\n", " clear_output(wait=False) # Clear the loading messages\n", " module = clip.module # Extract the Flax Module\n", " variables = {'params': clip.params} # Extract the parameters\n", " return module, variables" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that `FlaxCLIPVisionModel` itself is not a Flax `Module` which is why we need to do this extra step.\n", "\n", "### Extracting a submodule\n", "\n", "Calling `load_model` from the snippet above returns the `FlaxCLIPModule`, which is composed of `text_model` and `vision_model` submodules.\n", "\n", "An easy way to extract the `vision_model` sub-Module defined inside `.setup()` and its variables is to use [`flax.linen.Module.bind`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.bind) on the `clip` Module immediately followed by [`flax.linen.Module.unbind`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.unbind) on the `vision_model` sub-Module." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "import flax.linen as nn\n", "\n", "clip, clip_variables = load_model()\n", "vision_model, vision_model_vars = clip.bind(clip_variables).vision_model.unbind()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Creating a classifier\n", "\n", "To create a classifier define a new Flax [`Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html#module-basics) consisting of a `backbone` (the pretrained vision model) and a `head` (the classifier) submodules." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from typing import Callable\n", "import jax.numpy as jnp\n", "import jax\n", "\n", "class Classifier(nn.Module):\n", " num_classes: int\n", " backbone: nn.Module\n", " \n", "\n", " @nn.compact\n", " def __call__(self, x):\n", " x = self.backbone(x).pooler_output\n", " x = nn.Dense(\n", " self.num_classes, name='head', kernel_init=nn.zeros)(x)\n", " return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To construct a classifier `model`, the `vision_model` Module is passed as the `backbone` to `Classifier`. Then the model's `params` can be randomly initialized by passing fake data that is used to infer the parameter shapes." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "num_classes = 3\n", "model = Classifier(num_classes=num_classes, backbone=vision_model)\n", "\n", "x = jnp.empty((1, 224, 224, 3))\n", "variables = model.init(jax.random.key(1), x)\n", "params = variables['params']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Transfering the parameters\n", "Since `params` are currently random, the pretrained parameters from `vision_model_vars` have to be transfered to the `params` structure at the appropriate location (i.e. the `backbone`):" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "params['backbone'] = vision_model_vars['params']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Note:** if the model contains other variable collections such as `batch_stats`, these have to be transfered as well.\n", "\n", "## Optimization\n", "\n", "If you need to to train different parts of the model separately, you have three options:\n", "\n", "1. Use `stop_gradient`.\n", "2. Filter the parameters for `jax.grad`.\n", "3. Use multiple optimizers for different parameters.\n", "\n", "For most situations we recommend using multiple optimizers via [Optax](https://optax.readthedocs.io/)'s [`multi_transform`](https://optax.readthedocs.io/en/latest/api.html#optax.multi_transform) as its both efficient and can be easily extended to implement many fine-tunning strategies. \n", "\n", "### **optax.multi_transform**\n", "\n", "To use `optax.multi_transform` following must be defined:\n", "\n", "1. The parameter partitions.\n", "2. A mapping between partitions and their optimizer.\n", "3. A pytree with the same shape as the parameters but its leaves containing the corresponding partition label.\n", "\n", "To freeze layers with `optax.multi_transform` for the model above, the following setup can be used:\n", "\n", "* Define the `trainable` and `frozen` parameter partitions.\n", "* For the `trainable` parameters select the Adam (`optax.adam`) optimizer.\n", "- For the `frozen` parameters select the `optax.set_to_zero` optimizer. This dummy optimizer zeros-out the gradients so no training is done.\n", "- Map parameters to partitions using [`flax.traverse_util.path_aware_map`](https://flax.readthedocs.io/en/latest/api_reference/flax.traverse_util.html#flax.traverse_util.path_aware_map), mark the leaves from the `backbone` as `frozen`, and the rest as `trainable`." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "FrozenDict({\n", " backbone: {\n", " embeddings: {\n", " class_embedding: 'frozen',\n", " patch_embedding: {\n", " kernel: 'frozen',\n", " },\n", " },\n", " },\n", " head: {\n", " bias: 'trainable',\n", " kernel: 'trainable',\n", " },\n", "})" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from flax import traverse_util\n", "import optax\n", "\n", "partition_optimizers = {'trainable': optax.adam(5e-3), 'frozen': optax.set_to_zero()}\n", "param_partitions = traverse_util.path_aware_map(\n", " lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params)\n", "tx = optax.multi_transform(partition_optimizers, param_partitions)\n", "\n", "# visualize a subset of the param_partitions structure\n", "flat = list(traverse_util.flatten_dict(param_partitions).items())\n", "traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To implement [differential learning rates](https://blog.slavv.com/differential-learning-rates-59eff5209a4f), the `optax.set_to_zero` can be replaced with any other optimizer, different optimizers and partitioning schemes can be selected depending on the task. For more information on advanced optimizers, refer to Optax's [Combining Optimizers](https://optax.readthedocs.io/en/latest/api.html#combining-optimizers) documentation.\n", "\n", "## Creating the `TrainState`\n", "\n", "Once the module, params, and optimizer are defined, the `TrainState` can be constructed as usual:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "from flax.training.train_state import TrainState\n", "\n", "state = TrainState.create(\n", " apply_fn=model.apply,\n", " params=params,\n", " tx=tx)" ] }, { "cell_type": "markdown", "id": "083d8854", "metadata": {}, "source": [ "Since the optimizer takes care of the freezing or fine-tunning strategy, the `train_step` requires no additional changes, training can proceed normally." ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.14" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: docs/guides/training_techniques/transfer_learning.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # Transfer learning +++ This guide demonstrates various parts of the transfer learning workflow with Flax. Depending on the task, a pretrained model can be used just as a feature extractor or it can be fine-tuned as part of a larger model. This guide demonstrates how to: * Load a pretrained model from HuggingFace [Transformers](https://huggingface.co/docs/transformers/index) and extract a specific sub-module from that pretrained model. * Create a classifier model. * Transfer the pretrained parameters to the new model structure. * Create an optimizer for training different parts of the model separately with [Optax](https://optax.readthedocs.io/). * Set up the model for training.
Performance Note Depending on your task, some of the content in this guide may be suboptimal. For example, if you are only going to train a linear classifier on top of a pretrained model, it may be better to just extract the feature embeddings once, which can result in much faster training, and you can use specialized algorithms for linear regression or logistic classification. This guide shows how to do transfer learning with all the model parameters.

+++ ## Setup ```{code-cell} ipython3 :tags: [skip-execution] # Note that the Transformers library doesn't use the latest Flax version. ! pip install -q "transformers[flax]" # Install/upgrade Flax and JAX. For JAX installation with GPU/TPU support, # visit https://github.com/jax-ml/jax#installation. ! pip install -U -q flax jax jaxlib ``` ## Create a function for model loading To load a pre-trained classifier, for convenience first create a function that returns a [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html#module-basics) and its pretrained variables. In the code below, the `load_model` function uses HuggingFace's `FlaxCLIPVisionModel` model from the [Transformers](https://huggingface.co/docs/transformers/index) library and extracts a `FlaxCLIPModule` module. ```{code-cell} ipython3 %%capture from IPython.display import clear_output from transformers import FlaxCLIPModel # Note: FlaxCLIPModel is not a Flax Module def load_model(): clip = FlaxCLIPModel.from_pretrained('openai/clip-vit-base-patch32') clear_output(wait=False) # Clear the loading messages module = clip.module # Extract the Flax Module variables = {'params': clip.params} # Extract the parameters return module, variables ``` Note that `FlaxCLIPVisionModel` itself is not a Flax `Module` which is why we need to do this extra step. ### Extracting a submodule Calling `load_model` from the snippet above returns the `FlaxCLIPModule`, which is composed of `text_model` and `vision_model` submodules. An easy way to extract the `vision_model` sub-Module defined inside `.setup()` and its variables is to use [`flax.linen.Module.bind`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.bind) on the `clip` Module immediately followed by [`flax.linen.Module.unbind`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.unbind) on the `vision_model` sub-Module. ```{code-cell} ipython3 import flax.linen as nn clip, clip_variables = load_model() vision_model, vision_model_vars = clip.bind(clip_variables).vision_model.unbind() ``` ### Creating a classifier To create a classifier define a new Flax [`Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html#module-basics) consisting of a `backbone` (the pretrained vision model) and a `head` (the classifier) submodules. ```{code-cell} ipython3 from typing import Callable import jax.numpy as jnp import jax class Classifier(nn.Module): num_classes: int backbone: nn.Module @nn.compact def __call__(self, x): x = self.backbone(x).pooler_output x = nn.Dense( self.num_classes, name='head', kernel_init=nn.zeros)(x) return x ``` To construct a classifier `model`, the `vision_model` Module is passed as the `backbone` to `Classifier`. Then the model's `params` can be randomly initialized by passing fake data that is used to infer the parameter shapes. ```{code-cell} ipython3 num_classes = 3 model = Classifier(num_classes=num_classes, backbone=vision_model) x = jnp.empty((1, 224, 224, 3)) variables = model.init(jax.random.key(1), x) params = variables['params'] ``` ## Transfering the parameters Since `params` are currently random, the pretrained parameters from `vision_model_vars` have to be transfered to the `params` structure at the appropriate location (i.e. the `backbone`): ```{code-cell} ipython3 params['backbone'] = vision_model_vars['params'] ``` **Note:** if the model contains other variable collections such as `batch_stats`, these have to be transfered as well. ## Optimization If you need to to train different parts of the model separately, you have three options: 1. Use `stop_gradient`. 2. Filter the parameters for `jax.grad`. 3. Use multiple optimizers for different parameters. For most situations we recommend using multiple optimizers via [Optax](https://optax.readthedocs.io/)'s [`multi_transform`](https://optax.readthedocs.io/en/latest/api.html#optax.multi_transform) as its both efficient and can be easily extended to implement many fine-tunning strategies. ### **optax.multi_transform** To use `optax.multi_transform` following must be defined: 1. The parameter partitions. 2. A mapping between partitions and their optimizer. 3. A pytree with the same shape as the parameters but its leaves containing the corresponding partition label. To freeze layers with `optax.multi_transform` for the model above, the following setup can be used: * Define the `trainable` and `frozen` parameter partitions. * For the `trainable` parameters select the Adam (`optax.adam`) optimizer. - For the `frozen` parameters select the `optax.set_to_zero` optimizer. This dummy optimizer zeros-out the gradients so no training is done. - Map parameters to partitions using [`flax.traverse_util.path_aware_map`](https://flax.readthedocs.io/en/latest/api_reference/flax.traverse_util.html#flax.traverse_util.path_aware_map), mark the leaves from the `backbone` as `frozen`, and the rest as `trainable`. ```{code-cell} ipython3 from flax import traverse_util import optax partition_optimizers = {'trainable': optax.adam(5e-3), 'frozen': optax.set_to_zero()} param_partitions = traverse_util.path_aware_map( lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params) tx = optax.multi_transform(partition_optimizers, param_partitions) # visualize a subset of the param_partitions structure flat = list(traverse_util.flatten_dict(param_partitions).items()) traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:])) ``` To implement [differential learning rates](https://blog.slavv.com/differential-learning-rates-59eff5209a4f), the `optax.set_to_zero` can be replaced with any other optimizer, different optimizers and partitioning schemes can be selected depending on the task. For more information on advanced optimizers, refer to Optax's [Combining Optimizers](https://optax.readthedocs.io/en/latest/api.html#combining-optimizers) documentation. ## Creating the `TrainState` Once the module, params, and optimizer are defined, the `TrainState` can be constructed as usual: ```{code-cell} ipython3 from flax.training.train_state import TrainState state = TrainState.create( apply_fn=model.apply, params=params, tx=tx) ``` Since the optimizer takes care of the freezing or fine-tunning strategy, the `train_step` requires no additional changes, training can proceed normally. ================================================ FILE: docs/guides/training_techniques/use_checkpointing.ipynb ================================================ { "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "6e9134fa", "metadata": {}, "source": [ "# Save and load checkpoints\n", "\n", "This guide demonstrates how to save and load Flax checkpoints with [Orbax](https://github.com/google/orbax).\n", "\n", "Orbax provides a variety of features for saving and loading model data, which you will learn about in this doc:\n", "\n", "* Support for various array types and storage formats\n", "* Asynchronous saving to reduce training wait time\n", "* Versioning and automatic bookkeeping of past checkpoints\n", "* Flexible [`transformations`](https://orbax.readthedocs.io/en/latest/transformations.html) to tweak and load old checkpoints\n", "* [`jax.sharding`](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)-based API to save and load in multi-host scenarios\n", "\n", "---\n", "**_Ongoing migration to Orbax:_**\n", "\n", "After July 30 2023, Flax's legacy `flax.training.checkpoints` API will be deprecated in favor of [Orbax](https://github.com/google/orbax).\n", "\n", "* **If you are a new Flax user**: Use the new `orbax.checkpoint` API, as demonstrated in this guide.\n", "\n", "* **If you have legacy `flax.training.checkpoints` code in your project**: Consider the following options:\n", "\n", " * **Migrating your code to Orbax (Recommended)**: Migrate your API calls to `orbax.checkpoint` API by following this [migration guide](https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/orbax_upgrade_guide.html).\n", "\n", " * **Automatically use the Orbax backend**: Add `flax.config.update('flax_use_orbax_checkpointing', True)` to your project, which will let your `flax.training.checkpoints` calls automatically use the Orbax backend to save your checkpoints.\n", "\n", " * **Scheduled flip**: This will become the default mode after **May 2023** (tentative date).\n", "\n", " * Visit [Orbax-as-backend troubleshooting section](https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#orbax-as-backend-troubleshooting) if you meet any issue in the automatic migration.\n", "---\n", "\n", "For backward-compatibility, this guide shows the Orbax-equivalent calls in the Flax legacy `flax.training.checkpoints` API.\n", "\n", "If you need to learn more about `orbax.checkpoint`, refer to the [Orbax docs](https://orbax.readthedocs.io/en/latest/).\n" ] }, { "cell_type": "markdown", "id": "5a2f6aae", "metadata": {}, "source": [ "## Setup\n", "\n", "Install/upgrade Flax and [Orbax](https://github.com/google/orbax). For JAX installation with GPU/TPU support, visit [this section on GitHub](https://github.com/jax-ml/jax#installation)." ] }, { "attachments": {}, "cell_type": "markdown", "id": "-icO30rwmKYj", "metadata": {}, "source": [ "Note: Before running `import jax`, create eight fake devices to mimic a [multi-host environment](https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html?#aside-hosts-and-devices-in-jax) in this notebook. Note that the order of imports is important here. The `os.environ[\"XLA_FLAGS\"] = '--xla_force_host_platform_device_count=8'` command works only with the CPU backend, which means it won't work with GPU/TPU acceleration on if you're running this notebook in Google Colab. If you are already running the code on multiple devices (for example, in a 4x2 TPU environment), you can skip running the next cell." ] }, { "cell_type": "code", "execution_count": 1, "id": "ArKLnsyGRxGv", "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ[\"XLA_FLAGS\"] = '--xla_force_host_platform_device_count=8'" ] }, { "cell_type": "code", "execution_count": 2, "id": "SJT9DTxTytjn", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Tensorflow library not found, tensorflow.io.gfile operations will use native shim calls. GCS paths (i.e. 'gs://...') cannot be accessed.\n" ] } ], "source": [ "from typing import Optional, Any\n", "import shutil\n", "\n", "import numpy as np\n", "import jax\n", "from jax import random, numpy as jnp\n", "\n", "import flax\n", "from flax import linen as nn\n", "from flax.training import checkpoints, train_state\n", "from flax import struct, serialization\n", "import orbax.checkpoint\n", "\n", "import optax" ] }, { "cell_type": "code", "execution_count": 3, "id": "afd6db30", "metadata": {}, "outputs": [], "source": [ "ckpt_dir = '/tmp/flax_ckpt'\n", "\n", "if os.path.exists(ckpt_dir):\n", " shutil.rmtree(ckpt_dir) # Remove any existing checkpoints from the last notebook run." ] }, { "cell_type": "markdown", "id": "40d434cd", "metadata": {}, "source": [ "## Save checkpoints\n", "\n", "In Orbax and Flax, you can save and load any given JAX [pytree](https://jax.readthedocs.io/en/latest/pytrees.html). This includes not only typical Python and NumPy containers, but also customized classes extended from [`flax.struct.dataclass`](https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html#flax.struct.dataclass). That means you can store almost any data generated — not only your model parameters, but any arrays/dictionaries, metadata/configs, and so on.\n", "\n", "First, create a pytree with many data structures and containers, and play with it:" ] }, { "cell_type": "code", "execution_count": 4, "id": "56dec3f6", "metadata": { "outputId": "f1856d96-1961-48ed-bb7c-cb63fbaa7567" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1695322343.254588 1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.\n" ] }, { "data": { "text/plain": [ "{'model': TrainState(step=1, apply_fn=, params={'bias': Array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': Array([[ 0.26048955, -0.61399287, -0.23458514],\n", " [ 0.11050402, -0.8765793 , 0.9800635 ],\n", " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=.init_fn at 0x13d5d83a0>, update=.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState())),\n", " 'config': {'dimensions': array([5, 3])},\n", " 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ], dtype=float32)]}" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# A simple model with one linear layer.\n", "key1, key2 = random.split(random.key(0))\n", "x1 = random.normal(key1, (5,)) # A simple JAX array.\n", "model = nn.Dense(features=3)\n", "variables = model.init(key2, x1)\n", "\n", "# Flax's TrainState is a pytree dataclass and is supported in checkpointing.\n", "# Define your class with `@flax.struct.dataclass` decorator to make it compatible.\n", "tx = optax.sgd(learning_rate=0.001) # An Optax SGD optimizer.\n", "state = train_state.TrainState.create(\n", " apply_fn=model.apply,\n", " params=variables['params'],\n", " tx=tx)\n", "# Perform a simple gradient update similar to the one during a normal training workflow.\n", "state = state.apply_gradients(grads=jax.tree_util.tree_map(jnp.ones_like, state.params))\n", "\n", "# Some arbitrary nested pytree with a dictionary and a NumPy array.\n", "config = {'dimensions': np.array([5, 3])}\n", "\n", "# Bundle everything together.\n", "ckpt = {'model': state, 'config': config, 'data': [x1]}\n", "ckpt" ] }, { "cell_type": "markdown", "id": "8c715b95", "metadata": {}, "source": [ "### With Orbax" ] }, { "cell_type": "markdown", "id": "6fc59dfa", "metadata": {}, "source": [ "Save the checkpoint with `orbax.checkpoint.PyTreeCheckpointer`, directly to the `tmp/orbax/single_save` directory.\n", "\n", "Note: An optional `save_args` is provided. This is recommended for performance speedups, as it bundles smaller arrays in your pytree to a single large file instead of multiple smaller files." ] }, { "cell_type": "code", "execution_count": 5, "id": "61b12da2", "metadata": {}, "outputs": [], "source": [ "from flax.training import orbax_utils\n", "\n", "orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()\n", "save_args = orbax_utils.save_args_from_target(ckpt)\n", "orbax_checkpointer.save('/tmp/flax_ckpt/orbax/single_save', ckpt, save_args=save_args)" ] }, { "cell_type": "markdown", "id": "07d4de1a", "metadata": {}, "source": [ "Next, to use versioning and automatic bookkeeping features, you need to wrap `orbax.checkpoint.CheckpointManager` over `orbax.checkpoint.PyTreeCheckpointer`.\n", "\n", "In addition, provide `orbax.checkpoint.CheckpointManagerOptions` that customizes your needs, such as how often and on what criteria you prefer old checkpoints be deleted. See [documentation](https://orbax.readthedocs.io/en/latest/guides/checkpoint/api_refactor.html) for a full list of options offered.\n", "\n", "`orbax.checkpoint.CheckpointManager` should be placed at the top-level outside your training steps to manage your saves." ] }, { "cell_type": "code", "execution_count": 6, "id": "d3686ea5", "metadata": { "outputId": "b7132933-566d-440d-c34e-c5468d87cbdc" }, "outputs": [ { "data": { "text/plain": [ "['4', '3']" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2, create=True)\n", "checkpoint_manager = orbax.checkpoint.CheckpointManager(\n", " '/tmp/flax_ckpt/orbax/managed', orbax_checkpointer, options)\n", "\n", "# Inside a training loop\n", "for step in range(5):\n", " # ... do your training\n", " checkpoint_manager.save(step, ckpt, save_kwargs={'save_args': save_args})\n", "\n", "os.listdir('/tmp/flax_ckpt/orbax/managed') # Because max_to_keep=2, only step 3 and 4 are retained" ] }, { "cell_type": "markdown", "id": "8ecbc4cc", "metadata": {}, "source": [ "### With the legacy API\n", "\n", "And here's how to save with the legacy Flax checkpointing utilities (note that this provides less management features compared with `orbax.checkpoint.CheckpointManagerOptions`):" ] }, { "cell_type": "code", "execution_count": 7, "id": "4cdb35ef", "metadata": { "outputId": "6d849273-15ce-4480-8864-726d1838ac1f" }, "outputs": [ { "data": { "text/plain": [ "'/tmp/flax_ckpt/flax-checkpointing/checkpoint_0'" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Import Flax Checkpoints.\n", "from flax.training import checkpoints\n", "\n", "checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing',\n", " target=ckpt,\n", " step=0,\n", " overwrite=True,\n", " keep=2)" ] }, { "cell_type": "markdown", "id": "6b658bd1", "metadata": {}, "source": [ "## Restore checkpoints\n", "\n", "### With Orbax\n", "\n", "In Orbax, call `.restore()` for either `orbax.checkpoint.PyTreeCheckpointer` or `orbax.checkpoint.CheckpointManager` to restore your checkpoint in the raw pytree format." ] }, { "cell_type": "code", "execution_count": 8, "id": "a807a9c1", "metadata": { "outputId": "b4af1ef4-f22f-459b-bdca-2e6bfa16c08b" }, "outputs": [ { "data": { "text/plain": [ "{'config': {'dimensions': array([5, 3])},\n", " 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n", " dtype=float32)],\n", " 'model': {'opt_state': [None, None],\n", " 'params': {'bias': array([-0.001, -0.001, -0.001], dtype=float32),\n", " 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],\n", " [ 0.11050402, -0.8765793 , 0.9800635 ],\n", " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)},\n", " 'step': 1}}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "raw_restored = orbax_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save')\n", "raw_restored" ] }, { "cell_type": "markdown", "id": "8c015a22", "metadata": {}, "source": [ "Note that the `step` number is required for `CheckpointManger`. You can also use `.latest_step()` to find the latest step available." ] }, { "cell_type": "code", "execution_count": 9, "id": "251d7085", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'config': {'dimensions': array([5, 3])},\n", " 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n", " dtype=float32)],\n", " 'model': {'opt_state': [None, None],\n", " 'params': {'bias': array([-0.001, -0.001, -0.001], dtype=float32),\n", " 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],\n", " [ 0.11050402, -0.8765793 , 0.9800635 ],\n", " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)},\n", " 'step': 1}}" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "step = checkpoint_manager.latest_step() # step = 4\n", "checkpoint_manager.restore(step)" ] }, { "cell_type": "markdown", "id": "c7fe3bc8", "metadata": {}, "source": [ "### With the legacy API\n", "\n", "Note that with the migration to Orbax in progress, `flax.training.checkpointing.restore_checkpoint` can automatically identify whether a checkpoint is saved in the legacy Flax format or with an Orbax backend, and restore the pytree correctly. Therefore, adding `flax.config.update('flax_use_orbax_checkpointing', True)` won't hurt your ability to restore old checkpoints.\n", "\n", "Here's how to restore checkpoints using the legacy API:" ] }, { "cell_type": "code", "execution_count": 10, "id": "150b20a0", "metadata": { "outputId": "85ffceca-f38d-46b8-e567-d9d38b7885f9" }, "outputs": [ { "data": { "text/plain": [ "{'config': {'dimensions': array([5, 3])},\n", " 'data': {'0': array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n", " dtype=float32)},\n", " 'model': {'opt_state': {'0': None, '1': None},\n", " 'params': {'bias': array([-0.001, -0.001, -0.001], dtype=float32),\n", " 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],\n", " [ 0.11050402, -0.8765793 , 0.9800635 ],\n", " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)},\n", " 'step': 1}}" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "raw_restored = checkpoints.restore_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', target=None)\n", "raw_restored" ] }, { "cell_type": "markdown", "id": "987b981f", "metadata": {}, "source": [ "## Restore with custom dataclasses\n", "\n", "### With Orbax\n", "\n", "* The pytrees restored in the previous examples are in the form of raw dictionaries. Original pytrees contain custom dataclasses like [`TrainState`](https://flax.readthedocs.io/en/latest/flip/1009-optimizer-api.html?#train-state) and `optax` states.\n", "* This is because when restoring a pytree, the program does not yet know which structure it once belonged to.\n", "* To resolve this, you should first provide an example pytree to let Orbax or Flax know exactly which structure to restore to.\n", "\n", "This section demonstrates how to set up any custom Flax dataclass explicitly, and have the same structure as a saved checkpoint.\n", "\n", "Note: Data that was a JAX NumPy array (`jnp.array`) format will be restored as a NumPy array (`numpy.array`). This would not affect your work because JAX will [automatically convert](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html) NumPy arrays to JAX arrays once the computation starts." ] }, { "cell_type": "code", "execution_count": 11, "id": "58f42513", "metadata": { "outputId": "110c6b6e-fe42-4179-e5d8-6b92d355e11b" }, "outputs": [ { "data": { "text/plain": [ "{'config': {'dimensions': array([5, 3])},\n", " 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n", " dtype=float32)],\n", " 'model': TrainState(step=1, apply_fn=, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],\n", " [ 0.11050402, -0.8765793 , 0.9800635 ],\n", " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=.init_fn at 0x13d5d83a0>, update=.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()))}" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "empty_state = train_state.TrainState.create(\n", " apply_fn=model.apply,\n", " params=jax.tree_util.tree_map(np.zeros_like, variables['params']), # values of the tree leaf doesn't matter\n", " tx=tx,\n", ")\n", "empty_config = {'dimensions': np.array([0, 0])}\n", "target = {'model': empty_state, 'config': empty_config, 'data': [jnp.zeros_like(x1)]}\n", "state_restored = orbax_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save', item=target)\n", "state_restored" ] }, { "cell_type": "markdown", "id": "f1c18bc6", "metadata": {}, "source": [ "### With the legacy API\n", "\n", "Alternatively, you can restore from Orbax `CheckpointManager` and from the legacy Flax code as follows:" ] }, { "cell_type": "code", "execution_count": 12, "id": "a61e9a66", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'config': {'dimensions': array([5, 3])},\n", " 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n", " dtype=float32)],\n", " 'model': TrainState(step=1, apply_fn=, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],\n", " [ 0.11050402, -0.8765793 , 0.9800635 ],\n", " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=.init_fn at 0x13d5d83a0>, update=.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()))}" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "checkpoint_manager.restore(4, items=target)" ] }, { "cell_type": "code", "execution_count": 13, "id": "412af50e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.\n" ] }, { "data": { "text/plain": [ "{'model': TrainState(step=1, apply_fn=, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],\n", " [ 0.11050402, -0.8765793 , 0.9800635 ],\n", " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=.init_fn at 0x13d5d83a0>, update=.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState())),\n", " 'config': {'dimensions': array([5, 3])},\n", " 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ], dtype=float32)]}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "checkpoints.restore_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', target=target)" ] }, { "cell_type": "markdown", "id": "27461ac8", "metadata": {}, "source": [ "It's often recommended to refactor out the process of initializing a checkpoint's structure (for example, a [`TrainState`](https://flax.readthedocs.io/en/latest/flip/1009-optimizer-api.html?#train-state)), so that saving/loading is easier and less error-prone. This is because functions and complex objects like `apply_fn` and `tx` (optimizer) cannot be serialized into the checkpoint file and must be initialized by code." ] }, { "attachments": {}, "cell_type": "markdown", "id": "136a300a", "metadata": {}, "source": [ "## Restore when checkpoint structures differ\n", "\n", "During your development, your checkpoint structure will change when changing the model, adding/removing fields during tweaking, and so on.\n", "\n", "This section explains how to load old data to your new code.\n", "\n", "Below is a simple example — a `CustomTrainState` extended from `flax.training.train_state.TrainState` that contains an extra field called `batch_stats`. When working on a real-world model, you may need this when applying [batch normalization](https://flax.readthedocs.io/en/latest/guides/training_techniques/batch_norm.html).\n", "\n", "Here, you store the new `CustomTrainState` as step 5, while step 4 contains the old/previous `TrainState`." ] }, { "cell_type": "code", "execution_count": 14, "id": "be65d4af", "metadata": { "outputId": "4fe776f0-65f8-4fc4-d64a-990520b36dce" }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class CustomTrainState(train_state.TrainState):\n", " batch_stats: Any = None\n", "\n", "custom_state = CustomTrainState.create(\n", " apply_fn=state.apply_fn,\n", " params=state.params,\n", " tx=state.tx,\n", " batch_stats=np.arange(10),\n", ")\n", "\n", "custom_ckpt = {'model': custom_state, 'config': config, 'data': [x1]}\n", "# Use a custom state to read the old `TrainState` checkpoint.\n", "custom_target = {'model': custom_state, 'config': None, 'data': [jnp.zeros_like(x1)]}\n", "\n", "# Save it in Orbax.\n", "custom_save_args = orbax_utils.save_args_from_target(custom_ckpt)\n", "checkpoint_manager.save(5, custom_ckpt, save_kwargs={'save_args': custom_save_args})" ] }, { "cell_type": "markdown", "id": "379c2255", "metadata": {}, "source": [ "It is recommended to keep your checkpoints up-to-date with your pytree dataclass definitions. However, you might be forced to restore the checkpoints with incompatible reference objects at runtime. When this happens, the checkpoint restoration will try to respect the structure of the reference when given.\n", "\n", "Below are examples of a few common scenarios." ] }, { "attachments": {}, "cell_type": "markdown", "id": "d5fa9652", "metadata": {}, "source": [ "### Scenario 1: When a reference object is partial\n", "\n", "If your reference object is a subtree of your checkpoint, the restoration will ignore the additional field(s) and restore a checkpoint with the same structure as the reference.\n", "\n", "Like in the example below, the `batch_stats` field in `CustomTrainState` was ignored, and the checkpoint was restored as a `TrainState`.\n", "\n", "This can also be useful for reading only part of your checkpoint." ] }, { "cell_type": "code", "execution_count": 15, "id": "68828029", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'config': {'dimensions': array([5, 3])},\n", " 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n", " dtype=float32)],\n", " 'model': TrainState(step=0, apply_fn=, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],\n", " [ 0.11050402, -0.8765793 , 0.9800635 ],\n", " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=.init_fn at 0x13d5d83a0>, update=.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()))}" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "restored = checkpoint_manager.restore(5, items=target)\n", "assert not hasattr(restored, 'batch_stats')\n", "assert type(restored['model']) == train_state.TrainState\n", "restored" ] }, { "attachments": {}, "cell_type": "markdown", "id": "5c6822c6", "metadata": {}, "source": [ "### Scenario 2: When a checkpoint is partial\n", "\n", "On the other hand, if the reference object contains a value that is not available in the checkpoint, the checkpointing code will by default warn that some data is not compatible.\n", "\n", "To bypass the error, you need to pass an Orbax [`transform`](https://orbax.readthedocs.io/en/latest/guides/checkpoint/transformations.html) that teaches Orbax how to conform this checkpoint into the structure of the `custom_target`.\n", "\n", "In this case, pass a default `{}` that lets Orbax use values in the `custom_target` to fill in the blank. This allows you to restore an old checkpoint into a new data structure, the `CustomTrainState`." ] }, { "cell_type": "code", "execution_count": 16, "id": "a5d14c9f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "KeyError when target state has an unmentioned field: 'batch_stats'\n", "\n" ] }, { "data": { "text/plain": [ "{'config': None,\n", " 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ], dtype=float32)],\n", " 'model': CustomTrainState(step=1, apply_fn=, params={'bias': Array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': Array([[ 0.26048955, -0.61399287, -0.23458514],\n", " [ 0.11050402, -0.8765793 , 0.9800635 ],\n", " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=.init_fn at 0x13d5d83a0>, update=.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))}" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "try:\n", " checkpoint_manager.restore(4, items=custom_target)\n", "except KeyError as e:\n", " print(f'KeyError when target state has an unmentioned field: {e}')\n", " print('')\n", "\n", "# Step 4 is an original `TrainState`, without the `batch_stats`\n", "custom_restore_args = orbax_utils.restore_args_from_target(custom_target)\n", "restored = checkpoint_manager.restore(4, items=custom_target,\n", " restore_kwargs={'transforms': {}, 'restore_args': custom_restore_args})\n", "assert type(restored['model']) == CustomTrainState\n", "np.testing.assert_equal(restored['model'].batch_stats,\n", " custom_target['model'].batch_stats)\n", "restored" ] }, { "cell_type": "markdown", "id": "74a4b0fd", "metadata": {}, "source": [ "##### With Orbax\n", "\n", "If you have already saved your checkpoints with the Orbax backend, you can use `orbax_transforms` to access this `transforms` argument in the Flax API." ] }, { "cell_type": "code", "execution_count": 17, "id": "29fd1e33", "metadata": { "outputId": "cdbb9247-d1eb-4458-aa83-8db0332af7cb" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.\n" ] }, { "data": { "text/plain": [ "{'model': CustomTrainState(step=1, apply_fn=, params={'bias': Array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': Array([[ 0.26048955, -0.61399287, -0.23458514],\n", " [ 0.11050402, -0.8765793 , 0.9800635 ],\n", " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=.init_fn at 0x13d5d83a0>, update=.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])),\n", " 'config': None,\n", " 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ], dtype=float32)]}" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Save in the \"Flax-with-Orbax\" backend.\n", "flax.config.update('flax_use_orbax_checkpointing', True)\n", "checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing',\n", " target=ckpt,\n", " step=4,\n", " overwrite=True,\n", " keep=2)\n", "\n", "checkpoints.restore_checkpoint('/tmp/flax_ckpt/flax-checkpointing', target=custom_target, step=4,\n", " orbax_transforms={})" ] }, { "cell_type": "markdown", "id": "830ef07c", "metadata": {}, "source": [ "##### With the legacy API\n", "\n", "Using the legacy `flax.training.checkpoints` API, similar things are doable too, but they are not as flexible as the [Orbax Transformations](https://orbax.readthedocs.io/en/latest/guides/checkpoint/transformations.html).\n", "\n", "You need to restore the checkpoint to a raw dict with `target=None`, modify the structure accordingly, and then deserialize it back to the original target." ] }, { "cell_type": "code", "execution_count": 18, "id": "051e7a16", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'model': CustomTrainState(step=1, apply_fn=, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],\n", " [ 0.11050402, -0.8765793 , 0.9800635 ],\n", " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=.init_fn at 0x13d5d83a0>, update=.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])),\n", " 'config': {'dimensions': array([5, 3])},\n", " 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n", " dtype=float32)]}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Save using the legacy Flax `checkpoints` API.\n", "flax.config.update('flax_use_orbax_checkpointing', False)\n", "checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing',\n", " target=ckpt,\n", " step=5,\n", " overwrite=True,\n", " keep=2)\n", "\n", "# Pass no target to get a raw state dictionary first.\n", "raw_state_dict = checkpoints.restore_checkpoint('/tmp/flax_ckpt/flax-checkpointing', target=None, step=5)\n", "# Add/remove fields as needed.\n", "raw_state_dict['model']['batch_stats'] = np.flip(np.arange(10))\n", "# Restore the classes with correct target now\n", "flax.serialization.from_state_dict(custom_target, raw_state_dict)" ] }, { "cell_type": "markdown", "id": "a6b39501", "metadata": {}, "source": [ "## Asynchronized checkpointing\n", "\n", "Checkpointing is I/O heavy, and if you have a large amount of data to save, it may be worthwhile to put it into a background thread, while continuing with your training.\n", "\n", "You can do this by creating an [`orbax.checkpoint.AsyncCheckpointer`](https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/async_checkpointer.py) in place of the `orbax.checkpoint.PyTreeCheckpointer`.\n", "\n", "Note: You should use the same `async_checkpointer` to handle all your async saves across your training steps, so that it can make sure that a previous async save is done before the next one begins. This enables bookkeeping, such as `keep` (the number of checkpoints) and `overwrite` to be consistent across steps.\n", "\n", "Whenever you want to explicitly wait until an async save is done, you can call `async_checkpointer.wait_until_finished()`." ] }, { "cell_type": "code", "execution_count": 19, "id": "85be68a6", "metadata": { "outputId": "aefce94c-8bae-4355-c142-05f2b61c39e2" }, "outputs": [ { "data": { "text/plain": [ "{'config': {'dimensions': array([5, 3])},\n", " 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n", " dtype=float32)],\n", " 'model': TrainState(step=1, apply_fn=, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],\n", " [ 0.11050402, -0.8765793 , 0.9800635 ],\n", " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=.init_fn at 0x13d5d83a0>, update=.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()))}" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# `orbax.checkpoint.AsyncCheckpointer` needs some multi-process initialization, because it was\n", "# originally designed for multi-process large model checkpointing.\n", "# For Python notebooks or other single-process settings, just set up with `num_processes=1`.\n", "# Refer to https://jax.readthedocs.io/en/latest/multi_process.html#initializing-the-cluster\n", "# for how to set it up in multi-process scenarios.\n", "jax.distributed.initialize(\"localhost:8889\", num_processes=1, process_id=0)\n", "\n", "async_checkpointer = orbax.checkpoint.AsyncCheckpointer(\n", " orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=50)\n", "\n", "# Save your job:\n", "async_checkpointer.save('/tmp/flax_ckpt/orbax/single_save_async', ckpt, save_args=save_args)\n", "# ... Continue with your work...\n", "\n", "# ... Until a time when you want to wait until the save completes:\n", "async_checkpointer.wait_until_finished() # Blocks until the checkpoint saving is completed.\n", "async_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save_async', item=target)" ] }, { "cell_type": "markdown", "id": "13e93db6", "metadata": {}, "source": [ "If you are using Orbax `CheckpointManager`, just pass in the async_checkpointer when initializing it. Then, in practice, call `async_checkpoint_manager.wait_until_finished()` instead." ] }, { "cell_type": "code", "execution_count": 20, "id": "af33b138", "metadata": {}, "outputs": [], "source": [ "async_checkpoint_manager = orbax.checkpoint.CheckpointManager(\n", " '/tmp/flax_ckpt/orbax/managed_async', async_checkpointer, options)\n", "async_checkpoint_manager.wait_until_finished()" ] }, { "cell_type": "markdown", "id": "bb0e03cd", "metadata": {}, "source": [ "## Multi-host/multi-process checkpointing\n", "\n", "JAX provides a few ways to scale up your code on multiple hosts at the same time. This usually happens when the number of devices (CPU/GPU/TPU) is so large that different devices are managed by different hosts (CPU). To get started on JAX in multi-process settings, check out [Using JAX in multi-host and multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html) and the [distributed array guide](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).\n", "\n", "In the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm with JAX `jit`, a large multi-process array can have its data sharded across different devices. (Note that JAX `pjit` and `jit` have been merged into a single unified interface. To learn about compiling and executing JAX functions in multi-host or multi-core environments, refer to [this guide](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) and the [jax.Array migration guide](https://jax.readthedocs.io/en/latest/jax_array_migration.html).) When a multi-process array is serialized, each host dumps its data shards to a single shared storage, such as a Google Cloud bucket.\n", "\n", "Orbax supports saving and loading pytrees with multi-process arrays in the same fashion as single-process pytrees. However, it's recommended to use the asynchronized [`orbax.AsyncCheckpointer`](https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/async_checkpointer.py) to save large multi-process arrays on another thread, so that you can perform computation alongside the saves. With pure Orbax, saving checkpoints in a multi-process context uses the same API as in a single-process context." ] }, { "cell_type": "code", "execution_count": 21, "id": "ubdUvyMrhD-1", "metadata": {}, "outputs": [], "source": [ "from jax.sharding import PartitionSpec, NamedSharding\n", "\n", "# Create an array sharded across multiple devices.\n", "mesh_shape = (4, 2)\n", "devices = np.asarray(jax.devices()).reshape(*mesh_shape)\n", "mesh = jax.sharding.Mesh(devices, ('x', 'y'))\n", "\n", "mp_array = jax.device_put(np.arange(8 * 2).reshape(8, 2),\n", " NamedSharding(mesh, PartitionSpec('x', 'y')))\n", "\n", "# Make it a pytree.\n", "mp_ckpt = {'model': mp_array}" ] }, { "cell_type": "code", "execution_count": 22, "id": "a669bc05", "metadata": {}, "outputs": [], "source": [ "async_checkpoint_manager.save(0, mp_ckpt)\n", "async_checkpoint_manager.wait_until_finished()" ] }, { "cell_type": "markdown", "id": "4deee32e", "metadata": {}, "source": [ "When restoring a checkpoint with multi-process arrays, you need to specify what `sharding` each array should be restored back to. Otherwise, they will be restored as large `np.array`s on process 0, costing time and memory.\n", "\n", "(In this notebook, since we are on single-process, it will be restored as `np.array` even if we provide shardings.)\n", "\n", "### With Orbax\n", "\n", "Orbax allows you to specify this by passing a pytree of `sharding`s in `restore_args`. If you already have a reference pytree that has all the arrays with the right sharding, you can use `orbax_utils.restore_args_from_target` to transform it into the `restore_args` that Orbax needs." ] }, { "cell_type": "code", "execution_count": 23, "id": "b8e7daaa", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'model': Array([[ 0, 1],\n", " [ 2, 3],\n", " [ 4, 5],\n", " [ 6, 7],\n", " [ 8, 9],\n", " [10, 11],\n", " [12, 13],\n", " [14, 15]], dtype=int32)}" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The reference doesn't need to be as large as your checkpoint!\n", "# Just make sure it has the `.sharding` you want.\n", "mp_smaller = jax.device_put(np.arange(8).reshape(4, 2),\n", " NamedSharding(mesh, PartitionSpec('x', 'y')))\n", "ref_ckpt = {'model': mp_smaller}\n", "\n", "restore_args = orbax_utils.restore_args_from_target(ref_ckpt)\n", "async_checkpoint_manager.restore(\n", " 0, items=ref_ckpt, restore_kwargs={'restore_args': restore_args})" ] }, { "cell_type": "markdown", "id": "edc355ce", "metadata": {}, "source": [ "### With the legacy Flax: use `save_checkpoint_multiprocess`\n", "\n", "In legacy Flax, to save multi-process arrays, use [`flax.training.checkpoints.save_checkpoint_multiprocess()`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.checkpoints.save_checkpoint_multiprocess) in place of `save_checkpoint()` and with the same arguments.\n", "\n", "If your checkpoint is too large, you can specify `timeout_secs` in the manager and give it more time to finish writing." ] }, { "cell_type": "code", "execution_count": 24, "id": "5d10039b", "metadata": { "outputId": "901bb097-0899-479d-b9ae-61dae79e7057" }, "outputs": [ { "data": { "text/plain": [ "'/tmp/flax_ckpt/checkpoint_3'" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=50)\n", "checkpoints.save_checkpoint_multiprocess(ckpt_dir,\n", " mp_ckpt,\n", " step=3,\n", " overwrite=True,\n", " keep=4,\n", " orbax_checkpointer=async_checkpointer)" ] }, { "cell_type": "code", "execution_count": 25, "id": "a9f9724c", "metadata": { "outputId": "393c4a0e-8a8c-4ca6-c609-93c8bab38e75" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.\n" ] }, { "data": { "text/plain": [ "{'model': Array([[ 0, 1],\n", " [ 2, 3],\n", " [ 4, 5],\n", " [ 6, 7],\n", " [ 8, 9],\n", " [10, 11],\n", " [12, 13],\n", " [14, 15]], dtype=int32)}" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mp_restored = checkpoints.restore_checkpoint(ckpt_dir,\n", " target=ref_ckpt,\n", " step=3,\n", " orbax_checkpointer=async_checkpointer)\n", "mp_restored" ] }, { "cell_type": "markdown", "id": "65cfdd59", "metadata": {}, "source": [ "## Orbax-as-backend troubleshooting\n", "\n", "As an intermediate stage of the migration (to Orbax from the legacy Flax `checkpoints` API), `flax.training.checkpoints` APIs will start to use Orbax as their backend when saving checkpoints starting from May 15, 2023.\n", "\n", "Checkpoints saved with the Orbax backend can be readable by either `flax.training.checkpoints.restore_checkpoint` or `orbax.checkpoint.PyTreeCheckpointer`.\n", "\n", "Code-wise, this is equivalent to setting the config flag [`flax.config.flax_use_orbax_checkpointing`](https://github.com/google/flax/blob/main/flax/configurations.py#L103) default to `True`. You can overwrite this value in your project with `flax.config.update('flax_use_orbax_checkpointing', )` at any time.\n", "\n", "In general, this automatic migration will not affect most users. However, you may encounter issues if your API usage follows some specific pattern. Check out the sections below for troubleshooting." ] }, { "cell_type": "markdown", "id": "415bceb1", "metadata": {}, "source": [ "### If your devices hang when writing checkpoints\n", "\n", "If you are running in a multi-host environment (usually anything larger than 8 TPU devices) and your devices hang when writing checkpoints, check if your code is in the following pattern (that is, the `save_checkpoint` only ran on host `0`):\n", "\n", "```\n", "if jax.process_index() == 0:\n", " flax.training.checkpoints.save_checkpoint(...)\n", "```\n", "\n", "Unfortunately this is a legacy pattern that will be deprecated and won't be supported, because in a multi-process environment, the checkpointing code should coordinate among hosts instead of being triggered only on the host `0`. Replacing the code above with the following should resolve the hang issue:\n", "\n", "```\n", "flax.training.checkpoints.save_checkpoint_multiprocess(...)\n", "```" ] }, { "cell_type": "markdown", "id": "70e0ebb3", "metadata": {}, "source": [ "### If you don't save pytrees\n", "\n", "Orbax uses `orbax.checkpoint.PyTreeCheckpointHandler` to save checkpoints, which means they only save pytrees.\n", "\n", "If you want to save singular arrays or numbers, you have two options:\n", "\n", "1. Use `orbax.ArrayCheckpointHandler` to save them following [this migration section](https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/orbax_upgrade_guide.html#saving-loading-a-single-jax-or-numpy-array).\n", "\n", "1. Wrap it inside a pytree and save as usual." ] } ], "metadata": { "gpuClass": "standard", "jupytext": { "formats": "ipynb,md", "main_language": "python" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: docs/guides/training_techniques/use_checkpointing.md ================================================ --- jupyter: jupytext: formats: ipynb,md main_language: python text_representation: extension: .md format_name: markdown format_version: '1.3' jupytext_version: 1.13.8 --- # Save and load checkpoints This guide demonstrates how to save and load Flax checkpoints with [Orbax](https://github.com/google/orbax). Orbax provides a variety of features for saving and loading model data, which you will learn about in this doc: * Support for various array types and storage formats * Asynchronous saving to reduce training wait time * Versioning and automatic bookkeeping of past checkpoints * Flexible [`transformations`](https://orbax.readthedocs.io/en/latest/transformations.html) to tweak and load old checkpoints * [`jax.sharding`](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)-based API to save and load in multi-host scenarios --- **_Ongoing migration to Orbax:_** After July 30 2023, Flax's legacy `flax.training.checkpoints` API will be deprecated in favor of [Orbax](https://github.com/google/orbax). * **If you are a new Flax user**: Use the new `orbax.checkpoint` API, as demonstrated in this guide. * **If you have legacy `flax.training.checkpoints` code in your project**: Consider the following options: * **Migrating your code to Orbax (Recommended)**: Migrate your API calls to `orbax.checkpoint` API by following this [migration guide](https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/orbax_upgrade_guide.html). * **Automatically use the Orbax backend**: Add `flax.config.update('flax_use_orbax_checkpointing', True)` to your project, which will let your `flax.training.checkpoints` calls automatically use the Orbax backend to save your checkpoints. * **Scheduled flip**: This will become the default mode after **May 2023** (tentative date). * Visit [Orbax-as-backend troubleshooting section](https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#orbax-as-backend-troubleshooting) if you meet any issue in the automatic migration. --- For backward-compatibility, this guide shows the Orbax-equivalent calls in the Flax legacy `flax.training.checkpoints` API. If you need to learn more about `orbax.checkpoint`, refer to the [Orbax docs](https://orbax.readthedocs.io/en/latest/). ## Setup Install/upgrade Flax and [Orbax](https://github.com/google/orbax). For JAX installation with GPU/TPU support, visit [this section on GitHub](https://github.com/jax-ml/jax#installation). Note: Before running `import jax`, create eight fake devices to mimic a [multi-host environment](https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html?#aside-hosts-and-devices-in-jax) in this notebook. Note that the order of imports is important here. The `os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'` command works only with the CPU backend, which means it won't work with GPU/TPU acceleration on if you're running this notebook in Google Colab. If you are already running the code on multiple devices (for example, in a 4x2 TPU environment), you can skip running the next cell. ```python import os os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' ``` ```python from typing import Optional, Any import shutil import numpy as np import jax from jax import random, numpy as jnp import flax from flax import linen as nn from flax.training import checkpoints, train_state from flax import struct, serialization import orbax.checkpoint import optax ``` ```python ckpt_dir = '/tmp/flax_ckpt' if os.path.exists(ckpt_dir): shutil.rmtree(ckpt_dir) # Remove any existing checkpoints from the last notebook run. ``` ## Save checkpoints In Orbax and Flax, you can save and load any given JAX [pytree](https://jax.readthedocs.io/en/latest/pytrees.html). This includes not only typical Python and NumPy containers, but also customized classes extended from [`flax.struct.dataclass`](https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html#flax.struct.dataclass). That means you can store almost any data generated — not only your model parameters, but any arrays/dictionaries, metadata/configs, and so on. First, create a pytree with many data structures and containers, and play with it: ```python outputId="f1856d96-1961-48ed-bb7c-cb63fbaa7567" # A simple model with one linear layer. key1, key2 = random.split(random.key(0)) x1 = random.normal(key1, (5,)) # A simple JAX array. model = nn.Dense(features=3) variables = model.init(key2, x1) # Flax's TrainState is a pytree dataclass and is supported in checkpointing. # Define your class with `@flax.struct.dataclass` decorator to make it compatible. tx = optax.sgd(learning_rate=0.001) # An Optax SGD optimizer. state = train_state.TrainState.create( apply_fn=model.apply, params=variables['params'], tx=tx) # Perform a simple gradient update similar to the one during a normal training workflow. state = state.apply_gradients(grads=jax.tree_util.tree_map(jnp.ones_like, state.params)) # Some arbitrary nested pytree with a dictionary and a NumPy array. config = {'dimensions': np.array([5, 3])} # Bundle everything together. ckpt = {'model': state, 'config': config, 'data': [x1]} ckpt ``` ### With Orbax Save the checkpoint with `orbax.checkpoint.PyTreeCheckpointer`, directly to the `tmp/orbax/single_save` directory. Note: An optional `save_args` is provided. This is recommended for performance speedups, as it bundles smaller arrays in your pytree to a single large file instead of multiple smaller files. ```python from flax.training import orbax_utils orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer() save_args = orbax_utils.save_args_from_target(ckpt) orbax_checkpointer.save('/tmp/flax_ckpt/orbax/single_save', ckpt, save_args=save_args) ``` Next, to use versioning and automatic bookkeeping features, you need to wrap `orbax.checkpoint.CheckpointManager` over `orbax.checkpoint.PyTreeCheckpointer`. In addition, provide `orbax.checkpoint.CheckpointManagerOptions` that customizes your needs, such as how often and on what criteria you prefer old checkpoints be deleted. See [documentation](https://orbax.readthedocs.io/en/latest/guides/checkpoint/api_refactor.html) for a full list of options offered. `orbax.checkpoint.CheckpointManager` should be placed at the top-level outside your training steps to manage your saves. ```python outputId="b7132933-566d-440d-c34e-c5468d87cbdc" options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2, create=True) checkpoint_manager = orbax.checkpoint.CheckpointManager( '/tmp/flax_ckpt/orbax/managed', orbax_checkpointer, options) # Inside a training loop for step in range(5): # ... do your training checkpoint_manager.save(step, ckpt, save_kwargs={'save_args': save_args}) os.listdir('/tmp/flax_ckpt/orbax/managed') # Because max_to_keep=2, only step 3 and 4 are retained ``` ### With the legacy API And here's how to save with the legacy Flax checkpointing utilities (note that this provides less management features compared with `orbax.checkpoint.CheckpointManagerOptions`): ```python outputId="6d849273-15ce-4480-8864-726d1838ac1f" # Import Flax Checkpoints. from flax.training import checkpoints checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', target=ckpt, step=0, overwrite=True, keep=2) ``` ## Restore checkpoints ### With Orbax In Orbax, call `.restore()` for either `orbax.checkpoint.PyTreeCheckpointer` or `orbax.checkpoint.CheckpointManager` to restore your checkpoint in the raw pytree format. ```python outputId="b4af1ef4-f22f-459b-bdca-2e6bfa16c08b" raw_restored = orbax_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save') raw_restored ``` Note that the `step` number is required for `CheckpointManger`. You can also use `.latest_step()` to find the latest step available. ```python step = checkpoint_manager.latest_step() # step = 4 checkpoint_manager.restore(step) ``` ### With the legacy API Note that with the migration to Orbax in progress, `flax.training.checkpointing.restore_checkpoint` can automatically identify whether a checkpoint is saved in the legacy Flax format or with an Orbax backend, and restore the pytree correctly. Therefore, adding `flax.config.update('flax_use_orbax_checkpointing', True)` won't hurt your ability to restore old checkpoints. Here's how to restore checkpoints using the legacy API: ```python outputId="85ffceca-f38d-46b8-e567-d9d38b7885f9" raw_restored = checkpoints.restore_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', target=None) raw_restored ``` ## Restore with custom dataclasses ### With Orbax * The pytrees restored in the previous examples are in the form of raw dictionaries. Original pytrees contain custom dataclasses like [`TrainState`](https://flax.readthedocs.io/en/latest/flip/1009-optimizer-api.html?#train-state) and `optax` states. * This is because when restoring a pytree, the program does not yet know which structure it once belonged to. * To resolve this, you should first provide an example pytree to let Orbax or Flax know exactly which structure to restore to. This section demonstrates how to set up any custom Flax dataclass explicitly, and have the same structure as a saved checkpoint. Note: Data that was a JAX NumPy array (`jnp.array`) format will be restored as a NumPy array (`numpy.array`). This would not affect your work because JAX will [automatically convert](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html) NumPy arrays to JAX arrays once the computation starts. ```python outputId="110c6b6e-fe42-4179-e5d8-6b92d355e11b" empty_state = train_state.TrainState.create( apply_fn=model.apply, params=jax.tree_util.tree_map(np.zeros_like, variables['params']), # values of the tree leaf doesn't matter tx=tx, ) empty_config = {'dimensions': np.array([0, 0])} target = {'model': empty_state, 'config': empty_config, 'data': [jnp.zeros_like(x1)]} state_restored = orbax_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save', item=target) state_restored ``` ### With the legacy API Alternatively, you can restore from Orbax `CheckpointManager` and from the legacy Flax code as follows: ```python checkpoint_manager.restore(4, items=target) ``` ```python checkpoints.restore_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', target=target) ``` It's often recommended to refactor out the process of initializing a checkpoint's structure (for example, a [`TrainState`](https://flax.readthedocs.io/en/latest/flip/1009-optimizer-api.html?#train-state)), so that saving/loading is easier and less error-prone. This is because functions and complex objects like `apply_fn` and `tx` (optimizer) cannot be serialized into the checkpoint file and must be initialized by code. ## Restore when checkpoint structures differ During your development, your checkpoint structure will change when changing the model, adding/removing fields during tweaking, and so on. This section explains how to load old data to your new code. Below is a simple example — a `CustomTrainState` extended from `flax.training.train_state.TrainState` that contains an extra field called `batch_stats`. When working on a real-world model, you may need this when applying [batch normalization](https://flax.readthedocs.io/en/latest/guides/training_techniques/batch_norm.html). Here, you store the new `CustomTrainState` as step 5, while step 4 contains the old/previous `TrainState`. ```python outputId="4fe776f0-65f8-4fc4-d64a-990520b36dce" class CustomTrainState(train_state.TrainState): batch_stats: Any = None custom_state = CustomTrainState.create( apply_fn=state.apply_fn, params=state.params, tx=state.tx, batch_stats=np.arange(10), ) custom_ckpt = {'model': custom_state, 'config': config, 'data': [x1]} # Use a custom state to read the old `TrainState` checkpoint. custom_target = {'model': custom_state, 'config': None, 'data': [jnp.zeros_like(x1)]} # Save it in Orbax. custom_save_args = orbax_utils.save_args_from_target(custom_ckpt) checkpoint_manager.save(5, custom_ckpt, save_kwargs={'save_args': custom_save_args}) ``` It is recommended to keep your checkpoints up-to-date with your pytree dataclass definitions. However, you might be forced to restore the checkpoints with incompatible reference objects at runtime. When this happens, the checkpoint restoration will try to respect the structure of the reference when given. Below are examples of a few common scenarios. ### Scenario 1: When a reference object is partial If your reference object is a subtree of your checkpoint, the restoration will ignore the additional field(s) and restore a checkpoint with the same structure as the reference. Like in the example below, the `batch_stats` field in `CustomTrainState` was ignored, and the checkpoint was restored as a `TrainState`. This can also be useful for reading only part of your checkpoint. ```python restored = checkpoint_manager.restore(5, items=target) assert not hasattr(restored, 'batch_stats') assert type(restored['model']) == train_state.TrainState restored ``` ### Scenario 2: When a checkpoint is partial On the other hand, if the reference object contains a value that is not available in the checkpoint, the checkpointing code will by default warn that some data is not compatible. To bypass the error, you need to pass an Orbax [`transform`](https://orbax.readthedocs.io/en/latest/guides/checkpoint/transformations.html) that teaches Orbax how to conform this checkpoint into the structure of the `custom_target`. In this case, pass a default `{}` that lets Orbax use values in the `custom_target` to fill in the blank. This allows you to restore an old checkpoint into a new data structure, the `CustomTrainState`. ```python try: checkpoint_manager.restore(4, items=custom_target) except KeyError as e: print(f'KeyError when target state has an unmentioned field: {e}') print('') # Step 4 is an original `TrainState`, without the `batch_stats` custom_restore_args = orbax_utils.restore_args_from_target(custom_target) restored = checkpoint_manager.restore(4, items=custom_target, restore_kwargs={'transforms': {}, 'restore_args': custom_restore_args}) assert type(restored['model']) == CustomTrainState np.testing.assert_equal(restored['model'].batch_stats, custom_target['model'].batch_stats) restored ``` ##### With Orbax If you have already saved your checkpoints with the Orbax backend, you can use `orbax_transforms` to access this `transforms` argument in the Flax API. ```python outputId="cdbb9247-d1eb-4458-aa83-8db0332af7cb" # Save in the "Flax-with-Orbax" backend. flax.config.update('flax_use_orbax_checkpointing', True) checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', target=ckpt, step=4, overwrite=True, keep=2) checkpoints.restore_checkpoint('/tmp/flax_ckpt/flax-checkpointing', target=custom_target, step=4, orbax_transforms={}) ``` ##### With the legacy API Using the legacy `flax.training.checkpoints` API, similar things are doable too, but they are not as flexible as the [Orbax Transformations](https://orbax.readthedocs.io/en/latest/guides/checkpoint/transformations.html). You need to restore the checkpoint to a raw dict with `target=None`, modify the structure accordingly, and then deserialize it back to the original target. ```python # Save using the legacy Flax `checkpoints` API. flax.config.update('flax_use_orbax_checkpointing', False) checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', target=ckpt, step=5, overwrite=True, keep=2) # Pass no target to get a raw state dictionary first. raw_state_dict = checkpoints.restore_checkpoint('/tmp/flax_ckpt/flax-checkpointing', target=None, step=5) # Add/remove fields as needed. raw_state_dict['model']['batch_stats'] = np.flip(np.arange(10)) # Restore the classes with correct target now flax.serialization.from_state_dict(custom_target, raw_state_dict) ``` ## Asynchronized checkpointing Checkpointing is I/O heavy, and if you have a large amount of data to save, it may be worthwhile to put it into a background thread, while continuing with your training. You can do this by creating an [`orbax.checkpoint.AsyncCheckpointer`](https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/async_checkpointer.py) in place of the `orbax.checkpoint.PyTreeCheckpointer`. Note: You should use the same `async_checkpointer` to handle all your async saves across your training steps, so that it can make sure that a previous async save is done before the next one begins. This enables bookkeeping, such as `keep` (the number of checkpoints) and `overwrite` to be consistent across steps. Whenever you want to explicitly wait until an async save is done, you can call `async_checkpointer.wait_until_finished()`. ```python outputId="aefce94c-8bae-4355-c142-05f2b61c39e2" # `orbax.checkpoint.AsyncCheckpointer` needs some multi-process initialization, because it was # originally designed for multi-process large model checkpointing. # For Python notebooks or other single-process settings, just set up with `num_processes=1`. # Refer to https://jax.readthedocs.io/en/latest/multi_process.html#initializing-the-cluster # for how to set it up in multi-process scenarios. jax.distributed.initialize("localhost:8889", num_processes=1, process_id=0) async_checkpointer = orbax.checkpoint.AsyncCheckpointer( orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=50) # Save your job: async_checkpointer.save('/tmp/flax_ckpt/orbax/single_save_async', ckpt, save_args=save_args) # ... Continue with your work... # ... Until a time when you want to wait until the save completes: async_checkpointer.wait_until_finished() # Blocks until the checkpoint saving is completed. async_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save_async', item=target) ``` If you are using Orbax `CheckpointManager`, just pass in the async_checkpointer when initializing it. Then, in practice, call `async_checkpoint_manager.wait_until_finished()` instead. ```python async_checkpoint_manager = orbax.checkpoint.CheckpointManager( '/tmp/flax_ckpt/orbax/managed_async', async_checkpointer, options) async_checkpoint_manager.wait_until_finished() ``` ## Multi-host/multi-process checkpointing JAX provides a few ways to scale up your code on multiple hosts at the same time. This usually happens when the number of devices (CPU/GPU/TPU) is so large that different devices are managed by different hosts (CPU). To get started on JAX in multi-process settings, check out [Using JAX in multi-host and multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html) and the [distributed array guide](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). In the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm with JAX `jit`, a large multi-process array can have its data sharded across different devices. (Note that JAX `pjit` and `jit` have been merged into a single unified interface. To learn about compiling and executing JAX functions in multi-host or multi-core environments, refer to [this guide](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) and the [jax.Array migration guide](https://jax.readthedocs.io/en/latest/jax_array_migration.html).) When a multi-process array is serialized, each host dumps its data shards to a single shared storage, such as a Google Cloud bucket. Orbax supports saving and loading pytrees with multi-process arrays in the same fashion as single-process pytrees. However, it's recommended to use the asynchronized [`orbax.AsyncCheckpointer`](https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/async_checkpointer.py) to save large multi-process arrays on another thread, so that you can perform computation alongside the saves. With pure Orbax, saving checkpoints in a multi-process context uses the same API as in a single-process context. ```python from jax.sharding import PartitionSpec, NamedSharding # Create an array sharded across multiple devices. mesh_shape = (4, 2) devices = np.asarray(jax.devices()).reshape(*mesh_shape) mesh = jax.sharding.Mesh(devices, ('x', 'y')) mp_array = jax.device_put(np.arange(8 * 2).reshape(8, 2), NamedSharding(mesh, PartitionSpec('x', 'y'))) # Make it a pytree. mp_ckpt = {'model': mp_array} ``` ```python async_checkpoint_manager.save(0, mp_ckpt) async_checkpoint_manager.wait_until_finished() ``` When restoring a checkpoint with multi-process arrays, you need to specify what `sharding` each array should be restored back to. Otherwise, they will be restored as large `np.array`s on process 0, costing time and memory. (In this notebook, since we are on single-process, it will be restored as `np.array` even if we provide shardings.) ### With Orbax Orbax allows you to specify this by passing a pytree of `sharding`s in `restore_args`. If you already have a reference pytree that has all the arrays with the right sharding, you can use `orbax_utils.restore_args_from_target` to transform it into the `restore_args` that Orbax needs. ```python # The reference doesn't need to be as large as your checkpoint! # Just make sure it has the `.sharding` you want. mp_smaller = jax.device_put(np.arange(8).reshape(4, 2), NamedSharding(mesh, PartitionSpec('x', 'y'))) ref_ckpt = {'model': mp_smaller} restore_args = orbax_utils.restore_args_from_target(ref_ckpt) async_checkpoint_manager.restore( 0, items=ref_ckpt, restore_kwargs={'restore_args': restore_args}) ``` ### With the legacy Flax: use `save_checkpoint_multiprocess` In legacy Flax, to save multi-process arrays, use [`flax.training.checkpoints.save_checkpoint_multiprocess()`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.checkpoints.save_checkpoint_multiprocess) in place of `save_checkpoint()` and with the same arguments. If your checkpoint is too large, you can specify `timeout_secs` in the manager and give it more time to finish writing. ```python outputId="901bb097-0899-479d-b9ae-61dae79e7057" async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=50) checkpoints.save_checkpoint_multiprocess(ckpt_dir, mp_ckpt, step=3, overwrite=True, keep=4, orbax_checkpointer=async_checkpointer) ``` ```python outputId="393c4a0e-8a8c-4ca6-c609-93c8bab38e75" mp_restored = checkpoints.restore_checkpoint(ckpt_dir, target=ref_ckpt, step=3, orbax_checkpointer=async_checkpointer) mp_restored ``` ## Orbax-as-backend troubleshooting As an intermediate stage of the migration (to Orbax from the legacy Flax `checkpoints` API), `flax.training.checkpoints` APIs will start to use Orbax as their backend when saving checkpoints starting from May 15, 2023. Checkpoints saved with the Orbax backend can be readable by either `flax.training.checkpoints.restore_checkpoint` or `orbax.checkpoint.PyTreeCheckpointer`. Code-wise, this is equivalent to setting the config flag [`flax.config.flax_use_orbax_checkpointing`](https://github.com/google/flax/blob/main/flax/configurations.py#L103) default to `True`. You can overwrite this value in your project with `flax.config.update('flax_use_orbax_checkpointing', )` at any time. In general, this automatic migration will not affect most users. However, you may encounter issues if your API usage follows some specific pattern. Check out the sections below for troubleshooting. ### If your devices hang when writing checkpoints If you are running in a multi-host environment (usually anything larger than 8 TPU devices) and your devices hang when writing checkpoints, check if your code is in the following pattern (that is, the `save_checkpoint` only ran on host `0`): ``` if jax.process_index() == 0: flax.training.checkpoints.save_checkpoint(...) ``` Unfortunately this is a legacy pattern that will be deprecated and won't be supported, because in a multi-process environment, the checkpointing code should coordinate among hosts instead of being triggered only on the host `0`. Replacing the code above with the following should resolve the hang issue: ``` flax.training.checkpoints.save_checkpoint_multiprocess(...) ``` ### If you don't save pytrees Orbax uses `orbax.checkpoint.PyTreeCheckpointHandler` to save checkpoints, which means they only save pytrees. If you want to save singular arrays or numbers, you have two options: 1. Use `orbax.ArrayCheckpointHandler` to save them following [this migration section](https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/orbax_upgrade_guide.html#saving-loading-a-single-jax-or-numpy-array). 1. Wrap it inside a pytree and save as usual. ================================================ FILE: docs/index.rst ================================================ .. Flax documentation main file, created by sphinx-quickstart on Mon Feb 17 11:41:38 2020. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. ****************************** Flax Linen ****************************** .. div:: sd-text-left sd-font-italic Neural networks with JAX ---- Flax Linen delivers an **end-to-end and flexible user experience for researchers who use JAX with neural networks**. Flax exposes the full power of `JAX `__. It is made up of loosely coupled libraries, which are showcased with end-to-end integrated `guides `__ and `examples `__. Flax Linen is used by `hundreds of projects (and growing) `__, both in the open source community (like `Hugging Face `__) and at Google (like `Gemini `__, `Imagen `__, `Scenic `__, and `Big Vision `__). Features ^^^^^^^^^ .. grid:: .. grid-item:: :columns: 12 12 12 6 .. card:: Safety :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal Flax is designed for correctness and safety. Thanks to its immutable Modules and Functional API, Flax helps mitigate bugs that arise when handling state in JAX. .. grid-item:: :columns: 12 12 12 6 .. card:: Control :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal Flax grants more fine-grained control and expressivity than most Neural Network frameworks via its Variable Collections, RNG Collections and Mutability conditions. .. grid-item:: :columns: 12 12 12 6 .. card:: Functional API :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal Flax's functional API radically redefines what Modules can do via lifted transformations like vmap, scan, etc, while also enabling seamless integration with other JAX libraries like Optax and Chex. .. grid-item:: :columns: 12 12 12 6 .. card:: Terse code :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal Flax's :meth:`compact ` Modules enables submodules to be defined directly at their callsite, leading to code that is easier to read and avoids repetition. ---- Installation ^^^^^^^^^^^^ .. code-block:: bash pip install flax # or to install the latest version of Flax: pip install --upgrade git+https://github.com/google/flax.git Flax installs the vanilla CPU version of JAX, if you need a custom version please check out `JAX's installation page `__. Basic usage ^^^^^^^^^^^^ .. testsetup:: import jax from jax import random import flax.linen as nn import jax.numpy as jnp .. testcode:: class MLP(nn.Module): # create a Flax Module dataclass out_dims: int @nn.compact def __call__(self, x): x = x.reshape((x.shape[0], -1)) x = nn.Dense(128)(x) # create inline Flax Module submodules x = nn.relu(x) x = nn.Dense(self.out_dims)(x) # shape inference return x model = MLP(out_dims=10) # instantiate the MLP model x = jnp.empty((4, 28, 28, 1)) # generate random data variables = model.init(random.key(42), x)# initialize the weights y = model.apply(variables, x) # make forward pass ---- Learn more ^^^^^^^^^^ .. grid:: .. grid-item:: :columns: 6 6 6 4 .. card:: :material-regular:`rocket_launch;2em` Quickstart :class-card: sd-text-black sd-bg-light :link: quick_start.html .. grid-item:: :columns: 6 6 6 4 .. card:: :material-regular:`library_books;2em` Guides :class-card: sd-text-black sd-bg-light :link: guides/index.html .. grid-item:: :columns: 6 6 6 4 .. card:: :material-regular:`science;2em` Examples :class-card: sd-text-black sd-bg-light :link: examples.html .. grid-item:: :columns: 6 6 6 4 .. card:: :material-regular:`import_contacts;2em` Glossary :class-card: sd-text-black sd-bg-light :link: glossary.html .. grid-item:: :columns: 6 6 6 4 .. card:: :material-regular:`settings;2em` Developer notes :class-card: sd-text-black sd-bg-light :link: developer_notes/index.html .. grid-item:: :columns: 6 6 6 4 .. card:: :material-regular:`history_edu;2em` The Flax philosophy :class-card: sd-text-black sd-bg-light :link: https://flax.readthedocs.io/en/latest/philosophy.html .. grid-item:: :columns: 6 6 6 4 .. card:: :material-regular:`menu_book;2em` API reference :class-card: sd-text-black sd-bg-light :link: api_reference/index.html ---- Ecosystem ^^^^^^^^^ Notable examples in Flax include: .. grid:: .. grid-item:: :columns: 6 6 6 4 .. card:: `🤗 Hugging Face `__ :class-card: sd-border-0 :shadow: none :class-title: sd-text-center sd-fs-5 .. div:: sd-text-center sd-font-italic NLP and computer vision models .. grid-item:: :columns: 6 6 6 4 .. card:: `🥑 DALLE Mini `__ :class-card: sd-border-0 :shadow: none :class-title: sd-text-center sd-fs-5 .. div:: sd-text-center sd-font-italic Model for text-to-image generation .. grid-item:: :columns: 6 6 6 4 .. card:: `PaLM `__ :class-card: sd-border-0 :shadow: none :class-title: sd-text-center sd-fs-5 .. div:: sd-text-center sd-font-italic 540-billion parameter model for text generation .. grid-item:: :columns: 6 6 6 4 .. card:: `Imagen `__ :class-card: sd-border-0 :shadow: none :class-title: sd-text-center sd-fs-5 .. div:: sd-text-center sd-font-italic Text-to-image diffusion models .. grid-item:: :columns: 6 6 6 4 .. card:: `Scenic `__ :class-card: sd-border-0 :shadow: none :class-title: sd-text-center sd-fs-5 .. div:: sd-text-center sd-font-italic Libraries for large-scale computer vision .. grid-item:: :columns: 6 6 6 4 .. card:: `Big Vision `__ :class-card: sd-border-0 :shadow: none :class-title: sd-text-center sd-fs-5 .. div:: sd-text-center sd-font-italic Large-scale computer vision models .. grid-item:: :columns: 6 6 6 4 .. card:: `MaxText `__ :class-card: sd-border-0 :shadow: none :class-title: sd-text-center sd-fs-5 .. div:: sd-text-center sd-font-italic Open source high performance LLM .. grid-item:: :columns: 6 6 6 4 .. card:: `T5x `__ :class-card: sd-border-0 :shadow: none :class-title: sd-text-center sd-fs-5 .. div:: sd-text-center sd-font-italic Large language models .. grid-item:: :columns: 6 6 6 4 .. card:: `Brax `__ :class-card: sd-border-0 :shadow: none :class-title: sd-text-center sd-fs-5 .. div:: sd-text-center sd-font-italic On-device differentiable reinforcement learning environments .. role:: bold :class: bold .. toctree:: :hidden: :maxdepth: 2 Quick start guides/flax_fundamentals/flax_basics guides/index examples/index glossary faq developer_notes/index The Flax philosophy How to contribute api_reference/index Flax NNX ================================================ FILE: docs/linen_intro.ipynb ================================================ { "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/linen_intro.ipynb)\n", "[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/notebooks/linen_intro.ipynb)\n", "\n", "# Preface\n", "\n", "
\n", "
CAVEAT PROGRAMMER
\n", "\n", "The below is an alpha API preview and things might break. The surface syntax of the features of the API are not fixed in stone, and we welcome feedback on any points." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Useful links\n", "\n", "⟶ [Slides](https://docs.google.com/presentation/d/1ngKWUwsSqAwPRvATG8sAxMzu9ujv4N__cKsUofdNno0/edit?usp=sharing) for the core ideas of the new Functional Core and Linen\n", "\n", "⟶ \"Design tests\" guided our design process. Many are available for [functional core](https://github.com/google/flax/tree/main/examples/core_design_test) and some for the [proposed Module abstraction](https://github.com/google/flax/tree/main/examples/linen_design_test/)\n", "\n", "⟶ Ported examples: [ImageNet](https://github.com/google/flax/tree/main/examples/imagenet) and [WMT](https://github.com/google/flax/tree/main/examples/wmt) (to the proposed Module abstraction). TODO: Port to functional core.\n", "\n", "⟶ Our new [discussion forums](https://github.com/google/flax/discussions/)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Install and Import" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "skip-execution" ] }, "outputs": [], "source": [ "# Install the newest JAXlib version.\n", "!pip install --upgrade -q pip jax jaxlib\n", "# Install Flax at head:\n", "!pip install --upgrade -q git+https://github.com/google/flax.git" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import functools\n", "from typing import Any, Callable, Sequence, Optional\n", "import jax\n", "from jax import lax, random, numpy as jnp\n", "import flax\n", "from flax import linen as nn" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Invoking Modules" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Let's instantiate a `Dense` layer.\n", " - Modules are actually objects in this API, so we provide _constructor arguments_ when initializing the Module. In this case, we only have to provide the output `features` dimension." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "model = nn.Dense(features=3)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We need to initialize the Module variables, these include the parameters of the Module as well as any other state variables.\n", "\n", "We call the `init` method on the instantiated Module. If the Module `__call__` method has args `(self, *args, **kwargs)` then we call `init` with `(rngs, *args, **kwargs)` so in this case, just `(rng, input)`:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "outputId": "3adfaeaf-977e-4e82-8adf-d254fae6eb91" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, { "data": { "text/plain": [ "FrozenDict({\n", " params: {\n", " kernel: DeviceArray([[ 0.6503669 , 0.8678979 , 0.46042678],\n", " [ 0.05673932, 0.9909285 , -0.63536596],\n", " [ 0.76134115, -0.3250529 , -0.6522163 ],\n", " [-0.8243032 , 0.4150194 , 0.19405058]], dtype=float32),\n", " bias: DeviceArray([0., 0., 0.], dtype=float32),\n", " },\n", "})" ] }, "execution_count": 4, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "# Make RNG Keys and a fake input.\n", "key1, key2 = random.split(random.key(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "# provide key and fake input to get initialized variables\n", "init_variables = model.init(key2, x)\n", "\n", "init_variables" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We call the `apply` method on the instantiated Module. If the Module `__call__` method has args `(self, *args, **kwargs)` then we call `apply` with `(variables, *args, rngs=, mutable=, **kwargs)` where\n", " - `` are the optional _call time_ RNGs for things like dropout. For simple Modules this is just a single key, but if your module has multiple __kinds__ of data, it's a dictionary of rng-keys per-kind, e.g. `{'params': key0, 'dropout': key1}` for a Module with dropout layers.\n", " - `` is an optional list of names of __kinds__ that are expected to be mutated during the call. e.g. `['batch_stats']` for a layer updating batchnorm statistics.\n", "\n", "So in this case, just `(variables, input)`:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "outputId": "e8c389a6-29f3-4f93-97ea-703e85a8b811" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[ 0.5035518 , 1.8548559 , -0.4270196 ],\n", " [ 0.0279097 , 0.5589246 , -0.43061775],\n", " [ 0.35471284, 1.5741 , -0.3286552 ],\n", " [ 0.5264864 , 1.2928858 , 0.10089308]], dtype=float32)" ] }, "execution_count": 5, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "y = model.apply(init_variables, x)\n", "y" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Additional points:\n", " - If you want to `init` or `apply` a Module using a method other than call, you need to provide the `method=` kwarg to `init` and `apply` to use it instead of the default `__call__`, e.g. `method='encode'`, `method='decode'` to apply the encode/decode methods of an autoencoder." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Defining Basic Modules" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Composing submodules" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We support declaring modules in `setup()` that can still benefit from shape inference by using __Lazy Initialization__ that sets up variables the first time the Module is called." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "outputId": "1a6c6a17-0b95-42c2-b5bf-b9ad80fd7758", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameter shapes:\n", " {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n", "output:\n", " [[ 4.2292815e-02 -4.3807115e-02 2.9323792e-02 6.5492536e-03\n", " -1.7147182e-02]\n", " [ 1.2967804e-01 -1.4551792e-01 9.4432175e-02 1.2521386e-02\n", " -4.5417294e-02]\n", " [ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n", " 0.0000000e+00]\n", " [ 9.3024090e-04 2.7864411e-05 2.4478839e-04 8.1344356e-04\n", " -1.0110775e-03]]\n" ] } ], "source": [ "class ExplicitMLP(nn.Module):\n", " features: Sequence[int]\n", "\n", " def setup(self):\n", " # we automatically know what to do with lists, dicts of submodules\n", " self.layers = [nn.Dense(feat) for feat in self.features]\n", " # for single submodules, we would just write:\n", " # self.layer1 = nn.Dense(feat1)\n", "\n", " def __call__(self, inputs):\n", " x = inputs\n", " for i, lyr in enumerate(self.layers):\n", " x = lyr(x)\n", " if i != len(self.layers) - 1:\n", " x = nn.relu(x)\n", " return x\n", "\n", "key1, key2 = random.split(random.key(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = ExplicitMLP(features=[3,4,5])\n", "init_variables = model.init(key2, x)\n", "y = model.apply(init_variables, x)\n", "\n", "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n", "print('output:\\n', y)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Here we show the equivalent compact form of the MLP that declares the submodules inline using the `@compact` decorator." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "outputId": "b3709789-e66e-4e20-f6b2-04022f8a62bb", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameter shapes:\n", " {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n", "output:\n", " [[ 4.2292815e-02 -4.3807115e-02 2.9323792e-02 6.5492536e-03\n", " -1.7147182e-02]\n", " [ 1.2967804e-01 -1.4551792e-01 9.4432175e-02 1.2521386e-02\n", " -4.5417294e-02]\n", " [ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n", " 0.0000000e+00]\n", " [ 9.3024090e-04 2.7864411e-05 2.4478839e-04 8.1344356e-04\n", " -1.0110775e-03]]\n" ] } ], "source": [ "class SimpleMLP(nn.Module):\n", " features: Sequence[int]\n", "\n", " @nn.compact\n", " def __call__(self, inputs):\n", " x = inputs\n", " for i, feat in enumerate(self.features):\n", " x = nn.Dense(feat, name=f'layers_{i}')(x)\n", " if i != len(self.features) - 1:\n", " x = nn.relu(x)\n", " # providing a name is optional though!\n", " # the default autonames would be \"Dense_0\", \"Dense_1\", ...\n", " # x = nn.Dense(feat)(x)\n", " return x\n", "\n", "key1, key2 = random.split(random.key(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = SimpleMLP(features=[3,4,5])\n", "init_variables = model.init(key2, x)\n", "y = model.apply(init_variables, x)\n", "\n", "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n", "print('output:\\n', y)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Declaring and using variables" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Flax uses lazy initialization, which allows declared variables to be initialized only at the first site of their use, using whatever shape information is available a the local call site for shape inference. Once a variable has been initialized, a reference to the data is kept for use in subsequent calls.\n", "\n", "For declaring parameters that aren't mutated inside the model, but rather by gradient descent, we use the syntax:\n", "\n", " `self.param(parameter_name, parameter_init_fn, *init_args, **init_kwargs)`\n", "\n", "with arguments:\n", " - `parameter_name` just the name, a string\n", " - `parameter_init_fn` a function taking an RNG key and a variable number of other arguments, i.e. `fn(rng, *args)`. typically those in `nn.initializers` take an `rng` and a `shape` argument.\n", " - the remaining arguments to feed to the init function when initializing.\n", "\n", "Again, we'll demonstrate declaring things inline as we typically do using the `@compact` decorator." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "outputId": "bc5cb1f2-c5e9-4159-d131-73247009e32f", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameters:\n", " FrozenDict({\n", " params: {\n", " kernel: DeviceArray([[ 0.6503669 , 0.8678979 , 0.46042678],\n", " [ 0.05673932, 0.9909285 , -0.63536596],\n", " [ 0.76134115, -0.3250529 , -0.6522163 ],\n", " [-0.8243032 , 0.4150194 , 0.19405058]], dtype=float32),\n", " bias: DeviceArray([0., 0., 0.], dtype=float32),\n", " },\n", "})\n", "output:\n", " [[ 0.5035518 1.8548559 -0.4270196 ]\n", " [ 0.0279097 0.5589246 -0.43061775]\n", " [ 0.35471284 1.5741 -0.3286552 ]\n", " [ 0.5264864 1.2928858 0.10089308]]\n" ] } ], "source": [ "class SimpleDense(nn.Module):\n", " features: int\n", " kernel_init: Callable = nn.initializers.lecun_normal()\n", " bias_init: Callable = nn.initializers.zeros_init()\n", "\n", " @nn.compact\n", " def __call__(self, inputs):\n", " kernel = self.param('kernel',\n", " self.kernel_init, # RNG passed implicitly.\n", " (inputs.shape[-1], self.features)) # shape info.\n", " y = lax.dot_general(inputs, kernel,\n", " (((inputs.ndim - 1,), (0,)), ((), ())),)\n", " bias = self.param('bias', self.bias_init, (self.features,))\n", " y = y + bias\n", " return y\n", "\n", "key1, key2 = random.split(random.key(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = SimpleDense(features=3)\n", "init_variables = model.init(key2, x)\n", "y = model.apply(init_variables, x)\n", "\n", "print('initialized parameters:\\n', init_variables)\n", "print('output:\\n', y)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We can also declare variables in setup, though in doing so you can't take advantage of shape inference and have to provide explicit shape information at initialization. The syntax is a little repetitive in this case right now, but we do force agreement of the assigned names." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "outputId": "1e822bd8-7a08-4e80-e0e6-a86637c46772", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameters:\n", " FrozenDict({\n", " params: {\n", " kernel: DeviceArray([[ 0.6503669 , 0.8678979 , 0.46042678],\n", " [ 0.05673932, 0.9909285 , -0.63536596],\n", " [ 0.76134115, -0.3250529 , -0.6522163 ],\n", " [-0.8243032 , 0.4150194 , 0.19405058]], dtype=float32),\n", " bias: DeviceArray([0., 0., 0.], dtype=float32),\n", " },\n", "})\n", "output:\n", " [[ 0.5035518 1.8548559 -0.4270196 ]\n", " [ 0.0279097 0.5589246 -0.43061775]\n", " [ 0.35471284 1.5741 -0.3286552 ]\n", " [ 0.5264864 1.2928858 0.10089308]]\n" ] } ], "source": [ "class ExplicitDense(nn.Module):\n", " features_in: int # <-- explicit input shape\n", " features: int\n", " kernel_init: Callable = nn.initializers.lecun_normal()\n", " bias_init: Callable = nn.initializers.zeros_init()\n", "\n", " def setup(self):\n", " self.kernel = self.param('kernel',\n", " self.kernel_init,\n", " (self.features_in, self.features))\n", " self.bias = self.param('bias', self.bias_init, (self.features,))\n", "\n", " def __call__(self, inputs):\n", " y = lax.dot_general(inputs, self.kernel,\n", " (((inputs.ndim - 1,), (0,)), ((), ())),)\n", " y = y + self.bias\n", " return y\n", "\n", "key1, key2 = random.split(random.key(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = ExplicitDense(features_in=4, features=3)\n", "init_variables = model.init(key2, x)\n", "y = model.apply(init_variables, x)\n", "\n", "print('initialized parameters:\\n', init_variables)\n", "print('output:\\n', y)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## General Variables" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "For declaring generally mutable _variables_ that may be mutated inside the model we use the call:\n", "\n", " `self.variable(variable_kind, variable_name, variable_init_fn, *init_args, **init_kwargs)`\n", "\n", "with arguments:\n", " - `variable_kind` the \"kind\" of state this variable is, i.e. the name of the nested-dict collection that this will be stored in inside the top Modules variables. e.g. `batch_stats` for the moving statistics for a batch norm layer or `cache` for autoregressive cache data. Note that parameters also have a kind, but they're set to the default `param` kind.\n", " - `variable_name` just the name, a string\n", " - `variable_init_fn` a function taking a variable number of other arguments, i.e. `fn(*args)`. Note that we __don't__ assume the need for an RNG, if you _do_ want an RNG, provide it via a `self.make_rng(variable_kind)` call in the provided arguments.\n", " - the remaining arguments to feed to the init function when initializing.\n", "\n", "⚠️ Unlike parameters, we expect these to be mutated, so `self.variable` returns not a constant, but a _reference_ to the variable. To __get__ the raw value, you'd write `myvariable.value` and to __set__ it `myvariable.value = new_value`." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "outputId": "2a8f5453-81b1-44dc-a431-d14b372c5710", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized variables:\n", " FrozenDict({\n", " counter: {\n", " count: DeviceArray(0, dtype=int32),\n", " },\n", "})\n", "mutated variables:\n", " FrozenDict({\n", " counter: {\n", " count: DeviceArray(1, dtype=int32),\n", " },\n", "})\n", "output:\n", " 1\n" ] } ], "source": [ "class Counter(nn.Module):\n", " @nn.compact\n", " def __call__(self):\n", " # easy pattern to detect if we're initializing\n", " is_initialized = self.has_variable('counter', 'count')\n", " counter = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32))\n", " if is_initialized:\n", " counter.value += 1\n", " return counter.value\n", "\n", "\n", "key1 = random.key(0)\n", "\n", "model = Counter()\n", "init_variables = model.init(key1)\n", "print('initialized variables:\\n', init_variables)\n", "\n", "y, mutated_variables = model.apply(init_variables, mutable=['counter'])\n", "\n", "print('mutated variables:\\n', mutated_variables)\n", "print('output:\\n', y)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Another Mutability and RNGs Example" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Let's make an artificial, goofy example that mixes differentiable parameters, stochastic layers, and mutable variables:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "outputId": "8f299a5c-74c8-476c-93fa-e5543901ec45", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "updated variables:\n", " FrozenDict({\n", " params: {\n", " Dense_0: {\n", " kernel: DeviceArray([[ 0.6498898 , -0.5000124 , 0.78573596],\n", " [-0.25609785, -0.7132329 , 0.2500864 ],\n", " [-0.64630085, 0.39321756, -1.0203307 ],\n", " [ 0.38721725, 0.86828285, 0.10860055]], dtype=float32),\n", " bias: DeviceArray([0., 0., 0.], dtype=float32),\n", " },\n", " BatchNorm_0: {\n", " scale: DeviceArray([1., 1., 1.], dtype=float32),\n", " bias: DeviceArray([0., 0., 0.], dtype=float32),\n", " },\n", " },\n", " batch_stats: {\n", " BatchNorm_0: {\n", " mean: DeviceArray([ 0.00059601, -0.00103457, 0.00166948], dtype=float32),\n", " var: DeviceArray([0.9907686, 0.9923046, 0.992195 ], dtype=float32),\n", " },\n", " },\n", "})\n", "initialized variable shapes:\n", " FrozenDict({\n", " batch_stats: {\n", " BatchNorm_0: {\n", " mean: (3,),\n", " var: (3,),\n", " },\n", " },\n", " params: {\n", " BatchNorm_0: {\n", " bias: (3,),\n", " scale: (3,),\n", " },\n", " Dense_0: {\n", " bias: (3,),\n", " kernel: (4, 3),\n", " },\n", " },\n", "})\n", "output:\n", " [[[-0.21496922 0.21550177 -0.35633382]\n", " [-0.21496922 -2.0458 1.3015485 ]\n", " [-0.21496922 -0.925116 -0.35633382]\n", " [-0.6595459 0.21550177 0.3749205 ]]\n", "\n", " [[-0.21496922 1.642865 -0.35633382]\n", " [-0.21496922 1.3094063 -0.88034123]\n", " [ 2.5726683 0.21550177 0.34353197]\n", " [-0.21496922 0.21550177 1.6778195 ]]\n", "\n", " [[-1.6060593 0.21550177 -1.9460517 ]\n", " [ 1.4126908 -1.4898677 1.2790381 ]\n", " [-0.21496922 0.21550177 -0.35633382]\n", " [-0.21496922 0.21550177 -0.7251308 ]]]\n", "eval output:\n", " [[[ 3.2246590e-01 2.6108384e-02 4.4821960e-01]\n", " [ 8.5726947e-02 -5.4385906e-01 3.8821870e-01]\n", " [-2.3933809e-01 -2.7381191e-01 -1.7526165e-01]\n", " [-6.2515378e-02 -5.2414006e-01 1.7029770e-01]]\n", "\n", " [[ 1.5014435e-01 3.4498507e-01 -1.3554120e-01]\n", " [-3.6971044e-04 2.6463276e-01 -1.2491019e-01]\n", " [ 3.8763803e-01 2.9023719e-01 1.6291586e-01]\n", " [ 4.1320035e-01 4.1468274e-02 4.7670874e-01]]\n", "\n", " [[-1.9433719e-01 5.2831882e-01 -3.7554008e-01]\n", " [ 2.2608691e-01 -4.0989807e-01 3.8292480e-01]\n", " [-2.4945706e-01 1.6170470e-01 -2.5247774e-01]\n", " [-7.2220474e-02 1.2077977e-01 -8.8408351e-02]]]\n" ] } ], "source": [ "class Block(nn.Module):\n", " features: int\n", " training: bool\n", " @nn.compact\n", " def __call__(self, inputs):\n", " x = nn.Dense(self.features)(inputs)\n", " x = nn.Dropout(rate=0.5)(x, deterministic=not self.training)\n", " x = nn.BatchNorm(use_running_average=not self.training)(x)\n", " return x\n", "\n", "key1, key2, key3, key4 = random.split(random.key(0), 4)\n", "x = random.uniform(key1, (3,4,4))\n", "\n", "model = Block(features=3, training=True)\n", "\n", "init_variables = model.init({'params': key2, 'dropout': key3}, x)\n", "_, init_params = flax.core.pop(init_variables, 'params')\n", "\n", "# When calling `apply` with mutable kinds, returns a pair of output,\n", "# mutated_variables.\n", "y, mutated_variables = model.apply(\n", " init_variables, x, rngs={'dropout': key4}, mutable=['batch_stats'])\n", "\n", "# Now we reassemble the full variables from the updates (in a real training\n", "# loop, with the updated params from an optimizer).\n", "updated_variables = flax.core.freeze(dict(params=init_params,\n", " **mutated_variables))\n", "\n", "print('updated variables:\\n', updated_variables)\n", "print('initialized variable shapes:\\n',\n", " jax.tree_util.tree_map(jnp.shape, init_variables))\n", "print('output:\\n', y)\n", "\n", "# Let's run these model variables during \"evaluation\":\n", "eval_model = Block(features=3, training=False)\n", "y = eval_model.apply(updated_variables, x) # Nothing mutable; single return value.\n", "print('eval output:\\n', y)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# JAX transformations inside modules" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## JIT" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "It's not immediately clear what use this has, but you can compile specific submodules if there's a reason to.\n", "\n", "_Known Gotcha_: at the moment, the decorator changes the RNG stream slightly, so comparing jitted an unjitted initializations will look different." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "outputId": "3f324d0f-259f-40f0-8273-103f7fc281c5", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameter shapes:\n", " {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n", "output:\n", " [[ 0.2524199 0.11621253 0.5246693 0.19144788 0.2096542 ]\n", " [ 0.08557513 -0.04126885 0.2502836 0.03910369 0.16575359]\n", " [ 0.2804383 0.27751124 0.44969672 0.26016283 0.05875347]\n", " [ 0.2440843 0.17069656 0.45499086 0.20377949 0.13428023]]\n" ] } ], "source": [ "class MLP(nn.Module):\n", " features: Sequence[int]\n", "\n", " @nn.compact\n", " def __call__(self, inputs):\n", " x = inputs\n", " for i, feat in enumerate(self.features):\n", " # JIT the Module (it's __call__ fn by default.)\n", " x = nn.jit(nn.Dense)(feat, name=f'layers_{i}')(x)\n", " if i != len(self.features) - 1:\n", " x = nn.relu(x)\n", " return x\n", "\n", "key1, key2 = random.split(random.key(3), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = MLP(features=[3,4,5])\n", "init_variables = model.init(key2, x)\n", "y = model.apply(init_variables, x)\n", "\n", "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n", "print('output:\\n', y)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Remat" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "For memory-expensive computations, we can `remat` our method to recompute a Module's output during a backwards pass.\n", "\n", "_Known Gotcha_: at the moment, the decorator changes the RNG stream slightly, so comparing remat'd and undecorated initializations will look different." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "outputId": "7fe8e13b-7dd6-4e55-ee50-ce334e8ed178", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameter shapes:\n", " {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n", "output:\n", " [[-0.14814317 0.06889858 -0.19695625 0.12019286 0.02068037]\n", " [-0.04439102 -0.06698258 -0.11579747 -0.19906905 -0.04342325]\n", " [-0.08875751 -0.13392815 -0.23153095 -0.39802808 -0.0868225 ]\n", " [-0.01606487 -0.02424064 -0.04190649 -0.07204203 -0.01571464]]\n" ] } ], "source": [ "class RematMLP(nn.Module):\n", " features: Sequence[int]\n", " # For all transforms, we can annotate a method, or wrap an existing\n", " # Module class. Here we annotate the method.\n", " @nn.remat\n", " @nn.compact\n", " def __call__(self, inputs):\n", " x = inputs\n", " for i, feat in enumerate(self.features):\n", " x = nn.Dense(feat, name=f'layers_{i}')(x)\n", " if i != len(self.features) - 1:\n", " x = nn.relu(x)\n", " return x\n", "\n", "key1, key2 = random.split(random.key(3), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = RematMLP(features=[3,4,5])\n", "init_variables = model.init(key2, x)\n", "y = model.apply(init_variables, x)\n", "\n", "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n", "print('output:\\n', y)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Vmap" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "You can now `vmap` Modules inside. The transform has a lot of arguments, they have the usual jax vmap args:\n", " - `in_axes` - an integer or `None` for each input argument\n", " - `out_axes` - an integer or `None` for each output argument\n", " - `axis_size` - the axis size if you need to give it explicitly\n", "\n", "In addition, we provide for each __kind__ of variable it's axis rules:\n", "\n", " - `variable_in_axes` - a dict from kinds to a single integer or `None` specifying the input axes to map\n", " - `variable_out_axes` - a dict from kinds to a single integer or `None` specifying the output axes to map\n", " - `split_rngs` - a dict from RNG-kinds to a bool, specifying whether to split the rng along the axis.\n", "\n", "\n", "Below we show an example defining a batched, multiheaded attention module from a single-headed unbatched attention implementation." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "outputId": "223d880e-c7b2-4210-ebb5-dbfcdd9aed09", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameter shapes:\n", " {'params': {'attention': {'key': {'kernel': (2, 64, 32)}, 'out': {'bias': (2, 64), 'kernel': (2, 32, 64)}, 'query': {'kernel': (2, 64, 32)}, 'value': {'kernel': (2, 64, 32)}}}}\n", "output:\n", " (3, 13, 2)\n" ] } ], "source": [ "class RawDotProductAttention(nn.Module):\n", " attn_dropout_rate: float = 0.1\n", " train: bool = False\n", "\n", " @nn.compact\n", " def __call__(self, query, key, value, bias=None, dtype=jnp.float32):\n", " assert key.ndim == query.ndim\n", " assert key.ndim == value.ndim\n", "\n", " n = query.ndim\n", " attn_weights = lax.dot_general(\n", " query, key,\n", " (((n-1,), (n - 1,)), ((), ())))\n", " if bias is not None:\n", " attn_weights += bias\n", " norm_dims = tuple(range(attn_weights.ndim // 2, attn_weights.ndim))\n", " attn_weights = jax.nn.softmax(attn_weights, axis=norm_dims)\n", " attn_weights = nn.Dropout(self.attn_dropout_rate)(attn_weights,\n", " deterministic=not self.train)\n", " attn_weights = attn_weights.astype(dtype)\n", "\n", " contract_dims = (\n", " tuple(range(n - 1, attn_weights.ndim)),\n", " tuple(range(0, n - 1)))\n", " y = lax.dot_general(\n", " attn_weights, value,\n", " (contract_dims, ((), ())))\n", " return y\n", "\n", "class DotProductAttention(nn.Module):\n", " qkv_features: Optional[int] = None\n", " out_features: Optional[int] = None\n", " train: bool = False\n", "\n", " @nn.compact\n", " def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32):\n", " qkv_features = self.qkv_features or inputs_q.shape[-1]\n", " out_features = self.out_features or inputs_q.shape[-1]\n", "\n", " QKVDense = functools.partial(\n", " nn.Dense, features=qkv_features, use_bias=False, dtype=dtype)\n", " query = QKVDense(name='query')(inputs_q)\n", " key = QKVDense(name='key')(inputs_kv)\n", " value = QKVDense(name='value')(inputs_kv)\n", "\n", " y = RawDotProductAttention(train=self.train)(\n", " query, key, value, bias=bias, dtype=dtype)\n", "\n", " y = nn.Dense(features=out_features, dtype=dtype, name='out')(y)\n", " return y\n", "\n", "class MultiHeadDotProductAttention(nn.Module):\n", " qkv_features: Optional[int] = None\n", " out_features: Optional[int] = None\n", " batch_axes: Sequence[int] = (0,)\n", " num_heads: int = 1\n", " broadcast_dropout: bool = False\n", " train: bool = False\n", " @nn.compact\n", " def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32):\n", " qkv_features = self.qkv_features or inputs_q.shape[-1]\n", " out_features = self.out_features or inputs_q.shape[-1]\n", "\n", " # Make multiheaded attention from single-headed dimension.\n", " Attn = nn.vmap(DotProductAttention,\n", " in_axes=(None, None, None),\n", " out_axes=2,\n", " axis_size=self.num_heads,\n", " variable_axes={'params': 0},\n", " split_rngs={'params': True,\n", " 'dropout': not self.broadcast_dropout})\n", "\n", " # Vmap across batch dimensions.\n", " for axis in reversed(sorted(self.batch_axes)):\n", " Attn = nn.vmap(Attn,\n", " in_axes=(axis, axis, axis),\n", " out_axes=axis,\n", " variable_axes={'params': None},\n", " split_rngs={'params': False, 'dropout': False})\n", "\n", " # Run the vmap'd class on inputs.\n", " y = Attn(qkv_features=qkv_features // self.num_heads,\n", " out_features=out_features,\n", " train=self.train,\n", " name='attention')(inputs_q, inputs_kv, bias)\n", "\n", " return y.mean(axis=-2)\n", "\n", "\n", "key1, key2, key3, key4 = random.split(random.key(0), 4)\n", "x = random.uniform(key1, (3, 13, 64))\n", "\n", "model = functools.partial(\n", " MultiHeadDotProductAttention,\n", " broadcast_dropout=False,\n", " num_heads=2,\n", " batch_axes=(0,))\n", "\n", "init_variables = model(train=False).init({'params': key2}, x, x)\n", "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n", "\n", "y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4})\n", "print('output:\\n', y.shape)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Scan" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Scan allows us to apply `lax.scan` to Modules, including their parameters and mutable variables. To use it we have to specify how we want each \"kind\" of variable to be transformed. For scanned variables we specify similar to vmap via in `variable_in_axes`, `variable_out_axes`:\n", " - `nn.broadcast` broadcast the variable kind across the scan steps as a constant\n", " - `` scan along `axis` for e.g. unique parameters at each step\n", "\n", "OR we specify that the variable kind is to be treated like a \"carry\" by passing to the `variable_carry` argument.\n", "\n", "Further, for `scan`'d variable kinds, we further specify whether or not to split the rng at each step." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "outputId": "7d9ebed3-64de-4ca8-9dce-4b09ba9e31a1", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameter shapes:\n", " {'params': {'lstm_cell': {'hf': {'bias': (2,), 'kernel': (2, 2)}, 'hg': {'bias': (2,), 'kernel': (2, 2)}, 'hi': {'bias': (2,), 'kernel': (2, 2)}, 'ho': {'bias': (2,), 'kernel': (2, 2)}, 'if': {'kernel': (2, 2)}, 'ig': {'kernel': (2, 2)}, 'ii': {'kernel': (2, 2)}, 'io': {'kernel': (2, 2)}}}}\n", "output:\n", " ((DeviceArray([[-0.562219 , 0.92847174]], dtype=float32), DeviceArray([[-0.31570646, 0.2885693 ]], dtype=float32)), DeviceArray([[[-0.08265854, 0.01302483],\n", " [-0.10249066, 0.21991298],\n", " [-0.26609066, 0.22519003],\n", " [-0.27982554, 0.28393182],\n", " [-0.31570646, 0.2885693 ]]], dtype=float32))\n" ] } ], "source": [ "class SimpleScan(nn.Module):\n", " features: int\n", "\n", " @nn.compact\n", " def __call__(self, xs):\n", " LSTM = nn.scan(nn.LSTMCell,\n", " in_axes=1, out_axes=1,\n", " variable_broadcast='params',\n", " split_rngs={'params': False})\n", " lstm = LSTM(self.features, name=\"lstm_cell\")\n", "\n", " dummy_rng = random.key(0)\n", " input_shape = xs[:, 0].shape\n", " init_carry = lstm.initialize_carry(dummy_rng, input_shape)\n", "\n", " return lstm(init_carry, xs)\n", "\n", "key1, key2 = random.split(random.key(0), 2)\n", "xs = random.uniform(key1, (1, 5, 2))\n", "\n", "model = SimpleScan(2)\n", "init_variables = model.init(key2, xs)\n", "\n", "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n", "\n", "y = model.apply(init_variables, xs)\n", "print('output:\\n', y)" ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "language_info": { "name": "python", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: docs/linen_intro.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/linen_intro.ipynb) [![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/notebooks/linen_intro.ipynb) # Preface
CAVEAT PROGRAMMER
The below is an alpha API preview and things might break. The surface syntax of the features of the API are not fixed in stone, and we welcome feedback on any points. +++ ## Useful links ⟶ [Slides](https://docs.google.com/presentation/d/1ngKWUwsSqAwPRvATG8sAxMzu9ujv4N__cKsUofdNno0/edit?usp=sharing) for the core ideas of the new Functional Core and Linen ⟶ "Design tests" guided our design process. Many are available for [functional core](https://github.com/google/flax/tree/main/examples/core_design_test) and some for the [proposed Module abstraction](https://github.com/google/flax/tree/main/examples/linen_design_test/) ⟶ Ported examples: [ImageNet](https://github.com/google/flax/tree/main/examples/imagenet) and [WMT](https://github.com/google/flax/tree/main/examples/wmt) (to the proposed Module abstraction). TODO: Port to functional core. ⟶ Our new [discussion forums](https://github.com/google/flax/discussions/) +++ # Install and Import ```{code-cell} :tags: [skip-execution] # Install the newest JAXlib version. !pip install --upgrade -q pip jax jaxlib # Install Flax at head: !pip install --upgrade -q git+https://github.com/google/flax.git ``` ```{code-cell} import functools from typing import Any, Callable, Sequence, Optional import jax from jax import lax, random, numpy as jnp import flax from flax import linen as nn ``` # Invoking Modules +++ Let's instantiate a `Dense` layer. - Modules are actually objects in this API, so we provide _constructor arguments_ when initializing the Module. In this case, we only have to provide the output `features` dimension. ```{code-cell} model = nn.Dense(features=3) ``` We need to initialize the Module variables, these include the parameters of the Module as well as any other state variables. We call the `init` method on the instantiated Module. If the Module `__call__` method has args `(self, *args, **kwargs)` then we call `init` with `(rngs, *args, **kwargs)` so in this case, just `(rng, input)`: ```{code-cell} :outputId: 3adfaeaf-977e-4e82-8adf-d254fae6eb91 # Make RNG Keys and a fake input. key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (4,4)) # provide key and fake input to get initialized variables init_variables = model.init(key2, x) init_variables ``` We call the `apply` method on the instantiated Module. If the Module `__call__` method has args `(self, *args, **kwargs)` then we call `apply` with `(variables, *args, rngs=, mutable=, **kwargs)` where - `` are the optional _call time_ RNGs for things like dropout. For simple Modules this is just a single key, but if your module has multiple __kinds__ of data, it's a dictionary of rng-keys per-kind, e.g. `{'params': key0, 'dropout': key1}` for a Module with dropout layers. - `` is an optional list of names of __kinds__ that are expected to be mutated during the call. e.g. `['batch_stats']` for a layer updating batchnorm statistics. So in this case, just `(variables, input)`: ```{code-cell} :outputId: e8c389a6-29f3-4f93-97ea-703e85a8b811 y = model.apply(init_variables, x) y ``` Additional points: - If you want to `init` or `apply` a Module using a method other than call, you need to provide the `method=` kwarg to `init` and `apply` to use it instead of the default `__call__`, e.g. `method='encode'`, `method='decode'` to apply the encode/decode methods of an autoencoder. +++ # Defining Basic Modules +++ ## Composing submodules +++ We support declaring modules in `setup()` that can still benefit from shape inference by using __Lazy Initialization__ that sets up variables the first time the Module is called. ```{code-cell} :outputId: 1a6c6a17-0b95-42c2-b5bf-b9ad80fd7758 :tags: [] class ExplicitMLP(nn.Module): features: Sequence[int] def setup(self): # we automatically know what to do with lists, dicts of submodules self.layers = [nn.Dense(feat) for feat in self.features] # for single submodules, we would just write: # self.layer1 = nn.Dense(feat1) def __call__(self, inputs): x = inputs for i, lyr in enumerate(self.layers): x = lyr(x) if i != len(self.layers) - 1: x = nn.relu(x) return x key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (4,4)) model = ExplicitMLP(features=[3,4,5]) init_variables = model.init(key2, x) y = model.apply(init_variables, x) print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables))) print('output:\n', y) ``` Here we show the equivalent compact form of the MLP that declares the submodules inline using the `@compact` decorator. ```{code-cell} :outputId: b3709789-e66e-4e20-f6b2-04022f8a62bb :tags: [] class SimpleMLP(nn.Module): features: Sequence[int] @nn.compact def __call__(self, inputs): x = inputs for i, feat in enumerate(self.features): x = nn.Dense(feat, name=f'layers_{i}')(x) if i != len(self.features) - 1: x = nn.relu(x) # providing a name is optional though! # the default autonames would be "Dense_0", "Dense_1", ... # x = nn.Dense(feat)(x) return x key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (4,4)) model = SimpleMLP(features=[3,4,5]) init_variables = model.init(key2, x) y = model.apply(init_variables, x) print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables))) print('output:\n', y) ``` ## Declaring and using variables +++ Flax uses lazy initialization, which allows declared variables to be initialized only at the first site of their use, using whatever shape information is available a the local call site for shape inference. Once a variable has been initialized, a reference to the data is kept for use in subsequent calls. For declaring parameters that aren't mutated inside the model, but rather by gradient descent, we use the syntax: `self.param(parameter_name, parameter_init_fn, *init_args, **init_kwargs)` with arguments: - `parameter_name` just the name, a string - `parameter_init_fn` a function taking an RNG key and a variable number of other arguments, i.e. `fn(rng, *args)`. typically those in `nn.initializers` take an `rng` and a `shape` argument. - the remaining arguments to feed to the init function when initializing. Again, we'll demonstrate declaring things inline as we typically do using the `@compact` decorator. ```{code-cell} :outputId: bc5cb1f2-c5e9-4159-d131-73247009e32f :tags: [] class SimpleDense(nn.Module): features: int kernel_init: Callable = nn.initializers.lecun_normal() bias_init: Callable = nn.initializers.zeros_init() @nn.compact def __call__(self, inputs): kernel = self.param('kernel', self.kernel_init, # RNG passed implicitly. (inputs.shape[-1], self.features)) # shape info. y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())),) bias = self.param('bias', self.bias_init, (self.features,)) y = y + bias return y key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (4,4)) model = SimpleDense(features=3) init_variables = model.init(key2, x) y = model.apply(init_variables, x) print('initialized parameters:\n', init_variables) print('output:\n', y) ``` We can also declare variables in setup, though in doing so you can't take advantage of shape inference and have to provide explicit shape information at initialization. The syntax is a little repetitive in this case right now, but we do force agreement of the assigned names. ```{code-cell} :outputId: 1e822bd8-7a08-4e80-e0e6-a86637c46772 :tags: [] class ExplicitDense(nn.Module): features_in: int # <-- explicit input shape features: int kernel_init: Callable = nn.initializers.lecun_normal() bias_init: Callable = nn.initializers.zeros_init() def setup(self): self.kernel = self.param('kernel', self.kernel_init, (self.features_in, self.features)) self.bias = self.param('bias', self.bias_init, (self.features,)) def __call__(self, inputs): y = lax.dot_general(inputs, self.kernel, (((inputs.ndim - 1,), (0,)), ((), ())),) y = y + self.bias return y key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (4,4)) model = ExplicitDense(features_in=4, features=3) init_variables = model.init(key2, x) y = model.apply(init_variables, x) print('initialized parameters:\n', init_variables) print('output:\n', y) ``` ## General Variables +++ For declaring generally mutable _variables_ that may be mutated inside the model we use the call: `self.variable(variable_kind, variable_name, variable_init_fn, *init_args, **init_kwargs)` with arguments: - `variable_kind` the "kind" of state this variable is, i.e. the name of the nested-dict collection that this will be stored in inside the top Modules variables. e.g. `batch_stats` for the moving statistics for a batch norm layer or `cache` for autoregressive cache data. Note that parameters also have a kind, but they're set to the default `param` kind. - `variable_name` just the name, a string - `variable_init_fn` a function taking a variable number of other arguments, i.e. `fn(*args)`. Note that we __don't__ assume the need for an RNG, if you _do_ want an RNG, provide it via a `self.make_rng(variable_kind)` call in the provided arguments. - the remaining arguments to feed to the init function when initializing. ⚠️ Unlike parameters, we expect these to be mutated, so `self.variable` returns not a constant, but a _reference_ to the variable. To __get__ the raw value, you'd write `myvariable.value` and to __set__ it `myvariable.value = new_value`. ```{code-cell} :outputId: 2a8f5453-81b1-44dc-a431-d14b372c5710 :tags: [] class Counter(nn.Module): @nn.compact def __call__(self): # easy pattern to detect if we're initializing is_initialized = self.has_variable('counter', 'count') counter = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32)) if is_initialized: counter.value += 1 return counter.value key1 = random.key(0) model = Counter() init_variables = model.init(key1) print('initialized variables:\n', init_variables) y, mutated_variables = model.apply(init_variables, mutable=['counter']) print('mutated variables:\n', mutated_variables) print('output:\n', y) ``` ## Another Mutability and RNGs Example +++ Let's make an artificial, goofy example that mixes differentiable parameters, stochastic layers, and mutable variables: ```{code-cell} :outputId: 8f299a5c-74c8-476c-93fa-e5543901ec45 :tags: [] class Block(nn.Module): features: int training: bool @nn.compact def __call__(self, inputs): x = nn.Dense(self.features)(inputs) x = nn.Dropout(rate=0.5)(x, deterministic=not self.training) x = nn.BatchNorm(use_running_average=not self.training)(x) return x key1, key2, key3, key4 = random.split(random.key(0), 4) x = random.uniform(key1, (3,4,4)) model = Block(features=3, training=True) init_variables = model.init({'params': key2, 'dropout': key3}, x) _, init_params = flax.core.pop(init_variables, 'params') # When calling `apply` with mutable kinds, returns a pair of output, # mutated_variables. y, mutated_variables = model.apply( init_variables, x, rngs={'dropout': key4}, mutable=['batch_stats']) # Now we reassemble the full variables from the updates (in a real training # loop, with the updated params from an optimizer). updated_variables = flax.core.freeze(dict(params=init_params, **mutated_variables)) print('updated variables:\n', updated_variables) print('initialized variable shapes:\n', jax.tree_util.tree_map(jnp.shape, init_variables)) print('output:\n', y) # Let's run these model variables during "evaluation": eval_model = Block(features=3, training=False) y = eval_model.apply(updated_variables, x) # Nothing mutable; single return value. print('eval output:\n', y) ``` # JAX transformations inside modules +++ ## JIT +++ It's not immediately clear what use this has, but you can compile specific submodules if there's a reason to. _Known Gotcha_: at the moment, the decorator changes the RNG stream slightly, so comparing jitted an unjitted initializations will look different. ```{code-cell} :outputId: 3f324d0f-259f-40f0-8273-103f7fc281c5 :tags: [] class MLP(nn.Module): features: Sequence[int] @nn.compact def __call__(self, inputs): x = inputs for i, feat in enumerate(self.features): # JIT the Module (it's __call__ fn by default.) x = nn.jit(nn.Dense)(feat, name=f'layers_{i}')(x) if i != len(self.features) - 1: x = nn.relu(x) return x key1, key2 = random.split(random.key(3), 2) x = random.uniform(key1, (4,4)) model = MLP(features=[3,4,5]) init_variables = model.init(key2, x) y = model.apply(init_variables, x) print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables))) print('output:\n', y) ``` ## Remat +++ For memory-expensive computations, we can `remat` our method to recompute a Module's output during a backwards pass. _Known Gotcha_: at the moment, the decorator changes the RNG stream slightly, so comparing remat'd and undecorated initializations will look different. ```{code-cell} :outputId: 7fe8e13b-7dd6-4e55-ee50-ce334e8ed178 :tags: [] class RematMLP(nn.Module): features: Sequence[int] # For all transforms, we can annotate a method, or wrap an existing # Module class. Here we annotate the method. @nn.remat @nn.compact def __call__(self, inputs): x = inputs for i, feat in enumerate(self.features): x = nn.Dense(feat, name=f'layers_{i}')(x) if i != len(self.features) - 1: x = nn.relu(x) return x key1, key2 = random.split(random.key(3), 2) x = random.uniform(key1, (4,4)) model = RematMLP(features=[3,4,5]) init_variables = model.init(key2, x) y = model.apply(init_variables, x) print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables))) print('output:\n', y) ``` ## Vmap +++ You can now `vmap` Modules inside. The transform has a lot of arguments, they have the usual jax vmap args: - `in_axes` - an integer or `None` for each input argument - `out_axes` - an integer or `None` for each output argument - `axis_size` - the axis size if you need to give it explicitly In addition, we provide for each __kind__ of variable it's axis rules: - `variable_in_axes` - a dict from kinds to a single integer or `None` specifying the input axes to map - `variable_out_axes` - a dict from kinds to a single integer or `None` specifying the output axes to map - `split_rngs` - a dict from RNG-kinds to a bool, specifying whether to split the rng along the axis. Below we show an example defining a batched, multiheaded attention module from a single-headed unbatched attention implementation. ```{code-cell} :outputId: 223d880e-c7b2-4210-ebb5-dbfcdd9aed09 :tags: [] class RawDotProductAttention(nn.Module): attn_dropout_rate: float = 0.1 train: bool = False @nn.compact def __call__(self, query, key, value, bias=None, dtype=jnp.float32): assert key.ndim == query.ndim assert key.ndim == value.ndim n = query.ndim attn_weights = lax.dot_general( query, key, (((n-1,), (n - 1,)), ((), ()))) if bias is not None: attn_weights += bias norm_dims = tuple(range(attn_weights.ndim // 2, attn_weights.ndim)) attn_weights = jax.nn.softmax(attn_weights, axis=norm_dims) attn_weights = nn.Dropout(self.attn_dropout_rate)(attn_weights, deterministic=not self.train) attn_weights = attn_weights.astype(dtype) contract_dims = ( tuple(range(n - 1, attn_weights.ndim)), tuple(range(0, n - 1))) y = lax.dot_general( attn_weights, value, (contract_dims, ((), ()))) return y class DotProductAttention(nn.Module): qkv_features: Optional[int] = None out_features: Optional[int] = None train: bool = False @nn.compact def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32): qkv_features = self.qkv_features or inputs_q.shape[-1] out_features = self.out_features or inputs_q.shape[-1] QKVDense = functools.partial( nn.Dense, features=qkv_features, use_bias=False, dtype=dtype) query = QKVDense(name='query')(inputs_q) key = QKVDense(name='key')(inputs_kv) value = QKVDense(name='value')(inputs_kv) y = RawDotProductAttention(train=self.train)( query, key, value, bias=bias, dtype=dtype) y = nn.Dense(features=out_features, dtype=dtype, name='out')(y) return y class MultiHeadDotProductAttention(nn.Module): qkv_features: Optional[int] = None out_features: Optional[int] = None batch_axes: Sequence[int] = (0,) num_heads: int = 1 broadcast_dropout: bool = False train: bool = False @nn.compact def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32): qkv_features = self.qkv_features or inputs_q.shape[-1] out_features = self.out_features or inputs_q.shape[-1] # Make multiheaded attention from single-headed dimension. Attn = nn.vmap(DotProductAttention, in_axes=(None, None, None), out_axes=2, axis_size=self.num_heads, variable_axes={'params': 0}, split_rngs={'params': True, 'dropout': not self.broadcast_dropout}) # Vmap across batch dimensions. for axis in reversed(sorted(self.batch_axes)): Attn = nn.vmap(Attn, in_axes=(axis, axis, axis), out_axes=axis, variable_axes={'params': None}, split_rngs={'params': False, 'dropout': False}) # Run the vmap'd class on inputs. y = Attn(qkv_features=qkv_features // self.num_heads, out_features=out_features, train=self.train, name='attention')(inputs_q, inputs_kv, bias) return y.mean(axis=-2) key1, key2, key3, key4 = random.split(random.key(0), 4) x = random.uniform(key1, (3, 13, 64)) model = functools.partial( MultiHeadDotProductAttention, broadcast_dropout=False, num_heads=2, batch_axes=(0,)) init_variables = model(train=False).init({'params': key2}, x, x) print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables))) y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4}) print('output:\n', y.shape) ``` ## Scan +++ Scan allows us to apply `lax.scan` to Modules, including their parameters and mutable variables. To use it we have to specify how we want each "kind" of variable to be transformed. For scanned variables we specify similar to vmap via in `variable_in_axes`, `variable_out_axes`: - `nn.broadcast` broadcast the variable kind across the scan steps as a constant - `` scan along `axis` for e.g. unique parameters at each step OR we specify that the variable kind is to be treated like a "carry" by passing to the `variable_carry` argument. Further, for `scan`'d variable kinds, we further specify whether or not to split the rng at each step. ```{code-cell} :outputId: 7d9ebed3-64de-4ca8-9dce-4b09ba9e31a1 :tags: [] class SimpleScan(nn.Module): features: int @nn.compact def __call__(self, xs): LSTM = nn.scan(nn.LSTMCell, in_axes=1, out_axes=1, variable_broadcast='params', split_rngs={'params': False}) lstm = LSTM(self.features, name="lstm_cell") dummy_rng = random.key(0) input_shape = xs[:, 0].shape init_carry = lstm.initialize_carry(dummy_rng, input_shape) return lstm(init_carry, xs) key1, key2 = random.split(random.key(0), 2) xs = random.uniform(key1, (1, 5, 2)) model = SimpleScan(2) init_variables = model.init(key2, xs) print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables))) y = model.apply(init_variables, xs) print('output:\n', y) ``` ================================================ FILE: docs/quick_start.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "6eea21b3", "metadata": {}, "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/quick_start.ipynb)\n", "[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/quick_start.ipynb)\n", "\n", "# Quick start\n", "\n", "Welcome to Flax!\n", "\n", "Flax is an open source Python neural network library built on top of [JAX](https://github.com/jax-ml/jax). This tutorial demonstrates how to construct a simple convolutional neural\n", "network (CNN) using the [Flax](https://flax.readthedocs.io) Linen API and train\n", "the network for image classification on the MNIST dataset." ] }, { "cell_type": "markdown", "id": "nwJWKIhdwxDo", "metadata": {}, "source": [ "## 1. Install Flax" ] }, { "cell_type": "code", "execution_count": null, "id": "bb81587e", "metadata": { "tags": [ "skip-execution" ] }, "outputs": [], "source": [ "!pip install -q flax>=0.7.5" ] }, { "cell_type": "markdown", "id": "b529fbef", "metadata": {}, "source": [ "## 2. Loading data\n", "\n", "Flax can use any\n", "data-loading pipeline and this example demonstrates how to utilize TFDS. Define a function that loads and prepares the MNIST dataset and converts the\n", "samples to floating-point numbers." ] }, { "cell_type": "code", "execution_count": 48, "id": "bRlrHqZVXZvk", "metadata": {}, "outputs": [], "source": [ "import tensorflow_datasets as tfds # TFDS for MNIST\n", "import tensorflow as tf # TensorFlow operations\n", "\n", "def get_datasets(num_epochs, batch_size):\n", " \"\"\"Load MNIST train and test datasets into memory.\"\"\"\n", " train_ds = tfds.load('mnist', split='train')\n", " test_ds = tfds.load('mnist', split='test')\n", "\n", " train_ds = train_ds.map(lambda sample: {'image': tf.cast(sample['image'],\n", " tf.float32) / 255.,\n", " 'label': sample['label']}) # normalize train set\n", " test_ds = test_ds.map(lambda sample: {'image': tf.cast(sample['image'],\n", " tf.float32) / 255.,\n", " 'label': sample['label']}) # normalize test set\n", "\n", " train_ds = train_ds.repeat(num_epochs).shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from\n", " train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency\n", " test_ds = test_ds.shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from\n", " test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency\n", "\n", " return train_ds, test_ds" ] }, { "cell_type": "markdown", "id": "7057395a", "metadata": {}, "source": [ "## 3. Define network\n", "\n", "Create a convolutional neural network with the Linen API by subclassing\n", "[Flax Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html).\n", "Because the architecture in this example is relatively simple—you're just\n", "stacking layers—you can define the inlined submodules directly within the\n", "`__call__` method and wrap it with the\n", "[`@compact`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/decorators.html#flax.linen.compact)\n", "decorator. To learn more about the Flax Linen `@compact` decorator, refer to the [`setup` vs `compact`](https://flax.readthedocs.io/en/latest/guides/setup_or_nncompact.html) guide." ] }, { "cell_type": "code", "execution_count": 49, "id": "cbc079cd", "metadata": {}, "outputs": [], "source": [ "from flax import linen as nn # Linen API\n", "\n", "class CNN(nn.Module):\n", " \"\"\"A simple CNN model.\"\"\"\n", "\n", " @nn.compact\n", " def __call__(self, x):\n", " x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n", " x = nn.relu(x)\n", " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", " x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n", " x = nn.relu(x)\n", " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", " x = x.reshape((x.shape[0], -1)) # flatten\n", " x = nn.Dense(features=256)(x)\n", " x = nn.relu(x)\n", " x = nn.Dense(features=10)(x)\n", " return x" ] }, { "cell_type": "markdown", "id": "hy7iRu7_zlx-", "metadata": {}, "source": [ "### View model layers\n", "\n", "Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input." ] }, { "cell_type": "code", "execution_count": 50, "id": "lDHfog81zLQa", "metadata": { "outputId": "2c580f41-bf5d-40ec-f1cf-ab7f319a84da" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\u001b[3m CNN Summary \u001b[0m\n", "┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mpath \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmodule\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1minputs \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1moutputs \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mflops \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mvjp_flops\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mparams \u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━┩\n", "│ │ CNN │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 8708106 │ 26957556 │ │\n", "├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", "│ Conv_0 │ Conv │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 455424 │ 1341472 │ bias: │\n", "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[3… │\n", "│ │ │ │ │ │ │ kernel: │\n", "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[3… │\n", "│ │ │ │ │ │ │ │\n", "│ │ │ │ │ │ │ \u001b[1m320 \u001b[0m\u001b[1;2m(1.3 \u001b[0m │\n", "│ │ │ │ │ │ │ \u001b[1;2mKB)\u001b[0m │\n", "├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", "│ Conv_1 │ Conv │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 6566144 │ 19704320 │ bias: │\n", "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[6… │\n", "│ │ │ │ │ │ │ kernel: │\n", "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[3… │\n", "│ │ │ │ │ │ │ │\n", "│ │ │ │ │ │ │ \u001b[1m18,496 \u001b[0m │\n", "│ │ │ │ │ │ │ \u001b[1;2m(74.0 KB)\u001b[0m │\n", "├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", "│ Dense_0 │ Dense │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 1605888 │ 5620224 │ bias: │\n", "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[2… │\n", "│ │ │ │ │ │ │ kernel: │\n", "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[3… │\n", "│ │ │ │ │ │ │ │\n", "│ │ │ │ │ │ │ \u001b[1m803,072 \u001b[0m │\n", "│ │ │ │ │ │ │ \u001b[1;2m(3.2 MB)\u001b[0m │\n", "├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", "│ Dense_1 │ Dense │ \u001b[2mfloat32\u001b[0m[1… │ \u001b[2mfloat32\u001b[0m[… │ 5130 │ 17940 │ bias: │\n", "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[1… │\n", "│ │ │ │ │ │ │ kernel: │\n", "│ │ │ │ │ │ │ \u001b[2mfloat32\u001b[0m[2… │\n", "│ │ │ │ │ │ │ │\n", "│ │ │ │ │ │ │ \u001b[1m2,570 \u001b[0m │\n", "│ │ │ │ │ │ │ \u001b[1;2m(10.3 KB)\u001b[0m │\n", "├─────────┼────────┼────────────┼───────────┼─────────┼───────────┼────────────┤\n", "│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m Total\u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m824,458 \u001b[0m\u001b[1m \u001b[0m│\n", "│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1;2m(3.3 MB)\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\n", "└─────────┴────────┴────────────┴───────────┴─────────┴───────────┴────────────┘\n", "\u001b[1m \u001b[0m\n", "\u001b[1m Total Parameters: 824,458 \u001b[0m\u001b[1;2m(3.3 MB)\u001b[0m\u001b[1m \u001b[0m\n", "\n", "\n" ] } ], "source": [ "import jax\n", "import jax.numpy as jnp # JAX NumPy\n", "\n", "cnn = CNN()\n", "print(cnn.tabulate(jax.random.key(0), jnp.ones((1, 28, 28, 1)),\n", " compute_flops=True, compute_vjp_flops=True))" ] }, { "cell_type": "markdown", "id": "4b5ac16e", "metadata": {}, "source": [ "## 4. Create a `TrainState`\n", "\n", "A common pattern in Flax is to create a single dataclass that represents the\n", "entire training state, including step number, parameters, and optimizer state.\n", "\n", "Because this is such a common pattern, Flax provides the class\n", "[`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/flax.training.html#train-state)\n", "that serves most basic usecases." ] }, { "cell_type": "code", "execution_count": null, "id": "qXr7JDpIxGNZ", "metadata": { "outputId": "1249b7fb-6787-41eb-b34c-61d736300844" }, "outputs": [], "source": [ "!pip install -q clu" ] }, { "cell_type": "code", "execution_count": 52, "id": "CJDaJNijyOji", "metadata": {}, "outputs": [], "source": [ "from clu import metrics\n", "from flax.training import train_state # Useful dataclass to keep train state\n", "from flax import struct # Flax dataclasses\n", "import optax # Common loss functions and optimizers" ] }, { "cell_type": "markdown", "id": "8b86b5f1", "metadata": {}, "source": [ "We will be using the `clu` library for computing metrics. For more information on `clu`, refer to the [repo](https://github.com/google/CommonLoopUtils) and [notebook](https://colab.research.google.com/github/google/CommonLoopUtils/blob/master/clu_synopsis.ipynb#scrollTo=ueom-uBWLbeQ)." ] }, { "cell_type": "code", "execution_count": 53, "id": "7W0qf7FC9uG5", "metadata": {}, "outputs": [], "source": [ "@struct.dataclass\n", "class Metrics(metrics.Collection):\n", " accuracy: metrics.Accuracy\n", " loss: metrics.Average.from_output('loss')" ] }, { "cell_type": "markdown", "id": "f3ce5e4c", "metadata": {}, "source": [ "You can then subclass `train_state.TrainState` so that it also contains metrics. This has the advantage that we only need\n", "to pass around a single argument to functions like `train_step()` (see below) to calculate the loss, update the parameters and compute the metrics all at once." ] }, { "cell_type": "code", "execution_count": 54, "id": "e0102447", "metadata": {}, "outputs": [], "source": [ "class TrainState(train_state.TrainState):\n", " metrics: Metrics\n", "\n", "def create_train_state(module, rng, learning_rate, momentum):\n", " \"\"\"Creates an initial `TrainState`.\"\"\"\n", " params = module.init(rng, jnp.ones([1, 28, 28, 1]))['params'] # initialize parameters by passing a template image\n", " tx = optax.sgd(learning_rate, momentum)\n", " return TrainState.create(\n", " apply_fn=module.apply, params=params, tx=tx,\n", " metrics=Metrics.empty())" ] }, { "cell_type": "markdown", "id": "a15de484", "metadata": {}, "source": [ "## 5. Training step\n", "\n", "A function that:\n", "\n", "- Evaluates the neural network given the parameters and a batch of input images\n", " with [`TrainState.apply_fn`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) (which contains the [`Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply)\n", " method (forward pass)).\n", "- Computes the cross entropy loss, using the predefined [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api.html#optax.softmax_cross_entropy_with_integer_labels). Note that this function expects integer labels, so there is no need to convert labels to onehot encoding.\n", "- Evaluates the gradient of the loss function using\n", " [`jax.grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad).\n", "- Applies a\n", " [pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions)\n", " of gradients to the optimizer to update the model's parameters.\n", "\n", "Use JAX's [@jit](https://jax.readthedocs.io/en/latest/jax.html#jax.jit)\n", "decorator to trace the entire `train_step` function and just-in-time compile\n", "it with [XLA](https://www.tensorflow.org/xla) into fused device operations\n", "that run faster and more efficiently on hardware accelerators." ] }, { "cell_type": "code", "execution_count": 55, "id": "9b0af486", "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def train_step(state, batch):\n", " \"\"\"Train for a single step.\"\"\"\n", " def loss_fn(params):\n", " logits = state.apply_fn({'params': params}, batch['image'])\n", " loss = optax.softmax_cross_entropy_with_integer_labels(\n", " logits=logits, labels=batch['label']).mean()\n", " return loss\n", " grad_fn = jax.grad(loss_fn)\n", " grads = grad_fn(state.params)\n", " state = state.apply_gradients(grads=grads)\n", " return state" ] }, { "cell_type": "markdown", "id": "0ff5145f", "metadata": {}, "source": [ "## 6. Metric computation\n", "\n", "Create a separate function for loss and accuracy metrics. Loss is calculated using the `optax.softmax_cross_entropy_with_integer_labels` function, while accuracy is calculated using `clu.metrics`." ] }, { "cell_type": "code", "execution_count": 56, "id": "961bf70b", "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def compute_metrics(*, state, batch):\n", " logits = state.apply_fn({'params': state.params}, batch['image'])\n", " loss = optax.softmax_cross_entropy_with_integer_labels(\n", " logits=logits, labels=batch['label']).mean()\n", " metric_updates = state.metrics.single_from_model_output(\n", " logits=logits, labels=batch['label'], loss=loss)\n", " metrics = state.metrics.merge(metric_updates)\n", " state = state.replace(metrics=metrics)\n", " return state" ] }, { "cell_type": "markdown", "id": "497241c3", "metadata": {}, "source": [ "## 7. Download data" ] }, { "cell_type": "code", "execution_count": 57, "id": "bff5393e", "metadata": {}, "outputs": [], "source": [ "num_epochs = 10\n", "batch_size = 32\n", "\n", "train_ds, test_ds = get_datasets(num_epochs, batch_size)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "809ae1a0", "metadata": {}, "source": [ "## 8. Seed randomness\n", "\n", "- Set the TF random seed to ensure dataset shuffling (with `tf.data.Dataset.shuffle`) is reproducible.\n", "- Get one\n", " [PRNGKey](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.PRNGKey.html#jax.random.PRNGKey)\n", " and use it for parameter initialization. (Learn\n", " more about\n", " [JAX PRNG design](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html)\n", " and [PRNG chains](https://flax.readthedocs.io/en/latest/philosophy.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables).)" ] }, { "cell_type": "code", "execution_count": 58, "id": "xC4MFyBsfT-U", "metadata": {}, "outputs": [], "source": [ "tf.random.set_seed(0)" ] }, { "cell_type": "code", "execution_count": 59, "id": "e4f6f4d3", "metadata": {}, "outputs": [], "source": [ "init_rng = jax.random.key(0)" ] }, { "cell_type": "markdown", "id": "80fbb60b", "metadata": {}, "source": [ "## 9. Initialize the `TrainState`\n", "\n", "Remember that the function `create_train_state` initializes the model parameters, optimizer and metrics\n", "and puts them into the training state dataclass that is returned." ] }, { "cell_type": "code", "execution_count": 60, "id": "445fcab0", "metadata": {}, "outputs": [], "source": [ "learning_rate = 0.01\n", "momentum = 0.9" ] }, { "cell_type": "code", "execution_count": 61, "id": "5221eafd", "metadata": {}, "outputs": [], "source": [ "state = create_train_state(cnn, init_rng, learning_rate, momentum)\n", "del init_rng # Must not be used anymore." ] }, { "cell_type": "markdown", "id": "b1c00230", "metadata": {}, "source": [ "## 10. Train and evaluate\n", "\n", "Create a \"shuffled\" dataset by:\n", "- Repeating the dataset equal to the number of training epochs\n", "- Allocating a buffer of size 1024 (containing the first 1024 samples in the dataset) of which to randomly sample batches from\n", " - Everytime a sample is randomly drawn from the buffer, the next sample in the dataset is loaded into the buffer\n", "\n", "Define a training loop that:\n", "- Randomly samples batches from the dataset.\n", "- Runs an optimization step for each training batch.\n", "- Computes the mean training metrics across each batch in an epoch.\n", "- Computes the metrics for the test set using the updated parameters.\n", "- Records the train and test metrics for visualization.\n", "\n", "Once the training and testing is done after 10 epochs, the output should show that your model was able to achieve approximately 99% accuracy." ] }, { "cell_type": "code", "execution_count": 62, "id": "74295360", "metadata": {}, "outputs": [], "source": [ "# since train_ds is replicated num_epochs times in get_datasets(), we divide by num_epochs\n", "num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs" ] }, { "cell_type": "code", "execution_count": 63, "id": "cRtnMZuQFlKl", "metadata": {}, "outputs": [], "source": [ "metrics_history = {'train_loss': [],\n", " 'train_accuracy': [],\n", " 'test_loss': [],\n", " 'test_accuracy': []}" ] }, { "cell_type": "code", "execution_count": 64, "id": "2c40ce90", "metadata": { "outputId": "258a2c76-2c8f-4a9e-d48b-dde57c342a87" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train epoch: 1, loss: 0.20290373265743256, accuracy: 93.87000274658203\n", "test epoch: 1, loss: 0.07591685652732849, accuracy: 97.60617065429688\n", "train epoch: 2, loss: 0.05760224163532257, accuracy: 98.28500366210938\n", "test epoch: 2, loss: 0.050395529717206955, accuracy: 98.3974380493164\n", "train epoch: 3, loss: 0.03897436335682869, accuracy: 98.83000183105469\n", "test epoch: 3, loss: 0.04574578255414963, accuracy: 98.54767608642578\n", "train epoch: 4, loss: 0.028721099719405174, accuracy: 99.15166473388672\n", "test epoch: 4, loss: 0.035722777247428894, accuracy: 98.91827392578125\n", "train epoch: 5, loss: 0.021948494017124176, accuracy: 99.37999725341797\n", "test epoch: 5, loss: 0.035723842680454254, accuracy: 98.87820434570312\n", "train epoch: 6, loss: 0.01705147698521614, accuracy: 99.54833221435547\n", "test epoch: 6, loss: 0.03456473350524902, accuracy: 98.96835327148438\n", "train epoch: 7, loss: 0.014007646590471268, accuracy: 99.6116714477539\n", "test epoch: 7, loss: 0.04089202359318733, accuracy: 98.7880630493164\n", "train epoch: 8, loss: 0.011265480890870094, accuracy: 99.73333740234375\n", "test epoch: 8, loss: 0.03337760642170906, accuracy: 98.93830108642578\n", "train epoch: 9, loss: 0.00918484665453434, accuracy: 99.78334045410156\n", "test epoch: 9, loss: 0.034478139132261276, accuracy: 98.96835327148438\n", "train epoch: 10, loss: 0.007260234095156193, accuracy: 99.84166717529297\n", "test epoch: 10, loss: 0.032822880893945694, accuracy: 99.07852172851562\n" ] } ], "source": [ "for step,batch in enumerate(train_ds.as_numpy_iterator()):\n", "\n", " # Run optimization steps over training batches and compute batch metrics\n", " state = train_step(state, batch) # get updated train state (which contains the updated parameters)\n", " state = compute_metrics(state=state, batch=batch) # aggregate batch metrics\n", "\n", " if (step+1) % num_steps_per_epoch == 0: # one training epoch has passed\n", " for metric,value in state.metrics.compute().items(): # compute metrics\n", " metrics_history[f'train_{metric}'].append(value) # record metrics\n", " state = state.replace(metrics=state.metrics.empty()) # reset train_metrics for next training epoch\n", "\n", " # Compute metrics on the test set after each training epoch\n", " test_state = state\n", " for test_batch in test_ds.as_numpy_iterator():\n", " test_state = compute_metrics(state=test_state, batch=test_batch)\n", "\n", " for metric,value in test_state.metrics.compute().items():\n", " metrics_history[f'test_{metric}'].append(value)\n", "\n", " print(f\"train epoch: {(step+1) // num_steps_per_epoch}, \"\n", " f\"loss: {metrics_history['train_loss'][-1]}, \"\n", " f\"accuracy: {metrics_history['train_accuracy'][-1] * 100}\")\n", " print(f\"test epoch: {(step+1) // num_steps_per_epoch}, \"\n", " f\"loss: {metrics_history['test_loss'][-1]}, \"\n", " f\"accuracy: {metrics_history['test_accuracy'][-1] * 100}\")" ] }, { "cell_type": "markdown", "id": "gfsecJzvzgCT", "metadata": {}, "source": [ "## 11. Visualize metrics" ] }, { "cell_type": "code", "execution_count": 65, "id": "Zs5atiqIG9Kz", "metadata": { "outputId": "431a2fcd-44fa-4202-f55a-906555f060ac" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA3cAAAE/CAYAAADlpzo+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/av/WaAAAACXBIWXMAAAsTAAALEwEAmpwYAABsiElEQVR4nO3dd3yddd3/8dcneyfNaJs26aJ7JAFKyxLUKlBWAQEBmQK9uRUEb/EWt7hufooDFcEyBVFUFK1QLFBEZLeFpLt00qRJ23Rk7+T7++O6kp6maXPSjJPxfj4e53HONc/3nJ7mOu/zXeacQ0RERERERAa2sFAXQERERERERLpP4U5ERERERGQQULgTEREREREZBBTuREREREREBgGFOxERERERkUFA4U5ERERERGQQULgTEREREREZBBTuRPqYmW03s0+EuhwiIiK9ycxeNbMDZhYd6rKIDBUKdyIiIiLSo8xsHPARwAEX9uHzRvTVc4n0Rwp3Iv2AmUWb2c/NrNi//bz1l04zSzez58yszMz2m9l/zCzM3/YVM9tpZpVmttHM5oX2lYiIiABwLfA28DhwXetKM8s2s7+aWamZ7TOzXwVsu9nM1vvXtHVmdoK/3pnZxID9Hjez7/uPP2pmRf71cBfwmJkN86+bpX7N4XNmlhVwfKqZPeZfbw+Y2d/89WvM7IKA/SLNbK+Z5fXSeyTS4xTuRPqHrwMnA3lALjAH+Ia/7UtAEZABjAC+BjgzmwLcCpzknEsEzga292mpRUREOnYt8JR/O9vMRphZOPAc8CEwDhgNPA1gZpcB3/GPS8Kr7dsX5HONBFKBscBCvO+3j/nLY4Ba4FcB+z8JxAEzgOHAz/z1TwBXB+x3LlDinMsPshwiIaeqa5H+4TPAbc65PQBmdjfwG+CbQCOQCYx1zm0G/uPv0wxEA9PNrNQ5tz0UBRcREQlkZqfjBas/Oef2mtkW4Cq8mrxRwJedc03+7q/79zcBP3LOLfeXN3fhKVuAbzvn6v3lWuAvAeX5AfAv/3EmMB9Ic84d8Hf5t3//O+CbZpbknKsArsELgiIDhmruRPqHUXi/ZLb60F8H8GO8i9yLZrbVzO4C8IPeHXi/dO4xs6fNbBQiIiKhdR3wonNur7/8e39dNvBhQLALlA1sOcbnK3XO1bUumFmcmf3GzD40swrgNSDFrznMBvYHBLs2zrli4A3gU2aWghcCnzrGMomEhMKdSP9QjPcrZ6sx/jqcc5XOuS855yYAFwD/09q3zjn3e+dc6y+kDvh/fVtsERGRg8wsFrgcONPMdvn94L6I1+VgNzDmCIOeFALHHeG0NXjNKFuNbLfdtVv+EjAFmOucSwLOaC2e/zypfnjryG/xmmZeBrzlnNt5hP1E+iWFO5HQiDSzmNYb8AfgG2aWYWbpwLfwmodgZueb2UQzM6ACaAaazWyKmX3cH3ilDq8ZSnNoXo6IiAgAF+Fdi6bj9SPPA6bhdSm4CCgB7jGzeP8aeJp/3MPAnWZ2onkmmlnrj575wFVmFm5m5wBndlKGRLxrYpmZpQLfbt3gnCsBXgB+7Q+8EmlmZwQc+zfgBOB2vD54IgOKwp1IaCzBu/C03mKAFcAqYDXwHvB9f99JwMtAFfAW8Gvn3Kt4/e3uAfYCu/A6hX+tz16BiIjI4a4DHnPO7XDO7Wq94Q1ociVeC5SJwA68wcI+DeCc+zPwA7wmnJV4ISvVP+ft/nFleH3U/9ZJGX4OxOJdH98G/tlu+zV4/dk3AHvwujjgl6O1v9544K/Bv2yR/sGca1+TLSIiIiIyNJnZt4DJzrmrO91ZpJ/RaJkiIiIiInhz4AE34tXuiQw4apYpIiIiIkOemd2MN+DKC86510JdHpFjoWaZIiIiIiIig4Bq7kRERERERAYBhTsREREREZFBYEANqJKenu7GjRsX6mKIiEgvW7ly5V7nXEaoyzFQ6PooIjJ0HO0aOaDC3bhx41ixYkWoiyEiIr3MzD4MdRkGEl0fRUSGjqNdI9UsU0REpIeZ2aNmtsfM1hxhu5nZL8xss5mtMrMTAradY2Yb/W139V2pRURkoFO4ExER6XmPA+ccZft8YJJ/Wwg8AGBm4cD9/vbpwJVmNr1XSyoiIoOGwp2IiEgP8+fI2n+UXRYATzjP20CKmWUCc4DNzrmtzrkG4Gl/XxERkU4NqD53IiL9RWNjI0VFRdTV1YW6KANaTEwMWVlZREZGhroofW003mTJrYr8dR2tn3ssT6DP6MAzhP8/iEgPUbgTETkGRUVFJCYmMm7cOMws1MUZkJxz7Nu3j6KiIsaPHx/q4vS1jj407ijrDz+B2UK8Jp2MGTPmsO36jA4sQ/z/g4j0EDXLFBE5BnV1daSlpelLczeYGWlpaUO1ZqkIyA5YzgKKj7L+MM65Rc652c652RkZh4+Irc/owDLE/z+ISA9RuBMROUb60tx9Q/g9XAxc64+aeTJQ7pwrAZYDk8xsvJlFAVf4+x6TIfz+Dkj69xKR7lKzTBERkR5mZn8APgqkm1kR8G0gEsA59yCwBDgX2AzUADf425rM7FZgKRAOPOqcW9vnL0BERAYk1dyJiAxAZWVl/PrXv+7yceeeey5lZWVdPu7666/nmWee6fJxQ5Vz7krnXKZzLtI5l+Wce8Q596Af7PBHyfy8c+4459ws59yKgGOXOOcm+9t+ELpX0X19/TkVERnqhlS4W1tczlPvHHFCdxGRAeNIX5qbm5uPetySJUtISUnppVKJHGqwfk47K7+IDG1NzS0cqG7gw33VrC4q5/VNe1myuoSn393Bpt2VvfrcQTXLNLNzgPvwmog87Jy7p932zwBf8RergP92zhUc7VgzSwX+CIwDtgOXO+cOdPP1HNVL63Zz37JNLMgbTUK0WqSKyMB11113sWXLFvLy8oiMjCQhIYHMzEzy8/NZt24dF110EYWFhdTV1XH77bezcOFCAMaNG8eKFSuoqqpi/vz5nH766bz55puMHj2av//978TGxnb63MuWLePOO++kqamJk046iQceeIDo6GjuuusuFi9eTEREBGeddRb33nsvf/7zn7n77rsJDw8nOTmZ1157rbffGulH+vpz+tBDD7Fo0SIaGhqYOHEiTz75JHFxcezevZtbbrmFrVu3AvDAAw9w6qmn8sQTT3DvvfdiZuTk5PDkk09y/fXXc/7553PppZcCkJCQQFVVFa+++ip33313UOX/5z//yde+9jWam5tJT0/npZdeYsqUKbz55ptkZGTQ0tLC5MmTefvtt0lPT++DfwkR6YrmFkdlXSMVtU1U1DVSUdvo3zdR3va4kYq6pkO2ta6vbjjyD0DfXTCDSSMSe63snSYcMwsH7gc+iTeK13IzW+ycWxew2zbgTOfcATObDywC5nZy7F3AMufcPWZ2l7/8FXpRbnYKzsHqonJOOS6tN59KRIaQu/+xlnXFFT16zumjkvj2BTOOuP2ee+5hzZo15Ofn8+qrr3LeeeexZs2atiHUH330UVJTU6mtreWkk07iU5/6FGlph/7d27RpE3/4wx946KGHuPzyy/nLX/7C1VdffdRy1dXVcf3117Ns2TImT57MtddeywMPPMC1117Ls88+y4YNGzCztiZ13/3ud1m6dCmjR49WM7sQCsVnFPr+c3rJJZdw8803A/CNb3yDRx55hNtuu40vfOELnHnmmTz77LM0NzdTVVXF2rVr+cEPfsAbb7xBeno6+/cfbc55z7vvvttp+VtaWrj55pt57bXXGD9+PPv37ycsLIyrr76ap556ijvuuIOXX36Z3NxcBTuRXtLc4qiq88JWeQfh69BQdvi2qvqmo54/zCAxJpKk2AiSYiJJiolkXHocybHe46TYSJJiIvx7f9nfNzU+qldfezDVV3OAzc65rQBm9jSwAGgLd865NwP2fxtv6ObOjl2A19kc4LfAq/R2uMtKAWBVUZnCnYgMKnPmzDlkbqxf/OIXPPvsswAUFhayadOmw740jx8/nry8PABOPPFEtm/f3unzbNy4kfHjxzN58mQArrvuOu6//35uvfVWYmJiuOmmmzjvvPM4//zzATjttNO4/vrrufzyy7nkkkt64JXKQNbbn9M1a9bwjW98g7KyMqqqqjj77LMBeOWVV3jiiScA2mqRn3jiCS699NK2gJWamtoj5S8tLeWMM85o26/1vJ/97GdZsGABd9xxB48++ig33HBDp88nMpQ456hvaqGqvonq+iYq/ZBV1Xpff+hyZZ23X1V9E5X1TVTVNVJd39y239GYQWJ0YPiKYExq3CHLHQY1/3F8VARhYf1zdNtgwt1ooDBguQiYe5T9bwReCOLYEf6wzzjnSsxseFAl7obU+CiyU2MpKCrr7acSkSGks9qLvhAfH9/2+NVXX+Xll1/mrbfeIi4ujo9+9KMdzp0VHR3d9jg8PJza2tpOn8e5DufTJiIignfffZdly5bx9NNP86tf/YpXXnmFBx98kHfeeYfnn3+evLw88vPzD/vyLr2vP3xGofc/p9dffz1/+9vfyM3N5fHHH+fVV1894r7OuQ6nHoiIiKClpaVtn4aGhi6V/0jnzc7OZsSIEbzyyiu88847PPXUU0csm8hA0tzi2gLV4aGskar65oDHTf5yY1tAaz2uqr6JxuaOrzGBwgwSoiNIjIkkITqChBgviGWlxLYtx0e3hrODAS65tfYsNpKEfhzOuiuYcNfRK+/wnTezj+GFu9O7euwRn9xsIbAQYMyYMV05tEO5WSm8v6Os2+cREQmlxMREKis77pRdXl7OsGHDiIuLY8OGDbz99ts99rxTp05l+/btbN68ua1P05lnnklVVRU1NTWce+65nHzyyUycOBGALVu2MHfuXObOncs//vEPCgsLFe6GkL7+nFZWVpKZmUljYyNPPfUUo0ePBmDevHk88MAD3HHHHTQ3N1NdXc28efO4+OKL+eIXv0haWhr79+8nNTWVcePGsXLlSi6//HL+/ve/09jY2KXyn3LKKXz+859n27Ztbc0yW2vvbrrpJq6++mquueYawsPDu/16RXqbc47SynoKD9SwY38NO/bVtj3eeaCWAzUN1Bylf1mg2MhwEmIiSIz2wldCdATZqXEkBgQyL7R59223mIP3idGRxESGaU7Iowgm3BUB2QHLWUBx+53MLAd4GJjvnNsXxLG7zSzTr7XLBPZ09OTOuUV4ffiYPXt2l4JhR3KzUnhuVQmllfVkJEZ3foCISD+UlpbGaaedxsyZM4mNjWXEiBFt28455xwefPBBcnJymDJlCieffHKPPW9MTAyPPfYYl112WduAKrfccgv79+9nwYIFbTUXP/vZzwD48pe/zKZNm3DOMW/ePHJzc3usLNL/9fXn9Hvf+x5z585l7NixzJo1qy1Y3nfffSxcuJBHHnmE8PBwHnjgAU455RS+/vWvc+aZZxIeHs7xxx/P448/zs0338yCBQuYM2cO8+bNO6S2LtCRyp+RkcGiRYu45JJLaGlpYfjw4bz00ksAXHjhhdxwww1qkin9SlV9E4X7vcBW2Ho7UMuO/TUUHaihrrHlkP1HJEWTPSyOOeNTSYuPagtfiYcFtEhvW1QE8dHhRIQPqUH6Q8aO1MSmbQezCOADYB6wE1gOXBU4qaqZjQFeAa4N7H93tGPN7MfAvoABVVKdc/97tLLMnj3brVix4mi7dOrdbfu5/Ddv8ch1s5k3bUTnB4iIdGD9+vVMmzYt1MUYFDp6L81spXNudoiKNOB0dH3UZ7T/WbFiBV/84hf5z3/+c8R99O8mPa2xuYWSsrq2GrdDgtyBWvZXNxyyf2uN2pjUWLKHxTEmLY7sYXFkp8aRNSyWmEjVOofa0a6RndbcOeeazOxWYCnedAaP+uHsFn/7g8C3gDTg1341aZNzbvaRjvVPfQ/wJzO7EdgBXNatVxmkmaOTCDMoKCpXuBMREZE+cc899/DAAw+or530OOcc+6ob2kJb0YFaduyraQtzJeV1NLccrMyJCDNGD4tlTGocZ49KZkxqHNmp3nL2sDhS4iLV7HEAC2qyN+fcEmBJu3UPBjy+Cbgp2GP99fvwavT6VFxUBJNHJFJQWNbXTy0i0u99/vOf54033jhk3e23365mZNKvDMTP6V133cVdd90V6mLIAFXb0EzhgcBat4PNJnfsrzms31t6QjRjUmM5ceywttCW7Ye4kUkxaiI5iA3Jmbxzs1JYum7XEUe0EhEZqu6///5QF0GkU/qcykDV3OKoqG2krNabf628tpGymgZvXY2/3LrNX95X3cDeqvpDzhMXFd4W2E45Lo0xqXF+DZzXdDIuakh+xReGaLjLyU7mjysKKdxfy5i0uFAXR0REREQGCOcc1Q3NlNU0HBLCytrCmj9xdm0jZbUNh6yrrDv6/GtxUeEkx0a23calx3H8mJS20NYa4NLio1RBIR0akuGudTLz/KIyhTsRERGRIayusZnNe6rYVV4XUGvWcFhgqwiobWtqOfKAhJHhdkhAG54Yw6ThiW3LKXGH3ifHRrVti4pQc0npniEZ7qaMTCQ6IoxVhWVcmDsq1MURERERkV7mnGN3RT3rSypYv6uCDSWVrC+pYOve6kMGHAEwg8ToCFLiotqC2OhhsaTEHjmYta6LiwpXrZqEzJAMd5HhYcwYlURBUVmoiyIiIjJolZWV8fvf/57Pfe5zXT725z//OQsXLiQuTi1spOvqGpvZtLuqLcitL6lgw65KymoOTko/OiWWaZmJnD1jJNMyk8gaFtsW0BJjIgkPU0CTgWdIhjuA3OwU/vDuDpqaWzRikIgMOL39pXncuHGsWLGC9PT07hRThriysjJ+/etfH/Pn9Oqrr+4X4a6pqYmIiCH7lalfc85RUl7Hhl0VrPdr4jbsqmRraRWtlXGxkeFMGZnI/JkjmToyiWmZSUwZ6TWTFBlshuxfqtysFB57Yzub9lQxLTMp1MUREemSwfKlWQa3u+66iy1btpCXl8cnP/lJhg8fzp/+9Cfq6+u5+OKLufvuu6murubyyy+nqKiI5uZmvvnNb7J7926Ki4v52Mc+Rnp6Ov/61786PP9///d/s3z5cmpra7n00ku5++67AVi+fDm333471dXVREdHs2zZMuLi4vjKV77C0qVLMTNuvvlmbrvttkN+yFixYgV33nknr776Kt/5zncoLi5m+/btpKen88Mf/pBrrrmG6upqAH71q19x6qmnAvCjH/2IJ598krCwMObPn8/NN9/MZZddxnvvvQfApk2buOKKK1i5cmUfvOuDV21DMx/srjwsyJXXHqyNyxoWy7TMJM6dOZKpmV6QG5Map1o4GTKGbrjLTgGgoLBM4U5EuueFu2DX6p4958hZMP+eI27u7S/NgX7605/y6KOPAnDTTTdxxx13dHjuT3/609x1110sXryYiIgIzjrrLO69994ee0ukG0LwGQVv4u41a9aQn5/Piy++yDPPPMO7776Lc44LL7yQ1157jdLSUkaNGsXzzz8PQHl5OcnJyfz0pz/lX//611Frj3/wgx+QmppKc3Mz8+bNY9WqVUydOpVPf/rT/PGPf+Skk06ioqKC2NhYFi1axLZt23j//feJiIhg//79nb7ElStX8vrrrxMbG0tNTQ0vvfQSMTExbNq0iSuvvJIVK1bwwgsv8Le//Y133nmHuLg49u/fT2pqKsnJyeTn55OXl8djjz3G9ddf36W3dyhzzrGzrJYNJQFBblcF2/dWt9XGxUV5tXHn5WQybWQiU/3auKQY1cbJ0DZkw924tDiSYiIoKCrjijljQl0cEZEu6e0vza1WrlzJY489xjvvvINzjrlz53LmmWeydevWw869f/9+nn32WTZs2ICZUVZW1ptvgQwwL774Ii+++CLHH388AFVVVWzatImPfOQj3HnnnXzlK1/h/PPP5yMf+UjQ5/zTn/7EokWLaGpqoqSkhHXr1mFmZGZmctJJJwGQlOT9gPvyyy9zyy23tDWvTE1N7fT8F154IbGxsQA0NjZy6623kp+fT3h4OB988EHbeW+44Ya2mvDW895000089thj/PSnP+WPf/wj7777btCvayipaWhi465KNuyqZEPJwSAXOGXAmNQ4po5M5IKcUUzLTGTqSK82Lky1cSKHGbLhzszIzU6hoLA81EURkYGuk9qL3tYbX5pbvf7661x88cXEx8cDcMkll/Cf//yHc84557BzNzU1ERMTw0033cR5553H+eef36OvU7ohxJ9R8GpjvvrVr/Jf//Vfh21buXIlS5Ys4atf/SpnnXUW3/rWtzo937Zt27j33ntZvnw5w4YN4/rrr6eurg7nXIcjFR5pfUREBC0tLQDU1dUdsq31cw/ws5/9jBEjRlBQUEBLSwsxMTFHPe+nPvUp7r77bj7+8Y9z4oknkpaW1ulrGuz2Vzew8sMDfnNKL8ht31eN82vj4qPCmZqZxIW5o5iWmcS0zESmjEwiIXrIfl0V6bIh/b8lNyuFB/69hdqGZmKjwkNdHBGRY9LTX5rbn7sjkydP7vDc7777LsuWLePpp5/mV7/6Fa+88soxvSYZHBITE6msrATg7LPP5pvf/Caf+cxnSEhIYOfOnURGRtLU1ERqaipXX301CQkJPP7444cce6Qa5oqKCuLj40lOTmb37t288MILfPSjH2Xq1KkUFxezfPlyTjrpJCorK4mNjeWss87iwQcf5KMf/Whbs8zU1FTGjRvHypUrmT9/Pn/5y1+O+FrKy8vJysoiLCyM3/72tzQ3NwNw1lln8d3vfperrrrqkGaZMTExnH322fz3f/83jzzySM++sQNEaWU972zbxztb9/Putv1s3F3Ztm1cWhxTRyaxIM8PciO90SpVGyfSPUM63OVkJdPc4lhXUs6JYztvniEi0l/05pfmQGeccQbXX389d911F845nn32WZ588kmKi4sPO3dVVRU1NTWce+65nHzyyUycOLE33wIZANLS0jjttNOYOXMm8+fP56qrruKUU04BICEhgd/97nds3ryZL3/5y4SFhREZGckDDzwAwMKFC5k/fz6ZmZkd9g3Nzc3l+OOPZ8aMGUyYMIHTTjsNgKioKP74xz9y2223UVtbS2xsLC+//DI33XQTH3zwATk5OURGRnLzzTdz66238u1vf5sbb7yRH/7wh8ydO/eIr+Vzn/scn/rUp/jzn//Mxz72sbZavXPOOYf8/Hxmz55NVFQU5557Lj/84Q8B+MxnPsNf//pXzjrrrB59X/urXeV1vLNtH29v3c872/axtdQbfCYuKpwTxw7jwrxRzBmfyvTMJOJVGyfSK+xIv8r2R7Nnz3YrVqzosfPtqahjzg+X8c3zp3Pj6eN77LwiMvitX7+eadOmhbQMV111FatWrWL+/PlkZWXx8MMPA0f/0jx79mx++ctfcv/99x/xSzMcOhVCRwOqLF269LBzjx49mgULFrQ1jbvzzju57rrrOn0dHb2XZrbSOTe7m2/RkNHR9bE/fEaHunvvvZfy8nK+973vBX3MQPp3KzpQwzt+kHtn234+3FcDeJN/zx43jLkT0pg7PpWZo5OJ1LRTIj3maNfIIR3uAE75v2WcNC6VX1x5fI+eV0QGt4H0Bay/U7jrPoW7/ufiiy9my5YtvPLKK12aL7K//rs55/hwXw3vbtvP235Ty51ltQAkx0Zy0rhUTp6QytzxaUwflaSpB0R60dGukUO+TjwnK5lVRWWhLoaIiIgcwdy5c6mvrz9k3ZNPPsmsWbNCVKLOPfvss6EuQrc459hSWn1In7ldFd6AM2nxUcwZn8rNHxnP3AlpTBmRqL5yIv3EkA93udkpLF27m7KaBlLiokJdHBGRPjUQvzTL0PPOO++EugiDXkuLY9OeqrYw9862/eyt8v42ZCRGM3d8KnMnpHHy+FQmDk/ocIRQEQk9hbusFABWFZVzxuSM0BZGRKSP6UuzyNDU3OLYsKuirc/cu9v2c6CmEYDM5BhOn5jW1mdufHq8wpzIADHkw92srGQACgrLFO5EpEuONL+VBG8g9fseiPQZHVh68/9DU3MLa4sr2mrmlm/fT4U/UXh2aizzpo1gzvhUTh6fRnZqrD43IgPUkA93STGRHJcRT4H63YlIF8TExLBv3z7S0tL0JegYOefYt29f22TQ0rP0GR1Yevr/Q2NzC6uKytvC3MoPD1BV74W58enxnDsrk7n+ACijUmJ75DlFJPSGfLgDr2nma5v26hdOEQlaVlYWRUVFlJaWhrooA1pMTAxZWVmhLsagpM/owNPd/w81DU28tG43/ygo4Y3Ne6lt9CZanzg8gQV5o9qaWY5I0g8qIoOVwh3eoCp/fX8nJeV1+vVKRIISGRnJ+PGaH1P6L31Gh4b6pmb+vbGUxQXFLFu/h9rGZkYmxXDZ7CxOnpDGnPGppCdEh7qYItJHFO7wpkMAWFVUpnAnIiIi/VpTcwtvbd3H4vxi/rl2F5V1TQyLi+SSE0ZzYe4oThqXqqkJRIYohTtgWmYSkeFGfmE558zMDHVxRERkEDCzc4D7gHDgYefcPe22DwMeBY4D6oDPOufW+NtuB24GDHjIOffzPiy69EMtLY73dhxgcUExS1aXsLeqgYToCM6aMYILc0dx2sR0IsPDQl1MEWlpgeo9UF4EZTu8+/IiKC/0bmd8GaYv6LWnDyrcBXGBmgo8BpwAfN05d6+/fgrwx4BdJwDfcs793My+g3fhau0M8DXn3JJuvJZjFhMZzrTMJAoKy0Lx9CIiMsiYWThwP/BJoAhYbmaLnXPrAnb7GpDvnLvYv47eD8wzs5l418c5QAPwTzN73jm3qW9fhYSac461xRX8o6CY51aVsLOsluiIMOZNG86FuaP46JThxESGh7qYIkNLY92hYa01vLUGuYqd0Nxw6DHRyZCc5d0i43u1eJ2GuyAvUPuBLwAXBR7rnNsI5AWcZyfwbMAuP2sNgqGWk5XM394vpqXFqSmDiIh01xxgs3NuK4CZPQ0sAAKvndOB/wNwzm0ws3FmNgKYBrztnKvxj/03cDHwoz4sv4TQltIqFucX849VxWwtrSYizPjIpHTuPHsyn5g2gsSYyFAXUWRwcg5q9kO5H9TKCg8PctXtBqmyMEjM9ILb6BO8WrnkLEjOhpRs73FMcp+9hGBq7jq9QDnn9gB7zOy8o5xnHrDFOfdhN8rba3KzUvjd2zvYureKicMTQ10cEREZ2EYDhQHLRcDcdvsUAJcAr5vZHGAskAWsAX5gZmlALXAusKLXSywhtbOsln8UFPOPgmLWFldgBnPHp3LT6RM4Z+ZIUuOjQl1EkYGvqQEqiwNCW9HhQa6p9tBjIuMOhrWROX5gyz64LmkUhPefH1yCCXfBXKCCcQXwh3brbjWza/EuWl9yzh04hvP2iLzsFAAKCssV7kREpLs6agLSfobqe4D7zCwfWA28DzQ559ab2f8DXgKq8EJg02FPYLYQWAgwZsyYniu59JnSynqWrC5hcUExKz/0vgLlZqfwzfOnc96sTEYma8qCXtdYB1W7oGoPRCdC2iQI15AU/V5LMzTVef9+Te1v9VBXfrDGLbD2rXIXh/0pjh/uBbUR02Hy2QHBLQtSxkDsMBhAU6UF8+kN5gJ19BOYRQEXAl8NWP0A8D3/XN8DfgJ8toNj++TiNSEjgfiocAqKyvjUiZpzSUREuqUIyA5YzgKKA3dwzlUANwCYN8nqNv+Gc+4R4BF/2w/989Hu+EXAIoDZs2d36bosoVNe28jSNbtYXFDMm1v20uJgyohEvnz2FC7IGcWYtLhQF3Hgc877cl+1xwtulbv9+11Qtdu/97fVlR96bEQsjJzp1dBk5nq34dMgQtNJHKalxavlaqr3g1bA46b6Q7c11bfbHhDEDglpQR7XctjvXR0LjzoY1I6b5we2wFq30RA5uH5ECSbcdXqBCsJ84D3n3O7WFYGPzewh4LmODuyri1d4mDErK1mDqoiISE9YDkwys/F4/c2vAK4K3MHMUoAa51wDcBPwmh/4MLPhzrk9ZjYGr+nmKX1ZeOlZNQ1NvLx+D/8oKObfG0tpaG5hTGocn/voRC7IHcWUkT3cYqiuwqupiIyFqHivWVlkHIQN8NE0W1qgZm+7kLb70Metwa190zqAiBhIGAGJIyFjCkw4ExKGQ8JIb33tASgp8G6r/wwrHvGOC4uE4VO9oDfSD3wjZ3rv7VDQVA/7tkDpBtj7gXdf+gHs23T4wCFdER7l/ZtERHuhOiL64HJkrFdjFhETcPO3R7ZbDtwe6Z8nKtELcPEZA/9z30XBhLtOL1BBuJJ2TTLNLNM5V+IvXozXxyCkcrNSeOyN7dQ3NRMdodGnRETk2DjnmszsVmAp3kjTjzrn1prZLf72B/EGTnnCzJrx+rHfGHCKv/h97hqBz4ey24Icm/qmZl77YC//KCjmpXW7qW1sZkRSNNecMpYLckeRm5WM9WRTL+egaAWsfBzW/KXjcNMa8qLivBH7olqX4ztY39n2OIhKOPg4rBvfm5rq/Zq01nAWWNu2+2CAq9oDrvnw42OS/YA2HLLnHAxwCSMhccTBbTHJnTevy/20d9/SAmXbD4a9klWw8QV4/3f+jgbpkw7W7mXmwshZXiAZqOqrvPAWGOBKN8CB7QHvu8GwsZA+BSZ+3AtPnQWtIwW47nxm5Ig6DXfBXKDMbCRev7kkoMXM7gCmO+cqzCwOb6TN/2p36h+ZWR5es8ztHWzvc7nZKTQ0t7ChpJJcvw+eiIjIsfCn91nSbt2DAY/fAiYd4diP9G7ppDc0tzje2rKPxQU7+eeaXVT4k4tffMJoLsgZxZzxqYT39IjctWVeLdPKx2H3Gi9w5X4axp/hhaaGamisgYYaaKz2ltse13jbKorb7VMDLY1dK0dETCdB0F8fEQ3Vew8NcLUd/XZhXnBoDWcjZ/phbeTB2rbEEV6Qi4ztgTeynbAwSJ3g3WZc7K1zznuvSgpg1yrv/sM3vfe/VcpYP+zlQGae9zhheM+Xrztq9kPpRti78WCA2/uB1yetVVgEpB4HI2bAzEsgYyqkT/YCbW+839JjguoxGsQFahdec82Ojq0B0jpYf02XStoHWgPdqqIyhTsRERHplHP+5OL5xTy/ehd7q+qJjwrn7BkjuSB3FKdP6oXJxZ2DouV+Ld1fvVq6UcfDBffBzE95A4N0V3NjB8GwXShsqAp43NG+NV5tW+D6pgaIS/OCWdpxMPZUP7CNOPQ+Lr3/DWxiBsmjvdvUcw+ur94bUMPnB7/1iw9uTxh5aA1fZo7X36s3B+lwzqsF3bvRC3KlGw/WyAUO5R8R6wW2MSdDxnVejVzGVEgd369GgJTg9bP/NaE1KjmG9IQo8gvLuUa9G0REROQIymsaWfSfLfzt/WJ2ltUSFRHGvKnDuSB3FB+f2kuTi9cegFV/8kLdnnVev6K8K+GE62BUXs8+V3gkxKZ4Nzm6+HSYOM+7taorh12rveacraFv80vgWrztscP8ppytA7fkebWEXe0f1tLiDeXfFuBaw9wHUB8wWEx0MmRM9kaDbA1wGZMhecyQ65M22CncBTAzcrNSKCgqC3VRREREpB9qbnH8cXkhP166gfLaRs6YnMH/fHIyZ83opcnFnYPCd7xAt/ZZb6TAUSfABb/wa+kSev45pftikmHc6d6tVUONF8pL8g/243vnwYODkkQleP32AkNfxhQvaDc3wv6th4e4vZsO7V8ZP9w7ZtalBwNcxlSvRnQADecvx07hrp2crBRe2biHyrrG3vkjLSIiIgPSiu37+fbitawtrmDO+FS+c8EMpo9K6p0nqz0ABX/0Ql3per+W7jNw4nXel34ZeKLiIGu2d2vV1OA1lQzsx/feE15TVoDwaG+S7PLCQ4f/T872+sCN+8jBAJc+GeJS+/Y1Sb+jcNdObnYyzsHqneWcelx6qIsjIiIiIba7oo57XtjAs+/vZGRSDL+48nguyMns2dEuwaul2/G2F+jW/c2rpRt9Ilz4S5hxiWrpBqOIKH/wlZyD61qavakHSgq8Wr7yIphxUcCgJpP1WZAjUrhrJzcrBYCCQoU7ERGRoay+qZlHX9/OL1/ZRFOz49aPTeRzHzuOuKge/vpUsx8KnvZC3d6NEJ0Ex1/t9aUL/NIvQ0NYuF8bNxlyLgt1aWSAUbhrZ1h8FGNS41ilfnciIiJD1isbdvPdf6xj+74aPjl9BN84bxpj03pw0mrnYMdbsOIxWPd3aK6H0bNhwf3e0PtDZYJsEelRCncdyM1OYeX2/aEuhoiIiPSxbXur+d5z63hlwx4mZMTz+A0n8dEpPThPWc1+KPiDX0v3gVdLd8K1Xl+6kbN67nlEZEhSuOtAblYy/ygoZk9lHcMTY0JdHBEREell1fVN/PKVzTzy+laiI8L5+rnTuO7UcURF9MAw8c7Bh2/4fen+7o2OmHUSLPi115dKtXQi0kMU7jrQNpl5YTmfmK5wJyIiMlg55/h7fjH/98J6dlfUc+mJWfzvOVN65sfd6n1Q8HtY+VvYt8mba+zE672+dCNndv/8IiLtKNx1YMaoJMLDjIKiMj4xfUSoiyMiIiK9YM3Ocr69eC0rPzxATlYyD1x9IieMGda9kzoH21/3aunWL/Zq6bLnwkcegOkXecPhi4j0EoW7DsRFRTBpeAIFReWhLoqIiIj0sP3VDfx46UaeXr6D1LgofvSpHC49MYuwsG5MbVC9F/J/D+/9FvZt9iaxnv1Zr5ZuxPSeK7yIyFEo3B1BXnYK/1y7C+dcz89jIyIiIn2uqbmFp97ZwU9e3Eh1QzOfPW08X5g3ieTYyGM7oXOw7TW/lu4f0NII2SfDR+6E6QtUSycifU7h7ghyslJ4enkhO/bX9OzQxyIiItLn3tyyl7sXr2Pj7kpOm5jGdy6YwaQRicGfoKEa9m6C0o3eXHSlG6FkFZTv8GrpTrrJG/Fy+LTeexEiIp1QuDuC3OxkAPILyxTuREREBqidZbX88Pn1PL+6hKxhsTx49YmcPWPEkVvl1B6A0g8OBrjWW/mOg/tYOKQdB6Ny4eNf92rpImP75gWJiByFwt0RTB6RSExkGAWF5SzIGx3q4oiIiEgX1DU2s+i1rfz61c0A/M8nJ7PwjAnERIZ7zSkrd0PpBm+uudKNBx9X7T54kogYSJsE2XPghGsgfTJkTIXUCRARFaJXJiJyZAp3RxAZHsaMUcmsKioLdVFEREQkSM45lq7dzfefX8fOA9VcPTWc23ObSa9dCktaw9wGqAsYNC0qETKmwMRPePfpUyBjMqSMhbDw0L0YEZEuUrg7itysFH7/7oc0NbcQEd4Dk5iKiIhIz2tuggPbKNn8Pv95800iDmziscgSJsQXE769Frb7+8Wle+Ft5qcOBriMqZCYCRo8TUQGAYW7o8jNTubRN1r4YHcV00clhbo4IiIiQ1tjnTcZeGs/uL0bofQD3L7NWEsjmcDlQHXcCGJHzyAs4ywvzLXWxsWnhfoViIj0KoW7o8jNSgGgoKhM4U5ERCQUtvwL3vmN15Sy7ENwLd56C8MNG0dxxBheclNY3TiSCdNO4Mr580hNSw9tmUVEQkTh7ijGpsWRHBtJQWEZV84ZE+riiIiIDC3vPgQv/C8kjoKs2ZBzeVstXH5NGt9esoWCHeWcOHYYd184g5mjk0NdYhGRkFK4OwozIycrmYKi8s53FhERkZ7R0gxLvw7vPACTzoZLH4Fob066PZV1/OifG3lm5XsMT4zm55/OY0HeqCNPbSAiMoQo3HUiLzuFX7+6hdqGZmKjNGKWiIhIr6qvgr/cCB/8E07+HJz1fQgLp6Gphd++uZ37lm2ivqmZW848jls/PpGEaH2VERFpFdRfRDM7B7gPCAceds7d0277VOAx4ATg6865ewO2bQcqgWagyTk321+fCvwRGIc3jtXlzrkD3Xs5PS83K4XmFsfa4nJmj0sNdXFEREQGr/Kd8IdPw+61cO69MOdmAF77oJS7/7GWLaXVfHzqcL55/nTGp8eHuLAiIv1Pp+P7m1k4cD8wH5gOXGlm09vtth/4AnAvHfuYcy6vNdj57gKWOecmAcv85X4nJ9trv59fWBbagoiIiAxmxe/DQx+H/dvhqj+3Bbs/vLuDax99l+YWx6PXz+bR609SsBMROYJgJm+bA2x2zm11zjUATwMLAndwzu1xzi0HGrvw3AuA3/qPfwtc1IVj+8zwxBhGJceo352IiEhvWf8cPHYuhEfCjUth0ifaNr24dhcT0uNZ+sUz+PjUESEspIhI/xdMuBsNFAYsF/nrguWAF81spZktDFg/wjlXAuDfD+/COftUTlYKq4rKQl0MERGRwcU5ePOX8MerYfg0uGkZjJgRsNmxqsgbDTM6Qv3eRUQ6E0y462j4KdeF5zjNOXcCXrPOz5vZGV04FjNbaGYrzGxFaWlpVw7tMbnZKXy4r4YD1Q0heX4REZFBp7kRnrsDXvwGTL8QrnsOEg+tmSs6UMu+6gZys1NCUkQRkYEmmHBXBGQHLGcBxcE+gXOu2L/fAzyL18wTYLeZZQL493uOcPwi59xs59zsjIyMYJ+2R+X6/e5W7VTTTBERkW6rLYOnLoWVj8Pp/wOXPg5RcYftVuC3mslTuBMRCUow4W45MMnMxptZFHAFsDiYk5tZvJkltj4GzgLW+JsXA9f5j68D/t6VgvelWaOTMYMCDaoiIiLSPQe2wyNnwfbXYcH98IlvQ1jHX0cKCsuIighjysjEvi2jiMgA1elUCM65JjO7FViKNxXCo865tWZ2i7/9QTMbCawAkoAWM7sDb2TNdOBZf2LRCOD3zrl/+qe+B/iTmd0I7AAu69FX1oMSYyI5LiNB4U5ERKQ7Ct+FP1wJLU1wzbMw/ug9NQoKy5k5KonI8GB+ixYRkaDmuXPOLQGWtFv3YMDjXXjNNdurAHKPcM59wLygSxpiOVnJvPbBXpxz+GFVREREgrX6Gfjb5yBpFHzmz5A+6ai7NzW3sHpnOZ8+Kfuo+4mIyEH6KSxIedkp7K2qp7i8LtRFERERGTicg3//GP5yI4w+0RsRs5NgB7C5tIraxmb1txMR6QKFuyDlZqUA6ncnIiLBMbNzzGyjmW02s7s62D7MzJ41s1Vm9q6ZzQzY9kUzW2tma8zsD2YW07el7yFN9fDsLfCv70POp+Hav0F8WlCHtl5vNVKmiEjwFO6CNDUzkchwaxu5S0RE5EjMLBy4H28aoOnAlWY2vd1uXwPynXM5wLXAff6xo4EvALOdczPx+rtf0Vdl7zE1++GJi2DV0/Cxr8PFv4GI6KAPzy8sJykmgnFph4+iKSIiHVO4C1J0RDjTM5NUcyciIsGYA2x2zm11zjUATwML2u0zHVgG4JzbAIwzs9aJ3iKAWDOLAOLowhRE/cLeTfDwPNi5Ej71CJz5v9DF/uoFhWXkZqeon7uISBco3HVBTlYKa3ZW0NzSlTncRURkCBoNFAYsF/nrAhUAlwCY2RxgLJDlnNsJ3Is3knQJUO6ce7HXS9xTtv0HHv4E1FXAdf+AWZd2+RS1Dc1s3F2p/nYiIl2kcNcFudkpVNU3sbW0KtRFERGR/q2j6qb2vwzeAwwzs3zgNuB9oMnMhuHV8o0HRgHxZnb1YU9gttDMVpjZitLS0h4t/DF7/3fw5MWQMAJuehnGzD2m06wtLqe5xbX1dxcRkeAo3HVBXnYyAPlqmikiIkdXBASO4Z9Fu6aVzrkK59wNzrk8vD53GcA24BPANudcqXOuEfgrcGr7J3DOLXLOzXbOzc7IyOillxGklhZ4+W74++dh3Glw44uQOv6YT9d6nc3xr7siIhIchbsumJCeQEJ0BKuKykNdFBER6d+WA5PMbLyZReENiLI4cAczS/G3AdwEvOacq8BrjnmymcWZ1+FsHrC+D8veNY218MwN8PpP4YTr4DPPQGxKt065qqicUckxDE8cmIOEioiESlCTmIsnLMyYNTpZI2aKiMhROeeazOxWYCneaJePOufWmtkt/vYHgWnAE2bWDKwDbvS3vWNmzwDvAU14zTUXheBldK5qD/zhSm/glLO+D6fc2uWBUzpSUFSmKRBERI6Bwl0X5WQn8+jr26hvaiY6IjzUxRERkX7KObcEWNJu3YMBj98COpzN2zn3beDbvVrA7tq9Dn7/aajZC5/+HUw7v0dOe6C6gQ/31XDlnDE9cj4RkaFEzTK7KC8rhcZmx/qSylAXRUREJDQ2vwyPnAXNDXDDkh4LdkBb6xgNpiIi0nUKd13U2kxE892JiMiQtPxheOpyGDYObn4FRh3fo6cvKCzHDGZlaTAVEZGuUrPMLspMjiE9IVr97kREZGhpaYYXvwFv/xomnQ2XPgLRiT3+NAVFZUwa7g1gJiIiXaO/nF1kZuRlJ6vmTkREho76KvjLTfDBCzD3v+HsH0BYz/c7d85RUFjGx6YO7/Fzi4gMBWqWeQxys1LYureairrGUBdFRESkd5XvhMfOgU1L4dx7Yf49vRLsAHaW1bKvukEjZYqIHCOFu2OQk52Cc7BG892JiMhgVpwPD8+D/dvhqj/DnJt79ekKCr3rap4GUxEROSYKd8cg1+/kna9+dyIiMlhteB4emw9hEXDjUpj0iV5/yoKiMqIiwpgysuf78omIDAUKd8cgJS6KsWlxrCpUzZ2IiAwyzsGbv4KnPwMZU+GmZTBiRp88dX5hGTNGJREVoa8nIiLHQn89j1FuVopGzBQRkcGluRGe+yK8+HWYfiFc/zwkjuiTp25qbmF1UbnmtxMR6QaFu2OUm51CSXkdeyrqQl0UERGR7qsrh6cug5WPwelfhEsfh6i4Pnv6zaVV1DY2k5ut+e1ERI6Vwt0xau13V6BBVUREZKCr2Q+PnAXb/wML7odPfAfC+vYrQusUQ6q5ExE5dgp3x2jGqGTCw0zz3YmIyMAXOwzGngbXPAvHXx2SIhQUlZMUE8G4tPiQPL+IyGCgScyPUWxUOJNHJKrfnYiIDHxmcP5PQ1qEgsIycrNTCAuzkJZDRGQgC6rmzszOMbONZrbZzO7qYPtUM3vLzOrN7M6A9dlm9i8zW29ma83s9oBt3zGznWaW79/O7ZmX1HfyspMpKCzDORfqooiIiAxYdY3NbNhVqSaZIiLd1Gm4M7Nw4H5gPjAduNLMprfbbT/wBeDeduubgC8556YBJwOfb3fsz5xzef5tybG+iFDJzUqhoq6J7ftqQl0UERGRAWttcTnNLY7c7JRQF0VEZEALpuZuDrDZObfVOdcAPA0sCNzBObfHObccaGy3vsQ5957/uBJYD4zukZL3Azn+L4yr1DRTRETkmOX788a2DlYmIiLHJphwNxooDFgu4hgCmpmNA44H3glYfauZrTKzR81sWFfPGWqTRyQQExlGvgZVEREROWYFhWVkJscwPCkm1EURERnQggl3HfVs7lInMzNLAP4C3OGcq/BXPwAcB+QBJcBPjnDsQjNbYWYrSktLu/K0vS4iPIyZo5JZpekQREREjtmqojL1txMR6QHBhLsiIDtgOQsoDvYJzCwSL9g95Zz7a+t659xu51yzc64FeAiv+edhnHOLnHOznXOzMzIygn3aPpObncKaneU0NreEuigiIiIDTllNA9v31ai/nYhIDwgm3C0HJpnZeDOLAq4AFgdzcjMz4BFgvXPup+22ZQYsXgysCa7I/Utudgr1TS1s3FUZ6qKIiIgMOAV+65fcbPW3ExHprk7nuXPONZnZrcBSIBx41Dm31sxu8bc/aGYjgRVAEtBiZnfgjayZA1wDrDazfP+UX/NHxvyRmeXhNfHcDvxXD76uPtPa+XtVUTkzR+vCJCIi0hUFhWWYwSxdQ0VEui2oScz9MLak3boHAx7vwmuu2d7rdNxnD+fcNcEXs/8akxpHSlwkBYVlXDV3TKiLIyIiMqAUFJYxMSOBxJjIUBdFRGTAC2oSczkyMyMnK4UCTYcgIiLSJc45CorK1N9ORKSHKNz1gLysZD7YXUlNQ1OoiyIiIjJg7CyrZW9Vg+a3ExHpIQp3PSA3O4UWB2t2VnS+s4iIiAC0TSWkmjsRkZ6hcNcDcvy5eVapaaaIiEjQCgrLiAoPY+rIpFAXRURkUFC46wEZidGMToklv7As1EUREREZMPILy5g+KomoCH0dERHpCfpr2kNyspLbmpeIiIjI0TW3OFbvLCdPTTJFRHqMwl0Pyc1OYcf+GvZXN4S6KCIiIv3e5j1V1DQ0a/JyEZEepHDXQ3L9fneaEkFERKRzBX5Xhtbrp4iIdJ/CXQ+ZlZWMGawqVNNMERGRzuQXlZEYE8G4tPhQF0VEZNBQuOshCdERTMxIUM2diIhIEFYVlZGblUJYmIW6KCIig4bCXQ/KzU5hVVEZzrlQF0VERELMzM4xs41mttnM7upg+zAze9bMVpnZu2Y2018/xczyA24VZnZHn7+AXlTX2MyGkkr1txMR6WEKdz0oNyuZvVUN7CyrDXVRREQkhMwsHLgfmA9MB640s+ntdvsakO+cywGuBe4DcM5tdM7lOefygBOBGuDZvip7X1hbXEFTi1N/OxGRHqZw14Ny/eGcC9TvTkRkqJsDbHbObXXONQBPAwva7TMdWAbgnNsAjDOzEe32mQdscc592NsF7kutg6loGgQRkZ6lcNeDpo5MIio8jFXqdyciMtSNBgoDlov8dYEKgEsAzGwOMBbIarfPFcAfeqmMIVNQVEZmcgzDk2JCXRQRkUFF4a4HRUWEMW1UEvn+L5IiIjJkdTRKSPsO2fcAw8wsH7gNeB9oajuBWRRwIfDnDp/AbKGZrTCzFaWlpT1S6L5SUFimJpkiIr1A4a6H5WUls2ZnOc0tGlRFRGQIKwKyA5azgOLAHZxzFc65G/y+ddcCGcC2gF3mA+8553Z39ATOuUXOudnOudkZGRk9WvjeVFbTwPZ9NeRoMBURkR6ncNfDcrJSqG5oZktpVaiLIiIiobMcmGRm4/0auCuAxYE7mFmKvw3gJuA151xFwC5XMgibZK4q8vql56nmTkSkxync9bDWQVXUNFNEZOhyzjUBtwJLgfXAn5xza83sFjO7xd9tGrDWzDbg1dLd3nq8mcUBnwT+2rcl730FhWWYwcws1dyJiPS0iFAXYLCZkB5PYnQEq4rKuHx2ducHiIjIoOScWwIsabfuwYDHbwGTjnBsDZDWqwUMkYKiMo7LSCApJjLURRERGXRUc9fDwsKMWVnJmg5BRESkHecc+YXlGkxFRKSXKNz1gtzsFNaXVFDX2BzqooiIiPQbxeV17K2qJ0+DqYiI9AqFu16Qm5VMU4tjfUlF5zuLiIgMEa2Tl+dq8nIRkV6hcNcLWi9aBRpURUREpE1BURlR4WFMHZkU6qKIiAxKQYU7MzvHzDaa2WYzu6uD7VPN7C0zqzezO4M51sxSzewlM9vk3w/r/svpH0YmxZCRGN023LOIiIh4P3pOG5VEVIR+WxYR6Q2d/nU1s3DgfrxhmqcDV5rZ9Ha77Qe+ANzbhWPvApY55yYBy/zlQcHMyM1KIb+oLNRFERER6ReaWxyri8rJ0xQIIiK9JpifzuYAm51zW51zDcDTwILAHZxze5xzy4HGLhy7APit//i3wEXH9hL6p7zsZLaWVlNe2/4tERERGXq2lFZR3dCs/nYiIr0omHA3GigMWC7y1wXjaMeOcM6VAPj3wzs6gZktNLMVZraitLQ0yKcNvRx/mOc1O9U0U0REJF+DqYiI9Lpgwp11sM4Fef7uHOvt7Nwi59xs59zsjIyMrhwaUjl+s5N8DaoiIiJCQWEZiTERjE+LD3VRREQGrWDCXRGQHbCcBRQHef6jHbvbzDIB/Ps9QZ5zQEiJi2JcWhyr1O9ORESEgqIycrKSCQvr6HdfERHpCcGEu+XAJDMbb2ZRwBXA4iDPf7RjFwPX+Y+vA/4efLEHhtzsFAoK1SxTRESGtrrGZjaUVJLrd1kQEZHe0Wm4c841AbcCS4H1wJ+cc2vN7BYzuwXAzEaaWRHwP8A3zKzIzJKOdKx/6nuAT5rZJuCT/vKgkpuVwq6KOnZX1IW6KCIiIiGzrqSCphan/nYiIr0sIpidnHNLgCXt1j0Y8HgXXpPLoI711+8D5nWlsANNbrbX766gsIyzZowMcWlERERCo8Dvf56ncCci0qs0i2gvmjEqmfAwo0D97kREZAgrKCxjZFIMI5JiQl0UEZFBTeGuF8VEhjN1ZCKritTvTkREhq6CovK21iwiItJ7FO56WU5WCgWFZbS0dGkGCBERkUGhvKaRbXur1d9ORKQPKNz1srzsZCrqmti+rzrURREREelzq3aWAWikTBGRPqBw18ty/IuZmmaKiMhQ1DqYyqwsNcsUEeltCne9bNLwBGIjw8n3L24iIiJDSX5hOcdlxJMUExnqooiIDHoKd70sIjyMWaOTWaURM0VEZIhxzpFfWKb+diIifUThrg/kZCWzpriCxuaWUBdFRESkz5SU17G3ql7z24mI9BGFuz6Qm51CQ1MLG3dVhrooIiIifaa1v50GUxER6RsKd32g9aKmycxFRGQoyS8qIyo8jKmZiaEuiojIkKBw1weyU2MZFhfZ9gumiIjIULCqsJxpmYlER4SHuigiIkOCwl0fMDNys1M0HYKIiAwZzS2O1TvLNZiKiEgfUrjrIzlZKXywu5Lq+qZQF0VERKTXbS2toqq+Sf3tRET6kMJdH8nLTqbFwZqdqr0TEZHBr3V+V9XciYj0HYW7PpLj/3KpppkiIjIUFBSVkRgdwYT0+FAXRURkyFC46yPpCdGMToklXyNmiojIEFBQWE5OdjJhYRbqooiIDBkKd30oLztFI2aKiMigV9fYzPqSCvW3ExHpYwp3fSgnK5miA7Xsq6oPdVFERER6zfqSCppaXFuXBBER6RsKd32otVO5+t2JiMhg1tpKJU+DqYiI9CmFuz40c3QyZl4ncxERkcGqoKicEUnRjEyOCXVRRESGFIW7PpQQHcGk4QnqdyciMgSY2TlmttHMNpvZXR1sH2Zmz5rZKjN718xmBmxLMbNnzGyDma03s1P6tvTdU1BYpv52IiIhoHDXx3KzUigoKsc5F+qiiIhILzGzcOB+YD4wHbjSzKa32+1rQL5zLge4FrgvYNt9wD+dc1OBXGB975e6Z5TXNLJ1b7XmtxMRCYGgwl0Qvz6amf3C377KzE7w108xs/yAW4WZ3eFv+46Z7QzYdm6PvrJ+Kic7hf3VDRQdqA11UUREpPfMATY757Y65xqAp4EF7faZDiwDcM5tAMaZ2QgzSwLOAB7xtzU458r6rOTdtGpnGaD+diIiodBpuAvy18f5wCT/thB4AMA5t9E5l+ecywNOBGqAZwOO+1nrdufcku6+mE61NEPlrl5/mqPJ85upqN+diMigNhooDFgu8tcFKgAuATCzOcBYIAuYAJQCj5nZ+2b2sJkdNhO4mS00sxVmtqK0tLQ3XsMxae16MCsrObQFEREZgoKpuQvm18cFwBPO8zaQYmaZ7faZB2xxzn3Y7VIfq1fvgQdPh8J3Q1aEKSMTiQoP04iZIiKDW0czd7dvj38PMMzM8oHbgPeBJiACOAF4wDl3PFANHNZqxjm3yDk32zk3OyMjoyfL3i0FReVMyIgnKSYy1EURERlyggl3wfz6GMw+VwB/aLfuVr8Z56NmNiyIsnTPrMsgKgEePx9W/anXn64jURFhTB+VRL4GVRERGcyKgOyA5SygOHAH51yFc+4Gv3XLtUAGsM0/tsg5946/6zN4Ya/fc86RX1jW1kpFRET6VjDhLphfH4+6j5lFARcCfw7Y/gBwHJAHlAA/6fDJe7LZScZkuPkVyDoJ/nozLPsetLR075zHIC87hTU7y2lu0aAqIiKD1HJgkpmN96+BVwCLA3fwR8SM8hdvAl7zA98uoNDMpvjb5gHr+qrg3bGroo7SynoNpiIiEiLBhLtOf30MYp/5wHvOud2tK5xzu51zzc65FuAhvOafh+nxZidxqXDNs3D8NfCfe+GZ66Ghpvvn7YKcrGRqGprZvKeqT59XRET6hnOuCbgVWIo30uWfnHNrzewWM7vF320asNbMNuBdJ28POMVtwFNmtgrvR9Af9lnhu6G1v53CnYhIaEQEsU/br4/ATrxfH69qt89ivCaWTwNzgXLnXEnA9itp1yTTzDID9rkYWHMM5T82EVFw4S8hYyq8+A048CFc+QdIGtUnT9960SsoLGPKyMQ+eU4REelb/kBhS9qtezDg8Vt4A5F1dGw+MLs3y9cb8gvLiQw3pmXq2iYiEgqd1twF+evjEmArsBmvFu5zrcebWRzwSeCv7U79IzNb7f8q+THgi919MV1iBqfeClc+Dfs2w0Mfh+L3++Spx6fFkxgToREzRURkUCkoLGN6ZhLREeGhLoqIyJAUTM1dML8+OuDzRzi2BkjrYP01XSppb5lyDnx2KfzhCnh0Plz8IMy4qFefMizMyMlKVrgTEZFBo6XFsXpnORcf3348NRER6StBTWI+6I2c6Q20MnIW/Pk6eO3H4Hp3sJPcrBQ2lFRS19jcq88jIiLSF7buraKqvkn97UREQkjhrlXCcLjuHzDrcnjl+/DXhdBY12tPl5OVQlOLY11JRa89h4iISF/JL/Tmb83L1uTlIiKhonAXKDIGLlkEH/8mrP4T/PYCqNrTK0+VFzCoioiIyEBXUFhGQnQEE9ITQl0UEZEhS+GuPTM44064/AnYtdobaGVXzw/kOTI5hhFJ0awqKu/xc4uIiPS1gqIycrKSCQvraOpbERHpCwp3RzJ9AXz2BWhpgkfPho0v9PhT5GSlqOZOREQGvLrGZtaXVKi/nYhIiCncHc2o472BVtImwh+uhDd/2aMDreRlp7B1bzXltY09dk4REZG+tr6kgsZmR25WSqiLIiIypCncdSZpFNzwAky/0JvwfPFt0NTQI6fOyfI6na9W00wRERnAWrsY5GowFRGRkFK4C0ZUHFz6OJzxZXj/SXjyYqjZ3+3T5oxOAdB8dyIiMqAVFJYxPDGakUkxoS6KiMiQpnAXrLAw+Pg34JKHoGi5N9BK6QfdOmVyXCQT0uPV705ERAa0/KIycrNTMNNgKiIioaRw11U5l8P1z0FDFTz8Cdi8rHuny0pWzZ2IiAxY5bWNbC2tbpviR0REQkfh7lhkz/EGWknOgqcug3cfOuZT5WansLuinl3lvTdhuoiISG9p7TeuwVREREJP4e5YpYyBG5fCpE/Ckjvh+TuhuanLp8nxL4aqvRMRkYGo9fo1K0uDqYiIhJrCXXdEJ8IVv4dTb4PlD8HvL4Pasi6dYsaoJCLCTP3uRERkQMovLGNCRjzJsZGhLoqIyJCncNddYeFw1vfhwl/CttfgkU/Cvi1BHx4TGc7UzETV3ImIyIC0qqhMTTJFRPoJhbuecsK1cO3foboUHp4H218P+tCcrBRWFZXT0tJzE6SLiIj0tl3ldeyuqCdXTTJFRPoFhbueNO50uGkZxGfAExfBe08GdVheVgqVdU1s21fdu+UTERHpQfl+l4JcjZQpItIvKNz1tLTj4MaXYPxHYPGt8OI3oKX5qIfkZHu/eD76+jbqGo++r4iISH9RUFRGZLgxLTMp1EUREREU7npHbApc9Wc46WZ485fw9GegvvKIu08ensgVJ2Xz1Ds7OOtnr/GvjXv6rqwiIiLHqKCwjGmZScREhoe6KCIigsJd7wmPgPPuhXPvhU0vwqPnQNmODncNCzPu+VQOv795LpHhxg2PLee/f7eSkvLaPi60iIhIcFpaHKuLyjWYiohIP6Jw19vm3Ayf+TOUFcJDH4fCd4+466nHpfPC7Wfw5bOn8MqGPcz7yb956LWtNDa39GGBRUREOrd1bzWV9U3qbyci0o8o3PWFifPgppchKgEePx9W/emIu0ZFhPH5j03k5f85k5MnpPGDJeu54Jevs2L7/j4ssIiIyNG1zs+qkTJFRPoPhbu+kjEZbn4Fsk6Cv94Mr3wfWo5cI5edGscj183mN9ecSEVtI5c++Bb/+0wB+6sb+rDQIiIiHSsoKiMhOoIJGQmhLoqIiPgU7vpSXCpc8ywcfw289mN45npoqDni7mbG2TNG8vKXzuS/zpzAX9/bycd/8ip/XL5Dc+KJiEhIFRSWMWt0MuFhFuqiiIiIL6hwZ2bnmNlGM9tsZnd1sN3M7Bf+9lVmdkLAtu1mttrM8s1sRcD6VDN7ycw2+ffDeuYl9XMRUXDhL+Gs78O6xfDYfKgoPuohcVERfHX+NJ7/wkeYPDyRr/xlNZf95i3Wl1T0UaFFREQOqm9qZl1JhfrbiYj0M52GOzMLB+4H5gPTgSvNbHq73eYDk/zbQuCBdts/5pzLc87NDlh3F7DMOTcJWOYvDw1mcOptcOXTsG+zN9BK8fudHjZlZCJ//K+TufeyXLbtreb8X77O959bR1V9Ux8UWkRExLO+pJLGZkdetvrbiYj0J8HU3M0BNjvntjrnGoCngQXt9lkAPOE8bwMpZpbZyXkXAL/1H/8WuCj4Yg8SU86Bzy6FsAh4dD786VpY9j0oeBqKVkJd+WGHmBmXnpjFK186k8tnZ/Pw69v4xE/+zZLVJTinppoiItL72gZTUc2diEi/EhHEPqOBwoDlImBuEPuMBkoAB7xoZg74jXNukb/PCOdcCYBzrsTMhnf05Ga2EK82kDFjxgRR3AFm5ExvoJWlX4Od78H658A1H9yeMALSJkH6RP9+EqRNJCVlLP93ySwum53F159dw+eeeo8zJ2fw3QUzGJsWH7rXIyIig15BURnDE6MZmRQT6qKIiEiAYMJdRz2l21cRHW2f05xzxX54e8nMNjjnXgu2gH4YXAQwe/bswVk1lTAcPvWw97ipAQ5sh32bYK9/27fJ659XGzAdQlgkpE7ghPRJPD9tIq9nDuM3a7dy6c8+5OqP5nHLRycQHREekpcjIiKDW0FhGTlZKZhpMBURkf4kmHBXBGQHLGcB7UcAOeI+zrnW+z1m9ixeM8/XgN1mlunX2mUCe47tJQwyEVHetAkZkw/fVrP/YNjbu8nrr7d3E2EfLOWMlkbOMCAC9v0nkQ/ezCJj/ExGjp/p1falT4Zh4yA8sq9fkYiIDCIVdY1sKa3m4uNHh7ooIiLSTjDhbjkwyczGAzuBK4Cr2u2zGLjVzJ7Ga7JZ7oe2eCDMOVfpPz4L+G7AMdcB9/j3f+/2qxns4lJhzFzvFqi5Cco+bAt+DVtW07xtDeGbX4Qtfz64X1iEF/AOa+Y5CeLTvYFeREREjmJ1kdcfXP3tRET6n07DnXOuycxuBZYC4cCjzrm1ZnaLv/1BYAlwLrAZqAFu8A8fATzrN9uIAH7vnPunv+0e4E9mdiOwA7isx17VUBMeAWnHeTfOIfNUGNbYzIP/3sKTr65iUvhuPjezmdNTDhC2f7NX47flFWiuP3iOmJSDQS8w+KVOgIjoUL0yEZEBy8zOAe7Du3Y+7Jy7p932YcCjwHFAHfBZ59waf9t2oBJoBprajTYdUvn+YCo5o1NCWg4RETlcMDV3OOeW4AW4wHUPBjx2wOc7OG4rkHuEc+4D5nWlsBK8mMhw7vjEZC7KG823Fq/l2pWlzBiVxPcv+gLHjxkGLc1QXgh7N8PeDw429dz6Lyj4/cETWRgMnwHTLoAZF0HGlJC9JhGRgSJgGqFP4nVdWG5mi51z6wJ2+xqQ75y72Mym+vsHXhc/5pzb22eFDlJBYRkT0uNJjlMzfxGR/iaocCcD17j0eH57w0ksWb2L7z63lkseeJMr54zhK2dPJXnYOK+Z5qRPHHpQfaXfn88Pfttfh1f/D179IWRMhekLYPpFMHyamnKKiHSsbRohAL/bwgIgMNxNB/4PwDm3wczGmdkI59zuPi9tFxQUlXHqcemhLoaIiHRA4W4IMDPOy8nkjMnp/OylTTz+5jaWrtnFV8+dxqdOGH34aGfRiTDqeO/WqqIENjwH6/4Or/0Y/v3/vKabMy7ywt6ImQp6IiIHBTONUAFwCfC6mc0BxuINSLabI08jFFK7yuvYXVFPbpYmLxcR6Y+CmcRcBonEmEi+dcF0/nHb6YxNi+POPxfw6UVv88Huys4PTsqEOTfD9c/BlzbCeT+FpFHwn5/Ag6fDL0+Al++G4nzQZOoiIsFMI3QPMMzM8oHbgPeBJn/bac65E4D5wOfN7IzDnsBsoZmtMLMVpaWlPVfyoygoKgMgR4OpiIj0Swp3Q9CMUck8c8up3HPJLD7YXcm59/2He17YQE1DU+cHgzcv30k3wnWL4c5NcMF9XvPON+6DRWfCL/LgpW/BzpUKeiIyVHU6jZBzrsI5d4NzLg+4FsgAtvnb2qYRAlqnEaLd8Yucc7Odc7MzMjJ65UW0V1BYRkSYMT0zqU+eT0REukbhbogKCzOumDOGZf9zJhcfP5oH/72FT/70NV5cu6trJ4pPhxOvh2uehS9vhgt/5TXXfOt+eOjj8PMcWPp1KFwOLS298lpERPqhtmmEzCwKbxqhxYE7mFmKvw3gJuA151yFmcWbWaK/T+s0Qmv6sOxHVFBUxrTMJGIiw0NdFBER6YD63A1xaQnR/PiyXC4/KZtvPLuGhU+u5BPThvPtC2aQnRrXtZPFpcIJ13i32gOw8QWvj967i+CtX0HSaH8wlgWQNQfC9NvCkOUcVBR7tbvF70HJKu+HgrGnwbjTvSk41IdTBrAgpxGaBjxhZs14A63c6B9+tGmEQqalxbGqsJwFx48KdVFEROQIzA2gZnOzZ892K1asCHUxBq3G5hYee2MbP395Ey3OcdvHJ3HzRyYQFdHNEFZXDhv/6QW9zS978+sljITpF3qjbo45GcL0K/CgVrPfC3E7/Vvxe1DlDwgYFgEZ06BqF1T7/YYSM/2gdxqMPd2bc1Fhb0gxs5X9aW63/q4vro+b91TxiZ/+mx9fmsNls7M7P0BERHrF0a6RqrmTNpHhYSw84zjOzxnF3f9Yy4+XbuSv7xXx9fOmccakDCLCjzHkxSRD7qe9W10FbHoR1v0N3nvCq9WLH35wHr0xp3qTssvA1VANJQV+kPNr5g5sP7g9fTJM+BiMPgFGnQAjZ0FkjFebt3cTfPg6bH/Dm4JjzTPeMfHD/aDn1+xlTFXYE+ljBf7k5XkaTEVEpN9SzZ0c0SsbdvOtv6+l6EAtqfFRnDNzJOfPymTuhDTCw3rgi3V9lR/0/u7dN9ZAXDpMO9+r0Rv3EQW9/q6pAfasPVgbt/M9KN0Azu9fmZTlhbjWIDcqzwv7wXAO9m/1Qt6Hftir2Olti0s7GPTGngbDp6uZ7yCjmruu6Yvr47f/voZnVhax6jtn98w1QEREjsnRrpEKd3JUdY3NvLqxlOdXl7Bs/W5qGppJT4hi/sxMzsvJ5KRxqT1zkW+o9ppsrvu714SzsRpiU2HqeV6N3vgzITyy+88jx66lBfZtCghyK2HXGq+ZLXj/XqNPPBjkRp/gjazaU5zzagA/fMOr2fvwdSjb4T/3MK/Wd9zpXg3fiJlq6jvAKdx1TV9cHxfc/waxkWE8vfCUXn0eERE5OoU76RG1Dc28unEPz60qYdmG3dQ1tjA8MZpzZ3lB78QxwwjriaDXWAubl/lB7wVoqISYFC/oTb8IJnwUIqI6OYl0i3NQXnhojVxxvvdvARAZ79XCBQa5lLF931SybMfBoLf9DTiwzVsfnQxjTzlYuzcyR7XAA4zCXdf09vWxvqmZWd9+kRtOH8dX50/rtecREZHOqc+d9IjYqHDmz8pk/qxMahqaeGXDHp5fVcIf3t3B429uZ0SSF/TOz8nk+OxuBL3IWK9p5rTzobEOtv7LC3rrn4P8p7wv7lPmw9RzvRqb/igswnsdkXHefUSsvxzbP2uUqve2C3LvHRzcJCwSRs6EnMsP1sylT+4fryNlDOSNgbwrveXynQebcH74BnzgDzAYlegN3DPuNK+5b2auaoJFumBDSSUNzS3kZaWEuigiInIUCndyTOKiIjg/ZxTn54yiqr6JZet38/yqEp56ZwePvbGdUckxbTV6edkp2LHW6ETGeEFuynxoqoet//aC3obnYNXTPfui+kp49MGgd0gAjDn4ODLOe+1ty4EBMa7d8f66iHb7Hym81Fd6tXCBQa61eSMGGVNg4icP9pUbMRMiovvq3eme5NFeCM253Fuu3BXQZ+8NePklb31kPIyZe7Bmb9QJqg0WOYqCojIAcjWYiohIv6ZmmdKjKusaedkPev/+oJTGZsfolFjOz/GC3qzRycce9AI1N3oBpbW/V3/T3Og1L22sgaa6g48bawNu/nJTXcC2Gq+28pD9a4Bj+H8aFnF4OGxpgn2bD54vZczBZpWtA55EJ/bgG9HPVO0J6LP3BuxZ562PiIXsOQcHaMma3TuB1jlobmj3714HTbUH/91bHzcFfFbaPkOB2+u8QB+dGHBLarfcbl1k7IAZZVTNMrumt6+P//OnfP6zaS/vfm1ez/wNFxGRY6ZmmdJnEmMiufj4LC4+Povy2kZeWreb51cV88jr2/jNa1sZkxrHeTmZnDcrkxmjko79S0J4JGSf1LOF76/aAkH7cBgYENstN7YLjK1BwjmYdenBQBefHupX17cShsOMi70bQPU+2PGmV7u3/Q341w8B59WuZs/xgt7Imf77Hxi4AgNZQIAPJrC1jiTaVeFRATW4MV6wa6rzamLrKqClsfNzWDhEJxwhBB4pHLZbH5Xg3feHZrnSZ1YVlZOb1Y1WGCIi0icU7qTXJMdGcumJWVx6YhZlNQ28uG43z60qYdFrW3ng1S2MS/OC3vk5o5g6MlFfGo7EzKtFiojuv30MB6r4NG+OxWkXeMu1B+DDt/zavf/Aaz86chiz8HbNaf3AFRkLUXHedA2tTWtb17feH/Y49gj7Bpy3szDVVO8FvfoK/77Sm27ksHWVh66r2Q8HPjy4vrE6uPcuMv7IQTB5NHz8G8H/O0i/VlHXyJbSKhbkjgp1UUREpBMKd9InUuKiuHx2NpfPzmZ/dQMvrt3Fc6tKeODVLdz/ry1MyIjn/FmZnJcziikjB3GzQOnfYod5A/VMPddbriv35tprC1wBIay/DcjS+gNAd2tjW5qhoarjINhhQAzYt3qvd5+QoXA3iKwpKsc59bcTERkIFO6kz6XGR3HFnDFcMWcM+6rq+efaXTy/qoRf/Wszv3hlM5OGJ/g1eplMHK6gJyEUkwyjjg91KfpWWLj3uoOdbF4GvXx/MJWcLH0mRET6O4U7Cam0hGg+M3csn5k7ltLKev65poTnVpVw37JN/PzlTUwdmch5/qibEzISQl1cEZEhp6CwjPHp8aTEaURZEZH+TuFO+o2MxGiuOWUc15wyjj0VdSxZXcLzq0v4yUsf8JOXPmBaZpI36uasTMalx4e6uCIiQ0JBYTknT0gNdTFERCQICnfSLw1PiuH608Zz/WnjKSmv5YXVu3huVTE/XrqRHy/dyMzRSZw7K5NTJqQxY1QyURFhoS6yiMigs7uijl0VdepvJyIyQCjcSb+XmRzLZ08fz2dPH8/OslpeWO013fzRPzcCEBURxqzRyZwwJoUTxgzjhLHDGJEUE+JSi4gMfAWFZYAGUxERGSiCCndmdg5wHxAOPOycu6fddvO3nwvUANc7594zs2zgCWAk0AIscs7d5x/zHeBmoNQ/zdecc0u6/YpkUBudEstNH5nATR+ZwO6KOt778ADv7TjAezvK+O1bH/LQf7a17Xf8mBSOHzOME8akqHZPROQYFBSVERFmTM9MCnVRREQkCJ2GOzMLB+4HPgkUAcvNbLFzbl3AbvOBSf5tLvCAf98EfMkPeonASjN7KeDYnznn7u25lyNDyYikGObPymT+rEwA6puaWVdcwXs7yrzA9+EBnltVAqh2T0TkWBQUljM1M5GYSE1aLyIyEARTczcH2Oyc2wpgZk8DC4DAcLcAeMI554C3zSzFzDKdcyVACYBzrtLM1gOj2x0r0iOiI8I5fswwjh8zjBsZD8Cu8rq2oPfejgP89s3Da/daw970zCTV7omI+FpaHAVFZVyoyctFRAaMYMLdaKAwYLkIr1aus31G4wc7ADMbBxwPvBOw361mdi2wAq+G70DQJRcJwsjkGM6dlcm5AbV7a4sreO/DA7y/o4yVAbV70a21e2OHtdXwDVftnogMUdv2VVNZ16T+diIiA0gw4c46WOe6so+ZJQB/Ae5wzlX4qx8Avufv9z3gJ8BnD3tys4XAQoAxY8YEUVyRI4uOCPdq6sYMa1tXUl7Lex/6TTl3HODxN7az6LUWwKvdCwx700clERmu2j0RGfxaB1PJU7gTERkwggl3RUB2wHIWUBzsPmYWiRfsnnLO/bV1B+fc7tbHZvYQ8FxHT+6cWwQsApg9e3b7UCnSbZnJsZyXE8t5OQdr99bsrOB9P+wt37affxR4H/noiDByspI5wW/+ecLYFIYnqnZPRAafgsIy4qPCOS4jIdRFERGRIAUT7pYDk8xsPLATuAK4qt0+i/GaWD6N12Sz3DlX4o+i+Qiw3jn308ADAvrkAVwMrOnG6xDpMdER4Zw4dhgnjj1Yu1dcVuv33fNq+B59YxuNr20FIGtYrF8bmMIJY4cxLVO1eyIy8BUUlTMrK5nwsI4a54iISH/UabhzzjWZ2a3AUrypEB51zq01s1v87Q8CS/CmQdiMNxXCDf7hpwHXAKvNLN9f1zrlwY/MLA+vWeZ24L966DWJ9LhRKbGMSonl/BxvYIG6xmbWFpe3hb13tu1jsV+7FxMZxpSRSUwansDE4QlMzPDus1Pj9CVJRAaEhqYW1hVXcMNp40JdFBER6YKg5rnzw9iSduseDHjsgM93cNzrdNwfD+fcNV0qqUg/EhMZzoljUzlxbCoAzjmKyw/Ou7ehpJJ/f1DKMyuL2o6JighjQnq8F/j826ThiYxLjyM6QsOMi0j/sWFXBQ3NLRpMRURkgAkq3InI0ZkZo1NiGZ0SywUBw4aX1zSyubSSzXuq2m4FRWU8v7oE5/cgDQ8zxqTGcVxGApNGHKzpO254AgnR+i8qIn2vdTAVhTsRkYFF3xxFelFyXOQhNXytahua2VJaxZbSg6Fv054qXt24h6aWg+MGZSbHHFLT1xr80hKi+/qliMgQkl9YTnpCNKOSNWCUiMhAonAnEgKxUeHMHJ3MzNHJh6xvbG7hw301bN7jBb9NuyvZXFrF0+8WUtvY3LZfanwUEzO82j2vead3n5kcgzeOkYjIsSsoKiMvO1l/T0REBhiFO5F+JDI8rK2WLlBLi6O4vPaQ5p2b91TxwpoSymoa2/aLjwr3Al9GAhMDmniOSY0jQiN4ikgQKusa2VJaxYKAJuYiIjIwKNyJDABhYUbWsDiyhsXx0SnD29Y759hX3dDWrHOLH/re3LKPv76/s22/qPAwxqV7/fpGJMUwIimG4YnR/uNohifGkBQboV/pRYTVO8txTv3tREQGIoU7kQHMzEhPiCY9IZqTJ6Qdsq2irrEt7G0u9YLfxt2V/GfTXqrqmw47V3REGMOTohmR6IW/DD/8KQSKDC0FheUA5GQld7KniIj0Nwp3IoNUUkwkx48ZxvFjhh22rbq+iT2V9eypqGO3f7+nsp7dFXXsqahn/a4K/v1BvUKgyBBUUFjGuLQ4UuKiQl0UERHpIoU7kSEoPjqC8dERjE+PP+p+rSFwtx/+AkPg7oo6hUCRQaigqIw541M731FERPodhTsROaLuhMDdAUEwmBCY4TcvTU/07jMSorz7xIPr46PCFQRlwDCzc4D7gHDgYefcPe22DwMeBY4D6oDPOufWBGwPB1YAO51z5/dFmXdX1FFSXkduVkpfPJ2IiPQwhTsR6bbuhsDdFfXsrapn+75qVnx4gP3VDR0eHxMZ1tbH0At+Ue2Wo0lPiCI9MZrEaNUISuj4wex+4JNAEbDczBY759YF7PY1IN85d7GZTfX3nxew/XZgPZDUR8XW5OUiIgOcwp2I9JlgQ2BTcwv7qxvYU+mFvr1VDd59wHLRgRryC70gGDDve5uoiDC/NjAgAAaEwYy2GsJoNQ2V3jAH2Oyc2wpgZk8DC4DAcDcd+D8A59wGMxtnZiOcc7vNLAs4D/gB8D99VehVReVEhBkzRvVZnhQRkR6kcCci/U5EeBjDk2IYnhTT6b7NLY791X74a71VesulfhAsKa9j9c5y9lU30NxBEowKDyOtLQRGHdI8ND0hiuTYSFLi/PvYSJJiIwkPUxiUoxoNFAYsFwFz2+1TAFwCvG5mc4CxQBawG/g58L9AYq+XNLBARWVMGZlITGR4Xz6tiIj0EIU7ERnQwsOMjESvJq4zLS2OAzUNB2sCq+opraw/dLmqnvUlleyrrqexuYMqQV9iTATJsZF+8Iv0H0cdspzib09uXY6LUr/BoaOjf+T2H6h7gPvMLB9YDbwPNJnZ+cAe59xKM/voEZ/AbCGwEGDMmDHdLnBLi6OgsIzzNXm5iMiApXAnIkNGWJiRlhBNWkI0UzqpEHHOUV7byN6qBsprG6mobaSstoHymkbKahspr22kvMa7L6ttZHdFFWU1jZTXNhw1FEaEWVsoTD4sBEYdstwWGv376AjVpgwgRUB2wHIWUBy4g3OuArgBwLzEv82/XQFcaGbnAjFAkpn9zjl3dbvjFwGLAGbPnn3kD12Qtu+rpqKuiTwNpiIiMmAp3ImIdMDMSImL6vJcX845ahubvdDXGv5qAsJhwPry2kb2VzewtbTaC5B1jbijfEWPjQxvC31JfgBMiokkKTaCpBh/OTaSpJiIg9v95QQNMNPXlgOTzGw8sBMvsF0VuIOZpQA1zrkG4CbgNT/wfdW/4dfc3dk+2PWGgqIyQIOpiIgMZAp3IiI9yMyIi4ogLiqCzOTYLh3b3OKoqms6LASWtdYc1hxcX1bbSOH+GirrmqiobaSyg2kmAoUZJLYFQC8MtgbDgyGxo6DorYuNVHPSrnDONZnZrcBSvKkQHnXOrTWzW/ztDwLTgCfMrBlvoJUbQ1ZgoKCwnLiocCYOTwhlMUREpBsU7kRE+onwMPOaYMZFdvnY5hZHZV0jFbVNVNR5YbCirtFvUhq4rqmtmenWvVVt22oamo96/shwOxgA/ZrBpHY1h621henxUZw6Mf1Y34ZBwzm3BFjSbt2DAY/fAiZ1co5XgVd7oXiHyS8sY9boZA0WJCIygCnciYgMAuFhx9aMtFVDU4sXDv2awI6DYSPltQe3F5fVti03NLe0nWtcWhyvfvljPfXSpA80NLWwrqSCG04dF+qiiIhINyjciYgIURFhbYPNHIu6xmY/BDbRGBD0ZGCIDDf+eftHiIoIC3VRRESkGxTuRESk22Iiw4mJDGd4n87KJj3FzJiQob52IiIDnX6iExERERERGQQU7kRERERERAYBhTsREREREZFBIKhwZ2bnmNlGM9tsZnd1sN3M7Bf+9lVmdkJnx5pZqpm9ZGab/PthPfOSREREREREhp5Ow52ZhQP3A/OB6cCVZja93W7z8ebqmQQsBB4I4ti7gGXOuUnAMn9ZREREREREjkEwNXdzgM3Oua3OuQbgaWBBu30WAE84z9tAiplldnLsAuC3/uPfAhd176WIiIiIiIgMXcGEu9FAYcBykb8umH2OduwI51wJgH8/PPhii4iIiIiISKBgwp11sM4FuU8wxx79yc0WmtkKM1tRWlralUNFRERERESGjGDCXRGQHbCcBRQHuc/Rjt3tN93Ev9/T0ZM75xY552Y752ZnZGQEUVwREREREZGhJ5hwtxyYZGbjzSwKuAJY3G6fxcC1/qiZJwPlflPLox27GLjOf3wd8PduvhYREREREZEhy5zrvJWkmZ0L/BwIBx51zv3AzG4BcM49aGYG/Ao4B6gBbnDOrTjSsf76NOBPwBhgB3CZc25/J+UoBT7s+ss8RDqwt5vnGGr0nnWd3rOu03vWdYP5PRvrnFNzjSD10PURBvdnqrfoPes6vWddo/er6wb7e3bEa2RQ4W4wMbMVzrnZoS7HQKL3rOv0nnWd3rOu03smPU2fqa7Te9Z1es+6Ru9X1w3l9yyoScxFRERERESkf1O4ExERERERGQSGYrhbFOoCDEB6z7pO71nX6T3rOr1n0tP0meo6vWddp/esa/R+dd2Qfc+GXJ87ERERERGRwWgo1tyJiIiIiIgMOkMq3JnZOWa20cw2m9ldoS5Pf2dm2Wb2LzNbb2Zrzez2UJdpIDCzcDN738yeC3VZBgIzSzGzZ8xsg/9ZOyXUZervzOyL/v/JNWb2BzOLCXWZZGDT9bFrdH08drpGdo2ukV031K+RQybcmVk4cD8wH5gOXGlm00Nbqn6vCfiSc24acDLweb1nQbkdWB/qQgwg9wH/dM5NBXLRe3dUZjYa+AIw2zk3E28O0StCWyoZyHR9PCa6Ph47XSO7RtfILtA1cgiFO2AOsNk5t9U51wA8DSwIcZn6NedciXPuPf9xJd4flNGhLVX/ZmZZwHnAw6Euy0BgZknAGcAjAM65BudcWUgLNTBEALFmFgHEAcUhLo8MbLo+dpGuj8dG18iu0TXymA3pa+RQCnejgcKA5SL0hzhoZjYOOB54J8RF6e9+Dvwv0BLicgwUE4BS4DG/mc7DZhYf6kL1Z865ncC9wA6gBCh3zr0Y2lLJAKfrYzfo+tglP0fXyK7QNbKLdI0cWuHOOlinoUKDYGYJwF+AO5xzFaEuT39lZucDe5xzK0NdlgEkAjgBeMA5dzxQDai/z1GY2TC8WpXxwCgg3syuDm2pZIDT9fEY6foYPF0jj4mukV2ka+TQCndFQHbAchZDrJr2WJhZJN6F6ynn3F9DXZ5+7jTgQjPbjtes6eNm9rvQFqnfKwKKnHOtv3g/g3chkyP7BLDNOVfqnGsE/gqcGuIyycCm6+Mx0PWxy3SN7DpdI7tuyF8jh1K4Ww5MMrPxZhaF17lycYjL1K+ZmeG1817vnPtpqMvT3znnvuqcy3LOjcP7fL3inBtSvxZ1lXNuF1BoZlP8VfOAdSEs0kCwAzjZzOL8/6PzUAd76R5dH7tI18eu0zWy63SNPCZD/hoZEeoC9BXnXJOZ3QosxRs551Hn3NoQF6u/Ow24BlhtZvn+uq8555aErkgyCN0GPOV/qdwK3BDi8vRrzrl3zOwZ4D28EfveBxaFtlQykOn6eEx0fZS+omtkF+gaCeacmtWLiIiIiIgMdEOpWaaIiIiIiMigpXAnIiIiIiIyCCjciYiIiIiIDAIKdyIiIiIiIoOAwp2IiIiIiMggoHAnIiIiIiIyCCjciYiIiIiIDAIKdyIiIiIiIoPA/wdADUHMxal/GAAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt # Visualization\n", "\n", "# Plot loss and accuracy in subplots\n", "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))\n", "ax1.set_title('Loss')\n", "ax2.set_title('Accuracy')\n", "for dataset in ('train','test'):\n", " ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')\n", " ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')\n", "ax1.legend()\n", "ax2.legend()\n", "plt.show()\n", "plt.clf()" ] }, { "cell_type": "markdown", "id": "qQbKS0tV3sZ1", "metadata": {}, "source": [ "## 12. Perform inference on test set\n", "\n", "Define a jitted inference function `pred_step`. Use the learned parameters to do model inference on the test set and visualize the images and their corresponding predicted labels." ] }, { "cell_type": "code", "execution_count": 66, "id": "DFwxgBQf44ks", "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def pred_step(state, batch):\n", " logits = state.apply_fn({'params': state.params}, test_batch['image'])\n", " return logits.argmax(axis=1)\n", "\n", "test_batch = test_ds.as_numpy_iterator().next()\n", "pred = pred_step(state, test_batch)" ] }, { "cell_type": "code", "execution_count": 67, "id": "5d5nF3u44JFI", "metadata": { "outputId": "1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAqkAAAKqCAYAAAAZssdpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/av/WaAAAACXBIWXMAAAsTAAALEwEAmpwYAABhcUlEQVR4nO3debxV8/7H8c+neZ7k0qDipktRcRMqDcqQuBVFbshM5Ip0yVS5dCV0dQ0ZKq6hIopKoSRjUt1QJNU9NKGRSnPr98c5Hr/z+e5jD2dP33XO6/l47MfjvPdee63vOefb2p+z+uzv1iAIBAAAAPBJiWwPAAAAAHBRpAIAAMA7FKkAAADwDkUqAAAAvEORCgAAAO9QpAIAAMA7oShSVTVHVTvFsV2gqg0LeYxCPxf+YK4gHswTxIu5gngwT9IjFEWqr1S1rKqOVtUfVXWzqk5V1TrZHhf8o6odVHWOqv6sqjnZHg/8pKr9VXWVqv6iqutUdaSqlsr2uOAfzimIh6oOUdW9qro93+2IbI8rXhSpyblRRE4WkaYiUltEtorIv7M5IHhrh4iMFZGB2R4IvDZVRI4PgqCKiBwjIs1E5G/ZHRI8xTkF8ZoYBEGlfLdV2R5QvEJVpKpqS1X9RFW3qup6VX1UVcs4m52VdyVio6qOUNUS+Z5/uap+rapbVPUtVa2f5JAOF5G3giD4MQiCXSIyQUSaJLlPpIBvcyUIgvlBEDwvIqE5ORQHHs6TlUEQbP1t9yJyQESK1X/v+crDucI5xUO+zZOwC1WRKiL7ReQmEakpuVcwO4rIdc423UWkhYgcLyJdReRyERFV7SYit4vIuSJysIh8ICLjCzqIqt6WN8EKvOXbdIyItFbV2qpaQUR6i8iMlHynSJZvcwV+8m6eqOpfVfUXEdkouVdSn0zFN4qkeTdX4CUf58k5mtuSuFRV+6bim8yYIAi8v4lIjoh0KuD+/iIyOV8OROTMfPk6EZmd9/UMEbki32MlRORXEamf77kNExxXFcmdQIGI7BOR/4pIjWz/vIrzzde5km9fnUQkJ9s/p+J+832e5D3/SBH5h4gcmu2fV3G++T5XOKf4cfN1nohIY8ltRywpIq1EZL2IXJjtn1e8t1BdSVXVRqo6TVV/yLvSMExy/1rJb3W+r7+T3F+OiEh9EXkk318ZmyX3v9OSeaPTEyJSTkQOEpGKIvKacCXVCx7OFXjI53kSBMG3IrJURB5Pxf6QHJ/nCvzh2zwJguCrIAjWBUGwPwiCj0XkERHpUdj9ZVqoilTJLQqXiciRQe4bC26X3F9gfofl+7qeiKzL+3q1iFwTBEG1fLfyeb80Q1VvV/tOOHPLt2kzEXk2CILNQRDsltw3TbVUVXdCIvN8myvwk+/zpJSI/LHQ3x1Syfe5Aj/4Pk+CAsbjrbAVqZVF5BcR2a6qR4lIQb0VA1W1uqoeJrnvvp+Yd/9oERmkqk1ERFS1qqr2LOggQRAMC+w74cwt36aficglefsqLbmX7dcFQbAxNd8ukuDVXFHVEqpaTkRK50Ytp5HN9Mg83+bJlar6h7yvG4vIIBGZnapvFknxba5wTvGTb/Oka96xVFVbSu5qIa+n7ttNr7AVqbeIyF9FZJuIPC3//4vN73URWSgii0VkuuS+uUmCIJgsIsNFZELeJfglItI5BePZJSLfisgGETlLchuikX2+zZW2IrJTRN6U3L+cd4rI20nuE8nzbZ60FpEvVXWH5M6VNyX3Sgyyz7e5wjnFT77Nk14isiJvPP8RkeFBEDyX5D4zRvMaawEAAABvhO1KKgAAAIoBilQAAAB4hyIVAAAA3qFIBQAAgHdKRXtQVXlXVcgFQZCR9dCYK+GXibnCPAk/zimIF+cUxCPaPOFKKgAAALxDkQoAAADvUKQCAADAOxSpAAAA8A5FKgAAALxDkQoAAADvUKQCAADAOxSpAAAA8A5FKgAAALxDkQoAAADvUKQCAADAOxSpAAAA8A5FKgAAALxDkQoAAADvUKQCAADAOxSpAAAA8A5FKgAAALxDkQoAAADvUKQCAADAO6WyPQCgqLj00ktNHjdunMmzZs0y+bTTTkv3kIq92rVrm1yrVi2TDzrooIT2d+qpp0bdfxAEEc+ZPn26ybNnzzZ506ZNCY0B4fTBBx+YXNDvvXfv3ibv2LEjrWMCfMeVVAAAAHiHIhUAAADeoUgFAACAd+hJdbRp08bkbt26mVyjRg2Tt27davJ9991n8vjx4012+xDfeOMNk7t27RrvUOGZ008/3eQDBw6Y3LZtW5M7dOhg8pw5c9IzsCLs+eefN7ljx44mly9f3uRy5cqZXLZsWZML6imNRlVjPv+iiy4yecuWLSa/9957Ji9cuNDkf/3rXybv3LkzoTHCD7t37zb5zDPPjNjmyCOPNHnx4sXpHBI81LlzZ5Pr1q0bsc2DDz5ocpUqVUyeNm2ayY8//rjJM2bMSGaIGcWVVAAAAHiHIhUAAADeoUgFAACAdzRaD5aqJtag5blOnTqZfOedd0Zs4/akliiRWB2/fv16k911GV179+412e2RS1YQBBp7q+QVtbkSj+rVq5u8YMECkxs0aGDyrl27TG7atKnJK1euTN3gCiETcyXV88Tt+3XPZ+7jGzZsSOXhI5QqFdnmH2st1lh9rVOmTDF5wIABJufk5MQ/wBTgnFI4vXr1Mrl///4R29x8880mf/zxx+kcUtqF8ZySbu7ayi+99JLJzZo1M9ntNy2MX375xWS319mdi8uWLTPZ7adOtWjzhCupAAAA8A5FKgAAALxDkQoAAADvFOl1Um+44QaT3TVMK1WqFHMfbm+G23d43HHHmdykSZNEhpj1PkQU3iWXXGKy24Pq+vbbb03md5+8M844I+rje/bsMXnu3LnpHE5EP5mIyNNPP23yHXfcYbLbt/7EE0+Y7K7VPHr0aJMz3ZOKwrn++utNPvHEEyO26d27t8lh70mFyCGHHGLy1KlTTW7evHnax+D2tbprdi9atMjke+65x+ShQ4emZ2Bx4EoqAAAAvEORCgAAAO9QpAIAAMA7oe5JLVOmjMlPPvmkyW7PoLse4aZNmyL26X6esrue2P79+01210CcPn26yS1btow4Rn7Dhw+P+jj81aNHj4S2z2ZfT1H1zjvvZHsIxueffx5xX6xzwOmnnx71cfe8hXCaNWuWya1bt87SSJBJ1apVMzkTPajJcteQr1Gjhsk33nhjxsbClVQAAAB4hyIVAAAA3qFIBQAAgHdC3ZPavn17k/v06RN1e7cH9eyzz47YZuHChQmNoUQJW+dXrFgx6vZr1qwxec6cOQkdD9nhrmX5e/dFM23atFQNByFSr149k/v162eyu35muXLlTH7xxRdNTvdar0iPr776KttDQAZUrVrV5Lvvvjvlx9i+fbvJa9euNdl9r0zNmjVNHjx4sMnu+3t27txp8gUXXFCocaYCV1IBAADgHYpUAAAAeIciFQAAAN7RIAh+/0HV33/QA7Nnzza5Q4cOUbd310B9++23Ez5mnTp1THbXE7vmmmuiPv+EE04wOdEe2EQFQZCRRRZ9nyvJateuXcR97777btTnfPrppyafcsopJrtr7mZbJuZK2OdJ+fLlTXb7z6666qqI51x99dUm165d2+Q9e/aY7K6d7Ga3XyzTOKekhvt7F4l8TevcuXOmhpMWxfGc8vrrr5tc0HtfEjFz5syI+55++mmTp0yZYvJJJ51ksrsm79SpU01evnx5EiNMXrR5wpVUAAAAeIciFQAAAN6hSAUAAIB3QrVOauXKlU3+4x//GHV7t5fD/ezkeBx22GEmP/744yZ36dLF5AMHDph88803m7xo0aKEx4DMO/TQQ00eM2ZMwvu47777TPatBxWRevbsafJ5551n8tFHH23ysccea3K0Hv/fc/HFF5s8adKkhPeB8HF71kUi3/OA8Dn99NOTev7kyZNN7t27d8Q2u3fvjrqPefPmRc1hwpVUAAAAeIciFQAAAN6hSAUAAIB3QtWT6vZ/uZ+J7XLXAnP7RQvi7nP69OkmN2nSxOR9+/aZfMcdd5g8atSomMeEf2rUqGHy4YcfHvM5H374ocmx1lFF6rl96/fee6/Jxx9/vMnu+oHJUk18WchHHnnE5BtvvNHkVatWRX3+iy++aPLcuXNNjtW/huz48ssvI+5z19lt2LChyStWrEjrmJC8jRs3muyui+x6//33TXbXWi/u/365kgoAAADvUKQCAADAOxSpAAAA8E6oelITFatntVWrVhH3jR071uRGjRpF3ceTTz5p8ogRI+IcHXxWt27dhJ/j9hZl+zPWi6MjjjjC5H79+iX0/LVr15q8YcMGk//3v/+Z/PHHH8fcZ7ly5UyuVq2ayZ06dTK5YsWKJp977rkmV6hQwWR3nVV3fWh3vd6PPvoo+oCRNSVLljTZ7ZmmJ9V/t99+u8nPPvts1O1r1qxpcvXq1U3etGlTSsYVVlxJBQAAgHcoUgEAAOAdilQAAAB4hyIVAAAA3gnVG6fmz58fNbds2dLks88+2+SlS5eaPHTo0IhjuIu2u2+k+Nvf/mbylClTfn/ACA33zS1///vfYz7nxx9/NPmpp55K6ZiQOPeNTg8++GDU7d2F8NevXx91f9ngvnmzS5cuJt95550mn3HGGSZ37NjR5AceeMDku+66K9khAiikxo0bm9ytWzeTY53DijqupAIAAMA7FKkAAADwDkUqAAAAvKNBEPz+g6q//6AHBgwYYHIqFtJ/5513oh5jyZIlSR8jk4Ig0Ewcx/e5Eovb11dQv7Jr6tSpJru9RGGTibkS9nnio4MOOsjkxx9/3OQePXqYvG7dOpMPO+ywhI7HOSU13N+TiEjfvn1NvvTSS01+7rnn0jmklCuO5xT3g2BmzJhhstuD6lqzZo3JTZo0idhm+/bthRydn6LNE66kAgAAwDsUqQAAAPAORSoAAAC8E6p1Ul0vvfSSyYn2pL7yyisR91100UUm7927N/GBIXRq1KiR8HMee+yxNIwESMymTZtMHjZsmMk9e/Y0uU6dOmkfEwon2ntEEA5uT+m//vUvky+//HKTTzrpJJPdntbXXnst4hjPPPOMyS+//HKiwwwNrqQCAADAOxSpAAAA8A5FKgAAALwT6p7U0047LaHtN2/ebPLFF18csQ09qMVDpUqVTL7hhhuibn/gwIGI+7Zt25bSMQGpcOWVV5rs9jkuXLgwk8NBAvbv32/y3LlzszQSpMqYMWNMdntMx44da3Lr1q1N7tixY8Q+K1asaPKcOXNM3rBhQ8Lj9BVXUgEAAOAdilQAAAB4hyIVAAAA3glVT+rxxx9v8ujRoxN6fpUqVUxu2bJlxDYffvhh4gND6Nx9990mlygR/e+1mTNnRtw3b968lI4JiatatarJ+/btM3nHjh2ZHE5G/PnPfzb5jjvuMLlLly4mu/3U48ePT8/AkDT3d5WTk5OdgSBttmzZYnL37t2j5kmTJkXsw11b1e1rdddG3rVrV8Lj9AVXUgEAAOAdilQAAAB4hyIVAAAA3glVT+o//vEPk1XV5MWLF5vcvHlzk0uVst9u9erVUzY2hEvfvn2jPr57926TR4wYkc7hoJC++uorkx9++GGTH3rooUwOJyXcfrLjjjvOZHcd1Jo1a5rsrovqnjdHjhyZ7BABpMl7771nckHvfXB7Us866yyTBw0aZPLgwYNTM7gs4EoqAAAAvEORCgAAAO9QpAIAAMA7Xvekuj2lZ5xxhslvvvmmyS+++KLJrAeI39StW9fkWOuirly50uT3338/5WNC8mrXrm3y7bffbnLp0qVNXrRokclvv/22ye3btze5TJkySY4wct1Sd71n97O6E/XOO++YfNNNN5ns9u3CDxdffHG2hwAP7dy50+QVK1ZEbOP2pBZlXEkFAACAdyhSAQAA4B2KVAAAAHjH657UY445xmS3j7BevXqZHA5CrGvXriaXK1cu6vYvvPBCOoeDFOnXr5/J7jqp9913X9Tn//TTTya7a47G6l1212p21yiNx+eff26y+9ner776qsnuZ3n/8ssvJrs9bfBT2bJlI+6bNm1aFkaCVCpZsqTJbl98LHfccYfJF110UdJjCjOupAIAAMA7FKkAAADwDkUqAAAAvON1T2os7tqX3bp1y85A4L0TTzwx6uO//vqryXPnzk3ncJAijz32mMmLFy82+amnnjK5Vq1aJlerVs3kDRs2RD2eu37uJ598YnJBPam7du0y2e0x/eKLL6IeE8XH+vXrsz0EJOnPf/6zye6/d3dt51RwX7/mz5+f8mNkC1dSAQAA4B2KVAAAAHiHIhUAAADe8bon1e0vW7BggcktWrQw+YILLoi6v++//97kWbNmFX5wCJXHH3/c5PPPP99kdz3NefPmpX1MSL2PPvrI5CZNmpjs9qS6edGiRekZGIBiwe0HXbVqlcmp6El97733TL7++utNXrZsWdLH8AVXUgEAAOAdilQAAAB4hyIVAAAA3tFonzWtqol/EHUatWnTxuSpU6eaXLVqVZPdNQ87d+5scnHoPwuCQGNvlTzf5goSl4m5wjwJP84piBfnFJGzzz7bZLd/9PTTTzd506ZNJg8aNChin+77a955551khph10eYJV1IBAADgHYpUAAAAeIciFQAAAN4JVU8qEkf/GOJF/xjiwTkF8eKcgnjQkwoAAIBQoUgFAACAdyhSAQAA4B2KVAAAAHiHIhUAAADeoUgFAACAdyhSAQAA4J2o66QCAAAA2cCVVAAAAHiHIhUAAADeoUgFAACAdyhSAQAA4B2KVAAAAHiHIhUAAADeCUWRqqo5qtopju0CVW1YyGMU+rnwB3MF8WCeIF7MFcSDeZIeoShSfaeqZVR1maquyfZY4CdV7aCqc1T1Z1XNyfZ44CdVHaKqe1V1e77bEdkeF/zDOQXxUtXjVfX9vPPJj6p6Y7bHFC+K1NQYKCI/ZXsQ8NoOERkruXMFiGZiEASV8t1WZXtA8BLnFMSkqjVFZKaIPCkiB4lIQxF5O6uDSkCoilRVbamqn6jqVlVdr6qPqmoZZ7OzVHWVqm5U1RGqWiLf8y9X1a9VdYuqvqWq9VMwpsNF5CIR+Wey+0Lq+DZXgiCYHwTB8yJCweER3+YJ/OXbXOGc4iff5omI3CwibwVB8GIQBLuDINgWBMHXSe4zY0JVpIrIfhG5SURqisjJItJRRK5ztukuIi1E5HgR6Soil4uIqGo3EbldRM4VkYNF5AMRGV/QQVT1trwJVuDN2fzfefvdmfy3hxTyca7APz7Ok3NUdbOqLlXVvqn4JpESPs4V+Me3eXKSiGxW1Y9V9SdVnaqq9VL0vaZfEATe30QkR0Q6FXB/fxGZnC8HInJmvnydiMzO+3qGiFyR77ESIvKriNTP99yGCY6ru4jMzPu6vYisyfbPqrjffJ0r+fbVSURysv1zKu43X+eJiDQWkdoiUlJEWonIehG5MNs/r+J883Wu5NsX5xQPbr7OExFZLiJbReQEESknIqNE5KNs/7zivYXqSqqqNlLVaar6g6r+IiLDJPevlfxW5/v6O8k94YuI1BeRR/L9lbFZRFRE6hRyLBVF5AERuaEwz0d6+TRX4C/f5kkQBF8FQbAuCIL9QRB8LCKPiEiPwu4PqePbXIGfPJwnOyW3SP4sCIJdIjJURFqpatUk9pkxoSpSReQJEVkmIkcGQVBFci+Lq7PNYfm+rici6/K+Xi0i1wRBUC3frXzeC4GhqrerfXetueVtdqSINBCRD1T1BxF5TURq5U3MBqn6hlFoPs0V+Mv3eRIUMB5kh+9zBX7wbZ58Ibnnkd/89nUozithK1Iri8gvIrJdVY8SkYL6tQaqanVVPUxEbhSRiXn3jxaRQaraREREVauqas+CDhIEwbDAvrvW3PI2WyK5E6153u1KEfkx7+vVBewWmeXTXBFVLaGq5USkdG7UchrZTI/M822edM07lqpqSxH5m4i8nrpvF0nwba5wTvGTV/NERMaJSHdVba6qpUXkLhH5MAiCrSn5btMsbEXqLSLyVxHZJiJPy///YvN7XUQWishiEZkuImNERIIgmCwiw0VkQt4l+CUi0rmwAwmCYF8QBD/8dpPcy/IH8vL+wu4XKePNXMnTVnL/2+VNyf3LeaeEaBmQIsy3edJLRFbkjec/IjI8CILnktwnUsO3ucI5xU9ezZMgCN6V3Ku50yV3qcyGeeMLBQ2CIPZWAAAAQAaF7UoqAAAAigGKVAAAAHiHIhUAAADeoUgFAACAd0pFe1BVeVdVyAVBkJG10Jgr4ZeJucI8CT/OKYgX5xTEI9o84UoqAAAAvEORCgAAAO9QpAIAAMA7FKkAAADwDkUqAAAAvEORCgAAAO9QpAIAAMA7UddJBQAAQHa0b98+4r45c+aYPHToUJOHDBmSxhFlFldSAQAA4B2KVAAAAHiHIhUAAADeoScVAADAA24Pqtt/WtxwJRUAAADeoUgFAACAdyhSAQAA4B16UgEASNJDDz1kcv/+/U1+9dVXTT7//PPTPSSEUEHrosby3nvvpXwcvuBKKgAAALxDkQoAAADvUKQCAADAOxoEwe8/qPr7DxZTkyZNMrl69eomd+zYMZPDiSkIAs3EccI2VypXrmzywoULTd65c6fJN9xwQ8Q+3n///dQPLIsyMVfCNk8QiXNKwfbv32/ygQMHTF63bp3JF1xwQcQ+5s2bl/qBZRHnlNiGDBli8uDBg2M+x+1B7dChQwpHlHnR5glXUgEAAOAdilQAAAB4hyIVAAAA3qEnNYbWrVub7PaCzJ071+ROnTqle0gJoX+sYGXKlDF5xowZJrdr187k2bNnR+zjjDPOSP3Asoj+McSDc0rB3HVPx48fb3KJEvaakNuzKiJSsmTJ1A8sizinRHLXQZ0zZ07C+1DNyD/BjKEnFQAAAKFCkQoAAADvUKQCAADAO6WyPYBUqlGjhsmbN29Oep8NGzY0uVSpIvUjK7b27Nlj8saNG6NuX69evYj73L5Wd58Aig/3/R0F9Zwm8jiKhmR7UMO+BmqyuJIKAAAA71CkAgAAwDsUqQAAAPBOqBssmzRpYvKzzz5rctu2bU12P489Hsccc0zUxydMmJDwPhE+jRo1irjv5JNPNtldMxfFT7Vq1SLuGz58uMlffvmlyY8++mg6h4QMcdeudNdFdXNBXn75ZZPdtVcRPm5PaizuWuxuLm64kgoAAADvUKQCAADAOxSpAAAA8E6oe1Jvvvlmk1u0aGFy+fLlTY6nJ7Vq1aomX3311Sbv3r3b5BdeeCHmPgEUTe46yvPnz4/Yxu1TddfTHDFihMnbt283+ZVXXjF56tSpJn/66acmp2J9aCQuFeuksnZq0dOuXbuEtue9DRZXUgEAAOAdilQAAAB4hyIVAAAA3glVT2qpUna4xx9/fNTt3c9bj6dX68gjjzS5SpUqJk+ZMsXkXbt2xdwn/Pfhhx+a3KNHD5PdNRBFRPr27WsyvURFX5s2bUx2zwcFrZPq+uKLL0xu1qyZyeXKlTP52muvjZrdXvv9+/ebvGrVKpMXLFhgstsH6a7b6o4XBUvFOqnua1bdunVNXrNmTSFHh2xJdJ3UIUOGpGUcYcWVVAAAAHiHIhUAAADeoUgFAACAd0LVk3rZZZeZ3Lx586jbf//99wkfg89KLp7cz1N31zxE8VCjRg2TH3/8cZO7detmcpkyZUwu6Jxz6623mjx58mST//nPf5p80003xTXW37jrQbuaNm0aNbu9lH369DG5bNmyCY2nuPr444+j5latWplc0JqoJ554YtRMT6r/Eu0p7dChQ3oGUkRwJRUAAADeoUgFAACAdyhSAQAA4B2ve1LdXqjrr78+6vbPPvusyVu2bIm6fUG9XGeffXZ8g0ORsnfvXpPdtSbdNXpFRBo3bmxyxYoVTd6xY0eKRod0cXtQZ86caXKLFi2iPn/btm0mF9SPNnHixKj7cHtW77//fpPdeZXsOapkyZIm/+EPfzB5/vz5Se2/uHL7RV977TWTW7dubXJB66a6/cEvv/yyye7vDv4ZPHhwQtu/99576RlIPnPmzIn6uLvGt09rtXIlFQAAAN6hSAUAAIB3KFIBAADgHY22HqSqZnWxyEsvvdTkcePGRd1+9uzZJn/zzTdRtz/mmGMi7mvbtm3U57i9iytXrjR57NixJo8YMSLq/tItCILID51Pg2zPlVRz1011+09FItdSrVWrlskbNmxI/cDSKBNzJdvzJNkeVHf7e++912R3bcyiiHNK4bh97gWtk+r2qbrblC5dOvUDS6PicE5xJbrGttuHXBjt27ePmhPtk03FmBIRbZ5wJRUAAADeoUgFAACAdyhSAQAA4B2v1kl110UdMGBAQs/v2LFj1JwKbk/QUUcdZXLPnj1NznZPKlCcVa9e3eREe1BfeOEFky+77DKT3T5D4Pd8+umnJp944okR27i9gAWtpQq/JLqm6NChQ1M+Brfn1O1JDTP+BQAAAMA7FKkAAADwDkUqAAAAvEORCgAAAO949cYpdzH0Ro0aZWkk/++HH34wOdYHCvzxj39M53CQJQW9gaGgxbjhlxtvvNHkWG+Uuu+++0x23xTBG6VQWCNHjjT5pZdeitgm1mL+N910U9R9wn+JvtGqIHPmzDG5KL1RysWVVAAAAHiHIhUAAADeoUgFAACAd7zqSc3JyTH5zjvvNLlevXoJ7e/VV1812V3c392/iMjGjRtNPvroo03eunVrQmNA0VBQ/2kQBFkYCRJx8803R338/fffN9ldFJu+Y6RLQX3usRbzP+mkk9I6JqSf2z/63nvvJb2PoowrqQAAAPAORSoAAAC8Q5EKAAAA73jVk+oaMWJESvd3xhlnxNxmwoQJJtODCoTXzJkzTe7Ro4fJ7trMBfWp57dhwwaT33zzTZO/++67RIeIYqqgfudY66TSB+8ft6fU7Wt3uY/H6klNxbqqsRSmLzZTuJIKAAAA71CkAgAAwDsUqQAAAPCO1z2pqdamTZtsDwEhsWTJEpMbN26cpZEgGVdffbXJNWrUMLlDhw4mJ9r/tWPHDpMfe+yxiG1uu+22hPaJomn16tUmr1u3LmKbww47zGS3R9VdRxXZl2g/p7vGqdtnPHToUJNj9bimgntMn3AlFQAAAN6hSAUAAIB3KFIBAADgnWLVk1qvXr2Y20ycODEDI4HvjjnmmGwPASngrnPcqVMnk48//niT3XVTjzjiCJPdz06/8MILTT711FMLM0wUA/PmzTP5k08+idimbt26JrvrpLrzz83uMZB5bt/wnDlzTHZ7Ul2Z6EF1e/FZJxUAAABIAEUqAAAAvEORCgAAAO8U6Z7UBg0amFy9enWTt2zZEvGcVatWpXNICCl3vUKRgj97G+GyaNGihLY/6KCD0jQSFDcFrXnq3ueed9x1VN0eVvgn1hqksXpUUyHM6+tyJRUAAADeoUgFAACAdyhSAQAA4J0i3ZNap04dkytXrmzyd999F/Gcgj5PGcXPlClTTG7cuHHENu5nLqPocdfLveKKK7I0EhQ1I0eOjLivR48eJrt9726P6o033mjypEmTUjQ6pIq7BmmsNUmHDBkSc5/uWqruPt11UMOMK6kAAADwDkUqAAAAvEORCgAAAO8U6Z7Uo48+OurjM2bMyNBIEDb0JkNEpGLFiib/4Q9/iLr9mDFj0jkcFCHz5s2LuC/WOqlhXu8S8YmnJzWebYoKrqQCAADAOxSpAAAA8A5FKgAAALxTpHtSjzvuuKiPb9myJUMjARBGd999t8mlS5c22V3Hcvr06WkfE4quhx56yOT+/fub7PaxXnjhhekeEpBVXEkFAACAdyhSAQAA4B2KVAAAAHinSPekup9jfOmll5r83XffZXA0CJOJEyeafO2110Zss3btWpPpcS565s6da3KnTp1Mds8xa9asSfuYUHQNHDgwagaKG66kAgAAwDsUqQAAAPAORSoAAAC8o0EQ/P6Dqr//IEIhCIKMfNgzcyX8MjFXwj5P6tSpE/Vxt0+5KOKcgnhxTkE8os0TrqQCAADAOxSpAAAA8A5FKgAAALxDT2oRR/8Y4kX/GOLBOQXx4pyCeNCTCgAAgFChSAUAAIB3KFIBAADgHYpUAAAAeIciFQAAAN6hSAUAAIB3KFIBAADgnajrpAIAAADZwJVUAAAAeIciFQAAAN6hSAUAAIB3KFIBAADgHYpUAAAAeIciFQAAAN4JRZGqqjmq2imO7QJVbVjIYxT6ufAHcwXxYJ4gXswVxIN5kh6hKFJ9paoDVXWJqm5T1f+p6sBsjwl+UtUZqro9322Pqn6Z7XHBL5pruKpuyrs9oKqa7XHBP7z+IB5hnyelsj2AkFMRuUREvhCRP4rI26q6OgiCCdkdFnwTBEHn/FlV3xORd7MzGnjsahHpJiLNRCQQkXdEZJWIjM7imOAnXn8Qj1DPk1BdSVXVlqr6iapuVdX1qvqoqpZxNjtLVVep6kZVHaGqJfI9/3JV/VpVt6jqW6paP5nxBEHwQBAEi4Ig2BcEwTci8rqItE5mn0gN3+aKM7YGInKKiDyfqn2icDycJ31E5KEgCNYEQbBWRB4SkUuT3CdSwLe5wuuPn5gnqRWqIlVE9ovITSJSU0ROFpGOInKds013EWkhIseLSFcRuVxERFW7icjtInKuiBwsIh+IyPiCDqKqt+VNsAJvv/McldzCY2lS3yFSxdu5Irl/1X4QBMH/kvj+kBq+zZMmIvJ5vvx53n3IPt/mSv7n8PrjD+ZJKgVB4P1NRHJEpFMB9/cXkcn5ciAiZ+bL14nI7LyvZ4jIFfkeKyEiv4pI/XzPbZjEGIdK7gtK2Wz/vIrzLSRzZYWIXJrtn1Vxvvk6TyT3Be6ofPnIvP1otn9mxfXm61xxxsLrD/OkSM6TUF1JVdVGqjpNVX9Q1V9EZJjk/rWS3+p8X38nIrXzvq4vIo/k+ytjs+T2atRJwbj6Se7VsS5BEOxOdn9InsdzpY2IHCoik5LdF5Ln4TzZLiJV8uUqIrI9yHuFQfZ4OFd+GxevPx5hnqRWqIpUEXlCRJaJyJFBEFSR3Mvi7jtfD8v3dT0RWZf39WoRuSYIgmr5buWDIPjYPYiq3q72ndjm5mx7uYjcJiIdgyBYk6LvE8nzbq7k6SMirwVBUNBjyDzf5slSyX3T1G+aSZj+a65o822u8PrjJ+ZJCoWtSK0sIr+IyHZVPUpE+hawzUBVra6qh4nIjSIyMe/+0SIySFWbiIioalVV7VnQQYIgGBYEQaXfu/22nar2lty/kk4LgmBV6r5NpIBXcyVvP+VFpKeIPJuS7xCp4Ns8+Y+I3KyqdVS1togMEOaLL7yaK7z+eIt5kkJhK1JvEZG/isg2EXla/v8Xm9/rIrJQRBaLyHQRGSMiEgTBZBEZLiIT8i7BLxGRzgU8PxH3ishBIvJZvr9gWCrGD77NFZHcpYV+FpE5KdgXUsO3efKkiEwVkS/z9jc97z5kn29zhdcfPzFPUkhpdQIAAIBvwnYlFQAAAMUARSoAAAC8Q5EKAAAA71CkAgAAwDuloj2oqryrKuSCIHDXZ0sL5kr4ZWKuME/Cj3MK4sU5BfGINk+4kgoAAADvUKQCAADAOxSpAAAA8A5FKgAAALxDkQoAAADvUKQCAADAOxSpAAAA8A5FKgAAALxDkQoAAADvUKQCAADAOxSpAAAA8E6pbA8AAAAAIg0bNjT5pptuitimb9++Ufcxbdo0k6+66iqTf/zxx0KOLvO4kgoAAADvUKQCAADAOxSpAAAA8I4GQfD7D6r+/oMZcNhhh5ncqlUrk9u0aWNyt27dTK5Ro4bJa9euNfmzzz6LOOYNN9xg8ubNm+Maq6+CINBMHCfbcwXJy8RcKerzxO0nExE5+uijk9rn22+/bfLu3buT2l+yiuo5pUKFCibfeeedJh977LEmd+nSJanj/fTTTxH3ub2ErhdffNHkhQsXmvzLL78kNaZU45wSqVQp+1agwYMHm9yvXz+Tq1SpkvQxv/jiC5Pdubtu3bqkj5GMaPOEK6kAAADwDkUqAAAAvEORCgAAAO9ktSf1yiuvNPmvf/2ryY0bNza5Zs2aJqvaNoZo30u8HnnkEZMHDBiQ9D6zqaj2jyH16B+LrWrVqiY//PDDJrvnMBGRsmXLJnXM7777zuR77rnH5HHjxiW1/0QV1XNKz549TZ4wYYI7HpOTfb1x91eYfc6aNcvkgQMHmuz2ImYa55RI119/vcmjRo0yOZ559umnn5p83HHHmVymTJmo+xw+fLjJgwYNijLi9KMnFQAAAKFCkQoAAADvUKQCAADAO6Vib5I6F198scn//ve/TS5durTJ7ppvGzZsMNnts/j2229Nfuedd0yuXbu2yZdddlnEGHv37h11jDk5ORHPAVA0HXXUUSa755Q6deqkfQz169c3efTo0SZXrlzZZLfHDfG59dZbU7q/bdu2mbxy5UqT3T7CwujUqZPJbq+h22e7ffv2pI+J5Ljrv8fyn//8J+K+a665xmR3jfgxY8aYXLFixYSO6ROupAIAAMA7FKkAAADwDkUqAAAAvJPRntTu3bubvHXrVpOfe+45kx999FGT16xZk9LxtG7dOuI+d23Wq666yuQ77rgjpWNAOBTU0+N+lrerR48eJp977rkmH3744QmNwe1pc+fqnj17EtofIntK3X/vbt+6u/3y5ctNdnvDCuOGG24w+cILLzS5WrVqJo8YMcLkL7/80uQ5c+YkPabioF69elEf37Fjh8kjR440eenSpSa/9dZbJu/atcvkeM4pJ5xwgsnu57j379/f5DPOOMPkSZMmmeyek+hRzbwuXboktL37718kcu3k+++/3+TFixebXFCtExZcSQUAAIB3KFIBAADgHYpUAAAAeEejfVZwqj8T99prrzV50aJFJs+fPz+Vh4vJ7d0Siezze+WVV0zu1atXWseUakX1c7YbNGhgsruWpPvZxiVK2L/HOnfubPJ5551n8jHHHGNy+fLlI8bwxz/+Ma6xpovb07Zz586k9lccPmd7yJAhJrs9fW7PXyxun2JB54fp06cntE+X2yt58803m9y3b1+Tf/rpJ5PbtWtn8qpVq5IaT1E9p3zwwQcmt2rVyuRmzZqZvGTJkrSPKRb3POX2oB555JEm//3vfzf5oYceSs/A8hSHc0qiBgwYYLLbU+6u/15QjbZp0yaTGzZsaPK0adNMbtOmjcnuerq33XZblBGnX7R5wpVUAAAAeIciFQAAAN6hSAUAAIB3MrpOqvuZ0z5yexe3bNmSpZEgv7Jly5rsrv3o9qSuXbvW5IMOOsjkcuXKJT2mBQsWmOx+Vrfb8+z2ybrrXT722GNRj/fjjz+afODAgXiGWawccsghJrufx3755Zeb7Pag7t+/3+ScnByT3Z6/WbNmRd0+Fb7//nuT3T5at3eyRYsWJrtrbybbk1pUnXPOOSa7c8eHHlSXO6bXXnvNZLfX0O1NTHdPKiL9+9//Ntk9j7vrae/evTtiH3feeafJpUrZUs5dz9nta432XiTfcCUVAAAA3qFIBQAAgHcoUgEAAOCdjPakZlvz5s1NdvsYRUS2bt1qsts/guxwe2h++OEHk2vWrGmy22vofka1m//zn/+Y7PZ2uccTiex7Lah3KJpYa9O5c9Fd2zXR4xUHr7/+usktW7aMur3bg+quI+l+PruP3L7C8ePHm+x+Xrv7M0Iu99/boEGDsjOQJPzjH/8w+fzzzzfZ7V1034NBn3v67dmzx2T3HFOYc87pp59uckG1TVhxJRUAAADeoUgFAACAdyhSAQAA4J0i3ZPq9tu4695VqFAh4jnuNl999VXqB4aEuX08J598ssmNGjUy2V2zdP369ekZWALcfse77ror6vYTJ040efHixakeUugtWrTIZPfz1V3uOqbu52hPnjw5JePKpMMOOyzq4xdeeKHJF198cTqHgyzauXOnye5586yzzjK5adOmJnOOCacxY8ZEfXzHjh0mf/LJJ+kcTkpxJRUAAADeoUgFAACAdyhSAQAA4J0i3ZPq9tv07Nkz5nNWrlyZruEgjZYvX57tIURwPwt86NChJpctW9Zk93PhBw4cmJ6BhdhFF11kstuDqqomL1261GS3J2/16tUpHF12tGjRIurjbm8+8JvTTjvNZHpSw8l9LXG9/PLLJr/xxhvpHE5KcfYCAACAdyhSAQAA4B2KVAAAAHinSPekup9h7SponcopU6akaTQo6g4++GCTR48ebbLbN/Tiiy+afPPNN5u8ffv2FI6uaOjcubPJsXpQ+/bta3JR6EEtV66cyW7vveu5555L53DgEXcuxFpDF+F09dVXm1yzZs2o24dx/effcCUVAAAA3qFIBQAAgHcoUgEAAOCdUPekur1Zjz32mMnu57m7Xn/99ZSPCcVHyZIlTX777bdNrlWrlsnu5yc///zzJm/YsCGFoyua3HVOXW7/5YcffpjO4WTF3//+d5OPOuqoqNu/+uqr6RwOPOL++6hYsWKWRoJUcV9HRERGjBhhchAEJn/22WcmT5s2LfUDyxCupAIAAMA7FKkAAADwDkUqAAAAvEORCgAAAO+E+o1TAwYMMLlPnz4mu83E7hsIvvvuu/QMDMXCwIEDTW7WrJnJe/fuNfnCCy802X2jFWKrVq2aye6/8aKoYcOGJl9//fVRt9+zZ4/JOTk5qR5SKFWvXt3kq666ymT3zSZffvmlye6Ha+zatSuFo4tPqVL2Jbtly5Ym33rrrSbH+veRje8B0R100EEmDxs2LGKbSpUqRd2H+8EwYcaVVAAAAHiHIhUAAADeoUgFAACAd0LVk9q2bVuTb7/99qjbr1mzxuRLLrnE5N27d6dmYCjyhg4dGnHfXXfdFfU5vXr1MjnMCyr7Ys6cOSa3b9/e5Pfeey9zg0mTOnXqmLxgwQKTq1SpEvX5TzzxhMlLlixJzcBCzu1B/ec//5nQ8z///HOTly9fbvL48eNNXrp0qckrVqxI6HgFue+++0y+5ZZbTFZVk92e1HXr1pn87LPPJj0mROf20bu2bt1qcr9+/Ux265aCfPvttyYXpffbcCUVAAAA3qFIBQAAgHcoUgEAAOAdr3tS3V6OcePGmVyuXDmT3f4btyeQHlTEq2vXribH6j8VEXnxxRdNnjFjRkrHBJH169dHfdztUXX7OX1Uu3Ztk92+21g9qK7JkycnPaaiqEWLFkk9v3nz5ia76yL37NnT5F9//dXkxYsXmzxlypSIY8yaNctkd91T9xixbNq0yeSrr77a5G3btiW0P4g0atTI5DFjxkTd/g9/+EPUx3/66SeTW7dubXI8a0F369bN5LVr18Z8TlhwJRUAAADeoUgFAACAdyhSAQAA4B2ve1Ivuugik+vXrx91+5UrV5p87LHHmjxv3rzUDAxFzoUXXmjyv/71r5jPcdfhvfjii1M5JBQRpUuXNvmyyy4zedSoUSaXKVMmof3fdtttJn/44YcJPb+oOuSQQ0x2e/1c7nqVr732WtT9denSJer+KlSoYHKrVq1ijiee/sNEDBo0yGT65BN3xx13mPyPf/zD5Fhr08Zy5JFHRt1fPNwxunP3448/NvnHH39M+BjZwpVUAAAAeIciFQAAAN6hSAUAAIB3NFr/hKqmtkEmQaNHjzbZ/ezlEiVsjX3gwIGo+3M/w3rChAkR27zxxhtR9/HNN9+YvG/fvqjbJ+qwww4z2V1DLdG1XoMgSLzBpRCyPVcSdcwxx5j8zjvvmOz2n7n9pyIinTt3Ntn9rO6wycRcSXaeuGuGur+XLVu2mOz2ar3wwgvJHD6C+++1U6dOEdu4a1ueeeaZCR1j+/btJj/11FMmu32He/fuTWj/iQrLOcXt+fzggw+ibu/2Cj/33HNRt3fX6XbXpz3jjDOiPr+g3sNke1Ldfe7cudPk4cOHm/zMM8+YnOr1NcNwTonFfQ0+6KCD3OOb/PTTT5vsrt3csGHDqMdLtse1IG4P6sMPP2zygw8+mPQxkhFtnnAlFQAAAN6hSAUAAIB3KFIBAADgHa97UqtVq2bysGHDTHb7v4444oikjxmrH+STTz4xOScnx2S3pzXW/sqXL2/yAw88EHX/J510UuSgowhL/1i6uT9ndy3J4447zuS5c+eafOWVV0bs012XN+zC2D/m9lLdcMMNUbdfvny5ye56grH06NHD5EMPPdTk6tWrJ7S/grh977fccovJ06dPT/oYyQjLOcX9N//ll1+afPjhh5vsrnvqrqtdqpRdVvy8884z2V3X210X1ZWJntRY+3v//fdN7tChQ1LHd4XxnOJasGCBye5rhfszf+KJJ0zu1q2bye45w/X222+bPHbs2Iht3HW93fdHuGstu2NcuHChySeccELUMaUbPakAAAAIFYpUAAAAeIciFQAAAN7xuic1lkqVKpncq1cvk90e1WuuucbkqlWrRuwz1WuUpXp/bl9ULGHpH0s3d03c888/3+T9+/eb7PY2umv2FkVFoX/soYceMrlfv34mly5dOp2HL5C7frO7FqU75kcffTTq87MtrOeUmTNnmnzaaaeZ7K5Hu2fPHpPd9TGTPZdv3rw54r7HH3/cZLd30P2c92+//dZkd21Yd/vZs2eb/NZbb5m8YsWKKCNOXFE4p7h96BMnTnSPb3Ki82LWrFkmd+3a1eRdu3bF3EeDBg1MHjdunMktWrQwedKkSSa7awRnGj2pAAAACBWKVAAAAHiHIhUAAADeCXVPaqIOPvhgk9u1axexzSmnnGKy+xnvbu9H/fr1ox4z2X4Vt7fkqquuSuj5Ye0fS5a7nuybb75psrsG76hRo0zu379/OobltaLQP+Zy130cM2aMyXXr1jU50Z5vt49x2rRpEdu4/dDuWsphE9ZzStu2bU2eOnWqye57HAoYj8mxzuVuT6vb/9mnT5+I5/z8889R9xk2ReGc4q63+9xzz5ns9qy68+K7774zefjw4SaPHz/e5F9++aVQ44zGrWOWLFmS8mMkg55UAAAAhApFKgAAALxDkQoAAADvFKue1FRw+1rd7Orbt29C+1+3bp3JDz/8sMm7d+9OaH9h7R9LlNsr7H7+ubs+5kcffWTy2WefbXJR6w2LR1HoH0uU27tcs2bNhJ7vrju5c+fOpMfku6JyTjnjjDNMvuKKK0w+77zz3PGY7K41+dprr5m8bNkykxcvXlyYYYZacTynIHH0pAIAACBUKFIBAADgHYpUAAAAeIee1CKuqPSPuQ499FCTZ8yYYXKzZs1M3rFjh8knn3yyyb6tG5cN9I8hHkX1nILU45yCeNCTCgAAgFChSAUAAIB3KFIBAADgncQ+qBrwhLumoduD6jr22GNNzsnJSfWQAABACnElFQAAAN6hSAUAAIB3KFIBAADgHXpSEQrNmzc3+aabboq6/dixY01et25dqocEAADSiCupAAAA8A5FKgAAALxDkQoAAADvaBD8/sfe8pm44cfnbCNefM424sE5BfHinIJ4RJsnXEkFAACAdyhSAQAA4B2KVAAAAHgnak8qAAAAkA1cSQUAAIB3KFIBAADgHYpUAAAAeIciFQAAAN6hSAUAAIB3KFIBAADgnVAUqaqao6qd4tguUNWGhTxGoZ8LfzBXEA/mCeLFXEE8mCfpEYoi1VeqOkRV96rq9ny3I7I9LvhLVcuo6jJVXZPtscA/qlpNVZ9T1Z/ybkOyPSb4SVXLqupoVf1RVTer6lRVrZPtccEvqtpBVeeo6s+qmpPt8SSKIjV5E4MgqJTvtirbA4LXBorIT9keBLw1UkQqiEgDEWkpIher6mVZHRF8daOInCwiTUWktohsFZF/Z3NA8NIOERkrua89oROqIlVVW6rqJ6q6VVXXq+qjqlrG2ewsVV2lqhtVdYSqlsj3/MtV9WtV3aKqb6lq/Qx/C8gQH+eKqh4uIheJyD+T3RdSw8N5co6IPBAEwa9BEOSIyBgRuTzJfSIFPJwrh4vIW0EQ/BgEwS4RmSAiTZLcJ5Lk2zwJgmB+EATPi0goL6CFqkgVkf0icpOI1JTcvyA7ish1zjbdRaSFiBwvIl0l7wSvqt1E5HYROVdEDhaRD0RkfEEHUdXb8iZYgTdn83Py/qtlqar2TcU3iZTwca78O2+/O5P/9pAiPs4Tdb4+pvDfHlLIt7kyRkRaq2ptVa0gIr1FZEZKvlMkw7d5Em5BEHh/E5EcEelUwP39RWRyvhyIyJn58nUiMjvv6xkickW+x0qIyK8iUj/fcxsmOK7GkvvfLCVFpJWIrBeRC7P98yrON4/nSncRmZn3dXsRWZPtn1Vxvnk8T14QkddEpLKINBSRlSKyO9s/r+J883iuVJHcAiYQkX0i8l8RqZHtn1dxvfk6T/Ltq5OI5GT755ToLVRXUlW1kapOU9UfVPUXERkmuX+t5Lc639ffSW4RKSJSX0QeyfdXxmbJvUpR6EbzIAi+CoJgXRAE+4Mg+FhEHhGRHoXdH1LHp7miqhVF5AERuaEwz0f6+DRP8vxNcq+0fysir0tuEcKb7Dzg4Vx5QkTKichBIlJRcv+44Upqlnk4T0ItVEWq5P6jXCYiRwZBUEVyL4urs81h+b6uJyLr8r5eLSLXBEFQLd+tfF5xaajq7WrfsW9uUcYXFDAeZIdPc+VIyX0jzAeq+oPkvpjUyjuJNUjVN4xC8WmeSBAEm4Mg6B0EwaFBEDSR3HP0/BR+vyg8r+aKiDQTkWfz5sxuyW0naqmqbkGEzPJtnoRa2IrUyiLyi4hsV9WjRKSgHtCBqlpdVQ+T3Hc/Tsy7f7SIDFLVJiIiqlpVVXsWdJAgCIYF9h375vbbdqraNe9YqqotJfcqyOup+3aRBJ/myhLJPSk1z7tdKSI/5n29uoDdInN8mieiqn9U1YNUtaSqdhaRq0Xk3tR9u0iCV3NFRD4TkUvy9lVacv/beF0QBBtT8+2ikLyaJ6paQlXLiUjp3KjlNPKNXN4KW5F6i4j8VUS2icjT8v+/2PxeF5GFIrJYRKZLbnO5BEEwWUSGi8iEvEvwS0Skc5Lj6SUiK/LG8x8RGR4EwXNJ7hOp4c1cCYJgXxAEP/x2k9z/wjmQl/cXdr9ICW/mSZ4/i8iXeeP5p4j0DoJgaZL7RGr4NlduEZFdktsaskFEzpLc3ndkl2/zpK3kthC9KblXbXeKyNtJ7jNjNMhtqAUAAAC8EbYrqQAAACgGKFIBAADgHYpUAAAAeIciFQAAAN4pFe1BVeVdVSEXBEFG1m1lroRfJuYK8yT8OKcgXpxTEI9o84QrqQAAAPAORSoAAAC8Q5EKAAAA71CkAgAAwDsUqQAAAPAORSoAAAC8Q5EKAAAA71CkAgAAwDsUqQAAAPAORSoAAAC8Q5EKAAAA71CkAgAAwDsUqQAAAPAORSoAAAC8Q5EKAAAA71CkAgAAwDsUqQAAAPAORSoAAAC8Q5EKAAAA75TK9gCiGTdunMknn3yyyR9++KHJS5YsMfmzzz4zOScnx+R9+/ZFHLNVq1Ymn3rqqSYPHjzY5M2bN0fsAwBQtFWqVMnkOnXqmDxp0iSTmzRpYvLixYsj9vn+++9H3WbevHkmL1u2LJ6hophr0aKFyW5tdODAAZP/+te/mjxx4sT0DCwOXEkFAACAdyhSAQAA4B2KVAAAAHhHgyD4/QdVf//BNGjTpo3Js2bNMrl06dImq6rJ0b4XEZENGzaYvH///ohtatWqFXWf3bt3N/mNN96IesxsC4JAY2+VvEzPlTBasWKFyQ888IDJTz31VCaHEyETc6WozZNy5cqZ3Lp164ht3PPa4YcfbvJZZ51l8tKlS012+xLvvvtuk7dt2xbXWFOFc0qul156yeQLLrgg6X3Gek37/vvvTR41apTJI0eOTHoMqcQ5JTvatWtn8tixY01u0KCByW5PqvtadfTRR6ducAWINk+4kgoAAADvUKQCAADAOxSpAAAA8I5XPamu+++/3+SBAwea7K5R+sEHH5js9pcecsghJrv9PSIiW7ZsMfmLL74w+YknnjD5hx9+iNiHT+gfy3XppZeavHbtWpPfeeedlB/TXZvu008/NfmFF14wuU+fPikfQyLoH4vkrn15xx13mNy1a1eT3XNOOsycOdPk888/3+Tt27en9fjF5ZxSvnx5k59//nmTzzjjDJMrVKiQ9DETfZ+F+76K9957z+SLLrrI5J9++qnwgysEzinpcfDBB5vcuHFjkydMmGByzZo1Td65c6fJ7uuh+3rpvnalGj2pAAAACBWKVAAAAHiHIhUAAADeKZXtAUQzefJkk92eVPdzi88991yTy5QpEzXv3r074ph79+5NeJzwT8+ePU121yA99dRT0z6GE044wWS336ygnmhkVokS9u/0k08+2WR3HeTq1atH3V9B5xT389hXrlyZyBAj1t90e50rVqxocrp7UouLKlWqmOyukZ2sTZs2Rdznzh93/rmvcX/4wx9M7tixo8m9e/c22bd1VFE4bdu2NfnJJ580uWrVqlGf/91335ns9i5//vnnSYwutbiSCgAAAO9QpAIAAMA7FKkAAADwjtc9qccee2xSz9+zZ0/UjKLDXaPQ/XzzHTt2mJyTk5PuIUmPHj1MXr9+vclunywyz12L+ZZbbkno+dOmTTP5yiuvjNgm2bUpBw0aZLLbqz9//nyTO3ToYPKqVauSOn5xdeaZZ0Z9/OOPPzb5scceM7l58+YmL1682GR37ohErsv7zTffmHz99deb/O2335p8xBFHmOyu5froo4+azHswwsntRS5dunRCz3fXVW3Tpo3J9KQCAAAAUVCkAgAAwDsUqQAAAPCO1z2py5cvN9ldZ9JVt25dk+vXr2+yu25l2bJlI/YxY8YMk7/44ouY40T23XjjjSY3adLEZHcduTVr1qR8DC1btjT5lFNOMdld13f16tUpHwOsQw45xORRo0aZfN5550V9/tatW03+y1/+YvInn3xisvtZ6oXhrnF43XXXmdy+ffuoz3/mmWdMzsSawEVRo0aNoj7urnPsfl66m+Ph9qDG8vLLL5t82223JXxM+O2uu+6KuG/IkCFJ7fO+++4z2e2n9glXUgEAAOAdilQAAAB4hyIVAAAA3vG6J/Wss84yOQgCk9015WbPnm1yw4YNEz6mu77mww8/bLK7RqG79t2BAwcSPiaS5/YebtmyxeR///vfaR9DuXLlTC5Vyut/XsVCr169TO7Zs2fU7X/88UeTTzrpJJPdz7xOhWbNmpnsfr56rB5U9/Pe//Wvf6ViWIjB/fddsWJFk921meNRooS9buT2xd5zzz0mx+qt//nnn03et29fwmNCZvXt29fkgvpPY9UZc+fONdmtW3zuQXVxJRUAAADeoUgFAACAdyhSAQAA4B2vm+Y6deoU9fEGDRqY7Pasrlu3zuRZs2aZvGTJkoh9uusg3n777VGzuy7diBEjfn/ASJkyZcqY3LVrV5Nfeuklk7/66qu0j+m0006L+rjbw+b20br9kEjeBRdcEPVxdx3UHj16mJyOHtRLLrnE5AEDBph87LHHJrS//v37m/zGG28Ualyw3NcL91zvzpVu3bqZ3KdPH5PdPsGCXt/ctb7vvffeuMb6m19//dVkd21xt29+586dCe0fqVetWjWTY/XNx8OtjebNm5f0PrOFK6kAAADwDkUqAAAAvEORCgAAAO943ZNar169hLZ3+3cefPBBk7dt2xZzH+5ne3fs2NHkiRMnmux+Bq67buo777wT85hIXNOmTU2uX7++yW6vYTqULl3aZHdNTddRRx1l8owZM0w+/vjjUzOwYqx8+fImV6pUKer2OTk5Jn/00UcJHa958+Ymu73RIpG9i+7al+48imXTpk0mP/PMMwk9H/H573//a7L7HofatWub7Pacv/jiiya7a5q6/acikT2k7vssYpkyZYrJF198cULPR/q5Pelnn322yaecckrC+3Rf7wYOHGjywoULE96nL7iSCgAAAO9QpAIAAMA7FKkAAADwjtc9qdOnTzfZ7eW4/PLLTXb7RQuzBtzevXtNnjlzpsnnn3++ya+99prJzz33nMnu524vX7484TEhktur5X6W8TnnnGPyo48+anKsNUnLli1rcps2bSK2ueaaa0x2+5fdMbm/+7Zt20YdAxK3f/9+k2N9VvnRRx9t8jfffJPQ8dxeaHf93nR44YUXTHa/Z6SG2+fXpUsXk93XBnfdY1dBPajJmj17tsnu577DP7t37za5e/fuSe9z48aNJk+ePDnpffqCK6kAAADwDkUqAAAAvEORCgAAAO9QpAIAAMA7Gm2xYFVNbCXhFKtWrZrJboPxuHHjMjiago0ePdrkq6++2mT3jVXu4sqFeXNXIoIg0NhbJS/bc8V9A0GHDh1MXrFihcmvv/66ye5ixzfffLPJLVq0SHhMkyZNMtl9051vMjFXMj1PhgwZYvLdd9+dycOLSPILtLtz1/3Qh+3btxduYIVUXM4psfzlL38x2T3XlyiR+DUg9wNnKleuHHV79/XDXRh+zpw5CY8hlYriOSVRjRs3Nvmtt94y2f1QCFdB82jJkiUmn3766SavX78+kSFmXbR5wpVUAAAAeIciFQAAAN6hSAUAAIB3vO5JDQN3AXe3T9ZdwLl58+Ymf/HFF2kZ12+KS/9Y06ZNTX7yySdNPvHEE6M+3+0b/PTTT00uaHHkM8880+R27dqZPGDAAJNHjhwZdQzZVhT7x0qVsp9Xcuutt5rs9i67/Z6LFi0yecOGDSa7C7oXZM2aNSZPmTLF5IoVK5q8adMmk0877TSTFy9eHPOY6VRczimxuB/G4fYaxvpgh8cffzziPve8dfjhh5s8atQok+vVq2eye95q1apV1DGkW1E8p8Ti9qC6H75x7LHHJrS/ZcuWRdzXu3dvk9NdR6QbPakAAAAIFYpUAAAAeIciFQAAAN4pFXsTRPPtt9+a/MEHH5h84YUXRs1h7yXxhftzvPbaa02+4YYbTF69erXJq1atMvmll14yef/+/RHHdHtQXfPnz4/6ONJv3759Jt93331R88EHH2yy24NaGO76mW4Pqsvtc812DyoKVqdOHZNj9aA++uijJg8cODBimz179pjsrof59ddfm/z555+b7PY7uq8348ePjzpGJK9+/fomJ9qDum7dOpPdNVBFwrcOajK4kgoAAADvUKQCAADAOxSpAAAA8E6oe1Jr1aplsru+oNvfkw7ff/991IzscHu1rrzyyqT216xZs4j7OnXqZPJHH31k8ieffJLUMZF5yfag3nTTTRH3devWLepz3H7q/v37JzUGZMZ1110X9fFnnnnG5JtvvtnkgvrcY1mxYoXJbp/rLbfcYvLZZ59tMj2p6XfOOecktL37WnXRRReZXJz6TwvClVQAAAB4hyIVAAAA3qFIBQAAgHdC1ZPaqFEjk+fOnWvy2rVrTXZ7uz788MO0jCs/d4woGtzP0BaJ/Fz4BQsWmHzgwIG0jgnZV758eZP79OmT8D7++c9/muz21sNPGzdujPq4u05yYXpQY3HXAXa5r0cVKlQw+ddff035mIq7a665xuRYrwPvv/++ycuWLUv5mMKMK6kAAADwDkUqAAAAvEORCgAAAO+Eqif1sssuM/nQQw81ee/evZkcToGWL19usqqa/Kc//SmTw0EGub97FH133nmnyU2bNo35nAkTJpj86quvpnRMyIyXX37Z5L/85S8mZ+L9Ce4x3H5md91eelCT97e//c3kkSNHmlyiRPRrfy+99JLJ7vq5sLiSCgAAAO9QpAIAAMA7FKkAAADwTqh6UpcsWWJyEAQmu2sWnnzyySYvXLjQ5J07dyY9ptq1a5t85plnmuyOkf6zcGrdunXMbb766qsMjATZ1LJlS5Pj6Sdz17K87777oj6OcHB/b+65/uKLLzZ53LhxJq9ZsyZin1WqVDHZXXfX/Vz3o446yuRPP/3U5EysDV7cuL/nWOuguo8PGTIk1UMq0riSCgAAAO9QpAIAAMA7FKkAAADwTqh6Ul988UWT27dvb/Lll19u8v3332/yJZdcYvKoUaMijvH0009HHUOtWrWi7sNdJ/Hrr782+fXXX4+6f/jpo48+irhvwIABJrvr9qLoef75500uW7Zsws9ZunRpSseE7HjllVdM7tatm8m9evUyOZ7fu7vGZqx+R6RfnTp1TL766qujbr9gwQKT+/bta/L333+fmoEVE1xJBQAAgHcoUgEAAOAdilQAAAB4J1Q9qa7rrrvOZHcN0qeeesrkxo0bmzx69OiIfQ4bNsxkd020MmXKmFy5cmWTt27darK7zt327dsjjomiwV1Dc+LEiVkaCVLl2muvNblhw4YJ7+Oaa65J1XDgscGDB5vcqlUrk+vVqxdzH+7rTaJWrlyZ1PMRae3atSa7dcXDDz9scs2aNU2uUKGCyXv37k3h6Io+rqQCAADAOxSpAAAA8A5FKgAAALyj0XpgVDW5Bpksq1GjhslDhw41+bzzzot4jrvWZaweoc8++8xk97O8P/7445jjTKcgCDQTxwn7XInl+OOPj7jP/d3OmzfPZHcdX99lYq74Pk+qV69usrumYcWKFaM+/4UXXoi4z12fOew4p8TnqKOOMvmtt94yuW7duhHPUbU/Wnf+TZ482WT3HDRjxgyTs/0eiKJwTnnmmWdMPvHEE01+6aWXTH777bdNXrhwYXoGVoREmydcSQUAAIB3KFIBAADgHYpUAAAAeKdI96SC/rF0GjJkiMnr1q0z2V1Pz3dFoX8sWaVLlzZ57NixJvfu3dvke+65x+RHHnkkYp9btmxJ0ej8wDkF8eKcgnjQkwoAAIBQoUgFAACAdyhSAQAA4B16Uos4+scQL/rHEA/OKYgX5xTEg55UAAAAhApFKgAAALxDkQoAAADvUKQCAADAOxSpAAAA8A5FKgAAALxDkQoAAADvRF0nFQAAAMgGrqQCAADAOxSpAAAA8A5FKgAAALxDkQoAAADvUKQCAADAOxSpAAAA8E4oilRVzVHVTnFsF6hqw0Ieo9DPhT+YK4gH8wTxYq4gHsyT9AhFkeorzTVcVTfl3R5QVc32uOAfVR2oqktUdZuq/k9VB2Z7TPCPqnZQ1Tmq+rOq5mR7PPCXqg5R1b2quj3f7Yhsjwt+Cfs8oUhNztUi0k1EmolIUxE5W0SuyeaA4C0VkUtEpLqInCki/VS1V3aHBA/tEJGxIsIfMYjHxCAIKuW7rcr2gOCl0M6TUBWpqtpSVT9R1a2qul5VH1XVMs5mZ6nqKlXdqKojVLVEvudfrqpfq+oWVX1LVesnOaQ+IvJQEARrgiBYKyIPicilSe4TKeDbXAmC4IEgCBYFQbAvCIJvROR1EWmdzD6RPA/nyfwgCJ4XkdC8iBQXvs0V+Il5klqhKlJFZL+I3CQiNUXkZBHpKCLXOdt0F5EWInK8iHQVkctFRFS1m4jcLiLnisjBIvKBiIwv6CCqelveBCvwlm/TJiLyeb78ed59yD7f5kr+56iInCIiS5P6DpEK3s4TeMfHuXKOqm5W1aWq2jcV3ySSxjxJpSAIvL+JSI6IdCrg/v4iMjlfDkTkzHz5OhGZnff1DBG5It9jJUTkVxGpn++5DRMc134ROSpfPjJvP5rtn1lxvfk6V5yxDJXcP2jKZvvnVVxvvs8TEekkIjnZ/jlx83euiEhjEaktIiVFpJWIrBeRC7P98yquN+ZJem6hupKqqo1UdZqq/qCqv4jIMMn9ayW/1fm+/k5yfzkiIvVF5JF8f2Vsltw+wTpJDGm7iFTJl6uIyPYgb2YgezycK7+Nq5/k9qZ2CYJgd7L7Q3J8nSfwj29zJQiCr4IgWBcEwf4gCD4WkUdEpEdh94fUYJ6kVqiKVBF5QkSWiciRQRBUkdzL4u676Q/L93U9EVmX9/VqEbkmCIJq+W7l835phqrervadcOaWb9Olkvumqd80E/4L1xe+zRVR1ctF5DYR6RgEwZoUfZ9IjnfzBN7yfa4EBYwHmcc8SaGwFamVReQXEdmuqkeJSEG9FQNVtbqqHiYiN4rIxLz7R4vIIFVtIiKiqlVVtWdBBwmCYFhg3wlnbvk2/Y+I3KyqdVS1togMEJFnU/KdIllezRVV7S25f1GfFoTonZXFgG/zpISqlhOR0rlRy2nkmy6QHb7Nla55x1JVbSkif5PcN2Qiu5gnKRS2IvUWEfmriGwTkafl/3+x+b0uIgtFZLGITBeRMSIiQRBMFpHhIjIh7xL8EhHpnOR4nhSRqSLyZd7+pufdh+zzba7cKyIHichn+f7aHZ3kPpE83+ZJWxHZKSJvSu4Vlp0i8naS+0Rq+DZXeonIirzx/EdEhgdB8FyS+0TymCcppLRPAgAAwDdhu5IKAACAYoAiFQAAAN6hSAUAAIB3KFIBAADgnVLRHlRV3lUVckEQZGQ9NOZK+GVirjBPwo9zCuLFOQXxiDZPuJIKAAAA71CkAgAAwDsUqQAAAPAORSoAAAC8Q5EKAAAA71CkAgAAwDsUqQAAAPAORSoAAAC8Q5EKAAAA71CkAgAAwDsUqQAAAPBOqWwPAAAAoCgqUcJeCyxdurTJBw4cMLlcuXJRny8iUrJkyaTGtG3bNpP37t2b1P7SiSupAAAA8A5FKgAAALxDkQoAAADv0JOKYqlChQomn3baaSa3bds24X0uWrTI5Dlz5pi8bt26hPcJuHP1zjvvNPlPf/qTyeedd17axwQgPpdccknU7PaHNm3a1OQqVapE7LNatWpJjWnmzJkmu69V//vf/6Juv2PHjqSOnwiupAIAAMA7FKkAAADwDkUqAAAAvKNBEPz+g6q//2Ax1b59e5MHDx4c9fEOHTqY/N5776VhVL8vCALNxHF8nyuVK1c2edy4cSZ369bNZFX7Y4v27+T3bNiwweRjjjnG5E2bNiW8z3TKxFzxfZ74wJ0nEydONHnChAkmjxo1yuSff/45PQPLwzmlcNy1LUuVinxLSKtWrUyuX79+1H3WqlXL5PXr10fd3l2js1+/fibXrFnT5I0bN5p8wQUXmLxs2bKoxyuO5xR3XdPnn3/e5F69emVyOCkxbdo0k7t27ZrS/UebJ1xJBQAAgHcoUgEAAOAdilQAAAB4h3VSHYn2nMbirj+W7R7V4qpJkyYmuz2orldeecXkgtY4ffnll00+7rjjTH700Uej7rNnz54m+9ajisxwe1Dfeustk59++mmTH3vsMZPT3YOK+Li9iO65/rbbbjO5Xr16Efto0KCByW4PaWF64/NLtNf+0EMPNfnYY481OVZPanHUqFEjk8PYg+ravn171o7NlVQAAAB4hyIVAAAA3qFIBQAAgHeKVU/qkCFDTHb7TTPBPSY9qZnxww8/mPz++++b7Pb9jR8/PuFjzJs3z+QDBw6Y7PYSnn/++SY/8cQTCR8T4eP2HU6fPt3kkSNHmvzwww+b7M4rZIfbU/rnP//Z5NGjR5vsrkGaCV9++aXJ7txxe1Ld91B89NFHJr/xxhspHF3R5L7XIBPctZT3799vsnuOqVKlStT9ffrppyb/9NNPSYwuOVxJBQAAgHcoUgEAAOAdilQAAAB4p0j3pLr9NYmucZoJ7pjoUU2PnJwck0899dS0H/PNN9802V2jsHPnzia7fbH79u1Lz8CQVSNGjDD5+++/N/nBBx/M5HAQp4YNG5o8aNAgky+77DKTY61B6v57FxH5+uuvTXZ7A9evXx9znPmtXr3aZPqZ/eP2i7777rsmv/rqqzH3sWbNGpOTXU/XJ1xJBQAAgHcoUgEAAOAdilQAAAB4p0j1pLr9nZnoQR06dGjUx9u1a2dyrDG6+3PXdkV4uX1C1atXN9n9nG56UouGK664wuRmzZqZfMIJJ2RyOCik8847z+RLL7006vYrV640edKkSSbPnDkz4jnu+s0In5tuuinq42PHjjW5b9++JnPet7iSCgAAAO9QpAIAAMA7FKkAAADwTqh7Ut1+zcGDByf0/IL6Sd19utldxzTRdU2L0vpliK5Xr15RH1+7dq3JO3fuTOdwkAFdu3aNuO+ee+4x+d577zX5559/TuuYUDglS5Y0uXv37ia76x5PmTLF5HPPPTct44Lf3PcWuOrWrWvySSedZPLSpUtN3rJlS2oGFlJcSQUAAIB3KFIBAADgHYpUAAAAeEej9UiqqlcNlMn2oHbo0MHkRPtJUyHR78Hte0pUEATJ7SBOvs0VH7i9hhUrVjTZXTcxVg9rumVirhT1ebJgwYKI+2rUqGHyEUcckanhpEVxOadUqlTJ5Llz55rcvHlzk/v162fyE088kZZxhUlxPKdMmDDB5J49eyb0/DVr1pg8Z86cmM/58MMPTXZrmxUrViQ0hkyLNk+4kgoAAADvUKQCAADAOxSpAAAA8E6o1klt165dQtsnu6ZpOrg9qe731L59e5PdfhS3rxbZU6FCBZMnTpxocuXKlU12+7/dHjeEz913321ys2bNIrZxexURDtu3bzd58eLFJrs9qaNGjTL51FNPNfnVV181uaD+Zd97BxHbzJkzTW7durXJtWvXjvp8dx3Viy++OOYx3W1+/fVXk999912Tn3nmGZOnTp0a8xjZwpVUAAAAeIciFQAAAN6hSAUAAIB3vF4nNdHPufdhHdRYYvWcuoYOHWqy29MaS3FZ09Dt86lXr57JXbp0ifr8H374weSFCxfGPOYtt9xicrdu3Ux217h15+P5559v8qZNm2IeM52K45qGiSpXrpzJH330kcnVq1ePeI7bu/jLL7+kfFyZVFzOKa7OnTubPG3aNJPdf++xXr927doVcZ/b9zp//nyTH3vsMZN972HlnCLSoEEDk6+44gqTjz32WJPbtGljckHnlGS5a3i/9tprJt94440m79ixI+VjyI91UgEAABAqFKkAAADwDkUqAAAAvOPVOqmJ9lu6fOxBTVZR/J7i4fbxtW3b1uQLLrjA5Fq1aplcv359kxPtb060vywes2bNMjnbPahI3L333mvycccdZ/Ktt94a8Zyw96Ai1zvvvGNykyZNTJ49e7bJhx56aNT9uf3NIiInnXSSySeffLLJXbt2Ndmdf26vIbIvJyfH5Lvuuivq9m4P6wknnBCxjdsf7c6Dpk2bRj1G1apVTb7sssuiPn7RRReZvHv37qj7TyWupAIAAMA7FKkAAADwDkUqAAAAvEORCgAAAO9kdTH/RBe2d4Vh8X6X+z26PwNXst9jWBfedt8AULFixUTHY/K6detMnjBhgslXXXWVyZUrVza5MG+ccsfgNpu7C8Gfe+65Jm/bti3hYyaDhbcjlS5d2uQlS5aY7H5oxNFHHx2xD/eNE2EX1nNKprmLsru5UaNGEc85/vjjTW7WrJnJ7nno1VdfNblPnz4m//rrr/ENNk04p2RGpUqVTHbftHf33Xeb3Lt374T2X7NmTZO3bNmS0PNjYTF/AAAAhApFKgAAALxDkQoAAADvZHQx/+LQg5rs9+j2MRZXVapUMfnAgQMmu71W7kL57qLrbj/osGHDTHZ7UN3fw549eyLG+OGHH5r85ptvRmyTX5cuXUx2P7Bg69atJrs9qu5i4tnuNysOnnjiCZMbNmxo8t/+9jeTi1r/KQrPPT+4OR6nn366yc8//7zJ7utNhQoVTOYcUTy4r0+rV682+frrrze5du3aJru1lU+4kgoAAADvUKQCAADAOxSpAAAA8E5We1JjGTp0qMlFsQfV516QbHJ7UN31AadPn27yiBEjTL711ltNPvHEE02uU6dO1P2762H2798/YoyJ/q5Hjhxpsvu7f+aZZ0x+7bXXoubLL7/c5Eyvq1oUVatWzWT3Z7xy5UqTn3766XQPCcWY+5q3adMmkw8++OAMjqZocN/vsHPnTpNr1aqV0uO5fcJuf2gquOvruuumHnPMMSk/ZqZwJRUAAADeoUgFAACAdyhSAQAA4J2M9qS2a9cuoe196EF1e04HDx4c9XGX21c7ZMiQFIyq6Pv2229NdtenPPXUU03+y1/+YnLZsmVNdntOXW6P67XXXmvy+vXroz6/MNyeVvezvdesWWNy9+7do+7Pt8/tDqNYP+P777/f5ILWzwVSxf03/ac//cnkzZs3Z3I4RcIDDzxgsrtmqLueNbKLK6kAAADwDkUqAAAAvEORCgAAAO94vU5qrO1j9ay6z4/VX1oY7hhY9zQ1OnXqZPLw4cNNvuCCCxLan9tz+o9//MPk//73vybv378/of2ngtv32qRJE5MnT55ssts/6fbhun26iOSuiThs2DCTVdXkwnz+OvzUrVs3k93f7caNGzM4mlxXXnmlyU8++aTJbm/96NGjTc7GmMPmmmuuMdldkxsi69atM3nfvn1ZGglXUgEAAOAhilQAAAB4hyIVAAAA3tFo60eqavTFJZMUa+1KH7k9pz6s5RpNEAQae6vkpXuuuBo1amSyu7bdyJEjMzmcjOjXr5/JjzzySNTtS5YsmdD+MzFXMj1PYjnqqKNMXrp0qckff/yxyR07djS5OK6TWlTOKU2bNjXZXT/ztttuM3nx4sVJHa9y5com9+rVK2IbtwfV7Yl210U97rjjTP7++++TGWLK+XhOGTRokMn33ntvSscTBm7/9fvvv2/ymDFjTM7JyUnreKLNE66kAgAAwDsUqQAAAPAORSoAAAC8k9F1Ul1uP2ei66gme7y5c+fG3Mb3ntPiavny5VFzUeT2cIexpztsxo4da3Jx7EEtLk477TSTjzzySJOnTZtmstuv/Omnn5p8+umnmzxixAiTK1WqFDEG99/09u3bTb7kkktM9q0HNQzc3uOXX37ZZHf9XPf9D7FceumlJpcqlf4ya8GCBSa7/dNTp041+d133zX5119/Tcu4UoErqQAAAPAORSoAAAC8Q5EKAAAA72R1ndRYhgwZktbti4OisqYhItdZdPslu3fvbnKivVA+rmmYbrHWSX3rrbdM7tOnj8kbNmxIz8A8VlTOKTVq1DB55syZJv/5z39OaH/umqaF6Rl3e0zdtVTdvlffFcdzSrVq1Ux250U67Nq1y+SdO3em/ZipxDqpAAAACBWKVAAAAHiHIhUAAADe8bonFckrKv1jiFShQgWT3R67NWvWJLS/4tg/duihh5o8ceJEk8eNGxf18bD1fqVCUT2nuOukTpo0yeSC1jXNL9GeVHfdVZHINTa3bNkSdR++K47nFCSOnlQAAACECkUqAAAAvEORCgAAAO/Qk1rEFdX+MaQe/WOIR3E5p9StW9fkCy64wOR+/fqZ7PakLly40OQ33njD5Oeffz7imAcOHEh4nD7jnIJ40JMKAACAUKFIBQAAgHcoUgEAAOAdelKLuOLSP4bk0T+GeHBOQbw4pyAe9KQCAAAgVChSAQAA4B2KVAAAAHiHIhUAAADeoUgFAACAdyhSAQAA4B2KVAAAAHiHIhUAAADeoUgFAACAdyhSAQAA4B2KVAAAAHhHg4CPvQUAAIBfuJIKAAAA71CkAgAAwDsUqQAAAPAORSoAAAC8Q5EKAAAA71CkAgAAwDv/Bx6z9iwB7yj7AAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, axs = plt.subplots(5, 5, figsize=(12, 12))\n", "for i, ax in enumerate(axs.flatten()):\n", " ax.imshow(test_batch['image'][i, ..., 0], cmap='gray')\n", " ax.set_title(f\"label={pred[i]}\")\n", " ax.axis('off')" ] }, { "cell_type": "markdown", "id": "edb528b6", "metadata": {}, "source": [ "Congratulations! You made it to the end of the annotated MNIST example. You can revisit\n", "the same example, but structured differently as a couple of Python modules, test\n", "modules, config files, another Colab, and documentation in Flax's Git repo:\n", "\n", "[https://github.com/google/flax/tree/main/examples/mnist](https://github.com/google/flax/tree/main/examples/mnist)" ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst", "main_language": "python" }, "language_info": { "name": "python", "version": "3.9.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: docs/quick_start.md ================================================ --- jupytext: formats: ipynb,md:myst main_language: python text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/quick_start.ipynb) [![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/quick_start.ipynb) # Quick start Welcome to Flax! Flax is an open source Python neural network library built on top of [JAX](https://github.com/jax-ml/jax). This tutorial demonstrates how to construct a simple convolutional neural network (CNN) using the [Flax](https://flax.readthedocs.io) Linen API and train the network for image classification on the MNIST dataset. +++ ## 1. Install Flax ```{code-cell} :tags: [skip-execution] !pip install -q flax>=0.7.5 ``` ## 2. Loading data Flax can use any data-loading pipeline and this example demonstrates how to utilize TFDS. Define a function that loads and prepares the MNIST dataset and converts the samples to floating-point numbers. ```{code-cell} import tensorflow_datasets as tfds # TFDS for MNIST import tensorflow as tf # TensorFlow operations def get_datasets(num_epochs, batch_size): """Load MNIST train and test datasets into memory.""" train_ds = tfds.load('mnist', split='train') test_ds = tfds.load('mnist', split='test') train_ds = train_ds.map(lambda sample: {'image': tf.cast(sample['image'], tf.float32) / 255., 'label': sample['label']}) # normalize train set test_ds = test_ds.map(lambda sample: {'image': tf.cast(sample['image'], tf.float32) / 255., 'label': sample['label']}) # normalize test set train_ds = train_ds.repeat(num_epochs).shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency test_ds = test_ds.shuffle(1024) # create shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) # group into batches of batch_size and skip incomplete batch, prefetch the next sample to improve latency return train_ds, test_ds ``` ## 3. Define network Create a convolutional neural network with the Linen API by subclassing [Flax Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html). Because the architecture in this example is relatively simple—you're just stacking layers—you can define the inlined submodules directly within the `__call__` method and wrap it with the [`@compact`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/decorators.html#flax.linen.compact) decorator. To learn more about the Flax Linen `@compact` decorator, refer to the [`setup` vs `compact`](https://flax.readthedocs.io/en/latest/guides/setup_or_nncompact.html) guide. ```{code-cell} from flax import linen as nn # Linen API class CNN(nn.Module): """A simple CNN model.""" @nn.compact def __call__(self, x): x = nn.Conv(features=32, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = nn.Conv(features=64, kernel_size=(3, 3))(x) x = nn.relu(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) x = x.reshape((x.shape[0], -1)) # flatten x = nn.Dense(features=256)(x) x = nn.relu(x) x = nn.Dense(features=10)(x) return x ``` ### View model layers Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input. ```{code-cell} :outputId: 2c580f41-bf5d-40ec-f1cf-ab7f319a84da import jax import jax.numpy as jnp # JAX NumPy cnn = CNN() print(cnn.tabulate(jax.random.key(0), jnp.ones((1, 28, 28, 1)), compute_flops=True, compute_vjp_flops=True)) ``` ## 4. Create a `TrainState` A common pattern in Flax is to create a single dataclass that represents the entire training state, including step number, parameters, and optimizer state. Because this is such a common pattern, Flax provides the class [`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/flax.training.html#train-state) that serves most basic usecases. ```{code-cell} :outputId: 1249b7fb-6787-41eb-b34c-61d736300844 !pip install -q clu ``` ```{code-cell} from clu import metrics from flax.training import train_state # Useful dataclass to keep train state from flax import struct # Flax dataclasses import optax # Common loss functions and optimizers ``` We will be using the `clu` library for computing metrics. For more information on `clu`, refer to the [repo](https://github.com/google/CommonLoopUtils) and [notebook](https://colab.research.google.com/github/google/CommonLoopUtils/blob/master/clu_synopsis.ipynb#scrollTo=ueom-uBWLbeQ). ```{code-cell} @struct.dataclass class Metrics(metrics.Collection): accuracy: metrics.Accuracy loss: metrics.Average.from_output('loss') ``` You can then subclass `train_state.TrainState` so that it also contains metrics. This has the advantage that we only need to pass around a single argument to functions like `train_step()` (see below) to calculate the loss, update the parameters and compute the metrics all at once. ```{code-cell} class TrainState(train_state.TrainState): metrics: Metrics def create_train_state(module, rng, learning_rate, momentum): """Creates an initial `TrainState`.""" params = module.init(rng, jnp.ones([1, 28, 28, 1]))['params'] # initialize parameters by passing a template image tx = optax.sgd(learning_rate, momentum) return TrainState.create( apply_fn=module.apply, params=params, tx=tx, metrics=Metrics.empty()) ``` ## 5. Training step A function that: - Evaluates the neural network given the parameters and a batch of input images with [`TrainState.apply_fn`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState) (which contains the [`Module.apply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply) method (forward pass)). - Computes the cross entropy loss, using the predefined [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api.html#optax.softmax_cross_entropy_with_integer_labels). Note that this function expects integer labels, so there is no need to convert labels to onehot encoding. - Evaluates the gradient of the loss function using [`jax.grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad). - Applies a [pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions) of gradients to the optimizer to update the model's parameters. Use JAX's [@jit](https://jax.readthedocs.io/en/latest/jax.html#jax.jit) decorator to trace the entire `train_step` function and just-in-time compile it with [XLA](https://www.tensorflow.org/xla) into fused device operations that run faster and more efficiently on hardware accelerators. ```{code-cell} @jax.jit def train_step(state, batch): """Train for a single step.""" def loss_fn(params): logits = state.apply_fn({'params': params}, batch['image']) loss = optax.softmax_cross_entropy_with_integer_labels( logits=logits, labels=batch['label']).mean() return loss grad_fn = jax.grad(loss_fn) grads = grad_fn(state.params) state = state.apply_gradients(grads=grads) return state ``` ## 6. Metric computation Create a separate function for loss and accuracy metrics. Loss is calculated using the `optax.softmax_cross_entropy_with_integer_labels` function, while accuracy is calculated using `clu.metrics`. ```{code-cell} @jax.jit def compute_metrics(*, state, batch): logits = state.apply_fn({'params': state.params}, batch['image']) loss = optax.softmax_cross_entropy_with_integer_labels( logits=logits, labels=batch['label']).mean() metric_updates = state.metrics.single_from_model_output( logits=logits, labels=batch['label'], loss=loss) metrics = state.metrics.merge(metric_updates) state = state.replace(metrics=metrics) return state ``` ## 7. Download data ```{code-cell} num_epochs = 10 batch_size = 32 train_ds, test_ds = get_datasets(num_epochs, batch_size) ``` ## 8. Seed randomness - Set the TF random seed to ensure dataset shuffling (with `tf.data.Dataset.shuffle`) is reproducible. - Get one [PRNGKey](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.PRNGKey.html#jax.random.PRNGKey) and use it for parameter initialization. (Learn more about [JAX PRNG design](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html) and [PRNG chains](https://flax.readthedocs.io/en/latest/philosophy.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables).) ```{code-cell} tf.random.set_seed(0) ``` ```{code-cell} init_rng = jax.random.key(0) ``` ## 9. Initialize the `TrainState` Remember that the function `create_train_state` initializes the model parameters, optimizer and metrics and puts them into the training state dataclass that is returned. ```{code-cell} learning_rate = 0.01 momentum = 0.9 ``` ```{code-cell} state = create_train_state(cnn, init_rng, learning_rate, momentum) del init_rng # Must not be used anymore. ``` ## 10. Train and evaluate Create a "shuffled" dataset by: - Repeating the dataset equal to the number of training epochs - Allocating a buffer of size 1024 (containing the first 1024 samples in the dataset) of which to randomly sample batches from - Everytime a sample is randomly drawn from the buffer, the next sample in the dataset is loaded into the buffer Define a training loop that: - Randomly samples batches from the dataset. - Runs an optimization step for each training batch. - Computes the mean training metrics across each batch in an epoch. - Computes the metrics for the test set using the updated parameters. - Records the train and test metrics for visualization. Once the training and testing is done after 10 epochs, the output should show that your model was able to achieve approximately 99% accuracy. ```{code-cell} # since train_ds is replicated num_epochs times in get_datasets(), we divide by num_epochs num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs ``` ```{code-cell} metrics_history = {'train_loss': [], 'train_accuracy': [], 'test_loss': [], 'test_accuracy': []} ``` ```{code-cell} :outputId: 258a2c76-2c8f-4a9e-d48b-dde57c342a87 for step,batch in enumerate(train_ds.as_numpy_iterator()): # Run optimization steps over training batches and compute batch metrics state = train_step(state, batch) # get updated train state (which contains the updated parameters) state = compute_metrics(state=state, batch=batch) # aggregate batch metrics if (step+1) % num_steps_per_epoch == 0: # one training epoch has passed for metric,value in state.metrics.compute().items(): # compute metrics metrics_history[f'train_{metric}'].append(value) # record metrics state = state.replace(metrics=state.metrics.empty()) # reset train_metrics for next training epoch # Compute metrics on the test set after each training epoch test_state = state for test_batch in test_ds.as_numpy_iterator(): test_state = compute_metrics(state=test_state, batch=test_batch) for metric,value in test_state.metrics.compute().items(): metrics_history[f'test_{metric}'].append(value) print(f"train epoch: {(step+1) // num_steps_per_epoch}, " f"loss: {metrics_history['train_loss'][-1]}, " f"accuracy: {metrics_history['train_accuracy'][-1] * 100}") print(f"test epoch: {(step+1) // num_steps_per_epoch}, " f"loss: {metrics_history['test_loss'][-1]}, " f"accuracy: {metrics_history['test_accuracy'][-1] * 100}") ``` ## 11. Visualize metrics ```{code-cell} :outputId: 431a2fcd-44fa-4202-f55a-906555f060ac import matplotlib.pyplot as plt # Visualization # Plot loss and accuracy in subplots fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5)) ax1.set_title('Loss') ax2.set_title('Accuracy') for dataset in ('train','test'): ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss') ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy') ax1.legend() ax2.legend() plt.show() plt.clf() ``` ## 12. Perform inference on test set Define a jitted inference function `pred_step`. Use the learned parameters to do model inference on the test set and visualize the images and their corresponding predicted labels. ```{code-cell} @jax.jit def pred_step(state, batch): logits = state.apply_fn({'params': state.params}, test_batch['image']) return logits.argmax(axis=1) test_batch = test_ds.as_numpy_iterator().next() pred = pred_step(state, test_batch) ``` ```{code-cell} :outputId: 1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e fig, axs = plt.subplots(5, 5, figsize=(12, 12)) for i, ax in enumerate(axs.flatten()): ax.imshow(test_batch['image'][i, ..., 0], cmap='gray') ax.set_title(f"label={pred[i]}") ax.axis('off') ``` Congratulations! You made it to the end of the annotated MNIST example. You can revisit the same example, but structured differently as a couple of Python modules, test modules, config files, another Colab, and documentation in Flax's Git repo: [https://github.com/google/flax/tree/main/examples/mnist](https://github.com/google/flax/tree/main/examples/mnist) ================================================ FILE: docs/robots.txt ================================================ User-agent: * Disallow: /api_reference/flax.linen/_autosummary/ # for SEO, since Google still indexes this deprecated link Sitemap: https://flax.readthedocs.io/sitemap.xml ================================================ FILE: docs_nnx/.gitignore ================================================ _formatted_howtos ================================================ FILE: docs_nnx/.readthedocs.yaml ================================================ # .readthedocs.yml # Read the Docs configuration file # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details # Required version: 2 build: os: ubuntu-22.04 tools: python: "3.12" jobs: pre_build: - pip install ".[all, testing, docs]" - pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ # Build documentation in the docs/ directory with Sphinx sphinx: configuration: docs_nnx/conf.py # Optionally build your docs in additional formats such as PDF and ePub formats: - htmlzip - epub # - pdf ================================================ FILE: docs_nnx/Makefile ================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs_nnx/README.md ================================================ # Where to find the docs The FLAX documentation can be found here: https://flax.readthedocs.io/en/latest/ # How to build the docs 1. Clone the `flax` repository with `git clone https://github.com/google/flax.git`. 1. In the main `flax` folder, install the required dependencies using `uv pip install -e .[docs]`. 1. [Optional] If you need to make any local changes to the docs, create and switch to a branch. Make your changes to the docs in that branch. 1. To build the docs, in the `flax/docs_nnx` folder run the make script: `make html`. Alternatively, install [`entr`](https://github.com/eradman/entr/), which helps run arbitrary commands when files change. Then run `find ../ ! -regex '.*/[\.|\_].*' | entr -s 'make html'`. 1. If the build is successful, you should get the `The HTML pages are in _build/html.` message. You can preview the docs in `flax/docs/_build/html`. # How to run embedded code tests We use `doctest` blocks for embedded code in documents, that are also tested. Learn more at https://www.sphinx-doc.org/en/master/usage/extensions/doctest.html To run tests locally, run `make doctest` # How to write code documentation Our documentation is written in reStructuredText for Sphinx. It is a meta-language that is compiled into online documentation. For more details, check out [Sphinx's documentation](https://www.sphinx-doc.org/en/master/usage/restructuredtext/index.html). As a result, our docstrings adhere to a specific syntax that has to be kept in mind. Below we provide some guidelines. To learn how to contribute to Jupyter Notebooks or other formats in Flax docs, refer to the dedicated [Contributing](https://flax.readthedocs.io/en/latest/contributing.html) page. ## How much information to put in a docstring Docstring should be informative. We prefer to err on the side of too much documentation than too little. For instance, providing a one-line explanation to a new `Module` which implements new functionality is not sufficient. Furthermore, we highly encourage adding examples to your docstrings, so users can directly see how code can be used. ## How to write inline tested code We use [doctest](https://docs.python.org/3/library/doctest.html) syntax for writing examples in documentation. These examples are ran as tests as part of our CI process. In order to write `doctest` code in your documentation, please use the following notation: ```bash # Example code:: # # def sum(a, b): # return a + b # # sum(0, 1) ``` The `Example code` string at the beginning can be replaced by anything as long as there are two semicolons and a newline following it, and the code is indented. ## How to use "code font" When writing code font in a docstring, please use double backticks. Example: ```bash # This returns a ``str`` object. ``` Note that argument names and objects like True, None or any strings should usually be put in `code`. ## How to create cross-references/links It is possible to create cross-references to other classes, functions, and methods. In the following, `obj_typ` is either `class`, `func`, or `meth`. ```bash # First method: # :`path_to_obj` # Second method: # ::`description ` ``` You can use the second method if the `path_to_obj` is very long. Some examples: ```bash # Create: a reference to class flax.linen.Module. # :class:`flax.linen.Module` # Create a reference to local function my_func. # :func:`my_func` # Create a reference "Module.apply()" to method flax.linen.Module.apply. # :meth:`Module.apply() ` # ``` To creata a hyperlink, use the following syntax: ```bash # Note the double underscore at the end: # `Link to Google `__ ``` ### How to specify arguments for classes and methods * Class attributes should be specified using the `Attributes:` tag. * Method argument should be specified using the `Args:` tags. * All attributes and arguments should have types. Here is an example from our library: ```python class DenseGeneral(Module): """A linear transformation with flexible axes. Attributes: features: int or tuple with number of output features. axis: int or tuple with axes to apply the transformation on. For instance, (-2, -1) will apply the transformation to the last two axes. batch_dims: tuple with batch axes. use_bias: whether to add a bias to the output (default: True). dtype: the dtype of the computation (default: float32). kernel_init: initializer function for the weight matrix. bias_init: initializer function for the bias. precision: numerical precision of the computation see `jax.lax.Precision` for details. """ features: Union[int, Iterable[int]] axis: Union[int, Iterable[int]] = -1 batch_dims: Iterable[int] = () use_bias: bool = True dtype: Dtype = jnp.float32 kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros precision: Any = None @compact def __call__(self, inputs: Array) -> Array: """Applies a linear transformation to the inputs along multiple dimensions. Args: inputs: The nd-array to be transformed. Returns: The transformed input. """ ... ``` ================================================ FILE: docs_nnx/_ext/codediff.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Sphinx directive for creating code diff tables. Use directive as follows: .. codediff:: :title: , --- In order to highlight a line of code, append "#!" to it. """ import sphinx from docutils import nodes from docutils.parsers.rst import directives from docutils.statemachine import ViewList from sphinx.util.docutils import SphinxDirective MISSING = object() class CodeDiffParser: def parse( self, lines: list[str], title: str, groups: list[str] | None = None, skip_test: str | None = None, code_sep: str = '---', sync: object = MISSING, ): """Parse the code diff block and format it so that it renders in different tabs and is tested by doctest. For example: .. testcode:: tab0, tab2, tab3 .. codediff:: :title: Tab 0, Tab 1, Tab 2, Tab 3 :groups: tab0, tab1, tab2, tab3 :skip_test: tab1, tab3 --- --- --- For group tab0: and are executed. For group tab1: Nothing is executed. For group tab2: and are executed. For group tab3: is executed. Arguments: lines: a string list, where each element is a single string code line title: a single string that contains the titles of each tab (they should be separated by commas) groups: a single string that contains the group of each tab (they should be separated by commas). Code snippets that are part of the same group will be executed together. If groups=None, then the group names will default to the tab title names. skip_test: a single string denoting which group(s) to skip testing (they should be separated by commas). This is useful for legacy code snippets that no longer run correctly anymore. If skip_test=None, then no tests are skipped. code_sep: the separator character(s) used to denote a separate code block for a new tab. The default code separator is '---'. sync: an option for Sphinx directives, that will sync all tabs together. This means that if the user clicks to switch to another tab, all tabs will switch to the new tab. """ titles = [t.strip() for t in title.split(',')] num_tabs = len(titles) sync = sync is not MISSING # skip legacy code snippets in upgrade guides if skip_test is not None: skip_tests = {index.strip() for index in skip_test.split(',')} else: skip_tests = set() code_blocks = '\n'.join(lines) if code_blocks.count(code_sep) != num_tabs - 1: raise ValueError( f'Expected {num_tabs-1} code separator(s) for {num_tabs} tab(s), but got {code_blocks.count(code_sep)} code separator(s) instead.' ) code_blocks = [ code_block.split('\n') for code_block in code_blocks.split(code_sep + '\n') ] # list[code_tab_list1[string_line1, ...], ...] # by default, put each code snippet in a different group denoted by an index number, to be executed separately if groups is not None: groups = [group_name.strip() for group_name in groups.split(',')] else: groups = titles if len(groups) != num_tabs: raise ValueError( f'Expected {num_tabs} group assignment(s) for {num_tabs} tab(s), but got {len(groups)} group assignment(s) instead.' ) tabs = [] test_codes = [] for i, code_block in enumerate(code_blocks): if groups[i] not in skip_tests: test_codes.append((code_block, groups[i])) tabs.append((titles[i], self._code_block(code_block))) output = self._tabs(*tabs, sync=sync) return output, test_codes def _code_block(self, lines): """Creates a codeblock.""" # Remove right trailing whitespace so we can detect the comments. lines = [x.rstrip() for x in lines] highlight = lambda x: x.endswith('#!') code = map(lambda x: x[:-2].rstrip() if highlight(x) else x, lines) highlights = [i + 1 for i in range(len(lines)) if highlight(lines[i])] highlights = ','.join(str(i) for i in highlights) directive = ['.. code-block:: python'] if highlights: directive += [f' :emphasize-lines: {highlights}'] # Indent code and add empty line so the code is picked up by the directive. return directive + [''] + list(map(lambda x: ' ' + x, code)) def _tabs(self, *contents: tuple[str, list[str]], sync): output = ['.. tab-set::'] + [' '] for title, content in contents: output += [f' .. tab-item:: {title}'] if sync: key = title.strip() output += [f' :sync: {key}'] output += [' '] output += [' ' + line for line in content] return output class CodeDiffDirective(SphinxDirective): has_content = True option_spec = { 'title': directives.unchanged, 'groups': directives.unchanged, 'skip_test': directives.unchanged, 'code_sep': directives.unchanged, 'sync': directives.flag, } def run(self): table_code, test_codes = CodeDiffParser().parse( list(self.content), **self.options ) # Create a test node as a comment node so it won't show up in the docs. # We add attribute "testnodetype" so it is be picked up by the doctest # builder. This functionality is not officially documented but can be found # in the source code: # https://github.com/sphinx-doc/sphinx/blob/master/sphinx/ext/doctest.py # (search for 'testnodetype'). test_nodes = [] for test_code, group in test_codes: test_node = nodes.comment( '\n'.join(test_code), '\n'.join(test_code), testnodetype='testcode', groups=[group], ) self.set_source_info(test_node) test_node['options'] = {} test_node['language'] = 'python3' test_nodes.append(test_node) # The table node is the side-by-side diff view that will be shown on RTD. table_node = nodes.paragraph() self.content = ViewList(table_code, self.content.parent) self.state.nested_parse(self.content, self.content_offset, table_node) return [table_node] + test_nodes def setup(app): app.add_directive('codediff', CodeDiffDirective) return { 'version': sphinx.__display_version__, 'parallel_read_safe': True, 'parallel_write_safe': True, } ================================================ FILE: docs_nnx/_ext/codediff_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for codediff Sphinx extension.""" from absl.testing import parameterized from codediff import CodeDiffParser class CodeDiffTest(parameterized.TestCase): def test_parse(self): input_text = r"""@jax.jit #! def get_initial_params(key): #! init_val = jnp.ones((1, 28, 28, 1), jnp.float32) initial_params = CNN().init(key, init_val)['params'] extra_line return initial_params --- @jax.pmap #! def get_initial_params(key): init_val = jnp.ones((1, 28, 28, 1), jnp.float32) initial_params = CNN().init(key, init_val)['params'] return initial_params""" expected_table = """.. tab-set::\n \n .. tab-item:: Single device\n \n .. code-block:: python\n :emphasize-lines: 1,2\n \n @jax.jit\n def get_initial_params(key):\n init_val = jnp.ones((1, 28, 28, 1), jnp.float32)\n initial_params = CNN().init(key, init_val)['params']\n extra_line\n return initial_params\n \n .. tab-item:: Ensembling on multiple devices\n \n .. code-block:: python\n :emphasize-lines: 1\n \n @jax.pmap\n def get_initial_params(key):\n init_val = jnp.ones((1, 28, 28, 1), jnp.float32)\n initial_params = CNN().init(key, init_val)['params']\n return initial_params""" expected_testcodes = [ r"""@jax.jit #! def get_initial_params(key): #! init_val = jnp.ones((1, 28, 28, 1), jnp.float32) initial_params = CNN().init(key, init_val)['params'] extra_line return initial_params """, r"""@jax.pmap #! def get_initial_params(key): init_val = jnp.ones((1, 28, 28, 1), jnp.float32) initial_params = CNN().init(key, init_val)['params'] return initial_params""", ] title_left = 'Single device' title_right = 'Ensembling on multiple devices' actual_table, actual_testcodes = CodeDiffParser().parse( lines=input_text.split('\n'), title=f'{title_left}, {title_right}', ) actual_table = '\n'.join(actual_table) actual_testcodes = ['\n'.join(testcode) for testcode, _ in actual_testcodes] self.assertEqual(expected_table, actual_table) self.assertEqual(expected_testcodes[0], actual_testcodes[0]) self.assertEqual(expected_testcodes[1], actual_testcodes[1]) @parameterized.parameters( { 'input_text': r"""x = 1 --- x = 2 """, 'title': 'Tab 0, Tab1, Tab2', 'groups': None, 'error_msg': 'Expected 2 code separator\\(s\\) for 3 tab\\(s\\), but got 1 code separator\\(s\\) instead.', }, { 'input_text': r"""x = 1 --- x = 2 --- x = 3 --- x = 4 """, 'title': 'Tab 0, Tab1, Tab2', 'groups': None, 'error_msg': 'Expected 2 code separator\\(s\\) for 3 tab\\(s\\), but got 3 code separator\\(s\\) instead.', }, { 'input_text': r"""x = 1 --- x = 2 --- x = 3 """, 'title': 'Tab 0, Tab1, Tab2', 'groups': 'tab0, tab2', 'error_msg': 'Expected 3 group assignment\\(s\\) for 3 tab\\(s\\), but got 2 group assignment\\(s\\) instead.', }, { 'input_text': r"""x = 1 --- x = 2 --- x = 3 """, 'title': 'Tab 0, Tab1, Tab2', 'groups': 'tab0, tab1, tab2, tab3', 'error_msg': 'Expected 3 group assignment\\(s\\) for 3 tab\\(s\\), but got 4 group assignment\\(s\\) instead.', }, ) def test_parse_errors(self, input_text, title, groups, error_msg): with self.assertRaisesRegex(ValueError, error_msg): _, _ = CodeDiffParser().parse( lines=input_text.split('\n'), title=title, groups=groups, ) ================================================ FILE: docs_nnx/_ext/flax_module.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Sphinx directive for visualizing Flax modules. Use directive as follows: .. flax_module:: :module: flax.linen :class: Dense """ import importlib import sphinx import sphinx.ext.autosummary.generate as ag from docutils import nodes from docutils.parsers.rst import directives from docutils.statemachine import ViewList from sphinx.util.docutils import SphinxDirective from docs.conf_sphinx_patch import generate_autosummary_content def render_module(modname: str, qualname: str, app): parent = importlib.import_module(modname) obj = getattr(parent, qualname) template = ag.AutosummaryRenderer(app) template_name = 'flax_module' imported_members = False recursive = False context = {} return generate_autosummary_content( qualname, obj, parent, template, template_name, imported_members, app, recursive, context, modname, qualname, ) class FlaxModuleDirective(SphinxDirective): has_content = True option_spec = { 'module': directives.unchanged, 'class': directives.unchanged, } def run(self): module_template = render_module( self.options['module'], self.options['class'], self.env.app ) module_template = module_template.splitlines() # Create a container for the rendered nodes container_node = nodes.container() self.content = ViewList(module_template, self.content.parent) self.state.nested_parse(self.content, self.content_offset, container_node) return [container_node] def setup(app): app.add_directive('flax_module', FlaxModuleDirective) return { 'version': sphinx.__display_version__, 'parallel_read_safe': True, 'parallel_write_safe': True, } ================================================ FILE: docs_nnx/_static/css/flax_theme.css ================================================ @import url("theme.css"); .wy-nav-content { max-width: 1290px; } .rst-content table.docutils { width: 100%; } .rst-content table.docutils td { vertical-align: top; padding: 0; } .rst-content table.docutils td p { padding: 8px; } .rst-content div[class^=highlight] { border: 0; margin: 0; } ================================================ FILE: docs_nnx/_templates/autosummary/flax_module.rst ================================================ {{ fullname | escape | underline }} .. currentmodule:: {{ module }} .. autoclass:: {{ objname }} :exclude-members: .. automethod:: __call__ {% block methods %} {% for item in methods %} {%- if item not in inherited_members and item not in annotations and not item in ['__init__', 'setup'] %} .. automethod:: {{ item }} {%- endif %} {%- endfor %} {% if methods %} .. rubric:: Methods .. autosummary:: {% for item in methods %} {%- if item not in inherited_members and item not in annotations and not item in ['__init__', 'setup'] %} ~{{ name }}.{{ item }} {%- endif %} {%- endfor %} {% endif %} {% endblock %} ================================================ FILE: docs_nnx/api_reference/flax.config.rst ================================================ flax.config package ==================== .. automodule:: flax.configurations :members: :undoc-members: :exclude-members: FlagHolder, bool_flag, static_bool_env ================================================ FILE: docs_nnx/api_reference/flax.core.frozen_dict.rst ================================================ flax.core.frozen_dict package ============================= .. currentmodule:: flax.core.frozen_dict .. autoclass:: FrozenDict :members: pretty_repr, copy, pop, unfreeze, tree_flatten .. autofunction:: freeze .. autofunction:: unfreeze .. autofunction:: copy .. autofunction:: pop .. autofunction:: pretty_repr ================================================ FILE: docs_nnx/api_reference/flax.nnx/bridge.rst ================================================ bridge ------------------------ .. automodule:: flax.nnx.bridge .. currentmodule:: flax.nnx.bridge .. flax_module:: :module: flax.nnx.bridge :class: ToNNX .. flax_module:: :module: flax.nnx.bridge :class: ToLinen .. autofunction:: to_linen .. flax_module:: :module: flax.nnx.bridge :class: NNXMeta ================================================ FILE: docs_nnx/api_reference/flax.nnx/filterlib.rst ================================================ filterlib ------------------------ .. automodule:: flax.nnx .. currentmodule:: flax.nnx .. autofunction:: flax.nnx.filterlib.to_predicate .. autoclass:: WithTag .. autoclass:: PathContains .. autoclass:: OfType .. autoclass:: Any .. autoclass:: All .. autoclass:: Not .. autoclass:: Everything .. autoclass:: Nothing ================================================ FILE: docs_nnx/api_reference/flax.nnx/graph.rst ================================================ graph ------------------------ .. automodule:: flax.nnx .. currentmodule:: flax.nnx .. autofunction:: split .. autofunction:: merge .. autofunction:: update .. autofunction:: pop .. autofunction:: state .. autofunction:: variables .. autofunction:: graph .. autofunction:: graphdef .. autofunction:: iter_graph .. autofunction:: recursive_map .. autofunction:: clone .. autofunction:: call .. autofunction:: set_metadata .. autofunction:: cached_partial .. autoclass:: GraphDef :members: .. autoclass:: UpdateContext :members: .. autofunction:: update_context .. autofunction:: current_update_context .. autofunction:: find_duplicates .. autofunction:: pure .. autofunction:: flatten .. autofunction:: unflatten ================================================ FILE: docs_nnx/api_reference/flax.nnx/helpers.rst ================================================ helpers ------------------------ .. automodule:: flax.nnx .. currentmodule:: flax.nnx .. autoclass:: Sequential :members: .. autoclass:: List :members: .. autoclass:: Dict :members: .. autoclass:: TrainState :members: ================================================ FILE: docs_nnx/api_reference/flax.nnx/index.rst ================================================ flax.nnx ------------------------ Experimental API. See the `NNX page `__ for more details. .. toctree:: :maxdepth: 3 graph object module nn/index rnglib spmd state training/index transforms variables helpers visualization filterlib bridge ================================================ FILE: docs_nnx/api_reference/flax.nnx/module.rst ================================================ module ------------------------ .. automodule:: flax.nnx :members: iter_children, iter_modules .. currentmodule:: flax.nnx .. autoclass:: Module :members: .. autofunction:: view .. autofunction:: view_info .. autofunction:: with_attributes ================================================ FILE: docs_nnx/api_reference/flax.nnx/nn/activations.rst ================================================ Activation functions ------------------------ .. automodule:: flax.nnx .. currentmodule:: flax.nnx .. autofunction:: celu .. autofunction:: elu .. autofunction:: gelu .. autofunction:: glu .. autofunction:: hard_sigmoid .. autofunction:: hard_silu .. autofunction:: hard_swish .. autofunction:: hard_tanh .. autofunction:: leaky_relu .. autofunction:: log_sigmoid .. autofunction:: log_softmax .. autofunction:: logsumexp .. autofunction:: one_hot .. autofunction:: relu .. autofunction:: relu6 as relu6, .. autofunction:: selu .. autofunction:: sigmoid .. autofunction:: identity .. autofunction:: silu .. autofunction:: soft_sign .. autofunction:: softmax .. autofunction:: softplus .. autofunction:: standardize .. autofunction:: swish .. autofunction:: tanh ================================================ FILE: docs_nnx/api_reference/flax.nnx/nn/attention.rst ================================================ Attention ------------------------ .. automodule:: flax.nnx .. currentmodule:: flax.nnx .. flax_module:: :module: flax.nnx :class: MultiHeadAttention .. autofunction:: combine_masks .. autofunction:: dot_product_attention .. autofunction:: make_attention_mask .. autofunction:: make_causal_mask ================================================ FILE: docs_nnx/api_reference/flax.nnx/nn/dtypes.rst ================================================ Dtypes ------------------------ .. automodule:: flax.nnx.nn.dtypes .. currentmodule:: flax.nnx.nn.dtypes .. autofunction:: canonicalize_dtype .. autofunction:: promote_dtype ================================================ FILE: docs_nnx/api_reference/flax.nnx/nn/index.rst ================================================ nn ---------------------------- Neural network layers and activation functions used in NNX :class:`Module`'s. See the `NNX page `__ for more details. .. toctree:: :maxdepth: 3 activations attention dtypes initializers linear lora normalization recurrent stochastic ================================================ FILE: docs_nnx/api_reference/flax.nnx/nn/initializers.rst ================================================ Initializers ------------------------ .. automodule:: flax.nnx.initializers .. currentmodule:: flax.nnx.initializers .. autofunction:: constant .. autofunction:: delta_orthogonal .. autofunction:: glorot_normal .. autofunction:: glorot_uniform .. autofunction:: he_normal .. autofunction:: he_uniform .. autofunction:: kaiming_normal .. autofunction:: kaiming_uniform .. autofunction:: lecun_normal .. autofunction:: lecun_uniform .. autofunction:: normal .. autofunction:: truncated_normal .. autofunction:: ones .. autofunction:: ones_init .. autofunction:: orthogonal .. autofunction:: uniform .. autofunction:: variance_scaling .. autofunction:: xavier_normal .. autofunction:: xavier_uniform .. autofunction:: zeros .. autofunction:: zeros_init ================================================ FILE: docs_nnx/api_reference/flax.nnx/nn/linear.rst ================================================ Linear ------------------------ NNX linear layer classes. .. automodule:: flax.nnx .. currentmodule:: flax.nnx .. flax_module:: :module: flax.nnx :class: Conv .. flax_module:: :module: flax.nnx :class: ConvTranspose .. flax_module:: :module: flax.nnx :class: Embed .. flax_module:: :module: flax.nnx :class: Linear .. flax_module:: :module: flax.nnx :class: LinearGeneral .. flax_module:: :module: flax.nnx :class: Einsum ================================================ FILE: docs_nnx/api_reference/flax.nnx/nn/lora.rst ================================================ LoRA ------------------------ NNX LoRA classes. .. automodule:: flax.nnx .. currentmodule:: flax.nnx .. flax_module:: :module: flax.nnx :class: LoRA .. flax_module:: :module: flax.nnx :class: LoRALinear ================================================ FILE: docs_nnx/api_reference/flax.nnx/nn/normalization.rst ================================================ Normalization ------------------------ .. automodule:: flax.nnx .. currentmodule:: flax.nnx .. flax_module:: :module: flax.nnx :class: BatchNorm .. flax_module:: :module: flax.nnx :class: LayerNorm .. flax_module:: :module: flax.nnx :class: RMSNorm .. flax_module:: :module: flax.nnx :class: GroupNorm .. flax_module:: :module: flax.nnx :class: InstanceNorm .. flax_module:: :module: flax.nnx :class: SpectralNorm .. flax_module:: :module: flax.nnx :class: WeightNorm ================================================ FILE: docs_nnx/api_reference/flax.nnx/nn/recurrent.rst ================================================ Recurrent ------------------------ .. automodule:: flax.nnx.nn.recurrent .. currentmodule:: flax.nnx.nn.recurrent .. flax_module:: :module: flax.nnx.nn.recurrent :class: LSTMCell .. flax_module:: :module: flax.nnx.nn.recurrent :class: OptimizedLSTMCell .. flax_module:: :module: flax.nnx.nn.recurrent :class: SimpleCell .. flax_module:: :module: flax.nnx.nn.recurrent :class: GRUCell .. flax_module:: :module: flax.nnx.nn.recurrent :class: RNN .. flax_module:: :module: flax.nnx.nn.recurrent :class: Bidirectional .. autofunction:: flip_sequences ================================================ FILE: docs_nnx/api_reference/flax.nnx/nn/stochastic.rst ================================================ Stochastic ------------------------ .. automodule:: flax.nnx .. currentmodule:: flax.nnx .. autoclass:: Dropout :members: ================================================ FILE: docs_nnx/api_reference/flax.nnx/object.rst ================================================ object ------------------------ .. automodule:: flax.nnx .. currentmodule:: flax.nnx .. autoclass:: Pytree :members: .. autoclass:: Object :members: .. autofunction:: data .. autodata:: Data :annotation: .. autofunction:: static .. autodata:: Static :annotation: .. autofunction:: is_data .. autofunction:: register_data_type .. autofunction:: check_pytree ================================================ FILE: docs_nnx/api_reference/flax.nnx/rnglib.rst ================================================ rnglib ------------------------ .. automodule:: flax.nnx .. currentmodule:: flax.nnx .. autoclass:: Rngs :members: __init__ .. autoclass:: RngStream :members: .. autofunction:: split_rngs .. autofunction:: fork_rngs .. autofunction:: reseed ================================================ FILE: docs_nnx/api_reference/flax.nnx/spmd.rst ================================================ spmd ------------------------ .. automodule:: flax.nnx .. currentmodule:: flax.nnx .. autofunction:: get_partition_spec .. autofunction:: get_named_sharding .. autofunction:: with_partitioning ================================================ FILE: docs_nnx/api_reference/flax.nnx/state.rst ================================================ state ------------------------ .. automodule:: flax.nnx .. currentmodule:: flax.nnx .. autoclass:: State :members: .. autoclass:: FlatState :members: .. autofunction:: filter_state .. autofunction:: from_flat_state .. autofunction:: map_state .. autofunction:: merge_state .. autofunction:: replace_by_pure_dict .. autofunction:: restore_int_paths .. autofunction:: to_flat_state .. autofunction:: to_pure_dict .. autofunction:: split_state ================================================ FILE: docs_nnx/api_reference/flax.nnx/summary.rst ================================================ summary ------------------------ .. automodule:: flax.nnx .. currentmodule:: flax.nnx .. autofunction:: tabulate ================================================ FILE: docs_nnx/api_reference/flax.nnx/training/index.rst ================================================ training ---------------------------- Experimental API. See the `NNX page `__ for more details. .. toctree:: :maxdepth: 3 metrics optimizer ================================================ FILE: docs_nnx/api_reference/flax.nnx/training/metrics.rst ================================================ Metrics ------------------------ .. automodule:: flax.nnx.metrics .. currentmodule:: flax.nnx.metrics .. autoclass:: Metric :members: __init__, reset, update, compute .. autoclass:: Average :members: __init__, reset, update, compute .. autoclass:: Accuracy :members: update .. autoclass:: Welford :members: __init__, reset, update, compute .. autoclass:: MultiMetric :members: __init__, reset, update, compute ================================================ FILE: docs_nnx/api_reference/flax.nnx/training/optimizer.rst ================================================ Optimizer ------------------------ .. automodule:: flax.nnx.optimizer .. currentmodule:: flax.nnx.optimizer .. autoclass:: Optimizer :members: __init__, update ================================================ FILE: docs_nnx/api_reference/flax.nnx/transforms.rst ================================================ transforms ------------------------ .. automodule:: flax.nnx .. currentmodule:: flax.nnx .. autofunction:: grad .. autofunction:: jit .. autofunction:: shard_map .. autofunction:: remat .. autofunction:: scan .. autoclass:: Carry .. autofunction:: value_and_grad .. autofunction:: vmap .. autofunction:: eval_shape .. autofunction:: custom_vjp .. autofunction:: vjp .. autofunction:: jvp .. autofunction:: cond .. autofunction:: switch .. autofunction:: while_loop .. autofunction:: fori_loop ================================================ FILE: docs_nnx/api_reference/flax.nnx/variables.rst ================================================ variables ------------------------ .. automodule:: flax.nnx .. currentmodule:: flax.nnx .. autoclass:: BatchStat :members: .. autoclass:: Cache :members: .. autoclass:: Intermediate :members: .. autoclass:: Param :members: .. autoclass:: Variable :members: .. autoclass:: VariableMetadata :members: .. autofunction:: with_metadata .. autofunction:: variable_name_from_type .. autofunction:: variable_type_from_name .. autofunction:: register_variable_name ================================================ FILE: docs_nnx/api_reference/flax.nnx/visualization.rst ================================================ visualization ------------------------ .. automodule:: flax.nnx .. currentmodule:: flax.nnx .. autofunction:: display ================================================ FILE: docs_nnx/api_reference/flax.struct.rst ================================================ flax.struct package ===================== .. currentmodule:: flax.struct .. automodule:: flax.struct .. autofunction:: dataclass .. autoclass:: PyTreeNode ================================================ FILE: docs_nnx/api_reference/flax.training.rst ================================================ flax.training package ===================== Train state ------------------------ .. currentmodule:: flax.training.train_state .. autoclass:: TrainState :members: apply_gradients, create ================================================ FILE: docs_nnx/api_reference/flax.traverse_util.rst ================================================ flax.traverse_util package ============================ .. currentmodule:: flax.traverse_util .. automodule:: flax.traverse_util Dict utils ------------ .. autofunction:: flatten_dict .. autofunction:: unflatten_dict .. autofunction:: path_aware_map ================================================ FILE: docs_nnx/api_reference/index.rst ================================================ API Reference ============= .. toctree:: :maxdepth: 4 flax.nnx/index flax.core.frozen_dict flax.struct flax.training ================================================ FILE: docs_nnx/conf.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Configuration file for the Sphinx documentation builder.""" # This file only contains a selection of the most common options. For a full # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # # import os # import sys # sys.path.insert(0, os.path.abspath('.')) import os import sys import doctest sys.path.insert(0, os.path.abspath('..')) # Include local extension. sys.path.append(os.path.abspath('./_ext')) # Set environment variable to indicate that we are building the docs. os.environ['FLAX_DOC_BUILD'] = 'true' # patch sphinx # -- Project information ----------------------------------------------------- project = 'Flax' copyright = '2023, The Flax authors' # pylint: disable=redefined-builtin author = 'The Flax authors' # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ 'sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.autosectionlabel', 'sphinx.ext.doctest', 'sphinx.ext.intersphinx', 'sphinx.ext.mathjax', 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', 'myst_nb', 'codediff', 'flax_module', 'sphinx_design', ] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # The suffix(es) of source filenames. # Note: important to list ipynb before md here: we have both md and ipynb # copies of each notebook, and myst will choose which to convert based on # the order in the source_suffix list. Notebooks which are not executed have # outputs stored in ipynb but not in md, so we must convert the ipynb. source_suffix = ['.rst', '.ipynb', '.md'] autosummary_generate = True master_doc = 'index' autodoc_typehints = 'none' # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # # html_theme = 'pydata_sphinx_theme' html_theme = 'sphinx_book_theme' html_css_files = ['css/flax_theme.css'] # The name of an image file (relative to this directory) to place at the top # of the sidebar. html_logo = './flax.png' html_favicon = './flax.png' # title of the website html_title = '' # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named 'default.css' will overwrite the builtin 'default.css'. html_static_path = ['_static'] html_extra_path = ['robots.txt'] # href with no underline and white bold text color announcement = """ This site covers the new Flax NNX API. [Click here for the old Flax Linen API] """ html_theme_options = { 'repository_url': 'https://github.com/google/flax', 'use_repository_button': True, # add a 'link to repository' button 'use_issues_button': False, # add an 'Open an Issue' button 'path_to_docs': ( 'docs_nnx' ), # used to compute the path to launch notebooks in colab 'launch_buttons': { 'colab_url': 'https://colab.research.google.com/', }, 'prev_next_buttons_location': None, 'show_navbar_depth': 1, 'announcement': announcement, } # -- Options for myst ---------------------------------------------- # uncomment line below to avoid running notebooks during development nb_execution_mode = os.environ.get("NB_EXECUTION_MODE", 'off') # Notebook cell execution timeout; defaults to 30. nb_execution_timeout = 100 # List of patterns, relative to source directory, that match notebook # files that will not be executed. myst_enable_extensions = ['dollarmath'] nb_execution_excludepatterns = [ 'mnist_tutorial.ipynb', # <-- times out 'transfer_learning.ipynb', # <-- transformers requires flax<=0.7.0 'flax/nnx', # exclude nnx 'guides/demo.ipynb', # TODO(cgarciae): broken, remove or update 'examples/gemma.ipynb', 'guides/bridge_guide.ipynb', # TODO(cgarciae): broken, bridge doesn't support Linen sow yet ] # raise exceptions on execution so CI can catch errors nb_execution_allow_errors = False nb_execution_raise_on_error = True # -- Extension configuration ------------------------------------------------- # Tell sphinx-autodoc-typehints to generate stub parameter annotations including # types, even if the parameters aren't explicitly documented. always_document_param_types = True # -- doctest configuration ------------------------------------------------- doctest_default_flags = doctest.NORMALIZE_WHITESPACE doctest_global_setup = """ import jax import jax.numpy as jnp from flax import nnx import logging as slog from absl import logging as alog # Avoid certain absl logging messages to break doctest filtered_message = [ 'SaveArgs.aggregate is deprecated', '', ] class _CustomLogFilter(slog.Formatter): def format(self, record): message = super(_CustomLogFilter, self).format(record) for m in filtered_message: if m in message: return '' return message alog.use_absl_handler() alog.get_absl_handler().setFormatter(_CustomLogFilter()) """ ================================================ FILE: docs_nnx/conf_sphinx_patch.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Patch Sphinx to improve documentation aesthetics.""" # TODO(cgarciae): Send a PR to sphinx to upstream this fix. Issue: https://github.com/google/flax/issues/2196 # This patch is needed to make autosummary provide the "annotations" # variable so we can exclude function attributes from the methods list # in flax_module.rst. The patch as such only adds this single line: # # ns['annotations'] = list(getattr(obj, '__annotations__', {}).keys())' # # We should consider sending a PR to sphinx so we can get rid of this. # Original source: https://github.com/sphinx-doc/sphinx/blob/0aedcc9a916daa92d477226da67d33ce1831822e/sphinx/ext/autosummary/generate.py#L211-L351 from typing import Any import sphinx.ext.autodoc import sphinx.ext.autosummary.generate as ag def generate_autosummary_content( name: str, obj: Any, parent: Any, template: ag.AutosummaryRenderer, template_name: str, imported_members: bool, app: Any, recursive: bool, context: dict, modname: str = None, qualname: str = None, ) -> str: doc = ag.get_documenter(app, obj, parent) def skip_member(obj: Any, name: str, objtype: str) -> bool: try: return app.emit_firstresult( 'autodoc-skip-member', objtype, name, obj, False, {} ) except Exception as exc: ag.logger.warning( __( 'autosummary: failed to determine %r to be documented, ' 'the following exception was raised:\n%s' ), name, exc, type='autosummary', ) return False def get_class_members(obj: Any) -> dict[str, Any]: members = sphinx.ext.autodoc.get_class_members( obj, [qualname], ag.safe_getattr ) return {name: member.object for name, member in members.items()} def get_module_members(obj: Any) -> dict[str, Any]: members = {} for name in ag.members_of(obj, app.config): try: members[name] = ag.safe_getattr(obj, name) except AttributeError: continue return members def get_all_members(obj: Any) -> dict[str, Any]: if doc.objtype == 'module': return get_module_members(obj) elif doc.objtype == 'class': return get_class_members(obj) return {} def get_members( obj: Any, types: set[str], include_public: list[str] = [], imported: bool = True, ) -> tuple[list[str], list[str]]: items: list[str] = [] public: list[str] = [] all_members = get_all_members(obj) for name, value in all_members.items(): documenter = ag.get_documenter(app, value, obj) if documenter.objtype in types: # skip imported members if expected if imported or getattr(value, '__module__', None) == obj.__name__: skipped = skip_member(value, name, documenter.objtype) if skipped is True: pass elif skipped is False: # show the member forcedly items.append(name) public.append(name) else: items.append(name) if name in include_public or not name.startswith('_'): # considers member as public public.append(name) return public, items def get_module_attrs(members: Any) -> tuple[list[str], list[str]]: """Find module attributes with docstrings.""" attrs, public = [], [] try: analyzer = ag.ModuleAnalyzer.for_module(name) attr_docs = analyzer.find_attr_docs() for namespace, attr_name in attr_docs: if namespace == '' and attr_name in members: attrs.append(attr_name) if not attr_name.startswith('_'): public.append(attr_name) except ag.PycodeError: pass # give up if ModuleAnalyzer fails to parse code return public, attrs def get_modules(obj: Any) -> tuple[list[str], list[str]]: items: list[str] = [] for _, modname, _ispkg in ag.pkgutil.iter_modules(obj.__path__): fullname = name + '.' + modname try: module = ag.import_module(fullname) if module and hasattr(module, '__sphinx_mock__'): continue except ImportError: pass items.append(fullname) public = [x for x in items if not x.split('.')[-1].startswith('_')] return public, items ns: dict[str, Any] = {} ns.update(context) if doc.objtype == 'module': scanner = ag.ModuleScanner(app, obj) ns['members'] = scanner.scan(imported_members) ns['functions'], ns['all_functions'] = get_members( obj, {'function'}, imported=imported_members ) ns['classes'], ns['all_classes'] = get_members( obj, {'class'}, imported=imported_members ) ns['exceptions'], ns['all_exceptions'] = get_members( obj, {'exception'}, imported=imported_members ) ns['attributes'], ns['all_attributes'] = get_module_attrs(ns['members']) ispackage = hasattr(obj, '__path__') if ispackage and recursive: ns['modules'], ns['all_modules'] = get_modules(obj) elif doc.objtype == 'class': ns['members'] = dir(obj) ns['inherited_members'] = set(dir(obj)) - set(obj.__dict__.keys()) ns['methods'], ns['all_methods'] = get_members( obj, {'method'}, ['__init__'] ) ns['attributes'], ns['all_attributes'] = get_members( obj, {'attribute', 'property'} ) ns['annotations'] = list(getattr(obj, '__annotations__', {}).keys()) if modname is None or qualname is None: modname, qualname = ag.split_full_qualified_name(name) if doc.objtype in ('method', 'attribute', 'property'): ns['class'] = qualname.rsplit('.', 1)[0] if doc.objtype in ('class',): shortname = qualname else: shortname = qualname.rsplit('.', 1)[-1] ns['fullname'] = name ns['module'] = modname ns['objname'] = qualname ns['name'] = shortname ns['objtype'] = doc.objtype ns['underline'] = len(name) * '=' if template_name: return template.render(template_name, ns) else: return template.render(doc.objtype, ns) ag.generate_autosummary_content = generate_autosummary_content ================================================ FILE: docs_nnx/contributing.md ================================================ # How to contribute Everyone can contribute to Flax, and the Flax development team values everyone's contributions! You can contribute in many more ways than just writing code. Answering questions on the [Flax GitHub Discussions page](https://github.com/google/flax/discussions), helping each other, and improving Flax documentation are extremely valuable to the Flax ecosystem. We also appreciate if you spread the word, for instance by starring the [Flax GitHub repository](https://github.com/google/flax), or referencing Flax in blog posts of projects that used it. This project follows [Google's Open Source Community Guidelines](https://opensource.google/conduct/). ## Ways to contribute We welcome pull requests (PRs), in particular for those issues [marked as PR-ready](https://github.com/google/flax/issues?q=is%3Aopen+is%3Aissue+label%3A%22Status%3A+pull+requests+welcome%22). For other proposals, you should first open a GitHub Issue or a GitHub Discussion to start a conversation about your planned contribution. ## Contributing code using pull requests The Flax development team performs all development using [Git](https://git-scm.com/). To contribute, you should have basic knowledge of [Git](https://git-scm.com/) and [GitHub](https://docs.github.com). (You can learn how to set up Git by following Git's official [Getting Started - First-Time Git Setup](https://git-scm.com/book/en/v2/Getting-Started-First-Time-Git-Setup) and GitHub's [Set Up Git](https://docs.github.com/en/get-started/quickstart/set-up-git) guides.) To contribute code to Flax on GitHub, follow these steps: ### To create a pull request from a fork 1. Using GitHub's web UI, fork the Flax repository by clicking the 'Fork' button on the [`github.com/google/flax` repository page](http://www.github.com/google/flax). This creates a fork (a copy) of the Flax repository in your own GitHub. Reference: [Creating a pull request from a fork](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork). 2. Install [Python >=3.7](https://www.python.org/downloads/). 3. (Optional) Create a virtual environment or a Docker container. See [`dev/README.md`](https://github.com/google/flax/blob/main/dev/README.md) for details on how to set up a Docker Container. To set up a virtual environment, run the following: ```bash python3 -m virtualenv env . env/bin/activate ``` This ensures all your dependencies are installed in this environment. 4. Clone your local forked Flax repo with `git clone`. Then, install the required packages with [PyPi](https://pip.pypa.io/en/stable/cli/pip_install/). This enables you to immediately test the code after modifying it: ```bash git clone https://github.com/YOUR_USERNAME/flax cd flax pip install -e ".[all,testing,docs]" ``` You can also use [uv](https://docs.astral.sh/uv/) to setup the development environment: ```bash uv sync --all-extras ``` 5. Set up pre-commit hooks, this will run some automated checks during each `git` commit and possibly update some files that require changes. ```bash pip install pre-commit pre-commit install ``` 6. Add the Google Flax repo (not your fork) as an upstream remote, so you can use it to sync your changes. ```bash git remote add upstream http://www.github.com/google/flax ``` 7. Create a branch, such as `my_development_branch`, you will develop from: ```bash git checkout -b my_development_branch ``` 8. Implement your changes using your favorite editor (we recommend [Visual Studio Code](https://code.visualstudio.com/)). Make sure the tests pass by running the following command from the top of the repository: ```bash ./tests/run_all_tests.sh ``` 9. Once you finish making changes, don't forget to create commits ([learn how to write a commit message](https://chris.beams.io/posts/git-commit/)): ```bash git add file1.py file2.py ... # or use `git add .` to add all changed files git commit -m "Your commit message" ``` Then sync your code with the main repository: ```bash git fetch upstream git rebase upstream/main ``` 10. Finally, push your commit on your `my_development_branch`, and create a remote branch in your fork that you can use to create a pull request from: ```bash git push --set-upstream origin my_development_branch ``` After running the command, you should get a GitHub link in your (VS Code) terminal output for creating a pull request. If you don't receive a link after `git push`, use the [GitHub web UI](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request?tool=webui) to create a pull request. 11. Make sure your pull request passes the [Flax PR checklist](https://github.com/google/flax/blob/main/.github/pull_request_template.md#checklist). If so, create a pull request from the Flax repository and send it for review. Consult [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests. You can learn more in GitHub's [Creating a pull request from a fork ](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork). documentation. ### Adding or updating dependencies To add or update dependencies, you must use `uv` after updating the `pyproject.toml` file to ensure that the `uv.lock` file is up-to-date. ```bash uv sync --all-extras ``` Alternatively use can use `uv add` to add or update the dependencies automatically, for example: ```bash uv add 'some-package>=1.2.3' ``` ### Updating Jupyter Notebooks We use [jupytext](https://jupytext.readthedocs.io/) to maintain two synced copies of docs in `docs/notebooks`: one in the Jupyter Notebook (`.ipynb`) format, and one in Markdown (`.md`). The former can be opened and executed directly in [Google Colab](https://colab.research.google.com/). Markdown makes it easier to track changes/diffs within version control and, for example, GitHub web UI, since `.ipynb` files are based on JSON. #### Editing Jupyter Notebooks (`.ipynb`) For making large changes that substantially modify code and outputs, it's recommended to edit the notebooks in [Jupyter](https://jupyter.org/install) or in [Colab](https://colab.research.google.com/). If you choose to work in Colab, go to **File** and click **Upload notebook**, then pick your file. After loading it into Colab and editing it, make sure you run the cells, and that there aren't any errors. Click on **Runtime**, then select **Run all**. After you finish, click **File** > **Download** > **Download ipynb**. You may also want to test that the file executes properly by using `sphinx-build`, as explained above. After you make changes in your Jupyter Notebook, follow the steps _Syncing notebooks_ below. #### Editing Markdown files (`.md`) For making smaller changes to the text content of the notebooks, it is easiest to edit the `.md` versions using a text editor. After you make changes in your Markdown file, follow the steps _Syncing notebooks_ below. #### Syncing notebooks After editing either the `.ipynb` or `.md` versions of the docs, sync the two versions using [jupytext](https://jupytext.readthedocs.io/) by running `jupytext --sync` on the updated notebooks. First, make sure you have jupytext installed. The jupytext version should match the one specified in [.pre-commit-config.yaml](https://github.com/google/flax/blob/main/.pre-commit-config.yaml) (currently, it is v1.13.8). ```bash pip install jupytext==1.13.8 ``` Then, after you have made your changes in the Jupyter Notebook, sync the contents with its Markdown-equivalent file by running the following command: ```bash jupytext --sync path/to/the/file.ipynb ``` Similarly, to sync your Markdown file with its Jupyter Notebook version, run: ```bash jupytext --sync path/to/the/file.md ``` Note that if you receive an error, and it is the first time you worked in a Jupyter Notebook, you may need to (re)create a synced copy of the document (which is explained in detail in _Creating new notebooks_ section below): ```bash jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb ``` Once you're finished with syncing the `.md` and `.ipynb` files, you can check that they are properly synced using the [pre-commit](https://pre-commit.com/) framework to perform the same checks used in the Flax GitHub CI: ```bash git add docs -u # pre-commit runs on files in git staging. pre-commit run jupytext ``` #### Creating new notebooks If you are adding a new Jupyter Notebook to the documentation, you can use `jupytext --set-formats`. It can set up both the Jupyter Notebook (`.ipynb`) and Markdown (`.md`) versions of the file: ```bash jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb ``` This works by adding a `"jupytext"` metadata field to the notebook file which specifies the desired formats. The `jupytext --sync` command can then recognize them when invoked. After you make changes in your file(s), follow the steps from the _Syncing notebooks_ section above to keep the contents of both Markdown and Jupyter Notebook files in sync. #### Notebooks within the Sphinx build Some of the notebooks are built automatically as part of the pre-submit checks and as part of the [Read the Docs](https://flax.readthedocs.io/en/latest) build. The build will fail if cells raise errors. If the errors are intentional, you can either catch them, or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/jax-ml/jax/pull/2402/files)). You have to add this metadata by hand in the `.ipynb` file. It will be preserved when somebody else re-saves the notebook. We exclude some notebooks from the build because, for example, they contain long computations. See `exclude_patterns` in [`conf.py`](https://github.com/google/flax/blob/main/docs/conf.py). ### Updating the pull request contents Every pull request should ideally be limited to just one commit, so if you have multiple commits please squash them. Assuming you now have only one commit in your pull request, and want to add changes requested during review: 1. Make the changes locally in your editor. 2. Run `git commit -a --amend`. This updates the commit contents and allows you to edit the commit message. 3. At this point, `git push` alone will result in an error. Instead, use `git push --force`. 4. Check that it's done: The changes to your commit should be immediately reflected in the Github web UI. ## Troubleshooting ### Too many commits in a pull request If your PR has too many commits associated with it (for example, more than five), you need to squash them. Otherwise, the Flax docs build process may fail with an error message. This is because of the following reasons: * There are more than five commits in your pull request; and * The Flax source sync process fails when the commit tree is too large. To squash your commits, you can rebase your branch to `main` and create a new commit containing all your changes, run the following command: ```bash git rebase main && git reset --soft main && git commit ``` This will apply all your changes to the main branch. Note that if you had to resolve any conflicts while working on your change (for instance, you did a `pull upstream main` which led to conflict), then you will have to resolve these conflicts again. After you have successfully rebased your branch, you should push your changes. And because you changed the commit history, you may have to use `git push --force`. ## Contributor License Agreement Contributions to this project must be accompanied by a Contributor License Agreement. You (or your employer) retain the copyright to your contribution; this simply gives us permission to use and redistribute your contributions as part of the project. Head over to to see your current agreements on file or to sign a new one. You generally only need to submit a CLA once, so if you've already submitted one (even if it was for a different project), you probably don't need to do it again. All submissions to Google Open Source projects need to follow Google’s Contributor License Agreement (CLA), in which contributors agree that their contribution is an original work of authorship. This doesn’t prohibit the use of coding assistance tools, but what’s submitted does need to be a contributor’s original creation. ================================================ FILE: docs_nnx/examples/core_examples.rst ================================================ Core examples ============= Core examples are hosted on the GitHub Flax repository in the `examples `__ directory. Each example is designed to be **self-contained and easily forkable**, while reproducing relevant results in different areas of machine learning. Some of the examples below have a link "Interactive🕹" that lets you run them directly in Colab. Transformers ******************** - :octicon:`mark-github;0.9em` `Gemma `__ : A family of open-weights Large Language Model (LLM) by Google DeepMind, based on Gemini research and technology. - :octicon:`mark-github;0.9em` `LM1B `__ : Transformer encoder trained on the One Billion Word Benchmark. Toy examples ******************** `NNX toy examples `__ directory contains a few smaller, standalone toy examples for simple training scenarios. ================================================ FILE: docs_nnx/examples/gemma.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Example: Using pretrained Gemma for inference with Flax NNX\n", "\n", "This example shows how to use Flax NNX to load the [Gemma](https://ai.google.dev/gemma) open model files and use them to perform sampling/inference for generating text. You will use [Flax NNX `gemma` modules](https://github.com/google/flax/tree/main/examples/gemma) written with Flax and JAX for model parameter configuration and inference.\n", "\n", "> Gemma is a family of lightweight, state-of-the-art open models based on Google DeepMind’s [Gemini](https://deepmind.google/technologies/gemini/#introduction). Read more about [Gemma](https://blog.google/technology/developers/gemma-open-models/) and [Gemma 2](https://blog.google/technology/developers/google-gemma-2/).\n", "\n", "You are recommended to use [Google Colab](https://colab.research.google.com/) with access to A100 GPU acceleration to run the code." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Installation\n", "\n", "Install the necessary dependencies, including `kagglehub`." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "! pip install --no-deps -U flax\n", "! pip install jaxtyping kagglehub treescope" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download the model\n", "\n", "To use Gemma model, you'll need a [Kaggle](https://www.kaggle.com/models/google/gemma/) account and API key:\n", "\n", "1. To create an account, visit [Kaggle](https://www.kaggle.com/) and click on 'Register'.\n", "2. If/once you have an account, you need to sign in, go to your ['Settings'](https://www.kaggle.com/settings), and under 'API' click on 'Create New Token' to generate and download your Kaggle API key.\n", "3. In [Google Colab](https://colab.research.google.com/), under 'Secrets' add your Kaggle username and API key, storing the username as `KAGGLE_USERNAME` and the key as `KAGGLE_KEY`. If you are using a [Kaggle Notebook](https://www.kaggle.com/code) for free TPU or other hardware acceleration, it has a key storage feature under 'Add-ons' > 'Secrets', along with instructions for accessing stored keys.\n", "\n", "Then run the cell below." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2e7cf9f0345845f1a3edc72fa4411eb4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HTML(value='
(()=>{ if (customElements.get('treescope-container') === undefined) { class TreescopeContainer extends HTMLElement { constructor() { super(); this.attachShadow({mode: \"open\"}); this.defns = {}; this.state = {}; } } customElements.define(\"treescope-container\", TreescopeContainer); } if (customElements.get('treescope-run-here') === undefined) { class RunHere extends HTMLElement { constructor() { super() } connectedCallback() { const run = child => { const fn = new Function(child.textContent); child.textContent = \"\"; fn.call(this); this.remove(); }; const child = this.querySelector(\"script\"); if (child) { run(child); } else { new MutationObserver(()=>{ run(this.querySelector(\"script\")); }).observe(this, {childList: true}); } } } customElements.define(\"treescope-run-here\", RunHere); } })();
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "transformer = transformer_lib.Transformer.from_params(params)\n", "nnx.display(transformer)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Perform sampling/inference\n", "\n", "Build a Flax NNX [`gemma.Sampler`](https://github.com/google/flax/blob/main/examples/gemma/sampler.py) on top of your model and tokenizer with the right parameter shapes." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "cellView": "form" }, "outputs": [], "source": [ "sampler = sampler_lib.Sampler(\n", " transformer=transformer,\n", " vocab=vocab,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You're ready to start sampling!\n", "\n", "**Note:** This Flax NNX [`gemma.Sampler`](https://github.com/google/flax/blob/main/examples/gemma/sampler.py) uses JAX’s [just-in-time (JIT) compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html), so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent.\n", "\n", "Write a prompt in `input_batch` and perform inference. Feel free to tweak `total_generation_steps` (the number of steps performed when generating a response)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "cellView": "form" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prompt:\n", "\n", "# Python program for implementation of Bubble Sort\n", "\n", "def bubbleSort(arr):\n", "Output:\n", "\n", " for i in range(len(arr)):\n", " for j in range(len(arr) - i - 1):\n", " if arr[j] > arr[j + 1]:\n", " swap(arr, j, j + 1)\n", "\n", "\n", "def swap(arr, i, j):\n", " temp = arr[i]\n", " arr[i] = arr[j]\n", " arr[j] = temp\n", "\n", "\n", "# Driver code\n", "arr = [5, 2, 8, 3, 1, 9]\n", "print(\"Unsorted array:\")\n", "print(arr)\n", "bubbleSort(arr)\n", "print(\"Sorted array:\")\n", "print(arr)\n", "\n", "\n", "# Time complexity of Bubble sort O(n^2)\n", "# where n is the length of the array\n", "\n", "\n", "# Space complexity of Bubble sort O(1)\n", "# as it only requires constant extra space for the swap operation\n", "\n", "\n", "# This program uses the bubble sort algorithm to sort the given array in ascending order.\n", "\n", "```python\n", "# This program uses the bubble sort algorithm to sort the given array in ascending order.\n", "\n", "def bubbleSort(arr):\n", " for i in range(len(arr)):\n", " for j in range(len(arr) - i - 1):\n", " if arr[j] > arr[j + 1]:\n", " swap(arr, j, j + 1)\n", "\n", "\n", "def swap(\n", "\n", "##########\n" ] } ], "source": [ "input_batch = [\n", " \"\\n# Python program for implementation of Bubble Sort\\n\\ndef bubbleSort(arr):\",\n", " ]\n", "\n", "out_data = sampler(\n", " input_strings=input_batch,\n", " total_generation_steps=300, # The number of steps performed when generating a response.\n", " )\n", "\n", "for input_string, out_string in zip(input_batch, out_data.text):\n", " print(f\"Prompt:\\n{input_string}\\nOutput:\\n{out_string}\")\n", " print()\n", " print(10*'#')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You should get a Python implementation of the bubble sort algorithm." ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: docs_nnx/examples/gemma.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # Example: Using pretrained Gemma for inference with Flax NNX This example shows how to use Flax NNX to load the [Gemma](https://ai.google.dev/gemma) open model files and use them to perform sampling/inference for generating text. You will use [Flax NNX `gemma` modules](https://github.com/google/flax/tree/main/examples/gemma) written with Flax and JAX for model parameter configuration and inference. > Gemma is a family of lightweight, state-of-the-art open models based on Google DeepMind’s [Gemini](https://deepmind.google/technologies/gemini/#introduction). Read more about [Gemma](https://blog.google/technology/developers/gemma-open-models/) and [Gemma 2](https://blog.google/technology/developers/google-gemma-2/). You are recommended to use [Google Colab](https://colab.research.google.com/) with access to A100 GPU acceleration to run the code. +++ ## Installation Install the necessary dependencies, including `kagglehub`. ```{code-cell} ipython3 ! pip install --no-deps -U flax ! pip install jaxtyping kagglehub treescope ``` ## Download the model To use Gemma model, you'll need a [Kaggle](https://www.kaggle.com/models/google/gemma/) account and API key: 1. To create an account, visit [Kaggle](https://www.kaggle.com/) and click on 'Register'. 2. If/once you have an account, you need to sign in, go to your ['Settings'](https://www.kaggle.com/settings), and under 'API' click on 'Create New Token' to generate and download your Kaggle API key. 3. In [Google Colab](https://colab.research.google.com/), under 'Secrets' add your Kaggle username and API key, storing the username as `KAGGLE_USERNAME` and the key as `KAGGLE_KEY`. If you are using a [Kaggle Notebook](https://www.kaggle.com/code) for free TPU or other hardware acceleration, it has a key storage feature under 'Add-ons' > 'Secrets', along with instructions for accessing stored keys. Then run the cell below. ```{code-cell} ipython3 import kagglehub kagglehub.login() ``` If everything went well, it should say `Kaggle credentials set. Kaggle credentials successfully validated.`. **Note:** In Google Colab, you can instead authenticate into Kaggle using the code below after following the optional step 3 from above. ``` import os from google.colab import userdata # `userdata` is a Colab API. os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME') os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY') ``` Now, load the Gemma model you want to try. The code in the next cell utilizes [`kagglehub.model_download`](https://github.com/Kaggle/kagglehub/blob/8efe3e99477aa4f41885840de6903e61a49df4aa/src/kagglehub/models.py#L16) to download model files. **Note:** For larger models, such as `gemma 7b` and `gemma 7b-it` (instruct), you may require a hardware accelerator with plenty of memory, such as the NVIDIA A100. **Note:** To avoid 403 error when downloading the model, you need to consent to the license for Gemma models on Kaggle. To do that, open https://www.kaggle.com/models/google/gemma/flax/ in the browser and click on "Download" button choosing any version of Gemma model. In the next window you will be proposed to agree with Gemma models usage license. Once, this step is done, you will be able to download the model using the code below. ```{code-cell} ipython3 from IPython.display import clear_output VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"} weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}') ckpt_path = f'{weights_dir}/{VARIANT}' vocab_path = f'{weights_dir}/tokenizer.model' ``` ## Python imports ```{code-cell} ipython3 from flax import nnx import sentencepiece as spm ``` To interact with the Gemma model, you will use the Flax NNX `gemma` code from [`google/flax` examples on GitHub](https://github.com/google/flax/tree/main/examples/gemma). Since it is not exposed as a package, you need to use the following workaround to import from the Flax NNX `examples/gemma` on GitHub. ```{code-cell} ipython3 import sys import tempfile with tempfile.TemporaryDirectory() as tmp: # Create a temporary directory and clone the `flax` repo. # Then, append the `examples/gemma` folder to the path for loading the `gemma` modules. ! git clone https://github.com/google/flax.git {tmp}/flax sys.path.append(f"{tmp}/flax/examples/gemma") import params as params_lib import sampler as sampler_lib import transformer as transformer_lib sys.path.pop(); ``` ## Load and prepare the Gemma model First, load the Gemma model parameters for use with Flax. ```{code-cell} ipython3 :cellView: form params = params_lib.load_and_format_params(ckpt_path) ``` Next, load the tokenizer file constructed using the [SentencePiece](https://github.com/google/sentencepiece) library. ```{code-cell} ipython3 :cellView: form vocab = spm.SentencePieceProcessor() vocab.Load(vocab_path) ``` Then, use the Flax NNX [`gemma.transformer.TransformerConfig.from_params`](https://github.com/google/flax/blob/3f3c03b23d4fd3d85d1c5d4d97381a8a2c48b475/examples/gemma/transformer.py#L193) function to automatically load the correct configuration from a checkpoint. **Note:** The vocabulary size is smaller than the number of input embeddings due to unused tokens in this release. ```{code-cell} ipython3 transformer = transformer_lib.Transformer.from_params(params) nnx.display(transformer) ``` ## Perform sampling/inference Build a Flax NNX [`gemma.Sampler`](https://github.com/google/flax/blob/main/examples/gemma/sampler.py) on top of your model and tokenizer with the right parameter shapes. ```{code-cell} ipython3 :cellView: form sampler = sampler_lib.Sampler( transformer=transformer, vocab=vocab, ) ``` You're ready to start sampling! **Note:** This Flax NNX [`gemma.Sampler`](https://github.com/google/flax/blob/main/examples/gemma/sampler.py) uses JAX’s [just-in-time (JIT) compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html), so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent. Write a prompt in `input_batch` and perform inference. Feel free to tweak `total_generation_steps` (the number of steps performed when generating a response). ```{code-cell} ipython3 :cellView: form input_batch = [ "\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):", ] out_data = sampler( input_strings=input_batch, total_generation_steps=300, # The number of steps performed when generating a response. ) for input_string, out_string in zip(input_batch, out_data.text): print(f"Prompt:\n{input_string}\nOutput:\n{out_string}") print() print(10*'#') ``` You should get a Python implementation of the bubble sort algorithm. ================================================ FILE: docs_nnx/examples/index.rst ================================================ Examples ======== .. toctree:: :maxdepth: 2 gemma core_examples ================================================ FILE: docs_nnx/faq.rst ================================================ Frequently Asked Questions (FAQ) ================================ This is a collection of answers to frequently asked questions (FAQ). You can contribute to the Flax FAQ by starting a new topic in `GitHub Discussions `__. Where to search for an answer to a Flax-related question? ********************************************************* There are a number of official Flax resources to search for information: - `Flax Documentation on ReadTheDocs `__ (this site): Use the `search bar `__ or the table of contents on the left-hand side. - `google/flax GitHub Discussions `__: Search for an existing topic or start a new one. If you can't find what you're looking for, feel free to ask the Flax team or community a question. - `google/flax GitHub Issues `__: Use the search bar to look for an existing issue or a feature request, or start a new one. How to take the derivative with respect to an intermediate value (using :code:`Module.perturb`)? ************************************************************************************************ To take the derivative(s) or gradient(s) of the output with respect to a hidden/intermediate activation inside a model layer, you can use :meth:`flax.linen.Module.perturb`. You define a zero-value :class:`flax.linen.Module` "perturbation" parameter – :code:`perturb(...)` – in the forward pass with the same shape as the intermediate activation, define the loss function with :code:`'perturbations'` as an added standalone argument, perform a JAX derivative operation with :code:`jax.grad` on the perturbation argument. For full examples and detailed documentation, go to: - The :meth:`flax.linen.Module.perturb` API docs - The `Extracting gradients of intermediate values `_ guide - `Flax GitHub Discussions #1152 `__ Is Flax Linen :code:`remat_scan()` the same as :code:`scan(remat(...))`? ************************************************************************ Flax :code:`remat_scan()` (:meth:`flax.linen.remat_scan()`) and :code:`scan(remat(...))` (:meth:`flax.linen.scan` over :meth:`flax.linen.remat`) are not the same, and :code:`remat_scan()` is limited in cases it supports. Namely, :code:`remat_scan()` treats the inputs and outputs as carries (hidden states that are carried through the training loop). You are recommended to use :code:`scan(remat(...))`, as typically you would need the extra parameters, such as ``in_axes`` (for input array axes) or ``out_axes`` (output array axes), which :meth:`flax.linen.remat_scan` does not expose. What are the recommended training loop libraries? ************************************************* Consider using CLU (Common Loop Utils) `google/CommonLoopUtils `__. To get started, go to this `CLU Synopsis Colab `__. You can find answers to common questions about CLU with Flax on `google/flax GitHub Discussions `__. Check out the official `google/flax Examples `__ for examples of using the training loop with (CLU) metrics. For example, this is `Flax ImageNet's train.py `__. For computer vision research, consider `google-research/scenic `__. Scenic is a set of shared light-weight libraries solving commonly encountered tasks when training large-scale vision models (with examples of several projects). Scenic is developed in JAX with Flax. To get started, go to the `README page on GitHub `__. ================================================ FILE: docs_nnx/flip/0000-template.md ================================================ - Start Date: (fill me in with today's date, YYYY-MM-DD) - FLIP PR: [#0000](https://github.com/google/flax/pull/0000) - FLIP Issue: [#0000](https://github.com/google/flax/issues/0000) (Below sections are just a possible structure - please adapt to your FLIP.) # Summary [summary]: #summary One paragraph explanation of the FLIP. # Motivation [motivation]: #motivation Why are we doing this? What use cases does it support? What is the expected outcome? # Implementation [implementation]: #implementation The technical part. # Discussion [discussion]: #discussion Summarize the discussion from the original issue and from the pull request. ================================================ FILE: docs_nnx/flip/1009-optimizer-api.md ================================================ - Start Date: 2021-02-08 - FLIP PR: [#1011](https://github.com/google/flax/pull/1011) - FLIP Issue: [#1009](https://github.com/google/flax/issues/1009) Table of contents: - [Summary] - [Motivation] - [Using Optax] - [Gradient Transformations] - [Optax Training Step] - [Multi Optimizer] - [Train State] - [Previous API] - [Optimizer and OptimizerDef] - [Previous Training Step] - [Update Plan] - [Appendix] - [Setup Code] # Summary [Summary]: #summary This FLIP proposes to replace our current `flax.optim` API (referred to as [previous API] in this document) with [Optax], DeepMind's optimizer library. # Motivation [motivation]: #motivation Our current API (referred to as [previous API] in this document) uses a pattern where an `Optimizer` dataclass is created from a pytree of `target` variables and from an `OptimizerDef` that defines how to update optimizer state, hyperparameters, and target variables. This pattern is relatively complex for implementing a simple optimizer, while being quite verbose in the typical Linen train step (especially when using mutable state collections). This package `flax.optim` contains some optimizers, but that list is far from exhaustive and ideally we would instead use JAX optimizers from a dedicated PyPi package. DeepMind already has a dedicated library — [Optax] — that implements a wide range of interesting optimizers and provides a framework to compose new optimizers from reusable gradient transformations. [Optax]: https://github.com/deepmind/optax # Using Optax [Using Optax]: #using-optax ## Gradient Transformations [Gradient Transformations]: #gradient-transformations While [Optax] does provide predefined optimizers (like `optax.adam`, or `optax.sgd` with momentum), it is really a library of *gradient transformations* and the idiomatic way of instantiating an optimizer is by providing a combination of these gradient transformations. To emulate the momentum optimizer from the example when using the [previous API] we would write: ```python import optax tx = optax.chain( optax.trace(decay=0.9, nesterov=False), optax.scale_by_schedule(lambda step: -get_learning_rate(step)), ) ``` Remarks: - Above gradient transformation would be equivalent with the example under [Optimizer and OptimizerDef] where we define a Momentum optimizer without Nesterov momentum (note that the `beta` parameter corresponds to the `decay` parameter of the `optax.trace()` transformation, and the learning rate is applied in a second chained transformation). - Note that hyper parameters like `decay` or `nesterov` only exist in the inner scope of the higher order functions returning the `GradientTransformation`. Such a gradient transformation is currently defined as a `NamedTuple` of the `init()` and the `update()` function. In principle this pattern could be extended to also store hyperparameters, maybe a point to discuss on the [Optax] repo. - We can use a `get_learning_rate()` that returns the learning rate depending on the step number when defining the Optax gradient update transformation. Above code illustrates how this can be a drop-in replacement for a function we also use in our [previous training step], where this update function already exists (notice how we need to invert the sign because we add the gradient update to the parameters). In addition, you can use [`inject_hyperparams()`](https://github.com/deepmind/optax/pull/48) to schedule arbitrary hyper parameters with Optax. ## Optax Training Step [Optax Training Step]: #optax-training-step ```python @functools.partial(jax.jit, static_argnums=(4, 5)) def train_step(opt_state, variables, inputs, labels, apply_fn, tx_update_fn): def loss_fn(params): logits, new_model_state = apply_fn( {**variables, 'params': params}, inputs, mutable=['batch_stats']) loss = xent_loss(logits, labels) return loss, new_model_state variables, params = variables.pop('params') (loss, new_model_state), grads = jax.value_and_grad(loss_fn, has_aux=True)( params) updates, new_opt_state = tx_update_fn(grads, opt_state, params) new_params = optax.apply_updates(params, updates) new_variables = {**variables, **new_model_state, 'params': new_params} return new_opt_state, new_variables, loss opt_state = tx.init(variables['params']) for batch in ds.as_numpy_iterator(): opt_state, variables, loss = train_step( opt_state, variables, batch['image'], batch['label'], model.apply, tx.update) print(loss) ``` Remarks: - Since `tx.update()` only transforms the gradient, we still need to call `optax.apply_updates()` to apply these transformed gradients to the parameters. - Compared with the [previous API], we can now keep the entire `variables` including the `params` as an input and output to the `train_step()`. - Splitting `params` from `variables` is still necessary inside the train step because we only want to compute gradients with respect to `params` and not the entire `variables`. - We can still log internal optimizer state, such as the learning rate, as long as Optax transformations expose that information in their respective state. For example, `optax.scale_by_schedule()` currently only exposes `opt_state.count` but could easily be extend to also expose the `step_size`. The same is true for internal optimizer states that change over time. ## Multi Optimizer [Multi Optimizer]: #multi-optimizer The [previous API] defined `flax.optim.MultiOptimizer` for processing different parts of the parameter tree with different optimizers: ```python biases_traversal = flax.optim.ModelParamTraversal( lambda path, _: path.endswith('/bias')) not_biases_traversal = flax.optim.ModelParamTraversal( lambda path, _: not path.endswith('/bias')) optimizer_def = flax.optim.MultiOptimizer( (biases_traversal, flax.optim.GradientDescent(learning_rate=0.1)), (not_biases_traversal, flax.optim.GradientDescent(learning_rate=0.05)), ) ``` Note how we first define a traversal that selects parameters based on their path (which is the concatenation of module scopes and variable name), and then create a `MultiOptimizer` that binds a different optimizer for each of these separate traversals. Optax has recently implemented `optax.masked()` that can be used for specifying gradient transformations that only applied to a subset of the gradients: ```python def flattened_traversal(fn): def mask(data): flat = traverse_util.flatten_dict(data) return traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()}) return mask tx = optax.chain( optax.masked(optax.sgd(learning_rate=0.1), mask=flattened_traversal(lambda path, _: path[-1] == 'bias')), optax.masked(optax.sgd(learning_rate=0.05), mask=flattened_traversal(lambda path, _: path[-1] != 'bias')), ) ``` ## Train State [Train State]: #train-state In Flax it is common to hand around a `TrainState` object that can then be used for checkpointing. This simplifies the above [Optax training step] a bit by reducing the number of arguments and getting rid of the `static_argnums`. We can define a `TrainState` dataclass that wraps the common pattern of updating the optimizer state and parameters by applying the gradients. ```python # Small helper class in flax.training class TrainState(flax.struct.PyTreeNode): step: int apply_fn: Callable = flax.struct.field(pytree_node=False) params: flax.core.FrozenDict[str, Any] tx: optax.GradientTransformation = flax.struct.field(pytree_node=False) opt_state: optax.OptState def apply_gradients(self, *, grads, **kwargs): updates, new_opt_state = self.tx.update( grads, self.opt_state, self.params) new_params = optax.apply_updates(self.params, updates) return self.replace( step=self.step + 1, params=new_params, opt_state=new_opt_state, **kwargs, ) @classmethod def create(cls, *, apply_fn, params, tx, **kwargs): opt_state = tx.init(params) return cls( step=0, apply_fn=apply_fn, params=params, tx=tx, opt_state=opt_state, **kwargs, ) ``` Users can then derive from this dataclass and add more fields, for example mutable model state: ```python from flax.training import train_state class TrainState(train_state.TrainState): batch_stats: flax.core.FrozenDict[str, Any] ``` With this the [Optax Training Step] becomes: ```python @jax.jit def train_step(state, inputs, labels): def loss_fn(params): outputs, new_model_state = state.apply_fn( {'params': params, 'batch_stats': state.batch_stats}, inputs, mutable=['batch_stats']) loss = xent_loss(outputs, labels) return loss, new_model_state (loss, new_model_state), grads = jax.value_and_grad( loss_fn, has_aux=True)(state.params) new_state = state.apply_gradients( grads=grads, batch_stats=new_model_state['batch_stats'], ) return new_state, loss state = TrainState.create( apply_fn=model.apply, params=variables['params'], tx=tx, batch_stats=variables['batch_stats'], ) for batch in ds.as_numpy_iterator(): state, loss = train_step(state, batch['image'], batch['label']) ``` The train step without mutable state reduces to: ```python @jax.jit def train_step(state, inputs, labels): def loss_fn(params): outputs = state.apply_fn({'params': params}, inputs) loss = xent_loss(outputs, labels) return loss loss, grads = jax.value_and_grad(loss_fn)(state.params) new_state = state.update(grads=grads) return new_state, loss state = flax.training.TrainState.create( apply_fn=model.apply, params=variables['params'], tx=tx, ) for batch in ds.as_numpy_iterator(): state, loss = train_step(state, batch['image'], batch['label']) ``` Remarks: - It is a common pattern in Flax training loops to have a `TrainState` dataclass that is updated with new state after every step. - The simple solution proposed in `flax.training.train_state` an be extended with additional data, but advanced usecases (e.g. multiple different models and/or optimizers) are not supported. Users should instead fork the dataclass and re-implement it to their needs. - As opposed to the `Optimizer` abstraction in the [previous API], the `TrainState` now directly contains the `.params`, without having to to through `.optimizer` # Previous API [previous API]: #previous-api ## Optimizer and OptimizerDef [Optimizer and OptimizerDef]: #optimizer-and-optimizerdef The optimizer itself would be implemented by creating a new class derived from `OpimizerDef`: ```python # flax/optim/momentum.py @flax.struct.dataclass class _MomentumHyperParams: learning_rate: jnp.ndarray beta: jnp.ndarray @flax.struct.dataclass class _MomentumParamState: momentum: np.ndarray class Momentum(flax.optim.OptimizerDef): def __init__(self, learning_rate=None, beta=0.9): super().__init__( _MomentumHyperParams(learning_rate, beta) ) def init_param_state(self, param): return _MomentumParamState(jnp.zeros_like(param)) def apply_param_gradient(self, step, hyper_params, param, state, grad): del step assert hyper_params.learning_rate is not None new_momentum = state.momentum * hyper_params.beta + grad new_params = param - hyper_params.learning_rate * new_momentum return new_params, _MomentumParamState(new_momentum) ``` Remarks: - Note the relationship between `OptimizerDef` and `Optimizer` : When the function `Optimizer.apply_gradient()` is called from the user code, it calls into `OptimizerDef.apply_gradient()` (among other things) which in turn will call `OptimizerDef.apply_param_gradient()` (implemented by subclasses of `OptimizerDef`). - The functions `init_param_state()` and `apply_param_gradient()` are called for every leaf in the params/grads pytree. This makes it possible to write the calculations directly without `jax.tree_util.tree_map()`. - The interface was defined in pre-Linen without the distinction of `params` vs. other collections in `variables` in mind. The original API was elegant because one only needed to pass around the optimizer, which included the parameters, optimizer state, optimizer hyperparameters, and a reference to the `OptimizerDef` to perform the param/state update. ## Previous Training Step [Previous Training Step]: #previous-training-step An optimizer would first be constructed from its definition and the pytree of target params: ```python optimizer_def = flax.optim.Momentum(learning_rate=0.1, beta=0.9) optimizer = optimizer_def.create(variables['params']) ``` Then, the target variables would optimized in the train step (assuming a single non-params collection "batch_stats"): ```python def make_train_step(apply_fn): @jax.jit def train_step(optimizer, batch_stats, inputs, labels): def loss_fn(params): variables = {'params': params, 'batch_stats': batch_stats} logits, new_model_state = apply_fn( variables, inputs, mutable=['batch_stats']) loss = xent_loss(logits, labels) return loss, new_model_state['batch_stats'] (loss, new_batch_stats), grad = jax.value_and_grad(loss_fn, has_aux=True)( optimizer.target) lr = get_learning_rate(step) new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr) return new_optimizer, new_batch_stats, loss return train_step batch_stats = variables['batch_stats'] train_step = make_train_step(model.apply) for step, batch in enumerate(ds) optimizer, batch_stats, loss = train_step( optimizer, batch_stats, batch['image'], batch['label']) ``` Remarks: - Notice how `optimizer.apply_gradient()` can take additional arguments to update hyperparameters, such as learning rate from an independent function `get_learning_rate()` in this case. # Update Plan [Update Plan]: #update-plan 1. Finalize discussions on this FLIP 2. Add [equivalence tests] to Optax that guarantee that existing `flax.optim` optimizers return identical values with corresponding `optax` optimizers. 3. Update examples to use Optax and verify that they reach the same final performance with the same computational cost. 4. Port missing optimizers to Optax (e.g. Adafactor) - and verify above points. 5. Update all documentation (including README, Flax guided tour, HOWTOs, ...) to talk exclusively about Optax optimizers. 6. Create a transition guide for updating users from `flax.optim` to using Optax. This transition guide should also point to Optax's [equivalence tests] and the pull requests updating the examples. 7. Mark optimizers in `flax.optim` as deprecated. [equivalence tests]: https://github.com/deepmind/optax/blob/master/optax/_src/equivalence_test.py Note that all current Flax examples use an optimizer that is already available in Optax: | Example | Flax | Optax | Comments | | -------- | -------------- | ----------- | ----------------------------------- | | imagenet | optim.Momentum | optax.sgd | DynamicScale can be used unchanged. | | mnist | optim.Momentum | optax.sgd | | | nlp_seq | optim.Adam | optax.adamw | | | pixelcnn | optim.Adam | optax.adam | | | ppo | optim.Adam | optax.adam | | | seq2seq | optim.Adam | optax.adam | | | vae | optim.Adam | optax.adam | | | wmt | optim.Adam | optax.adamw | | (Flax's Adam implementation has an optional parameter for weight decay, but in Optax Adam with and without weight decay are two different aliases.) # Appendix [Appendix]: #appendix ## Setup Code [Setup Code]: #setup-code The following setup code can be used for running the code snippets in this FLIP: ```python import functools from typing import Callable, Sequence import jax import jax.numpy as jnp import flax import flax.linen as nn import tensorflow as tf import tensorflow_datasets as tfds def pp(features): return { 'image': tf.cast(features['image'], tf.float32) / 255 - 0.5, 'label': features['label'], } class Model(nn.Module): @nn.compact def __call__(self, inputs): x = inputs.reshape([inputs.shape[0], -1]) x = nn.normalization.BatchNorm(True)(x) x = nn.Dense(10)(x) x = nn.log_softmax(x) return x def onehot(labels, num_classes, on_value=1.0, off_value=0.0): x = (labels[..., None] == jnp.arange(num_classes)[None]) x = jax.lax.select( x, jnp.full(x.shape, on_value), jnp.full(x.shape, off_value)) return x.astype(jnp.float32) def xent_loss(logits, labels): return -jnp.sum( onehot(labels, num_classes=10) * logits) / labels.size def get_learning_rate(step): return 0.1 model = Model() rng = jax.random.key(0) ds = tfds.load('mnist')['train'].take(160).map(pp).batch(16) batch = next(iter(ds)) variables = model.init(rng, jnp.array(batch['image'][:1])) jax.tree_util.tree_map(jnp.shape, variables) ``` ================================================ FILE: docs_nnx/flip/1777-default-dtype.md ================================================ # FLIP: Default dtypes - Start Date: 2022-01-11 - FLIP PR: [#1776](https://github.com/google/flax/pull/1776) - FLIP Issue: [#1777](https://github.com/google/flax/issues/1777) - Status: Implemented ## Summary This FLIP proposes to replace the default dtype which is currently fixed to float32, and instead use the JAX type promotion results to derive a default dtype from the input and parameters of a layer. ## Motivation Currently, Linen Modules always produce `module.dtype` (defaults to float32) outputs regardless of input and parameter dtypes. Half-precision types like float16 and bfloat16 are supported by explicitly passing the half-precision type to each Module. The way this is currently implemented is that each Module has a dtype argument with float32 as the default value. The layer guarantees that this dtype will be the return type of the result returned by `__call__`. The current behavior is problematic and results in silent bugs, especially for dtypes that do not fit inside float32 (complex, float64). Also, the Linen dtype behavior is significantly different from how NumPy and by extension JAX handle dtypes. ### Dtypes in JAX JAX uses a NumPy-inspired [dtype promotion](https://github.com/jax-ml/jax/blob/main/jax/_src/dtypes.py) mechanism as explained [here](https://jax.readthedocs.io/en/latest/type_promotion.html?highlight=lattice#type-promotion-semantics). The type promotion rules are summarized by the following type lattice: ![JAX type promotion lattice](https://jax.readthedocs.io/en/latest/_images/type_lattice.svg) ## Dtypes in Linen Besides input arguments, state and in particular parameters could affect dtype promotion. For example: we might feed a float64 input to a Dense layer with float32 parameters. Currently, the result would be truncated to float32. If the input is a complex number the result is even worse because the imaginary part will be silently dropped when casting to float32. By using the dtype promotion rules already available in JAX we can avoid this issue. A public API is available called `jax.numpy.result_dtype(*args)`, which returns the dtype that JAX would promote the given arguments to, in accordance with the type promotion lattice. For Linen layers the arguments would be the layer inputs together with the parameters. For example, for a linear layer this would be inputs, kernel, and bias. Note that there is also a `param_dtype` attribute in standard Linen Modules that also defaults to flaot32. This behavior is left untouched and encodes the common case of having float32 parameters. There are a few reasons why float32 is almost always the correct dtype for parameters: 1. Storing weights in half-precision often leads to underflow during optimization. 2. Double precision is rarely used because it severely slows down modern accelerators (GPU, TPU). Therefore, such a cost should be explicitly opted-in for. 3. Complex Modules are relatively uncommon. Even within complex networks, the complex inputs can be projected with a real matrix. # Implementation A simplified example implementation: ```python def promote_arrays(*xs, dtype): if dtype is None: dtype = jnp.result_type(*jax.tree_util.tree_leaves(xs)) return jax.tree_util.tree_map(lambda x: jnp.asarray(x, dtype), xs) Dtype = Any class Dense(nn.Module): features: int kernel_init: Callable bias_init: Callable dtype: Optional[Dtype] = None param_dtype: Dtype = jnp.float32 @nn.compact def __call__(self, x): kernel = self.param("kernel", self.kernel_init, (x.shape[-1], self.features), self.param_dtype) bias = self.param("bias", self.bias_init, (self.features,), self.param_dtype) x, kernel, bias = promote_arrays(x, kernel, bias, dtype=self.dtype) return x @ kernel + bias ``` ## Half-precision dtypes Some layers don’t work with half-precision dtypes internally. For example: The normalization layers currently compute mean and variance in float32 even when a half-precision dtype is specified to avoid numerical issues. We can replicate this behavior by calling result_dtype with a dummy argument that has the minimum precision for the sub computation to work correctly. ## Backward compatibility This proposal causes some layers to behave differently in cases where the dtype is not specified to a Linen Module. By default, parameters are in float32. Therefore, passing in half or float32 precision inputs will cause a float32 dtype and no functional differences with current behavior. When passing complex or float64 precision, the result will no longer truncate the imaginary component or the precision. The silent truncation is problematic and has caused [user complaints](https://github.com/google/flax/issues/805#issuecomment-981468837). Therefore, this change can be considered a bugfix. Thus, although this proposal strictly speaking changes behavior it is unlikely to cause problems for users. There are 2 exceptions to this which should be rare and easy to fix: 1. A user relies on the enforced float32 to downcast a double precision value. 2. A user relies on the float32 to explicitly upcast a half precision value even though the weights are in half precision. ## Corner cases In this section we describe corner cases where the implementation of the proposal is not obvious. The two main concerns are how complex numbers are handled in existing layers and how to determine the dtype of state variables. **Autoregressive decoding cache** Currently, only attention implements autoregressive caching and the stored key and value mirror the dtype of the key and value passed to the layer. Forcing the cache dtype to be the same as the output dtype could result in reduced precision during cached decoding vs uncached. This seems undesirable. Decision: keep the current behavior. **Batch statistics** BatchNorm layers are often used with a half precision output dtype. However, calculating statistics is by default always done in float32 to avoid numerical precision issues and over/underflow for float16. With float64 this would actually cause a downcast so we should now use `np.promote_types(float32, dtype)` such that the precision is at least float32. The running batch statistics will be stored with the same dtype for consistency. **Complex number support** Currently, our complex number support is brittle because the default behavior is to truncate the output to the real part. This issue will be fixed by the automatic type promotion proposed in this FLIP. However, some layers require some additional thought to extend to complex numbers correctly: 1. Normalization layers use the complex conjugate to calculate norms instead of normal squaring. 2. Attention: It’s not exactly clear how the dot product and softmax are defined in this case. Raise an error on complex inputs. 3. Recurrent layers: might require special gating / activation functions to function correctly, but these can be specified by the user. # Discussion Summarizing the main points from the discussion: ## Consider implicit complex truncation an error Q: I'm wondering if we should always raise an error if one of the xs tree leaves is complex but dtype is not. Users should maybe remove imaginary part by themselves if that's really what they want to do. (Maybe it's a contrived example, but I can imagine cases where layers have their dtype set by parent modules based on assumptions without complex numbers in mind) A: This is worth considering in a follow-up CL but this might as well be solved in JAX directly where the safeguard would apply more generally. In NumPy this was also considered but abandoned because it is not backwards compatible. ## Dtype attribute names Q: Are the dtype and param_dtype arguments confusing? In particular, should dtype perhaps be called output_dtype to make the difference between the two dtypes more explicit? A: This would be a large and orthogonal change wrt to this proposal so leaving it out for now. Also, this breaks with the standard dtype argument in NumPY/JAX. Although dtype indeed constrains the output dtype it is also a hint for the dtype we would like the computation to happen in. ================================================ FILE: docs_nnx/flip/2396-rnn.md ================================================ # RNN Flip - Start Date: 2022-08-18 - FLIP PR: [#2604](https://github.com/google/flax/pull/2604) - FLIP Issue: [#2396](https://github.com/google/flax/issues/2396) - Authors: Jasmijn Bastings (@bastings) and Cristian Garcia (@cgarciae) ## Summary This FLIP adds support for higher-level recurrent layers (RNN, GRU, LSTM) that can help users process input sequences using the recurrent cells already available in Flax. ## Motivation Implementing well known recurrent architectures is tricky and prone to user errors, even a simple LSTM layers involves the manual creation and handling of the carry/memory and correctly setting up `nn.scan`: ```python @nn.compact def __call__(self, x): LSTM = nn.scan( nn.LSTMCell, variable_broadcast="params", split_rngs={"params": False} ) carry = LSTM.initialize_carry( jax.random.key(0), batch_dims=x.shape[:1], size=self.hidden_size ) carry, x = LSTM()(carry, x) return x ``` Slightly more complicated cases involving padding like in the [seq2seq](https://github.com/google/flax/blob/main/examples/seq2seq/models.py) example require even more work but couple potentially be simplified to a couple of lines with the right abstractions. We propose providing users with clean, correct, and efficient abstractions to use recurrent cells. ## Requirements * **Masking**: We need to support a batch of sequences that contain padding at the end of each sequence. * We do not intend to support non-contiguous padding, i.e. padding that is not at the end of a sequence, for performance reasons, except in the case of packing (see below). * **Bidirectionality**: The ability to process a sequence in both the forward and reverse directions, respecting padding (i.e., the reverse direction should start with the actual inputs, not with padding values). * **Performance**: The proposed classes should be benchmarked to provide the best performance in terms of step time and/or memory use. * **Recurrent Dropout**: Support for recurrent dropout in cells (e.g. dropout on the state of the cell). ## Implementation ### High-level structure We propose to have these 3 levels of abstraction: * **Cells (unchanged)**: all RNNCellBase subclasses such as LSTMCell and GRUCell, these implement the stepwise logic. These already exist in Flax today. * **Layers (new)**: a class (RNN) that takes a cell and scans over a sequence respecting possible padding values and optionally also allows packed sequences. * **Bidirectional (new)**: a single class that takes a forward and a backward RNN instance and correctly processes the input sequence in both directions and merges the results. ### Example of proposed API We start with a code example of what you could do with the proposed API, and then we discuss the API in detail below. ```python cell = nn.LSTMCell() # Encodes a batch of input sequences. carry, outputs = nn.RNN(cell, cell_size)(inputs, seq_lengths) ``` A Bidirectional layer with a LSTM RNNs for the forward and backward directions respectively would look like this: ```python forward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32) backward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32) # Bidirectional combinator. bi_rnn = nn.Bidirectional(forward_rnn, backward_rnn) # Encodes a batch of input sequences in both directions. carry, outputs = bi_rnn(inputs, seq_lengths) ``` Next we will discuss `RNN`, `Bidirectional`, and proposed changes to `RNNCellBase`. ### RNNBase The `RNNBase` class serves as a base class for the `RNN` class, it specifies the API that all RNN layers should implement to be compatible with the `Bidirectional`. `RNNBase` contains the `__call__` and `flip_sequences` methods: ```python class RNNBase(Protocol): def __call__( self, inputs: jax.Array, *, initial_carry: Optional[Carry] = None, init_key: Optional[random.KeyArray] = None, seq_lengths: Optional[Array] = None, return_carry: Optional[bool] = None, time_major: Optional[bool] = None, reverse: Optional[bool] = None, keep_order: Optional[bool] = None, ) -> Union[Output, Tuple[Carry, Output]]: ... ``` Where: * `inputs`: the input sequence. * `initial_carry`: the initial carry, if not provided it will be initialized using the cell's :meth:`RNNCellBase.initialize_carry` method. * `init_key`: a PRNG key used to initialize the carry, if not provided ``jax.random.key(0)`` will be used. Most cells will ignore this argument. * `seq_lengths`: an optional integer array of shape ``(*batch)`` indicating the length of each sequence, elements whose index in the time dimension is greater than the corresponding length will be considered padding and will be ignored. * `return_carry`: if ``return_carry=False`` (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence. * `time_major`: if ``time_major=False`` (default) it will expect inputs with shape ``(*batch, time, *features)``, else it will expect inputs with shape ``(time, *batch, *features)``. * `reverse`: if ``reverse=False`` (default) the sequence is processed from left to right and returned in the original order, else it will be processed from right to left, and returned in reverse order. If ``seq_lengths`` is passed, padding will always remain at the end of the sequence. * `keep_order`: if ``keep_order=True``, when ``reverse=True`` the output will be reversed back to the original order after processing, this is useful to align sequences in bidirectional RNNs. If ``keep_order=False`` (default), the output will remain in the order specified by ``reverse``. * `Returns`: if ``return_carry=False`` (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence. ### RNN The `RNN` module inherits from `RNNBase`, it main function is to apply an `RNNCellBase` instance over a batch of input sequences, it can be used with any type of cell (e.g., `GRUCell`, `LSTMCell`, etc). It accepts the following parameters: ```python class RNN(RNNBase): cell: RNNCellBase, cell_size: int | Tuple[int, ...] time_axis: int = -2, variable_axes = FrozenDict(), variable_broadcast: CollectionFilter = 'params' variable_carry: CollectionFilter = False split_rngs = FrozenDict({'params': False}) # implement RNNBase ... ``` Attributes like `variable_axes`, `variable_broadcast`, `variable_carry`, and `split_rngs` are directly passed to `nn.scan`, their default values are set such that common cells like `LSTMCell` and `GRUCell` work out of the box. ### Masking `seq_lengths` is defined as an integer array of shape `(*batch,)` indicating the length of each sequence.
Discussion There are various masking formats found in other frameworks, here are some of the most popular ones: * **Binary masking**: specifies per-sample and timestep whether that data point should be included or not in the computation, it can be non-contigous (e.g., [1, 1, 0, 1]). This is used by Keras. * **Sequence length masking**: specifies per-sample the number of non-padding examples contained in the sequence, any padding contained in the sequence should be stacked at the end. This is used by FlaxFormer. * **Segmentation Mask**: specifies row and timestep to which sample the data point belongs to, this format allows more than one sample per row which potentially reduces the total amount of padding needed (e.g. [1, 1, 1, 2, 2, 0, 0]). Pytorch uses this representation (see [pack_padded_sequence](https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pack_padded_sequence.html)). While Sequence packing (see [LM1B example](https://github.com/google/flax/blob/main/examples/lm1b/input_pipeline.py#L90-L92)) is is more powerful, its implementation is more complex and it is not clear whether it is worth the effort. The simplest format is sequence length masking, which is the one we propose to use.
### Bidirectional Bidirectional processing can be achieved via a Module that accepts a `forward_rnn` Module and a `backward_rnn` Module, both of which should be `RNN` instances, in order to process the input sequence in both directions. Here we present some pseudo code of the implementation: ```python def __call__(self, inputs, seq_lengths): # Encode in the forward direction. carry_forward, outputs_forward = self.forward_rnn( inputs, seq_lengths=seq_lengths, return_carry=True, reverse=False, ) # Encode in the reverse order. carry_backward, outputs_backward = self.backward_rnn( inputs, seq_lengths=seq_lengths, return_carry=True, reverse=True, # process in reverse order keep_order=True, # but return the sequence in the original order ) # Merge both sequences. outputs = jax.tree.map(self.merge_fn, outputs_forward, outputs_backward) return (carry_forward, carry_backward), outputs ``` Here `merge_fn` a function that takes both outputs and fuses them (`concat` by default). As showcased in the beginning of this document, usage would look like this: ```python forward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32) backward_rnn = nn.RNN(nn.GRUCell(), cell_size=32) # Bidirectional combinator. bi_rnn = nn.Bidirectional(forward_rnn, backward_rnn) # Encodes a batch of input sequences in both directions. carry, outputs = bi_rnn(inputs, seq_lengths) ``` ### Recurrent Dropout There are two main uses of dropout in RNNs: 1. Input dropout: regular dropout applied to the inputs, different for every step. 4. Recurrent dropout: applies dropout to a recurrent input/output, same for every step. Flax's `nn.scan` can easily express both types of dropout via `split_rns`, input dropout would split rngs while recurrent dropout would not. [#2540](https://github.com/google/flax/pull/2540) was introduces such that the `rng_name` in `nn.Dropout` can now be defined by the user, this way Cells could define both types of dropout e.g: ```python self.dropout = nn.Dropout(...) # input dropout self.recurrent_dropout = nn.Dropout(..., rng_collection='recurrent_dropout') ``` Based on this, `nn.scan` / `nn.RNN` can now specify `split_rngs` accordingly e.g: ``` nn.scan(scan_fn, ..., split_rngs={'dropout': True, 'recurrent_dropout': False}) ``` # Future ideas
show ### Sequence Packing Allow packing multiple sequences to make efficient use of space/memory. This might result in a trade-off where step time is higher (because at each step we need to check whether we are starting a new sequence and reset the carry/initial state), but where less padding is used increasing efficiency overall. ### RNNCell redesign #### Make initialize_state an instance method First altenative is to make `initialize_carry` a instance method. With this change hyperparameters can be passed directly to the cell, it signature would look like this: ```python def initialize_carry(self, sample_input) -> Carry: ... ``` Usage would look like this: ```python LSTM = nn.scan( nn.LSTMCell, variable_broadcast='params', split_rngs={'dropout': True}) lstm = LSTM(features=32) carry = lstm.initialize_carry(x[:, 0]) carry, y = lstm(carry, x) ``` #### Remove initialize_carry An alternative is to remove `initialize_carry` entirely and have the carry state be handled as a carry collection. This would simplify usage quite a bit: ```python LSTM = nn.scan( nn.LSTMCell, variable_broadcast='params', split_rngs={'dropout': True}) y = LSTM(features=32)(carry, x) ``` However, this would require `nn.scan` to support initialization of carry collections which is currently not possible. Also, users would have to specify that a collection is mutable e.g. `mutable=['carry']`, even if they are not interested in the output carry state.
================================================ FILE: docs_nnx/flip/2434-general-metadata.md ================================================ # FLIP: Axis Metadata - Start Date: 2022-08-08 - FLIP Issue: [#2434](https://github.com/google/flax/issues/2434) - FLIP PR: [#2435](https://github.com/google/flax/pull/2435) - Status: Proposal ## Summary This FLIP proposes to extend Flax's variable collections with a generic axis metadata API. The core of the API is an abstract base class that is recognized by lifting transformations that can add an axis (vmap, scan). Users can extend the base class to keep track of per-axis metadata in a way that works with lifted transformations. ## Motivation Generally, there is no way in Flax to track metadata for variables across lifted transformations. Axis metadata is used to keep track of semantic information about axes into other (Flax independent) APIs. For example, optimizers like AdaFactor can be configured on a per-axis level and partitioning APIs in JAX like xmap or pjit require per variable annotations to map effectiently to parallel hardware. Currently, there is an experimental [API](https://github.com/google/flax/blob/main/flax/linen/partitioning.py) supporting partitioning annotations with wrappers around lifted transforms that change axes (``nn.scan_with_axes``, ``nn.vmap_with_axes``) and a special APIs to create variables (``param_with_axes`` and ``variable_with_axes``). The experimental partitioning API stores the metadata in a separate collection named "[collection]_axes". The experimental API has a number of shortcomings that we like to solve: 1. The current API works for tracking PartitionSpecs but not for other types of metadata like optimizer annotations. 2. The implementation using an "xxx_axes" collection requires error-prone and non-composable string manipulation. 3. Special, partioning-aware variable creators and lifted transforms are required 4. The partioning API is hard to use with pre-existing Modules that aren't partioning aware. ## Proposal To generalize metadata tracking and keep the specific metadata out of core Flax we propose the following abstract base class: ```python TAxisMetadata = TypeVar("TAxisMetadata", bound="AxisMetadata") class AxisMetadata(metaclass=abc.ABCMeta): """Abstract base class for boxed Metadata. ``AxisMetadata`` enables arbitrary, per axis metadata for variables. By using ``unbox`` the metadata is stripped away to obtain the original variables. By using unboxing, most code handling variables does not need to handle ``AxisMetadata`` specifically, but can directly operate on the JAX arrays that they wrap. Additionally, ``AxisMetadata`` supports updating metadata whenever an axis is added or removed by a functional transformation (e.g.: ``nn.scan`` or ``nn.vmap``) using the ``add_axis`` and ``remove_axis`` methods. By extending ``AxisMetadata``, custom metadata can be stored. See ``Partitioned`` for a specific implementation. """ @abc.abstractmethod def unbox(self) -> Any: """Returns the content of the AxisMetadata box. Note that unlike ``meta.unbox`` the unbox call should recursively unbox metadata. It should simply return value that it wraps directly even if that value itself is an instance of AxisMetadata. In practise, AxisMetadata subclasses should be registred as PyTree nodes to support passing instances to JAX and Flax APIs. The leaves returned for this note should correspond to the value returned by unbox. Returns: The unboxed value. """ pass @abc.abstractmethod def add_axis(self: TAxisMetadata, index: int, params: Dict[Any, Any]) -> TAxisMetadata: """Adds a new axis to the axis metadata. Note that add_axis and remove_axis should act as each other's inverse (meaning: ``x.add_axis(i, p).remove_axis(i, p) == x``) Args: index: The position at which the new axis will be inserted params: An arbitrary dictionary of parameters passed by the transformation that introduces the new axis (e.g.: ``nn.scan`` or ``nn.vmap``). The user passes this dictionary as the `metadata_param` argument to the transformation. Returns: A new instance of the same type as self and with the same ``unbox`` content with updated axis metadata. """ pass @abc.abstractmethod def remove_axis(self: TAxisMetadata, index: int, params: Dict[Any, Any]) -> TAxisMetadata: """Removes an axis from the axis metadata. Note that add_axis and remove_axis should act as each other's inverse (meaning: ``x.remove_axis(i, p).add_axis(i, p) == x``) Args: index: The position of the axis that is to be removed params: An arbitrary dictionary of parameters passed by the transformation that introduced the axis (e.g.: ``nn.scan`` or ``nn.vmap``). The user passes this dictionary as the `metadata_param` argument to the transformation. Returns: A new instance of the same type as self and with the same ``unbox`` content with updated axis metadata. """ pass ``` We call this type of class wrapping a value and keeping track of some additional data a **box**. By defining an abstract base class for this box, the API does not need to be aware of the specifics of the metadata that is tracked. This should make the API future proof and modular. The ``add_axis`` and ``remove_axis`` method return an instance of their own type instead of mutating in-place. Typically, an implementation would be a ``flax.struct.PyTreeNode`` because the box should still be a valid JAX value and must therefore be handled by the PyTree API. Calling ``jax.tree.map`` on a boxed value will simply map over the value in the box. The lifted transforms that need to handle metadata will call ``jax.tree.map(..., is_leaf=lambda x: isinstance(x, AxisMetadata))`` to find the AxisMetadata instances within a PyTree. Advantages of the boxing approach: 1. Boxing can be used outside of Flax and metadata is automatically "inherited". For example, the optimizer state will have the same partitioning spec as the parameters, because the state is initialized using a ``jax.tree.map`` over the boxed parameters. 2. Boxes are composable. 3. Boxing avoids string manipulation and generally avoids having to handle additional auxiliary collections like "param_axes" in the current partitioning API. 4. No need to lift metadata collections separately. Disadvantages: 1. Adding the boxes changes the PyTree hierarchy and introduces dataclasses within the otherwise plain, nested dict of variables. 3. Custom Pytree nodes have a small runtime overhead. It's hard to observe this in practise because JAX calls are async. ### Init syntax Boxes can be created directly by the init function of a variable. Therefore, we propose to create metadata using higher-order initializers. The main advantage of this is that we can decouple metadata handling completely from the Module definition. Also, most Modules already overwrite attributes to override the default initialzers so users can add metadata to existing Modules without requiring any code changes. To illustrate this, let's consider a metadata class that keeps track of PartitionSpecs used by ``pjit``: ```python class Partitioned(flax.struct.PyTreeNode, AxisMetadata): value: Any names: Tuple[Optional[str], ...] = flax.struct.field(pytree_node=False) def add_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata: axis_name = self._get_partition_name(params) names = list(self.names) names.insert(index, axis_name) return self.replace(names=tuple(names)) def remove_axis(self, index: int, params: Dict[Any, Any]) -> TAxisMetadata: axis_name = self._get_partition_name(params) names = list(self.names) assert names.pop(index) == axis_name return self.replace(names=tuple(names)) def with_partitioning(init_fn, names): def wrapper(*args, **kwargs): return Partitioned(init_fn(*args, **kwargs), names) return wrapper ``` Here we also defined a small utility called ``with_partitioning`` that we can use to wrap existing initialzers to add metadata: ```python # init kernel with lecun normal and split the output features over the data axis partitioned_dense = nn.Dense(features, kernel_init=with_partitioning(nn.initializers.lecun_normal, (None, "data"))) ``` Initializing a model that creates partitioned weights would result in the following variable structure: ```python variables = partitioned_dense.init(rng, jnp.ones((4,))) jax.tree.map(np.shape, variables) # => {"params": {"kernel": Partitioned(value=(4, 8), names=(None, "data")), bias: (8,)}} ``` The variable tree with metadata can be used to integrate with other libraries and APIs. For example, we can turn the ``Partitioned`` metadata into ``jax.pjit`` sharding annotations: ```python def to_sharding_spec(x): if isinstance(x, Partitioned): return PartitionSpec(*x.names) else: # fully replicated return PartitionSpec() # Result: {"params": {"kernel": PartitionSpec(None, "data"), bias: PartitionSpec()}} variables_pspec = jax.tree.map(to_sharding_spec, variables, is_leaf=lambda x: isinstance(x, Partitioned)) ``` ### Unbox syntax Metadata typically doesn't need to be handled by Modules directly. Therefore, we propose to make Modules agnostic to Metadata boxes by default. The ``unbox`` method can be used to unpack a variable such that only the original JAX arrays remain. Users can manually call unbox but to make sure Module classes don't have to call it everywhere we add an unbox keyword arg to variable returning APIs (e.g.: ``.param``, ``.variable``, ``.get_variable``). The keyword arg ``unbox`` will default to ``True`` such that a Modules are metadata agnostic by default. This also means existing Modules will be backward compatible with the new API. ```python kernel = self.param("kernel", self.kernel_init, shape) # No AxisMetadata instances kernel_box = self.get_variable("param", "kernel", unbox=False) # AxisMetadata boxes are preserved ``` ### Lift syntax When calling a lifted transformation that adds an axis you will now be able to pass a dictionary with arguments. These params will be passed to ``AxisMetadata`` add_axis/remove_axis callbacks: ```python nn.scan(..., variable_axes={"params": 0}, metadata_params={nn.Partitioned.AXIS_NAME: "layers"}) ``` A dict is used such that users can add their own arguments to custom AxisMetadata classes. ================================================ FILE: docs_nnx/flip/2974-kw-only-dataclasses.md ================================================ # FLIP: kw_only dataclasses Authors: Brennan Saeta, Ivy Zheng - Start Date: Mar 23, 2023 - FLIP Issue: [TBD] - FLIP PR: #2974 - Status: Implementing ## Summary Python 3.10 adds support for `kw_only` dataclasses. Subclasses of `flax.linen.Module` are automatically converted to `dataclasses` on users' behalf, but today, Flax doesn't allow setting the `kw_only` parameter to this dataclass transform, even if users are running Python 3.10. This proposal allows users to use this new feature with `nn.Module`'s. ## Motivation In larger Flax-based codebases (e.g. [`PaxML`](https://github.com/google/paxml) / [`Praxis`](https://github.com/google/praxis)), it’s not uncommon to define an (abstract) subclass of nn.Module that contains shared functionality that is itself further subclassed for specific implementations (e.g. [`BaseLayer`](https://github.com/google/praxis/blob/main/praxis/base_layer.py), or [`StackedTransformerRepeat`](https://github.com/google/praxis/blob/81479b260fcc13de8549cdbfb0fdf5c3f188ac90/praxis/layers/transformers.py#L1836) which is further subclassed by [`PipelineCompatibleStackedTransformerRepeat`](https://github.com/google/praxis/blob/81479b260fcc13de8549cdbfb0fdf5c3f188ac90/praxis/layers/transformers.py#L2198)). Often, these parent types define hyperparameters (constructor arguments), often with default values. Without `kw_only` on the `dataclass` transform, default values must be specified for all child layers hyperparameters. This is suboptimal, because users could forget to set them when instantiating the modules. For example, `Child` must set a default value for `num_heads` (because a non-defaulted argument can’t come after a defaulted argument if they are positional), but no reasonable default is available: ```python class BaseLayer(nn.Module): mesh: Optional[jax.experimental.mesh.Mesh] = None def with_sharding(self, some_variable, some_sharding): if self.mesh: # Do something useful here. class Child(BaseLayer): num_heads: int # Don't want to have to set a default argument! def __call__(self, x): ... ``` Note: Flax already has this problem, which is why `nn.Module` has its own fancy `kw_only_dataclasses.dataclass` transform: it moves the `name` and `parent` dataclass fields to the end, so they can have defaults. ## Implementation To allow modules to optionally opt into this `kw_only` dataclass behavior, we leverage arguments to `__init_subclass__`. This would look as follows: ```python class BaseLayer(nn.Module, kw_only=True): ... class Child(BaseLayer): ... ``` The implementation of `nn.Module`’s `__init_subclass__` will be tweaked as follows: ```python class Module(ModuleBase): def __init_subclass__(self, kw_only: Optional[bool] = None): # ... if kw_only: if is_python_310_or_above(): dataclass_transform_args = {'kw_only': True} else: raise TypeError("Can't use `kw_only` before Py3.10.") else: dataclass_transform_args = {} kw_only_dataclasses.dataclass( cls, unsafe_hash='__hash__' not in cls.__dict__, repr=False, **dataclass_transform_args) ``` ### Forward compatibility For future simplification, if `kw_only` is requested and the Python version is 3.10 or above, bypass the `kw_only_dataclasses` implementation and just use the regular `dataclasses` transform. That means we may one day remove `flax/linen/kw_only_dataclasses.py` when Flax rolls over 3.10. ## Discussion ### Aligned with Python `dataclass` We prefer to keep the behavior of `nn.Module`’s `kw_only` aligned with the Python dataclasses. Note that this means `kw_only` will not be inheritable, and this could happen: ```python class BaseLayer(nn.Module, kw_only=True): base_muliplier: Optional[int] = -1 class ChildLayer(BaseLayer): child_multiplier: int BaseLayer(2) # This will throw error ChildLayer(2) # But this will not ``` ### `flax.struct.dataclass` There’s a potentially related feature to allow `kw_only` to be specified for `flax.struct.dataclass`. This should be considered an orthogonal decision. ================================================ FILE: docs_nnx/flip/3099-rnnbase-refactor.md ================================================ # Refactor RNNCellBase in FLIP Authors: Cristian Garcia, Marcus Chiam, Jasmijn Bastings - Start Date: May 1, 2023 - FLIP Issue: [TBD] - FLIP PR: #3053 - Status: Implemented ## Summary This proposal aims to improve the usability of the `RNNCellBase` class by refactoring the `initialize_carry` method and other relevant components. ## Motivation Currently, `initialize_carry` is used to both initialize the carry and pass crucial metadata like the number of features. The API can be unintuitive as it requires users to manually calculate things that could typically be inferred by the modules themselves, such as the shape of batch dimensions and the shape of feature dimensions. ### Example: ConvLSTM The current API can be unintuitive in cases like `ConvLSTM` where a the `size` parameter contains both the input image shape and output feature dimensions: ```python x = jnp.ones((2, 4, 4, 3)) # (batch, *image_shape, channels) # image shape: vvvvvvv carry = nn.ConvLSTMCell.initialize_carry(key1, (16,), (64, 64, 16)) # batch size: ^^ ^^ :output features lstm = nn.ConvLSTMCell(features=6, kernel_size=(3, 3)) (carry, y), initial_params = lstm.init_with_output(key2, carry, x) ``` This FLIP will propose some changes to `initialize_carry` such that the previous example can be simplified to: ```python x = jnp.ones((2, 4, 4, 3)) # (batch, *image_shape, channels) lstm = nn.ConvLSTMCell(features=6, kernel_size=(3, 3)) carry = lstm.initialize_carry(key1, input_shape=x.shape) (carry, y), initial_params = lstm.init_with_output(key2, carry, x) ``` ## Implementation The proposal suggests the following changes: ### initialize_carry `initialize_carry` should be refactored as an instance method with the following signature: ```python def initialize_carry(self, key, sample_input): ``` `sample_input` should be an array of the same shape that will be processed by the cell, excluding the time axis. ### Refactor RNNCellBase subclasses `RNNCellBase` should be refactored to include the metadata required to initialize the cell and execute its forward pass. For `LSTMCell` and `GRUCell`, this means adding a `features` attribute that should be provided by the user upon construction. This change aligns with the structure of most other `Module`s, making them more familiar to users. ```python x = jnp.ones((2, 100, 10)) # (batch, time, features) cell = nn.LSTMCell(features=32) carry = cell.initialize_carry(PRNGKey(0), x[:, 0]) # sample input (carry, y), variables = cell.init_with_output(PRNGKey(1), carry, x) ``` ### num_feature_dims To simplify the handling of `RNNCellBase` instances in abstractions like `RNN`, each cell should implement the `num_feature_dims` property. For most cells, such as `LSTMCell` and `GRUCell`, this is always 1. For cells like `ConvLSTM`, this depends on their `kernel_size`. ## Discussion ### Alternative Approaches * To eliminate the need for `num_feature_dims`, `RNN` could support only a single batch dimension, i.e., inputs of the form `(batch, time, *features)`. Currently, it supports both multiple batch dimensions and multiple feature dimensions. * Another approach could be a complete redesign of how Flax deals with recurrent states. For example, a `memory` collection could be handled as part of the variables. However, this introduces challenges such as handling stateless cells during training, passing state from one layer to another, and performing initialization inside `scan`. ### Refactor Cost Initial TGP results showed 761 broken and 110 failed tests. However, after fixing one test, TGP results in 231 broken and 13 failed tests so there seems to be a lot of overlap between the broken tests. To minimize refactor costs, the current implementation will be kept for Google internal users under a deprecated name. This will allow users to migrate to the new API at their own pace. For Open Source users we should bump Flax version to `0.7.0` so existing users can continue to depend on `0.6.x` versions. ================================================ FILE: docs_nnx/flip/4105-jax-style-nnx-transforms.md ================================================ # JAX-style NNX Transforms - Authors: Cristian Garcia, Anselm Levskaya - Date: Jun/2024 - FLIP PR: #4107 - Status: Implementing ## Motivation NNX allows users to utilize Modules at the top level due to their eager initialization and self-contained state. This naturally leads users to want to use them with transforms and soon start playing with NNX transforms. Since NNX Modules resemble PyTrees in that they contain Arrays, new users often attempt to apply JAX conventions, for example: ```py @nnx.vmap(in_axes=(1, 0)) def f(m1: Module, m2: Module): ... ``` However, this can be misleading. Currently, NNX transforms follow Linen's convention of treating input Modules as a single unit (all Modules are split together to preserve shared references) and provide APIs for transforming that State separately. The previous example effectively translates to: ```py # this is what is really happening @nnx.vmap(in_axes=(IGNORE, IGNORE), state_axes={BatchStat: None, ...: 0}) def f(m1: Module, m2: Module): ... ``` Note that `IGNORE` is not a real symbol, but represents the fact that any value placed here won't affect the outcome, as Modules are replaced by empty PyTree placeholders (similar to `None`). The `state_axes` parameter controls how the State is vectorized through a mapping of high-level `Filter`s to their desired axes. In this example, `...` (ellipsis) is a filter that accepts everything, so by default all States are vectorized on the 0th axis. To express their original intention, users must resort to more complex custom filters that guess the index of each Module in the monolith. While this is straightforward in simple cases, users generally need to calculate the index (Modules appear in the order specified by `jax.tree.leaves` over the `args`): ```py select_m1 = lambda path, value: path[0] == 0 select_m2 = lambda path, value: path[0] == 1 # To select modules individually, you must create a filter (which can be tricky) @nnx.vmap(state_axes={select_m1: 1, select_m2: 0}) def f(m1: Module, m2: Module): ... ``` ## What if JAX conventions Just Worked™? This proposal aims to align NNX transforms with user's expectations based on their JAX experience, making the syntax work as intuitively as possible. The original example would function **as if** `m1` and `m2` were PyTrees vectorized in axes `1` and `0` respectively: ```py @nnx.vmap(in_axes=(1, 0)) def f(m1: Module, m2: Module): ... ``` The primary advantage of this approach is that for `vmap` and `scan`, we could eliminate the `state_axes` and `split_rngs` arguments, relying solely on the `in_axes` API. This syntax alone would likely suffice for 80-90% of use cases, as users tend to manage state in predictable ways. ### The Lift symbols To enable more fine-grained state control within each Module, we introduce the `Lift` API. By using special types containing State Filters in place of a tree prefix, state lifting can now be done **structurally**. This allows different Filters to be applied to different Modules in the arguments without the need for complex path-based filters. Ideally, each transform would support its own Lift type, adding the desired behavior through existing JAX APIs. For example, in `vmap`, we could allow `StateAxes` instances (vmap's Lift type) to be accepted by `in/out_axes` to control how substates are handled by mapping state `Filter`s to an axis specifier: ```py state_axes = StateAxes({Param: 1, BatchStat: None}) @nnx.vmap(in_axes=(state_axes, 0)) def f(m1: Module, m2: Module): ... ``` In this case, `m1`'s `Param`s are vectorized in axis `1` while its `BatchStat`s are broadcasted, and `m2`'s entire state is vectorized in axis `0`. For `nnx.grad`, we could allow `DiffState` to be used in the `argnums` parameter to specify both the position of the argument to be differentiated and a Filter specifying the differentiable State of the Module: ```py grads = nnx.grad(loss_fn, argnums=(DiffState(0, LoRAParam),))(model, x, y) ``` ## Rng Handling To simplify RNG state handling, we propose removing the separate `split_rngs` parameter in `vmap` and `scan`. Instead, we suggest introducing a new `nnx.split_rngs` API that would manage RNG handling before and after the transformation. This approach provides more explicit control to the user and aligns better with JAX transform behavior. ## Consistent Aliasing To ensure the correctness of transformations with objects that obey reference semantics, we must enforce consistent lifting/lowering specifications for all aliases of a reference. Transforms must adhere to two rules: 1. All aliases of a reference must receive the **exact same** lifting/lowering specification. 2. Captured references are not allowed on the output of transformed functions. For example: ```py @nnx.vmap(in_axes=(m1_axes, m2_axes, m1_axes), out_axes=m2_axes) def f(m1, m2, m1_alias): return m2 m2 = f(m1, m2, m1) ``` Here, `m1` has two input aliases as it is passed as the first and third input to `f`, but this is acceptable because `m1_axes` is assigned to both in `in_axes`. `m2` is passed as the second input and has an output alias, which is also acceptable because `m2_axes` is assigned in both `in_axes` and `out_axes`. Let's examine some examples of programs that should be **rejected** based on these criteria: ### Inconsistent input aliases Consider a function with two arguments `m1` and `m2` being vectorized in axis `0` and `1` respectively. Passing the same Module as both arguments would be inconsistent: ```py @nnx.vmap(in_axes=(0, 1)) def f(m1: Module, m2: Module): ... f(m, m) # This should be rejected ``` ### Inconsistent input / output aliases Now consider an identity function `g` under `vmap` with `in_axes=0` and `out_axes=1`. In JAX, this would result in transposing the arrays in the inputs: ```py @nnx.vmap(in_axes=0, out_axes=1) def g(m: Module): return m ``` While this appears correct, in NNX this behavior is not well-defined because shared mutable references behave as auxiliary outputs. Under the hood, `g` is converted into a function that has the inputs as an extra first output, and `out_axes` is set to the same values as `in_axes` for that output: ```py @nnx.vmap(in_axes=0, out_axes=(0, 1)) def g_real(m: Module): return m, m ``` This return structure reveals an inconsistency: we're attempting to lower `m` with both `out_axes=0` and `out_axes=1`. ### Inconsistent aliases in nested structures Similar issues can arise in less obvious cases, such as when `m` is contained within another structure: ```py @nnx.vmap(in_axes=0, out_axes=1) def f(m: Module): return SomeModule(m) ``` This means we must traverse the entire graph of both inputs and outputs to check for consistent assignments. The same problem occurs when passing shared reference inputs/outputs with different specifications: ```py shared = Shared() m1, m2 = Foo(shared), Foo(shared) @nnx.vmap(in_axes=(0, 1)) def f(m1, m2): # shared is passed through both ... ``` ### Captured Modules cannot be outputs Finally, let's consider the second consistent aliasing rule, which states that captured Modules cannot be outputs. The main issue here is that NNX needs to split all input references together to track changes, but captured Modules bypass this process. Treating them as new references would result in **implicit cloning**: ```py m = SomeModule() @nnx.vmap(out_axes=0, axis_size=5) def f(): return m assert m is not f() # implicit cloning ``` To preserve reference identity, we must disallow captured Modules as outputs. In practice, we can detect captured Modules using the trace level context machinery used to restrict stateful updates on Modules from a different level. ## Recap In this document, we have: * Discussed issues with the current implementation that make it unintuitive for JAX users. * Proposed refactoring NNX transforms to allow users to use regular JAX semantics when interacting with objects, removing extra arguments introduced by NNX transforms. * Introduced the use of Lift types in JAX APIs to compensate for the lack of a "prefix" notion in NNX objects, enabling independent lifting of Module substates. * Proposed a new `nnx.split_rngs` API to replace the `split_rngs` arguments in `vmap` and `scan`, making RNG handling an explicit operation and giving users more control. * Analyzed edge cases resulting from aliasing shared mutable references and proposed enforcing **consistent aliasing** on all transforms with semantics over the inputs. ================================================ FILE: docs_nnx/flip/4844-var-eager-sharding.md ================================================ - Start Date: 2025-09-12 - FLIP PR: [#4844](https://github.com/google/flax/pull/4844) # FLIP 4844: Variable eager sharding ## Summary [summary]: #summary Simplify the creation of sharded NNX models. When a sharding annotation is provided, all `nnx.Variable` creation will **require a mesh context** and automatically be sharded as annotated. See [GSPMD Guide](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html) for a comprehensive guide on how to make sharded NNX models. # Motivation To create a sharded model, user should only need to do this: ```python mesh = jax.make_mesh(((2, 4)), ("data", "model")) with jax.set_mesh(mesh): model = YourModelWithShardingAnnotations() ``` Instead of the current boilerplate combo of `nnx.jit`, `nnx.get_partition_spec`, `with_sharding_constraint` and `nnx.update`: ```python @nnx.jit def create_sharded_model(): model = YourModelWithShardingAnnotations() # Unsharded at this moment. state = nnx.state(model) # The model's state, a pure pytree. pspecs = nnx.get_partition_spec(state) # Strip out the annotations from state. sharded_state = jax.lax.with_sharding_constraint(state, pspecs) nnx.update(model, sharded_state) # The model is sharded now! return model mesh = jax.make_mesh(((2, 4)), ("data", "model")) with jax.set_mesh(mesh): sharded_model = create_sharded_model() ``` # Backward compatibility User can turn off this feature in two ways: * **Global config flag**: Run `flax.config.update('flax_always_shard_variable', False)` before running any NNX model initialization. * **Variable-specific flag**: Create a specific variable with metadata `eager_sharding=False`, such as: `nnx.Param(..., eager_sharding=False)`. # Flexibility options For debugging in a CPU environment, make a dummy mesh to run the model: ```python mesh = jax.make_mesh(((1, 1, 1)), ('your', 'axes', 'names')) with jax.set_mesh(mesh): ... ``` For JAX explicit mode, remove the `out_sharding=` annotation on the `nnx.Variable`. # Implementation [implementation]: #implementation When an `nnx.Variable` is created, check for the metadata `out_sharding`, and if present, check if under a valid global mesh context of was supplied with a valid mesh. If no, throw error; if yes, call `jax.lax.with_sharding_constraint` to apply sharding constraint on the value. Note that this only works in auto sharding mode. User should use JAX-level APIs to annotate shardings for explicit mode. ================================================ FILE: docs_nnx/flip/5310-tree-mode-nnx.md ================================================ # Tree Mode NNX Mar 4, 2026 Cristian Garcia, Samuel Anklesaria, Flax Team ## Motivation Current NNX APIs allow general graph structures and graph transformations, this includes: 1. Tracking Variable state updates 2. Handling shared references (graphs) 3. Supporting prefix filters (StateAxes, DiffState, StateSharding) 4. Propagating graph updates (static state and structure changes) While powerful, some of these capabilities (**3** and **4**) are beyond what JAX transform APIs offer and supporting them results in both internal complexity, harder to reason about code, and a larger set of APIs a user must learn. We wish to tackle all these issues by simplifying NNX. ## Proposal To do this we propose two things. First, the introduction of **Tree Mode NNX**: a reimplementation of the NNX APIs that only handles trees, assumes referential transparency, and has a more limited support for state updates. Concretely, this means: * Automatic state updates only for Variables in NNX transforms. * Tree structure assumed and enforced on all APIs (no sharing) * Modules treated as stateless pytrees (no graph updates). * Full JAX transform compatibility (remove [prefix filters](#prefix-filters): StateAxes, DiffState, StateSharding). Second, simplifying graph support. Graphs stand out as an important feature for some NNX users. However, we will be limiting support to **1** and **2**, meaning that prefix filters and graph updates will be dropped. This will make it such that tree and graph transforms can share the same underlying implementation and semantics while still allowing for a great deal of expressivity. ## Implementation Tree mode will be implemented on top of the current APIs by introducing a `graph` argument, when `True` graph support is enabled, when `False` only trees are supported and internals rely on `jax.tree.*` APIs. Additionally, a `graph_updates` argument will be added to NNX transforms, when `False` transforms will no longer propagate graph structure update (**4**) or support prefix filters (**3**). ```py def split(..., graph: bool | None = None) ... def jit(..., graph: bool | None = None, graph_updates: bool | None = None) ... ``` If `graph` or `graph_updates` are not provided, their default values will be taken from the `nnx_graph_mode` and `nnx_graph_updates` config flags respectively. These can be easily fetched and updated via `set_graph_mode` and `set_graph_updates`. ```py # status print(nnx.set_graph_mode.current_value()) print(nnx.set_graph_updates.current_value()) # set value nnx.set_graph_mode(True/False) nnx.set_graph_updates(True/False) # via env vars # NNX_GRAPH_MODE=true/false # NNX_GRAPH_UPDATES=true/false # context managers with nnx.set_graph_mode(True/False): ... with nnx.set_graph_updates(True/False): ... ``` The goal will be to have the default value for `nnx_graph_mode` and `nnx_graph_updates` to be set to `False`, thus enabling tree mode for new projects. Users that don’t want to migrate can use these flags to make sure their code continues to work with current features. ### Simple transforms These new transforms are highly simplified compared to current transforms, they are easier to implement and optimize, while supporting both trees and graphs. Given a user function f, most simplified transforms follow this pattern: ```py def transform_wrapper(*args): if graph: args = to_tree(args) check_no_aliases(args=args) @jax_transform def transformed_f(*args): updates, snapshot = updates_and_snapshot(args) if graph: args = from_tree(args) out = f(*args) if graph: out = to_tree(out) check_no_aliases(args=updates, out=out) updates = mask_variable_updates(updates, snapshot) return out, updates out, updates = transformed_f(*args) apply_variable_updates(args, updates) if graph: out = from_tree(out) return out ``` The transformed function tracks input Variable `updates`, applies f, and masks Variable updates (no updates for Variables that didn’t change). It also checks that there are no Variable aliases between the inputs and outputs (no shared references), and returns the user output plus Variable updates. The wrapper function calls the transformed function, applies the Variable updates to the input Variables, and returns the user output. To support graphs, we simply convert objects to a tree representation before passing them to jax, and back to graphs before passing them to the user code. ## Backward Compatibility When tree mode is on by default, code that relies on graphs, graph updates, and prefix filters will stop working. There are two ways to port existing code, the first is reverting the defaults config via `set_graph_mode` and `set_graph_updates` somewhere in the after the imports: ```py from flax import nnx ... nnx.set_graph_mode(True) nnx.set_graph_updates(True) ``` The previous implementation of the transform APIs will also be accessible via the `nnx.compat` module. They are implemented as partials that set `graph=True` and `graph_updates=True`: ```py nnx.compat.split = partial(nnx.split, graph=True) ... nnx.compat.jit = partial(nnx.jit, graph=True, graph_updates=True) ... ``` The above shortcuts will make it such that porting existing code (if needed) is as simple as performing some rewrites: `nnx.split` → `nnx.compat.split` `nnx.jit` → `nnx.compat.jit` … ## Breaking changes ### Prefix filters {#prefix-filters} Code that relies on prefix filters such as StateAxes, StateSharding, and DiffState will require some restructuring as JAX has no equivalent mechanisms (these were added to make Linen migration easier). The solution is to use `split` and `merge` to create state groups, and pass each group through their corresponding tree prefix on the jax transform. For example: ```py # previous code state_axes = nnx.StateAxes({some_filter: 0, ...: None}) @nnx.vmap(in_axis=state_axes, graph=True, graph_updates=True) def f(model): ... ``` This can be rewritten to `split` the model into two state groups using the previous filter, passing the groups as separate arguments, one vectorized and the other broadcasted, and using `merge` to reconstruct the model inside the transform. ```py # new code graphdef, vectorized, broadcasted = nnx.split(model, some_filter, ...) @nnx.vmap(in_axis=(0, None)) def f(vectorized, broadcasted): model = nnx.merge(graphdef, vectorized, broadcasted) ... ``` This is roughly how prefix filters were implemented under the hood. ### nnx.grad Code that uses `nnx.grad` will change in two ways: 1. The first argument will no longer be differentiated w.r.t. to `Param`s only, this is because `grad` used this prefix filter by default: `DiffState(0, Param)`. 2. The gradients of NNX Pytree/Module types will no longer be `State` types. Now they just follow JAX and return the same input type. Concretely it means that code like this: ```py # previous code def loss_fn(model: Foo): ... # uses argnums=nnx.DiffState(0, nnx.Param) grads = nnx.grad(loss_fn)(model) ``` Now has to explicitly use `split` and `merge` if to avoid calculating gradients for the non-differentiable state: ```py # new code def graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) def loss_fn(params, nondiff): model = nnx.merge(graphdef, params, nondiff) ... # uses argnums=0 grads = nnx.grad(loss_fn)(params, nondiff) ``` If there is no non-differentiable the `model` can be passed in directly but the gradients will now be of the same type: ```py # new code def loss_fn(model: Foo): ... # uses argnums=0 grads: Foo = nnx.grad(loss_fn)(model) ``` ### nnx.custom_vjp Previously `nnx.custom_vjp` did two particular things: 1. The backward function returned the gradients of the Variable updates (`m_updates_g`) along with the output gradient. 2. The tangent for nnx.Pytree/Module objects were of type `nnx.State`. For a `Foo` Module with `x: Param` and `y: Param` attributes, a simple example could look like this: ```py # previous code @nnx.custom_vjp def f(m: Foo): return jnp.sin(m.x) * m.y def f_fwd(m: Foo): return f(m), (jnp.cos(m.x), jnp.sin(m.x), m) def f_bwd(res, g): (m_updates_g,), out_g = g cos_x, sin_x, m = res m_g: nnx.State = nnx.clone(m_updates_g) # create copy m_g['x'][...] = cos_x * out_g * m.y m_g['y'][...] = sin_x * out_g return (m_g,) # State gradient ``` In the new implementation gradients for Variable updates are not returned, and the tangent type is the same as the input type (`Foo`), this matches the behavior of `jax.custom_vjp`: ```py # new code @nnx.custom_vjp def f(m: Foo): return jnp.sin(m.x) * m.y def f_fwd(m: Foo): return f(m), (jnp.cos(m.x), jnp.sin(m.x), m) def f_bwd(res, g): # no gradients for updates cos_x, sin_x, m = res m_g: Foo = nnx.clone(m) # create copy m_g.x[...] = cos_x * g * m.y m_g.y[...] = sin_x * g return (m_g,) # Foo gradient ``` Note that to avoid losing information, now differentiable Variables are not allowed to be updated inside `nnx.custom_vjp`. ### transform\_metadata Previously NNX transforms like `vmap` and `scan` had a `transform_metadata` metadata argument that allowed them to update the sharding metadata. ```py # old code @nnx.split_rngs(8) @nnx.vmap(in_axes=0, out_axes=0, transform_metadata={nnx.PARTITION_NAME: 'din'}) class create_stack(rngs): # 'din' added to out_sharding metadata return nnx.Variable(rngs.uniform((16,)), out_sharding=('dout',)) v_stack = create_stack(nnx.Rngs(0)) assert v_stack.shape == (8, 16) assert v_stack.out_shardings == ('din', 'dout') ``` The new simplified NNX transform implementations don’t support this argument. However, to keep supporting the behavior, a new `nnx.transform_metadata` transform is introduced that can be inserted to get back the same results. TODO: mention it works on `jax.vmap`. ```py # new code @nnx.split_rngs(8) @nnx.vmap(in_axes=0, out_axes=0) @nnx.transform_metadata(in_axes=0, out_axes=0, partition='din') class create_stack(rngs): # 'din' added to out_sharding metadata return nnx.Variable(rngs.uniform((16,)), out_sharding=('dout',)) v_stack = create_stack(nnx.Rngs(0)) assert v_stack.shape == (8, 16) assert v_stack.out_shardings == ('din', 'dout') ``` `transform_metada` accepts `in_axes` and `out_axes`, these should match the values passed to the corresponding transform. ### Module.sow Previously, `Module.sow` used graph updates to capture intermediate values during computations and propagate them outside, it was used in conjunction with `nnx.pop` to log and extract intermediates: ```py # old code class Foo(nnx.Module): def __call__(self, x): self.sow(nnx.Intermediate, "y_mean", jnp.mean(x)) return x model = Foo() result = model(x) intermediates = nnx.pop(model, nnx.Intermediate) # extract intermediate values ``` To achieve the same without graph updates we’ve added a new `nnx.capture` API which allows for a similar workflow. ```py # New Code class Foo(nnx.Module): def __call__(self, x): self.sow(nnx.Intermediate, "y_mean", jnp.mean(x)) return x model = Foo() result, intermediates = nnx.capture(model, nnx.Intermediate)(x) ``` In general, `nnx.capture` takes a function or Module to be transformed, a `nnx.Variable` subclass to collect, and an optional `init` argument to initialize the collected state, which will be stored within `nnx.Variable` objects. `nnx.capture` creates a `__captures__: tuple[Variable, ...]` attribute on each `Module` instance, each Variable in `__captures__` contains a dictionary which `sow` and `perturb` populate. ### Module.perturb Similarly, `Module.perturb` was previously used to extract the gradients of intermediate values. This was done in two steps: initializing a perturbation state by running a module once, and then passing the perturbation state as a differentiable target to `grad`. ```py class Model(nnx.Module): def __call__(self, x): x = self.perturb('grad_of_x', x) ... return y # old code @nnx.jit def train_step(model, optimizer, x, y): model(x) # Initialize perturbation state def loss_fn(model): y_pred = model(x) return jnp.mean((y_pred - y) ** 2) diff_state = nnx.DiffState(0, (nnx.Param, nnx.Perturbation)) grads = nnx.grad(loss_fn, argnums=diff_state)(model) grads, interm_grads = nnx.state(grads, nnx.Param, nnx.Perturbation) optimizer.update(model, grads) nnx.pop(model, nnx.Perturbation) # clean up perturbations return interm_grads ``` Similar pattern can be used with `nnx.capture` during both perturbation initialization and when running the forward pass to insert the differentiable perturbations state. In this version explicitly pass the `perturbs` state as a separate argument and use `argnums` to specify that both arguments are differentiable: ```py # new code @nnx.jit def train_step(model, optimizer, x, y): _, perturbs = nnx.capture(model, nnx.Perturbation)(x) # init perturbations def loss_fn(model, perturbs): y_pred = nnx.capture(model, init=perturbs)(x) return jnp.mean((y_pred - y) ** 2) grads, interm_grads = nnx.grad(loss_fn, argnums=(0, 1))(model, perturbs) optimizer.update(model, grads) return interm_grads ``` ================================================ FILE: docs_nnx/flip/README.md ================================================ # FLIP: Flax Improvement Process Most changes can be discussed with simple issues/discussions and pull requests. Some changes though are a bit larger in scope or require more discussion, and these should be implemented as FLIPs. This allows for writing longer documents that can be discussed in a pull request themselves. The structure of FLIPs is kept as lightweight as possible to start and might be extended later on. ## When you should use a FLIP - When your change requires a design doc. We prefer collecting the designs as FLIPs for better discoverability and further reference. - When your change requires extensive discussion. It's fine to have relatively short discussions on issues or pull requests, but when the discussion gets longer this becomes unpractical for later digestion. FLIPs allow to update the main document with a summary of the discussion and these updates can be discussed themselves in the pull request adding the FLIP. ## How to start a FLIP First, create an issue with the [FLIP label]. All pull requests that relate to the FLIP (i.e. adding the FLIP itself as well as any implementing pull requests) should be linked to this issue. Then create a pull request that consists of a copy of the `0000-template.md` renamed to `%04d-{short-title}.md` - with the number being the issue number. [FLIP label]: https://github.com/google/flax/issues?q=label%3AFLIP ================================================ FILE: docs_nnx/guides/blog.md ================================================ ### Do we need another JAX NN library? Hello, today I want to talk to you about a new JAX library that I have been working on, but before I do that, I wanted to discuss the topic: Do we need another JAX NN library? ### JAX Libraries JAX NN libraries come in a wide variety ranging from functional like Flax and Haiku, to Pytree-based like Equinox. ================================================ FILE: docs_nnx/guides/bridge_guide.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Use Flax NNX and Linen together\n", "\n", "This guide is for existing Flax users who want to make their codebase a mixture of Flax Linen and Flax NNX `Module`s, which is made possible thanks to the `flax.nnx.bridge` API.\n", "\n", "This will be helpful if you:\n", "\n", "* Want to migrate your codebase to NNX gradually, one module at a time;\n", "* Have external dependency that already moved to NNX but you haven't, or is still in Linen while you've moved to NNX.\n", "\n", "We hope this allows you to move and try out NNX at your own pace, and leverage the best of both worlds. We will also talk about how to resolve the caveats of interoperating the two APIs, on a few aspects that they are fundamentally different.\n", "\n", "**Note**:\n", "\n", "This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide.\n", "\n", "And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ[\"XLA_FLAGS\"] = '--xla_force_host_platform_device_count=8'\n", "\n", "from flax import nnx\n", "from flax import linen as nn\n", "from flax.nnx import bridge\n", "import jax\n", "from jax import numpy as jnp\n", "from jax.experimental import mesh_utils\n", "from typing import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submodule is all you need\n", "\n", "A Flax model is always a tree of modules - either old Linen modules (`flax.linen.Module`, usually written as `nn.Module`) or NNX modules (`nnx.Module`).\n", "\n", "An `nnx.bridge` wrapper glues the two types together, in both ways:\n", "\n", "* `nnx.bridge.ToNNX`: Convert a Linen module to NNX, so that it can be a submodule of another NNX module, or stand alone to be trained in NNX-style training loops.\n", "* `nnx.bridge.ToLinen`: Vice versa, convert a NNX module to Linen.\n", "\n", "This means you can move in either top-down or bottom-up behavior: convert the whole Linen module to NNX, then gradually move down, or convert all the lower level modules to NNX then move up." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Basics\n", "\n", "There are two fundamental difference between Linen and NNX modules:\n", "\n", "* **Stateless vs. stateful**: Linen module instances are stateless: variables are returned from a purely functional `.init()` call and managed separately. NNX modules, however, owns its variables as instance attributes.\n", "\n", "* **Lazy vs. eager**: Linen modules only allocate space to create variables when they actually see their input. Whereas NNX module instances create their variables the moment they are instantiated, without seeing a sample input.\n", "\n", "With that in mind, let's look at how the `nnx.bridge` wrappers tackle the differences.\n", "\n", "### Linen -> NNX\n", "\n", "Since Linen modules may require an input to create variables, we semi-formally supported lazy initialization in the NNX modules converted from Linen. The Linen variables are created when you give it a sample input.\n", "\n", "For you, it's calling `nnx.bridge.lazy_init()` where you call `module.init()` in Linen code.\n", "\n", "(Note: you can call `nnx.display` upon any NNX module to inspect all its variables and state.)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "class LinenDot(nn.Module):\n", " out_dim: int\n", " w_init: Callable[..., Any] = nn.initializers.lecun_normal()\n", " @nn.compact\n", " def __call__(self, x):\n", " # Linen might need the input shape to create the weight!\n", " w = self.param('w', self.w_init, (x.shape[-1], self.out_dim))\n", " return x @ w\n", "\n", "x = jax.random.normal(jax.random.key(42), (4, 32))\n", "model = bridge.ToNNX(LinenDot(64),\n", " rngs=nnx.Rngs(0)) # => `model = LinenDot(64)` in Linen\n", "bridge.lazy_init(model, x) # => `var = model.init(key, x)` in Linen\n", "y = model(x) # => `y = model.apply(var, x)` in Linen\n", "\n", "nnx.display(model)\n", "\n", "# In-place swap your weight array and the model still works!\n", "model.w.value = jax.random.normal(jax.random.key(1), (32, 64))\n", "assert not jnp.allclose(y, model(x))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`nnx.bridge.lazy_init` also works even if the top-level module is a pure-NNX one, so you can do sub-moduling as you wish:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "class NNXOuter(nnx.Module):\n", " def __init__(self, out_dim: int, rngs: nnx.Rngs):\n", " self.dot = nnx.bridge.ToNNX(LinenDot(out_dim), rngs=rngs)\n", " self.b = nnx.Param(jax.random.uniform(rngs.params(), (1, out_dim,)))\n", "\n", " def __call__(self, x):\n", " return self.dot(x) + self.b\n", "\n", "x = jax.random.normal(jax.random.key(42), (4, 32))\n", "model = bridge.lazy_init(NNXOuter(64, rngs=nnx.Rngs(0)), x) # Can fit into one line\n", "nnx.display(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Linen weight is already converted to a typical NNX variable, which is a thin wrapper of the actual JAX array value within. Here, `w` is an `nnx.Param` because it belongs to the `params` collection of `LinenDot` module.\n", "\n", "We will talk more about different collections and types in the [NNX Variable <-> Linen Collections](#variable-types-vs-collections) section. Right now, just know that they are converted to NNX variables like native ones." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "assert isinstance(model.dot.w, nnx.Param)\n", "assert isinstance(model.dot.w.value, jax.Array)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you create this model witout using `nnx.bridge.lazy_init`, the NNX variables defined outside will be initialized as usual, but the Linen part (wrapped inside `ToNNX`) will not." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "partial_model = NNXOuter(64, rngs=nnx.Rngs(0))\n", "nnx.display(partial_model)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "full_model = bridge.lazy_init(partial_model, x)\n", "nnx.display(full_model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### NNX -> Linen\n", "\n", "To convert an NNX module to Linen, you should forward your creation arguments to `bridge.ToLinen` and let it handle the actual creation process.\n", "\n", "This is because NNX module instance initializes all its variables eagerly when it is created, which consumes memory and compute. On the other hand, Linen modules are stateless, and the typical `init` and `apply` process involves multiple creation of them. So `bridge.to_linen` will handle the actual module creation and make sure no memory is allocated twice." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['params']\n", "(32, 64)\n", "(4, 64)\n" ] } ], "source": [ "class NNXDot(nnx.Module):\n", " def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):\n", " self.w = nnx.Param(nnx.initializers.lecun_normal()(\n", " rngs.params(), (in_dim, out_dim)))\n", " def __call__(self, x: jax.Array):\n", " return x @ self.w\n", "\n", "x = jax.random.normal(jax.random.key(42), (4, 32))\n", "# Pass in the arguments, not an actual module\n", "model = bridge.to_linen(NNXDot, 32, out_dim=64)\n", "variables = model.init(jax.random.key(0), x)\n", "y = model.apply(variables, x)\n", "\n", "print(list(variables.keys()))\n", "print(variables['params']['w'].shape) # => (32, 64)\n", "print(y.shape) # => (4, 64)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`bridge.to_linen` is actually a convenience wrapper around the Linen module `bridge.ToLinen`. Most likely you won't need to use `ToLinen` directly at all, unless you are using one of the built-in arguments of `ToLinen`. For example, if your NNX module doesn't want to be initialized with RNG handling:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class NNXAddConstant(nnx.Module):\n", " def __init__(self):\n", " self.constant = nnx.Variable(jnp.array(1))\n", " def __call__(self, x):\n", " return x + self.constant\n", "\n", "# You have to use `skip_rng=True` because this module's `__init__` don't\n", "# take `rng` as argument\n", "model = bridge.ToLinen(NNXAddConstant, skip_rng=True)\n", "y, var = model.init_with_output(jax.random.key(0), x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Similar to `ToNNX`, you can use `ToLinen` to create a submodule of another Linen module." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(32, 64) (1, 64) (4, 64)\n" ] } ], "source": [ "class LinenOuter(nn.Module):\n", " out_dim: int\n", " @nn.compact\n", " def __call__(self, x):\n", " dot = bridge.to_linen(NNXDot, x.shape[-1], self.out_dim)\n", " b = self.param('b', nn.initializers.lecun_normal(), (1, self.out_dim))\n", " return dot(x) + b\n", "\n", "x = jax.random.normal(jax.random.key(42), (4, 32))\n", "model = LinenOuter(out_dim=64)\n", "y, variables = model.init_with_output(jax.random.key(0), x)\n", "w, b = variables['params']['ToLinen_0']['w'], variables['params']['b']\n", "print(w.shape, b.shape, y.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Handling RNG keys\n", "\n", "All Flax modules, Linen or NNX, automatically handle the RNG keys for variable creation and random layers like dropouts. However, the specific logics of RNG key splitting are different, so you cannot generate the same params between Linen and NNX modules, even if you pass in same keys.\n", "\n", "Another difference is that NNX modules are stateful, so they can track and update the RNG keys within themselves.\n", "\n", "### Linen to NNX\n", "\n", "If you convert a Linen module to NNX, you enjoy the stateful benefit and don't need to pass in extra RNG keys on every module call. You can use always `nnx.reseed` to reset the RNG state within." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "x = jax.random.normal(jax.random.key(42), (4, 32))\n", "model = bridge.ToNNX(nn.Dropout(rate=0.5, deterministic=False), rngs=nnx.Rngs(dropout=0))\n", "# We don't really need to call lazy_init because no extra params were created here,\n", "# but it's a good practice to always add this line.\n", "bridge.lazy_init(model, x)\n", "y1, y2 = model(x), model(x)\n", "assert not jnp.allclose(y1, y2) # Two runs yield different outputs!\n", "\n", "# Reset the dropout RNG seed, so that next model run will be the same as the first.\n", "nnx.reseed(model, dropout=0)\n", "y1 = model(x)\n", "nnx.reseed(model, dropout=0)\n", "y2 = model(x)\n", "assert jnp.allclose(y1, y2) # Two runs yield the same output!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### NNX to Linen\n", "\n", "`to_linen` will automatically take the `rngs` dict argument and create a `Rngs` object that is passed to the underlying NNX module via the `rngs` keyword argument. If the module holds internal `RngState`, `to_linen` will always call reseed using the `rngs` dict to reset the RNG state." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "x = jax.random.normal(jax.random.key(42), (4, 32))\n", "model = bridge.to_linen(nnx.Dropout, rate=0.5)\n", "variables = model.init({'dropout': jax.random.key(0)}, x)\n", "\n", "# Just pass different RNG keys for every `apply()` call.\n", "y1 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})\n", "y2 = model.apply(variables, x, rngs={'dropout': jax.random.key(2)})\n", "assert not jnp.allclose(y1, y2) # Every call yields different output!\n", "y3 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})\n", "assert jnp.allclose(y1, y3) # When you use same top-level RNG, outputs are same" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## NNX variable types vs. Linen collections\n", "\n", "When you want to group some variables as one category, in Linen you use different collections; in NNX, since all variables shall be top-level Python attributes, you use different variable types.\n", "\n", "Therefore, when mixing Linen and NNX modules, Flax must know the 1-to-1 mapping between Linen collections and NNX variable types, so that `ToNNX` and `ToLinen` can do the conversion automatically.\n", "\n", "Flax keeps a registry for this, and it already covers all Flax's built-in Linen collections. You can register extra mapping of NNX variable type and Linen collection names using `nnx.register_variable_name_type_pair`.\n", "\n", "### Linen to NNX\n", "\n", "For any collection of your Linen module, `ToNNX` will convert all its endpoint arrays (aka. leaves) to a subtype of `nnx.Variable`, either from registry or automatically created on-the-fly.\n", "\n", "(However, we still keep the whole collection as one class attribute, because Linen modules may have duplicated names over different collections.)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 12 (48 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[ 0.53824717, 0.7668343 , -0.38585317],\n", " [-0.35335615, -0.5244857 , -0.43152452],\n", " [-1.0662307 , 0.14089198, -0.16519307],\n", " [ 0.3971692 , 0.43213558, -0.461545 ]], dtype=float32)\n", "\u001b[38;2;255;213;3m)\u001b[0m\n", "\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([0., 0., 0.], dtype=float32)\n", "\u001b[38;2;255;213;3m)\u001b[0m\n", "\u001b[38;2;79;201;177mcounter\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(0, dtype=int32)\n", "\u001b[38;2;255;213;3m)\u001b[0m\n", "\n", "(Intermediate( # 1 (4 B)\n", " value=Array(0.5475821, dtype=float32)\n", "),)\n" ] } ], "source": [ "class LinenMultiCollections(nn.Module):\n", " out_dim: int\n", " def setup(self):\n", " self.w = self.param('w', nn.initializers.lecun_normal(), (x.shape[-1], self.out_dim))\n", " self.b = self.param('b', nn.zeros_init(), (self.out_dim,))\n", " self.count = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32))\n", "\n", " def __call__(self, x):\n", " if not self.is_initializing():\n", " self.count.value += 1\n", " y = x @ self.w + self.b\n", " self.sow('intermediates', 'dot_sum', jnp.sum(y))\n", " return y\n", "\n", "x = jax.random.normal(jax.random.key(42), (2, 4))\n", "model = bridge.lazy_init(bridge.ToNNX(LinenMultiCollections(3), rngs=nnx.Rngs(0)), x)\n", "print(model.w) # Of type `nnx.Param` - note this is still under attribute `params`\n", "print(model.b) # Of type `nnx.Param`\n", "print(model.count) # Of type `counter` - auto-created type from the collection name\n", "print(type(model.count))\n", "\n", "y = model(x, mutable=True) # Linen's `sow()` needs `mutable=True` to trigger\n", "print(model.dot_sum) # Of type `nnx.Intermediates`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can quickly separate different types of NNX variables apart using `nnx.split`.\n", "\n", "This can be handy when you only want to set some variables as trainable." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "All Params: ['b', 'w']\n", "All Counters: ['count']\n", "All the rest (intermediates and RNG keys): ['dot_sum', 'rngs']\n" ] } ], "source": [ "# Separate variables of different types with nnx.split\n", "CountType = type(model.count)\n", "static, params, counter, the_rest = nnx.split(model, nnx.Param, CountType, ...)\n", "print('All Params:', list(params.keys()))\n", "print('All Counters:', list(counter.keys()))\n", "print('All the rest (intermediates and RNG keys):', list(the_rest.keys()))\n", "\n", "model = nnx.merge(static, params, counter, the_rest) # You can merge them back at any time\n", "y = model(x, mutable=True) # still works!" ] }, { "cell_type": "markdown", "id": "cc9d78ed", "metadata": {}, "source": [ " All Params: ['b', 'w']\n", " All Counters: ['count']\n", " All the rest (intermediates and RNG keys): ['dot_sum', 'rngs']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### NNX to Linen\n", "\n", "If you define custom NNX variable types, you should register their names with `nnx.register_variable_name` so that they go to the desired collections." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "All Linen collections: ['nnx', 'LoRAParam', 'params', 'counts']\n", "{'w': Array([[ 0.2916921 , 0.22780475, 0.06553137],\n", " [ 0.17487915, -0.34043145, 0.24764155],\n", " [ 0.6420431 , 0.6220095 , -0.44769976],\n", " [ 0.11161668, 0.83873135, -0.7446058 ]], dtype=float32)}\n" ] } ], "source": [ "@nnx.register_variable_name('counts', overwrite=True)\n", "class Count(nnx.Variable): pass\n", "\n", "\n", "class NNXMultiCollections(nnx.Module):\n", " def __init__(self, din, dout, rngs):\n", " self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout)))\n", " self.lora = nnx.LoRA(din, 3, dout, rngs=rngs)\n", " self.count = Count(jnp.array(0))\n", "\n", " def __call__(self, x):\n", " self.count.value += 1\n", " return (x @ self.w.value) + self.lora(x)\n", "\n", "xkey, pkey, dkey = jax.random.split(jax.random.key(0), 3)\n", "x = jax.random.normal(xkey, (2, 4))\n", "model = bridge.to_linen(NNXMultiCollections, 4, 3)\n", "var = model.init({'params': pkey, 'dropout': dkey}, x)\n", "print('All Linen collections:', list(var.keys()))\n", "print(var['params'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " All Linen collections: ['LoRAParam', 'params', 'counts']\n", " {'w': Array([[ 0.2916921 , 0.22780475, 0.06553137],\n", " [ 0.17487915, -0.34043145, 0.24764155],\n", " [ 0.6420431 , 0.6220095 , -0.44769976],\n", " [ 0.11161668, 0.83873135, -0.7446058 ]], dtype=float32)}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Partition metadata\n", "\n", "Flax uses a metadata wrapper box over the raw JAX array to annotate how a variable should be sharded.\n", "\n", "In Linen, this is an optional feature that triggered by using `nn.with_partitioning` on initializers (see more on [Linen partition metadata guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html)). In NNX, since all NNX variables are wrapped by `nnx.Variable` class anyway, that class will hold the sharding annotations too.\n", "\n", "The `bridge.ToNNX` and `bridge.ToLinen` API will automatically convert the sharding annotations, if you use the built-in annotation methods (aka. `nn.with_partitioning` for Linen and `nnx.with_partitioning` for NNX).\n", "\n", "### Linen to NNX\n", "\n", "Even if you are not using any partition metadata in your Linen module, the variable JAX arrays will be converted to `nnx.Variable`s that wraps the true JAX array within.\n", "\n", "If you use `nn.with_partitioning` to annotate your Linen module's variables, the annotation will be converted to a `.sharding` field in the corresponding `nnx.Variable`.\n", "\n", "You can then use `nnx.with_sharding_constraint` to explicitly put the arrays into the annotated partitions within a `jax.jit`-compiled function, to initialize the whole model with every array at the right sharding." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "We have 8 fake JAX devices now to partition this model...\n", "\n", "('in', 'out')\n", "GSPMDSharding({devices=[2,4]<=[8]})\n" ] } ], "source": [ "class LinenDotWithPartitioning(nn.Module):\n", " out_dim: int\n", " @nn.compact\n", " def __call__(self, x):\n", " w = self.param('w', nn.with_partitioning(nn.initializers.lecun_normal(),\n", " ('in', 'out')),\n", " (x.shape[-1], self.out_dim))\n", " return x @ w\n", "\n", "@nnx.jit\n", "def create_sharded_nnx_module(x):\n", " model = bridge.lazy_init(\n", " bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x)\n", " state = nnx.state(model)\n", " sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state))\n", " nnx.update(model, sharded_state)\n", " return model\n", "\n", "\n", "print(f'We have {len(jax.devices())} fake JAX devices now to partition this model...')\n", "mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)),\n", " axis_names=('in', 'out'))\n", "x = jax.random.normal(jax.random.key(42), (4, 32))\n", "with jax.set_mesh(mesh):\n", " model = create_sharded_nnx_module(x)\n", "\n", "print(type(model.w)) # `nnx.Param`\n", "print(model.w.sharding) # The partition annotation attached with `w`\n", "print(model.w.value.sharding) # The underlying JAX array is sharded across the 2x4 mesh" ] }, { "cell_type": "markdown", "id": "08555a06", "metadata": {}, "source": [ " We have 8 fake JAX devices now to partition this model...\n", " \n", " ('in', 'out')\n", " GSPMDSharding({devices=[2,4]<=[8]})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### NNX to Linen\n", "\n", "If you are not using any metadata feature of the `nnx.Variable` (i.e., no sharding annotation, no registered hooks), the converted Linen module will not add a metadata wrapper to your NNX variable, and you don't need to worry about it.\n", "\n", "But if you did add sharding annotations to your NNX variables, `ToLinen` will convert them to a default Linen partition metadata class called `bridge.NNXMeta`, retaining all the metadata you put into the NNX variable.\n", "\n", "Like with any Linen metadata wrappers, you can use `linen.unbox()` to get the raw JAX array tree." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GSPMDSharding({devices=[2,4]<=[8]})\n" ] } ], "source": [ "class NNXDotWithParititioning(nnx.Module):\n", " def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):\n", " init_fn = nnx.with_partitioning(nnx.initializers.lecun_normal(), ('in', 'out'))\n", " self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim)))\n", " def __call__(self, x: jax.Array):\n", " return x @ self.w\n", "\n", "x = jax.random.normal(jax.random.key(42), (4, 32))\n", "\n", "@jax.jit\n", "def create_sharded_variables(key, x):\n", " model = bridge.to_linen(NNXDotWithParititioning, 32, 64)\n", " variables = model.init(key, x)\n", " # A `NNXMeta` wrapper of the underlying `nnx.Param`\n", " assert type(variables['params']['w']) == bridge.NNXMeta\n", " # The annotation coming from the `nnx.Param` => (in, out)\n", " assert variables['params']['w'].metadata['sharding'] == ('in', 'out')\n", "\n", " unboxed_variables = nn.unbox(variables)\n", " variable_pspecs = nn.get_partition_spec(variables)\n", " assert isinstance(unboxed_variables['params']['w'], jax.Array)\n", " assert variable_pspecs['params']['w'] == jax.sharding.PartitionSpec('in', 'out')\n", "\n", " sharded_vars = jax.tree.map(jax.lax.with_sharding_constraint,\n", " nn.unbox(variables),\n", " nn.get_partition_spec(variables))\n", " return sharded_vars\n", "\n", "with jax.set_mesh(mesh):\n", " variables = create_sharded_variables(jax.random.key(0), x)\n", "\n", "# The underlying JAX array is sharded across the 2x4 mesh\n", "print(variables['params']['w'].sharding)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " GSPMDSharding({devices=[2,4]<=[8]})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Lifted transforms\n", "\n", "In general, if you want to apply Linen/NNX-style lifted transforms upon an `nnx.bridge`-converted module, just go ahead and do it in the usual Linen/NNX syntax.\n", "\n", "For Linen-style transforms, note that `bridge.ToLinen` is the top level module class, so you may want to just use it as the first argument of your transforms (which needs to be a `linen.Module` class in most cases)\n", "\n", "### Linen to NNX\n", "\n", "NNX style lifted transforms are similar to JAX transforms, and they work on functions." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4, 32, 64)\n", "(4, 64)\n" ] } ], "source": [ "class NNXVmapped(nnx.Module):\n", " def __init__(self, out_dim: int, vmap_axis_size: int, rngs: nnx.Rngs):\n", " self.linen_dot = nnx.bridge.ToNNX(nn.Dense(out_dim, use_bias=False), rngs=rngs)\n", " self.vmap_axis_size = vmap_axis_size\n", "\n", " def __call__(self, x):\n", "\n", " @nnx.split_rngs(splits=self.vmap_axis_size)\n", " @nnx.vmap(in_axes=(0, 0), axis_size=self.vmap_axis_size)\n", " def vmap_fn(submodule, x):\n", " return submodule(x)\n", "\n", " return vmap_fn(self.linen_dot, x)\n", "\n", "x = jax.random.normal(jax.random.key(0), (4, 32))\n", "model = bridge.lazy_init(NNXVmapped(64, 4, rngs=nnx.Rngs(0)), x)\n", "\n", "print(model.linen_dot.kernel.shape) # (4, 32, 64) - first axis with dim 4 got vmapped\n", "y = model(x)\n", "print(y.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " (4, 32, 64)\n", " (4, 64)" ] }, { "cell_type": "markdown", "id": "61a1ac21", "metadata": {}, "source": [ "### NNX to Linen\n", "\n", "Note that `bridge.ToLinen` is the top level module class, so you may want to just use it as the first argument of your transforms (which needs to be a `linen.Module` class in most cases).\n", "\n", "`ToLien` can naturally be used with Linen transforms like `nn.vmap` or `nn.scan`." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4, 32, 64)\n", "(4, 64)\n" ] } ], "source": [ "class LinenVmapped(nn.Module):\n", " dout: int\n", " @nn.compact\n", " def __call__(self, x):\n", " inner = nn.vmap(bridge.ToLinen, variable_axes={'params': 0}, split_rngs={'params': True}\n", " )(nnx.Linear, args=(x.shape[-1], self.dout))\n", " return inner(x)\n", "\n", "x = jax.random.normal(jax.random.key(42), (4, 32))\n", "model = LinenVmapped(64)\n", "var = model.init(jax.random.key(0), x)\n", "print(var['params']['VmapToLinen_0']['kernel'].shape) # (4, 32, 64) - leading dim 4 got vmapped\n", "y = model.apply(var, x)\n", "print(y.shape)" ] }, { "cell_type": "markdown", "id": "178d2b2f", "metadata": {}, "source": [ " (4, 32, 64)\n", " (4, 64)" ] } ], "metadata": { "jupytext": { "cell_metadata_filter": "-all", "formats": "ipynb,md:myst", "main_language": "python" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: docs_nnx/guides/bridge_guide.md ================================================ --- jupytext: cell_metadata_filter: -all formats: ipynb,md:myst main_language: python text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # Use Flax NNX and Linen together This guide is for existing Flax users who want to make their codebase a mixture of Flax Linen and Flax NNX `Module`s, which is made possible thanks to the `flax.nnx.bridge` API. This will be helpful if you: * Want to migrate your codebase to NNX gradually, one module at a time; * Have external dependency that already moved to NNX but you haven't, or is still in Linen while you've moved to NNX. We hope this allows you to move and try out NNX at your own pace, and leverage the best of both worlds. We will also talk about how to resolve the caveats of interoperating the two APIs, on a few aspects that they are fundamentally different. **Note**: This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide. And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html). ```{code-cell} ipython3 import os os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' from flax import nnx from flax import linen as nn from flax.nnx import bridge import jax from jax import numpy as jnp from jax.experimental import mesh_utils from typing import * ``` ## Submodule is all you need A Flax model is always a tree of modules - either old Linen modules (`flax.linen.Module`, usually written as `nn.Module`) or NNX modules (`nnx.Module`). An `nnx.bridge` wrapper glues the two types together, in both ways: * `nnx.bridge.ToNNX`: Convert a Linen module to NNX, so that it can be a submodule of another NNX module, or stand alone to be trained in NNX-style training loops. * `nnx.bridge.ToLinen`: Vice versa, convert a NNX module to Linen. This means you can move in either top-down or bottom-up behavior: convert the whole Linen module to NNX, then gradually move down, or convert all the lower level modules to NNX then move up. +++ ## The Basics There are two fundamental difference between Linen and NNX modules: * **Stateless vs. stateful**: Linen module instances are stateless: variables are returned from a purely functional `.init()` call and managed separately. NNX modules, however, owns its variables as instance attributes. * **Lazy vs. eager**: Linen modules only allocate space to create variables when they actually see their input. Whereas NNX module instances create their variables the moment they are instantiated, without seeing a sample input. With that in mind, let's look at how the `nnx.bridge` wrappers tackle the differences. ### Linen -> NNX Since Linen modules may require an input to create variables, we semi-formally supported lazy initialization in the NNX modules converted from Linen. The Linen variables are created when you give it a sample input. For you, it's calling `nnx.bridge.lazy_init()` where you call `module.init()` in Linen code. (Note: you can call `nnx.display` upon any NNX module to inspect all its variables and state.) ```{code-cell} ipython3 class LinenDot(nn.Module): out_dim: int w_init: Callable[..., Any] = nn.initializers.lecun_normal() @nn.compact def __call__(self, x): # Linen might need the input shape to create the weight! w = self.param('w', self.w_init, (x.shape[-1], self.out_dim)) return x @ w x = jax.random.normal(jax.random.key(42), (4, 32)) model = bridge.ToNNX(LinenDot(64), rngs=nnx.Rngs(0)) # => `model = LinenDot(64)` in Linen bridge.lazy_init(model, x) # => `var = model.init(key, x)` in Linen y = model(x) # => `y = model.apply(var, x)` in Linen nnx.display(model) # In-place swap your weight array and the model still works! model.w.value = jax.random.normal(jax.random.key(1), (32, 64)) assert not jnp.allclose(y, model(x)) ``` `nnx.bridge.lazy_init` also works even if the top-level module is a pure-NNX one, so you can do sub-moduling as you wish: ```{code-cell} ipython3 class NNXOuter(nnx.Module): def __init__(self, out_dim: int, rngs: nnx.Rngs): self.dot = nnx.bridge.ToNNX(LinenDot(out_dim), rngs=rngs) self.b = nnx.Param(jax.random.uniform(rngs.params(), (1, out_dim,))) def __call__(self, x): return self.dot(x) + self.b x = jax.random.normal(jax.random.key(42), (4, 32)) model = bridge.lazy_init(NNXOuter(64, rngs=nnx.Rngs(0)), x) # Can fit into one line nnx.display(model) ``` The Linen weight is already converted to a typical NNX variable, which is a thin wrapper of the actual JAX array value within. Here, `w` is an `nnx.Param` because it belongs to the `params` collection of `LinenDot` module. We will talk more about different collections and types in the [NNX Variable <-> Linen Collections](#variable-types-vs-collections) section. Right now, just know that they are converted to NNX variables like native ones. ```{code-cell} ipython3 assert isinstance(model.dot.w, nnx.Param) assert isinstance(model.dot.w.value, jax.Array) ``` If you create this model witout using `nnx.bridge.lazy_init`, the NNX variables defined outside will be initialized as usual, but the Linen part (wrapped inside `ToNNX`) will not. ```{code-cell} ipython3 partial_model = NNXOuter(64, rngs=nnx.Rngs(0)) nnx.display(partial_model) ``` ```{code-cell} ipython3 full_model = bridge.lazy_init(partial_model, x) nnx.display(full_model) ``` ### NNX -> Linen To convert an NNX module to Linen, you should forward your creation arguments to `bridge.ToLinen` and let it handle the actual creation process. This is because NNX module instance initializes all its variables eagerly when it is created, which consumes memory and compute. On the other hand, Linen modules are stateless, and the typical `init` and `apply` process involves multiple creation of them. So `bridge.to_linen` will handle the actual module creation and make sure no memory is allocated twice. ```{code-cell} ipython3 class NNXDot(nnx.Module): def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs): self.w = nnx.Param(nnx.initializers.lecun_normal()( rngs.params(), (in_dim, out_dim))) def __call__(self, x: jax.Array): return x @ self.w x = jax.random.normal(jax.random.key(42), (4, 32)) # Pass in the arguments, not an actual module model = bridge.to_linen(NNXDot, 32, out_dim=64) variables = model.init(jax.random.key(0), x) y = model.apply(variables, x) print(list(variables.keys())) print(variables['params']['w'].shape) # => (32, 64) print(y.shape) # => (4, 64) ``` `bridge.to_linen` is actually a convenience wrapper around the Linen module `bridge.ToLinen`. Most likely you won't need to use `ToLinen` directly at all, unless you are using one of the built-in arguments of `ToLinen`. For example, if your NNX module doesn't want to be initialized with RNG handling: ```{code-cell} ipython3 class NNXAddConstant(nnx.Module): def __init__(self): self.constant = nnx.Variable(jnp.array(1)) def __call__(self, x): return x + self.constant # You have to use `skip_rng=True` because this module's `__init__` don't # take `rng` as argument model = bridge.ToLinen(NNXAddConstant, skip_rng=True) y, var = model.init_with_output(jax.random.key(0), x) ``` Similar to `ToNNX`, you can use `ToLinen` to create a submodule of another Linen module. ```{code-cell} ipython3 class LinenOuter(nn.Module): out_dim: int @nn.compact def __call__(self, x): dot = bridge.to_linen(NNXDot, x.shape[-1], self.out_dim) b = self.param('b', nn.initializers.lecun_normal(), (1, self.out_dim)) return dot(x) + b x = jax.random.normal(jax.random.key(42), (4, 32)) model = LinenOuter(out_dim=64) y, variables = model.init_with_output(jax.random.key(0), x) w, b = variables['params']['ToLinen_0']['w'], variables['params']['b'] print(w.shape, b.shape, y.shape) ``` ## Handling RNG keys All Flax modules, Linen or NNX, automatically handle the RNG keys for variable creation and random layers like dropouts. However, the specific logics of RNG key splitting are different, so you cannot generate the same params between Linen and NNX modules, even if you pass in same keys. Another difference is that NNX modules are stateful, so they can track and update the RNG keys within themselves. ### Linen to NNX If you convert a Linen module to NNX, you enjoy the stateful benefit and don't need to pass in extra RNG keys on every module call. You can use always `nnx.reseed` to reset the RNG state within. ```{code-cell} ipython3 x = jax.random.normal(jax.random.key(42), (4, 32)) model = bridge.ToNNX(nn.Dropout(rate=0.5, deterministic=False), rngs=nnx.Rngs(dropout=0)) # We don't really need to call lazy_init because no extra params were created here, # but it's a good practice to always add this line. bridge.lazy_init(model, x) y1, y2 = model(x), model(x) assert not jnp.allclose(y1, y2) # Two runs yield different outputs! # Reset the dropout RNG seed, so that next model run will be the same as the first. nnx.reseed(model, dropout=0) y1 = model(x) nnx.reseed(model, dropout=0) y2 = model(x) assert jnp.allclose(y1, y2) # Two runs yield the same output! ``` ### NNX to Linen `to_linen` will automatically take the `rngs` dict argument and create a `Rngs` object that is passed to the underlying NNX module via the `rngs` keyword argument. If the module holds internal `RngState`, `to_linen` will always call reseed using the `rngs` dict to reset the RNG state. ```{code-cell} ipython3 x = jax.random.normal(jax.random.key(42), (4, 32)) model = bridge.to_linen(nnx.Dropout, rate=0.5) variables = model.init({'dropout': jax.random.key(0)}, x) # Just pass different RNG keys for every `apply()` call. y1 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)}) y2 = model.apply(variables, x, rngs={'dropout': jax.random.key(2)}) assert not jnp.allclose(y1, y2) # Every call yields different output! y3 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)}) assert jnp.allclose(y1, y3) # When you use same top-level RNG, outputs are same ``` ## NNX variable types vs. Linen collections When you want to group some variables as one category, in Linen you use different collections; in NNX, since all variables shall be top-level Python attributes, you use different variable types. Therefore, when mixing Linen and NNX modules, Flax must know the 1-to-1 mapping between Linen collections and NNX variable types, so that `ToNNX` and `ToLinen` can do the conversion automatically. Flax keeps a registry for this, and it already covers all Flax's built-in Linen collections. You can register extra mapping of NNX variable type and Linen collection names using `nnx.register_variable_name_type_pair`. ### Linen to NNX For any collection of your Linen module, `ToNNX` will convert all its endpoint arrays (aka. leaves) to a subtype of `nnx.Variable`, either from registry or automatically created on-the-fly. (However, we still keep the whole collection as one class attribute, because Linen modules may have duplicated names over different collections.) ```{code-cell} ipython3 class LinenMultiCollections(nn.Module): out_dim: int def setup(self): self.w = self.param('w', nn.initializers.lecun_normal(), (x.shape[-1], self.out_dim)) self.b = self.param('b', nn.zeros_init(), (self.out_dim,)) self.count = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32)) def __call__(self, x): if not self.is_initializing(): self.count.value += 1 y = x @ self.w + self.b self.sow('intermediates', 'dot_sum', jnp.sum(y)) return y x = jax.random.normal(jax.random.key(42), (2, 4)) model = bridge.lazy_init(bridge.ToNNX(LinenMultiCollections(3), rngs=nnx.Rngs(0)), x) print(model.w) # Of type `nnx.Param` - note this is still under attribute `params` print(model.b) # Of type `nnx.Param` print(model.count) # Of type `counter` - auto-created type from the collection name print(type(model.count)) y = model(x, mutable=True) # Linen's `sow()` needs `mutable=True` to trigger print(model.dot_sum) # Of type `nnx.Intermediates` ``` You can quickly separate different types of NNX variables apart using `nnx.split`. This can be handy when you only want to set some variables as trainable. ```{code-cell} ipython3 # Separate variables of different types with nnx.split CountType = type(model.count) static, params, counter, the_rest = nnx.split(model, nnx.Param, CountType, ...) print('All Params:', list(params.keys())) print('All Counters:', list(counter.keys())) print('All the rest (intermediates and RNG keys):', list(the_rest.keys())) model = nnx.merge(static, params, counter, the_rest) # You can merge them back at any time y = model(x, mutable=True) # still works! ``` All Params: ['b', 'w'] All Counters: ['count'] All the rest (intermediates and RNG keys): ['dot_sum', 'rngs'] +++ ### NNX to Linen If you define custom NNX variable types, you should register their names with `nnx.register_variable_name` so that they go to the desired collections. ```{code-cell} ipython3 @nnx.register_variable_name('counts', overwrite=True) class Count(nnx.Variable): pass class NNXMultiCollections(nnx.Module): def __init__(self, din, dout, rngs): self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout))) self.lora = nnx.LoRA(din, 3, dout, rngs=rngs) self.count = Count(jnp.array(0)) def __call__(self, x): self.count.value += 1 return (x @ self.w.value) + self.lora(x) xkey, pkey, dkey = jax.random.split(jax.random.key(0), 3) x = jax.random.normal(xkey, (2, 4)) model = bridge.to_linen(NNXMultiCollections, 4, 3) var = model.init({'params': pkey, 'dropout': dkey}, x) print('All Linen collections:', list(var.keys())) print(var['params']) ``` All Linen collections: ['LoRAParam', 'params', 'counts'] {'w': Array([[ 0.2916921 , 0.22780475, 0.06553137], [ 0.17487915, -0.34043145, 0.24764155], [ 0.6420431 , 0.6220095 , -0.44769976], [ 0.11161668, 0.83873135, -0.7446058 ]], dtype=float32)} +++ ## Partition metadata Flax uses a metadata wrapper box over the raw JAX array to annotate how a variable should be sharded. In Linen, this is an optional feature that triggered by using `nn.with_partitioning` on initializers (see more on [Linen partition metadata guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html)). In NNX, since all NNX variables are wrapped by `nnx.Variable` class anyway, that class will hold the sharding annotations too. The `bridge.ToNNX` and `bridge.ToLinen` API will automatically convert the sharding annotations, if you use the built-in annotation methods (aka. `nn.with_partitioning` for Linen and `nnx.with_partitioning` for NNX). ### Linen to NNX Even if you are not using any partition metadata in your Linen module, the variable JAX arrays will be converted to `nnx.Variable`s that wraps the true JAX array within. If you use `nn.with_partitioning` to annotate your Linen module's variables, the annotation will be converted to a `.sharding` field in the corresponding `nnx.Variable`. You can then use `nnx.with_sharding_constraint` to explicitly put the arrays into the annotated partitions within a `jax.jit`-compiled function, to initialize the whole model with every array at the right sharding. ```{code-cell} ipython3 class LinenDotWithPartitioning(nn.Module): out_dim: int @nn.compact def __call__(self, x): w = self.param('w', nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out')), (x.shape[-1], self.out_dim)) return x @ w @nnx.jit def create_sharded_nnx_module(x): model = bridge.lazy_init( bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x) state = nnx.state(model) sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state)) nnx.update(model, sharded_state) return model print(f'We have {len(jax.devices())} fake JAX devices now to partition this model...') mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)), axis_names=('in', 'out')) x = jax.random.normal(jax.random.key(42), (4, 32)) with jax.set_mesh(mesh): model = create_sharded_nnx_module(x) print(type(model.w)) # `nnx.Param` print(model.w.sharding) # The partition annotation attached with `w` print(model.w.value.sharding) # The underlying JAX array is sharded across the 2x4 mesh ``` We have 8 fake JAX devices now to partition this model... ('in', 'out') GSPMDSharding({devices=[2,4]<=[8]}) +++ ### NNX to Linen If you are not using any metadata feature of the `nnx.Variable` (i.e., no sharding annotation, no registered hooks), the converted Linen module will not add a metadata wrapper to your NNX variable, and you don't need to worry about it. But if you did add sharding annotations to your NNX variables, `ToLinen` will convert them to a default Linen partition metadata class called `bridge.NNXMeta`, retaining all the metadata you put into the NNX variable. Like with any Linen metadata wrappers, you can use `linen.unbox()` to get the raw JAX array tree. ```{code-cell} ipython3 class NNXDotWithParititioning(nnx.Module): def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs): init_fn = nnx.with_partitioning(nnx.initializers.lecun_normal(), ('in', 'out')) self.w = nnx.Param(init_fn(rngs.params(), (in_dim, out_dim))) def __call__(self, x: jax.Array): return x @ self.w x = jax.random.normal(jax.random.key(42), (4, 32)) @jax.jit def create_sharded_variables(key, x): model = bridge.to_linen(NNXDotWithParititioning, 32, 64) variables = model.init(key, x) # A `NNXMeta` wrapper of the underlying `nnx.Param` assert type(variables['params']['w']) == bridge.NNXMeta # The annotation coming from the `nnx.Param` => (in, out) assert variables['params']['w'].metadata['sharding'] == ('in', 'out') unboxed_variables = nn.unbox(variables) variable_pspecs = nn.get_partition_spec(variables) assert isinstance(unboxed_variables['params']['w'], jax.Array) assert variable_pspecs['params']['w'] == jax.sharding.PartitionSpec('in', 'out') sharded_vars = jax.tree.map(jax.lax.with_sharding_constraint, nn.unbox(variables), nn.get_partition_spec(variables)) return sharded_vars with jax.set_mesh(mesh): variables = create_sharded_variables(jax.random.key(0), x) # The underlying JAX array is sharded across the 2x4 mesh print(variables['params']['w'].sharding) ``` GSPMDSharding({devices=[2,4]<=[8]}) +++ ## Lifted transforms In general, if you want to apply Linen/NNX-style lifted transforms upon an `nnx.bridge`-converted module, just go ahead and do it in the usual Linen/NNX syntax. For Linen-style transforms, note that `bridge.ToLinen` is the top level module class, so you may want to just use it as the first argument of your transforms (which needs to be a `linen.Module` class in most cases) ### Linen to NNX NNX style lifted transforms are similar to JAX transforms, and they work on functions. ```{code-cell} ipython3 class NNXVmapped(nnx.Module): def __init__(self, out_dim: int, vmap_axis_size: int, rngs: nnx.Rngs): self.linen_dot = nnx.bridge.ToNNX(nn.Dense(out_dim, use_bias=False), rngs=rngs) self.vmap_axis_size = vmap_axis_size def __call__(self, x): @nnx.split_rngs(splits=self.vmap_axis_size) @nnx.vmap(in_axes=(0, 0), axis_size=self.vmap_axis_size) def vmap_fn(submodule, x): return submodule(x) return vmap_fn(self.linen_dot, x) x = jax.random.normal(jax.random.key(0), (4, 32)) model = bridge.lazy_init(NNXVmapped(64, 4, rngs=nnx.Rngs(0)), x) print(model.linen_dot.kernel.shape) # (4, 32, 64) - first axis with dim 4 got vmapped y = model(x) print(y.shape) ``` (4, 32, 64) (4, 64) +++ ### NNX to Linen Note that `bridge.ToLinen` is the top level module class, so you may want to just use it as the first argument of your transforms (which needs to be a `linen.Module` class in most cases). `ToLien` can naturally be used with Linen transforms like `nn.vmap` or `nn.scan`. ```{code-cell} ipython3 class LinenVmapped(nn.Module): dout: int @nn.compact def __call__(self, x): inner = nn.vmap(bridge.ToLinen, variable_axes={'params': 0}, split_rngs={'params': True} )(nnx.Linear, args=(x.shape[-1], self.dout)) return inner(x) x = jax.random.normal(jax.random.key(42), (4, 32)) model = LinenVmapped(64) var = model.init(jax.random.key(0), x) print(var['params']['VmapToLinen_0']['kernel'].shape) # (4, 32, 64) - leading dim 4 got vmapped y = model.apply(var, x) print(y.shape) ``` (4, 32, 64) (4, 64) ================================================ FILE: docs_nnx/guides/checkpointing.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Save and load checkpoints\n", "\n", "This guide demonstrates how to save and load Flax NNX model checkpoints with [Orbax](https://orbax.readthedocs.io/).\n", "\n", "> **Note:** The Flax team does not actively maintain a library for saving and loading model checkpoints to disk. Therefore, it is recommended you use external libraries like [Orbax](https://orbax.readthedocs.io/en/latest/index.html) to do it.\n", "\n", "In this guide you will learn how to:\n", "\n", "* Save checkpoints.\n", "* Restore checkpoints.\n", "* Restore checkpoints if checkpoint structures differ. \n", "* Perform multi-process checkpointing. \n", "\n", "The Orbax API examples used throughout the guide are for demonstration purposes, and for the most up-to-date recommended APIs refer to the [Orbax website](https://orbax.readthedocs.io/).\n", "\n", "> **Note:** The Flax team recommends using [Orbax](https://orbax.readthedocs.io/en/latest/index.html) for saving and loading checkpoints to disk, as we do not actively maintain a library for these functionalities.\n", "\n", "> **Note:** If you are looking for Flax Linen's legacy `flax.training.checkpoints` package, it was deprecated in 2023 in favor of Orbax. The documentation resides on the [Flax Linen site](https://flax-linen.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup\n", "\n", "Import the necessary dependencies, set up a checkpoint directory and an example Flax NNX model - `TwoLayerMLP` - by subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from flax import nnx\n", "import orbax.checkpoint as ocp\n", "import jax\n", "from jax import numpy as jnp\n", "import numpy as np\n", "\n", "ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class TwoLayerMLP(nnx.Module):\n", " def __init__(self, dim, rngs: nnx.Rngs):\n", " self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)\n", " self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)\n", "\n", " def __call__(self, x):\n", " x = self.linear1(x)\n", " return self.linear2(x)\n", "\n", "# Instantiate the model and show we can run it.\n", "model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", "x = jax.random.normal(jax.random.key(42), (3, 4))\n", "assert model(x).shape == (3, 4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Save checkpoints\n", "\n", "JAX checkpointing libraries, such as Orbax, can save and load any given JAX [pytree](https://jax.readthedocs.io/en/latest/pytrees.html), which is a pure, possibly nested container of [`jax.Array`s)](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) (or, \"tensors\" as some other frameworks would put it). In the context of machine learning, the checkpoint is usually a pytree of model parameters and other data, such as optimizer states.\n", "\n", "In Flax NNX, you can obtain such a pytree from an [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) by calling [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), and picking up the returned [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "_, state = nnx.split(model)\n", "nnx.display(state)\n", "\n", "checkpointer = ocp.StandardCheckpointer()\n", "checkpointer.save(ckpt_dir / 'state', state)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "\n", "\n", "
\n", "\n", "\n", "## Restore checkpoints\n", "\n", "Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes.\n", "\n", "At checkpoint restoration time, you need to have these classes ready in your runtime, and instruct the checkpointing library (Orbax) to restore your pytree back to that structure. This can be achieved as follows:\n", "- First, create an abstract Flax NNX model (without allocating any memory for arrays), and show its abstract variable state to the checkpointing library.\n", "- Once you have the state, use [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) to obtain your Flax NNX model, and use it as usual." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The abstract NNX state (all leaves are abstract arrays):\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "NNX State restored: \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/cgarciae/repos/flax/.venv/lib/python3.11/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py:1251: UserWarning: Sharding info not provided when restoring. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n", " warnings.warn(\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Restore the checkpoint back to its `nnx.State` structure - need an abstract reference.\n", "abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", "graphdef, abstract_state = nnx.split(abstract_model)\n", "print('The abstract NNX state (all leaves are abstract arrays):')\n", "nnx.display(abstract_state)\n", "\n", "state_restored = checkpointer.restore(ckpt_dir / 'state', abstract_state)\n", "jax.tree.map(np.testing.assert_array_equal, state, state_restored)\n", "print('NNX State restored: ')\n", "nnx.display(state_restored)\n", "\n", "# The model is now good to use!\n", "model = nnx.merge(graphdef, state_restored)\n", "assert model(x).shape == (3, 4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " The abstract NNX state (all leaves are abstract arrays):\n", "\n", "\n", "\n", "
\n", "\n", "\n", " NNX State restored: \n", "\n", "\n", " /Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n", " warnings.warn(\n", "\n", "\n", "\n", "
\n", "\n", "\n", "\n", "
\n", "\n", "\n", "## Save and restore as pure dictionaries\n", "\n", "When interacting with checkpoint libraries (like Orbax), you may prefer to work with Python built-in container types. In this case, you can use the [`nnx.State.to_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L170) and [`nnx.State.replace_by_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L179) API to convert an [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) to and from pure nested dictionaries." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.\n" ] } ], "source": [ "# Save as pure dict\n", "pure_dict_state = nnx.to_pure_dict(state)\n", "nnx.display(pure_dict_state)\n", "checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state)\n", "\n", "# Restore as a pure dictionary.\n", "restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')\n", "abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", "graphdef, abstract_state = nnx.split(abstract_model)\n", "nnx.replace_by_pure_dict(abstract_state, restored_pure_dict)\n", "model = nnx.merge(graphdef, abstract_state)\n", "assert model(x).shape == (3, 4) # The model still works!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "\n", "\n", "
\n", "\n", "\n", " WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.\n", "\n", "\n", "## Restore when checkpoint structures differ\n", "\n", "The ability to load a checkpoint as a pure nested dictionary can come in handy when you want to load some outdated checkpoints that no longer match with your current model code. Check out this simple example below.\n", "\n", "This pattern also works if you save the checkpoint as an [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) instead of a pure dictionary. Check out the [Checkpoint surgery section](https://flax.readthedocs.io/en/latest/guides/surgery.html#checkpoint-surgery) of the [Model Surgery](https://flax.readthedocs.io/en/latest/guides/surgery.html) guide for an example with code. The only difference is you need to reprocess your raw dictionary a bit before calling [`nnx.State.replace_by_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L179)." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "class ModifiedTwoLayerMLP(nnx.Module):\n", " \"\"\"A modified version of TwoLayerMLP, which requires bias arrays.\"\"\"\n", " def __init__(self, dim, rngs: nnx.Rngs):\n", " self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True) # We need bias now!\n", " self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True) # We need bias now!\n", "\n", " def __call__(self, x):\n", " x = self.linear1(x)\n", " return self.linear2(x)\n", "\n", "# Accommodate your old checkpoint to the new code.\n", "restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')\n", "restored_pure_dict['linear1']['bias'] = jnp.zeros((4,))\n", "restored_pure_dict['linear2']['bias'] = jnp.zeros((4,))\n", "\n", "# Same restore code as above.\n", "abstract_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", "graphdef, abstract_state = nnx.split(abstract_model)\n", "nnx.replace_by_pure_dict(abstract_state, restored_pure_dict)\n", "model = nnx.merge(graphdef, abstract_state)\n", "assert model(x).shape == (3, 4) # The new model works!\n", "\n", "nnx.display(model.linear1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.\n", "\n", "\n", "\n", "
\n", "\n", "\n", "\n", "
\n", "\n", "\n", "## Multi-process checkpointing\n", "\n", "In a multi-host/multi-process environment, you would want to restore your checkpoint as sharded across multiple devices. Check out the [Load sharded model from a checkpoint](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html#load-sharded-model-from-a-checkpoint) section in the Flax [Scale up on multiple devices](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html) guide to learn how to derive a sharding pytree and use it to load your checkpoint.\n", "\n", "> **Note:** JAX provides several ways to scale up your code on multiple hosts at the same time. This usually happens when the number of devices (CPU/GPU/TPU) is so large that different devices are managed by different hosts (CPU). Check out JAX’s [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html), [Using JAX in multi-host and multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html), [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), and [Manual parallelism with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Other checkpointing features\n", "\n", "This guide only uses the simplest [`orbax.checkpoint.StandardCheckpointer`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpointers.html#standardcheckpointer) API to show how to save and load on the Flax modeling side. Feel free to use other tools or libraries as you see fit.\n", "\n", "In addition, check out [the Orbax website](https://orbax.readthedocs.io/en/latest/index.html) for other commonly used features, such as:\n", "\n", "* [`CheckpointManager`](https://orbax.readthedocs.io/en/latest/guides/checkpoint/api_refactor.html) to track checkpoints from different steps.\n", "\n", "* [Asynchronous checkpointing](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html).\n", "\n", "* [Orbax transformations](https://orbax.readthedocs.io/en/latest/guides/checkpoint/transformations.html): A way to modify pytree structure during loading time, instead of after loading time, which is demonstrated in this guide." ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: docs_nnx/guides/checkpointing.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # Save and load checkpoints This guide demonstrates how to save and load Flax NNX model checkpoints with [Orbax](https://orbax.readthedocs.io/). > **Note:** The Flax team does not actively maintain a library for saving and loading model checkpoints to disk. Therefore, it is recommended you use external libraries like [Orbax](https://orbax.readthedocs.io/en/latest/index.html) to do it. In this guide you will learn how to: * Save checkpoints. * Restore checkpoints. * Restore checkpoints if checkpoint structures differ. * Perform multi-process checkpointing. The Orbax API examples used throughout the guide are for demonstration purposes, and for the most up-to-date recommended APIs refer to the [Orbax website](https://orbax.readthedocs.io/). > **Note:** The Flax team recommends using [Orbax](https://orbax.readthedocs.io/en/latest/index.html) for saving and loading checkpoints to disk, as we do not actively maintain a library for these functionalities. > **Note:** If you are looking for Flax Linen's legacy `flax.training.checkpoints` package, it was deprecated in 2023 in favor of Orbax. The documentation resides on the [Flax Linen site](https://flax-linen.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html). +++ ### Setup Import the necessary dependencies, set up a checkpoint directory and an example Flax NNX model - `TwoLayerMLP` - by subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). ```{code-cell} ipython3 from flax import nnx import orbax.checkpoint as ocp import jax from jax import numpy as jnp import numpy as np ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/') ``` ```{code-cell} ipython3 class TwoLayerMLP(nnx.Module): def __init__(self, dim, rngs: nnx.Rngs): self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False) self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=False) def __call__(self, x): x = self.linear1(x) return self.linear2(x) # Instantiate the model and show we can run it. model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) x = jax.random.normal(jax.random.key(42), (3, 4)) assert model(x).shape == (3, 4) ``` ## Save checkpoints JAX checkpointing libraries, such as Orbax, can save and load any given JAX [pytree](https://jax.readthedocs.io/en/latest/pytrees.html), which is a pure, possibly nested container of [`jax.Array`s)](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) (or, "tensors" as some other frameworks would put it). In the context of machine learning, the checkpoint is usually a pytree of model parameters and other data, such as optimizer states. In Flax NNX, you can obtain such a pytree from an [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) by calling [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), and picking up the returned [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State). ```{code-cell} ipython3 _, state = nnx.split(model) nnx.display(state) checkpointer = ocp.StandardCheckpointer() checkpointer.save(ckpt_dir / 'state', state) ```
## Restore checkpoints Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes. At checkpoint restoration time, you need to have these classes ready in your runtime, and instruct the checkpointing library (Orbax) to restore your pytree back to that structure. This can be achieved as follows: - First, create an abstract Flax NNX model (without allocating any memory for arrays), and show its abstract variable state to the checkpointing library. - Once you have the state, use [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) to obtain your Flax NNX model, and use it as usual. ```{code-cell} ipython3 # Restore the checkpoint back to its `nnx.State` structure - need an abstract reference. abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0))) graphdef, abstract_state = nnx.split(abstract_model) print('The abstract NNX state (all leaves are abstract arrays):') nnx.display(abstract_state) state_restored = checkpointer.restore(ckpt_dir / 'state', abstract_state) jax.tree.map(np.testing.assert_array_equal, state, state_restored) print('NNX State restored: ') nnx.display(state_restored) # The model is now good to use! model = nnx.merge(graphdef, state_restored) assert model(x).shape == (3, 4) ``` The abstract NNX state (all leaves are abstract arrays):
NNX State restored: /Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with. warnings.warn(
## Save and restore as pure dictionaries When interacting with checkpoint libraries (like Orbax), you may prefer to work with Python built-in container types. In this case, you can use the [`nnx.State.to_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L170) and [`nnx.State.replace_by_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L179) API to convert an [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) to and from pure nested dictionaries. ```{code-cell} ipython3 # Save as pure dict pure_dict_state = nnx.to_pure_dict(state) nnx.display(pure_dict_state) checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state) # Restore as a pure dictionary. restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict') abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0))) graphdef, abstract_state = nnx.split(abstract_model) nnx.replace_by_pure_dict(abstract_state, restored_pure_dict) model = nnx.merge(graphdef, abstract_state) assert model(x).shape == (3, 4) # The model still works! ```
WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under. ## Restore when checkpoint structures differ The ability to load a checkpoint as a pure nested dictionary can come in handy when you want to load some outdated checkpoints that no longer match with your current model code. Check out this simple example below. This pattern also works if you save the checkpoint as an [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) instead of a pure dictionary. Check out the [Checkpoint surgery section](https://flax.readthedocs.io/en/latest/guides/surgery.html#checkpoint-surgery) of the [Model Surgery](https://flax.readthedocs.io/en/latest/guides/surgery.html) guide for an example with code. The only difference is you need to reprocess your raw dictionary a bit before calling [`nnx.State.replace_by_pure_dict`](https://github.com/google/flax/blob/764e1732dcd3b8bf178b9ba73ddecf125709b5d7/flax/nnx/statelib.py#L179). ```{code-cell} ipython3 class ModifiedTwoLayerMLP(nnx.Module): """A modified version of TwoLayerMLP, which requires bias arrays.""" def __init__(self, dim, rngs: nnx.Rngs): self.linear1 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True) # We need bias now! self.linear2 = nnx.Linear(dim, dim, rngs=rngs, use_bias=True) # We need bias now! def __call__(self, x): x = self.linear1(x) return self.linear2(x) # Accommodate your old checkpoint to the new code. restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict') restored_pure_dict['linear1']['bias'] = jnp.zeros((4,)) restored_pure_dict['linear2']['bias'] = jnp.zeros((4,)) # Same restore code as above. abstract_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0))) graphdef, abstract_state = nnx.split(abstract_model) nnx.replace_by_pure_dict(abstract_state, restored_pure_dict) model = nnx.merge(graphdef, abstract_state) assert model(x).shape == (3, 4) # The new model works! nnx.display(model.linear1) ``` WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.
## Multi-process checkpointing In a multi-host/multi-process environment, you would want to restore your checkpoint as sharded across multiple devices. Check out the [Load sharded model from a checkpoint](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html#load-sharded-model-from-a-checkpoint) section in the Flax [Scale up on multiple devices](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html) guide to learn how to derive a sharding pytree and use it to load your checkpoint. > **Note:** JAX provides several ways to scale up your code on multiple hosts at the same time. This usually happens when the number of devices (CPU/GPU/TPU) is so large that different devices are managed by different hosts (CPU). Check out JAX’s [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html), [Using JAX in multi-host and multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html), [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), and [Manual parallelism with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html). +++ ## Other checkpointing features This guide only uses the simplest [`orbax.checkpoint.StandardCheckpointer`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpointers.html#standardcheckpointer) API to show how to save and load on the Flax modeling side. Feel free to use other tools or libraries as you see fit. In addition, check out [the Orbax website](https://orbax.readthedocs.io/en/latest/index.html) for other commonly used features, such as: * [`CheckpointManager`](https://orbax.readthedocs.io/en/latest/guides/checkpoint/api_refactor.html) to track checkpoints from different steps. * [Asynchronous checkpointing](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html). * [Orbax transformations](https://orbax.readthedocs.io/en/latest/guides/checkpoint/transformations.html): A way to modify pytree structure during loading time, instead of after loading time, which is demonstrated in this guide. ================================================ FILE: docs_nnx/guides/demo.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "a1b37dff", "metadata": {}, "source": [ "# NNX Demo" ] }, { "cell_type": "code", "execution_count": 1, "id": "e8099a6f", "metadata": {}, "outputs": [], "source": [ "import jax\n", "from jax import numpy as jnp\n", "from flax import nnx" ] }, { "cell_type": "markdown", "id": "bcc5cffe", "metadata": {}, "source": [ "### [1] NNX is Pythonic" ] }, { "cell_type": "code", "execution_count": 7, "id": "d99b73af", "metadata": { "outputId": "d8ef66d5-6866-4d5c-94c2-d22512bfe718" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "model = MLP(\n", " blocks=[Block(\n", " linear=Linear(\n", " in_features=4,\n", " out_features=4,\n", " use_bias=True,\n", " dtype=None,\n", " param_dtype=,\n", " precision=None,\n", " kernel_init=.init at 0x28ae86dc0>,\n", " bias_init=,\n", " dot_general=\n", " ),\n", " bn=BatchNorm(\n", " num_features=4,\n", " \n", "...\n" ] } ], "source": [ "\n", "class Block(nnx.Module):\n", " def __init__(self, din, dout, *, rngs):\n", " self.linear = nnx.Linear(din, dout, rngs=rngs)\n", " self.bn = nnx.BatchNorm(dout, rngs=rngs)\n", "\n", " def __call__(self, x):\n", " return nnx.relu(self.bn(self.linear(x)))\n", "\n", "\n", "class MLP(nnx.Module):\n", " def __init__(self, nlayers, dim, *, rngs): # explicit RNG threading\n", " self.blocks = [\n", " Block(dim, dim, rngs=rngs) for _ in range(nlayers)\n", " ]\n", " self.count = Count(0) # stateful variables are defined as attributes\n", "\n", " def __call__(self, x):\n", " self.count.value += 1 # in-place stateful updates\n", " for block in self.blocks:\n", " x = block(x)\n", " return x\n", "\n", "class Count(nnx.Variable): # custom Variable types define the \"collections\"\n", " pass\n", "\n", "model = MLP(5, 4, rngs=nnx.Rngs(0)) # no special `init` method\n", "model.set_attributes(use_running_average=False) # set flags\n", "y = model(jnp.ones((2, 4))) # call methods directly\n", "\n", "print(f'{model = }'[:500] + '\\n...')" ] }, { "cell_type": "markdown", "id": "523aa27c", "metadata": {}, "source": [ "Because NNX Modules contain their own state, they are very easily to inspect:" ] }, { "cell_type": "code", "execution_count": 9, "id": "6f278ec4", "metadata": { "outputId": "10a46b0f-2993-4677-c26d-36a4ddf33449" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "model.count = Count(\n", " raw_value=1\n", ")\n", "model.blocks[0].linear.kernel = Param(\n", " raw_value=Array([[-0.80345297, -0.34071913, -0.9408296 , 0.01005968],\n", " [ 0.26146442, 1.1247735 , 0.54563737, -0.374164 ],\n", " [ 1.0281805 , -0.6798804 , -0.1488401 , 0.05694951],\n", " [-0.44308168, -0.60587114, 0.434087 , -0.40541083]], dtype=float32)\n", ")\n" ] } ], "source": [ "print(f'{model.count = }')\n", "print(f'{model.blocks[0].linear.kernel = }')\n", "# print(f'{model.blocks.sdf.kernel = }') # typesafe inspection" ] }, { "cell_type": "markdown", "id": "95f389f2", "metadata": {}, "source": [ "### [2] Model Surgery is Intuitive" ] }, { "cell_type": "code", "execution_count": 10, "id": "96f61108", "metadata": { "outputId": "e6f86be8-3537-4c48-f471-316ee0fb6c45" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y.shape = (2, 4)\n" ] } ], "source": [ "# Module sharing\n", "model.blocks[1] = model.blocks[3]\n", "# Weight tying\n", "model.blocks[0].linear.kernel = model.blocks[-1].linear.kernel\n", "# Monkey patching\n", "def my_optimized_layer(x): return x\n", "model.blocks[2] = my_optimized_layer\n", "\n", "y = model(jnp.ones((2, 4))) # still works\n", "print(f'{y.shape = }')" ] }, { "cell_type": "markdown", "id": "aca5a6cd", "metadata": {}, "source": [ "### [3] Interacting with JAX is easy" ] }, { "cell_type": "code", "execution_count": 11, "id": "c166dcc7", "metadata": { "outputId": "9a3f378b-739e-4f45-9968-574651200ede" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "state = State({\n", " 'blocks': {\n", " '0': {\n", " 'linear': {\n", " 'kernel': Param(\n", " raw_value=Array([[-0.33095378, 0.67149884, 0.33700302, 0.30972847],\n", " [ 0.8662822 , -0.11225506, -1.0820619 , -0.9906892 ],\n", " [ 0.88298297, -0.2143851 , 0.48143268, 0.6474548 ],\n", " [-0.7710582 , 0.3372276 , 0.15487202, 0.6219269 ]], dtype=float32)\n", " ),\n", " 'bias': Param(\n", " raw_value=Array([0., 0., 0., 0.], dtype=float32)\n", " \n", "...\n", "\n", "graphdef = GraphDef(\n", " type=MLP,\n", " index=0,\n", " attributes=('blocks', 'count'),\n", " subgraphs={\n", " 'blocks': GraphDef(\n", " type=list,\n", " index=1,\n", " attributes=('0', '1', '2', '3', '4'),\n", " subgraphs={\n", " '0': GraphDef(\n", " type=Block,\n", " index=2,\n", " attributes=('line\n", "...\n" ] } ], "source": [ "graphdef, state = model.split()\n", "\n", "# state is a dictionary-like JAX pytree\n", "print(f'{state = }'[:500] + '\\n...')\n", "\n", "# graphdef is also a JAX pytree, but just metadata\n", "print(f'\\n{graphdefefefefefef = }'[:300] + '\\n...')" ] }, { "cell_type": "code", "execution_count": 12, "id": "9f03e3af", "metadata": { "outputId": "0007d357-152a-449e-bcb9-b1b5a91d2d8d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y.shape = (2, 4)\n", "model.count.value = Array(3, dtype=int32, weak_type=True)\n" ] } ], "source": [ "graphdef, state = model.split()\n", "\n", "@jax.jit\n", "def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array):\n", " model = graphdef.merge(state)\n", " y = model(x)\n", " state, _ = model.split()\n", " return y, state\n", "\n", "x = jnp.ones((2, 4))\n", "y, state = forward(graphdef,state, x)\n", "\n", "model.update(state)\n", "\n", "print(f'{y.shape = }')\n", "print(f'{model.count.value = }')" ] }, { "cell_type": "code", "execution_count": 7, "id": "9e23dbb4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y.shape = (2, 4)\n", "model.count = Array(4, dtype=int32, weak_type=True)\n" ] } ], "source": [ "params, batch_stats, counts, graphdef = model.split(nnx.Param, nnx.BatchStat, Count)\n", "\n", "@jax.jit\n", "def forward(graphdef: nnx.GraphDef, params, batch_stats, counts, x: jax.Array):\n", " model = graphdef.merge(params, batch_stats, counts)\n", " y = model(x, train=True)\n", " params, batch_stats, counts, _ = model.split(nnx.Param, nnx.BatchStat, Count)\n", " return y, params, batch_stats, counts\n", "\n", "x = jnp.ones((2, 4))\n", "y, params, batch_stats, counts = forward(graphdef, params, batch_stats, counts, x)\n", "\n", "model.update(params, batch_stats, counts)\n", "\n", "print(f'{y.shape = }')\n", "print(f'{model.count = }')" ] }, { "cell_type": "code", "execution_count": 14, "id": "2461bfe8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y.shape = (2, 4)\n", "parent.model.count.value = Array(4, dtype=int32, weak_type=True)\n" ] } ], "source": [ "class Parent(nnx.Module):\n", " def __init__(self, model: MLP):\n", " self.model = model\n", "\n", " def __call__(self, x):\n", " params, batch_stats, counts, graphdef = self.model.split(nnx.Param, nnx.BatchStat, Count)\n", "\n", " @jax.jit\n", " def forward(graphdef: nnx.GraphDef, params, batch_stats, counts, x: jax.Array):\n", " model = graphdef.merge(params, batch_stats, counts)\n", " y = model(x)\n", " params, batch_stats, counts, _ = model.split(nnx.Param, nnx.BatchStat, Count)\n", " return y, params, batch_stats, counts\n", "\n", " y, params, batch_stats, counts = forward(graphdef, params, batch_stats, counts, x)\n", "\n", " self.model.update(params, batch_stats, counts)\n", " return y\n", "\n", "parent = Parent(model)\n", "\n", "y = parent(jnp.ones((2, 4)))\n", "\n", "print(f'{y.shape = }')\n", "print(f'{parent.model.count.value = }')" ] }, { "cell_type": "code", "execution_count": null, "id": "2e340bcb", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: docs_nnx/guides/demo.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # NNX Demo ```{code-cell} ipython3 import jax from jax import numpy as jnp from flax import nnx ``` ### [1] NNX is Pythonic ```{code-cell} ipython3 :outputId: d8ef66d5-6866-4d5c-94c2-d22512bfe718 class Block(nnx.Module): def __init__(self, din, dout, *, rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.bn = nnx.BatchNorm(dout, rngs=rngs) def __call__(self, x): return nnx.relu(self.bn(self.linear(x))) class MLP(nnx.Module): def __init__(self, nlayers, dim, *, rngs): # explicit RNG threading self.blocks = [ Block(dim, dim, rngs=rngs) for _ in range(nlayers) ] self.count = Count(0) # stateful variables are defined as attributes def __call__(self, x): self.count.value += 1 # in-place stateful updates for block in self.blocks: x = block(x) return x class Count(nnx.Variable): # custom Variable types define the "collections" pass model = MLP(5, 4, rngs=nnx.Rngs(0)) # no special `init` method model.set_attributes(use_running_average=False) # set flags y = model(jnp.ones((2, 4))) # call methods directly print(f'{model = }'[:500] + '\n...') ``` Because NNX Modules contain their own state, they are very easily to inspect: ```{code-cell} ipython3 :outputId: 10a46b0f-2993-4677-c26d-36a4ddf33449 print(f'{model.count = }') print(f'{model.blocks[0].linear.kernel = }') # print(f'{model.blocks.sdf.kernel = }') # typesafe inspection ``` ### [2] Model Surgery is Intuitive ```{code-cell} ipython3 :outputId: e6f86be8-3537-4c48-f471-316ee0fb6c45 # Module sharing model.blocks[1] = model.blocks[3] # Weight tying model.blocks[0].linear.kernel = model.blocks[-1].linear.kernel # Monkey patching def my_optimized_layer(x): return x model.blocks[2] = my_optimized_layer y = model(jnp.ones((2, 4))) # still works print(f'{y.shape = }') ``` ### [3] Interacting with JAX is easy ```{code-cell} ipython3 :outputId: 9a3f378b-739e-4f45-9968-574651200ede graphdef, state = model.split() # state is a dictionary-like JAX pytree print(f'{state = }'[:500] + '\n...') # graphdef is also a JAX pytree, but just metadata print(f'\n{graphdefefefefefef = }'[:300] + '\n...') ``` ```{code-cell} ipython3 :outputId: 0007d357-152a-449e-bcb9-b1b5a91d2d8d graphdef, state = model.split() @jax.jit def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array): model = graphdef.merge(state) y = model(x) state, _ = model.split() return y, state x = jnp.ones((2, 4)) y, state = forward(graphdef,state, x) model.update(state) print(f'{y.shape = }') print(f'{model.count.value = }') ``` ```{code-cell} ipython3 params, batch_stats, counts, graphdef = model.split(nnx.Param, nnx.BatchStat, Count) @jax.jit def forward(graphdef: nnx.GraphDef, params, batch_stats, counts, x: jax.Array): model = graphdef.merge(params, batch_stats, counts) y = model(x, train=True) params, batch_stats, counts, _ = model.split(nnx.Param, nnx.BatchStat, Count) return y, params, batch_stats, counts x = jnp.ones((2, 4)) y, params, batch_stats, counts = forward(graphdef, params, batch_stats, counts, x) model.update(params, batch_stats, counts) print(f'{y.shape = }') print(f'{model.count = }') ``` ```{code-cell} ipython3 class Parent(nnx.Module): def __init__(self, model: MLP): self.model = model def __call__(self, x): params, batch_stats, counts, graphdef = self.model.split(nnx.Param, nnx.BatchStat, Count) @jax.jit def forward(graphdef: nnx.GraphDef, params, batch_stats, counts, x: jax.Array): model = graphdef.merge(params, batch_stats, counts) y = model(x) params, batch_stats, counts, _ = model.split(nnx.Param, nnx.BatchStat, Count) return y, params, batch_stats, counts y, params, batch_stats, counts = forward(graphdef, params, batch_stats, counts, x) self.model.update(params, batch_stats, counts) return y parent = Parent(model) y = parent(jnp.ones((2, 4))) print(f'{y.shape = }') print(f'{parent.model.count.value = }') ``` ```{code-cell} ipython3 ``` ================================================ FILE: docs_nnx/guides/extracting_intermediates.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "0fa991e8", "metadata": {}, "source": [ "# Extracting intermediate values\n", "\n", "This guide will show you how to extract intermediate values from a module.\n", "Consider a toy neural network with two pieces: a \"feature\" component that embeds\n", "inputs in some feature space, and a \"loss\" component that operates on those features.\n", "We'll want to log these feature components during training to identify any issues with\n", "the feature extraction. To do this, we can use the `Module.sow` method." ] }, { "cell_type": "code", "execution_count": 2, "id": "e4c7c65b", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "W0317 18:04:12.704562 2028538 cpp_gen_intrinsics.cc:74] Empty bitcode string provided for eigen. Optimizations relying on this IR will be disabled.\n" ] } ], "source": [ "from flax import nnx\n", "import jax\n", "import jax.numpy as jnp\n", "from functools import partial\n", "\n", "class Foo(nnx.Module):\n", " def __init__(self, *, rngs: nnx.Rngs):\n", " self.dense1 = nnx.Linear(8, 32, rngs=rngs)\n", " self.dense2 = nnx.Linear(32, 1, rngs=rngs)\n", "\n", " def features(self, x, rngs= None):\n", " feature = nnx.relu(self.dense1(x))\n", " self.sow(nnx.Intermediate, 'features', feature)\n", " return feature\n", "\n", " def loss(self, x_features, y_features):\n", " return jnp.sum((x_features - y_features)**2)\n", "\n", " def __call__(self, x, y):\n", " return self.loss(self.features(x), self.features(y))\n", "\n", "# Instantiate the model.\n", "rngs = nnx.Rngs(0)\n", "model = Foo(rngs=rngs)\n", "\n", "# Dummy input for testing\n", "x, y = rngs.normal((2,8))" ] }, { "cell_type": "markdown", "id": "c56dd826", "metadata": {}, "source": [ "Here, `self.sow` will store intermediate values under the key `'features'` in a collection associated with the\n", "`nnx.Intermediate` type. If you want to log values to multiple different collections, you can use different subclasses of `nnx.Intermediate`\n", "for each collection.\n", "\n", "Now, we can wrap the module with the `nnx.capture` decorator, which wraps any `Callable` accepting a module as its argument (which includes `nnx.Module`s, their methods, or ordinary functions) to return both the resulting loss as well as any intermediate values stored to the `nnx.Intermediate` collection:" ] }, { "cell_type": "code", "execution_count": 2, "id": "c508f8f3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "State({\n", " 'features': Intermediate(\n", " value=((32,), (32,))\n", " )\n", "})" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "capturing_model = nnx.capture(model, nnx.Intermediate)\n", "result, intms = capturing_model(x, y)\n", "jax.tree.map(lambda a: a.shape, intms)" ] }, { "cell_type": "markdown", "id": "d2f2f609", "metadata": {}, "source": [ "Note that, by default, sow appends values every time it is called. We can see\n", "this in the *features* intermediate values logged above. It contains a tuple with one element for the call on `x` and one for the call on `y`. To override the default append behavior, specify `init_fn` and `reduce_fn` - see `Module.sow()`.\n", "\n", "## How `nnx.capture` Works\n", "\n", "`nnx.capture` works by temporarily installing a set of mutable capture buffers on every module in the graph before calling the wrapped function, then harvesting those buffers afterward. Before calling the wrapped function, `capture` walks the entire module graph with `iter_modules`. For each module it sets a `__captures__` attribute: a tuple of Variable instances, one per requested `var_type`. Each Variable holds a plain `dict` that maps sow-key → accumulated value.\n", "\n", "We can see this `__captures__` tuple by printing out the module contents during a `nnx.capture` call:" ] }, { "cell_type": "code", "execution_count": 15, "id": "1e5c2ae7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Captures: (Intermediate(\n", " value={}\n", "),)\n" ] } ], "source": [ "@nnx.capture(nnx.Intermediate)\n", "def print_captures(model):\n", " print(\"Captures:\", model.__captures__)\n", "_, intms = print_captures(nnx.Module())" ] }, { "cell_type": "markdown", "id": "2daab2c9", "metadata": {}, "source": [ "`Module.sow` looks for the Variable in the `__captures__` tuple whose type matches `variable_type`, then writes its value into that dict using `reduce_fn`.\n", "\n", "If no matching type is found, `sow` silently returns `False` without logging the value. This can be used to capture only a subset of the sown values. For example:" ] }, { "cell_type": "code", "execution_count": 16, "id": "159a909b-0c3a-411e-9ccb-98ddecd5720e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "State({\n", " 'gets_sown': Metric1(\n", " value=((2,),)\n", " )\n", "})" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class Metric1(nnx.Intermediate):\n", " pass\n", "\n", "class Metric2(nnx.Intermediate):\n", " pass\n", "\n", "@nnx.capture(Metric1)\n", "def get_captures(model):\n", " model.sow(Metric1, 'gets_sown', jnp.ones(2))\n", " model.sow(Metric2, 'gets_ignored', jnp.ones(2))\n", "_, intms = get_captures(nnx.Module())\n", "jax.tree.map(lambda a: a.shape, intms)" ] }, { "cell_type": "markdown", "id": "1328ce66", "metadata": {}, "source": [ "## Capturing all intermediate values\n", "\n", "To observe the output of each method without manually adding calls to `sow`, we can call `nnx.capture` with the `method_outputs` argument. This will automatically `sow` the output of each method using the given variable type, including methods of sub-modules." ] }, { "cell_type": "code", "execution_count": 7, "id": "47781215", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "State({\n", " '__call__': Intermediate(\n", " value=((),)\n", " ),\n", " 'dense1': {\n", " '__call__': Intermediate(\n", " value=((32,), (32,))\n", " )\n", " },\n", " 'features': Intermediate(\n", " value=((32,), (32,))\n", " ),\n", " 'loss': Intermediate(\n", " value=((),)\n", " )\n", "})" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class Foo(nnx.Module):\n", " def __init__(self, *, rngs: nnx.Rngs):\n", " self.dense1 = nnx.Linear(8, 32, rngs=rngs)\n", " self.dense2 = nnx.Linear(32, 1, rngs=rngs)\n", "\n", " def features(self, x, rngs= None):\n", " feature = nnx.relu(self.dense1(x))\n", " return feature\n", "\n", " def loss(self, x_features, y_features):\n", " return jnp.sum((x_features - y_features)**2)\n", "\n", " def __call__(self, x, y):\n", " return self.loss(self.features(x), self.features(y))\n", "\n", "model = Foo(rngs=nnx.Rngs(0))\n", "capturing_model = nnx.capture(model, nnx.Intermediate, method_outputs=nnx.Intermediate)\n", "result, intms = capturing_model(x, y)\n", "jax.tree.map(lambda a: a.shape, intms)" ] }, { "cell_type": "markdown", "id": "eee2809a", "metadata": {}, "source": [ "This pattern should be considered the \"sledge hammer\" approach to capturing intermediates. As a debugging and inspection tool it is very useful, but using the other patterns described in this guide will give you more fine-grained control over what intermediates you want to extract. We can also combine the `method_output_type` argument with manual calls to sow to capture both layer outputs and computations mid-layer.\n", "\n", "## Extracting gradients of intermediate values\n", "\n", "For debugging purposes, it can be useful to extract the gradients of intermediate values. This is a little tricky: jax doesn't have a stable mechanism for sowing information from the backward pass into to objects from the forward pass. Instead, we record the gradients of intermediate values using the `Module.perturb` method as follows:" ] }, { "cell_type": "code", "execution_count": 30, "id": "84911b66", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "State({\n", " 'grad_of_x': Perturbation( # 1 (4 B)\n", " value=Array(3., dtype=float32, weak_type=True)\n", " ),\n", " 'activations': Intermediate( # 1 (4 B)\n", " value=(Array(1., dtype=float32, weak_type=True),)\n", " )\n", "})" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class Model(nnx.Module):\n", " def __call__(self, x):\n", " x2 = self.perturb('grad_of_x', x)\n", " self.sow(nnx.Intermediate, 'activations', x2)\n", " return 3 * x2\n", "\n", "model = Model()\n", "\n", "def train_step(model, x):\n", " _, perturbations = nnx.capture(model, nnx.Perturbation)(x)\n", " def loss(model, perturbations, x):\n", " return nnx.capture(model, nnx.Intermediate, init=perturbations)(x)\n", "\n", " (grads, perturb_grads), sowed = nnx.grad(loss, argnums=(0, 1), has_aux=True)(model, perturbations, x)\n", " return nnx.merge_state(perturb_grads, sowed)\n", "\n", "train_step(model, 1.0)" ] }, { "cell_type": "markdown", "id": "7dc978af", "metadata": {}, "source": [ "There are four steps:\n", "\n", "**Step One: Initialize *perturbations* of the model**.\n", "\n", "We do this with a call to `nnx.capture(model, nnx.Perturbation)`. Before the call, `capture` installs `__captures__` on the module — a tuple containing one empty `Perturbation` buffer (as described in \"How `nnx.capture` Works\" above). When `self.perturb` runs, it checks `__captures__` for a matching `Perturbation` Variable, initialises the slot to `zeros_like(value)`, and returns `zeros + x`. After the call, `__captures__` is removed and the filled buffer is returned as `perturbations`." ] }, { "cell_type": "code", "execution_count": 12, "id": "5f4fda80", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "before perturb: (Perturbation(\n", " value={}\n", "),)\n", "after perturb: (Perturbation( # 1 (4 B)\n", " value={'grad_of_x': Array(0., dtype=float32, weak_type=True)}\n", "),)\n", "\u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'grad_of_x'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mPerturbation\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(0., dtype=float32, weak_type=True)\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", "\u001b[38;2;255;213;3m})\u001b[0m\n" ] } ], "source": [ "class Model(nnx.Module):\n", " def __call__(self, x):\n", " print(\"before perturb:\", self.__captures__)\n", " x2 = self.perturb('grad_of_x', x)\n", " print(\"after perturb:\", self.__captures__)\n", " self.sow(nnx.Intermediate, 'activations', x2)\n", " # sow is a no-op: Intermediate is not in __captures__, so it returns False silently\n", " return 3 * x2\n", "\n", "model = Model()\n", "_, perturbations = nnx.capture(model, nnx.Perturbation)(1.0)\n", "print(perturbations)" ] }, { "cell_type": "markdown", "id": "2b4e1d98", "metadata": {}, "source": [ "There are only two differences between `sow` and `perturb`:\n", "\n", "- The `nnx.Variable` tag used for values written with `self.perturb` is `nnx.Perturbation` rather than `nnx.Intermediate`.\n", " \n", "- `perturb` returns the logged value. You must use this returned value rather than the original value for the gradient capturing machinery to work.\n", "\n", "The `var_types` argument to `capture` restricts which of the logged values we want to return. Because we only want the intermediates logged with `self.perturb` statements, we only capture `nnx.Perturbation` types.\n", "\n", "**Step Two: Run the model again, but add in these perturbations**.\n", "\n", "Call `capture` again with `init=perturbations`. `capture` first builds a mapping from module path to the Variables in `init`, then uses it to pre-populate `__captures__`. Now `__captures__` has *two* buffers: an empty `Intermediate` buffer (from `var_types`) and a `Perturbation` buffer pre-populated from `init`. `self.perturb` finds the pre-populated buffer and returns `x + perturbation`; `self.sow` writes into the `Intermediate` buffer as normal." ] }, { "cell_type": "code", "execution_count": 13, "id": "a4087d73", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "before perturb: (Intermediate(\n", " value={}\n", "), Perturbation( # 1 (4 B)\n", " value={'grad_of_x': Array(0., dtype=float32, weak_type=True)}\n", "))\n", "after sow: (Intermediate( # 1 (4 B)\n", " value={'activations': (Array(1., dtype=float32, weak_type=True),)}\n", "), Perturbation( # 1 (4 B)\n", " value={'grad_of_x': Array(0., dtype=float32, weak_type=True)}\n", "))\n" ] } ], "source": [ "class Model(nnx.Module):\n", " def __call__(self, x):\n", " print(\"before perturb:\", self.__captures__)\n", " x2 = self.perturb('grad_of_x', x)\n", " self.sow(nnx.Intermediate, 'activations', x2)\n", " print(\"after sow: \", self.__captures__)\n", " return 3 * x2\n", "\n", "model = Model()\n", "_, interms = nnx.capture(model, nnx.Intermediate, init=perturbations)(1.0)" ] }, { "cell_type": "markdown", "id": "63e2ba59", "metadata": {}, "source": [ "This changes the behavior of `x2 = self.perturb('name', x)` to essentially be `x2 = x + perturbations['name']`. The gradient of our output with respect to `x` will be the same as the gradient with respect to the perturbation, because JAX can differentiate through the addition with respect to the perturbation value stored in the capture dict.\n", "\n", "**Step Three: Take gradients**.\n", "\n", "Specifically, take the gradient of this second `capture` call with respect to the perturbation arguments. JAX traces through exactly the same `__captures__` setup as Step Two, but with abstract (traced) array values instead of concrete ones. This will give us the same values as the gradients with respect to the intermediate variables. If we want to track intermediate variables in the forward pass at the same time, we'll need to return the intermediate values output of the `capture` call as well, so we'll need to pass `has_aux=True` to `nnx.grad`.\n", "\n", "**Step Four: Combine intermediate states**\n", "\n", "Merge the `State` object we get from the perturbation gradients with the `State` object for forward intermediates with `nnx.merge_state(perturb_grads, sowed)`. At this point `__captures__` no longer exists on any module — it was cleaned up at the end of the `capture` call in Step Three." ] }, { "cell_type": "markdown", "id": "23ccf952", "metadata": {}, "source": [ "## NNX Transforms and Capturing\n", "\n", "`nnx.capture` composes with NNX transforms such as `nnx.vmap`. The main thing to keep in mind is that perturbations must be initialized with a run that has the same batch structure as the training step that will consume them.\n", "\n", "Consider a model that calls both `sow` and `perturb`:" ] }, { "cell_type": "code", "execution_count": 34, "id": "7c8f8d83", "metadata": {}, "outputs": [], "source": [ "class Foo(nnx.Module):\n", " def __init__(self, dim):\n", " self.w = nnx.Param(jax.random.normal(jax.random.key(0), dim))\n", "\n", " def __call__(self, x):\n", " x = self.perturb('grad_of_x', x)\n", " y = jnp.dot(x, self.w)\n", " self.sow(nnx.Intermediate, 'y', y)\n", " return y" ] }, { "cell_type": "markdown", "id": "786278eb", "metadata": {}, "source": [ "The training step vmaps `loss_grad` over a batch of inputs and perturbations, while the model weights are shared across the batch (`in_axes=None`):" ] }, { "cell_type": "code", "execution_count": 35, "id": "8ac86ed1", "metadata": {}, "outputs": [], "source": [ "@nnx.jit\n", "def train_step(model, x):\n", " _, perturbations = init_perturbations(model, x)\n", " def loss_grad(model, perturbations, x):\n", " def loss(model, perturbations, x):\n", " loss, interms = nnx.capture(model, nnx.Intermediate, init=perturbations)(x)\n", " return loss, interms\n", " (grads, perturb_grads), sowed = nnx.grad(loss, argnums=(0, 1), has_aux=True)(model, perturbations, x)\n", " return grads, nnx.merge_state(perturb_grads, sowed)\n", " return nnx.vmap(loss_grad, in_axes=(None, 0, 0))(model, perturbations, x)" ] }, { "cell_type": "markdown", "id": "3de42b97-9923-436f-9db0-d9aeedc259ad", "metadata": {}, "source": [ "After every training step, we can sum the gradients and pass them to an `Optimizer` to adjust the model, as usual. But we can also look at the full batch of sown values and perturbations.\n", "\n", "Because `train_step` expects `perturbations` to have a leading batch axis (axis 0), the perturbation initialization run must also produce a batched `perturbations` state. We do this inside an `init_perturbations` method that splits the model and vmaps the run with `in_axes=(0, None, 0)` for `(intermediates, params, x)`." ] }, { "cell_type": "code", "execution_count": 36, "id": "76c291c8", "metadata": {}, "outputs": [], "source": [ "@nnx.capture(nnx.Perturbation)\n", "def init_perturbations(model, x):\n", " graphdef, intms, params = nnx.split(model, nnx.Intermediate, nnx.Param)\n", " def forward(intms, params, x):\n", " return nnx.merge(graphdef, intms, params)(x)\n", " return nnx.vmap(forward, in_axes=(0, None, 0))(intms, params, x)" ] }, { "cell_type": "markdown", "id": "981642f6", "metadata": {}, "source": [ "Putting it together:" ] }, { "cell_type": "code", "execution_count": 37, "id": "7a741ca4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "State({\n", " 'grad_of_x': Perturbation(\n", " value=(3, 4)\n", " ),\n", " 'y': Intermediate(\n", " value=((3,),)\n", " )\n", "})" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model, x = Foo(4), jnp.ones((3, 4))\n", "_, intermediates = train_step(model, x)\n", "jax.tree.map(lambda a: a.shape, intermediates)" ] }, { "cell_type": "markdown", "id": "027216c7", "metadata": {}, "source": [ "The pattern generalises: whenever a transform introduces a new batch axis over which `capture` runs, initialize perturbations with a matching vmapped pre-run so that the `init=perturbations` argument inside the transform has the correct shape." ] } ], "metadata": { "jupytext": { "formats": "ipynb,md", "main_language": "python" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.5" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: docs_nnx/guides/extracting_intermediates.md ================================================ --- jupyter: jupytext: formats: ipynb,md main_language: python text_representation: extension: .md format_name: markdown format_version: '1.3' jupytext_version: 1.13.8 --- # Extracting intermediate values This guide will show you how to extract intermediate values from a module. Consider a toy neural network with two pieces: a "feature" component that embeds inputs in some feature space, and a "loss" component that operates on those features. We'll want to log these feature components during training to identify any issues with the feature extraction. To do this, we can use the `Module.sow` method. ```python from flax import nnx import jax import jax.numpy as jnp from functools import partial class Foo(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.dense1 = nnx.Linear(8, 32, rngs=rngs) self.dense2 = nnx.Linear(32, 1, rngs=rngs) def features(self, x, rngs= None): feature = nnx.relu(self.dense1(x)) self.sow(nnx.Intermediate, 'features', feature) return feature def loss(self, x_features, y_features): return jnp.sum((x_features - y_features)**2) def __call__(self, x, y): return self.loss(self.features(x), self.features(y)) # Instantiate the model. rngs = nnx.Rngs(0) model = Foo(rngs=rngs) # Dummy input for testing x, y = rngs.normal((2,8)) ``` Here, `self.sow` will store intermediate values under the key `'features'` in a collection associated with the `nnx.Intermediate` type. If you want to log values to multiple different collections, you can use different subclasses of `nnx.Intermediate` for each collection. Now, we can wrap the module with the `nnx.capture` decorator, which wraps any `Callable` accepting a module as its argument (which includes `nnx.Module`s, their methods, or ordinary functions) to return both the resulting loss as well as any intermediate values stored to the `nnx.Intermediate` collection: ```python capturing_model = nnx.capture(model, nnx.Intermediate) result, intms = capturing_model(x, y) jax.tree.map(lambda a: a.shape, intms) ``` Note that, by default, sow appends values every time it is called. We can see this in the *features* intermediate values logged above. It contains a tuple with one element for the call on `x` and one for the call on `y`. To override the default append behavior, specify `init_fn` and `reduce_fn` - see `Module.sow()`. ## How `nnx.capture` Works `nnx.capture` works by temporarily installing a set of mutable capture buffers on every module in the graph before calling the wrapped function, then harvesting those buffers afterward. Before calling the wrapped function, `capture` walks the entire module graph with `iter_modules`. For each module it sets a `__captures__` attribute: a tuple of Variable instances, one per requested `var_type`. Each Variable holds a plain `dict` that maps sow-key → accumulated value. We can see this `__captures__` tuple by printing out the module contents during a `nnx.capture` call: ```python @nnx.capture(nnx.Intermediate) def print_captures(model): print("Captures:", model.__captures__) _, intms = print_captures(nnx.Module()) ``` `Module.sow` looks for the Variable in the `__captures__` tuple whose type matches `variable_type`, then writes its value into that dict using `reduce_fn`. If no matching type is found, `sow` silently returns `False` without logging the value. This can be used to capture only a subset of the sown values. For example: ```python class Metric1(nnx.Intermediate): pass class Metric2(nnx.Intermediate): pass @nnx.capture(Metric1) def get_captures(model): model.sow(Metric1, 'gets_sown', jnp.ones(2)) model.sow(Metric2, 'gets_ignored', jnp.ones(2)) _, intms = get_captures(nnx.Module()) jax.tree.map(lambda a: a.shape, intms) ``` ## Capturing all intermediate values To observe the output of each method without manually adding calls to `sow`, we can call `nnx.capture` with the `method_outputs` argument. This will automatically `sow` the output of each method using the given variable type, including methods of sub-modules. ```python class Foo(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.dense1 = nnx.Linear(8, 32, rngs=rngs) self.dense2 = nnx.Linear(32, 1, rngs=rngs) def features(self, x, rngs= None): feature = nnx.relu(self.dense1(x)) return feature def loss(self, x_features, y_features): return jnp.sum((x_features - y_features)**2) def __call__(self, x, y): return self.loss(self.features(x), self.features(y)) model = Foo(rngs=nnx.Rngs(0)) capturing_model = nnx.capture(model, nnx.Intermediate, method_outputs=nnx.Intermediate) result, intms = capturing_model(x, y) jax.tree.map(lambda a: a.shape, intms) ``` This pattern should be considered the "sledge hammer" approach to capturing intermediates. As a debugging and inspection tool it is very useful, but using the other patterns described in this guide will give you more fine-grained control over what intermediates you want to extract. We can also combine the `method_output_type` argument with manual calls to sow to capture both layer outputs and computations mid-layer. ## Extracting gradients of intermediate values For debugging purposes, it can be useful to extract the gradients of intermediate values. This is a little tricky: jax doesn't have a stable mechanism for sowing information from the backward pass into to objects from the forward pass. Instead, we record the gradients of intermediate values using the `Module.perturb` method as follows: ```python class Model(nnx.Module): def __call__(self, x): x2 = self.perturb('grad_of_x', x) self.sow(nnx.Intermediate, 'activations', x2) return 3 * x2 model = Model() def train_step(model, x): _, perturbations = nnx.capture(model, nnx.Perturbation)(x) def loss(model, perturbations, x): return nnx.capture(model, nnx.Intermediate, init=perturbations)(x) (grads, perturb_grads), sowed = nnx.grad(loss, argnums=(0, 1), has_aux=True)(model, perturbations, x) return nnx.merge_state(perturb_grads, sowed) train_step(model, 1.0) ``` There are four steps: **Step One: Initialize *perturbations* of the model**. We do this with a call to `nnx.capture(model, nnx.Perturbation)`. Before the call, `capture` installs `__captures__` on the module — a tuple containing one empty `Perturbation` buffer (as described in "How `nnx.capture` Works" above). When `self.perturb` runs, it checks `__captures__` for a matching `Perturbation` Variable, initialises the slot to `zeros_like(value)`, and returns `zeros + x`. After the call, `__captures__` is removed and the filled buffer is returned as `perturbations`. ```python class Model(nnx.Module): def __call__(self, x): print("before perturb:", self.__captures__) x2 = self.perturb('grad_of_x', x) print("after perturb:", self.__captures__) self.sow(nnx.Intermediate, 'activations', x2) # sow is a no-op: Intermediate is not in __captures__, so it returns False silently return 3 * x2 model = Model() _, perturbations = nnx.capture(model, nnx.Perturbation)(1.0) print(perturbations) ``` There are only two differences between `sow` and `perturb`: - The `nnx.Variable` tag used for values written with `self.perturb` is `nnx.Perturbation` rather than `nnx.Intermediate`. - `perturb` returns the logged value. You must use this returned value rather than the original value for the gradient capturing machinery to work. The `var_types` argument to `capture` restricts which of the logged values we want to return. Because we only want the intermediates logged with `self.perturb` statements, we only capture `nnx.Perturbation` types. **Step Two: Run the model again, but add in these perturbations**. Call `capture` again with `init=perturbations`. `capture` first builds a mapping from module path to the Variables in `init`, then uses it to pre-populate `__captures__`. Now `__captures__` has *two* buffers: an empty `Intermediate` buffer (from `var_types`) and a `Perturbation` buffer pre-populated from `init`. `self.perturb` finds the pre-populated buffer and returns `x + perturbation`; `self.sow` writes into the `Intermediate` buffer as normal. ```python class Model(nnx.Module): def __call__(self, x): print("before perturb:", self.__captures__) x2 = self.perturb('grad_of_x', x) self.sow(nnx.Intermediate, 'activations', x2) print("after sow: ", self.__captures__) return 3 * x2 model = Model() _, interms = nnx.capture(model, nnx.Intermediate, init=perturbations)(1.0) ``` This changes the behavior of `x2 = self.perturb('name', x)` to essentially be `x2 = x + perturbations['name']`. The gradient of our output with respect to `x` will be the same as the gradient with respect to the perturbation, because JAX can differentiate through the addition with respect to the perturbation value stored in the capture dict. **Step Three: Take gradients**. Specifically, take the gradient of this second `capture` call with respect to the perturbation arguments. JAX traces through exactly the same `__captures__` setup as Step Two, but with abstract (traced) array values instead of concrete ones. This will give us the same values as the gradients with respect to the intermediate variables. If we want to track intermediate variables in the forward pass at the same time, we'll need to return the intermediate values output of the `capture` call as well, so we'll need to pass `has_aux=True` to `nnx.grad`. **Step Four: Combine intermediate states** Merge the `State` object we get from the perturbation gradients with the `State` object for forward intermediates with `nnx.merge_state(perturb_grads, sowed)`. At this point `__captures__` no longer exists on any module — it was cleaned up at the end of the `capture` call in Step Three. ## NNX Transforms and Capturing `nnx.capture` composes with NNX transforms such as `nnx.vmap`. The main thing to keep in mind is that perturbations must be initialized with a run that has the same batch structure as the training step that will consume them. Consider a model that calls both `sow` and `perturb`: ```python class Foo(nnx.Module): def __init__(self, dim): self.w = nnx.Param(jax.random.normal(jax.random.key(0), dim)) def __call__(self, x): x = self.perturb('grad_of_x', x) y = jnp.dot(x, self.w) self.sow(nnx.Intermediate, 'y', y) return y ``` The training step vmaps `loss_grad` over a batch of inputs and perturbations, while the model weights are shared across the batch (`in_axes=None`): ```python @nnx.jit def train_step(model, x): _, perturbations = init_perturbations(model, x) def loss_grad(model, perturbations, x): def loss(model, perturbations, x): loss, interms = nnx.capture(model, nnx.Intermediate, init=perturbations)(x) return loss, interms (grads, perturb_grads), sowed = nnx.grad(loss, argnums=(0, 1), has_aux=True)(model, perturbations, x) return grads, nnx.merge_state(perturb_grads, sowed) return nnx.vmap(loss_grad, in_axes=(None, 0, 0))(model, perturbations, x) ``` After every training step, we can sum the gradients and pass them to an `Optimizer` to adjust the model, as usual. But we can also look at the full batch of sown values and perturbations. Because `train_step` expects `perturbations` to have a leading batch axis (axis 0), the perturbation initialization run must also produce a batched `perturbations` state. We do this inside an `init_perturbations` method that splits the model and vmaps the run with `in_axes=(0, None, 0)` for `(intermediates, params, x)`. ```python @nnx.capture(nnx.Perturbation) def init_perturbations(model, x): graphdef, intms, params = nnx.split(model, nnx.Intermediate, nnx.Param) def forward(intms, params, x): return nnx.merge(graphdef, intms, params)(x) return nnx.vmap(forward, in_axes=(0, None, 0))(intms, params, x) ``` Putting it together: ```python model, x = Foo(4), jnp.ones((3, 4)) _, intermediates = train_step(model, x) jax.tree.map(lambda a: a.shape, intermediates) ``` The pattern generalises: whenever a transform introduces a new batch axis over which `capture` runs, initialize perturbations with a matching vmapped pre-run so that the `init=perturbations` argument inside the transform has the correct shape. ================================================ FILE: docs_nnx/guides/filters_guide.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "95b08e64", "metadata": {}, "source": [ "# Filters\n", "\n", "Flax NNX uses [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) extensively as a way to create [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) groups in APIs, such as [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.state()`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.state), and many of the [Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html).\n", "\n", "In this guide you will learn how to:\n", "\n", "* Use [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) to group Flax NNX variables and states into subgroups;\n", "* Understand relationships between types, such as [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat), and [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html);\n", "* Express your `Filter`s flexibly with [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) language.\n", "\n", "In the following example [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) and [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat) are used as `Filter`s to split the model into two groups: one with the parameters and the other with the batch statistics:" ] }, { "cell_type": "code", "execution_count": 1, "id": "45485345", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "params = State({\n", " 'a': Param(\n", " value=0\n", " )\n", "})\n", "batch_stats = State({\n", " 'b': BatchStat(\n", " value=True\n", " )\n", "})\n" ] } ], "source": [ "from flax import nnx\n", "\n", "class Foo(nnx.Module):\n", " def __init__(self):\n", " self.a = nnx.Param(0)\n", " self.b = nnx.BatchStat(True)\n", "\n", "foo = Foo()\n", "\n", "graphdef, params, batch_stats = nnx.split(foo, nnx.Param, nnx.BatchStat)\n", "\n", "print(f'{params = }')\n", "print(f'{batch_stats = }')" ] }, { "cell_type": "markdown", "id": "8f77e99a", "metadata": {}, "source": [ "Let's dive deeper into `Filter`s." ] }, { "cell_type": "markdown", "id": "a0413d64", "metadata": {}, "source": [ "## The `Filter` Protocol\n", "\n", "In general, Flax `Filter`s are predicate functions of the form:\n", "\n", "```python\n", "\n", "(path: tuple[Key, ...], value: Any) -> bool\n", "\n", "```\n", "\n", "where:\n", "\n", "- `Key` is a hashable and comparable type;\n", "- `path` is a tuple of `Key`s representing the path to the value in a nested structure; and\n", "- `value` is the value at the path.\n", "\n", "The function returns `True` if the value should be included in the group, and `False` otherwise.\n", "\n", "Types are not functions of this form. They are treated as `Filter`s because, as you will learn in the next section, types and some other literals are converted to _predicates_. For example, [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) is roughly converted to a predicate like this:" ] }, { "cell_type": "code", "execution_count": 2, "id": "30f4c868", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "is_param((), nnx.Param(0)) = True\n" ] } ], "source": [ "def is_param(path, value) -> bool:\n", " return isinstance(value, nnx.Param)\n", "\n", "print(f'{is_param((), nnx.Param(0)) = }')" ] }, { "cell_type": "markdown", "id": "a8a2641e", "metadata": {}, "source": [ "Such function matches any value that is an instance of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param). Internally Flax NNX uses `OfType` which defines a callable of this form for a given type:" ] }, { "cell_type": "code", "execution_count": 3, "id": "b3095221", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "is_param((), nnx.Param(0)) = True\n" ] } ], "source": [ "is_param = nnx.OfType(nnx.Param)\n", "\n", "print(f'{is_param((), nnx.Param(0)) = }')" ] }, { "cell_type": "markdown", "id": "87c06e39", "metadata": {}, "source": [ "## The `Filter` DSL\n", "\n", "Flax NNX exposes a small domain specific language ([DSL](https://en.wikipedia.org/wiki/Domain-specific_language)), formalized as the [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) type. This means users don't have to create functions like in the previous section.\n", "\n", "Here is a list of all the callable `Filter`s included in Flax NNX, and their corresponding DSL literals (when available):\n", "\n", "\n", "| Literal | Callable | Description |\n", "|--------|----------------------|-------------|\n", "| `...` or `True` | `Everything()` | Matches all values |\n", "| `None` or `False` | `Nothing()` | Matches no values |\n", "| `type` | `OfType(type)` | Matches values that are instances of `type` or have a `type` attribute that is an instance of `type` |\n", "| | `PathContains(key)` | Matches values that have an associated `path` that contains the given `key` |\n", "| `'{filter}'` str | `WithTag('{filter}')` | Matches values that have string `tag` attribute equal to `'{filter}'`. Used by `RngKey` and `RngCount`. |\n", "| `(*filters)` tuple or `[*filters]` list | `Any(*filters)` | Matches values that match any of the inner `filters` |\n", "| | `All(*filters)` | Matches values that match all of the inner `filters` |\n", "| | `Not(filter)` | Matches values that do not match the inner `filter` |\n", "\n", "\n", "Let's check out the DSL in action by using [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) as an example. Consider the following:\n", "\n", "1) You want to vectorize all parameters;\n", "2) Apply `'dropout'` `Rng(Keys|Counts)` on the `0`th axis; and\n", "3) Broadcast the rest.\n", "\n", "To do this, you can use the following `Filter`s to define a `nnx.StateAxes` object that you can pass to `nnx.vmap`'s `in_axes` to specify how the `model`'s various sub-states should be vectorized:" ] }, { "cell_type": "code", "execution_count": 4, "id": "d38b7694", "metadata": {}, "outputs": [], "source": [ "state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None})\n", "\n", "@nnx.vmap(in_axes=(state_axes, 0))\n", "def forward(model, x):\n", " ..." ] }, { "cell_type": "markdown", "id": "bd60f0e1", "metadata": {}, "source": [ "Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...` expands to `Everything()`.\n", "\n", "If you wish to manually convert literal into a predicate, you can use [`nnx.filterlib.to_predicate`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html#flax.nnx.filterlib.to_predicate):" ] }, { "cell_type": "code", "execution_count": 5, "id": "7e065fa9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "is_param = OfType()\n", "everything = Everything()\n", "nothing = Nothing()\n", "params_or_dropout = Any(OfType(), WithTag('dropout'))\n" ] } ], "source": [ "is_param = nnx.filterlib.to_predicate(nnx.Param)\n", "everything = nnx.filterlib.to_predicate(...)\n", "nothing = nnx.filterlib.to_predicate(False)\n", "params_or_dropout = nnx.filterlib.to_predicate((nnx.Param, 'dropout'))\n", "\n", "print(f'{is_param = }')\n", "print(f'{everything = }')\n", "print(f'{nothing = }')\n", "print(f'{params_or_dropout = }')" ] }, { "cell_type": "markdown", "id": "db9b4cf3", "metadata": {}, "source": [ "## Grouping `State`s\n", "\n", "With the knowledge of `Filter`s from previous sections at hand, let's learn how to roughly implement [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split). Here are the key ideas:\n", "\n", "* Use `nnx.graph.flatten` to get the [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) and [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) representation of the node.\n", "* Convert all the `Filter`s to predicates.\n", "* Use `State.flat_state` to get the flat representation of the state.\n", "* Traverse all the `(path, value)` pairs in the flat state and group them according to the predicates.\n", "* Use `State.from_flat_state` to convert the flat states to nested [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s." ] }, { "cell_type": "code", "execution_count": 6, "id": "068208fc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "params = State({\n", " 'a': Param(\n", " value=0\n", " )\n", "})\n", "batch_stats = State({\n", " 'b': BatchStat(\n", " value=True\n", " )\n", "})\n" ] } ], "source": [ "from typing import Any\n", "KeyPath = tuple[nnx.graph.Key, ...]\n", "\n", "def split(node, *filters):\n", " graphdef, state = nnx.graph.flatten(node)\n", " predicates = [nnx.filterlib.to_predicate(f) for f in filters]\n", " flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]\n", "\n", " for path, value in state:\n", " for i, predicate in enumerate(predicates):\n", " if predicate(path, value):\n", " flat_states[i][path] = value\n", " break\n", " else:\n", " raise ValueError(f'No filter matched {path = } {value = }')\n", "\n", " states: tuple[nnx.GraphState, ...] = tuple(\n", " nnx.State.from_flat_path(flat_state) for flat_state in flat_states\n", " )\n", " return graphdef, *states\n", "\n", "# Let's test it.\n", "foo = Foo()\n", "\n", "graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat)\n", "\n", "print(f'{params = }')\n", "print(f'{batch_stats = }')" ] }, { "cell_type": "markdown", "id": "7b3aeac8", "metadata": {}, "source": [ "**Note:*** It's very important to know that **filtering is order-dependent**. The first `Filter` that matches a value will keep it, and therefore you should place more specific `Filter`s before more general `Filter`s.\n", "\n", "For example, as demonstrated below, if you:\n", "\n", "1) Create a `SpecialParam` type that is a subclass of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param), and a `Bar` object (subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)) that contains both types of parameters; and\n", "2) Try to split the [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s before the `SpecialParam`s\n", "\n", "then all the values will be placed in the [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) group, and the `SpecialParam` group will be empty because all `SpecialParam`s are also [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s:" ] }, { "cell_type": "code", "execution_count": 7, "id": "014da4d4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "params = State({\n", " 'a': Param(\n", " value=0\n", " ),\n", " 'b': SpecialParam(\n", " value=0\n", " )\n", "})\n", "special_params = State({})\n" ] } ], "source": [ "class SpecialParam(nnx.Param):\n", " pass\n", "\n", "class Bar(nnx.Module):\n", " def __init__(self):\n", " self.a = nnx.Param(0)\n", " self.b = SpecialParam(0)\n", "\n", "bar = Bar()\n", "\n", "graphdef, params, special_params = split(bar, nnx.Param, SpecialParam) # wrong!\n", "print(f'{params = }')\n", "print(f'{special_params = }')" ] }, { "cell_type": "markdown", "id": "a9f0b7b8", "metadata": {}, "source": [ "And reversing the order will ensure that the `SpecialParam` are captured first:" ] }, { "cell_type": "code", "execution_count": 8, "id": "a2ebf5b2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "params = State({\n", " 'a': Param(\n", " value=0\n", " )\n", "})\n", "special_params = State({\n", " 'b': SpecialParam(\n", " value=0\n", " )\n", "})\n" ] } ], "source": [ "graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!\n", "print(f'{params = }')\n", "print(f'{special_params = }')" ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: docs_nnx/guides/filters_guide.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # Filters Flax NNX uses [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) extensively as a way to create [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) groups in APIs, such as [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`nnx.state()`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.state), and many of the [Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html). In this guide you will learn how to: * Use [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) to group Flax NNX variables and states into subgroups; * Understand relationships between types, such as [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) or [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat), and [`Filter`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html); * Express your `Filter`s flexibly with [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) language. In the following example [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) and [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat) are used as `Filter`s to split the model into two groups: one with the parameters and the other with the batch statistics: ```{code-cell} ipython3 from flax import nnx class Foo(nnx.Module): def __init__(self): self.a = nnx.Param(0) self.b = nnx.BatchStat(True) foo = Foo() graphdef, params, batch_stats = nnx.split(foo, nnx.Param, nnx.BatchStat) print(f'{params = }') print(f'{batch_stats = }') ``` Let's dive deeper into `Filter`s. +++ ## The `Filter` Protocol In general, Flax `Filter`s are predicate functions of the form: ```python (path: tuple[Key, ...], value: Any) -> bool ``` where: - `Key` is a hashable and comparable type; - `path` is a tuple of `Key`s representing the path to the value in a nested structure; and - `value` is the value at the path. The function returns `True` if the value should be included in the group, and `False` otherwise. Types are not functions of this form. They are treated as `Filter`s because, as you will learn in the next section, types and some other literals are converted to _predicates_. For example, [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) is roughly converted to a predicate like this: ```{code-cell} ipython3 def is_param(path, value) -> bool: return isinstance(value, nnx.Param) print(f'{is_param((), nnx.Param(0)) = }') ``` Such function matches any value that is an instance of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param). Internally Flax NNX uses `OfType` which defines a callable of this form for a given type: ```{code-cell} ipython3 is_param = nnx.OfType(nnx.Param) print(f'{is_param((), nnx.Param(0)) = }') ``` ## The `Filter` DSL Flax NNX exposes a small domain specific language ([DSL](https://en.wikipedia.org/wiki/Domain-specific_language)), formalized as the [`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html) type. This means users don't have to create functions like in the previous section. Here is a list of all the callable `Filter`s included in Flax NNX, and their corresponding DSL literals (when available): | Literal | Callable | Description | |--------|----------------------|-------------| | `...` or `True` | `Everything()` | Matches all values | | `None` or `False` | `Nothing()` | Matches no values | | `type` | `OfType(type)` | Matches values that are instances of `type` or have a `type` attribute that is an instance of `type` | | | `PathContains(key)` | Matches values that have an associated `path` that contains the given `key` | | `'{filter}'` str | `WithTag('{filter}')` | Matches values that have string `tag` attribute equal to `'{filter}'`. Used by `RngKey` and `RngCount`. | | `(*filters)` tuple or `[*filters]` list | `Any(*filters)` | Matches values that match any of the inner `filters` | | | `All(*filters)` | Matches values that match all of the inner `filters` | | | `Not(filter)` | Matches values that do not match the inner `filter` | Let's check out the DSL in action by using [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) as an example. Consider the following: 1) You want to vectorize all parameters; 2) Apply `'dropout'` `Rng(Keys|Counts)` on the `0`th axis; and 3) Broadcast the rest. To do this, you can use the following `Filter`s to define a `nnx.StateAxes` object that you can pass to `nnx.vmap`'s `in_axes` to specify how the `model`'s various sub-states should be vectorized: ```{code-cell} ipython3 state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None}) @nnx.vmap(in_axes=(state_axes, 0)) def forward(model, x): ... ``` Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))` and `...` expands to `Everything()`. If you wish to manually convert literal into a predicate, you can use [`nnx.filterlib.to_predicate`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/filterlib.html#flax.nnx.filterlib.to_predicate): ```{code-cell} ipython3 is_param = nnx.filterlib.to_predicate(nnx.Param) everything = nnx.filterlib.to_predicate(...) nothing = nnx.filterlib.to_predicate(False) params_or_dropout = nnx.filterlib.to_predicate((nnx.Param, 'dropout')) print(f'{is_param = }') print(f'{everything = }') print(f'{nothing = }') print(f'{params_or_dropout = }') ``` ## Grouping `State`s With the knowledge of `Filter`s from previous sections at hand, let's learn how to roughly implement [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split). Here are the key ideas: * Use `nnx.graph.flatten` to get the [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) and [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) representation of the node. * Convert all the `Filter`s to predicates. * Use `State.flat_state` to get the flat representation of the state. * Traverse all the `(path, value)` pairs in the flat state and group them according to the predicates. * Use `State.from_flat_state` to convert the flat states to nested [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s. ```{code-cell} ipython3 from typing import Any KeyPath = tuple[nnx.graph.Key, ...] def split(node, *filters): graphdef, state = nnx.graph.flatten(node) predicates = [nnx.filterlib.to_predicate(f) for f in filters] flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates] for path, value in state: for i, predicate in enumerate(predicates): if predicate(path, value): flat_states[i][path] = value break else: raise ValueError(f'No filter matched {path = } {value = }') states: tuple[nnx.GraphState, ...] = tuple( nnx.State.from_flat_path(flat_state) for flat_state in flat_states ) return graphdef, *states # Let's test it. foo = Foo() graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat) print(f'{params = }') print(f'{batch_stats = }') ``` **Note:*** It's very important to know that **filtering is order-dependent**. The first `Filter` that matches a value will keep it, and therefore you should place more specific `Filter`s before more general `Filter`s. For example, as demonstrated below, if you: 1) Create a `SpecialParam` type that is a subclass of [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param), and a `Bar` object (subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)) that contains both types of parameters; and 2) Try to split the [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s before the `SpecialParam`s then all the values will be placed in the [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) group, and the `SpecialParam` group will be empty because all `SpecialParam`s are also [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s: ```{code-cell} ipython3 class SpecialParam(nnx.Param): pass class Bar(nnx.Module): def __init__(self): self.a = nnx.Param(0) self.b = SpecialParam(0) bar = Bar() graphdef, params, special_params = split(bar, nnx.Param, SpecialParam) # wrong! print(f'{params = }') print(f'{special_params = }') ``` And reversing the order will ensure that the `SpecialParam` are captured first: ```{code-cell} ipython3 graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct! print(f'{params = }') print(f'{special_params = }') ``` ================================================ FILE: docs_nnx/guides/flax_gspmd.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Scale up on multiple devices\n", "\n", "This guide demonstrates how to scale up a Flax NNX model on multiple accelerators (GPUs or Google TPUs) using JAX's parallel programming APIs.\n", "\n", "[Introduction to Parallel Programming](https://docs.jax.dev/en/latest/sharded-computation.html) is a fantastic guide to learn about the distributed programming essentials of JAX. It describes three parallelism APIs - automatic, explicit and manual - for different levels of control.\n", "\n", "This guide will primarily cover the automatic scenario, which use the [`jax.jit`](https://jax.readthedocs.io/en/latest/jit-compilation.html) to compile your single-device code as multi-device. You will use [`flax.nnx.spmd`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) APIs to annotate your model variables with how it should be sharded.\n", "\n", "If you want to follow explicit sharding style, follow [JAX Explicit Sharding](https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html) guide and use JAX's relevant APIs. No API on Flax side is needed." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "You have 8 “fake” JAX devices now: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]\n" ] } ], "source": [ "from functools import partial\n", "\n", "import jax\n", "from jax import numpy as jnp\n", "from jax.sharding import PartitionSpec as P, NamedSharding, AxisType\n", "import optax\n", "import flax\n", "from flax import nnx\n", "\n", "# Ignore this if you are already running on a TPU or GPU\n", "if not jax._src.xla_bridge.backends_are_initialized():\n", " jax.config.update('jax_num_cpu_devices', 8)\n", "print(f'You have 8 “fake” JAX devices now: {jax.devices()}')" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Set up a `2x4` device mesh as the [JAX data sharding tutorial](https://docs.jax.dev/en/latest/sharded-computation.html#key-concept-data-sharding) instructs.\n", "\n", "In this guide we use a standard FSDP layout and shard our devices on two axes - `data` and `model`, for doing batch data parallelism and tensor parallelism." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Create an auto-mode mesh of two dimensions and annotate each axis with a name.\n", "auto_mesh = jax.make_mesh((2, 4), ('data', 'model'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Compatibility Note: This guide covers the [eager sharding feature](https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html) that greatly simplifies creating sharded model. If your project already used Flax GSPMD API on version `flax<0.12`, you might have turned the feature off to keep your code working. Users can toggle this feature using the `nnx.use_eager_sharding` function." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "nnx.use_eager_sharding(True)\n", "assert nnx.using_eager_sharding()" ] }, { "cell_type": "markdown", "id": "c24144d8", "metadata": {}, "source": [ "The `nnx.use_eager_sharding` function can also be used as a context manager to toggle the eager sharding feature within a specific scope." ] }, { "cell_type": "code", "execution_count": null, "id": "2d849e2e", "metadata": {}, "outputs": [], "source": [ "with nnx.use_eager_sharding(False):\n", " assert not nnx.using_eager_sharding()" ] }, { "cell_type": "markdown", "id": "c9f808ec", "metadata": {}, "source": [ "You can also enable eager sharding on a per-variable basis by passing `eager_sharding=False` during variable initialization. The mesh can also be passed this way." ] }, { "cell_type": "code", "execution_count": null, "id": "67bbd440", "metadata": {}, "outputs": [], "source": [ "nnx.Param(jnp.ones(4,4), out_sharding=(None, 'model'), eager_sharding=True, mesh=auto_mesh)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Shard a single-array model\n", "\n", "Let's begin by sharding the simplest component possible - a Flax variable.\n", "\n", "When you define a Flax variable, you can pass in a metadata field called `out_sharding`, to specify how the underlying JAX array should be sharded. This field should be a tuple of names, each of which refer to how an axis of the array should be sharded.\n", "\n", "**You must have an existing device mesh** and create a sharding-annotated `nnx.Variable` within its scope. This allows the result variable to be sharded accordingly on those devices. The device mesh can be your actual accelerator mesh, or a dummy fake CPU mesh like in this notebook." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PartitionSpec(None, 'model')\n" ] }, { "data": { "text/html": [ "
                                                \n",
       "                                                \n",
       "                                                \n",
       "                                                \n",
       "                                                \n",
       "  CPU 0,4     CPU 1,5     CPU 2,6     CPU 3,7   \n",
       "                                                \n",
       "                                                \n",
       "                                                \n",
       "                                                \n",
       "                                                \n",
       "
\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0,4\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 1,5\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mCPU 2,6\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mCPU 3,7\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "rngs = nnx.Rngs(0)\n", "\n", "with jax.set_mesh(auto_mesh):\n", " w = nnx.Param(\n", " rngs.lecun_normal()((4, 8)),\n", " out_sharding=(None, 'model')\n", " )\n", " print(w.sharding.spec)\n", " jax.debug.visualize_array_sharding(w) # already sharded!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Initialize with style\n", "\n", "When using existing modules, you can apply [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) on initializers to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight.\n", "\n", "Also, you should use `jax.jit` for the whole initialization for maximum performance. This is because without `jax.jit`, a single-device variable must be created first before we apply sharding constraints and then make it sharded, which is wasteful. `jax.jit` will automatically optimize this out." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def init_sharded_linear(key):\n", " init_fn = nnx.nn.linear.default_kernel_init\n", " # Shard your parameter along `model` dimension, as in model/tensor parallelism\n", " return nnx.Linear(4, 8, use_bias=False, rngs=nnx.Rngs(key),\n", " kernel_init=nnx.with_partitioning(init_fn, (None, 'model')))\n", "\n", "with jax.set_mesh(auto_mesh):\n", " key= rngs()\n", " linear = init_sharded_linear(key)\n", " assert linear.kernel.sharding.spec == P(None, 'model') # already sharded!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Run the model\n", "\n", "If you also shard your input correctly, JAX would be able to carry out the most natural and optimized computation and produce your output as sharded.\n", "\n", "You should still make sure to `jax.jit` for maximum performance, and also to explicitly control how each array is sharded when you want to. We will give an example of that control in the next section.\n", "\n", "> Note: You need to `jax.jit` a pure function that takes the model as an argument, instead of jitting the callable model directly." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PartitionSpec('data', 'model')\n" ] }, { "data": { "text/html": [ "
                                    \n",
       "                                    \n",
       "  CPU 0    CPU 1    CPU 2    CPU 3  \n",
       "                                    \n",
       "                                    \n",
       "                                    \n",
       "                                    \n",
       "                                    \n",
       "  CPU 4    CPU 5    CPU 6    CPU 7  \n",
       "                                    \n",
       "                                    \n",
       "                                    \n",
       "
\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mCPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mCPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mCPU 4\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mCPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mCPU 6\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mCPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# For simple computations, you can get correctly-sharded output without jitting\n", "# In this case, ('data', None) @ (None, 'model') = ('data', 'model')\n", "with jax.set_mesh(auto_mesh):\n", " # Create your input data, sharded along `data` dimension, as in data parallelism\n", " x = jax.device_put(jnp.ones((16, 4)), P('data', None))\n", "\n", " # Run the model forward function, jitted\n", " y = jax.jit(lambda m, x: m(x))(linear, x)\n", " print(y.sharding.spec) # sharded: ('data', 'model')\n", " jax.debug.visualize_array_sharding(y)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Shard a wholesome model\n", "\n", "Now we construct a more wholesome model to show a few advanced tricks. Check out this simple `DotReluDot` module that does two matmuls, and the `MultiDotReluDot` module that creates an arbitrary stack of `DotReluDot` sublayers.\n", "\n", "Make note of the following:\n", "\n", "* **Additional axis annotation**: Transforms like `vmap` and `scan` will add additional dimensions to the JAX arrays. Unfortunately, in auto sharding mode you will need to use `nnx.vmap` and `nnx.scan` instead of raw JAX transforms, so that both JAX and Flax knows how to shard this dimension. You won't need this in [explicit sharding mode](#explicit-sharding).\n", "\n", "* [`jax.lax.with_sharding_constraint`](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#constraining-shardings-of-intermediates-in-jitted-code): They can help you to enforce specific shardings on intermediate activations. Only works under an auto mode mesh context." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "class DotReluDot(nnx.Module):\n", " def __init__(self, depth: int, rngs: nnx.Rngs):\n", " init_fn = nnx.initializers.lecun_normal()\n", " self.dot1 = nnx.Linear(\n", " depth, depth,\n", " kernel_init=nnx.with_partitioning(init_fn, (None, 'model')),\n", " use_bias=False, # or use `bias_init` to give it annotation too\n", " rngs=rngs)\n", " self.w2 = nnx.Param(\n", " init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation\n", " sharding=('model', None),\n", " )\n", "\n", " def __call__(self, x: jax.Array):\n", " y = self.dot1(x)\n", " y = jax.nn.relu(y)\n", " y = jax.lax.with_sharding_constraint(y, P('data', 'model'))\n", " z = jnp.dot(y, self.w2[...])\n", " return z\n", "\n", "class MultiDotReluDot(nnx.Module):\n", " def __init__(self, depth: int, num_layers: int, rngs: nnx.Rngs):\n", " # Annotate the additional axis with sharding=None, meaning it will be\n", " # replicated across all devices.\n", " @nnx.vmap(transform_metadata={nnx.PARTITION_NAME: None})\n", " def create_sublayers(r):\n", " return DotReluDot(depth, r)\n", " self.layers = create_sublayers(rngs.fork(split=num_layers))\n", "\n", " def __call__(self, x):\n", " def scan_over_layers(x, layer):\n", " return layer(x), None\n", " x, _ = jax.lax.scan(scan_over_layers, x, self.layers)\n", " return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now a sample training loop, using `jax.jit`." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.251457\n", "0.8495563\n", "0.6590716\n", "0.5399748\n", "0.39150265\n" ] } ], "source": [ "@jax.jit\n", "def train_step(model, optimizer, x, y):\n", " def loss_fn(model: DotReluDot):\n", " y_pred = model(x)\n", " return jnp.mean((y_pred - y) ** 2)\n", "\n", " loss, grads = jax.value_and_grad(loss_fn)(model)\n", " optimizer.update(model, grads)\n", " return model, loss\n", "\n", "\n", "with jax.set_mesh(auto_mesh):\n", " # Training data\n", " input = jax.device_put(rngs.normal((8, 1024)), P('data', None))\n", " label = jax.device_put(rngs.normal((8, 1024)), P('data', None))\n", " # Model and optimizer\n", " model = MultiDotReluDot(1024, 2, rngs=nnx.Rngs(0))\n", " optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", "\n", " # The loop\n", " for i in range(5):\n", " model, loss = train_step(model, optimizer, input, label)\n", " print(loss) # Model (over-)fitting to the labels quickly." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Profiling\n", "\n", "If you are using a Google TPU pod or a pod slice, you can create a custom `block_all()` utility function, as defined below, to measure the performance:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "13 ms ± 588 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ "%%timeit\n", "\n", "def block_all(xs):\n", " jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)\n", " return xs\n", "\n", "with jax.set_mesh(auto_mesh):\n", " new_state = block_all(train_step(model, optimizer, input, label))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load a sharded model from a checkpoint\n", "\n", "Now you learned how to initialize a sharded model without OOM, but what about saving and loading it from a checkpoint on disk? JAX checkpointing libraries, such as [Orbax](https://orbax.readthedocs.io/en/latest/), support loading a model distributedly if a sharding pytree is provided. Below is an example that uses Orbax's `StandardCheckpointer` API.\n", "\n", "Make sure you save a model's state, especially if your model shares some variables across modules. Given a You can generate an identical abstract pytree with shardings using Flax’s `nnx.get_abstract_model`." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PartitionSpec(None, None, 'model')\n", "(2, 1024, 1024)\n" ] } ], "source": [ "import orbax.checkpoint as ocp\n", "\n", "# Save the sharded state.\n", "sharded_state = nnx.state(model)\n", "path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')\n", "checkpointer = ocp.StandardCheckpointer()\n", "checkpointer.save(path / 'checkpoint_name', sharded_state)\n", "\n", "# Load a sharded state from the checkpoint.\n", "graphdef, abs_state = nnx.get_abstract_model(\n", " lambda: MultiDotReluDot(1024, 2, rngs=nnx.Rngs(0)), auto_mesh)\n", "restored_state = checkpointer.restore(path / 'checkpoint_name',\n", " target=abs_state)\n", "restored_model = nnx.merge(graphdef, abs_state)\n", "print(restored_model.layers.dot1.kernel.sharding.spec)\n", "print(restored_model.layers.dot1.kernel.shape)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Logical axis annotation\n", "\n", "JAX's [automatic](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) [SPMD](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you have the option to annotate with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`), as long as you provide a mapping from your alias to the device mesh axes.\n", "\n", "You can provide the mapping along with the annotation as another metadata of the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), or overwrite it at top-level. Check out the `LogicalDotReluDot` example below." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# The mapping from alias annotation to the device mesh.\n", "sharding_rules = (('batch', 'data'), ('hidden', 'model'), ('embed', None))\n", "\n", "class LogicalDotReluDot(nnx.Module):\n", " def __init__(self, depth: int, rngs: nnx.Rngs):\n", " init_fn = nnx.initializers.lecun_normal()\n", " self.dot1 = nnx.Linear(\n", " depth, depth,\n", " kernel_init=nnx.with_partitioning(init_fn, ('embed', 'hidden')),\n", " use_bias=False, # or use `bias_init` to give it annotation too\n", " rngs=rngs)\n", " self.w2 = nnx.Param(\n", " init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation\n", " sharding=('hidden', 'embed'),\n", " )\n", "\n", " def __call__(self, x: jax.Array):\n", " y = self.dot1(x)\n", " y = jax.nn.relu(y)\n", " # Unfortunately the logical aliasing doesn't work on lower-level JAX calls.\n", " y = jax.lax.with_sharding_constraint(y, P('data', None))\n", " z = jnp.dot(y, self.w2[...])\n", " return z\n", "\n", "class LogicalMultiDotReluDot(nnx.Module):\n", " def __init__(self, depth: int, num_layers: int, rngs: nnx.Rngs):\n", " @nnx.vmap(transform_metadata={nnx.PARTITION_NAME: None})\n", " def create_sublayers(r):\n", " return LogicalDotReluDot(depth, r)\n", " self.layers = create_sublayers(rngs.fork(split=num_layers))\n", "\n", " def __call__(self, x):\n", " def scan_over_layers(x, layer):\n", " return layer(x), None\n", " x, _ = jax.lax.scan(scan_over_layers, x, self.layers)\n", " return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you didn't provide all `sharding_rule` annotations in the model definition, you can apply them at top level by put them into the context via `nnx.logical_axis_rules`." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "with jax.set_mesh(auto_mesh), nnx.logical_axis_rules(sharding_rules):\n", " # Model and optimizer\n", " logical_model = LogicalMultiDotReluDot(1024, 2, rngs=nnx.Rngs(0))\n", " logical_output = logical_model(input)\n", "\n", "# Check out their equivalency with some easier-to-read sharding descriptions.\n", "assert logical_model.layers.dot1.kernel.sharding.is_equivalent_to(\n", " NamedSharding(auto_mesh, P(None, None, 'model')), ndim=3\n", ")\n", "assert logical_model.layers.w2.sharding.is_equivalent_to(\n", " NamedSharding(auto_mesh, P(None, 'model', None)), ndim=3\n", ")\n", "assert logical_output.sharding.is_equivalent_to(\n", " NamedSharding(auto_mesh, P('data', None)), ndim=2\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### When to use device axis / logical axis\n", "\n", "Choosing when to use a device or logical axis depends on how much you want to control the partitioning of your model:\n", "\n", "* **Device mesh axis**:\n", "\n", " * For a simpler model, this can save you a few extra lines of code of converting the logical naming back to the device naming.\n", "\n", " * Shardings of intermediate *activation* values can only be done via [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) and device mesh axis. Therefore, if you want super fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing.\n", "\n", "* **Logical naming**: This is helpful if you want to experiment around and find the most optimal partition layout for your *model weights*." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Explicit sharding\n", "\n", "[Explicit sharding](https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html), also called \"sharding-in-types\", is a new JAX sharding feature that allows every sharding of every array to be deterministic and explicit. Instead of letting XLA compiler figure out the shardings, you as user would explicitly state the shardings via JAX APIs.\n", "\n", "For education purposes, we provide a simple Flax model example using explicit sharding. Note how you specify shardings for this model:\n", "\n", "* Parameters: `out_sharding` argument passed into JAX initializers.\n", "\n", "* Ambigious computations like `jnp.dot`: provide `out_sharding` argument to specify the output sharding.\n", "\n", "* Additional dimension from transforms: use `jax.vmap`'s argument `spmd_axis_name`, instead of Flax lifted transforms." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PartitionSpec(None, None, 'model')\n", "PartitionSpec(None, 'model', None)\n" ] } ], "source": [ "# Explicit axis mesh\n", "explicit_mesh = jax.make_mesh((2, 4), ('data', 'model'),\n", " axis_types=(AxisType.Explicit, AxisType.Explicit))\n", "\n", "class ExplicitDotReluDot(nnx.Module):\n", " def __init__(self, depth: int, rngs: nnx.Rngs):\n", " init_fn = nnx.initializers.lecun_normal()\n", " self.dot1 = nnx.Linear(\n", " depth, depth,\n", " kernel_init=partial(init_fn, out_sharding=P(None, 'model')),\n", " use_bias=False,\n", " rngs=rngs)\n", " self.w2 = nnx.Param(\n", " init_fn(rngs.params(), (depth, depth), out_sharding=P('model', None)),\n", " )\n", " self.b2 = nnx.Param(jnp.zeros((depth, ), out_sharding=P(None,)))\n", "\n", " def __call__(self, x: jax.Array):\n", " y = self.dot1(x)\n", " y = jax.nn.relu(y)\n", " z = jnp.dot(y, self.w2[...], out_sharding=P('data', None))\n", " z = z + self.b2\n", " return z\n", "\n", "\n", "class ExplicitMultiDotReluDot(nnx.Module):\n", " def __init__(self, depth: int, num_layers: int, rngs: nnx.Rngs):\n", " # Annotate the additional axis with sharding=None, meaning it will be\n", " # replicated across all devices.\n", " @partial(jax.vmap, spmd_axis_name=None)\n", " def create_sublayers(r):\n", " return ExplicitDotReluDot(depth, r)\n", " self.layers = create_sublayers(rngs.fork(split=num_layers))\n", "\n", " def __call__(self, x):\n", " def scan_over_layers(x, layer):\n", " return layer(x), None\n", " x, _ = jax.lax.scan(scan_over_layers, x, self.layers)\n", " return x\n", "\n", "\n", "with jax.set_mesh(explicit_mesh):\n", " model = ExplicitMultiDotReluDot(1024, 2, rngs=nnx.Rngs(0))\n", " x = jax.device_put(rngs.normal((8, 1024)),\n", " NamedSharding(explicit_mesh, P('data', None)))\n", " y = model(x)\n", "\n", "print(model.layers.dot1.kernel.sharding.spec)\n", "print(model.layers.w2.sharding.spec)\n", "assert x.sharding.is_equivalent_to(y.sharding, ndim=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One thing easier in explicit mode is that you can obtain the abstract array tree with shardings via `jax.eval_shape`, instead of calling `nnx.get_abstract_sharding`. This is not possible in auto mode." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PartitionSpec(None, None, 'model')\n", "PartitionSpec(None, 'model', None)\n" ] } ], "source": [ "# Get the sharding tree to load checkpoint with\n", "with jax.set_mesh(explicit_mesh):\n", " abs_model = jax.eval_shape(\n", " lambda: ExplicitMultiDotReluDot(1024, 2, rngs=nnx.Rngs(0)))\n", " print(abs_model.layers.dot1.kernel.sharding.spec)\n", " print(abs_model.layers.w2.sharding.spec)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Further readings\n", "\n", "JAX has abundant documentation on scaled computing.\n", "\n", "- [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html): A 101 level tutorial covering the basics of automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), semi-automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html), and manual sharding with [`shard_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html#jax.experimental.shard_map.shard_map).\n", "- [JAX in multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html).\n", "- [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html): A more detailed tutorial about parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) and [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html)." ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: docs_nnx/guides/flax_gspmd.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # Scale up on multiple devices This guide demonstrates how to scale up a Flax NNX model on multiple accelerators (GPUs or Google TPUs) using JAX's parallel programming APIs. [Introduction to Parallel Programming](https://docs.jax.dev/en/latest/sharded-computation.html) is a fantastic guide to learn about the distributed programming essentials of JAX. It describes three parallelism APIs - automatic, explicit and manual - for different levels of control. This guide will primarily cover the automatic scenario, which use the [`jax.jit`](https://jax.readthedocs.io/en/latest/jit-compilation.html) to compile your single-device code as multi-device. You will use [`flax.nnx.spmd`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) APIs to annotate your model variables with how it should be sharded. If you want to follow explicit sharding style, follow [JAX Explicit Sharding](https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html) guide and use JAX's relevant APIs. No API on Flax side is needed. +++ ### Setup ```{code-cell} ipython3 from functools import partial import jax from jax import numpy as jnp from jax.sharding import PartitionSpec as P, NamedSharding, AxisType import optax import flax from flax import nnx # Ignore this if you are already running on a TPU or GPU if not jax._src.xla_bridge.backends_are_initialized(): jax.config.update('jax_num_cpu_devices', 8) print(f'You have 8 “fake” JAX devices now: {jax.devices()}') ``` Set up a `2x4` device mesh as the [JAX data sharding tutorial](https://docs.jax.dev/en/latest/sharded-computation.html#key-concept-data-sharding) instructs. In this guide we use a standard FSDP layout and shard our devices on two axes - `data` and `model`, for doing batch data parallelism and tensor parallelism. ```{code-cell} ipython3 # Create an auto-mode mesh of two dimensions and annotate each axis with a name. auto_mesh = jax.make_mesh((2, 4), ('data', 'model')) ``` > Compatibility Note: This guide covers the [eager sharding feature](https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html) that greatly simplifies creating sharded model. If your project already used Flax GSPMD API on version `flax<0.12`, you might have turned the feature off to keep your code working. Users can toggle this feature using the `nnx.use_eager_sharding` function. ```{code-cell} ipython3 nnx.use_eager_sharding(True) assert nnx.using_eager_sharding() ``` The `nnx.use_eager_sharding` function can also be used as a context manager to toggle the eager sharding feature within a specific scope. ```{code-cell} ipython3 with nnx.use_eager_sharding(False): assert not nnx.using_eager_sharding() ``` You can also enable eager sharding on a per-variable basis by passing `eager_sharding=False` during variable initialization. The mesh can also be passed this way. ```{code-cell} ipython3 nnx.Param(jnp.ones(4,4), out_sharding=(None, 'model'), eager_sharding=True, mesh=auto_mesh) ``` ## Shard a single-array model Let's begin by sharding the simplest component possible - a Flax variable. When you define a Flax variable, you can pass in a metadata field called `out_sharding`, to specify how the underlying JAX array should be sharded. This field should be a tuple of names, each of which refer to how an axis of the array should be sharded. **You must have an existing device mesh** and create a sharding-annotated `nnx.Variable` within its scope. This allows the result variable to be sharded accordingly on those devices. The device mesh can be your actual accelerator mesh, or a dummy fake CPU mesh like in this notebook. ```{code-cell} ipython3 rngs = nnx.Rngs(0) with jax.set_mesh(auto_mesh): w = nnx.Param( rngs.lecun_normal()((4, 8)), out_sharding=(None, 'model') ) print(w.sharding.spec) jax.debug.visualize_array_sharding(w) # already sharded! ``` ### Initialize with style When using existing modules, you can apply [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) on initializers to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight. Also, you should use `jax.jit` for the whole initialization for maximum performance. This is because without `jax.jit`, a single-device variable must be created first before we apply sharding constraints and then make it sharded, which is wasteful. `jax.jit` will automatically optimize this out. ```{code-cell} ipython3 @jax.jit def init_sharded_linear(key): init_fn = nnx.nn.linear.default_kernel_init # Shard your parameter along `model` dimension, as in model/tensor parallelism return nnx.Linear(4, 8, use_bias=False, rngs=nnx.Rngs(key), kernel_init=nnx.with_partitioning(init_fn, (None, 'model'))) with jax.set_mesh(auto_mesh): key= rngs() linear = init_sharded_linear(key) assert linear.kernel.sharding.spec == P(None, 'model') # already sharded! ``` ### Run the model If you also shard your input correctly, JAX would be able to carry out the most natural and optimized computation and produce your output as sharded. You should still make sure to `jax.jit` for maximum performance, and also to explicitly control how each array is sharded when you want to. We will give an example of that control in the next section. > Note: You need to `jax.jit` a pure function that takes the model as an argument, instead of jitting the callable model directly. ```{code-cell} ipython3 # For simple computations, you can get correctly-sharded output without jitting # In this case, ('data', None) @ (None, 'model') = ('data', 'model') with jax.set_mesh(auto_mesh): # Create your input data, sharded along `data` dimension, as in data parallelism x = jax.device_put(jnp.ones((16, 4)), P('data', None)) # Run the model forward function, jitted y = jax.jit(lambda m, x: m(x))(linear, x) print(y.sharding.spec) # sharded: ('data', 'model') jax.debug.visualize_array_sharding(y) ``` ## Shard a wholesome model Now we construct a more wholesome model to show a few advanced tricks. Check out this simple `DotReluDot` module that does two matmuls, and the `MultiDotReluDot` module that creates an arbitrary stack of `DotReluDot` sublayers. Make note of the following: * **Additional axis annotation**: Transforms like `vmap` and `scan` will add additional dimensions to the JAX arrays. Unfortunately, in auto sharding mode you will need to use `nnx.vmap` and `nnx.scan` instead of raw JAX transforms, so that both JAX and Flax knows how to shard this dimension. You won't need this in [explicit sharding mode](#explicit-sharding). * [`jax.lax.with_sharding_constraint`](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#constraining-shardings-of-intermediates-in-jitted-code): They can help you to enforce specific shardings on intermediate activations. Only works under an auto mode mesh context. ```{code-cell} ipython3 class DotReluDot(nnx.Module): def __init__(self, depth: int, rngs: nnx.Rngs): init_fn = nnx.initializers.lecun_normal() self.dot1 = nnx.Linear( depth, depth, kernel_init=nnx.with_partitioning(init_fn, (None, 'model')), use_bias=False, # or use `bias_init` to give it annotation too rngs=rngs) self.w2 = nnx.Param( init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation sharding=('model', None), ) def __call__(self, x: jax.Array): y = self.dot1(x) y = jax.nn.relu(y) y = jax.lax.with_sharding_constraint(y, P('data', 'model')) z = jnp.dot(y, self.w2[...]) return z class MultiDotReluDot(nnx.Module): def __init__(self, depth: int, num_layers: int, rngs: nnx.Rngs): # Annotate the additional axis with sharding=None, meaning it will be # replicated across all devices. @nnx.vmap(transform_metadata={nnx.PARTITION_NAME: None}) def create_sublayers(r): return DotReluDot(depth, r) self.layers = create_sublayers(rngs.fork(split=num_layers)) def __call__(self, x): def scan_over_layers(x, layer): return layer(x), None x, _ = jax.lax.scan(scan_over_layers, x, self.layers) return x ``` Now a sample training loop, using `jax.jit`. ```{code-cell} ipython3 @jax.jit def train_step(model, optimizer, x, y): def loss_fn(model: DotReluDot): y_pred = model(x) return jnp.mean((y_pred - y) ** 2) loss, grads = jax.value_and_grad(loss_fn)(model) optimizer.update(model, grads) return model, loss with jax.set_mesh(auto_mesh): # Training data input = jax.device_put(rngs.normal((8, 1024)), P('data', None)) label = jax.device_put(rngs.normal((8, 1024)), P('data', None)) # Model and optimizer model = MultiDotReluDot(1024, 2, rngs=nnx.Rngs(0)) optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) # The loop for i in range(5): model, loss = train_step(model, optimizer, input, label) print(loss) # Model (over-)fitting to the labels quickly. ``` ## Profiling If you are using a Google TPU pod or a pod slice, you can create a custom `block_all()` utility function, as defined below, to measure the performance: ```{code-cell} ipython3 %%timeit def block_all(xs): jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs) return xs with jax.set_mesh(auto_mesh): new_state = block_all(train_step(model, optimizer, input, label)) ``` ## Load a sharded model from a checkpoint Now you learned how to initialize a sharded model without OOM, but what about saving and loading it from a checkpoint on disk? JAX checkpointing libraries, such as [Orbax](https://orbax.readthedocs.io/en/latest/), support loading a model distributedly if a sharding pytree is provided. Below is an example that uses Orbax's `StandardCheckpointer` API. Make sure you save a model's state, especially if your model shares some variables across modules. Given a You can generate an identical abstract pytree with shardings using Flax’s `nnx.get_abstract_model`. ```{code-cell} ipython3 import orbax.checkpoint as ocp # Save the sharded state. sharded_state = nnx.state(model) path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/') checkpointer = ocp.StandardCheckpointer() checkpointer.save(path / 'checkpoint_name', sharded_state) # Load a sharded state from the checkpoint. graphdef, abs_state = nnx.get_abstract_model( lambda: MultiDotReluDot(1024, 2, rngs=nnx.Rngs(0)), auto_mesh) restored_state = checkpointer.restore(path / 'checkpoint_name', target=abs_state) restored_model = nnx.merge(graphdef, abs_state) print(restored_model.layers.dot1.kernel.sharding.spec) print(restored_model.layers.dot1.kernel.shape) ``` ## Logical axis annotation JAX's [automatic](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) [SPMD](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you have the option to annotate with more descriptive axis names (not just device mesh axis names like `'data'` and `'model'`), as long as you provide a mapping from your alias to the device mesh axes. You can provide the mapping along with the annotation as another metadata of the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), or overwrite it at top-level. Check out the `LogicalDotReluDot` example below. ```{code-cell} ipython3 # The mapping from alias annotation to the device mesh. sharding_rules = (('batch', 'data'), ('hidden', 'model'), ('embed', None)) class LogicalDotReluDot(nnx.Module): def __init__(self, depth: int, rngs: nnx.Rngs): init_fn = nnx.initializers.lecun_normal() self.dot1 = nnx.Linear( depth, depth, kernel_init=nnx.with_partitioning(init_fn, ('embed', 'hidden')), use_bias=False, # or use `bias_init` to give it annotation too rngs=rngs) self.w2 = nnx.Param( init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation sharding=('hidden', 'embed'), ) def __call__(self, x: jax.Array): y = self.dot1(x) y = jax.nn.relu(y) # Unfortunately the logical aliasing doesn't work on lower-level JAX calls. y = jax.lax.with_sharding_constraint(y, P('data', None)) z = jnp.dot(y, self.w2[...]) return z class LogicalMultiDotReluDot(nnx.Module): def __init__(self, depth: int, num_layers: int, rngs: nnx.Rngs): @nnx.vmap(transform_metadata={nnx.PARTITION_NAME: None}) def create_sublayers(r): return LogicalDotReluDot(depth, r) self.layers = create_sublayers(rngs.fork(split=num_layers)) def __call__(self, x): def scan_over_layers(x, layer): return layer(x), None x, _ = jax.lax.scan(scan_over_layers, x, self.layers) return x ``` If you didn't provide all `sharding_rule` annotations in the model definition, you can apply them at top level by put them into the context via `nnx.logical_axis_rules`. ```{code-cell} ipython3 with jax.set_mesh(auto_mesh), nnx.logical_axis_rules(sharding_rules): # Model and optimizer logical_model = LogicalMultiDotReluDot(1024, 2, rngs=nnx.Rngs(0)) logical_output = logical_model(input) # Check out their equivalency with some easier-to-read sharding descriptions. assert logical_model.layers.dot1.kernel.sharding.is_equivalent_to( NamedSharding(auto_mesh, P(None, None, 'model')), ndim=3 ) assert logical_model.layers.w2.sharding.is_equivalent_to( NamedSharding(auto_mesh, P(None, 'model', None)), ndim=3 ) assert logical_output.sharding.is_equivalent_to( NamedSharding(auto_mesh, P('data', None)), ndim=2 ) ``` ### When to use device axis / logical axis Choosing when to use a device or logical axis depends on how much you want to control the partitioning of your model: * **Device mesh axis**: * For a simpler model, this can save you a few extra lines of code of converting the logical naming back to the device naming. * Shardings of intermediate *activation* values can only be done via [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) and device mesh axis. Therefore, if you want super fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing. * **Logical naming**: This is helpful if you want to experiment around and find the most optimal partition layout for your *model weights*. +++ ## Explicit sharding [Explicit sharding](https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html), also called "sharding-in-types", is a new JAX sharding feature that allows every sharding of every array to be deterministic and explicit. Instead of letting XLA compiler figure out the shardings, you as user would explicitly state the shardings via JAX APIs. For education purposes, we provide a simple Flax model example using explicit sharding. Note how you specify shardings for this model: * Parameters: `out_sharding` argument passed into JAX initializers. * Ambigious computations like `jnp.dot`: provide `out_sharding` argument to specify the output sharding. * Additional dimension from transforms: use `jax.vmap`'s argument `spmd_axis_name`, instead of Flax lifted transforms. ```{code-cell} ipython3 # Explicit axis mesh explicit_mesh = jax.make_mesh((2, 4), ('data', 'model'), axis_types=(AxisType.Explicit, AxisType.Explicit)) class ExplicitDotReluDot(nnx.Module): def __init__(self, depth: int, rngs: nnx.Rngs): init_fn = nnx.initializers.lecun_normal() self.dot1 = nnx.Linear( depth, depth, kernel_init=partial(init_fn, out_sharding=P(None, 'model')), use_bias=False, rngs=rngs) self.w2 = nnx.Param( init_fn(rngs.params(), (depth, depth), out_sharding=P('model', None)), ) self.b2 = nnx.Param(jnp.zeros((depth, ), out_sharding=P(None,))) def __call__(self, x: jax.Array): y = self.dot1(x) y = jax.nn.relu(y) z = jnp.dot(y, self.w2[...], out_sharding=P('data', None)) z = z + self.b2 return z class ExplicitMultiDotReluDot(nnx.Module): def __init__(self, depth: int, num_layers: int, rngs: nnx.Rngs): # Annotate the additional axis with sharding=None, meaning it will be # replicated across all devices. @partial(jax.vmap, spmd_axis_name=None) def create_sublayers(r): return ExplicitDotReluDot(depth, r) self.layers = create_sublayers(rngs.fork(split=num_layers)) def __call__(self, x): def scan_over_layers(x, layer): return layer(x), None x, _ = jax.lax.scan(scan_over_layers, x, self.layers) return x with jax.set_mesh(explicit_mesh): model = ExplicitMultiDotReluDot(1024, 2, rngs=nnx.Rngs(0)) x = jax.device_put(rngs.normal((8, 1024)), NamedSharding(explicit_mesh, P('data', None))) y = model(x) print(model.layers.dot1.kernel.sharding.spec) print(model.layers.w2.sharding.spec) assert x.sharding.is_equivalent_to(y.sharding, ndim=2) ``` One thing easier in explicit mode is that you can obtain the abstract array tree with shardings via `jax.eval_shape`, instead of calling `nnx.get_abstract_sharding`. This is not possible in auto mode. ```{code-cell} ipython3 # Get the sharding tree to load checkpoint with with jax.set_mesh(explicit_mesh): abs_model = jax.eval_shape( lambda: ExplicitMultiDotReluDot(1024, 2, rngs=nnx.Rngs(0))) print(abs_model.layers.dot1.kernel.sharding.spec) print(abs_model.layers.w2.sharding.spec) ``` ## Further readings JAX has abundant documentation on scaled computing. - [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html): A 101 level tutorial covering the basics of automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), semi-automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html), and manual sharding with [`shard_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html#jax.experimental.shard_map.shard_map). - [JAX in multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html). - [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html): A more detailed tutorial about parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) and [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html). ================================================ FILE: docs_nnx/guides/index.rst ================================================ Guides ====== .. toctree:: :maxdepth: 2 ../guides_basic ../guides_advanced ================================================ FILE: docs_nnx/guides/jax_and_nnx_transforms.rst ================================================ Flax NNX vs JAX transformations =============================== This guide describes the differences between `Flax NNX transformations `__ and `JAX transformations `__, and how to seamlessly switch between them or use them side-by-side. The examples here will focus on ``nnx.jit``, ``jax.jit``, ``nnx.grad`` and ``jax.grad`` function transformations (transforms). First, let's set up imports and generate some dummy data: .. testcode:: Flax NNX, JAX from flax import nnx import jax x = jax.random.normal(jax.random.key(0), (1, 2)) y = jax.random.normal(jax.random.key(1), (1, 3)) Differences *********** Flax NNX transformations can transform functions that are not pure and make mutations and side-effects: - Flax NNX transforms enable you to transform functions that take in Flax NNX graph objects as arguments - such as ``nnx.Module``, ``nnx.Rngs``, ``nnx.Optimizer``, and so on - even those whose state will be mutated. - In comparison, these kinds of objects aren't recognized in JAX transformations. The Flax NNX `Functional API `_ provides a way to convert graph structures to `pytrees `__ and back. By doing this at every function boundary you can effectively use graph structures with any JAX transforms and propagate state updates in a way consistent with functional purity. Flax NNX custom transforms, such as ``nnx.jit`` and ``nnx.grad``, simply remove the boilerplate, and as a result the code looks stateful. Below is an example of using the ``nnx.jit`` and ``nnx.grad`` transforms compared to the the code that uses ``jax.jit`` and ``jax.grad`` transforms. Notice that: - The function signature of Flax NNX-transformed functions can accept the ``nnx.Linear`` ``nnx.Module`` instances directly and make stateful updates to the ``Module``. - The function signature of JAX-transformed functions can only accept the pytree-registered ``nnx.State`` and ``nnx.GraphDef`` objects, and must return an updated copy of them to maintain the purity of the transformed function. .. codediff:: :title: Flax NNX transforms, JAX transforms :groups: Flax NNX, JAX :sync: @nnx.jit def train_step(model, x, y): def loss_fn(model): return ((model(x) - y) ** 2).mean() grads = nnx.grad(loss_fn)(model) params = nnx.state(model, nnx.Param) params = jax.tree_util.tree_map( lambda p, g: p - 0.1 * g, params, grads ) nnx.update(model, params) model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) train_step(model, x, y) --- @jax.jit #! def train_step(graphdef, state, x, y): #! def loss_fn(graphdef, state): #! model = nnx.merge(graphdef, state) #! return ((model(x) - y) ** 2).mean() grads = jax.grad(loss_fn, argnums=1)(graphdef, state) #! model = nnx.merge(graphdef, state) #! params = nnx.state(model, nnx.Param) params = jax.tree_util.tree_map( lambda p, g: p - 0.1 * g, params, grads ) nnx.update(model, params) return nnx.split(model) #! graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0))) #! graphdef, state = train_step(graphdef, state, x, y) #! Mixing Flax NNX and JAX transforms ********************************** Both Flax NNX transforms and JAX transforms can be mixed together, so long as the JAX-transformed function in your code is pure and has valid argument types that are recognized by JAX. .. codediff:: :title: Using ``nnx.jit`` with ``jax.grad``, Using ``jax.jit`` with ``nnx.grad`` :groups: Flax NNX, JAX :sync: @nnx.jit def train_step(model, x, y): def loss_fn(graphdef, state): #! model = nnx.merge(graphdef, state) return ((model(x) - y) ** 2).mean() grads = jax.grad(loss_fn, 1)(*nnx.split(model)) #! params = nnx.state(model, nnx.Param) params = jax.tree_util.tree_map( lambda p, g: p - 0.1 * g, params, grads ) nnx.update(model, params) model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) train_step(model, x, y) --- @jax.jit #! def train_step(graphdef, state, x, y): #! model = nnx.merge(graphdef, state) def loss_fn(model): return ((model(x) - y) ** 2).mean() grads = nnx.grad(loss_fn)(model) params = nnx.state(model, nnx.Param) params = jax.tree_util.tree_map( lambda p, g: p - 0.1 * g, params, grads ) nnx.update(model, params) return nnx.split(model) graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0))) graphdef, state = train_step(graphdef, state, x, y) ================================================ FILE: docs_nnx/guides/performance.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Performance considerations\n", "\n", "Currently, Flax [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) traverses the object graph in pure Python which can add overhead. This overhead mostly affects small to medium models and can be mitigated in the following ways:\n", "* By leveraging JAX's [Asynchronous dispatch](#asynchronous-dispatch).\n", "* By using [nnx.cached_partial](#caching-graph-node-traversals) to cache the graph node traversals.\n", "* By using a [Functional training loop](#functional-training-loop) which stages out the graph traversals.\n", "\n", "A full resolution _might_ involve developing a C extension (e.g. `flaxlib`) to speed up some of the traversal logic in [`graph.py`](https://github.com/google/flax/blob/main/flax/nnx/graph.py). Before we continue lets an example of a model and a simple training loop:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from flax import nnx\n", "import jax\n", "import jax.numpy as jnp\n", "import optax\n", "\n", "class Model(nnx.Module):\n", " def __init__(self, din, dmid, dout, rngs: nnx.Rngs):\n", " self.linear = nnx.Linear(din, dmid, rngs=rngs)\n", " self.bn = nnx.BatchNorm(dmid, rngs=rngs)\n", " self.dropout = nnx.Dropout(0.2, rngs=rngs)\n", " self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)\n", "\n", " def __call__(self, x):\n", " x = nnx.relu(self.dropout(self.bn(self.linear(x))))\n", " return self.linear_out(x)\n", " \n", "model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization\n", "optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", "metrics = nnx.MultiMetric(\n", " loss=nnx.metrics.Average('loss'),\n", ")\n", "\n", "@nnx.jit # <== currently slow\n", "def train_step(model, optimizer, metrics, x, y):\n", " def loss_fn(model):\n", " y_pred = model(x) # call methods directly\n", " return ((y_pred - y) ** 2).mean()\n", "\n", " loss, grads = nnx.value_and_grad(loss_fn)(model)\n", " optimizer.update(model, grads) # in-place updates\n", " metrics.update(loss=loss)\n", "\n", " return loss\n", " \n", "for _ in range(10):\n", " x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))\n", " loss = train_step(model, optimizer, metrics, x, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Important thing here is that we created a `train_step()` function that uses `nnx.jit` and takes in a `model`, `optimizer`, and `metrics` arguments, all of which are Flax NNX objects. We'll later see how to improve this." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Asynchronous dispatch\n", "\n", "Asynchronous dispatch is a feature of JAX where it runs operations in the background whenever possible so Python can continue executing other code. This can be use to absorb the cost of data loading and in this case the overhead of `nnx.jit` and similar transforms. In general, as the amount of computation JAX has to perform per iteration increases the more it is able to absorb the python overhead since eventually the JAX computation will be the main blocker and programs with different overhead will have the same performance. This could be achieved in a couple of ways:\n", "\n", "* Increasing the batch size.\n", "* Increasing the model size.\n", "* Performing more JAX steps per python step if data loading is fast enough.\n", "\n", "To demonstrate this, the graph below which shows total time of running [benchmarks/nnx_simple_training.py](https://github.com/google/flax/blob/main/benchmarks/nnx_simple_training.py) for both `jax.jit` and `nnx.jit` with different model sizes:\n", "\n", "![performance-graph](images/performance-graph.png)\n", "\n", "As we can observe, after a certain model size both `jax.jit` and `nnx.jit` converge to the same runtime cost. This means we don't have to modify our training loop above.\n", "\n", "## Caching graph node traversals\n", "\n", "The simplest way to get rid of the traversal overhead entirely is by using `nnx.cached_partial` to convert a transformed function and the input graph objects into a partial function which caches the graph object and just expects the remaining arguments. In this example we use `nnx.cached_partial` over `train_step` and partially apply `model`, `optimizer`, and `metrics`, to create `cached_train_step`. Then we simply update our training loop to use `cached_train_step` which only expects the `x` and `y` inputs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cached_train_step = nnx.cached_partial(train_step, model, optimizer, metrics)\n", "\n", "for _ in range(10):\n", " x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))\n", " loss = cached_train_step(x, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that `cached_partial` will enforce that the structure of the graph nodes doesn't change during `train_step` (no mutations except for `Variable` state update) so the cache is guaranteed to be up-to-date and we can avoid costly checks which require traversals. This is actually what is expected for most step functions as making any change here would imply costly recompilation, so enforcing this might be a secondary feature that could be useful for this purpose.\n", "\n", "Similarly, to prevent the user from mutating the cached objects outside, `cached_partial` creates a copy of all the graph nodes but, to allow state to be propagated to the original objects, they share references to the same `Variable`s." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Functional training loop\n", "\n", "To remove the Python overhead we can create a functional training loop that uses regular `jax.jit` in combination with `nnx.split` and `nnx.merge` to stage out the traversal logic. Concretely we can use [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) before the training loop to create a single `graphdef` and `state` pytrees for all the graph nodes. Then we change `train_step()` to accept `graphdef` and `state`, and use [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) to recreate the objects inside, and either `nnx.split` or `nnx.state` at the end to get the output `state`. At the end of the training loop or whenever needed we can use [`nnx.update`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.update) to update the objects to the current `state`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# split before training loop\n", "graphdef, state = nnx.split((model, optimizer, metrics))\n", "\n", "@jax.jit # regular JAX\n", "def jax_train_step(graphdef, state, x, y):\n", " # merge at the beginning of the function\n", " model, optimizer, metrics = nnx.merge(graphdef, state)\n", "\n", " def loss_fn(model):\n", " y_pred = model(x) # call methods directly\n", " return ((y_pred - y) ** 2).mean()\n", "\n", " loss, grads = nnx.value_and_grad(loss_fn)(model)\n", " optimizer.update(model, grads)\n", " metrics.update(loss=loss)\n", "\n", " state = nnx.state((model, optimizer, metrics))\n", " return loss, state\n", "\n", "for _ in range(10):\n", " x, y = jnp.ones((32, 2)), jnp.zeros((32, 3))\n", " loss, state = jax_train_step(graphdef, state, x, y)\n", "\n", "# update objects after training\n", "nnx.update((model, optimizer, metrics), state)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice that we only need to do this for `jit`, the use of other Flax transforms like [`nnx.value_and_grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.value_and_grad) inside `train_step` doesn't have any performance cost since `jit` will make sure this only traced once." ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "language_info": { "name": "python", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: docs_nnx/guides/performance.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # Performance considerations Currently, Flax [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) traverses the object graph in pure Python which can add overhead. This overhead mostly affects small to medium models and can be mitigated in the following ways: * By leveraging JAX's [Asynchronous dispatch](#asynchronous-dispatch). * By using [nnx.cached_partial](#caching-graph-node-traversals) to cache the graph node traversals. * By using a [Functional training loop](#functional-training-loop) which stages out the graph traversals. A full resolution _might_ involve developing a C extension (e.g. `flaxlib`) to speed up some of the traversal logic in [`graph.py`](https://github.com/google/flax/blob/main/flax/nnx/graph.py). Before we continue lets an example of a model and a simple training loop: ```{code-cell} from flax import nnx import jax import jax.numpy as jnp import optax class Model(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dmid, rngs=rngs) self.bn = nnx.BatchNorm(dmid, rngs=rngs) self.dropout = nnx.Dropout(0.2, rngs=rngs) self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x): x = nnx.relu(self.dropout(self.bn(self.linear(x)))) return self.linear_out(x) model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) metrics = nnx.MultiMetric( loss=nnx.metrics.Average('loss'), ) @nnx.jit # <== currently slow def train_step(model, optimizer, metrics, x, y): def loss_fn(model): y_pred = model(x) # call methods directly return ((y_pred - y) ** 2).mean() loss, grads = nnx.value_and_grad(loss_fn)(model) optimizer.update(model, grads) # in-place updates metrics.update(loss=loss) return loss for _ in range(10): x, y = jnp.ones((32, 2)), jnp.zeros((32, 3)) loss = train_step(model, optimizer, metrics, x, y) ``` Important thing here is that we created a `train_step()` function that uses `nnx.jit` and takes in a `model`, `optimizer`, and `metrics` arguments, all of which are Flax NNX objects. We'll later see how to improve this. +++ ## Asynchronous dispatch Asynchronous dispatch is a feature of JAX where it runs operations in the background whenever possible so Python can continue executing other code. This can be use to absorb the cost of data loading and in this case the overhead of `nnx.jit` and similar transforms. In general, as the amount of computation JAX has to perform per iteration increases the more it is able to absorb the python overhead since eventually the JAX computation will be the main blocker and programs with different overhead will have the same performance. This could be achieved in a couple of ways: * Increasing the batch size. * Increasing the model size. * Performing more JAX steps per python step if data loading is fast enough. To demonstrate this, the graph below which shows total time of running [benchmarks/nnx_simple_training.py](https://github.com/google/flax/blob/main/benchmarks/nnx_simple_training.py) for both `jax.jit` and `nnx.jit` with different model sizes: ![performance-graph](images/performance-graph.png) As we can observe, after a certain model size both `jax.jit` and `nnx.jit` converge to the same runtime cost. This means we don't have to modify our training loop above. ## Caching graph node traversals The simplest way to get rid of the traversal overhead entirely is by using `nnx.cached_partial` to convert a transformed function and the input graph objects into a partial function which caches the graph object and just expects the remaining arguments. In this example we use `nnx.cached_partial` over `train_step` and partially apply `model`, `optimizer`, and `metrics`, to create `cached_train_step`. Then we simply update our training loop to use `cached_train_step` which only expects the `x` and `y` inputs: ```{code-cell} cached_train_step = nnx.cached_partial(train_step, model, optimizer, metrics) for _ in range(10): x, y = jnp.ones((32, 2)), jnp.zeros((32, 3)) loss = cached_train_step(x, y) ``` Note that `cached_partial` will enforce that the structure of the graph nodes doesn't change during `train_step` (no mutations except for `Variable` state update) so the cache is guaranteed to be up-to-date and we can avoid costly checks which require traversals. This is actually what is expected for most step functions as making any change here would imply costly recompilation, so enforcing this might be a secondary feature that could be useful for this purpose. Similarly, to prevent the user from mutating the cached objects outside, `cached_partial` creates a copy of all the graph nodes but, to allow state to be propagated to the original objects, they share references to the same `Variable`s. +++ ## Functional training loop To remove the Python overhead we can create a functional training loop that uses regular `jax.jit` in combination with `nnx.split` and `nnx.merge` to stage out the traversal logic. Concretely we can use [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) before the training loop to create a single `graphdef` and `state` pytrees for all the graph nodes. Then we change `train_step()` to accept `graphdef` and `state`, and use [`nnx.merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge) to recreate the objects inside, and either `nnx.split` or `nnx.state` at the end to get the output `state`. At the end of the training loop or whenever needed we can use [`nnx.update`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.update) to update the objects to the current `state`. ```{code-cell} # split before training loop graphdef, state = nnx.split((model, optimizer, metrics)) @jax.jit # regular JAX def jax_train_step(graphdef, state, x, y): # merge at the beginning of the function model, optimizer, metrics = nnx.merge(graphdef, state) def loss_fn(model): y_pred = model(x) # call methods directly return ((y_pred - y) ** 2).mean() loss, grads = nnx.value_and_grad(loss_fn)(model) optimizer.update(model, grads) metrics.update(loss=loss) state = nnx.state((model, optimizer, metrics)) return loss, state for _ in range(10): x, y = jnp.ones((32, 2)), jnp.zeros((32, 3)) loss, state = jax_train_step(graphdef, state, x, y) # update objects after training nnx.update((model, optimizer, metrics), state) ``` Notice that we only need to do this for `jit`, the use of other Flax transforms like [`nnx.value_and_grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.value_and_grad) inside `train_step` doesn't have any performance cost since `jit` will make sure this only traced once. ================================================ FILE: docs_nnx/guides/pytree.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "90ad74ee", "metadata": {}, "source": [ "# Module & Pytree" ] }, { "cell_type": "markdown", "id": "5c0691d0", "metadata": {}, "source": [ "Flax NNX's Modules are by default registered as JAX Pytrees, this allows using them throughout most of JAX APIs but in particular JAX transforms and the `jax.tree.*` functions. Thanks to the pytree protocol a simple NNX program might look like this:" ] }, { "cell_type": "code", "execution_count": 1, "id": "9b2b929d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y.shape = (5, 3)\n" ] } ], "source": [ "from flax import nnx\n", "import jax\n", "import jax.numpy as jnp\n", "\n", "class Linear(nnx.Module):\n", " def __init__(self, din, dout, rngs: nnx.Rngs):\n", " self.din, self.dout = din, dout\n", " self.kernel = nnx.Param(rngs.normal((din, dout)))\n", "\n", "rngs = nnx.Rngs(0)\n", "weights = Linear(2, 3, rngs=rngs)\n", "\n", "@jax.jit\n", "def forward(weights, x):\n", " return x @ weights.kernel\n", "\n", "y = forward(weights, x=rngs.uniform((5, 2)))\n", "print(f\"{y.shape = }\")" ] }, { "cell_type": "markdown", "id": "bcfbbdb2", "metadata": {}, "source": [ "Here `weights`, of type `Linear`, was able to be passed directly to the `jit`-ed function `forward`. Throughout the rest of this guide we will try to answer the questions:\n", "1. What are pytrees? \n", "2. How does NNX implement pytrees?" ] }, { "cell_type": "markdown", "id": "4f610bb3", "metadata": {}, "source": [ "## Pytrees 101\n", "Most modern ML models have too many Arrays for users to pass around individually, to deal with this JAX developed a way to track Array data in nested structures that still allowed caching for compilation: Pytrees. JAX pytrees are tree structures made of python objects that can be recursively traversed in order to collect an ordered list of leaves and a definition of the tree structure, this is done via the `jax.tree.flatten` function. Most common pytrees are native python containers like `list`, `dict`, and `tuple`, but interestingly it also include `None`. The example bellow shows how to collect all the integer leaves from a nested structure using `flatten`:" ] }, { "cell_type": "code", "execution_count": 2, "id": "c3529274", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "leaves = [1, 2, 3, 4]\n", "treedef = PyTreeDef([{'a': *}, {'b': *, 'c': (*, *), 'd': None}])\n" ] } ], "source": [ "pytree = [\n", " {'a': 1},\n", " {\n", " 'b': 2,\n", " 'c': (3, 4),\n", " 'd': None,\n", " }\n", "]\n", "\n", "leaves, treedef = jax.tree.flatten(pytree)\n", "print(f\"leaves = {leaves}\")\n", "print(f\"treedef = {treedef}\")" ] }, { "cell_type": "markdown", "id": "9c037b5e", "metadata": {}, "source": [ "Note that `None` is not a leaf because its defined as a pytree with no children. The main purpose of being able to flatten, apart from collecting the leaves, is being able reconstruct the pytree structure from the tree definition from any sequence of leaves of the same length via the `jax.tree.unflatten` function:" ] }, { "cell_type": "code", "execution_count": 3, "id": "d8237524", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "old pytree = [{'a': 1}, {'b': 2, 'c': (3, 4), 'd': None}]\n", "new pytree = [{'a': 10}, {'b': 20, 'c': (30, 40), 'd': None}]\n" ] } ], "source": [ "new_leaves = [x * 10 for x in leaves]\n", "new_pytree = jax.tree.unflatten(treedef, new_leaves)\n", "\n", "print(f\"old pytree = {pytree}\")\n", "print(f\"new pytree = {new_pytree}\")" ] }, { "cell_type": "markdown", "id": "117c8b2d", "metadata": {}, "source": [ "### Custom Pytrees\n", "JAX allows us to register custom pytree node type by using the `jax.tree_util.register_pytree_node` utility. For any type we are able to define a flatten that decomposes the object into a a sequence of nodes / children and a static (hashable) structure, and a unflatten function which takes the sequence of nodes and the static structure and creates a new instance. In the following example we create a simple type `Foo` with the attributes `a`, `b`, and `c`, and define `a` and `b` as nodes and `c` as static." ] }, { "cell_type": "code", "execution_count": 4, "id": "0c46905c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "leaves = [1, 2]\n", "treedef = PyTreeDef(CustomNode(Foo[('hi',)], [*, *]))\n" ] } ], "source": [ "class Foo:\n", " def __init__(self):\n", " self.a = 1\n", " self.b = 2\n", " self.c = \"hi\"\n", "\n", "def flatten_foo(foo: Foo):\n", " nodes = [foo.a, foo.b] # sequence of nodes\n", " static = (foo.c,) # hashable & equatable structure\n", " return nodes, static\n", "\n", "def unflatten_foo(static, nodes):\n", " foo = object.__new__(Foo) # create uninitialized instance\n", " foo.a = nodes[0]\n", " foo.b = nodes[1]\n", " foo.c = static[0]\n", " return foo\n", "\n", "jax.tree_util.register_pytree_node(Foo, flatten_foo, unflatten_foo)\n", "\n", "foo = Foo()\n", "leaves, treedef = jax.tree.flatten(foo)\n", "print(f\"leaves = {leaves}\")\n", "print(f\"treedef = {treedef}\")" ] }, { "cell_type": "markdown", "id": "15647b13", "metadata": {}, "source": [ "Notice that `'hi'` does not appear in the leaves because `c` is defined as static, but you can see it as part of the `PyTreeDef` structure." ] }, { "cell_type": "markdown", "id": "d603fa09", "metadata": {}, "source": [ "## nnx.Pytree\n", "In general it would be cumbersome for users to manually register the pytree definition for every type they create. To automate this process NNX provides the `nnx.Pytree` base type that offers a simple API: users annotate attributes using either `nnx.static` or `nnx.data`, and Pytree will register some flatten and unflatten functions that will take the annotations into account. The `nnx.data` and `nnx.static` annotations must only be assigned to `Pytree` attributes directly." ] }, { "cell_type": "code", "execution_count": 5, "id": "95016a94", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pytree structure:\n", " - pytree.layers[0].b = Array([0.], dtype=float32)\n", " - pytree.layers[0].w = Array([[1.]], dtype=float32)\n", " - pytree.layers[1].b = Array([0.], dtype=float32)\n", " - pytree.layers[1].w = Array([[1.]], dtype=float32)\n" ] } ], "source": [ "class Linear(nnx.Pytree):\n", " def __init__(self, din: int, dout: int):\n", " self.din = nnx.static(din)\n", " self.dout = nnx.static(dout)\n", " self.w = nnx.data(jnp.ones((din, dout)))\n", " self.b = nnx.data(jnp.zeros((dout,)))\n", "\n", "class MLP(nnx.Pytree):\n", " def __init__(self, num_layers, dim):\n", " self.num_layers = nnx.static(num_layers)\n", " self.layers = nnx.data([\n", " Linear(dim, dim) for _ in range(num_layers)\n", " ])\n", "\n", "pytree = MLP(num_layers=2, dim=1)\n", "\n", "def pytree_structure(pytree, title='pytree structure'):\n", " print(f\"{title}:\")\n", " path_leaves, treedef = jax.tree.flatten_with_path(pytree)\n", " for path, value in path_leaves:\n", " print(f\" - pytree{jax.tree_util.keystr(path)} = {value!r}\")\n", "\n", "pytree_structure(pytree)" ] }, { "cell_type": "markdown", "id": "3a2db214", "metadata": {}, "source": [ "As you can see above, only the `data` paths appear in the leaves. However, its very verbose to have to define `static` and `data` for each attribute, so `Pytree` has sensible defaults. You can remove most of them and it will just work:" ] }, { "cell_type": "code", "execution_count": 6, "id": "a8665146", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pytree structure:\n", " - pytree.layers[0].b = Array([0.], dtype=float32)\n", " - pytree.layers[0].w = Array([[1.]], dtype=float32)\n", " - pytree.layers[1].b = Array([0.], dtype=float32)\n", " - pytree.layers[1].w = Array([[1.]], dtype=float32)\n" ] } ], "source": [ "class Linear(nnx.Pytree):\n", " def __init__(self, din: int, dout: int):\n", " self.din = din # static\n", " self.dout = dout # static\n", " self.w = jnp.ones((din, dout)) # data\n", " self.b = jnp.zeros((dout,)) # data\n", "\n", "class MLP(nnx.Pytree):\n", " def __init__(self, num_layers, dim):\n", " self.num_layers = num_layers # static\n", " self.layers = nnx.List([ # data\n", " Linear(dim, dim) for _ in range(num_layers)\n", " ])\n", "\n", "pytree = MLP(num_layers=2, dim=1)\n", "pytree_structure(pytree)" ] }, { "cell_type": "markdown", "id": "a249c5f0", "metadata": {}, "source": [ "The only change we had to do here is use `nnx.List` to signal that `layers` contains `data`, the status of the rest of the attributes can be correctly inferred. The rules that determine if a value is data or not are the following:\n", "\n", "* `Array`s, `Variable`s, `ArrayRef`s, and `nnx.Pytree`s are data.\n", "* Types registered using `nnx.register_data_type` are data.\n", "* All other types are static.\n", "\n", "To check if a value is data use the `nnx.is_data` function which will return its status:" ] }, { "cell_type": "code", "execution_count": 7, "id": "27682936", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "# ------ DATA ------------\n", "nnx.is_data( jnp.array(0) ) = True # Arrays are data\n", "nnx.is_data( nnx.Param(1) ) = True # Variables are data\n", "nnx.is_data( nnx.Rngs(2) ) = True # nnx.Pytrees are data\n", "\n", "# ------ STATIC ------------\n", "nnx.is_data( 'hello' ) = False # strings, arbitrary objects\n", "nnx.is_data( 42 ) = False # int, float, bool, complex, etc.\n", "nnx.is_data( [1, 2.0, 3j, jnp.array(1)] ) = False # list, dict, tuple, regular pytrees\n", "\n" ] } ], "source": [ "print(f\"\"\"\n", "# ------ DATA ------------\n", "{nnx.is_data( jnp.array(0) ) = } # Arrays are data\n", "{nnx.is_data( nnx.Param(1) ) = } # Variables are data\n", "{nnx.is_data( nnx.Rngs(2) ) = } # nnx.Pytrees are data\n", "\n", "# ------ STATIC ------------\n", "{nnx.is_data( 'hello' ) = } # strings, arbitrary objects\n", "{nnx.is_data( 42 ) = } # int, float, bool, complex, etc.\n", "{nnx.is_data( [1, 2.0, 3j, jnp.array(1)] ) = } # list, dict, tuple, regular pytrees\n", "\"\"\")" ] }, { "cell_type": "markdown", "id": "5ab06e79", "metadata": {}, "source": [ "### When to use explicit annotations?" ] }, { "cell_type": "markdown", "id": "9dead73d", "metadata": {}, "source": [ "There are cases were you do want to explicitely annotate the attributes to avoid ambiguity or protect yourself against possible edge cases. These include constraining input arguments which might have unexpected types, forcing attributes as data when their type is not treated as data by default, or using `nnx.static` as a way to assert the attribute should not contain data." ] }, { "cell_type": "code", "execution_count": 8, "id": "9e064461", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pytree structure:\n", " - pytree.bias.value = Array(0., dtype=float32, weak_type=True)\n", " - pytree.ls[0] = Array(0, dtype=int32, weak_type=True)\n", " - pytree.ls[1] = Array(1, dtype=int32, weak_type=True)\n", " - pytree.ls[2] = Array(2, dtype=int32, weak_type=True)\n", " - pytree.x = 1.0\n", " - pytree.y = 42\n" ] } ], "source": [ "class Bar(nnx.Pytree):\n", " def __init__(self, x, use_bias: bool):\n", " self.x = nnx.data(x) # constrain inputs (e.g. user could pass Array or float)\n", " self.y = nnx.data(42) # force types that are not data by default\n", " self.ls = nnx.List([jnp.array(i) for i in range(3)]) # use nnx.List for lists of data\n", " self.bias = nnx.data(None) # optional values that can be data\n", " if use_bias:\n", " self.bias = nnx.Param(jnp.array(0.0))\n", "\n", "pytree = Bar(1.0, True)\n", "pytree_structure(pytree)" ] }, { "cell_type": "markdown", "id": "8055b72c", "metadata": {}, "source": [ "### Dataclasses\n", "`nnx.Pytree` dataclasses can be created by using the `nnx.dataclass` decorator. To control the status of each field, `nnx.static` and `nnx.data` can be used as `field` specifiers." ] }, { "cell_type": "code", "execution_count": 9, "id": "9ca77d65", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pytree structure:\n", " - pytree.ls[0].i = 0\n", " - pytree.ls[0].x = Array(0, dtype=int32, weak_type=True)\n", " - pytree.ls[1].i = 1\n", " - pytree.ls[1].x = Array(42, dtype=int32, weak_type=True)\n" ] } ], "source": [ "import dataclasses\n", "\n", "@nnx.dataclass\n", "class Foo(nnx.Pytree):\n", " i: int = nnx.data()\n", " x: jax.Array\n", " a: int\n", " s: str = nnx.static(default='hi', kw_only=True)\n", "\n", "@nnx.dataclass\n", "class Bar(nnx.Pytree):\n", " ls: list[Foo] = nnx.data()\n", " shapes: list[int]\n", "\n", "pytree = Bar(\n", " ls=[Foo(i, jnp.array(42 * i), hash(i)) for i in range(2)],\n", " shapes=[8, 16, 32]\n", ")\n", "pytree_structure(pytree)" ] }, { "cell_type": "markdown", "id": "fca51f65", "metadata": {}, "source": [ "`dataclasses.dataclass` can also be used directly, however type checkers will not handle `nnx.static` and `nnx.data` correctly. To solve this `dataclasses.field` can be used by setting `metadata` with the appropriate entry for `static`." ] }, { "cell_type": "code", "execution_count": 10, "id": "ff54e732", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dataclass pytree structure:\n", " - pytree.a = 10\n" ] } ], "source": [ "@dataclasses.dataclass\n", "class Bar(nnx.Pytree):\n", " a: int = dataclasses.field(metadata={'static': False}) # data\n", " b: str = dataclasses.field(metadata={'static': True}) # static\n", "\n", "pytree = Bar(a=10, b=\"hello\")\n", "pytree_structure(pytree, title='dataclass pytree structure')" ] }, { "cell_type": "markdown", "id": "d6036a0e", "metadata": {}, "source": [ "### Attribute Updates" ] }, { "cell_type": "markdown", "id": "4e0a1f63", "metadata": {}, "source": [ "The status of an attribute is defined during its first assignment and will not change upon reassignment. However, it is possible to override the status by explicitly using `nnx.data` or `nnx.static` on reassignment." ] }, { "cell_type": "code", "execution_count": 11, "id": "509a517e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "original:\n", " - pytree.a = Array(1., dtype=float32, weak_type=True)\n", " - pytree.c = 3.14\n", "updated:\n", " - pytree.a = '🤔'\n", " - pytree.b = 42\n" ] } ], "source": [ "class Foo(nnx.Pytree):\n", " def __init__(self):\n", " self.a = jnp.array(1.0) # data\n", " self.b = \"Hello, world!\" # static\n", " self.c = nnx.data(3.14) # data\n", "\n", "pytree = Foo()\n", "pytree_structure(pytree, \"original\")\n", "\n", "pytree.a = \"🤔\" # data status doesn't change\n", "pytree.b = nnx.data(42) # explicit annotation overrides status to data\n", "pytree.c = nnx.static(0.5) # explicit annotation overrides status to static\n", "pytree_structure(pytree, \"updated\")" ] }, { "cell_type": "markdown", "id": "17fd41f5", "metadata": {}, "source": [ "### Attribute checks\n", "`Pytree` has a variety of checks to prevent a common class of errors in JAX. This includes checking for Arrays being assigned to new `static` attributes:" ] }, { "cell_type": "code", "execution_count": 12, "id": "98e04ff9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ValueError: Found Arrays in value of type '' annotated with nnx.static(...) when setting attribute 'name' of Pytree type ''.\n" ] } ], "source": [ "class Foo(nnx.Pytree):\n", " def __init__(self, name):\n", " self.name = nnx.static(name)\n", "\n", "try:\n", " foo = Foo(name=jnp.array(123))\n", "except ValueError as e:\n", " print(\"ValueError:\", e)" ] }, { "cell_type": "markdown", "id": "070a07f0", "metadata": {}, "source": [ "Checking for Arrays being assigned to known `static` attributes:" ] }, { "cell_type": "code", "execution_count": 13, "id": "c864d5b1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ValueError: Cannot assign data value of type '' to static attribute 'name' of Pytree type ''. To override the status explicitly wrap the value with nnx.data on assignment:\n", "\n", " _.name = nnx.data(...)\n", "\n", "\n" ] } ], "source": [ "try:\n", " foo = Foo(name=\"mattjj\")\n", " foo.name = jnp.array(123)\n", "except ValueError as e:\n", " print(\"ValueError:\", e)" ] }, { "cell_type": "markdown", "id": "c3b17f07", "metadata": {}, "source": [ "Checking for Arrays after `__init__` on `static` attributes that could've been inserted via mutation. This check can be manually trigger via `nnx.check_pytree` at any time." ] }, { "cell_type": "code", "execution_count": 14, "id": "628698dd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ValueError: Found unexpected Arrays on value of type in static attribute 'ls' of Pytree type ''. This is an error starting from Flax version 0.12.0.\n", "Consider one of the following options:\n", "\n", "1. If the attribute is meant to be static, either remove the Array value or wrap it in a static container.\n", "2. Wrap the value with nnx.data on assignment:\n", "\n", " _.ls = nnx.data(...)\n", "\n", "3. Annotate the class attribute with nnx.Data:\n", "\n", " class Foo(Pytree):\n", " ls: nnx.Data[list]\n", "\n", "4. If the container is a list or dict, try using nnx.List(...) or nnx.Dict(...) instead.\n", "5. Disable pytree for this class:\n", "\n", " class Foo(Pytree, pytree=False):\n", "\n", "\n" ] } ], "source": [ "class Foo(nnx.Pytree):\n", " def __init__(self):\n", " self.ls = [] # treated as static\n", " for i in range(5):\n", " self.ls.append(jnp.array(i)) # append arrays into static attribute\n", "\n", "try:\n", " foo = Foo() # error: Array found in static attribute after `__init__`\n", "except ValueError as e:\n", " print(\"ValueError:\", e)" ] }, { "cell_type": "markdown", "id": "37ee2429", "metadata": {}, "source": [ "Checking for `nnx.data` or `nnx.static` annotations stored inside nested structures that are not `nnx.Pytree` instances:" ] }, { "cell_type": "code", "execution_count": 15, "id": "f9d69634", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ValueError: Found unexpected tags {'static', 'data'} on attribute 'Foo.a'. Values from nnx.data(...) and\n", "nnx.static(...) should be assigned to nnx.Pytree attributes directly, they should not be inside other structures. Got value of type '' on Pytree of type ''.\n" ] } ], "source": [ "class Foo(nnx.Pytree):\n", " def __init__(self):\n", " self.a = [nnx.data(1), nnx.static(2)] # annotations in sub-pytree\n", "\n", "try:\n", " foo = Foo()\n", "except ValueError as e:\n", " print(\"ValueError:\", e)" ] }, { "cell_type": "markdown", "id": "452915ac", "metadata": {}, "source": [ "### Trace-level awareness\n", "To prevent tracer leakage NNX will raise an error when trying to update the attribute of a `Pytree` or the value of a `Variable` on instances that are passed as captures to functions called by JAX transforms:" ] }, { "cell_type": "code", "execution_count": 16, "id": "668db479", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Error: Cannot mutate 'Foo' from different trace level (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.TraceContextError)\n" ] } ], "source": [ "class Foo(nnx.Pytree):\n", " def __init__(self):\n", " self.count = nnx.data(0)\n", "\n", "foo = Foo()\n", "\n", "@jax.vmap # or jit, grad, shard_map, pmap, scan, etc.\n", "def increment(n):\n", " # foo passed as capture\n", " foo.count += 1 # error!\n", "\n", "try:\n", " increment(jnp.arange(5))\n", "except Exception as e:\n", " print(f\"Error: {e}\")" ] }, { "cell_type": "markdown", "id": "536d77b0", "metadata": {}, "source": [ "### Reference Sharing" ] }, { "cell_type": "markdown", "id": "2449a3c5", "metadata": {}, "source": [ "As the name implies Pytrees should be trees. To check if a structure is a well-defined tree you can use the `nnx.find_duplicates` functions which will return a list of duplicates, where each duplicate is a list of path tuples. In the example below we see that `left` and `right` are shared references therefore `find_duplicates` returns a non-empty list with the paths:" ] }, { "cell_type": "code", "execution_count": 17, "id": "32c46ce8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "nnx.find_duplicates(m) = [[('left',), ('right',)]] # not a tree\n" ] } ], "source": [ "class Shared(nnx.Pytree):\n", " def __init__(self):\n", " self.x = jnp.array(1.0)\n", "\n", "class Parent(nnx.Pytree):\n", " def __init__(self):\n", " self.left = Shared()\n", " self.right = self.left # reference sharing\n", "\n", "m = Parent()\n", "\n", "print(f\"{nnx.find_duplicates(m) = } # not a tree\")" ] }, { "cell_type": "markdown", "id": "49b2e5f0", "metadata": {}, "source": [ "The main issue is that sharing is not preserved across pytree operations including JAX transforms, and this results in unintended state duplication:" ] }, { "cell_type": "code", "execution_count": 18, "id": "c33e4862", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Before: m.left is m.right = True\n", "Inside: m.left is m.right = False\n", "After: m.left is m.right = False\n" ] } ], "source": [ "m = Parent()\n", "print(f\"Before: {m.left is m.right = }\")\n", "\n", "@jax.jit\n", "def f(m):\n", " print(f\"Inside: {m.left is m.right = }\")\n", " return m\n", "\n", "m = f(m)\n", "print(f\"After: {m.left is m.right = }\")" ] }, { "cell_type": "markdown", "id": "d60953f4", "metadata": {}, "source": [ "Reference sharing is rare in most Machine Learning applications, however if it is required you can either use the `nnx.{split, merge, state, update}` APIs to move the deduplicated state and graph definiton across the JAX transforms:" ] }, { "cell_type": "code", "execution_count": 19, "id": "dda51b67", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Before: m.left is m.right = True\n", "Inside: m.left is m.right = True\n", "After: m.left is m.right = True\n", "state = State({\n", " 'left': {\n", " 'x': Array(1., dtype=float32, weak_type=True)\n", " }\n", "})\n" ] } ], "source": [ "m = Parent()\n", "print(f\"Before: {m.left is m.right = }\")\n", "graphdef, state = nnx.split(m)\n", "\n", "@jax.jit\n", "def f(graphdef, state):\n", " m = nnx.merge(graphdef, state)\n", " print(f\"Inside: {m.left is m.right = }\")\n", " return nnx.state(m)\n", "\n", "state = f(graphdef, state)\n", "nnx.update(m, state)\n", "\n", "print(f\"After: {m.left is m.right = }\")\n", "print(f\"{state = }\") # deduplicated state" ] }, { "cell_type": "markdown", "id": "0afee781", "metadata": {}, "source": [ "Or alternatively you can use the NNX transforms which preserve shared references:" ] }, { "cell_type": "code", "execution_count": 20, "id": "caa01e3b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Before: m.left is m.right = True\n", "Inside: m.left is m.right = True\n", "After: m.left is m.right = True\n" ] } ], "source": [ "m = Parent()\n", "print(f\"Before: {m.left is m.right = }\")\n", "\n", "@nnx.jit\n", "def f(m):\n", " print(f\"Inside: {m.left is m.right = }\")\n", " return m\n", "\n", "m = f(m)\n", "\n", "print(f\"After: {m.left is m.right = }\")" ] }, { "cell_type": "markdown", "id": "734aa5b3", "metadata": {}, "source": [ "### Turning off pytree registration\n", "`nnx.Pytree` allows you to turn off the pytree registration along with the attribute checks for subtypes by setting `pytree` type attribute option to `False`. This can be useful when upgrading to previous NNX code to newer Flax verions as you will still be able to use the NNX APIs or when creating types that should not be treated as pytree because e.g. they shared references." ] }, { "cell_type": "code", "execution_count": 21, "id": "d2e03753", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " nnx.state(foo) = State({\n", " 'a': {\n", " 0: Array(2, dtype=int32, weak_type=True),\n", " 1: Array(4, dtype=int32, weak_type=True)\n", " },\n", " 'b': Array(6, dtype=int32, weak_type=True)\n", "})\n", " jax.tree_util.all_leaves([foo]) = True\n" ] } ], "source": [ "class Foo(nnx.Pytree, pytree=False):\n", " def __init__(self):\n", " self.a = [jnp.array(1), jnp.array(2)] # no checks\n", " self.b = \"hello\" \n", " self.b = jnp.array(3) # no checks\n", "\n", "foo = Foo()\n", "\n", "@nnx.jit # can use in NNX transformations\n", "def double(foo: Foo):\n", " foo.a = [x * 2 for x in foo.a]\n", " foo.b *= 2\n", "\n", "double(foo)\n", "print(f\"{ nnx.state(foo) = }\") # can be used with NNX APIs\n", "print(f\"{ jax.tree_util.all_leaves([foo]) = }\") # not a pytree" ] }, { "cell_type": "markdown", "id": "bfaf17de", "metadata": {}, "source": [ "## Module" ] }, { "cell_type": "markdown", "id": "67000b88", "metadata": {}, "source": [ "NNX Modules are `Pytree`s that have two additional methods for traking intermediate values: `sow` and `perturb`." ] }, { "cell_type": "markdown", "id": "cc5afc70", "metadata": {}, "source": [ "### sow\n", "`sow` receives a `Variable` type, a `name`, and a `value`, and stores it in the `Module` so it can be retrieved at a later time. As the following example shows, NNX APIs such as `nnx.state` or `nnx.pop` are a good way of retrieving the sowed state, however `pop` is recommended because it explicitly removes the temporary state from the Module." ] }, { "cell_type": "code", "execution_count": 22, "id": "ca9f58a2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'blocks'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m0\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'y_mean'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mIntermediate\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0mArray(4.659754e-06, dtype=float32),\u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m1\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'y_mean'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mIntermediate\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0mArray(0.00025933, dtype=float32),\u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254m2\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'y_mean'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mIntermediate\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0mArray(0.05561922, dtype=float32),\u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", " \u001b[38;2;255;213;3m}\u001b[0m\n", "\u001b[38;2;255;213;3m})\u001b[0m\n" ] } ], "source": [ "class Block(nnx.Module):\n", " def __init__(self, din: int, dout: int, rngs: nnx.Rngs):\n", " self.linear = nnx.Linear(din, dout, rngs=rngs)\n", " self.bn = nnx.BatchNorm(dout, rngs=rngs)\n", " self.dropout = nnx.Dropout(0.1, rngs=rngs)\n", "\n", " def __call__(self, x):\n", " y = nnx.relu(self.dropout(self.bn(self.linear(x))))\n", " self.sow(nnx.Intermediate, \"y_mean\", jnp.mean(y))\n", " return y\n", "\n", "class MLP(nnx.Module):\n", " def __init__(self, num_layers, dim, rngs: nnx.Rngs):\n", " self.blocks = nnx.List([Block(dim, dim, rngs) for _ in range(num_layers)])\n", "\n", " def __call__(self, x):\n", " for block in self.blocks:\n", " x = block(x)\n", " return x\n", "\n", "\n", "model = MLP(num_layers=3, dim=20, rngs=nnx.Rngs(0))\n", "x = jnp.ones((10, 20))\n", "y = model(x)\n", "intermediates = nnx.pop(model, nnx.Intermediate) # extract intermediate values\n", "print(intermediates)" ] }, { "cell_type": "markdown", "id": "6e266e5f", "metadata": {}, "source": [ "### perturb\n", "`perturb` is similar to `sow` but it aims to capture the gradient of a value, currently this is a two step process although it might be simplified in the future:\n", "1. Initialize the pertubation state by running the model once.\n", "2. Pass the perturbation state as a differentiable target to `grad`.\n", "\n", "As an example lets create a simple model and use `perturb` to get the intermediate gradient `xgrad` for the variable `x`, and initialize the perturbations:" ] }, { "cell_type": "code", "execution_count": 23, "id": "41398e14", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "nnx.state(model, nnx.Perturbation) = \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'xgrad'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mPerturbation\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[0., 0., 0.]], dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", "\u001b[38;2;255;213;3m})\u001b[0m\n" ] } ], "source": [ "import optax\n", "\n", "class Model(nnx.Module):\n", " def __init__(self, rngs):\n", " self.linear1 = nnx.Linear(2, 3, rngs=rngs)\n", " self.linear2 = nnx.Linear(3, 4, rngs=rngs)\n", " def __call__(self, x):\n", " x = nnx.gelu(self.linear1(x))\n", " x = self.perturb('xgrad', x)\n", " x = self.linear2(x)\n", " return x\n", "\n", "rngs = nnx.Rngs(0)\n", "model = Model(rngs)\n", "optimizer = nnx.Optimizer(model, tx=optax.sgd(1e-1), wrt=nnx.Param)\n", "x, y = rngs.uniform((1, 2)), rngs.uniform((1, 4))\n", "_ = model(x) # initialize perturbations\n", "print(f\"{nnx.state(model, nnx.Perturbation) = !s}\")" ] }, { "cell_type": "markdown", "id": "c9221005", "metadata": {}, "source": [ "Next we'll create a training step function that differentiates w.r.t. both the parameters of the model and the perturbations, the later will be the gradients for the intermediate values. `nnx.jit` and `nnx.value_and_grad` will be use to automatically propagate state updates. We'll return the `loss` function and the itermediate gradients." ] }, { "cell_type": "code", "execution_count": 24, "id": "d10effba", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "step = 0, loss = Array(0.7326511, dtype=float32), iterm_grads = \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'xgrad'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mPerturbation\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[-0.430146 , -0.14356601, 0.2935633 ]], dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", "\u001b[38;2;255;213;3m})\u001b[0m\n", "step = 1, loss = Array(0.65039134, dtype=float32), iterm_grads = \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'xgrad'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mPerturbation\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[-0.38535568, -0.11745065, 0.24441527]], dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", "\u001b[38;2;255;213;3m})\u001b[0m\n" ] } ], "source": [ "@nnx.jit\n", "def train_step(model, optimizer, x, y):\n", " graphdef, params, perturbations = nnx.split(model, nnx.Param, nnx.Perturbation)\n", "\n", " def loss_fn(params, perturbations):\n", " model = nnx.merge(graphdef, params, perturbations)\n", " return jnp.mean((model(x) - y) ** 2)\n", "\n", " loss, (grads, iterm_grads) = nnx.value_and_grad(loss_fn, argnums=(0, 1))(params, perturbations)\n", " optimizer.update(model, grads)\n", "\n", " return loss, iterm_grads\n", "\n", "for step in range(2):\n", " loss, iterm_grads = train_step(model, optimizer, x, y)\n", " print(f\"{step = }, {loss = }, {iterm_grads = !s}\")" ] }, { "cell_type": "markdown", "id": "d8511c3d", "metadata": {}, "source": [ "## Object" ] }, { "cell_type": "markdown", "id": "1c2fd61c", "metadata": {}, "source": [ "`Object` are NNX types that are **not** registered as JAX pytrees. Formally, any `Object` subclass is a `nnx.Pytree` with `pytree=False`." ] }, { "cell_type": "code", "execution_count": 25, "id": "a9cab639", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " nnx.state(foo) = State({\n", " 'a': {\n", " 0: Array(2, dtype=int32, weak_type=True),\n", " 1: Array(4, dtype=int32, weak_type=True)\n", " },\n", " 'b': Array(6, dtype=int32, weak_type=True)\n", "})\n", " jax.tree_util.all_leaves([foo]) = True\n" ] } ], "source": [ "class Foo(nnx.Object): # instead of Foo(nnx.Pytree, pytree=False)\n", " def __init__(self):\n", " self.a = [jnp.array(1), jnp.array(2)] # no checks\n", " self.b = \"hello\" \n", " self.b = jnp.array(3) # no checks\n", "\n", "foo = Foo()\n", "\n", "@nnx.jit # can use in NNX transformations\n", "def double(foo: Foo):\n", " foo.a = [x * 2 for x in foo.a]\n", " foo.b *= 2\n", "\n", "double(foo)\n", "print(f\"{ nnx.state(foo) = }\") # can be used with NNX APIs\n", "print(f\"{ jax.tree_util.all_leaves([foo]) = }\") # not a pytree" ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: docs_nnx/guides/pytree.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # Module & Pytree +++ Flax NNX's Modules are by default registered as JAX Pytrees, this allows using them throughout most of JAX APIs but in particular JAX transforms and the `jax.tree.*` functions. Thanks to the pytree protocol a simple NNX program might look like this: ```{code-cell} ipython3 from flax import nnx import jax import jax.numpy as jnp class Linear(nnx.Module): def __init__(self, din, dout, rngs: nnx.Rngs): self.din, self.dout = din, dout self.kernel = nnx.Param(rngs.normal((din, dout))) rngs = nnx.Rngs(0) weights = Linear(2, 3, rngs=rngs) @jax.jit def forward(weights, x): return x @ weights.kernel y = forward(weights, x=rngs.uniform((5, 2))) print(f"{y.shape = }") ``` Here `weights`, of type `Linear`, was able to be passed directly to the `jit`-ed function `forward`. Throughout the rest of this guide we will try to answer the questions: 1. What are pytrees? 2. How does NNX implement pytrees? +++ ## Pytrees 101 Most modern ML models have too many Arrays for users to pass around individually, to deal with this JAX developed a way to track Array data in nested structures that still allowed caching for compilation: Pytrees. JAX pytrees are tree structures made of python objects that can be recursively traversed in order to collect an ordered list of leaves and a definition of the tree structure, this is done via the `jax.tree.flatten` function. Most common pytrees are native python containers like `list`, `dict`, and `tuple`, but interestingly it also include `None`. The example bellow shows how to collect all the integer leaves from a nested structure using `flatten`: ```{code-cell} ipython3 pytree = [ {'a': 1}, { 'b': 2, 'c': (3, 4), 'd': None, } ] leaves, treedef = jax.tree.flatten(pytree) print(f"leaves = {leaves}") print(f"treedef = {treedef}") ``` Note that `None` is not a leaf because its defined as a pytree with no children. The main purpose of being able to flatten, apart from collecting the leaves, is being able reconstruct the pytree structure from the tree definition from any sequence of leaves of the same length via the `jax.tree.unflatten` function: ```{code-cell} ipython3 new_leaves = [x * 10 for x in leaves] new_pytree = jax.tree.unflatten(treedef, new_leaves) print(f"old pytree = {pytree}") print(f"new pytree = {new_pytree}") ``` ### Custom Pytrees JAX allows us to register custom pytree node type by using the `jax.tree_util.register_pytree_node` utility. For any type we are able to define a flatten that decomposes the object into a a sequence of nodes / children and a static (hashable) structure, and a unflatten function which takes the sequence of nodes and the static structure and creates a new instance. In the following example we create a simple type `Foo` with the attributes `a`, `b`, and `c`, and define `a` and `b` as nodes and `c` as static. ```{code-cell} ipython3 class Foo: def __init__(self): self.a = 1 self.b = 2 self.c = "hi" def flatten_foo(foo: Foo): nodes = [foo.a, foo.b] # sequence of nodes static = (foo.c,) # hashable & equatable structure return nodes, static def unflatten_foo(static, nodes): foo = object.__new__(Foo) # create uninitialized instance foo.a = nodes[0] foo.b = nodes[1] foo.c = static[0] return foo jax.tree_util.register_pytree_node(Foo, flatten_foo, unflatten_foo) foo = Foo() leaves, treedef = jax.tree.flatten(foo) print(f"leaves = {leaves}") print(f"treedef = {treedef}") ``` Notice that `'hi'` does not appear in the leaves because `c` is defined as static, but you can see it as part of the `PyTreeDef` structure. +++ ## nnx.Pytree In general it would be cumbersome for users to manually register the pytree definition for every type they create. To automate this process NNX provides the `nnx.Pytree` base type that offers a simple API: users annotate attributes using either `nnx.static` or `nnx.data`, and Pytree will register some flatten and unflatten functions that will take the annotations into account. The `nnx.data` and `nnx.static` annotations must only be assigned to `Pytree` attributes directly. ```{code-cell} ipython3 class Linear(nnx.Pytree): def __init__(self, din: int, dout: int): self.din = nnx.static(din) self.dout = nnx.static(dout) self.w = nnx.data(jnp.ones((din, dout))) self.b = nnx.data(jnp.zeros((dout,))) class MLP(nnx.Pytree): def __init__(self, num_layers, dim): self.num_layers = nnx.static(num_layers) self.layers = nnx.data([ Linear(dim, dim) for _ in range(num_layers) ]) pytree = MLP(num_layers=2, dim=1) def pytree_structure(pytree, title='pytree structure'): print(f"{title}:") path_leaves, treedef = jax.tree.flatten_with_path(pytree) for path, value in path_leaves: print(f" - pytree{jax.tree_util.keystr(path)} = {value!r}") pytree_structure(pytree) ``` As you can see above, only the `data` paths appear in the leaves. However, its very verbose to have to define `static` and `data` for each attribute, so `Pytree` has sensible defaults. You can remove most of them and it will just work: ```{code-cell} ipython3 class Linear(nnx.Pytree): def __init__(self, din: int, dout: int): self.din = din # static self.dout = dout # static self.w = jnp.ones((din, dout)) # data self.b = jnp.zeros((dout,)) # data class MLP(nnx.Pytree): def __init__(self, num_layers, dim): self.num_layers = num_layers # static self.layers = nnx.List([ # data Linear(dim, dim) for _ in range(num_layers) ]) pytree = MLP(num_layers=2, dim=1) pytree_structure(pytree) ``` The only change we had to do here is use `nnx.List` to signal that `layers` contains `data`, the status of the rest of the attributes can be correctly inferred. The rules that determine if a value is data or not are the following: * `Array`s, `Variable`s, `ArrayRef`s, and `nnx.Pytree`s are data. * Types registered using `nnx.register_data_type` are data. * All other types are static. To check if a value is data use the `nnx.is_data` function which will return its status: ```{code-cell} ipython3 print(f""" # ------ DATA ------------ {nnx.is_data( jnp.array(0) ) = } # Arrays are data {nnx.is_data( nnx.Param(1) ) = } # Variables are data {nnx.is_data( nnx.Rngs(2) ) = } # nnx.Pytrees are data # ------ STATIC ------------ {nnx.is_data( 'hello' ) = } # strings, arbitrary objects {nnx.is_data( 42 ) = } # int, float, bool, complex, etc. {nnx.is_data( [1, 2.0, 3j, jnp.array(1)] ) = } # list, dict, tuple, regular pytrees """) ``` ### When to use explicit annotations? +++ There are cases were you do want to explicitely annotate the attributes to avoid ambiguity or protect yourself against possible edge cases. These include constraining input arguments which might have unexpected types, forcing attributes as data when their type is not treated as data by default, or using `nnx.static` as a way to assert the attribute should not contain data. ```{code-cell} ipython3 class Bar(nnx.Pytree): def __init__(self, x, use_bias: bool): self.x = nnx.data(x) # constrain inputs (e.g. user could pass Array or float) self.y = nnx.data(42) # force types that are not data by default self.ls = nnx.List([jnp.array(i) for i in range(3)]) # use nnx.List for lists of data self.bias = nnx.data(None) # optional values that can be data if use_bias: self.bias = nnx.Param(jnp.array(0.0)) pytree = Bar(1.0, True) pytree_structure(pytree) ``` ### Dataclasses `nnx.Pytree` dataclasses can be created by using the `nnx.dataclass` decorator. To control the status of each field, `nnx.static` and `nnx.data` can be used as `field` specifiers. ```{code-cell} ipython3 import dataclasses @nnx.dataclass class Foo(nnx.Pytree): i: int = nnx.data() x: jax.Array a: int s: str = nnx.static(default='hi', kw_only=True) @nnx.dataclass class Bar(nnx.Pytree): ls: list[Foo] = nnx.data() shapes: list[int] pytree = Bar( ls=[Foo(i, jnp.array(42 * i), hash(i)) for i in range(2)], shapes=[8, 16, 32] ) pytree_structure(pytree) ``` `dataclasses.dataclass` can also be used directly, however type checkers will not handle `nnx.static` and `nnx.data` correctly. To solve this `dataclasses.field` can be used by setting `metadata` with the appropriate entry for `static`. ```{code-cell} ipython3 @dataclasses.dataclass class Bar(nnx.Pytree): a: int = dataclasses.field(metadata={'static': False}) # data b: str = dataclasses.field(metadata={'static': True}) # static pytree = Bar(a=10, b="hello") pytree_structure(pytree, title='dataclass pytree structure') ``` ### Attribute Updates +++ The status of an attribute is defined during its first assignment and will not change upon reassignment. However, it is possible to override the status by explicitly using `nnx.data` or `nnx.static` on reassignment. ```{code-cell} ipython3 class Foo(nnx.Pytree): def __init__(self): self.a = jnp.array(1.0) # data self.b = "Hello, world!" # static self.c = nnx.data(3.14) # data pytree = Foo() pytree_structure(pytree, "original") pytree.a = "🤔" # data status doesn't change pytree.b = nnx.data(42) # explicit annotation overrides status to data pytree.c = nnx.static(0.5) # explicit annotation overrides status to static pytree_structure(pytree, "updated") ``` ### Attribute checks `Pytree` has a variety of checks to prevent a common class of errors in JAX. This includes checking for Arrays being assigned to new `static` attributes: ```{code-cell} ipython3 class Foo(nnx.Pytree): def __init__(self, name): self.name = nnx.static(name) try: foo = Foo(name=jnp.array(123)) except ValueError as e: print("ValueError:", e) ``` Checking for Arrays being assigned to known `static` attributes: ```{code-cell} ipython3 try: foo = Foo(name="mattjj") foo.name = jnp.array(123) except ValueError as e: print("ValueError:", e) ``` Checking for Arrays after `__init__` on `static` attributes that could've been inserted via mutation. This check can be manually trigger via `nnx.check_pytree` at any time. ```{code-cell} ipython3 class Foo(nnx.Pytree): def __init__(self): self.ls = [] # treated as static for i in range(5): self.ls.append(jnp.array(i)) # append arrays into static attribute try: foo = Foo() # error: Array found in static attribute after `__init__` except ValueError as e: print("ValueError:", e) ``` Checking for `nnx.data` or `nnx.static` annotations stored inside nested structures that are not `nnx.Pytree` instances: ```{code-cell} ipython3 class Foo(nnx.Pytree): def __init__(self): self.a = [nnx.data(1), nnx.static(2)] # annotations in sub-pytree try: foo = Foo() except ValueError as e: print("ValueError:", e) ``` ### Trace-level awareness To prevent tracer leakage NNX will raise an error when trying to update the attribute of a `Pytree` or the value of a `Variable` on instances that are passed as captures to functions called by JAX transforms: ```{code-cell} ipython3 class Foo(nnx.Pytree): def __init__(self): self.count = nnx.data(0) foo = Foo() @jax.vmap # or jit, grad, shard_map, pmap, scan, etc. def increment(n): # foo passed as capture foo.count += 1 # error! try: increment(jnp.arange(5)) except Exception as e: print(f"Error: {e}") ``` ### Reference Sharing +++ As the name implies Pytrees should be trees. To check if a structure is a well-defined tree you can use the `nnx.find_duplicates` functions which will return a list of duplicates, where each duplicate is a list of path tuples. In the example below we see that `left` and `right` are shared references therefore `find_duplicates` returns a non-empty list with the paths: ```{code-cell} ipython3 class Shared(nnx.Pytree): def __init__(self): self.x = jnp.array(1.0) class Parent(nnx.Pytree): def __init__(self): self.left = Shared() self.right = self.left # reference sharing m = Parent() print(f"{nnx.find_duplicates(m) = } # not a tree") ``` The main issue is that sharing is not preserved across pytree operations including JAX transforms, and this results in unintended state duplication: ```{code-cell} ipython3 m = Parent() print(f"Before: {m.left is m.right = }") @jax.jit def f(m): print(f"Inside: {m.left is m.right = }") return m m = f(m) print(f"After: {m.left is m.right = }") ``` Reference sharing is rare in most Machine Learning applications, however if it is required you can either use the `nnx.{split, merge, state, update}` APIs to move the deduplicated state and graph definiton across the JAX transforms: ```{code-cell} ipython3 m = Parent() print(f"Before: {m.left is m.right = }") graphdef, state = nnx.split(m) @jax.jit def f(graphdef, state): m = nnx.merge(graphdef, state) print(f"Inside: {m.left is m.right = }") return nnx.state(m) state = f(graphdef, state) nnx.update(m, state) print(f"After: {m.left is m.right = }") print(f"{state = }") # deduplicated state ``` Or alternatively you can use the NNX transforms which preserve shared references: ```{code-cell} ipython3 m = Parent() print(f"Before: {m.left is m.right = }") @nnx.jit def f(m): print(f"Inside: {m.left is m.right = }") return m m = f(m) print(f"After: {m.left is m.right = }") ``` ### Turning off pytree registration `nnx.Pytree` allows you to turn off the pytree registration along with the attribute checks for subtypes by setting `pytree` type attribute option to `False`. This can be useful when upgrading to previous NNX code to newer Flax verions as you will still be able to use the NNX APIs or when creating types that should not be treated as pytree because e.g. they shared references. ```{code-cell} ipython3 class Foo(nnx.Pytree, pytree=False): def __init__(self): self.a = [jnp.array(1), jnp.array(2)] # no checks self.b = "hello" self.b = jnp.array(3) # no checks foo = Foo() @nnx.jit # can use in NNX transformations def double(foo: Foo): foo.a = [x * 2 for x in foo.a] foo.b *= 2 double(foo) print(f"{ nnx.state(foo) = }") # can be used with NNX APIs print(f"{ jax.tree_util.all_leaves([foo]) = }") # not a pytree ``` ## Module +++ NNX Modules are `Pytree`s that have two additional methods for traking intermediate values: `sow` and `perturb`. +++ ### sow `sow` receives a `Variable` type, a `name`, and a `value`, and stores it in the `Module` so it can be retrieved at a later time. As the following example shows, NNX APIs such as `nnx.state` or `nnx.pop` are a good way of retrieving the sowed state, however `pop` is recommended because it explicitly removes the temporary state from the Module. ```{code-cell} ipython3 class Block(nnx.Module): def __init__(self, din: int, dout: int, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.bn = nnx.BatchNorm(dout, rngs=rngs) self.dropout = nnx.Dropout(0.1, rngs=rngs) def __call__(self, x): y = nnx.relu(self.dropout(self.bn(self.linear(x)))) self.sow(nnx.Intermediate, "y_mean", jnp.mean(y)) return y class MLP(nnx.Module): def __init__(self, num_layers, dim, rngs: nnx.Rngs): self.blocks = nnx.List([Block(dim, dim, rngs) for _ in range(num_layers)]) def __call__(self, x): for block in self.blocks: x = block(x) return x model = MLP(num_layers=3, dim=20, rngs=nnx.Rngs(0)) x = jnp.ones((10, 20)) y = model(x) intermediates = nnx.pop(model, nnx.Intermediate) # extract intermediate values print(intermediates) ``` ### perturb `perturb` is similar to `sow` but it aims to capture the gradient of a value, currently this is a two step process although it might be simplified in the future: 1. Initialize the pertubation state by running the model once. 2. Pass the perturbation state as a differentiable target to `grad`. As an example lets create a simple model and use `perturb` to get the intermediate gradient `xgrad` for the variable `x`, and initialize the perturbations: ```{code-cell} ipython3 import optax class Model(nnx.Module): def __init__(self, rngs): self.linear1 = nnx.Linear(2, 3, rngs=rngs) self.linear2 = nnx.Linear(3, 4, rngs=rngs) def __call__(self, x): x = nnx.gelu(self.linear1(x)) x = self.perturb('xgrad', x) x = self.linear2(x) return x rngs = nnx.Rngs(0) model = Model(rngs) optimizer = nnx.Optimizer(model, tx=optax.sgd(1e-1), wrt=nnx.Param) x, y = rngs.uniform((1, 2)), rngs.uniform((1, 4)) _ = model(x) # initialize perturbations print(f"{nnx.state(model, nnx.Perturbation) = !s}") ``` Next we'll create a training step function that differentiates w.r.t. both the parameters of the model and the perturbations, the later will be the gradients for the intermediate values. `nnx.jit` and `nnx.value_and_grad` will be use to automatically propagate state updates. We'll return the `loss` function and the itermediate gradients. ```{code-cell} ipython3 @nnx.jit def train_step(model, optimizer, x, y): graphdef, params, perturbations = nnx.split(model, nnx.Param, nnx.Perturbation) def loss_fn(params, perturbations): model = nnx.merge(graphdef, params, perturbations) return jnp.mean((model(x) - y) ** 2) loss, (grads, iterm_grads) = nnx.value_and_grad(loss_fn, argnums=(0, 1))(params, perturbations) optimizer.update(model, grads) return loss, iterm_grads for step in range(2): loss, iterm_grads = train_step(model, optimizer, x, y) print(f"{step = }, {loss = }, {iterm_grads = !s}") ``` ## Object +++ `Object` are NNX types that are **not** registered as JAX pytrees. Formally, any `Object` subclass is a `nnx.Pytree` with `pytree=False`. ```{code-cell} ipython3 class Foo(nnx.Object): # instead of Foo(nnx.Pytree, pytree=False) def __init__(self): self.a = [jnp.array(1), jnp.array(2)] # no checks self.b = "hello" self.b = jnp.array(3) # no checks foo = Foo() @nnx.jit # can use in NNX transformations def double(foo: Foo): foo.a = [x * 2 for x in foo.a] foo.b *= 2 double(foo) print(f"{ nnx.state(foo) = }") # can be used with NNX APIs print(f"{ jax.tree_util.all_leaves([foo]) = }") # not a pytree ``` ================================================ FILE: docs_nnx/guides/randomness.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Randomness\n", "\n", "Flax NNX uses the stateful `nnx.Rngs` class to simplify Jax's handling of random states. For example, the code below uses a `nnx.Rngs` object to define a simple linear model with dropout:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from flax import nnx\n", "\n", "class Model(nnx.Module):\n", " def __init__(self, *, rngs: nnx.Rngs):\n", " self.linear = nnx.Linear(20, 10, rngs=rngs)\n", " self.drop = nnx.Dropout(0.1)\n", "\n", " def __call__(self, x, *, rngs):\n", " return nnx.relu(self.drop(self.linear(x), rngs=rngs))\n", "\n", "rngs = nnx.Rngs(0)\n", "model = Model(rngs=rngs) # pass rngs to initialize parameters\n", "x = rngs.normal((32, 20)) # convenient jax.random methods\n", "y = model(x, rngs=rngs) # pass rngs for dropout masks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We always pass `nnx.Rngs` objects to models at initialization (to initialize parameters). For models with nondeterministic outputs like the one above, we also pass `nnx.Rngs` objects to the model's `__call__` method.\n", "\n", "The Flax NNX [pseudorandom number generator (PRNG)](https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences) system has the following main characteristics:\n", "\n", "- It is **explicit**.\n", "- It is **order-based**.\n", "- It uses **dynamic counters**.\n", "\n", "> **Note:** To learn more about random number generation in JAX, the `jax.random` API, and PRNG-generated sequences, check out this [JAX PRNG tutorial](https://jax.readthedocs.io/en/latest/random-numbers.html).\n", "\n", "## `Rngs`, `RngStream`, and `RngState`\n", "\n", "In Flax NNX, the `nnx.Rngs` type is the primary convenience API for managing the random state(s). Following Flax Linen's footsteps, `nnx.Rngs` have the ability to create multiple named PRNG key [streams](https://jax.readthedocs.io/en/latest/jep/263-prng.html), each with its own state, for the purpose of having tight control over randomness in the context of [JAX transformations (transforms)](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations).\n", "\n", "Here are the main PRNG-related types in Flax NNX:\n", "\n", "* **`nnx.Rngs`**: The main user interface. It defines a set of named `nnx.RngStream` objects.\n", "* **`nnx.RngStream`**: An object that can generate a stream of PRNG keys. It holds a root `key` and a `count` inside an `nnx.RngKey` and `nnx.RngCount` `nnx.Variable`s, respectively. When a new key is generated, the count is incremented.\n", "* **`nnx.RngState`**: The base type for all RNG-related states.\n", " * **`nnx.RngKey`**: NNX Variable type for holding PRNG keys. It includes a `tag` attribute containing the name of the PRNG key stream.\n", " * **`nnx.RngCount`**: NNX Variable type for holding PRNG counts. It includes a `tag` attribute containing the PRNG key stream name.\n", "\n", "To create an `nnx.Rngs` object you can simply pass an integer seed or `jax.random.key` instance to any keyword argument of your choice in the constructor.\n", "\n", "Here's an example:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "rngs = nnx.Rngs(params=0, dropout=random.key(1))\n", "nnx.display(rngs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice that the `key` and `count` `nnx.Variable`s contain the PRNG key stream name in a `tag` attribute. This is primarily used for filtering as we'll see later.\n", "\n", "To generate new keys, you can access one of the streams and use its `__call__` method with no arguments. This will return a new key by using `random.fold_in` with the current `key` and `count`. The `count` is then incremented so that subsequent calls will return new keys." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "params_key = rngs.params()\n", "dropout_key = rngs.dropout()\n", "\n", "nnx.display(rngs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the `key` attribute does not change when new PRNG keys are generated." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using random state with flax Modules.\n", "\n", "Almost all flax Modules require a random state for initialization. In a `Linear` layer, for example, we need to sample the weights and biases from the appropriate Normal distribution. Random state is provided using the `rngs` keyword argument at initialization." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "linear = nnx.Linear(20, 10, rngs=rngs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Specifically, this will use the RngSteam `rngs.params` for weight initialization. The `params` stream is also used for initialization of `nnx.Conv`, `nnx.ConvTranspose`, and `nnx.Embed`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `nnx.Dropout` module also requires a random state, but it requires this state at *call* time rather than initialization. Once again, we can pass it random state using the `rngs` keyword argument." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "dropout = nnx.Dropout(0.5)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Array([2., 0., 2., 2.], dtype=float32)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import jax.numpy as jnp\n", "dropout(jnp.ones(4), rngs=rngs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `nnx.Dropout` layer will use the rng's `dropout` stream. This also applies to Modules that use `Dropout` as a sub-Module, like `nnx.MultiHeadAttention`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To summarize, there are only two standard PRNG key stream names used by Flax NNX's built-in layers, shown in the table below:\n", "\n", "| PRNG key stream name | Description |\n", "|----------------------|-----------------------------------------------|\n", "| `params` | Used for parameter initialization |\n", "| `dropout` | Used by `nnx.Dropout` to create dropout masks |" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Default PRNG key stream\n", "\n", "One of the downsides of having named streams is that the user needs to know all the possible names that a model will use when creating the `nnx.Rngs` object. While this could be solved with some documentation, Flax NNX provides a `default` stream that can be\n", "be used as a fallback when a stream is not found. To use the default PRNG key stream, you can simply pass an integer seed or `jax.random.key` as the first positional argument." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "rngs = nnx.Rngs(0, params=1)\n", "\n", "key1 = rngs.params() # Call params.\n", "key2 = rngs.dropout() # Fallback to the default stream.\n", "key3 = rngs() # Call the default stream directly.\n", "\n", "nnx.display(rngs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As shown above, a PRNG key from the `default` stream can also be generated by calling the `nnx.Rngs` object itself.\n", "\n", "> **Note**\n", ">
For large projects it is recommended to use named streams to avoid potential conflicts. For small projects or quick prototyping just using the `default` stream is a good choice." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### jax.random shorthand methods\n", "Since a very common pattern is to sample a key and immediately pass it to a function from `jax.random`, both `Rngs` and `RngStream` expose the same functions as methods with the same signature except they don't require a key:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "import jax\n", "rngs = nnx.Rngs(0, params=1)\n", "\n", "# using jax.random\n", "z1 = jax.random.normal(rngs(), (2, 3))\n", "z2 = jax.random.bernoulli(rngs.params(), 0.5, (10,))\n", "\n", "# shorthand methods\n", "z1 = rngs.normal((2, 3)) # generates key from rngs.default\n", "z2 = rngs.params.bernoulli(0.5, (10,)) # generates key from rngs.params" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Forking random state" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Say you want to train a model that uses dropout on a batch of data. You don't want to use the same random state for every dropout mask in your batch. Instead, you want to fork the random state into separate pieces for each layer. This can be accomplished with the `fork` method, as shown below." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "class Model(nnx.Module):\n", " def __init__(self, rngs: nnx.Rngs):\n", " self.linear = nnx.Linear(20, 10, rngs=rngs)\n", " self.drop = nnx.Dropout(0.1)\n", "\n", " def __call__(self, x, rngs):\n", " return nnx.relu(self.drop(self.linear(x), rngs=rngs))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "model = Model(rngs=nnx.Rngs(0))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "@nnx.vmap(in_axes=(None, 0, 0), out_axes=0)\n", "def model_forward(model, x, rngs):\n", " return model(x, rngs=rngs)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Rngs( # RngState: 2 (12 B)\n", " default=RngStream( # RngState: 2 (12 B)\n", " tag='default',\n", " key=RngKey( # 1 (8 B)\n", " value=Array((), dtype=key) overlaying:\n", " [0 1],\n", " tag='default'\n", " ),\n", " count=RngCount( # 1 (4 B)\n", " value=Array(1, dtype=uint32),\n", " tag='default'\n", " )\n", " )\n", " ),\n", " Rngs( # RngState: 10 (60 B)\n", " default=RngStream( # RngState: 10 (60 B)\n", " tag='default',\n", " key=RngKey( # 5 (40 B)\n", " value=Array(shape=(5,), dtype=key),\n", " tag='default'\n", " ),\n", " count=RngCount( # 5 (20 B)\n", " value=Array(shape=(5,), dtype=dtype('uint32')),\n", " tag='default'\n", " )\n", " )\n", " ))" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dropout_rngs = nnx.Rngs(1)\n", "forked_rngs = dropout_rngs.fork(split=5)\n", "(dropout_rngs, forked_rngs)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(5, 10)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_forward(model, jnp.ones((5, 20)), forked_rngs).shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The output of `rng.fork` is another `Rng` with keys and counts that have an expanded shape. In the example above, the `RngKey` and `RngCount` of `dropout_rngs` have shape `()`, but in `forked_rngs` they have shape `(5,)`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Implicit Random State" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So far, we have looked at passing random state directly to each Module when it gets called. But there's another way to handle call-time randomness in flax: we can bundle the random state into the Module itself. This makes the random state is just another type of Module state. Using implicit random state requires passing the `rngs` keyward argument when initializing the module rather than when calling it. For example, here is how we might construct the simple `Module` we defined earlier using an implicit style." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y.shape = (1, 10)\n" ] } ], "source": [ "class Model(nnx.Module):\n", " def __init__(self, rngs: nnx.Rngs):\n", " self.linear = nnx.Linear(20, 10, rngs=rngs)\n", " self.drop = nnx.Dropout(0.1, rngs=rngs)\n", "\n", " def __call__(self, x):\n", " return nnx.relu(self.drop(self.linear(x)))\n", "\n", "model = Model(nnx.Rngs(params=0, dropout=1))\n", "\n", "y = model(x=jnp.ones((1, 20)))\n", "print(f'{y.shape = }')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This implicit state handling style is less verbose than passing RNGs explicitly, and more closely resembles code in other deep learning frameworks like PyTorch. However, as we'll see in the following sections, using implicit state makes it less obvious how to apply jax transformations to your Modules. With explicit state, you can usually use tranforms like `jax.vmap` directly. With implicit state, you'll need to some extra tricks with `nnx.vmap` to make everything work. Because of this additional complexity, we recommend that new flax projects stick to the explicit style." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Filtering random state\n", "\n", "Implicit random state can be manipulated using [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`nnx.RngState`, `nnx.RngKey`, `nnx.RngCount`) or using strings corresponding to the stream names (refer to [the Flax NNX `Filter` DSL](https://flax.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = Model(nnx.Rngs(params=0, dropout=1))\n", "\n", "rng_state = nnx.state(model, nnx.RngState) # All random states.\n", "key_state = nnx.state(model, nnx.RngKey) # Only PRNG keys.\n", "count_state = nnx.state(model, nnx.RngCount) # Only counts.\n", "rng_dropout_state = nnx.state(model, 'dropout') # Only `dropout`.\n", "\n", "nnx.display(rng_dropout_state)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reseeding\n", "\n", "In Haiku and Flax Linen, random states are explicitly passed to `Module.apply` each time before you call the model. This makes it easy to control the randomness of the model when needed (for example, for reproducibility).\n", "\n", "In Flax NNX, there are two ways to approach this:\n", "\n", "1. By passing an `nnx.Rngs` object through the `__call__` stack manually, as shown previously.\n", "2. By using `nnx.reseed` to set the random state of the model to a specific configuration. This option is less intrusive and can be used even if the model is not designed to enable manual control over the random state.\n", "\n", "`nnx.reseed` is a function that accepts an arbitrary graph node (this includes [pytrees](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#working-with-pytrees) of `nnx.Module`s) and some keyword arguments containing the new seed or key value for the `nnx.RngStream`s specified by the argument names. `nnx.reseed` will then traverse the graph and update the random state of the matching `nnx.RngStream`s, this includes both setting the `key` to a possibly new value and resetting the `count` to zero.\n", "\n", "Here's an example of how to use `nnx.reseed` to reset the random state of the `nnx.Dropout` layer and verify that the computation is identical to the first time the model was called:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "model = Model(nnx.Rngs(params=0, dropout=1))\n", "x = jnp.ones((1, 20))\n", "\n", "y1 = model(x)\n", "y2 = model(x)\n", "\n", "nnx.reseed(model, dropout=1) # reset dropout RngState\n", "y3 = model(x)\n", "\n", "assert not jnp.allclose(y1, y2) # different\n", "assert jnp.allclose(y1, y3) # same" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Forking implicit random state\n", "\n", "We saw above how to use `rng.fork` when passing explicit random state through [Flax NNX transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) like `nnx.vmap` or `nnx.pmap`. The decorator `nnx.fork_rngs` allows this for implicit random state. Consider the example below, which generates a batch of samples from the nondeterministic model we defined above." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(5, 1, 10)\n" ] } ], "source": [ "rng_axes = nnx.StateAxes({'dropout': 0, ...: None})\n", "\n", "@nnx.fork_rngs(split={'dropout': 5})\n", "@nnx.vmap(in_axes=(rng_axes, None), out_axes=0)\n", "def sample_from_model(model, x):\n", " return model(x)\n", "\n", "print(sample_from_model(model, x).shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here `sample_from_model` is modified by two decorators:\n", "- The function we get from the `nnx.vmap` decorator expects that the random state of the `model` argument has already been split into 5 pieces. It runs the model once for each random key.\n", "- The function we get from the `nnx.fork_rngs` decorator splits the random state of its `model` argument into five pieces before passing it on to the inner function." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Transforming implicit state" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the previous section, we showed how to use `nnx.vmap` with a module that contained implicit random state. But we can use other `nnx` transformations too! Remember: implicit random state isn't different from any other type of Model state, and this applies to Flax NNX transforms too. This means you can use the Flax NNX state handling APIs of each transform to get the results you want. For a more involved example, let’s explore how to implement recurrent dropout on an `RNNCell` using `nnx.scan`.\n", "\n", "We'll start by constructing the `RNNCell` class:\n", "\n", "- First, create an `nnx.Dropout` layer that will sample PRNG keys from a custom `recurrent_dropout` stream.\n", "- Apply dropout (`drop`) to the hidden state `h` of the `RNNCell`.\n", "- Then, define an `initial_state` function to create the initial state of the `RNNCell`.\n", "- Finally, instantiate `RNNCell`." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "class Count(nnx.Variable): pass\n", "\n", "class RNNCell(nnx.Module):\n", " def __init__(self, din, dout, rngs):\n", " self.linear = nnx.Linear(dout + din, dout, rngs=rngs)\n", " self.drop = nnx.Dropout(0.1, rngs=rngs, rng_collection='recurrent_dropout')\n", " self.dout = dout\n", " self.count = Count(jnp.array(0, jnp.uint32))\n", "\n", " def __call__(self, h, x) -> tuple[jax.Array, jax.Array]:\n", " h = self.drop(h) # Recurrent dropout.\n", " y = nnx.relu(self.linear(jnp.concatenate([h, x], axis=-1)))\n", " self.count.value += 1\n", " return y, y\n", "\n", " def initial_state(self, batch_size: int):\n", " return jnp.zeros((batch_size, self.dout))\n", "\n", "cell = RNNCell(8, 16, nnx.Rngs(params=0, recurrent_dropout=1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, use `nnx.scan` over an `unroll` function to implement the `rnn_forward` operation:\n", "- The key ingredient of recurrent dropout is to apply the same dropout mask across all time steps. Therefore, to achieve this you will pass `nnx.StateAxes` to `nnx.scan`'s `in_axes`, specifying that the `cell`'s `recurrent_dropout` PRNG stream will be broadcast, and the rest of the `RNNCell`'s state will be carried over.\n", "- Also, the hidden state `h` will be the `nnx.scan`'s `Carry` variable, and the sequence `x` will be `scan`ned over its axis `1`." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y.shape = (4, 20, 16)\n", "cell.count.value = Array(20, dtype=uint32)\n" ] } ], "source": [ "@nnx.jit\n", "def rnn_forward(cell: RNNCell, x: jax.Array):\n", " h = cell.initial_state(batch_size=x.shape[0])\n", "\n", " # Broadcast the 'recurrent_dropout' PRNG state to have the same mask on every step.\n", " state_axes = nnx.StateAxes({'recurrent_dropout': None, ...: nnx.Carry})\n", " @nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1))\n", " def unroll(cell: RNNCell, h, x) -> tuple[jax.Array, jax.Array]:\n", " h, y = cell(h, x)\n", " return h, y\n", "\n", " h, y = unroll(cell, h, x)\n", " return y\n", "\n", "x = jnp.ones((4, 20, 8))\n", "y = rnn_forward(cell, x)\n", "\n", "print(f'{y.shape = }')\n", "print(f'{cell.count.value = }')" ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: docs_nnx/guides/randomness.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # Randomness Flax NNX uses the stateful `nnx.Rngs` class to simplify Jax's handling of random states. For example, the code below uses a `nnx.Rngs` object to define a simple linear model with dropout: ```{code-cell} ipython3 from flax import nnx class Model(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.linear = nnx.Linear(20, 10, rngs=rngs) self.drop = nnx.Dropout(0.1) def __call__(self, x, *, rngs): return nnx.relu(self.drop(self.linear(x), rngs=rngs)) rngs = nnx.Rngs(0) model = Model(rngs=rngs) # pass rngs to initialize parameters x = rngs.normal((32, 20)) # convenient jax.random methods y = model(x, rngs=rngs) # pass rngs for dropout masks ``` We always pass `nnx.Rngs` objects to models at initialization (to initialize parameters). For models with nondeterministic outputs like the one above, we also pass `nnx.Rngs` objects to the model's `__call__` method. The Flax NNX [pseudorandom number generator (PRNG)](https://flax.readthedocs.io/en/latest/glossary.html#term-RNG-sequences) system has the following main characteristics: - It is **explicit**. - It is **order-based**. - It uses **dynamic counters**. > **Note:** To learn more about random number generation in JAX, the `jax.random` API, and PRNG-generated sequences, check out this [JAX PRNG tutorial](https://jax.readthedocs.io/en/latest/random-numbers.html). ## `Rngs`, `RngStream`, and `RngState` In Flax NNX, the `nnx.Rngs` type is the primary convenience API for managing the random state(s). Following Flax Linen's footsteps, `nnx.Rngs` have the ability to create multiple named PRNG key [streams](https://jax.readthedocs.io/en/latest/jep/263-prng.html), each with its own state, for the purpose of having tight control over randomness in the context of [JAX transformations (transforms)](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations). Here are the main PRNG-related types in Flax NNX: * **`nnx.Rngs`**: The main user interface. It defines a set of named `nnx.RngStream` objects. * **`nnx.RngStream`**: An object that can generate a stream of PRNG keys. It holds a root `key` and a `count` inside an `nnx.RngKey` and `nnx.RngCount` `nnx.Variable`s, respectively. When a new key is generated, the count is incremented. * **`nnx.RngState`**: The base type for all RNG-related states. * **`nnx.RngKey`**: NNX Variable type for holding PRNG keys. It includes a `tag` attribute containing the name of the PRNG key stream. * **`nnx.RngCount`**: NNX Variable type for holding PRNG counts. It includes a `tag` attribute containing the PRNG key stream name. To create an `nnx.Rngs` object you can simply pass an integer seed or `jax.random.key` instance to any keyword argument of your choice in the constructor. Here's an example: ```{code-cell} ipython3 rngs = nnx.Rngs(params=0, dropout=random.key(1)) nnx.display(rngs) ``` Notice that the `key` and `count` `nnx.Variable`s contain the PRNG key stream name in a `tag` attribute. This is primarily used for filtering as we'll see later. To generate new keys, you can access one of the streams and use its `__call__` method with no arguments. This will return a new key by using `random.fold_in` with the current `key` and `count`. The `count` is then incremented so that subsequent calls will return new keys. ```{code-cell} ipython3 params_key = rngs.params() dropout_key = rngs.dropout() nnx.display(rngs) ``` Note that the `key` attribute does not change when new PRNG keys are generated. +++ ### Using random state with flax Modules. Almost all flax Modules require a random state for initialization. In a `Linear` layer, for example, we need to sample the weights and biases from the appropriate Normal distribution. Random state is provided using the `rngs` keyword argument at initialization. ```{code-cell} ipython3 linear = nnx.Linear(20, 10, rngs=rngs) ``` Specifically, this will use the RngSteam `rngs.params` for weight initialization. The `params` stream is also used for initialization of `nnx.Conv`, `nnx.ConvTranspose`, and `nnx.Embed`. +++ The `nnx.Dropout` module also requires a random state, but it requires this state at *call* time rather than initialization. Once again, we can pass it random state using the `rngs` keyword argument. ```{code-cell} ipython3 dropout = nnx.Dropout(0.5) ``` ```{code-cell} ipython3 import jax.numpy as jnp dropout(jnp.ones(4), rngs=rngs) ``` The `nnx.Dropout` layer will use the rng's `dropout` stream. This also applies to Modules that use `Dropout` as a sub-Module, like `nnx.MultiHeadAttention`. +++ To summarize, there are only two standard PRNG key stream names used by Flax NNX's built-in layers, shown in the table below: | PRNG key stream name | Description | |----------------------|-----------------------------------------------| | `params` | Used for parameter initialization | | `dropout` | Used by `nnx.Dropout` to create dropout masks | +++ ### Default PRNG key stream One of the downsides of having named streams is that the user needs to know all the possible names that a model will use when creating the `nnx.Rngs` object. While this could be solved with some documentation, Flax NNX provides a `default` stream that can be be used as a fallback when a stream is not found. To use the default PRNG key stream, you can simply pass an integer seed or `jax.random.key` as the first positional argument. ```{code-cell} ipython3 rngs = nnx.Rngs(0, params=1) key1 = rngs.params() # Call params. key2 = rngs.dropout() # Fallback to the default stream. key3 = rngs() # Call the default stream directly. nnx.display(rngs) ``` As shown above, a PRNG key from the `default` stream can also be generated by calling the `nnx.Rngs` object itself. > **Note** >
For large projects it is recommended to use named streams to avoid potential conflicts. For small projects or quick prototyping just using the `default` stream is a good choice. +++ ### jax.random shorthand methods Since a very common pattern is to sample a key and immediately pass it to a function from `jax.random`, both `Rngs` and `RngStream` expose the same functions as methods with the same signature except they don't require a key: ```{code-cell} ipython3 import jax rngs = nnx.Rngs(0, params=1) # using jax.random z1 = jax.random.normal(rngs(), (2, 3)) z2 = jax.random.bernoulli(rngs.params(), 0.5, (10,)) # shorthand methods z1 = rngs.normal((2, 3)) # generates key from rngs.default z2 = rngs.params.bernoulli(0.5, (10,)) # generates key from rngs.params ``` ## Forking random state +++ Say you want to train a model that uses dropout on a batch of data. You don't want to use the same random state for every dropout mask in your batch. Instead, you want to fork the random state into separate pieces for each layer. This can be accomplished with the `fork` method, as shown below. ```{code-cell} ipython3 class Model(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(20, 10, rngs=rngs) self.drop = nnx.Dropout(0.1) def __call__(self, x, rngs): return nnx.relu(self.drop(self.linear(x), rngs=rngs)) ``` ```{code-cell} ipython3 model = Model(rngs=nnx.Rngs(0)) ``` ```{code-cell} ipython3 @nnx.vmap(in_axes=(None, 0, 0), out_axes=0) def model_forward(model, x, rngs): return model(x, rngs=rngs) ``` ```{code-cell} ipython3 dropout_rngs = nnx.Rngs(1) forked_rngs = dropout_rngs.fork(split=5) (dropout_rngs, forked_rngs) ``` ```{code-cell} ipython3 model_forward(model, jnp.ones((5, 20)), forked_rngs).shape ``` The output of `rng.fork` is another `Rng` with keys and counts that have an expanded shape. In the example above, the `RngKey` and `RngCount` of `dropout_rngs` have shape `()`, but in `forked_rngs` they have shape `(5,)`. +++ # Implicit Random State +++ So far, we have looked at passing random state directly to each Module when it gets called. But there's another way to handle call-time randomness in flax: we can bundle the random state into the Module itself. This makes the random state is just another type of Module state. Using implicit random state requires passing the `rngs` keyward argument when initializing the module rather than when calling it. For example, here is how we might construct the simple `Module` we defined earlier using an implicit style. ```{code-cell} ipython3 class Model(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(20, 10, rngs=rngs) self.drop = nnx.Dropout(0.1, rngs=rngs) def __call__(self, x): return nnx.relu(self.drop(self.linear(x))) model = Model(nnx.Rngs(params=0, dropout=1)) y = model(x=jnp.ones((1, 20))) print(f'{y.shape = }') ``` This implicit state handling style is less verbose than passing RNGs explicitly, and more closely resembles code in other deep learning frameworks like PyTorch. However, as we'll see in the following sections, using implicit state makes it less obvious how to apply jax transformations to your Modules. With explicit state, you can usually use tranforms like `jax.vmap` directly. With implicit state, you'll need to some extra tricks with `nnx.vmap` to make everything work. Because of this additional complexity, we recommend that new flax projects stick to the explicit style. +++ ## Filtering random state Implicit random state can be manipulated using [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) just like any other type of state. It can be filtered using types (`nnx.RngState`, `nnx.RngKey`, `nnx.RngCount`) or using strings corresponding to the stream names (refer to [the Flax NNX `Filter` DSL](https://flax.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl)). Here's an example using `nnx.state` with various filters to select different substates of the `Rngs` inside a `Model`: ```{code-cell} ipython3 model = Model(nnx.Rngs(params=0, dropout=1)) rng_state = nnx.state(model, nnx.RngState) # All random states. key_state = nnx.state(model, nnx.RngKey) # Only PRNG keys. count_state = nnx.state(model, nnx.RngCount) # Only counts. rng_dropout_state = nnx.state(model, 'dropout') # Only `dropout`. nnx.display(rng_dropout_state) ``` ## Reseeding In Haiku and Flax Linen, random states are explicitly passed to `Module.apply` each time before you call the model. This makes it easy to control the randomness of the model when needed (for example, for reproducibility). In Flax NNX, there are two ways to approach this: 1. By passing an `nnx.Rngs` object through the `__call__` stack manually, as shown previously. 2. By using `nnx.reseed` to set the random state of the model to a specific configuration. This option is less intrusive and can be used even if the model is not designed to enable manual control over the random state. `nnx.reseed` is a function that accepts an arbitrary graph node (this includes [pytrees](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#working-with-pytrees) of `nnx.Module`s) and some keyword arguments containing the new seed or key value for the `nnx.RngStream`s specified by the argument names. `nnx.reseed` will then traverse the graph and update the random state of the matching `nnx.RngStream`s, this includes both setting the `key` to a possibly new value and resetting the `count` to zero. Here's an example of how to use `nnx.reseed` to reset the random state of the `nnx.Dropout` layer and verify that the computation is identical to the first time the model was called: ```{code-cell} ipython3 model = Model(nnx.Rngs(params=0, dropout=1)) x = jnp.ones((1, 20)) y1 = model(x) y2 = model(x) nnx.reseed(model, dropout=1) # reset dropout RngState y3 = model(x) assert not jnp.allclose(y1, y2) # different assert jnp.allclose(y1, y3) # same ``` ## Forking implicit random state We saw above how to use `rng.fork` when passing explicit random state through [Flax NNX transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) like `nnx.vmap` or `nnx.pmap`. The decorator `nnx.fork_rngs` allows this for implicit random state. Consider the example below, which generates a batch of samples from the nondeterministic model we defined above. ```{code-cell} ipython3 rng_axes = nnx.StateAxes({'dropout': 0, ...: None}) @nnx.fork_rngs(split={'dropout': 5}) @nnx.vmap(in_axes=(rng_axes, None), out_axes=0) def sample_from_model(model, x): return model(x) print(sample_from_model(model, x).shape) ``` Here `sample_from_model` is modified by two decorators: - The function we get from the `nnx.vmap` decorator expects that the random state of the `model` argument has already been split into 5 pieces. It runs the model once for each random key. - The function we get from the `nnx.fork_rngs` decorator splits the random state of its `model` argument into five pieces before passing it on to the inner function. +++ ## Transforming implicit state +++ In the previous section, we showed how to use `nnx.vmap` with a module that contained implicit random state. But we can use other `nnx` transformations too! Remember: implicit random state isn't different from any other type of Model state, and this applies to Flax NNX transforms too. This means you can use the Flax NNX state handling APIs of each transform to get the results you want. For a more involved example, let’s explore how to implement recurrent dropout on an `RNNCell` using `nnx.scan`. We'll start by constructing the `RNNCell` class: - First, create an `nnx.Dropout` layer that will sample PRNG keys from a custom `recurrent_dropout` stream. - Apply dropout (`drop`) to the hidden state `h` of the `RNNCell`. - Then, define an `initial_state` function to create the initial state of the `RNNCell`. - Finally, instantiate `RNNCell`. ```{code-cell} ipython3 class Count(nnx.Variable): pass class RNNCell(nnx.Module): def __init__(self, din, dout, rngs): self.linear = nnx.Linear(dout + din, dout, rngs=rngs) self.drop = nnx.Dropout(0.1, rngs=rngs, rng_collection='recurrent_dropout') self.dout = dout self.count = Count(jnp.array(0, jnp.uint32)) def __call__(self, h, x) -> tuple[jax.Array, jax.Array]: h = self.drop(h) # Recurrent dropout. y = nnx.relu(self.linear(jnp.concatenate([h, x], axis=-1))) self.count.value += 1 return y, y def initial_state(self, batch_size: int): return jnp.zeros((batch_size, self.dout)) cell = RNNCell(8, 16, nnx.Rngs(params=0, recurrent_dropout=1)) ``` Next, use `nnx.scan` over an `unroll` function to implement the `rnn_forward` operation: - The key ingredient of recurrent dropout is to apply the same dropout mask across all time steps. Therefore, to achieve this you will pass `nnx.StateAxes` to `nnx.scan`'s `in_axes`, specifying that the `cell`'s `recurrent_dropout` PRNG stream will be broadcast, and the rest of the `RNNCell`'s state will be carried over. - Also, the hidden state `h` will be the `nnx.scan`'s `Carry` variable, and the sequence `x` will be `scan`ned over its axis `1`. ```{code-cell} ipython3 @nnx.jit def rnn_forward(cell: RNNCell, x: jax.Array): h = cell.initial_state(batch_size=x.shape[0]) # Broadcast the 'recurrent_dropout' PRNG state to have the same mask on every step. state_axes = nnx.StateAxes({'recurrent_dropout': None, ...: nnx.Carry}) @nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1)) def unroll(cell: RNNCell, h, x) -> tuple[jax.Array, jax.Array]: h, y = cell(h, x) return h, y h, y = unroll(cell, h, x) return y x = jnp.ones((4, 20, 8)) y = rnn_forward(cell, x) print(f'{y.shape = }') print(f'{cell.count.value = }') ``` ================================================ FILE: docs_nnx/guides/surgery.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Model surgery\n", "\n", "Model surgery is an act of making modifications on an existing neural network's building blocks and parameters, such as layer replacement, parameter or state manipulation, or even \"monkey patching\". In this guide, you will learn how to perform model surgery in Flax NNX using several real-world scenarios:\n", "\n", "* __Pythonic `nnx.Module` manipulation__: Using Pythonic ways to manipulate sub-`Module`s given a model.\n", "\n", "* __Manipulation of an abstract model or state__: A key trick for playing with `flax.nnx.Module`s and states without memory allocation.\n", "\n", "* __Checkpoint surgery from a raw state to model__: How to manipulate parameter states when they are incompatible with existing model code.\n", "\n", "* __Partial initialization__: How to initialize only a part of the model from scratch using a naive method or a memory-efficient method." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from typing import *\n", "from pprint import pprint\n", "import functools\n", "\n", "import jax\n", "from jax import lax, numpy as jnp, tree_util as jtu\n", "\n", "from jax.sharding import PartitionSpec, Mesh, NamedSharding\n", "from jax.experimental import mesh_utils\n", "import flax\n", "from flax import nnx\n", "import flax.traverse_util\n", "import numpy as np\n", "import orbax.checkpoint as orbax\n", "\n", "key = jax.random.key(0)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class TwoLayerMLP(nnx.Module):\n", " def __init__(self, dim, rngs: nnx.Rngs):\n", " self.linear1 = nnx.Linear(dim, dim, rngs=rngs)\n", " self.linear2 = nnx.Linear(dim, dim, rngs=rngs)\n", "\n", " def __call__(self, x):\n", " x = self.linear1(x)\n", " return self.linear2(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pythonic `nnx.Module` manipulation\n", "\n", "It is easier to perform model surgery when:\n", "\n", "1) You already have a fully fleshed-out model loaded with correct parameters; and\n", "2) You don't intend to change your model definition code.\n", "\n", "You can perform a variety of Pythonic operations on its sub-`Module`s, such as sub-`Module` swapping, `Module` sharing, variable sharing, and monkey-patching:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", "x = jax.random.normal(jax.random.key(42), (3, 4))\n", "np.testing.assert_allclose(model(x), model.linear2(model.linear1(x)))\n", "\n", "# Sub-`Module` swapping.\n", "original1, original2 = model.linear1, model.linear2\n", "model.linear1, model.linear2 = model.linear2, model.linear1\n", "np.testing.assert_allclose(model(x), original1(original2(x)))\n", "\n", "# `Module` sharing (tying all weights together).\n", "model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", "model.linear2 = model.linear1\n", "assert not hasattr(nnx.state(model), 'linear2')\n", "np.testing.assert_allclose(model(x), model.linear1(model.linear1(x)))\n", "\n", "# Variable sharing (weight-tying).\n", "model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", "model.linear1.kernel = model.linear2.kernel # the bias parameter is kept separate\n", "assert 'linear2' in nnx.state(model)\n", "assert 'bias' in nnx.state(model)['linear2']\n", "assert not hasattr(nnx.state(model)['linear2'], 'kernel')\n", "\n", "# Monkey-patching.\n", "model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", "def awesome_layer(x): return x\n", "model.linear2 = awesome_layer\n", "np.testing.assert_allclose(model(x), model.linear1(x))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating an abstract model or state without memory allocation\n", "\n", "To do more complex model surgery, the key technique you can use is creating and manipulating an abstract model or state without allocating any real parameter data. This makes trial iteration faster and removes any concern on memory constraints.\n", "\n", "To create an abstract model:\n", "\n", "* Create a function that returns a valid Flax NNX model; and\n", "* Run `nnx.eval_shape` (not `jax.eval_shape`) upon it.\n", "\n", "Now you can use `nnx.split` as usual to get its abstract state. Note that all fields that should be `jax.Array`s in a real model are now of an abstract `jax.ShapeDtypeStruct` type with only shape/dtype/sharding information." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "State({\n", " 'linear1': {\n", " 'bias': Param( # 4 (16 B)\n", " value=ShapeDtypeStruct(shape=(4,), dtype=float32)\n", " ),\n", " 'kernel': Param( # 16 (64 B)\n", " value=ShapeDtypeStruct(shape=(4, 4), dtype=float32)\n", " )\n", " },\n", " 'linear2': {\n", " 'bias': Param( # 4 (16 B)\n", " value=ShapeDtypeStruct(shape=(4,), dtype=float32)\n", " ),\n", " 'kernel': Param( # 16 (64 B)\n", " value=ShapeDtypeStruct(shape=(4, 4), dtype=float32)\n", " )\n", " }\n", "})\n" ] } ], "source": [ "abs_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", "gdef, abs_state = nnx.split(abs_model)\n", "pprint(abs_state)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When you fill every `nnx.Variable` pytree leaf's `value` attributes with real `jax.Array`s, the abstract model becomes equivalent to a real model." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", "abs_state['linear1']['kernel'].value = model.linear1.kernel.value\n", "abs_state['linear1']['bias'].value = model.linear1.bias.value\n", "abs_state['linear2']['kernel'].value = model.linear2.kernel.value\n", "abs_state['linear2']['bias'].value = model.linear2.bias.value\n", "nnx.update(abs_model, abs_state)\n", "np.testing.assert_allclose(abs_model(x), model(x)) # They are equivalent now!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Checkpoint surgery\n", "\n", "With the abstract state technique in hand, you can perform arbitrary manipulation on any checkpoint - or runtime parameter pytree - to make them fit with your given model code, and then call `nnx.update` to merge them.\n", "\n", "This can be helpful if you are trying to significantly change the model code - for example, when migrating from Flax Linen to Flax NNX - and old weights are no longer naturally compatible.\n", "\n", "Let's run a simple example here:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Save a version of model into a checkpoint\n", "checkpointer = orbax.PyTreeCheckpointer()\n", "old_model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n", "checkpointer.save(f'/tmp/nnx-surgery-state', nnx.state(model), force=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this new model, the sub-`Module`s are renamed from `linear(1|2)` to `layer(1|2)`. Since the pytree structure has changed, it is impossible to directly load the old checkpoint with the new model state structure:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "This will throw error: : User-provided restore item and on-disk value metadata tree structures do not match: {'layer1': Diff(lhs={'bias': {'value': ShapeDtypeStruct(shape=(4,), dtype=float32)}, 'kernel': {'value': ShapeDtypeStruct(shape=(4, 4), dtype=float32)}}, rhs=None), 'layer2': Diff(lhs={'bias': {'value': ShapeDtypeStruct(shape=(4,), dtype=float32)}, 'kernel': {'value': ShapeDtypeStruct(shape=(4, 4), dtype=float32)}}, rhs=None), 'linear1': Diff(lhs=None, rhs={'bias': {'value': ValueMetadataEntry(value_type='jax.Array', skip_deserialize=False, write_shape=(4,))}, 'kernel': {'value': ValueMetadataEntry(value_type='jax.Array', skip_deserialize=False, write_shape=(4, 4))}}), 'linear2': Diff(lhs=None, rhs={'bias': {'value': ValueMetadataEntry(value_type='jax.Array', skip_deserialize=False, write_shape=(4,))}, 'kernel': {'value': ValueMetadataEntry(value_type='jax.Array', skip_deserialize=False, write_shape=(4, 4))}})}\n" ] } ], "source": [ "class ModifiedTwoLayerMLP(nnx.Module):\n", " def __init__(self, dim, rngs: nnx.Rngs):\n", " self.layer1 = nnx.Linear(dim, dim, rngs=rngs) # no longer linear1!\n", " self.layer2 = nnx.Linear(dim, dim, rngs=rngs)\n", "\n", " def __call__(self, x):\n", " x = self.layer1(x)\n", " return self.layer2(x)\n", "\n", "abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", "try:\n", " with_item = checkpointer.restore('/tmp/nnx-surgery-state', item=nnx.state(abs_model))\n", " print(with_item)\n", "except Exception as e:\n", " print(f'This will throw error: {type(e)}: {e}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "However, you can load the parameter pytree as a raw dictionary, perform the renames, and generate a new state that is guaranteed to be compatible with your new model definition." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'linear1': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},\n", " 'kernel': {'value': Array([[ 0.5350889 , -0.48486355, -0.4022262 , -0.61925626],\n", " [-0.46665004, 0.31773907, 0.38944173, -0.54608804],\n", " [ 0.84378934, -0.93099 , -0.67658 , 0.0724705 ],\n", " [-0.6101737 , 0.12972134, 0.877074 , 0.27292168]], dtype=float32)}},\n", " 'linear2': {'bias': {'value': Array([0., 0., 0., 0.], dtype=float32)},\n", " 'kernel': {'value': Array([[ 0.67979455, 0.7079946 , -0.22166717, -0.4147039 ],\n", " [ 0.20622818, 0.01024843, 0.31011865, -0.40491563],\n", " [ 0.12478007, -0.7697264 , -0.48899388, 0.8853114 ],\n", " [-0.5123713 , -0.23335123, 0.4374407 , 0.63321066]], dtype=float32)}}}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/cgarciae/repos/flax/.venv/lib/python3.11/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py:1251: UserWarning: Sharding info not provided when restoring. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n", " warnings.warn(\n" ] } ], "source": [ "def process_raw_dict(raw_state_dict):\n", " flattened = nnx.traversals.flatten_mapping(raw_state_dict)\n", " # Cut the '.value' postfix on every leaf path.\n", " flattened = {(path[:-1] if path[-1] == 'value' else path): value\n", " for path, value in flattened.items()}\n", " return nnx.traversals.unflatten_mapping(flattened)\n", "\n", "# Make your local change on the checkpoint dictionary.\n", "raw_dict = checkpointer.restore('/tmp/nnx-surgery-state')\n", "pprint(raw_dict)\n", "raw_dict['layer1'] = raw_dict.pop('linear1')\n", "raw_dict['layer2'] = raw_dict.pop('linear2')\n", "\n", "# Fit it into the model state.\n", "abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", "graph_def, state = nnx.split(abs_model)\n", "nnx.replace_by_pure_dict(state, process_raw_dict(raw_dict))\n", "restored_model = nnx.merge(graph_def, state)\n", "\n", "np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Partial initialization\n", "\n", "In some cases - such as with LoRA (Low-Rank Adaption) - you may want to randomly-initialize only *part of* your model parameters. This can be achieved through:\n", "\n", "- Naive partial initialization; or\n", "- Memory-efficient partial initialization." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Naive partial initialization\n", "\n", "To do naive partial initialization, you can just initialize the whole model, then swap the pre-trained parameters in. However, this approach may allocate additional memory midway if your modification requires re-creating module parameters that you will later discard. Below is an example of this.\n", "\n", "> **Note:** You can use `jax.live_arrays()` to check all the arrays live in memory at any given time. This call can be “messed up” when you run a single Jupyter notebook cell multiple times (due to garbage-collection of old Python variables). However, restarting the Python kernel in the notebook and running the code from scratch will always yield the same output." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of jax arrays in memory at start: 38\n", "Number of jax arrays in memory midway: 42 (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)\n", "Number of jax arrays in memory at end: 40 (2 discarded - only lora_a & lora_b are used in model)\n" ] } ], "source": [ "# Some pretrained model state\n", "old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", "\n", "simple_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(42)))\n", "print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}')\n", "# In this line, extra kernel and bias is created inside the new LoRALinear!\n", "# They are wasted, because you are going to use the kernel and bias in `old_state` anyway.\n", "simple_model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=nnx.Rngs(42))\n", "print(f'Number of jax arrays in memory midway: {len(jax.live_arrays())}'\n", " ' (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)')\n", "nnx.update(simple_model, old_state)\n", "print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}'\n", " ' (2 discarded - only lora_a & lora_b are used in model)')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Memory-efficient partial initialization\n", "\n", "To do memory-efficient partial initialization, use `nnx.jit`'s efficiently compiled code to make sure only the state parameters you need are initialized:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of JAX Arrays in memory at start: 44\n", "Number of JAX Arrays in memory at end: 50 (2 new created - lora_a and lora_b)\n" ] } ], "source": [ "# Some pretrained model state\n", "old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n", "\n", "# Use `nnx.jit` (which wraps `jax.jit`) to automatically skip unused arrays - memory efficient!\n", "@nnx.jit(donate_argnums=0)\n", "def partial_init(old_state, rngs):\n", " model = TwoLayerMLP(4, rngs=rngs)\n", " # Create a new state.\n", " model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=rngs)\n", " # Add the existing state.\n", " nnx.update(model, old_state)\n", " return model\n", "\n", "print(f'Number of JAX Arrays in memory at start: {len(jax.live_arrays())}')\n", "# Note that `old_state` will be deleted after this `partial_init` call.\n", "good_model = partial_init(old_state, nnx.Rngs(42))\n", "print(f'Number of JAX Arrays in memory at end: {len(jax.live_arrays())}'\n", " ' (2 new created - lora_a and lora_b)')" ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: docs_nnx/guides/surgery.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # Model surgery Model surgery is an act of making modifications on an existing neural network's building blocks and parameters, such as layer replacement, parameter or state manipulation, or even "monkey patching". In this guide, you will learn how to perform model surgery in Flax NNX using several real-world scenarios: * __Pythonic `nnx.Module` manipulation__: Using Pythonic ways to manipulate sub-`Module`s given a model. * __Manipulation of an abstract model or state__: A key trick for playing with `flax.nnx.Module`s and states without memory allocation. * __Checkpoint surgery from a raw state to model__: How to manipulate parameter states when they are incompatible with existing model code. * __Partial initialization__: How to initialize only a part of the model from scratch using a naive method or a memory-efficient method. ```{code-cell} ipython3 from typing import * from pprint import pprint import functools import jax from jax import lax, numpy as jnp, tree_util as jtu from jax.sharding import PartitionSpec, Mesh, NamedSharding from jax.experimental import mesh_utils import flax from flax import nnx import flax.traverse_util import numpy as np import orbax.checkpoint as orbax key = jax.random.key(0) ``` ```{code-cell} ipython3 class TwoLayerMLP(nnx.Module): def __init__(self, dim, rngs: nnx.Rngs): self.linear1 = nnx.Linear(dim, dim, rngs=rngs) self.linear2 = nnx.Linear(dim, dim, rngs=rngs) def __call__(self, x): x = self.linear1(x) return self.linear2(x) ``` ## Pythonic `nnx.Module` manipulation It is easier to perform model surgery when: 1) You already have a fully fleshed-out model loaded with correct parameters; and 2) You don't intend to change your model definition code. You can perform a variety of Pythonic operations on its sub-`Module`s, such as sub-`Module` swapping, `Module` sharing, variable sharing, and monkey-patching: ```{code-cell} ipython3 model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) x = jax.random.normal(jax.random.key(42), (3, 4)) np.testing.assert_allclose(model(x), model.linear2(model.linear1(x))) # Sub-`Module` swapping. original1, original2 = model.linear1, model.linear2 model.linear1, model.linear2 = model.linear2, model.linear1 np.testing.assert_allclose(model(x), original1(original2(x))) # `Module` sharing (tying all weights together). model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) model.linear2 = model.linear1 assert not hasattr(nnx.state(model), 'linear2') np.testing.assert_allclose(model(x), model.linear1(model.linear1(x))) # Variable sharing (weight-tying). model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) model.linear1.kernel = model.linear2.kernel # the bias parameter is kept separate assert 'linear2' in nnx.state(model) assert 'bias' in nnx.state(model)['linear2'] assert not hasattr(nnx.state(model)['linear2'], 'kernel') # Monkey-patching. model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) def awesome_layer(x): return x model.linear2 = awesome_layer np.testing.assert_allclose(model(x), model.linear1(x)) ``` ## Creating an abstract model or state without memory allocation To do more complex model surgery, the key technique you can use is creating and manipulating an abstract model or state without allocating any real parameter data. This makes trial iteration faster and removes any concern on memory constraints. To create an abstract model: * Create a function that returns a valid Flax NNX model; and * Run `nnx.eval_shape` (not `jax.eval_shape`) upon it. Now you can use `nnx.split` as usual to get its abstract state. Note that all fields that should be `jax.Array`s in a real model are now of an abstract `jax.ShapeDtypeStruct` type with only shape/dtype/sharding information. ```{code-cell} ipython3 abs_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0))) gdef, abs_state = nnx.split(abs_model) pprint(abs_state) ``` When you fill every `nnx.Variable` pytree leaf's `value` attributes with real `jax.Array`s, the abstract model becomes equivalent to a real model. ```{code-cell} ipython3 model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) abs_state['linear1']['kernel'].value = model.linear1.kernel.value abs_state['linear1']['bias'].value = model.linear1.bias.value abs_state['linear2']['kernel'].value = model.linear2.kernel.value abs_state['linear2']['bias'].value = model.linear2.bias.value nnx.update(abs_model, abs_state) np.testing.assert_allclose(abs_model(x), model(x)) # They are equivalent now! ``` ## Checkpoint surgery With the abstract state technique in hand, you can perform arbitrary manipulation on any checkpoint - or runtime parameter pytree - to make them fit with your given model code, and then call `nnx.update` to merge them. This can be helpful if you are trying to significantly change the model code - for example, when migrating from Flax Linen to Flax NNX - and old weights are no longer naturally compatible. Let's run a simple example here: ```{code-cell} ipython3 # Save a version of model into a checkpoint checkpointer = orbax.PyTreeCheckpointer() old_model = TwoLayerMLP(4, rngs=nnx.Rngs(0)) checkpointer.save(f'/tmp/nnx-surgery-state', nnx.state(model), force=True) ``` In this new model, the sub-`Module`s are renamed from `linear(1|2)` to `layer(1|2)`. Since the pytree structure has changed, it is impossible to directly load the old checkpoint with the new model state structure: ```{code-cell} ipython3 class ModifiedTwoLayerMLP(nnx.Module): def __init__(self, dim, rngs: nnx.Rngs): self.layer1 = nnx.Linear(dim, dim, rngs=rngs) # no longer linear1! self.layer2 = nnx.Linear(dim, dim, rngs=rngs) def __call__(self, x): x = self.layer1(x) return self.layer2(x) abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0))) try: with_item = checkpointer.restore('/tmp/nnx-surgery-state', item=nnx.state(abs_model)) print(with_item) except Exception as e: print(f'This will throw error: {type(e)}: {e}') ``` However, you can load the parameter pytree as a raw dictionary, perform the renames, and generate a new state that is guaranteed to be compatible with your new model definition. ```{code-cell} ipython3 def process_raw_dict(raw_state_dict): flattened = nnx.traversals.flatten_mapping(raw_state_dict) # Cut the '.value' postfix on every leaf path. flattened = {(path[:-1] if path[-1] == 'value' else path): value for path, value in flattened.items()} return nnx.traversals.unflatten_mapping(flattened) # Make your local change on the checkpoint dictionary. raw_dict = checkpointer.restore('/tmp/nnx-surgery-state') pprint(raw_dict) raw_dict['layer1'] = raw_dict.pop('linear1') raw_dict['layer2'] = raw_dict.pop('linear2') # Fit it into the model state. abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0))) graph_def, state = nnx.split(abs_model) nnx.replace_by_pure_dict(state, process_raw_dict(raw_dict)) restored_model = nnx.merge(graph_def, state) np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4)))) ``` ## Partial initialization In some cases - such as with LoRA (Low-Rank Adaption) - you may want to randomly-initialize only *part of* your model parameters. This can be achieved through: - Naive partial initialization; or - Memory-efficient partial initialization. +++ ### Naive partial initialization To do naive partial initialization, you can just initialize the whole model, then swap the pre-trained parameters in. However, this approach may allocate additional memory midway if your modification requires re-creating module parameters that you will later discard. Below is an example of this. > **Note:** You can use `jax.live_arrays()` to check all the arrays live in memory at any given time. This call can be “messed up” when you run a single Jupyter notebook cell multiple times (due to garbage-collection of old Python variables). However, restarting the Python kernel in the notebook and running the code from scratch will always yield the same output. ```{code-cell} ipython3 # Some pretrained model state old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0))) simple_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(42))) print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}') # In this line, extra kernel and bias is created inside the new LoRALinear! # They are wasted, because you are going to use the kernel and bias in `old_state` anyway. simple_model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=nnx.Rngs(42)) print(f'Number of jax arrays in memory midway: {len(jax.live_arrays())}' ' (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)') nnx.update(simple_model, old_state) print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}' ' (2 discarded - only lora_a & lora_b are used in model)') ``` ### Memory-efficient partial initialization To do memory-efficient partial initialization, use `nnx.jit`'s efficiently compiled code to make sure only the state parameters you need are initialized: ```{code-cell} ipython3 # Some pretrained model state old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0))) # Use `nnx.jit` (which wraps `jax.jit`) to automatically skip unused arrays - memory efficient! @nnx.jit(donate_argnums=0) def partial_init(old_state, rngs): model = TwoLayerMLP(4, rngs=rngs) # Create a new state. model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=rngs) # Add the existing state. nnx.update(model, old_state) return model print(f'Number of JAX Arrays in memory at start: {len(jax.live_arrays())}') # Note that `old_state` will be deleted after this `partial_init` call. good_model = partial_init(old_state, nnx.Rngs(42)) print(f'Number of JAX Arrays in memory at end: {len(jax.live_arrays())}' ' (2 new created - lora_a and lora_b)') ``` ================================================ FILE: docs_nnx/guides/tiny_nnx.ipynb ================================================ { "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Tiny NNX\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cgarciae/nnx/blob/main/docs/tiny_nnx.ipynb)\n", "\n", "A pedagogical implementation of NNX's core APIs.\n", "\n", "## Core APIs" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import dataclasses\n", "import hashlib\n", "import typing as tp\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "from jax import random\n", "\n", "A = tp.TypeVar(\"A\")\n", "M = tp.TypeVar(\"M\", bound=\"Module\")\n", "Sharding = tp.Tuple[tp.Optional[str], ...]\n", "Array = jax.Array\n", "\n", "\n", "class Variable(tp.Generic[A]):\n", "\n", " def __init__(\n", " self,\n", " value: A,\n", " *,\n", " sharding: tp.Optional[Sharding] = None,\n", " ):\n", " self.value = value\n", " self.sharding = sharding\n", "\n", " def __repr__(self) -> str:\n", " return (\n", " f\"{type(self).__name__}(value={self.value}, sharding={self.sharding})\"\n", " )\n", "\n", " def __init_subclass__(cls):\n", " super().__init_subclass__()\n", " jax.tree_util.register_pytree_node(\n", " cls,\n", " lambda x: ((x.value,), (x.sharding,)),\n", " lambda metadata, value: cls(value[0], sharding=metadata[0]),\n", " )\n", "\n", "\n", "class State(dict[str, Variable[tp.Any]]):\n", "\n", " def extract(self, variable_type: tp.Type[Variable]) -> \"State\":\n", " return State(\n", " {\n", " path: variable\n", " for path, variable in self.items()\n", " if isinstance(variable, variable_type)\n", " }\n", " )\n", "\n", " def __repr__(self) -> str:\n", " elems = \",\\n \".join(\n", " f\"'{path}': {variable}\".replace(\"\\n\", \"\\n \")\n", " for path, variable in self.items()\n", " )\n", " return f\"State({{\\n {elems}\\n}})\"\n", "\n", "\n", "jax.tree_util.register_pytree_node(\n", " State,\n", " # in reality, values and paths should be sorted by path\n", " lambda x: (tuple(x.values()), tuple(x.keys())),\n", " lambda paths, values: State(dict(zip(paths, values))),\n", ")\n", "\n", "\n", "@dataclasses.dataclass\n", "class GraphDef(tp.Generic[M]):\n", " type: tp.Type[M]\n", " index: int\n", " submodules: dict[str, tp.Union[\"GraphDef[Module]\", int]]\n", " static_fields: dict[str, tp.Any]\n", "\n", " def merge(self, state: State) -> M:\n", " module = GraphDef._build_module_recursive(self, {})\n", " module.update(state)\n", " return module\n", "\n", " @staticmethod\n", " def _build_module_recursive(\n", " graphdef: tp.Union[\"GraphDef[M]\", int],\n", " index_to_module: dict[int, \"Module\"],\n", " ) -> M:\n", " if isinstance(graphdef, int):\n", " return index_to_module[graphdef] # type: ignore\n", "\n", " assert graphdef.index not in index_to_module\n", "\n", " # add a dummy module to the index to avoid infinite recursion\n", " module = object.__new__(graphdef.type)\n", " index_to_module[graphdef.index] = module\n", "\n", " submodules = {\n", " name: GraphDef._build_module_recursive(submodule, index_to_module)\n", " for name, submodule in graphdef.submodules.items()\n", " }\n", " vars(module).update(graphdef.static_fields)\n", " vars(module).update(submodules)\n", " return module\n", "\n", " def apply(\n", " self, state: State\n", " ) -> tp.Callable[..., tuple[tp.Any, tuple[State, \"GraphDef[M]\"]]]:\n", " def _apply(*args, **kwargs):\n", " module = self.merge(state)\n", " out = module(*args, **kwargs) # type: ignore\n", " return out, module.split()\n", "\n", " return _apply\n", "\n", "\n", "class Module:\n", "\n", " def split(self: M) -> tp.Tuple[State, GraphDef[M]]:\n", " state = State()\n", " graphdef = Module._partition_recursive(\n", " module=self, module_id_to_index={}, path_parts=(), state=state\n", " )\n", " assert isinstance(graphdef, GraphDef)\n", " return state, graphdef\n", "\n", " @staticmethod\n", " def _partition_recursive(\n", " module: M,\n", " module_id_to_index: dict[int, int],\n", " path_parts: tp.Tuple[str, ...],\n", " state: State,\n", " ) -> tp.Union[GraphDef[M], int]:\n", " if id(module) in module_id_to_index:\n", " return module_id_to_index[id(module)]\n", "\n", " index = len(module_id_to_index)\n", " module_id_to_index[id(module)] = index\n", "\n", " submodules = {}\n", " static_fields = {}\n", "\n", " # iterate fields sorted by name to ensure deterministic order\n", " for name, value in sorted(vars(module).items(), key=lambda x: x[0]):\n", " value_path = (*path_parts, name)\n", " # if value is a Module, recurse\n", " if isinstance(value, Module):\n", " submoduledef = Module._partition_recursive(\n", " value, module_id_to_index, value_path, state\n", " )\n", " submodules[name] = submoduledef\n", " # if value is a Variable, add to state\n", " elif isinstance(value, Variable):\n", " state[\"/\".join(value_path)] = value\n", " else: # otherwise, add to graphdef fields\n", " static_fields[name] = value\n", "\n", " return GraphDef(\n", " type=type(module),\n", " index=index,\n", " submodules=submodules,\n", " static_fields=static_fields,\n", " )\n", "\n", " def update(self, state: State) -> None:\n", " for path, value in state.items():\n", " path_parts = path.split(\"/\")\n", " Module._set_value_at_path(self, path_parts, value)\n", "\n", " @staticmethod\n", " def _set_value_at_path(\n", " module: \"Module\", path_parts: tp.Sequence[str], value: Variable[tp.Any]\n", " ) -> None:\n", " if len(path_parts) == 1:\n", " setattr(module, path_parts[0], value)\n", " else:\n", " Module._set_value_at_path(\n", " getattr(module, path_parts[0]), path_parts[1:], value\n", " )\n", "\n", "\n", "@dataclasses.dataclass\n", "class Rngs:\n", " key: jax.Array\n", " count: int = 0\n", " count_path: tuple[int, ...] = ()\n", "\n", " def fork(self) -> \"Rngs\":\n", " \"\"\"Forks the context, guaranteeing that all the random numbers generated\n", " will be different from the ones generated in the original context. Fork is\n", " used to create a new Rngs that can be passed to a JAX transform\"\"\"\n", " count_path = self.count_path + (self.count,)\n", " self.count += 1\n", " return Rngs(self.key, count_path=count_path)\n", "\n", " def make_rng(self) -> jax.Array:\n", " fold_data = self._stable_hash(self.count_path + (self.count,))\n", " self.count += 1\n", " return random.fold_in(self.key, fold_data) # type: ignore\n", "\n", " @staticmethod\n", " def _stable_hash(data: tuple[int, ...]) -> int:\n", " hash_str = \" \".join(str(x) for x in data)\n", " _hash = hashlib.blake2s(hash_str.encode())\n", " hash_bytes = _hash.digest()\n", " # uint32 is represented as 4 bytes in big endian\n", " return int.from_bytes(hash_bytes[:4], byteorder=\"big\")\n", "\n", "\n", "# in the real NNX Rngs is not a pytree, instead\n", "# it has a split/merge API similar to Module\n", "# but for simplicity we use a pytree here\n", "jax.tree_util.register_pytree_node(\n", " Rngs,\n", " lambda x: ((x.key,), (x.count, x.count_path)),\n", " lambda metadata, value: Rngs(value[0], *metadata),\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Basic Layers" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class Param(Variable[A]):\n", " pass\n", "\n", "\n", "class BatchStat(Variable[A]):\n", " pass\n", "\n", "\n", "class Linear(Module):\n", "\n", " def __init__(self, din: int, dout: int, *, rngs: Rngs):\n", " self.din = din\n", " self.dout = dout\n", " key = rngs.make_rng()\n", " self.w = Param(random.uniform(key, (din, dout)))\n", " self.b = Param(jnp.zeros((dout,)))\n", "\n", " def __call__(self, x: jax.Array) -> jax.Array:\n", " return x @ self.w.value + self.b.value\n", "\n", "\n", "class BatchNorm(Module):\n", "\n", " def __init__(self, din: int, mu: float = 0.95):\n", " self.mu = mu\n", " self.scale = Param(jax.numpy.ones((din,)))\n", " self.bias = Param(jax.numpy.zeros((din,)))\n", " self.mean = BatchStat(jax.numpy.zeros((din,)))\n", " self.var = BatchStat(jax.numpy.ones((din,)))\n", "\n", " def __call__(self, x, train: bool) -> jax.Array:\n", " if train:\n", " axis = tuple(range(x.ndim - 1))\n", " mean = jax.numpy.mean(x, axis=axis)\n", " var = jax.numpy.var(x, axis=axis)\n", " # ema update\n", " self.mean.value = self.mu * self.mean.value + (1 - self.mu) * mean\n", " self.var.value = self.mu * self.var.value + (1 - self.mu) * var\n", " else:\n", " mean, var = self.mean.value, self.var.value\n", "\n", " scale, bias = self.scale.value, self.bias.value\n", " x = (x - mean) / jax.numpy.sqrt(var + 1e-5) * scale + bias\n", " return x\n", "\n", "\n", "class Dropout(Module):\n", "\n", " def __init__(self, rate: float):\n", " self.rate = rate\n", "\n", " def __call__(self, x: jax.Array, *, train: bool, rngs: Rngs) -> jax.Array:\n", " if train:\n", " mask = random.bernoulli(rngs.make_rng(), (1 - self.rate), x.shape)\n", " x = x * mask / (1 - self.rate)\n", " return x" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Scan Over Layers Example" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class Block(Module):\n", "\n", " def __init__(self, din: int, dout: int, *, rngs: Rngs):\n", " self.linear = Linear(din, dout, rngs=rngs)\n", " self.bn = BatchNorm(dout)\n", " self.dropout = Dropout(0.1)\n", "\n", " def __call__(self, x: jax.Array, *, train: bool, rngs: Rngs) -> jax.Array:\n", " x = self.linear(x)\n", " x = self.bn(x, train=train)\n", " x = jax.nn.gelu(x)\n", " x = self.dropout(x, train=train, rngs=rngs)\n", " return x\n", "\n", "\n", "class ScanMLP(Module):\n", "\n", " def __init__(self, hidden_size: int, n_layers: int, *, rngs: Rngs):\n", " self.n_layers = n_layers\n", "\n", " # lift init\n", " key = random.split(rngs.make_rng(), n_layers - 1)\n", " graphdef: GraphDef[Block] = None # type: ignore\n", "\n", " def init_fn(key):\n", " nonlocal graphdef\n", " state, graphdef = Block(\n", " hidden_size, hidden_size, rngs=Rngs(key)\n", " ).split()\n", " return state\n", "\n", " state = jax.vmap(init_fn)(key)\n", " self.layers = graphdef.merge(state)\n", " self.linear = Linear(hidden_size, hidden_size, rngs=rngs)\n", "\n", " def __call__(self, x: jax.Array, *, train: bool, rngs: Rngs) -> jax.Array:\n", " # lift call\n", " key: jax.Array = random.split(rngs.make_rng(), self.n_layers - 1) # type: ignore\n", " state, graphdef = self.layers.split()\n", "\n", " def scan_fn(x, inputs: tuple[jax.Array, State]):\n", " key, state = inputs\n", " x, (state, _) = graphdef.apply(state)(x, train=train, rngs=Rngs(key))\n", " return x, state\n", "\n", " x, state = jax.lax.scan(scan_fn, x, (key, state))\n", " self.layers.update(state)\n", " x = self.linear(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "state = State({\n", " 'layers/bn/bias': Param(value=(4, 10), sharding=None),\n", " 'layers/bn/mean': BatchStat(value=(4, 10), sharding=None),\n", " 'layers/bn/scale': Param(value=(4, 10), sharding=None),\n", " 'layers/bn/var': BatchStat(value=(4, 10), sharding=None),\n", " 'layers/linear/b': Param(value=(4, 10), sharding=None),\n", " 'layers/linear/w': Param(value=(4, 10, 10), sharding=None),\n", " 'linear/b': Param(value=(10,), sharding=None),\n", " 'linear/w': Param(value=(10, 10), sharding=None)\n", "})\n", "graphdef = GraphDef(type=, index=0, submodules={'layers': GraphDef(type=, index=1, submodules={'bn': GraphDef(type=, index=2, submodules={}, static_fields={'mu': 0.95}), 'dropout': GraphDef(type=, index=3, submodules={}, static_fields={'rate': 0.1}), 'linear': GraphDef(type=, index=4, submodules={}, static_fields={'din': 10, 'dout': 10})}, static_fields={}), 'linear': GraphDef(type=, index=5, submodules={}, static_fields={'din': 10, 'dout': 10})}, static_fields={'n_layers': 5})\n" ] } ], "source": [ "module = ScanMLP(hidden_size=10, n_layers=5, rngs=Rngs(random.key(0)))\n", "x = jax.random.normal(random.key(0), (2, 10))\n", "y = module(x, train=True, rngs=Rngs(random.key(1)))\n", "\n", "state, graphdef = module.split()\n", "print(\"state =\", jax.tree.map(jnp.shape, state))\n", "print(\"graphdef =\", graphdef)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Filtering State" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "params = State({\n", " 'layers/bn/bias': Param(value=(4, 10), sharding=None),\n", " 'layers/bn/scale': Param(value=(4, 10), sharding=None),\n", " 'layers/linear/b': Param(value=(4, 10), sharding=None),\n", " 'layers/linear/w': Param(value=(4, 10, 10), sharding=None),\n", " 'linear/b': Param(value=(10,), sharding=None),\n", " 'linear/w': Param(value=(10, 10), sharding=None)\n", "})\n", "batch_stats = State({\n", " 'layers/bn/mean': BatchStat(value=(4, 10), sharding=None),\n", " 'layers/bn/var': BatchStat(value=(4, 10), sharding=None)\n", "})\n" ] } ], "source": [ "# split\n", "params = state.extract(Param)\n", "batch_stats = state.extract(BatchStat)\n", "# merge\n", "state = State({**params, **batch_stats})\n", "\n", "print(\"params =\", jax.tree.map(jnp.shape, params))\n", "print(\"batch_stats =\", jax.tree.map(jnp.shape, batch_stats))" ] } ], "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: docs_nnx/guides/transforms.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "962be290", "metadata": {}, "source": [ "# Transformations\n", "\n", "In general, JAX transformations (transforms) operate on [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) of `jax.Array`s\n", "and abide by value semantics. This presents a challenge for Flax NNX, which represents `nnx.Module`s as regular Python objects\n", "that follow reference semantics. To address this, Flax NNX introduced its own set of transforms that extend JAX\n", "transforms to allow `nnx.Module`s and other Flax NNX objects to be passed in and out of transforms while preserving\n", "reference semantics.\n", "\n", "Flax NNX transforms should feel quite familiar if you have used JAX transforms before. They use the\n", "same APIs and behave like the JAX transforms when only working with pytrees of `jax.Array`s. However, when working with\n", "Flax NNX objects, they allow Python's reference semantics to be preserved for these objects, this includes:\n", "\n", "* Preserving shared references across multiple objects in the inputs and outputs of the transformation.\n", "* Propagating any state changes made to the objects inside the transformation to the objects outside the transformation.\n", "* Enforcing consistency of how objects are transformed when aliases are present across multiple inputs and outputs." ] }, { "cell_type": "code", "execution_count": null, "id": "8d645146", "metadata": {}, "outputs": [], "source": [ "import jax\n", "from jax import numpy as jnp, random\n", "from flax import nnx" ] }, { "cell_type": "markdown", "id": "b44fb248", "metadata": {}, "source": [ "Throughout this guide, `nnx.vmap` is used as a case study to demonstrate how Flax NNX transforms work. However, the principles\n", "outlined in this document extends to all transforms.\n", "\n", "## Basic example\n", "\n", "To begin, let's look at a simple example of using `nnx.vmap` to extend an element wise `vector_dot` function to work on\n", "batched inputs. We will define a `Weights` Module with no methods to hold some parameters, these weights will be passed\n", "as an input to the `vector_dot` function along with some data. Both the weights and data will be batched on axis `0` and we will use\n", "`nnx.vmap` to apply `vector_dot` to each batch element, and the result will be a batched on axis `1`:" ] }, { "cell_type": "code", "execution_count": 2, "id": "4eab27a4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y.shape = (3, 10)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/ivyzheng/envs/f1/lib/python3.12/site-packages/treescope/renderers.py:251: UserWarning: Ignoring error inside wrapper hook :\n", "Traceback (most recent call last):\n", " File \"/Users/ivyzheng/envs/f1/lib/python3.12/site-packages/treescope/renderers.py\", line 225, in _render_subtree\n", " postprocessed_result = hook(\n", " ^^^^^\n", " File \"/Users/ivyzheng/envs/f1/lib/python3.12/site-packages/treescope/_internal/handlers/autovisualizer_hook.py\", line 47, in use_autovisualizer_if_present\n", " result = autoviz(node, path)\n", " ^^^^^^^^^^^^^^^^^^^\n", " File \"/Users/ivyzheng/envs/f1/lib/python3.12/site-packages/treescope/_internal/api/array_autovisualizer.py\", line 306, in __call__\n", " jax.sharding.PositionalSharding\n", " File \"/Users/ivyzheng/envs/f1/lib/python3.12/site-packages/jax/_src/deprecations.py\", line 54, in getattr\n", " raise AttributeError(message)\n", "AttributeError: jax.sharding.PositionalSharding was deprecated in JAX v0.6.0 and removed in JAX v0.7.0\n", "\n", " warnings.warn(\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "class Weights(nnx.Module):\n", " def __init__(self, kernel: jax.Array, bias: jax.Array):\n", " self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)\n", "\n", "weights = Weights(\n", " kernel=random.uniform(random.key(0), (10, 2, 3)),\n", " bias=jnp.zeros((10, 3)),\n", ")\n", "x = jax.random.normal(random.key(1), (10, 2))\n", "\n", "def vector_dot(weights: Weights, x: jax.Array):\n", " assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'\n", " assert x.ndim == 1, 'Batch dimensions not allowed'\n", " return x @ weights.kernel + weights.bias\n", "\n", "y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(weights, x)\n", "\n", "print(f'{y.shape = }')\n", "nnx.display(weights)" ] }, { "cell_type": "markdown", "id": "d2b222eb", "metadata": {}, "source": [ "Notice that `in_axes` interacts naturally with the `Weights` Module, treating it as if it were a pytree of `jax.Array`s. Prefix patterns are also allowed, so `in_axes=(0, 0)` would have also worked in this case.\n", "\n", "Objects are also allowed as outputs of Flax NNX transforms, which can be useful to transform initializers. For example,\n", "you can define a `create_weights` function to create an single `Weights` `nnx.Module`, and use `nnx.vmap` to create a stack of\n", "`Weights` with the same shapes as before:" ] }, { "cell_type": "code", "execution_count": 3, "id": "0b076a0f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def create_weights(seed: jax.Array):\n", " return Weights(\n", " kernel=random.uniform(random.key(seed), (2, 3)),\n", " bias=jnp.zeros((3,)),\n", " )\n", "\n", "seeds = jnp.arange(10)\n", "weights = nnx.vmap(create_weights)(seeds)\n", "nnx.display(weights)" ] }, { "cell_type": "markdown", "id": "fac3dca9", "metadata": {}, "source": [ "## Transforming methods\n", "\n", "Methods in Python are just functions that take the instance as the first argument, this means that you can decorate methods from `Module` and other Flax NNX subtypes. For example, we can refactor `Weights` from the previous example and decorate `__init__` with `vmap` to do the work of `create_weights`, and add a `__call__` method and decorate it with `@nnx.vmap` to do the work of `vector_dot`:" ] }, { "cell_type": "code", "execution_count": 4, "id": "5d9a55fd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y.shape = (3, 10)\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "class WeightStack(nnx.Module):\n", " @nnx.vmap\n", " def __init__(self, seed: jax.Array):\n", " self.kernel = nnx.Param(random.uniform(random.key(seed), (2, 3)))\n", " self.bias = nnx.Param(jnp.zeros((3,)))\n", "\n", " @nnx.vmap(in_axes=0, out_axes=1)\n", " def __call__(self, x: jax.Array):\n", " assert self.kernel.ndim == 2, 'Batch dimensions not allowed'\n", " assert x.ndim == 1, 'Batch dimensions not allowed'\n", " return x @ self.kernel + self.bias\n", "\n", "weights = WeightStack(jnp.arange(10))\n", "\n", "x = jax.random.normal(random.key(1), (10, 2))\n", "y = weights(x)\n", "\n", "print(f'{y.shape = }')\n", "nnx.display(weights)" ] }, { "cell_type": "markdown", "id": "13b52d61", "metadata": {}, "source": [ "The rest of the guide will focus on transforming individual functions. But do note that all examples can be written in this method style." ] }, { "cell_type": "markdown", "id": "0251e7db", "metadata": {}, "source": [ "## State propagation\n", "\n", "So far our functions have been stateless. However, the real power of Flax NNX transforms comes when you have stateful functions, because one of their main features is to propagate state changes to preserve reference semantics. Let's update the previous example by adding\n", "a `count` attribute to `Weights` and incrementing it in the new `stateful_vector_dot` function:" ] }, { "cell_type": "code", "execution_count": 5, "id": "a4fbadb3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Count( # 10 (40 B)\n", " value=Array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=int32)\n", ")" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class Count(nnx.Variable): pass\n", "\n", "class Weights(nnx.Module):\n", " def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):\n", " self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)\n", " self.count = Count(count)\n", "\n", "weights = Weights(\n", " kernel=random.uniform(random.key(0), (10, 2, 3)),\n", " bias=jnp.zeros((10, 3)),\n", " count=jnp.arange(10),\n", ")\n", "x = jax.random.normal(random.key(1), (10, 2))\n", "\n", "def stateful_vector_dot(weights: Weights, x: jax.Array):\n", " assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'\n", " assert x.ndim == 1, 'Batch dimensions not allowed'\n", " weights.count += 1\n", " return x @ weights.kernel + weights.bias\n", "\n", "\n", "y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(weights, x)\n", "\n", "weights.count" ] }, { "cell_type": "markdown", "id": "322312ee", "metadata": {}, "source": [ "After running `stateful_vector_dot` once, you verified that the `count` attribute was correctly updated. Because `Weights` was vectorized, `count` was initialized as an `arange(10)`, and all of its elements were incremented by `1` inside the transformation. The most important part is that updates were propagated to the original `Weights` object outside the transformation. Nice!" ] }, { "cell_type": "markdown", "id": "7294661f", "metadata": {}, "source": [ "### Graph updates propagation\n", "\n", "JAX transforms see inputs as pytrees of `jax.Array`s, and Flax NNX sees inputs as pytrees of `jax.Array`s and Python references, where references form a graph. Flax NNX's state propagation machinery can track arbitrary updates to the objects as long as they're local to the inputs (updates to globals inside transforms are not supported).\n", "\n", "This means that you can modify graph structure as needed, including updating existing attributes, adding/deleting attributes, swapping attributes, sharing (new) references between objects, sharing `nnx.Variable`s between objects, etc. Sky is the limit!\n", "\n", "The following example demonstrates performing some arbitrary updates to the `Weights` object inside `nnx.vmap`, and verifying that the updates are correctly propagated to the original `Weights` object outside the transformation:" ] }, { "cell_type": "code", "execution_count": 6, "id": "76c58a29", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "class Count(nnx.Variable): pass\n", "\n", "class Weights(nnx.Module):\n", " def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):\n", " self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)\n", " self.count = Count(count)\n", "\n", "weights = Weights(\n", " kernel=random.uniform(random.key(0), (10, 2, 3)),\n", " bias=jnp.zeros((10, 3)),\n", " count=jnp.arange(10),\n", ")\n", "x = jax.random.normal(random.key(1), (10, 2))\n", "\n", "def crazy_vector_dot(weights: Weights, x: jax.Array):\n", " assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'\n", " assert x.ndim == 1, 'Batch dimensions not allowed'\n", " weights.count += 1\n", " y = x @ weights.kernel + weights.bias\n", " weights.some_property = ['a', 2, False] # add attribute\n", " del weights.bias # delete attribute\n", " weights.new_param = weights.kernel # share reference\n", " return y\n", "\n", "y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(weights, x)\n", "\n", "nnx.display(weights)" ] }, { "cell_type": "markdown", "id": "743bcc34", "metadata": {}, "source": [ "> With great power comes great responsibility.\n", ">
\\- Uncle Ben\n", "\n", "While this feature is very powerful, it must be used with care because it can clash with JAX's underlying assumptions for certain transforms. For example, `jit` expects the structure of the inputs to be stable in order to cache the compiled function, so changing the graph structure inside an `nnx.jit`-ed function causes continuous recompilations and performance degradation. On the other hand, `scan` only allows a fixed `carry` structure, so adding/removing sub-states declared as carry will cause an error." ] }, { "cell_type": "markdown", "id": "0d11d191", "metadata": {}, "source": [ "## Transforming sub-states (lift types)\n", "\n", "Certain JAX transforms allow the use of pytree prefixes to specify how different parts of the inputs/outputs should be transformed. Flax NNX supports pytree prefixes for pytree structures but currently it doesn't have the notion of a prefix for graph objects. Instead, Flax NNX introduces the concept of “lift types” which allow specifying how different sub-states of an object should be transformed. Different transforms support different lift types, here is the list of currently supported FLax NNX lift types for each JAX transformation:\n", "\n", "| Lift type | JAX transforms |\n", "|------------------|-----------------------------------------|\n", "| `StateAxes` | `vmap`, `pmap`, `scan` |\n", "| `StateSharding` | `jit`, `shard_map` |\n", "| `DiffState` | `grad`, `value_and_grad`, `custom_vjp` |\n", "\n", "To specify how to vectorize different sub-states of an object in `nnx.vmap`, the Flax team created a `nnx.StateAxes`. `StateAxes` maps a set of sub-states via Flax NNX [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) to their corresponding axes, and you can pass the `nnx.StateAxes` to `in_axes` and `out_axes` as if it/they were a pytree prefix.\n", "\n", "Let's use the previous `stateful_vector_dot` example and vectorize only the `nnx.Param` variables and broadcast the `count` variable so we only keep a single count for all the batch elements.\n", "To do this we will define a `nnx.StateAxes` with a filter that matches the `nnx.Param` variables and maps them to axis `0`, and all the `Count` variables to `None`, and pass this `nnx.StateAxes` to `in_axes` for the `Weights` object." ] }, { "cell_type": "code", "execution_count": 7, "id": "d10aee8a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Count( # 1 (4 B)\n", " value=Array(1, dtype=int32, weak_type=True)\n", ")" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class Weights(nnx.Module):\n", " def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array):\n", " self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)\n", " self.count = Count(count)\n", "\n", "weights = Weights(\n", " kernel=random.uniform(random.key(0), (10, 2, 3)),\n", " bias=jnp.zeros((10, 3)),\n", " count=jnp.array(0),\n", ")\n", "x = jax.random.normal(random.key(1), (10, 2))\n", "\n", "\n", "def stateful_vector_dot(weights: Weights, x: jax.Array):\n", " assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'\n", " assert x.ndim == 1, 'Batch dimensions not allowed'\n", " weights.count += 1\n", " return x @ weights.kernel + weights.bias\n", "\n", "state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count\n", "y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(weights, x)\n", "\n", "weights.count" ] }, { "cell_type": "markdown", "id": "1cfd87e1", "metadata": {}, "source": [ "Here, `count` is now a scalar since it's not being vectorized. Also, note that `nnx.StateAxes` can only be used directly on Flax NNX objects, and it cannot be used as a prefix for a pytree of objects." ] }, { "cell_type": "markdown", "id": "1c8bb104", "metadata": {}, "source": [ "### Random state\n", "\n", "In Flax NNX, a random state is just a regular state. This means that it is stored inside `nnx.Module`s that need it, and it is treated as any other type of state. This is a simplification over Flax Linen, where a random state was handled by a separate mechanism. In practice `nnx.Module`s simply need to keep a reference to a `Rngs` object that is passed to them during initialization, and use it to generate a unique key for each random operation. For the purposes of this guide, this means that random state can be transformed like any other type of state but we also need to be aware of how the state is laid out so we can transform it correctly.\n", "\n", "Suppose you want to change things up a bit and apply the same weights to all elements in the batch. But you also want to add different random noise to each element.\n", "\n", "To do this, you will add an `Rngs` attribute to `Weights`, created from a `seed` key argument passed during construction. This seed key must be `split` beforehand, so that you can vectorize it successfully. For pedagogical reasons, you will assign the seed key to a `noise` “stream” and sample from it. To vectorize the PRNG state, you must configure `nnx.StateAxes` to map all `RngState`s (a base class for all variables in `Rngs`) to axis `0`, and `nnx.Param` and `Count` to `None`." ] }, { "cell_type": "code", "execution_count": 8, "id": "33c284b6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "False\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "class Weights(nnx.Module):\n", " def __init__(self, kernel, bias, count, seed):\n", " self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)\n", " self.count = Count(count)\n", " self.rngs = nnx.Rngs(noise=seed)\n", "\n", "weights = Weights(\n", " kernel=random.uniform(random.key(0), (2, 3)),\n", " bias=jnp.zeros((3,)),\n", " count=jnp.array(0),\n", " seed=random.split(random.key(0), num=10),\n", ")\n", "x = random.normal(random.key(1), (10, 2))\n", "\n", "def noisy_vector_dot(weights: Weights, x: jax.Array):\n", " assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'\n", " assert x.ndim == 1, 'Batch dimensions not allowed'\n", " weights.count += 1\n", " y = x @ weights.kernel + weights.bias\n", " return y + random.normal(weights.rngs.noise(), y.shape)\n", "\n", "state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})\n", "y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)\n", "y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x)\n", "\n", "print(jnp.allclose(y1, y2))\n", "nnx.display(weights)" ] }, { "cell_type": "markdown", "id": "6f26b99f", "metadata": {}, "source": [ "Because `Rngs`'s state is updated in place and automatically propagated by `nnx.vmap`, we will get a different result every time that `noisy_vector_dot` is called.\n", "\n", "In the example above, you manually split the random state during construction. This is fine, as it makes the intention clear, but it also doesn't let you use `Rngs` outside of `nnx.vmap` because its state is always split. To solve this, you can pass an unsplit seed and use the `nnx.split_rngs` decorator before `nnx.vmap` to split the `RngState` right before each call to the function, and then \"lower\" it back so that it becomes usable." ] }, { "cell_type": "code", "execution_count": 9, "id": "8c9e5858", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "False\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "weights = Weights(\n", " kernel=random.uniform(random.key(0), (2, 3)),\n", " bias=jnp.zeros((3,)),\n", " count=jnp.array(0),\n", " seed=0,\n", ")\n", "x = random.normal(random.key(1), (10, 2))\n", "\n", "state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None})\n", "\n", "@nnx.split_rngs(splits=10)\n", "@nnx.vmap(in_axes=(state_axes, 0))\n", "def noisy_vector_dot(weights: Weights, x: jax.Array):\n", " assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'\n", " assert x.ndim == 1, 'Batch dimensions not allowed'\n", " weights.count += 1\n", " y = x @ weights.kernel + weights.bias\n", " return y + random.normal(weights.rngs.noise(), y.shape)\n", "\n", "y1 = noisy_vector_dot(weights, x)\n", "y2 = noisy_vector_dot(weights, x)\n", "\n", "print(jnp.allclose(y1, y2))\n", "nnx.display(weights)" ] }, { "cell_type": "markdown", "id": "60eee7f9", "metadata": {}, "source": [ "## Rules and limitations\n", "In this section we will cover some rules and limitations apply when using Modules inside transformations.\n", "\n", "### Mutable Module cannot be passed by closure\n", "\n", "While Python allows for passing objects as closures to functions, this is generally not supported by Flax NNX transforms. The reason is that because Modules are mutable it is very easy to capture tracer into a Module created outside of the transform, this is silent error in JAX. To avoid this, Flax NNX checks that the Modules and Variables being mutated are passed as arguments to the transformed function.\n", "\n", "For example, if we have a stateful Module such as `Counter` that increments a counter every time it is called, and we try to pass it as a closure to a function decorated with `nnx.jit`, we would be leaking the tracer. However Flax NNX will raise an error instead to prevent this:" ] }, { "cell_type": "code", "execution_count": 10, "id": "f8b95c03", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cannot mutate Param from a different trace level (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.TraceContextError)\n" ] } ], "source": [ "class Counter(nnx.Module):\n", " def __init__(self):\n", " self.count = nnx.Param(jnp.array(0))\n", "\n", " def increment(self):\n", " self.count += jnp.array(1)\n", "\n", "counter = Counter()\n", "\n", "@nnx.jit\n", "def f(x):\n", " counter.increment()\n", " return 2 * x\n", "\n", "try:\n", " y = f(3)\n", "except Exception as e:\n", " print(e)" ] }, { "cell_type": "markdown", "id": "6f37e23b", "metadata": {}, "source": [ "To solve this issue pass all Module as arguments to the functions being transformed. In this case `f` should accept `counter` as an argument." ] }, { "cell_type": "markdown", "id": "75edf7a8", "metadata": {}, "source": [ "### Consistent aliasing\n", "\n", "The main issue with allowing for reference semantics in transforms is that references can be shared across inputs and outputs. This can be problematic if it is not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below you have a single `Weights` Module `m` whose reference appears in multiple places in `arg1` and `arg2`. The problem here is that you also specify that you want to vectorize `arg1` in axis `0` and `arg2` in axis `1`. This would be fine in JAX because of referential transparency of pytrees. But this would be problematic in Flax NNX because you are trying to vectorize `m` in two different ways. Flax NNX will enforce consistency by raising an error." ] }, { "cell_type": "code", "execution_count": 11, "id": "46b1cc25", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Inconsistent aliasing detected. The following nodes have different prefixes:\n", "Node: Param\n", " param: 0\n", " param: 0\n", " param: 1\n" ] } ], "source": [ "class Weights(nnx.Module):\n", " def __init__(self, array: jax.Array):\n", " self.param = nnx.Param(array)\n", "\n", "m = Weights(jnp.arange(10))\n", "arg1 = {'a': {'b': m}, 'c': m}\n", "arg2 = [(m, m), m]\n", "\n", "@nnx.vmap(in_axes=(0, 1))\n", "def f(arg1, arg2):\n", " ...\n", "\n", "try:\n", " f(arg1, arg2)\n", "except ValueError as e:\n", " print(e)" ] }, { "cell_type": "markdown", "id": "46aa978c", "metadata": {}, "source": [ "Inconsistent aliasing can also happen between inputs and outputs. In the next example you have a trivial function that accepts and immediately returns `arg1`. However, `arg1` is vectorized on axis `0` on the input, and axis `1` on the output. As expected, this is problematic and Flax NNX will raise an error." ] }, { "cell_type": "code", "execution_count": 12, "id": "cca9cf31", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Inconsistent aliasing detected. The following nodes have different prefixes:\n", "Node: Param\n", " param: 0\n", " param: 0\n", " param: 1\n" ] } ], "source": [ "@nnx.vmap(in_axes=0, out_axes=1)\n", "def f(arg1):\n", " return arg1\n", "\n", "try:\n", " f(arg1)\n", "except ValueError as e:\n", " print(e)" ] }, { "cell_type": "markdown", "id": "13f9aeea", "metadata": {}, "source": [ "## Axis metadata\n", "\n", "Flax NNX `Variable`s can hold arbitrary metadata, which can be added by simply passing it as keyword arguments to its constructor. This is often used to store `sharding` information, as used by the `nnx.spmd` APIs (like `nnx.get_partition_spec` and `nnx.get_named_sharding`).\n", "\n", "However, it is often important to keep this axes-related information in sync to what the actual state of the axes is when transforms are involved. For example, if you vectorize a variable on axis `1`, you should remove the `sharding` information at position `1` when inside a `vmap` or `scan` to reflect the fact that the axes are temporarily removed.\n", "\n", "To achieve this, Flax NNX transforms provide a non-standard `transform_metadata` dictionary argument. And when the `nnx.PARTITION_NAME` key is present, the `sharding` metadata will be updated as specified by `in_axes` and `out_axes`.\n", "\n", "Let's see an example of this in action:" ] }, { "cell_type": "code", "execution_count": null, "id": "d85c772c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Inner m.param.shape = (3, 5)\n", "Inner m.param.out_sharding = ('a', None)\n", "Outter m.param.shape = (3, 4, 5)\n", "Outter m.param.out_sharding = ('a', 'b', None)\n" ] } ], "source": [ "mesh = jax.make_mesh((1, 1), ('a', 'b'))\n", "\n", "class Weights(nnx.Module):\n", " def __init__(self, array: jax.Array, out_sharding: tuple[str | None, ...]):\n", " self.param = nnx.Param(array, out_sharding=out_sharding)\n", "\n", "@nnx.vmap(in_axes=1, transform_metadata={nnx.PARTITION_NAME: 'b'})\n", "def f(m: Weights):\n", " print(f'Inner {m.param.shape = }')\n", " print(f'Inner {m.param.out_sharding = }')\n", "\n", "with jax.set_mesh(mesh):\n", " m = Weights(jnp.ones((3, 4, 5)), out_sharding=('a', 'b', None))\n", " f(m)\n", "\n", "print(f'Outter {m.param.shape = }')\n", "print(f'Outter {m.param.out_sharding = }')" ] }, { "cell_type": "markdown", "id": "a23bda09", "metadata": {}, "source": [ "Here, you added a `sharding` metadata to the `nnx.Param` variables, and used `transform_metadata` to update the `sharding` metadata to reflect the axis changes. Specifically, you can see that the first axis `b` was removed from the `sharding` metadata when inside of `nnx.vmap`, and then added back when outside of `nnx.vmap`.\n", "\n", "You can verify that this also works when `nnx.Module`s are created inside the transformation - the new `sharding` axes will be added to the `nnx.Module` `nnx.Variable`s outside the transformation, matching the axes of the transformed `nnx.Variable`s." ] }, { "cell_type": "code", "execution_count": null, "id": "358e51f7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Outter m.param.shape = (3, 4, 5)\n", "Outter m.param.out_sharding = ('a', 'b', None)\n" ] } ], "source": [ "@nnx.vmap(out_axes=1, axis_size=4, transform_metadata={nnx.PARTITION_NAME: 'b'})\n", "def init_vmap():\n", " return Weights(jnp.ones((3, 5)), out_sharding=('a', None))\n", "\n", "with jax.set_mesh(mesh):\n", " m = init_vmap()\n", "print(f'Outter {m.param.shape = }')\n", "print(f'Outter {m.param.out_sharding = }')" ] } ], "metadata": { "jupytext": { "cell_metadata_filter": "-all", "formats": "ipynb,md:myst", "main_language": "python" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: docs_nnx/guides/transforms.md ================================================ --- jupytext: cell_metadata_filter: -all formats: ipynb,md:myst main_language: python text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # Transformations In general, JAX transformations (transforms) operate on [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) of `jax.Array`s and abide by value semantics. This presents a challenge for Flax NNX, which represents `nnx.Module`s as regular Python objects that follow reference semantics. To address this, Flax NNX introduced its own set of transforms that extend JAX transforms to allow `nnx.Module`s and other Flax NNX objects to be passed in and out of transforms while preserving reference semantics. Flax NNX transforms should feel quite familiar if you have used JAX transforms before. They use the same APIs and behave like the JAX transforms when only working with pytrees of `jax.Array`s. However, when working with Flax NNX objects, they allow Python's reference semantics to be preserved for these objects, this includes: * Preserving shared references across multiple objects in the inputs and outputs of the transformation. * Propagating any state changes made to the objects inside the transformation to the objects outside the transformation. * Enforcing consistency of how objects are transformed when aliases are present across multiple inputs and outputs. ```{code-cell} ipython3 import jax from jax import numpy as jnp, random from flax import nnx ``` Throughout this guide, `nnx.vmap` is used as a case study to demonstrate how Flax NNX transforms work. However, the principles outlined in this document extends to all transforms. ## Basic example To begin, let's look at a simple example of using `nnx.vmap` to extend an element wise `vector_dot` function to work on batched inputs. We will define a `Weights` Module with no methods to hold some parameters, these weights will be passed as an input to the `vector_dot` function along with some data. Both the weights and data will be batched on axis `0` and we will use `nnx.vmap` to apply `vector_dot` to each batch element, and the result will be a batched on axis `1`: ```{code-cell} ipython3 class Weights(nnx.Module): def __init__(self, kernel: jax.Array, bias: jax.Array): self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias) weights = Weights( kernel=random.uniform(random.key(0), (10, 2, 3)), bias=jnp.zeros((10, 3)), ) x = jax.random.normal(random.key(1), (10, 2)) def vector_dot(weights: Weights, x: jax.Array): assert weights.kernel.ndim == 2, 'Batch dimensions not allowed' assert x.ndim == 1, 'Batch dimensions not allowed' return x @ weights.kernel + weights.bias y = nnx.vmap(vector_dot, in_axes=0, out_axes=1)(weights, x) print(f'{y.shape = }') nnx.display(weights) ``` Notice that `in_axes` interacts naturally with the `Weights` Module, treating it as if it were a pytree of `jax.Array`s. Prefix patterns are also allowed, so `in_axes=(0, 0)` would have also worked in this case. Objects are also allowed as outputs of Flax NNX transforms, which can be useful to transform initializers. For example, you can define a `create_weights` function to create an single `Weights` `nnx.Module`, and use `nnx.vmap` to create a stack of `Weights` with the same shapes as before: ```{code-cell} ipython3 def create_weights(seed: jax.Array): return Weights( kernel=random.uniform(random.key(seed), (2, 3)), bias=jnp.zeros((3,)), ) seeds = jnp.arange(10) weights = nnx.vmap(create_weights)(seeds) nnx.display(weights) ``` ## Transforming methods Methods in Python are just functions that take the instance as the first argument, this means that you can decorate methods from `Module` and other Flax NNX subtypes. For example, we can refactor `Weights` from the previous example and decorate `__init__` with `vmap` to do the work of `create_weights`, and add a `__call__` method and decorate it with `@nnx.vmap` to do the work of `vector_dot`: ```{code-cell} ipython3 class WeightStack(nnx.Module): @nnx.vmap def __init__(self, seed: jax.Array): self.kernel = nnx.Param(random.uniform(random.key(seed), (2, 3))) self.bias = nnx.Param(jnp.zeros((3,))) @nnx.vmap(in_axes=0, out_axes=1) def __call__(self, x: jax.Array): assert self.kernel.ndim == 2, 'Batch dimensions not allowed' assert x.ndim == 1, 'Batch dimensions not allowed' return x @ self.kernel + self.bias weights = WeightStack(jnp.arange(10)) x = jax.random.normal(random.key(1), (10, 2)) y = weights(x) print(f'{y.shape = }') nnx.display(weights) ``` The rest of the guide will focus on transforming individual functions. But do note that all examples can be written in this method style. +++ ## State propagation So far our functions have been stateless. However, the real power of Flax NNX transforms comes when you have stateful functions, because one of their main features is to propagate state changes to preserve reference semantics. Let's update the previous example by adding a `count` attribute to `Weights` and incrementing it in the new `stateful_vector_dot` function: ```{code-cell} ipython3 class Count(nnx.Variable): pass class Weights(nnx.Module): def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array): self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias) self.count = Count(count) weights = Weights( kernel=random.uniform(random.key(0), (10, 2, 3)), bias=jnp.zeros((10, 3)), count=jnp.arange(10), ) x = jax.random.normal(random.key(1), (10, 2)) def stateful_vector_dot(weights: Weights, x: jax.Array): assert weights.kernel.ndim == 2, 'Batch dimensions not allowed' assert x.ndim == 1, 'Batch dimensions not allowed' weights.count += 1 return x @ weights.kernel + weights.bias y = nnx.vmap(stateful_vector_dot, in_axes=0, out_axes=1)(weights, x) weights.count ``` After running `stateful_vector_dot` once, you verified that the `count` attribute was correctly updated. Because `Weights` was vectorized, `count` was initialized as an `arange(10)`, and all of its elements were incremented by `1` inside the transformation. The most important part is that updates were propagated to the original `Weights` object outside the transformation. Nice! +++ ### Graph updates propagation JAX transforms see inputs as pytrees of `jax.Array`s, and Flax NNX sees inputs as pytrees of `jax.Array`s and Python references, where references form a graph. Flax NNX's state propagation machinery can track arbitrary updates to the objects as long as they're local to the inputs (updates to globals inside transforms are not supported). This means that you can modify graph structure as needed, including updating existing attributes, adding/deleting attributes, swapping attributes, sharing (new) references between objects, sharing `nnx.Variable`s between objects, etc. Sky is the limit! The following example demonstrates performing some arbitrary updates to the `Weights` object inside `nnx.vmap`, and verifying that the updates are correctly propagated to the original `Weights` object outside the transformation: ```{code-cell} ipython3 class Count(nnx.Variable): pass class Weights(nnx.Module): def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array): self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias) self.count = Count(count) weights = Weights( kernel=random.uniform(random.key(0), (10, 2, 3)), bias=jnp.zeros((10, 3)), count=jnp.arange(10), ) x = jax.random.normal(random.key(1), (10, 2)) def crazy_vector_dot(weights: Weights, x: jax.Array): assert weights.kernel.ndim == 2, 'Batch dimensions not allowed' assert x.ndim == 1, 'Batch dimensions not allowed' weights.count += 1 y = x @ weights.kernel + weights.bias weights.some_property = ['a', 2, False] # add attribute del weights.bias # delete attribute weights.new_param = weights.kernel # share reference return y y = nnx.vmap(crazy_vector_dot, in_axes=0, out_axes=1)(weights, x) nnx.display(weights) ``` > With great power comes great responsibility. >
\- Uncle Ben While this feature is very powerful, it must be used with care because it can clash with JAX's underlying assumptions for certain transforms. For example, `jit` expects the structure of the inputs to be stable in order to cache the compiled function, so changing the graph structure inside an `nnx.jit`-ed function causes continuous recompilations and performance degradation. On the other hand, `scan` only allows a fixed `carry` structure, so adding/removing sub-states declared as carry will cause an error. +++ ## Transforming sub-states (lift types) Certain JAX transforms allow the use of pytree prefixes to specify how different parts of the inputs/outputs should be transformed. Flax NNX supports pytree prefixes for pytree structures but currently it doesn't have the notion of a prefix for graph objects. Instead, Flax NNX introduces the concept of “lift types” which allow specifying how different sub-states of an object should be transformed. Different transforms support different lift types, here is the list of currently supported FLax NNX lift types for each JAX transformation: | Lift type | JAX transforms | |------------------|-----------------------------------------| | `StateAxes` | `vmap`, `pmap`, `scan` | | `StateSharding` | `jit`, `shard_map` | | `DiffState` | `grad`, `value_and_grad`, `custom_vjp` | To specify how to vectorize different sub-states of an object in `nnx.vmap`, the Flax team created a `nnx.StateAxes`. `StateAxes` maps a set of sub-states via Flax NNX [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) to their corresponding axes, and you can pass the `nnx.StateAxes` to `in_axes` and `out_axes` as if it/they were a pytree prefix. Let's use the previous `stateful_vector_dot` example and vectorize only the `nnx.Param` variables and broadcast the `count` variable so we only keep a single count for all the batch elements. To do this we will define a `nnx.StateAxes` with a filter that matches the `nnx.Param` variables and maps them to axis `0`, and all the `Count` variables to `None`, and pass this `nnx.StateAxes` to `in_axes` for the `Weights` object. ```{code-cell} ipython3 class Weights(nnx.Module): def __init__(self, kernel: jax.Array, bias: jax.Array, count: jax.Array): self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias) self.count = Count(count) weights = Weights( kernel=random.uniform(random.key(0), (10, 2, 3)), bias=jnp.zeros((10, 3)), count=jnp.array(0), ) x = jax.random.normal(random.key(1), (10, 2)) def stateful_vector_dot(weights: Weights, x: jax.Array): assert weights.kernel.ndim == 2, 'Batch dimensions not allowed' assert x.ndim == 1, 'Batch dimensions not allowed' weights.count += 1 return x @ weights.kernel + weights.bias state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count y = nnx.vmap(stateful_vector_dot, in_axes=(state_axes, 0), out_axes=1)(weights, x) weights.count ``` Here, `count` is now a scalar since it's not being vectorized. Also, note that `nnx.StateAxes` can only be used directly on Flax NNX objects, and it cannot be used as a prefix for a pytree of objects. +++ ### Random state In Flax NNX, a random state is just a regular state. This means that it is stored inside `nnx.Module`s that need it, and it is treated as any other type of state. This is a simplification over Flax Linen, where a random state was handled by a separate mechanism. In practice `nnx.Module`s simply need to keep a reference to a `Rngs` object that is passed to them during initialization, and use it to generate a unique key for each random operation. For the purposes of this guide, this means that random state can be transformed like any other type of state but we also need to be aware of how the state is laid out so we can transform it correctly. Suppose you want to change things up a bit and apply the same weights to all elements in the batch. But you also want to add different random noise to each element. To do this, you will add an `Rngs` attribute to `Weights`, created from a `seed` key argument passed during construction. This seed key must be `split` beforehand, so that you can vectorize it successfully. For pedagogical reasons, you will assign the seed key to a `noise` “stream” and sample from it. To vectorize the PRNG state, you must configure `nnx.StateAxes` to map all `RngState`s (a base class for all variables in `Rngs`) to axis `0`, and `nnx.Param` and `Count` to `None`. ```{code-cell} ipython3 class Weights(nnx.Module): def __init__(self, kernel, bias, count, seed): self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias) self.count = Count(count) self.rngs = nnx.Rngs(noise=seed) weights = Weights( kernel=random.uniform(random.key(0), (2, 3)), bias=jnp.zeros((3,)), count=jnp.array(0), seed=random.split(random.key(0), num=10), ) x = random.normal(random.key(1), (10, 2)) def noisy_vector_dot(weights: Weights, x: jax.Array): assert weights.kernel.ndim == 2, 'Batch dimensions not allowed' assert x.ndim == 1, 'Batch dimensions not allowed' weights.count += 1 y = x @ weights.kernel + weights.bias return y + random.normal(weights.rngs.noise(), y.shape) state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None}) y1 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x) y2 = nnx.vmap(noisy_vector_dot, in_axes=(state_axes, 0))(weights, x) print(jnp.allclose(y1, y2)) nnx.display(weights) ``` Because `Rngs`'s state is updated in place and automatically propagated by `nnx.vmap`, we will get a different result every time that `noisy_vector_dot` is called. In the example above, you manually split the random state during construction. This is fine, as it makes the intention clear, but it also doesn't let you use `Rngs` outside of `nnx.vmap` because its state is always split. To solve this, you can pass an unsplit seed and use the `nnx.split_rngs` decorator before `nnx.vmap` to split the `RngState` right before each call to the function, and then "lower" it back so that it becomes usable. ```{code-cell} ipython3 weights = Weights( kernel=random.uniform(random.key(0), (2, 3)), bias=jnp.zeros((3,)), count=jnp.array(0), seed=0, ) x = random.normal(random.key(1), (10, 2)) state_axes = nnx.StateAxes({nnx.RngState: 0, (nnx.Param, Count): None}) @nnx.split_rngs(splits=10) @nnx.vmap(in_axes=(state_axes, 0)) def noisy_vector_dot(weights: Weights, x: jax.Array): assert weights.kernel.ndim == 2, 'Batch dimensions not allowed' assert x.ndim == 1, 'Batch dimensions not allowed' weights.count += 1 y = x @ weights.kernel + weights.bias return y + random.normal(weights.rngs.noise(), y.shape) y1 = noisy_vector_dot(weights, x) y2 = noisy_vector_dot(weights, x) print(jnp.allclose(y1, y2)) nnx.display(weights) ``` ## Rules and limitations In this section we will cover some rules and limitations apply when using Modules inside transformations. ### Mutable Module cannot be passed by closure While Python allows for passing objects as closures to functions, this is generally not supported by Flax NNX transforms. The reason is that because Modules are mutable it is very easy to capture tracer into a Module created outside of the transform, this is silent error in JAX. To avoid this, Flax NNX checks that the Modules and Variables being mutated are passed as arguments to the transformed function. For example, if we have a stateful Module such as `Counter` that increments a counter every time it is called, and we try to pass it as a closure to a function decorated with `nnx.jit`, we would be leaking the tracer. However Flax NNX will raise an error instead to prevent this: ```{code-cell} ipython3 class Counter(nnx.Module): def __init__(self): self.count = nnx.Param(jnp.array(0)) def increment(self): self.count += jnp.array(1) counter = Counter() @nnx.jit def f(x): counter.increment() return 2 * x try: y = f(3) except Exception as e: print(e) ``` To solve this issue pass all Module as arguments to the functions being transformed. In this case `f` should accept `counter` as an argument. +++ ### Consistent aliasing The main issue with allowing for reference semantics in transforms is that references can be shared across inputs and outputs. This can be problematic if it is not taken care of because it would lead to ill-defined or inconsistent behavior. In the example below you have a single `Weights` Module `m` whose reference appears in multiple places in `arg1` and `arg2`. The problem here is that you also specify that you want to vectorize `arg1` in axis `0` and `arg2` in axis `1`. This would be fine in JAX because of referential transparency of pytrees. But this would be problematic in Flax NNX because you are trying to vectorize `m` in two different ways. Flax NNX will enforce consistency by raising an error. ```{code-cell} ipython3 class Weights(nnx.Module): def __init__(self, array: jax.Array): self.param = nnx.Param(array) m = Weights(jnp.arange(10)) arg1 = {'a': {'b': m}, 'c': m} arg2 = [(m, m), m] @nnx.vmap(in_axes=(0, 1)) def f(arg1, arg2): ... try: f(arg1, arg2) except ValueError as e: print(e) ``` Inconsistent aliasing can also happen between inputs and outputs. In the next example you have a trivial function that accepts and immediately returns `arg1`. However, `arg1` is vectorized on axis `0` on the input, and axis `1` on the output. As expected, this is problematic and Flax NNX will raise an error. ```{code-cell} ipython3 @nnx.vmap(in_axes=0, out_axes=1) def f(arg1): return arg1 try: f(arg1) except ValueError as e: print(e) ``` ## Axis metadata Flax NNX `Variable`s can hold arbitrary metadata, which can be added by simply passing it as keyword arguments to its constructor. This is often used to store `sharding` information, as used by the `nnx.spmd` APIs (like `nnx.get_partition_spec` and `nnx.get_named_sharding`). However, it is often important to keep this axes-related information in sync to what the actual state of the axes is when transforms are involved. For example, if you vectorize a variable on axis `1`, you should remove the `sharding` information at position `1` when inside a `vmap` or `scan` to reflect the fact that the axes are temporarily removed. To achieve this, Flax NNX transforms provide a non-standard `transform_metadata` dictionary argument. And when the `nnx.PARTITION_NAME` key is present, the `sharding` metadata will be updated as specified by `in_axes` and `out_axes`. Let's see an example of this in action: ```{code-cell} ipython3 mesh = jax.make_mesh((1, 1), ('a', 'b')) class Weights(nnx.Module): def __init__(self, array: jax.Array, out_sharding: tuple[str | None, ...]): self.param = nnx.Param(array, out_sharding=out_sharding) @nnx.vmap(in_axes=1, transform_metadata={nnx.PARTITION_NAME: 'b'}) def f(m: Weights): print(f'Inner {m.param.shape = }') print(f'Inner {m.param.out_sharding = }') with jax.set_mesh(mesh): m = Weights(jnp.ones((3, 4, 5)), out_sharding=('a', 'b', None)) f(m) print(f'Outter {m.param.shape = }') print(f'Outter {m.param.out_sharding = }') ``` Here, you added a `sharding` metadata to the `nnx.Param` variables, and used `transform_metadata` to update the `sharding` metadata to reflect the axis changes. Specifically, you can see that the first axis `b` was removed from the `sharding` metadata when inside of `nnx.vmap`, and then added back when outside of `nnx.vmap`. You can verify that this also works when `nnx.Module`s are created inside the transformation - the new `sharding` axes will be added to the `nnx.Module` `nnx.Variable`s outside the transformation, matching the axes of the transformed `nnx.Variable`s. ```{code-cell} ipython3 @nnx.vmap(out_axes=1, axis_size=4, transform_metadata={nnx.PARTITION_NAME: 'b'}) def init_vmap(): return Weights(jnp.ones((3, 5)), out_sharding=('a', None)) with jax.set_mesh(mesh): m = init_vmap() print(f'Outter {m.param.shape = }') print(f'Outter {m.param.out_sharding = }') ``` ================================================ FILE: docs_nnx/guides/view.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "75afc9f3", "metadata": {}, "source": [ "# Model Views\n", "This guide covers how to use the `nnx.view` function. This function is useful for handling state in layers like `Dropout` and `BatchNorm`, which behave differently in training and evaluation. Similar to `.view` for numpy arrays, `nnx.view` allows you to set modes of the model while still sharing the same data. For a quick intro to how this function works, refer to the following example:" ] }, { "cell_type": "code", "execution_count": null, "id": "8e333aab", "metadata": {}, "outputs": [], "source": [ "from flax import nnx\n", "\n", "# example model with different train/eval behavior\n", "rngs = nnx.Rngs(0)\n", "model = nnx.Sequential(\n", " nnx.Linear(2, 4, rngs=rngs), nnx.BatchNorm(4, rngs=rngs), nnx.Dropout(0.1)\n", ")\n", "\n", "# set train and eval modes\n", "train_model = nnx.view(model, deterministic=False, use_running_average=False)\n", "eval_model = nnx.view(model, deterministic=True, use_running_average=True)\n", "\n", "# Can see deterministic is different between train_model and eval_model\n", "assert train_model.layers[2].deterministic == False\n", "assert eval_model.layers[2].deterministic == True\n", "\n", "# Weights are shared between the models\n", "assert train_model.layers[0].kernel is eval_model.layers[0].kernel\n", "\n", "# Print information about kwargs for nnx.view with nnx.view_info\n", "print(nnx.view_info(model))" ] }, { "cell_type": "markdown", "id": "f70f9353", "metadata": {}, "source": [ "## Motivation\n", "\n", "Some layers in ML inherently involve state. Consider for example the `nnx.Dropout` layer, which behaves differently during training and evaluation. In these different scenarios, we need a simple way to ensure that the model behaves as intended to avoid silent bugs. A common pattern in other frameworks is to mutate a single `model` object to switch between training and evaluation modes. This requires the programmer to remember to toggle modes in many places throughout the code, which can hurt readability and lead to subtle bugs when a mode switch is forgotten.\n", "\n", "`nnx.view` offers a cleaner alternative: you declare the different model configurations once at the beginning of your code and then simply use the appropriate view wherever needed. Each view shares the same underlying weights, so parameter updates are automatically reflected across all views. We demonstrate this with a simple example below." ] }, { "cell_type": "code", "execution_count": null, "id": "886c7479", "metadata": {}, "outputs": [], "source": [ "import jax\n", "import optax\n", "import matplotlib.pyplot as plt\n", "\n", "in_dim, hidden_dim, out_dim = 16, 32, 2\n", "\n", "\n", "class MyModel(nnx.Module):\n", " def __init__(\n", " self,\n", " in_dim: int,\n", " hidden_dim: int,\n", " out_dim: int,\n", " dropout_rate: float,\n", " *,\n", " rngs: nnx.Rngs,\n", " ):\n", " self.lin1 = nnx.Linear(in_dim, hidden_dim, rngs=rngs)\n", " self.do = nnx.Dropout(dropout_rate)\n", " self.bn = nnx.BatchNorm(hidden_dim, rngs=rngs)\n", " self.lin2 = nnx.Linear(hidden_dim, out_dim, rngs=rngs)\n", "\n", " def __call__(self, x, *, rngs=None):\n", " x = nnx.relu(self.do(self.bn(self.lin1(x)), rngs=rngs))\n", " return self.lin2(x)" ] }, { "cell_type": "markdown", "id": "0568e6ce", "metadata": {}, "source": [ "Lets take a look at the model to see what is going on." ] }, { "cell_type": "code", "execution_count": null, "id": "da2c8f80", "metadata": {}, "outputs": [], "source": [ "# can display to inspect state\n", "model = MyModel(in_dim, hidden_dim, out_dim, 0.1, rngs=nnx.Rngs(0))\n", "nnx.display(model)\n", "\n", "# can assert to inspect state\n", "assert model.do.deterministic == False" ] }, { "cell_type": "markdown", "id": "1a075c12", "metadata": {}, "source": [ "From the model display, we can see that `Dropout` has `deterministic == False`, suggesting that the model is in training mode. In order to know this, we had to display the model and/or know that `Dropout` is set to training mode by default. It is not clear what state the model is in just by looking at the code without additional inspection. We instead want to be very explicit about what state the model is in. \n", "\n", "This is where `nnx.view` comes in. This function updates the modes for each submodule of a neural network based on the kwargs passed into the function. The underlying model weights are then shared between different views. We set up a training and evaluation version of the model below." ] }, { "cell_type": "code", "execution_count": null, "id": "11e59178", "metadata": {}, "outputs": [], "source": [ "train_model = nnx.view(model, deterministic=False)\n", "eval_model = nnx.view(model, deterministic=True)\n", "\n", "# weights are references to the same data\n", "assert train_model.lin1.kernel is eval_model.lin1.kernel\n", "\n", "# Dropout.deterministic is different in each model\n", "assert train_model.do.deterministic is False\n", "assert eval_model.do.deterministic is True" ] }, { "cell_type": "markdown", "id": "5c1ee1db", "metadata": {}, "source": [ "## Example with `nnx.view`" ] }, { "cell_type": "markdown", "id": "e35d1bfd", "metadata": {}, "source": [ "We first set up data generators and define train/eval step functions. The `train_step` receives an `nnx.Rngs` object for dropout randomness, while `eval_step` doesn't since dropout is disabled in `eval_model`." ] }, { "cell_type": "code", "execution_count": null, "id": "a0f72d8d", "metadata": {}, "outputs": [], "source": [ "ndata, batch_size, total_epochs, lr = 2048, 32, 100, 1e-3\n", "rngs = nnx.Rngs(0)\n", "x = rngs.normal((ndata, in_dim))\n", "y = rngs.normal((ndata, out_dim))\n", "\n", "\n", "@nnx.jit\n", "def train_step(model, optimizer, x, y, rngs):\n", " def loss_fn(model, rngs):\n", " return ((model(x, rngs=rngs) - y) ** 2).mean()\n", "\n", " grads = nnx.grad(loss_fn)(model, rngs)\n", " optimizer.update(model, grads)\n", "\n", "\n", "@nnx.jit\n", "def eval_step(model, x, y):\n", " return ((model(x) - y) ** 2).mean()" ] }, { "cell_type": "markdown", "id": "70c05c4d", "metadata": {}, "source": [ "Now we create `train_model` and `eval_model` views up front. During the training loop we simply use the appropriate view — there is no need to call `.train()` or `.eval()`, and it is always clear from the code which mode the model is in." ] }, { "cell_type": "code", "execution_count": null, "id": "175db8e7", "metadata": {}, "outputs": [], "source": [ "model = MyModel(in_dim, hidden_dim, out_dim, 0.1, rngs=rngs)\n", "optimizer = nnx.Optimizer(model, optax.adam(lr), wrt=nnx.Param)\n", "train_model = nnx.view(model, deterministic=False) # training view\n", "eval_model = nnx.view(model, deterministic=True) # eval view\n", "\n", "eval_results = []\n", "for epoch in range(total_epochs):\n", " for i in range(ndata // batch_size):\n", " idx = slice(i * batch_size, (i + 1) * batch_size)\n", " train_step(train_model, optimizer, x[idx], y[idx], rngs) # use train_model\n", "\n", " eval_results.append(eval_step(eval_model, x, y)) # use eval_model\n", "plt.plot(eval_results)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "3d666cdb", "metadata": {}, "source": [ "## Getting information with `nnx.view_info`\n", "To see more information about the options for `nnx.view`, we can use the `nnx.view_info` function to display information about the arguments. This will display each submodule which contains a `set_view` method. It also provides information about the keyword arguments accepted by each submodule, including type information, default values, and docstring descriptions." ] }, { "cell_type": "code", "execution_count": null, "id": "54a65a31", "metadata": {}, "outputs": [], "source": [ "print(nnx.view_info(model))" ] }, { "cell_type": "markdown", "id": "47479be6", "metadata": {}, "source": [ "## Writing modules compatible with `nnx.view`\n", "\n", "You can make any custom module work with `nnx.view` by defining a `set_view` method. When `nnx.view` is called, it traverses the module tree and calls `set_view` on every submodule that defines one. `nnx.view` inspects the signature of each `set_view` method and only passes the keyword arguments that match the method's declared parameters. This means each module only receives the kwargs it cares about.\n", "\n", "Your `set_view` method should follow these conventions:\n", "\n", "1. **Accept keyword arguments with `None` defaults.** Each kwarg represents a configurable mode for this module. A `None` default means \"leave unchanged\", so views only override the modes you explicitly set.\n", "2. **Only update the attribute when the kwarg is not `None`.** This ensures that unrelated views don't accidentally reset each other's settings.\n", "3. **Include a Google-style docstring.** The `nnx.view_info` function parses these docstrings to display human-readable information about available view options.\n", "\n", "The general pattern looks like this:\n", "\n", "```python\n", "class MyLayer(nnx.Module):\n", " ...\n", "\n", " def set_view(self, kwarg1: type1 = None, ..., kwargN: typeN = None):\n", " \"\"\"Description of the module's configurable modes.\n", "\n", " Args:\n", " kwarg1: description of kwarg1.\n", " ...\n", " kwargN: description of kwargN.\n", " \"\"\"\n", " if kwarg1 is not None:\n", " self.kwarg1 = kwarg1\n", " ...\n", "```\n", "\n", "Here is a concrete example — a `PrintLayer` that can be toggled to print a message during its forward pass:" ] }, { "cell_type": "code", "execution_count": null, "id": "2dfdfd64", "metadata": {}, "outputs": [], "source": [ "class PrintLayer(nnx.Module):\n", " def __init__(self, msg: str | None = None):\n", " self.msg = msg\n", "\n", " def __call__(self, *args, **kwargs):\n", " if self.msg:\n", " print(self.msg)\n", "\n", " def set_view(self, msg: bool | None = None):\n", " \"\"\"Example set_view docstring. This follows Google style docstrings.\n", "\n", " Args:\n", " msg: bool indicating if a message should be printed.\n", " If True, the `__call__` method prints the message.\n", " \"\"\"\n", " if msg is not None:\n", " self.msg = msg\n", "\n", "\n", "model = PrintLayer()\n", "model_print = nnx.view(model, msg='Hello, World!')\n", "\n", "model() # nothing printed\n", "model_print() # prints \"Hello, World!\"" ] }, { "cell_type": "markdown", "id": "c7b261b8", "metadata": {}, "source": [ "We can use `nnx.view_info` to inspect what view options `PrintLayer` exposes. This is especially handy when working with unfamiliar models — it lists every submodule that defines `set_view`, along with the accepted kwargs, their types, defaults, and docstring descriptions." ] }, { "cell_type": "code", "execution_count": null, "id": "a5e3bc03", "metadata": {}, "outputs": [], "source": [ "# Display the information for nnx.view\n", "print(nnx.view_info(model))" ] }, { "cell_type": "markdown", "id": "1acbcc09", "metadata": {}, "source": [ "The output shows that `PrintLayer` accepts a `msg` kwarg of type `bool` in its `set_view` method. When building larger models composed of many custom submodules, `nnx.view_info` gives you a quick summary of all the configurable modes across the entire module tree.\n", "\n", "## Using `with_attributes`\n", "\n", "If you are working with modules that don't implement the `set_view` API, you can use {func}`nnx.with_attributes ` to create views by directly replacing their attributes. Like `nnx.view`, it returns a new instance that shares jax arrays with the original, leaving the original unchanged." ] }, { "cell_type": "code", "execution_count": null, "id": "62b5c185", "metadata": {}, "outputs": [], "source": [ "class NoisyLinear(nnx.Module):\n", " def __init__(self, din, dout, *, training=None, rngs: nnx.Rngs):\n", " self.linear = nnx.Linear(din, dout, rngs=rngs)\n", " self.training = training\n", "\n", " def __call__(self, x, rngs=None):\n", " assert self.training is not None\n", " x = self.linear(x)\n", " if self.training:\n", " x = x + rngs.normal(x.shape) * 0.1\n", " return x\n", "\n", "rngs = nnx.Rngs(0)\n", "model = nnx.Sequential(\n", " NoisyLinear(4, 8, rngs=rngs),\n", " NoisyLinear(8, 2, rngs=rngs),\n", ")\n", "\n", "train_model = nnx.with_attributes(model, training=True)\n", "eval_model = nnx.with_attributes(model, training=False)\n", "\n", "print(f'{train_model.layer1.training=}')\n", "y1 = train_model(jnp.ones((1, 4)), rngs=rngs)\n", "\n", "print(f'{eval_model.layer1.training=}')\n", "y2 = eval_model(jnp.ones((1, 4)))" ] }, { "cell_type": "markdown", "id": "0cc37a57", "metadata": {}, "source": [ "## Using `recursive_map`\n", "\n", "For more advanced transformations — such as replacing submodules — you can use {func}`nnx.recursive_map `. This function traverses the entire module tree bottom-up, calling a user-defined function `f(path, node)` on every node and leaf. Whatever `f` returns is used as the replacement for that node in the new tree. The resulting model view shares the Variables with the original (unless instructed otherwise).\n", "\n", "In the example below, we use `recursive_map` to replace every `nnx.Linear` layer with a `NoisyLinear` version (reusing the class defined earlier) that adds random noise during training:" ] }, { "cell_type": "code", "execution_count": null, "id": "2e77b49d", "metadata": {}, "outputs": [], "source": [ "import jax.numpy as jnp\n", "\n", "def add_noise(path, node):\n", " if isinstance(node, nnx.Linear):\n", " noisy = nnx.eval_shape(\n", " lambda: NoisyLinear(node.in_features, node.out_features, rngs=nnx.Rngs(0))\n", " )\n", " noisy.linear = node\n", " return noisy\n", " return node\n", "\n", "rngs = nnx.Rngs(0)\n", "model = nnx.Sequential(\n", " nnx.Linear(4, 8, rngs=rngs),\n", " nnx.Linear(8, 2, rngs=rngs),\n", ")\n", "\n", "noisy_model = nnx.recursive_map(add_noise, model)\n", "\n", "y = noisy_model(jnp.ones((1, 4)), rngs=rngs)\n", "print(noisy_model)s" ] }, { "cell_type": "markdown", "id": "bf521e45", "metadata": {}, "source": [ "Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`." ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: docs_nnx/guides/view.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # Model Views This guide covers how to use the `nnx.view` function. This function is useful for handling state in layers like `Dropout` and `BatchNorm`, which behave differently in training and evaluation. Similar to `.view` for numpy arrays, `nnx.view` allows you to set modes of the model while still sharing the same data. For a quick intro to how this function works, refer to the following example: ```{code-cell} from flax import nnx # example model with different train/eval behavior rngs = nnx.Rngs(0) model = nnx.Sequential( nnx.Linear(2, 4, rngs=rngs), nnx.BatchNorm(4, rngs=rngs), nnx.Dropout(0.1) ) # set train and eval modes train_model = nnx.view(model, deterministic=False, use_running_average=False) eval_model = nnx.view(model, deterministic=True, use_running_average=True) # Can see deterministic is different between train_model and eval_model assert train_model.layers[2].deterministic == False assert eval_model.layers[2].deterministic == True # Weights are shared between the models assert train_model.layers[0].kernel is eval_model.layers[0].kernel # Print information about kwargs for nnx.view with nnx.view_info print(nnx.view_info(model)) ``` ## Motivation Some layers in ML inherently involve state. Consider for example the `nnx.Dropout` layer, which behaves differently during training and evaluation. In these different scenarios, we need a simple way to ensure that the model behaves as intended to avoid silent bugs. A common pattern in other frameworks is to mutate a single `model` object to switch between training and evaluation modes. This requires the programmer to remember to toggle modes in many places throughout the code, which can hurt readability and lead to subtle bugs when a mode switch is forgotten. `nnx.view` offers a cleaner alternative: you declare the different model configurations once at the beginning of your code and then simply use the appropriate view wherever needed. Each view shares the same underlying weights, so parameter updates are automatically reflected across all views. We demonstrate this with a simple example below. ```{code-cell} import jax import optax import matplotlib.pyplot as plt in_dim, hidden_dim, out_dim = 16, 32, 2 class MyModel(nnx.Module): def __init__( self, in_dim: int, hidden_dim: int, out_dim: int, dropout_rate: float, *, rngs: nnx.Rngs, ): self.lin1 = nnx.Linear(in_dim, hidden_dim, rngs=rngs) self.do = nnx.Dropout(dropout_rate) self.bn = nnx.BatchNorm(hidden_dim, rngs=rngs) self.lin2 = nnx.Linear(hidden_dim, out_dim, rngs=rngs) def __call__(self, x, *, rngs=None): x = nnx.relu(self.do(self.bn(self.lin1(x)), rngs=rngs)) return self.lin2(x) ``` Lets take a look at the model to see what is going on. ```{code-cell} # can display to inspect state model = MyModel(in_dim, hidden_dim, out_dim, 0.1, rngs=nnx.Rngs(0)) nnx.display(model) # can assert to inspect state assert model.do.deterministic == False ``` From the model display, we can see that `Dropout` has `deterministic == False`, suggesting that the model is in training mode. In order to know this, we had to display the model and/or know that `Dropout` is set to training mode by default. It is not clear what state the model is in just by looking at the code without additional inspection. We instead want to be very explicit about what state the model is in. This is where `nnx.view` comes in. This function updates the modes for each submodule of a neural network based on the kwargs passed into the function. The underlying model weights are then shared between different views. We set up a training and evaluation version of the model below. ```{code-cell} train_model = nnx.view(model, deterministic=False) eval_model = nnx.view(model, deterministic=True) # weights are references to the same data assert train_model.lin1.kernel is eval_model.lin1.kernel # Dropout.deterministic is different in each model assert train_model.do.deterministic is False assert eval_model.do.deterministic is True ``` ## Example with `nnx.view` +++ We first set up data generators and define train/eval step functions. The `train_step` receives an `nnx.Rngs` object for dropout randomness, while `eval_step` doesn't since dropout is disabled in `eval_model`. ```{code-cell} ndata, batch_size, total_epochs, lr = 2048, 32, 100, 1e-3 rngs = nnx.Rngs(0) x = rngs.normal((ndata, in_dim)) y = rngs.normal((ndata, out_dim)) @nnx.jit def train_step(model, optimizer, x, y, rngs): def loss_fn(model, rngs): return ((model(x, rngs=rngs) - y) ** 2).mean() grads = nnx.grad(loss_fn)(model, rngs) optimizer.update(model, grads) @nnx.jit def eval_step(model, x, y): return ((model(x) - y) ** 2).mean() ``` Now we create `train_model` and `eval_model` views up front. During the training loop we simply use the appropriate view — there is no need to call `.train()` or `.eval()`, and it is always clear from the code which mode the model is in. ```{code-cell} model = MyModel(in_dim, hidden_dim, out_dim, 0.1, rngs=rngs) optimizer = nnx.Optimizer(model, optax.adam(lr), wrt=nnx.Param) train_model = nnx.view(model, deterministic=False) # training view eval_model = nnx.view(model, deterministic=True) # eval view eval_results = [] for epoch in range(total_epochs): for i in range(ndata // batch_size): idx = slice(i * batch_size, (i + 1) * batch_size) train_step(train_model, optimizer, x[idx], y[idx], rngs) # use train_model eval_results.append(eval_step(eval_model, x, y)) # use eval_model plt.plot(eval_results) plt.show() ``` ## Getting information with `nnx.view_info` To see more information about the options for `nnx.view`, we can use the `nnx.view_info` function to display information about the arguments. This will display each submodule which contains a `set_view` method. It also provides information about the keyword arguments accepted by each submodule, including type information, default values, and docstring descriptions. ```{code-cell} print(nnx.view_info(model)) ``` ## Writing modules compatible with `nnx.view` You can make any custom module work with `nnx.view` by defining a `set_view` method. When `nnx.view` is called, it traverses the module tree and calls `set_view` on every submodule that defines one. `nnx.view` inspects the signature of each `set_view` method and only passes the keyword arguments that match the method's declared parameters. This means each module only receives the kwargs it cares about. Your `set_view` method should follow these conventions: 1. **Accept keyword arguments with `None` defaults.** Each kwarg represents a configurable mode for this module. A `None` default means "leave unchanged", so views only override the modes you explicitly set. 2. **Only update the attribute when the kwarg is not `None`.** This ensures that unrelated views don't accidentally reset each other's settings. 3. **Include a Google-style docstring.** The `nnx.view_info` function parses these docstrings to display human-readable information about available view options. The general pattern looks like this: ```python class MyLayer(nnx.Module): ... def set_view(self, kwarg1: type1 = None, ..., kwargN: typeN = None): """Description of the module's configurable modes. Args: kwarg1: description of kwarg1. ... kwargN: description of kwargN. """ if kwarg1 is not None: self.kwarg1 = kwarg1 ... ``` Here is a concrete example — a `PrintLayer` that can be toggled to print a message during its forward pass: ```{code-cell} class PrintLayer(nnx.Module): def __init__(self, msg: str | None = None): self.msg = msg def __call__(self, *args, **kwargs): if self.msg: print(self.msg) def set_view(self, msg: bool | None = None): """Example set_view docstring. This follows Google style docstrings. Args: msg: bool indicating if a message should be printed. If True, the `__call__` method prints the message. """ if msg is not None: self.msg = msg model = PrintLayer() model_print = nnx.view(model, msg='Hello, World!') model() # nothing printed model_print() # prints "Hello, World!" ``` We can use `nnx.view_info` to inspect what view options `PrintLayer` exposes. This is especially handy when working with unfamiliar models — it lists every submodule that defines `set_view`, along with the accepted kwargs, their types, defaults, and docstring descriptions. ```{code-cell} # Display the information for nnx.view print(nnx.view_info(model)) ``` The output shows that `PrintLayer` accepts a `msg` kwarg of type `bool` in its `set_view` method. When building larger models composed of many custom submodules, `nnx.view_info` gives you a quick summary of all the configurable modes across the entire module tree. ## Using `with_attributes` If you are working with modules that don't implement the `set_view` API, you can use {func}`nnx.with_attributes ` to create views by directly replacing their attributes. Like `nnx.view`, it returns a new instance that shares jax arrays with the original, leaving the original unchanged. ```{code-cell} class NoisyLinear(nnx.Module): def __init__(self, din, dout, *, training=None, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.training = training def __call__(self, x, rngs=None): assert self.training is not None x = self.linear(x) if self.training: x = x + rngs.normal(x.shape) * 0.1 return x rngs = nnx.Rngs(0) model = nnx.Sequential( NoisyLinear(4, 8, rngs=rngs), NoisyLinear(8, 2, rngs=rngs), ) train_model = nnx.with_attributes(model, training=True) eval_model = nnx.with_attributes(model, training=False) print(f'{train_model.layer1.training=}') y1 = train_model(jnp.ones((1, 4)), rngs=rngs) print(f'{eval_model.layer1.training=}') y2 = eval_model(jnp.ones((1, 4))) ``` ## Using `recursive_map` For more advanced transformations — such as replacing submodules — you can use {func}`nnx.recursive_map `. This function traverses the entire module tree bottom-up, calling a user-defined function `f(path, node)` on every node and leaf. Whatever `f` returns is used as the replacement for that node in the new tree. The resulting model view shares the Variables with the original (unless instructed otherwise). In the example below, we use `recursive_map` to replace every `nnx.Linear` layer with a `NoisyLinear` version (reusing the class defined earlier) that adds random noise during training: ```{code-cell} import jax.numpy as jnp def add_noise(path, node): if isinstance(node, nnx.Linear): noisy = nnx.eval_shape( lambda: NoisyLinear(node.in_features, node.out_features, rngs=nnx.Rngs(0)) ) noisy.linear = node return noisy return node rngs = nnx.Rngs(0) model = nnx.Sequential( nnx.Linear(4, 8, rngs=rngs), nnx.Linear(8, 2, rngs=rngs), ) noisy_model = nnx.recursive_map(add_noise, model) y = noisy_model(jnp.ones((1, 4)), rngs=rngs) print(noisy_model)s ``` Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`. ================================================ FILE: docs_nnx/guides_advanced.rst ================================================ Advanced Guides ====== .. toctree:: :maxdepth: 1 :caption: Advanced guides/flax_gspmd guides/performance guides/bridge_guide guides/surgery guides/extracting_intermediates ================================================ FILE: docs_nnx/guides_basic.rst ================================================ Basic Guides ====== .. toctree:: :maxdepth: 1 :caption: Basic guides/pytree guides/transforms guides/view guides/filters_guide guides/randomness guides/checkpointing guides/jax_and_nnx_transforms ================================================ FILE: docs_nnx/hijax/hijax.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "15c2d208", "metadata": {}, "source": [ "# Hijax" ] }, { "cell_type": "code", "execution_count": 1, "id": "99809892", "metadata": {}, "outputs": [], "source": [ "from flax import nnx\n", "import jax\n", "import jax.numpy as jnp\n", "import optax\n", "\n", "current_mode = nnx.var_defaults().hijax # ignore: only needed for testing" ] }, { "cell_type": "code", "execution_count": 2, "id": "d1aaa0ec", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.85250294\n", "0.8165137\n", "0.7814907\n" ] } ], "source": [ "nnx.var_defaults(hijax=True)\n", "\n", "rngs = nnx.Rngs(0)\n", "model = nnx.Linear(2, 3, rngs=rngs)\n", "optimizer = nnx.Optimizer(model, optax.adamw(1e-2), wrt=nnx.Param)\n", "\n", "@jax.jit\n", "def train_step(x, y):\n", " loss_fn = lambda m: jnp.mean((m(x) - y) ** 2)\n", " loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(model, mutable=False)) # tmp fix for jax.grad\n", " optimizer.update(model, grads)\n", " return loss\n", "\n", "x, y = rngs.uniform((4, 2)), rngs.uniform((4, 3))\n", "for _ in range(3):\n", " print(train_step(x, y))" ] }, { "cell_type": "markdown", "id": "04458d66", "metadata": {}, "source": [ "## Hijax Variable" ] }, { "cell_type": "markdown", "id": "f4220c6f", "metadata": {}, "source": [ "State propagation:" ] }, { "cell_type": "code", "execution_count": 3, "id": "396a07a3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0\n", "1\n" ] } ], "source": [ "v = nnx.Variable(jnp.array(0), hijax=True)\n", "\n", "@jax.jit\n", "def inc(v):\n", " v[...] += 1\n", "\n", "print(v[...]); inc(v); print(v[...])" ] }, { "cell_type": "code", "execution_count": 4, "id": "2ab7d801", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{ \u001b[34;1mlambda \u001b[39;22m; a\u001b[35m:Variable()\u001b[39m. \u001b[34;1mlet\n", " \u001b[39;22mjit[\n", " name=inc\n", " jaxpr={ \u001b[34;1mlambda \u001b[39;22m; a\u001b[35m:Variable()\u001b[39m. \u001b[34;1mlet\n", " \u001b[39;22mb\u001b[35m:i32[]\u001b[39m = get_variable[\n", " avals=(ShapedArray(int32[], weak_type=True),)\n", " has_qdd=True\n", " treedef=PyTreeDef(CustomNode(Variable[(('eager_sharding', True), ('hijax', True), ('mutable', True), ('ref', False))], [*]))\n", " var_type=\n", " ] a\n", " c\u001b[35m:i32[]\u001b[39m = add b 1:i32[]\n", " _\u001b[35m:i32[]\u001b[39m = get_variable[\n", " avals=(ShapedArray(int32[], weak_type=True),)\n", " has_qdd=True\n", " treedef=PyTreeDef(CustomNode(Variable[(('eager_sharding', True), ('hijax', True), ('mutable', True), ('ref', False))], [*]))\n", " var_type=\n", " ] a\n", " set_variable[\n", " treedef=PyTreeDef(CustomNode(Variable[(('eager_sharding', True), ('hijax', True), ('mutable', True), ('ref', False))], [*]))\n", " var_type=\n", " ] a c\n", " \u001b[34;1min \u001b[39;22m() }\n", " ] a\n", " \u001b[34;1min \u001b[39;22m() }\n" ] } ], "source": [ "v = nnx.Variable(jnp.array(0), hijax=True)\n", "print(jax.make_jaxpr(inc)(v))" ] }, { "cell_type": "markdown", "id": "39070460", "metadata": {}, "source": [ "Pytree values:" ] }, { "cell_type": "code", "execution_count": 5, "id": "fcd0de3f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 2 (8 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;207;144;120m'a'\u001b[0m: Array(0, dtype=int32, weak_type=True), \u001b[38;2;207;144;120m'b'\u001b[0m: Array(2, dtype=int32, weak_type=True)\u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", "\u001b[38;2;255;213;3m)\u001b[0m\n", "\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 2 (8 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;207;144;120m'a'\u001b[0m: Array(1, dtype=int32, weak_type=True), \u001b[38;2;207;144;120m'b'\u001b[0m: Array(4, dtype=int32, weak_type=True)\u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", "\u001b[38;2;255;213;3m)\u001b[0m\n" ] } ], "source": [ "v = nnx.Variable({'a': jnp.array(0), 'b': jnp.array(2)}, hijax=True)\n", "\n", "@jax.jit\n", "def inc_and_double(v):\n", " v['a'] += 1\n", " v['b'] *= 2\n", "\n", "print(v); inc_and_double(v); print(v)" ] }, { "cell_type": "markdown", "id": "f0cfe954", "metadata": {}, "source": [ "Dynamic state structure:" ] }, { "cell_type": "code", "execution_count": 6, "id": "0d83a130", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Before: \u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", "\u001b[38;2;255;213;3m)\u001b[0m\n", "After: \u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m{\u001b[0m\u001b[38;2;207;144;120m'y_mean'\u001b[0m: Array(-1.1782329, dtype=float32)\u001b[38;2;255;213;3m}\u001b[0m,\n", " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", "\u001b[38;2;255;213;3m)\u001b[0m\n" ] } ], "source": [ "rngs = nnx.Rngs(0)\n", "x = rngs.uniform((4, 5))\n", "w = rngs.normal((5, 3))\n", "metrics = nnx.Variable({}, hijax=True)\n", "\n", "@jax.jit\n", "def linear(x, w, metrics: nnx.Variable):\n", " y = x @ w\n", " metrics['y_mean'] = jnp.mean(y)\n", " return y\n", "\n", "print(\"Before:\", metrics)\n", "y = linear(x, w, metrics)\n", "print(\"After:\", metrics)" ] }, { "cell_type": "code", "execution_count": 7, "id": "0a55df94", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([1, 2, 3], dtype=int32),\n", " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", "\u001b[38;2;255;213;3m)\u001b[0m\n" ] } ], "source": [ "# set default Variable mode for the rest of the guide\n", "nnx.var_defaults(hijax=True)\n", "\n", "variable = nnx.Variable(jnp.array([1, 2, 3]))\n", "\n", "print(variable)" ] }, { "cell_type": "markdown", "id": "1b2632f1", "metadata": {}, "source": [ "### Mutability" ] }, { "cell_type": "code", "execution_count": 8, "id": "b7b1f421", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "nnx.vars_as(model, mutable=False) = \u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Param: 3 (12 B)\u001b[0m\n", " \u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m1\u001b[0m, \u001b[38;2;182;207;169m3\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m,\n", " \u001b[38;2;156;220;254mmutable\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mFalse\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", "\u001b[38;2;255;213;3m)\u001b[0m\n", "nnx.vars_as(model, mutable=True) = \u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Param: 3 (12 B)\u001b[0m\n", " \u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m1\u001b[0m, \u001b[38;2;182;207;169m3\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", "\u001b[38;2;255;213;3m)\u001b[0m\n" ] } ], "source": [ "class Linear(nnx.Module):\n", " def __init__(self, in_features, out_features, rngs: nnx.Rngs):\n", " self.kernel = nnx.Param(rngs.normal((in_features, out_features)))\n", "\n", " def __call__(self, x):\n", " return x @ self.kernel\n", "\n", "model = Linear(1, 3, rngs=nnx.Rngs(0))\n", "\n", "print(f\"{nnx.vars_as(model, mutable=False) = !s}\")\n", "print(f\"{nnx.vars_as(model, mutable=True) = !s}\")" ] }, { "cell_type": "code", "execution_count": 9, "id": "594cb65e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ImmutableVariableError: Cannot mutate Variable as it is marked as immutable. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ImmutableVariableError)\n" ] } ], "source": [ "v = nnx.Variable(jnp.array(0))\n", "v_immut = nnx.vars_as(v, mutable=False)\n", "assert not v_immut.mutable\n", "\n", "try:\n", " v_immut[...] += 1 # raises an error\n", "except Exception as e:\n", " print(f\"{type(e).__name__}: {e}\")" ] }, { "cell_type": "markdown", "id": "58692a37", "metadata": {}, "source": [ "### Ref support" ] }, { "cell_type": "code", "execution_count": 10, "id": "fcd4fb4f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(0, dtype=int32, weak_type=True),\n", " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m,\n", " \u001b[38;2;156;220;254mref\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", "\u001b[38;2;255;213;3m)\u001b[0m\n", "Ref(0, dtype=int32, weak_type=True)\n" ] } ], "source": [ "v = nnx.Variable(jnp.array(0))\n", "v_ref = nnx.vars_as(v, ref=True)\n", "assert v_ref.ref\n", "print(v_ref)\n", "print(v_ref.get_raw_value())" ] }, { "cell_type": "code", "execution_count": 11, "id": "18256668", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "immutable = \u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(0, dtype=int32, weak_type=True),\n", " \u001b[38;2;156;220;254mhad_ref\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m,\n", " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m,\n", " \u001b[38;2;156;220;254mmutable\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mFalse\u001b[0m\n", "\u001b[38;2;255;213;3m)\u001b[0m\n", "mutable = \u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(0, dtype=int32, weak_type=True),\n", " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m,\n", " \u001b[38;2;156;220;254mref\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", "\u001b[38;2;255;213;3m)\u001b[0m\n" ] } ], "source": [ "v_immut = nnx.vars_as(v_ref, mutable=False)\n", "assert not v_immut.ref\n", "print(\"immutable =\", v_immut)\n", "\n", "v_ref = nnx.vars_as(v_immut, mutable=True)\n", "assert v_ref.ref\n", "print(\"mutable =\", v_ref)" ] }, { "cell_type": "markdown", "id": "f4e35e75", "metadata": {}, "source": [ "### Examples" ] }, { "cell_type": "code", "execution_count": 12, "id": "5400fe58", "metadata": {}, "outputs": [], "source": [ "class Block(nnx.Module):\n", " def __init__(self, din, dmid, dout, rngs: nnx.Rngs):\n", " self.linear = Linear(din, dmid, rngs=rngs)\n", " self.bn = nnx.BatchNorm(dmid, rngs=rngs)\n", " self.dropout = nnx.Dropout(0.1, rngs=rngs)\n", " self.linear_out = Linear(dmid, dout, rngs=rngs)\n", "\n", " def __call__(self, x):\n", " x = nnx.gelu(self.dropout(self.bn(self.linear(x))))\n", " return self.linear_out(x)" ] }, { "cell_type": "markdown", "id": "ba980b6b", "metadata": {}, "source": [ "#### Training Loop" ] }, { "cell_type": "code", "execution_count": 13, "id": "566c4249", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss = 1.000178\n", "loss = 0.9700456\n", "loss = 0.93967044\n" ] } ], "source": [ "# hijax Variables by default\n", "model = Block(2, 64, 3, rngs=nnx.Rngs(0))\n", "optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", "\n", "@jax.jit\n", "def train_step(model, optimizer, x, y):\n", " graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)\n", " def loss_fn(params):\n", " model = nnx.merge(graphdef, params, nondiff)\n", " return ((model(x) - y) ** 2).mean()\n", "\n", " loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, mutable=False)) # immutable for jax.grad\n", " optimizer.update(model, grads)\n", "\n", " return loss\n", "\n", "for _ in range(3):\n", " loss = train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3)))\n", " print(f\"{loss = !s}\")" ] }, { "cell_type": "markdown", "id": "1dea99c1", "metadata": {}, "source": [ "#### Scan Over Layers" ] }, { "cell_type": "code", "execution_count": 14, "id": "d8136be4", "metadata": {}, "outputs": [], "source": [ "# TODO: does not work with hijax yet\n", "# @jax.vmap\n", "# def create_stack(rngs):\n", "# return nnx.as_immutable_vars(Block(2, 64, 2, rngs=rngs))\n", "\n", "# block_stack = nnx.as_mutable_vars(create_stack(nnx.Rngs(0).fork(split=8)))\n", "\n", "# def scan_fn(x, block):\n", "# x = block(x)\n", "# return x, None\n", "\n", "# x = jax.random.uniform(jax.random.key(0), (3, 2))\n", "# y, _ = jax.lax.scan(scan_fn, x, block_stack)\n", "\n", "# print(\"y = \", y)" ] }, { "cell_type": "markdown", "id": "7ca18a0d", "metadata": {}, "source": [ "### Limitations" ] }, { "cell_type": "markdown", "id": "1dd39c79", "metadata": {}, "source": [ "#### Mutable Outputs" ] }, { "cell_type": "code", "execution_count": 15, "id": "c6062d19", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Error: mutable hitypes should use lo_ty_qdd instead\n" ] } ], "source": [ "@jax.jit\n", "def create_model(rngs):\n", " return Block(2, 64, 3, rngs=rngs)\n", "\n", "try:\n", " model = create_model(nnx.Rngs(0))\n", "except Exception as e:\n", " print(f\"Error:\", e)" ] }, { "cell_type": "code", "execution_count": 16, "id": "8bb1e9e7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "model.linear = \u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Param: 128 (512 B)\u001b[0m\n", " \u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 128 (512 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m2\u001b[0m, \u001b[38;2;182;207;169m64\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", "\u001b[38;2;255;213;3m)\u001b[0m\n" ] } ], "source": [ "@jax.jit\n", "def create_model(rngs):\n", " return nnx.vars_as((Block(2, 64, 3, rngs=rngs)), hijax=False)\n", "\n", "model = nnx.vars_as(create_model(nnx.Rngs(0)), hijax=True)\n", "\n", "print(\"model.linear =\", model.linear)" ] }, { "cell_type": "markdown", "id": "609bed7c", "metadata": {}, "source": [ "#### Reference Sharing (aliasing)" ] }, { "cell_type": "code", "execution_count": 17, "id": "045d03c1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "None\n" ] } ], "source": [ "# NOTE: doesn't currently fail on the jax side\n", "def get_error(f, *args):\n", " try:\n", " return f(*args)\n", " except Exception as e:\n", " return f\"{type(e).__name__}: {e}\"\n", "\n", "x = nnx.Variable(jnp.array(0))\n", "\n", "@jax.jit\n", "def f(a, b):\n", " ...\n", "\n", "print(get_error(f, x, x))" ] }, { "cell_type": "code", "execution_count": 18, "id": "bc2e87e5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "None\n", "\u001b[38;2;79;201;177mHasShared\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Variable: 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254ma\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(0, dtype=int32, weak_type=True),\n", " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254mb\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(0, dtype=int32, weak_type=True),\n", " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", "\u001b[38;2;255;213;3m)\u001b[0m\n" ] } ], "source": [ "# NOTE: doesn't currently fail on the jax side\n", "class HasShared(nnx.Pytree):\n", " def __init__(self):\n", " self.a = nnx.Variable(jnp.array(0))\n", " self.b = self.a\n", "\n", "@jax.jit\n", "def g(has_shared):\n", " has_shared.a[...] = 5\n", "\n", "has_shared = HasShared()\n", "\n", "print(get_error(g, has_shared))\n", "print(has_shared) # updates don't propagate" ] }, { "cell_type": "code", "execution_count": 19, "id": "6298f3d9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Duplicates found:\n", "- [('a',), ('b',)]\n" ] } ], "source": [ "print(\"Duplicates found:\")\n", "if (all_duplicates := nnx.find_duplicates(has_shared)):\n", " for duplicates in all_duplicates:\n", " print(\"-\", duplicates)" ] }, { "cell_type": "code", "execution_count": 20, "id": "00854d38", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;2;79;201;177mHasShared\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Variable: 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254ma\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(5, dtype=int32, weak_type=True),\n", " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254mb\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(5, dtype=int32, weak_type=True),\n", " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", "\u001b[38;2;255;213;3m)\u001b[0m\n" ] } ], "source": [ "@jax.jit\n", "def h(graphdef, state):\n", " has_shared = nnx.merge(graphdef, state)\n", " has_shared.a[...] = 5\n", "\n", "graphdef, state = nnx.split(has_shared)\n", "h(graphdef, state)\n", "print(has_shared)" ] }, { "cell_type": "code", "execution_count": 21, "id": "195296c8", "metadata": {}, "outputs": [], "source": [ "# clean up for CI tests\n", "_ = nnx.var_defaults(hijax=current_mode)" ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: docs_nnx/hijax/hijax.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # Hijax ```{code-cell} ipython3 from flax import nnx import jax import jax.numpy as jnp import optax current_mode = nnx.var_defaults().hijax # ignore: only needed for testing ``` ```{code-cell} ipython3 nnx.var_defaults(hijax=True) rngs = nnx.Rngs(0) model = nnx.Linear(2, 3, rngs=rngs) optimizer = nnx.Optimizer(model, optax.adamw(1e-2), wrt=nnx.Param) @jax.jit def train_step(x, y): loss_fn = lambda m: jnp.mean((m(x) - y) ** 2) loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(model, mutable=False)) # tmp fix for jax.grad optimizer.update(model, grads) return loss x, y = rngs.uniform((4, 2)), rngs.uniform((4, 3)) for _ in range(3): print(train_step(x, y)) ``` ## Hijax Variable +++ State propagation: ```{code-cell} ipython3 v = nnx.Variable(jnp.array(0), hijax=True) @jax.jit def inc(v): v[...] += 1 print(v[...]); inc(v); print(v[...]) ``` ```{code-cell} ipython3 v = nnx.Variable(jnp.array(0), hijax=True) print(jax.make_jaxpr(inc)(v)) ``` Pytree values: ```{code-cell} ipython3 v = nnx.Variable({'a': jnp.array(0), 'b': jnp.array(2)}, hijax=True) @jax.jit def inc_and_double(v): v['a'] += 1 v['b'] *= 2 print(v); inc_and_double(v); print(v) ``` Dynamic state structure: ```{code-cell} ipython3 rngs = nnx.Rngs(0) x = rngs.uniform((4, 5)) w = rngs.normal((5, 3)) metrics = nnx.Variable({}, hijax=True) @jax.jit def linear(x, w, metrics: nnx.Variable): y = x @ w metrics['y_mean'] = jnp.mean(y) return y print("Before:", metrics) y = linear(x, w, metrics) print("After:", metrics) ``` ```{code-cell} ipython3 # set default Variable mode for the rest of the guide nnx.var_defaults(hijax=True) variable = nnx.Variable(jnp.array([1, 2, 3])) print(variable) ``` ### Mutability ```{code-cell} ipython3 class Linear(nnx.Module): def __init__(self, in_features, out_features, rngs: nnx.Rngs): self.kernel = nnx.Param(rngs.normal((in_features, out_features))) def __call__(self, x): return x @ self.kernel model = Linear(1, 3, rngs=nnx.Rngs(0)) print(f"{nnx.vars_as(model, mutable=False) = !s}") print(f"{nnx.vars_as(model, mutable=True) = !s}") ``` ```{code-cell} ipython3 v = nnx.Variable(jnp.array(0)) v_immut = nnx.vars_as(v, mutable=False) assert not v_immut.mutable try: v_immut[...] += 1 # raises an error except Exception as e: print(f"{type(e).__name__}: {e}") ``` ### Ref support ```{code-cell} ipython3 v = nnx.Variable(jnp.array(0)) v_ref = nnx.vars_as(v, ref=True) assert v_ref.ref print(v_ref) print(v_ref.get_raw_value()) ``` ```{code-cell} ipython3 v_immut = nnx.vars_as(v_ref, mutable=False) assert not v_immut.ref print("immutable =", v_immut) v_ref = nnx.vars_as(v_immut, mutable=True) assert v_ref.ref print("mutable =", v_ref) ``` ### Examples ```{code-cell} ipython3 class Block(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): self.linear = Linear(din, dmid, rngs=rngs) self.bn = nnx.BatchNorm(dmid, rngs=rngs) self.dropout = nnx.Dropout(0.1, rngs=rngs) self.linear_out = Linear(dmid, dout, rngs=rngs) def __call__(self, x): x = nnx.gelu(self.dropout(self.bn(self.linear(x)))) return self.linear_out(x) ``` #### Training Loop ```{code-cell} ipython3 # hijax Variables by default model = Block(2, 64, 3, rngs=nnx.Rngs(0)) optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) @jax.jit def train_step(model, optimizer, x, y): graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) def loss_fn(params): model = nnx.merge(graphdef, params, nondiff) return ((model(x) - y) ** 2).mean() loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, mutable=False)) # immutable for jax.grad optimizer.update(model, grads) return loss for _ in range(3): loss = train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3))) print(f"{loss = !s}") ``` #### Scan Over Layers ```{code-cell} ipython3 # TODO: does not work with hijax yet # @jax.vmap # def create_stack(rngs): # return nnx.as_immutable_vars(Block(2, 64, 2, rngs=rngs)) # block_stack = nnx.as_mutable_vars(create_stack(nnx.Rngs(0).fork(split=8))) # def scan_fn(x, block): # x = block(x) # return x, None # x = jax.random.uniform(jax.random.key(0), (3, 2)) # y, _ = jax.lax.scan(scan_fn, x, block_stack) # print("y = ", y) ``` ### Limitations +++ #### Mutable Outputs ```{code-cell} ipython3 @jax.jit def create_model(rngs): return Block(2, 64, 3, rngs=rngs) try: model = create_model(nnx.Rngs(0)) except Exception as e: print(f"Error:", e) ``` ```{code-cell} ipython3 @jax.jit def create_model(rngs): return nnx.vars_as((Block(2, 64, 3, rngs=rngs)), hijax=False) model = nnx.vars_as(create_model(nnx.Rngs(0)), hijax=True) print("model.linear =", model.linear) ``` #### Reference Sharing (aliasing) ```{code-cell} ipython3 # NOTE: doesn't currently fail on the jax side def get_error(f, *args): try: return f(*args) except Exception as e: return f"{type(e).__name__}: {e}" x = nnx.Variable(jnp.array(0)) @jax.jit def f(a, b): ... print(get_error(f, x, x)) ``` ```{code-cell} ipython3 # NOTE: doesn't currently fail on the jax side class HasShared(nnx.Pytree): def __init__(self): self.a = nnx.Variable(jnp.array(0)) self.b = self.a @jax.jit def g(has_shared): has_shared.a[...] = 5 has_shared = HasShared() print(get_error(g, has_shared)) print(has_shared) # updates don't propagate ``` ```{code-cell} ipython3 print("Duplicates found:") if (all_duplicates := nnx.find_duplicates(has_shared)): for duplicates in all_duplicates: print("-", duplicates) ``` ```{code-cell} ipython3 @jax.jit def h(graphdef, state): has_shared = nnx.merge(graphdef, state) has_shared.a[...] = 5 graphdef, state = nnx.split(has_shared) h(graphdef, state) print(has_shared) ``` ```{code-cell} ipython3 # clean up for CI tests _ = nnx.var_defaults(hijax=current_mode) ``` ================================================ FILE: docs_nnx/hijax/index.rst ================================================ Hijax (experimental) ==================== Basic usage ^^^^^^^^^^^^ .. testsetup:: import jax import jax.numpy as jnp current_mode = nnx.var_defaults().hijax .. testcode:: from flax import nnx import optax nnx.var_defaults(hijax=True) class Model(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dmid, rngs=rngs) self.bn = nnx.BatchNorm(dmid, rngs=rngs) self.dropout = nnx.Dropout(0.2) self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x, rngs): x = nnx.relu(self.dropout(self.bn(self.linear(x)), rngs=rngs)) return self.linear_out(x) model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) @jax.jit def train_step(model, optimizer, rngs, x, y): graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) def loss_fn(params): model = nnx.merge(graphdef, params, nondiff) return ((model(x, rngs) - y) ** 2).mean() loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, mutable=False)) optimizer.update(model, grads) # in-place updates return loss nnx.var_defaults(hijax=current_mode) # clean up for CI tests ---- .. toctree:: :hidden: :maxdepth: 2 hijax ================================================ FILE: docs_nnx/index.rst ================================================ Flax ==== .. div:: sd-text-left sd-font-italic **N**\ eural **N**\ etworks for JA\ **X** ---- Flax provides a **flexible end-to-end user experience for researchers and developers who use JAX for neural networks**. Flax enables you to use the full power of `JAX `__. At the core of Flax is **NNX - a simplified API that makes it easier to create, inspect, debug, and analyze neural networks in JAX.** Flax NNX has first class support for Python reference semantics, enabling users to express their models using regular Python objects. Flax NNX is an evolution of the previous `Flax Linen `__ API, and it took years of experience to bring a simpler and more user-friendly API. .. note:: Flax Linen API is not going to be deprecated in the near future as most of Flax users still rely on this API. However, new users are encouraged to use Flax NNX. Check out `Why Flax NNX `_ for a comparison between Flax NNX and Linen, and our reasoning to make the new API. To move your Flax Linen codebase to Flax NNX, get familiarized with the API in `NNX Basics `_ and then start your move following the `evolution guide `_. Features ^^^^^^^^^ .. grid:: .. grid-item:: :columns: 12 12 12 6 .. card:: Pythonic :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal Flax NNX supports the use of regular Python objects, providing an intuitive and predictable development experience. .. grid-item:: :columns: 12 12 12 6 .. card:: Simple :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal Flax NNX relies on Python's object model, which results in simplicity for the user and increases development speed. .. grid-item:: :columns: 12 12 12 6 .. card:: Expressive :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal Flax NNX allows fine-grained control of the model's state via its `Filter `__ system. .. grid-item:: :columns: 12 12 12 6 .. card:: Familiar :class-card: sd-border-0 :shadow: none :class-title: sd-fs-5 .. div:: sd-font-normal Flax NNX makes it very easy to integrate objects with regular JAX code via the `Functional API `__. Basic usage ^^^^^^^^^^^^ .. testsetup:: import jax import jax.numpy as jnp .. testcode:: from flax import nnx import optax class Model(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dmid, rngs=rngs) self.bn = nnx.BatchNorm(dmid, rngs=rngs) self.dropout = nnx.Dropout(0.2) self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x, rngs): x = nnx.relu(self.dropout(self.bn(self.linear(x)), rngs=rngs)) return self.linear_out(x) model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) @nnx.jit # automatic state propagation def train_step(model, optimizer, x, y): loss_fn = lambda model: ((model(x) - y) ** 2).mean() loss, grads = nnx.value_and_grad(loss_fn)(model) optimizer.update(model, grads) # in-place updates return loss Installation ^^^^^^^^^^^^ Install via pip: .. code-block:: bash pip install flax Or install the latest version from the repository: .. code-block:: bash pip install git+https://github.com/google/flax.git ---- Learn more ^^^^^^^^^^ .. grid:: .. grid-item:: :columns: 6 6 6 4 .. card:: :material-regular:`rocket_launch;2em` Flax NNX Basics :class-card: sd-text-black sd-bg-light :link: nnx_basics.html .. grid-item:: :columns: 6 6 6 4 .. card:: :material-regular:`library_books;2em` MNIST Tutorial :class-card: sd-text-black sd-bg-light :link: mnist_tutorial.html .. grid-item:: :columns: 6 6 6 4 .. card:: :material-regular:`library_books;2em` Guides :class-card: sd-text-black sd-bg-light :link: guides/index.html .. grid-item:: :columns: 6 6 6 4 .. card:: :material-regular:`transform;2em` Flax Linen to Flax NNX :class-card: sd-text-black sd-bg-light :link: guides/linen_to_nnx.html .. grid-item:: :columns: 6 6 6 4 .. card:: :material-regular:`menu_book;2em` API reference :class-card: sd-text-black sd-bg-light :link: api_reference/index.html .. grid-item:: :columns: 6 6 6 4 .. card:: :material-regular:`import_contacts;2em` Glossary :class-card: sd-text-black sd-bg-light :link: nnx_glossary.html ---- .. toctree:: :hidden: :maxdepth: 3 nnx_basics mnist_tutorial why key_concepts guides_basic guides_advanced Models (Bonsai) Post-training (Tunix) hijax/index migrating/index examples/index nnx_glossary philosophy contributing api_reference/index ================================================ FILE: docs_nnx/key_concepts.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "929920d0", "metadata": {}, "source": [ "# JAX/Flax Key Concepts\n", "\n", "Flax is a **neural network library** built on top of JAX, a language for **accelerated numerical computations**. In effect, Flax is a pretty thin layer, and you likely will use some JAX APIs directly to do anything more than using the built-in Flax modules.\n", "\n", "This means a **basic understanding on JAX helps you to use Flax well**. You would have better a mental model to understand what's happening underneath and how to debug a confusing error. This doc aims to clarify a few key concepts and help you build that uniquely-JAX mental model as a practical model developer (pun intended).\n", "\n", "[JAX documentations](https://docs.jax.dev/en/latest/index.html) are great sources to learn more. We recommend all Flax users to at least read the [JAX Key Concepts](https://docs.jax.dev/en/latest/key-concepts.html) doc." ] }, { "cell_type": "code", "execution_count": 1, "id": "3515d62b", "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import flax\n", "from flax import nnx\n", "from functools import partial\n", "\n", "# For simulating multi-device environment\n", "jax.config.update('jax_num_cpu_devices', 8)" ] }, { "cell_type": "markdown", "id": "be2cad4a", "metadata": {}, "source": [ "## What is JAX?\n", "\n", "JAX is the lower level library that does **all the large-scale data computations**. It provides the singular data container, aka the `jax.Array`, and all the ways we possibly deal with them:\n", "\n", "* **Make arithmetic operations upon the arrays**, including: the `jax.numpy` ops, automatic differentiation (`jax.grad`), batching (`jax.vmap`), and more.\n", "\n", "* **Run computation on accelerators**, including: interface with various accelerator platforms and layouts; allocating buffers for arrays; compile and execute computation programs across accelerators.\n", "\n", "* **Bundle multiple arrays together** using a simple concept called [pytrees](#pytrees).\n", "\n", "This implies that any error related with accelerators and numericals are probably a JAX issue, or an issue with Flax built-in layers.\n", "\n", "It also means you *can* build a neural network model with JAX alone, especially if you are comfortable with functional programming. JAX docsite have some [simple examples](https://docs.jax.dev/en/latest/notebooks/neural_network_with_tfds_data.html). The article [GPT in 60 Lines of NumPy](https://jaykmody.com/blog/gpt-from-scratch/) also shows how to implement all the key elements of a GPT using JAX." ] }, { "cell_type": "code", "execution_count": 2, "id": "a3769631", "metadata": {}, "outputs": [], "source": [ "def jax_linear(x, kernel, bias):\n", " return jnp.dot(x, kernel) + bias\n", "\n", "params = {'kernel': jax.random.normal(jax.random.key(42), (4, 2)), \n", " 'bias': jnp.zeros((2,))}\n", "x = jax.random.normal(jax.random.key(0), (2, 4))\n", "y = jax_linear(x, params['kernel'], params['bias'])" ] }, { "cell_type": "markdown", "id": "ee6f86e7", "metadata": {}, "source": [ "## What is Flax?\n", "\n", "Flax is a **neural network toolkit**, offering higher level abstractions that are handy for model developers. Such as:\n", "\n", "* **Object-oriented `Module` class** to represent layers/models and bookkeep parameters.\n", "\n", "* **Modeling utilities** like random number handling, model traversal and surgery, optimizers, advanced parameter bookkeeping, sharding annotations, and more.\n", "\n", "* **Some built-in commonly-used** layers, initializers, and model examples.\n", "\n", "Take the example below: A Flax layer `Linear`, during initialization, takes one RNG key and automatically initialize all internal parameters as `jax.Array`s. In forward pass, it carries out the exact same computation via JAX APIs." ] }, { "cell_type": "code", "execution_count": 3, "id": "14caace1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n" ] } ], "source": [ "# Eligible parameters were created inside `linear`, using one RNG key 42\n", "linear = nnx.Linear(in_features=4, out_features=2, rngs=nnx.Rngs(42))\n", "\n", "# Flax created a `Param` wrapper over the actual `jax.Array` parameter to track metadata\n", "print(type(linear.kernel)) # flax.nnx.Param\n", "print(type(linear.kernel.value)) # jax.Array\n", "\n", "# The computation of the two are the same\n", "x = jax.random.normal(jax.random.key(0), (2, 4))\n", "flax_y = linear(x)\n", "jax_y = jax_linear(x, linear.kernel.value, linear.bias.value)\n", "assert jnp.array_equal(flax_y, jax_y)" ] }, { "cell_type": "markdown", "id": "09989bf7", "metadata": {}, "source": [ "## Pytrees\n", "\n", "Your code likely needs more than one `jax.Array`. A **pytree** is a container structure of multiple pytrees, possibly nested. It is a key and handly concept in the JAX world.\n", "\n", "Many things are pytrees: Python dicts, lists, tuples, dataclasses, and more. The key is that a pytree can be \"flattened\" into multiple children, which are either pytrees or individual leaves - a `jax.Array` counts as a leaf. Other metadata of a pytree are stored in the `PyTreeDef` object, allowing \"unflattening\" to restore the old pytree.\n", "\n", "Pytree is the primary data holder in JAX. When JAX transforms see a pytree argument, they automatically trace its internal `jax.Array`s when compiling. Therefore, it's crucial to organize your data as pytrees. You can use [`flax.struct.dataclass`](https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html#flax.struct.dataclass) to quickly construct a pytree node dataclass, or register your own classes via JAX API. [JAX pytree documentation](https://docs.jax.dev/en/latest/working-with-pytrees.html) has a thorough overview on pytrees and JAX APIs to manipulate them. \n", "\n", "In Flax, a `Module` is a pytree, and variables are its flattenable data. This means you can directly run JAX transforms upon a Flax model." ] }, { "cell_type": "code", "execution_count": 4, "id": "a2059c47", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "linear.bias.value: [0. 0.]\n", "linear.kernel.value: [[ 0.04119061 -0.2629074 ]\n", " [ 0.6772455 0.2807398 ]\n", " [ 0.16276604 0.16813846]\n", " [ 0.310975 -0.43336964]]\n", "treedef = PyTreeDef(CustomNode(Linear[(('_pytree__state', 'bias', 'kernel'), (('_object__nodes', frozenset({'kernel', '_pytree__state', 'bias'})), ('bias_init', ), ('dot_general', ), ('dtype', None), ('in_features', 4), ('kernel_init', .init at 0x120f45260>), ('out_features', 2), ('param_dtype', ), ('precision', None), ('promote_dtype', ), ('use_bias', True)))], [CustomNode(ObjectState[(False, False)], []), CustomNode(Param[()], [*]), CustomNode(Param[()], [*])]))\n" ] } ], "source": [ "# Flatten allows you to see all the content inside a pytree\n", "arrays, treedef = jax.tree.flatten_with_path(linear)\n", "assert len(arrays) > 1\n", "for kp, value in arrays:\n", " print(f'linear{jax.tree_util.keystr(kp)}: {value}')\n", "print(f'{treedef = }')\n", "\n", "# Unflatten brings the pytree back intact\n", "linear = jax.tree.unflatten(treedef, [value for _, value in arrays])" ] }, { "cell_type": "code", "execution_count": 5, "id": "4ea2f351", "metadata": {}, "outputs": [], "source": [ "y = jax.jit(linear)(x) # JAX transforms works on Flax modules" ] }, { "cell_type": "markdown", "id": "723b3f42", "metadata": {}, "source": [ "## Traced vs. static data\n", "\n", "A pytree *contains* JAX arrays, but a pytree is *more than* its JAX arrays. For example, a dictionary keeps information like the key of every array, and it might contain entries that are not JAX arrays. From JAX's standpoint, all data are one of the two types:\n", "\n", "* **Traced** (\"dynamic\") data: JAX will trace them during compilation and optimize the operations upon them. If they stay inside a pytree argument, `jax.tree.flatten` must return them as leaves. They must be data values (`jax.Array`, Numpy array, scalar, etc), and implement basic functionalities like `__eq__` and `__hash__`.\n", "\n", "* **\"Static\"** data: They stay as simple Python objects that don't get traced by JAX.\n", "\n", "In practice, you would want to control what data goes into dynamic, and what to static. Dynamic data and their computation will be optimized by JAX, but you cannot base your code control flow upon its values. Non-data values like strings must stay static.\n", "\n", "Take a Flax model: you would want JAX to only track and optimize its parameters, and the RNG keys. For trivial things like the model hyperparameters (e.g., the param shape, the initializer function), they can stay static to save compilation bandwidth and allow code path customization.\n", "\n", "Current Flax module automatically classifies this for you. Only the `jax.Array` attributes are treated as dynamic data, unless you explicitly wrap a data value using `nnx.Variable` classes." ] }, { "cell_type": "code", "execution_count": 6, "id": "7b6cbc3a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['rng']['default']['count'].value: 1\n", "['rng']['default']['key'].value: Array((), dtype=key) overlaying:\n", "[0 0]\n", "['traced_dim'].value: 4\n", "['w'].value: [[ 1.0040143 -0.9063372 -0.7481722 -1.1713669 ]\n", " [-0.8712328 0.5888381 0.72392994 -1.0255982 ]\n", " [ 1.661628 -1.8910251 -1.2889339 0.13360691]\n", " [-1.1530392 0.23929629 1.7448074 0.5050189 ]]\n" ] } ], "source": [ "class Foo(nnx.Module):\n", " def __init__(self, dim, rngs):\n", " self.w = nnx.Param(jax.random.normal(rngs.param(), (dim, dim)))\n", " self.dim = dim\n", " self.traced_dim = nnx.Param(dim) # This became traced!\n", " self.rng = rngs\n", "\n", "foo = Foo(4, nnx.Rngs(0))\n", "for kp, x in jax.tree.flatten_with_path(nnx.state(foo))[0]:\n", " print(f'{jax.tree_util.keystr(kp)}: {x}')" ] }, { "cell_type": "markdown", "id": "0b10383c", "metadata": {}, "source": [ "When compiling a function using this pytree, you'll notice the difference between traced and static values. You can only use static ones in control flows." ] }, { "cell_type": "code", "execution_count": 7, "id": "395c9d79", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "model.dim = 4\n", "model.traced_dim.value = JitTracer<~int32[]>\n", "Code path based on static data value works fine.\n", "Code path based on JAX data value throws error: Attempted boolean conversion of traced array with shape bool[].\n", "The error occurred while tracing the function jitted at /var/folders/4c/ylxxyg_n67957jf6616c7z5000gbn1/T/ipykernel_69242/584946237.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument model.traced_dim.value.\n", "See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError\n" ] } ], "source": [ "@jax.jit\n", "def jitted(model):\n", " print(f'{model.dim = }')\n", " print(f'{model.traced_dim.value = }') # This is being traced\n", " if model.dim == 4:\n", " print('Code path based on static data value works fine.')\n", " try:\n", " if model.traced_dim.value == 4:\n", " print('This will never run :(')\n", " except jax.errors.TracerBoolConversionError as e:\n", " print(f'Code path based on JAX data value throws error: {e}')\n", "\n", "jitted(foo)" ] }, { "cell_type": "markdown", "id": "202bf52b", "metadata": {}, "source": [ "## Abstract arrays\n", "\n", "Abstract array is a JAX class to represent an array not by its value, but simply by its metadata information like shape, dtype and sharding. It is fast and handy because it doesn't allocate any memory for the array data.\n", "\n", "You can construct an abstract array by calling [`jax.ShapeDtypeStruct`](https://docs.jax.dev/en/latest/_autosummary/jax.ShapeDtypeStruct.html) on your own, or use [`jax.eval_shape`](https://docs.jax.dev/en/latest/_autosummary/jax.eval_shape.html), which takes a function and arguments and returns the abstract version of its output." ] }, { "cell_type": "code", "execution_count": 8, "id": "21ebeebf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 1.0040143 -0.9063372 -0.7481722 -1.1713669 ]\n", " [-0.8712328 0.5888381 0.72392994 -1.0255982 ]\n", " [ 1.661628 -1.8910251 -1.2889339 0.13360691]\n", " [-1.1530392 0.23929629 1.7448074 0.5050189 ]]\n", "ShapeDtypeStruct(shape=(4, 4), dtype=float32)\n" ] } ], "source": [ "print(x)\n", "abs_x = jax.eval_shape(lambda x: x, x)\n", "print(abs_x)" ] }, { "cell_type": "markdown", "id": "e8345d12", "metadata": {}, "source": [ "It is a good way to dry-run your code and debug a model without any actual compute and memory cost. For example, you can have an overview of the parameters inside this very large model." ] }, { "cell_type": "code", "execution_count": null, "id": "f9b1b308", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Param: 67,084,290 (268.3 MB)\u001b[0m\n", " \u001b[38;2;156;220;254mbias\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 8,190 (32.8 KB)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(8190,), dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 67,076,100 (268.3 MB)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mShapeDtypeStruct(shape=(8190, 8190), dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m,\n", " \u001b[38;2;156;220;254mbias_init\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m,\n", " \u001b[38;2;156;220;254mdot_general\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m,\n", " \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mNone\u001b[0m,\n", " \u001b[38;2;156;220;254min_features\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;182;207;169m8190\u001b[0m,\n", " \u001b[38;2;156;220;254mkernel_init\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m.init at 0x120f45260>,\n", " \u001b[38;2;156;220;254mout_features\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;182;207;169m8190\u001b[0m,\n", " \u001b[38;2;156;220;254mparam_dtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mfloat32\u001b[0m,\n", " \u001b[38;2;156;220;254mprecision\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mNone\u001b[0m,\n", " \u001b[38;2;156;220;254mpromote_dtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m,\n", " \u001b[38;2;156;220;254muse_bias\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", "\u001b[38;2;255;213;3m)\u001b[0m\n" ] } ], "source": [ "class MLP(nnx.Module):\n", " def __init__(self, dim, nlayers, rngs):\n", " self.blocks = nnx.List([nnx.Linear(dim, dim, rngs=rngs) for _ in range(nlayers)])\n", " self.activation = jax.nn.relu\n", " self.nlayers = nlayers\n", " def __call__(self, x):\n", " for block in self.blocks:\n", " x = self.activation(block(x))\n", " return x\n", "\n", "dim, nlayers = 8190, 64 # Some very big numbers\n", "@partial(jax.jit, static_argnums=(0, 1))\n", "def init_state(dim, nlayers):\n", " return MLP(dim, nlayers, nnx.Rngs(0))\n", "abstract_model = jax.eval_shape(partial(init_state, dim, nlayers))\n", "print(abstract_model.blocks[0])" ] }, { "cell_type": "markdown", "id": "8894cbc6", "metadata": {}, "source": [ "Once you have an abstract pytree for your function input or output, it's easier to describe how you want your data to be sharded. You should use such a pytree with sharding information to instruct your checkpoint loading library to load your arrays distributedly. Our checkpointing guide contains [an example of how to do this](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html#load-a-sharded-model-from-a-checkpoint)." ] }, { "cell_type": "markdown", "id": "b98b5184", "metadata": {}, "source": [ "## Distributed computing\n", "\n", "Another big use case for abstract pytrees is to tell JAX machinery how you want each array to be sharded during any point of your computation.\n", "\n", "Remember what we mentioned earlier? JAX handles the actual computation and data allocation on accelerators. This means you **must** use some `jax.jit`-compiled function to run any distributed computation task.\n", "\n", "There are a few ways to tell `jax.jit` of your model sharding. The simplest way is to call `jax.lax.with_sharding_constraint` to constraint the so-to-be model with your predetermined model sharding." ] }, { "cell_type": "code", "execution_count": 10, "id": "9b289c02", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mNamedSharding(mesh=Mesh('model': 8, axis_types=(Auto,)), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)\n", "\u001b[38;2;255;213;3m)\u001b[0m\n" ] }, { "data": { "text/html": [ "
                                                                        \n",
       "                                                                        \n",
       "                                                                        \n",
       "                                                                        \n",
       "                                                                        \n",
       "  CPU 0    CPU 1    CPU 2    CPU 3    CPU 4    CPU 5    CPU 6    CPU 7  \n",
       "                                                                        \n",
       "                                                                        \n",
       "                                                                        \n",
       "                                                                        \n",
       "                                                                        \n",
       "
\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mCPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mCPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mCPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mCPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mCPU 4\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mCPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mCPU 6\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mCPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Some smaller numbers so that we actually can run it\n", "dim, nlayers = 1024, 2\n", "abstract_model = jax.eval_shape(partial(init_state, dim, nlayers))\n", "mesh = jax.make_mesh((jax.device_count(), ), 'model')\n", "\n", "# Generate sharding for each of your params manually, sharded along the last axis.\n", "def make_sharding(abs_x):\n", " if len(abs_x.shape) > 1:\n", " pspec = jax.sharding.PartitionSpec(None, 'model') # kernel\n", " else:\n", " pspec = jax.sharding.PartitionSpec('model',) # bias\n", " return jax.sharding.NamedSharding(mesh, pspec)\n", "model_shardings = jax.tree.map(make_sharding, abstract_model)\n", "print(model_shardings.blocks[0].kernel)\n", "\n", "@partial(jax.jit, static_argnums=(0, 1))\n", "def sharded_init(dim, nlayers):\n", " model = MLP(dim, nlayers, nnx.Rngs(0))\n", " return jax.lax.with_sharding_constraint(model, model_shardings)\n", "model = sharded_init(dim, nlayers)\n", "jax.debug.visualize_array_sharding(model.blocks[0].kernel.value)" ] }, { "cell_type": "markdown", "id": "3b7ac7ea", "metadata": {}, "source": [ "The example below are just to showcase how to do sharding in pure JAX API. Flax offers a small API to annotate the sharding when you define a parameter, so that you don't have to write an arbitrary `make_sharding()` function at top level. Check out our [GSPMD guide](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html) to learn more." ] }, { "cell_type": "markdown", "id": "7dbf421a", "metadata": {}, "source": [ "## Transformations\n", "\n", "For Flax transforms and their relation with JAX transforms, refer to [Flax Transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) doc. This should be a rarer use case now that Flax NNX modules are JAX pytrees." ] } ], "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: docs_nnx/key_concepts.md ================================================ # JAX/Flax Key Concepts Flax is a **neural network library** built on top of JAX, a language for **accelerated numerical computations**. In effect, Flax is a pretty thin layer, and you likely will use some JAX APIs directly to do anything more than using the built-in Flax modules. This means a **basic understanding on JAX helps you to use Flax well**. You would have better a mental model to understand what's happening underneath and how to debug a confusing error. This doc aims to clarify a few key concepts and help you build that uniquely-JAX mental model as a practical model developer (pun intended). [JAX documentations](https://docs.jax.dev/en/latest/index.html) are great sources to learn more. We recommend all Flax users to at least read the [JAX Key Concepts](https://docs.jax.dev/en/latest/key-concepts.html) doc. ```python import jax import jax.numpy as jnp import flax from flax import nnx from functools import partial # For simulating multi-device environment jax.config.update('jax_num_cpu_devices', 8) ``` ## What is JAX? JAX is the lower level library that does **all the large-scale data computations**. It provides the singular data container, aka the `jax.Array`, and all the ways we possibly deal with them: * **Make arithmetic operations upon the arrays**, including: the `jax.numpy` ops, automatic differentiation (`jax.grad`), batching (`jax.vmap`), and more. * **Run computation on accelerators**, including: interface with various accelerator platforms and layouts; allocating buffers for arrays; compile and execute computation programs across accelerators. * **Bundle multiple arrays together** using a simple concept called [pytrees](#pytrees). This implies that any error related with accelerators and numericals are probably a JAX issue, or an issue with Flax built-in layers. It also means you *can* build a neural network model with JAX alone, especially if you are comfortable with functional programming. JAX docsite have some [simple examples](https://docs.jax.dev/en/latest/notebooks/neural_network_with_tfds_data.html). The article [GPT in 60 Lines of NumPy](https://jaykmody.com/blog/gpt-from-scratch/) also shows how to implement all the key elements of a GPT using JAX. ```python def jax_linear(x, kernel, bias): return jnp.dot(x, kernel) + bias params = {'kernel': jax.random.normal(jax.random.key(42), (4, 2)), 'bias': jnp.zeros((2,))} x = jax.random.normal(jax.random.key(0), (2, 4)) y = jax_linear(x, params['kernel'], params['bias']) ``` ## What is Flax? Flax is a **neural network toolkit**, offering higher level abstractions that are handy for model developers. Such as: * **Object-oriented `Module` class** to represent layers/models and bookkeep parameters. * **Modeling utilities** like random number handling, model traversal and surgery, optimizers, advanced parameter bookkeeping, sharding annotations, and more. * **Some built-in commonly-used** layers, initializers, and model examples. Take the example below: A Flax layer `Linear`, during initialization, takes one RNG key and automatically initialize all internal parameters as `jax.Array`s. In forward pass, it carries out the exact same computation via JAX APIs. ```python # Eligible parameters were created inside `linear`, using one RNG key 42 linear = nnx.Linear(in_features=4, out_features=2, rngs=nnx.Rngs(42)) # Flax created a `Param` wrapper over the actual `jax.Array` parameter to track metadata print(type(linear.kernel)) # flax.nnx.Param print(type(linear.kernel.value)) # jax.Array # The computation of the two are the same x = jax.random.normal(jax.random.key(0), (2, 4)) flax_y = linear(x) jax_y = jax_linear(x, linear.kernel.value, linear.bias.value) assert jnp.array_equal(flax_y, jax_y) ``` ## Pytrees Your code likely needs more than one `jax.Array`. A **pytree** is a container structure of multiple pytrees, possibly nested. It is a key and handly concept in the JAX world. Many things are pytrees: Python dicts, lists, tuples, dataclasses, and more. The key is that a pytree can be "flattened" into multiple children, which are either pytrees or individual leaves - a `jax.Array` counts as a leaf. Other metadata of a pytree are stored in the `PyTreeDef` object, allowing "unflattening" to restore the old pytree. Pytree is the primary data holder in JAX. When JAX transforms see a pytree argument, they automatically trace its internal `jax.Array`s when compiling. Therefore, it's crucial to organize your data as pytrees. You can use [`flax.struct.dataclass`](https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html#flax.struct.dataclass) to quickly construct a pytree node dataclass, or register your own classes via JAX API. [JAX pytree documentation](https://docs.jax.dev/en/latest/working-with-pytrees.html) has a thorough overview on pytrees and JAX APIs to manipulate them. In Flax, a `Module` is a pytree, and variables are its flattenable data. This means you can directly run JAX transforms upon a Flax model. ```python # Flatten allows you to see all the content inside a pytree arrays, treedef = jax.tree.flatten_with_path(linear) assert len(arrays) > 1 for kp, value in arrays: print(f'linear{jax.tree_util.keystr(kp)}: {value}') print(f'{treedef = }') # Unflatten brings the pytree back intact linear = jax.tree.unflatten(treedef, [value for _, value in arrays]) ``` linear.bias.value: [0. 0.] linear.kernel.value: [[ 0.04119061 -0.2629074 ] [ 0.6772455 0.2807398 ] [ 0.16276604 0.16813846] [ 0.310975 -0.43336964]] treedef = PyTreeDef(CustomNode(Linear[(('_pytree__state', 'bias', 'kernel'), (('_object__nodes', frozenset({'kernel', '_pytree__state', 'bias'})), ('bias_init', ), ('dot_general', ), ('dtype', None), ('in_features', 4), ('kernel_init', .init at 0x120f45260>), ('out_features', 2), ('param_dtype', ), ('precision', None), ('promote_dtype', ), ('use_bias', True)))], [CustomNode(ObjectState[(False, False)], []), CustomNode(Param[()], [*]), CustomNode(Param[()], [*])])) ```python y = jax.jit(linear)(x) # JAX transforms works on Flax modules ``` ## Traced vs. static data A pytree *contains* JAX arrays, but a pytree is *more than* its JAX arrays. For example, a dictionary keeps information like the key of every array, and it might contain entries that are not JAX arrays. From JAX's standpoint, all data are one of the two types: * **Traced** ("dynamic") data: JAX will trace them during compilation and optimize the operations upon them. If they stay inside a pytree argument, `jax.tree.flatten` must return them as leaves. They must be data values (`jax.Array`, Numpy array, scalar, etc), and implement basic functionalities like `__eq__` and `__hash__`. * **"Static"** data: They stay as simple Python objects that don't get traced by JAX. In practice, you would want to control what data goes into dynamic, and what to static. Dynamic data and their computation will be optimized by JAX, but you cannot base your code control flow upon its values. Non-data values like strings must stay static. Take a Flax model: you would want JAX to only track and optimize its parameters, and the RNG keys. For trivial things like the model hyperparameters (e.g., the param shape, the initializer function), they can stay static to save compilation bandwidth and allow code path customization. Current Flax module automatically classifies this for you. Only the `jax.Array` attributes are treated as dynamic data, unless you explicitly wrap a data value using `nnx.Variable` classes. ```python class Foo(nnx.Module): def __init__(self, dim, rngs): self.w = nnx.Param(jax.random.normal(rngs.param(), (dim, dim))) self.dim = dim self.traced_dim = nnx.Param(dim) # This became traced! self.rng = rngs foo = Foo(4, nnx.Rngs(0)) for kp, x in jax.tree.flatten_with_path(nnx.state(foo))[0]: print(f'{jax.tree_util.keystr(kp)}: {x}') ``` ['rng']['default']['count'].value: 1 ['rng']['default']['key'].value: Array((), dtype=key) overlaying: [0 0] ['traced_dim'].value: 4 ['w'].value: [[ 1.0040143 -0.9063372 -0.7481722 -1.1713669 ] [-0.8712328 0.5888381 0.72392994 -1.0255982 ] [ 1.661628 -1.8910251 -1.2889339 0.13360691] [-1.1530392 0.23929629 1.7448074 0.5050189 ]] When compiling a function using this pytree, you'll notice the difference between traced and static values. You can only use static ones in control flows. ```python @jax.jit def jitted(model): print(f'{model.dim = }') print(f'{model.traced_dim.value = }') # This is being traced if model.dim == 4: print('Code path based on static data value works fine.') try: if model.traced_dim.value == 4: print('This will never run :(') except jax.errors.TracerBoolConversionError as e: print(f'Code path based on JAX data value throws error: {e}') jitted(foo) ``` model.dim = 4 model.traced_dim.value = JitTracer<~int32[]> Code path based on static data value works fine. Code path based on JAX data value throws error: Attempted boolean conversion of traced array with shape bool[]. The error occurred while tracing the function jitted at /var/folders/4c/ylxxyg_n67957jf6616c7z5000gbn1/T/ipykernel_69242/584946237.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument model.traced_dim.value. See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError ## Abstract arrays Abstract array is a JAX class to represent an array not by its value, but simply by its metadata information like shape, dtype and sharding. It is fast and handy because it doesn't allocate any memory for the array data. You can construct an abstract array by calling [`jax.ShapeDtypeStruct`](https://docs.jax.dev/en/latest/_autosummary/jax.ShapeDtypeStruct.html) on your own, or use [`jax.eval_shape`](https://docs.jax.dev/en/latest/_autosummary/jax.eval_shape.html), which takes a function and arguments and returns the abstract version of its output. ```python print(x) abs_x = jax.eval_shape(lambda x: x, x) print(abs_x) ``` [[ 1.0040143 -0.9063372 -0.7481722 -1.1713669 ] [-0.8712328 0.5888381 0.72392994 -1.0255982 ] [ 1.661628 -1.8910251 -1.2889339 0.13360691] [-1.1530392 0.23929629 1.7448074 0.5050189 ]] ShapeDtypeStruct(shape=(4, 4), dtype=float32) It is a good way to dry-run your code and debug a model without any actual compute and memory cost. For example, you can have an overview of the parameters inside this very large model. ```python class MLP(nnx.Module): def __init__(self, dim, nlayers, rngs): self.blocks = [nnx.Linear(dim, dim, rngs=rngs) for _ in range(nlayers)] self.activation = jax.nn.relu self.nlayers = nlayers def __call__(self, x): for block in self.blocks: x = self.activation(block(x)) return x dim, nlayers = 8190, 64 # Some very big numbers @partial(jax.jit, static_argnums=(0, 1)) def init_state(dim, nlayers): return MLP(dim, nlayers, nnx.Rngs(0)) abstract_model = jax.eval_shape(partial(init_state, dim, nlayers)) print(abstract_model.blocks[0]) ``` Linear( # Param: 67,084,290 (268.3 MB) bias=Param( # 8,190 (32.8 KB) value=ShapeDtypeStruct(shape=(8190,), dtype=float32) ), kernel=Param( # 67,076,100 (268.3 MB) value=ShapeDtypeStruct(shape=(8190, 8190), dtype=float32) ), bias_init=, dot_general=, dtype=None, in_features=8190, kernel_init=.init at 0x120f45260>, out_features=8190, param_dtype=float32, precision=None, promote_dtype=, use_bias=True ) Once you have an abstract pytree for your function input or output, it's easier to describe how you want your data to be sharded. You should use such a pytree with sharding information to instruct your checkpoint loading library to load your arrays distributedly. Our checkpointing guide contains [an example of how to do this](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html#load-a-sharded-model-from-a-checkpoint). ## Distributed computing Another big use case for abstract pytrees is to tell JAX machinery how you want each array to be sharded during any point of your computation. Remember what we mentioned earlier? JAX handles the actual computation and data allocation on accelerators. This means you **must** use some `jax.jit`-compiled function to run any distributed computation task. There are a few ways to tell `jax.jit` of your model sharding. The simplest way is to call `jax.lax.with_sharding_constraint` to constraint the so-to-be model with your predetermined model sharding. ```python # Some smaller numbers so that we actually can run it dim, nlayers = 1024, 2 abstract_model = jax.eval_shape(partial(init_state, dim, nlayers)) mesh = jax.make_mesh((jax.device_count(), ), 'model') # Generate sharding for each of your params manually, sharded along the last axis. def make_sharding(abs_x): if len(abs_x.shape) > 1: pspec = jax.sharding.PartitionSpec(None, 'model') # kernel else: pspec = jax.sharding.PartitionSpec('model',) # bias return jax.sharding.NamedSharding(mesh, pspec) model_shardings = jax.tree.map(make_sharding, abstract_model) print(model_shardings.blocks[0].kernel) @partial(jax.jit, static_argnums=(0, 1)) def sharded_init(dim, nlayers): model = MLP(dim, nlayers, nnx.Rngs(0)) return jax.lax.with_sharding_constraint(model, model_shardings) model = sharded_init(dim, nlayers) jax.debug.visualize_array_sharding(model.blocks[0].kernel.value) ``` Param( value=NamedSharding(mesh=Mesh('model': 8, axis_types=(Auto,)), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host) )
                                                                        
                                                                        
                                                                        
                                                                        
                                                                        
  CPU 0    CPU 1    CPU 2    CPU 3    CPU 4    CPU 5    CPU 6    CPU 7  
                                                                        
                                                                        
                                                                        
                                                                        
                                                                        
The example below are just to showcase how to do sharding in pure JAX API. Flax offers a small API to annotate the sharding when you define a parameter, so that you don't have to write an arbitrary `make_sharding()` function at top level. Check out our [GSPMD guide](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html) to learn more. ## Transformations For Flax transforms and their relation with JAX transforms, refer to [Flax Transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) doc. This should be a rarer use case now that Flax NNX modules are JAX pytrees. ================================================ FILE: docs_nnx/migrating/convert_pytorch_to_flax.rst ================================================ Convert PyTorch models to Flax ============================== .. testsetup:: import numpy as np import jax from jax import random, numpy as jnp from flax import nnx import torch We will show how to convert PyTorch models to Flax. We will cover convolutions, fc layers, batch norm, and average pooling. FC Layers -------------------------------- Let's start with fc layers. The only thing to be aware of here is that the PyTorch kernel has shape [outC, inC] and the Flax kernel has shape [inC, outC]. Transposing the kernel will do the trick. .. testcode:: t_fc = torch.nn.Linear(in_features=3, out_features=4) kernel = t_fc.weight.detach().cpu().numpy() bias = t_fc.bias.detach().cpu().numpy() # [outC, inC] -> [inC, outC] kernel = jnp.transpose(kernel, (1, 0)) key = random.key(0) x = random.normal(key, (1, 3)) j_fc = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0)) j_fc.kernel.value = kernel j_fc.bias.value = jnp.array(bias) j_out = j_fc(x) t_x = torch.from_numpy(np.array(x)) t_out = t_fc(t_x) t_out = t_out.detach().cpu().numpy() np.testing.assert_almost_equal(j_out, t_out, decimal=6) Convolutions -------------------------------- Let's now look at 2D convolutions. PyTorch uses the NCHW format and Flax uses NHWC. Consequently, the kernels will have different shapes. The kernel in PyTorch has shape [outC, inC, kH, kW] and the Flax kernel has shape [kH, kW, inC, outC]. Transposing the kernel will do the trick. .. testcode:: t_conv = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding='valid') kernel = t_conv.weight.detach().cpu().numpy() bias = t_conv.bias.detach().cpu().numpy() # [outC, inC, kH, kW] -> [kH, kW, inC, outC] kernel = jnp.transpose(kernel, (2, 3, 1, 0)) key = random.key(0) x = random.normal(key, (1, 6, 6, 3)) j_conv = nnx.Conv(3, 4, kernel_size=(2, 2), padding='valid', rngs=nnx.Rngs(0)) j_conv.kernel.value = kernel j_conv.bias.value = bias j_out = j_conv(x) # [N, H, W, C] -> [N, C, H, W] t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2))) t_out = t_conv(t_x) # [N, C, H, W] -> [N, H, W, C] t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1)) np.testing.assert_almost_equal(j_out, t_out, decimal=6) Convolutions and FC Layers -------------------------------- We have to be careful, when we have a model that uses convolutions followed by fc layers (ResNet, VGG, etc). In PyTorch, the activations will have shape [N, C, H, W] after the convolutions and are then reshaped to [N, C * H * W] before being fed to the fc layers. When we port our weights from PyTorch to Flax, the activations after the convolutions will be of shape [N, H, W, C] in Flax. Before we reshape the activations for the fc layers, we have to transpose them to [N, C, H, W]. Consider this PyTorch model: .. testcode:: class TModel(torch.nn.Module): def __init__(self): super(TModel, self).__init__() self.conv = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding='valid') self.fc = torch.nn.Linear(in_features=100, out_features=2) def forward(self, x): x = self.conv(x) x = x.reshape(x.shape[0], -1) x = self.fc(x) return x t_model = TModel() Now, if you want to use the weights from this model in Flax, the corresponding Flax model has to look like this: .. testcode:: class JModel(nnx.Module): def __init__(self, rngs): self.conv = nnx.Conv(3, 4, kernel_size=(2, 2), padding='valid', rngs=rngs) self.linear = nnx.Linear(100, 2, rngs=rngs) def __call__(self, x): x = self.conv(x) # [N, H, W, C] -> [N, C, H, W] x = jnp.transpose(x, (0, 3, 1, 2)) x = jnp.reshape(x, (x.shape[0], -1)) x = self.linear(x) return x j_model = JModel(nnx.Rngs(0)) The model looks very similar to the PyTorch model, except that we included a transpose operation before reshaping our activations for the fc layer. We can omit the transpose operation if we apply pooling before reshaping such that the spatial dimensions are 1x1. Other than the transpose operation before reshaping, we can convert the weights the same way as we did before: .. testcode:: conv_kernel = t_model.state_dict()['conv.weight'].detach().cpu().numpy() conv_bias = t_model.state_dict()['conv.bias'].detach().cpu().numpy() fc_kernel = t_model.state_dict()['fc.weight'].detach().cpu().numpy() fc_bias = t_model.state_dict()['fc.bias'].detach().cpu().numpy() # [outC, inC, kH, kW] -> [kH, kW, inC, outC] conv_kernel = jnp.transpose(conv_kernel, (2, 3, 1, 0)) # [outC, inC] -> [inC, outC] fc_kernel = jnp.transpose(fc_kernel, (1, 0)) j_model.conv.kernel.value = conv_kernel j_model.conv.bias.value = conv_bias j_model.linear.kernel.value = fc_kernel j_model.linear.bias.value = fc_bias key = random.key(0) x = random.normal(key, (1, 6, 6, 3)) j_out = j_model(x) t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2))) t_out = t_model(t_x) t_out = t_out.detach().cpu().numpy() np.testing.assert_almost_equal(j_out, t_out, decimal=6) Batch Norm -------------------------------- ``torch.nn.BatchNorm2d`` uses ``0.1`` as the default value for the ``momentum`` parameter while |nnx.BatchNorm|_ uses ``0.9``. However, this corresponds to the same computation, because PyTorch multiplies the estimated statistic with ``(1 − momentum)`` and the new observed value with ``momentum``, while Flax multiplies the estimated statistic with ``momentum`` and the new observed value with ``(1 − momentum)``. .. |nnx.BatchNorm| replace:: ``nnx.BatchNorm`` .. _nnx.BatchNorm: https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm .. testcode:: t_bn = torch.nn.BatchNorm2d(num_features=3, momentum=0.1) t_bn.eval() key = random.key(0) x = random.normal(key, (1, 6, 6, 3)) j_bn = nnx.BatchNorm(num_features=3, momentum=0.9, use_running_average=True, rngs=nnx.Rngs(0)) j_out = j_bn(x) # [N, H, W, C] -> [N, C, H, W] t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2))) t_out = t_bn(t_x) # [N, C, H, W] -> [N, H, W, C] t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1)) np.testing.assert_almost_equal(j_out, t_out, decimal=6) Average Pooling -------------------------------- ``torch.nn.AvgPool2d`` and |nnx.avg_pool()|_ are compatible when using default parameters. However, ``torch.nn.AvgPool2d`` has a parameter ``count_include_pad``. When ``count_include_pad=False``, the zero-padding will not be considered for the average calculation. There does not exist a similar parameter for |nnx.avg_pool()|_. However, we can easily implement a wrapper around the pooling operation. ``nnx.pool()`` is the core function behind |nnx.avg_pool()|_ and |nnx.max_pool()|_. .. |nnx.avg_pool()| replace:: ``nnx.avg_pool()`` .. _nnx.avg_pool(): https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.avg_pool .. |nnx.max_pool()| replace:: ``nnx.max_pool()`` .. _nnx.max_pool(): https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.max_pool .. testcode:: def avg_pool(inputs, window_shape, strides=None, padding='VALID'): """ Pools the input by taking the average over a window. In comparison to nnx.avg_pool(), this pooling operation does not consider the padded zero's for the average computation. """ assert len(window_shape) == 2 y = nnx.pool(inputs, 0., jax.lax.add, window_shape, strides, padding) counts = nnx.pool(jnp.ones_like(inputs), 0., jax.lax.add, window_shape, strides, padding) y = y / counts return y key = random.key(0) x = random.normal(key, (1, 6, 6, 3)) j_out = avg_pool(x, window_shape=(2, 2), strides=(1, 1), padding=((1, 1), (1, 1))) t_pool = torch.nn.AvgPool2d(kernel_size=2, stride=1, padding=1, count_include_pad=False) # [N, H, W, C] -> [N, C, H, W] t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2))) t_out = t_pool(t_x) # [N, C, H, W] -> [N, H, W, C] t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1)) np.testing.assert_almost_equal(j_out, t_out, decimal=6) Transposed Convolutions -------------------------------- ``torch.nn.ConvTranspose2d`` and |nnx.ConvTranspose|_ are not compatible. |nnx.ConvTranspose|_ is a wrapper around |jax.lax.conv_transpose|_ which computes a fractionally strided convolution, while ``torch.nn.ConvTranspose2d`` computes a gradient based transposed convolution. Currently, there is no implementation of a gradient based transposed convolution is ``Jax``. However, there is a pending `pull request`_ that contains an implementation. To load ``torch.nn.ConvTranspose2d`` parameters into Flax, we need to use the ``transpose_kernel`` arg in Flax's ``nnx.ConvTranspose`` layer. .. testcode:: t_conv = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=0) kernel = t_conv.weight.detach().cpu().numpy() bias = t_conv.bias.detach().cpu().numpy() # [inC, outC, kH, kW] -> [kH, kW, outC, inC] kernel = jnp.transpose(kernel, (2, 3, 1, 0)) key = random.key(0) x = random.normal(key, (1, 6, 6, 3)) # ConvTranspose expects the kernel to be [kH, kW, inC, outC], # but with `transpose_kernel=True`, it expects [kH, kW, outC, inC] instead j_conv = nnx.ConvTranspose(3, 4, kernel_size=(2, 2), padding='VALID', transpose_kernel=True, rngs=nnx.Rngs(0)) j_conv.kernel.value = kernel j_conv.bias.value = bias j_out = j_conv(x) # [N, H, W, C] -> [N, C, H, W] t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2))) t_out = t_conv(t_x) # [N, C, H, W] -> [N, H, W, C] t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1)) np.testing.assert_almost_equal(j_out, t_out, decimal=6) .. _`pull request`: https://github.com/jax-ml/jax/pull/5772 .. |nnx.ConvTranspose| replace:: ``nnx.ConvTranspose`` .. _nnx.ConvTranspose: https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.ConvTranspose .. |jax.lax.conv_transpose| replace:: ``jax.lax.conv_transpose`` .. _jax.lax.conv_transpose: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_transpose.html ================================================ FILE: docs_nnx/migrating/haiku_to_flax.rst ================================================ Haiku to Flax NNX ############################ This guide demonstrates the differences between Haiku and Flax NNX models, providing side-by-side example code to help you migrate to the Flax NNX API from Haiku. If you are new to Flax NNX, make sure you become familiarized with `Flax NNX basics `__, which covers the :class:`nnx.Module` system, `Flax transformations `__, and the `Functional API `__ with examples. Let’s start with some imports. .. testsetup:: Haiku, Flax NNX import jax import jax.numpy as jnp import optax from typing import Any Basic Module definition ======================= Both Haiku and Flax use the ``Module`` class as the default unit to express a neural network library layer. For example, to create a one-layer network with dropout and a ReLU activation function, you: * First, create a ``Block`` (by subclassing ``Module``) composed of one linear layer with dropout and a ReLU activation function. * Then, use ``Block`` as a sub-``Module`` when creating a ``Model`` (also by subclassing ``Module``), which is made up of ``Block`` and a linear layer. There are two fundamental differences between Haiku and Flax ``Module`` objects: * **Stateless vs. stateful**: * A ``haiku.Module`` instance is stateless. This means, the variables are returned from a purely functional ``Module.init()`` call and managed separately. * A :class:`flax.nnx.Module`, however, owns its variables as attributes of this Python object. * **Lazy vs. eager**: * A ``haiku.Module`` only allocates space to create variables when they actually see the input when the user calls the model (lazy). * A ``flax.nnx.Module`` instance creates variables the moment they are instantiated, before seeing a sample input (eager). .. codediff:: :title: Haiku, Flax NNX :sync: import haiku as hk class Block(hk.Module): def __init__(self, features: int, name=None): super().__init__(name=name) self.features = features def __call__(self, x, training: bool): x = hk.Linear(self.features)(x) x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x) x = jax.nn.relu(x) return x class Model(hk.Module): def __init__(self, dmid: int, dout: int, name=None): super().__init__(name=name) self.dmid = dmid self.dout = dout def __call__(self, x, training: bool): x = Block(self.dmid)(x, training) x = hk.Linear(self.dout)(x) return x --- from flax import nnx class Block(nnx.Module): def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs): self.linear = nnx.Linear(in_features, out_features, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) def __call__(self, x): x = self.linear(x) x = self.dropout(x) x = jax.nn.relu(x) return x class Model(nnx.Module): def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs): self.block = Block(din, dmid, rngs=rngs) self.linear = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x): x = self.block(x) x = self.linear(x) return x Variable creation ================= This section is about instantiating a model and initializing its parameters. * To generate model parameters for a Haiku model, you need to put it inside a forward function and use ``haiku.transform`` to make it purely functional. This results in a nested dictionary of `JAX Arrays `__ (``jax.Array`` data types) to be carried around and maintained separately. * In Flax NNX, the model parameters are automatically initialized when you instantiate the model, and the variables (:class:`nnx.Variable` objects) are stored inside the :class:`nnx.Module` (or its sub-Module) as attributes. You still need to provide it with a `pseudorandom number generator (PRNG) `__ key, but that key will be wrapped inside an :class:`nnx.Rngs` class and stored inside, generating more PRNG keys when needed. If you want to access Flax model parameters in the stateless, dictionary-like fashion for checkpoint saving or model surgery, check out the `Flax NNX split/merge API `__ (:func:`nnx.split` / :func:`nnx.merge`). .. codediff:: :title: Haiku, Flax NNX :sync: def forward(x, training: bool): return Model(256, 10)(x, training) model = hk.transform(forward) sample_x = jnp.ones((1, 784)) params = model.init(jax.random.key(0), sample_x, training=False) assert params['model/linear']['b'].shape == (10,) assert params['model/block/linear']['w'].shape == (784, 256) --- ... model = Model(784, 256, 10, rngs=nnx.Rngs(0)) # Parameters were already initialized during model instantiation. assert model.linear.bias.value.shape == (10,) assert model.block.linear.kernel.value.shape == (784, 256) Training step and compilation ============================= This section covers writing a training step and compiling it using the `JAX just-in-time compilation `__. When compiling the training step: * Haiku uses ``@jax.jit`` - a `JAX transformation `__ - to compile a purely functional training step. * Flax NNX uses :meth:`@nnx.jit` - a `Flax NNX transformation `__ (one of several transform APIs that behave similarly to JAX transforms, but also `work well with Flax objects `__). While ``jax.jit`` only accepts functions with pure stateless arguments, ``flax.nnx.jit`` allows the arguments to be stateful Modules. This greatly reduces the number of lines needed for a train step. When taking gradients: * Similarly, Haiku uses ``jax.grad`` (a JAX transformation for `automatic differentiation `__) to return a raw dictionary of gradients. * Meanwhile, Flax NNX uses :meth:`flax.nnx.grad` (a Flax NNX transformation) to return the gradients of Flax NNX Modules as :class:`flax.nnx.State` dictionaries. If you want to use regular ``jax.grad`` with Flax NNX, you need to use the `split/merge API `__. For optimizers: * If you are already using `Optax `__ optimizers like ``optax.adamw`` (instead of the raw ``jax.tree.map`` computation shown here) with Haiku, check out the :class:`flax.nnx.Optimizer` example in the `Flax basics `__ guide for a much more concise way of training and updating your model. Model updates during each training step: * The Haiku training step needs to return a `JAX pytree `__ of parameters as the input of the next step. * The Flax NNX training step does not need to return anything, because the ``model`` was already updated in-place within :meth:`nnx.jit`. * In addition, :class:`nnx.Module` objects are stateful, and ``Module`` automatically tracks several things within it, such as PRNG keys and ``flax.nnx.BatchNorm`` stats. That is why you don't need to explicitly pass a PRNG key in at every step. Also note that you can use :meth:`flax.nnx.reseed` to reset its underlying PRNG state. The dropout behavior: * In Haiku, you need to explicitly define and pass in the ``training`` argument to toggle ``haiku.dropout`` and make sure that random dropout only happens if ``training=True``. * In Flax NNX, you can call ``model.train()`` (:meth:`flax.nnx.Module.train`) to automatically switch :class:`flax.nnx.Dropout` to the training mode. Conversely, you can call ``model.eval()`` (:meth:`flax.nnx.Module.eval`) to turn off the training mode. You can learn more about what ``flax.nnx.Module.train`` does in its `API reference `__. .. codediff:: :title: Haiku, Flax NNX :sync: ... @jax.jit def train_step(key, params, inputs, labels): def loss_fn(params): logits = model.apply( params, key, inputs, training=True # <== inputs ) return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() grads = jax.grad(loss_fn)(params) params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads) return params --- model.train() # set deterministic=False @nnx.jit def train_step(model, inputs, labels): def loss_fn(model): logits = model( inputs, # <== inputs ) return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() grads = nnx.grad(loss_fn)(model) _, params, rest = nnx.split(model, nnx.Param, ...) params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads) nnx.update(model, nnx.merge_state(params, rest)) .. testcode:: Haiku :hide: train_step(jax.random.key(0), params, sample_x, jnp.ones((1,), dtype=jnp.int32)) .. testcode:: Flax NNX :hide: sample_x = jnp.ones((1, 784)) train_step(model, sample_x, jnp.ones((1,), dtype=jnp.int32)) Handling non-parameter states ============================= Haiku makes a distinction between trainable parameters and all other data ("states") that the model tracks. For example, the batch stats used in batch norm is considered a state. Models with states needs to be transformed with ``hk.transform_with_state`` so that their ``.init()`` returns both params and states. In Flax, there isn't such a strong distinction - they are all subclasses of ``nnx.Variable`` and seen by a module as its attributes. Parameters are instances of a subclass called ``nnx.Param``, and batch stats can be of another subclass called ``nnx.BatchStat``. You can use :func:`nnx.split` to quickly extract all data of a certain variable type. Let's see an example of this by taking the ``Block`` definition above but replace dropout with ``BatchNorm``. .. codediff:: :title: Haiku, Flax NNX :sync: class Block(hk.Module): def __init__(self, features: int, name=None): super().__init__(name=name) self.features = features def __call__(self, x, training: bool): x = hk.Linear(self.features)(x) x = hk.BatchNorm( create_scale=True, create_offset=True, decay_rate=0.99 )(x, is_training=training) x = jax.nn.relu(x) return x def forward(x, training: bool): return Model(256, 10)(x, training) model = hk.transform_with_state(forward) sample_x = jnp.ones((1, 784)) params, batch_stats = model.init(jax.random.key(0), sample_x, training=True) --- class Block(nnx.Module): def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs): self.linear = nnx.Linear(in_features, out_features, rngs=rngs) self.batchnorm = nnx.BatchNorm( num_features=out_features, momentum=0.99, rngs=rngs ) def __call__(self, x): x = self.linear(x) x = self.batchnorm(x) x = jax.nn.relu(x) return x model = Block(4, 4, rngs=nnx.Rngs(0)) model.linear.kernel # Param(value=...) model.batchnorm.mean # BatchStat(value=...) Flax takes the difference of trainable params and other data into account. ``nnx.grad`` will only take gradients on the ``nnx.Param`` variables, thus skipping the ``batchnorm`` arrays automatically. Therefore, the training step will look the same for Flax NNX with this model. Using multiple methods ====================== In this section you will learn how to use multiple methods in Haiku and Flax. As an example, you will implement an auto-encoder model with three methods: ``encode``, ``decode``, and ``__call__``. In Haiku, you need to use ``hk.multi_transform`` to explicitly define how the model shall be initialized and what methods (``encode`` and ``decode`` here) it can call. Note that you still need to define a ``__call__`` that activates both layers for the lazy initialization of all model parameters. In Flax, it's simpler as you initialized parameters in ``__init__`` and the :class:`nnx.Module` methods ``encode`` and ``decode`` can be used directly. .. codediff:: :title: Haiku, Flax NNX :sync: class AutoEncoder(hk.Module): def __init__(self, embed_dim: int, output_dim: int, name=None): super().__init__(name=name) self.encoder = hk.Linear(embed_dim, name="encoder") self.decoder = hk.Linear(output_dim, name="decoder") def encode(self, x): return self.encoder(x) def decode(self, x): return self.decoder(x) def __call__(self, x): x = self.encode(x) x = self.decode(x) return x def forward(): module = AutoEncoder(256, 784) init = lambda x: module(x) return init, (module.encode, module.decode) model = hk.multi_transform(forward) params = model.init(jax.random.key(0), x=jnp.ones((1, 784))) --- class AutoEncoder(nnx.Module): def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs): self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs) self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs) def encode(self, x): return self.encoder(x) def decode(self, x): return self.decoder(x) model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0)) ... The parameter structure is as follows: .. tab-set:: .. tab-item:: Haiku :sync: Haiku .. code-block:: python ... { 'auto_encoder/~/decoder': { 'b': (784,), 'w': (256, 784) }, 'auto_encoder/~/encoder': { 'b': (256,), 'w': (784, 256) } } .. tab-item:: Flax NNX :sync: Flax NNX .. code-block:: python _, params, _ = nnx.split(model, nnx.Param, ...) params { 'decoder': { 'bias': Param(value=(784,)), 'kernel': Param(value=(256, 784)) }, 'encoder': { 'bias': Param(value=(256,)), 'kernel': Param(value=(784, 256)) } } To call those custom methods: * In Haiku, you need to decouple the `.apply` function to extract your method before calling it. * In Flax, you can simply call the method directly. .. codediff:: :title: Haiku, Flax NNX :sync: encode, decode = model.apply z = encode(params, None, x=jnp.ones((1, 784))) --- ... z = model.encode(jnp.ones((1, 784))) Transformations ======================= Both Haiku and `Flax transformations `__ provide their own set of transforms that wrap `JAX transforms `__ in a way that they can be used with ``Module`` objects. For more information on Flax transforms, check out the `Transforms guide `__. Let's start with an example: * First, define an ``RNNCell`` ``Module`` that will contain the logic for a single step of the RNN. * Define a ``initial_state`` method that will be used to initialize the state (a.k.a. ``carry``) of the RNN. Like with ``jax.lax.scan`` (`API doc `__), the ``RNNCell.__call__`` method will be a function that takes the carry and input, and returns the new carry and output. In this case, the carry and the output are the same. .. codediff:: :title: Haiku, Flax NNX :sync: class RNNCell(hk.Module): def __init__(self, hidden_size: int, name=None): super().__init__(name=name) self.hidden_size = hidden_size def __call__(self, carry, x): x = jnp.concatenate([carry, x], axis=-1) x = hk.Linear(self.hidden_size)(x) x = jax.nn.relu(x) return x, x def initial_state(self, batch_size: int): return jnp.zeros((batch_size, self.hidden_size)) --- class RNNCell(nnx.Module): def __init__(self, input_size, hidden_size, rngs): self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs) self.hidden_size = hidden_size def __call__(self, carry, x): x = jnp.concatenate([carry, x], axis=-1) x = self.linear(x) x = jax.nn.relu(x) return x, x def initial_state(self, batch_size: int): return jnp.zeros((batch_size, self.hidden_size)) Next, we will define a ``RNN`` Module that will contain the logic for the entire RNN. In both cases, we use the library's ``scan`` call to run the ``RNNCell`` over the input sequence. The only difference is that Flax ``nnx.scan`` allows you to specify which axis to repeat over in arguments ``in_axes`` and ``out_axes``, which will be forwarded to the underlying `jax.lax.scan `__, whereas in Haiku you need to transpose the input and output explicitly. .. codediff:: :title: Haiku, Flax NNX :sync: class RNN(hk.Module): def __init__(self, hidden_size: int, name=None): super().__init__(name=name) self.hidden_size = hidden_size def __call__(self, x): cell = RNNCell(self.hidden_size) carry = cell.initial_state(x.shape[0]) carry, y = hk.scan( cell, carry, jnp.swapaxes(x, 1, 0) ) y = jnp.swapaxes(y, 0, 1) return y --- class RNN(nnx.Module): def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs): self.hidden_size = hidden_size self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs) def __call__(self, x): scan_fn = lambda carry, cell, x: cell(carry, x) carry = self.cell.initial_state(x.shape[0]) carry, y = nnx.scan( scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1) )(carry, self.cell, x) return y Scan over layers ======================= Most Haiku transforms should look similar with Flax, since they all wraps their JAX counterparts, but the scan-over-layers use case is an exception. Scan-over-layers is a technique where you run an input through a sequence of N repeated layers, passing the output of each layer as the input to the next layer. This pattern can significantly reduce compilation time for large models. In the example below, you will repeat the ``Block`` ``Module`` 5 times in the top-level ``MLP`` ``Module``. In Haiku, we define the ``Block`` Module as usual, and then inside ``MLP`` we will use ``hk.experimental.layer_stack`` over a ``stack_block`` function to create a stack of ``Block`` Modules. The same code will create 5 layers of parameters in initialization time, and run the input through them in call time. In Flax, model initialization and calling code are completely decoupled, so we use the :func:`nnx.vmap` transform to initialize the underlying ``Block`` parameters, and the :func:`nnx.scan` transform to run the model input through them. .. codediff:: :title: Haiku, Flax NNX :sync: class Block(hk.Module): def __init__(self, features: int, name=None): super().__init__(name=name) self.features = features def __call__(self, x, training: bool): x = hk.Linear(self.features)(x) x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x) x = jax.nn.relu(x) return x class MLP(hk.Module): def __init__(self, features: int, num_layers: int, name=None): super().__init__(name=name) self.features = features self.num_layers = num_layers def __call__(self, x, training: bool): @hk.experimental.layer_stack(self.num_layers) def stack_block(x): return Block(self.features)(x, training) stack = hk.experimental.layer_stack(self.num_layers) return stack_block(x) def forward(x, training: bool): return MLP(64, num_layers=5)(x, training) model = hk.transform(forward) sample_x = jnp.ones((1, 64)) params = model.init(jax.random.key(0), sample_x, training=False) --- class Block(nnx.Module): def __init__(self, input_dim, features, rngs): self.linear = nnx.Linear(input_dim, features, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) def __call__(self, x: jax.Array): # No need to require a second input! x = self.linear(x) x = self.dropout(x) x = jax.nn.relu(x) return x # No need to return a second output! class MLP(nnx.Module): def __init__(self, features, num_layers, rngs): @nnx.split_rngs(splits=num_layers) @nnx.vmap(in_axes=(0,), out_axes=0) def create_block(rngs: nnx.Rngs): return Block(features, features, rngs=rngs) self.blocks = create_block(rngs) self.num_layers = num_layers def __call__(self, x): @nnx.split_rngs(splits=self.num_layers) @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry) def forward(x, model): x = model(x) return x return forward(x, self.blocks) model = MLP(64, num_layers=5, rngs=nnx.Rngs(0)) There are a few other details to explain in the Flax example above: * **The `@nnx.split_rngs` decorator:** Flax transforms, like their JAX counterparts, are completely agnostic of the PRNG state and rely on input for PRNG keys. The ``nnx.split_rngs`` decorator allows you to split the ``nnx.Rngs`` before passing them to the decorated function and 'lower' them afterwards, so they can be used outside. * Here, you split the PRNG keys because ``jax.vmap`` and ``jax.lax.scan`` require a list of PRNG keys if each of its internal operations needs its own key. So for the 5 layers inside the ``MLP``, you split and provide 5 different PRNG keys from its arguments before going down to the JAX transform. * Note that actually ``create_block()`` knows it needs to create 5 layers *precisely because* it sees 5 PRNG keys, because ``in_axes=(0,)`` indicates that ``vmap`` will look into the first argument's first dimension to know the size it will map over. * Same goes for ``forward()``, which looks at the variables inside the first argument (aka. ``model``) to find out how many times it needs to scan. ``nnx.split_rngs`` here actually splits the PRNG state inside the ``model``. (If the ``Block`` ``Module`` doesn't have dropout, you don't need the :meth:`nnx.split_rngs` line as it would not consume any PRNG key anyway.) * **Why the Block Module in Flax doesn't need to take and return that extra dummy value:** ``jax.lax.scan`` `(API doc `__ requires its function to return two inputs - the carry and the stacked output. In this case, we didn't use the latter. Flax simplifies this, so that you can now choose to ignore the second output if you set ``out_axes=nnx.Carry`` instead of the default ``(nnx.Carry, 0)``. * This is one of the rare cases where Flax NNX transforms diverge from the `JAX transforms `__ APIs. There are more lines of code in the Flax example above, but they express what happens at each time more precisely. Since Flax transforms become way closer to the JAX transform APIs, it is recommended to have a good understanding of the underlying `JAX transforms `__ before using their `Flax NNX equivalents `__ Now inspect the variable pytree on both sides: .. tab-set:: .. tab-item:: Haiku :sync: Haiku .. code-block:: python ... { 'mlp/__layer_stack_no_per_layer/block/linear': { 'b': (5, 64), 'w': (5, 64, 64) } } ... .. tab-item:: Flax NNX :sync: Flax NNX .. code-block:: python _, params, _ = nnx.split(model, nnx.Param, ...) params { 'blocks': { 'linear': { 'bias': Param(value=(5, 64)), 'kernel': Param(value=(5, 64, 64)) } } } Top-level Haiku functions vs top-level Flax modules ======================= In Haiku, it is possible to write the entire model as a single function by using the raw ``hk.{get,set}_{parameter,state}`` to define/access model parameters and states. It is very common to write the top-level "Module" as a function instead. The Flax team recommends a more Module-centric approach that uses ``__call__`` to define the forward function. In Flax modules, the parameters and variables can be set and accessed as normal using regular Python class semantics. .. codediff:: :title: Haiku, Flax NNX :sync: ... def forward(x): counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones) multiplier = hk.get_parameter( 'multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones ) output = x + multiplier * counter hk.set_state("counter", counter + 1) return output model = hk.transform_with_state(forward) params, state = model.init(jax.random.key(0), jnp.ones((1, 64))) --- class Counter(nnx.Variable): pass class FooModule(nnx.Module): def __init__(self, rngs): self.counter = Counter(jnp.ones((), jnp.int32)) self.multiplier = nnx.Param( nnx.initializers.ones(rngs.params(), [1,], jnp.float32) ) def __call__(self, x): output = x + self.multiplier * self.counter.value self.counter.value += 1 return output model = FooModule(rngs=nnx.Rngs(0)) _, params, counter = nnx.split(model, nnx.Param, Counter) ================================================ FILE: docs_nnx/migrating/index.rst ================================================ Migrating ------------------------ .. toctree:: :maxdepth: 2 convert_pytorch_to_flax nnx_010_to_nnx_011 linen_to_nnx haiku_to_flax ================================================ FILE: docs_nnx/migrating/linen_to_nnx.rst ================================================ Flax Linen to Flax NNX ################################ This guide demonstrates the differences between Flax Linen and Flax NNX models, providing side-by-side example code to help you migrate to the Flax NNX API from Flax Linen. This document mainly teaches how to convert arbitrary Flax Linen code to Flax NNX. If you want to play it “safe” and convert your codebase iteratively, check out the `Use Flax NNX and Linen together via nnx.bridge `__ guide. To get the most out of this guide, it is highly recommended to get go through `Flax NNX basics `__ document, which covers the :class:`nnx.Module` system, `Flax transformations `__, and the `Functional API `__ with examples. .. testsetup:: Linen, NNX import jax import jax.numpy as jnp import optax import flax.linen as nn from typing import Any Basic ``Module`` definition =========================== Both Flax Linen and Flax NNX use the ``Module`` class as the default unit to express a neural network library layer. In the example below, you first create a ``Block`` (by subclassing ``Module``) composed of one linear layer with dropout and a ReLU activation function; then you use it as a sub-``Module`` when creating a ``Model`` (also by subclassing ``Module``), which is made up of ``Block`` and a linear layer. There are two fundamental differences between Flax Linen and Flax NNX ``Module`` objects: * **Stateless vs. stateful**: A ``flax.linen.Module`` (``nn.Module``) instance is stateless - the variables are returned from a purely functional ``Module.init()`` call and managed separately. A :class:`flax.nnx.Module`, however, owns its variables as attributes of this Python object. * **Lazy vs. eager**: A ``flax.linen.Module`` only allocates space to create variables when they actually see their input (lazy). A :class:`flax.nnx.Module` instance creates variables the moment they are instantiated before seeing a sample input (eager). * Flax Linen can use the ``@nn.compact`` decorator to define the model in a single method, and use shape inference from the input sample. A Flax NNX ``Module`` generally requests additional shape information to create all parameters during ``__init__`` , and separately defines the computation in the ``__call__`` method. .. codediff:: :title: Linen, NNX :sync: import flax.linen as nn class Block(nn.Module): features: int @nn.compact def __call__(self, x, training: bool): x = nn.Dense(self.features)(x) x = nn.Dropout(0.5, deterministic=not training)(x) x = jax.nn.relu(x) return x class Model(nn.Module): dmid: int dout: int @nn.compact def __call__(self, x, training: bool): x = Block(self.dmid)(x, training) x = nn.Dense(self.dout)(x) return x --- from flax import nnx class Block(nnx.Module): def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs): self.linear = nnx.Linear(in_features, out_features, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) def __call__(self, x): x = self.linear(x) x = self.dropout(x) x = jax.nn.relu(x) return x class Model(nnx.Module): def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs): self.block = Block(din, dmid, rngs=rngs) self.linear = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x): x = self.block(x) x = self.linear(x) return x Variable creation ================= Next, let’s discuss instantiating the model and initializing its parameters: * To generate model parameters for a Flax Linen model, you call the ``flax.linen.Module.init`` (``nn.Module.init``) method with a ``jax.random.key`` (`doc `__) plus some sample inputs that the model shall take. This results in a nested dictionary of `JAX Arrays `__ (``jax.Array`` data types) to be carried around and maintained separately. * In Flax NNX, the model parameters are automatically initialized when you instantiate the model, and the variables (:class:`nnx.Variable` objects) are stored inside the :class:`nnx.Module` (or its sub-``Module``) as attributes. You still need to provide it with a `pseudorandom number generator (PRNG) `__ key, but that key will be wrapped inside an :class:`nnx.Rngs` class and stored inside, generating more PRNG keys when needed. If you want to access Flax NNX model parameters in the stateless, dictionary-like fashion for checkpoint saving or model surgery, check out the `Flax NNX split/merge API `__ (:func:`nnx.split` / :func:`nnx.merge`). .. codediff:: :title: Linen, NNX :sync: model = Model(256, 10) sample_x = jnp.ones((1, 784)) variables = model.init(jax.random.key(0), sample_x, training=False) params = variables["params"] assert params['Dense_0']['bias'].shape == (10,) assert params['Block_0']['Dense_0']['kernel'].shape == (784, 256) --- model = Model(784, 256, 10, rngs=nnx.Rngs(0)) # Parameters were already initialized during model instantiation. assert model.linear.bias.value.shape == (10,) assert model.block.linear.kernel.value.shape == (784, 256) Training step and compilation ============================= Now, let’s proceed to writing a training step and compiling it using `JAX just-in-time compilation `__. Below are certain differences between Flax Linen and Flax NNX approaches. Compiling the training step: * Flax Linen uses ``@jax.jit`` - a `JAX transform `__ - to compile the training step. * Flax NNX uses :meth:`@nnx.jit` - a `Flax NNX transform `__ (one of several transform APIs that behave similarly to JAX transforms, but also `work well with Flax NNX objects `__). So, while ``jax.jit`` only accepts functions pure stateless arguments, ``nnx.jit`` allows the arguments to be stateful NNX Modules. This greatly reduced the number of lines needed for a train step. Taking gradients: * Similarly, Flax Linen uses ``jax.grad`` (a JAX transform for `automatic differentiation `__) to return a raw dictionary of gradients. * Flax NNX uses :meth:`nnx.grad` (a Flax NNX transform) to return the gradients of NNX Modules as :class:`nnx.State` dictionaries. If you want to use regular ``jax.grad`` with Flax NNX you need to use the `Flax NNX split/merge API `__. Optimizers: * If you are already using `Optax `__ optimizers like ``optax.adamw`` (instead of the raw ``jax.tree.map`` computation shown here) with Flax Linen, check out the :class:`nnx.Optimizer` example in the `Flax NNX basics `__ guide for a much more concise way of training and updating your model. Model updates during each training step: * The Flax Linen training step needs to return a `pytree `__ of parameters as the input of the next step. * The Flax NNX training step doesn't need to return anything, because the ``model`` was already updated in-place within :meth:`nnx.jit`. * In addition, :class:`nnx.Module` objects are stateful, and ``Module`` automatically tracks several things within it, such as PRNG keys and ``BatchNorm`` stats. That is why you don't need to explicitly pass an PRNG key in on every step. Also note that you can use :meth:`nnx.reseed` to reset its underlying PRNG state. Dropout behavior: * In Flax Linen, you need to explicitly define and pass in the ``training`` argument to control the behavior of ``flax.linen.Dropout`` (``nn.Dropout``), namely, its ``deterministic`` flag, which means random dropout only happens if ``training=True``. * In Flax NNX, you can call ``model.train()`` (:meth:`flax.nnx.Module.train`) to automatically switch :class:`nnx.Dropout` to the training mode. Conversely, you can call ``model.eval()`` (:meth:`flax.nnx.Module.eval`) to turn off the training mode. You can learn more about what ``nnx.Module.train`` does in its `API reference `__. .. codediff:: :title: Linen, NNX :sync: ... @jax.jit def train_step(key, params, inputs, labels): def loss_fn(params): logits = model.apply( {'params': params}, inputs, training=True, # <== inputs rngs={'dropout': key} ) return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() grads = jax.grad(loss_fn)(params) params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads) return params --- model.train() # Sets ``deterministic=False` under the hood for nnx.Dropout @nnx.jit def train_step(model, inputs, labels): def loss_fn(model): logits = model(inputs) return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() grads = nnx.grad(loss_fn)(model) _, params, rest = nnx.split(model, nnx.Param, ...) params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads) nnx.update(model, nnx.merge_state(params, rest)) .. testcode:: Linen :hide: train_step(jax.random.key(0), params, sample_x, jnp.ones((1,), dtype=jnp.int32)) .. testcode:: NNX :hide: sample_x = jnp.ones((1, 784)) train_step(model, sample_x, jnp.ones((1,), dtype=jnp.int32)) Collections and variable types ============================== One key difference between Flax Linen and NNX APIs is how they group variables into categories. Flax Linen uses different collections, while Flax NNX, since all variables shall be top-level Python attributes, you use different variable types. In Flax NNX, you can freely create your own variable types as subclasses of ``nnx.Variable``. For all the built-in Flax Linen layers and collections, Flax NNX already creates the corresponding layers and variable types. For example: * ``flax.linen.Dense`` (``nn.Dense``) creates ``params`` -> :class:`nnx.Linear` creates :class:nnx.Param`. * ``flax.linen.BatchNorm`` (``nn.BatchNorm``) creates ``batch_stats`` -> :class:`nnx.BatchNorm` creates :class:`nnx.BatchStats`. * ``flax.linen.Module.sow()`` creates ``intermediates`` -> :class:`nnx.Module.sow()` creates :class:`nnx.Intermediaries`. * In Flax NNX, you can also simply obtain the intermediates by assigning it to an ``nnx.Module`` attribute - for example, ``self.sowed = nnx.Intermediates(x)``. This will be similar to Flax Linen's ``self.variable('intermediates' 'sowed', lambda: x)``. .. codediff:: :title: Linen, NNX :sync: class Block(nn.Module): features: int def setup(self): self.dense = nn.Dense(self.features) self.batchnorm = nn.BatchNorm(momentum=0.99) self.count = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32)) @nn.compact def __call__(self, x, training: bool): x = self.dense(x) x = self.batchnorm(x, use_running_average=not training) self.count.value += 1 x = jax.nn.relu(x) return x x = jax.random.normal(jax.random.key(0), (2, 4)) model = Block(4) variables = model.init(jax.random.key(0), x, training=True) variables['params']['dense']['kernel'].shape # (4, 4) variables['batch_stats']['batchnorm']['mean'].shape # (4, ) variables['counter']['count'] # 1 --- class Counter(nnx.Variable): pass class Block(nnx.Module): def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs): self.linear = nnx.Linear(in_features, out_features, rngs=rngs) self.batchnorm = nnx.BatchNorm( num_features=out_features, momentum=0.99, rngs=rngs ) self.count = Counter(jnp.array(0)) def __call__(self, x): x = self.linear(x) x = self.batchnorm(x) self.count.value += 1 x = jax.nn.relu(x) return x model = Block(4, 4, rngs=nnx.Rngs(0)) model.linear.kernel # Param(value=...) model.batchnorm.mean # BatchStat(value=...) model.count # Counter(value=...) If you want to extract certain arrays from the pytree of variables: * In Flax Linen, you can access the specific dictionary path. * In Flax NNX, you can use :func:`nnx.split` to distinguish the types apart in Flax NNX. The code below is a simple example that splits up the variables by their types - check out the `Flax NNX Filters `__ guide for more sophisticated filtering expressions. .. codediff:: :title: Linen, NNX :sync: params, batch_stats, counter = ( variables['params'], variables['batch_stats'], variables['counter']) params.keys() # ['dense', 'batchnorm'] batch_stats.keys() # ['batchnorm'] counter.keys() # ['count'] # ... make arbitrary modifications ... # Merge back with raw dict to carry on: variables = {'params': params, 'batch_stats': batch_stats, 'counter': counter} --- graphdef, params, batch_stats, count = nnx.split( model, nnx.Param, nnx.BatchStat, Counter) params.keys() # ['batchnorm', 'linear'] batch_stats.keys() # ['batchnorm'] count.keys() # ['count'] # ... make arbitrary modifications ... # Merge back with ``nnx.merge`` to carry on: model = nnx.merge(graphdef, params, batch_stats, count) Using multiple methods ====================== In this section you will learn how to use multiple methods in both Flax Linen and Flax NNX. As an example, you will implement an auto-encoder model with three methods: ``encode``, ``decode``, and ``__call__``. Defining the encoder and decoder layers: * In Flax Linen, as before, define the layers without having to pass in the input shape, since the ``flax.linen.Module`` parameters will be initialized lazily using shape inference. * In Flax NNX, you must pass in the input shape since the :class:`nnx.Module` parameters will be initialized eagerly without shape inference. .. codediff:: :title: Linen, NNX :sync: class AutoEncoder(nn.Module): embed_dim: int output_dim: int def setup(self): self.encoder = nn.Dense(self.embed_dim) self.decoder = nn.Dense(self.output_dim) def encode(self, x): return self.encoder(x) def decode(self, x): return self.decoder(x) def __call__(self, x): x = self.encode(x) x = self.decode(x) return x model = AutoEncoder(256, 784) variables = model.init(jax.random.key(0), x=jnp.ones((1, 784))) --- class AutoEncoder(nnx.Module): def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs): self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs) self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs) def encode(self, x): return self.encoder(x) def decode(self, x): return self.decoder(x) def __call__(self, x): x = self.encode(x) x = self.decode(x) return x model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0)) The variable structure is as follows: .. tab-set:: .. tab-item:: Linen :sync: Linen .. code-block:: python # variables['params'] { decoder: { bias: (784,), kernel: (256, 784), }, encoder: { bias: (256,), kernel: (784, 256), }, } .. tab-item:: NNX :sync: NNX .. code-block:: python # _, params, _ = nnx.split(model, nnx.Param, ...) # params { 'decoder': { 'bias': Param(value=(784,)), 'kernel': Param(value=(256, 784)) }, 'encoder': { 'bias': Param(value=(256,)), 'kernel': Param(value=(784, 256)) } } To call methods other than ``__call__``: * In Flax Linen, you still need to use the ``apply`` API. * In Flax NNX, you can simply call the method directly. .. codediff:: :title: Linen, NNX :sync: z = model.apply(variables, x=jnp.ones((1, 784)), method="encode") --- z = model.encode(jnp.ones((1, 784))) Transformations =============== Both Flax Linen and `Flax NNX transformations `__ provide their own set of transforms that wrap `JAX transforms `__ in a way that they can be used with ``Module`` objects. Most of the transforms in Flax Linen, such as ``grad`` or ``jit``, don't change much in Flax NNX. But, for example, if you try to do ``scan`` over layers, as described in the next section, the code differs by a lot. Let’s start with an example: * First, define an ``RNNCell`` ``Module`` that will contain the logic for a single step of the RNN. * Define a ``initial_state`` method that will be used to initialize the state (a.k.a. ``carry``) of the RNN. Like with ``jax.lax.scan`` (`API doc `__), the ``RNNCell.__call__`` method will be a function that takes the carry and input, and returns the new carry and output. In this case, the carry and the output are the same. .. codediff:: :title: Linen, NNX :sync: class RNNCell(nn.Module): hidden_size: int @nn.compact def __call__(self, carry, x): x = jnp.concatenate([carry, x], axis=-1) x = nn.Dense(self.hidden_size)(x) x = jax.nn.relu(x) return x, x def initial_state(self, batch_size: int): return jnp.zeros((batch_size, self.hidden_size)) --- class RNNCell(nnx.Module): def __init__(self, input_size, hidden_size, rngs): self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs) self.hidden_size = hidden_size def __call__(self, carry, x): x = jnp.concatenate([carry, x], axis=-1) x = self.linear(x) x = jax.nn.relu(x) return x, x def initial_state(self, batch_size: int): return jnp.zeros((batch_size, self.hidden_size)) Next, define an ``RNN`` ``Module`` that will contain the logic for the entire RNN. In Flax Linen: * You will use ``flax.linen.scan`` (``nn.scan``) to define a new temporary type that wraps ``RNNCell``. During this process you will also: 1) instruct ``nn.scan`` to broadcast the ``params`` collection (all steps share the same parameters) and to not split the ``params`` PRNG stream (so that all steps initialize with the same parameters); and, finally, 2) specify that you want scan to run over the second axis of the input and stack outputs along the second axis as well. * You will then use this temporary type immediately to create an instance of the “lifted” ``RNNCell`` and use it to create the ``carry``, and the run the ``__call__`` method, which will ``scan`` over the sequence. In Flax NNX: * You will create a ``scan`` function (``scan_fn``) that will use the ``RNNCell`` defined in ``__init__`` to scan over the sequence, and explicitly set ``in_axes=(nnx.Carry, None, 1)``. ``nnx.Carry`` means that the ``carry`` argument will be the carry, ``None`` means that ``cell`` will be broadcasted to all steps, and ``1`` means ``x`` will be scanned across axis `1`. .. codediff:: :title: Linen, NNX :sync: class RNN(nn.Module): hidden_size: int @nn.compact def __call__(self, x): rnn = nn.scan( RNNCell, variable_broadcast='params', split_rngs={'params': False}, in_axes=1, out_axes=1 )(self.hidden_size) carry = rnn.initial_state(x.shape[0]) carry, y = rnn(carry, x) return y x = jnp.ones((3, 12, 32)) model = RNN(64) variables = model.init(jax.random.key(0), x=jnp.ones((3, 12, 32))) y = model.apply(variables, x=jnp.ones((3, 12, 32))) --- class RNN(nnx.Module): def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs): self.hidden_size = hidden_size self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs) def __call__(self, x): scan_fn = lambda carry, cell, x: cell(carry, x) carry = self.cell.initial_state(x.shape[0]) carry, y = nnx.scan( scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1) )(carry, self.cell, x) return y x = jnp.ones((3, 12, 32)) model = RNN(x.shape[2], 64, rngs=nnx.Rngs(0)) y = model(x) Scan over layers ================ In general, transforms of Flax Linen and Flax NNX should look the same. However, `Flax NNX transforms `__ are designed to be closer to their lower-level `JAX counterparts `__, and thus we throw away some assumptions in certain Linen lifted transforms. This scan-over-layers use case will be a good example to showcase it. Scan-over-layers is a technique where you run an input through a sequence of N repeated layers, passing the output of each layer as the input to the next layer. This pattern can significantly reduce compilation time for large models. In the example below, you will repeat the ``Block`` ``Module`` 5 times in the top-level ``MLP`` ``Module``. * In Flax Linen, you apply the ``flax.linen.scan`` (``nn.scan``) transforms upon the ``Block`` ``nn.Module`` to create a larger ``ScanBlock`` ``nn.Module`` that contains 5 ``Block`` ``nn.Module`` objects. It will automatically create a large parameter of shape ``(5, 64, 64)`` at initialization time, and iterate over at call time every ``(64, 64)`` slice for a total of 5 times, like a ``jax.lax.scan`` (`API doc `__) would. * Up close, in the logic of this model there actually is no need for the ``jax.lax.scan`` operation at initialization time. What happens there is more like a ``jax.vmap`` operation - you are given a ``Block`` sub-``Module`` that accepts ``(in_dim, out_dim)``, and you "vmap" it over ``num_layers`` of times to create a larger array. * In Flax NNX, you take advantage of the fact that model initialization and running code are completely decoupled, and instead use the :func:`nnx.vmap` transform to initialize the underlying ``Block`` parameters, and the :func:`nnx.scan` transform to run the model input through them. For more information on Flax NNX transforms, check out the `Transforms guide `__. .. codediff:: :title: Linen, NNX :sync: class Block(nn.Module): features: int training: bool @nn.compact def __call__(self, x, _): x = nn.Dense(self.features)(x) x = nn.Dropout(0.5)(x, deterministic=not self.training) x = jax.nn.relu(x) return x, None class MLP(nn.Module): features: int num_layers: int @nn.compact def __call__(self, x, training: bool): ScanBlock = nn.scan( Block, variable_axes={'params': 0}, split_rngs={'params': True}, length=self.num_layers) y, _ = ScanBlock(self.features, training)(x, None) return y model = MLP(64, num_layers=5) --- class Block(nnx.Module): def __init__(self, input_dim, features, rngs): self.linear = nnx.Linear(input_dim, features, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) def __call__(self, x: jax.Array): # No need to require a second input! x = self.linear(x) x = self.dropout(x) x = jax.nn.relu(x) return x # No need to return a second output! class MLP(nnx.Module): def __init__(self, features, num_layers, rngs): @nnx.split_rngs(splits=num_layers) @nnx.vmap(in_axes=(0,), out_axes=0) def create_block(rngs: nnx.Rngs): return Block(features, features, rngs=rngs) self.blocks = create_block(rngs) self.num_layers = num_layers def __call__(self, x): @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry) def forward(x, model): x = model(x) return x return forward(x, self.blocks) model = MLP(64, num_layers=5, rngs=nnx.Rngs(0)) There are a few other details to explain in the Flax NNX example above: * **The `@nnx.split_rngs` decorator:** Flax NNX transforms are completely agnostic of PRNG state, which makes them behave more like JAX transforms but diverge from the Flax Linen transforms that handle PRNG state. To regain this functionality, the ``nnx.split_rngs`` decorator allows you to split the ``nnx.Rngs`` before passing them to the decorated function and 'lower' them afterwards, so they can be used outside. * Here, you split the PRNG keys because ``jax.vmap`` and ``jax.lax.scan`` require a list of PRNG keys if each of its internal operations needs its own key. So for the 5 layers inside the ``MLP``, you split and provide 5 different PRNG keys from its arguments before going down to the JAX transform. * Note that actually ``create_block()`` knows it needs to create 5 layers *precisely because* it sees 5 PRNG keys, because ``in_axes=(0,)`` indicates that ``vmap`` will look into the first argument's first dimension to know the size it will map over. * Same goes for ``forward()``, which looks at the variables inside the first argument (aka. ``model``) to find out how many times it needs to scan. ``nnx.split_rngs`` here actually splits the PRNG state inside the ``model``. (If the ``Block`` ``Module`` doesn't have dropout, you don't need the :meth:`nnx.split_rngs` line as it would not consume any PRNG key anyway.) * **Why the Block Module in Flax NNX doesn't need to take and return that extra dummy value:** This is a requirement from ``jax.lax.scan`` `(API doc `__. Flax NNX simplifies this, so that you can now choose to ignore the second output if you set ``out_axes=nnx.Carry`` instead of the default ``(nnx.Carry, 0)``. * This is one of the rare cases where Flax NNX transforms diverge from the `JAX transforms `__ APIs. There are more lines of code in the Flax NNX example above, but they express what happens at each time more precisely. Since Flax NNX transforms become way closer to the JAX transform APIs, it is recommended to have a good understanding of the underlying `JAX transforms `__ before using their `Flax NNX equivalents `__ Now inspect the variable pytree on both sides: .. tab-set:: .. tab-item:: Linen :sync: Linen .. code-block:: python # variables = model.init(key, x=jnp.ones((1, 64)), training=True) # variables['params'] { ScanBlock_0: { Dense_0: { bias: (5, 64), kernel: (5, 64, 64), }, }, } .. tab-item:: NNX :sync: NNX .. code-block:: python # _, params, _ = nnx.split(model, nnx.Param, ...) # params { 'blocks': { 'linear': { 'bias': Param(value=(5, 64)), 'kernel': Param(value=(5, 64, 64)) } } } Using ``TrainState`` in Flax NNX ================================ Flax Linen has a convenient ``TrainState`` data class to bundle the model, parameters and optimizer. In Flax NNX, this is not really necessary. In this section, you will learn how to construct your Flax NNX code around ``TrainState`` for any backward compatibility needs. In Flax NNX: * You must first call :meth:`nnx.split` on the model to get the separate :class:`nnx.GraphDef` and :class:`nnx.State` objects. * You can pass in :class:`nnx.Param` to filter all trainable parameters into a single :class:`nnx.State`, and pass in ``...`` for the remaining variables. * You also need to subclass ``TrainState`` to add a field for the other variables. * Then, you can pass in :meth:`nnx.GraphDef.apply` as the ``apply`` function, :class:`nnx.State` as the parameters and other variables, and an optimizer as arguments to the ``TrainState`` constructor. Note that :class:`nnx.GraphDef.apply` will take in :class:`nnx.State` objects as arguments and return a callable function. This function can be called on the inputs to output the model's logits, as well as the updated :class:`nnx.GraphDef` and :class:`nnx.State` objects. Notice below the use of ``@jax.jit`` since you aren't passing in Flax NNX Modules into the ``train_step``. .. codediff:: :title: Linen, NNX :sync: from flax.training import train_state sample_x = jnp.ones((1, 784)) model = nn.Dense(features=10) params = model.init(jax.random.key(0), sample_x)['params'] state = train_state.TrainState.create( apply_fn=model.apply, params=params, tx=optax.adam(1e-3) ) @jax.jit def train_step(key, state, inputs, labels): def loss_fn(params): logits = state.apply_fn( {'params': params}, inputs, # <== inputs rngs={'dropout': key} ) return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() grads = jax.grad(loss_fn)(state.params) state = state.apply_gradients(grads=grads) return state --- from flax.training import train_state model = nnx.Linear(784, 10, rngs=nnx.Rngs(0)) model.train() # set deterministic=False graphdef, params, other_variables = nnx.split(model, nnx.Param, ...) class TrainState(train_state.TrainState): other_variables: nnx.State state = TrainState.create( apply_fn=graphdef.apply, params=params, other_variables=other_variables, tx=optax.adam(1e-3) ) @jax.jit def train_step(state, inputs, labels): def loss_fn(params, other_variables): logits, (graphdef, new_state) = state.apply_fn( params, other_variables )(inputs) # <== inputs return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() grads = jax.grad(loss_fn)(state.params, state.other_variables) state = state.apply_gradients(grads=grads) return state .. testcode:: Linen :hide: train_step(jax.random.key(0), state, sample_x, jnp.ones((1,), dtype=jnp.int32)) .. testcode:: NNX :hide: sample_x = jnp.ones((1, 784)) train_step(state, sample_x, jnp.ones((1,), dtype=jnp.int32)) ================================================ FILE: docs_nnx/migrating/nnx_010_to_nnx_011.rst ================================================ NNX 0.10 to NNX 0.11 ######################################### In this guide we present the code changes required when we update Flax NNX code from Flax version ``0.10.x`` to ``0.11.x``. Using Rngs in NNX Transforms ==================================== NNX layers that use RNGs like Dropout or MultiHeadAttention now hold a ``fork``-ed copy of the ``Rngs`` object given at construction time instead of a shared reference to the original ``Rngs`` object. This has two consequences: * It changes the checkpoint structure, as each layer will have unique RNG state. * It changes how ``nnx.split_rngs`` interacts with transforms like ``nnx.vmap`` and ``nnx.scan``, as the resulting RNG state will now not be stored in scalar form. Here is how a "scan over layers" looks like in the new version: .. tab-set:: .. tab-item:: v0.11 :sync: v0.11 .. code-block:: python import flax.nnx as nnx class MLP(nnx.Module): @nnx.split_rngs(splits=5) @nnx.vmap(in_axes=(0, 0)) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.bn = nnx.BatchNorm(3, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) self.node = nnx.Param(jnp.ones((2,))) @nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry) def __call__(self, x: jax.Array): return nnx.gelu(self.dropout(self.bn(self.linear(x)))) .. tab-item:: v0.10 :sync: v0.10 .. code-block:: python :emphasize-lines: 12 import flax.nnx as nnx class MLP(nnx.Module): @nnx.split_rngs(splits=5) @nnx.vmap(in_axes=(0, 0)) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.bn = nnx.BatchNorm(3, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) self.node = nnx.Param(jnp.ones((2,))) @nnx.split_rngs(splits=5) @nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry) def __call__(self, x: jax.Array): return nnx.gelu(self.dropout(self.bn(self.linear(x)))) The main thing to note is that the ``nnx.split_rngs`` over ``scan`` is not needed anymore, as the RNGs produced by ``__init__`` are no longer in scalar form (they keep the additional dimension) and thus can be used directly in ``scan`` without the need to split them again. Alternatively, can even remove the ``nnx.split_rngs`` decorator from the ``__init__`` method and use ``Rngs.fork`` directly before passing the RNGs to the module. .. code-block:: python class MLP(nnx.Module): @nnx.vmap(in_axes=(0, 0)) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.bn = nnx.BatchNorm(3, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) self.node = nnx.Param(jnp.ones((2,))) @nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry) def __call__(self, x: jax.Array): return nnx.gelu(self.dropout(self.bn(self.linear(x)))) rngs = nnx.Rngs(0) mlp = MLP(rngs=rngs.fork(splits=5)) Loading Checkpoints with RNGs ================================================== When loading checkpoints in the new version, you need to drop the old RNGs structure and partially reinitialize the model with new RNGs. To do this, you can use ``nnx.jit`` to 1. Remove the RNGs from the checkpoint. 2. Perform partial initialization of the model with new RNGs. .. code-block:: python # load checkpoint checkpointer = ocp.StandardCheckpointer() checkpoint = checkpointer.restore(path / "state") @jax.jit def fix_checkpoint(checkpoint, rngs: nnx.Rngs): # drop rngs keys flat_paths = nnx.traversals.flatten_mapping(checkpoint) flat_paths = { path[:-1] if path[-1] == "value" else path: value # remove "value" suffix for path, value in flat_paths.items() if "rngs" not in path # remove rngs paths } checkpoint = nnx.traversals.unflatten_mapping(flat_paths) # initialize new model with given rngs model = MyModel(rngs=rngs) # overwrite model parameters with checkpoint nnx.update(model, checkpoint) # get full checkpoint with new rngs new_checkpoint = nnx.state(model) return new_checkpoint checkpoint = fix_checkpoint(checkpoint, rngs=nnx.Rngs(params=0, dropout=1)) checkpointer.save(path.with_name(path.name + "_new"), checkpoint) The previous code is efficient because ``jit`` performs dead code elimination (DCE) so it will not actually initialize the existing model parameters in memory. Optimizer Updates ==================================== Optimizer has been updated to not hold a reference to the model anymore. Instead, it now takes the model and gradients as arguments in the ``update`` method. Concretely, these are the the new changes: 1. The ``wrt`` constructor argument is now required. 2. The ``model`` attribute has been removed. 3. The ``update`` method now takes ``(model, grads)`` instead of only ``(grads)``. .. tab-set:: .. tab-item:: v0.11 :sync: v0.11 .. code-block:: python :emphasize-lines: 17, 26 from flax import nnx import optax class Model(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dmid, rngs=rngs) self.bn = nnx.BatchNorm(dmid, rngs=rngs) self.dropout = nnx.Dropout(0.2, rngs=rngs) self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x): x = nnx.relu(self.dropout(self.bn(self.linear(x)))) return self.linear_out(x) model = Model(2, 64, 3, rngs=nnx.Rngs(0)) optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) @nnx.jit def train_step(model, optimizer, x, y): def loss_fn(model): y_pred = model(x) return ((y_pred - y) ** 2).mean() loss, grads = nnx.value_and_grad(loss_fn)(model) optimizer.update(model, grads) return loss .. tab-item:: v0.10 :sync: v0.10 .. code-block:: python :emphasize-lines: 17, 26 from flax import nnx import optax class Model(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dmid, rngs=rngs) self.bn = nnx.BatchNorm(dmid, rngs=rngs) self.dropout = nnx.Dropout(0.2, rngs=rngs) self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x): x = nnx.relu(self.dropout(self.bn(self.linear(x)))) return self.linear_out(x) model = Model(2, 64, 3, rngs=nnx.Rngs(0)) optimizer = nnx.Optimizer(model, optax.adam(1e-3)) @nnx.jit def train_step(model, optimizer, x, y): def loss_fn(model): y_pred = model(x) return ((y_pred - y) ** 2).mean() loss, grads = nnx.value_and_grad(loss_fn)(model) optimizer.update(grads) return loss Pytrees containing NNX Objects ==================================== In the new version, NNX modules are now Pytrees. This means that you can use them with JAX transforms like ``jax.vmap`` and ``jax.jit`` directly (more documentation on this will be available soon). However, this also means that code using ``jax.tree.*`` functions on structures that contain NNX modules will need to take this into account to maintain the current behavior. In these cases, the solution is to use the ``is_leaf`` argument to specify that NNX modules and other NNX objects should be treated as leaves. .. code-block:: python modules = [nnx.Linear(3, 3, rngs=nnx.Rngs(0)), nnx.BatchNorm(3, rngs=nnx.Rngs(1))] type_names = jax.tree.map( lambda x: type(x).__name__, modules, is_leaf=lambda x: isinstance(x, nnx.Pytree) # <-- specify that NNX objects are leaves ) ================================================ FILE: docs_nnx/mnist_tutorial.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "0", "metadata": {}, "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs_nnx/mnist_tutorial.ipynb)\n", "[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs_nnx/mnist_tutorial.ipynb)\n", "\n", "# MNIST tutorial\n", "\n", "Welcome to Flax NNX! In this tutorial you will learn how to build and train a simple convolutional neural network (CNN) to classify handwritten digits on the MNIST dataset using the Flax NNX API.\n", "\n", "Flax NNX is a Python neural network library built upon [JAX](https://github.com/jax-ml/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning.\n", "\n", "Let’s get started!\n", "\n", "## 1. Install Flax\n", "\n", "If `flax` is not installed in your Python environment, use `pip` to install the package from PyPI (below, just uncomment the code in the cell if you are working from Google Colab/Jupyter Notebook):" ] }, { "cell_type": "code", "execution_count": null, "id": "2", "metadata": {}, "outputs": [], "source": [ "# !pip install -U \"jax[cuda12]\"\n", "# !pip install -U flax" ] }, { "cell_type": "markdown", "id": "3", "metadata": {}, "source": [ "## 2. Load the MNIST dataset\n", "\n", "First, you need to load the MNIST dataset and then prepare the training and testing sets via Tensorflow Datasets (TFDS). You normalize image values, shuffle the data and divide it into batches, and prefetch samples to enhance performance." ] }, { "cell_type": "code", "execution_count": 1, "id": "4", "metadata": {}, "outputs": [], "source": [ "import tensorflow_datasets as tfds # TFDS to download MNIST.\n", "import tensorflow as tf # TensorFlow / `tf.data` operations.\n", "\n", "tf.random.set_seed(0) # Set the random seed for reproducibility.\n", "\n", "train_steps = 1200\n", "eval_every = 200\n", "batch_size = 32\n", "\n", "train_ds: tf.data.Dataset = tfds.load('mnist', split='train')\n", "test_ds: tf.data.Dataset = tfds.load('mnist', split='test')\n", "\n", "train_ds = train_ds.map(\n", " lambda sample: {\n", " 'image': tf.cast(sample['image'], tf.float32) / 255,\n", " 'label': sample['label'],\n", " }\n", ") # normalize train set\n", "test_ds = test_ds.map(\n", " lambda sample: {\n", " 'image': tf.cast(sample['image'], tf.float32) / 255,\n", " 'label': sample['label'],\n", " }\n", ") # Normalize the test set.\n", "\n", "# Create a shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from.\n", "train_ds = train_ds.repeat().shuffle(1024)\n", "# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency.\n", "train_ds = train_ds.batch(batch_size, drop_remainder=True).take(train_steps).prefetch(1)\n", "# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency.\n", "test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)" ] }, { "cell_type": "markdown", "id": "5", "metadata": {}, "source": [ "## 3. Define the model with Flax NNX\n", "\n", "Create a CNN for classification with Flax NNX by subclassing `nnx.Module`:" ] }, { "cell_type": "code", "execution_count": null, "id": "6", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from flax import nnx # The Flax NNX API.\n", "from functools import partial\n", "from typing import Optional\n", "\n", "class CNN(nnx.Module):\n", " \"\"\"A simple CNN model.\"\"\"\n", "\n", " def __init__(self, *, rngs: nnx.Rngs):\n", " self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)\n", " self.batch_norm1 = nnx.BatchNorm(32, rngs=rngs)\n", " self.dropout1 = nnx.Dropout(rate=0.025)\n", " self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)\n", " self.batch_norm2 = nnx.BatchNorm(64, rngs=rngs)\n", " self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))\n", " self.linear1 = nnx.Linear(3136, 256, rngs=rngs)\n", " self.dropout2 = nnx.Dropout(rate=0.025)\n", " self.linear2 = nnx.Linear(256, 10, rngs=rngs)\n", "\n", " def __call__(self, x, rngs: nnx.Rngs | None = None):\n", " x = self.avg_pool(nnx.relu(self.batch_norm1(self.dropout1(self.conv1(x), rngs=rngs))))\n", " x = self.avg_pool(nnx.relu(self.batch_norm2(self.conv2(x))))\n", " x = x.reshape(x.shape[0], -1) # flatten\n", " x = nnx.relu(self.dropout2(self.linear1(x), rngs=rngs))\n", " x = self.linear2(x)\n", " return x\n", "\n", "# Instantiate the model.\n", "model = CNN(rngs=nnx.Rngs(0))\n", "# Visualize it.\n", "nnx.display(model)" ] }, { "cell_type": "markdown", "id": "7", "metadata": {}, "source": [ "### Run the model\n", "\n", "Let's put the CNN model to the test! Here, you’ll perform a forward pass with arbitrary data and print the results." ] }, { "cell_type": "code", "execution_count": 3, "id": "8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Array([[ 0.11409501, 0.4546129 , -0.6421267 , -0.12122799, -0.22859162,\n", " 0.13616608, 1.0126765 , -0.03625144, 0.6132787 , -0.06018351]], dtype=float32)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import jax.numpy as jnp # JAX NumPy\n", "\n", "y = model(jnp.ones((1, 28, 28, 1)), nnx.Rngs(0))\n", "y" ] }, { "cell_type": "markdown", "id": "9", "metadata": {}, "source": [ "## 4. Create the optimizer and define some metrics\n", "\n", "In Flax NNX, you need to create an `nnx.Optimizer` object to manage the model's parameters and apply gradients during training. `nnx.Optimizer` receives the model's reference, so that it can update its parameters, and an [Optax](https://optax.readthedocs.io/) optimizer to define the update rules. Additionally, you will define an `nnx.MultiMetric` object to keep track of the `Accuracy` and the `Average` loss." ] }, { "cell_type": "code", "execution_count": 4, "id": "12", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import optax\n", "\n", "learning_rate = 0.005\n", "momentum = 0.9\n", "\n", "optimizer = nnx.Optimizer(\n", " model, optax.adamw(learning_rate, momentum), wrt=nnx.Param\n", ")\n", "metrics = nnx.MultiMetric(\n", " accuracy=nnx.metrics.Accuracy(),\n", " loss=nnx.metrics.Average('loss'),\n", ")\n", "\n", "nnx.display(optimizer)" ] }, { "cell_type": "markdown", "id": "13", "metadata": {}, "source": [ "## 5. Define training step functions\n", "\n", "In this section, you will define a loss function using the cross entropy loss ([`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that the CNN model will optimize over.\n", "\n", "In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric.\n", "\n", "During training — the `train_step` — you will use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the `optimizer` you have already defined. The `train_step` also receives an `nnx.Rngs` object to provide randomness for dropout. The `eval_step` omits `rngs` because the eval view already has `deterministic=True`, so dropout is disabled and no random key is needed. During both steps, the `loss` and `logits` are used to update the metrics." ] }, { "cell_type": "code", "execution_count": null, "id": "14", "metadata": {}, "outputs": [], "source": [ "def loss_fn(model: CNN, batch, rngs: nnx.Rngs | None = None):\n", " logits = model(batch['image'], rngs)\n", " loss = optax.softmax_cross_entropy_with_integer_labels(\n", " logits=logits, labels=batch['label']\n", " ).mean()\n", " return loss, logits\n", "\n", "@nnx.jit\n", "def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, rngs: nnx.Rngs, batch):\n", " \"\"\"Train for a single step.\"\"\"\n", " grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)\n", " (loss, logits), grads = grad_fn(model, batch, rngs)\n", " metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates.\n", " optimizer.update(model, grads) # In-place updates.\n", "\n", "@nnx.jit\n", "def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):\n", " loss, logits = loss_fn(model, batch)\n", " metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates." ] }, { "cell_type": "markdown", "id": "17", "metadata": {}, "source": [ "In the code above, the [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transformation decorator traces the `train_step` function for just-in-time compilation with [XLA](https://www.tensorflow.org/xla), optimizing performance on hardware accelerators, such as Google TPUs and GPUs. `nnx.jit` is a stateful version of the `jax.jit` transform that allows its function input and outputs to be Flax NNX objects. Similarly, `nnx.value_and_grad` is a stateful version of `jax.value_and_grad`. Check out [the transforms guide](https://flax.readthedocs.io/en/latest/guides/transforms.html) to learn more.\n", "\n", "> **Note:** The code shows how to perform several in-place updates to the model, the optimizer, the RNG streams, and the metrics, but _state updates_ were not explicitly returned. This is because Flax NNX transformations respect _reference semantics_ for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for a more concise and readable code. You can learn more in [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html).\n", "\n", "\n", "## 6. Train and evaluate the model\n", "\n", "Now, you can train the CNN model. Before the training loop, we use [`nnx.view`](https://flax.readthedocs.io/en/latest/guides/view.html) to create a `train_model` (with dropout enabled and batch norm in training mode) and an `eval_model` (with dropout disabled and batch norm using running statistics). These views share the same underlying weights, so updates during training are automatically reflected during evaluation." ] }, { "cell_type": "code", "execution_count": null, "id": "22", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAABL4AAAHDCAYAAAAqZtO0AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjUsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvWftoOwAAAAlwSFlzAAAPYQAAD2EBqD+naQAApvxJREFUeJzs3Qd4VFX6BvB3Mukd0hMCoXcSepMmCIrS7QXL2sWy/NUVReyyuitrwYq9NzoqiDRBmhBC7530AOmkzcz/+c7NpEACCSS5U97f81wyc+dm5oRkkjvvfOc7BovFYgEREREREREREZGDcdF7AERERERERERERPWBwRcRERERERERETkkBl9EREREREREROSQGHwREREREREREZFDYvBFREREREREREQOicEXERERERERERE5JAZfRERERERERETkkBh8ERERERERERGRQ2LwRUREREREREREDonBFxEREREREREROSQGX0Skq88//xwGgwGbNm3SeyhEREREVOq9995T52i9e/fWeyhERJeEwRcRERERERFV8s033yAmJgYbN27EgQMH9B4OEdFFY/BFREREREREZQ4fPoy1a9dixowZCAkJUSGYLcrLy9N7CERkBxh8EZHN27JlC6666ir4+/vD19cXQ4cOxfr16ysdU1xcjBdeeAGtW7eGp6cngoKCcNlll2Hp0qVlx6SkpODOO+9EkyZN4OHhgYiICIwZMwZHjhzR4asiIiIisk0SdDVq1AhXX301rr322iqDr8zMTPzzn/9UVWFyXiXnVxMnTkRGRkbZMQUFBXj++efRpk0bdX4m517jx4/HwYMH1e0rV65U0ynlY0Vybib7pSWG1R133KHOA+VzR44cCT8/P9xyyy3qttWrV+O6665D06ZN1Viio6PV2M6cOXPOuPfs2YPrr79eBXpeXl5o27YtnnnmGXXbihUr1OPOnTv3nM/79ttv1W3r1q27pP9bImp4rjo8JhFRje3cuRMDBgxQodeTTz4JNzc3fPjhhxg8eDBWrVpV1ndCTqqmT5+Ou+++G7169UJ2drbqGxYfH48rrrhCHTNhwgR1fw8//LA6SUtLS1PB2LFjx9R1IiIiItKCLwmo3N3dcdNNN+H999/H33//jZ49e6rbc3Nz1fnZ7t27cdddd6Fbt24q8FqwYAFOnDiB4OBgmEwmXHPNNVi2bBluvPFGPProo8jJyVHnXjt27EDLli1rPa6SkhKMGDFCvbn53//+F97e3mr/Tz/9hPz8fDzwwAPqzU+ZnvnOO++oschtVtu2bVPjlvPJe++9V53/SZC2cOFCvPLKK+r8UkIz+frHjRt3zv+JjLlv376X/P9LRA3MQkSko88++8wiv4r+/vvvKm8fO3asxd3d3XLw4MGyfUlJSRY/Pz/LwIEDy/bFxsZarr766mof5/Tp0+px/vOf/9TxV0BERETkODZt2qTOmZYuXaqum81mS5MmTSyPPvpo2THTpk1Tx8yZM+ecz5fjxaeffqqOmTFjRrXHrFixQh0jHys6fPiw2i/niVa333672vfUU0+dc3/5+fnn7Js+fbrFYDBYjh49WrZPzh3lHLLivorjEVOmTLF4eHhYMjMzy/alpaVZXF1dLc8991wV/2NEZOs41ZGIbJa8U/j7779j7NixaNGiRdl+KZO/+eabsWbNGlXZJQIDA1U11/79+6u8Lylll3ctpZT+9OnTDfY1EBEREdkTqWwKCwvDkCFD1HWZ3nfDDTfg+++/V+dmYvbs2YiNjT2nKsp6vPUYqfySSvvqjrkYUtVV1Xlexb5fUn3Wr18/KfJQLTNEeno6/vzzT1WhJlMiqxuPTNcsLCzEzz//XLbvhx9+UNVmt95660WPm4j0w+CLiGyWnKBI2br0Xjhb+/btYTabcfz4cXX9xRdfVL0mpIdE586d8cQTT6hydivp9/Daa6/ht99+UydzAwcOxOuvv676fhERERGR9qajBFwSekmDe1nNUTZpLZGamqqmLQqZHtipU6fz3pccI+dwrq51111H7kt6iZ1N2lZID7DGjRurPmDSv2vQoEHqtqysLPXx0KFD6uOFxt2uXTs1pbNiXzO53KdPH7Rq1arOvhYiajgMvojIIUiQJSdYn376qTqh+fjjj1W/Cflo9dhjj2Hfvn2qF5g0WH322WdVgGZ9J5CIiIjImS1fvhzJyckq/JIFg6ybNIMXdb26Y3WVX9bKsrPJG5kuLi7nHCv9XH/55Rf861//wrx581QfMWtjfHmjtLak6kt6yUqPMDm/lEWVWO1FZL/Y3J6IbJa8WydNS/fu3Vvlijxy4iMNSK3kXT5ZtVE2aboqYZg0vZeG91bSlPT//u//1CbTIuPi4vDGG2/g66+/brCvi4iIiMgWSbAVGhqKd99995zb5syZo1Y7/OCDD9T5lDSoPx85ZsOGDWrlbWkmXxVZOVJI1X5FR48erfGYt2/frt7Y/OKLL1RgZVVxZW9hbZtxoXELacY/efJkfPfdd2plSBm/TPckIvvEii8isllGoxHDhw/H/Pnz1bLWVlJqL0tKy4o+stqjOHnyZKXPlTJ3KUeXHg1CpkzKktpnn5DJUtjWY4iIiIiclQQ8Em7JSozXXnvtOdukSZPUqoyycqOslL1161YVhJ1N+moJOUZ6bc2cObPaY5o1a6bO96T3VkXvvfdejcctn1/xPq2X33rrrXPeUJU3RWV2gEyNrGo8VtKb7KqrrlJvjEoYeOWVV6p9RGSfWPFFRDZBTkIWL158zn6p2JJ37CTkevDBB1Vvhw8//FCFVdKjy6pDhw5qCeru3buryq9NmzappqRykibkncChQ4eqUn05Vu5HTtYkRJN39YiIiIicmQRaEmyNHj26ytulx5WERxIEyRuQcp513XXXqWbxcv516tQpdR9SESaN76X66ssvv1SVUxs3bsSAAQNU4/k//vhDndONGTMGAQEB6j7eeecdNe1R3pRctGgR0tLSajxu6ckln/f4448jMTFRvSkqjfWrWszo7bffVueU0g7j3nvvRfPmzdWbqzJNMiEhodKxMn4J/MRLL71U6/9PIrIhei8rSUTOTZapll9F1W3Hjx+3xMfHW0aMGGHx9fW1eHt7W4YMGWJZu3Ztpft5+eWXLb169bIEBgZavLy8LO3atbO88sorlqKiInV7RkaG5aGHHlL7fXx8LAEBAZbevXtbfvzxR52+ciIiIiLbMWrUKIunp6clLy+v2mPuuOMOi5ubmzqvOnnypGXSpEmWqKgoi7u7u6VJkyaW22+/Xd1mlZ+fb3nmmWcszZs3V58XHh5uufbaay0HDx4sOyY9Pd0yYcIEdY7XqFEjy3333WfZsWOHOg+U80QruW85h6vKrl27LMOGDVPnisHBwZZ77rnHsnXr1nPuQ8h9jxs3Tp0zytfbtm1by7PPPnvOfRYWFqrxyDnjmTNnav3/SUS2wyD/6B2+EREREREREdmKkpISREZGYtSoUfjkk0/0Hg4RXQL2+CIiIiIiIiKqQFaHTE9Pr9Qwn4jsEyu+iIiIiIiIiAC1EuW2bdtUXy9paB8fH6/3kIjoErHii4iIiIiIiAjA+++/jwceeAChoaGqOT8R2T9WfBERERERERERkUNixRcRERERERERETkkBl9EREREREREROSQXGEHzGYzkpKS4OfnB4PBoPdwiIiIyA5IN4ecnBy1HL2LC9/rs1U8zyMiIqL6PM+zi+BLToaio6P1HgYRERHZoePHj6NJkyZ6D4OqwfM8IiIiqs/zPLsIvuQdQOsX5O/vr/dwiIiIyA5kZ2erQMV6HkG2ied5REREVJ/neXYRfFnL3uVkiCdEREREVBucPmfbeJ5HRERE9Xmex4YXRERERERERETkkBh8ERERERERERGRQ2LwRUREREREREREDskuenwRERHVB5PJhOLiYr2HQRfJzc0NRqNR72FQA+HzleoSf38QETkPBl9EROR0LBYLUlJSkJmZqfdQ6BIFBgYiPDycDewdGJ+vVF/4+4OIyDkw+CIiIqdjfREdGhoKb29vvuixQxKG5OfnIy0tTV2PiIjQe0hUT/h8pbrG3x9ERM6FwRcRETnddCnri+igoCC9h0OXwMvLS32UF6/y/eS0JcfD5yvVF/7+ICJyHmxuT0RETsXaI0gqR8j+Wb+P7P3kmPh8pfrE3x9ERM6BwRcRETklTpdyDPw+Ogd+n6k+8OeKiMg5MPgiIiIiIiIiIiKHdFHB17vvvouYmBh4enqid+/e2LhxY7XHfv755+rdlIqbfB4RERHpR/6Ov/nmm3VyXytXrlR/37nqHpHtP1+JiIicTa2b2//www+YPHkyPvjgAxV6yR/hESNGYO/evaoxZFX8/f3V7VYsKyYiIqq9wYMHIy4urk5eAP/999/w8fGpk3ER0bn4fCUiIrLTiq8ZM2bgnnvuwZ133okOHTqoAEwaQ3766afVfo4EXeHh4WVbWFjYpY6biIiIzmKxWFBSUlKjY0NCQtgwnEhHfL6WKyoq0nsIRETkwFxq+0dp8+bNGDZsWPkduLio6+vWrav283Jzc9GsWTNER0djzJgx2Llz53kfp7CwENnZ2ZW2+mI2W7D+0ElsPHyq3h6DiIjoUt1xxx1YtWoV3nrrrbLWAdZ2Ar/99hu6d+8ODw8PrFmzBgcPHlR/b+WNJl9fX/Ts2RN//PHHeadOyf18/PHHGDdunHqB3bp1ayxYsOCixzt79mx07NhRjUke64033qh0+3vvvaceQ9ofyDivvfbastt+/vlndO7cGV5eXggKClLnGXl5eRc9FqKGZsvPV5PJhH/84x9o3ry5eo61bdtWjfNs8qa29TkcERGBSZMmld0m05rvu+8+NWZ5Dnfq1AmLFi1Stz3//POq0q0iGbt8DRX/f8aOHYtXXnkFkZGRagziq6++Qo8ePeDn56feLL/55puRlpZW6b7kdcQ111yjZpTIcQMGDFD/h3/++Sfc3NyQkpJS6fjHHntMHUNERA0sLwPY9BnsLvjKyMhQfyzPrtiS62f/kbGSP2Tyh3P+/Pn4+uuvYTab0a9fP5w4caLax5k+fToCAgLKNgnM6ssX647gxo/WY8bS8qmYRETkfJUX+UUlDb7J49aUvDDt27evqrpOTk5Wm/Xv41NPPYV///vf2L17N7p06aLecBo5ciSWLVuGLVu24Morr8SoUaNw7Nix8z7GCy+8gOuvvx7btm1Tn3/LLbfg1KnavzEkb5LJ/dx4443Yvn27eiH87LPPqhf+YtOmTXjkkUfw4osvqlYIixcvxsCBA9Vt8nXddNNNuOuuu9TXI/3Dxo8fX6v/K3Jcej1XHen5KufiTZo0wU8//YRdu3Zh2rRpePrpp/Hjjz+WHfP+++/joYcewr333quewxKqtWrVquzzr7rqKvz111/q3F7uQ74eo9GI2pCvV57/S5cuLQvNiouL8dJLL2Hr1q2YN28ejhw5okIyq8TERPW7QsK45cuXq9818rtCKudkf4sWLVR4ZiX3980336hjiIiogZQUAn+9BbzdFVj0GHBkDeyux1dtyR992awk9Grfvj0+/PBD9YetKlOmTFF9xKyk4qu+wq/hHcPxwsJd2HD4FFKyChAewMb7RETO5kyxCR2mLWnwx9314gh4u9fsT7G8EeTu7q6qO6QSQuzZs0d9lADpiiuuKDu2cePGiI2NLbsuf2/nzp2rXrxWrNo4m7zAlNBJvPrqq3j77bfVAjbyQry2bRGGDh2qwi7Rpk0b9eL4P//5j3oMeUEv/YqkakMqNqQqvGvXrupYCQjkRayEXbJfSPUX1Q1ZoEi+D/KGpfyMvPPOO+jVq1eVx0poIG9GfvHFFypwkDczX3vttUo/D/KGqASbEoDIfUr1jnyPp06dWi89XfV6rjrS81WqoiQ0s5LKL5m5IcGXBGni5Zdfxv/93//h0UcfLTtOKtGEVKPJ40hwJ89tIYFTbcnvAKlak/8nq4oBldynfE3yuBIOSjWc/PzK/+3333+vvg5hHYOQSrbPPvsMTzzxhLq+cOFCFBQUlH1dRERUj+QNot0LgKXTgNNHtH3hXQBXL9hVxVdwcLB6Nyc1NbXSfrlu/aN+IfJHSk5uDxw4UO0x8i6OlC9X3OpLVKAXesY0Ut+jRduS6u1xiIiI6otMDapIXiQ+/vjj6o2mwMBA9YJRXqReqIJEqk8qviiVv79nTzOqCXms/v37V9on1/fv36+CEnnRL6GWvLC97bbbVEVGfn6+Ok4CAAnNJOy67rrrMGvWLJw+fbrWY6DqFyh67rnnEB8fr/6vZYGi6r7HEl7JG5USjklwef/996updVKVZCVBmFQHzZw5U33f5frrr7+uPods9/kqAZJMt5TeYfJ4H330UdnjyX0kJSWp52FVEhISVMVYxcDpYshzvGLoJaSCS6rdmjZtqkLxQYMGqf3Wscljy7RFa+hVVRgorzHWr1+vrkuVqYReXBiAiKieJSUAn18N/DhRC718w4Ex7wH3rgKadIddVXzJHyf5IymlyTIv31ruLNfP945URXLCKyXTUpJtK0bHRuLvI6exYGsS7h5Q+3esiIjIvnm5GVU1hx6PWxfOflEnL6Jl+tB///tfNT1J+vhID60LNZA++8WkVOzI3/m6Ji9oJXiRaYy///67mmolVUOycp288Jexr127Vt0mAcozzzyDDRs2qMoUungVFygSskDRL7/8olpSyPS7s8mUMfm/t56zPfDAA6raR/q1SYWXkO+T9Ke6+uqr1XXp4/Tdd9+piiBHeq5aH9sRnq9SLSWPKd9HmZUhz0epApTnmJDHP58L3S79f8+eFirVgxf6f5A+fhLEyiZhuIRyEnjJdev/xYUeW1aYl+BMqr7k94X0U5PfM0REVE+yk4HlLwEJ30rJF+DqCfR7BOj/KODhC1tR66mO8k7h7bffrt6tktJ4aVYpf6isJ1ETJ05EVFSUKo23lnP36dNH/SGXRpjyh/Xo0aO4++67YStGdo7A8wt3YduJLBzOyEPzYL4rRETkTOQFY02nMOlJ3oCSN5AuRHrvSOWDVOdYK0qkV05DkcoVGcPZY5IKEWsfIFdXV9W0XjapQJLAS3r2yBRH+X5IhZhsEopJdZhM/arYBoFqx7pAkbSTqOkCRbLYkDQur0iCB2nIXrGFhVQL7du3T31/pTeT3C4hmzM/V235+SqPJ9+3Bx98sGyfNIe3kiBMAkx5Y3vIkCFVVppJr17r9/xsEljJtFcJv6zTXaVS60JkKujJkydVvzBrixPpB3j2Y8vUWwnSqqv6ktcYMgVUqtJatmx5TvUpERHVgaJ8YN1MYM2bQHHpAkSdrwOGPgcE1l+P9otV6zOHG264Aenp6epEVP6oyaot0pTW2vBe3pmREykrmZ4g7y7KsY0aNVIVY/LuYIcOHWArgnw9cFmrYKzal44FCUl4dFhrvYdERER0DnkxKlUZ8qJYpidVV90hK7zNmTNHVT7IC0/ptVUflVvVkd5A0pdHehXJeYMEKzIVTlZyFNLI+tChQ6oZtZwb/Prrr2p80kNKvj55wT18+HBVvSHX5bxDwjS6eOdboMjae+psUmkjAZZ8nyRAkO+L/FxVDHOkUkx6sbZr106FmnKbrNQnjdarI4GabFb1uXq3nmz1+SqP9+WXX2LJkiWqKkoq+6TasmJFpVRgytRWeQ5KI/ucnBwVmD388MNq+qH8TEyYMEH9fMib2/IzJGOX/mKDBw9Wz1mZ8iqVa/I6QSqvLtS6RKY3SlgoVZ7y2Dt27DinH7DMMJHbZeEMCXGl35dMa5Q3460rQ8rPrTyW9CmTN+CJiKgOSUXv9p+BP54HsksXLGzSE7jy30CTylP5bUmtenxV/KMjVVty0iJ/0Hv37l12m5QTW1dtEv/73//KjpXwS0rqrQ1sbcmYuEj1cf7WRK4cRURENkmmJ0m4IG8eWacBVUVejEqgJFUd8mJaXgh269atwcYpjyWNsmVKVadOndSbZfIC1Lo6m1R3yQv9yy+/XAVaMuVOpsd17NhRvWD9888/1fQ6qSaRPlMyJUtefFPDkpUJJSSRUEsCCTn/kwr/im9wyvdZpqV9++23avqqVOPIlD35aAurd+vJVp+v9913n6qslFBazuGlyqpi9ZeQ2R0yq0PCanleykIU0qPPavbs2Srclsoq+fqefPLJskBUntPyedJHTPrIybRX+b+4EPk/ktcQstqk3KdUfsnPUkVBQUGqMlSq4iSAkzfUpQ9gxeov+fmU3zUyHpmJQkREdeT4RuDjYcCcu7XQKyAamPAJ8I+lNh16CYPFDlIeeSdQToyysrLqrdF9bmEJur+0FIUlZix6+DJ0igqol8chIiJ9yQpfhw8fVtUNZ0/jIsf6fjbE+YM9TXWUFQZ//vnnsj6t1oBDWlHMnz//vP/HEo7Iio1S4SUVezt37lS3SWgl+x566KGy46XSRnqAVVdJVlXFl9xPVd8nPl/pYsjqjlJ1Jitjng9/voiIaiDzmFbhtWO2dt3NBxgwGej7EOCm34qNtTnPu6iKL0fk6+GKYe218n9pck9ERETkKCouUGRlXaBIGpyfjwQC0r+1pKREVfpIM3srWY2zYgWYkCqn803Va8jVu8m5yIsf6TEnFYgyLZOIiC5BYQ6w7EVgZs/S0MsAdL0VeCQeGPi4rqFXbTH4qmBUrDbdceHWJJjNNl8IR0RE1CCk3470KKpqk9vIPsjiADItTKYh7t69W63SePYCRRWb30s7C5mSKv3YVq9erfo3SaAl09qsZGqe9PSSVhbSy0oWIZCpe9ZG7dTwnPn5KqGs9AeUr/OKK67QezhERPbJbALivwLe6Q6sfgMoKQBiBgD3rQLGvAv4hcPe2MeyOA1kcNsQ+Hm6IjmrAH8fOYXeLYL0HhIREZHupD9XdT16WK1jP2q7QJFMA5MeaxJ8SWgifdekEbr0aLOSRuPSjF16RKWlpanpkNJDSh6D9OHMz1fpNUxERJfg8GpgyRQgZbt2vVFzYPjLQLurZWll2Cv2+DrLkz9vxY+bTuDm3k3x6rjO9fpYRETU8NjTxbGwx5f9O9/3ic9Xqk/8+SIiKnXyILB0GrBnkXbdIwAY9ATQ617A1QO2iD2+LsHo2Cj18dftySgqabil34mIiIiIiIiIGsyZTGDJM8C7vbXQy2AEet6t9fHq97DNhl61xamOZ+nbMgjBvh7IyC3EmgPpuLydVv5PRERERERERGT3TCXA5s+AldOB/JPavlbDgOGvAKHt4GhY8XUWo4sB13SJUJcXJHB1RyIiIiIiIiJyEPv/AD7oD/z6uBZ6BbcFbpkN3DrbIUMvwYqvKoyJi8Tna4/g912pOFNkgpe7Ue8hERERERERERFdnPS92rTGA0u1616NgSFPA93vBIyOHQ059ld3keKiA9G0sTeOncrHH7tTMSo2Uu8hERERERERkT05dUj72LiF3iMhZ5Z3UpvSuOlTwGICXNyA3vcBAx8HvBrBGTD4qoLBYMDo2EjMXHEA8xOSGHwREREBOHLkiFr9bMuWLYiLi9N7OERERLaluAA4ugbYvxTY/3t58NWkJ9BtItBxPODhq/coyVmUFAEbPwJWvQ4UZmn72l0DXPEiENQSzoTBVzVGx2nB16p9acjKL0aAt5veQyIiIic3ePBgFTi9+eabdXJ/d9xxBzIzMzFv3rw6uT8iKsfnK5GTOH1UmzomYdehVUDJmfLbpLLGYgZO/K1ti6cAncYDXScCTXpIxYWeIydHZbEAe34Blj5bHr6GdQaufBVoPrBeH7rEZMaelBzEHzuNLccy1ccf7+uLMH9P6InBVzXahPmhXbif+qb9tiMZN/ZqqveQiIiIiIjsVlFREdzd3fUeBtGlV9EcX69VdEnYlb6n8u3+UdrqeK2HAy0GAUX5wNbvgPgvgVMHtY+yhbTXqsC63AD4BOn11ZCjSd4GLHkaOLJau+4TCgx9Foi7BXCp+97l6TmFlUKu7SeycKbYVOmYLcdO48pO2gKCeuGqjheo+hILtnJ1RyIi0pdUe6xatQpvvfWWmpIvm0w93LFjB6666ir4+voiLCwMt912GzIyMso+7+eff0bnzp3h5eWFoKAgDBs2DHl5eXj++efxxRdfYP78+WX3t3LlylqPS8bUq1cveHh4ICIiAk899RRKSkou+PhCHk8+18fHB4GBgejfvz+OHj1aR/9jRPqxlefrv/71L7Rp0wbe3t5o0aIFnn32WRQXF1c6ZuHChejZsyc8PT0RHByMcePGld1WWFio7iM6Olo9x1u1aoVPPvlE3fb555+r521FUo0mY7OScUvV28cff6ymSctjiMWLF+Oyyy5Tny9f5zXXXIODBw9Wuq8TJ07gpptuQuPGjdXviB49emDDhg3q/9HFxQWbNm2qdLxU1jVr1gxms7mG3yWiWshO1sKqH24FXm8BfDEKWPuOFnoZjECz/sCw54EH1gL/3AmMfhtofw3g4Qf4hQGXPQY8vBm441egy42AqyeQvhtYMgWY0Q746U7g4AqAP790sXJSgfmTgA8HaqGX0QMY8H/AI/FawFoHoVdRiRlbj2fi878O45HvtuCy15aj5yt/4L6vNuODVQex8fApFXr5ebpiYJsQPDq0Nb64qxf6twqG3ljxdR6jukTi9cV7se7QSaRmF+henkdERPVYEl6c3/CP6+Zd42kO8gJ637596NSpE1588UXt093cVHB0991343//+x/OnDmjXqRef/31WL58OZKTk9ULx9dff129mM3JycHq1athsVjw+OOPY/fu3cjOzsZnn32m7k9eYNZGYmIiRo4cqV7kf/nll9izZw/uuece9eJWXvCe7/ElHBs7dqw6/rvvvlOVIBs3bqz0opnIZp6rdvp89fPzUwFVZGQktm/frp5vsu/JJ59Ut//yyy/qsZ555hn1HJbn4a+//lr2+RMnTsS6devw9ttvIzY2FocPH64U1NXEgQMHMHv2bMyZMwdGo/bCS8K8yZMno0uXLsjNzcW0adPUOBISElSoJfsGDRqEqKgoLFiwAOHh4YiPj1ehVkxMjAoE5f9BwjAruS6/i+TziS6ZqQRI3FRa1fU7kLK98u1SRdP6Cm1rMQTwqhwCV0l+f8T017arXgN2/KyFaclbgZ1ztC2wqTYNMu5mICCq3r48ciDFZ4B17wJr/gcU5Wr7Ok3Qglj5eboEqdkFqlorXqq5jp7G9sQsFJaYz/mxbhPqh65NA9GtaSP1sWWIL1xcbOt8jsHXeUQ39kb3Zo2w+ehpLNqWjH9c1lzvIRERUX2QF9Kv6rCQydNJgLtPjQ4NCAhQU4SkckNeBIqXX34ZXbt2xauvvlp23KeffqqqM+RFt7x4lIBp/PjxqhJCSDWJlVSVSEWH9f5q67333lOPNXPmTBVYtWvXDklJSerFvLyQlRfy1T3+qVOnkJWVpSo9WrbUGqy2b9/+osZBTkSv56qdPl+nTp1adlkCIwnQvv/++7Lg65VXXsGNN96IF154oew4CbiEjOnHH3/E0qVLVdAkpGqstiRMk1AtJCSkbN+ECRMqHSP/D3L7rl27VFj47bffIj09HX///XdZwCfVZlYSHt5///2YMWOGqkSTUEyCPamII7pouenAgT+0oOvgcqAgs8KNBq0nl0xflLArPBa4lJBVgrKed2tbUgKw5Stg209A5jFgxcvAyle16ZJSqdPmSsDIftNUxRtBO2YDf7wAZB3T9kV1B0ZMB5r2rvXdFZaYsCspWwu5jp1GwrFMJGZW6FdXKsDLrSzkkq1LdAD8PW3/55PB1wWMiYtUwdeChEQGX0REZFO2bt2KFStWqGlTZ5NpQ8OHD8fQoUPVi+cRI0ao69deey0aNaqbpaulAqVv376VqrRkuqK8gJdpSvICurrHlxezUp0h+6+44gr1wloqX2S6JJEj0uP5+sMPP6hqLbl/a7Dm7+9fdrtUWEkVWFXkNqnQksqrSyEhXsXQS+zfv1+F4zJ1USrIrNMTjx07poIveWwJCaurapNq0Yceeghz585VwZ1UtQ0ZMkSFe0Q1Jj93SVvKq7rkMizlt3s1Ku/V1XJo/fXhiozTtiteAnYv0KrAjv5VPi6fEK0CTCrBgssDYHJiJzZr02SPbyjvKycVXp2urXEgm5x1BvFHtZBLqrp2JGWrqYwVSdFW23D/StVcLYJ97LI6n8HXBYzsHIEXFu7C1hNZOJKRh5jgmr3TR0REdkSmMEk1hx6PewnkheyoUaPw2muvnXObBEjyolWqNdauXYvff/8d77zzjprSJC82pd9OfbvQ48vUpEceeUT1+5EX6FKdIsf36dOn3sdGdkqv56r1se3o+SpTFG+55RZVzSVBmlShSbXXG2+8UamKrDrnu03IlEKZhlnR2f3DhPTnOpv8P0ggNmvWLDUNU4IvCbykOqwmjy3VdDINU36HSIWcVIjJ9FKiC8o/pVVzSVN6qe7KP2vqbkRsaVXXcK16ph6agVfL3RuIvVHbMg5oVWAJ3wJ5acBfb2mb9BKTKrD2o7XjyblkndAqvLb/WP536bJ/An0nnffnoaDYhJ1JWSro2nL8tPqYkl1wznGNfdzRNToQ3Zo1Uh+7RAfC18MxIiPH+CrqUbCvh2rG9ue+dNXk/pGhrfUeEhER1TV556qGU5j0JC/2TKbylXK6deumeudIlYOra9V/0uVdOanCkk0qLOTFplRJSH+ds++vtmRqojy+vPi1vvv3119/qR5CTZo0ueDjC6nqkG3KlCmqekxewDL4Int/rtrC81UCNPl8Cc+szl48QnpsLVu2DHfeeec5ny+VZxJISZN+61THiqSKS/qQSb8ua7gllVoXcvLkSezdu1eFXgMGDFD71qxZc864pCG+TImurupLpjtKWCZTrq1TRInOIeGs9OeyrsB4YiNgqVDV4uEPtByiBV1S3eV3cVP/65xUdl3xAnD5VGDfEq0K7MBSrRJMtl+fADpfp4VgUi1Gjq0wVws+ZUGFktLph7JK4+XPAv6VK+XlnEymKMqURWt/rl1JWSg2VX6jwuhiQLtwv7JKLvnYLMjbLqu5aoLBVw2Mjo1Uwdf8hEQ8fHkrh/1hICIi2yYvmK2rmsl0KZnqIy8epSG29OyRF4jSSFqqOuRFo6x6Ji9qZcpUaGio+lzpm2PtpSX3t2TJEvUiVFZWk4oQacBdUw8++KBaSe3hhx/GpEmT1P0899xz6kW6VIPI41X3+NIk+6OPPsLo0aNVxYd8rkx/kioOIkeg9/O1devWauqg3L+s2iiN7CVEq0ierzK9UvrsyZRBCZCkub306ZPHu/3223HXXXeVNbeX4CwtLU1NS+7du7fqYfb000+ryk0Zr0w5vBCZuinjl+e/VLrJGGU12Irk/0h6ocmUxunTp6vjtmzZon5XSEAu5P9FQnIZq4zxQlVi5EQKsoFDK8vDrtyUyreHdihtTD8ciO5t2/2zZGyyOqRsWYlaBdiWL7VeYJs+0bbwLloAJkFYTZrsk31Nx936HbDsxfKf46b9gCtfBSK7llVzbTuRVTZlUYKu9JzCc+4q2NcdXSuEXF2aBMDb3YniIIsdyMrKknhSfdRD9pkiS+tnfrU0+9ciy47ETF3GQEREdePMmTOWXbt2qY/2Zu/evZY+ffpYvLy81N/Fw4cPW/bt22cZN26cJTAwUO1v166d5bHHHrOYzWb1dY4YMcISEhJi8fDwsLRp08byzjvvlN1fWlqa5YorrrD4+vqq+1uxYsV5H18eT47bsmVL2b6VK1daevbsaXF3d7eEh4db/vWvf1mKi4vVbed7/JSUFMvYsWMtERER6nObNWtmmTZtmsVkMtXZ91Pv8weqmfN9n/h8vfjnq3jiiScsQUFB6nNuuOEGy//+9z9LQEBApWNmz55tiYuLU8/D4OBgy/jx48tuk//3f/7zn2XP01atWlk+/fTTstvnzp2r9snXcs0111g++ugjNTar5557zhIbG3vOuJYuXWpp3769+jq7dOmifo/I58n9WR05csQyYcIEi7+/v8Xb29vSo0cPy4YNGyrdzyeffKI+b+PGjZaLYc8/X1SB2WyxpO62WNa8ZbF8drXF8kJji+U5//Lt5XCL5dsbLZa/P7FYTh+z2D35O3lwhcXy050Wy4vB5V/nS6EWy+x7LZbDq7X/E7JvR/6yWD4YWP79/V9ni3nHXMvR9FzL3PgTlmnztluueXu1peWUX1ROUXGTfaPeWa2OmbflhOXYyTz1d8bR1OY8zyD/wMbJ0s3yrpas/lSxIWdDeuDrzfhtRwruG9QCU67iqlNERPaqoKBAVRtJzxxPT0+9h0P1+P20hfMHurDzfZ/4fKXzeemll/DTTz9h27ZtF/X5/PmyY0V5wOE/tYou2ayr2lkFtSpfgVH6Yrl6wGF7lm37QZsKmbarfH/jlkC324DYmwG/MD1HSLV16jCwdJq20AGAEjdfbIz+B740X4lNJ/KQkav1QqwoxM8D3awrLTZrhE6RAfByb8D+dDqpzXmeE9W2XfrqjhJ8LUxIwr9GtIOLLHFAREREREQNShYKkCmkM2fOxMsvv6z3cKihnDxYvtLhkb8AU4XpXK6eQMyA0rBrGNC4BZyCd2OgzwNA7/uBxM1aALZjNnDqIPDH88Cyl4C2V2lTIWVlSiNf/tsqy5lMZP3+b/ht/RhGczFMcMH3pssxo2ACTu4KAHBaHedmNKBjZECllRajAr3YjukC+JNfQ4PbhsLPwxVJWQXYdPQ0ejWvutEmERGRvZKeOrJVRZpQ//bbbw0+JiKqmjM/X6Wn4Hfffad6gEl/L3JQxQXA0TWlVV2/A6cOVb49sCnQeoQWdsVc5tyrHEro0aSHto14Fdg5VwvBpJn/nkXa5hehNUTveivQuP5Xdqbzyy0swbbjmdhyNB1+O7/DNac+Q2Nkq9v+NHXGKyW3YK+lKSICPDGyLORqhI6R/vB0c/xqrrrGqY618PhPW/Hz5hO4tU9TvDy2s27jICKii8epLdWTFdRkq4o0jo6KioKt4VRH+8epjs7zfLU1/PmyQdK03dqUXqYyFueX3+biBjTrV1rVNRwIbq0FPlS9tN1A/Fdag/QzFX5fNB+kVYG1uwZw489+fZPI5VBGHuKPnsaW45nq477UHPQzbMdU16/RzuW4Ou6QJRLfBd4LS6vh6BbTWFVzRQRw4Y7qcKpjPa7uKMHXL9uS8dyojnAzuug9JCIiojojq8zJRkS2j89XcgglRcDx9eVhV/qeyrf7RZavwNhiEODhB3sLPKTMRLc2OaHttRUAhz0H7P1VqwI7uAI4vErbvBoBXW7U+oGFddRnjA4ou6AYW6Wa61hm6WqLmcg6U1x2e0tDIma5fouhxi3qeoFbANK7/xNRQx7AMx4MIusDg69a6NcySC0DKg3l1hzIwJC2oXoPiYiIiIiIyH5kJwMHSqcvHlwJFOWU32YwAtG9y8MuCWNsqKpLgiyZonYqr0i9JpSPJ3MLcVJ9LMLJvMLSj9p+ud1kscDf0w0BXm4I9NY++stlL+2ydZPb/Msuu6uPPu7GuundJM39O47TttNHgYRvgC1fA9mJwIb3tS2qu1YF1mmC3QWMejKbLTiYnlsWcMnH/Wm5KvCsyMPVBf0iDXjI8DO6pc2Fi6UEcHEFet0Lz4FPIFr6tVG9YfBVC65GF1zTJRKfrz2CBQlJDL6IiOyYHcz0pxrg99E58PtM9YE/Vw3EVAIkbipvTJ+yvfLtPiFAKwm6rgBaDtGqkBrQmSITMkpDqrODq4qB1qncImTkFaGoxFzrx5BqH9mOVT07uVquLoayYKw8FKscmAWcFZZZj6m2D1SjZsCQp4FB/wIOLgfivwD2/qY1x5dt8dNAp3FAt9uBJj1tKni0BVn5xdhyvDzkSjieiZyCknOOi27spfXlig5Et2hfdDjxE1z/fA0oyNQOaHMVMPwlbcou1TsGX7U0KlYLvn7fmaJ+STrDMqFERI7Ezc1NfczPz1d9cMi+yfex4veVHAufr1Sf+PujHuWmAweXaUHXgWXlL/YVg1ZdpHp1XQFExMlcwDp76MISU2klVuXKK6nQKrtcYX9+kanWj+HtbkRjH3cE+Xog2Me9/LJv+eUg9dEdRhcDskuDL9ky8ytftt6WWeEYCVeKTGaUmC3a15BXVOsxuru6lAdhFUOyCsFZoHd7BMS+gaDYZxF1dAEC93wP19MHtGow2ULaaVVgMh3SJwjOxmS2YH9ajhZyHT2tgq6D6XnnHOflZkSXJrLSYiN0axqIuKaBCPXzlHQd2LcYmD8VOHlAOzi0ozb9tMXghv+CnBiDr1qSH+Qmjbxw4vQZLNuTqirAiIjIfhiNRgQGBiItLU1d9/b25hLQdlqpIS9a5fso30/5vpLj4fOV6gN/f9QDsxlI3lK+AmNivPxPl98uVVwth2phV6uhgE9wje+6xGTGqXwtyNICLK0qy1qhVWnKYW4RcgrPrb6pSUikAixfdwT5eKjAKuisAKt8v0etix9UCFLLn9GCYnOFgKyoPBSrJkSruElgI5Vp6TmFaquZDgBeQA/DXtzitgpXuayDp/RcW/I0SpZMw07/AdgZNganwvshwNvjnGmZqirN01XNkrJXp/OKVAWXBFyybT2epaa2ni0myLss5JKP7cL9zv26U3ao/zvVS81a2Xj5VKDrbYALf+c0NAZftSQnW9Lk/r2VB9V0RwZfRET2Jzw8XH20vpgm+yUvWq3fT3JMfL5SfeHvj0t05rQ2VU6FXUuB/IzKt0fElq/AKBVepS/2pSdSZmlQlXFOgFU61dDaLyuvSIU7tSVTBM+uvDo70JLbrRVavh6uNhWqy1gkXJMtPKD2oZmENZXCsIpVZtXtz9dCw02WdthU1A7TcBtGG9fieuNKxLocQmz2CrWd2BeMH0sG4z3TICTj3CowPw9XrYfZ2VMyK1aaeVWelinHy+c15CIAEqjuS9V6c6kpi8cy1cqLZ5M+a7HREnAFqqmLcdGB6uenWrlpwPKXgS1fARYzYHQH+jwIDPg/wJMrTOvFYLGDye22thz53pQcjHjzT7gbXfD31GHqCUtERPbHZDKhuLj2J9RkG2R60vkqNWzt/IEu7fvE5ys15O8PqoK8bJT+XNYVGE9s1F7YlzK7+yEz4jIkhlyGfX59kFgScG6frDwt6DLX8hWo5CGNvMsDLKnMCj4rwCq77CPVSLYVZNkLqRTLLShB5pnKFWaGlO2IPjIbrVN/hZdJW4zADAM2u3bDXMNQ/FoUh8yaFpad53tsrSKruodZxeulwZm3No1Tpp5e6PstP4sVV1nceiKzymmuLUJ80DW6Ebo1C1Qf24b7qemqF1RcoC0S8Ocb5Qs2dBgLXPEC0Cjmov9fCHVynseKr4sgP/xtw/ywNzUHS3ak4Pqe0XoPiYiILoK86OELHyL7wOcrUcOQuoi8IpMKCk6fPgnDoZXwPbYCYWmr4VuUXunYA2iK5aZYLCuJw+aC1ijJdgX2yi2ppVv1JMiwBlUSaJ3TJ6tChZZMp6tR+ECXRP6PVWWW91mFHWqW0wig+AyweyEQ/yVcjqxGz5LN6InNeNUnGKZ+NyK73U047R1TVlWWfcG+ZlrAJtM6JQyV2y62wq/qyjI3ZBeUqLDr6Emtp19FUuknFVzWKYtyuZGPe+0D4V3zgKXTgMxj2r7IrsCI6UCzvrX+Wqh+MPi6SKPjIvGfJXsxf2sigy8iIiIiIrJZsiiXdcXCsj5ZeVX0zMopQGD+YfS3xGOISwJ6uuyFm6G8Kibf4oG/zJ2wwhyHlaZYJKG8V5dMVTt/gFUeZEm44GbHvaCclpsX0OV6bTt5UJvOl/AtkJsK47qZaCRb075aQ/wOYwB3nxrdbUGxqdoFAM7ta2atRpPpnEUoNllqvAhAq1DfspBLpi3K9UsKVKWXnfTxOrZOu+4XCQx7Duh8fZ0u2ECXjlMdL9LxU/kY8PoKVZK5fspQhPrXbu41EREROd/5A52L3yeiuicv8VbsTcP7Kw9iZ1L2eVcu9EIB+rrsUkHXEGMCmhgq9+pKNEZhp09vHGnUH1mhvRDo51sWcAVLD63Syx6urMh0SqZibepr/JfA/iXl0189/IFOE7QQTCqg6mHqqfycnyk2nRuWVbgsAaussijVXHXWoig7CVj2IrD1O+26qxfQ/1Gg/yM1Dvvo0nGqYwOIbuyt0uL4Y5lYtC0Zd13WXO8hERERERGRE5MgYNnuNLy9fD+2nciqdJv0J1YVV77uaO+egX7mzehyZiOa5W6Bq7m8UsZi9IC52WUwth0BtBqGqKCWiNLhayE7YXQD2o3UNgmEpAJMKsFOHwE2f6ZtYZ21AKzLddoKn3VE+np5u7uqLSLAC/WuKA9Y+w7w11tAcenUyS43AkOnAQF8ltgyBl+XQFZ3lOBrwdYkBl9ERERERKRb4LV0V6oKvHYkZqt9Xm5GTOzbDNf1iEaYtwW+KRthUCsw/g4kH6x8B4FNy1ZgNMQMgNHdW58vhOybfyQw8HHgssnAkdVaALZrAZC6HfjtCeD3qdoUyG63Ac0us5/pgGYzsP1H4I8XgJwkbV90H+DKV7UVS8nmMfi6BFd3icSLi3Yh4Xgmjp7MQ7MgljUSEREREVHDMJst+F0Crz/24WhKGoIM2ejjnotr23nhyuau8C3ZCfzxN3B4VXmFinBxA5r1A1pfoQVewW3qZSoaOSkJtFoM0rarTgHbfwI2fwGk7dQCJNkaNdcCsLhbAL9w2Kxj64HFU4Ck+PKQeNgLQMdxfM7YEfb4ukS3fbIBq/dn4PHhbTDp8tZ6D4eIiIjs4PyByvH7RFQFeYlWmAPkpQP5J4G8DCA/Q7uedxKWvHRkpCUhKyMJ3iWZCEIOPAwXWA1PGm9bgy4JJDz8GuqrIdJ+piU8kl5g22cDRTnafoMRaDNCmwrZ6grAaCO1OaePAn88B+ycq1139wMGTAb6PAi4sb+3LWCPrwae7ijB1/yEJDw0pJWaZ0xERERERFQ5yMrWAqyyECuj2mBLXTZVv0KdvOIIKd3UFSs3b8A7GPAp3eRycGst7ArryAoV0o/87Mm0QNlGvArsnKeFYMfXA3t/1TbfcKDrLUDXW4HGLfQZZ0E2sGYGsO49wFQIGFyArrcBl08FfEP1GRNdMgZfl2hEp3A8M28H9qflYk9KDtpH8J1KIiIiIiKHD7IKskpDq/Sqg6uzQy7zBSqyquLmA/gEweIdjFSTH7aecsXhM144afFHvmsgenRogyt6dIBvUIQWcrE3F9kDWflQBVy3AOl7tQBMVkjMTQFWv6FtMQOAbrcD7Uc1TIWV2aT1JFv+svY8Fs0HaSFdeKf6f3yqVwy+LpG/pxsubxuKxTtTVNUXgy8iIiIiInsMsjK10EpVYZ1dmWW9bL395MUFWe6+gHdQaUVWSGl1VlDpx5DSKq3S272DYXL1wqJtSXhn+QEcSMtVd+Hv6aoW1prUvzkCvNzq/v+CqCGFtAVGvAIMfU6r+pIQ7OByrTm+bJ6BQJcbtKmQ9RVAHVoFLHkaSN2hXW/cEhj+MtD2KlZJOggGX3VgdFykCr4Wbk3CkyPawsWFTw4iIiIiIl1XYZMg65yKrLODrYpBVkntH0f6/pQFVxWmF5Z9DKl8u5tXje62xGTGwtLA61B6ntonIdc/LmuOO/rHqDffiRyKqzvQcay2ZR4HEr4BtnwNZB0HNn6obZHdtACs0wTAsw4KTjIOAEuf1QI34RkADHoK6Hm3Nh5yGAy+6sDl7ULh6+GKxMwziD92Gj1iGus9JCIiIiIixwuyKk0prBBcVVWRZTFdZJBVMbgKqlCZVUWwVcdTsCTwklkkM1ccwOEMLfAK9HbD3Zc1x+39YuDHwIucQWA0MPgpYOATwKEVWhXYnl+15viySXWWrKooIVh079pXZZ05Dax6Hdj4kRZ4S4N9CbvkMb35Wt4RMfiqA55uRgzvGIY58YnqDxWDLyIiIiKiCwRZ8uKzUpP30uDqnGAr4+KDLA//s8Kq0iCrqmBLbtNptTYJvOZuScS7Kw7gyMl8ta+RBF4DWqjAS95kJ3I6Lkag1TBty00Htn0PxH8FZOzVKsJkC26jBWBdbgR81XIP1TMVA5s+BVZO137/iNYjtGmNIW0a5EsifRgsFpnQbtvsYZnrVfvScfunGxHk4471Tw+Fm9FF7yERERE5NXs4fyB+nxyarI627QetefXZqxeqIMtc+/v0CDirCquq/lgVLrt6wJYVS+AVn6gqvI6d0gKvxj7uuGdAC0zs2ww+DLyIKpP44vhGrQps5xygWHvewMUNaDcS6DoRaDlEC80qfs7+pcDvzwAZ+7R9Ie213mKthurzdVCDnj/wN2kd6d8ySIVeJ/OK8NeBDAxuy6VOiYiIiMgJZSUCGz4ANn8OFGaf/1jpqVNVY/eqgi0VZDlG352iEjNmx59QFV4nTp9R++S1xL0DW+DWPgy8iKol0xqb9ta2K6cDO2ZrIZhMgdw1X9v8mwBdb9VWjSzM1QIvaZgv5PfIkGe0FSONfJ45C36n64ir0QVXd4nAl+uOYsHWJAZfRERERORckrcB62ZqL0StjeJlGlK7awDf0PLwyjrN0IGCrNoEXj9tPo73VhxU/YFFsK877hvYErf0aQpvd748I6oxaXDf405tS9kBbPkK2Po9kH0CWPVvYNVrWlAm1aVGd6D3/cDAx7XAnZwKf7PWoTFxkSr4WrIjBQXjTKr3FxERERGRw5IpRAeWAWvfBg6vKt8fMwDoOwloPRxwYQuQwhITftx0Au+vOICkrAK1L8TPA/cPaombezWFlztfNxBdkvBOwFWvAcNeAPYsAuK/AA7/qf2Oaj8auOIFoHELvUdJOmHwVYe6NW2EqEAv9e7N8j1pGNk5Qu8hERERERHVvZJCYPtPwNqZQPpubZ+sjCYrrfWbBER21XuENqGgWAKv43h/5UEklwZeoX4eeGBwS9zUqynfKCeqa7JARedrtS3zmNbQPqil3qMinTH4qkMGgwGj4yLVH7b5CYkMvoiIiIjIseSf0lZF2/gRkJuq7XP3BbrfAfS+DwhsqvcIbSbw+n7jMby/6iBSswvVvnB/TxV43dAzmoEXUUPg7yMqxeCrjo2O1YKvFXvTkXWmGAFebnoPiYiIiIjo0pw6BKx/H9jydfkqan6RQJ/7tSbRXoF6j9BmAq9vNxzDB6sOIi1HC7wiAjzx4OCWuK4HAy8iIj0w+Kpj7cL90CbMF/tSc7FkZwqu7xGt95CIiIiIiC7O8b+1/l3SM0caRIvwzkDfh7VpjU7WnL46Z4pM+GbDUXyw6hAycrXAS1qgSIXXdT2awMOVgRcRkV4YfNXHdMfYSPz3931YuDWJwRcRERER2RezCdj7K7D2HeD4hvL9rYYB/R4Gmg/SVkoj5BeV4Ov1R/HRnxJ4FZUFXg8NaYVruzeBuysb+xMR6Y3BVz0YHRulgq+/DmQgLacAoX6eeg+JiIiIiOj8ivKBhG+A9e9pUxuF0R3ofD3Q9yEgrIPeI7QZeYUl+Gr9Ucz68xBO5mmBV3RjL0wa0grjujLwIiKyJQy+6kHTIG/ERQci4Xgmft2WjDv6N9d7SEREREREVctN05rV//0JcOaUts8zEOj5D6DXvYBfuN4jtBm5hSX4ct0RfLz6ME6VBl7NgrxVhde4rlFwMzLwIiKyNQy+6smYuEgVfM3fmsTgi4iIiIhsT9oeYN1MYNuPgEnrS4XAZkDfSUDXWwB3H71HaDNyCorx5bqjmLX6EDLzi9W+mCBvTLq8NcbGRcKVgRcRkc1i8FVPru4SgZcW7cKWY5k4djJfVYEREREREenKYgGOrAbWzgT2Lynf36Sn1r+r3TWACxuxW2UXFOPzv47gkzWH1YrtokWwDyZd3kr19WXgRURk+xh81RPp69WvZTDWHMjAwm1JqvyZiIiIiEgXpmJg5zxg3TtA8tbSnQag3dVAv0eApr11HqBtkZDrs78O49M1h5FdUKL2tQjxwSOXt8ao2EgYXdjcn4jIXjD4qkfyLpAEXwsSGHwRERERkQ4KsoH4L4H17wPZJ7R9rl7aVMY+DwJBLfUeoU3Jyi/GJ38dVqFXTmng1SrUFw9f3grXdGHgRURkjxh81aMRncIxdd4O7E3NwZ6UbLQL99d7SERERETkDLJOABs+ADZ/ARRma/t8QoBe9wE97gJ8gvQeoU3JzC9S0xllWmNOoRZ4tQmTwKs1RnaOYOBFRGTHGHzVowAvNwxuG4Lfd6Wqqq92VzL4IiIiIqJ6JNMYpX/XzjmAWQtwENwW6DcJ6Hw94Oap9whtyum8Iny85hC+WHtUrdgo2ob54ZGhrXFVp3C4MPAiIrJ7DL7q2Zi4KBV8zU9IwhMj2sJg4B9PIiIiIqrjhvUH/gDWvg0c/rN8f8wArWF9qysAFzZhr+hUXpFaofHLtUeQV2RS+9pH+OPRoa0wvAMDLyIiR8Lgq54NbR8KH3cjEjPPIP7YaXRv1ljvIRERERGRIygpBLb9CKybCaTv0fYZjECn8UDfSUBknN4jtDkncwvx0epD+GrdUeSXBl4dJPAa1hpXtA9j4EVE5ID41k8983QzYkTHcHVZpjsSERER6eXdd99FTEwMPD090bt3b2zcuLHaY4uLi/Hiiy+iZcuW6vjY2FgsXrz4nOMSExNx6623IigoCF5eXujcuTM2bdpUz1+Jk8s/Bfz5H+B/nYAFk7TQy91PC7se3QpM+Jih11nScwrx6q+7cdlrK/DhqkMq9OoU5Y9ZE3vgl0cuU+frDL2IiBwTK74awKi4SMzZkohftifj2Ws6wNXIvJGIiIga1g8//IDJkyfjgw8+UKHXm2++iREjRmDv3r0IDQ095/ipU6fi66+/xqxZs9CuXTssWbIE48aNw9q1a9G1a1d1zOnTp9G/f38MGTIEv/32G0JCQrB//340atRIh6/QCZw6BKx7D0j4BijO1/b5RwG97we63w54Bug9QpuTllOggq5vNhxFQbFZ7evSJACPDm2Ny9uFsg0JEZETMFgs0hTAtmVnZyMgIABZWVnw97e/BvHFJjN6v7pM9RL48q5eGNgmRO8hEREROTx7P3+oaxJ29ezZEzNnzlTXzWYzoqOj8fDDD+Opp5465/jIyEg888wzeOihh8r2TZgwQVV1SSAm5PP++usvrF69+qLHxe9TDRzfqPXv2r1IGnpp+8I7A/0eATqOA4xueo/Q5qRlF+D9VQfx7YZjKCzRAq/Y6EA8NrS1WnyKgRcRkX2rzfkDK74agJvRBSM7h+Pr9cdUk3sGX0RERNSQioqKsHnzZkyZMqVsn4uLC4YNG4Z169ZV+TmFhYVqimNFEnqtWbOm7PqCBQtU1dh1112HVatWISoqCg8++CDuueeeasci9ytbxRNXqoLZBOz5BVj7DnCiwpRUaVQvDeubDwQY3pwjJasAH0jgtfEYikoDr65NA1WF16A2DLyIiJyRS333h6jo+++/V39sxo4dC2dc3VEs2ZmCgmKtkSYRERFRQ8jIyIDJZEJYWFil/XI9JSWlys+RQGvGjBlq6qJUhy1duhRz5sxBcnJy2TGHDh3C+++/j9atW6upkA888AAeeeQRfPHFF9WOZfr06eodWusmVWdUQVEesHEW8E534MfbtNDL6A50vRV4cD1w689Ai0EMvc6SnHUG0+bvwMD/rMDna4+o0KtHs0b46h+9MOeBfhjcltMaiYiclWt994ewOnLkCB5//HEMGDAAzqh700aIDPBEUlYBVuxJw1WdI/QeEhEREVG13nrrLVW5Jf29JDCQJvd33nknPv3007JjJBDr0aMHXn31VXVden/t2LFDnSfefvvtVd6vVJ3JuWTFii+GXwByUoGNHwGbPgHOnNb2eQYCPe8Get0L+FUOLUkjK6e/v/IAfvz7BIpMWoVXr5jGapXGfi2DGHYREVHtK77knT85CZITnw4dOqgTG29v70onQWeTdxhvueUWvPDCC2jRogWckawSI03uxYKtXN2RiIiIGk5wcDCMRiNSU1Mr7Zfr4eHa6tNnk0b18+bNQ15eHo4ePYo9e/bA19e30rlcRESEOh+sqH379jh27Fi1Y/Hw8FC9OCpuTi1tNzD/IeDNTsDq/2qhV6MYYOR/gcm7gKHPMvSqwonT+Xh67nYM/s8K1U5EQq/ezRvj23t644f7+qB/q2CGXkREVPuKr4vpDyFkKWypBvvHP/5Ro+anjtr7YUxslFpVZtmeNGQXFMPfk41IiYiIqP65u7uje/fuWLZsWVnLCanWkuuTJk067+dKawvp3VVcXIzZs2fj+uuvL7tNVnSUqv+K9u3bh2bNmtXTV+IgZG2pw39q/bsOLC3f36SX1r+r3dWAi1HPEdqs46fy8d7KA/h58wkUm7RG/31bBKkKrz4tgvQeHhER2Xvwdb7+EPIuYFWkAeonn3yChISEGj+O9H6Q6jBH0z7CD61CfXEgLRe/70zFtd2b6D0kIiIichIyvVCmH8rUxF69eql2FVLNJVX8YuLEiSrgkvMwsWHDBiQmJiIuLk59fP7551VY9uSTT5bd5z//+U/069dPTXWUQEz6vn700UdqoyqYioGdc7UVGlO2l+40AO2vAfo+DDTtrfMAbdexk/mYuWI/5sQnosSsBV79WwXh0aFt0Kt5Y72HR0RENqxeV3XMycnBbbfdhlmzZqkS+5py1N4PUm49JjYSbyzdh/kJiQy+iIiIqMHccMMNSE9Px7Rp01RDewm0Fi9eXPaGpkxPlEp+q4KCAkydOlU1sJcpjiNHjsRXX32FwMDAsmN69uyJuXPnqnM3qfBv3ry5CtSkxQVVUJAFbP4C2PABkJ2o7XP10hrW93kACGqp9wht1pGMPMxccQBztyTCVBp4DWgdrFZp7BHDwIuIiC7MYLFIrXXNpzpKP6+ff/650sqM8u5hZmYm5s+fX+l4qfKSJqfSU8JK3ikUcmIlpfHSKPVCJPiSVX+ysrLsvg+E/PEe/N+VMLoYsH7KUIT4eeg9JCIiIofkSOcPjsyhv0+Zx7WwS0Kvohxtn08o0PteoMc/AG8GN9U5nJGHd5bvx/yEpLLAa2CbEBV4dW/WSO/hERGRHZ0/uNZnfwhZBWj7dmsZt0beOZRKMFkpyBGquGorJtgHsdGB2Ho8E79uT8bt/WL0HhIRERER1aWkBK1/l0xrtJi0fcFtgX6TgM7XA26eeo/QZh1Mz8XM5QfU7IjSvAuD22qBV9emDLyIiKgBpjrWpj+ENEPt1KlTpc+3lsefvd+ZjI6NVMGXrO7I4IuIiIjIAcisBmlUL4HXkQqLOTUfqPXvajVMpjzoOUKbdiAtB+8sP4CFW5PKAq/L24XikaGtERddPr2WiIio3oOv2vaHoHON6hKBl3/Zhc1HT6uVaaIbe+s9JCIiIiK6GMUFwLYfgHXvAhmlK1wajECn8UDfSUBknN4jtGn7U3Pw9vIDWLQtSS12KYa1D1MVXp2bBOg9PCIicrYeX3pxxN4PN89aj7UHT+LJK9viwcGt9B4OERGRw3HE8wdHZLffp/xTwN+fABs/BPLStX3ufkD324He9wOBztfSozb2puTg7WX78euO5LLAa3iHMFXh1SmKgRcREenU44vqzpi4SBV8LUhIYvBFREREZC9OHgTWvwds+QYoOaPt828C9Lkf6DYR8GRocz67k7NV4PXbjpSyfVd2DMfDQ1uhYyT/74iIqO4x+NLJlR0jMHXeDuxJyVHveLUN99N7SERERERUnWMbgLVvA3t+AVBaohTeBej3CNBxLGB003uENm1nUpYKvJbsTC3bN7JzOB6+vDXaR9hRpR8REdkdBl86CfB2w+C2oVi6KxULtibiifB2eg+JiIiIiCoym4A9i7SG9Sf+Lt/fejjQ72EgZgBgMOg5Qruo8JqxdJ865xXy3zWycwQeubw13/glIqIGweBL59UdteArCY8PbwsDT5yIiIiI9FeUp01lXP8ucPqIts/oDnS5QWtYH8o3LGsiOesMJry/FvlFJhV4XdMlEo9c3gqtwxh4ERFRw2HwpSNZscbb3Yjjp85gy/FMdGvaSO8hERERETmvnBRg40da0/qCTG2fVyOg591Az3sAP20Vc6qZOfGJKvRqF+6HmTd3RatQBl5ERNTwGHzpyMvdqFavmZeQpJrcM/giIiIi0kHabmDtTGD7j4CpSNvXqDnQ9yEg7mbA3UfvEdodWTh+TvwJdfmu/s0ZehERkW4YfOlsTFyUCr4WbUvG1Kvbw9XooveQiIiIiByfxQIcXqX17zrwR/n+6N5a/662IwEXo54jtGvbTmThYHoePN1ccFXncL2HQ0RETozBl84uax2MRt5uyMgtxLpDJzGgdYjeQyIiIiJyXKZiYMccYN07QMr20p0GoP0oLfCK7qXzAB3D7NJqrxEdw+HnyRUviYhIPwy+dOZmdFEr23yz4Zia7sjgi4iIiKgeFGQBmz8HNnwIZCdq+9y8ga63An0eABq30HuEDqOoxKwWbxLjuzXRezhEROTkGHzZyOqOEnwt3pGCl8Z2gqcby+qJiIiI6kTmMWD9B0D8l0BRjrbPJxTofS/Q4x+Ad2O9R+hwVuxNQ2Z+MUL9PHBZq2C9h0NERE6OwZcN6BnTGBEBnkjOKsDKvem4shP7IBARERFdsl3zgZ/uBCwm7XpIO6DvJKDL9YCrh96jc1jWpvbjukbB6GLQezhEROTk2EndBri4GDAqNlJdXrC1tPSeiIiIiC5Ns8sAozvQfBBwy8/Ag+uBbrcx9KpHp/OKsHxPmrrMaY5ERGQLWPFlQ9MdP/rzEJbtTkNOQTGbgBIRERFdKp8g4JEtgH+E3iNxGgu3JaHYZEHHSH+0DffTezhERESs+LIVcnLQIsQHhSVm/L4zVe/hEBERETkGhl4Nana8NnuB1V5ERGQrGHzZCIPBgDGxUeqydRUcIiIiIiJ7cSAtF1uPZ6q+XmPitDYeREREemPwZUNGl54grDmQgZO5hXoPh4iIiIio1k3tB7cJQbAv+6gREZFtYPBlQ5oH+6BLkwCYzBb8uj1Z7+EQEREREdWI2WzB3C2c5khERLaHwZcNNrkX8xM43ZGIiIiI7MP6QyeRnFUAP09XDG0fqvdwiIiIyjD4sjGjYiNhMACbjp7GidP5eg+HiIiIiKjGTe2v6RIJTzej3sMhIiIqw+DLxoT5e6JP8yB1eeFWTnckIiIiItuWX1SC33Zo563XdtcWayIiIrIVDL5suMk9V3ckIiIiIlu3eEcK8otMaBbkjW5NG+k9HCIiokoYfNmgqzqFw81owO7kbOxPzdF7OERERERE1ZpTOs1xfNcmMEjPDiIiIhvC4MsGBXq7Y1CbEHWZVV9EREREZKuSs87gr4MZ6vL4bpzmSEREtofBl40aHRdVtrqjxWLRezhEREREROeYt0XOVYFezRsjurG33sMhIiI6B4MvGzWsfSi83Iw4diofW09k6T0cIiIiIqJK5M3Z2fEn1OUJrPYiIiIbxeDLRnm7u2J4xzB1eX6C1jeBiIiIiMhWbE/MwoG0XHi4uuCqzhF6D4eIiKhKDL5s2OhYbXXHRduSYTJzuiMRERER2V5T++Edw+Hv6ab3cIiIiKrE4MuGDWgdgkBvN6TnFGL9oZN6D4eIiIiISCkqMZctwsRpjkREZMsYfNkwdykb76SVjXO6IxERERHZipV703Aqrwghfh64rFWw3sMhIiKqFoMvGzcmTpvu+NuOFBSWmPQeDhERERFR2TTHsXGRcDXyJQUREdku/pWycb1iGiPc3xM5BSVYuTdd7+EQERERkZPLzC/Csj2p6vL4bk30Hg4REdF5MfiycS4uBoyK1aY7WvsoEBERERHpZeG2ZBSbLOgQ4Y/2Ef56D4eIiOi8GHzZgdGxWsPQP3alIrewRO/hEBEREZETmxN/Qn0cz6b2RERkBxh82YFOUf5oEeyDwhIzlu5K0Xs4REREROSkDqbnYsuxTBhdDBhd2ouWiIjIljH4sgMGg0x31E4s5idwuiMRERER6WNuaVP7ga2DEernqfdwiIiILojBl52wvqO2en8GTuYW6j0cIiIiInIyZrMFc7dowReb2hMRkb1g8GUnWob4onNUAExmC37dwemORERERNSwNhw+hcTMM/DzdMUVHcL0Hg4REVGNMPiyI6NLpzsu5HRHIiIiImpgs0ub2l/TJQKebka9h0NERFQjDL7syDWxETAYgI1HtHfbiIiIiIgaQn5RCX7bnqwuc5ojERHZEwZfdiQiwAu9Yhqry4u2suqLiIiIiBrG7ztTkVdkQtPG3ujRrJHewyEiIqoxBl92ZkxclPrI1R2JiIiIqKGnOY7vFqVWHCciIrIXDL7szFWdwuHqYsCu5GwcSMvRezhERERE5OBSsgrw14EMdXl8V05zJCIi+8Lgy8408nHHoDYh6vICVn0RERERUT2bl5AIswXoGdMITYO89R4OERFRrTD4skOj47TVHRdsTYLFYtF7OERERETkoORcc/Zm6zRHVnsREZH9YfBlh4a1D4OXmxFHTuZj24ksvYdDRERERA5qZ1I29qflwt3VBVd3idB7OERERLXG4MsO+Xi4YliHsLKqLyIiIiKi+mxqP7xDGPw93fQeDhERUa0x+LJTY2K16Y4LtybBJE0XiIiIiIjqULHJXNZTdgKnORIRkZ1i8GWnBrYJQYCXG9JyCrHh0Em9h0NEREREDmbV3nSczCtCsK8HBrQO1ns4REREF4XBl52SPgsjO4ery5zuSERERER1bc4WbZrj2LhIuBr5soGIiOwT/4LZsVGl0x1/3Z6MwhKT3sMhIiIiIgeRlV+MP3alqctczZGIiOwZgy871rt5EML8PZBdUII/92XoPRwiIiIichALtyWhyGRGu3A/dIj013s4REREF43Blx0zuhhwTRet6mt+QqLewyEiIiIiBzGndDVHNrUnIiJ7x+DLzo2J04KvP3anIq+wRO/hEBEREZGdO5yRh/hjmXAxlJ9rEhER2SsGX3auc1QAYoK8UVBsxtJdqXoPh4iIiIjs3NzSai9ZRTzU31Pv4RAREV0SBl92zmAwYHRclLrM1R2JiIjofN59913ExMTA09MTvXv3xsaNG6s9tri4GC+++CJatmypjo+NjcXixYurPf7f//63Oi957LHH6mn01BDMZgtmx2stNNjUnoiIHAGDLwcwunR1xz/3peN0XpHewyEiIiIb9MMPP2Dy5Ml47rnnEB8fr4KsESNGIC1NW7nvbFOnTsWHH36Id955B7t27cL999+PcePGYcuWLecc+/fff6tju3Tp0gBfCdWnjUdOITHzDPw8XDG8Q5jewyEiIrpkDL4cQKtQX3SM9EeJ2YJfdyTrPRwiIiKyQTNmzMA999yDO++8Ex06dMAHH3wAb29vfPrpp1Ue/9VXX+Hpp5/GyJEj0aJFCzzwwAPq8htvvFHpuNzcXNxyyy2YNWsWGjVq1EBfDdV3U/uRnSPg6WbUezhERESXjMGXg1V9zU/gdEciIiKqrKioCJs3b8awYcPK9rm4uKjr69atq/JzCgsL1RTHiry8vLBmzZpK+x566CFcffXVle77fOR+s7OzK21kG84UmfDr9hR1eUJ3TnMkIiLHwODLQYwqDb7+PnIKSZln9B4OERER2ZCMjAyYTCaEhVWeuibXU1K0oONsMg1SqsT2798Ps9mMpUuXYs6cOUhOLq8u//7779W0yenTp9d4LHJsQEBA2RYdHX0JXxnVpd93pSC3sATRjb3Qoxmr94iIyDEw+HIQkYFe6NW8MSwWYNE2Vn0RERHRpXnrrbfQunVrtGvXDu7u7pg0aZKaJimVYuL48eN49NFH8c0335xTGXY+U6ZMQVZWVtkm90O2wdrUflzXJnBxMeg9HCIiojrB4MsBpztydUciIiKqKDg4GEajEampqZX2y/Xw8PAqPyckJATz5s1DXl4ejh49ij179sDX11f1+xIydVIa43fr1g2urq5qW7VqFd5++211WSrMquLh4QF/f/9KG+kvNbsAa/anq8vju2orhhMRETkCBl8ORJqQuroYsCMxGwfTc/UeDhEREdkIqdjq3r07li1bVrZPpi/K9b59+573c6WaKyoqCiUlJZg9ezbGjBmj9g8dOhTbt29HQkJC2dajRw/V6F4uS9BG9mN+QiLMFqgpjjHBPnoPh4iIqM641t1dkd4a+7hjQOtgrNibjgUJSfjnFW30HhIRERHZiMmTJ+P2229X4VSvXr3w5ptvqmoumb4oJk6cqAIua7+uDRs2IDExEXFxcerj888/r8KyJ598Ut3u5+eHTp06VXoMHx8fBAUFnbOfbJvFYsHszdo0x/Hd2NSeiIgcC4MvBzMmLkoLvrYm4bFhrWEwsD8DERERATfccAPS09Mxbdo01dBeAq3FixeXNbw/duxYWf8uUVBQgKlTp+LQoUNqiuPIkSPx1VdfITAwUMevgurDzqRs7E3NgburC67uHKH3cIiIiOoUgy8Hc0WHMHi6ueBwRp6a8ti5SYDeQyIiIiIbIQ3qZavKypUrK10fNGgQdu3aVav7P/s+yD7MKW1qf0X7MAR4u+k9HCIiojrFHl8OxsfDFcPah5X1aiAiIiIiqk6xyYwFW7Vzxgnd2dSeiIgcD4MvB17dceG2JJikSykRERERURVW709HRm4Rgn2lV2yI3sMhIiKqcwy+HNCgtiHw93RFanYhNh4+pfdwiIiIiMhGWZvaj46NgpuRLw2IiMjxXNRft3fffRcxMTFqeevevXtj48aN1R47Z84ctXqQNEKVlX6kkao0RqX64+FqxFWdtMak1tJ1IiIiIqKKsvKLsXR3qro8vhunORIRkWOqdfD1ww8/qOWwn3vuOcTHxyM2NhYjRoxAWlpalcc3btwYzzzzDNatW4dt27apJbNlW7JkSV2Mn6oxJk6b7vjr9hQUlZj1Hg4RERER2Zhftier88S2YX7oGOmv93CIiIhsI/iaMWMG7rnnHhVedejQAR988AG8vb3x6aefVnn84MGDMW7cOLRv3x4tW7bEo48+ii5dumDNmjV1MX6qRu8WQQj180DWmWL8uS9d7+EQERERkY2ZE3+irKm9wWDQezhERET6B19FRUXYvHkzhg0bVn4HLi7qulR0XYjFYsGyZcuwd+9eDBw48OJGTDVidDHgmi5a1deCrUl6D4eIiIiIbMiRjDxsOnoaLgaZKcBpjkRE5Lhca3NwRkYGTCYTwsLCKu2X63v27Kn287KyshAVFYXCwkIYjUa89957uOKKK6o9Xo6TzSo7O7s2w6RSo+Mi8elfh7F0Vyryi0rg7V6rbzcREREROag5W7Q+sJe1DkGYv6fewyEiIqo3DbJ0i5+fHxISEvD333/jlVdeUT3CVq5cWe3x06dPR0BAQNkWHR3dEMN0OLFNAtAsyBtnik0q/CIiIiIiMpst5dMc2dSeiIgcXK2Cr+DgYFWxlZpaOUSR6+Hh4dU/iIsLWrVqpVZ0/L//+z9ce+21KtyqzpQpU1SVmHU7fvx4bYZJpaRXw+jY0umOCZzuSERERERQUxxPnD4DXw9XDO9Q/Tk8ERGR0wVf7u7u6N69u+rTZWU2m9X1vn371vh+5HMqTmU8m4eHB/z9/SttdGmrO67al47TeUV6D4eIiIiIdGat9hrZORxe7ka9h0NERGRbUx1lmuKsWbPwxRdfYPfu3XjggQeQl5enVnkUEydOVBVbVlLZtXTpUhw6dEgd/8Ybb+Crr77CrbfeWrdfCVWpVagf2kf4o8RswW87UvQeDhERERHpqKDYhF+2JavL47s10Xs4RERE9a7W3c5vuOEGpKenY9q0aUhJSVHTFxcvXlzW8P7YsWNqaqOVhGIPPvggTpw4AS8vL7Rr1w5ff/21uh9quKqv3cnZWLA1ETf3bqr3cIiIiIhIJ7/vSkVOYQmiAr3QK6ax3sMhIiKqdwaLxWKBjZNVHaXJvfT74rTH2kvMPIP+/14OgwFY99RQhAdw5R4iInJ8PH+wD/w+Naw7PtuIlXvT8cjlrTB5eFu9h0NERFTv5w8Nsqoj6Uve0esZ0wgScS7axib3RERERM4oLacAf+5LV5fHcZojERE5CQZfTsK6uuN8ru5IRERE5JTmb0mC2QJ0axqI5sE+eg+HiIioQTD4chIjO0fA6GLA9sQsHErP1Xs4RERERNTAZpeu5sim9kRE5EwYfDmJIF8PDGgdrC4v2MqqLyIiIiJnsispG3tScuBudMGoLtpMACIiImfA4MsJpztK8GUHaxoQERERUR2ZU1rtNaxDKAK83fQeDhERUYNh8OVEhncMh4erCw6l52FnUrbewyEiIiKiBlBiMmNeaZ/X8V05zZGIiJwLgy8n4uvhimHtw9RlTnckIiIicg6r92cgI7cQjX3cMahtiN7DISIialAMvpzM6LjS6Y4JSTDLsj5ERERE5BRN7aXthZuRp/9ERORc+JfPyQxuGwI/T1ekZBdg45FTeg+HiIiIiOpR1pli/L4rVV2+tjunORIRkfNh8OVkPFyNuKpTuLrM6Y5EREREju237ckoKjGjTZgvOkb66z0cIiKiBsfgywmNjo1SH38tPREiIiIiIsee5ji+WxMYDAa9h0NERNTgGHw5ob4tgxDs64HM/GKsOZCu93CIiIiIqB4cPZmHv4+chosBGBunvfFJRETkbBh8OSGjiwHXdIlQl+eXLm1NRERERI5l7pZE9bF/q2CEB3jqPRwiIiJdMPhyUmNKV3dcuisV+UUleg+HiIiIiOqQxWLBnHgt+JrQjU3tiYjIeTH4clJx0YFo2tgb+UUm/LE7Te/hEBEREVEd2nT0NI6dyoePuxHDO4bpPRwiIiLdMPhyUtLcdHSsVvW1gNMdiYiIiBzKnNKm9ld1joC3u6vewyEiItINgy8nNrp0uuOqfWnIzC/SezhEREREVAcKik1YtC1ZXeY0RyIicnYMvpxYmzA/tAv3Q7HJgsU7UvQeDhERERHVgT92pyKnoARRgV7o3byx3sMhIiLSFYMvJ2et+uLqjkRERESOYfZmbZrjuK5RcHEx6D0cIiIiXTH4cnKjumjB1/rDJ5GaXaD3cIiIiIjoEqTlFODP/Rnq8rhuUXoPh4iISHcMvpxcdGNvdG/WCBYLsHArq76IiIiI7JksWmQyW9C1aSBahvjqPRwiIiLdMfgijCmd7sjgi4iIiMi+zYlPVB/Hs6k9ERGRwuCLMLJzBIwuBmw9kYXDGXl6D4eIiIiILsLu5GzsSs6Gm9GAUV0i9B4OERGRTWDwRQj29UD/VsHqMqu+iIiIiOzTnHitqf3QdmEI9HbXezhEREQ2gcEXKWNiras7JsIiDb+IiIiIyG6UmMyYV7pK93g2tSciIirD4IuU4R3D4OHqgoPpeapEnoiIiIjsx5oDGUjPKURjH3cMbhuq93CIiIhsBoMvUvw83TC0fWjZakBEREREZH9N7UfHRsLdlaf4REREVvyrSGXkRMna58ts5nRHIiIiInuQXVCMJTtT1GVOcyQiIqqMwReVkbJ4Pw9XJGUVYNPR03oPh4iIiIhq4LftySgsMaNVqC86RwXoPRwiIiKbwuCLyni6GTGiU3hZk3siIiIisn2zS6c5TujWBAaDQe/hEBER2RQGX1TJmDhtuuOv25NRbDLrPRwiIiIiOo/jp/Kx8fApSN41tqt2HkdERETlGHxRJX1bBCHY1x2n84uxZn+G3sMhIiIioho0te/fMhgRAV56D4eIiMjmMPiiSlyNLrimi/Zu4YKtXN2RiIiIyFZZLBbM2XJCXWZTeyIioqox+KJzjCpd3VFWBzpTZNJ7OERERERUhfhjp3H0ZD683Y24srRPKxEREVXG4IvO0a1pIJo08kJ+kQnL9qTqPRwiIiIiOk9T+6s6RcDb3VXv4RAREdkkBl90DlkNaHRp1df8BE53JCIiIrI1BcUmLCptSzGB0xyJiIiqxeCLqjQmTjuBWrU3HVn5xXoPh4iIiIgqWLY7DdkFJYgM8ESfFkF6D4eIiMhmMfiiKrUN90PbMD8UmcxYvDNZ7+EQERERUQVz4rWm9uO6RcHFxaD3cIiIiGwWgy+q1ug4ru5IREREZGsycguxcl+6ujyuaxO9h0NERGTTGHxRtax9vtYePIm07AK9h0NERERE8qZkQhJMZgtiowPRKtRX7+EQERHZNAZfVK3oxt5qhUeLBVi0jdMdiYiIiGzB7NJpjmxqT0REdGEMvui8ylZ35HRHIiIiIt3tScnGzqRsuBkNGNVFO08jIiKi6jH4ovO6ukskpF/q1uOZOHoyT+/hEBERETm1ufGJ6uPl7ULRyMdd7+EQERHZPAZfdF4hfh7o3yq4rJ8EEREREelD+nrN3aIFX+O7sak9ERFRTTD4olpNd7RIwy8iIiKyS++++y5iYmLg6emJ3r17Y+PGjdUeW1xcjBdffBEtW7ZUx8fGxmLx4sWVjpk+fTp69uwJPz8/hIaGYuzYsdi7d28DfCXOac2BDKTlFCLQ2w1D2obqPRwiIiK7wOCLLmhEp3C4u7rgQFoudifn6D0cIiIiugg//PADJk+ejOeeew7x8fEqyBoxYgTS0tKqPH7q1Kn48MMP8c4772DXrl24//77MW7cOGzZsqXsmFWrVuGhhx7C+vXrsXTpUhWWDR8+HHl5bI9QH+aUNrWXNyXl3IyIiIguzGCxgxKe7OxsBAQEICsrC/7+/noPxynd/9VmLN6ZgvsHtcRTV7XTezhEREQXxPOHyqTCS6qzZs6cqa6bzWZER0fj4YcfxlNPPXXO8ZGRkXjmmWdUsGU1YcIEeHl54euvv67yMdLT01XllwRiAwcOrNG4+H2qmZyCYvR85Q8UFJsx/6H+iI0O1HtIREREuqnN+QPfKqIaGROnTXdcuDUJZrPNZ6VERERUQVFRETZv3oxhw4aV7XNxcVHX161bV+XnFBYWqimOFUnotWbNmmofR04+RePGjets7KT5bUeKCr1ahvigS5MAvYdDRERkNxh8UY0MaRcKXw9XJGaeQfyx03oPh4iIiGohIyMDJpMJYWFhlfbL9ZSUlCo/R6ZBzpgxA/v371fVYTKVcc6cOUhOTq7yeDnmscceQ//+/dGpU6dqxyKBmrxLW3GjC5u9+URZU3uDwaD3cIiIiOwGgy+qEU83I0Z0DFeX53N1RyIiIof31ltvoXXr1mjXrh3c3d0xadIk3HnnnapSrCoyJXLHjh34/vvvz3u/0hBfpiZYN5luSed3/FQ+Nhw+Bcm7xnWN0ns4REREdoXBF9XY6NLpjr9sT0axyaz3cIiIiKiGgoODYTQakZqaWmm/XA8P197YOltISAjmzZunGtUfPXoUe/bsga+vL1q0aHHOsRKKLVq0CCtWrECTJk3OO5YpU6aoKZHW7fjx45f41Tm+eVsS1cd+LYMQGeil93CIiIjsCoMvqrH+LYMQ5OOOU3lF+OtAht7DISIiohqSiq3u3btj2bJllaYmyvW+ffue93Olz1dUVBRKSkowe/ZsjBkzpuw2WSNJQq+5c+di+fLlaN68+QXH4uHhoZrQVtyoevJ/PKc0+Brf9fyhIhEREZ2LwRfVmKvRBVd3iVCXF3C6IxERkV2ZPHkyZs2ahS+++AK7d+/GAw88oKq5ZPqimDhxoqrGstqwYYPq6XXo0CGsXr0aV155pQrLnnzyyUrTG2WFx2+//RZ+fn6qX5hsZ86c0eVrdERbjmficEYevNyMuLJT1dV5REREVD3X89xGVOXqjl+uO4olO2VlIZPq/UVERES274YbbkB6ejqmTZumwqm4uDgsXry4rOH9sWPHKvXvKigowNSpU1XwJVMcR44cia+++gqBgYFlx7z//vvq4+DBgys91meffYY77rijwb42Z2hqf1WncPh48NSdiIiotvjXk2qlW9NGiAr0Uqs7LtudVlYBRkRERLZPpiXKVpWVK1dWuj5o0CDs2rXrgtPwqP4UlpiwcKtWZT+hO6c5EhERXQxOdaRakeWzrU3uF2zV+k0QERERUd1bvjsN2QUliAjwRJ8WQXoPh4iIyC4x+KJaGx2rBV8r9qQj60yx3sMhIiIickiz47U3Gcd2jYLRxaD3cIiIiOwSgy+qtXbhfmgT5osik1n1+iIiIiKiunUytxAr96apy+O7Ruk9HCIiIrvF4IsubrpjadUXV3ckIiIiqnsLtiahxGxBlyYBaB3mp/dwiIiI7BaDL7ooo2O1dx7XHsxAWk6B3sMhIiIicihzSqc5TujGpvZERESXgsEXXZSmQd6Iiw6E2QL8si1Z7+EQEREROYx9qTnYnpgFVxcDRpVW2RMREdHFYfAl9i4Gis/oPQq7M6ZsdUdOdyQiIiKqK7PjT6iPQ9qForGPu97DISIismsMvvYtAb67EfjsKiCblUu1cXWXCMgCQ1uOZeLYyXy9h0NERERk90xmC+ZtsU5zZFN7IiKiS8Xgy80b8AoEkrYAs4YAiZv1HpHdCPXzRL+Wwerywm2s+iIiIiK6VNI/NTW7EIHebqrii4iIiC4Ng6/mA4B7VgAh7YCcZOCzkcD2n/Ueld2wru44P0F7Z5KIiIiILr2p/agukfBwNeo9HCIiIrvH4Es0bg78YynQegRQUgDM/gew/GXAbNZ7ZDZvRKdwuBtdsC81F3tSsvUeDhEREZHdyi0sweIdKeryeE5zJCIiqhMMvqw8/YGbvgP6PaJd//M/wI+3AYW5eo/MpgV4uWFw2xB1eX4CpzsSERERXazftifjTLEJLYJ91OrZREREdOkYfFXkYgSGvwSMfR8wugN7FgGfXglkHtN7ZDZtTJz2juSChCRYLBa9h0NERERk19McJ3RvAoPBoPdwiIiInDf4evfddxETEwNPT0/07t0bGzdurPbYWbNmYcCAAWjUqJHahg0bdt7jbULczcDtiwCfECB1OzDrcuDYBr1HZbOGtg+Fj7sRiZlnEH/stN7DISIiIrI7J07nY92hk+ry2K6c5khERKRb8PXDDz9g8uTJeO655xAfH4/Y2FiMGDECaWlpVR6/cuVK3HTTTVixYgXWrVuH6OhoDB8+HImJNt4MvWlvrel9eGcgLx344hpgyzd6j8omeboZMaJjeFnVFxERERHVjrVlRN8WQYgK9NJ7OERERM4bfM2YMQP33HMP7rzzTnTo0AEffPABvL298emnn1Z5/DfffIMHH3wQcXFxaNeuHT7++GOYzWYsW7YMNi8wGrhrCdB+FGAqAuY/CPw+FTCb9B6ZzRkdp63uuGhbMkpMXBSAiIiIqKakVcTszSfUZTa1JyIi0jH4KioqwubNm9V0xbI7cHFR16Waqyby8/NRXFyMxo0bwy64+wDXfQkMfFK7vvYd4LsbgQKuYFhR/1bBaOzjjpN5RfjroFamT0REREQXlnA8E4cy8uDlZsRVnSP0Hg4REZHzBl8ZGRkwmUwICwurtF+up6RoSy9fyL/+9S9ERkZWCs/OVlhYiOzs7EqbrlxcgMufAa79FHD1BPb/DnxyBXDqkL7jsiFuRhdcXXqixumORERERLVvan9lp3D4erjqPRwiIiKH0qCrOv773//G999/j7lz56rG+NWZPn06AgICyjbpC2YTOk0A7vwN8IsA0vdoTe8P/6n3qGxuuuOSnSkoKOZ0UCIiIqILKSwxYeE27U1DTnMkIiLSOfgKDg6G0WhEampqpf1yPTxca25enf/+978q+Pr999/RpUuX8x47ZcoUZGVllW3Hjx+HzYjqpjW9j+wGnDkNfDUO2FR1fzNn071pI9WMNbewBCv2VL3YARERERGVk3OmzPxihPl7oF/LYL2HQ0RE5NzBl7u7O7p3716pMb21UX3fvn2r/bzXX38dL730EhYvXowePXpc8HE8PDzg7+9fabMp/hHAnb8Cna8DzCXAon8Cvz4BmErgzFxcDLgmNqLSykREREREVL3ZpdMcx3aNgtHFoPdwiIiIHE6tpzpOnjwZs2bNwhdffIHdu3fjgQceQF5enlrlUUycOFFVbFm99tprePbZZ9WqjzExMaoXmGy5ubmwa25ewPhZwNBp2vWNHwHfTNCqwJzYmFitRH/53jRkFxTrPRwiIiIim3Uqr6isSn5CtyZ6D4eIiMgh1Tr4uuGGG9S0xWnTpiEuLg4JCQmqksva8P7YsWNITk4uO/79999Xq0Fee+21iIiIKNvkPuyewQAM+D/ghm8ANx/g0Epg1lAgfR+cVfsIP7QK9UVRiRlLdtRswQMiIiIiZ7RwaxJKzBZ0jgpAmzA/vYdDRETkkAwWi8UCGyerOkqTe+n3ZXPTHq1SdgDf3QRkHQM8ArQVIFtXv3KlI3tn2X68sXQfBrQOxlf/6K33cIiIyEnZxfkDOfX3aczMNdh6IgvPjeqAO/s313s4REREDnn+0KCrOjq08E7APcuBpn2Bwizg2+uAde8Btp8r1rlRsdrqjn8dyEB6TqHewyEiIiKyOQfSclTo5epiwOjScyciIiKqewy+6pJvCDBxPhB3K2AxA0umAAsfAUqK4Exign0QGx0IswX4dXv5tFciIiIiqtzUfnDbUAT5eug9HCIiIofF4KuuuXpI3Tow4lXA4ALEfwl8OQbIy4Azsb5zOT9BO6kjIiIiIo3JbMG8Ldo50oRu2sJAREREVD8YfNVX0/u+DwE3/wh4+APH1gKzhgCpO+EsRnWJUP8N8ccycfxUvt7DISIiIrIZ6w+dRHJWAfw9XXF5+1C9h0NEROTQGHzVp9ZXAHf/ATRqDmQeAz4ZDuz5Fc4g1N8TfVsEqcsLtibpPRwiIiIimzF784myvqgerka9h0NEROTQGHzVt5C2WtP75gOBolzg+5uB1TOcoun9mLjIsqW6iYiIiAjIKyzBbztS1OUJ3ZvoPRwiIiKHx+CrIXg3Bm6dA/S8G4AFWPYCMOdeoLgAjuzKjhFwMxqwJyUHe1Ny9B4OERERke4W70jBmWITmgf7oGt0oN7DISIicngMvhqK0Q24+g1g5H8BgxHY/iPw+dVAjvaOnyMK8HZTKxWJBVvZ5J6IiIhozhZtmuP4rlEwSENUIiIiqlcMvhpar3uA2+YCnoFA4iZg1uVAUgIcfXVH6fNlcYLpnURERETVSco8g7UHT6rLY7tyNUciIqKGwOBLDy0GaX2/gtsA2YnAp1cCO+fCEQ1rHwZvdyOOnzqDLccz9R4OERERkW7mbklUbV77tGiM6Mbeeg+HiIjIKTD40ktQS23Fx1ZXACVngJ/uAFZMB8xmOBIvdyOGdwhTlxcksMk9EREROSepfJ8TXzrNsRub2hMRETUUBl968gwAbv4B6DtJu77q38DPdwBFeXAkY+K0Uv5F25JRYnKsYI+IiIioJradyMLB9Dx4urngqk7heg+HiIjIaTD40puLERjxCjDmXcDFDdg1X5v6mKW9I+gILmsdjEbebsjILcS6Q1pfCyIiIiJnYq32GtExHH6ebnoPh4iIyGkw+LIVXW8Fbl8IeAcDKduAj4YAx/+GI3AzumBk5wh1eT6nOxIREZGTKSoxq4V+BKc5EhERNSwGX7akWV/g3hVAWCcgLw34/Gpg6/dwpOmOS3akoKDYpPdwiIiIiBrMir1pOJ1fjFA/D1zWKljv4RARETkVBl+2JrApcNcSoN01gKkQmHsfsPQ5wGzfYVGPZo0QEeCJnMISrNybpvdwiIiIiBp8muO4rlEwuhj0Hg4REZFTYfBlizx8geu/AgY8rl3/603g+1uAwhzYKxcXA0bHRqrL1lJ/IiIiIkd3Oq8Iy/dob/pxmiMREVHDY/Blq1xcgKHPAuM/BowewL7fgE+GA6ePwF6NKg2+/tidhpyCYr2HQ0RERFTvFm5LQrHJgo6R/mgb7qf3cIiIiJwOgy9b1+U64M7fAN9wIG2X1vT+yF+wR3LC1zLERzV4/X1nqt7DISIiIqp3s+MT1ccJrPYiIiLSBYMve9Cku9b0PrIrcOYU8OVoYPMXsDcGg0x31Jrcz+d0RyIiInJwB9JysfV4purrNTpOq3wnIiKihsXgy174RwJ3/Ap0HA+YS4CFjwC/PQWYSmBPrCd9fx3IQEZuod7DISIiIqo3c7doTe0HtwlBsK+H3sMhIiJySgy+7Im7N3Dtp8CQqdr1De8D31wLnDkNe9E82AddmgTAZLbg1+3Jeg+HiIiIqF6YzRbMLZ3myKb2RERE+mHwZW8MBmDQE9qqj27ewKEVwMfDgIwDsBdlqzsmcLojEREROab1h04iKasA/p6uGNo+VO/hEBEROS0GX/aqw2jgriWAfxPg5AHg48uBg8thL6s7Sn636ehpnDidr/dwiIiIiOqtqf01sZHwdDPqPRwiIiKnxeDLnkV00ZreN+kFFGQBX18LbPgIsFhgy8L8PdGneZC6vHArpzsSERGRY8kvKsFvO7RznAndtIV9iIiISB8MvuydbyhwxyIg9mbAYgJ+ewJY9BhQUgR7aHI/P0F7N5SIiIjIUSzekYL8IhNigrzRrWkjvYdDRETk1Bh8OQJXD2Dse8AVL0kTMGDz58BX44C8k7BVV3UKh5vRgD0pOdiXmqP3cIiIiIjqzJwKTe0N0t+BiIiIdMPgy1HISVX/R4CbfwDc/YCja4BZQ4C03bBFgd7uGNQmRF1mk3siIiJyFMlZZ/DXwQx1eVxXTnMkIiLSG4MvR9NmBHD3UqBRDJB5FPj4CmDvYtii0XHayeCCrUmw2HhfMiIiIqKamLdFzmuAXs0bI7qxt97DISIicnoMvhxRaHvg7uVAs8uAohzguxuBv96yuab3w9qHwsvNiGOn8pFwPFPv4RARERFdEnkjb078CXWZTe2JiIhsA4MvR+UTBNw2F+h+p5yGAUunAfMeBEoKYSu83V0xvGOYuvz1+mMoNpn1HhIRERHRRduemIX9abnwcHXBVZ0j9B4OERERMfhycK7uwDX/A0b+FzAYga3fAp9fA+SkwlaMLZ3uODv+BPr9ezlmLN2HlKwCvYdFREREdNFN7Ud0DIe/p5vewyEiIiIGX07S9L7XPcCtswHPAODERmDW5UDyVtiCwW1D8PTIdgj29UB6TiHeXrYf/V9bjge+3oy1BzLY+4uIiIjsQlGJWfUtFeM5zZGIiMhmMPhyFi2HaH2/gloD2SeAT68Eds3Xe1Rqie97B7bE2qcuxzs3dVWNYE1mC37bkYKbP96AYTNW4bO/DiO7oFjvoRIRERFVa9W+dJzKK0KInwcuaxWs93CIiIioFIMvZxLcCrj7D6Dl5UBxPvDjRGDV6zbR9N7d1QWjYiPx4319seSxgbi1T1P4uBtxMD0PLyzchd6vLMOUOduxKylb76ESERERnWP2Zq2p/di4SLgaeYpNRERkK/hX2dl4BQI3/wT0eVC7vuIV4Oe7gKJ82Iq24X54eWxnrH96KF4c0xGtQ31xptiE7zYew8i3V+Pa99difkIiCktMeg+ViIjIrrz77ruIiYmBp6cnevfujY0bN1Z7bHFxMV588UW0bNlSHR8bG4vFixdf0n06qsz8Iizbo/VQndC9id7DISIiogoYfDkjoytw5XRg1NuAixuwcw7w2VVAttaXwlb4ebphYt8Y/P7Pgfj+3j64uksEXF0M2HT0NB79PgH9/70c/1myB4mZZ/QeKhERkc374YcfMHnyZDz33HOIj49XQdaIESOQlpZW5fFTp07Fhx9+iHfeeQe7du3C/fffj3HjxmHLli0XfZ+OauG2ZBSbLOgQ4Y924f56D4eIiIgqMFjsoHt4dnY2AgICkJWVBX9/nkzUqSN/AT/eBuSfBHzDgRu/BZp0h61KzS7A9xuP49uNR5GaXaj2uRiAoe3DcFufZqqnhovsICIip8fzh8qkGqtnz56YOXOmum42mxEdHY2HH34YTz311DnHR0ZG4plnnsFDDz1Utm/ChAnw8vLC119/fVH36ajfp3Hv/YUtxzIx9er2uHtAC72HQ0RE5PCya3H+wIovZxfTH7hnORDaAchN0Sq/tv0EWxXm74lHh7XGmn9djvdv6YZ+LYNgtgBLd6Vi4qcbMXTGKny8+hCy8tkMn4iIyKqoqAibN2/GsGHDyva5uLio6+vWravycwoLC9X0xYok9FqzZs1F36f1fuVkteJmzw6m56rQy+hiwJg4ruZIRERkaxh8EdAoBvjH70CbqwBTITDnbmDZi/K2LWyVm9EFV3WOwLf39MEfkwfijn4x8PNwxeGMPLz8y270nv4Hnvx5K7afyNJ7qERERLrLyMiAyWRCWFhYpf1yPSUlpcrPkSmLM2bMwP79+1Ul19KlSzFnzhwkJydf9H2K6dOnq3dorZtUiNmzufGJ6uOgNiFqRUciIiKyLQy+SOPhB9z4DXDZP7Xrq98AfrgVKMyFrWsV6ofnR3dUzfBfHdcZ7cL9UFBsxo+bTmDUzDUY++5faqWlgmI2wyciIqqpt956C61bt0a7du3g7u6OSZMm4c4771RVXZdiypQpalqCdTt+/DjsldlswdwtWvA1vhurvYiIiGwRgy8q52IEhj0PjPsIMHoAe38BPh0BnD4Ke+Dj4YqbezfFb48OwM/398Xo2Ei4GQ1IOJ6J//tpK/pOX4bpv+3G8VO2s4IlERFRQwgODobRaERqqrbyoJVcDw8Pr/JzQkJCMG/ePOTl5eHo0aPYs2cPfH190aJFi4u+T+Hh4aF6cVTc7NWGw6fUIjt+nq4Y1r5y5RsRERHZBgZfdK7YG4A7fgF8QoHUHcCsIcDR6nt12BqDwYAeMY3x9k1dsfapoXhiRFtEBnjidH4xPlx1CAP/swJ3frYRK/akwSQNwoiIiBycVGx1794dy5YtK9sn0xflet++fc/7udLnKyoqCiUlJZg9ezbGjBlzyffpKObEn1Afr+kSAU83o97DISIioiow+KKqRfcE7l0BhHfRVnz8YhQQ/xXsjfTaeGhIK/z55BB8dFt3DGgdDFnHdMXedNz5+d8Y/N8V+HDVQZzKK9J7qERERPVq8uTJmDVrFr744gvs3r0bDzzwgKrmkumLYuLEiWoaotWGDRtUT69Dhw5h9erVuPLKK1Ww9eSTT9b4Ph1ZflEJft2u9Tub0K2J3sMhIiKiarhWdwMRApoAdy0G5j0A7JoPLJgEpO8BrnhRmxZpR1yNLhjeMVxt0gD/6/VH8dOm4zh+6gym/7YHbyzdp96tva1PM8RFB6qqMSIiIkdyww03ID09HdOmTVPN5+Pi4rB48eKy5vTHjh2r1L+roKAAU6dOVcGXTHEcOXIkvvrqKwQGBtb4Ph3Z7ztTkVdkQtPG3ujerJHewyEiIqJqGCwWqX+xbbLMtaz6Iw1Q7bkPhN2S1R3/fB1YOV273moYcO2ngGcA7NmZIhMWbk3Cl+uPYEdi+VLqnaMCVAA2KjYSXu72FfAREVE5nj/YB3v9Pt32yQas3p+Bx4a1xmPD2ug9HCIiIqeSXYvzB051pAuTd38HPwVc9wXg6gUc+AP4eBhw8iDsmYRa1/eMxsJJl2Hug/3Uakzuri7YnpiFJ2dvQ5/py/Dyol2qQoyIiIjIKiWrAH8dyFCXx3flNEciIiJbxuCLaq7jWG3qo38UkLEPmHU5cGgV7J1Ma+zatBFmXB+H9VOG4qmr2qFJIy9knSnGx2sOY8h/V2LipxuxdFcqm+ETERER5iUkQk4JesY0QtMgb72HQ0REROfB4ItqJzIOuGc5ENUDKMgEvhoHbJwFR9HYxx33D2qJVU8Mwad39MCQtiGQdl9/7kvHPV9uwsDXV+DdFQeQkVuo91CJiIhIB9IlZPZmbTVHNrUnIiKyfQy+qPb8woE7fgG63ABYTMCvjwOLJgOmYjgKo4sBl7cLw2d39sKqx4fgvoEtEOjthsTMM/jPkr3oO30ZHv1+CzYfPaVOgImIiMg57EzKxv60XNUeYWSXCL2HQ0RERBfA4IsujpsnMO5DYNjzMlkQ2PQJ8PV4IP8UHI1MYZgysr2aBvnGdbGIjQ5EscmC+QlJmPD+Oox8ew2+3XAMeYUleg+ViIiI6tnseK3aa3iHMPh7uuk9HCIiIroABl908WQO4GX/BG76DnD3BQ7/qfX9St8LR+TpZsSE7k0w/6H+WDCpP67v0QQeri7YnZyNp+duR59Xl+H5BTtxIC1X76ESERFRPSg2mbEgIUld5jRHIiIi+8Dgiy5d26uAfywFApsCpw9rKz7uXwpH1qVJIF6/NhYbnh6KqVe3R0yQN3IKS/D52iMYNmMVbp61Hot3JKPEZNZ7qERERFRHVu1Nx8m8IgT7emBA62C9h0NEREQ1wOCL6kZYB+CeFUCz/kBhNvDt9cDamdIBFo4s0Nsddw9ogeX/Nxhf3NULw9qHwcUArD14Evd/HY/LXluBt5ftR1p2gd5DJSIioks0Z4s2zXFsXCRcjTyNJiIisgf8i011xycYuG0e0G0iYDEDvz8DzJ8ElDj+CoguLgYMahOCj2/vgT+fHIIHB7dEkI87UrILMGPpPvT793I89G081h86yWb4REREdigrvxh/7EpTl8dzmiMREZHdYPBFdcvVHRj1NnDla4DBBUj4GvhiNJCbDmfRpJE3nryyHdZOuRxv3RiH7s0aocRswS/bknHjR+sx4s0/8dW6I8hlM3wiIiK7sWh7EopMZrQL90OHSH+9h0NEREQ1xOCL6qfpfZ/7gVt+BjwCgOPrgVlDgJQdcCYerkaMiYvC7Af64ZdHLsNNvZrCy82Ifam5eHb+TvR+5Q88O28H9qbk6D1UIiIiuoDZm7Vpjtd2Z7UXERGRPWHwRfWn1VDgnmVA45ZA1nHgk+HA7kVwRh0jAzB9fGesf3oonhvVAS1CfJBXZMJX64+qCrDrP1yHRduSUFTCZvhERES25nBGHuKPZao+nqPjIvUeDhEREdUCgy+qX8GttfCrxWCgOA/44Rbgz/86fNP76gR4ueHO/s2xbPIgfHN3b1zZMRxGFwM2Hj6FSd9uQf/XlmPG73uRnHVG76ESERFRqbnxWrXXwDYhCPXz1Hs4REREVAsMvqj+eTUCbpkN9LpPu778JWD23UCx84Y7BoMB/VsF44PbumPNv4bgkctbqaXR03MK8fbyA2o1yPu/2oy1BzLYDJ+IiEhHZrMFc7Ykqstsak9ERGR/GHxRwzC6AiNfB675H+DiCuz4GfhsJJCdDGcXEeCFycPbYu1Tl+Odm7qiV/PGMJktWLwzBTd/vAHDZqzCZ38dRtaZYr2HSkRE5HQ2HjmFE6fPwM/DFcM7hOk9HCIiIqolBl/UsHrcBdw2T6sCS4rXmt4nxus9Kpvg7uqCUbGR+PG+vljy2EDc1qcZfNyNOJiehxcW7kKfV5dhypzt2JWUrfdQiYiInMac0mmOV3eJgKebUe/hEBERUS0ZLHYwjyo7OxsBAQHIysqCvz+Xj3YIpw4D390IpO8BXD2B7ncCfuGATzDgEwJ4y8cg7aO7j7ZSpBPKKSjGvC2J+HLdUexPyy3b36NZI9zWtxmu7BSuVo8kIqJz8fzBPtjy9+lMkQk9X/kDuYUl6o0pqcomIiIi+zp/cG2wURFV1Lg58I+lWq+v/UuADe9Xf6wEYyoMC9KCMRWKBVe+bA3K5Dh3X4cJyvw83XBb3xjc2qcZNhw+pVaBXLIjBZuOnlZbsK87bugZjZt7N0NUoJfewyUiInIov+9KUaFXdGMv9aYTERER2R8GX6QfT3/gpu+Ard8BabuBvAwgLx3Il48ntY8lBdqWdVzbasLoURqGBVVdQVa2r/R2D3+bD8qkGX6fFkFqS80uwPcbj+PbjUeRml2Id1ccxPsrD+LydmGqCmxAq2C4yHrrREREdElmx2tN7cd1bcK/rURERHaKwRfpy8UIdL216ttkFm5RXmkYdlILxlQollHFPrmcDpScAUyFQHaittWE0f3carKKFWRn7/MM1DUoC/P3xKPDWuPBIS2xbHeqmga59uBJ/LE7VW0xQd6qQuy67tEI8HbTbZxERET2TN5oWrM/XV2e0C1K7+EQERHRRWLwRbZLwiUPX22TqZE1oYKys8Iwa1hmDcoqVpUV5wGmIiAnWdtqwsWtQlBWg6oyCcpc6n4dCTejC67sFKG2A2k5+Hr9MczefAJHTubj5V9247+/78Xo2Ejc1icGnZsE1PnjExERObL5CYkwW7S+ms2CfPQeDhEREV0kBl/kWKQRvmyNmtXs+OIz1VSQWcOyih9PAkU5gLkYyE3RtpowGCsEZDWoKpMVL2sZlLUK9cPzozviiRFtMT8hCV+uO4I9KTn4cdMJtcVGB2Jin2ZckYqIiKgGZO2n2Zu1yvHx3ZroPRwiIiK6BAy+yLm5eQGB0dpWE8UF54Zh1VaVnQQKswGLCchL07aaMLhoQVlZGFaxquzsBv8hpUGZFmb5eLji5t5NcVOvaGw+elo1w/91ezK2Hs/E/x3PxMu/7ML1PaNxa+9miG7sfQn/cUR2rjAXOLYeOLpGq+LscgMQ3ErvURGRjdiVnI29qTlwd3VRbxoRERGRkwVf7777Lv7zn/8gJSUFsbGxeOedd9CrV68qj925cyemTZuGzZs34+jRo/jf//6Hxx577FLHTaQPN08goIm21URJ4blh2PmqygqyAItZO1Y2rbXIBRgA78aVVrs0eAejh08wejQPxkstA7D8uBk/7z6DPTme+HhVAT768xAGtwnBxL4xGNgmBEY27CVHJ9WdxzcCR1YDh/8EEjcD5pLy2/98HWjaD+g2EegwBnBnMEzkzKzVXld0CEOAF/tlEhEROVXw9cMPP2Dy5Mn44IMP0Lt3b7z55psYMWIE9u7di9DQ0HOOz8/PR4sWLXDdddfhn//8Z12Nm8g+uHoA/pHaVhMlRVo4Vm1fsrP2FWTKhIzSzzkJZOw95y79AYwt3eAJmGFAlsUHJw/7q221WyAQ3Bqng7qjMLw7/AJD0NjHHcG+7upjoLc7gzGyP/JcknDLGnRJ6CULX1QU2BSIGag9t/b/Dhxbq22/PQl0vhboehsQ2dXmV30lorpVbDJjwVYt+GJTeyIiIvtnsEgTg1qQsKtnz56YOXOmum42mxEdHY2HH34YTz311Hk/NyYmRlV71bbiKzs7GwEBAcjKyoK/v7yMJyLFVAzknzrPapdnVZWdOa0FZdUwWwzYa4nGRnNbbDK3xd/mtkgzBKGRtxaCBfm6I8jHo8Jl2e9RdjnI1wOBXm5c8p0anqkESN4KHPlTC7pkGmNxfuVj/CKAmAFA84FA8wFAo5jy27KTgIRvgC1fA6ePlO8P66xVgXW5TptWTHaF5w/2wda+T8v3pOKuzzepN4DWTRmqFpMhIiIi+z1/qFXFV1FRkZqyOGXKlLJ9Li4uGDZsGNatW4e6UlhYqLaKXxARVcHoBviFaVtNwwEJv0oryIqy0rD74EEYU7cjPHMLgotOoL3hGNq7HMPtWKo+5YQlGH8XtcWmgrbYmN4OGyyRsKD6FwGSeUlQJmGYCshKgzEtLPMoDcusFWUMyugimc1A6o7yiq6ja7WeehVJT7yyoGsgENSq+uotqcoc+ARw2f9p97nlK2DXAiB1O/DbE8DvU4EOo7UqMLnPeliplYhsw+x4rdprdGwUQy8iIiIHUKvgKyMjAyaTCWFhlV9ky/U9e/bU2aCmT5+OF154oc7uj4hKGV0B3xBtk0UwAcTGVbg9JxU4vl6rljm2DpbkbWiCDDQxZmCc8S91SKGrP074dcF+z07YZmiPLaYYpOYBJ/OKkHWmWC39LpdlqwnJvCQIs27WcExVllkryUorzBiUOTEpTk7fWxp0rQKOrCmtYKzAMwBodll5RVdI+9oHVHJ8i0HadtUpYPtPQPyXWsgml2WTSjEJwOJurvk0ZiKyC1n5xVi6K1VdntCd0xyJiIgcgU2u6igVZdJHrGLFl0ynJKJ6JpVj0thbNpkLLSvfnfi7LAiTyx7F2Wh5eg1aYg2ulIOMHkBUdyCuD0qa9Mbpxl2RYfLCqbwiZOQWqo/aZflYiJPqY1GloExuk60mpN9YI2+3smoyCceCS6dcll8un5YpTYkZlNlp0HXqUHlF1+HV566M6u4LNO1bHnSFdylb4bROyKIRve8Det0LJG3RqsC2/6xNhVz+ErDiFaD1cC0EazNCq8AkIrv2y/ZkFJWY0S7cDx0i9J92SURERA0cfAUHB8NoNCI1VXsnzEquh4eHo654eHiojYh05uELtByibdaeYinbS4MwaQS+Xps2WdoUXH6hhMCAkNAOQLO+WijRqQ8Q0LzaBsKnS0MwCcRO5mlBmXZZPlYMzgqRXVACk9lSISjLrXFQdr7eZBUrzRiU6SjzeOWgK/tE5dtdPYHo3uVTF6XxfEOETTI9Mqqbtg1/Bdg1X6sCk5/7fYu1zSdUqwCTECy4Vf2PiYjqxZx47ffO+G5RMHBhCyIiIucLvtzd3dG9e3csW7YMY8eOLWtuL9cnTZpUX2MkIlshIYM1AOj7YHlVjlSDyXZ0HXDqIJC2U9v+/lj7vIBooGmf0q1v2RQ06Z0S6u+ptpqwBmVa9ZgWlJVXkFWuJjtZZVCGGgZlFadXljfuP/uy9Cnz92RQdtFkaq116qIEXacPV77dxQ1o0rO8oksuy0qpenL3BuJu0raM/VoVWMK3WjXaX29qW7P+WgAmlZNyPBHZhaMn87Dp6Gk1BX9sHKc5EhEROe1UR5mCePvtt6NHjx7o1asX3nzzTeTl5eHOO+9Ut0+cOBFRUVGqT5e1If6uXbvKLicmJiIhIQG+vr5o1YrvihPZNXk3PKiltnW9VduXm1ZaEVY6PVJW2ss6DmyX7afyXkzRFYIwqdxxu3D4VdugTKarnM6vqpqscmXZqXOCskK1oXJx63mDMq1Zv7Vxf/XVZU4dlMkKpCroKq3qythb+XaDUftZsAZd8jNiy8FRcGvgiheBy58F9i3RqsAOLAWO/qVtvz0JdL4O6HYbEBFXfWN9IrKppvYDWofU+O8MERER2T6DxSIlG7Uzc+ZM/Oc//0FKSgri4uLw9ttvo3fv3uq2wYMHIyYmBp9//rm6fuTIETRvfu40p0GDBmHlypV2ucw1EdWC9AlL3FQehB3/GyjOq3yM0V3rE2YNwqJ7AV6NGnyo1qCsqt5k5Ze1kEwCs5yCklo/hgRl1sox6xRLadjv4+EKP09X+Hq4qsu+1s2z8mVvN6P9BGcFWdpqi9agS1ZIrMQAhHcun7oo33tPO/8dn50EJHwDxH8FZB4t3y9fZ9eJQJfrdPnZdlY8f7APtvB9MpstGPifFThx+gzeujEOY1jxRURE5DDnDxcVfDnjCRER1RFTiRaASBB21Non7Kym5UL6hKkgrJ/2MdD2FrgoLDHhdF5xFRVkhedUk11sUHY2KRrycbcGZEb4errBz3rZww2+ap/rOZfldr/Sj7JPLnu6udRtD5uiPC3ctAZdyQmAxVz5GJnmKtVcEnTJlEBpIO+IzGatuk2qwHYvBEyF2n5ZDKLDaKDbRG0FytquOkm1wvMH+2AL36eNh0/h+g/Xqd+tfz8zDF7udbhQBhEREel6/mCTqzoSkQMzumrT2WTr80CFPmGlFWHy8eR+IG2Xtm36VPs8/yblfcKa9SvrE6YnD1cjwgNk86x1UGbtR2Zt2p9bUIK8whLkVtwKKl+XaZjy32W9fqmkcOzsyrKzK8+0UO2syjPrdWMJAk8mwCdpLYxH18CQuBkwF1d+kMYty4OumAGAbyicgvxsthikbTLFU6b5SgiWukO7LFujGK0XWNwtgH+E3iMmcmrWpvYjO4cz9CIiInIwrPgiItuTmw4cP6tPmPmsoMcjAGjau0KfsG416hNmr+RXdWGJWVWNWQOynAuFZXJ70VnHye1FJSpAqy03lKCL4SD6uuxCP5ed6O6yHx6GykFXiiEE29y6YK9XVxzy7YZin4gqp21aw7SyYK00cJPL0svNIcl/etIWLQDb/jNQlKPtN7gArYdrVWDysSFWqnQSPH+wD3p/nwqKTej58h/IKSzB9/f2QZ8WQQ0+BiIiIqodTnUkIsciU+ikmkhWjZQg7MTfQFHuuX3CJPyq2CfMUafRXSL5tZ9fZKockEmQdnZYVlAI/9O7EJX5N5rnxKNVwXZ4Wgoq3VeqJRDrzB2w1txRfTxukYquS5s+6eHqcm4VWun1SlVoVfRAqxSsebiqnmo2+zO9a74WgsnPtJVPKBB3sxaCyaIRdEl4/mAf9P4+LdiahEe+24Imjbzw5xND7KePIhERkRPLZvBFRI7fJ2xHhemR64DcKpZglOmQ1iCsWV8gIJor612oL1XazvIeXdKDrTCr8jHeQUDMZaVTFwfC1LgV8opNWkVZgRaeWS9XqkIrKt+nQrbSgK3icQXFZ/UDqwPe7sYLhmUStBkNBvVi10U+GrRFCKT/mdGACvsNkGI0bb+h9BjtWOvt1s9Vl9VHqGPV55Ret96fth/wyjqIRnt/gP++n+F6JqNs7IVR/9/enUBHWd57HP/NBANhSVjCHpDFAhKRVSOLC0rhAEW91lvFyOZ+3cvVHtQWxI32aC29VqjneIq2RwrUKvQi4lUULVe4LAFEFBBFCIRVLWExJEzmnud5mYQhwQ0yz8z7fj/nvGcW3sk8eWcy8+f//p//c4EOn5Ov0k4/USi9rjfG439uNc97Wnu2+QDxQ2pw/TqNnbFcizfu1d2XnqXxgzsn/PkBAMD3R+ILQLCYj7GvtsT3Cdu3qep+ma0rE2Hm0jTQD6cF+7iZ42SSXGb7fIn09ZdVp5S261/Zo8ses5qZilgWKY+buhlLkMUly46YSrWyY/tEdLDkuOtHyuy/HygpU1kk6b/aqqilo7osvFo/S1usS8JrlBbyfofiaIbmRfprVmSg1kerrpJ8IpMAiyXVKpNjXsKtMglnerzFEmmqklSrTPRVn6w7eXLv2GMr9tdJk3UVicNQSLde3FFNG9Q+7ceU+CE1uHyd9hwo0QVPLFJ5VHrnvkvUPrteQp8fAAD8MDS3BxAs5n/PjTt4m5kmZhzaF58IMysMFu+QPvy7t8WSOmZKZCwZ1tr0CcuQ7xOEsYous+rgiZVyZ9TzquNiia6W3ROWHDS9vRrWTbfbqTILCdgkma1CKzsuYebdd/x1s295NGoL3iLm0l43l95tc34ocuy2d39UJq9Web/3WO/+E/Y7tiBBlZ9r7z9x/zO0tLyv3o/2VXZ0ny6PvqurQ++oTWiPRtV6y27ry8+0CbB5kX4qVv1qf3fzsyNKrcTfyLy2NZL4Ar7NP9YU2b/BXm0bkvQCAMCnSHwB8Kd62dLZP/E2o/Sw1ycsNjWycLk3jW/zm95mhM/wVps0iR/bJywv9fuE7d/uJbo+P5bs2l8Y/++16njJv2NTF23yzwfN1c2Km2ZrXO/Uk2ju5HvTTz9/Tyr4i6If/0O52qpHwy/okYxZinT+icq6X6/SnH6KKlSZnIuekHSLJeeOJdmqS9Z5STxVSfp9e7Ku+uRe1bHoGxOKDTNS/z2H1PTyKm81x5/2znE9FAAAUEOY6ggguH3CTD8rUw1melmZy4O7qu7XtMtx0yP7Sg3bJnefsIN7Kqu5zOWXn8X/e7iWlHOeV81lkl3muo9Xw/SVw19K6/4mrXrRe+/GNGov9bxe6pEvZbZ0OcKkQ/yQGly9Th8VFWvYf/1T6WlhrXhokLLqkoAFACBV0OMLAH7QNMDPT+gTtrHqfg1axfcJa57rtk+YSYaY3lyxRNfeDfH/Hgp7VWyxRJcZczrTeVL+vVpUYKvAtO5lqfRA5Wv9oyFSr1HSjwb7onLvVBE/pAZXr9Nj8z/S80u2aFi3FpqW3zthzwsAAE4diS8AOB1Mn7DC/6tMhBWtlsqPxu9TO/OEPmG9a7ZPWEmxV6FmE13vSrs+NJmQ+H1adPOmLZpEl5m2WSer5sYDt0oPSR/Nkwr+7L1PY+o3l7qPlHqNlpp0VFARP6QGF6/T0Ui5LpjytvYdPKLnR/fRoK7NE/K8AADg9CDxBQA1oaJP2LLKPmGxapsY2yesR+XUSNMnrF6TU3jOQ97zxSq6itZI0UjV6Zixiq52A1K/Lxl+mL2bpNV/kdb+VTq0t/L+M/t7CbCzL5fS6ypIiB9Sg4vX6Z0NezTuhRVqUi9dyx68zC6uAQAAUgeJLwBIhPKItPtYn7Bt70tbl1bfJyy7c/z0yEbtTt4nrKxE2r6iMtG1faVUXha/j1m9siLRdaHUgEoFHCdSJm1a6FWBbX5LipZXrmLa7WovCWaSswFA/JAaXLxOd84s0PwPdmpc/3aaNCI3Ic8JAABOHxJfAOCC+Tj919bKijCTCKu2T1jL+ESYSXaZlftMostUkR0tid8/M8dLctntQimL1cfwHe3fIa2ZKa3+s/SvbfHTYXuN8RJhGY3kV8QPqSHRr9P+r8t03uNvqfRouebfNUDntGY6OAAAqYbEFwAki0NfVNMn7IQKrhOZ/kyxii6T6DKr9iXzSpJIfuXlXnLVVIF9/N9SpNS7v1YdbwqkqQIz02R99j4jfkgNiX6dZi3fpgmvrFOn5vX1xr0XKeSz9z0AAEFQ/D3ih1oJGxUABJHp79VlmLcZZV8f6xN2LBFmKrzMqpAVia6LpOxOvktAwLFwWOpwibeZlUA/mOMlwfasl9bN8TaTYDUrQna/Tsps6XrEQI35e8F2e/nTXjkkvQAACAASXwCQSGbFR1NZYzbDFN3yHy8kkln84ILbpLxbpaICLwG27u/SV1ukRY9Ibz8u/WiwVwVmLtMIFeAfW784pBWff6VwSLqyZ2vXwwEAAAlANAsALpH0gsv3Xuve3jbkCWn9XG9VSFONuOl1bzPTbntcJ/UcJTXp6HrEwCl7dfUOe9n/rGw1z6zjejgAACABWLsZAICgS68n9cyXblgo3bFC6neXVDdbOrhbWvI76Zle0ozh0tpZUulh16MFfhDT1vaVgh0V0xwBAEAwkPgCAACVmnaSBj8mjf9Y+tlfvOmOobC0dYn06q3Sb7tIr/2nVLTG9UiB72XV1q+07cvDqpeepsG5zV0PBwAAJAhTHQEAQFW10qWul3vb/u3SmpneVMh/bZNWPO9tLc71eoF1u1rKaOR6xMB3amo/rFtL1U0nBAYAICio+AIAAN8sK0e6+BfS3WulUXOlc34qpaVLuz6QFtznVYG9cou05Z/egg1Akikpi2j+Bzvt9auY5ggAQKBwugsAAHw34bDUcaC3Hf5S+mCOtyrknvXSB7O9rXEHqef1Uo98qUEL1yMGrLc+3q0DJUfVumGG8to3dj0cAACQQFR8AQCA769uY+mC26T/+F/p5rel3mOl9AbSl59Jix6Rnu4q/XWktGGBFDnqerQIuFhT+3/r2VrhMKvpAgAQJFR8AQCAHy4Uklr39rYhT0jr53pVYIXLpI0LvK1+C6nHSKnnKKlJR9cjRsDsPXBE727aa69f1au16+EAAIAEo+ILAACcHun1pJ750o1vSHeskPrdJdXNlg7ukpb8TnqmlzRjuLR2tlT2tevRBtKzzz6rdu3aqU6dOsrLy9Py5cu/cf+pU6eqc+fOysjIUJs2bfTzn/9cJSUlFf8eiUT0q1/9Su3bt7f7dOzYUY8++qiiSdTrbd6aHYqUR9WzbUN1aFrf9XAAAECCUfEFAABOv6adpMGPSZdOlDYt9KrAPl0kbV3ibQvul879d29VyJbdXY82EGbPnq3x48frj3/8o016maTWkCFDtHHjRjVr1qzK/jNnztSECRP0pz/9Sf369dOmTZs0duxYhUIhPf3003af3/zmN5o+fbpefPFF5ebmauXKlRo3bpyysrJ09913K5mmOdLUHgCAYKLiCwAA1Jxa6VLXy6XrX5buXScNfEjKaisd2S+teF567iJp0/+4HmUgmGTVzTffbBNTXbt2tQmwunXr2sRWdd5//331799f1113na0SGzx4sEaOHBlXJWb2ueKKKzR8+HC7z9VXX233+7ZKskT5eGexPtpZrPS0sEac29L1cAAAgAMkvgAAQGJk5UgX/0K6Z600aq6Ue5XUoKXU4WLXI/O90tJSrVq1SoMGDaq4LxwO29tLly6t9jGmyss8JpbE+uyzz7RgwQINGzYsbp9FixbZajBj7dq1WrJkiYYOHXrSsRw5ckTFxcVxW035uiyi89o10mVnN1PDuuk19jwAACB5MdURAAAkVjgsdRzobUdLvaow1Kh9+/bZflzNmzePu9/c3rBhQ7WPMZVe5nEDBgywPbuOHj2q2267TQ8++GDFPmYqpElcdenSRWlpafY5Hn/8ceXn5590LFOmTNHkyZOVCL3aNtLfbuun0qPlCXk+AACQfKj4AgAA7pD0SlqLFy/WE088oWnTpqmgoECvvPKKXnvtNdu8PmbOnDl66aWXbD8ws4/p9fXUU0/Zy5N54IEHtH///oqtsLCwxn+X9FqEvAAABBUVXwAAAD6XnZ1tK7J2794dd7+53aJFi2ofY1ZrHDVqlG666SZ7u1u3bjp06JBuueUWPfTQQ3aq5P3332+rvq699tqKfbZu3WqrusaMGVPtz61du7bdAAAAEoHTXwAAAD6Xnp6u3r17235cMeXl5fZ23759q33M4cOHbXLreCZ5Zpipj9+0j/nZAAAAyYCKLwAAgAAYP368rcLq06ePzj//fE2dOtVWcJlVHo3Ro0erdevWtlrLGDFihF0JsmfPnsrLy9PmzZttFZi5P5YAM9dNT6+2bdsqNzdXq1evto+54YYbnP6uAAAAMSS+AAAAAuCaa67R3r17NXHiRO3atUs9evTQwoULKxreb9u2La5665e//KVCoZC93LFjh5o2bVqR6Ip55plnbDLs9ttv1549e9SqVSvdeuut9jkAAACSQSgaq1VPYma1oKysLNsANTMz0/VwAABACiB+SA28TgAAoCbjB3p8AQAAAAAAwJdIfAEAAAAAAMCXSHwBAAAAAADAl0h8AQAAAAAAwJdIfAEAAAAAAMCXSHwBAAAAAADAl0h8AQAAAAAAwJdIfAEAAAAAAMCXaikFRKNRe1lcXOx6KAAAIEXE4oZYHIHkRJwHAABqMs5LicTXgQMH7GWbNm1cDwUAAKQYE0dkZWW5HgZOgjgPAADUZJwXiqbAadDy8nIVFRWpQYMGCoVCNZIpNMFWYWGhMjMzT/vPxzfj+LvF8XeL4+8Wx9/fx9+EOCYYatWqlcJhujskK+I8f+P4u8Xxd4vj7xbH363iJIrzUqLiy/wSOTk5Nf485sXgD8Idjr9bHH+3OP5ucfz9e/yp9Ep+xHnBwPF3i+PvFsffLY6/W8kQ53H6EwAAAAAAAL5E4gsAAAAAAAC+ROJLUu3atTVp0iR7icTj+LvF8XeL4+8Wx98tjj8SgfeZWxx/tzj+bnH83eL4u1U7iY5/SjS3BwAAAAAAAL4vKr4AAAAAAADgSyS+AAAAAAAA4EskvgAAAAAAAOBLJL4AAAAAAADgS4FPfD377LNq166d6tSpo7y8PC1fvtz1kALjvffe04gRI9SqVSuFQiHNnTvX9ZACZcqUKTrvvPPUoEEDNWvWTFdeeaU2btzoeliBMX36dJ177rnKzMy0W9++ffX666+7HlYg/frXv7afQffee6/roQTGww8/bI/58VuXLl1cDws+RJznDnGeW8R5bhHnJQ/ivMRLxjgv0Imv2bNna/z48XaJzYKCAnXv3l1DhgzRnj17XA8tEA4dOmSPuQlKkXjvvvuu7rjjDi1btkxvvvmmysrKNHjwYPu6oObl5OTYL+JVq1Zp5cqVuvTSS3XFFVdo/fr1rocWKCtWrNBzzz1ng1MkVm5urnbu3FmxLVmyxPWQ4DPEeW4R57lFnOcWcV5yIM5zJzfJ4rxQNBqNKqDMmT9zJuQPf/iDvV1eXq42bdrorrvu0oQJE1wPL1BMFvjVV1+1Z6Pgxt69e+0ZQRMoXXTRRa6HE0iNGzfWk08+qRtvvNH1UALh4MGD6tWrl6ZNm6bHHntMPXr00NSpU10PKzBnAk31x5o1a1wPBT5GnJc8iPPcI85zjzgvsYjz3Hk4CeO8wFZ8lZaW2gz8oEGDKu4Lh8P29tKlS52ODXBh//79FV/KSKxIJKJZs2bZs7CmFB6JYc6EDx8+PO57AInzySef2ClQHTp0UH5+vrZt2+Z6SPAR4jwgHnGeO8R5bhDnufVJksV5tRRQ+/btsx9CzZs3j7vf3N6wYYOzcQEumLPgZt57//79dc4557geTmCsW7fOBkAlJSWqX7++PRvetWtX18MKBBOAmqlPpgQebipxXnjhBXXu3NmWv0+ePFkXXnihPvzwQ9uPBjhVxHlAJeI8N4jz3CHOcysvCeO8wCa+AMSfETEfRK7nXgeN+TIwJcDmLOzLL7+sMWPG2CkIBEU1q7CwUPfcc4/teWIaXiPxhg4dWnHd9N0wAdKZZ56pOXPmMAUEAE4z4jw3iPPcIM5zb2gSxnmBTXxlZ2crLS1Nu3fvjrvf3G7RooWzcQGJduedd2r+/Pl29SXTiBOJk56errPOOste7927tz0r9fvf/9424UTNMdOfTHNr0/chxlSGmL8B0wvoyJEj9vsBidOwYUN16tRJmzdvdj0U+ARxHuAhznOHOM8N4rzk0zAJ4rxwkD+IzAfQokWL4sqAzW3mXiMIzLoWJhgyZddvv/222rdv73pIgWc+g8yXMWrWZZddZqcfmLOwsa1Pnz62/4C5TjDkpgHtp59+qpYtW7oeCnyCOA9BR5yXfIjzEoM4L/kcTII4L7AVX4ZZ4tqUnJo/hPPPP9+u8mCaDo4bN8710ALzB3B81nfLli32w8g03Wzbtq3TsQWl7H3mzJmaN2+enWu9a9cue39WVpYyMjJcD8/3HnjgAVsGbN7rBw4csK/F4sWL9cYbb7gemu+Z9/uJPU7q1aunJk2a0PskQe677z6NGDHClr0XFRVp0qRJNhAdOXKk66HBR4jz3CLOc4s4zy3iPHeI89y7LwnjvEAnvq655hq7tO/EiRPtl4FZ4nThwoVVGqGiZqxcuVIDBw6MC1ANE6SaZnioWdOnT7eXl1xySdz9M2bM0NixYx2NKjhMCfbo0aNtw0cThJr57yYY+vGPf+x6aECN2759uw1+vvjiCzVt2lQDBgzQsmXL7HXgdCHOc4s4zy3iPLeI8xBk25MwzgtFTR0sAAAAAAAA4DOB7fEFAAAAAAAAfyPxBQAAAAAAAF8i8QUAAAAAAABfIvEFAAAAAAAAXyLxBQAAAAAAAF8i8QUAAAAAAABfIvEFAAAAAAAAXyLxBQAAAAAAAF8i8QUAAAAAAABfIvEFAAAAAAAAXyLxBQAAAAAAAF8i8QUAAAAAAAD50f8DPdpQHPUYsXYAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from IPython.display import clear_output\n", "import matplotlib.pyplot as plt\n", "\n", "metrics_history = {\n", " 'train_loss': [],\n", " 'train_accuracy': [],\n", " 'test_loss': [],\n", " 'test_accuracy': [],\n", "}\n", "\n", "rngs = nnx.Rngs(0)\n", "train_model = nnx.view(model, deterministic=False, use_running_average=False)\n", "eval_model = nnx.view(model, deterministic=True, use_running_average=True)\n", "\n", "for step, batch in enumerate(train_ds.as_numpy_iterator()):\n", " # Run the optimization for one step and make a stateful update to the following:\n", " # - The train state's model parameters\n", " # - The optimizer state\n", " # - The training loss and accuracy batch metrics\n", " train_step(train_model, optimizer, metrics, rngs, batch)\n", "\n", " if step > 0 and (step % eval_every == 0 or step == train_steps - 1): # One training epoch has passed.\n", " # Log the training metrics.\n", " for metric, value in metrics.compute().items(): # Compute the metrics.\n", " metrics_history[f'train_{metric}'].append(value) # Record the metrics.\n", " metrics.reset() # Reset the metrics for the test set.\n", "\n", " # Compute the metrics on the test set after each training epoch.\n", " for test_batch in test_ds.as_numpy_iterator():\n", " eval_step(eval_model, metrics, test_batch)\n", "\n", " # Log the test metrics.\n", " for metric, value in metrics.compute().items():\n", " metrics_history[f'test_{metric}'].append(value)\n", " metrics.reset() # Reset the metrics for the next training epoch.\n", "\n", " clear_output(wait=True)\n", " # Plot loss and accuracy in subplots\n", " fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))\n", " ax1.set_title('Loss')\n", " ax2.set_title('Accuracy')\n", " for dataset in ('train', 'test'):\n", " ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')\n", " ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')\n", " ax1.legend()\n", " ax2.legend()\n", " plt.show()" ] }, { "cell_type": "markdown", "id": "25", "metadata": {}, "source": [ "## 7. Perform inference on the test set\n", "\n", "Create a `jit`-compiled model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. Since we already have `eval_model` (an `nnx.view` with `deterministic=True` and `use_running_average=True`), we can use it directly for inference. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance." ] }, { "cell_type": "code", "execution_count": null, "id": "26", "metadata": {}, "outputs": [], "source": [ "@nnx.jit\n", "def pred_step(model: CNN, batch):\n", " logits = model(batch['image'], None)\n", " return logits.argmax(axis=1)" ] }, { "cell_type": "markdown", "id": "1d6cb81f", "metadata": {}, "source": [ "We reuse the `eval_model` view created earlier so that `Dropout` is disabled and `BatchNorm` uses stored running stats during inference." ] }, { "cell_type": "code", "execution_count": null, "id": "27", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA7QAAAPGCAYAAADTLdZkAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjUsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvWftoOwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAjY5JREFUeJzt/QeYFFX6P27XINnAoqJgxIgBFUXBhCCirjlnFDNiXgOurgqYMfvVNbuIGDDnrIA5KwaUNbvGFSUYQIL0e1X/XvijddrtdmboOd33fV2sy4dD9ZnhnJ5+qqqfrsnlcrkEAAAAItOo3BMAAACAP0NBCwAAQJQUtAAAAERJQQsAAECUFLQAAABESUELAABAlBS0AAAARElBCwAAQJQUtAAAAESpqgra66+/PqmpqUk+/fTTkv5ejx49ko4dO9bpXNq3b5/su+++dXpM+CPWP9XOHqCaWf9UO3ugclVVQVuJXnnlleTwww9PVl111WTeeedNllpqqWTXXXdN3n///XJPDeaKqVOnJieccEKy2GKLJS1atEi6du2aPP744+WeFpTFmWeemX/BVtcvvqCh+uCDD5Ldd989WWKJJZKWLVsmK620UnLaaaclkydPLvfUoN6lRXFNTU3BX19++WVSDRqXewLUzuDBg5Pnnnsu2WWXXZLVV189+eabb5LLLrssWWuttZIXX3zRixqq4sn8jjvuSI4++uhkhRVWyJ+B3XLLLZORI0cmG264YbmnB3PNF198kZx11ln5k5tQDT7//POkS5cuSatWrfIn9xdccMHkhRdeSAYMGJC89tpryb333lvuKUK96tu3b9KrV6/fZLlcLjnkkEPyV4EXX3zxpBooaCN3zDHHJDfffHPStGnT2dluu+2WrLbaask555yT3HjjjWWdH9Snl19+ORk+fHhy3nnnJccdd1w+22efffIncvr37588//zz5Z4izDXpHlh33XWTX3/9Nfnuu+/KPR2od8OGDUsmTpyYPPvss/k71VIHH3xwMnPmzOSGG25IJkyYkLRu3brc04R6s9566+V/zSndD+kdCnvttVdSLar6luP0zN1WW22Vv1WxWbNmyXLLLZecfvrp+RcDIenZvvXXXz9/W+MyyyyTXHnllcHbH9Mzg8svv3z+mEsuuWT+hXWa14d0PnMWs6n0KlX6xP7ee+/Vy2NSGSph/adXZueZZ578C5hZmjdvnhxwwAH5s/Tp2Xuo5D0wy9NPP53fDxdffHG9Pg6VoxLW/w8//JD/76KLLvqbvF27dkmjRo0yr4+g0vZAyM0335y/3XjPPfdMqkVVX6FNb02cb7758lc50/+OGDEiOfXUU/NPkOkVnzmlZ/nS2xjT96fuscceyW233Zb069cv/2S5//7758ekZwS33Xbb/JmR9AX2yiuvnLz99tvJRRddlH9P6z333FNwLunfHT9+fFHzTm+tadKkScE/T281+O9//zv7bCVU6vp/4403khVXXDFZYIEFfjMmvQUtNXr06PwPE6jUPZBKX3wdccQRyYEHHpi/OweqZf2nzXrSt16lJzEHDRqULLTQQvk7c6644orkyCOPdPs9Fb8Hfm/69On5uaWFd3rLcdXIVZEhQ4bk0i/5k08+yf9+8uTJmTF9+/bNtWzZMvfLL7/Mzrp3757/exdccMHsbOrUqblOnTrlFllkkdy0adPy2bBhw3KNGjXKPfPMM7855pVXXpn/+88999zsbOmll8716dNn9u/TOaVjivk1cuTIP/w603mk46677ro/9X2iMlXi+l911VVzPXv2zHwdY8aMyY9NHxsqeQ+kLrvsslyrVq1y33777ez5pnsDqmH9n3766bkWLVr8Zsw//vGPOvmeUVkqdQ/M6f7778+Pufzyy3PVpKqv0Ka3DMzy448/5m8H6NatW3LVVVclY8eOTdZYY43Zf964ceP8G69nSc/IpL9Pz86ktyCk71u6/fbb82dj0g57c75/qWfPnvn/pk1q0jMmIW3bti26M+uc8/q9dN6HHXZY/n76Pn36FHU8qlMlrP8pU6bkb+n5vfS241l/DpW8B77//vv8FYVTTjkladOmTYnfAapZJaz/VHoVaqONNkp22mmn/BXaBx98MN8cLT1m2igKKn0P/P524/TqbXoluZpUdUE7ZsyY5OSTT87fYjDrfRizTJo06Te/T++v//2tK+mtjqn086zShZy2jk/ft1roRcW3335bcC7pC/DfdykrVdrhOH0vQHorwqz3FkIlr//0h1HofSm//PLL7D+HSt4D6fzTzq7pLcdQbes/bQqY3tqZ3s6ZfmxPascdd8zfvpl+nFt6a2ha5EKl7oE5/fTTT/n3BW+++eZVt+6rtqBNu+J17949/9679PPK0jeCp4vp9ddfzz8Jpk+GpUr/Tvr+pQsvvDD453/0Xr70PVDjxo0r6nHSFy+/b3SQbrwtttgi/3U988wz+Y0Hlb7+08Yfoc9Y+/rrr/P/tQ+o5D2Qvni6+uqr842gvvrqq9+c0EnfR5W+yEq/vnQ8VNr6T11++eXJmmuuObuYnSV9H2P6/si0z0JtiwQqU6XsgTndc889VdfdOKn2gnbUqFH5W7Xuuuuu/K0qs3zyySfB8emLhZ9//vk3Z2fSM4KpWW+6TjfDm2++mWyyySb57mKlSLuxph3TipHespA2Qpjzxcs222yTn88TTzyRrLLKKiU9NtWnUtZ/p06d8r9Pz6zO2RjqpZdemv3nUKl7ID2Zk76ASpvfpL9+Lz3eUUcdpfMxFbn+U2kDzNDH8qQndFIzZswoaR5Uj0rZA3O66aab8s2t0hM61aZqC9pZt+OmHYFnmTZtWv5sX0j6pJjeU592Qps1Nv19eltB586d81l6v/pDDz2UXHPNNb/5GJFZ7+VLX3gU6rj3Z++dT8/opJ87m35ESXqbwe8/iwoqef3vvPPOyfnnn5+/SjXrc2jTW5CHDBmSdO3aVYdjKnoPpJ+3fPfdd2f+PL2FLn0/2CWXXJJ/gQWVuP5n3fL52GOP5QuLWbd/pm655Zb8x/asvvrqRR2T6lMpe2CWcePG5S9qpbfZt2zZMqk2VVvQpm/KTs/qpY2T0jPb6ZmU9AO651zYc0pvXUxbw6e3cKVPmrfeemv+I0HSF9KzWmfvvffe+VbZhxxySP7syQYbbJAvONM3lqf5o48+mqy99tp1eu/8sccem9x33335K7Rpu+8bb7zxN3/eu3fvko9J5auU9Z8Wrbvsskty4okn5t+bkn7u29ChQ/PzvO6660o+HtWjEvbAwgsvnGy//faZfNYV2dCfQaWs/9Txxx+fPPzww/lGPmkDqPR9gw888EA+Sz/GyttOqPQ9MMutt96aL7qr8XbjvFwVt+tO22evu+66+Xbviy22WK5///65Rx99NNMSe9ZHILz66qu59dZbL9e8efN8u+30oxJ+L23dPXjw4Pz4Zs2a5Vq3bp3r3LlzbtCgQblJkyYVbNf9Z81qJV7oF1Ty+k9NmTIld9xxx+Xatm2bf8x11lkn98gjj9TJsakslboHfs/H9lBN6/+ll17KbbHFFvmfAU2aNMmtuOKKuTPPPDM3ffr0Ojk+laNS90Aq/TrSjxCaMWNGrhrVpP9T7qIaAAAAStWo5L8BAAAADYCCFgAAgCgpaAEAAIiSghYAAIAoKWgBAACIkoIWAACAKCloAQAAiFLjYgfW1NTU70zgD5T745Ktf6p5/afsAap5D1j/VPP6T9kDNOQ94AotAAAAUVLQAgAAECUFLQAAAFFS0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUVLQAgAAECUFLQAAAFFS0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUVLQAgAAECUFLQAAAFFS0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUVLQAgAAECUFLQAAAFFqXO4JUHvt2rUL5gsuuGAmmzFjRnDsv//97zqfFw3TWmutFcwPOOCAYN6vX79gfu+992ayxx57rJazS5J33303mD/11FO1PjYAAJXFFVoAAACipKAFAAAgSgpaAAAAoqSgBQAAIEoKWgAAAKJUk8vlckUNrKmp/9nwh5ZffvlgPnLkyKK7H0+fPj049oorrgjmxxxzTNIQFLlM602s679Tp06Z7KGHHgqOXXTRRZOGYMKECcH86aefDuYXXnhhMP/iiy8y2aeffprEqNzrP+Y9QGUo9x6w/qnm9Z+yB2jIe8AVWgAAAKKkoAUAACBKCloAAACipKAFAAAgSppC1cJGG22UyW6//fbg2ELf5iFDhhR13FTHjh2D+XzzzVfSY4YUahb13HPPZbJevXol1dYQoaGv/1Dzp9Rdd92VyZZeeumkISv0vS51Dbz77ruZ7Oabbw6OPf/880vaF9W2/hvSHij0PDhixIhgftVVV2WyU045JakkvXv3Dua77LJLJtt///2DY7///vukISv3Hmgo67+aFWpcuOeee5b0czHk0ksvDeavvvpq0hCUe/2n7AHKSVMoAAAAKpKCFgAAgCgpaAEAAIiSghYAAIAoKWgBAACIki7HRfjLX/4SzF977bVM1r59+3rrUPfVV18F82OOOaboYwwYMCCYr7zyysH8sccey2RbbrllUm0d/hr6+n/zzTdL6ghbDV2OS1Gow+XRRx+dNATlXv8NaQ9ceOGFwfxvf/tbMH/rrbcy2XbbbRcc++mnnyYxGjNmTDBfZZVVMtkdd9xRdEfkhqTce6ChrP9KM88882Sy/v37l/Rap9C/zYILLlj0PJ588slgvummmyYNQbnXf0PaAyuttFIwv/jii4P54osvXnT36kLHKPQai7lHl2MAAAAqkoIWAACAKCloAQAAiJKCFgAAgCgpaAEAAIhS43JPoCHp0qVLMD/jjDOC+dJLL13rxxwyZEgm+/jjj4sem/rmm2+KfrzTTz+9hNklyUcffVTSeKrH2LFji+4eO3Xq1GC+xx57ZLJu3bqV1G18/fXXT2rr0EMPLbqr47HHHhscO2PGjFrPg//9b77EEkvU+hjNmjVLYlSoi37Lli2LPsYmm2xShzOC4qyxxhrBfODAgUX/HBk6dGgwHzRoUDD//PPPM9kNN9wQHNuzZ8+kttq2bVvr12j8b4suumgw33zzzYs+RqFPgOjdu3cwf//994P5s88+m9TWQw89lMmmTJkSHLvjjjsG81tuuaXW8yjU5f+zzz5LYuAKLQAAAFFS0AIAABAlBS0AAABRUtACAAAQpZpcLpcramCgOUqlCTUnSJ1yyilFH+O5554ruvlN6ssvv0zmpv/+97/BfOGFFy66IdaAAQOSua3IZVpvGsr632qrrYL5TTfdFMznn3/+Wj/muHHjgvkGG2wwV5uILbjggsF84403DuZXX3110Y2lSrHccsuV1FChEtZ/ufbARhttlMmeeuqpko4Reg4r5Tm9ITnzzDOD+UknnVT0MSZMmFDS/mooyr0HGsrPgIZu3XXXDebXX3990c+nhxxySEmNMWfOnFn0/BZffPFg/vDDDwfz/fbbr+jXQG+++WYwr4vnm3Kv/4a0Bwo19Su0Pgq9/ua3fvzxxyTk5ZdfzmS9evVK5rb/tQdcoQUAACBKCloAAACipKAFAAAgSgpaAAAAoqSgBQAAIEqNyz2BhmTMmDHB/Pbbbw/m77zzTlEdNcvlwAMPzGQLLLBASd3Dbr311jqfF3/eUkstVW/djAu55ZZbgnl9djQOGT9+fDC/8847g/kKK6xQdJfYUtx///3BfJtttpnr3Y8rXahTdTVYY401gvmhhx5a62N/9tlntT4GFHLssccG8w4dOgTz7bbbLpPdd999SX35+eefg/liiy0WzF955ZVMduqppwbHXnjhhbWcHcWYOnVqMN9///2D+WmnnZbJNt988+DYH374IZjvs88+wXzJJZdM5qZ27dqV1Ol7vvnmK/rYhV5HvvHGG0kMXKEFAAAgSgpaAAAAoqSgBQAAIEoKWgAAAKKkoAUAACBKNblC7W1/P7Cmpv5nQ50aMWJEJttoo42CY5988slgvtVWW2WyGTNmJHNbkcu03jSU9T9lypRg3rRp03p7zPfffz+Yr7zyyklD1qxZs0y27bbbBscOHz681o8X6oaZWnfddaNf/+XaAxMnTsxkrVq1KukYoc7zp5xyStKQdenSJZi/9NJLtT52t27dgvmzzz6bNGTl3gMN5WdAQ9K+ffuiu99fc801wbxfv3719m8d+lSASy+9NDh26623LrrT+t/+9rfg2F9++SWp1PWfsgfKb8UVVyzp9dhdd92VyRo1Cl/L/PXXX4P5AQcckMmGDh2azG3/aw+4QgsAAECUFLQAAABESUELAABAlBS0AAAARElBCwAAQJQal3sC1F7Xrl2D+SqrrFL0MQp1ICxHR2NK69xb3x0Ql1566WDeu3fvTHbjjTcmDcXUqVOL7ub9/PPPB/P111+/6Mdr3rx5CbNjToMGDQrm8803X9HHKNRd9corr/zT8wIKa9u2bdGdcJ966qmif3Y1bty46I7IqZ49ewbzv/71r5nsww8/DI7deeedg/ndd98dzKEcPvjgg2B+zjnnBPNQR+NCrxePP/74YF6OjsZ/hiu0AAAARElBCwAAQJQUtAAAAERJQQsAAECUNIWKSMeOHYP5gw8+GMz/8pe/ZLKnn346OPaxxx6r5eyotkZUiy++eBKb8ePHB/OJEyfO9bnwvxuPzTPPPEUfo2XLlsF8iSWWyGRffvllCbMDQjp16lT02O+++y6YH3LIIZnssMMOC45dddVVg/mECROC+eDBgzPZpZdeGhz7/fffB3NoSHr06BHMd9hhh6KPceGFFwbziy66KImZK7QAAABESUELAABAlBS0AAAARElBCwAAQJQUtAAAAESp4rsch7ribb/99sGx2267bTBfe+21i368Ro3C5whmzpwZzF955ZWistQee+wRzBdaaKGiO7cOHDgwOPaHH34I5jQszz77bDDfcMMN5/pcampqkkpx+OGHB/NPPvmk6K979dVXD+b9+vUL5ldccUVJc6xk559/ftHPya1btw6ObdeuXTC/5ZZbMtmHH36YNGStWrWqt2Ofdtppwfyvf/1rMJ82bVq9zYW4FXrtEfLAAw8E88aNsy9D33jjjeDY/fbbL5gPHz48mE+dOrXo+UFDcuCBBwbza665ptaf7HDmmWcmlcgVWgAAAKKkoAUAACBKCloAAACipKAFAAAgSgpaAAAAohRdl+Odd945mB966KHBvHv37pksl8uV9JiljC/UzbjQMUIdlEvpqvxHjxn6njz99NMlHZuGJdSxNbXBBhvU+tiFumt//fXXwfy6665LKsWyyy5b9L6tz+ePavXOO+8E8/XXXz+T3XPPPcGxHTp0CObLLLNMUVm12HjjjYP5lVdeGcz333//ep4RDd1mm20WzE844YSij1GoW/Z2222XyR555JESZgdxW2KJJTLZUUcdVSfH7tu3byabMGFCUolcoQUAACBKCloAAACipKAFAAAgSgpaAAAAotSgm0LtsMMOmeyGG24Ijm3atGkwHzduXNFNWoYMGRLMf/nll2A+fPjwot9sfdpppwXzgw46KKkvX331Vb0dm8qzyy67BPPPP/88qXTHHHNMrY9R6Pv0xBNP1PrY1Wrs2LGZbPfddw+O7dWrVzA/77zz6nxeMfvpp59KagpF9TjggAOC+dVXXx3MP/zww0z27bffBsd27tw5mDdp0qSkOUKlufPOOzNZx44dSzrGlQWevws1UaxErtACAAAQJQUtAAAAUVLQAgAAECUFLQAAAFFS0AIAABClBtHleOeddw7moY7GhboZF+pQXJ9dhENOPfXUojs217e99tork73wwgvBsdOmTZsLM4LyWn755YP5csstV+tjT5w4sehOoPx5o0ePDuZvvfVWML/ssssy2QUXXBAc+/777wfzq666Kph369Ytkx133HFJbfXo0SOYF/r5V8gll1ySyU444YTg2KlTp5Z0bOKw6KKLZrJzzz03OHbLLbcsqfvxzTffnMmWWmqpkl6jhfbnK6+8Ehz7zTffBHOIwYYbbhjM11hjjaKP8fzzzwfzfv36JdXOFVoAAACipKAFAAAgSgpaAAAAoqSgBQAAIEoKWgAAAKJUk8vlckUNrKmpt0mMGDEimG+00UZFd8o7/PDD661z4+KLLx7M//GPf2Syvn37BscW+jaHuvmdddZZwbH77bdfMN9uu+2Kfsy//e1vwbGXXnpp0pAVuUzrTX2u/1LMN998wfzll18O5h06dCj62DfeeGMw79OnT1IpHY0feOCB4NgVVlih1o8X6viZ2nvvvaNf/w1pD1SDr7/+Opi3bds2mH/33XdF/2wo1CWzoSv3Hmjo679x4/CHVnz//fdFfy09e/YM5q+++motZ5cku+66azAfPnx40Z8Kce+99ybVqtzrP4Y90FCsvfbawfy5554runv9LbfcEhx76KGHlvQpC9W0B1yhBQAAIEoKWgAAAKKkoAUAACBKCloAAACipKAFAAAgSuG2ePVkww03DObdu3cP5v/+978z2UEHHVTrebRv3z6Y9+jRI5ifdNJJwXy55ZbLZNOmTQuOPf/884vu2leoo+D9999fdBfD1F/+8pdMtuOOOwbHDh06NJj/8MMPwZzy+Omnn4L59OnTa33szTbbLJjfcMMNwfyII47IZJMmTUrqS/PmzYP50ksvHczvvvvueulm/MUXXwTzSy65pNbHhj+j0L6LtaMxhTVp0iSYP/3000V/0kOh5/rRo0cn9WWhhRYqemyhrt3Q0DRq1Kjo10yhbsapl156KZNVczfjP8sVWgAAAKKkoAUAACBKCloAAACipKAFAAAgSnO1KdQ//vGPYJ7L5YL58OHDiz728ssvH8w32WSTTHbWWWcFx7Zq1SopxaOPPprJTj311ODYQo2e6sKWW24ZzO+5555M1q1bt+DYf/7zn8F87733ruXsmBtCzcVSHTt2LPoYiyyySDDfa6+9gvkSSyyRyV588cXg2Pvuuy+Yb7vttpmspqam6MdL7bnnnsnctNpqqwVzDdSAurTwwgtnstNPPz04tmvXrsF8/fXXn6vNn5o1a1bSa4lQI83333+/zucF9WHIkCGZbOWVVy7pNcJxxx2XyTR/Kp0rtAAAAERJQQsAAECUFLQAAABESUELAABAlBS0AAAARGmudjnebLPNSupy3L1790z23HPPldTNdb755stkv/zyS3Dsf/7zn5K6qIY6F8+YMSOZ21566aVg/sILL2SybbbZpuhOiKktttgikz388MMlz5H6ddpppwXzH3/8MZOdc845dfKYof0ZylJHHXVUMG/evHkma9QofJ5t5syZSX256667gvkBBxxQ1PcU/qxQZ/xQd1uqz3fffZfJWrZsGRw7fvz4op9jGzcu7aVfp06dgvmSSy6ZyS688MKixxb62TVu3LiS5gf17bDDDgvm++yzT9HH+L//+79g/uyzz/7pefH/cYUWAACAKCloAQAAiJKCFgAAgCgpaAEAAIiSghYAAIAozdUux0OGDAnm++67bzAPdUx99913g2Ovv/76YP7MM89ksi+++CI49sUXX0wqyY477pjJhg4dGhy71157Fd3dUJfjhqdQd+2LLrqoqM7fqRNOOCGYN2nSpJazC3faLKRQ1/NShTplPv7448GxRx55ZDD/4Ycf6mQuUEi7du1q3YX2nnvuqcMZ0ZCFntP/6JMeRowYUW9zCXWef+qpp4Jjt95662A+ZsyYOp8X/FktWrQI5oW6d4c89thjwfy888770/Pif3OFFgAAgCgpaAEAAIiSghYAAIAoKWgBAACIUk2uyA4sNTU1tX6wZs2aBfPllluu6GMUauikeUtx2rRpU1L+0UcfZbKpU6cmc1tdNQr6s+pi/Td0vXv3DuZLLrlkMD/jjDPqZR6NGoXPs73//vslNUl54403MtlLL72UxKjc679a9kA5XHHFFZnskEMOKekYhRoCVVLDnXLvgYa+/tu2bRvMN9lkk1of+7PPPgvmY8eOzWTfffddrR+Phrf+Y9gDdeHMM88M5ieddFIw//DDDzPZ6quvHhw7ZcqUWs6uuuX+xx5whRYAAIAoKWgBAACIkoIWAACAKCloAQAAiJKCFgAAgCjN1S7HEGuHP+ufal7/KXugfuhyHMcesP6p5vVfaXtgoYUWCuaffvppMJ9vvvmC+eabb57JHnvssVrOjhBdjgEAAKhICloAAACipKAFAAAgSgpaAAAAoqSgBQAAIEqNyz0BAKhWoS7Ha621VnDsmWeeGcz/85//1Pm8ACrVNttsU1I340KeeeaZOpoRteUKLQAAAFFS0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUdLlGADK5K233spkXbt2LctcAKpBqd2MCznuuOMy2emnn14nx6Y0rtACAAAQJQUtAAAAUVLQAgAAECUFLQAAAFGqyeVyuaIG1tTU/2yggCKXab2x/qnm9Z+yB6jmPWD9U83rP2UP0JD3gCu0AAAARElBCwAAQJQUtAAAAERJQQsAAECUFLQAAABUdpdjAAAAaEhcoQUAACBKCloAAACipKAFAAAgSgpaAAAAoqSgBQAAIEoKWgAAAKKkoAUAACBKCloAAACipKAFAAAgSlVV0F5//fVJTU1N8umnn5b093r06JF07NixTufSvn37ZN99963TY8Ifsf6pdvYA1cz6p9rZA5WrqgraSvb6668n2267bbLgggsmLVu2zG+8//u//yv3tKDeTZ06NTnhhBOSxRZbLGnRokXStWvX5PHHHy/3tGCu8jOAajRmzJhkl112SZZddtn8ul944YWTjTbaKLn//vvLPTWYK3766adkwIAByV//+tf8839asKeFe7VpXO4JUHuPPfZYss022yRrrrlmcsoppyTzzTdf8tFHHyVffPFFuacG9S49w3nHHXckRx99dLLCCivkn8i33HLLZOTIkcmGG25Y7ulBvfMzgGr12WefJT/++GPSp0+f/EnNyZMnJ3feeWf+5M5VV12VHHzwweWeItSr7777LjnttNOSpZZaKlljjTWSUaNGJdVIQRu5H374Idlnn32SrbbaKv+ivlEjF92pHi+//HIyfPjw5LzzzkuOO+64fJbuh/TqVP/+/ZPnn3++3FOEeuVnANUsPXmZ/prT4YcfnnTu3Dm58MILFbRUvHbt2iVff/110rZt2+TVV19N1llnnaQaVfVPvnvvvTf/IiA9q9esWbNkueWWS04//fTk119/DY5/7bXXkvXXXz9/W+MyyyyTXHnllcHbH9NL/8svv3z+mEsuuWT+hXWa14ebb745+e9//5uceeaZ+RcyP//8czJz5sx6eSwqSyWs//QF/DzzzPObFy3NmzdPDjjggOSFF15IPv/883p5XCpDJewBPwOo5vUfkv5MSB934sSJc+0xiVMl7IH0Mdq2bZtUu6q+QpvempjemnXMMcfk/ztixIjk1FNPzZ/xTq/4zGnChAn5s4C77rprssceeyS33XZb0q9fv6Rp06bJ/vvvnx+TvohIb3N59tln8y+wV1555eTtt99OLrroouT9999P7rnnnoJzSf/u+PHji5p3q1atkiZNmuT//xNPPJEssMACyZdffplsv/32+ceZd955k7333jv/uOmLe6jU9f/GG28kK664Yn4PzKlLly75/44ePTr/wwQqdQ/4GUA1r/9Z0hM5U6ZMSSZNmpTcd999ycMPP5zstttuf+r7QvWopD1Q9XJVZMiQIbn0S/7kk0/yv588eXJmTN++fXMtW7bM/fLLL7Oz7t275//eBRdcMDubOnVqrlOnTrlFFlkkN23atHw2bNiwXKNGjXLPPPPMb4555ZVX5v/+c889Nztbeumlc3369Jn9+3RO6Zhifo0cOXL231t99dXz801/HXHEEbk777wz/9903O67715n3zviV4nrf9VVV8317Nkz83WMGTMmPzZ9bKjkPeBnANW8/uec96w/T+ew884758aPH1/r7xmVpZL3QOqVV17J/3n6dVabqr5Cm94yMEvaVCC9HaBbt275RgJjx47Nv7l6lsaNGyd9+/ad/fv0jEz6+/TsTHoLwrrrrpvcfvvt+bMxK620Uv5N2rP07Nkz/9+0SU16q0JIertAsZ1Z55xX2t0sbYJwyCGHzO5oueOOOybTpk3Lfx3pG8XTRjlQies/PSOf3m7ze7OuSqV/DpW8B/wMoJrX/yxpU8Cdd945+eqrr/JXztJbRtM9ANWyB6pd42pv937yySfnbzFIby+YU3rbypzS++vT27jmlN7qmEo/zypdyB988EHy3nvvJW3atAk+3rfffltwLukL8F69ev3pzZje/jCnPffcM78h0/cRejFDJa//0PtSfvnll9l/DpW+B1J+BlCN63+WtIBIf6XSJmmbbbZZvvP3Sy+9lP8YE6j0PVDtqragTZsFdO/ePf/eo/QMdvpG8HQxpZ/ll36m5Z9pqpH+ndVWWy3fWS/kj97Ll55NHDduXFGPk37OVHpmaNYGSzfkoosu+psxiyyyyOx7/qFS13/a3S997+DvpR3/Zu0PqOQ94GcA1bz+C0mv1qZXz9L3LXbo0KGo41JdKn0PVJuqLWjTz2n6/vvvk7vuuiv/IdyzfPLJJ8Hx6W0sadOBOc/OpE+Uqfbt2+f/m26GN998M9lkk01KPiOYdmNNO6YVI71loUePHvn/n7amT29RSF/Uz/mknc43VegsEdWtUtZ/p06d8r9Pz6zO2RgqPSs/68+hkveAnwFU8/ovZNbbTX5/lQ2qZQ9Um6otaNO27qlcLn3/9P+Tvt/i8ssvD46fMWNG/vattBParLHp79MXC+kLilTa+eyhhx5Krrnmmsxnn6VPrumZm9/frlDbe+fTxzznnHOS6667bvY9+qlrr702f7+/BU8lr//0LPz555+fXH311bM/hza9BXnIkCFJ165ddTim4veAnwFU8/pPb+GcdTfCLNOnT09uuOGG/O34q6yySlHHpPpUyh6gygva9E3ZrVu3Tvr06ZMceeSR+TMpw4YN+83CnlN6W9fgwYPz98mn98zfeuut+Y8ESV9Iz2qdnX5MQtqMIG3OkZ492WCDDfK3EKRvLE/zRx99NFl77bXr9N75NddcM98u/F//+ld+s6W3T6RnndI3pp944oluuaSi139atO6yyy75tZ6+sEk/923o0KH5eaYv8KHS94CfAVTz+k9vK07v0EmvsC2++OLJN998k9x00035x7zgggvyH8UClbwHUpdddln+FupZd+bcf//9yRdffJH//0cccUT+Y34qXq6K23Wn7bPXXXfdXIsWLXKLLbZYrn///rlHH3000xI7bdedfjzIq6++mltvvfVyzZs3z7fbvuyyyzKPkbbuHjx4cH58s2bNcq1bt8517tw5N2jQoNykSZMKtuuujfQxBw4cmD9mkyZNcssvv3zuoosuqpNjUzkqdf1PmTIld9xxx+Xatm2bf8x11lkn98gjj9TJsakslboH/AygWtf/LbfckuvVq1du0UUXzTVu3Dj/eOnv77333lofm8pTiXtg1rGSAh/xM+trrXQ16f+Uu6gGAACAUjUq+W8AAABAA6CgBQAAIEoKWgAAAKKkoAUAACBKCloAAACipKAFAAAgSgpaAAAAotS42IE1NTX1OxP4A+X+uGTrn2pe/yl7gGreA9Y/1bz+U/YADXkPuEILAABAlBS0AAAARElBCwAAQJQUtAAAAERJQQsAAECUFLQAAABESUELAABAlBS0AAAARElBCwAAQJQUtAAAAERJQQsAAECUFLQAAABESUELAABAlBS0AAAARElBCwAAQJQal3sCFG/TTTcN5ocddlgw33bbbTPZueeeGxz797//vZazAwAAmLtcoQUAACBKCloAAACipKAFAAAgSgpaAAAAoqSgBQAAIEo1uVwuV9TAmpr6n00VateuXSbbfPPNg2MvvPDCYN6qVauiH2/69OkldUq+7rrrkoagyGVab6x/qnn9p+yBuWe++eYL5tdcc00w33333YP5iy++WPTPlx9++CFpyMq9B6z/4jRt2jSYN2vWrOhj9OrVK5gPGDAgmK+22mpFH7vQMc4444ykISv3+k/ZA+X/9x40aFAwHzhwYFLp/tcecIUWAACAKCloAQAAiJKCFgAAgCgpaAEAAIiSghYAAIAoNS73BKqlO2Xv3r2D+f7775/JOnfunNSXeeaZJ5jPP//89faYNCyNG4e3/YEHHhjMV1hhhaKP/dNPPwXza6+9Nph/++23mWzq1KlFPx7EbqWVVspkDz30UHBs+/btS+r+2LVr10y29957B8f+85///B8zpSEo9DO8Q4cOwbxv377J3LT66qsH827duhXdObfUjr6ljA/tCSinUjoUd+/evV7nEjNXaAEAAIiSghYAAIAoKWgBAACIkoIWAACAKGkKVccKNfPYYIMNat0QoVCznIsuuiiYH3bYYZlswoQJwbEXX3xxMKfynHzyySXlpQit59Q//vGPYD5y5MhM9sQTTwTHFspfe+21kuYI5dCuXbtg/uijj2ayJZdcMjj26quvDuannXZaMP/www+LbgpHHBZZZJFg/tZbb831uTR0U6ZMyWR33XVXWeYCdaFHjx7lnkKD5QotAAAAUVLQAgAAECUFLQAAAFFS0AIAABAlBS0AAABR0u6wCCuttFIwv/fee4vuTlmK8ePHB/ODDjoomN9zzz1Fd9W85ZZbajk7YrLHHntkslNOOSU4tlB37fq08cYbF5WlBg4cGMxff/31YH7rrbdmsqeeeio49s033/wfM4XitGjRoqRu9KGfGY888khw7LHHHhvMf/7552D+wAMPZLJ33nknOBYqTehn3ZAhQ8oyFyike/fu5Z5CRXCFFgAAgCgpaAEAAIiSghYAAIAoKWgBAACIkoIWAACAKNXkimxtWlNTk1S6xo3DTZ8vueSSYH7IIYfU+jE///zzTPa3v/0tOPbuu+9OqlU5OvBWwvofM2ZM0V276+J7XOj71FCO/dNPPwXzQt2/+/XrlzQE5V7/Me+Bua3QmvnnP/8ZzD/55JNMtsYaa5S0fgtp3759Jvvyyy+DY6dPn540ZOXeAw1l/Tdr1iyYX3bZZcF8v/32q/VjvvHGG8E89LOkUJfvUr6vhf6tp0yZEswLde6/6aabMtm4ceOSGJV7/TekPRCrHj16BPORI0cWfYxBgwaV9CkQleR/7QFXaAEAAIiSghYAAIAoKWgBAACIkoIWAACAKIW7IFW4Qk1xjjjiiHpr/lRK0w4o1aWXXlr0Wm/UKHwea+bMmbWeR6FjfPXVV8F8+PDhmeyhhx4Kjn3qqaeC+WKLLRbMd9ttt6IbrvXt2zeYb7311plshx12CI4dPXp0MJ8xY0YwJ35rr712Jrv44ouDY8ePHx/Md91111o3fyrk008/rZPj0HBMnTo1mB955JHBfOjQobV+zELPba+99lomW2655eb61zhkyJBaPyaUqykUdcMVWgAAAKKkoAUAACBKCloAAACipKAFAAAgSgpaAAAAolSVXY4LdeGri27Gjz/+eEldaKEU888/fzDfaKONgnkulyu6E/GPP/5YUpfMtdZaK5M99thjwbGnn356Ul8KdVC+6KKLMtnXX38dHHvTTTcF83bt2mWyF198MTj2sMMOC+ZXXXVVMCd+oa6rTZo0CY594YUXiu4UC6WaMmVKMH/22WdrfexC3YWXXHLJWh871AX+0EMPrbeOzRCzgQMHlnsKDZYrtAAAAERJQQsAAECUFLQAAABESUELAABAlBS0AAAARKkmF2qDGhpYU1P/s5lLRowYEcy7d+9e0nEmTpyYyTbZZJPg2NGjR5d0bH6ryGVabxrK+u/Tp08wv+6662r9tRx99NFV26G7UJfj3XbbrehjPPjgg8F8u+22S2Jf/w1pD5RDly5dgvnzzz+fyT766KPg2LXXXruk7uI0rD1QDev/iCOOCOaDBw8O5k2bNq31Y+67776Z7MYbb6z1cStNudd/teyBhvJvOGrUqGC+8cYbJ9Uq9z++f67QAgAAECUFLQAAAFFS0AIAABAlBS0AAABRUtACAAAQpcZJFVp22WXr5Dj77LNPJtPNmLqy2GKLZbLLLrus1sf96quvgvm1116bVKtvvvmm1sdo165dncyF8inUtfX6668P5o0aZc8JDxs2rKRuxs2bNy96Hj/88EMwh1Icdthhwfzcc88N5k2aNKm3uehoTKUZOHBgrY9Rzd2M/yxXaAEAAIiSghYAAIAoKWgBAACIkoIWAACAKFV8U6gTTzwxky211FJ1cuxnnnmm1sfo2LFjJuvWrVtJx9h8882D+bbbblv0Me69995gvttuu2WyadOmlTA7/qyePXtmspYtW9b6uIWa00yZMiWpVvPPP38wr6mpKfoYTz/9dB3OiHLYcccdg/lKK61U9DFWXHHFYP7JJ58E88aNsz+G55lnnuDYX375JZgPHz48mA8YMCCTTZ8+PTiWyrTDDjtkssMPP3yuN38q5TVaqe6+++5gPnbs2FofG0oVet6l/rlCCwAAQJQUtAAAAERJQQsAAECUFLQAAABESUELAABAlCqmy3GoU2Shjsa5XK6kY1988cXB/Oeff85ka6yxRkldVG+99dZM1rZt26QulPJ1FuqI3Lx580ymy/Hcseaaa9Z67YZcc801SbXaeuutg/kBBxwQzEv5ftfFvw3ltfbaa9f6GL179w7mhZ43Q/uxUDfjPn36BPO///3vwfyRRx7JZLpxV6bll18+mN9xxx1JQ3bWWWdlspkzZ5Z0jDPOOCOY33bbbZnslFNOCY798MMPS3pMoGFxhRYAAIAoKWgBAACIkoIWAACAKCloAQAAiJKCFgAAgChVTJfjeeedN5gffPDBtT72Dz/8EMx79uyZyW688cbg2IUXXjiY19TU1Lpb6tSpU4N5kyZNMlmjRs5hVLvhw4cn1WrLLbest2PrkhmPli1bBvOtttqq1sf+7LPPgvlJJ50UzG+55Zaij33nnXcG8+effz6YX3XVVZmsc+fOwbGTJ08ueh7Eo6F3Xw91NK6rOe+yyy6ZrEuXLsGxO+64YzAfM2ZMJpsxY0YdzI5KMHDgwFofY9CgQXUyl2qnugEAACBKCloAAACipKAFAAAgSgpaAAAAolSTK/Ld96HmRQ1Jq1atgvn48eOThqyUplD33XdfML/yyiuLbgiy5JJLljS/1q1bF90kq5IbW5Rj/Y8cOTKTdevWraRjvP7660U3xag0p556aib7xz/+ERzbuHHjotfd+++/Hxy73nrrBfNJkyYlsa//GH4GlGK33XardYOm1JdffpnJNt5443prGhZq9PdHjQFD2rZtG8y//fbbpCEr9x5o6Ot/gQUWCOb9+vXLZPvtt19JzdIKHbtZs2aZ7Oeffw6O/e6774r+vhZqolnodV59WnfddTPZq6++WnXrP4Y9UJ969OhR9Ou0UlXz97Uu94ArtAAAAERJQQsAAECUFLQAAABESUELAABAlBS0AAAARCnc2pO55sEHH8xk//znP4Nj559//mC+1VZbBfPFFlus6HmMHTs2mM+YMaPoY1C3unfvXutOh08//XRS6Tp27BjM+/btW3Q340JdBqdNm5bJevfuXW/djJk72rVrVyfHefjhh+ulmzGUqtCnDwwePLio7I86YLdv3z6Y/+Uvf8lk33zzTXDs6NGjk2J16tQpmK+zzjrB/Oijjw7mHTp0SGrrpJNOKrpL+vTp02v9eMTV5bgUgwYNqpO5EOYKLQAAAFFS0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUdLluI4V6nJ34YUXBvOzzz47k2222WbBscOHD6/l7JLk3//+dzDfdtttg/nkyZNr/Zj8OaGOxqV2OS51fIzdjEOdwlOLLrpo0d+PUDfjQt0zX3/99f8xU6rFHXfcMVcfr1AX2kLGjBmTyX788cc6nBGVpFCH4kJ5fSnUEblQXuhnwKhRozLZsssuW9JcQq+NFlxwweDY//73vyUdm7g/daJUofVI3XGFFgAAgCgpaAEAAIiSghYAAIAoKWgBAACIkoIWAACAKOlyXMcmTpxY0vjbb789k2266aZJfTn22GOD+UcffVRvj8mf88EHH2Sy5ZdfPql0p556ajDv27dv0d2MS3XEEUcE82uvvbbWx6bh+f777+vkOCNGjEjqQ+PG4R/NQ4cOLek4w4YNy2RTpkz50/Oi/Jo1axbMd9xxx2B+yCGHZLL//Oc/wbGXXHJJMH/11VeThmD11VcP5scff3wwL7WjccgXX3xRdFd84tejR4+S8lLocly/XKEFAAAgSgpaAAAAoqSgBQAAIEoKWgAAAKJUMU2hampqkoagTZs2wfyEE04I5o0aZc8pzJw5s6THHDNmTDC/+eabM9njjz9e0rEpnwcffDCTHXXUUUmMtt5662B+8sknZ7I111yzpEY5uVyu6HkceuihwVzzp+ry2GOP1clxFlhggUw2fvz4ko7RpEmTohv8FGpM8uWXX5bU5Id4HXfcccF80KBBRR9jgw02KOl5+uOPPw7mb731ViZ76KGHklKceOKJRT+nL7nkksF8wQUXTOrLnnvumckmTJhQb49HedVF8yfKwxVaAAAAoqSgBQAAIEoKWgAAAKKkoAUAACBKCloAAACiVDFdjn/66adgvtFGGxXd+bFQd9X6VEqH1vfffz+Yb7PNNsH8s88++9PzomGu6VK7eYe6sJaqZcuWwXyhhRbKZKecckpw7AEHHFDreRT62qdNmxbMjzjiiEymmzF/1In4qaeeCubdu3cvuuPsSSedVHQ340IdjW+55ZaSfs5ttdVWwXzq1KnBnHgtssgi9Xbs+eefP5ivscYaRed77713rZ/XS3ldVKovvvgimF922WXB/JVXXqm3udDwFHqurwsDBw4sKac0rtACAAAQJQUtAAAAUVLQAgAAECUFLQAAAFFS0AIAABClmlyR7eRK7a7akC299NLB/P777w/mq666ar3N5Zlnnslkw4cPD4594okngvmHH36YVLr67HrYUNd/qJvlO++8Exy74IILFn3cO++8s6R5LLHEEsG8a9euRX+f6uLfr9D6Hzx4cDAfOXJkUinKvf4r7WdAIaGu+KmHH344mE+ZMqXoPTrvvPMG886dOxfdzXjbbbcN5qNGjUoqXbn3QENZ/4U+peGwww5LYlSfXY7vu+++THbqqacGxxbatw1Fudd/Q9oDsX6fN95446p9/p4b/zau0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUarKplDEp9wNERrK+l9xxRWDeb9+/YL5gQcemMlatmxZb9/jUptCjRgxoujmT+eee25Srcq9/hvSHiiHxRZbLJjfcMMNmaxnz57BsRMnTgzmt99+eya79NJLo2xcU8l7oKGs/2bNmgXzxo0bF32MXXfdNZgvu+yyJc3lkEMOyWStW7cu6RhPP/10JnvuuedK2kNXXnllMJ86dWommzFjRhKjcq//hrQH6lN9vg6idjSFAgAAoCIpaAEAAIiSghYAAIAoKWgBAACIkoIWAACAKOlyTBTK3eEv1vXfrl27oruwdurUqdaP9/PPPwfza6+9Nph/++23mWzatGm1nkelKff6j3kPUBnKvQesf6p5/afsAcpJl2MAAAAqkoIWAACAKCloAQAAiJKCFgAAgCgpaAEAAIiSLsdEodwd/qx/qnn9p+wBqnkPWP9U8/pP2QOUky7HAAAAVCQFLQAAAFFS0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUVLQAgAAECUFLQAAAFFS0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUVLQAgAAECUFLQAAAFFS0AIAABAlBS0AAABRqsnlcrlyTwIAAABK5QotAAAAUVLQAgAAECUFLQAAAFFS0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUVLQAgAAECUFLQAAAFGqqoL2+uuvT2pqapJPP/20pL/Xo0ePpGPHjnU6l/bt2yf77rtvnR4T/oj1T7WzB6hm1j/Vzh6oXFVV0FaqDz74INl9992TJZZYImnZsmWy0korJaeddloyefLkck8N6t3UqVOTE044IVlsscWSFi1aJF27dk0ef/zxck8L5or0BVH6Aq3Qry+//LLcU4R69dprryV//etfkwUWWCCZf/75k8022ywZPXp0uacFc80H6oCkcbknQO18/vnnSZcuXZJWrVolhx9+eLLgggsmL7zwQjJgwID8k/y9995b7ilCvb+gv+OOO5Kjjz46WWGFFfJnYLfccstk5MiRyYYbblju6UG96tu3b9KrV6/fZLlcLjnkkEPyVwAWX3zxss0N6tvrr7+ef55fcskl8697Zs6cmVx++eVJ9+7dk5dffjnp0KFDuacI9Uod8P8oaCM3bNiwZOLEicmzzz6brLrqqvns4IMPzj+p33DDDcmECROS1q1bl3uaUC/SFyzDhw9PzjvvvOS4447LZ/vss0/+1qD+/fsnzz//fLmnCPVqvfXWy/+aU/rzID0zv9dee5VtXjA3nHLKKfk7c9IX8AsttFA+6927d7LiiismJ510UnLnnXeWe4pQr9QB/09V33KcnrXYaqut8rcqNmvWLFluueWS008/Pfn111+D49MzHeuvv37+yXOZZZZJrrzyyuDtj+lZkeWXXz5/zPSsYfrCOs3rww8//JD/76KLLvqbvF27dkmjRo2Spk2b1svjEr9KWP/pldl55pkn/+Q9S/PmzZMDDjgg/wInPXMJlbwHQm6++eb87cZ77rnnXHtM4lMJ6/+ZZ57J36Ewq5id9fonvUL7wAMPJD/99FO9PC6VoRL2gDrg/6nqK7TprYnzzTdfcswxx+T/O2LEiOTUU0/NL470is+c0jMc6W2Mu+66a7LHHnskt912W9KvX7/8Qtl///3zY9KzIdtuu23+LEn6AnvllVdO3n777eSiiy5K3n///eSee+4pOJf0744fP76oeae3FTRp0mT2G9UHDx6cfwE/aNCg/JN6elXqiiuuSI488shk3nnnrdX3iMpVCev/jTfeyJ+JT987Naf09ptU+j6q9IcJVOoe+L3p06fn55a+6EpvOYZKXv9pkZAWF7+Xvo9w2rRpyTvvvJOsu+66JX5nqBaVsAfUAf9/uSoyZMiQXPolf/LJJ/nfT548OTOmb9++uZYtW+Z++eWX2Vn37t3zf++CCy6YnU2dOjXXqVOn3CKLLJKbNm1aPhs2bFiuUaNGuWeeeeY3x7zyyivzf/+5556bnS299NK5Pn36zP59Oqd0TDG/Ro4c+Zvjn3766bkWLVr8Zsw//vGPOvmeUTkqcf2vuuqquZ49e2a+jjFjxuTHpo8NlbwHfu/+++/Pj7n88sv/9PeJylSJ63+11VbLrbjiirkZM2b8Zm5LLbVUfuwdd9xRB985KkUl7oHU6eqAXFVfoZ3zrN6PP/6YP9PXrVu35KqrrkrGjh2brLHGGrP/vHHjxvnmG7OkZ2TS36dnZ9JbENIzgLfffnv+bEzaXey7776bPbZnz575/6ZNatKz5iFt27YtujPrnPNKpWfhN9poo2SnnXbKn5l58MEHk7POOit/zPQN4lCp63/KlCn5W3p+L73teNafQyXvgdDtxumZ+/QqAlT6+j/00EPzc0ivTqW3daZXuc4444zk66+/zv+5nwFU+h5ItVcHVPctx2PGjElOPvnk/C0Gs+5Bn2XSpEm/+X16f/3vL9untzqm0s+zShdy2jb7vffeS9q0aRN8vG+//bbgXNIX4L/vVFmMtCFOeltDeitD2q47teOOO+af1NOPMklvi5jzvSVQSes//WEUel/KL7/8MvvPoZL3wJzS9wum7wnbfPPNPe9TFes/7ead9kpIbw8dOnRoPlt77bXzxe2ZZ56Zv40UKnkPqAOqvKBNO4KlTQPS996ln9WUvhE8XUxpC/h0AaQLoVTp31lttdWSCy+8MPjnf/RevvQN6OPGjSvqcdKW3LPe5J22p19zzTVnL+JZ0nv40/cGpO8xrO2LJCpPpaz/tOlB6HM2Z52dT38AQSXvgTml78/S3ZhqW/9p4Zp2uU+Lk/S9hekc0g7HcxYcUKl7QB1Q5QXtqFGjku+//z6566678pfpZ/nkk0+C47/66qvk559//s3ZmfRsSGpW4410M7z55pvJJptsku8wWYr0DGPaMa0Y6S0L6ZvAU//973+D7bjTxiCpGTNmlDQPqkOlrP9OnTrlf5+eWZ2zMdRLL700+8+hkvfAnG666ab8Fan0hQxU0/pPXwfN+bnjTzzxRP4FfnrrJ1TyHlAHVHlBm37Ux6wPoJ8l7YiXnukISRdEek992glt1tj09+ltBZ07d85n6XuWHnrooeSaa675zceIzHofR3rmplC3sT9773x69vGxxx7Lb6o5z0Tecsst+Xbdq6++elHHpLpUyvrfeeedk/PPPz+5+uqrZ38ObXoL8pAhQ5KuXbvqcEzF74FZ0jP76Yv49PaytMMrVNP6n9Ott96avPLKK/mfDenrIKjkPaAOqPKCNn1TdnpGo0+fPvm21umZlPTDiedc2HNKb11M22Kn98mnCyZ9wkw/EiR9IT2rdfbee++db+OdvqcjPXuywQYb5G8hSN9YnuaPPvpo/r0ddXnv/PHHH588/PDD+Texp2/8Tu+TTz97Lc0OPPBAt1xS0es/LVp32WWX5MQTT8y/NyX93Lf0fVTpPK+77rqSj0f1qJQ9MEs6n/QFl9uNqab1//TTT+dvF91ss83yr39efPHF/AnNv/71r8lRRx1V8vGoHpWyB9QB/3+5Km7XnbbPXnfddfOtrhdbbLFc//79c48++mimJXbarjv9eJBXX301t9566+WaN2+eb7d92WWXZR4jbd09ePDg/PhmzZrlWrdunevcuXNu0KBBuUmTJhVs110bL730Um6LLbbItW3bNtekSZN8C/szzzwzN3369Do5PpWhUtf/lClTcscdd1x+/aePuc466+QeeeSROjk2laVS90Aq/TrSj4+Y8+NLoNLX/4cffpjbbLPNcgsvvHD+8VZaaaXc2Wefnf9IFaiGPZB6SR2Qq0n/Z1ZxCwAAALHw5gIAAACipKAFAAAgSgpaAAAAoqSgBQAAIEoKWgAAAKKkoAUAACBKCloAAACi1LjYgTU1NfU7E/gD5f64ZOufal7/KXuAat4D1j/VvP5T9gANeQ+4QgsAAECUFLQAAABESUELAABAlBS0AAAARElBCwAAQJQUtAAAAERJQQsAAECUFLQAAABESUELAABAlBS0AAAARElBCwAAQJQUtAAAAERJQQsAAECUFLQAAABESUELAABAlBS0AAAARKlxuSdQ7f72t79lsgsuuCA4dr/99gvmQ4cOrfN5AQAANHSu0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUVLQAgAAECVdjueShx9+OJhvsskmmWzUqFHBsXfccUedzwvmhkIduk8++eRMtswyywTH1tTUBPNcLlf0fjnrrLOCY0ePHh3MAQBo2FyhBQAAIEoKWgAAAKKkoAUAACBKCloAAACipKAFAAAgSjW5Qi1Ci+wwWs0WWmihTHb//fcHx3bp0iWYT5gwIZNtuOGGwbH//ve/k2pV5DKtN9Z/1uDBgzPZUUcdFRzbuHHjBvF9De23VK9evRp09+Nyr/+UPUA17wHrn2pe/yl7YO6Zf/75g3m/fv1KOs5pp52WyZo1axYce8IJJwTzc889N4lhD7hCCwAAQJQUtAAAAERJQQsAAECUFLQAAABESVOoIhRq0nTppZdmsjXWWCM4dujQocH8yCOPzGQ//vhjyXOsdOVuiFDN6/+1114L5quvvnoma9QoznNkw4YNC+b77rtv0hCUe/1X+x6g/Mq9B6z/4nTu3DmYb7/99sG8TZs2mWyHHXYoemzqvffeC+Z33XVXJjv77LODYydPnpw0ZOVe/yl7oHa22267YN6/f/9M1qFDh+DY1q1bJ/Vl+vTpwfziiy/OZNdee21w7IcffpjUF02hAAAAqEgKWgAAAKKkoAUAACBKCloAAACipKAFAAAgSlXZ5bhQJ9ZzzjknmB9++OHBvHHjxpns2GOPDY697LLLGmznuhiU+/tUSeu/VIU6SK644opJpfj555+L7nD+1ltvJdW2/su1B0KPucIKKwTH7rjjjsF8scUWK/rxdtppp2Derl27oudX6r/ho48+msk++OCD4NgzzjgjmH/77bdJpSv3HqjmnwEbbbRRMD/xxBMz2WabbVbSv1/o+1rK2FLH77PPPsGxN910U9KQlXv9V/seKOSss84qupvxMsssE8ybNWuW1JdHHnmk6G7hhTqUh7z77rvBfLXVVkvqiy7HAAAAVCQFLQAAAFFS0AIAABAlBS0AAABRUtACAAAQpWyb3grTtm3bTDZo0KDg2IMOOiiYf/7558F8wIABmez6668veY7Vqnnz5pnsl19+KctcKOyJJ56o+C7H8847bzDv0aNHg+hyXK2aNGlSdNft+lSou2JddB4NdYUt1Cl2vvnmC+b/+Mc/gvnXX39dy9lRiQo9391www3BfIcddih6/ZfaCbeU8XVx7EJf42OPPRbMx40bV9JjEreVVlopmF955ZVFfxJCqet08uTJmeztt98Ojr3vvvuC+bPPPhvMX3jhhUx25JFH1rrL8ZQpU5KGxhVaAAAAoqSgBQAAIEoKWgAAAKKkoAUAACBKFdMUatFFFw3mjzzySCZbffXVg2O//PLLYL755psH87Fjx5Y0x2q18847F93IZM0115wLMyJk2WWXDeY77rhjUulCTRlSL7744lyfC//7ubq+FGry9PPPPwfzTz/9NJN16NCh1o2vCunTp08w/89//hPMBw4cWNJcqA5///vfg/l2221Xb03R7rrrrmB+zz33FN2EqpTmVIUUGlvo2FdffXXRxyYe/fv3D+YHH3xwMF9mmWWKPvZPP/0UzE866aRg/u6772aykSNHJnWhVatWmezoo48u6RjTp0/PZIMHD04aGldoAQAAiJKCFgAAgCgpaAEAAIiSghYAAIAoKWgBAACIUsV0OT7rrLOK7pL59ttvB8eus846wXzatGlJtQp14OzcuXNw7GWXXRbMV1lllWDet2/fWs6Ouuw6fckllwTztm3bJg3ZO++8E8w7duxY9DFatmwZzNddd91M9vLLL5cwO2rjjTfeyGS33nprSf/eoQ6NV111VXDsxx9/HMyfeOKJpLZatGgRzO++++5Mtummm5Z07EKd+M8555xM9ssvv5R0bOIW6t578sknB8fOnDkzmNfU1BS9dgv9fClFoS6sheZRSKnjqUzt27evl27GqYceeiiTXXjhhcGxddW5uBS33XZbJltiiSVKOkaoo/Gdd96ZNDSu0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUVLQAgAAEKWaXC6Xq4RucTvttFMwv/TSSzPZIossEhz71FNPBfMzzjijwXQsC2nVqlUwb926dSbba6+9gmN32223ojtzFnq8a6+9NpgX6kr65ptvJsUqcpnWm4a+/kvp7vfggw8Gx6600krJ3HbzzTcX1VHvjxTqzPnwww/Xurvfhx9+mMk6dOiQzG3lXv8x74GG7i9/+UvRnbSXW265ko49YMCAon+eNXTl3gOxrv9XXnklk6211lolfY/vueeeYL7PPvtkssmTJyf1Mec/M+/Qv1mhsYW6+X/33XdJQ1Du9R/zHgh9WsFzzz1X0jGefPLJYL7FFltksl9//TWZ2zbeeOOiuzA3bdo0OPbTTz8tuot+6LVRufeAK7QAAABESUELAABAlBS0AAAARElBCwAAQJQUtAAAAESpcVIh7rzzzmD+0UcfZbLLL7+8pC5ha6+9djAfOnRoJjv33HODY7/44otgPu+882aynXfeueiOgqllllmm6A63n3/+eXDsiBEjgvl7772Xyf71r3816G6A/H/uvffeBtHNePz48cH8vPPOy2TvvPNOnTzm008/ncn23HPPko6x7LLLZrK99947OHbYsGElHRtSEydOLLrjfqldjrfaaqtMdvbZZwfHlqMzJ3F0qy00fuWVVy76GNtvv33Rn1BRqJN8Xcx7nXXWCY71+oU/Uug5eW4/b15wwQXB/Igjjgjm88wzT9EdikMdm1Mff/xxEgNXaAEAAIiSghYAAIAoKWgBAACIkoIWAACAKNXkcrlcUQNLfDN+Q9a4cbgX1gknnBDMDz744GC+5JJLFv2YN910UzDfdtttM9n8889fdPOQ1HXXXVd0o6wXX3wxiVGRy7TeNPT1v+uuuxbdqKjQ+q8Lr776ajAfOHBgMH/44Ydr/ZhLL710MH/55Zcz2cILL1zSsSdPnlx0U60vv/wyqdT1H8MeqCT77bdfML/22mtrfezmzZsH8+nTpycNWbn3QKzr/5VXXslka621Vknf40Jfe2h8KWMLja+LeaTuueeeoptrhp7rG5Jyr/+Y90DTpk2LbiS75ZZbBvMff/wxmG+22WZFvfb4I7179y66eV+rVq2KbjBbyKmnnhrMzzzzzCTmPeAKLQAAAFFS0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUaq/dqcN2HzzzRfMb7/99mC+3HLLBfN999236Mfca6+9ih5baB4XXXRRMI+1czGl23333YP5oEGDgnl9djS++eabM9mhhx5aUofAutChQ4dgXmpH45Bff/11rnYzhrp07733FrWmqVxjx47NZJ07d6637raldsKti2N/9913wXznnXcuaS5UpmnTpmWyt99+u6Qux4U+fST06ShXXnllcOzRRx8dzDfccMOi65RCPv7446JrjzfffDOpRK7QAgAAECUFLQAAAFFS0AIAABAlBS0AAABRUtACAAAQpYrvcty1a9dMdskllwTHdunSJZjPnDkzmE+YMCGT3XrrrcGxbdq0CeY77LBDJuvRo0dw7GmnnRbMqR5LLbVUMF9++eXr7TEPPPDAYH7HHXfM1W7GoU6Aqeuvv77eHvOGG26ot2NDIc2bN6+T43zzzTdF/zyjMu29995Ff2LCyiuvHMzfe++9oh+v0DGGDh1a9DFyuVxSirPOOquk8XDGGWcE844dOwbzrbbaKphvv/32RWWlmj59ejA/++yzg/mNN94YzD/66KOkWrhCCwAAQJQUtAAAAERJQQsAAECUFLQAAABEqeKbQp1//vlFN3/673//G8zPOeecYF6ouVQpTj311Ew2cODA4NgRI0YE88022yyYv/nmm7WcHdVk4sSJwXzkyJHBvL4aQBVq/nTbbbcF80UXXbTWj1noa7n44otrfWwo1UEHHVQnx/n222/r5DhUltdff72kvBShRpepmpqakvKQxx57rN5ei1G5mjZtmslWWGGF4NiVVlopmdtCTc1effXV4Nh77713LswoTq7QAgAAECUFLQAAAFFS0AIAABAlBS0AAABRUtACAAAQpYrpcnzssccG8/XWWy+TzZw5s6TOkg888EBSX84888xMtskmmwTHduvWLZj37t07mOtyTCkefPDBYP7pp5/W+thLL710MO/QoUMmu/766+utm3Eh06ZNC+Yff/xxvT0mpBZZZJFM1rp165KOUahD/zXXXPOn5wV/xkknnRTMc7lc0ccoNHbvvff+0/Oi8nXs2DGYn3zyyZlsl112qZPHnDFjRiZr3Li00uqpp57KZE888USt5lWNXKEFAAAgSgpaAAAAoqSgBQAAIEoKWgAAAKKkoAUAACBKFdPlePvttw/mjRpla/Zbb711rnczLuTXX3/NZNOnTy/pGHvttVcwP/fcczPZuHHjSjo21WPeeecN5k2aNAnmTZs2zWQbbLBBcOywYcOC+cILL5w0BBMmTCj3FKhS++yzTyZbaqmlSjrGa6+9Fsy//PLLPz0v+F9Cz+s1NTUlHSM0/uqrrw6O/e6770o6NtXlgAMOCOaldDSeOnVqMD/vvPOC+cSJEzPZ+eefn5SiZ8+emUyX49K5QgsAAECUFLQAAABESUELAABAlBS0AAAARElBCwAAQJQqpsvxRx99FMxDXVdj7WhaqHvgW2+9Fcx1NKYuOoXffPPNwbxVq1aZbJNNNkkasiFDhgTzUrsSQqlC+yV1+OGHF32MadOmldSBE+rCSiutVPTPjFwuFxxbKA91Lr7mmmtKniPV49JLLw3mhxxySNHHKPS6pl+/fsH8p59+CuZHHXVUUlubb755JjvppJNqfdxq4wotAAAAUVLQAgAAECUFLQAAAFFS0AIAABClimkK9dJLLwXzffbZJ5O1adMmmdu6du0azHv37p3JunfvHhw7adKkYH7GGWfUcnZQ2I477pg0ZJ988knRjZ5GjRoVHDt27Ng6nxfMqU+fPsF8ySWXLPoYTz/9dEk51IUtttgimLds2bLo5pWF3HTTTZns9ddfL+kYVJfddtstmDdqFL5GN3r06KIbSP3888/J3Pb222/P9cesRK7QAgAAECUFLQAAAFFS0AIAABAlBS0AAABRUtACAAAQpYrpcjx06NBgvummm2ayHXbYITj2ySefDOYjRowI5k2bNs1ku+++e3DscsstV3RXtm+//bakzm7PPvtsMKfyFFqj33//fTBfcMEFa92FsqF3M/7rX/8azD/88MN6nhFkrbXWWvXWjf6+++6r9TGgVNtvv30wz+VyRR+j0NizzjrrT88LijF16tR662bcvn37Wh/jxhtvrJO5VDtXaAEAAIiSghYAAIAoKWgBAACIkoIWAACAKCloAQAAiFLFdDmePHlyMO/du3cm22effYJj+/fvH8xPP/30euti9sYbb2Sy66+/Pjh2woQJtZ4HcXvttdeC+SKLLBLMDz/88Ew2cODA4NjWrVvXcnZJMnPmzGBeqLPyjBkzil7/559/fjDXzZiGZPPNNw/m8847b9HH+Oqrr4L5dddd96fnBf9L3759g/lGG21U9PN96JMbCr0WS3333XclzREuuuiiYH7qqacG8w4dOmSyXXfdNTj2nXfeKel5/YgjjkiK9eijjwbz0aNHF30MCnOFFgAAgCgpaAEAAIiSghYAAIAoKWgBAACIUk0ul8sVNbBAUxeYG4pcpvWmktZ/p06dgvlWW20VzI866qhgPm7cuEx2xhlnBMc2a9YsmI8aNSqTffrpp8Gx1azc67/S9kBd2WOPPTLZtddeGxzbvHnzoo/bq1evYD5y5MikWpV7D1TS+m/Tpk0wf+ihh4L5WmutVfS/SaHv0zrrrBPMX3/99T+YKQ1l/cewB0488cRgPmDAgEzWpEmTepvH1KlTg3mXLl1KakRFaXvAFVoAAACipKAFAAAgSgpaAAAAoqSgBQAAIEoKWgAAAKKkyzFRKHeHP+ufal7/1b4H5plnnmB+++23Z7LtttuupGM///zzmWyjjTZqsOugXMr9tVfS+l977bWD+UsvvRTMGzUKX/uYOXNm0V2Lt9hii2D+3Xff/cFMaSjrP+Y9EOpG//e//z04tmPHjiUd+9lnn81k5557bnDsgw8+WNKx+S1djgEAAKhICloAAACipKAFAAAgSgpaAAAAoqSgBQAAIEq6HBOFcnf4s/6p5vVf7XugdevW9dah9bnnniu6y3E1K/ceqKT137Jly5K6HK+yyirB/K677spk/fr1C47VzTju9V9pe4D46HIMAABARVLQAgAAECUFLQAAAFFS0AIAABAlBS0AAABR0uWYKJS7w5/1TzWv/2rfA/PMM08wP/XUUzPZySefHBw7cuTIYL7ffvtlss8//7zkOVa6cu+Bal7/lF+513/KHqCcdDkGAACgIiloAQAAiJKCFgAAgCgpaAEAAIiSplBEodwNEax/qnn9p+wBqnkPWP9U8/pP2QOUk6ZQAAAAVCQFLQAAAFFS0AIAABAlBS0AAABRUtACAABQ2V2OAQAAoCFxhRYAAIAoKWgBAACIkoIWAACAKCloAQAAiJKCFgAAgCgpaAEAAIiSghYAAIAoKWgBAACIkoIWAACAKFVVQXv99dcnNTU1yaefflrS3+vRo0fSsWPHOp1L+/btk3333bdOjwl/xPqn2tkDVDPrn2pnD1Suqipoq8GZZ56Z36x1vfGgIRo1alR+vYd+vfjii+WeHtS7MWPGJLvsskuy7LLLJi1btkwWXnjhZKONNkruv//+ck8N5jqvgahGU6dOTU444YRkscUWS1q0aJF07do1efzxx5Nq0rjcE6DufPHFF8lZZ52VzDvvvOWeCsxVRx55ZLLOOuv8Jlt++eXLNh+YWz777LPkxx9/TPr06ZN/MTN58uTkzjvvTLbddtvkqquuSg4++OByTxHmCq+BqFb77rtvcscddyRHH310ssIKK+SvRG+55ZbJyJEjkw033DCpBgraCnLccccl6667bvLrr78m3333XbmnA3NNt27dkp133rnc04C5Ln3Rkv6a0+GHH5507tw5ufDCCxW0VA2vgahGL7/8cjJ8+PDkvPPOy++B1D777JO/S6F///7J888/n1SDqr7l+N5770222mqr/FntZs2aJcstt1xy+umn558MQ1577bVk/fXXz1/OX2aZZZIrr7wyeNl/wIAB+atD6TGXXHLJ/IJK8/r09NNP58/OXHzxxfX6OFSOSlr/qfQq1YwZM+r9cagclbYHZplnnnnyjztx4sS59pjEp5LWv9dAVOseSNf9PPPM85uTl82bN08OOOCA5IUXXkg+//zzpBpU9RXa9JL8fPPNlxxzzDH5/44YMSI59dRTkx9++CF/pmNOEyZMyJ8F33XXXZM99tgjue2225J+/folTZs2Tfbff//8mJkzZ+Zv83r22WfzC2vllVdO3n777eSiiy5K3n///eSee+4pOJf0744fP76oebdq1Spp0qTJ7N+nG++II45IDjzwwGS11Vb7098PqkulrP/Ufvvtl/z000/5J/X0am06/7XXXvtPfV+oHpW0B37++edkypQpyaRJk5L77rsvefjhh5PddtvtT31fqA6Vsv69BqKa98Abb7yRrLjiiskCCyzwmzFdunTJ/3f06NH5orri5arIkCFDcumX/Mknn+R/P3ny5MyYvn375lq2bJn75ZdfZmfdu3fP/70LLrhgdjZ16tRcp06dcossskhu2rRp+WzYsGG5Ro0a5Z555pnfHPPKK6/M//3nnntudrb00kvn+vTpM/v36ZzSMcX8Gjly5G+Of9lll+VatWqV+/bbb2fPd9VVV62D7xiVpBLXf3rMnXbaKXfdddfl7r333tzZZ5+dW2ihhXLNmzfPvf7663X2vaMyVOIemHPes/48ncPOO++cGz9+fK2/Z1SOSl3/XgNRzXsgXes9e/bMfB1jxozJj00fuxpU9RXa9JaBOW9XTG8HSK/upI00xo4dm6yxxhqz/7xx48ZJ3759Z/8+PSOT/j49O5PegpC+b+P222/Pn41ZaaWVfvP+jZ49e+b/m745O71VIaRt27ZFdySbc17ff/99/mzSKaeckrRp06bE7wDVrBLWf3q8OY+ZnhlN30u7+uqrJyeeeGLyyCOPFP39oPpUwh6YJW0Gkq79r776Kn/lIL1qNW3atCK/E1SjSlj/XgNR7XsgvTOnWbNmmTHpbcez/rwaNK72jzs4+eST87cYpLcXzCm9bWtO6f31v++cl17iT6WfZ5Uu5A8++CB57733Cj6pfvvttwXnki68Xr16lfw1pPNfcMEF87fbQLWt/5D0fSvbbbddctddd+Vf1Ke3IUOl74H0BVT6a1ZDkM022yzZZpttkpdeein/MSZQievfayCqfQ+kRfnUwPtzf/nll9l/Xg2qtqBNm2V07949f8/5aaedln8jeLqYXn/99fxnOaX3spcq/Tvp+zfSzpIhf3QPe/rCe9y4cUU9TvrknZ4ZSjfO1VdfnW+CkJ6Vn3MRT58+Pb/B0q8vHQ+Vtv7/SPpY6dWp9H2Fv39fCVTDHkiv1qZXD9L3bXXo0KGo41I9KmH9ew1Ete+BVLt27ZIvv/wyM+brr7+eXYhXg6otaEeNGpW/VSW9ipN+CP0sn3zySXB8+mSZvjie8+xM+kIh1b59+/x/083w5ptvJptssknJZ8TTLmRpx7RipLcs9OjRI7+A082TfgZn+uv30uMdddRRuv5Rkev/j3z88cf5H0xpkweoxj0w6zaz319lgEpZ/14DUe17INWpU6f879MrzHOewE/vzpn159WgagvaWbch5nLpe6b/n/SKzuWXXx4cn34cSHpPfdoJbdbY9PfpbQXp5/2l0s5nDz30UHLNNddkPvsvfXGRPvEW+sDvP3PvfPoZU3fffXfmz9PbJ9L3AlxyySX5zQWVuP5T6dnM39/ak/4wSbu8brHFFkmjRlX9yWRUwR5Ib2FbZJFFfvPn6dWpG264IX+r2SqrrFLUMakulbD+vQai2vfArLtxzj///PzdCrM+hza9BXnIkCFJ165dq6PDcTUXtOmbslu3bp306dMnf2YvPZMybNiw3yzsOaWX7AcPHpy/hSW9Z/7WW2/Nt8JOF9Cs1tl77713vhnHIYcckj9bssEGG+RvIUjfWJ7mjz76aMGPEvkz984vvPDCyfbbb5/JZ52NDP0ZVMr6T6UfS5K+aE+/nvRF/bvvvpufU8uWLZNzzjmn5ONRPSplD6S3Fadn5tMrDIsvvnjyzTffJDfddFP+MS+44AJ3KVCx699rIKp9D6TSonWXXXbJN8JMT3CmfUSGDh2an+d1112XVI1cFbfrTttnr7vuurkWLVrkFltssVz//v1zjz76aKYl9qwW8K+++mpuvfXWy38kSNpuO20V/3tp6+7Bgwfnxzdr1izXunXrXOfOnXODBg3KTZo0qWC77rqkZT3Vsv4vueSSXJcuXXILLrhgrnHjxrl27drlevfunfvggw9qfWwqTyXugVtuuSXXq1ev3KKLLprfA+njpb9PP8YKKn39h3gNRLXtgSlTpuSOO+64XNu2bfOPuc466+QeeeSRXDWpSf+n3EU1AAAAlMobzAAAAIiSghYAAIAoKWgBAACIkoIWAACAKCloAQAAiJKCFgAAgCgpaAEAAIhS42IH1tTU1O9M4A+U++OSrX+qef2n7AGqeQ9Y/1Tz+k/ZAzTkPeAKLQAAAFFS0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUVLQAgAAECUFLQAAAFFS0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUVLQAgAAECUFLQAAAFFS0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUVLQAgAAECUFLQAAAFFS0AIAABAlBS0AAABRalzuCQBAQ7bZZpsF87333juT7bXXXsGxo0ePDuaffvppJttxxx1LniMAVCtXaAEAAIiSghYAAIAoKWgBAACIkoIWAACAKCloAQAAiFJNLpfLFTWwpqb+Z1PBBg4cGMxHjRpVVFbtilym9cb6p5rXf7XvgRdeeCGYd+nSpdbHnjx5cibbd999g2PvvPPOpFqVew9U8/qf2y699NJgftNNNwXzF198Mal05V7/KXug4TryyCOD+UorrZTJ+vbtW9KxGzXKXvtcfvnlg2M/+uijpFx7wBVaAAAAoqSgBQAAIEoKWgAAAKKkoAUAACBKmkLVQo8ePTLZgAEDih4b6/e60NdSKC/U5KqU5lflbojQ0P9NqGzlXv/Vvge++eabYN6mTZt6+b4++eSTwbGbbrppUq3KvQcqaf2vuuqqwfzvf/97Sc3I7rnnnlrP5dhjj81k5557bknNnzbYYIOk0pV7/VfaHmhIQj9HrrnmmuDYlVdeOZgvX6BJU12sm9C/+0knnRQcO3jw4KS+aAoFAABARVLQAgAAECUFLQAAAFFS0AIAABAlBS0AAABRalzuCcQs1NW3lG7GMQh9PSNHjizpGIU6P+uYR12Yf/75g/m8886byaZMmRIc26JFi1rPY8KECcF86tSptT425XXBBRcE83POOadeHq9p06bBvHHj8I/sGTNm1Ms8qEwdO3YM5r179w7m22+/fdEdV7/44ouS5hLqUNyoUfhay7rrrlsnn7AA5XDAAQcE84MPPjiTde7cOWnI7rjjjqShcYUWAACAKCloAQAAiJKCFgAAgCgpaAEAAIiSghYAAIAo6XJcRKe8Qnmh7r2lGDRoUDI3VcPXSPmccMIJwXzjjTeu9bELdcVeYYUVgvnSSy+dyb766qvg2MUXXzyY53K5ouf3+uuvB/N11lmn6GPQMF100UXB/K9//Wu9dLrfcMMNg/lKK60UzN95551aPyYUMt988wXziy++OJPtvPPO9TaPQt2P+/fvH8x1Oaa+tWnTJpP94x//CI494ogjav06g8JcoQUAACBKCloAAACipKAFAAAgSgpaAAAAolSVTaFKbYBUF00+CjVGGjhwYDI31efXWKgBw9z+Gimfs88+O2nICjV/KtRwqhRt27at9TFomDp27BjMV1111Xp5vFdffTWYf/zxx/XyeFSX1VdfvU6OU8rzZqGGTs2bN6/1PH755ZdaHwP+yD777BPMjz/++Ey28sor1/rx7rvvvqIbsaWefvrpoo99wAEHBPOrrroqiZkrtAAAAERJQQsAAECUFLQAAABESUELAABAlBS0AAAARKniuxyHOux27969oroZl9K1uT67GW+88ca1PjYNz/zzzx/Mn3rqqaKPkcvlgvn777+fyaZMmRIc+9xzzwXzTz/9NJi3b9++6LGlePfdd4P5888/X+tjU16FuhY/+uijwXzhhReul3msscYawfy0004L5qeeemownzx5cp3Oi8rQuXPnBtNZeYsttqj1sYcPH17rY1BdCnXoPvPMM4P53/72t2DepEmToh/zs88+C+Z77LFHJnv77bfr5Dm9e6Deueiii0o6xg033JDJ/vOf/yQNjSu0AAAARElBCwAAQJQUtAAAAERJQQsAAECUFLQAAABEqeK7HIc6/dan+uxmXOrXWBcdjUN0M65MiyyySDA/77zzSurEGnLggQcW3Z2yUJdjKFcn1vrqZlxIoc6ZhTptrr322sF8++23z2QTJ06s5ezg/3nxxRczWZcuXYJjhw4dOhdmBMU55JBDgnn//v3r7TGHDBkSzF966aV6e8y/BX5mtGzZMjh23LhxwXzw4MGZbPr06UlD4wotAAAAUVLQAgAAECUFLQAAAFFS0AIAABAlBS0AAABRiq7LcaHOvXO7m3F9dvst9DWOHDkyqS+jRo0K5oMGDaq3x6RhOeWUU4J57969a33sQl31mjdvnsl0OYbSdOvWLZhfddVVmWy33XabCzOiIXvnnXeC+eabb17Scc4999ykIZg2bVq5p0BkOnToUG/HvuOOO4L56aefXm+P2b1796J/NhTqZlxo/48dOzaJgSu0AAAARElBCwAAQJQUtAAAAERJQQsAAECUanK5XK6ogTU1SUMwcODAud4UqlDDpLpoChVqAFWfzZ8KKdT8KfT9LvRvUCivC0Uu03rTUNZ/XQg1ikkddNBBSUPwxBNPBPPTTjstmD/77LNJpSv3+q+0PVDIHnvsEcxvvPHGufp9rat/76lTp2ay9ddfPzh29OjRSUNW7j1QSeu/UGOw4cOHJw1ZoYaBhZoOVpJyr/9K2wOFvpYZM2aUdJwPP/xwrjacKnV9zJw5M5Ndd911wbEHH3xwEvMecIUWAACAKCloAQAAiJKCFgAAgCgpaAEAAIiSghYAAIAoNU4i07179wbR6bcuuhnXd3fmUhSaRygv1PW50NdYaDz1r1evXkV3M66LLoqFOge+8847wXyZZZbJZJtssklw7BJLLBHMu3btGsx//PHHP5gpZL3++uvB/Jhjjin6GHfddVcwX2ihhYp+3tx6662DY7t16xbMmzRpEsybN29e9H5p6F2OqTtvvPFGMH/vvfeC+corr5w0BB988EG5p0CFOPnkk0t6HTRu3LhgfuSRRyb1Yd555w3mF198cdHdjAu9/j766KOTSuQKLQAAAFFS0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUarJFdnatFD30rmtUMfh+uwWXKj7cajjcqFOv5Vk4403nuvdjOuiA29tNJT1Xxf69esXzA888MBg/sADDwTzO++8s+jHfPfdd4vu2vrEE08Ex06bNi2YL7XUUiV1JYxRudd/pe2BWD344IPB/K9//WvRx1h//fWD+UsvvZQ0ZOXeA9Ww/gs9l2677bbBfPnll89kvXv3Do79+uuvg3nHjh2Lnt+UKVOCecuWLZNKV+71H/MeaNq0aSa7+uqrg2MLrd8bbrghmO+///5JfTj//PODeaEOxTUF/m222267ol/Txb4HXKEFAAAgSgpaAAAAoqSgBQAAIEoKWgAAAKIUXVOokSNHBvNqaMZUnwo1vgo1eqrP5k8NtSFCQ1n/lea4447LZIMHDw6OHTNmTDDv2rVrSQ1EYlTu9Z+yBxruz7+NNtqo6GNoCvXnWP+1U6iZzUUXXVT0MTSFKq9Y90CoednYsWNLOkaHDh2C+UcffZTU1iqrrJLJ7r///uDYpZdeOpg/88wzwXz77bfPZJMmTUpipCkUAAAAFUlBCwAAQJQUtAAAAERJQQsAAECUFLQAAABEqXESmY033jiYV3P340IdikMGDhxYr3Oh7nTs2DGT3XLLLcGxDz/8cDDv379/0pD17t27YjopEo9WrVoF8+bNmwfzcePGBfOZM2fWei4LLrhgJrv44ouDY7t161bSsWfMmJHJpk+fXtIxACpNOV5nrLHGGsH8sccey2QLL7xwcOzTTz9dUm1UTVyhBQAAIEoKWgAAAKKkoAUAACBKCloAAACipKAFAAAgStF1OS6kUIevUFff7t27N5iOyKNGjcpkTz31VHCsDsXV5d57781k7du3D4497rjjav14LVq0COZt2rQJ5q1bt85kO+20U3DsgQceGMwXWmihTJbL5YJjt99++2A+ZcqUYA6FXHHFFcF8t912C+YnnnhiMD/33HOL2hepDh06BPOTTjopk2211VZJXXjttdcy2euvv14nx4ZS3H///cH8oosuKvoYhbqQb7rppsH88ccfL/rYVK5TTjml6NcZN9xwQzD/z3/+U+t59O3bt+jXQQ899FDRnwzB/+MKLQAAAFFS0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUaqYLseFhDoDF+oWXI4ux4MGDSqq8zHVJ9ShtFCX46FDhwbzr7/+OpONGDEiOLZQZ9WFF144mIe6udbU1ATHFuooOHPmzEz2/PPPB8dOmDAhmEN9O/PMM4P5DjvsUHSX4xVWWCGYh/ZMof1SyFtvvRXMzzvvvJKOA/WlWbNmtT5GoZ8vhfYc1WXttdcO5ptvvnnRx5g0aVIwnz59ejBv2rRpJltqqaVK6nL8008/ZbLBgweXND9coQUAACBSCloAAACipKAFAAAgSgpaAAAAolTxTaFCjZ4GDBhQb49XqKFTqPnTH42H8ePHFz22TZs2RedrrLFGcGypjWhKUaih0wEHHJDJ7r333nqbB/wZjRqFz/126dJlrs5jxowZwfz4448P5k888UQ9zwjq/ucZ/BmFXgcVamxZF0INoMaOHVvSMY4++uhM9uyzz9ZqXtXIFVoAAACipKAFAAAgSgpaAAAAoqSgBQAAIEoKWgAAAKJU8V2OR44cOVcf76mnngrmuhlTqkMOOSSTnXjiicGxW2+9dTBfdtllM9n6669fb12Ob7rppmB+xx13BPMpU6bU+jEhZjU1NZnsp59+Co7de++9g7luxjR03333XdGvmbp3714nXcih0HNsqWPPPvvsYN6/f/+ij73bbruV9PqI0ngWAAAAIEoKWgAAAKKkoAUAACBKCloAAACipKAFAAAgShXT5bhHjx5z/TFDnYsHDhw41+dB9ZgwYUIwHzZs2FyfC8TqrrvuKqkLZV345JNPgvnzzz+fyS688MLg2NGjR9f5vKCc5plnnlofo0+fPsF8+PDhtT428SvlExwKdZJv2bJlMA91pH/66aeDY3Uzrl+u0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUaqYplChBk31bdCgQXP9MQGonQceeCCYDx06tKSmM6U0nDr++OOD+aefflr0sSFWLVq0COZrrbXWXJ8LlWny5MnB/Oeff85k8847b3Bsq1atSnrMV199NZNts802JR2DuuEKLQAAAFFS0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUarJ5XK5ogbW1CQx6tGjR1HZHxk4cGAdzog/o8hlWm9iXf9UhnKv/5Q9QDXvAeu/foS6gu+www4lHeORRx4J5ltssUVSKcq9/mPeA/vtt18mu+aaa0o6xhlnnBHMhwwZksk+++yzko5N3ewBV2gBAACIkoIWAACAKCloAQAAiJKCFgAAgCgpaAEAAIhSxXc5pjKUu8Of9U81r/+UPUA17wHrn2pe/yl7gHLS5RgAAICKpKAFAAAgSgpaAAAAoqSgBQAAIEoKWgAAAKKkoAUAACBKCloAAACipKAFAAAgSgpaAAAAoqSgBQAAIEoKWgAAAKKkoAUAACBKCloAAACipKAFAAAgSgpaAAAAoqSgBQAAIEo1uVwuV+5JAAAAQKlcoQUAACBKCloAAACipKAFAAAgSgpaAAAAoqSgBQAAIEoKWgAAAKKkoAUAACBKCloAAACipKAFAAAgSlVV0F5//fVJTU1N8umnn5b093r06JF07NixTufSvn37ZN99963TY8Ifsf6pdvYA1cz6p9rZA5WrqgraSvb6668n2267bbLgggsmLVu2zG+8//u//yv3tGCusP6pVmPGjEl22WWXZNlll82v/YUXXjjZaKONkvvvv7/cU4O5YurUqckJJ5yQLLbYYkmLFi2Srl27Jo8//ni5pwVzxSuvvJIcfvjhyaqrrprMO++8yVJLLZXsuuuuyfvvv59Uk8blngC199hjjyXbbLNNsuaaayannHJKMt988yUfffRR8sUXX5R7alDvrH+q2WeffZb8+OOPSZ8+ffIv6CdPnpzceeed+RM8V111VXLwwQeXe4pQr9KrXHfccUdy9NFHJyussEL+KtyWW26ZjBw5Mtlwww3LPT2oV4MHD06ee+65/InN1VdfPfnmm2+Syy67LFlrrbWSF198sc6vLDdUCtrI/fDDD8k+++yTbLXVVvkn9EaNXHSnelj/VLv0hXv6a07p2frOnTsnF154oYKWivbyyy8nw4cPT84777zkuOOOy2fpz4T0RXz//v2T559/vtxThHp1zDHHJDfffHPStGnT2dluu+2WrLbaask555yT3HjjjUk1qOpXf/fee2/+hXB6VrtZs2bJcsstl5x++unJr7/+Ghz/2muvJeuvv37+lpZlllkmufLKK4O3vgwYMCBZfvnl88dccskl80+qaV4f0kX83//+NznzzDPzL+Z//vnnZObMmfXyWFQW659qVwl7IGSeeebJP+7EiRPn2mMSn0pY/+mJzHS9z3nipnnz5skBBxyQvPDCC8nnn39eL49LZaiEPZDOp+kcxWwqvVMhvQX5vffeS6pFVV+hTW9LSW9PTM9upP8dMWJEcuqpp+av+qRn++Y0YcKE/Fnw9L70PfbYI7ntttuSfv365RfR/vvvnx+TvpBOb/N69tln80+uK6+8cvL2228nF110Uf5e9nvuuafgXNK/O378+KLm3apVq6RJkyb5///EE08kCyywQPLll18m22+/ff5x0nvo99577/zjpk/sEGL9U+0qYQ/Mkp7MmTJlSjJp0qTkvvvuSx5++OH8WXqo5PX/xhtvJCuuuGL+58CcunTpkv/v6NGj8wUFVOoeCMnlcvmT/WlRWzVyVWTIkCG59Ev+5JNP8r+fPHlyZkzfvn1zLVu2zP3yyy+zs+7du+f/3gUXXDA7mzp1aq5Tp065RRZZJDdt2rR8NmzYsFyjRo1yzzzzzG+OeeWVV+b//nPPPTc7W3rppXN9+vSZ/ft0TumYYn6NHDly9t9bffXV8/NNfx1xxBG5O++8M//fdNzuu+9eZ9874mf9U+0qcQ/MOe9Zf57OYeedd86NHz++1t8zKkclrv9VV10117Nnz8zXMWbMmPzY9LGhkvdAyLBhw/Ljrrvuuly1qOortOktA7OkTTXS2wG6deuWb6QxduzYZI011pj9540bN0769u07+/fpGZn09+nZmfQWhHXXXTe5/fbb82djVlpppeS7776bPbZnz575/6YNCtJbA0Latm1bdFe+Oef1008/5ZuAHHLIIbO7uu64447JtGnT8l/Haaedlr/1AH7P+qfaVcIemCVtiLPzzjsnX331Vf7KQXrLXLoPoJLXf3pXQnpb5+/Nujsn/XOo5D3we2PHjk0OO+ywZL311ss3C6wWjav94w5OPvnk/C0G6e0Fc0pv25pTen99eivjnNLbXFLp51mlC/mDDz7I36/epk2b4ON9++23BeeSPvn26tXrT2/G9PaHOe255575DZm+h8QLekKsf6pdJeyBWdIXUOmvWU1xNttss3z375deein/uYtQqT8DQu9N/OWXX2b/OVTyHpjTN998k39PcHpL8qz3l1eLqi1o02YZ3bt3z7/vIr2Kk74RPF1M6edZpp9n9mcay6R/J+0qlnaWDPmj93GkZ9PHjRtX1OOkn7U56w3g6QZLN+Siiy76mzGLLLLI7Hv+4fesf6pdpeyBQtKrtenVg/R9Wx06dCjquFSPSln/7dq1y/dQ+L2vv/569s8IqOQ9MGcBvsUWW+S/rmeeeabq1n7VFrSjRo1Kvv/+++Suu+7Kfwj9LJ988klwfHobV9p0Y86zM7M+tLh9+/b5/6ab4c0330w22WSTks+Ip5340o5pxUhvWejRo0f+/6cfzZDeopA+oc/5oiWdb6rQWSKqm/VPtauUPVDIrFstf3+VASpp/Xfq1Cn/+/Tq2pyNodI7E2b9OVTyHph1R8I222yTn0/aLHOVVVZJqk3VFrSzLsOnncBmSd9vdPnllwfHz5gxI38LY9oJbdbY9PfpC+b0RXUq7Xz20EMPJddcc03ms//SFxfpmZvf365Q23vn08dMP2fquuuum32Pfuraa6/N3+//v170UJ2sf6pdpeyB9Ba2WXckzDJ9+vTkhhtuyN9uWY0vbKie9Z/eiXD++ecnV1999ezPoU1vQR4yZEjStWtXHY6p+D2QXtndbbfd8m+xSj+GKH3vbDWq2oI2fVN269at82+YPvLII/NnUoYNG/abhT2n9NL94MGD8/fJp/fM33rrrfl28OmT6KzW2elHhaTNONIGNenZkw022CC/0NI3aKf5o48+mqy99tp1eu/8mmuumW8X/q9//Su/2dLbJ9KzTukb00888cSqu+WA4lj/VLtK2QPpbcXp1an0CsPiiy+efw/VTTfdlH/MCy64IP9RFFCp6z8tWnfZZZf88316cif97M+hQ4fm55me6IRK3wPHHnts/qPa0iu06cf+3Hjjjb/58969eydVIVfF7brT9tnrrrturkWLFrnFFlss179//9yjjz6aaYmdtutOW8O/+uqrufXWWy/XvHnzfLvtyy67LPMYaevuwYMH58c3a9Ys17p161znzp1zgwYNyk2aNKlgu+7aSB9z4MCB+WM2adIkt/zyy+cuuuiiOjk2lcP6p9pV4h645ZZbcr169cotuuiiucaNG+cfL/39vffeW+tjU1kqcf2npkyZkjvuuONybdu2zT/mOuusk3vkkUfq5NhUlkrcA7M+Uigp8Kta1KT/U+6iGgAAAErVqOS/AQAAAA2AghYAAIAoKWgBAACIkoIWAACAKCloAQAAiJKCFgAAgCgpaAEAAIhS42IH1tTU1O9M4A+U++OSrX+qef2n7AGqeQ9Y/1Tz+k/ZAzTkPeAKLQAAAFFS0AIAABAlBS0AAABRUtACAAAQJQUtAAAAUVLQAgAAECUFLQAAAFFS0AIAABAlBS0AAABRalzuCZC12WabBfPDDz88mG+66abBfIMNNshkr7/+ei1nB6Vr3759ML/lllsy2VlnnRUce//999f5vAAAiJsrtAAAAERJQQsAAECUFLQAAABESUELAABAlBS0AAAAREmX4zLr1atXJrv77ruDYz/66KNgvtpqqwXzDz/8sJazg9I0b948mA8bNiyYv/fee5nswQcfrPN5AQCVa/755w/mxx9/fCbbeuutg2PXXHPNYP7tt98G86uuuiqTffXVV8Gx1113XTCfPn16MKc0rtACAAAQJQUtAAAAUVLQAgAAECUFLQAAAFFS0AIAABClmlwulytqYE1N/c+mgi2xxBLB/J133slkzzzzTHDs/vvvH8zHjRuXVLoil2m9sf6Lc9BBBwXzY445Jpivs846meynn36q83nFrtzrP4Y9sPzyywfzZs2aZbLlllsuOHbbbbcN5vvtt1/R85g4cWIwP+OMM4L5TTfdVHRHzWpW7j3Q0Nd/Ka89br311uDY9dZbr6Rjv/LKK5nsgQceKHqdpz7//PNMputrw1v/DWkPhJ7T/+i1c+fOnZOG4Isvvgjmt9xySya79tprg2Or+dNLcv9jD7hCCwAAQJQUtAAAAERJQQsAAECUFLQAAABESUELAABAlHQ5rmOFuqlddtllwXz06NGZrF+/fnU+r9iVu8Of9V9cV9lXX301OPbss88O5oMHD67zeVWicq//cu2B008/PZNtsMEGwbFrr712MJ933nmL/n5OnTo1mN9///3BfPPNN89kCyywQHBsocd88803G2xXzoak3Hsg1p8Bb731ViZbccUVg2M//fTTYL7ooosG80JrvRQjRozIZAcccEBw7H/+85+kWpV7/TekPdC0adNgfueddwbzVVddNZNdcsklJT3mwgsvHMwPOeSQTDb//PMHxzZp0qToxyu0F0M/c6ql+3FOl2MAAAAqkYIWAACAKCloAQAAiJKCFgAAgChpClXHhg0bFsxXWWWVYK75RxwNEaz/rJtvvjmTLb300sGxG220UTD/9ddf63xelajc67++90CowVjq2WefLbo5RyGff/55JhsyZEhw7OTJk4P5Aw88EMxHjRqVydq0aVPSv+H7779f9M+LalbuPRDrz4Dtt9++6KY1Dz30UDBv1apVML/qqqsyWc+ePZPa+u9//xvM99xzz6L3YaUp9/qPeQ/MbVtttVUw//vf/150Q8NCja8K7dEdd9wxmE+fPj2pFJpCAQAAUJEUtAAAAERJQQsAAECUFLQAAABESUELAABAlHQ5LkLr1q2D+f/93/9lsu222y44dsCAAcH8oosuquXsqkO5O/xV8/rfcMMNg/mTTz5ZdHfWjz76qM7nVU3Kvf7LtQcOOuigTHbFFVcExx533HHB/P777y96PRbq5vrWW28F88UXX7zo79Pzzz8fzHv16pXJpk6dGhxbzcq9B6rhZ8BSSy0VzC+44IKiO6v++OOPwbFjxowJ5qGu5SuuuGJw7JQpU4L5kUceGcyvvfbapFKUe/1Xyx4oh1CX7m7dupV0jNVXX72kfRcjXY4BAACoSApaAAAAoqSgBQAAIEoKWgAAAKKkoAUAACBKjcs9gYZknnnmCebXXXddMN9iiy0y2b777hsce+utt9ZydlAehx56aDC/5ZZbMpluxtSlIUOGZLLnnnsuOLbQ2iulY3CzZs2K7mZcqkcffTSY62jM3NakSZNgfvzxxxfdzTj11VdfZbLDDjssOPa+++4ren59+/YN5meddVYwP+GEE4L522+/ncleeumloucBc8P48ePLPYWK4AotAAAAUVLQAgAAECUFLQAAAFFS0AIAABAlTaHmcPXVVwfz7bbbLpifeOKJmUzzJ2K1ww47BPNdd901mPfq1aueZ0S1mzFjRiZ799136+3xvv3222B+0EEHBfPLL7+86MZSG2ywQTBfZJFFip4H1IV+/fqV1ACwkFBjzHfeeSeprauuuiqYL7fccsH82GOPDeaPPPJIJuvUqVNw7GeffVbSHKGufPPNN+WeQkVwhRYAAIAoKWgBAACIkoIWAACAKCloAQAAiJKCFgAAgChVZZfjYcOGBfO99tormJ977rkl5RCj5s2bB/NCXWWfeeaZep4RNAxDhgwJ5quuumom+9vf/hYcu+mmmwbzJ598MpMdccQRwbGjRo36HzOF3zr44IMz2aBBg4JjZ86cGcxPOumkYF6fHcdDTj755GC+4IILBvP99tuvqM7Hf9S1/8svvyxpjlBI6OdFaueddy76GBMnTgzmU6ZMSaqdK7QAAABESUELAABAlBS0AAAARElBCwAAQJQUtAAAAESpJpfL5YoaWFOTxGjFFVfMZK+99lpw7D333BPMDzjggGA+bdq0Ws6OYhW5TOtNrOu/FFdccUVJXR7POOOMep4RDWX9V8seqAsHHnhgSV3xF1hggUw2ffr04NhC3Y+vvfbapNKVew/Euv7feuutorutvvTSS8F8/fXXT2IUek23zTbbBMeeeuqpwfzMM89MGoJyr/+Y98DcNv/88wfzK6+8MpjvvvvuRR/7sssuC+ZHHXVUUun+1x5whRYAAIAoKWgBAACIkoIWAACAKCloAQAAiJKCFgAAgChVfJfjjz76KJMtvvjiwbFrrbVWMH/33XeThtqxObXtttsWfYyRI0cG80KdnxuKcnf4i3X9l2L06NHB/I477qi3LsddunTJZKecckpwbI8ePYL5r7/+Gsw333zzort4NnTlXv/Vsgfq0xJLLBHM//Wvf2Wynj17lnTsgw46KJgPGTIkqRTl3gMNff23adMmmL/88suZbKmllgqO/fvf/x7MzzvvvCRGnTt3Lur7kfr222+D+VZbbZXJXn/99aTa1n8Me6ChuPTSS4P5oYceWvQxbr/99pKe63/88cek0ulyDAAAQEVS0AIAABAlBS0AAABRUtACAAAQpcZJhVtmmWUy2WGHHdagmz/tsssuwfyGG24I5k2bNi362GPGjAnmXbt2DeZTpkwp+tjE7bPPPqu3Y6+99tpFNz6Yd955g2MLNfNYf/31g3mfPn0qpikU8fviiy+C+ZZbbpnJzjrrrODYY445JphffvnlRc+jkhpF8f9ZffXVg3moAdTkyZODY5988smk0hsdnn766cGxhZoR7rDDDg2iKRTlNd988wXzCy64IJPttNNOJR17/PjxmWzgwIFV2/zpz3KFFgAAgCgpaAEAAIiSghYAAIAoKWgBAACIkoIWAACAKEXX5bhZs2bBfNiwYUV3Dxs+fHjSUOa933771apjZeqFF14ouhva5ptvHhy7wAILBHNdjqvHIossEsxXXHHFWq/z8847L5j/+9//zmR77rlncOx3330XzM8888xg3rp16z+YKTQMM2bMyGT9+/cPjl133XWDebdu3YL5tddem8nmmWeeoscSv5qamkz2888/B8dWWvfeX3/9NZNdddVVwbH77rtvMO/Ro0edz4uGq9CnLFx99dXBfLfddiv62KF6JLXXXntlsrFjxxZ9XP4fV2gBAACIkoIWAACAKCloAQAAiJKCFgAAgCgpaAEAAIhSdF2OC3Uu3WmnnYL54MGDM9mECROSuW3nnXcO5v/85z8z2X/+85+iv5bUNddcE8znn3/+TPb999//j5lSre65555g/ve//z2YN2nSJJN16tQpOLZQvs466xTdzbiQp59+Ophvt912JR0HGro99tgjmBf6mZHL5TLZ6aefHhyry3FlCq2BUFYtvv7665JeRw0YMCCTbbrppsGxjz/+eC1nx9zSsmXLkp4Hd91113rpZpx67LHHij42hblCCwAAQJQUtAAAAERJQQsAAECUFLQAAABEKbqmUKU677zz6u3YzZo1y2SXX355cOwuu+wSzEMNOi655JLg2ELNrDbbbLNgfvXVV2eyESNGBMdqFsWHH34YzFu1ahXMt95660zWvHnz4NiXX365pMcsxQ477BDMZ86cWetjQ0Py5Zdf1voYLVq0CObLLbdcMP/oo49q/Zg0LAsttFAw32KLLYL5ww8/nFS6jz/+OJjPM888mezEE08MjtUUqmEKPef961//Kum1eiGh1+V77rlncKz1Ub9coQUAACBKCloAAACipKAFAAAgSgpaAAAAoqSgBQAAIErRdTku1On39ddfD+adO3fOZE888USdzKV79+6ZbN999w2OnTRpUjC/6qqriv4aW7duXXQ349T06dMz2YABA4JjZ8yYEcypHm+//XYwHz9+fDA//vjjM9l1112X1JemTZsG83322afo+UHMFl988Vofo0mTJsG8TZs2wVyX4ziMGjUqmL/77ruZbJVVVgmObdu2bZ3Pq5q6RNMwbbTRRrXuZjxx4sRgvscee2Qy3YzLwxVaAAAAoqSgBQAAIEoKWgAAAKKkoAUAACBKCloAAACiFF2X46lTp5bUibU+uxyHuqh+/vnnwbFrrrlmMA91NO7Tp09w7HHHHRfMF1100WC+9957Z7LnnnsuOBbef//9YP7www8H87322iuTNW/ePDh23LhxtZxduJvgHx372muvrfVjQkPqZvzII4/U+tiTJ08uuhsu8fj111+DeS6XK/oYBx10UDAfMmRIUul0eI5foU8Cue2222p97L///e/BXEfjhsMVWgAAAKKkoAUAACBKCloAAACipKAFAAAgSgpaAAAAohRdl+NCrrnmmmB+4403ZrJvv/02OHb48OHBfJVVVgnmm2yySSbbaqutgmPnn3/+YP7kk09msk6dOgXHPvvss8F8tdVWC+YffvhhMIdS/POf/wzm22yzTdFrd+bMmcF80KBBmaxZs2bBsf369Qvmp556akkd0aku7dq1y2SNG4d/9BXqUl+fQnumUFfO5ZZbLpg3atSo6H132mmnBcf+8MMP/2OmxOjss8/OZEOHDg2O7dixYzDfYYcdgvndd9+dVIpdd9216LF10TWXulfoeXC++ear9bG7dOkSzCdNmtSg10eHDh0yWc+ePYNj//a3vwXzm2++OZMNHDgwaWhcoQUAACBKCloAAACipKAFAAAgSgpaAAAAolSTy+VyRQ2sqUliFGo6U+iNz19++WUwb9GiRTBfaqmlMtmYMWOCY//yl78E8+eee67o5lQPPPBAMJ8xY0ZS6YpcpvUm1vVfn/bbb79MduGFFwbHtmrVqujva6F/60J7q1BTtEpS7vUfwx5Yfvnlg/nIkSMz2aeffhoc26tXr1o3GCvUXKZQQ6dQU7MmTZokdfFvE1o3oSZZqXHjxiUNWbn3QENf/6U455xzgvkxxxxT0muMCy64oOjXKa+++mow//XXX5P6EnrdVagpWt++fYP5zz//nMkWX3zx4NgpU6Yklbr+Y9gDhZpCHXjggZnsiiuuqJPHDK3fhtRgr2nTppls3nnnLekYF198cSY79thjk7ntf+0BV2gBAACIkoIWAACAKCloAQAAiJKCFgAAgCgpaAEAAIhSxXc5DunatWsw79evX0nHCXWt/Oijj4JjH3zwwWA+YsSITPb999+XNI9qUO4Of5W0/svRafaAAw4I5gcddFAmmzhxYkn7thr2S7nXfwx74Omnnw7mG2ywQdHfz/PPPz+Yt27dOphvt912maxNmzZz/d8w1Ik1dffdd2ey/fffPzh25syZSUNW7j3Q0Nd/Xdhjjz2C+dVXXx3MW7ZsWfSx//nPfxZ97OnTpwfH/vvf/w7mO+ywQzA/5ZRTMtkaa6wRHDtp0qRgPmDAgEx26aWXJtW2/mPeA6Hux6HOx3XZ/biSXKzLMQAAANQfBS0AAABRUtACAAAQJQUtAAAAUVLQAgAAEKWq7HJMfMrd4c/6p5rXfwx7INR1PjVq1KhM1q5du7n+faqLf8MhQ4YE83POOSeYF+q6H6Ny74GGvv7r09prrx3M+/fvn8l22mmnWj/e1KlTg/nzzz8fzLt06RLM55133qIfc9dddw3md955Z9IQlHv9V9oeKPS1LLroorX+FJTFF188mO+33361fq7/8ssvk9oaPXp0MH/ggQeK7oD/66+/JnObLscAAABUJAUtAAAAUVLQAgAAECUFLQAAAFHSFIoolLshgvVPNa//mPfAKqusksnOP//84NjNNtus1o83ceLEYH7GGWcE88cee6zoYxdq8lSoiU4lKfceiHX916dGjbLXRJZaaqmSmrZtvfXWRT/eNttsE8wnT54czJ988slM9sgjjxTdPK4h7a1yr/+UPUA5aQoFAABARVLQAgAAECUFLQAAAFFS0AIAABAlBS0AAABR0uWYKJS7w5/1TzWv/5Q9QDXvAeufal7/KXuActLlGAAAgIqkoAUAACBKCloAAACipKAFAAAgSgpaAAAAoqSgBQAAIEoKWgAAAKKkoAUAACBKCloAAACipKAFAAAgSgpaAAAAoqSgBQAAIEoKWgAAAKKkoAUAACBKCloAAACipKAFAAAgSgpaAAAAoqSgBQAAIEoKWgAAAKKkoAUAACBKCloAAACipKAFAAAgSjW5XC5X7kkAAABAqVyhBQAAIEoKWgAAAKKkoAUAACBKCloAAACipKAFAAAgSgpaAAAAoqSgBQAAIEoKWgAAAKKkoAUAACCJ0f8PVuZNhiYH2eMAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "test_batch = test_ds.as_numpy_iterator().next()\n", "pred = pred_step(eval_model, test_batch)\n", "\n", "fig, axs = plt.subplots(5, 5, figsize=(12, 12))\n", "for i, ax in enumerate(axs.flatten()):\n", " ax.imshow(test_batch['image'][i, ..., 0], cmap='gray')\n", " ax.set_title(f'label={pred[i]}')\n", " ax.axis('off')" ] }, { "cell_type": "markdown", "id": "65342ab4", "metadata": {}, "source": [ "# 8. Export the model\n", "\n", "Flax models are great for research, but aren't meant to be deployed directly. Instead, high performance inference runtimes like LiteRT or TensorFlow Serving operate on a special [SavedModel](https://www.tensorflow.org/guide/saved_model) format. The [Orbax](https://orbax.readthedocs.io/en/latest/guides/export/orbax_export_101.html) library makes it easy to export Flax models to this format. First, we must create a `JaxModule` object wrapping a model and its prediction method." ] }, { "cell_type": "code", "execution_count": null, "id": "49cace09", "metadata": {}, "outputs": [], "source": [ "from orbax.export import JaxModule, ExportManager, ServingConfig" ] }, { "cell_type": "code", "execution_count": null, "id": "421309d4", "metadata": {}, "outputs": [], "source": [ "def exported_predict(model, y):\n", " return model(y, None)\n", "\n", "jax_module = JaxModule(eval_model, exported_predict)" ] }, { "cell_type": "markdown", "id": "787136af", "metadata": {}, "source": [ "We also need to tell Tensorflow Serving what input type `exported_predict` expects in its second argument. The export machinery expects type signature arguments to be PyTrees of `tf.TensorSpec`." ] }, { "cell_type": "code", "execution_count": null, "id": "9f2ad72e", "metadata": {}, "outputs": [], "source": [ "sig = [tf.TensorSpec(shape=(1, 28, 28, 1), dtype=tf.float32)]" ] }, { "cell_type": "markdown", "id": "31e9668a", "metadata": {}, "source": [ "Finally, we can bundle up the input signature and the `JaxModule` together using the `ExportManager` class." ] }, { "cell_type": "code", "execution_count": null, "id": "18cdf9ad", "metadata": {}, "outputs": [], "source": [ "export_mgr = ExportManager(jax_module, [\n", " ServingConfig('mnist_server', input_signature=sig)\n", "])\n", "\n", "output_dir='/tmp/mnist_export'\n", "export_mgr.save(output_dir)" ] }, { "cell_type": "markdown", "id": "28", "metadata": {}, "source": [ "Congratulations! You have learned how to use Flax NNX to build and train a simple classification model end-to-end on the MNIST dataset.\n", "\n", "Next, check out [Why Flax NNX?](https://flax.readthedocs.io/en/latest/why.html) and get started with a series of [Flax NNX Guides](https://flax.readthedocs.io/en/latest/guides/index.html)." ] } ], "metadata": { "accelerator": "GPU", "jupytext": { "formats": "ipynb,md:myst", "main_language": "python" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: docs_nnx/mnist_tutorial.md ================================================ --- jupytext: formats: ipynb,md:myst main_language: python text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs_nnx/mnist_tutorial.ipynb) [![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs_nnx/mnist_tutorial.ipynb) # MNIST tutorial Welcome to Flax NNX! In this tutorial you will learn how to build and train a simple convolutional neural network (CNN) to classify handwritten digits on the MNIST dataset using the Flax NNX API. Flax NNX is a Python neural network library built upon [JAX](https://github.com/jax-ml/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning. Let’s get started! ## 1. Install Flax If `flax` is not installed in your Python environment, use `pip` to install the package from PyPI (below, just uncomment the code in the cell if you are working from Google Colab/Jupyter Notebook): ```{code-cell} ipython3 # !pip install -U "jax[cuda12]" # !pip install -U flax ``` ## 2. Load the MNIST dataset First, you need to load the MNIST dataset and then prepare the training and testing sets via Tensorflow Datasets (TFDS). You normalize image values, shuffle the data and divide it into batches, and prefetch samples to enhance performance. ```{code-cell} ipython3 import tensorflow_datasets as tfds # TFDS to download MNIST. import tensorflow as tf # TensorFlow / `tf.data` operations. tf.random.set_seed(0) # Set the random seed for reproducibility. train_steps = 1200 eval_every = 200 batch_size = 32 train_ds: tf.data.Dataset = tfds.load('mnist', split='train') test_ds: tf.data.Dataset = tfds.load('mnist', split='test') train_ds = train_ds.map( lambda sample: { 'image': tf.cast(sample['image'], tf.float32) / 255, 'label': sample['label'], } ) # normalize train set test_ds = test_ds.map( lambda sample: { 'image': tf.cast(sample['image'], tf.float32) / 255, 'label': sample['label'], } ) # Normalize the test set. # Create a shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from. train_ds = train_ds.repeat().shuffle(1024) # Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency. train_ds = train_ds.batch(batch_size, drop_remainder=True).take(train_steps).prefetch(1) # Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency. test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) ``` ## 3. Define the model with Flax NNX Create a CNN for classification with Flax NNX by subclassing `nnx.Module`: ```{code-cell} ipython3 from flax import nnx # The Flax NNX API. from functools import partial from typing import Optional class CNN(nnx.Module): """A simple CNN model.""" def __init__(self, *, rngs: nnx.Rngs): self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs) self.batch_norm1 = nnx.BatchNorm(32, rngs=rngs) self.dropout1 = nnx.Dropout(rate=0.025) self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs) self.batch_norm2 = nnx.BatchNorm(64, rngs=rngs) self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2)) self.linear1 = nnx.Linear(3136, 256, rngs=rngs) self.dropout2 = nnx.Dropout(rate=0.025) self.linear2 = nnx.Linear(256, 10, rngs=rngs) def __call__(self, x, rngs: nnx.Rngs | None = None): x = self.avg_pool(nnx.relu(self.batch_norm1(self.dropout1(self.conv1(x), rngs=rngs)))) x = self.avg_pool(nnx.relu(self.batch_norm2(self.conv2(x)))) x = x.reshape(x.shape[0], -1) # flatten x = nnx.relu(self.dropout2(self.linear1(x), rngs=rngs)) x = self.linear2(x) return x # Instantiate the model. model = CNN(rngs=nnx.Rngs(0)) # Visualize it. nnx.display(model) ``` ### Run the model Let's put the CNN model to the test! Here, you’ll perform a forward pass with arbitrary data and print the results. ```{code-cell} ipython3 import jax.numpy as jnp # JAX NumPy y = model(jnp.ones((1, 28, 28, 1)), nnx.Rngs(0)) y ``` ## 4. Create the optimizer and define some metrics In Flax NNX, you need to create an `nnx.Optimizer` object to manage the model's parameters and apply gradients during training. `nnx.Optimizer` receives the model's reference, so that it can update its parameters, and an [Optax](https://optax.readthedocs.io/) optimizer to define the update rules. Additionally, you will define an `nnx.MultiMetric` object to keep track of the `Accuracy` and the `Average` loss. ```{code-cell} ipython3 import optax learning_rate = 0.005 momentum = 0.9 optimizer = nnx.Optimizer( model, optax.adamw(learning_rate, momentum), wrt=nnx.Param ) metrics = nnx.MultiMetric( accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average('loss'), ) nnx.display(optimizer) ``` ## 5. Define training step functions In this section, you will define a loss function using the cross entropy loss ([`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that the CNN model will optimize over. In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric. During training — the `train_step` — you will use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the `optimizer` you have already defined. The `train_step` also receives an `nnx.Rngs` object to provide randomness for dropout. The `eval_step` omits `rngs` because the eval view already has `deterministic=True`, so dropout is disabled and no random key is needed. During both steps, the `loss` and `logits` are used to update the metrics. ```{code-cell} ipython3 def loss_fn(model: CNN, batch, rngs: nnx.Rngs | None = None): logits = model(batch['image'], rngs) loss = optax.softmax_cross_entropy_with_integer_labels( logits=logits, labels=batch['label'] ).mean() return loss, logits @nnx.jit def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, rngs: nnx.Rngs, batch): """Train for a single step.""" grad_fn = nnx.value_and_grad(loss_fn, has_aux=True) (loss, logits), grads = grad_fn(model, batch, rngs) metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates. optimizer.update(model, grads) # In-place updates. @nnx.jit def eval_step(model: CNN, metrics: nnx.MultiMetric, batch): loss, logits = loss_fn(model, batch) metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates. ``` In the code above, the [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transformation decorator traces the `train_step` function for just-in-time compilation with [XLA](https://www.tensorflow.org/xla), optimizing performance on hardware accelerators, such as Google TPUs and GPUs. `nnx.jit` is a stateful version of the `jax.jit` transform that allows its function input and outputs to be Flax NNX objects. Similarly, `nnx.value_and_grad` is a stateful version of `jax.value_and_grad`. Check out [the transforms guide](https://flax.readthedocs.io/en/latest/guides/transforms.html) to learn more. > **Note:** The code shows how to perform several in-place updates to the model, the optimizer, the RNG streams, and the metrics, but _state updates_ were not explicitly returned. This is because Flax NNX transformations respect _reference semantics_ for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for a more concise and readable code. You can learn more in [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). ## 6. Train and evaluate the model Now, you can train the CNN model. Before the training loop, we use [`nnx.view`](https://flax.readthedocs.io/en/latest/guides/view.html) to create a `train_model` (with dropout enabled and batch norm in training mode) and an `eval_model` (with dropout disabled and batch norm using running statistics). These views share the same underlying weights, so updates during training are automatically reflected during evaluation. ```{code-cell} ipython3 from IPython.display import clear_output import matplotlib.pyplot as plt metrics_history = { 'train_loss': [], 'train_accuracy': [], 'test_loss': [], 'test_accuracy': [], } rngs = nnx.Rngs(0) train_model = nnx.view(model, deterministic=False, use_running_average=False) eval_model = nnx.view(model, deterministic=True, use_running_average=True) for step, batch in enumerate(train_ds.as_numpy_iterator()): # Run the optimization for one step and make a stateful update to the following: # - The train state's model parameters # - The optimizer state # - The training loss and accuracy batch metrics train_step(train_model, optimizer, metrics, rngs, batch) if step > 0 and (step % eval_every == 0 or step == train_steps - 1): # One training epoch has passed. # Log the training metrics. for metric, value in metrics.compute().items(): # Compute the metrics. metrics_history[f'train_{metric}'].append(value) # Record the metrics. metrics.reset() # Reset the metrics for the test set. # Compute the metrics on the test set after each training epoch. for test_batch in test_ds.as_numpy_iterator(): eval_step(eval_model, metrics, test_batch) # Log the test metrics. for metric, value in metrics.compute().items(): metrics_history[f'test_{metric}'].append(value) metrics.reset() # Reset the metrics for the next training epoch. clear_output(wait=True) # Plot loss and accuracy in subplots fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5)) ax1.set_title('Loss') ax2.set_title('Accuracy') for dataset in ('train', 'test'): ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss') ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy') ax1.legend() ax2.legend() plt.show() ``` ## 7. Perform inference on the test set Create a `jit`-compiled model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. Since we already have `eval_model` (an `nnx.view` with `deterministic=True` and `use_running_average=True`), we can use it directly for inference. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance. ```{code-cell} ipython3 @nnx.jit def pred_step(model: CNN, batch): logits = model(batch['image'], None) return logits.argmax(axis=1) ``` We reuse the `eval_model` view created earlier so that `Dropout` is disabled and `BatchNorm` uses stored running stats during inference. ```{code-cell} ipython3 test_batch = test_ds.as_numpy_iterator().next() pred = pred_step(eval_model, test_batch) fig, axs = plt.subplots(5, 5, figsize=(12, 12)) for i, ax in enumerate(axs.flatten()): ax.imshow(test_batch['image'][i, ..., 0], cmap='gray') ax.set_title(f'label={pred[i]}') ax.axis('off') ``` # 8. Export the model Flax models are great for research, but aren't meant to be deployed directly. Instead, high performance inference runtimes like LiteRT or TensorFlow Serving operate on a special [SavedModel](https://www.tensorflow.org/guide/saved_model) format. The [Orbax](https://orbax.readthedocs.io/en/latest/guides/export/orbax_export_101.html) library makes it easy to export Flax models to this format. First, we must create a `JaxModule` object wrapping a model and its prediction method. ```{code-cell} ipython3 from orbax.export import JaxModule, ExportManager, ServingConfig ``` ```{code-cell} ipython3 def exported_predict(model, y): return model(y, None) jax_module = JaxModule(eval_model, exported_predict) ``` We also need to tell Tensorflow Serving what input type `exported_predict` expects in its second argument. The export machinery expects type signature arguments to be PyTrees of `tf.TensorSpec`. ```{code-cell} ipython3 sig = [tf.TensorSpec(shape=(1, 28, 28, 1), dtype=tf.float32)] ``` Finally, we can bundle up the input signature and the `JaxModule` together using the `ExportManager` class. ```{code-cell} ipython3 export_mgr = ExportManager(jax_module, [ ServingConfig('mnist_server', input_signature=sig) ]) output_dir='/tmp/mnist_export' export_mgr.save(output_dir) ``` Congratulations! You have learned how to use Flax NNX to build and train a simple classification model end-to-end on the MNIST dataset. Next, check out [Why Flax NNX?](https://flax.readthedocs.io/en/latest/why.html) and get started with a series of [Flax NNX Guides](https://flax.readthedocs.io/en/latest/guides/index.html). ================================================ FILE: docs_nnx/nnx_basics.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Flax basics\n", "\n", "Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in [JAX](https://jax.readthedocs.io/). It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home.\n", "\n", "To begin, install Flax with `pip` and import necessary dependencies:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "tags": [ "skip-execution" ] }, "outputs": [], "source": [ "# ! pip install -U flax" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from flax import nnx\n", "import jax\n", "import jax.numpy as jnp" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Flax NNX Module system\n", "\n", "The main difference between the Flax `Module` and other Module systems in [Flax Linen](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html) or [Haiku](https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html#Built-in-Haiku-nets-and-nested-modules) is that in NNX everything is **explicit**. This means, among other things, that the NNX Module itself holds the state (such as parameters) directly, the [PRNG](https://jax.readthedocs.io/en/latest/random-numbers.html) state is threaded by the user, and all shape information must be provided on initialization (no shape inference).\n", "\n", "Let's begin by creating a Linear `Module`. As shown next, dynamic state is usually stored in `Param`s, and static state (all types not handled by NNX) such as integers or strings are stored directly. Attributes of type `jax.Array` and `numpy.ndarray` are also treated as dynamic states, although storing them inside Variables, such as Param, is preferred. Also the `Rngs` object can be used to get new unique keys based on a root PRNG key passed to the constructor." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class Linear(nnx.Module):\n", " def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n", " self.w = nnx.Param(rngs.params.uniform((din, dout)))\n", " self.b = nnx.Param(jnp.zeros((dout,)))\n", " self.din, self.dout = din, dout\n", "\n", " def __call__(self, x: jax.Array):\n", " return x @ self.w + self.b[None]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Also note that the inner values of `Variable`s can be accessed using the `value` property, but for convenience they implement all numeric operators and can be used directly in arithmetic expressions (as shown in the code above).\n", "\n", "To initialize a Flax `Module`, you just call the constructor, and all the parameters of a Module are usually created eagerly. Since Modules hold their own state methods, you can call them directly without the need for a separate apply method.\n", "This can be very convenient for debugging, allowing you to directly inspect the entire structure of the model." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[1.5643291 0.94782424 0.37971854 1.0724319 0.22112393]]\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = Linear(2, 5, rngs=nnx.Rngs(params=0))\n", "y = model(x=jnp.ones((1, 2)))\n", "\n", "print(y)\n", "nnx.display(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above visualization by `nnx.display` is generated using the awesome\n", "[Treescope](https://treescope.readthedocs.io/en/stable/index.html#) library." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Stateful computation\n", "\n", "Implementing layers, such as `BatchNorm`, requires performing state updates during a forward pass. In Flax NNX, you just need to create a `Variable` and update its `.value` during the forward pass." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "counter.count[...] = Array(0, dtype=int32, weak_type=True)\n", "counter.count[...] = Array(1, dtype=int32, weak_type=True)\n" ] } ], "source": [ "class Count(nnx.Variable): pass\n", "\n", "class Counter(nnx.Module):\n", " def __init__(self):\n", " self.count = Count(jnp.array(0))\n", "\n", " def __call__(self):\n", " self.count[...] += 1\n", "\n", "counter = Counter()\n", "print(f'{counter.count[...] = }')\n", "counter()\n", "print(f'{counter.count[...] = }')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Mutable references are usually avoided in JAX. But Flax NNX provides sound mechanisms\n", "to handle them, as demonstrated in later sections of this guide." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Nested Modules\n", "\n", "Flax `Module`s can be used to compose other Modules in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a `list`, `dict`, `tuple`, and so on.\n", "\n", "The example below shows how to define a simple `MLP` by subclassing `Module`. The model consists of two `Linear` layers, a `Dropout` layer, and a `BatchNorm` layer. Note that we need to pass the `__call__` method the RNG state that we want the `Dropout` layer to use." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "class MLP(nnx.Module):\n", " def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):\n", " self.linear1 = Linear(din, dmid, rngs=rngs)\n", " self.dropout = nnx.Dropout(rate=0.1)\n", " self.bn = nnx.BatchNorm(dmid, rngs=rngs)\n", " self.linear2 = Linear(dmid, dout, rngs=rngs)\n", "\n", " def __call__(self, x: jax.Array, rngs: nnx.Rngs):\n", " x = nnx.gelu(self.dropout(self.bn(self.linear1(x)), rngs=rngs))\n", " return self.linear2(x)\n", "\n", "model = MLP(2, 16, 5, rngs=nnx.Rngs(0))\n", "\n", "y = model(x=jnp.ones((3, 2)), rngs=nnx.Rngs(1))\n", "\n", "nnx.display(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model surgery\n", "\n", "Flax `Module`s are mutable by default. This means that their structure can be changed at any time, which makes [model surgery](https://flax.readthedocs.io/en/latest/guides/surgery.html) quite easy, as any sub-Module attribute can be replaced with anything else, such as new Modules, existing shared Modules, Modules of different types, and so on. Moreover, `Variable`s can also be modified or replaced/shared.\n", "\n", "The following example shows how to replace the `Linear` layers in the `MLP` model from the previous example with `LoraLinear` layers:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "class LoraParam(nnx.Param): pass\n", "\n", "class LoraLinear(nnx.Module):\n", " def __init__(self, linear: Linear, rank: int, rngs: nnx.Rngs):\n", " self.linear = linear\n", " self.A = LoraParam(rngs.normal((linear.din, rank)))\n", " self.B = LoraParam(rngs.normal((rank, linear.dout)))\n", "\n", " def __call__(self, x: jax.Array):\n", " return self.linear(x) + x @ self.A @ self.B\n", "\n", "rngs = nnx.Rngs(0)\n", "model = MLP(2, 32, 5, rngs=rngs)\n", "\n", "# Model surgery.\n", "model.linear1 = LoraLinear(model.linear1, 4, rngs=rngs)\n", "model.linear2 = LoraLinear(model.linear2, 4, rngs=rngs)\n", "\n", "y = model(x=jnp.ones((3, 2)), rngs=rngs)\n", "\n", "nnx.display(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Flax transformations\n", "\n", "[Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html) extend [JAX transforms](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations) to support `Module`s and other objects. They serve as supersets of their equivalent JAX counterparts with the addition of being aware of the object's state and providing additional APIs to transform it.\n", "\n", "One of the main features of Flax Transforms is the preservation of reference semantics, meaning that any mutation of the object graph that occurs inside the transform is propagated outside as long as it is legal within the transform rules. In practice this means that Flax programs can be expressed using imperative code, highly simplifying the user experience.\n", "\n", "In the following example, you define a `train_step` function that takes a `MLP` model, an `Optimizer`, and a batch of data, and returns the loss for that step. The loss and the gradients are computed using the `nnx.value_and_grad` transform over the `loss_fn`. The gradients are passed to the optimizer's `update` method to update the model's parameters." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss = Array(1.0000616, dtype=float32)\n", "optimizer.step.value = Array(1, dtype=uint32)\n" ] } ], "source": [ "import optax\n", "\n", "# An MLP containing 2 custom `Linear` layers, 1 `nnx.Dropout` layer, 1 `nnx.BatchNorm` layer.\n", "model = MLP(2, 16, 10, rngs=nnx.Rngs(0))\n", "optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", "\n", "@nnx.jit # Automatic state management\n", "def train_step(model, optimizer, x, y, rngs):\n", " def loss_fn(model: MLP, rngs: nnx.Rngs):\n", " y_pred = model(x, rngs)\n", " return jnp.mean((y_pred - y) ** 2)\n", "\n", " loss, grads = nnx.value_and_grad(loss_fn)(model, rngs)\n", " optimizer.update(model, grads) # In place updates.\n", "\n", " return loss\n", "\n", "x, y = jnp.ones((5, 2)), jnp.ones((5, 10))\n", "loss = train_step(model, optimizer, x, y, rngs)\n", "\n", "print(f'{loss = }')\n", "print(f'{optimizer.step.value = }')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are two things happening in this example that are worth mentioning:\n", "\n", "1. The updates to each of the `BatchNorm` and `Dropout` layer's state is automatically propagated from within `loss_fn` to `train_step` all the way to the `model` reference outside.\n", "2. The `optimizer` holds a mutable reference to the model - this relationship is preserved inside the train_step function making it possible to update the model's parameters using the optimizer alone.\n", "\n", "> **Note**
`nnx.jit` has performance overhead for small models, check the [Performance Considerations](https://flax.readthedocs.io/en/latest/guides/performance.html) guide for more information.\n", "\n", "### Scan over layers\n", "\n", "The next example uses Flax `nnx.vmap` to create a stack of multiple MLP layers and `nnx.scan` to iteratively apply each layer of the stack to the input.\n", "\n", "In the code below notice the following:\n", "\n", "1. The custom `create_model` function takes in a key and returns an `MLP` object, since you create five keys and use `nnx.vmap` over `create_model` a stack of 5 `MLP` objects is created.\n", "2. The `nnx.scan` is used to iteratively apply each `MLP` in the stack to the input `x`.\n", "3. The nnx.scan (consciously) deviates from `jax.lax.scan` and instead mimics nnx.vmap, which is more expressive. nnx.scan allows specifying multiple inputs, the scan axes of each input/output, and the position of the carry.\n", "4. `State` updates for `BatchNorm` layers are automatically propagated by nnx.scan.\n", "5. The `rngs` object is split into separate streams for each layer using the `fork` method." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "@nnx.vmap(in_axes=0, out_axes=0)\n", "def create_model(rngs):\n", " return MLP(10, 32, 10, rngs=rngs)\n", "\n", "@nnx.scan(in_axes=(0, 0, nnx.Carry), out_axes=nnx.Carry)\n", "def forward(model: MLP, rngs: nnx.Rngs, x):\n", " x = model(x, rngs)\n", " return x\n", " \n", "param_rngs = nnx.Rngs(0).fork(split=5)\n", "model = create_model(param_rngs)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y.shape = (3, 10)\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "x = jnp.ones((3, 10))\n", "dropout_rngs = nnx.Rngs(1).fork(split=5)\n", "y = forward(model, dropout_rngs, x)\n", "\n", "print(f'{y.shape = }')\n", "nnx.display(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "How do Flax NNX transforms achieve this? To understand how Flax NNX objects interact with JAX transforms, the next section explains the Flax NNX Functional API." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Flax Functional API\n", "\n", "The Flax NNX Functional API establishes a clear boundary between reference/object semantics and value/pytree semantics. It also allows the same amount of fine-grained control over the state that Flax Linen and Haiku users are used to. The Flax NNX Functional API consists of three basic methods: `nnx.split`, `nnx.merge`, and `nnx.update`.\n", "\n", "Below is an example of `StatefulLinear` `Module` that uses the Functional API. It contains:\n", "\n", "- Some `Param` Variables; and\n", "- A custom `Count` Variable type, which is used to track the integer scalar state that increases on every forward pass." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "class Count(nnx.Variable): pass\n", "\n", "class StatefulLinear(nnx.Module):\n", " def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n", " self.w = nnx.Param(rngs.uniform((din, dout)))\n", " self.b = nnx.Param(jnp.zeros((dout,)))\n", " self.count = Count(jnp.array(0, dtype=jnp.uint32))\n", "\n", " def __call__(self, x: jax.Array):\n", " self.count.value += 1\n", " return x @ self.w + self.b\n", "\n", "model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))\n", "y = model(jnp.ones((1, 3)))\n", "\n", "nnx.display(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### State and GraphDef\n", "\n", "A Flax `Module` can be decomposed into `State` and `GraphDef` using the `nnx.split` function:\n", "\n", "- `State` is a `Mapping` from strings to `Variable`s or nested `State`s.\n", "- `GraphDef` contains all the static information needed to reconstruct a `Module` graph, it is analogous to [JAX's `PyTreeDef`](https://jax.readthedocs.io/en/latest/pytrees.html#internal-pytree-handling)." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "graphdef, state = nnx.split(model)\n", "\n", "nnx.display(graphdef, state)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Split, merge, and update\n", "\n", "Flax's `nnx.merge` is the reverse of `nnx.split`. It takes the `GraphDef` + `State` and reconstructs the `Module`. The example below demonstrates this as follows:\n", "\n", "- By using `nnx.split` and `nnx.merge` in sequence any `Module` can be lifted to be used in any JAX transform.\n", "- `nnx.update` can update an object in place with the content of a given `State`.\n", "- This pattern is used to propagate the state from a transform back to the source object outside." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "model.count.value = Array(1, dtype=uint32)\n", "model.count.value = Array(2, dtype=uint32)\n" ] } ], "source": [ "print(f'{model.count.value = }')\n", "\n", "# 1. Use `nnx.split` to create a pytree representation of the `nnx.Module`.\n", "graphdef, state = nnx.split(model)\n", "\n", "@jax.jit\n", "def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array) -> tuple[jax.Array, nnx.State]:\n", " # 2. Use `nnx.merge` to create a new model inside the JAX transformation.\n", " model = nnx.merge(graphdef, state)\n", " # 3. Call the `nnx.Module`\n", " y = model(x)\n", " # 4. Use `nnx.split` to propagate `nnx.State` updates.\n", " _, state = nnx.split(model)\n", " return y, state\n", "\n", "y, state = forward(graphdef, state, x=jnp.ones((1, 3)))\n", "# 5. Update the state of the original `nnx.Module`.\n", "nnx.update(model, state)\n", "\n", "print(f'{model.count.value = }')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The key insight of this pattern is that using mutable references is fine within a transform context (including the base eager interpreter) but it is necessary to use the Functional API when crossing boundaries.\n", "\n", "**Why aren't Modules just pytrees?** The main reason is that it is very easy to lose track of shared references by accident this way, for example if you pass two `Module`s that have a shared Module through a JAX boundary, you will silently lose that sharing. Flax's Functional API makes this behavior explicit, and thus it is much easier to reason about." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fine-grained State control\n", "\n", "Experienced [Flax Linen](https://flax-linen.readthedocs.io/) or [Haiku](https://dm-haiku.readthedocs.io/) API users may recognize that having all the states in a single structure is not always the best choice as there are cases in which you may want to handle different subsets of the state differently. This is a common occurrence when interacting with [JAX transforms](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations).\n", "\n", "For example:\n", "\n", "- Not every model state can or should be differentiated when interacting with `jax.grad`.\n", "- Or, sometimes, there is a need to specify what part of the model's state is a carry and what part is not when using `jax.lax.scan`.\n", "\n", "To address this, the Flax NNX API has `nnx.split`, which allows you to pass one or more `Filter`s to partition the `Variable`s into mutually exclusive `State`s. Flax NNx uses `Filter` create `State` groups in APIs (such as `nnx.split`, `nnx.state`, and many of NNX transforms).\n", "\n", "The example below shows the most common `Filter`s:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Use `nnx.Variable` type `Filter`s to split into multiple `nnx.State`s.\n", "graphdef, params, counts = nnx.split(model, nnx.Param, Count)\n", "\n", "nnx.display(params, counts)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Note:** `Filter`s must be exhaustive, if a value is not matched an error will be raised.\n", "\n", "As expected, the `nnx.merge` and `nnx.update` methods naturally consume multiple `State`s:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Merge multiple `State`s\n", "model = nnx.merge(graphdef, params, counts)\n", "# Update with multiple `State`s\n", "nnx.update(model, params, counts)" ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: docs_nnx/nnx_basics.md ================================================ --- jupytext: formats: ipynb,md:myst text_representation: extension: .md format_name: myst format_version: 0.13 jupytext_version: 1.13.8 --- # Flax basics Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in [JAX](https://jax.readthedocs.io/). It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home. To begin, install Flax with `pip` and import necessary dependencies: ```{code-cell} ipython3 :tags: [skip-execution] # ! pip install -U flax ``` ```{code-cell} ipython3 from flax import nnx import jax import jax.numpy as jnp ``` ## The Flax NNX Module system The main difference between the Flax `Module` and other Module systems in [Flax Linen](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html) or [Haiku](https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html#Built-in-Haiku-nets-and-nested-modules) is that in NNX everything is **explicit**. This means, among other things, that the NNX Module itself holds the state (such as parameters) directly, the [PRNG](https://jax.readthedocs.io/en/latest/random-numbers.html) state is threaded by the user, and all shape information must be provided on initialization (no shape inference). Let's begin by creating a Linear `Module`. As shown next, dynamic state is usually stored in `Param`s, and static state (all types not handled by NNX) such as integers or strings are stored directly. Attributes of type `jax.Array` and `numpy.ndarray` are also treated as dynamic states, although storing them inside Variables, such as Param, is preferred. Also the `Rngs` object can be used to get new unique keys based on a root PRNG key passed to the constructor. ```{code-cell} ipython3 class Linear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.w = nnx.Param(rngs.params.uniform((din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) self.din, self.dout = din, dout def __call__(self, x: jax.Array): return x @ self.w + self.b[None] ``` Also note that the inner values of `Variable`s can be accessed using the `value` property, but for convenience they implement all numeric operators and can be used directly in arithmetic expressions (as shown in the code above). To initialize a Flax `Module`, you just call the constructor, and all the parameters of a Module are usually created eagerly. Since Modules hold their own state methods, you can call them directly without the need for a separate apply method. This can be very convenient for debugging, allowing you to directly inspect the entire structure of the model. ```{code-cell} ipython3 model = Linear(2, 5, rngs=nnx.Rngs(params=0)) y = model(x=jnp.ones((1, 2))) print(y) nnx.display(model) ``` The above visualization by `nnx.display` is generated using the awesome [Treescope](https://treescope.readthedocs.io/en/stable/index.html#) library. +++ ### Stateful computation Implementing layers, such as `BatchNorm`, requires performing state updates during a forward pass. In Flax NNX, you just need to create a `Variable` and update its `.value` during the forward pass. ```{code-cell} ipython3 class Count(nnx.Variable): pass class Counter(nnx.Module): def __init__(self): self.count = Count(jnp.array(0)) def __call__(self): self.count[...] += 1 counter = Counter() print(f'{counter.count[...] = }') counter() print(f'{counter.count[...] = }') ``` Mutable references are usually avoided in JAX. But Flax NNX provides sound mechanisms to handle them, as demonstrated in later sections of this guide. +++ ### Nested Modules Flax `Module`s can be used to compose other Modules in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a `list`, `dict`, `tuple`, and so on. The example below shows how to define a simple `MLP` by subclassing `Module`. The model consists of two `Linear` layers, a `Dropout` layer, and a `BatchNorm` layer. Note that we need to pass the `__call__` method the RNG state that we want the `Dropout` layer to use. ```{code-cell} ipython3 class MLP(nnx.Module): def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): self.linear1 = Linear(din, dmid, rngs=rngs) self.dropout = nnx.Dropout(rate=0.1) self.bn = nnx.BatchNorm(dmid, rngs=rngs) self.linear2 = Linear(dmid, dout, rngs=rngs) def __call__(self, x: jax.Array, rngs: nnx.Rngs): x = nnx.gelu(self.dropout(self.bn(self.linear1(x)), rngs=rngs)) return self.linear2(x) model = MLP(2, 16, 5, rngs=nnx.Rngs(0)) y = model(x=jnp.ones((3, 2)), rngs=nnx.Rngs(1)) nnx.display(model) ``` ### Model surgery Flax `Module`s are mutable by default. This means that their structure can be changed at any time, which makes [model surgery](https://flax.readthedocs.io/en/latest/guides/surgery.html) quite easy, as any sub-Module attribute can be replaced with anything else, such as new Modules, existing shared Modules, Modules of different types, and so on. Moreover, `Variable`s can also be modified or replaced/shared. The following example shows how to replace the `Linear` layers in the `MLP` model from the previous example with `LoraLinear` layers: ```{code-cell} ipython3 class LoraParam(nnx.Param): pass class LoraLinear(nnx.Module): def __init__(self, linear: Linear, rank: int, rngs: nnx.Rngs): self.linear = linear self.A = LoraParam(rngs.normal((linear.din, rank))) self.B = LoraParam(rngs.normal((rank, linear.dout))) def __call__(self, x: jax.Array): return self.linear(x) + x @ self.A @ self.B rngs = nnx.Rngs(0) model = MLP(2, 32, 5, rngs=rngs) # Model surgery. model.linear1 = LoraLinear(model.linear1, 4, rngs=rngs) model.linear2 = LoraLinear(model.linear2, 4, rngs=rngs) y = model(x=jnp.ones((3, 2)), rngs=rngs) nnx.display(model) ``` ## Flax transformations [Flax NNX transformations (transforms)](https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html) extend [JAX transforms](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations) to support `Module`s and other objects. They serve as supersets of their equivalent JAX counterparts with the addition of being aware of the object's state and providing additional APIs to transform it. One of the main features of Flax Transforms is the preservation of reference semantics, meaning that any mutation of the object graph that occurs inside the transform is propagated outside as long as it is legal within the transform rules. In practice this means that Flax programs can be expressed using imperative code, highly simplifying the user experience. In the following example, you define a `train_step` function that takes a `MLP` model, an `Optimizer`, and a batch of data, and returns the loss for that step. The loss and the gradients are computed using the `nnx.value_and_grad` transform over the `loss_fn`. The gradients are passed to the optimizer's `update` method to update the model's parameters. ```{code-cell} ipython3 import optax # An MLP containing 2 custom `Linear` layers, 1 `nnx.Dropout` layer, 1 `nnx.BatchNorm` layer. model = MLP(2, 16, 10, rngs=nnx.Rngs(0)) optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) @nnx.jit # Automatic state management def train_step(model, optimizer, x, y, rngs): def loss_fn(model: MLP, rngs: nnx.Rngs): y_pred = model(x, rngs) return jnp.mean((y_pred - y) ** 2) loss, grads = nnx.value_and_grad(loss_fn)(model, rngs) optimizer.update(model, grads) # In place updates. return loss x, y = jnp.ones((5, 2)), jnp.ones((5, 10)) loss = train_step(model, optimizer, x, y, rngs) print(f'{loss = }') print(f'{optimizer.step.value = }') ``` There are two things happening in this example that are worth mentioning: 1. The updates to each of the `BatchNorm` and `Dropout` layer's state is automatically propagated from within `loss_fn` to `train_step` all the way to the `model` reference outside. 2. The `optimizer` holds a mutable reference to the model - this relationship is preserved inside the train_step function making it possible to update the model's parameters using the optimizer alone. > **Note**
`nnx.jit` has performance overhead for small models, check the [Performance Considerations](https://flax.readthedocs.io/en/latest/guides/performance.html) guide for more information. ### Scan over layers The next example uses Flax `nnx.vmap` to create a stack of multiple MLP layers and `nnx.scan` to iteratively apply each layer of the stack to the input. In the code below notice the following: 1. The custom `create_model` function takes in a key and returns an `MLP` object, since you create five keys and use `nnx.vmap` over `create_model` a stack of 5 `MLP` objects is created. 2. The `nnx.scan` is used to iteratively apply each `MLP` in the stack to the input `x`. 3. The nnx.scan (consciously) deviates from `jax.lax.scan` and instead mimics nnx.vmap, which is more expressive. nnx.scan allows specifying multiple inputs, the scan axes of each input/output, and the position of the carry. 4. `State` updates for `BatchNorm` layers are automatically propagated by nnx.scan. 5. The `rngs` object is split into separate streams for each layer using the `fork` method. ```{code-cell} ipython3 @nnx.vmap(in_axes=0, out_axes=0) def create_model(rngs): return MLP(10, 32, 10, rngs=rngs) @nnx.scan(in_axes=(0, 0, nnx.Carry), out_axes=nnx.Carry) def forward(model: MLP, rngs: nnx.Rngs, x): x = model(x, rngs) return x param_rngs = nnx.Rngs(0).fork(split=5) model = create_model(param_rngs) ``` ```{code-cell} ipython3 x = jnp.ones((3, 10)) dropout_rngs = nnx.Rngs(1).fork(split=5) y = forward(model, dropout_rngs, x) print(f'{y.shape = }') nnx.display(model) ``` How do Flax NNX transforms achieve this? To understand how Flax NNX objects interact with JAX transforms, the next section explains the Flax NNX Functional API. +++ ## The Flax Functional API The Flax NNX Functional API establishes a clear boundary between reference/object semantics and value/pytree semantics. It also allows the same amount of fine-grained control over the state that Flax Linen and Haiku users are used to. The Flax NNX Functional API consists of three basic methods: `nnx.split`, `nnx.merge`, and `nnx.update`. Below is an example of `StatefulLinear` `Module` that uses the Functional API. It contains: - Some `Param` Variables; and - A custom `Count` Variable type, which is used to track the integer scalar state that increases on every forward pass. ```{code-cell} ipython3 class Count(nnx.Variable): pass class StatefulLinear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.w = nnx.Param(rngs.uniform((din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) self.count = Count(jnp.array(0, dtype=jnp.uint32)) def __call__(self, x: jax.Array): self.count.value += 1 return x @ self.w + self.b model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0)) y = model(jnp.ones((1, 3))) nnx.display(model) ``` ### State and GraphDef A Flax `Module` can be decomposed into `State` and `GraphDef` using the `nnx.split` function: - `State` is a `Mapping` from strings to `Variable`s or nested `State`s. - `GraphDef` contains all the static information needed to reconstruct a `Module` graph, it is analogous to [JAX's `PyTreeDef`](https://jax.readthedocs.io/en/latest/pytrees.html#internal-pytree-handling). ```{code-cell} ipython3 graphdef, state = nnx.split(model) nnx.display(graphdef, state) ``` ### Split, merge, and update Flax's `nnx.merge` is the reverse of `nnx.split`. It takes the `GraphDef` + `State` and reconstructs the `Module`. The example below demonstrates this as follows: - By using `nnx.split` and `nnx.merge` in sequence any `Module` can be lifted to be used in any JAX transform. - `nnx.update` can update an object in place with the content of a given `State`. - This pattern is used to propagate the state from a transform back to the source object outside. ```{code-cell} ipython3 print(f'{model.count.value = }') # 1. Use `nnx.split` to create a pytree representation of the `nnx.Module`. graphdef, state = nnx.split(model) @jax.jit def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array) -> tuple[jax.Array, nnx.State]: # 2. Use `nnx.merge` to create a new model inside the JAX transformation. model = nnx.merge(graphdef, state) # 3. Call the `nnx.Module` y = model(x) # 4. Use `nnx.split` to propagate `nnx.State` updates. _, state = nnx.split(model) return y, state y, state = forward(graphdef, state, x=jnp.ones((1, 3))) # 5. Update the state of the original `nnx.Module`. nnx.update(model, state) print(f'{model.count.value = }') ``` The key insight of this pattern is that using mutable references is fine within a transform context (including the base eager interpreter) but it is necessary to use the Functional API when crossing boundaries. **Why aren't Modules just pytrees?** The main reason is that it is very easy to lose track of shared references by accident this way, for example if you pass two `Module`s that have a shared Module through a JAX boundary, you will silently lose that sharing. Flax's Functional API makes this behavior explicit, and thus it is much easier to reason about. +++ ### Fine-grained State control Experienced [Flax Linen](https://flax-linen.readthedocs.io/) or [Haiku](https://dm-haiku.readthedocs.io/) API users may recognize that having all the states in a single structure is not always the best choice as there are cases in which you may want to handle different subsets of the state differently. This is a common occurrence when interacting with [JAX transforms](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations). For example: - Not every model state can or should be differentiated when interacting with `jax.grad`. - Or, sometimes, there is a need to specify what part of the model's state is a carry and what part is not when using `jax.lax.scan`. To address this, the Flax NNX API has `nnx.split`, which allows you to pass one or more `Filter`s to partition the `Variable`s into mutually exclusive `State`s. Flax NNx uses `Filter` create `State` groups in APIs (such as `nnx.split`, `nnx.state`, and many of NNX transforms). The example below shows the most common `Filter`s: ```{code-cell} ipython3 # Use `nnx.Variable` type `Filter`s to split into multiple `nnx.State`s. graphdef, params, counts = nnx.split(model, nnx.Param, Count) nnx.display(params, counts) ``` **Note:** `Filter`s must be exhaustive, if a value is not matched an error will be raised. As expected, the `nnx.merge` and `nnx.update` methods naturally consume multiple `State`s: ```{code-cell} ipython3 # Merge multiple `State`s model = nnx.merge(graphdef, params, counts) # Update with multiple `State`s nnx.update(model, params, counts) ``` ================================================ FILE: docs_nnx/nnx_glossary.rst ================================================ ***************** Flax NNX glossary ***************** For additional terms, refer to the `JAX glossary `__. .. glossary:: Filter A way to extract only certain :term:`nnx.Variable` objects out of a Flax NNX :term:`Module` (``nnx.Module``). This is usually done by calling :meth:`nnx.split ` upon the :class:`nnx.Module`. Refer to the `Filter guide `__ to learn more. Folding in In Flax, `folding in `__ means generating a new `JAX pseudorandom number generator (PRNG) `__ key, given an input PRNG key and integer. This is typically used when you want to generate a new key but still be able to use the original PRNG key afterwards. You can also do this in JAX with `jax.random.split `__, but this method will effectively create two PRNG keys, which is slower. Learn how Flax generates new PRNG keys automatically in the `Randomness/PRNG guide `__. GraphDef :class:`nnx.GraphDef` is a class that represents all the static, stateless, and Pythonic parts of a Flax :term:`Module` (:class:`nnx.Module`). Merge Refer to :term:`Split and merge`. Module :class:`nnx.Module ` is a dataclass that enables defining and initializing parameters in a referentially-transparent form. It is responsible for storing and updating :term:`Variable objects and parameters within itself. Params / parameters :class:`nnx.Param ` is a particular subclass of :class:`nnx.Variable ` that generally contains the trainable weights. PRNG states A Flax :class:`nnx.Module ` can keep a reference of a `pseudorandom number generator (PRNG) `__ state object :class:`nnx.Rngs ` that can generate new `JAX PRNG `__ keys. These keys are used to generate random JAX arrays through `JAX's functional PRNGs `__. You can use a PRNG state with different seeds to add more fine-grained control to your model (for example, to have independent random numbers for parameters and dropout masks). Refer to the Flax `Randomness/PRNG guide `__ for more details. Split and merge :meth:`nnx.split ` is a way to represent an :class:`nnx.Module ` by two parts: 1) a static Flax NNX :term:`GraphDef ` that captures its Pythonic static information; and 2) one or more :term:`Variable state(s)` that capture its `JAX arrays `__ (``jax.Array``) in the form of `JAX pytrees `__. They can be merged back to the original ``nnx.Module`` using :meth:`nnx.merge `. Transformation A Flax NNX transformation (transform) is a wrapped version of a `JAX transformation `__ that allows the function that is being transformed to take the Flax NNX :term:`Module` (``nnx.Module``) as input or output. For example, a "lifted" version of `jax.jit `__ is :meth:`nnx.jit `. Check out the `Flax NNX transforms guide `__ to learn more. Variable The weights / parameters / data / array :class:`nnx.Variable ` residing in a Flax :term:`Module`. Variables are defined inside modules as :class:`nnx.Variable ` or its subclasses. ================================================ FILE: docs_nnx/philosophy.md ================================================ # The Flax philosophy In no particular order: * Library code should be easy to read and understand. * Prefer duplicating code over a bad abstraction. * Generally, prefer duplicating code over adding options to functions. * Comment-driven design: If it's hard to document your code, consider changing the design. * Unit test-driven design: If it's hard to test your code, consider changing the design. * People start projects by copying an existing implementation — make base implementations excellent. * If we expose an abstraction to our developers, we own the mental overhead. * Developer-facing functional programming abstractions confuse some users, expose them where the benefit is high. * "Read the manual" is not an appropriate response to developer confusion. The framework should guide developers towards good solutions, such as through assertions and error messages. * An unhelpful error message is a bug. * "Debugging is twice as hard as writing the code in the first place. Therefore, if you write the code as cleverly as possible, you are, by definition, not smart enough to debug it." — Brian Kernighan ## Design principles Flax is a neural network library built on [JAX](https://jax.readthedocs.io) that has been adopted by a growing set of users, most notably in the JAX submissions for the MLPerf 0.7 benchmark. Our experience over the last year (and many conversations with users and JAX core devs) has guided a redesign of the API called [Linen](https://github.com/google/flax/blob/main/flax/linen/README.md) ([`flax.linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/index.html)) in response to the following basic design questions. ### How does a neural network library benefit from being built on JAX and leverage JAX’s unique strengths? The world already has TensorFlow and PyTorch, and there’s little need to build a clone of either. We believe that the composable function-transformation approach that JAX takes opens up new frontiers for making neural net code more maintainable, more scalable and more performant than existing libraries. While we strive to offer an API familiar to those experienced with Keras/Sonnet/PyTorch, Linen is fundamentally a functional system for defining neural nets in JAX. Just a few examples of what we believe a JAX-targeted library can enable: - Write models as “single-example” code and introduce batching automatically with [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html). - Automatically handle ragged batches in NLP and other masking issues. - Create efficient compile-time and runtime models by utilizing rematerialized `scan` for massive convolutional networks. - Remove memory headaches by enabling easy rematerialization, reversibility, and model-parallel data sharding. ### How does one interoperate with JAX transformations? Arguably, the entire point of a neural net library is to offer an implicit variable management API to save the user from having to manually thread thousands of variables through a complex tree of functions. However, JAX operates on pure functions. To handle both current and future JAX transforms (configured and composed in any way), Linen Modules are directly “functionalized”, that is, automatically cast in-place as explicit functions of the form: $$f \left( v_{in}, x \right) \rightarrow v_{out}, y$$ Where $v_{in}$ is the variable collections and [PRNG](https://jax.readthedocs.io/en/latest/jep/263-prng.html) state used by the model, $v_{out}$ the mutated output variable collections, $x$ the input data and $y$ the output data. Applying JAX transformations then simply reduces to specifying any argument-specific transform options to the various variable collections and PRNG state. This unleashes the flexibility and strength of [JAX transformations](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) – for example, one can achieve either device-parallel training or per-device ensembling by using [`jax.pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) in different ways, without any explicit library support. Moreover, **within [Modules](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module)**, we expose lightweight wrappers around the complex JAX transforms such as [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html) and [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) that annotate how each variable collection is to be transformed by JAX. Importantly, we handle the nontrivial cases of creating new variables and transformed variables under mapping and loop transforms correctly for initialization and application. ### How are parameters represented, and how do we handle general “differentiable algorithms” that update stateful variables? We follow the JAX functional conventions of storing data in “pytrees”: JAX arrays contained in nested tuples, lists, dictionaries. Because researchers inevitably manually interact with this data, we use nested dictionaries with meaningful default keys and offer several utilities (traversals, etc.) for handling them directly. Linen uses an accelerated version of a Python frozen dictionary that caches its JAX-flattened form to speed up [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html)ted function call overheads. Flax generalizes the operation of a neural net by allowing models to accept collections of several different “kinds”: parameters, batch-norm stats, autoregressive caches, debug information, fine-grained hyperparameters, etc. Each collection is stored in a nested dictionary of the same structure as the model. Importantly, we do *not* conflate these various kinds under the single vague rubric of “state”, but keep different logical types of variables separate that can be treated differently under JAX transformations and under mutations (e.g. training vs prediction). Similarly, we allow for multiple separate named PRNG chains inside [Modules](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) for separate treatment of randomness for different applications such as initialization, dropout, sampling, etc. At every stage the data associated with a neural net is not kept in a custom object hierarchy, but left in an explicit, Python and JAX native form that is easy to introspect and modify. Users have utilized this to map TF and PyTorch checkpoints to Flax, to implement submodel-specific loss terms, and to perform fast model surgery, etc. For saving this data, most Flax examples store these nested dictionaries via the efficient “msgpack” binary format – but as variables are simply Python dicts, you can use any (non-JAX-aware) serialization library directly. ### How does one interoperate with purely functional JAX code? To be broadly useful to the JAX ecosystem, users shouldn’t need to heavily refactor their code in order to add “trainability” for a given numerical task. _“The library should not get in the way.”_ Utilizing purely functional code from within Linen is trivial: [Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) implementations are just JAX code with named variables. Using Linen Modules inside otherwise purely functional code can be as simple as using a single top-level Module transformation to allow initialization and pure application of any JAX program that might contain various trainable sections. ================================================ FILE: docs_nnx/robots.txt ================================================ User-agent: * Disallow: /api_reference/flax.linen/_autosummary/ # for SEO, since Google still indexes this deprecated link Sitemap: https://flax.readthedocs.io/sitemap.xml ================================================ FILE: docs_nnx/why.rst ================================================ Why Flax NNX? ============= In 2020, the Flax team released the Flax Linen API to support modeling research on JAX, with a focus on scaling and performance. We have learned a lot from users since then. The team introduced certain ideas that have proven to be beneficial to users, such as: * Organizing variables into `collections `_. * Automatic and efficient `pseudorandom number generator (PRNG) management `_. * `Variable metadata `_ for `Single Program Multi Data (SPMD) `_ annotations, optimizer metadata, and other use cases. One of the choices the Flax team made was to use functional (``compact``) semantics for neural network programming via lazy initialization of parameters. This made for concise implementation code and aligned the Flax Linen API with Haiku. However, this also meant that the semantics of Modules and variables in Flax were non-Pythonic and often surprising. It also led to implementation complexity and obscured the core ideas of `transformations (transforms) `_ on neural networks. .. testsetup:: Linen, NNX import jax from jax import random, numpy as jnp from flax import nnx import flax.linen as nn Introducing Flax NNX -------------------- Fast forward to 2024, the Flax team developed Flax NNX - an attempt to retain the features that made Flax Linen useful for users, while introducing some new principles. The central idea behind Flax NNX is to introduce reference semantics into JAX. The following are its main features: - **NNX is Pythonic**: Regular Python semantics for Modules, including support for mutability and shared references. - **NNX is simple**: Many of the complex APIs in Flax Linen are either simplified using Python idioms or completely removed. - **Better JAX integration**: Custom NNX transforms adopt the same APIs as the JAX transforms. And with NNX it is easier to use `JAX transforms (higher-order functions) `_ directly. Here is an example of a simple Flax NNX program that illustrates many of the points from above: .. testcode:: NNX from flax import nnx import optax class Model(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dmid, rngs=rngs) self.bn = nnx.BatchNorm(dmid, rngs=rngs) self.dropout = nnx.Dropout(0.2, rngs=rngs) self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x): x = nnx.relu(self.dropout(self.bn(self.linear(x)))) return self.linear_out(x) model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # Eager initialization optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) @nnx.jit # Automatic state management for JAX transforms. def train_step(model, optimizer, x, y): def loss_fn(model): y_pred = model(x) # call methods directly return ((y_pred - y) ** 2).mean() loss, grads = nnx.value_and_grad(loss_fn)(model) optimizer.update(model, grads) # in-place updates return loss Flax NNX's improvements on Linen -------------------------------- The rest of this document uses various examples that demonstrate how Flax NNX improves on Flax Linen. Inspection ^^^^^^^^^^ The first improvement is that Flax NNX Modules are regular Python objects. This means that you can easily construct and inspect ``Module`` objects. On the other hand, Flax Linen Modules are not easy to inspect and debug because they are lazy, which means some attributes are not available upon construction and are only accessible at runtime. .. codediff:: :title: Linen, NNX :sync: class Block(nn.Module): def setup(self): self.linear = nn.Dense(10) block = Block() try: block.linear # AttributeError: "Block" object has no attribute "linear". except AttributeError as e: pass ... --- class Block(nnx.Module): def __init__(self, rngs): self.linear = nnx.Linear(5, 10, rngs=rngs) block = Block(nnx.Rngs(0)) block.linear # Linear( # kernel=Param( # value=Array(shape=(5, 10), dtype=float32) # ), # bias=Param( # value=Array(shape=(10,), dtype=float32) # ), # ... Notice that in the Flax NNX example above, there is no shape inference - both the input and output shapes must be provided to the ``Linear`` ``nnx.Module``. This is a tradeoff that allows for more explicit and predictable behavior. Running computation ^^^^^^^^^^^^^^^^^^^ In Flax Linen, all top-level computation must be done through the ``flax.linen.Module.init`` or ``flax.linen.Module.apply`` methods, and the parameters or any other type of state are handled as a separate structure. This creates an asymmetry between: 1) code that runs inside ``apply`` that can run methods and other ``Module`` objects directly; and 2) code that runs outside of ``apply`` that must use the ``apply`` method. In Flax NNX, there's no special context because parameters are held as attributes and methods can be called directly. That means your NNX Module's ``__init__`` and ``__call__`` methods are not treated differently from other class methods, whereas Flax Linen Module's ``setup()`` and ``__call__`` methods are special. .. codediff:: :title: Linen, NNX :sync: Encoder = lambda: nn.Dense(10) Decoder = lambda: nn.Dense(2) class AutoEncoder(nn.Module): def setup(self): self.encoder = Encoder() self.decoder = Decoder() def __call__(self, x) -> jax.Array: return self.decoder(self.encoder(x)) def encode(self, x) -> jax.Array: return self.encoder(x) x = jnp.ones((1, 2)) model = AutoEncoder() params = model.init(random.key(0), x)['params'] y = model.apply({'params': params}, x) z = model.apply({'params': params}, x, method='encode') y = Decoder().apply({'params': params['decoder']}, z) --- Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs) Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs) class AutoEncoder(nnx.Module): def __init__(self, rngs): self.encoder = Encoder(rngs) self.decoder = Decoder(rngs) def __call__(self, x) -> jax.Array: return self.decoder(self.encoder(x)) def encode(self, x) -> jax.Array: return self.encoder(x) x = jnp.ones((1, 2)) model = AutoEncoder(nnx.Rngs(0)) y = model(x) z = model.encode(x) y = model.decoder(z) In Flax Linen, calling sub-Modules directly is not possible because they are not initialized. Therefore, what you must do is construct a new instance and then provide a proper parameter structure. But in Flax NNX you can call sub-Modules directly without any issues. State handling ^^^^^^^^^^^^^^ One of the areas where Flax Linen is notoriously complex is in state handling. When you use either a `Dropout` layer, a `BatchNorm` layer, or both, you suddenly have to handle the new state and use it to configure the ``flax.linen.Module.apply`` method. In Flax NNX, state is kept inside an ``nnx.Module`` and is mutable, which means it can just be called directly. .. codediff:: :title: Linen, NNX :sync: class Block(nn.Module): train: bool def setup(self): self.linear = nn.Dense(10) self.bn = nn.BatchNorm(use_running_average=not self.train) self.dropout = nn.Dropout(0.1, deterministic=not self.train) def __call__(self, x): return nn.relu(self.dropout(self.bn(self.linear(x)))) x = jnp.ones((1, 5)) model = Block(train=True) vs = model.init(random.key(0), x) params, batch_stats = vs['params'], vs['batch_stats'] y, updates = model.apply( {'params': params, 'batch_stats': batch_stats}, x, rngs={'dropout': random.key(1)}, mutable=['batch_stats'], ) batch_stats = updates['batch_stats'] --- class Block(nnx.Module): def __init__(self, rngs): self.linear = nnx.Linear(5, 10, rngs=rngs) self.bn = nnx.BatchNorm(10, rngs=rngs) self.dropout = nnx.Dropout(0.1, rngs=rngs) def __call__(self, x): return nnx.relu(self.dropout(self.bn(self.linear(x)))) x = jnp.ones((1, 5)) model = Block(nnx.Rngs(0)) y = model(x) ... The main benefit of Flax NNX's state handling is that you don't have to change the training code when you add a new stateful layer. In addition, in Flax NNX, layers that handle state are also very easy to implement. Below is a simplified version of a ``BatchNorm`` layer that updates the mean and variance every time it is called. .. testcode:: NNX class BatchNorm(nnx.Module): def __init__(self, features: int, mu: float = 0.95): # Variables self.scale = nnx.Param(jax.numpy.ones((features,))) self.bias = nnx.Param(jax.numpy.zeros((features,))) self.mean = nnx.BatchStat(jax.numpy.zeros((features,))) self.var = nnx.BatchStat(jax.numpy.ones((features,))) self.mu = mu # Static def __call__(self, x): mean = jax.numpy.mean(x, axis=-1) var = jax.numpy.var(x, axis=-1) # ema updates self.mean.value = self.mu * self.mean + (1 - self.mu) * mean self.var.value = self.mu * self.var + (1 - self.mu) * var # normalize and scale x = (x - mean) / jax.numpy.sqrt(var + 1e-5) return x * self.scale + self.bias Model surgery ^^^^^^^^^^^^^ In Flax Linen, `model surgery `_ has historically been challenging because of two reasons: 1. Due to lazy initialization, it is not guaranteed that you can replace a sub-``Module`` with a new one. 2. The parameter structure is separated from the ``flax.linen.Module`` structure, which means you have to manually keep them in sync. In Flax NNX, you can replace sub-Modules directly as per the Python semantics. Since parameters are part of the ``nnx.Module`` structure, they are never out of sync. Below is an example of how you can implement a LoRA layer, and then use it to replace a ``Linear`` layer in an existing model. .. codediff:: :title: Linen, NNX :sync: class LoraLinear(nn.Module): linear: nn.Dense rank: int @nn.compact def __call__(self, x: jax.Array): A = self.param(random.normal, (x.shape[-1], self.rank)) B = self.param(random.normal, (self.rank, self.linear.features)) return self.linear(x) + x @ A @ B try: model = Block(train=True) model.linear = LoraLinear(model.linear, rank=5) # <-- ERROR lora_params = model.linear.init(random.key(1), x) lora_params['linear'] = params['linear'] params['linear'] = lora_params except AttributeError as e: pass --- class LoraParam(nnx.Param): pass class LoraLinear(nnx.Module): def __init__(self, linear, rank, rngs): self.linear = linear self.A = LoraParam(random.normal(rngs(), (linear.in_features, rank))) self.B = LoraParam(random.normal(rngs(), (rank, linear.out_features))) def __call__(self, x: jax.Array): return self.linear(x) + x @ self.A @ self.B rngs = nnx.Rngs(0) model = Block(rngs) model.linear = LoraLinear(model.linear, rank=5, rngs=rngs) ... As shown above, in Flax Linen this doesn't really work in this case because the ``linear`` sub-``Module`` is not available. However, the rest of the code provides an idea of how the ``params`` structure must be manually updated. Performing arbitrary model surgery is not easy in Flax Linen, and currently the `intercept_methods `_ API is the only way to do generic patching of methods. But this API is not very ergonomic. In Flax NNX, to do generic model surgery you can just use ``nnx.iter_graph``, which is much simpler and easier than in Linen. Below is an example of replacing all ``nnx.Linear`` layers in a model with custom-made ``LoraLinear`` NNX layers. .. testcode:: NNX rngs = nnx.Rngs(0) model = Block(rngs) for path, module in nnx.iter_graph(model): if isinstance(module, nnx.Module): for name, value in vars(module).items(): if isinstance(value, nnx.Linear): setattr(module, name, LoraLinear(value, rank=5, rngs=rngs)) Transforms ^^^^^^^^^^ Flax Linen transforms are very powerful in that they enable fine-grained control over the model's state. However, Flax Linen transforms have drawbacks, such as: 1. They expose additional APIs that are not part of JAX, making their behavior confusing and sometimes divergent from their JAX counterparts. This also constrains your ways to interact with `JAX transforms `_ and keep up with JAX API changes. 2. They work on functions with very specific signatures, namely: - A ``flax.linen.Module`` must be the first argument. - They accept other ``Module`` objects as arguments but not as return values. 3. They can only be used inside ``flax.linen.Module.apply``. On the other hand, `Flax NNX transforms `_ are intented to be equivalent to their corresponding `JAX transforms `_ with an exception - they can be used on Flax NNX Modules. This means that Flax transforms: 1) Have the same API as JAX transforms. 2) Can accept Flax NNX Modules on any argument, and ``nnx.Module`` objects can be returned from it/them. 3) Can be used anywhere including the training loop. Below is an example of using ``vmap`` with Flax NNX to both create a stack of weights by transforming the ``create_weights`` function, which returns some ``Weights``, and to apply that stack of weights to a batch of inputs individually by transforming the ``vector_dot`` function, which takes ``Weights`` as the first argument and a batch of inputs as the second argument. .. testcode:: NNX class Weights(nnx.Module): def __init__(self, kernel: jax.Array, bias: jax.Array): self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias) def create_weights(seed: jax.Array): return Weights( kernel=random.uniform(random.key(seed), (2, 3)), bias=jnp.zeros((3,)), ) def vector_dot(weights: Weights, x: jax.Array): assert weights.kernel.ndim == 2, 'Batch dimensions not allowed' assert x.ndim == 1, 'Batch dimensions not allowed' return x @ weights.kernel + weights.bias seeds = jnp.arange(10) weights = nnx.vmap(create_weights, in_axes=0, out_axes=0)(seeds) x = jax.random.normal(random.key(1), (10, 2)) y = nnx.vmap(vector_dot, in_axes=(0, 0), out_axes=1)(weights, x) Contrary to Flax Linen transforms, the ``in_axes`` argument and other APIs do affect how the ``nnx.Module`` state is transformed. In addition, Flax NNX transforms can be used as method decorators, because ``nnx.Module`` methods are simply functions that take a ``Module`` as the first argument. This means that the previous example can be rewritten as follows: .. testcode:: NNX class WeightStack(nnx.Module): @nnx.vmap(in_axes=(0, 0), out_axes=0) def __init__(self, seed: jax.Array): self.kernel = nnx.Param(random.uniform(random.key(seed), (2, 3))) self.bias = nnx.Param(jnp.zeros((3,))) @nnx.vmap(in_axes=(0, 0), out_axes=1) def __call__(self, x: jax.Array): assert self.kernel.ndim == 2, 'Batch dimensions not allowed' assert x.ndim == 1, 'Batch dimensions not allowed' return x @ self.kernel + self.bias weights = WeightStack(jnp.arange(10)) x = jax.random.normal(random.key(1), (10, 2)) y = weights(x) ================================================ FILE: examples/README.md ================================================ # Flax Examples Each example is designed to be **self-contained and easily forkable**, while reproducing relevant results in different areas of machine learning. As discussed in [#231](https://github.com/google/flax/issues/231), we decided to go for a standard pattern for all examples including the simplest ones (like MNIST). This makes every example a bit more verbose, but once you know one example, you know the structure of all of them. Having unit tests and integration tests is also very useful when you fork these examples. For more examples including contributions from the community and other projects currently using Flax see the **[Examples](https://flax.readthedocs.io/en/latest/examples.html)** section in the documentation. ================================================ FILE: examples/__init__.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/cloud/README.md ================================================ # Launching jobs on Google Cloud This directory provides a simple script that can be used to create a new VM on Google Cloud, train an example on that VM and then shutting it down. The training is implemented in a shell that is run on the VM after startup by setting the `startup_script_file` in the metadata. The script opens a TMUX session, installs Flax repository from Github with all dependencies, and then runs the training in parallel with `gcloud storage rsync` that copies the training artifacts in a storage bucket. The advantage of this approach is that every training is run in a single VM that contains all code and configuration, so it is easy to run multiple experiments in parallel without interference. The individual trainings can also be inspected by logging into the VM via SSH and attaching to the tmux session. The script `launch_gce.py` launches the VM and prints out the relevant commands to track the progress update and to log into the machine. Note that the VM also shuts down if an error is encountered, after waiting for five minutes. ## Preparation Prerequisites: 1. Create a Google Cloud account. 2. Set up billing: https://console.cloud.google.com/billing 3. Create a storage bucket (GCS). 4. Optional: Get quota for accelerators. This is usually granted with a short delay: https://console.cloud.google.com/iam-admin/quotas ## Setting up your environment The commands below use the same set of pre-defined environment variables. Mandatory environment variables: - `$PROJECT`: The name of your Google Cloud project. - `$GCS_BUCKET`: The name of the Google Cloud Storage bucket where the model output (artifacts, final checkpoint) are to be stored. - `$ZONE`: Compute zone (e.g. `central1-a`). Optional environment variables: - `$REPO`: Alternative repo to use instead of the default https://github.com/google/flax - this is useful for development. - `$BRANCH`: Alternative branch to use instead of the default `main`. ## Training the MNIST example Use the following command to launch the MNIST example on cloud (make sure to set `$PROJECT` and `$GCS_BUCKET` accordingly): ```shell python examples/cloud/launch_gce.py \ --project=$PROJECT \ --zone=us-west1-a \ --machine_type=n2-standard-2 \ --gcs_workdir_base=gs://$GCS_BUCKET/workdir_base \ --repo=${REPO:-https://github.com/google/flax} \ --branch=${BRANCH:-main} \ --example=mnist \ --args='--config=configs/default.py' \ --name=default ``` ## Training the imagenet example Note that you need to first prepare the `imagenet2012` dataset. For this, download the data from http://image-net.org/ as described in the [tensorflow_datasets catalog](https://www.tensorflow.org/datasets/catalog/imagenet2012). Then point the environment variable `$IMAGENET_DOWNLOAD_PATH` to the directory where the downloads are stored and prepare the dataset by running ```shell python -c " import tensorflow_datasets as tfds tfds.builder('imagenet2012').download_and_prepare( download_config=tfds.download.DownloadConfig( manual_dir='$IMAGENET_DOWNLOAD_PATH')) " ``` Then copy the contents of the directory `~/tensorflow_datasets` into the directory `gs://$GCS_TFDS_BUCKET/datasets` (note that `$GCS_TFDS_BUCKET` and `$GCS_BUCKET` can be identical). After this preparation you can run the imagenet example with the following command (make sure to set `$PROJECT`, `$GCS_BUCKET` and `$GCS_TFDS_BUCKET` accordingly): ```shell python examples/cloud/launch_gce.py \ --project=$PROJECT \ --zone=us-west1-a \ --machine_type=n1-standard-96 \ --accelerator_type=nvidia-tesla-v100 --accelerator_count=8 \ --gcs_workdir_base=gs://$GCS_BUCKET/workdir_base \ --tfds_data_dir=gs://$GCS_TFDS_BUCKET/datasets \ --repo=${REPO:-https://github.com/google/flax} \ --branch=${BRANCH:-main} \ --example=imagenet \ --args='--config=configs/v100_x8_mixed_precision.py' \ --name=v100_x8_mixed_precision ``` ## Tips You can add `--connect` to above commands to directly land in the training session once the VM is ready. This is very helpful for debugging when changing things. Note that the VM automatically shuts down after 5 minutes of inactivity, both in case of success as in case of failure. On OS X this could be combined with `VM_READY_CMD="osascript -e 'display notification \"VM ready\"'"` so get undistracted when the VM is up and running. When tweaking the startup script or individual arguments, it is often helpful to connect to the VM, stop the scripts and end the tmux session, and then copy and paste the contents of the generated `flax-...-startup_script.sh`, after modifying these contents accordingly. ================================================ FILE: examples/cloud/launch_gce.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Script for creating VM on Google cloud and run a Flax example inside. # See ./README.md for instructions. import datetime import os import re import subprocess import time from collections.abc import Sequence from absl import app from absl import flags # General options. flags.DEFINE_bool( 'dry_run', False, help=( 'If set, then the command to launch the GCE instance will only be ' 'printed to stdout.' ), ) flags.DEFINE_bool( 'connect', False, help='Same as --wait, but directly connect to VM once it is ready.', ) flags.DEFINE_bool( 'wait', False, help=( 'If set, then the script will wait until VM is ready. If VM_READY_CMD ' 'is set in environment, then that command will be executed once the VM ' 'is ready. Useful for sending a notification, e.g. "osascript" (mac).' ), ) # Machine configuration. flags.DEFINE_string('project', None, help='Name of the Google Cloud project.') flags.DEFINE_string('zone', None, help='Zone in which the VM will be created.') flags.DEFINE_string( 'machine_type', None, help='Machine type to use for VM. See "gcloud compute machine-types list".', ) flags.DEFINE_string( 'accelerator_type', '', help=( 'Type of accelerator to use, or empty. ' 'See "gcloud compute accelerator-types list".' ), ) flags.DEFINE_integer( 'shutdown_secs', 300, help=( 'How long to wait (after successful/failed training) before shutting ' 'down the VM. Set to 0 to disable.' ), ) flags.DEFINE_integer( 'accelerator_count', 8, help='Number of accelerators to use.' ) # GCS configuration. flags.DEFINE_string( 'gcs_workdir_base', None, help=( 'GCS base directory for model output. The --workdir argument will be ' 'constructed from {gcs_workdir_base}/{example}/{name}/{timestamp} .' ), ) flags.DEFINE_string( 'tfds_data_dir', '', help=( 'Optional tfds data directory. This can be useful to prepare datasets' ' on GCS and then point the jobs to this preloaded directory. Dataset' ' will be downloaded from the web if not specified.' ), ) # Repo configuration. flags.DEFINE_string( 'repo', 'https://github.com/google/flax', help='Git repository' ) flags.DEFINE_string('branch', 'main', help='Git repository') # Example configuration. flags.DEFINE_string( 'example', None, help='Name of Flax example (e.g. "imagenet").' ) flags.DEFINE_string( 'args', '', help=( 'Any additional command line arguments for {example}_main.py, like ' 'for example --config. Note that --workdir will be provided by the ' 'script.' ), ) # Run configuration. flags.DEFINE_string( 'name', None, help=( 'Name of the experiment. Note that the provided name will be ' 'extended to {example}/{name}/{timestamp}' ), ) FLAGS = flags.FLAGS flags.mark_flags_as_required( ['project', 'zone', 'machine_type', 'gcs_workdir_base', 'example', 'name'] ) timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') def generate_startup_file(vm_name: str) -> str: directory = os.path.dirname(os.path.abspath(__file__)) startup_script_src = os.path.join(directory, 'startup_script.sh') startup_script_dst = os.path.join(directory, f'{vm_name}-startup_script.sh') assert not os.path.exists(startup_script_dst) with open(startup_script_src, encoding='utf8') as f: startup_script_content = f.read() for from_str, to_str in ( ('__REPO__', FLAGS.repo), ('__BRANCH__', FLAGS.branch), ('__EXAMPLE__', FLAGS.example), ('__TIMESTAMP__', timestamp), ('__NAME__', FLAGS.name), ('__ARGS__', FLAGS.args), ('__GCS_WORKDIR_BASE__', FLAGS.gcs_workdir_base), ('__TFDS_DATA_DIR__', FLAGS.tfds_data_dir), ('__ACCELERATOR_TYPE__', FLAGS.accelerator_type), ('__SHUTDOWN_SECS__', str(FLAGS.shutdown_secs)), ): startup_script_content = startup_script_content.replace(from_str, to_str) with open(startup_script_dst, 'w', encoding='utf8') as f: f.write(startup_script_content) return startup_script_dst def launch_gce(*, vm_name: str, startup_script: str): # Note : Use `gcloud compute images list --project ml-images` to get a list # of available VM images. args = [ 'gcloud', 'compute', 'instances', 'create', vm_name, f'--project={FLAGS.project}', f'--zone={FLAGS.zone}', '--image=c1-deeplearning-tf-2-10-cu113-v20221107-debian-10', '--image-project=ml-images', f'--machine-type={FLAGS.machine_type}', '--scopes=cloud-platform,storage-full', '--boot-disk-size=256GB', '--boot-disk-type=pd-ssd', '--metadata=install-nvidia-driver=True', f'--metadata-from-file=startup-script={startup_script}', ] if FLAGS.accelerator_type and FLAGS.accelerator_count: args.extend([ '--maintenance-policy=TERMINATE', f'--accelerator=type={FLAGS.accelerator_type},count={FLAGS.accelerator_count}', ]) if FLAGS.dry_run: print() print('Would run the following command without --dry-run:') print() print(' \\\n '.join(args)) print() return print() print('Creating instance on GCE... This will take some minutes...') print() result = subprocess.run(args) if result.returncode: raise RuntimeError('Could not create VM!') def print_howto(login_args: Sequence[str]): print(f""" ############################################################################### ############################################################################### You can start/stop the instace via the web UI: https://console.cloud.google.com/compute/instances?project={FLAGS.project} Once the VM has started, you can login and connect to the training session: {' '.join(login_args)} Note that you can disconnect from the tmux session without stopping the training with the keystrokes 'CTRL-B A'. See "man tmux" for help about tmux. To observe the training via Tensorboard, simply run in your local computer: $ tensorboard --logdir={FLAGS.gcs_workdir_base} You can also browse the files at https://console.cloud.google.com/storage/browser/{FLAGS.gcs_workdir_base.replace('gs://', '')} ############################################################################### ############################################################################### """) def main(_): for name in ('repo', 'branch', 'example', 'name', 'gcs_workdir_base'): value = getattr(FLAGS, name) if re.match(r'[^\w:/_-]', value): raise ValueError(f'Invalid flag value: --{name}="{value}"') example_base_directory = os.path.join( os.path.dirname(os.path.abspath(__file__)), os.path.pardir, ) if not os.path.isdir(os.path.join(example_base_directory, FLAGS.example)): raise ValueError(f'Could not find --example={FLAGS.example}') if FLAGS.connect and FLAGS.dry_run: raise ValueError('Cannot --connect to VM with --dry_run') vm_name = '-'.join([ 'flax', FLAGS.example, timestamp, ]) vm_name = re.sub(r'[^a-z0-9-]', '-', vm_name) startup_script = generate_startup_file(vm_name) launch_gce(vm_name=vm_name, startup_script=startup_script) login_args = [ 'gcloud', 'compute', 'ssh', '--project', FLAGS.project, '--zone', FLAGS.zone, vm_name, '--', '/sudo_tmux_a.sh', ] print('Your instance is being started...') print_howto(login_args) if FLAGS.connect or FLAGS.wait: login_true_args = login_args[:-1] + ['true'] while True: try: result = subprocess.run( login_true_args, timeout=10, stderr=subprocess.PIPE ) if result.returncode == 0: break stderr = result.stderr.decode('utf8') if 'connection refused' in stderr.lower(): print('(Connection refused - waiting a little longer...)') time.sleep(20) else: raise ValueError(f'Unknown error: {stderr}') except ValueError as e: if 'HTTP 502' not in str(e): raise e print('(Bad Gateway - waiting a little longer...)') time.sleep(20) except subprocess.TimeoutExpired: print('(Timeout - waiting a little longer...)') time.sleep(20) if 'VM_READY_CMD' in os.environ: os.system(os.environ['VM_READY_CMD']) if FLAGS.connect: result = subprocess.run(login_args) # SSH session has cleared previous message, print it again. print_howto(login_args) if __name__ == '__main__': app.run(main) ================================================ FILE: examples/cloud/startup_script.sh ================================================ #!/bin/bash # Note that all __XYZ__ strings are replaced by launch_gce.py WORKDIR="/train/workdir_base/__EXAMPLE__/__NAME__/__TIMESTAMP__" mkdir -p /train cd /train # Login directly with: # gcloud compute ssh $VM -- /sudo_tmux_a.sh echo -e '#!/bin/bash\nsudo /tmux_a.sh' > /sudo_tmux_a.sh chmod a+x /sudo_tmux_a.sh echo -e '#!/bin/bash\ntmux a' > /tmux_a.sh chmod a+x /tmux_a.sh # Main script running in bottom left tmux pane. cat >/install_train_stop.sh <&1 | tee -a $WORKDIR/setup_train_log_${TIMESTAMP}.txt if [ __SHUTDOWN_SECS__ -gt 0 ]; then echo echo WILL SHUT DOWN IN $((__SHUTDOWN_SECS__/60)) MIN ... sleep __SHUTDOWN_SECS__ && shutdown now fi EOF # Set up TMUX panes: tmux new-session -s flax -d # - top left: htop tmux send 'htop ' tmux split-window tmux selectp -U tmux split-window -h # - top right: htop tmux send 'watch nvidia-smi ' tmux selectp -D # - bottom left: main script tmux send '. /install_train_stop.sh ' tmux split-window -h # - bottom right: rsync files to GCS bucket. tmux send " while true; do gcloud storage rsync --recursive workdir_base __GCS_WORKDIR_BASE__ sleep 60 done 2>&1 | tee -a $WORKDIR/gcs_rsync_'__TIMESTAMP__'.txt " ================================================ FILE: examples/gemma/README.md ================================================ ## Language modeling Trains Gemma model on the One Billion Word Benchmark (lm1b; Chelba *et al.*, 2013). This example is based on lm1b and similarly uses linear learning rate warmup and inverse square root learning rate schedule. ### Requirements * TensorFlow datasets `lm1b` need to be downloaded and prepared (see below). A sentencepiece tokenizer vocabulary will be automatically generated and saved on each training run. * This example additionally depends on the `sentencepiece` and `tensorflow-text` packages. ### Downloading the LM1B Datasets We recommend downloading and preparing the TFDS datasets beforehand. You can download and prepare LM1B datasets using TFDS directly: `python -m tensorflow_datasets.scripts.download_and_prepare --datasets=lm1b`. #### Using Cloud Storage FUSE for TPUs For Cloud TPUs, we recommend using a cheap standard instance and saving the prepared TFDS data on a storage bucket, from where it can be mounted to the TPU VM using [Cloud Storage FUSE](https://cloud.google.com/storage/docs/cloud-storage-fuse/quickstart-mount-bucket). ##### Copy the preprocessed dataset to the Cloud Storage We assume that the dataset was downloaded and prepared. We also assume we have configured `gcloud` CLI. The following commands helps to setup the storage and copy the dataset: ```bash # Install gcsfuse CLI export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s` # For example, GCSFUSE_REPO=gcsfuse-noble for Ubuntu 24.04 echo "deb https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add - sudo apt-get update sudo apt-get install -y fuse gcsfuse --no-install-recommends gcsfuse -v # gcsfuse version 2.12.2 (Go version go1.24.0) ``` Let's get where LM1B dataset was locally stored: ```bash python -c "import tensorflow_datasets as tfds; b=tfds.builder('lm1b'); print(b.info.data_dir)" # For example: /home/user/tensorflow_datasets/lm1b/1.1.0 ``` Let's create a GCS bucket for the dataset and link the bucket to a local folder. We choose the bucket name "flax-lm1b-tfdataset" but this can be changed. ```bash gcloud storage buckets create gs://flax-lm1b-tfdataset mkdir -p $HOME/data gcsfuse flax-lm1b-tfdataset $HOME/data ``` Now let's copy the data to the bucket: ```bash # Let's assume that prepared dataset is at $HOME/tensorflow_datasets/lm1b/ cp -R $HOME/tensorflow_datasets/lm1b $HOME/data ``` ##### Setup the dataset on TPU VM We previously have choosen the bucket name "flax-lm1b-tfdataset" where stored the dataset, adapt this name to your situation. ```bash # On the TPU VM gcsfuse flax-lm1b-tfdataset $HOME/tensorflow_datasets ls $HOME/tensorflow_datasets/lm1b/1.1.0/ ``` ### How to run on GPU(s) Install Jax with CUDA support, Flax and the example dependencies with the following command: ```bash pip install jax[cuda12] # Check whether GPUs are available: # python3 -c "import jax; print(jax.devices())" git clone --depth=1 --branch=main https://github.com/google/flax cd flax pip install -e . cd examples/gemma pip install -r requirements.txt ``` Start the training: - train a small transformer model: ```bash python3 main.py --workdir=$HOME/logs/small_gemma_lm1b --config=configs/small.py ``` - train Gemma3-4B model: ```bash python3 main.py --workdir=$HOME/logs/gemma3-4b_lm1b --config=configs/gemma3_4b.py ``` To monitor the trainings with the TensorBoard: ```bash tensorboard --logdir=$HOME/logs ``` ### How to run on Cloud TPUs Setup the TPU VM and install the Flax dependencies on it as described [here](https://cloud.google.com/tpu/docs/jax-pods) for creating pod slices, or [here](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm) for a single v4-8 TPU. First create a single TPUv4-8 VM and connect to it (you can find more detailed instructions [here](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm)): ```bash ZONE=us-central1-a TPU_TYPE=v4-8 TPU_NAME=$USER-flax-gemma-lm1b gcloud compute tpus tpu-vm create $TPU_NAME \ --zone $ZONE \ --accelerator-type $TPU_TYPE \ --version tpu-ubuntu2204-base gcloud compute tpus tpu-vm ssh $TPU_NAME --zone $ZONE -- \ -L 6006:localhost:6006 ``` When connected install JAX: ```bash pip install "jax[tpu]>=0.2.16" \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html ``` Then install Flax + the example dependencies: ```bash git clone --depth=1 --branch=main https://github.com/google/flax cd flax pip install -e . cd examples/gemma pip install -r requirements.txt ``` In case of errors when installing example dependencies, try to upgrade existing `pip` package and downgrade `setuptools` and repeat the installation command ```bash # Optionally # pip install -U pip # pip install -U "setuptools<70" # pip install -r requirements.txt ``` And finally start the training: ```bash python3 main.py --workdir=$HOME/logs/gemma_lm1b_256 --config.per_device_batch_size=32 ``` Note that you might want to set `TFDS_DATA_DIR` as explained below. You probably also want to start the long-running command above in a `tmux` session and start some monitoring in a separate pane (note that we forwarded port 6006 locally above): ```bash tensorboard --logdir=$HOME/logs ``` ================================================ FILE: examples/gemma/configs/default.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Default Hyperparameter configuration.""" import dataclasses from train import MeshRules, TrainConfig @dataclasses.dataclass(unsafe_hash=True) class Config: # Path to load or store sentencepiece vocab file. vocab_path: str | None = None # Vocabulary size if `vocab_path` is not given. vocab_size: int = 35_000 # lm1b dataset vocab size: 35913 (Gemma expected vocab size: 262_144) # Maximum number of characters to use for training. max_corpus_chars: int = 10**7 # Name of TFDS translation dataset to use. dataset_name: str = 'lm1b' # Optional name of TFDS translation dataset to use for evaluation. eval_dataset_name: str = 'lm1b' # Optional name of TFDS split to use for evaluation. eval_split: str = 'test' # Per device batch size for training. per_device_batch_size: int = 32 # Per device batch size for training. eval_per_device_batch_size: int = 32 # Prompt for language model sampling prompts: tuple[str, ...] = ( 'Paris is a the capital', 'Flax is a', # From train set: 'The shutdown was aimed at creating efficiencies as', # -> the plant was already operating at its maximum capacity of 3,000 tonnes of cellulose paste per day 'A big theme of this hire is that there are parts of', # -> our operations that to use a pretty trite phrase , need to be taken to the next level ... # From test set: 'Because of Bear Stearns , many analysts are', # -> raising the odds that a 2008 recession could be worse than expected 'Next month , the Brazilian bourse', # -> opens a London office', ) # Temperature for top_p sampling. sampling_temperature: float = 0.0 # Top-p sampling threshold. sampling_top_p: float = 0.95 # Number of steps to take during training. num_train_steps: int = 500_000 # Number of steps to take during evaluation. # Large enough to evaluate all samples: 306_688 / (32 * 8) = 1198 num_eval_steps: int = 2_000 # Number of steps to generate predictions. # -1 will use the whole eval dataset. num_predict_steps: int = 50 # Base learning rate. learning_rate: float = 0.0016 # Linear learning rate warmup. warmup_steps: int = 1000 # Cross entropy loss label smoothing. label_smoothing: float = 0.0 # Decay factor for AdamW style weight decay. weight_decay: float = 0.1 # Maximum length cutoff for training examples. max_target_length: int = 128 # Maximum length cutoff for eval examples. max_eval_target_length: int = 512 # Gemma transformer name. # Possible values defined in transformer.TransformerConfig: # (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_1b, gemma3_4b, ...) transformer_name: str | None = "gemma3_1b" # or alternatively define the model using the dict of parameters transformer_params: dict | None = None # Whether to save model checkpoints. save_checkpoints: bool = True # Whether to restore from existing model checkpoints. restore_checkpoints: bool = True # Save a checkpoint every these number of steps. checkpoint_every_steps: int = 10_000 # Frequency of eval during training, e.g. every 1_000 steps. eval_every_steps: int = 5_000 # Use bfloat16 mixed precision training instead of float32. use_bfloat16: bool = True # Integer for PRNG random seed. seed: int = 0 # Parallelism mesh_axes: tuple[str, ...] = ('data', 'fsdp', 'tensor') axis_rules: MeshRules = MeshRules( embed='fsdp', mlp='tensor', kv='tensor', vocab='tensor', ) data_sharding: tuple[str, ...] = ('data', 'fsdp') # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. # By default, product of the DCN axes should equal number of slices # and product of the ICI axes should equal number of devices per slice. # ICI (Inter-Chip Interconnection): A high-speed connection between # sets of TPU chips, which form the TPU network. # DCN (Data Center Network): A connection between the TPU networks; # not as fast as ICI. # ICI has around 100x the bandwidth of DCN, but it is not a general # purpose connection, which is why DCN is necessary for scaling to # extremely large ML models. dcn_data_parallelism: int = -1 dcn_fsdp_parallelism: int = 1 dcn_tensor_parallelism: int = 1 ici_data_parallelism: int = 1 ici_fsdp_parallelism: int = -1 ici_tensor_parallelism: int = 1 def get_config() -> TrainConfig: """Get the default hyperparameter configuration.""" config = Config() return TrainConfig(**dataclasses.asdict(config)) ================================================ FILE: examples/gemma/configs/gemma3_4b.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Default Hyperparameter configuration.""" import dataclasses from train import MeshRules, TrainConfig @dataclasses.dataclass(unsafe_hash=True) class Config: # Path to load or store sentencepiece vocab file. vocab_path: str | None = None # Vocabulary size if `vocab_path` is not given. vocab_size: int = 35_000 # lm1b dataset vocab size: 35913 (Gemma expected vocab size: 262_144) # Maximum number of characters to use for training. max_corpus_chars: int = 10**7 # Name of TFDS translation dataset to use. dataset_name: str = 'lm1b' # Optional name of TFDS translation dataset to use for evaluation. eval_dataset_name: str = 'lm1b' # Optional name of TFDS split to use for evaluation. eval_split: str = 'test' # Per device batch size for training. per_device_batch_size: int = 32 # Per device batch size for training. eval_per_device_batch_size: int = 32 # Prompt for language model sampling prompts: tuple[str, ...] = ( 'Paris is a the capital', 'Flax is a', # From train set: 'The shutdown was aimed at creating efficiencies as', # -> the plant was already operating at its maximum capacity of 3,000 tonnes of cellulose paste per day 'A big theme of this hire is that there are parts of', # -> our operations that to use a pretty trite phrase , need to be taken to the next level ... # From test set: 'Because of Bear Stearns , many analysts are', # -> raising the odds that a 2008 recession could be worse than expected 'Next month , the Brazilian bourse', # -> opens a London office', ) # Temperature for top_p sampling. sampling_temperature: float = 0.0 # Top-p sampling threshold. sampling_top_p: float = 0.95 # Number of steps to take during training. num_train_steps: int = 500_000 # Number of steps to take during evaluation. # Large enough to evaluate all samples: 306_688 / (32 * 8) = 1198 num_eval_steps: int = 2_000 # Number of steps to generate predictions. # -1 will use the whole eval dataset. num_predict_steps: int = 50 # Base learning rate. learning_rate: float = 0.0016 # Linear learning rate warmup. warmup_steps: int = 1000 # Cross entropy loss label smoothing. label_smoothing: float = 0.0 # Decay factor for AdamW style weight decay. weight_decay: float = 0.1 # Maximum length cutoff for training examples. max_target_length: int = 128 # Maximum length cutoff for eval examples. max_eval_target_length: int = 512 # Gemma transformer name. # Possible values defined in transformer.TransformerConfig: # (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_1b, gemma3_4b, ...) transformer_name: str | None = "gemma3_4b" # or alternatively define the model using the dict of parameters transformer_params: dict | None = None # Whether to save model checkpoints. save_checkpoints: bool = True # Whether to restore from existing model checkpoints. restore_checkpoints: bool = True # Save a checkpoint every these number of steps. checkpoint_every_steps: int = 10_000 # Frequency of eval during training, e.g. every 1_000 steps. eval_every_steps: int = 5_000 # Use bfloat16 mixed precision training instead of float32. use_bfloat16: bool = True # Integer for PRNG random seed. seed: int = 0 # Parallelism mesh_axes: tuple[str, ...] = ('data', 'fsdp', 'tensor') axis_rules: MeshRules = MeshRules( embed='fsdp', mlp='tensor', kv='tensor', vocab='tensor', ) data_sharding: tuple[str, ...] = ('data', 'fsdp') # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. # By default, product of the DCN axes should equal number of slices # and product of the ICI axes should equal number of devices per slice. # ICI (Inter-Chip Interconnection): A high-speed connection between # sets of TPU chips, which form the TPU network. # DCN (Data Center Network): A connection between the TPU networks; # not as fast as ICI. # ICI has around 100x the bandwidth of DCN, but it is not a general # purpose connection, which is why DCN is necessary for scaling to # extremely large ML models. dcn_data_parallelism: int = -1 dcn_fsdp_parallelism: int = 1 dcn_tensor_parallelism: int = 1 ici_data_parallelism: int = 1 ici_fsdp_parallelism: int = -1 ici_tensor_parallelism: int = 1 def replace(self, **kwargs): return dataclasses.replace(self, **kwargs) def get_config() -> TrainConfig: """Get the default hyperparameter configuration.""" config = Config() return TrainConfig(**dataclasses.asdict(config)) ================================================ FILE: examples/gemma/configs/small.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Default Hyperparameter configuration.""" import dataclasses from train import MeshRules, TrainConfig @dataclasses.dataclass(unsafe_hash=True) class Config: # Path to load or store sentencepiece vocab file. vocab_path: str | None = None # Vocabulary size if `vocab_path` is not given. vocab_size: int = 35_000 # lm1b dataset vocab size: 35913 (Gemma expected vocab size: 262_144) # Maximum number of characters to use for training. max_corpus_chars: int = 10**7 # Name of TFDS translation dataset to use. dataset_name: str = 'lm1b' # Optional name of TFDS translation dataset to use for evaluation. eval_dataset_name: str = 'lm1b' # Optional name of TFDS split to use for evaluation. eval_split: str = 'test' # Per device batch size for training. per_device_batch_size: int = 32 # Per device batch size for training. eval_per_device_batch_size: int = 32 # Prompt for language model sampling prompts: tuple[str, ...] = ( 'Paris is a the capital', 'Flax is a', # From train set: 'The shutdown was aimed at creating efficiencies as', # -> the plant was already operating at its maximum capacity of 3,000 tonnes of cellulose paste per day 'A big theme of this hire is that there are parts of', # -> our operations that to use a pretty trite phrase , need to be taken to the next level ... # From test set: 'Because of Bear Stearns , many analysts are', # -> raising the odds that a 2008 recession could be worse than expected 'Next month , the Brazilian bourse', # -> opens a London office', ) # Temperature for top_p sampling. sampling_temperature: float = 0.0 # Top-p sampling threshold. sampling_top_p: float = 0.95 # Number of steps to take during training. num_train_steps: int = 500_000 # Number of steps to take during evaluation. num_eval_steps: int = 500 # Number of steps to generate predictions. # -1 will use the whole eval dataset. num_predict_steps: int = 50 # Base learning rate. learning_rate: float = 0.0016 # Linear learning rate warmup. warmup_steps: int = 1000 # Cross entropy loss label smoothing. label_smoothing: float = 0.0 # Decay factor for AdamW style weight decay. weight_decay: float = 0.1 # Maximum length cutoff for training examples. max_target_length: int = 128 # Maximum length cutoff for eval examples. max_eval_target_length: int = 512 # Gemma transformer name. # Possible values defined in transformer.TransformerConfig: # (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_1b, gemma3_4b, ...) transformer_name: str | None = None # or alternatively define the model using the dict of parameters transformer_params: dict | None = dataclasses.field( default_factory=lambda: { "num_layers": 6, "embed_dim": 512, "hidden_dim": 2048, "num_heads": 4, "head_dim": 256, "num_kv_heads": 1, "use_post_attn_norm": True, "use_post_ffw_norm": True, "use_qk_norm": True, "attention_types": (2, 2, 2, 2, 2, 1), # local_sliding, ..., local_sliding, global "query_pre_attn_norm": 1, # QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM "attn_logits_soft_cap": None, "final_logit_softcap": None, "sliding_window_size": 128, "transpose_gating_einsum": True, "local_base_frequency": 10_000, "global_base_frequency": 1_000_000, } ) # Whether to save model checkpoints. save_checkpoints: bool = True # Whether to restore from existing model checkpoints. restore_checkpoints: bool = True # Save a checkpoint every these number of steps. checkpoint_every_steps: int = 10_000 # Frequency of eval during training, e.g. every 1_000 steps. eval_every_steps: int = 5_000 # Use bfloat16 mixed precision training instead of float32. use_bfloat16: bool = True # Integer for PRNG random seed. seed: int = 0 # Parallelism mesh_axes: tuple[str, ...] = ('data', 'fsdp', 'tensor') axis_rules: MeshRules = MeshRules( embed='fsdp', mlp='tensor', kv='tensor', vocab='tensor', ) data_sharding: tuple[str, ...] = ('data', 'fsdp') # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. # By default, product of the DCN axes should equal number of slices # and product of the ICI axes should equal number of devices per slice. # ICI (Inter-Chip Interconnection): A high-speed connection between # sets of TPU chips, which form the TPU network. # DCN (Data Center Network): A connection between the TPU networks; # not as fast as ICI. # ICI has around 100x the bandwidth of DCN, but it is not a general # purpose connection, which is why DCN is necessary for scaling to # extremely large ML models. dcn_data_parallelism: int = -1 dcn_fsdp_parallelism: int = 1 dcn_tensor_parallelism: int = 1 ici_data_parallelism: int = 1 ici_fsdp_parallelism: int = -1 ici_tensor_parallelism: int = 1 def replace(self, **kwargs): return dataclasses.replace(self, **kwargs) def get_config() -> TrainConfig: """Get the default hyperparameter configuration.""" config = Config() return TrainConfig(**dataclasses.asdict(config)) ================================================ FILE: examples/gemma/configs/tiny.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Default Hyperparameter configuration.""" import dataclasses from train import MeshRules, TrainConfig @dataclasses.dataclass(unsafe_hash=True) class Config: # Path to load or store sentencepiece vocab file. vocab_path: str | None = None # Vocabulary size if `vocab_path` is not given. vocab_size: int = 35_000 # lm1b dataset vocab size: 35913 (Gemma expected vocab size: 262_144) # Maximum number of characters to use for training. max_corpus_chars: int = 10**7 # Name of TFDS translation dataset to use. dataset_name: str = 'lm1b' # Optional name of TFDS translation dataset to use for evaluation. eval_dataset_name: str = 'lm1b' # Optional name of TFDS split to use for evaluation. eval_split: str = 'test' # Per device batch size for training. per_device_batch_size: int = 32 # Per device batch size for training. eval_per_device_batch_size: int = 32 # Prompt for language model sampling prompts: tuple[str, ...] = ( 'Paris is a the capital', 'Flax is a', # From train set: 'The shutdown was aimed at creating efficiencies as', # -> the plant was already operating at its maximum capacity of 3,000 tonnes of cellulose paste per day 'A big theme of this hire is that there are parts of', # -> our operations that to use a pretty trite phrase , need to be taken to the next level ... # From test set: 'Because of Bear Stearns , many analysts are', # -> raising the odds that a 2008 recession could be worse than expected 'Next month , the Brazilian bourse', # -> opens a London office', ) # Temperature for top_p sampling. sampling_temperature: float = 0.0 # Top-p sampling threshold. sampling_top_p: float = 0.95 # Number of steps to take during training. num_train_steps: int = 500_000 # Number of steps to take during evaluation. num_eval_steps: int = 500 # Number of steps to generate predictions. # -1 will use the whole eval dataset. num_predict_steps: int = 20 # Base learning rate. learning_rate: float = 0.0016 # Linear learning rate warmup. warmup_steps: int = 1000 # Cross entropy loss label smoothing. label_smoothing: float = 0.0 # Decay factor for AdamW style weight decay. weight_decay: float = 0.1 # Maximum length cutoff for training examples. max_target_length: int = 128 # Maximum length cutoff for eval examples. max_eval_target_length: int = 512 # Gemma transformer name. # Possible values defined in transformer.TransformerConfig: # (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_1b, gemma3_4b, ...) transformer_name: str | None = None # or alternatively define the model using the dict of parameters transformer_params: dict | None = dataclasses.field( default_factory=lambda: { "num_layers": 4, "embed_dim": 256, "hidden_dim": 256 * 4 // 2, # embed_dim * num_heads // 2 "num_heads": 4, "head_dim": 128, "num_kv_heads": 1, "use_post_attn_norm": False, "use_post_ffw_norm": False, "attention_types": (1, 1, 1, 1), # global * num_layers "final_logit_softcap": None, } ) # Whether to save model checkpoints. save_checkpoints: bool = True # Whether to restore from existing model checkpoints. restore_checkpoints: bool = True # Save a checkpoint every these number of steps. checkpoint_every_steps: int = 10_000 # Frequency of eval during training, e.g. every 1_000 steps. eval_every_steps: int = 5_000 # Use bfloat16 mixed precision training instead of float32. use_bfloat16: bool = True # Integer for PRNG random seed. seed: int = 0 # Parallelism mesh_axes: tuple[str, ...] = ('data', 'fsdp', 'tensor') axis_rules: MeshRules = MeshRules( embed='fsdp', mlp='tensor', kv='tensor', vocab='tensor', ) data_sharding: tuple[str, ...] = ('data', 'fsdp') # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. # By default, product of the DCN axes should equal number of slices # and product of the ICI axes should equal number of devices per slice. # ICI (Inter-Chip Interconnection): A high-speed connection between # sets of TPU chips, which form the TPU network. # DCN (Data Center Network): A connection between the TPU networks; # not as fast as ICI. # ICI has around 100x the bandwidth of DCN, but it is not a general # purpose connection, which is why DCN is necessary for scaling to # extremely large ML models. dcn_data_parallelism: int = -1 dcn_fsdp_parallelism: int = 1 dcn_tensor_parallelism: int = 1 ici_data_parallelism: int = 1 ici_fsdp_parallelism: int = -1 ici_tensor_parallelism: int = 1 def replace(self, **kwargs): return dataclasses.replace(self, **kwargs) def get_config() -> TrainConfig: """Get the default hyperparameter configuration.""" config = Config() return TrainConfig(**dataclasses.asdict(config)) ================================================ FILE: examples/gemma/helpers.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Helper functions.""" from __future__ import annotations from collections.abc import Callable from typing import Any, TypeVar import flax from flax import nnx from flax.typing import VariableDict # pylint: disable=g-importing-member,g-multiple-import M = TypeVar('M', bound='nnx.Module') def _flatten_path(path: tuple[str | int, ...]) -> str: def f(item) -> str: if isinstance(item, str): return f'{item}' elif isinstance(item, int): return f'[{item}]' else: raise ValueError(f'Unexpected type {type(item)}') return '.'.join([f(item) for item in path]).replace('.[', '[') def module_from_linen_variables( module_factory: Callable[[], M], variables: VariableDict, map_key_fn: None | ( Callable[[tuple[str, ...]], tuple[str | int, ...]] ) = None, assign_val_fn: None | ( Callable[ [dict[tuple[str, ...], Any], tuple[str | int, ...], VariableDict], dict[tuple[str, ...], Any], ] ) = None, ) -> M: """Returns an `nnx.Module` initialized with the `variables` of a linen module. Args: module_factory: A no-args callable that returns an `nnx.Module`. variables: A dictionary of variables. map_key_fn: An optional function for mapping keys in the `variables` dictionary to keys in the `nnx.Module`'s state. If not provided it is assumed that after removing the collection name the keys in the `variables` dictionary are the same as the keys in the `nnx.Module`'s state. """ if map_key_fn is None: def map_key_fn(path: tuple[str, ...]) -> tuple[str | int, ...]: return path[1:] if 'params' in variables else path if assign_val_fn is None: def assign_val_fn( state: dict[tuple[str, ...], Any], mapped_path: tuple[str | int, ...], val: Any, ) -> dict[tuple[str, ...], Any]: state[mapped_path].set_value(val) return state mdl: M = nnx.eval_shape(module_factory) graph_def, state = nnx.split(mdl) state = dict(nnx.to_flat_state(state)) for path, val in flax.traverse_util.flatten_dict(variables).items(): mapped_path = map_key_fn(path) if mapped_path not in state: raise ValueError( f"'{mdl.__class__.__name__}.{_flatten_path(mapped_path)}' doesn't " f' exist (original path={path}).' ) state = assign_val_fn(state, mapped_path, val) state = nnx.from_flat_state(state) return nnx.merge(graph_def, state) ================================================ FILE: examples/gemma/helpers_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Tests for helpers.""" from __future__ import annotations from absl.testing import absltest from absl.testing import parameterized from flax import nnx import flax.linen as nn import helpers import jax import jax.numpy as jnp import numpy as np class ModuleFromLinenVariablesTest(parameterized.TestCase): @parameterized.parameters( dict( inputs_shape=(1, 2, 3, 4), num_features=10, use_bias=True, ), dict( inputs_shape=(10, 5), num_features=4, use_bias=False, ), ) def test_same_structure(self, inputs_shape, num_features, use_bias): rng_key = jax.random.PRNGKey(0) rng_inputs, rng_params = jax.random.split(rng_key) inputs = jax.random.normal(rng_inputs, inputs_shape) linen_mdl = nn.Dense( features=num_features, use_bias=use_bias, ) linen_init_vars = linen_mdl.init(rng_params, jnp.zeros(inputs_shape)) linen_output = linen_mdl.apply( linen_init_vars, inputs, ) mdl = helpers.module_from_linen_variables( module_factory=lambda: nnx.Linear( in_features=inputs_shape[-1], out_features=num_features, use_bias=use_bias, rngs=nnx.Rngs(params=rng_params), ), variables=linen_init_vars, ) output = mdl(inputs) np.testing.assert_array_equal(output, linen_output) @parameterized.parameters( dict( inputs_shape=(1, 2, 3, 4), num_features=(10, 20, 7), use_bias=(False, True, False), ), ) def test_different_structure(self, inputs_shape, num_features, use_bias): rng_key = jax.random.PRNGKey(0) rng_inputs, rng_params = jax.random.split(rng_key) inputs = jax.random.normal(rng_inputs, inputs_shape) linen_mdl = nn.Sequential([ nn.Sequential([ nn.BatchNorm(use_running_average=False), nn.Dense(features=f, use_bias=b), ]) for f, b in zip(num_features, use_bias) ]) linen_init_vars = linen_mdl.init(rng_key, jnp.zeros(inputs_shape)) linen_output, linen_vars = linen_mdl.apply( linen_init_vars, inputs, mutable=['batch_stats'], ) module_factory = lambda: nnx.Sequential(*[ nnx.Sequential( nnx.BatchNorm( num_features=in_f, use_running_average=False, rngs=nnx.Rngs(params=rng_params), ), nnx.Linear( in_features=in_f, out_features=out_f, use_bias=b, rngs=nnx.Rngs(params=rng_params), ), ) for in_f, out_f, b in zip(in_features, out_features, use_bias) ]) def _map_key_fn(key: tuple[str, ...]) -> tuple[str | int, ...]: new_key = [] for k in key[1:]: if k.startswith('layers_'): prefix, suffix = k.split('layers_') assert not prefix, prefix new_key.append('layers') new_key.append(int(suffix)) else: new_key.append(k) return tuple(new_key) in_features = (inputs_shape[-1], *num_features[:-1]) out_features = num_features mdl = helpers.module_from_linen_variables( module_factory=module_factory, variables=linen_init_vars, map_key_fn=_map_key_fn, ) output = mdl(inputs) np.testing.assert_array_equal(output, linen_output) for i in range(len(num_features)): np.testing.assert_array_equal( mdl.layers[i].layers[0].mean[...], linen_vars['batch_stats'][f'layers_{i}']['layers_0']['mean'], ) np.testing.assert_array_equal( mdl.layers[i].layers[0].var[...], linen_vars['batch_stats'][f'layers_{i}']['layers_0']['var'], ) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/gemma/input_pipeline.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Input pipeline for a LM1B dataset.""" import os from typing import Any import tokenizer import tensorflow as tf import tensorflow_datasets as tfds AUTOTUNE = tf.data.experimental.AUTOTUNE Features = dict[str, tf.Tensor] class NormalizeFeatureNamesOp: """Normalizes feature names to 'inputs' and 'targets'.""" def __call__(self, features: Features) -> Features: features['inputs'] = features.pop('text') # Unnecessary step used for uniformizing with examples/wmt. features['targets'] = features['inputs'] return features def get_raw_dataset(dataset_name: str, split: str) -> tf.data.Dataset: """Loads a raw text dataset and normalizes feature keys. Args: dataset_name: TFDS dataset name. split: Split to use. This must be the full split. We shard the split across multiple hosts and currently don't support sharding subsplits. Returns: Dataset with source and target language features mapped to 'inputs' and 'targets'. """ split = tfds.split_for_jax_process(split, drop_remainder=True) ds = tfds.load(dataset_name, split=split) ds = ds.map(NormalizeFeatureNamesOp(), num_parallel_calls=AUTOTUNE) return ds def pack_dataset( dataset: tf.data.Dataset, key2length: int | dict[str, int], keys: list[str] | None = None, ) -> tf.data.Dataset: """Creates a 'packed' version of a dataset on-the-fly. Adapted from the mesh-tf implementation. This is meant to replace the irritation of having to create a separate "packed" version of a dataset to train efficiently on TPU. Each example in the output dataset represents several examples in the input dataset. For each key in the input dataset, two additional keys are created: _segmentation: an int32 tensor identifying the parts representing the original example. _position: an int32 tensor identifying the position within the original example. Example: Two input examples get combined to form an output example. The input examples are: {"inputs": [8, 7, 1, 0], "targets":[4, 1, 0]} {"inputs": [2, 3, 4, 1], "targets":[5, 6, 1]} The output example is: { "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0] "inputs_segmentation": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0] "inputs_position": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0] "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0] "targets_segmentation": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0] "targets_position": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0] } 0 represents padding in both the inputs and the outputs. Sequences in the incoming examples are truncated to length "length", and the sequences in the output examples all have fixed (padded) length "length". Args: dataset: a tf.data.Dataset key2length: an integer, or a dict from feature-key to integer keys: a list of strings (e.g. ["inputs", "targets"]) Returns: a tf.data.Dataset """ shapes = tf.nest.map_structure(lambda spec: spec.shape, dataset.element_spec) if keys is None: keys = list(shapes.keys()) for k in keys: if k not in shapes: raise ValueError( 'Key %s not found in dataset. Available keys are %s' % (k, shapes.keys()) ) if not shapes[k].is_compatible_with(tf.TensorShape([None])): # type: ignore[wrong-arg-types] raise ValueError('Tensors to be packed must be one-dimensional.') # make sure that the length dictionary contains all keys as well as the # keys suffixed by "_segmentation" and "_position" if isinstance(key2length, int): key2length = {k: key2length for k in keys} for k in keys: for suffix in ['_segmentation', '_position']: key2length[k + suffix] = key2length[k] # trim to length dataset = dataset.map( lambda x: {k: x[k][: key2length[k]] for k in keys}, num_parallel_calls=AUTOTUNE, ) # Setting batch_size=length ensures that the concatenated sequences (if they # have length >=1) are sufficient to fill at least one packed example. batch_size = max(key2length.values()) dataset = dataset.padded_batch( batch_size, padded_shapes={k: [-1] for k in keys} ) dataset = _pack_with_tf_ops(dataset, keys, key2length) # Set the Tensor shapes correctly since they get lost in the process. def my_fn(x): return {k: tf.reshape(v, [key2length[k]]) for k, v in x.items()} return dataset.map(my_fn, num_parallel_calls=AUTOTUNE) def _pack_with_tf_ops( dataset: tf.data.Dataset, keys: list[str], key2length: dict[str, int] ) -> tf.data.Dataset: """Helper-function for packing a dataset which has already been batched. Helper for pack_dataset() Uses tf.while_loop. Args: dataset: a dataset containing padded batches of examples. keys: a list of strings key2length: a dict from feature-key to integer Returns: a dataset. """ empty_example = {} for k in keys: empty_example[k] = tf.zeros([0], dtype=tf.int32) empty_example[k + '_position'] = tf.zeros([0], dtype=tf.int32) keys_etc = empty_example.keys() def write_packed_example(partial, outputs): new_partial = empty_example.copy() new_outputs = {} for k in keys_etc: new_outputs[k] = outputs[k].write( outputs[k].size(), tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]), ) return new_partial, new_outputs def map_fn(x): """Internal function to flat_map over. Consumes a batch of input examples and produces a variable number of output examples. Args: x: a single example Returns: a tf.data.Dataset """ partial = empty_example.copy() i = tf.zeros([], dtype=tf.int32) dynamic_batch_size = tf.shape(x[keys[0]])[0] outputs = {} for k in keys: outputs[k] = tf.TensorArray( tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] ) outputs[k + '_position'] = tf.TensorArray( tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] ) def body_fn(i, partial, outputs): """Body function for while_loop. Args: i: integer scalar partial: dictionary of Tensor (partially-constructed example) outputs: dictionary of TensorArray Returns: A triple containing the new values of the inputs. """ can_append = True one_example = {} for k in keys: val = tf.cast(x[k][i], tf.int32) val = val[: tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))] one_example[k] = val for k in keys: can_append = tf.logical_and( can_append, tf.less_equal( tf.size(partial[k]) + tf.size(one_example[k]), key2length[k] ), ) def false_fn(): return write_packed_example(partial, outputs) def true_fn(): return partial, outputs partial, outputs = tf.cond(can_append, true_fn, false_fn) new_partial = {} for k in keys: new_seq = one_example[k][: key2length[k]] new_seq_len = tf.size(new_seq) new_partial[k] = tf.concat([partial[k], new_seq], 0) new_partial[k + '_position'] = tf.concat( [partial[k + '_position'], tf.range(new_seq_len)], 0 ) partial = new_partial return i + 1, partial, outputs # For loop over all examples in the batch. _, partial, outputs = tf.while_loop( cond=lambda *_: True, body=body_fn, loop_vars=(i, partial, outputs), shape_invariants=( tf.TensorShape([]), {k: tf.TensorShape([None]) for k in keys_etc}, # type: ignore[wrong-arg-types] {k: tf.TensorShape(None) for k in keys_etc}, # type: ignore[wrong-arg-types] ), maximum_iterations=dynamic_batch_size, ) _, outputs = write_packed_example(partial, outputs) packed = {k: outputs[k].stack() for k in keys_etc} for k in keys: packed[k + '_segmentation'] = tf.cumsum( tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1 ) * tf.cast(tf.not_equal(packed[k], 0), tf.int32) return packed dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE) return dataset.unbatch() def shift_data_by_truncation(x): # https://github.com/AI-Hypercomputer/maxtext/blob/7fe1de75b3919c0fda00d23ad6cb29def9098362/MaxText/input_pipeline/_input_pipeline_utils.py#L53 x['inputs'] = x['inputs'][:-1] x['targets'] = x['targets'][1:] return x # ----------------------------------------------------------------------------- # Main dataset prep routines. # ----------------------------------------------------------------------------- def preprocess_data( dataset, shuffle: bool, num_epochs: int | None = 1, pack_examples: bool = True, shuffle_buffer_size: int = 1024, max_length: int = 512, batch_size: int = 256, drop_remainder: bool = True, prefetch_size: int = AUTOTUNE, shift: bool = True, ): """Shuffle and batch/pack the given dataset.""" def length_filter(max_len): def filter_fn(x): source, target = x['inputs'], x['targets'] l = tf.maximum(tf.shape(source)[0], tf.shape(target)[0]) return tf.less(l, max_len + 1) return filter_fn if max_length > 0: dataset = dataset.filter(length_filter(max_length)) if shuffle: dataset = dataset.shuffle(shuffle_buffer_size) dataset = dataset.repeat(num_epochs) # Shift inputs for teacher-forced training if shift: dataset = dataset.map( shift_data_by_truncation, num_parallel_calls=AUTOTUNE, deterministic=True, ) if pack_examples: dataset = pack_dataset(dataset, max_length) dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) else: # simple (static-shape) padded batching dataset = dataset.padded_batch( batch_size, padded_shapes={'inputs': max_length, 'targets': max_length}, padding_values={'inputs': 0, 'targets': 0}, drop_remainder=drop_remainder, ) if prefetch_size: dataset = dataset.prefetch(prefetch_size) return dataset def get_datasets( config: Any, *, n_devices: int, vocab_path: str | None = None, ): """Load and return dataset of batched examples for use during training.""" if vocab_path is None: vocab_path = os.path.expanduser('~/lm1b_sentencepiece_model') train_data = get_raw_dataset(config.dataset_name, 'train') if config.eval_dataset_name: eval_dataset_name = config.eval_dataset_name else: eval_dataset_name = config.dataset_name eval_data = get_raw_dataset(eval_dataset_name, config.eval_split) # Tokenize data. sp_processor = tokenizer.load_or_train_tokenizer( train_data, vocab_path=vocab_path, vocab_size=config.vocab_size, max_corpus_chars=config.max_corpus_chars, ) train_data = train_data.map( tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE ) eval_data = eval_data.map( tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE ) batch_size = config.per_device_batch_size * n_devices if config.eval_per_device_batch_size > 0: eval_batch_size = config.eval_per_device_batch_size * n_devices else: eval_batch_size = batch_size train_ds = preprocess_data( train_data, shuffle=True, num_epochs=None, pack_examples=True, batch_size=batch_size, max_length=config.max_target_length, ) eval_ds = preprocess_data( eval_data, shuffle=False, pack_examples=False, batch_size=eval_batch_size, max_length=config.max_eval_target_length, ) return train_ds, eval_ds, sp_processor ================================================ FILE: examples/gemma/input_pipeline_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import pathlib import sys import tempfile from absl.testing import absltest import tensorflow_datasets as tfds from configs import default import input_pipeline # We just use different values here to verify that the input pipeline uses the # the correct value for the 3 different datasets. _TARGET_LENGTH = 32 _EVAL_TARGET_LENGTH = 48 class InputPipelineTest(absltest.TestCase): def setUp(self): super().setUp() if sys.version_info >= (3, 13): self.skipTest('Test (and tensorflow-text) does not suport Python 3.13+') self.train_ds, self.eval_ds = self._get_datasets() def _get_datasets(self): config = default.get_config() config.per_device_batch_size = 1 config.eval_per_device_batch_size = 2 config.vocab_size = 32 config.max_corpus_chars = 1000 config.max_target_length = _TARGET_LENGTH config.max_eval_target_length = _EVAL_TARGET_LENGTH vocab_path = os.path.join(tempfile.mkdtemp(), 'sentencepiece_model') # Go two directories up to the root of the flax directory. # "/path/to/flax/examples/gemma/input_pipeline_test.py" -> "/path/to/flax" flax_root_dir = pathlib.Path(__file__).absolute().parents[2] data_dir = str(flax_root_dir) + '/.tfds/metadata' # pylint: disable=unused-variable with tfds.testing.mock_data(num_examples=128, data_dir=data_dir): train_ds, eval_ds, _ = input_pipeline.get_datasets( n_devices=2, config=config, vocab_path=vocab_path ) return train_ds, eval_ds def test_train_ds(self): expected_shape = [2, _TARGET_LENGTH] # 2 devices. # For training we pack multiple short examples in one example. # *_position and *_segmentation indicate the boundaries. for batch in self.train_ds.take(3): self.assertEqual( {k: v.shape.as_list() for k, v in batch.items()}, { 'inputs': expected_shape, 'inputs_position': expected_shape, 'inputs_segmentation': expected_shape, 'targets': expected_shape, 'targets_position': expected_shape, 'targets_segmentation': expected_shape, }, ) def test_eval_ds(self): expected_shape = [4, _EVAL_TARGET_LENGTH] # 2 devices. for batch in self.eval_ds.take(3): self.assertEqual( {k: v.shape.as_list() for k, v in batch.items()}, { 'inputs': expected_shape, 'targets': expected_shape, }, ) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/gemma/layers.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Base layers.""" from __future__ import annotations from collections.abc import Sequence from typing import Any, Union from flax import nnx import jax import jax.numpy as jnp from jaxtyping import Array, ArrayLike # pylint: disable=g-importing-member,g-multiple-import Shape = Sequence[Union[int, Any]] class Einsum(nnx.Module): """Einsum is a convenience module for parameterized tensor multiplication.""" def __init__( self, einsum_str: str, shape: Shape, *, kernel_init: nnx.Initializer = nnx.initializers.normal(), rngs: nnx.Rngs, dtype: Any = jnp.float32, ): self.einsum_str = einsum_str self.w = nnx.Param(kernel_init(rngs.params(), shape, dtype)) def __call__(self, x: ArrayLike) -> Array: return jnp.einsum(self.einsum_str, x, self.w[...]) @property def shape(self) -> Shape: return self.w.shape class RMSNorm(nnx.Module): """RMSNorm layer.""" def __init__( self, dim: int, *, scale_init: nnx.Initializer = nnx.initializers.zeros_init(), rngs: nnx.Rngs, dtype: Any = jnp.float32, ): self.scale = nnx.Param(scale_init(rngs.params(), dim, dtype)) def __call__(self, x: Array) -> Array: dtype = self.scale.dtype var = jnp.mean(jnp.square(x), axis=-1, keepdims=True) normed_inputs = jnp.asarray(x * jax.lax.rsqrt(var + 1e-06), dtype=dtype) # normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale is # a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to # a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs. scale = jnp.expand_dims(self.scale, axis=range(len(x.shape) - 1)) normed_inputs = normed_inputs * (1 + scale) return normed_inputs ================================================ FILE: examples/gemma/layers_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Tests for transformer layers.""" from absl.testing import absltest from absl.testing import parameterized from flax import nnx import layers import jax.numpy as jnp import numpy as np class EinsumTest(parameterized.TestCase): @parameterized.parameters( dict( inputs_shape=(1, 4), params_shape=(3, 2, 4, 3), eqn='TD,SNDH->STNH', expected_shape=(3, 1, 2, 3), ), dict( inputs_shape=(1, 2, 4), params_shape=(2, 4, 8), eqn='ANH,NHD->AD', expected_shape=(1, 8), ), ) def test_einsum(self, inputs_shape, params_shape, eqn, expected_shape): einsum = layers.Einsum(eqn, params_shape, rngs=nnx.Rngs(params=0)) output = einsum( jnp.ones(inputs_shape), ) self.assertEqual(output.shape, expected_shape) @parameterized.parameters( dict( shape=(1, 4), ), dict( shape=(2, 5, 4, 7), ), ) def test_shape(self, shape): einsum = layers.Einsum('ij->ji', shape, rngs=nnx.Rngs(params=0)) self.assertEqual(einsum.shape, shape) class RMSNormTest(parameterized.TestCase): @parameterized.parameters(dict(x=[0.1, 0.2], expected=[0.6324429, 1.2648858])) def test_rmsnorm(self, x, expected): x = jnp.array([x]) rmsnorm = layers.RMSNorm(x.shape[-1], rngs=nnx.Rngs(params=0)) output = rmsnorm(x) np.testing.assert_array_equal(output, jnp.array([expected])) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/gemma/main.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Main file for training Gemma model on the One Billion Word Benchmark dataset. This file is intentionally kept short. The majority for logic is in libraries that can be easily tested and imported in Colab. """ from absl import app from absl import flags from absl import logging from clu import platform import train import jax from ml_collections import config_flags import tensorflow as tf FLAGS = flags.FLAGS flags.DEFINE_string('workdir', None, 'Directory to store model data.') config_flags.DEFINE_config_file( 'config', 'configs/default.py', 'File path to the training hyperparameter configuration.', lock_config=True, ) flags.mark_flags_as_required(['workdir']) def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.experimental.set_visible_devices([], 'GPU') logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) logging.info('JAX local devices: %r', jax.local_devices()) # Add a note so that we can tell which task is which JAX host. # (Depending on the platform task 0 is not guaranteed to be host 0) platform.work_unit().set_task_status( f'process_index: {jax.process_index()}, ' f'process_count: {jax.process_count()}' ) platform.work_unit().create_artifact( platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' ) train.train_and_evaluate(FLAGS.config, FLAGS.workdir) if __name__ == '__main__': jax.config.config_with_absl() app.run(main) ================================================ FILE: examples/gemma/modules.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Transformer sub-modules.""" from __future__ import annotations from collections.abc import Sequence import enum from typing import Any, Union from flax import nnx import layers import positional_embeddings import sow_lib import jax import jax.numpy as jnp from jaxtyping import Array, ArrayLike # pylint: disable=g-importing-member,g-multiple-import LayerCache = dict[str, Array] Shape = Sequence[Union[int, Any]] K_MASK = -2.3819763e38 # Set to a large negative number. DEFAULT_ROPE_BASE_FREQUENCY = 10_000 DEFAULT_ROPE_SCALE_FACTOR = 1.0 class AttentionType(enum.Enum): GLOBAL = 1 LOCAL_SLIDING = 2 class Embedder(nnx.Module): """Embedder module.""" def __init__( self, vocab_size: int, embed_dim: int, *, embedding_init: nnx.Initializer = nnx.initializers.normal(), dtype: Any = jnp.float32, rngs: nnx.Rngs, ): self.input_embedding = nnx.Param( embedding_init(rngs.params(), (vocab_size, embed_dim), dtype) ) def encode(self, x: ArrayLike) -> Array: x = self.input_embedding[(x,)] x *= jnp.sqrt(x.shape[-1]).astype(x.dtype) return x def decode(self, x: ArrayLike) -> Array: return jnp.dot(x, self.input_embedding.T) @property def embed_dim(self): return self.input_embedding.shape[1] @property def num_embed(self): return self.input_embedding.shape[0] class Attention(nnx.Module): """Attention module.""" def __init__( self, num_heads: int, num_kv_heads: int, features: int, head_dim: int, query_pre_attn_scalar: float, attn_type: AttentionType, *, rngs: nnx.Rngs, rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY, rope_scale_factor: float = DEFAULT_ROPE_SCALE_FACTOR, attn_logits_soft_cap: float | None = None, sliding_window_size: int | None = None, use_qk_norm: bool = False, sow_config: sow_lib.SowConfig = sow_lib.SowConfig(), dtype: Any = jnp.float16, kernel_init: nnx.Initializer = nnx.initializers.normal(), scale_init: nnx.Initializer = nnx.initializers.zeros_init(), attn_vec_einsum_kernel_init: nnx.Initializer | None = None, qkv_einsum_kernel_init: nnx.Initializer | None = None, q_einsum_kernel_init: nnx.Initializer | None = None, kv_einsum_kernel_init: nnx.Initializer | None = None, ): if attn_type == AttentionType.LOCAL_SLIDING and sliding_window_size is None: raise ValueError( '`sliding_window_size` must be set if `attn_type` is Local Sliding.' ) self.query_pre_attn_scalar = query_pre_attn_scalar self.attn_type = attn_type self.sliding_window_size = sliding_window_size self.attn_logits_soft_cap = attn_logits_soft_cap attn_vec_einsum_kernel_init = attn_vec_einsum_kernel_init if attn_vec_einsum_kernel_init else kernel_init self.attn_vec_einsum = layers.Einsum( einsum_str='BTNH,NHD->BTD', shape=(num_heads, head_dim, features), kernel_init=attn_vec_einsum_kernel_init, dtype=dtype, rngs=rngs, ) self.rope_base_frequency = rope_base_frequency self.rope_scale_factor = rope_scale_factor self.use_qk_norm = use_qk_norm self.sow_config = sow_config if num_heads == num_kv_heads: qkv_einsum_kernel_init = qkv_einsum_kernel_init if qkv_einsum_kernel_init else kernel_init self.qkv_einsum = layers.Einsum( einsum_str='BTD,SNDH->SBTNH', shape=(3, num_heads, features, head_dim), kernel_init=qkv_einsum_kernel_init, dtype=dtype, rngs=rngs, ) else: if num_heads % num_kv_heads != 0: raise ValueError( f"Number of query heads ({num_heads}) must be divisible by " f"number of key/value heads ({num_kv_heads})." ) q_einsum_kernel_init = q_einsum_kernel_init if q_einsum_kernel_init else kernel_init self.q_einsum = layers.Einsum( einsum_str='BTD,NDH->BTNH', shape=(num_heads, features, head_dim), kernel_init=q_einsum_kernel_init, dtype=dtype, rngs=rngs, ) kv_einsum_kernel_init = kv_einsum_kernel_init if kv_einsum_kernel_init else kernel_init self.kv_einsum = layers.Einsum( einsum_str='BSD,CKDH->CBSKH', shape=(2, num_kv_heads, features, head_dim), kernel_init=kv_einsum_kernel_init, dtype=dtype, rngs=rngs, ) if self.use_qk_norm: self._query_norm = layers.RMSNorm( head_dim, scale_init=scale_init, dtype=dtype, rngs=rngs, ) self._key_norm = layers.RMSNorm( head_dim, scale_init=scale_init, dtype=dtype, rngs=rngs, ) def __call__( self, x: Array, segment_pos: Array, cache: LayerCache | None, attn_mask: Array, ) -> tuple[LayerCache | None, Array]: seq_len = x.shape[1] if self.use_qkv_einsum: query_proj, key_proj, value_proj = self.qkv_einsum(x) else: query_proj = self.q_einsum(x) key_proj, value_proj = self.kv_einsum(x) if self.use_qk_norm: query_proj = self._query_norm(query_proj) key_proj = self._key_norm(key_proj) query_proj = positional_embeddings.apply_rope( query_proj, segment_pos, head_dim=self.head_dim, max_wavelength=self.rope_base_frequency, scale_factor=self.rope_scale_factor, ) query_scaled = query_proj * self.query_pre_attn_scalar key_proj = positional_embeddings.apply_rope( key_proj, segment_pos, head_dim=self.head_dim, max_wavelength=self.rope_base_frequency, scale_factor=self.rope_scale_factor, ) # Cache is left aligned. if cache is not None: end_index = cache['end_index'][0] slice_indices = (0, end_index % cache['v'].shape[1], 0, 0) value_proj = jax.lax.dynamic_update_slice( cache['v'], value_proj, slice_indices, ) key_proj = jax.lax.dynamic_update_slice( cache['k'], key_proj, slice_indices ) use_gqa = self.num_heads > self.num_kv_heads and self.num_kv_heads > 1 if use_gqa: # Reshape matrices to enable einsums over groups. num_groups = self.num_heads // self.num_kv_heads batch_size, seq_size, _, head_dim = query_scaled.shape query_scaled = query_scaled.reshape( (batch_size, seq_size, self.num_kv_heads, num_groups, head_dim) ) logits = jnp.einsum('BTKGH,BSKH->BTKGS', query_scaled, key_proj) logits = logits.reshape( (batch_size, seq_size, self.num_heads, -1) ) else: logits = jnp.einsum('BTNH,BSNH->BTNS', query_scaled, key_proj) if self.attn_logits_soft_cap is not None: logits = jnp.tanh(logits / self.attn_logits_soft_cap) logits = logits * self.attn_logits_soft_cap if self.attn_type == AttentionType.LOCAL_SLIDING: if self.sliding_window_size is None: raise ValueError( 'sliding_window_size must be set if attn_type is Local Sliding.' ) all_ones = jnp.ones_like(attn_mask) sliding_mask = jnp.triu( all_ones, -1 * self.sliding_window_size + 1 ) * jnp.tril(all_ones, self.sliding_window_size - 1) attn_mask = sliding_mask * attn_mask padded_logits = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK) self.sow_config.maybe_sow_attn_logits_topk(padded_logits, self) probs = jax.nn.softmax(padded_logits, axis=-1).astype(key_proj.dtype) if use_gqa: # Reshape matrices to enable einsums over groups. num_groups = self.num_heads // self.num_kv_heads batch_size, seq_size1, _, _ = probs.shape probs = probs.reshape( (batch_size, seq_size1, self.num_kv_heads, num_groups, -1) ) encoded = jnp.einsum('BTKGS,BSKH->BTKGH', probs, value_proj) encoded = encoded.reshape( (batch_size, seq_size, self.num_heads, head_dim) ) else: encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj) attn_output = self.attn_vec_einsum(encoded) if cache is not None: new_cache = { 'v': value_proj, 'k': key_proj, 'end_index': cache['end_index'] + seq_len, } else: new_cache = None return new_cache, attn_output @property def head_dim(self): return self.attn_vec_einsum.shape[1] @property def num_heads(self): return ( self.qkv_einsum.shape[1] if self.use_qkv_einsum else self.q_einsum.shape[0] ) @property def num_kv_heads(self): return ( self.qkv_einsum.shape[1] if self.use_qkv_einsum else self.kv_einsum.shape[1] ) @property def use_qkv_einsum(self): return hasattr(self, 'qkv_einsum') and self.qkv_einsum is not None def init_cache( self, cache_size: int, batch_size: int, dtype: jnp.dtype = jnp.bfloat16, ) -> LayerCache: return { 'v': jnp.zeros( (batch_size, cache_size, self.num_kv_heads, self.head_dim), dtype=dtype, ), 'k': jnp.zeros( (batch_size, cache_size, self.num_kv_heads, self.head_dim), dtype=dtype, ), 'end_index': jnp.zeros((batch_size,), dtype=jnp.int32), } class FeedForward(nnx.Module): """Feed forward module.""" def __init__( self, features: int, hidden_dim: int, *, kernel_init: nnx.Initializer = nnx.initializers.normal(), rngs: nnx.Rngs, sow_config: sow_lib.SowConfig = sow_lib.SowConfig(), dtype: Any = jnp.float32, ): self.gate_proj = nnx.Linear( in_features=features, out_features=hidden_dim, use_bias=False, rngs=rngs, kernel_init=kernel_init, dtype=dtype, ) self.up_proj = nnx.Linear( in_features=features, out_features=hidden_dim, use_bias=False, rngs=rngs, kernel_init=kernel_init, dtype=dtype, ) self.down_proj = nnx.Linear( in_features=hidden_dim, out_features=features, use_bias=False, rngs=rngs, kernel_init=kernel_init, dtype=dtype, ) self.sow_config = sow_config def __call__(self, x: ArrayLike) -> Array: ff_gate = self.gate_proj(x) gate_value = nnx.gelu(ff_gate) ff1 = self.up_proj(x) activations = gate_value * ff1 self.sow_config.maybe_sow_mlp_hidden_topk(activations, self) outputs = self.down_proj(activations) return outputs class Block(nnx.Module): """Transformer block.""" def __init__( self, config, # TransformerConfig attn_type: AttentionType, *, rngs: nnx.Rngs, sow_config: sow_lib.SowConfig = sow_lib.SowConfig(), ): num_heads = config.num_heads num_kv_heads = config.num_kv_heads embed_dim = config.embed_dim head_dim = config.head_dim hidden_dim = config.hidden_dim sliding_window_size = config.sliding_window_size use_post_attn_norm = config.use_post_attn_norm use_post_ffw_norm = config.use_post_ffw_norm query_pre_attn_scalar = config.query_pre_attn_scalar() if attn_type == AttentionType.LOCAL_SLIDING: rope_base_frequency = config.local_base_frequency rope_scale_factor = config.local_scale_factor else: rope_base_frequency = config.global_base_frequency rope_scale_factor = config.global_scale_factor attn_logits_soft_cap = config.attn_logits_soft_cap use_qk_norm = config.use_qk_norm dtype = config.dtype self.pre_attention_norm = layers.RMSNorm( embed_dim, scale_init=maybe_with_partitioning( nnx.initializers.zeros_init(), config.axis_rules, ("embed", ), ), rngs=rngs, dtype=dtype, ) self.attn = Attention( num_heads=num_heads, num_kv_heads=num_kv_heads, features=embed_dim, head_dim=head_dim, query_pre_attn_scalar=query_pre_attn_scalar, attn_type=attn_type, rope_base_frequency=rope_base_frequency, rope_scale_factor=rope_scale_factor, attn_logits_soft_cap=attn_logits_soft_cap, sliding_window_size=sliding_window_size, rngs=rngs, use_qk_norm=use_qk_norm, sow_config=sow_config, attn_vec_einsum_kernel_init=maybe_with_partitioning( nnx.initializers.normal(), config.axis_rules, (None, "embed", "kv"), # sharded array shape: (num_heads, head_dim, features) ), qkv_einsum_kernel_init=maybe_with_partitioning( nnx.initializers.normal(), config.axis_rules, (None, None, "embed", "kv"), # sharded array shape: (3, num_heads, features, head_dim) ), q_einsum_kernel_init=maybe_with_partitioning( nnx.initializers.normal(), config.axis_rules, (None, "embed", "kv"), # sharded array shape: (num_heads, features, head_dim) ), kv_einsum_kernel_init=maybe_with_partitioning( nnx.initializers.normal(), config.axis_rules, (None, None, "embed", "kv"), # sharded array shape: (2, num_kv_heads, features, head_dim) ), scale_init=maybe_with_partitioning( nnx.initializers.zeros_init(), config.axis_rules, ("embed", ), ), dtype=dtype, ) if use_post_attn_norm: self.post_attention_norm = layers.RMSNorm( embed_dim, scale_init=maybe_with_partitioning( nnx.initializers.zeros_init(), config.axis_rules, ("embed", ), ), rngs=rngs, dtype=dtype, ) else: self.post_attention_norm = None self.pre_ffw_norm = layers.RMSNorm( embed_dim, scale_init=maybe_with_partitioning( nnx.initializers.zeros_init(), config.axis_rules, ("embed", ), ), rngs=rngs, dtype=dtype, ) self.mlp = FeedForward( features=embed_dim, hidden_dim=hidden_dim, kernel_init=maybe_with_partitioning( nnx.initializers.normal(), config.axis_rules, ("embed", "mlp"), ), rngs=rngs, sow_config=sow_config, ) if use_post_ffw_norm: self.post_ffw_norm = layers.RMSNorm( embed_dim, scale_init=maybe_with_partitioning( nnx.initializers.zeros_init(), config.axis_rules, ("embed", ), ), rngs=rngs, dtype=dtype, ) else: self.post_ffw_norm = None self.sow_config = sow_config def __call__( self, x: jax.Array, segment_pos: jax.Array, cache: LayerCache | None, attn_mask: jax.Array, ) -> tuple[LayerCache | None, jax.Array]: # Attention. attn_inputs = self.pre_attention_norm(x) cache, attn_output = self.attn( attn_inputs, segment_pos, cache, attn_mask, ) if self.post_attention_norm is not None: attn_output = self.post_attention_norm(attn_output) x += attn_output self.sow_config.maybe_sow_rs_after_attention(x, self) # Feed forward. ffw_inputs = self.pre_ffw_norm(x) ffw_outputs = self.mlp(ffw_inputs) if self.post_ffw_norm is not None: ffw_outputs = self.post_ffw_norm(ffw_outputs) x += ffw_outputs self.sow_config.maybe_sow_rs_after_ffw(x, self) return cache, x def init_cache( self, cache_size: int, batch_size: int, dtype: jnp.dtype = jnp.bfloat16, ) -> LayerCache: return self.attn.init_cache( cache_size=cache_size, batch_size=batch_size, dtype=dtype, ) def maybe_with_partitioning(fn, axis_rules, axis_rules_args=()): if axis_rules is None: return fn return nnx.with_partitioning(fn, axis_rules(*axis_rules_args)) ================================================ FILE: examples/gemma/modules_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Tests for transformer modules.""" from absl.testing import absltest from absl.testing import parameterized from flax import nnx import modules import transformer as transformer_lib import jax import jax.numpy as jnp import numpy as np class EmbedderTest(parameterized.TestCase): @parameterized.parameters( dict( vocab_size=10, embed_dim=4, inputs=[2, 3], expected=[[2.0, 2.0, 2.0, 2.0], [2.0, 2.0, 2.0, 2.0]], ), ) def test_encode(self, vocab_size, embed_dim, inputs, expected): embedder = modules.Embedder( vocab_size=vocab_size, embed_dim=embed_dim, rngs=nnx.Rngs(params=0), ) embedder.input_embedding[...] = jnp.ones((vocab_size, embed_dim)) output = embedder.encode(inputs) np.testing.assert_array_equal(output, jnp.array(expected)) @parameterized.parameters( dict( vocab_size=5, embed_dim=2, inputs=[[1, 2]], expected=[[3.0, 3.0, 3.0, 3.0, 3.0]], ), ) def test_decode(self, vocab_size, embed_dim, inputs, expected): embedder = modules.Embedder( vocab_size=vocab_size, embed_dim=embed_dim, rngs=nnx.Rngs(params=0), ) embedder.input_embedding[...] = jnp.ones((vocab_size, embed_dim)) output = embedder.decode(jnp.array(inputs)) np.testing.assert_array_equal(output, jnp.array(expected)) class AttentionTest(parameterized.TestCase): @parameterized.parameters( dict( head_dim=2, ), dict( head_dim=20, ), ) def test_head_dim(self, head_dim): attn = modules.Attention( num_heads=4, num_kv_heads=2, features=5, head_dim=head_dim, query_pre_attn_scalar=1.0, attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) self.assertEqual(attn.head_dim, head_dim) @parameterized.parameters( dict( num_heads=4, num_kv_heads=2, expected_use_qkv_einsum=False, ), dict( num_heads=3, num_kv_heads=3, expected_use_qkv_einsum=True, ), ) def test_use_qkv_einsum( self, num_heads, num_kv_heads, expected_use_qkv_einsum, ): attn = modules.Attention( num_heads=num_heads, num_kv_heads=num_kv_heads, features=5, head_dim=8, query_pre_attn_scalar=1.0, attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) self.assertEqual(attn.use_qkv_einsum, expected_use_qkv_einsum) self.assertEqual(hasattr(attn, 'q_einsum'), not expected_use_qkv_einsum) self.assertEqual(hasattr(attn, 'kv_einsum'), not expected_use_qkv_einsum) @parameterized.parameters( dict( num_heads=2, head_dim=4, features=8, segment_pos=0, cache_size=3, batch_size=2, expected_cache_shape=(2, 3, 2, 4), expected_output_shape=(2, 1, 8), ), ) def test_attention( self, num_heads, head_dim, features, segment_pos, cache_size, batch_size, expected_cache_shape, expected_output_shape, ): attn_mask = jnp.ones((batch_size, 1, cache_size)) attn = modules.Attention( num_heads, num_heads, features, head_dim, query_pre_attn_scalar=1.0, attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) cache = attn.init_cache( cache_size=cache_size, batch_size=batch_size, dtype=jnp.float32, ) x = jnp.ones((batch_size, 1, features)) cache, output = attn(x, jnp.array([[segment_pos]]), cache, attn_mask) self.assertEqual(cache['k'].shape, expected_cache_shape) self.assertEqual(output.shape, expected_output_shape) @parameterized.parameters( dict( sliding_window_size=2, ), ) def test_sliding_window(self, sliding_window_size): num_heads = 2 head_dim = 4 features = 8 segment_pos = 0 cache_size = 3 batch_size = 2 attn_mask = jnp.ones((batch_size, 1, cache_size)) x = jnp.ones((batch_size, 1, features)) attn = modules.Attention( num_heads, num_heads, features, head_dim, query_pre_attn_scalar=1.0, attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) cache = attn.init_cache( cache_size=cache_size, batch_size=batch_size, dtype=jnp.float32, ) _, output = attn(x, jnp.array([[segment_pos]]), cache, attn_mask) sliding_attn = modules.Attention( num_heads, num_heads, features, head_dim, query_pre_attn_scalar=1.0, attn_type=modules.AttentionType.LOCAL_SLIDING, sliding_window_size=sliding_window_size, rngs=nnx.Rngs(params=0), ) _, sliding_output = sliding_attn( x, jnp.array([[segment_pos]]), cache, attn_mask ) self.assertFalse((output == sliding_output).all()) class FeedForwardTest(parameterized.TestCase): @parameterized.parameters( dict( features=2, hidden_dim=3, batch_size=2, expected_val=[11.72758674, 47.99916], expected_shape=(2, 1, 2), ), ) def test_ffw( self, features, hidden_dim, batch_size, expected_val, expected_shape ): inputs = jnp.arange(1, batch_size+1)[:, None, None] inputs = jnp.repeat(inputs, features, axis=-1) ffw = modules.FeedForward( features=features, hidden_dim=hidden_dim, rngs=nnx.Rngs(params=0), ) ffw.gate_proj.kernel[...] = jnp.ones((features, hidden_dim)) ffw.up_proj.kernel[...] = jnp.ones((features, hidden_dim)) ffw.down_proj.kernel[...] = jnp.ones((hidden_dim, features)) with jax.default_matmul_precision('float32'): outputs = ffw(inputs) np.testing.assert_array_almost_equal(outputs[:, 0, 0], expected_val) self.assertEqual(outputs.shape, expected_shape) class BlockTest(parameterized.TestCase): @parameterized.parameters( dict( num_heads=2, embed_dim=4, head_dim=6, cache_size=3, batch_size=2, use_post_attn_norm=False, use_post_ffw_norm=False, expected_cache_shape=(2, 3, 2, 6), expected_output_shape=(2, 1, 4), ), ) def test_block( self, num_heads, embed_dim, head_dim, cache_size, batch_size, use_post_attn_norm, use_post_ffw_norm, expected_cache_shape, expected_output_shape, ): inputs = jnp.ones((batch_size, 1, embed_dim)) attn_mask = jnp.ones((batch_size, 1, cache_size)) config = transformer_lib.TransformerConfig( num_heads=num_heads, num_kv_heads=num_heads, embed_dim=embed_dim, head_dim=head_dim, hidden_dim=1, use_post_attn_norm=use_post_attn_norm, use_post_ffw_norm=use_post_ffw_norm, final_logit_softcap=None, num_layers=-1, num_embed=-1, attention_types=[], ) block = modules.Block( config, attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) cache = block.init_cache( cache_size=cache_size, batch_size=batch_size, dtype=jnp.float32, ) new_cache, outputs = block(inputs, jnp.array([[0]]), cache, attn_mask) self.assertEqual(block.post_attention_norm is not None, use_post_attn_norm) self.assertEqual(new_cache['k'].shape, expected_cache_shape) self.assertEqual(outputs.shape, expected_output_shape) @parameterized.parameters( dict( num_heads=1, embed_dim=1, head_dim=2, cache_size=1, batch_size=1, ), ) def test_post_attention_norm( self, num_heads, embed_dim, head_dim, cache_size, batch_size, ): inputs = jnp.ones((batch_size, 1, embed_dim)) attn_mask = jnp.ones((batch_size, 1, cache_size)) normed_block_config = transformer_lib.TransformerConfig( num_heads=num_heads, num_kv_heads=num_heads, embed_dim=embed_dim, head_dim=head_dim, hidden_dim=1, use_post_attn_norm=True, use_post_ffw_norm=False, final_logit_softcap=None, num_layers=-1, num_embed=-1, attention_types=[], ) normed_block = modules.Block( normed_block_config, attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) unnormed_block_config = transformer_lib.TransformerConfig( num_heads=num_heads, num_kv_heads=num_heads, embed_dim=embed_dim, head_dim=head_dim, hidden_dim=1, use_post_attn_norm=False, use_post_ffw_norm=False, final_logit_softcap=None, num_layers=-1, num_embed=-1, attention_types=[], ) unnormed_block = modules.Block( unnormed_block_config, attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) # Ok to use the same cache for both blocks. cache = normed_block.init_cache( cache_size=cache_size, batch_size=batch_size, dtype=jnp.float32, ) all_outputs = [] for block in (normed_block, unnormed_block): _, outputs = block(inputs, jnp.array([[0]]), cache, attn_mask) all_outputs.append(outputs) normed_output, unnormed_output = all_outputs # pylint: disable=unbalanced-tuple-unpacking self.assertTrue(jnp.not_equal(normed_output, unnormed_output).all()) @parameterized.parameters( dict( num_heads=1, embed_dim=1, head_dim=2, cache_size=1, batch_size=1, ), ) def test_post_ffw_norm( self, num_heads, embed_dim, head_dim, cache_size, batch_size, ): inputs = jnp.ones((batch_size, 1, embed_dim)) attn_mask = jnp.ones((batch_size, 1, cache_size)) normed_block_config = transformer_lib.TransformerConfig( num_heads=num_heads, num_kv_heads=num_heads, embed_dim=embed_dim, head_dim=head_dim, hidden_dim=1, use_post_attn_norm=False, use_post_ffw_norm=True, final_logit_softcap=None, num_layers=-1, num_embed=-1, attention_types=[], ) normed_block = modules.Block( normed_block_config, attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) unnormed_block_config = transformer_lib.TransformerConfig( num_heads=num_heads, num_kv_heads=num_heads, embed_dim=embed_dim, head_dim=head_dim, hidden_dim=1, use_post_attn_norm=False, use_post_ffw_norm=False, final_logit_softcap=None, num_layers=-1, num_embed=-1, attention_types=[], ) unnormed_block = modules.Block( unnormed_block_config, attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) # Ok to use the same cache for both blocks. cache = normed_block.init_cache( cache_size=cache_size, batch_size=batch_size, dtype=jnp.float32, ) all_outputs = [] for block in (normed_block, unnormed_block): _, outputs = block(inputs, jnp.array([[0]]), cache, attn_mask) all_outputs.append(outputs) normed_output, unnormed_output = all_outputs # pylint: disable=unbalanced-tuple-unpacking print(normed_output.shape, unnormed_output.shape) print(f"{normed_output=}") print(f"{unnormed_output=}") self.assertTrue(jnp.not_equal(normed_output, unnormed_output).all()) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/gemma/params.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Utils for loading Gemma params.""" from collections.abc import Mapping import functools from typing import Any import jax import jax.numpy as jnp import orbax.checkpoint Params = Mapping[str, Any] def load_and_format_params(path: str) -> Params: """Loads parameters and formats them for compatibility.""" params = load_params(path) param_state = jax.tree.map(jnp.array, params) remapped_params = param_remapper(param_state) nested_params = nest_params(remapped_params) return nested_params def load_metadata(path: str) -> Any | None: """Loads metadata from a checkpoint path.""" checkpointer = orbax.checkpoint.PyTreeCheckpointer() metadata = checkpointer.metadata(path) return metadata @functools.cache def load_params(path: str) -> Params: """Loads parameters from a checkpoint path.""" checkpointer = orbax.checkpoint.PyTreeCheckpointer() params = checkpointer.restore(path) return params def param_remapper(orig_params: Params) -> Params: """Remaps params to new module layout. This is needed here because the model definition does not have a separate `mlp` module. Args: orig_params: original dict of parameters in Gemma format. Returns: dict of params with different names. """ new_params = {} for k, v in orig_params.items(): if 'mlp/' in k: layer_name, param = k.rsplit('/', maxsplit=1) if layer_name not in new_params: new_params[layer_name] = {} if 'w' in v: new_params[layer_name][param] = v['w'] else: new_params[k] = v return new_params def nest_params(params: Params) -> Params: """Nests params as a dict of dicts rather than a flat dict.""" nested_params = {} for path, param in params.items(): *path, leaf = path.split('/') subdict = nested_params for key in path: subdict = subdict.setdefault(key, {}) subdict[leaf] = param return nested_params ================================================ FILE: examples/gemma/positional_embeddings.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Utils for positional embeddings (including RoPE).""" import jax import jax.numpy as jnp _MAX_WAVELENGTH = 10_000 def add_positional_embedding( input_embedding: jax.Array, position: int, max_wavelength: int = _MAX_WAVELENGTH, ) -> jax.Array: """Adds positional embeddings to input embeddings.""" embed_dim = input_embedding.shape[-1] num_timescales = embed_dim // 2 log_timescale_increment = jnp.log(float(max_wavelength)) / jnp.maximum( jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1 ) inv_timescales = jnp.exp( jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment ) scaled_time = position * inv_timescales signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)]) signal = jnp.pad(signal, [[0, jnp.mod(embed_dim, 2)]]) position_embedding = signal.astype(jnp.float32) return input_embedding + position_embedding def apply_rope( inputs: jax.Array, # [B, L] positions: jax.Array, # [B, L] head_dim: int, max_wavelength: int = _MAX_WAVELENGTH, scale_factor: float = 1.0, ) -> jax.Array: """Applies RoPE.""" fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim timescale = max_wavelength**fraction sinusoid_inp = ( positions[..., jnp.newaxis] / timescale[jnp.newaxis, jnp.newaxis, :] ) sinusoid_inp = sinusoid_inp[..., jnp.newaxis, :] if scale_factor < 1.0: raise ValueError(f'scale_factor must be >= 1.0, got {scale_factor}') sinusoid_inp /= scale_factor sin = jnp.sin(sinusoid_inp) cos = jnp.cos(sinusoid_inp) first_half, second_half = jnp.split(inputs, 2, axis=-1) first_part = first_half * cos - second_half * sin second_part = second_half * cos + first_half * sin out = jnp.concatenate([first_part, second_part], axis=-1) return out.astype(inputs.dtype) ================================================ FILE: examples/gemma/positional_embeddings_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Tests for the positional embeddings utilities.""" from absl.testing import absltest from absl.testing import parameterized import positional_embeddings import jax import jax.numpy as jnp import numpy as np # positional_embeddings.py uses implicit rank broadcast and needs this config to # be 'allow', while the rest of Flax can use jax_numpy_rank_promotion=raise. jax.config.update('jax_numpy_rank_promotion', 'allow') class PositionalEmbeddingsTest(parameterized.TestCase): @parameterized.parameters( dict( input_embedding_shape=(2, 1, 1, 5), position=3, max_wavelength=100, expected=[[[[1.1411201, 1.0299965, 0.0100075, 1.99955, 1.0]]], [[[1.1411201, 1.0299965, 0.0100075, 1.99955, 1.0]]]] ) ) def test_adds_positional_embeddings( self, input_embedding_shape, position, max_wavelength, expected ): outputs = positional_embeddings.add_positional_embedding( jnp.ones(input_embedding_shape), position, max_wavelength ) np.testing.assert_array_almost_equal(outputs, jnp.array(expected)) @parameterized.parameters( dict( input_embedding_shape=(2, 1, 2, 4), position=3, head_dim=4, max_wavelength=100, expected=[ [[ [-1.1311126, 0.6598157, -0.8488725, 1.2508571], [-1.1311126, 0.6598157, -0.8488725, 1.2508571], ]], [[ [-1.1311126, 0.6598157, -0.8488725, 1.2508571], [-1.1311126, 0.6598157, -0.8488725, 1.2508571], ]], ], ) ) def test_rope_positional_embeddings( self, input_embedding_shape, position, head_dim, max_wavelength, expected ): outputs = positional_embeddings.apply_rope( jnp.ones(input_embedding_shape), jnp.array([[position]]), head_dim, max_wavelength, ) np.testing.assert_array_almost_equal(outputs, jnp.array(expected)) if __name__ == "__main__": absltest.main() ================================================ FILE: examples/gemma/requirements.txt ================================================ absl-py~=2.2 clu==0.0.12 flax~=0.10 jax~=0.6 mlcroissant~=1.0 numpy~=2.1 optax~=0.2 sentencepiece~=0.2 jaxtyping~=0.3 tensorflow~=2.19 tensorflow-datasets~=4.9 tensorflow-text~=2.19 ================================================ FILE: examples/gemma/sampler.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Sampler for Gemma transformer. An example of a sampling class for a Gemma model. """ from __future__ import annotations from collections.abc import Sequence import dataclasses import flax from flax import nnx import modules import sow_lib import transformer as transformer_lib from flax.nnx import graph from flax.nnx import statelib import jax import jax.numpy as jnp import sentencepiece as spm def _sample_top_p(probs: jnp.ndarray, p: float, key: jax.Array) -> jnp.ndarray: """Sample a token using top-p sampling.""" probs_sorted, indices = jax.lax.top_k(probs, k=probs.shape[-1]) cumsum_probs = jnp.cumsum(probs_sorted, axis=-1) mask = cumsum_probs - probs_sorted > p probs_sorted = jnp.where(mask, 0.0, probs_sorted) probs_sorted /= jnp.sum(probs_sorted, axis=-1, keepdims=True) next_token = jax.random.categorical(key, logits=jnp.log(probs_sorted)) next_token = jnp.take_along_axis(indices, next_token[..., None], axis=-1) next_token = jnp.squeeze(next_token, axis=-1) return next_token def _compute_attention_masks( time_step: jax.Array, seq_len: int, input_mask: jax.Array ) -> jax.Array: """Computes causal attention mask.""" batch_size = input_mask.shape[0] batch_time_step = jnp.full((batch_size, 1), time_step, dtype=jnp.uint32) causal_padding = jnp.greater( jnp.expand_dims(jnp.arange(seq_len), 0), batch_time_step ) max_seq_len = min(input_mask.shape[-1], seq_len) input_mask = jax.lax.dynamic_slice( input_mask, (0, jnp.maximum(time_step - seq_len + 1, 0)), (batch_size, max_seq_len), ) input_mask = ( jnp.zeros((batch_size, seq_len), dtype=jnp.bool_) .at[:, :max_seq_len] .set(input_mask) ) causal_padding = jnp.logical_or(causal_padding, input_mask) attention_mask = causal_padding[:, jnp.newaxis, :].astype(jnp.bool_) return ~attention_mask @flax.struct.dataclass class _SamplingState: """Internal sampling state.""" # Decoding step. decoding_step: jnp.int32 # Number of tokens in the prompt. num_input_tokens: jnp.ndarray # [B] # Fixed-size buffer for accumulating the output tokens. token_buffer: jnp.ndarray # [B, L] # Position indices, based on ignoring pad tokens. positions: jnp.ndarray # [B, L] # Model state for conditioning the model on autoregressively. cache: dict[str, modules.LayerCache] # Is decoding done on the given sequence? done: jnp.ndarray # [B] # Total sampling steps (including the prompt). total_sampling_steps: int # Fixed-size buffer for accumulating the output logits. logits_buffer: jnp.ndarray | None # [B, L, V] # List of tokens that are forbidden to be generated. forbidden_token_ids: Sequence[int] | None # Intermediate activations from the model if requested. intermediates: sow_lib.TransformerIntermediates | None # Random seed for sampling. seed: jax.Array # Tempurature for top_p sampling. temperature: float = flax.struct.field(pytree_node=False) # Top-p sampling threshold. top_p: float = flax.struct.field(pytree_node=False) @dataclasses.dataclass class SamplerOutput: """Output of the sampler.""" # Decoded samples from the model. text: list[str] # Per-step logits used during sampling. logits: list[list[float]] # Tokens corresponding to the generated samples. tokens: list[list[int]] # Intermediate activations from the model if requested. intermediates: sow_lib.TransformerIntermediates | None = None class Sampler: """Sampler for gemma transformer.""" def __init__( self, transformer: transformer_lib.Transformer, vocab: spm.SentencePieceProcessor, cache_size: int = 1024, ): """Initializes a sampler for a Gemma model. Args: transformer: an instance of the Gemma transformer. vocab: vocabulary of the given model. cache_size: size of the cache for the transformer. """ self.vocab = vocab self.cache_size = cache_size graphdef, state = nnx.split(transformer) self._transformer_graphdef: graph.NodeDef = graphdef self._transformer_state: statelib.State = state # we separate out state and graph def so that the state can be passed as an # argument to _sample_fn, resulting in it not being treated as a static # arg. This greatly reduces the size of the HLO and reduces compile time self._compiled_sample_fn = jax.jit(self._sample_fn) @property def transformer(self) -> transformer_lib.Transformer: return nnx.merge(self._transformer_graphdef, self._transformer_state) @property def transformer_state(self) -> statelib.State: return self._transformer_state @transformer_state.setter def transformer_state(self, state: statelib.State) -> statelib.State: def check_tree_structure(tree1, tree2): if jax.tree_util.tree_structure(tree1) != jax.tree_util.tree_structure( tree2 ): raise ValueError( "New state must have the same structure as the old state." ) def check_shape_dtype(x, y): return jnp.shape(x) == jnp.shape(y) and jnp.dtype(x) == jnp.dtype(y) if not all( jax.tree_util.tree_leaves( jax.tree_util.tree_map(check_shape_dtype, tree1, tree2) ) ): raise ValueError( "New state must have the same shape and dtype as the old state." ) check_tree_structure(self._transformer_state, state) self._transformer_state = state @property def dtype(self) -> jnp.dtype: return jax.tree_util.tree_leaves( nnx.to_flat_state(self._transformer_state) )[0].dtype def _sample_step( self, params: statelib.State, sampler_state: _SamplingState ) -> _SamplingState: """Performs a single sampling step.""" batch_size = sampler_state.token_buffer.shape[0] decoding_step = jnp.asarray(sampler_state.decoding_step, dtype=jnp.int32) last_token = sampler_state.token_buffer[:, decoding_step] input_mask = sampler_state.token_buffer == self.vocab.pad_id() attention_mask = _compute_attention_masks( decoding_step, self.cache_size, input_mask ) step_positions = jnp.expand_dims( sampler_state.positions[:, decoding_step], -1 ) last_token = last_token.reshape((batch_size, 1)) transformer = nnx.merge(self._transformer_graphdef, params) logits, cache = transformer( last_token, step_positions, sampler_state.cache, attention_mask, ) if sampler_state.forbidden_token_ids: logits = logits.at[:, :, sampler_state.forbidden_token_ids].set(-jnp.inf) def sample_top_p(logits, key): probs = jax.nn.softmax(logits[:, -1] / sampler_state.temperature, axis=-1) next_token = _sample_top_p(probs, sampler_state.top_p, key) return next_token def sample_best(logits): next_token = jnp.argmax(logits, axis=-1) next_token = next_token[:, 0] return next_token if sampler_state.temperature > 0: key = jax.random.fold_in(sampler_state.seed, decoding_step) next_token_candidate = sample_top_p(logits, key) else: next_token_candidate = sample_best(logits) next_token_candidate = jnp.where( decoding_step < sampler_state.num_input_tokens - 1, sampler_state.token_buffer[:, decoding_step + 1], next_token_candidate, ) token_buffer = sampler_state.token_buffer.at[:, decoding_step + 1].set( next_token_candidate ) if sampler_state.logits_buffer is not None: next_logits = jnp.squeeze(logits, 1) logits_buffer = sampler_state.logits_buffer.at[:, decoding_step + 1].set( next_logits ) else: logits_buffer = sampler_state.logits_buffer if sampler_state.intermediates is not None: sampler_state.intermediates.merge(decoding_step, transformer) done = sampler_state.done | jnp.equal( token_buffer[:, decoding_step + 1], self.vocab.eos_id() ) return _SamplingState( decoding_step=sampler_state.decoding_step + 1, num_input_tokens=sampler_state.num_input_tokens, token_buffer=token_buffer, positions=sampler_state.positions, logits_buffer=logits_buffer, cache=cache, done=done, total_sampling_steps=sampler_state.total_sampling_steps, forbidden_token_ids=sampler_state.forbidden_token_ids, intermediates=sampler_state.intermediates, temperature=sampler_state.temperature, top_p=sampler_state.top_p, seed=sampler_state.seed, ) def init_sample_state( self, all_input_ids: list[jax.Array], total_sampling_steps: int, include_logits: bool, forbidden_token_ids: Sequence[int] | None, temperature: float, top_p: float, seed: jax.Array, ) -> _SamplingState: """Initializes the sampling state given input prompts.""" batch_size = len(all_input_ids) num_input_tokens = [len(input_ids) for input_ids in all_input_ids] buffer_size = total_sampling_steps + 1 token_buffer = jnp.full( ( batch_size, buffer_size, ), self.vocab.pad_id(), dtype=jnp.int32, ) input_mask = jnp.ones_like(token_buffer, dtype=jnp.bool_) for i, (input_ids, num_tokens) in enumerate( zip(all_input_ids, num_input_tokens) ): token_buffer = token_buffer.at[i, :num_tokens].set(input_ids) input_mask = input_mask.at[i, :num_tokens].set( input_ids != self.vocab.pad_id() ) positions = transformer_lib.build_positions_from_mask(input_mask) done = jnp.zeros((batch_size,), dtype=jnp.bool_) if include_logits: logits_buffer = jnp.zeros( (batch_size, buffer_size, self.transformer.num_embed), dtype=jnp.float32, ) else: logits_buffer = None return _SamplingState( decoding_step=0, num_input_tokens=jnp.array(num_input_tokens, dtype=jnp.int32), token_buffer=token_buffer, positions=positions, logits_buffer=logits_buffer, cache=self.transformer.init_cache( cache_size=self.cache_size, batch_size=batch_size, dtype=self.dtype, ), done=done, total_sampling_steps=total_sampling_steps, forbidden_token_ids=forbidden_token_ids, intermediates=self.transformer.init_intermediates( batch_size, buffer_size, self.transformer.sow_config ), temperature=temperature, top_p=top_p, seed=seed, ) def tokenize(self, input_string: str) -> jax.Array: """Tokenizes the input string.""" input_ids = self.vocab.EncodeAsIds(input_string) input_ids = jnp.array( [self.vocab.bos_id()] + jnp.array(input_ids).tolist(), dtype=jnp.int32 ) return input_ids def mask_tokens_after_eos_ids(self, token_buffer): """Mask token IDs after the EOS token with the padding ID.""" eos_id = self.vocab.eos_id() eos_exists = jnp.any(jnp.equal(token_buffer, eos_id), axis=-1) eos_indices = jnp.where( eos_exists, jnp.argmax(jnp.equal(token_buffer, eos_id), axis=-1), token_buffer.shape[-1], ) mask = jnp.less_equal( jnp.arange(token_buffer.shape[-1]), eos_indices[:, None] ) masked_token_buffer = token_buffer * mask + self.vocab.pad_id() * (1 - mask) return masked_token_buffer def _sample_fn( self, params: statelib.State, initial_sampling_state: _SamplingState, ) -> _SamplingState: """Internal sampling function (to be jitted).""" def sample_with_params(sampler_state: _SamplingState): return self._sample_step(params, sampler_state) def cond_fn(sampler_state: _SamplingState): return ( sampler_state.decoding_step < sampler_state.total_sampling_steps ) & jnp.any(jnp.logical_not(sampler_state.done)) return jax.lax.while_loop( cond_fn, sample_with_params, initial_sampling_state ) def __call__( self, input_strings: Sequence[str], total_generation_steps: int, echo: bool = False, return_logits: bool = True, forbidden_tokens: Sequence[str] | None = None, temperature: float = 0.0, top_p: float = 0.95, seed: jax.Array | None = None, ) -> SamplerOutput: """Samples a completion of the input string. Args: input_strings: input prompts to feed to the model for sampling. total_generation_steps: number of generation steps. will correspond to the longest prompt in the batch. echo: whether to return the prompt as part of the output sample. return_logits: whether to return per-step logits used during generation. forbidden_tokens: list of tokens that are forbidden to be generated. Each token must map to a single token id in the vocab. temperature: temperature for sampling. top_p: top-p sampling threshold. seed: random seed for sampling. Returns: sampler_output: A SamplerOutput object containing the generated samples. """ forbidden_token_ids = None if forbidden_tokens is not None: forbidden_token_ids = [] for token in forbidden_tokens: token_id = self.vocab.EncodeAsIds(token) if len(token_id) != 1: raise ValueError( "Forbidden tokens must map to single token ids in the vocab." ) forbidden_token_ids.extend(token_id) forbidden_token_ids = tuple(forbidden_token_ids) all_input_ids = [self.tokenize(x) for x in input_strings] max_input_length = max(len(input_ids) for input_ids in all_input_ids) total_sampling_steps = max_input_length + total_generation_steps if seed is None: seed = jax.random.PRNGKey(0) initial_sampling_state = self.init_sample_state( all_input_ids, include_logits=return_logits, total_sampling_steps=total_sampling_steps, forbidden_token_ids=forbidden_token_ids, temperature=temperature, top_p=top_p, seed=seed, ) sampling_state = self._compiled_sample_fn( self._transformer_state, initial_sampling_state ) masked_token_buffer = self.mask_tokens_after_eos_ids( sampling_state.token_buffer ) out_tokens = [] out_logits = [] for i, (token_buffer, num_tokens) in enumerate( zip( masked_token_buffer, sampling_state.num_input_tokens, ) ): start_idx = 0 if echo else num_tokens out_tokens.append(token_buffer[start_idx:total_sampling_steps].tolist()) if return_logits: logits_buffer = sampling_state.logits_buffer[i] out_logits.append( logits_buffer[start_idx:total_sampling_steps].tolist() ) decoded_outputs = [self.vocab.DecodeIds(tokens) for tokens in out_tokens] if sampling_state.intermediates is not None: sampling_state.intermediates.trim(total_sampling_steps) result = SamplerOutput( text=decoded_outputs, logits=out_logits, tokens=out_tokens, intermediates=sampling_state.intermediates, ) return result ================================================ FILE: examples/gemma/sampler_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Minimal test for sampler.""" import os from collections.abc import Iterable from absl.testing import absltest from absl.testing import parameterized from flax import nnx import modules import params as params_lib import sampler as sampler_lib import sow_lib import transformer as transformer_lib import jax import jax.numpy as jnp import numpy as np import sentencepiece as spm class MockVocab(spm.SentencePieceProcessor): def __init__(self): super().__init__() self._start_id = 3 self._mapping_text_to_id = { '': 0, '': 1, '': 2, 'input': 3, 'string': 4, 'hello': 5, 'world': 6, 'Hello': 7, 'there': 8, '!': 9, 'My': 10, 'name': 11, 'is': 12, 'Morgane': 13, } self._vocab_size = len(self._mapping_text_to_id) def pad_id(self) -> int: return 0 def bos_id(self) -> int: return 1 def eos_id(self) -> int: return 2 def GetPieceSize(self) -> int: # pylint: disable=invalid-name return self._vocab_size def DecodeIds(self, ids: Iterable[int]) -> str: # pylint: disable=invalid-name reverse_mapping = {v: k for k, v in self._mapping_text_to_id.items()} return ' '.join(reverse_mapping[e] for e in ids) def EncodeAsIds(self, text: str) -> list[int]: # pylint: disable=invalid-name words = text.split(' ') return [self._mapping_text_to_id[word] for word in words] class SamplerTest(parameterized.TestCase): def assertReasonableTensor(self, array, expected_shape=None): self.assertIsNotNone(array) if expected_shape is not None: self.assertEqual(array.shape, expected_shape) def test_samples(self): vocab = MockVocab() num_layers = 6 transformer_config = transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types num_layers=num_layers, num_embed=vocab.GetPieceSize(), embed_dim=768, hidden_dim=6144, num_heads=4, num_kv_heads=4, head_dim=256, final_logit_softcap=None, attention_types=[modules.AttentionType.GLOBAL] * num_layers, attn_logits_soft_cap=None, use_post_attn_norm=None, use_post_ffw_norm=None, ) transformer = transformer_lib.Transformer( transformer_config, rngs=nnx.Rngs(params=0) ) sampler = sampler_lib.Sampler( transformer=transformer, vocab=vocab, cache_size=1024, ) result = sampler(['input string', 'hello world'], total_generation_steps=10) self.assertIsNotNone(result) top_p_result = sampler( ['input string', 'hello world'], total_generation_steps=10, temperature=9, top_p=0.95, ) self.assertIsNotNone(top_p_result) self.assertNotEqual(result.text, top_p_result.text) top_p_result_2 = sampler( ['input string', 'hello world'], total_generation_steps=10, temperature=9, top_p=0.95, seed=jax.random.PRNGKey(42), ) self.assertIsNotNone(top_p_result_2) self.assertNotEqual(top_p_result.text, top_p_result_2.text) def test_state_update(self): vocab = MockVocab() num_layers = 6 transformer_config = transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types num_layers=num_layers, num_embed=vocab.GetPieceSize(), embed_dim=768, hidden_dim=6144, num_heads=4, num_kv_heads=4, head_dim=256, final_logit_softcap=None, attention_types=[modules.AttentionType.GLOBAL] * num_layers, attn_logits_soft_cap=None, use_post_attn_norm=None, use_post_ffw_norm=None, ) transformer = transformer_lib.Transformer( transformer_config, rngs=nnx.Rngs(params=0) ) sampler = sampler_lib.Sampler( transformer=transformer, vocab=vocab, cache_size=1024, ) input_strings = ['input string', 'hello world'] original_logits = sampler(input_strings, total_generation_steps=10).logits new_transformer = transformer_lib.Transformer( transformer_config, rngs=nnx.Rngs(params=42) ) sampler.transformer_state = nnx.state(new_transformer, nnx.Param) new_logits = sampler(input_strings, total_generation_steps=10).logits with self.assertRaises(AssertionError): np.testing.assert_allclose( original_logits, new_logits, atol=1e-1, rtol=1e-1 ) def test_invalid_state_update(self): vocab = MockVocab() def make_config(num_layers, embed_dim): return transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types num_layers=num_layers, num_embed=vocab.GetPieceSize(), embed_dim=embed_dim, hidden_dim=6144, num_heads=4, num_kv_heads=4, head_dim=256, final_logit_softcap=None, attention_types=[modules.AttentionType.GLOBAL] * num_layers, attn_logits_soft_cap=None, use_post_attn_norm=None, use_post_ffw_norm=None, ) transformer = transformer_lib.Transformer( make_config(num_layers=6, embed_dim=768), rngs=nnx.Rngs(params=0) ) sampler = sampler_lib.Sampler( transformer=transformer, vocab=vocab, cache_size=1024, ) new_transformer = transformer_lib.Transformer( make_config(num_layers=3, embed_dim=768), rngs=nnx.Rngs(params=42) ) with self.assertRaisesRegex( ValueError, '.*must have the same structure.*' ): sampler.transformer_state = nnx.state(new_transformer, nnx.Param) new_transformer = transformer_lib.Transformer( make_config(num_layers=6, embed_dim=1024), rngs=nnx.Rngs(params=42) ) with self.assertRaisesRegex( ValueError, '.*must have the same shape and dtype.*' ): sampler.transformer_state = nnx.state(new_transformer, nnx.Param) def test_forbidden_tokens(self): vocab = MockVocab() transformer_config = transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types num_layers=0, num_embed=vocab.GetPieceSize(), embed_dim=32, hidden_dim=64, num_heads=4, num_kv_heads=1, head_dim=64, final_logit_softcap=None, attention_types=[], use_post_attn_norm=None, use_post_ffw_norm=None, ) transformer = transformer_lib.Transformer( transformer_config, rngs=nnx.Rngs(params=0) ) # Pre-cook the embedding matrix so that the output is deterministic. transformer.embedder.input_embedding.set_value(jnp.eye( vocab.GetPieceSize(), 32 )) sampler = sampler_lib.Sampler( transformer=transformer, vocab=vocab, cache_size=8, ) # First, we check that the sampler would produce the tokens that we are # trying to forbid. result1 = sampler( ['input string', 'hello world'], total_generation_steps=10, forbidden_tokens=None, ) self.assertIn('string', result1.text[0]) self.assertIn('world', result1.text[1]) # Then, we check that the sampler does not produce the forbidden tokens. result2 = sampler( ['input string', 'hello world'], total_generation_steps=10, forbidden_tokens=['string', 'world'], ) for output in result2.text: self.assertNotIn('string', output) self.assertNotIn('world', output) def test_forward_equivalence(self): vocab = MockVocab() num_layers = 2 transformer_config = transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types num_layers=num_layers, num_embed=vocab.GetPieceSize(), embed_dim=32, hidden_dim=64, num_heads=4, num_kv_heads=1, head_dim=64, final_logit_softcap=None, attention_types=[modules.AttentionType.GLOBAL] * num_layers, use_post_attn_norm=None, use_post_ffw_norm=None, ) transformer = transformer_lib.Transformer( transformer_config, rngs=nnx.Rngs(params=0) ) raw_input = 'Hello there ! My name is Morgane ' token_input = jnp.asarray( [vocab.bos_id()] + vocab.EncodeAsIds(raw_input) ).reshape((1, -1)) batch_size = 1 cache_size = 9 cache = transformer.init_cache( cache_size=cache_size, batch_size=batch_size, dtype=jnp.float32, ) input_mask = token_input != vocab.pad_id() positions = transformer_lib.build_positions_from_mask(input_mask) attention_mask = transformer_lib.make_causal_attn_mask(input_mask) n_input_tokens = token_input.shape[1] output_forward, _ = transformer( last_tokens=token_input, positions=positions, cache=cache, attention_mask=attention_mask, ) output_forward = output_forward[0, :n_input_tokens] sampler = sampler_lib.Sampler( transformer=transformer, vocab=vocab, cache_size=cache_size, ) output_transformer = sampler( [raw_input], total_generation_steps=10, echo=True, ) out_logits = np.array(output_transformer.logits)[0, 1 : n_input_tokens + 1] np.testing.assert_almost_equal(output_forward, out_logits) def test_sampler_init_sample_state(self): vocab = MockVocab() transformer_config = transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types num_layers=0, num_embed=vocab.GetPieceSize(), embed_dim=32, hidden_dim=64, num_heads=4, num_kv_heads=1, head_dim=64, final_logit_softcap=None, attention_types=[], use_post_attn_norm=None, use_post_ffw_norm=None, ) transformer = transformer_lib.Transformer( transformer_config, rngs=nnx.Rngs(params=0) ) sampler = sampler_lib.Sampler( transformer=transformer, vocab=vocab, cache_size=8, ) input_strings = [' hello world', 'input string '] all_input_ids = [sampler.tokenize(x) for x in input_strings] total_sampling_steps = 5 sample_state = sampler.init_sample_state( all_input_ids, total_sampling_steps=total_sampling_steps, include_logits=True, forbidden_token_ids=None, temperature=0.0, top_p=0.95, seed=jax.random.PRNGKey(0), ) # Check that the position indices correctly ignore padding self.assertListEqual(list(sample_state.positions[0]), [0, 0, 1, 2, 3, 4]) self.assertListEqual(list(sample_state.positions[1]), [0, 1, 2, 2, 3, 4]) def test_sampler_mask_tokens_after_eos_ids(self): vocab = MockVocab() transformer_config = transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types num_layers=0, num_embed=vocab.GetPieceSize(), embed_dim=32, hidden_dim=64, num_heads=4, num_kv_heads=1, head_dim=64, final_logit_softcap=None, attention_types=[], use_post_attn_norm=None, use_post_ffw_norm=None, ) transformer = transformer_lib.Transformer( transformer_config, rngs=nnx.Rngs(params=0) ) sampler = sampler_lib.Sampler( transformer=transformer, vocab=vocab, cache_size=8, ) input_strings = ['hello world hello world', 'input string hello'] all_input_ids = [sampler.tokenize(x) for x in input_strings] total_sampling_steps = 5 sample_state = sampler.init_sample_state( all_input_ids, total_sampling_steps=total_sampling_steps, include_logits=True, forbidden_token_ids=None, temperature=0.0, top_p=0.95, seed=jax.random.PRNGKey(0), ) masked_token_buffer = sampler.mask_tokens_after_eos_ids( sample_state.token_buffer ) self.assertListEqual(list(masked_token_buffer[0]), [1, 5, 6, 2, 0, 0]) self.assertListEqual(list(masked_token_buffer[1]), [1, 3, 4, 2, 0, 0]) def test_sampler_sows_intermediates(self): vocab = MockVocab() num_layers = 3 config = transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types num_layers=num_layers, num_embed=vocab.GetPieceSize(), embed_dim=64, hidden_dim=128, num_heads=2, num_kv_heads=1, head_dim=64, final_logit_softcap=None, attention_types=[modules.AttentionType.GLOBAL] * num_layers, use_post_attn_norm=None, attn_logits_soft_cap=None, use_post_ffw_norm=None, ) sow_config = sow_lib.SowConfig( embeddings=True, rs_after_attention=False, # This should results in a None value. rs_after_ffw=True, attn_logits_topk=5, mlp_hidden_topk=11, ) transformer = transformer_lib.Transformer( config, rngs=nnx.Rngs(params=0), sow_config=sow_config ) sampler = sampler_lib.Sampler( transformer=transformer, vocab=vocab, ) raw_input = ['input string', 'hello world'] result = sampler(raw_input, total_generation_steps=10) input_length = max([len(vocab.EncodeAsIds(i)) for i in raw_input]) input_length += 1 # +1 for BOS token output_length = max(len(tokens) for tokens in result.tokens) length = input_length + output_length self.assertIsNotNone(result) intermediates = result.intermediates self.assertIsNotNone(intermediates) self.assertReasonableTensor( intermediates.embeddings, expected_shape=(2, length, config.embed_dim), ) # Verify that the intermediates are different for two different steps. self.assertNotAlmostEqual( jnp.sum(intermediates.embeddings[:, 1, ...]), jnp.sum(intermediates.embeddings[:, 2, ...]), ) # Verify that the intermediates are filled in for each layer. self.assertLen(intermediates.layers, config.num_layers) for layer in intermediates.layers: # For the requested intermediates we check the shape and that values are # not all zeros, which was the initial value. self.assertReasonableTensor( layer.rs_after_ffw, expected_shape=(2, length, config.embed_dim), ) self.assertReasonableTensor( layer.attn_logits_topk_values, expected_shape=( 2, length, config.num_heads, sow_config.attn_logits_topk, ), ) self.assertReasonableTensor( layer.attn_logits_topk_indices, expected_shape=( 2, length, config.num_heads, sow_config.attn_logits_topk, ), ) self.assertReasonableTensor( layer.mlp_hidden_topk_values, expected_shape=(2, length, sow_config.mlp_hidden_topk), ) self.assertReasonableTensor( layer.mlp_hidden_topk_indices, expected_shape=(2, length, sow_config.mlp_hidden_topk), ) # For the none requested intermediates we want to have None values. self.assertIsNone(layer.rs_after_attention) def test_compute_attention_mask(self): # Check that the input mask is correctly applied when total sampling steps # is lower than the max cache length. input_mask = jnp.array([[1, 1, 0, 0, 0], [1, 1, 0, 1, 0]], dtype=jnp.bool_) seq_len = 8 time_step = jnp.asarray(4, dtype=jnp.int32) attn_mask = sampler_lib._compute_attention_masks( time_step, seq_len, input_mask ) expected_attn_mask = jnp.array( [[0, 0, 1, 1, 1, 0, 0, 0], [0, 0, 1, 0, 1, 0, 0, 0]], dtype=jnp.bool_ ) self.assertTrue((attn_mask.squeeze(1) == expected_attn_mask).all()) # Check that the input mask is correctly applied when total sampling steps # is *longer* than the max cache length. seq_len = 4 time_step = jnp.asarray(4, dtype=jnp.int32) attn_mask = sampler_lib._compute_attention_masks( time_step, seq_len, input_mask ) expected_attn_mask = jnp.array( [[0, 1, 1, 1], [0, 1, 0, 1]], dtype=jnp.bool_ ) self.assertTrue((attn_mask.squeeze(1) == expected_attn_mask).all()) @parameterized.parameters( {"url": "google/gemma/flax/2b"}, {"url": "google/gemma-2/flax/gemma2-2b"}, {"url": "google/gemma-3/flax/gemma3-1b"}, ) def test_models_from_kaggle(self, url): # A smoke test based on guide/gemma.md to ensure models are working correctly # Check Kaggle creds as env var otherwise skip the test has_kaggle_creds = all(k in os.environ for k in ["KAGGLE_USERNAME", "KAGGLE_KEY"]) try: import kagglehub has_kagglehub_dep = True except ModuleNotFoundError: has_kagglehub_dep = False if not (has_kaggle_creds and has_kagglehub_dep): self.skipTest('Skip the test as no Kaggle deps/creds') variant = url.split("/")[-1] weights_dir = kagglehub.model_download(url) ckpt_path = f"{weights_dir}/{variant}" vocab_path = f"{weights_dir}/tokenizer.model" vocab = spm.SentencePieceProcessor() vocab.Load(vocab_path) params = params_lib.load_and_format_params(ckpt_path) transformer = transformer_lib.Transformer.from_params(params) sampler = sampler_lib.Sampler( transformer=transformer, vocab=vocab, ) input_batch = [ "# Python function to compute a square of the input number", ] out_data = sampler( input_strings=input_batch, total_generation_steps=50, ) assert "def square(" in out_data.text[0], out_data.text[0] if __name__ == '__main__': absltest.main() ================================================ FILE: examples/gemma/sow_lib.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Utilities for sowing intermediate activations.""" import dataclasses from flax import nnx import jax import jax.numpy as jnp @jax.tree_util.register_dataclass @dataclasses.dataclass class LayerIntermediates: """Intermediate activations for a single layer.""" # Dense residual stream activations. rs_after_attention: jax.Array | None = None rs_after_ffw: jax.Array | None = None # Sparse representations for large activations. mlp_hidden_topk_values: jax.Array | None = None mlp_hidden_topk_indices: jax.Array | None = None attn_logits_topk_values: jax.Array | None = None attn_logits_topk_indices: jax.Array | None = None def merge(self, decoding_step, layer: nnx.Module): """Merges the intermediate activations from one step.""" for field in dataclasses.fields(self.__class__): value = getattr(self, field.name) if value is None: continue # We put mlp and attn intermediates into this class without any further # nesting. So we have to retrieve the intermediates from the correct # sub-module. try: if field.name.startswith('attn_'): step_value = getattr( layer.attn, field.name.replace('attn_', '') )[0] elif field.name.startswith('mlp_'): step_value = getattr(layer.mlp, field.name.replace('mlp_', ''))[0] else: step_value = getattr(layer, field.name)[0] except AttributeError as exc: raise ValueError( f'Intermediate {field.name} is not in the step intermediates.' ) from exc # This logic is the same for all intermediates. The second dimenions is # the length dimension, where we want to merge the intermediates from # multiple steps. setattr( self, field.name, value.at[:, decoding_step + 1].set(step_value[:, 0, ...]), ) def trim(self, max_length: int): """Trims the intermediate activations to the given length.""" for field in dataclasses.fields(self.__class__): value = getattr(self, field.name) if value is not None: setattr(self, field.name, value[:, :max_length, ...]) @jax.tree_util.register_dataclass @dataclasses.dataclass class TransformerIntermediates: """Intermediate activations of a transformer network.""" # Embeddings of the input tokens. embeddings: jax.Array | None = None # Intermediate activations of each layer. layers: list[LayerIntermediates] = dataclasses.field(default_factory=list) def merge(self, decoding_step, transformer: nnx.Module): """Merges the intermediate activations from one step.""" if self.embeddings is not None: try: self.embeddings = self.embeddings.at[:, decoding_step + 1, ...].set( transformer.embeddings[0][:, 0, ...] ) except AttributeError as exc: raise ValueError( 'Embeddings are not in the step intermediates.' ) from exc if len(self.layers) != len(transformer.layers): raise ValueError( 'Number of layers in the transformer and intermediates do not match.' ) for layer_intermediates, layer_module in zip( self.layers, transformer.layers ): layer_intermediates.merge(decoding_step, layer_module) def trim(self, max_length: int): """Trims the intermediate activations to the given length.""" if self.embeddings is not None: self.embeddings = self.embeddings[:, :max_length, ...] for layer in self.layers: layer.trim(max_length) @dataclasses.dataclass(frozen=True) class SowConfig: """Module for sowing intermediate activations.""" # Whether to sow embeddings. embeddings: bool = False # Whether to sow activations after each attention block (in residual stream). rs_after_attention: bool = False # Whether to sow activations after each FFW block (in residual stream). # This is the same as the residual stream activations after a whole layer. rs_after_ffw: bool = False # If non-zero, top-k activations in a ffw hidden layer are sowed. # We use a sparse representation here to save memory. mlp_hidden_topk: int = 0 # If non-zero, top-k attention logits are sowed. # We use a sparse representation here to save memory. attn_logits_topk: int = 0 def maybe_sow_embeddings( self, embeddings: jax.Array, module: nnx.Module, ): """Sows embeddings if configured.""" if self.embeddings: module.sow(nnx.Intermediate, 'embeddings', embeddings) def maybe_sow_rs_after_attention( self, activations: jax.Array, module: nnx.Module, ): """Sows activations after attention if configured.""" if self.rs_after_attention: module.sow(nnx.Intermediate, 'rs_after_attention', activations) def maybe_sow_rs_after_ffw( self, activations: jax.Array, module: nnx.Module, ): """Sows activations after FFW if configured.""" if self.rs_after_ffw: module.sow(nnx.Intermediate, 'rs_after_ffw', activations) def maybe_sow_mlp_hidden_topk( self, activations: jax.Array, module: nnx.Module, ): """Sows top-absolute-k activations in a mlp hidden layer if configured.""" if self.mlp_hidden_topk: _, indices = jax.lax.top_k(jnp.abs(activations), self.mlp_hidden_topk) values = jnp.take_along_axis(activations, indices, axis=-1) module.sow(nnx.Intermediate, 'hidden_topk_values', values) module.sow(nnx.Intermediate, 'hidden_topk_indices', indices) def maybe_sow_attn_logits_topk( self, logits: jax.Array, module: nnx.Module, ): """Sows top-k attention logits if configured.""" if self.attn_logits_topk: values, indices = jax.lax.top_k(logits, self.attn_logits_topk) module.sow(nnx.Intermediate, 'logits_topk_values', values) module.sow(nnx.Intermediate, 'logits_topk_indices', indices) ================================================ FILE: examples/gemma/tokenizer.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Provides op for tokenizing a dataset.""" from collections.abc import Iterable import dataclasses import os import sys import tempfile import time from typing import Any from absl import logging import jax import tensorflow as tf from sentencepiece import SentencePieceProcessor # pylint: disable=g-importing-member from sentencepiece import SentencePieceTrainer # pylint: disable=g-importing-member if sys.version_info < (3, 13): import tensorflow_text as tftxt Features = dict[str, tf.Tensor] def _dump_chars_to_textfile( dataset: tf.data.Dataset, maxchars: int = int(1e7), data_keys=('inputs', 'targets'), ) -> tuple[str, int]: """Write part of a TFDS sentence dataset to lines in a text file. Args: dataset: tf.dataset containing string-data. maxchars: int: approximate number of characters to save from dataset. data_keys: Tuple[str]: what keys in dataset to dump from. Returns: name of temp file with dataset bytes, exact number of characters dumped. """ char_count = 0 ds_iter = dataset.as_numpy_iterator() with tempfile.NamedTemporaryFile( delete=False, prefix='/tmp/ds_chars' ) as outfp: while char_count < maxchars: example = next(ds_iter) for k in data_keys: line = example[k] + b'\n' char_count += len(line) outfp.write(line) return outfp.name, char_count def _train_sentencepiece( dataset: tf.data.Dataset, *, vocab_size: int, maxchars: int = int(1e7), model_path: str, model_type: str = 'unigram', character_coverage: float = 1.0, data_keys=('inputs', 'targets'), pad_id: int = 0, eos_id: int = 1, bos_id: int = 2, unk_id: int = 3, ): """Train SentencePiece tokenizer from subset of tf dataset. Args: dataset: tf.dataset vocab_size: int: size of vocab tokens to train. maxchars: int: number of characters to use for sentencepiece training. model_path: str: path of model file to save vocab model to. model_type: str: type of sentencepiece vocab to train. character_coverage: amount of characters covered by the model, good defaults are 0.9995 for languages with rich character set like Japanese or Chinese and 1.0 for other languages with small character set. data_keys: Tuple[str]: keys of dataset to use for training. pad_id: int: pad piece id eos_id: int: end of sentence piece id bos_id: int: begin of sentence piece id unk_id: int: unknown piece id Returns: path to the trained sentencepiece vocabulary model. """ if model_path.startswith('gs://'): abs_model_path = model_path else: abs_model_path = os.path.abspath(os.path.expanduser(model_path)) fname, _ = _dump_chars_to_textfile( dataset, maxchars=maxchars, data_keys=data_keys ) with tempfile.NamedTemporaryFile( delete=False, prefix='/tmp/sp_tmp' ) as model_fp: pass # we just want a prefix'd tmp-filename argstr = ' '.join([ f'--input={fname}', f'--vocab_size={vocab_size}', f'--character_coverage={character_coverage}', f'--model_prefix={model_fp.name}', f'--model_type={model_type}', # Setup ids for PAD, EOS, BOS, UNK as 0, 1, 2, 3 # Default values: # --unk_id (Override UNK () id.) type: int32 default: 0 # --bos_id (Override BOS () id. Set -1 to disable BOS.) type: int32 default: 1 # --eos_id (Override EOS () id. Set -1 to disable EOS.) type: int32 default: 2 # --pad_id (Override PAD () id. Set -1 to disable PAD.) type: int32 default: -1 # https://github.com/google/sentencepiece/blob/master/doc/options.md f'--pad_id={pad_id}', f'--bos_id={bos_id}', f'--eos_id={eos_id}', f'--unk_id={unk_id}', ]) SentencePieceTrainer.Train(argstr) if jax.process_index() == 0: # Use an intermediate filename that is renamed to the target name to address # create and fill delays. copy_rename_path = abs_model_path + '.rntmp' tf.io.gfile.copy(model_fp.name + '.model', copy_rename_path, overwrite=True) tf.io.gfile.rename(copy_rename_path, abs_model_path, overwrite=True) logging.info('copied %s to %s', model_fp.name + '.model', abs_model_path) else: while not tf.io.gfile.exists(abs_model_path): time.sleep(1) time.sleep(1) return abs_model_path def _load_sentencepiece_tokenizer( model_path: str, add_bos: bool = False, add_eos: bool = True, reverse: bool = False, ): """Load a tf-text SentencePiece tokenizer from given model filepath.""" with tf.io.gfile.GFile(model_path, 'rb') as model_fp: sp_model = model_fp.read() sp_tokenizer = tftxt.SentencepieceTokenizer( model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse ) return sp_tokenizer def load_or_train_tokenizer( dataset: tf.data.Dataset, *, vocab_path: str, vocab_size: int, max_corpus_chars: int, data_keys: tuple[str, str] = ('inputs', 'targets'), ): """Loads the tokenizer at `vocab_path` or trains a one from `dataset`.""" try: return _load_sentencepiece_tokenizer(vocab_path) except tf.errors.NotFoundError: logging.info('SentencePiece vocab not found, building one from data.') vocab_path = _train_sentencepiece( dataset, vocab_size=vocab_size, maxchars=max_corpus_chars, model_path=vocab_path, data_keys=data_keys, ) return _load_sentencepiece_tokenizer(vocab_path) @dataclasses.dataclass class TokenizeOp: sp_tokenizer: Any data_keys: Iterable[str] = ('inputs', 'targets') def __call__(self, features: Features) -> Features: for k in self.data_keys: features[k] = self.sp_tokenizer.tokenize(features[k]) return features def load_sentencepiece_processor(vocab_path: str): spp = SentencePieceProcessor() spp.Load(vocab_path) return spp ================================================ FILE: examples/gemma/train.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Language Modeling example. This script trains a Transformer on a LM1B dataset. """ # pytype: disable=wrong-arg-count # pytype: disable=attribute-error import dataclasses import os from typing import Any from absl import logging from clu import metric_writers from clu import periodic_actions from flax import nnx import input_pipeline import sampler as sampler_lib import tokenizer import transformer as transformer_lib import utils from flax.training import checkpoints from flax.training import common_utils import jax from jax import random import jax.numpy as jnp import numpy as np import optax import tensorflow as tf @dataclasses.dataclass(unsafe_hash=True) class MeshRules: embed: str | None = None mlp: str | None = None kv: str | None = None vocab: str | None = None def __call__(self, *keys: str) -> tuple[str, ...]: return tuple( getattr(self, key) if key is not None else None for key in keys ) @dataclasses.dataclass(unsafe_hash=True) class TrainConfig: """Configuration for training a gemma model.""" # Path to load or store sentencepiece vocab file. vocab_path: str | None # Vocabulary size if `vocab_path` is not given. vocab_size: int # Maximum number of characters to use for training. max_corpus_chars: int # Name of TFDS translation dataset to use. dataset_name: str # Optional name of TFDS translation dataset to use for evaluation. eval_dataset_name: str # Optional name of TFDS split to use for evaluation. eval_split: str # Per device batch size for training. per_device_batch_size: int # Per device batch size for training. eval_per_device_batch_size: int # Prompt for language model sampling prompts: tuple[str, ...] # Temperature for top_p sampling. sampling_temperature: float # Top-p sampling threshold. sampling_top_p: float # Number of steps to take during training. num_train_steps: int # Number of steps to take during evaluation. # Large enough to evaluate all samples: 306_688 / (32 * 8) = 1198 num_eval_steps: int # Number of steps to generate predictions. # -1 will use the whole eval dataset. num_predict_steps: int # Base learning rate. learning_rate: float # Linear learning rate warmup. warmup_steps: int # Cross entropy loss label smoothing. label_smoothing: float # Decay factor for AdamW style weight decay. weight_decay: float # Maximum length cutoff for training examples. max_target_length: int # Maximum length cutoff for eval examples. max_eval_target_length: int # Gemma transformer name. # Possible values defined in transformer.TransformerConfig: # (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_1b, gemma3_4b, # ...) transformer_name: str | None # or alternatively define the model using the dict of parameters transformer_params: dict[Any, Any] | None # Whether to save model checkpoints. save_checkpoints: bool # Whether to restore from existing model checkpoints. restore_checkpoints: bool # Save a checkpoint every these number of steps. checkpoint_every_steps: int # Frequency of eval during training, e.g. every 1_000 steps. eval_every_steps: int # Use bfloat16 mixed precision training instead of float32. use_bfloat16: bool # Integer for PRNG random seed. seed: int # Parallelism mesh_axes: tuple[str, ...] axis_rules: MeshRules data_sharding: tuple[str, ...] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. # By default, product of the DCN axes should equal number of slices # and product of the ICI axes should equal number of devices per slice. # ICI (Inter-Chip Interconnection): A high-speed connection between # sets of TPU chips, which form the TPU network. # DCN (Data Center Network): A connection between the TPU networks; # not as fast as ICI. # ICI has around 100x the bandwidth of DCN, but it is not a general # purpose connection, which is why DCN is necessary for scaling to # extremely large ML models. dcn_data_parallelism: int = -1 dcn_fsdp_parallelism: int = 1 dcn_tensor_parallelism: int = 1 ici_data_parallelism: int = 1 ici_fsdp_parallelism: int = -1 ici_tensor_parallelism: int = 1 def replace(self, **kwargs): return dataclasses.replace(self, **kwargs) def __post_init__(self): if isinstance(self.axis_rules, dict): self.axis_rules = MeshRules(**self.axis_rules) def rsqrt_schedule( init_value: float, shift: int = 0, ): """Applies a reverse square-root schedule. The reverse square root schedule is simply `lr = init_value / sqrt(step)`. Args: init_value: Base learning rate (before applying the rsqrt schedule). shift: How many steps the rsqrt should be shifted. Shifting the rsqrt schedule makes it less steep in the beginning (close to 0). Returns: A schedule that applies the reverse square root. """ def schedule(count): return init_value * (count + shift) ** -0.5 * shift**0.5 return schedule def create_learning_rate_schedule(learning_rate: float, warmup_steps: int): """Creates a rsqrt schedule with linear warmup.""" return optax.join_schedules( [ optax.linear_schedule( init_value=0, end_value=learning_rate, transition_steps=warmup_steps, ), rsqrt_schedule(init_value=learning_rate, shift=warmup_steps), ], boundaries=[warmup_steps], ) def compute_weighted_cross_entropy( logits, targets, weights=None, label_smoothing=0.0 ): """Compute weighted cross entropy and entropy for log probs and targets. Args: logits: [batch, length, num_classes] float array. targets: categorical targets [batch, length] int array. weights: None or array of shape [batch, length]. label_smoothing: label smoothing constant, used to determine the on and off values. Returns: Tuple of scalar loss and batch normalizing factor. """ if logits.ndim != targets.ndim + 1: raise ValueError( 'Incorrect shapes. Got shape %s logits and %s targets' % (str(logits.shape), str(targets.shape)) ) vocab_size = logits.shape[-1] confidence = 1.0 - label_smoothing low_confidence = (1.0 - confidence) / (vocab_size - 1) normalizing_constant = -( confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) ) soft_targets = common_utils.onehot( targets, vocab_size, on_value=confidence, off_value=low_confidence ) loss = -jnp.sum(soft_targets * nnx.log_softmax(logits), axis=-1) loss = loss - normalizing_constant normalizing_factor = np.prod(targets.shape) if weights is not None: loss = loss * weights normalizing_factor = weights.sum() return loss.sum(), normalizing_factor def compute_weighted_accuracy(logits, targets, weights=None): """Compute weighted accuracy for log probs and targets. Args: logits: [batch, length, num_classes] float array. targets: categorical targets [batch, length] int array. weights: None or array of shape [batch, length] Returns: Tuple of scalar loss and batch normalizing factor. """ if logits.ndim != targets.ndim + 1: raise ValueError( 'Incorrect shapes. Got shape %s logits and %s targets' % (str(logits.shape), str(targets.shape)) ) loss = jnp.equal(jnp.argmax(logits, axis=-1), targets) normalizing_factor = np.prod(logits.shape[:-1]) if weights is not None: loss = loss * weights normalizing_factor = weights.sum() return loss.sum(), normalizing_factor def compute_metrics(logits, labels, weights, label_smoothing=0.0): """Compute summary metrics.""" loss, weight_sum = compute_weighted_cross_entropy( logits, labels, weights, label_smoothing ) acc, _ = compute_weighted_accuracy(logits, labels, weights) metrics = { 'loss': loss, 'accuracy': acc, 'denominator': weight_sum, } return metrics # Primary training / eval / decode step functions. # ----------------------------------------------------------------------------- def train_step( state: utils.TrainState, batch, learning_rate_fn, label_smoothing=0.0, ): """Perform a single training step.""" # X_position and X_segmentation are needed only when using "packed examples" # where multiple sequences are packed into the same example with this # metadata. # if such features are not present they are ignored and the example is treated # like a normal, unpacked sequence example. train_keys = ['inputs', 'inputs_position', 'inputs_segmentation', 'targets'] (inputs, inputs_positions, inputs_segmentation, targets) = ( batch.get(k, None) for k in train_keys ) # TODO: this should be defined globally pad_id = 0 weights = jnp.where(inputs > pad_id, 1, 0).astype(jnp.float32) input_mask = inputs > pad_id attention_mask = transformer_lib.make_causal_attn_mask( input_mask ) # (B, L, L) # inputs_segmentation: (B, L) mask = ( inputs_segmentation[:, :, None] == inputs_segmentation[:, None, :] ) # (B, L, L) attention_mask = jnp.logical_and(mask, attention_mask) def loss_fn(params): """loss function used for training.""" module = nnx.merge(state.graphdef, params) logits, _ = module( inputs, positions=inputs_positions, attention_mask=attention_mask, cache=None, ) loss, weight_sum = compute_weighted_cross_entropy( logits, targets, weights, label_smoothing ) mean_loss = loss / weight_sum return mean_loss, logits step = state.step lr = learning_rate_fn(step) grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, logits), grads = grad_fn(state.params) new_state = state.apply_gradients(grads=grads) metrics = compute_metrics(logits, targets, weights) metrics['learning_rate'] = lr return new_state, metrics def eval_step( params: nnx.State, batch, graphdef: nnx.GraphDef[transformer_lib.Transformer], label_smoothing=0.0, ): """Calculate evaluation metrics on a batch.""" inputs, targets = batch['inputs'], batch['targets'] # TODO: this should be defined globally pad_id = 0 weights = jnp.where(inputs > pad_id, 1, 0).astype(jnp.float32) input_mask = inputs > pad_id inputs_positions = transformer_lib.build_positions_from_mask(input_mask) attention_mask = transformer_lib.make_causal_attn_mask(input_mask) module = nnx.merge(graphdef, params) logits, _ = module( inputs, positions=inputs_positions, attention_mask=attention_mask, cache=None, ) return compute_metrics(logits, targets, weights, label_smoothing) def evaluate( *, jit_eval_step, state: utils.TrainState, eval_ds: tf.data.Dataset, num_eval_steps: int, ): """Evaluate the target an return a dictionary with the metrics.""" logging.info('Gathering evaluation metrics.') eval_metrics = [] eval_iter = iter(eval_ds) # pytype: disable=wrong-arg-types for _, eval_batch in zip(range(num_eval_steps), eval_iter): eval_batch = jax.tree.map(lambda x: x._numpy(), eval_batch) # pylint: disable=protected-access metrics = jit_eval_step(state.params, eval_batch, state.graphdef) eval_metrics.append(metrics) eval_metrics = common_utils.stack_forest(eval_metrics) eval_metrics_sums = jax.tree.map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree.map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums, ) return eval_summary def train_and_evaluate(config: TrainConfig, workdir: str): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ workdir = os.path.abspath(workdir) tf.io.gfile.makedirs(workdir) vocab_path = config.vocab_path if vocab_path is None: vocab_path = os.path.join(workdir, 'sentencepiece_model') config.vocab_path = vocab_path tf.io.gfile.makedirs(os.path.split(vocab_path)[0]) # Load Dataset # --------------------------------------------------------------------------- logging.info('Initializing dataset.') train_ds, eval_ds, encoder = input_pipeline.get_datasets( n_devices=jax.local_device_count(), config=config, vocab_path=vocab_path ) train_iter = iter(train_ds) vocab_size = int(encoder.vocab_size()) logging.info('Initializing model, optimizer, and step functions.') # Build Model and Optimizer # --------------------------------------------------------------------------- if config.transformer_name is not None: model_config = transformer_lib.TransformerConfig.from_version_name( config.transformer_name, num_embed=vocab_size, dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, axis_rules=config.axis_rules, ) else: assert config.transformer_params is not None model_config = transformer_lib.TransformerConfig.from_dict( **config.transformer_params, num_embed=vocab_size, dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, axis_rules=config.axis_rules, ) # Mesh definition devices_array = utils.create_device_mesh(config) mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) start_step = 0 rng = jax.random.PRNGKey(config.seed) rng, init_rng = jax.random.split(rng) _, inference_rng = random.split(rng) def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): return transformer_lib.Transformer(config, rngs=nnx.Rngs(params=key)) learning_rate_fn = create_learning_rate_schedule( learning_rate=config.learning_rate, warmup_steps=config.warmup_steps ) optimizer = optax.adamw( learning_rate_fn, b1=0.9, b2=0.98, eps=1e-9, weight_decay=config.weight_decay, ) state, state_sharding = utils.setup_initial_state( constructor, optimizer, model_config, init_rng, mesh ) data_sharding = jax.NamedSharding(mesh, jax.P(config.data_sharding)) if config.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. state = checkpoints.restore_checkpoint(workdir, state) # Grab last step. start_step = int(state.step) writer = metric_writers.create_default_writer( workdir, just_logging=jax.process_index() > 0 ) if start_step == 0: writer.write_hparams(dataclasses.asdict(config)) # compile multidevice versions of train/eval/predict step fn. jit_train_step = jax.jit( train_step, in_shardings=( state_sharding, data_sharding, ), # type: ignore out_shardings=(state_sharding, None), # type: ignore static_argnames=('learning_rate_fn', 'label_smoothing'), donate_argnums=0, ) jit_eval_step = jax.jit( eval_step, in_shardings=( state_sharding.params, data_sharding, ), # type: ignore out_shardings=None, # type: ignore static_argnames=('graphdef', 'label_smoothing'), ) vocab = tokenizer.load_sentencepiece_processor(vocab_path) sampler = sampler_lib.Sampler( transformer=nnx.merge(state.graphdef, state.params), vocab=vocab, cache_size=1024, ) # Main Train Loop # --------------------------------------------------------------------------- # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. logging.info('Starting training loop.') hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=config.num_train_steps, writer=writer ) if jax.process_index() == 0: hooks += [ report_progress, periodic_actions.Profile(logdir=workdir, num_profile_steps=5), ] train_metrics = [] with metric_writers.ensure_flushes(writer): for step in range(start_step, config.num_train_steps): is_last_step = step == config.num_train_steps - 1 # Shard data to devices and do a training step. with jax.profiler.StepTraceAnnotation('train', step_num=step): with report_progress.timed('data'): batch = next(train_iter) batch = jax.tree.map( lambda x: jnp.asarray(x, device=data_sharding), batch ) with report_progress.timed('train_step'): state, metrics = jit_train_step(state, batch, learning_rate_fn, 0.0) train_metrics.append(metrics) # Quick indication that training is happening. logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step) for h in hooks: h(step) # Write batch loss and lr every step to TB # without overwhelming the stdout: if jax.process_index() == 0: tb_writer = writer._writers[-1] # pylint: disable=protected-access lr = train_metrics[-1]['learning_rate'] train_batch_loss = train_metrics[-1]['loss'] denominator = train_metrics[-1]['denominator'] tb_writer.write_scalars( step, { 'train_learning_rate': lr, 'train_loss': train_batch_loss / denominator, }, ) # Periodic metric handling. if (step > 0 and step % config.eval_every_steps == 0) or is_last_step: with report_progress.timed('training_metrics'): logging.info('Gathering training metrics.') train_metrics = common_utils.stack_forest(train_metrics) # Remove learning_rate from the summary _ = train_metrics.pop('learning_rate') metrics_sums = jax.tree.map(jnp.sum, train_metrics) denominator = metrics_sums.pop('denominator') summary = jax.tree.map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), max=1.0e4) summary = {'train_' + k: v for k, v in summary.items()} writer.write_scalars(step, summary) train_metrics = [] with report_progress.timed('generate_text'): # update sampler's transformer state: sampler.transformer_state = state.params exemplars = sampler( config.prompts, total_generation_steps=config.num_predict_steps, temperature=config.sampling_temperature, top_p=config.sampling_top_p, seed=inference_rng, echo=True, ) writer.write_texts(step, {'samples': exemplars.text[0]}) with report_progress.timed('eval'): eval_results = evaluate( jit_eval_step=jit_eval_step, state=state, eval_ds=eval_ds, num_eval_steps=config.num_eval_steps, ) # (clipped) perplexity after averaging log-perplexity eval_results['perplexity'] = jnp.clip( jnp.exp(eval_results['loss']), max=1.0e4 ) writer.write_scalars( step, {'eval_' + k: v for k, v in eval_results.items()} ) # Save a checkpoint on one host after every checkpoint_freq steps. save_checkpoint = ( step % config.checkpoint_every_steps == 0 or is_last_step ) if config.save_checkpoints and save_checkpoint: logging.info('Saving checkpoint step %d.', step) with report_progress.timed('checkpoint'): checkpoints.save_checkpoint_multiprocess(workdir, state, step) ================================================ FILE: examples/gemma/transformer.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Gemma transformer.""" from __future__ import annotations from collections.abc import Iterable import dataclasses import enum import functools from typing import Any from flax import nnx import helpers import layers import modules import params as params_lib import sow_lib import jax.numpy as jnp from jaxtyping import Array # pylint: disable=g-importing-member,g-multiple-import Cache = dict[str, modules.LayerCache] def make_attention_layers_types( pattern: tuple[modules.AttentionType, ...], num_layers: int, ) -> tuple[modules.AttentionType, ...]: """Returns the list of attention types for every layers.""" pattern_size = len(pattern) out = pattern * (num_layers // pattern_size) if num_layers % pattern_size != 0: out += pattern[: num_layers % pattern_size] return tuple(out) class QueryPreAttentionNormalisation(enum.Enum): """Initialization strategy.""" # Whether to scale the query by 1/sqrt(head_dim) BY_ONE_OVER_SQRT_HEAD_DIM = enum.auto() # Whether to scale the query by `embed_dim // num_heads` BY_EMBED_DIM_DIV_NUM_HEADS = enum.auto() # Whether to scale the query by `1/sqrt(embed_dim // num_heads)` BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS = enum.auto() _NUM_LAYERS_GEMMA_2B = 18 _NUM_LAYERS_GEMMA_7B = 28 _NUM_LAYERS_GEMMA2_2B = 26 _NUM_LAYERS_GEMMA2_9B = 42 _NUM_LAYERS_GEMMA2_27B = 46 _NUM_LAYERS_GEMMA3_1B = 26 _NUM_LAYERS_GEMMA3_4B = 34 _NUM_LAYERS_GEMMA3_12B = 48 _NUM_LAYERS_GEMMA3_27B = 62 GEMMA3_ATTENTION_PATTERN = ( modules.AttentionType.LOCAL_SLIDING, modules.AttentionType.LOCAL_SLIDING, modules.AttentionType.LOCAL_SLIDING, modules.AttentionType.LOCAL_SLIDING, modules.AttentionType.LOCAL_SLIDING, modules.AttentionType.GLOBAL, ) @dataclasses.dataclass(frozen=True) class TransformerConfig: """Configuration for the gemma transformer.""" num_layers: int num_embed: int embed_dim: int hidden_dim: int num_heads: int head_dim: int num_kv_heads: int final_logit_softcap: float | None use_post_attn_norm: bool use_post_ffw_norm: bool attention_types: Iterable[modules.AttentionType] query_pre_attn_norm: QueryPreAttentionNormalisation = ( QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM ) attn_logits_soft_cap: float | None = None transpose_gating_einsum: bool = False local_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY global_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY local_scale_factor: float = modules.DEFAULT_ROPE_SCALE_FACTOR global_scale_factor: float = modules.DEFAULT_ROPE_SCALE_FACTOR use_qk_norm: bool = False sliding_window_size: int | None = None dtype: Any = jnp.float32 axis_rules: Any | None = None def query_pre_attn_scalar(self) -> float: """Returns the scalar to multiply the query by before attention.""" match self.query_pre_attn_norm: case QueryPreAttentionNormalisation.BY_EMBED_DIM_DIV_NUM_HEADS: return self.embed_dim // self.num_heads case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS: # pylint: disable=line-too-long return (self.embed_dim // self.num_heads) ** -0.5 case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM | _: return self.head_dim**-0.5 @classmethod def from_path(cls, path: str) -> TransformerConfig: """Creates a TransformerConfig from loaded parameters.""" params = params_lib.load_params(path) return cls.from_params(params) @classmethod def from_params(cls, params: params_lib.Params) -> TransformerConfig: """Creates a TransformerConfig from loaded parameters. Args: params: Model parameters Returns: TransformerConfig. """ # Post Attn Norm is only used starting from Gemma 2. use_post_attn_norm = ( 'post_attention_norm' in params['transformer']['layer_0'] ) # QK Norm is only used starting from Gemma 3. use_qk_norm = '_query_norm' in params['transformer']['layer_0']['attn'] # Num layers will give use the model size. layer_names = [ name for name in params['transformer'].keys() if 'layer' in name ] layer_names = [name.replace('layer_', '') for name in layer_names] num_layers = max([int(layer) for layer in layer_names]) + 1 if not use_post_attn_norm: # Gemma 1. if num_layers == _NUM_LAYERS_GEMMA_2B: return cls.gemma_2b() if num_layers == _NUM_LAYERS_GEMMA_7B: return cls.gemma_7b() raise ValueError( 'Guessing Gemma 1 model, but could not determine size from params.' ) elif not use_qk_norm: # Gemma 2. if num_layers == _NUM_LAYERS_GEMMA2_2B: return cls.gemma2_2b() if num_layers == _NUM_LAYERS_GEMMA2_9B: return cls.gemma2_9b() if num_layers == _NUM_LAYERS_GEMMA2_27B: return cls.gemma2_27b() raise ValueError( 'Guessing Gemma 2 model but could not determine size from params.' ) else: # Gemma 3. if num_layers == _NUM_LAYERS_GEMMA3_1B: return cls.gemma3_1b() if num_layers == _NUM_LAYERS_GEMMA3_4B: return cls.gemma3_4b() if num_layers == _NUM_LAYERS_GEMMA3_12B: return cls.gemma3_12b() if num_layers == _NUM_LAYERS_GEMMA3_27B: return cls.gemma3_27b() raise ValueError('Could not determine Gemma variant from params.') @classmethod def from_version_name(cls, name: str, **override) -> TransformerConfig: possible_names = ( "gemma_2b", "gemma_7b", "gemma2_2b", "gemma2_9b", "gemma2_27b", "gemma3_1b", "gemma3_4b", "gemma3_12b", "gemma3_27b", ) if name not in possible_names: raise ValueError( f'Unknown version name: {name}. ' f'Please choose one of the following: {possible_names}' ) if hasattr(cls, name): model_config = getattr(cls, name)(**override) return model_config else: raise RuntimeError( 'Something wrong in TransformerConfig code. ' f'No attribute {name} in TransformerConfig' ) @classmethod def from_dict(cls, **config: Any) -> TransformerConfig: # Deserialize query_pre_attn_norm values: if "query_pre_attn_norm" in config: config["query_pre_attn_norm"] = QueryPreAttentionNormalisation(config["query_pre_attn_norm"]) else: config["query_pre_attn_norm"] = QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM return cls(**config) @classmethod def gemma_2b(cls, **override) -> TransformerConfig: num_layers = _NUM_LAYERS_GEMMA_2B config = { 'num_layers': num_layers, 'num_embed': 256128, 'embed_dim': 2048, 'hidden_dim': 16384, 'num_heads': 8, 'head_dim': 256, 'num_kv_heads': 1, 'final_logit_softcap': None, 'attention_types': (modules.AttentionType.GLOBAL,) * num_layers, 'use_post_attn_norm': False, 'use_post_ffw_norm': False, } for key, value in override.items(): config[key] = value return cls(**config) @classmethod def gemma_7b(cls, **override): num_layers = _NUM_LAYERS_GEMMA_7B config = { "num_layers": num_layers, "num_embed": 256128, "embed_dim": 3072, "hidden_dim": 24576, "num_heads": 16, "head_dim": 256, "num_kv_heads": 16, "final_logit_softcap": None, "attention_types": (modules.AttentionType.GLOBAL,) * num_layers, "use_post_attn_norm": False, "use_post_ffw_norm": False, } for key, value in override.items(): config[key] = value return cls(**config) @classmethod def gemma2_2b(cls, **override): num_layers = _NUM_LAYERS_GEMMA2_2B config = { 'num_layers': num_layers, 'num_embed': 256128, 'embed_dim': 2304, 'hidden_dim': 9216, 'num_heads': 8, 'head_dim': 256, 'num_kv_heads': 4, 'final_logit_softcap': 30.0, 'attention_types': ( modules.AttentionType.LOCAL_SLIDING, modules.AttentionType.GLOBAL, ) * int(num_layers / 2), 'use_post_attn_norm': True, 'use_post_ffw_norm': True, 'query_pre_attn_norm': QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM, 'attn_logits_soft_cap': 50.0, 'sliding_window_size': 4096, } for key, value in override.items(): config[key] = value return cls(**config) @classmethod def gemma2_9b(cls, **override): num_layers = _NUM_LAYERS_GEMMA2_9B config = { "num_layers": num_layers, "num_embed": 256128, "embed_dim": 3584, "hidden_dim": 28672, "num_heads": 16, "head_dim": 256, "num_kv_heads": 8, "final_logit_softcap": 30.0, "attention_types": ( modules.AttentionType.LOCAL_SLIDING, modules.AttentionType.GLOBAL, ) * int(num_layers / 2), "use_post_attn_norm": True, "use_post_ffw_norm": True, "attn_logits_soft_cap": 50.0, "sliding_window_size": 4096, } for key, value in override.items(): config[key] = value return cls(**config) @classmethod def gemma2_27b(cls, **override): num_layers = _NUM_LAYERS_GEMMA2_27B config = { "num_layers": num_layers, "num_embed": 256128, "embed_dim": 4608, "hidden_dim": 72728, "num_heads": 32, "head_dim": 128, "num_kv_heads": 16, "final_logit_softcap": 30.0, "use_post_attn_norm": True, "use_post_ffw_norm": True, "attention_types": ( modules.AttentionType.LOCAL_SLIDING, modules.AttentionType.GLOBAL, ) * int(num_layers / 2), "attn_logits_soft_cap": 50.0, "sliding_window_size": 4096, } for key, value in override.items(): config[key] = value return cls(**config) @classmethod def gemma3_1b(cls, **override): num_layers = _NUM_LAYERS_GEMMA3_1B config = { "num_layers": num_layers, "final_logit_softcap": None, "num_embed": 262144, "embed_dim": 1152, "hidden_dim": 6 * 1152, "num_heads": 4, "head_dim": 256, "num_kv_heads": 1, "use_post_attn_norm": True, "use_post_ffw_norm": True, "use_qk_norm": True, "attention_types": make_attention_layers_types( GEMMA3_ATTENTION_PATTERN, num_layers ), "query_pre_attn_norm": QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM, "attn_logits_soft_cap": None, "sliding_window_size": 512, "transpose_gating_einsum": True, "local_base_frequency": 10_000, "global_base_frequency": 1_000_000, } for key, value in override.items(): config[key] = value return cls(**config) @classmethod def gemma3_4b(cls, **override): num_layers = _NUM_LAYERS_GEMMA3_4B config = { "num_layers": num_layers, "final_logit_softcap": None, "num_embed": 262_144, "embed_dim": 2560, "hidden_dim": 2560 * 8 // 2, "num_heads": 8, "head_dim": 256, "num_kv_heads": 4, "use_post_attn_norm": True, "use_post_ffw_norm": True, "use_qk_norm": True, "attention_types": make_attention_layers_types( GEMMA3_ATTENTION_PATTERN, num_layers ), "query_pre_attn_norm": QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM, "attn_logits_soft_cap": None, "sliding_window_size": 1024, "transpose_gating_einsum": True, "local_base_frequency": 10_000, "global_base_frequency": 1_000_000, "global_scale_factor": 8.0, } for key, value in override.items(): config[key] = value return cls(**config) @classmethod def gemma3_12b(cls, **override): num_layers = _NUM_LAYERS_GEMMA3_12B config = { "num_layers": num_layers, "final_logit_softcap": None, "num_embed": 262144, "embed_dim": 30 * 128, "hidden_dim": 8 * 30 * 128 // 2, "num_heads": 16, "head_dim": 256, "num_kv_heads": 8, "use_post_attn_norm": True, "use_post_ffw_norm": True, "use_qk_norm": True, "attention_types": make_attention_layers_types( GEMMA3_ATTENTION_PATTERN, num_layers ), "query_pre_attn_norm": QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM, "attn_logits_soft_cap": None, "sliding_window_size": 1024, "transpose_gating_einsum": True, "local_base_frequency": 10_000, "global_base_frequency": 1_000_000, "global_scale_factor": 8.0, } for key, value in override.items(): config[key] = value return cls(**config) @classmethod def gemma3_27b(cls, **override): num_layers = _NUM_LAYERS_GEMMA3_27B config = { "num_layers": num_layers, "final_logit_softcap": None, "num_embed": 262144, "embed_dim": 5376, "hidden_dim": 5376 * 8 // 2, "num_heads": 32, "head_dim": 128, "num_kv_heads": 16, "use_post_attn_norm": True, "use_post_ffw_norm": True, "use_qk_norm": True, "attention_types": make_attention_layers_types( GEMMA3_ATTENTION_PATTERN, num_layers ), "query_pre_attn_norm": QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS, "attn_logits_soft_cap": None, "sliding_window_size": 1024, "transpose_gating_einsum": True, "local_base_frequency": 10_000, "global_base_frequency": 1_000_000, "global_scale_factor": 8.0, } for key, value in override.items(): config[key] = value return cls(**config) def __post_init__(self): if self.num_heads != self.num_kv_heads: if self.num_heads % self.num_kv_heads != 0: raise ValueError( f"Number of query heads ({self.num_heads}) must be divisible by " f"number of key/value heads ({self.num_kv_heads})." ) def _map_linen_var_names(key: tuple[str, ...]) -> tuple[str | int, ...]: """Maps linen variable names to nnx variable names.""" new_key = [] for k in key: if k.startswith('layer_'): prefix, suffix = k.split('layer_') assert not prefix, prefix new_key.append('layers') new_key.append(int(suffix)) elif k == 'gating_einsum': new_key.append('gate_proj') new_key.append('kernel') elif k == 'linear': new_key.append('down_proj') new_key.append('kernel') else: new_key.append(k) return tuple(new_key) def _assign_linen_params_to_nnx_state( state: dict[tuple[str, ...], Any], mapped_path: tuple[str | int, ...], val: Any, transpose_gating_einsum: bool, ) -> dict[tuple[str, ...], Any]: """Splits and maybe transposes gate_proj.""" if 'gate_proj' in mapped_path: if transpose_gating_einsum: val = jnp.swapaxes(val, 1, 2) state[mapped_path].set_value(val[0]) state[mapped_path[:-2] + ('up_proj', 'kernel')].set_value(val[1]) else: state[mapped_path].set_value(val) return state class Transformer(nnx.Module): """Gemma transformer.""" @classmethod def from_params( cls, params: params_lib.Params, config: None | TransformerConfig = None, sow_config: sow_lib.SowConfig = sow_lib.SowConfig(), ) -> Transformer: if config is None: config = TransformerConfig.from_params(params) assign_val_fn = functools.partial( _assign_linen_params_to_nnx_state, transpose_gating_einsum=config.transpose_gating_einsum, ) return helpers.module_from_linen_variables( module_factory=lambda: cls( config, rngs=nnx.Rngs(params=0), sow_config=sow_config ), variables=params['transformer'], map_key_fn=_map_linen_var_names, assign_val_fn=assign_val_fn, ) def __init__( self, config: TransformerConfig, *, rngs: nnx.Rngs, sow_config: sow_lib.SowConfig = sow_lib.SowConfig(), ): self.embedder = modules.Embedder( vocab_size=config.num_embed, embed_dim=config.embed_dim, embedding_init=modules.maybe_with_partitioning( nnx.initializers.normal(), config.axis_rules, ("vocab", "embed"), ), dtype=config.dtype, rngs=rngs, ) self.layers = nnx.List([ modules.Block( config=config, attn_type=attn_type, sow_config=sow_config, rngs=rngs, ) for _, attn_type in zip( range(config.num_layers), config.attention_types ) ]) self.final_norm = layers.RMSNorm( config.embed_dim, scale_init=modules.maybe_with_partitioning( nnx.initializers.zeros_init(), config.axis_rules, ("embed", ), ), rngs=rngs, ) self.final_logits_softcap = config.final_logit_softcap self.sow_config = sow_config def __call__( self, last_tokens: Array, # [B, L] positions: Array, # [B, L] cache: Cache | None, # (sequence length L') attention_mask: Array, # [B, L, L'] ) -> tuple[Array, Cache | None]: """Transformer forward pass. You can run this forward pass two ways: with or without an attention kv cache. Args: last_tokens: input sequence of tokens. positions: input absolute positions. cache: Attention KV cache or None. attention_mask: transformer input mask. Returns: predicted_logits, new_cache predicted_logits: output logits predicted by the model new_cache: updated cache if the input cache is not None, None elsewhere. """ new_cache = None if cache is None else {} x = self.embedder.encode(last_tokens) self.sow_config.maybe_sow_embeddings(x, self) for i, layer in enumerate(self.layers): layer_name = f'layer_{i}' layer_cache = cache[layer_name] if cache else None layer_cache, x = layer( x, positions, layer_cache, attention_mask, ) if cache is not None: new_cache[layer_name] = layer_cache # pytype: disable=container-type-mismatch x = self.final_norm(x) logits = self.embedder.decode(x) if self.final_logits_softcap is not None: logits /= self.final_logits_softcap logits = jnp.tanh(logits) * self.final_logits_softcap return logits, new_cache # pytype: disable=bad-return-type @property def embed_dim(self) -> int: return self.embedder.embed_dim @property def num_embed(self) -> int: return self.embedder.num_embed @property def num_layers(self) -> int: return len(self.layers) def init_cache( self, cache_size: int, batch_size: int, dtype: jnp.dtype = jnp.bfloat16, ) -> Cache: """Initializes a new Transformer cache.""" return { f'layer_{i}': self.layers[i].init_cache( cache_size=cache_size, batch_size=batch_size, dtype=dtype, ) for i in range(self.num_layers) } def init_intermediates( self, batch_size: int, buffer_size: int, sow_config: sow_lib.SowConfig, dtype: jnp.dtype = jnp.float32, ) -> sow_lib.TransformerIntermediates: """Initializes the intermediate activations that will be filled.""" intermediates = sow_lib.TransformerIntermediates() residual_stream_dummy = jnp.zeros( (batch_size, buffer_size, self.embed_dim), dtype=dtype, ) if sow_config.embeddings: intermediates.embeddings = residual_stream_dummy for layer in self.layers: layer_intermediates = sow_lib.LayerIntermediates() if sow_config.rs_after_attention: layer_intermediates.rs_after_attention = residual_stream_dummy if sow_config.rs_after_ffw: layer_intermediates.rs_after_ffw = residual_stream_dummy if sow_config.attn_logits_topk: shape = ( batch_size, buffer_size, layer.attn.num_heads, sow_config.attn_logits_topk, ) layer_intermediates.attn_logits_topk_values = jnp.zeros( shape, dtype=dtype, ) layer_intermediates.attn_logits_topk_indices = jnp.zeros( shape, dtype=jnp.int32, ) if sow_config.mlp_hidden_topk: shape = ( batch_size, buffer_size, sow_config.mlp_hidden_topk, ) layer_intermediates.mlp_hidden_topk_values = jnp.zeros( shape, dtype=dtype, ) layer_intermediates.mlp_hidden_topk_indices = jnp.zeros( shape, dtype=jnp.int32, ) intermediates.layers.append(layer_intermediates) return intermediates def make_causal_attn_mask( input_mask: Array, ) -> Array: """Attention mask in batch mode. Args: input_mask: Input mask for the input. True for non-padded tokens only, else False. Returns: Attention mask. """ seq_len = input_mask.shape[-1] attn_mask = input_mask[..., None, :] causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_)) # Prefixes can be attended by all tokens attn_mask *= causal_mask[None, ...] return attn_mask def build_positions_from_mask(input_mask: Array) -> Array: """Computes the `positions` from the `input_mask`. Args: input_mask: The tokens `input_mask`, True for non-padded tokens only. Returns: The indices to use for RoPE and absolute position encodings for the given input mask. """ positions = jnp.cumsum(input_mask, axis=-1) # Subtract one for all positions from the first valid one as they are # 0-indexed return positions - (positions >= 1) ================================================ FILE: examples/gemma/transformer_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Tests for the Gemma transformer.""" from collections import defaultdict from absl.testing import absltest from absl.testing import parameterized from flax import nnx import modules import sow_lib import transformer as transformer_lib import jax.numpy as jnp import numpy as np def create_fake_params(config: transformer_lib.TransformerConfig): def nested_defaultdict(): return defaultdict(nested_defaultdict) res = nested_defaultdict() res['transformer'] = nested_defaultdict() params = res['transformer'] # 1. embedding params params['embedder']['input_embedding'] = jnp.ones( (config.num_embed, config.embed_dim) ) # 2. final norm params params['final_norm'] = {'scale': jnp.ones((config.embed_dim,))} # 3. attention block params for layer_idx in range(config.num_layers): params[f'layer_{layer_idx}']['attn']['attn_vec_einsum']['w'] = jnp.ones( (config.num_heads, config.head_dim, config.embed_dim) ) if config.num_heads == config.num_kv_heads: params[f'layer_{layer_idx}']['attn']['qkv_einsum']['w'] = jnp.ones( (3, config.num_heads, config.embed_dim, config.head_dim) ) else: params[f'layer_{layer_idx}']['attn']['q_einsum']['w'] = jnp.ones( (config.num_heads, config.embed_dim, config.head_dim) ) params[f'layer_{layer_idx}']['attn']['kv_einsum']['w'] = jnp.ones( (2, config.num_kv_heads, config.embed_dim, config.head_dim) ) # 4. feedforward block params params[f'layer_{layer_idx}']['mlp']['gating_einsum'] = jnp.ones( (2, config.embed_dim, config.hidden_dim) ) params[f'layer_{layer_idx}']['mlp']['linear'] = jnp.ones( (config.hidden_dim, config.embed_dim) ) # 5. layer norm params params[f'layer_{layer_idx}']['pre_attention_norm']['scale'] = jnp.ones(( config.embed_dim, )) params[f'layer_{layer_idx}']['pre_ffw_norm']['scale'] = jnp.ones(( config.embed_dim, )) if config.use_post_attn_norm: params[f'layer_{layer_idx}']['post_attention_norm']['scale'] = jnp.ones(( config.embed_dim, )) if config.use_post_ffw_norm: params[f'layer_{layer_idx}']['post_ffw_norm']['scale'] = jnp.ones(( config.embed_dim, )) return res class TransformerTest(parameterized.TestCase): @parameterized.parameters( # Prime number to ease shape tracing dict( num_layers=3, num_embed=17, embed_dim=2, num_heads=2, num_kv_heads=2, hidden_dim=11, head_dim=8, cache_size=29, batch_size=7, sequence_length=18, expected_outputs_shape=(7, 18, 17), # batch_size, seq_size, num_embed expected_cache_shape=(7, 29, 2, 8), # batch_size, cache_size, num_kv_heads, head_dim ), dict( num_layers=3, num_embed=4, embed_dim=2, num_heads=2, num_kv_heads=1, hidden_dim=4, head_dim=4, cache_size=2, batch_size=1, sequence_length=1, expected_outputs_shape=(1, 1, 4), # batch_size, seq_size, num_embed expected_cache_shape=(1, 2, 1, 4), # batch_size, cache_size, num_kv_heads, head_dim ), dict( num_layers=3, num_embed=7, embed_dim=5, num_heads=4, num_kv_heads=2, hidden_dim=6, head_dim=8, cache_size=9, batch_size=1, sequence_length=1, expected_outputs_shape=(1, 1, 7), # batch_size, seq_size, num_embed expected_cache_shape=(1, 9, 2, 8), # batch_size, cache_size, num_kv_heads, head_dim ), ) def test_transformer( self, num_layers, num_embed, embed_dim, num_heads, num_kv_heads, hidden_dim, head_dim, cache_size, batch_size, sequence_length, expected_outputs_shape, expected_cache_shape, ): config = transformer_lib.TransformerConfig( num_layers=num_layers, num_embed=num_embed, embed_dim=embed_dim, hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, num_kv_heads=num_kv_heads, final_logit_softcap=None, attention_types=[modules.AttentionType.GLOBAL] * num_layers, use_post_attn_norm=False, use_post_ffw_norm=False, ) attention_mask = jnp.ones((batch_size, 1, cache_size), dtype=jnp.bool) transformer = transformer_lib.Transformer( config=config, rngs=nnx.Rngs(params=0) ) cache = transformer.init_cache( cache_size=cache_size, batch_size=batch_size, dtype=jnp.float32, ) outputs, cache = transformer( jnp.tile(jnp.arange(sequence_length), (batch_size, 1)), jnp.tile(jnp.arange(sequence_length), (batch_size, 1)), cache, attention_mask, ) self.assertEqual(outputs.shape, expected_outputs_shape) self.assertEqual(cache['layer_0']['v'].shape, expected_cache_shape) @parameterized.parameters( ('final_logit_softcap',), ('attn_logits_soft_cap',), ) def test_logit_softcap( self, soft_cap_arg, ): cache_size = 2 batch_size = 1 soft_cap_val = 0.001 attention_mask = jnp.ones((batch_size, 1, cache_size), dtype=jnp.bool) params = dict( num_layers=3, num_embed=4, embed_dim=2, num_heads=2, num_kv_heads=1, hidden_dim=4, head_dim=4, attention_types=[modules.AttentionType.GLOBAL] * 3, use_post_attn_norm=False, use_post_ffw_norm=False, ) no_soft_cap_args = { 'final_logit_softcap': None, 'attn_logits_soft_cap': None, } soft_cap_args = no_soft_cap_args.copy() soft_cap_args[soft_cap_arg] = soft_cap_val config_soft_cap = transformer_lib.TransformerConfig( **(params | soft_cap_args) ) config_no_soft_cap = transformer_lib.TransformerConfig( **(params | no_soft_cap_args) ) all_outputs = [] for config in [config_soft_cap, config_no_soft_cap]: transformer = transformer_lib.Transformer( config=config, rngs=nnx.Rngs(params=1) ) cache = transformer.init_cache( cache_size=cache_size, batch_size=batch_size, dtype=jnp.float32, ) outputs, _ = transformer( jnp.array([[1]]), jnp.array([[1]]), cache, attention_mask ) all_outputs.append(outputs) soft_cap_outputs, no_soft_cap_outputs = all_outputs # pylint: disable=unbalanced-tuple-unpacking # Ensure that values aren't equal coming out of computation self.assertFalse((soft_cap_outputs == no_soft_cap_outputs).all()) # Run soft capping manually manual_soft_cap_logits = no_soft_cap_outputs / soft_cap_val manual_soft_cap_logits = jnp.tanh(manual_soft_cap_logits) * soft_cap_val np.testing.assert_array_almost_equal( manual_soft_cap_logits, soft_cap_outputs, 1e-5 ) @parameterized.parameters([ dict( config=transformer_lib.TransformerConfig( num_layers=2, num_embed=0, # unused embed_dim=0, # unused hidden_dim=0, # unused num_heads=3, head_dim=4, num_kv_heads=3, final_logit_softcap=None, attention_types=[modules.AttentionType.GLOBAL] * 2, use_post_attn_norm=False, use_post_ffw_norm=False, ), cache_size=2, keys=['layer_0', 'layer_1'], k_shape=(1, 2, 3, 4), v_shape=(1, 2, 3, 4), ) ]) def test_creates_cache(self, config, cache_size, keys, k_shape, v_shape): transformer = transformer_lib.Transformer( config=config, rngs=nnx.Rngs(params=0) ) cache = transformer.init_cache( cache_size=cache_size, batch_size=1, dtype=jnp.float32, ) self.assertEqual(list(cache.keys()), keys) self.assertEqual(cache['layer_0']['k'].shape, k_shape) self.assertEqual(cache['layer_0']['v'].shape, v_shape) @parameterized.parameters([ dict( batch_size=1, seq_size=4, config=transformer_lib.TransformerConfig( num_layers=2, num_embed=4, # unused embed_dim=2, hidden_dim=12, # unused num_heads=3, head_dim=4, num_kv_heads=3, final_logit_softcap=None, attention_types=[modules.AttentionType.GLOBAL] * 2, use_post_attn_norm=False, use_post_ffw_norm=False, ), ) ]) def test_forward_no_cache( self, batch_size: int, seq_size: int, config: transformer_lib.TransformerConfig, ): cache_size = 6 token_input = jnp.ones((batch_size, seq_size), dtype=jnp.int32) transformer = transformer_lib.Transformer( config=config, rngs=nnx.Rngs(params=0) ) empty_cache = transformer.init_cache( cache_size=cache_size, batch_size=batch_size, dtype=jnp.float32, ) attention_mask = jnp.ones( (batch_size, seq_size, cache_size), dtype=jnp.bool ) positions = transformer_lib.build_positions_from_mask(token_input != 0) output_cache, _ = transformer( token_input, positions, empty_cache, attention_mask ) attention_mask = jnp.ones((batch_size, seq_size, seq_size), dtype=jnp.bool) output_none, cache_none = transformer( token_input, positions, None, attention_mask ) self.assertIsNone(cache_none) np.testing.assert_array_almost_equal(output_cache, output_none, 1e-5) def test_attention_types( self, ): config = transformer_lib.TransformerConfig( num_layers=2, num_embed=4, embed_dim=2, hidden_dim=12, num_heads=3, head_dim=4, num_kv_heads=3, final_logit_softcap=None, attention_types=[modules.AttentionType.GLOBAL] * 2, use_post_attn_norm=False, use_post_ffw_norm=False, ) transformer = transformer_lib.Transformer( config=config, rngs=nnx.Rngs(params=0) ) cache = transformer.init_cache( cache_size=6, batch_size=1, dtype=jnp.float32, ) self.assertTrue(cache) @parameterized.parameters( dict( config=transformer_lib.TransformerConfig( num_layers=2, num_embed=4, embed_dim=2, hidden_dim=12, num_heads=3, head_dim=4, num_kv_heads=3, final_logit_softcap=None, attention_types=[modules.AttentionType.GLOBAL] * 2, use_post_attn_norm=False, use_post_ffw_norm=False, ), ), dict( config=transformer_lib.TransformerConfig( num_layers=2, num_embed=4, embed_dim=2, hidden_dim=12, num_heads=3, head_dim=4, num_kv_heads=3, final_logit_softcap=None, attention_types=[modules.AttentionType.GLOBAL] * 2, use_post_attn_norm=True, use_post_ffw_norm=True, ), ), dict( config=transformer_lib.TransformerConfig( num_layers=2, num_embed=4, embed_dim=5, hidden_dim=12, num_heads=6, head_dim=4, num_kv_heads=3, final_logit_softcap=None, attention_types=[modules.AttentionType.GLOBAL] * 2, use_post_attn_norm=True, use_post_ffw_norm=True, ), ), ) def test_load_from_params(self, config): params = create_fake_params(config) transformer = transformer_lib.Transformer.from_params(params, config) logits, _ = transformer( last_tokens=jnp.tile(jnp.arange(3), (2, 1)), positions=jnp.tile(jnp.arange(3), (2, 1)), cache=None, attention_mask=jnp.ones((2, 1, 3), dtype=jnp.bool), ) self.assertEqual(logits.shape, (2, 3, 4)) @parameterized.parameters([ sow_lib.SowConfig(embeddings=True), sow_lib.SowConfig(rs_after_attention=True), sow_lib.SowConfig(rs_after_ffw=True), sow_lib.SowConfig(attn_logits_topk=5), sow_lib.SowConfig(mlp_hidden_topk=11), ]) def test_sow_intermediates(self, sow_config): batch_size = 3 sequence_length = 7 num_layers = 2 config = transformer_lib.TransformerConfig( num_layers=num_layers, num_embed=4, embed_dim=48, hidden_dim=12, num_heads=1, head_dim=4, num_kv_heads=1, final_logit_softcap=None, use_post_attn_norm=False, use_post_ffw_norm=False, attention_types=[modules.AttentionType.GLOBAL] * num_layers, ) attention_mask = jnp.ones( (batch_size, sequence_length, sequence_length), dtype=jnp.bool ) transformer = transformer_lib.Transformer( config=config, rngs=nnx.Rngs(params=0), sow_config=sow_config ) transformer( jnp.tile(jnp.arange(sequence_length), (batch_size, 1)), jnp.tile(jnp.arange(sequence_length), (batch_size, 1)), None, attention_mask, ) if sow_config.embeddings: self.assertTrue(hasattr(transformer, 'embeddings')) embeddings = transformer.embeddings[0] self.assertEqual( embeddings.shape, (batch_size, sequence_length, config.embed_dim), ) else: self.assertFalse(hasattr(transformer, 'embeddings')) for layer in transformer.layers: if sow_config.rs_after_attention: self.assertTrue(hasattr(layer, 'rs_after_attention')) rs_after_attention = layer.rs_after_attention[0] self.assertIsNotNone(rs_after_attention) self.assertEqual( rs_after_attention.shape, (batch_size, sequence_length, config.embed_dim), ) else: self.assertFalse(hasattr(layer, 'rs_after_attention')) if sow_config.rs_after_ffw: self.assertTrue(hasattr(layer, 'rs_after_ffw')) rs_after_ffw = layer.rs_after_ffw[0] self.assertIsNotNone(rs_after_ffw) self.assertEqual( rs_after_ffw.shape, (batch_size, sequence_length, config.embed_dim), ) else: self.assertFalse(hasattr(layer, 'rs_after_ffw')) if sow_config.attn_logits_topk: self.assertTrue(hasattr(layer.attn, 'logits_topk_values')) attn_logits_topk_values = layer.attn.logits_topk_values[0] self.assertIsNotNone(attn_logits_topk_values) self.assertEqual( attn_logits_topk_values.shape, ( batch_size, sequence_length, config.num_heads, sow_config.attn_logits_topk, ), ) self.assertTrue(hasattr(layer.attn, 'logits_topk_indices')) attn_logits_topk_indices = layer.attn.logits_topk_indices[0] self.assertIsNotNone(attn_logits_topk_indices) self.assertEqual( attn_logits_topk_indices.shape, ( batch_size, sequence_length, config.num_heads, sow_config.attn_logits_topk, ), ) else: self.assertFalse(hasattr(layer.attn, 'logits_topk_values')) self.assertFalse(hasattr(layer.attn, 'logits_topk_indices')) if sow_config.mlp_hidden_topk: self.assertTrue(hasattr(layer.mlp, 'hidden_topk_values')) ffw_hidden_topk_values = layer.mlp.hidden_topk_values[0] self.assertIsNotNone(ffw_hidden_topk_values) self.assertEqual( ffw_hidden_topk_values.shape, ( batch_size, sequence_length, sow_config.mlp_hidden_topk, ), ) self.assertTrue(hasattr(layer.mlp, 'hidden_topk_indices')) ffw_hidden_topk_indices = layer.mlp.hidden_topk_indices[0] self.assertIsNotNone(ffw_hidden_topk_indices) self.assertEqual( ffw_hidden_topk_indices.shape, ( batch_size, sequence_length, sow_config.mlp_hidden_topk, ), ) else: self.assertFalse(hasattr(layer.mlp, 'hidden_topk_values')) self.assertFalse(hasattr(layer.mlp, 'hidden_topk_indices')) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/gemma/utils.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copied over from MaxText # (https://github.com/google/maxtext/blob/main/MaxText/max_utils.py). """Provides utilities for training the Flax gemma example.""" from collections.abc import Callable import logging from typing import Any from flax import nnx import transformer from flax.training import train_state import jax from jax.experimental import mesh_utils import jax.numpy as jnp import numpy as np Dtype = Any Shape = tuple[int, ...] class TrainState(train_state.TrainState): graphdef: nnx.GraphDef[transformer.Transformer] # Mesh utils. # ----------------------------------------------------------------------------- def create_device_mesh(config: Any): """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas. Args: config: The training configuration. Returns: The device mesh. """ devices = jax.devices() num_devices = len(devices) try: num_slices = 1 + max([d.slice_index for d in devices]) except AttributeError: num_slices = 1 num_devices_per_slice = num_devices // num_slices logging.info(f'Devices: {devices}') # pylint: disable=logging-fstring-interpolation logging.info(f'Number of devices: {num_devices}') # pylint: disable=logging-fstring-interpolation multi_slice_env = hasattr(jax.devices()[0], 'slice_index') dcn_parallelism = [ config.dcn_data_parallelism, config.dcn_fsdp_parallelism, config.dcn_tensor_parallelism, ] ici_parallelism = [ config.ici_data_parallelism, config.ici_fsdp_parallelism, config.ici_tensor_parallelism, ] # Find possible unspecified parallelisms dcn_parallelism = fill_unspecified_mesh_axes( dcn_parallelism, num_slices, 'DCN' ) ici_parallelism = fill_unspecified_mesh_axes( ici_parallelism, num_devices_per_slice, 'ICI' ) if multi_slice_env: mesh = mesh_utils.create_hybrid_device_mesh( ici_parallelism, dcn_parallelism ) else: mesh = mesh_utils.create_device_mesh(ici_parallelism) logging.info(f'Decided on mesh: {mesh}') # pylint: disable=logging-fstring-interpolation logging.info(f'Mesh shape: {mesh.shape}') # pylint: disable=logging-fstring-interpolation return mesh def fill_unspecified_mesh_axes( parallelism_vals, target_product, parallelism_type ): """Evaluates unspecified DCN/ICI parallelism values.""" if -1 in parallelism_vals: assert parallelism_vals.count(-1) == 1, ( f'Found unspecified values (-1) for more than one {parallelism_type} ' ' parallelism axis. At most one axis can be unspecified.' ) determined_val = target_product / np.prod(parallelism_vals) * -1 assert determined_val >= 1 and determined_val.is_integer, ( 'Unspecified value unable to be determined with the given ' f' {parallelism_type} parallelism values' ) parallelism_vals[parallelism_vals.index(-1)] = int(determined_val) target_type = 'slices' if parallelism_type == 'DCN' else 'devices per slice' assert np.prod(parallelism_vals) == target_product, ( f'Number of {target_type} {target_product} does not match the product' f' of the {parallelism_type} parallelism {np.prod(parallelism_vals)}' ) return parallelism_vals # State initialization utils. # ----------------------------------------------------------------------------- def _to_array(x): if not isinstance(x, jax.Array): x = jnp.asarray(x) return x def setup_initial_state( constructor: Callable[ [transformer.TransformerConfig, jax.Array], transformer.Transformer ], tx, config: transformer.TransformerConfig, rng: jax.Array, mesh: jax.sharding.Mesh, ) -> tuple[TrainState, TrainState]: """We initialize train state, optionally loading from checkpoint. Args: constructor: the model constructor tx: the optax.GradientTransformation config: config object rng: jax.prng key mesh: jax.devices() mesh Returns: state: the initialized train state state_mesh_annotations: the mesh annotations for the train state """ @jax.jit def sharded_init(): model = constructor(config, rng) graphdef, params = nnx.split(model, nnx.Param) state = TrainState.create( apply_fn=graphdef.apply, params=params, tx=tx, graphdef=graphdef, ) state = jax.tree.map(_to_array, state) state_spec = nnx.get_partition_spec(state) state = jax.lax.with_sharding_constraint(state, state_spec) return state # Initialization with jax.set_mesh(mesh): state = sharded_init() state_sharding = nnx.get_named_sharding(state, mesh) return state, state_sharding ================================================ FILE: examples/imagenet/README.md ================================================ ## ImageNet classification Trains a ResNet50 model ([He *et al.*, 2016]) for the ImageNet classification task ([Russakovsky *et al.*, 2015]). This example uses linear learning rate warmup and cosine learning rate schedule. [He *et al.*, 2016]: https://arxiv.org/abs/1512.03385 [Russakovsky *et al.*, 2015]: https://arxiv.org/abs/1409.0575 You can run this code and even modify it directly in Google Colab, no installation required: https://colab.research.google.com/github/google/flax/blob/main/examples/imagenet/imagenet.ipynb The Colab also demonstrates how to load pretrained checkpoints from Cloud storage at [gs://flax_public/examples/imagenet/](https://console.cloud.google.com/storage/browser/flax_public/examples/imagenet) Table of contents: - [Requirements](#requirements) - [Example runs](#example-runs) - [Running locally](#running-locally) - [Overriding parameters on the command line](#overriding-parameters-on-the-command-line) - [Running fake data benchmarks](#running-fake-data-benchmarks) - [Running on Cloud](#running-on-cloud) - [Preparing the dataset](#preparing-the-dataset) - [Google Cloud TPU](#google-cloud-tpu) - [Google Cloud GPU](#google-cloud-gpu) ### Requirements * TensorFlow dataset `imagenet2012:5.*.*` * `≈180GB` of RAM if you want to cache the dataset in memory for faster IO ### Example runs While the example should run on a variety of hardware, we have tested the following GPU and TPU configurations: | Name | Steps | Walltime | Top-1 accuracy | Metrics | Workdir | | :---------------------- | -----: | :------- | :------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | TPU v3-32 | 125100 | 2.1h | 76.54% | [tfhub.dev](https://tensorboard.dev/experiment/GhPHRoLzTqu7c8vynTk6bg/) | [gs://flax_public/examples/imagenet/tpu_v3_32](https://console.cloud.google.com/storage/browser/flax_public/examples/imagenet/tpu_v3_32) | | TPU v2-32 | 125100 | 2.5h | 76.67% | [tfhub.dev](https://tensorboard.dev/experiment/qBJ7T9VPSgO5yeb0HAKbIA/) | [gs://flax_public/examples/imagenet/tpu_v2_32](https://console.cloud.google.com/storage/browser/flax_public/examples/imagenet/tpu_v2_32) | | TPU v3-8 | 125100 | 4.4h | 76.37% | [tfhub.dev](https://tensorboard.dev/experiment/JwxRMYrsR4O6V6fnkn3dmg/) | [gs://flax_public/examples/imagenet/tpu](https://console.cloud.google.com/storage/browser/flax_public/examples/imagenet/tpu) | | v100_x8 | 250200 | 13.2h | 76.72% | [tfhub.dev](https://tensorboard.dev/experiment/venzpsNXR421XLkvvzSkqQ/#scalars&_smoothingWeight=0®exInput=%5Eimagenet/v100_x8%24) | [gs://flax_public/examples/imagenet/v100_x8](https://console.cloud.google.com/storage/browser/flax_public/examples/imagenet/v100_x8) | | v100_x8_mixed_precision | 62500 | 4.3h | 76.27% | [tfhub.dev](https://tensorboard.dev/experiment/venzpsNXR421XLkvvzSkqQ/#scalars&_smoothingWeight=0®exInput=%5Eimagenet/v100_x8_mixed_precision%24) | [gs://flax_public/examples/imagenet/v100_x8_mixed_precision](https://console.cloud.google.com/storage/browser/flax_public/examples/imagenet/v100_x8_mixed_precision) | ### Running locally ```shell python main.py --workdir=./imagenet --config=configs/default.py ``` #### Overriding parameters on the command line Specify a hyperparameter configuration by the means of setting `--config` flag. Configuration flag is defined using [config_flags](https://github.com/google/ml_collections/tree/master#config-flags). `config_flags` allows overriding configuration fields. This can be done as follows: ```shell python main.py --workdir=./imagenet_default --config=configs/default.py \ --config.num_epochs=100 ``` ### Running fake data benchmarks Execute the following code with `flax/examples/imagenet` as your current directory: ```shell bash ../../tests/download_dataset_metadata.sh python imagenet_fake_data_benchmark.py ``` If you get an error like this: ```shell Cloning into 'datasets'... fatal: cannot change to 'https://github.com/tensorflow/datasets/': No such file or directory error: failed to initialize sparse-checkout ``` This mean your git version is outdated. Just update it and re-run. ### Running on Cloud #### Preparing the dataset For running the ResNet50 model on imagenet dataset, you first need to prepare the `imagenet2012` dataset. Download the data from http://image-net.org/ as described in the [tensorflow_datasets catalog](https://www.tensorflow.org/datasets/catalog/imagenet2012). Then point the environment variable `$IMAGENET_DOWNLOAD_PATH` to the directory where the downloads are stored and prepare the dataset by running ```shell python -c " import tensorflow_datasets as tfds tfds.builder('imagenet2012').download_and_prepare( download_config=tfds.download.DownloadConfig( manual_dir='$IMAGENET_DOWNLOAD_PATH')) " ``` The contents of the directory `~/tensorflow_datasets` should be copied to your gcs bucket. Point the environment variable `GCS_TFDS_BUCKET` to your bucket and run the following command: ```shell gcloud storage cp --recursive ~/tensorflow_datasets gs://$GCS_TFDS_BUCKET/datasets ``` #### Google Cloud TPU See below for commands to set up a single VM with 8 TPUs attached (`--accelerator-type v3-8`), or for an entire TPU slice spanning multiple VMs (e.g. `--accelerator-type v3-32`). For more details about how to set up and use TPUs, refer to Cloud docs for [single VM setup](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm) and [pod slice setup](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm). First create a single TPUv3-8 VM and connect to it: ``` ZONE=us-central1-a TPU_TYPE=v3-8 VM_NAME=imagenet gcloud alpha compute tpus tpu-vm create $VM_NAME \ --zone $ZONE \ --accelerator-type $TPU_TYPE \ --version v2-alpha gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \ -L 6006:localhost:6006 ``` When connected install JAX: ``` pip install "jax[tpu]>=0.2.21" \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html ``` Then install Flax + the example dependencies: ``` git clone --depth=1 --branch=main https://github.com/google/flax cd flax pip install -e . cd examples/imagenet pip install -r requirements.txt ``` And finally start the training: ``` export TFDS_DATA_DIR=gs://$GCS_TFDS_BUCKET/datasets python3 main.py --workdir=$HOME/logs/imagenet_tpu --config=configs/tpu.py \ --jax_backend_target="grpc://192.168.0.2:8470" ``` Note that you might want to set `TFDS_DATA_DIR` as explained above. You probably also want to start the long-running command above in a `tmux` session and start some monitoring in a separate pane (note that we forwarded port 6006 locally above): ``` tensorboard --logdir=$HOME/logs ``` When running on pod slices, after creating the TPU VM, there are different ways of running the training in SPMD fashion on the hosts connected to the TPUs that make up the slice. We simply send the same installation/execution shell commands to all hosts in parallel with the command below. If anything fails it's usually a good idea to connect to a single host and execute the commands interactively. For convenience, the TPU creation commands are inlined below. Please note that we define `GCS_TFDS_BUCKET` to where your data stands in your cloud bucket. Also `YOUR_BUCKET` is the work directory you are experimenting in. You should choose ZONE based on where your TPU and work directory is. [Here](https://cloud.google.com/tpu/docs/types-zones) has some useful information on which zones you can have different types of TPUs. ```shell VM_NAME=imagenet REPO=https://github.com/google/flax BRANCH=main WORKDIR=gs://$YOUR_BUCKET/flax/examples/imagenet/$(date +%Y%m%d_%H%M) gcloud alpha compute tpus tpu-vm create $VM_NAME \ --zone=$ZONE \ --version v2-alpha --accelerator-type v3-32 FLAGS="--config.batch_size=$((32*256))" gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE \ --worker=all --command " pip install 'jax[tpu]>=0.2.21' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && pip install --user git+$REPO.git && git clone --depth=1 -b $BRANCH $REPO && cd flax/examples/imagenet && pip install -r requirements.txt && export TFDS_DATA_DIR=gs://$GCS_TFDS_BUCKET/datasets && python3 main.py --workdir=$WORKDIR --config=configs/tpu.py $FLAGS " ``` Please don't forget to disconnect and delete your vm after you are done: ``` gcloud alpha compute tpus tpu-vm delete $VM_NAME \ --zone $ZONE ``` #### Google Cloud GPU Can be launched with utility script described in [../cloud/README.md](../cloud/README.md) There are two configurations available: - `configs/v100_x8.py` : Full precision GPU training - `configs/v100_x8_mixed_precision.py` : Mixed precision GPU training. Note that mixed precision handling is implemented manually with [`training.dynamic_scale`](https://github.com/google/flax/blob/main/flax/training/dynamic_scale.py) ================================================ FILE: examples/imagenet/configs/default.py ================================================ # Copyright 2021 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Default Hyperparameter configuration.""" import ml_collections def get_config(): """Get the default hyperparameter configuration.""" config = ml_collections.ConfigDict() # As defined in the `models` module. config.model = 'ResNet50' # `name` argument of tensorflow_datasets.builder() config.dataset = 'imagenet2012:5.*.*' config.learning_rate = 0.1 config.warmup_epochs = 5.0 config.momentum = 0.9 config.batch_size = 128 config.shuffle_buffer_size = 16 * 128 config.prefetch = 10 config.num_epochs = 100.0 config.log_every_steps = 100 config.cache = False config.half_precision = False # If num_train_steps==-1 then the number of training steps is calculated from # num_epochs using the entire dataset. Similarly for steps_per_eval. config.num_train_steps = -1 config.steps_per_eval = -1 # whether to profile the training loop config.profile = True return config def metrics(): return [ 'train_loss', 'eval_loss', 'train_accuracy', 'eval_accuracy', 'steps_per_second', 'train_learning_rate', ] ================================================ FILE: examples/imagenet/configs/fake_data_benchmark.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Hyperparameter configuration for Fake data benchmark.""" import jax from configs import default as default_lib def get_config(): """Get the hyperparameter configuration for Fake data benchmark.""" # Override default configuration to avoid duplication of field definition. config = default_lib.get_config() config.batch_size = 256 * jax.device_count() config.half_precision = True config.num_epochs = 5 # Run for a single step: config.num_train_steps = 1 config.steps_per_eval = 1 return config ================================================ FILE: examples/imagenet/configs/tpu.py ================================================ # Copyright 2021 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Hyperparameter configuration to run the example on TPUs.""" from configs import default as default_lib def get_config(): """Get the hyperparameter configuration to train on TPUs.""" config = default_lib.get_config() # Consider setting the batch size to max(tpu_chips * 256, 8 * 1024) if you # train on a larger pod slice. config.batch_size = 1024 config.shuffle_buffer_size = 16 * 1024 config.cache = True config.half_precision = True return config metrics = default_lib.metrics ================================================ FILE: examples/imagenet/configs/v100_x8.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Hyperparameter configuration to run the example on 8 x Nvidia V100 GPUs.""" from configs import default as default_lib def get_config(): """Get the hyperparameter configuration to train on 8 x Nvidia V100 GPUs.""" # Override default configuration to avoid duplication of field definition. config = default_lib.get_config() config.batch_size = 512 config.shuffle_buffer_size = 16 * 512 config.cache = True return config metrics = default_lib.metrics ================================================ FILE: examples/imagenet/configs/v100_x8_mixed_precision.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Hyperparameter configuration to run the example on 8 x Nvidia V100 GPUs.""" from configs import default as default_lib def get_config(): """Get the hyperparameter configuration to train on 8 x Nvidia V100 GPUs.""" # Override default configuration to avoid duplication of field definition. config = default_lib.get_config() config.batch_size = 2048 config.shuffle_buffer_size = 16 * 2048 config.cache = True config.half_precision = True return config metrics = default_lib.metrics ================================================ FILE: examples/imagenet/imagenet.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Flax Imagenet Example\n", "\n", "\"Open\n", "\n", "Demonstration notebook for\n", "https://github.com/google/flax/tree/main/examples/imagenet\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The **Flax Notebook Workflow**:\n", "\n", "1. Run the entire notebook end-to-end and check out the outputs.\n", " - This will open Python files in the right-hand editor!\n", " - You'll be able to interactively explore metrics in TensorBoard.\n", "2. Change `config` and train for different hyperparameters. Check out the\n", " updated TensorBoard plots.\n", "3. Update the code in `train.py`. Thanks to `%autoreload`, any changes you\n", " make in the file will automatically appear in the notebook. Some ideas to\n", " get you started:\n", " - Change the model.\n", " - Log some per-batch metrics during training.\n", " - Add new hyperparameters to `configs/default.py` and use them in\n", " `train.py`.\n", "4. At any time, feel free to paste code from `train.py` into the notebook\n", " and modify it directly there!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "outputId": "cb862d1a-2f71-444f-9770-9f0d53b11389" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GPU 0: Tesla T4 (UUID: GPU-d1652fa8-88b9-3d02-7a65-e7ebabeb0372)\n" ] } ], "source": [ "# Tested with a T4 GPU.\n", "!nvidia-smi -L" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "outputId": "80340396-77c2-4654-cc6d-67040f227eb9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n" ] } ], "source": [ "# Install ml-collections & latest Flax version from Github.\n", "!pip install -q clu ml-collections git+https://github.com/google/flax" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "example_directory = 'examples/imagenet'\n", "editor_relpaths = ('configs/default.py', 'input_pipeline.py', 'models.py', 'train.py')\n", "\n", "repo, branch = 'https://github.com/google/flax', 'main'" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "cellView": "form", "outputId": "9449a7b4-8a5d-4446-abe0-7886435ebd1c" }, "outputs": [ { "data": { "text/html": [ "

WARNING : Editing in VM - changes lost after reboot!!

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", " ((filepath) => {{\n", " if (!google.colab.kernel.accessAllowed) {{\n", " return;\n", " }}\n", " google.colab.files.view(filepath);\n", " }})(\"/content/examples/imagenet/configs/default.py\")" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", " ((filepath) => {{\n", " if (!google.colab.kernel.accessAllowed) {{\n", " return;\n", " }}\n", " google.colab.files.view(filepath);\n", " }})(\"/content/examples/imagenet/input_pipeline.py\")" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", " ((filepath) => {{\n", " if (!google.colab.kernel.accessAllowed) {{\n", " return;\n", " }}\n", " google.colab.files.view(filepath);\n", " }})(\"/content/examples/imagenet/models.py\")" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", " ((filepath) => {{\n", " if (!google.colab.kernel.accessAllowed) {{\n", " return;\n", " }}\n", " google.colab.files.view(filepath);\n", " }})(\"/content/examples/imagenet/train.py\")" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# (If you run this code in Jupyter[lab], then you're already in the\n", "# example directory and nothing needs to be done.)\n", "\n", "#@markdown **Fetch newest Flax, copy example code**\n", "#@markdown\n", "#@markdown **If you select no** below, then the files will be stored on the\n", "#@markdown *ephemeral* Colab VM. **After some time of inactivity, this VM will\n", "#@markdown be restarted and any changes are lost**.\n", "#@markdown\n", "#@markdown **If you select yes** below, then you will be asked for your\n", "#@markdown credentials to mount your personal Google Drive. In this case, all\n", "#@markdown changes you make will be *persisted*, and even if you re-run the\n", "#@markdown Colab later on, the files will still be the same (you can of course\n", "#@markdown remove directories inside your Drive's `flax/` root if you want to\n", "#@markdown manually revert these files).\n", "\n", "if 'google.colab' in str(get_ipython()):\n", " import os\n", " os.chdir('/content')\n", " # Download Flax repo from Github.\n", " if not os.path.isdir('flaxrepo'):\n", " !git clone --depth=1 -b $branch $repo flaxrepo\n", " # Copy example files & change directory.\n", " mount_gdrive = 'no' #@param ['yes', 'no']\n", " if mount_gdrive == 'yes':\n", " DISCLAIMER = 'Note : Editing in your Google Drive, changes will persist.'\n", " from google.colab import drive\n", " drive.mount('/content/gdrive')\n", " example_root_path = f'/content/gdrive/My Drive/flax/{example_directory}'\n", " else:\n", " DISCLAIMER = 'WARNING : Editing in VM - changes lost after reboot!!'\n", " example_root_path = f'/content/{example_directory}'\n", " from IPython import display\n", " display.display(display.HTML(\n", " f'

{DISCLAIMER}

'))\n", " if not os.path.isdir(example_root_path):\n", " os.makedirs(example_root_path)\n", " !cp -r flaxrepo/$example_directory/* \"$example_root_path\"\n", " os.chdir(example_root_path)\n", " from google.colab import files\n", " for relpath in editor_relpaths:\n", " s = open(f'{example_root_path}/{relpath}').read()\n", " open(f'{example_root_path}/{relpath}', 'w').write(\n", " f'## {DISCLAIMER}\\n' + '#' * (len(DISCLAIMER) + 3) + '\\n\\n' + s)\n", " files.view(f'{example_root_path}/{relpath}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "outputId": "acc1f45d-5062-4ff3-e6d4-10b4ffe0f8ef" }, "outputs": [], "source": [ "# Note : In Colab, above cell changed the working directory.\n", "!pwd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports / Helpers" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "cellView": "form", "outputId": "9dc7fb32-331e-44a6-b6e8-830f6a64d845" }, "outputs": [ { "data": { "text/plain": [ "[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# It's possible to run this Colab with TPUs:\n", "# 1. change runtime type to TPU\n", "# 2. install compatible flax version: `!pip install flax==0.6.4 jax==0.3.25 jaxlib==0.3.25`\n", "# 3. uncomment lines below\n", "\n", "# import flax, jax, jax.tools.colab_tpu\n", "# jax.tools.colab_tpu.setup_tpu()\n", "\n", "import jax\n", "jax.devices()" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "import json\n", "from absl import logging\n", "import flax\n", "import jax\n", "import jax.numpy as jnp\n", "from matplotlib import pyplot as plt\n", "import numpy as np\n", "import tensorflow as tf\n", "import tensorflow_datasets as tfds\n", "\n", "logging.set_verbosity(logging.INFO)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "# Helper functions for images.\n", "\n", "def show_img(img, ax=None, title=None):\n", " \"\"\"Shows a single image.\"\"\"\n", " if ax is None:\n", " ax = plt.gca()\n", " img *= tf.constant(input_pipeline.STDDEV_RGB, shape=[1, 1, 3], dtype=img.dtype)\n", " img += tf.constant(input_pipeline.MEAN_RGB, shape=[1, 1, 3], dtype=img.dtype)\n", " img = np.clip(img.numpy().astype(int), 0, 255)\n", " ax.imshow(img)\n", " ax.set_xticks([])\n", " ax.set_yticks([])\n", " if title:\n", " ax.set_title(title)\n", "\n", "def show_img_grid(imgs, titles):\n", " \"\"\"Shows a grid of images.\"\"\"\n", " n = int(np.ceil(len(imgs)**.5))\n", " _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n))\n", " for i, (img, title) in enumerate(zip(imgs, titles)):\n", " show_img(img, axs[i // n][i % n], title)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "outputId": "f943d165-b953-4a70-9f93-96eb857c3d53" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The autoreload extension is already loaded. To reload it, use:\n", " %reload_ext autoreload\n" ] } ], "source": [ "# Local imports from current directory - auto reload.\n", "# Any changes you make to train.py will appear automatically.\n", "%load_ext autoreload\n", "%autoreload 2\n", "import input_pipeline\n", "import models\n", "import train\n", "from configs import default as config_lib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "outputId": "a9b6cfe9-cc1c-451a-f8f7-69356cb7bdd2" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:No config specified, defaulting to config: imagenette/full-size-v2\n", "INFO:absl:Load dataset info from /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0\n", "INFO:absl:Reusing dataset imagenette (/root/tensorflow_datasets/imagenette/full-size-v2/1.0.0)\n", "INFO:absl:Constructing tf.data.Dataset imagenette for split train, from /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0\n" ] }, { "data": { "text/plain": [ "tfds.core.DatasetInfo(\n", " name='imagenette',\n", " full_name='imagenette/full-size-v2/1.0.0',\n", " description=\"\"\"\n", " Imagenette is a subset of 10 easily classified classes from the Imagenet\n", " dataset. It was originally prepared by Jeremy Howard of FastAI. The objective\n", " behind putting together a small version of the Imagenet dataset was mainly\n", " because running new ideas/algorithms/experiments on the whole Imagenet take a\n", " lot of time.\n", " \n", " This version of the dataset allows researchers/practitioners to quickly try out\n", " ideas and share with others. The dataset comes in three variants:\n", " \n", " * Full size\n", " * 320 px\n", " * 160 px\n", " \n", " Note: The v2 config correspond to the new 70/30 train/valid split (released in\n", " Dec 6 2019).\n", " \"\"\",\n", " config_description=\"\"\"\n", " full-size variant.\n", " \"\"\",\n", " homepage='https://github.com/fastai/imagenette',\n", " data_path='/root/tensorflow_datasets/imagenette/full-size-v2/1.0.0',\n", " file_format=tfrecord,\n", " download_size=1.45 GiB,\n", " dataset_size=1.46 GiB,\n", " features=FeaturesDict({\n", " 'image': Image(shape=(None, None, 3), dtype=uint8),\n", " 'label': ClassLabel(shape=(), dtype=int64, num_classes=10),\n", " }),\n", " supervised_keys=('image', 'label'),\n", " disable_shuffling=False,\n", " splits={\n", " 'train': ,\n", " 'validation': ,\n", " },\n", " citation=\"\"\"@misc{imagenette,\n", " author = \"Jeremy Howard\",\n", " title = \"imagenette\",\n", " url = \"https://github.com/fastai/imagenette/\"\n", " }\"\"\",\n", ")" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# We load \"imagenette\" that has similar pictures to \"imagenet2012\" but can be\n", "# downloaded automatically and is much smaller.\n", "dataset_builder = tfds.builder('imagenette')\n", "dataset_builder.download_and_prepare()\n", "ds = dataset_builder.as_dataset('train')\n", "dataset_builder.info" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "# Utilities to help with Imagenette labels.\n", "\n", "![ ! -f mapping_imagenet.json ] && wget --no-check-certificate https://raw.githubusercontent.com/ozendelait/wordnet-to-json/master/mapping_imagenet.json\n", "\n", "with open('mapping_imagenet.json') as f:\n", " mapping_imagenet = json.load(f)\n", "# Mapping imagenette label name to imagenet label index.\n", "imagenette_labels = {\n", " d['v3p0']: d['label']\n", " for d in mapping_imagenet\n", "}\n", "# Mapping imagenette label name to human-readable label.\n", "imagenette_idx = {\n", " d['v3p0']: idx\n", " for idx, d in enumerate(mapping_imagenet)\n", "}\n", "\n", "def imagenette_label(idx):\n", " \"\"\"Returns a short human-readable string for provided imagenette index.\"\"\"\n", " net = dataset_builder.info.features['label'].int2str(idx)\n", " return imagenette_labels[net].split(',')[0]\n", "\n", "def imagenette_imagenet2012(idx):\n", " \"\"\"Returns the imagenet2012 index for provided imagenette index.\"\"\"\n", " net = dataset_builder.info.features['label'].int2str(idx)\n", " return imagenette_idx[net]\n", "\n", "def imagenet2012_label(idx):\n", " \"\"\"Returns a short human-readable string for provided imagenet2012 index.\"\"\"\n", " return mapping_imagenet[idx]['label'].split(',')[0]" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "outputId": "78142300-cc8b-4a6c-f781-5ab29578d828" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Constructing tf.data.Dataset imagenette for split train[0:9469], from /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0\n", "WARNING:absl:options.experimental_threading is deprecated. Use options.threading instead.\n", "INFO:absl:Constructing tf.data.Dataset imagenette for split validation[0:3925], from /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0\n", "WARNING:absl:options.experimental_threading is deprecated. Use options.threading instead.\n" ] } ], "source": [ "train_ds = input_pipeline.create_split(\n", " dataset_builder, 128, train=True,\n", ")\n", "eval_ds = input_pipeline.create_split(\n", " dataset_builder, 128, train=False,\n", ")" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "outputId": "8b3b9cf2-7649-4953-99bb-32a689fe0a29" }, "outputs": [ { "data": { "text/plain": [ "{'image': (TensorShape([128, 224, 224, 3]), tf.float32),\n", " 'label': (TensorShape([128]), tf.int64)}" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_batch = next(iter(train_ds))\n", "{k: (v.shape, v.dtype) for k, v in train_batch.items()}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training from scratch" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Get a live update during training - use the \"refresh\" button!\n", "# (In Jupyter[lab] start \"tensorbaord\" in the local directory instead.)\n", "if 'google.colab' in str(get_ipython()):\n", " %load_ext tensorboard\n", " %tensorboard --logdir=." ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "outputId": "2d0bc789-213d-4a34-a7b1-e7852b40f375" }, "outputs": [ { "data": { "text/plain": [ "batch_size: 32\n", "cache: false\n", "dataset: imagenette\n", "half_precision: true\n", "learning_rate: 0.1\n", "log_every_steps: 100\n", "model: ResNet18\n", "momentum: 0.9\n", "num_epochs: 5.0\n", "num_train_steps: -1\n", "prefetch: 1\n", "shuffle_buffer_size: 128\n", "steps_per_eval: -1\n", "warmup_epochs: 0.5" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "config = config_lib.get_config()\n", "config.dataset = 'imagenette'\n", "config.model = 'ResNet18'\n", "config.half_precision = True\n", "# Reduce batch size, shuffle buffer and prefetch to avoid Colab runtime OOM.\n", "config.batch_size = 32\n", "config.shuffle_buffer_size = 128\n", "config.prefetch = 1\n", "# Reduce epochs so this Colab doesn't take forever.\n", "config.num_epochs = 5.0\n", "config.warmup_epochs = 0.5\n", "config" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "outputId": "de56d320-c336-459b-f258-5d6ae41ce0af" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Constructing tf.data.Dataset imagenette for split train[0:9469], from /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0\n", "WARNING:absl:options.experimental_threading is deprecated. Use options.threading instead.\n", "INFO:absl:Constructing tf.data.Dataset imagenette for split validation[0:3925], from /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0\n", "WARNING:absl:options.experimental_threading is deprecated. Use options.threading instead.\n" ] } ], "source": [ "# Regenerate datasets with updated batch_size.\n", "train_ds = input_pipeline.create_split(\n", " dataset_builder, config.batch_size, train=True,\n", ")\n", "eval_ds = input_pipeline.create_split(\n", " dataset_builder, config.batch_size, train=False,\n", ")" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "outputId": "018da2c5-c6f0-42ac-843f-7ac855a6bf14" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:No config specified, defaulting to config: imagenette/full-size-v2\n", "INFO:absl:Load dataset info from /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0\n", "INFO:absl:Constructing tf.data.Dataset imagenette for split train[0:9469], from /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0\n", "WARNING:absl:options.experimental_threading is deprecated. Use options.threading instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "ResNet18_lr=0.03\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Constructing tf.data.Dataset imagenette for split validation[0:3925], from /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0\n", "WARNING:absl:options.experimental_threading is deprecated. Use options.threading instead.\n", "INFO:absl:Restoring legacy Flax checkpoint from models/ResNet18_lr=0.03/checkpoint_1475\n", "INFO:absl:Initial compilation, this might take some minutes...\n", "INFO:absl:No config specified, defaulting to config: imagenette/full-size-v2\n", "INFO:absl:Load dataset info from /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0\n", "INFO:absl:Constructing tf.data.Dataset imagenette for split train[0:9469], from /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0\n", "WARNING:absl:options.experimental_threading is deprecated. Use options.threading instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "ResNet18_lr=0.1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Constructing tf.data.Dataset imagenette for split validation[0:3925], from /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0\n", "WARNING:absl:options.experimental_threading is deprecated. Use options.threading instead.\n", "INFO:absl:Restoring legacy Flax checkpoint from models/ResNet18_lr=0.1/checkpoint_1475\n", "INFO:absl:Initial compilation, this might take some minutes...\n", "INFO:absl:No config specified, defaulting to config: imagenette/full-size-v2\n", "INFO:absl:Load dataset info from /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0\n", "INFO:absl:Constructing tf.data.Dataset imagenette for split train[0:9469], from /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0\n", "WARNING:absl:options.experimental_threading is deprecated. Use options.threading instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "ResNet18_lr=0.3\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Constructing tf.data.Dataset imagenette for split validation[0:3925], from /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0\n", "WARNING:absl:options.experimental_threading is deprecated. Use options.threading instead.\n", "INFO:absl:Restoring legacy Flax checkpoint from models/ResNet18_lr=0.3/checkpoint_1475\n", "INFO:absl:Initial compilation, this might take some minutes...\n" ] } ], "source": [ "# Takes ~1.5 min / epoch.\n", "for lr in (0.03, 0.1, 0.3):\n", " config.learning_rate = lr\n", " name = f'{config.model}_lr={config.learning_rate}'\n", " print(f'\\n\\n{name}')\n", " state = train.train_and_evaluate(config, workdir=f'./models/{name}')" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "cellView": "form" }, "outputs": [], "source": [ "if 'google.colab' in str(get_ipython()):\n", " #@markdown You can upload the training results directly to https://tensorbaord.dev\n", " #@markdown\n", " #@markdown Note that everybody with the link will be able to see the data.\n", " upload_data = 'no' #@param ['yes', 'no']\n", " if upload_data == 'yes':\n", " !tensorboard dev upload --one_shot --logdir ./models --name 'Flax examples/mnist'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load pre-trained model" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "outputId": "b06fa3d8-a950-46d2-e03e-fc6c971bdbd0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-rw-r--r-- 1 root root 196M Apr 3 13:38 checkpoint_250200\n" ] } ], "source": [ "# Load model checkpoint from cloud.\n", "from flax.training import checkpoints\n", "\n", "config_name = 'v100_x8'\n", "pretrained_path = f'gs://flax_public/examples/imagenet/{config_name}'\n", "latest_checkpoint = checkpoints.natural_sort(\n", " tf.io.gfile.glob(f'{pretrained_path}/checkpoint_*'))[0]\n", "if not os.path.exists(os.path.basename(latest_checkpoint)):\n", " tf.io.gfile.copy(latest_checkpoint, os.path.basename(latest_checkpoint))\n", "!ls -lh checkpoint_*" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "# Load config that was used to train checkpoint.\n", "import importlib\n", "config = importlib.import_module(f'configs.{config_name}').get_config()" ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "outputId": "57777298-4b4b-4a82-b0f2-4b6ff3b949af" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Restoring legacy Flax checkpoint from ./checkpoint_250200\n" ] } ], "source": [ "# Load models & state (takes ~1 min to load the model).\n", "model_cls = getattr(models, config.model)\n", "model = train.create_model(\n", " model_cls=model_cls, half_precision=config.half_precision)\n", "base_learning_rate = config.learning_rate * config.batch_size / 256.\n", "steps_per_epoch = (\n", " dataset_builder.info.splits['train'].num_examples // config.batch_size\n", ")\n", "learning_rate_fn = train.create_learning_rate_fn(\n", " config, base_learning_rate, steps_per_epoch)\n", "state = train.create_train_state(\n", " jax.random.key(0), config, model, image_size=input_pipeline.IMAGE_SIZE,\n", " learning_rate_fn=learning_rate_fn)\n", "state = train.restore_checkpoint(state, './')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference" ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "outputId": "793a656b-f3ad-4596-ad4f-44c686e5e885" }, "outputs": [ { "data": { "text/plain": [ "{'image': TensorShape([32, 224, 224, 3]), 'label': TensorShape([32])}" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load batch from imagenette eval set.\n", "batch = next(iter(eval_ds))\n", "{k: v.shape for k, v in batch.items()}" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "# Evaluate using model trained on imagenet.\n", "logits = model.apply({'params': state.params, 'batch_stats': state.batch_stats}, batch['image'][:128], train=False)" ] }, { "cell_type": "code", "execution_count": 51, "metadata": { "outputId": "6ab4bb0b-2c03-4663-d7ac-e51b979d121f" }, "outputs": [ { "data": { "text/plain": [ "[16, 24]" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Find classification mistakes.\n", "preds_labels = list(zip(logits.argmax(axis=-1), map(imagenette_imagenet2012, batch['label'])))\n", "error_idxs = [idx for idx, (pred, label) in enumerate(preds_labels) if pred != label]\n", "error_idxs" ] }, { "cell_type": "code", "execution_count": 52, "metadata": { "outputId": "142c1acf-037e-4ab0-9ca3-bdf0829c51c4" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAk0AAAJKCAYAAAAx/3HgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9d7wlVZX3j7/3rnzSzfd2387d5IyNiGRBRQQUEyM+Chhx1NHRnzpRUcfw+HMcnXFGMcyAaQIC84g6ijqAARAEJDYNTedwu28+sXLt7x+7zrn30k1wBLGZ83nZck+dOlW70tqfWuuz1hJKKUUXXXTRRRdddNFFF48L+UwPoIsuuuiiiy666GJ/QJc0ddFFF1100UUXXTwJdElTF1100UUXXXTRxZNAlzR10UUXXXTRRRddPAl0SVMXXXTRRRdddNHFk0CXNHXRRRdddNFFF108CXRJUxdddNFFF1100cWTQJc0ddFFF1100UUXXTwJdElTF1100UUXXXTRxZNAlzQ9i3HTTTchhOCmm256pofyB4OPfOQjCCGe6WF00UUXvwf8IdvA008/ndNPP73zecuWLQghuPLKK5+xMXXxxOiSpi4eE//6r//K5z//+Wd6GP8rcfrpp3PJJZc808PooosuuuhiHsxnegBd/OHiX//1X7n//vv50z/902d6KF100UUXz2qsWLEC3/exLOuZHkoXj4Oup+kPEFmWEQTBMz2MLp4BtFqtZ3oIXXTxjOMP2QY+Xc+oEALXdTEM42nZfhdPDbqk6WlCWzuzfv16LrjgAiqVCgMDA7znPe/ZyxgIIXjXu97Ft7/9bQ4//HAcx+FHP/oRADt37uRNb3oTIyMjOI7D4Ycfzr/8y7/stb8dO3Zw/vnnUywWGR4e5r3vfS9hGO61XqvVYv369UxOTj7u+E8//XR+8IMfsHXrVoQQCCFYuXIlAFEU8eEPf5i1a9fS09NDsVjklFNO4cYbb1ywjXaM/m//9m/53Oc+x4oVK/A8j9NOO437779/r32uX7+eV7/61fT39+O6LscddxzXXXfd447zf7KfR+OKK67gjDPOYHh4GMdxOOyww/jSl760YJ2LL76YwcFB4jje6/cvfvGLOfjggxcs+9a3vsXatWvxPI/+/n5e+9rXsn379gXrnH766RxxxBHceeednHrqqRQKBf7yL//yCcfbRRf7A/Z3GwiP/4x+97vf5ZxzzmF0dBTHcVizZg1/8zd/Q5qme23nK1/5CmvWrMHzPI4//nh+8Ytf7LXOY2mabrjhBk455RSKxSK9vb28/OUv58EHH3zCsXfx9KAbnnuaccEFF7By5Uo+9alP8atf/Yp/+Id/YGZmhm984xsL1rvhhhu46qqreNe73sXg4CArV65kz549nHDCCR2DMjQ0xA9/+EPe/OY3U6vVOmEz3/c588wz2bZtG+9+97sZHR3lm9/8JjfccMNe47n99tt5wQtewGWXXcZHPvKRxxz3X/3VX1GtVtmxYwef+9znACiVSgDUajW+9rWvceGFF/LWt76Ver3OP//zP3PWWWdx++23c8wxxyzY1je+8Q3q9TrvfOc7CYKAv//7v+eMM87gvvvuY2RkBIAHHniAk046iSVLlvDnf/7nFItFrrrqKs4//3yuueYaXvGKVzzhuX4y+9kXvvSlL3H44Yfzspe9DNM0+d73vsc73vEOsizjne98JwBveMMb+MY3vsH111/Pueee2/nt7t27ueGGG7jssss6yz7xiU/woQ99iAsuuIC3vOUtTExM8IUvfIFTTz2V3/zmN/T29nbWnZqa4uyzz+a1r30tr3/96x93nF10sT9if7WBbTzWM3rllVdSKpV43/veR6lU4oYbbuDDH/4wtVqNz3zmM53f//M//zOXXnopJ554In/6p3/Kpk2beNnLXkZ/fz/Lli173H3/9Kc/5eyzz2b16tV85CMfwfd9vvCFL3DSSSdx1113dV5ku/g9QnXxtOCyyy5TgHrZy162YPk73vEOBah77rmnswxQUkr1wAMPLFj3zW9+s1q8eLGanJxcsPy1r32t6unpUa1WSyml1Oc//3kFqKuuuqqzTrPZVAcccIAC1I033thZfuONNypAXXbZZU94DOecc45asWLFXsuTJFFhGC5YNjMzo0ZGRtSb3vSmzrLNmzcrQHmep3bs2NFZfttttylAvfe97+0sO/PMM9WRRx6pgiDoLMuyTJ144onqwAMPfNxx/jb7aV+X+Wifx/k466yz1OrVqzuf0zRVS5cuVX/0R3+0YL2/+7u/U0IItWnTJqWUUlu2bFGGYahPfOITC9a77777lGmaC5afdtppClCXX3754x5fF13sj3g22MDHe0b3ZTcuvfRSVSgUOnYsiiI1PDysjjnmmAU28ytf+YoC1GmnndZZ1rZjV1xxRWfZMccco4aHh9XU1FRn2T333KOklOqiiy56wvF38dSjG557mtH2VLTxJ3/yJwD813/914Llp512Gocddljns1KKa665hvPOOw+lFJOTk51/Z511FtVqlbvuuquzrcWLF/PqV7+68/tCocDb3va2vcZz+umno5R6Um9YjwXDMLBtG9Dag+npaZIk4bjjjuuMaT7OP/98lixZ0vl8/PHH87znPa9zDqanp7nhhhu44IILqNfrneOcmprirLPOYsOGDezcufMJx/VE+3kseJ7X+btarTI5Oclpp53Gpk2bqFarAEgp+T//5/9w3XXXUa/XO+t/+9vf5sQTT2TVqlUAXHvttWRZxgUXXLDgmi1atIgDDzxwrxCm4zi88Y1vfMJj66KL/RX7uw18rGd0vt1o261TTjmlE/4DuOOOOxgfH+ftb397x2YCXHLJJfT09DzufsfGxrj77ru55JJL6O/v7yw/6qijeNGLXvSEdq2Lpwfd8NzTjAMPPHDB5zVr1iClZMuWLQuWtyfdNiYmJpidneUrX/kKX/nKV/a57fHxcQC2bt3KAQccsFf9oUfrbJ5KfP3rX+ezn/0s69evX6DzefRxwN7nAOCggw7iqquuAuCRRx5BKcWHPvQhPvShD+1zf+Pj4wsI0b7wRPt5LNx8881cdtll3HrrrXuJPKvVase4XXTRRXz605/mP//zP7nooot46KGHuPPOO7n88ss762/YsAGl1D7HAuyVGbNkyZIFxrSLLp5t2N9t4GM9ow888AB//dd/zQ033ECtVlvwXftla+vWrcDe58CyLFavXv24+23/dl/HcOihh3L99dfTbDYpFotP/mC6+J3RJU2/ZzxWYcX5by2gPTgAr3/967n44ov3+ZujjjrqqR3ck8S3vvUtLrnkEs4//3w+8IEPMDw8jGEYfOpTn2Ljxo2/9fbax/r+97+fs846a5/rHHDAAb/TmB8LGzdu5Mwzz+SQQw7h7/7u71i2bBm2bfNf//VffO5zn+uMDeCwww5j7dq1fOtb3+Kiiy7iW9/6FrZtc8EFFyw4FiEEP/zhD/eZBdPWhbXx6OveRRfPduxvNnBfz+js7CynnXYalUqFj33sY6xZswbXdbnrrrv4sz/7swV2o4tnF7qk6WnGhg0bFrxBPfLII2RZ9oQCvqGhIcrlMmma8sIXvvBx112xYgX3338/SqkFBumhhx76ncb+WMbt6quvZvXq1Vx77bUL1pkvhp6PDRs27LXs4Ycf7pyD9huXZVlPeKyPhyfaz77wve99jzAMue6661i+fHln+aPDaG1cdNFFvO9972NsbIx//dd/5ZxzzqGvr6/z/Zo1a1BKsWrVKg466KD/8bF00cWzBfuzDXws3HTTTUxNTXHttddy6qmndpZv3rx5r3GBPgdnnHFGZ3kcx2zevJmjjz76MffR/u2+jmH9+vUMDg52vUzPALqapqcZ//RP/7Tg8xe+8AUAzj777Mf9nWEYvOpVr+Kaa67ZZ9r8xMRE5++XvvSl7Nq1i6uvvrqzrNVq7dOl/duk2xaLxY6b+dFjA605aOO2227j1ltv3ed2/t//+38LNEm33347t912W+ccDA8Pc/rpp/PlL3+ZsbGxvX4//1gfD0+0n31hX8dSrVa54oor9rn+hRdeiBCC97znPWzatInXv/71C75/5StfiWEYfPSjH12wzfY+pqamntSxdNHFswX7sw18vLHBQrsRRRFf/OIXF6x33HHHMTQ0xOWXX04URZ3lV155JbOzs4+7j8WLF3PMMcfw9a9/fcG6999/Pz/+8Y956Utf+j8efxf/c3Q9TU8zNm/ezMte9jJe8pKXcOutt/Ktb32L173udY/7htHG//2//5cbb7yR5z3vebz1rW/lsMMOY3p6mrvuuouf/vSnTE9PA/DWt76Vf/zHf+Siiy7izjvvZPHixXzzm9+kUCjstc3fJt127dq1/Md//Afve9/7eO5zn0upVOK8887j3HPP5dprr+UVr3gF55xzDps3b+byyy/nsMMOo9Fo7LWdAw44gJNPPpk//uM/JgxDPv/5zzMwMMAHP/jBzjr/9E//xMknn8yRRx7JW9/6VlavXs2ePXu49dZb2bFjB/fcc88Tnq8ns59H48UvfjG2bXPeeedx6aWX0mg0+OpXv8rw8PA+CdzQ0BAveclL+M53vkNvby/nnHPOgu/XrFnDxz/+cf7iL/6CLVu2cP7551Mul9m8eTP/+Z//ydve9jbe//73P+GxdNHFswX7sw18LJx44on09fVx8cUX8+53vxshBN/85jf3elGyLIuPf/zjXHrppZxxxhn80R/9EZs3b+aKK654Qk0TwGc+8xnOPvtsnv/85/PmN7+5U3Kgp6fnd0rm6eJ3wO89X+9/CdrptuvWrVOvfvWrVblcVn19fepd73qX8n1/wbqAeuc737nP7ezZs0e9853vVMuWLVOWZalFixapM888U33lK19ZsN7WrVvVy172MlUoFNTg4KB6z3veo370ox/9Tum2jUZDve51r1O9vb0K6JQfyLJMffKTn1QrVqxQjuOoY489Vn3/+99XF1988YISBe0U2s985jPqs5/9rFq2bJlyHEedcsopC9KN29i4caO66KKL1KJFi5RlWWrJkiXq3HPPVVdfffXjjvO32c++Sg5cd9116qijjlKu66qVK1eqT3/60+pf/uVfFKA2b9681/6uuuoqBai3ve1tjzmma665Rp188smqWCyqYrGoDjnkEPXOd75TPfTQQ511TjvtNHX44Yc/7rF10cX+imeDDXy8Z/Tmm29WJ5xwgvI8T42OjqoPfvCD6vrrr99rf0op9cUvflGtWrVKOY6jjjvuOPXzn/9cnXbaaU9YckAppX7605+qk046SXmepyqVijrvvPPUunXrnnDsXTw9EEo9ihp38ZTgIx/5CB/96EeZmJhgcHDwmR7OM4ItW7awatUqPvOZzzyt3pXf137a+O53v8v555/Pz3/+c0455ZSnfX9ddLE/omsDu3g2oqtp6qKL3xJf/epXWb16NSeffPIzPZQuuuiiiy5+j+hqmrro4kni3//937n33nv5wQ9+wN///d8/ZnZhF1100UUXz050SVMXXTxJXHjhhZRKJd785jfzjne845keThdddNFFF79ndDVNXXTRRRdddNFFF08CXU1TF1100UUXXXTRxZNAlzR10UUXXXTRRRddPAl0SdOjcOWVVyKE2KuZ5JPB6aefzhFHHPGUjmflypVccsklT+k2u/jdcNNNNyGE4Kabbnqmh9LF/wJ0bVIXT4SuTfr9oUuauviDxLp16/jIRz6yz4nii1/8IldeeeXvfUz7Gz7ykY88YX+vLrro4smha5N+dzwbbFKXNHXxB4l169bx0Y9+tGuguuiiiz8IdG1SF9AlTV108axDs9l8pofQRRdddNHBs8kmdUnTk8B3v/tdzjnnHEZHR3EchzVr1vA3f/M3pGm6z/XvvPNOTjzxRDzPY9WqVVx++eV7rROGIZdddhkHHHAAjuOwbNkyPvjBDxKG4ROOZ+PGjWzcuPFJjX12dpb3vve9rFy5EsdxWLp0KRdddFGnw3cURXz4wx9m7dq19PT0UCwWOeWUU7jxxhv32ta///u/s3btWsrlMpVKhSOPPJK///u/73wfxzEf/ehHOfDAA3Fdl4GBAU4++WR+8pOfLNjO+vXrefWrX01/fz+u63Lcccdx3XXXdb6/8sorec1rXgPAC17wAoQQnXj9ypUreeCBB/jZz37WWX766acvON4//dM/ZdmyZTiOwwEHHMCnP/1psix7wnO1cuVKzj33XH784x9zzDHH4Louhx12GNdee+0T/vYXv/gFr3nNa1i+fHnner73ve/F9/3OOldccQVCCH7zm9/s9ftPfvKTGIbBzp07O8tuu+02XvKSl9DT00OhUOC0007j5ptvXvC7j3zkIwghWLduHa973evo6+vrVir/X4CuTdLo2qTHRtcmPT3oFrd8ErjyyisplUq8733vo1QqccMNN/DhD3+YWq3GZz7zmQXrzszM8NKXvpQLLriACy+8kKuuuoo//uM/xrZt3vSmNwGQZRkve9nL+OUvf8nb3vY2Dj30UO677z4+97nP8fDDD/P//t//e9zxnHnmmQBPKAxtNBqccsopPPjgg7zpTW/iOc95DpOTk1x33XXs2LGDwcFBarUaX/va17jwwgt561vfSr1e55//+Z8566yzuP322znmmGMA+MlPfsKFF17ImWeeyac//WkAHnzwQW6++Wbe8573APph+dSnPsVb3vIWjj/+eGq1GnfccQd33XUXL3rRiwB44IEHOOmkk1iyZAl//ud/TrFY5KqrruL888/nmmuu4RWveAWnnnoq7373u/mHf/gH/vIv/5JDDz0UgEMPPZTPf/7z/Mmf/AmlUom/+qu/AmBkZASAVqvFaaedxs6dO7n00ktZvnw5t9xyC3/xF3/B2NgYn//855/wWm/YsIE/+qM/4u1vfzsXX3wxV1xxBa95zWv40Y9+1DmGfeE73/kOrVaLP/7jP2ZgYIDbb7+dL3zhC+zYsYPvfOc7ALz61a/mne98J9/+9rc59thjF/z+29/+NqeffjpLliwB4IYbbuDss89m7dq1XHbZZUgpueKKKzjjjDP4xS9+wfHHH7/g9695zWs48MAD+eQnP7lXp/Uunn3o2qSuTerapGcIz2S34D9EXHHFFXt1t2+1Wnutd+mll6pCoaCCIOgsO+200xSgPvvZz3aWhWGojjnmGDU8PKyiKFJKKfXNb35TSSnVL37xiwXbvPzyyxWgbr755s6yFStWqIsvvnjBeitWrFArVqx4wmP58Ic/rAB17bXX7vVdlmVKKaWSJFFhGC74bmZmRo2MjKg3velNnWXvec97VKVSUUmSPOb+jj76aHXOOec87pjOPPNMdeSRRy44b1mWqRNPPFEdeOCBnWXf+c539tktXCmlDj/88AXdwdv4m7/5G1UsFtXDDz+8YPmf//mfK8Mw1LZt2x53bCtWrFCAuuaaazrLqtWqWrx4sTr22GM7y9pd0uePbV/3yKc+9SklhFBbt27tLLvwwgvV6OioStO0s+yuu+5a0N08yzJ14IEHqrPOOqtzndr7WLVqlXrRi17UWdbuJH/hhRc+7rF1sf+ia5O6Nqlrk/5w0A3PPQl4ntf5u16vMzk5ySmnnEKr1WL9+vUL1jVNk0svvbTz2bZtLr30UsbHx7nzzjsB/QZw6KGHcsghhzA5Odn5d8YZZwDs0w09H1u2bHlS6cfXXHMNRx99NK94xSv2+q7dN80wDGzbBvTb5vT0NEmScNxxx3HXXXd11u/t7aXZbO7l1p6P3t5eHnjgATZs2LDP76enp7nhhhu44IILOudxcnKSqakpzjrrLDZs2LDAFfzb4jvf+Q6nnHIKfX19C87rC1/4QtI05ec///kTbmN0dHTB+apUKlx00UX85je/Yffu3Y/5u/n3SLPZZHJykhNPPBGl1ALX90UXXcSuXbsWXONvf/vbeJ7Hq171KgDuvvtuNmzYwOte9zqmpqY6x9FsNjnzzDP5+c9/vpdr/+1vf/sTn6AunjXo2qSuTerapGcG3fDck8ADDzzAX//1X3PDDTdQq9UWfFetVhd8Hh0dpVgsLlh20EEHAdqwnHDCCWzYsIEHH3yQoaGhfe5vfHz8KRn3xo0bOzf94+HrX/86n/3sZ1m/fj1xHHeWr1q1qvP3O97xDq666irOPvtslixZwotf/GIuuOACXvKSl3TW+djHPsbLX/5yDjroII444ghe8pKX8IY3vIGjjjoKgEceeQSlFB/60If40Ic+tM+xjI+Pd9zBvy02bNjAvffe+zud1wMOOGCvRrzzr9+iRYv2+btt27bx4Q9/mOuuu46ZmZkF382/R170ohexePFivv3tb3PmmWeSZRn/9m//xstf/nLK5XLnOAAuvvjixxxntVqlr6+v83n+teri2Y+uTeraJOjapGcCXdL0BJidneW0006jUqnwsY99jDVr1uC6LnfddRd/9md/9qTEfI9GlmUceeSR/N3f/d0+v1+2bNnvOuwnjW9961tccsklnH/++XzgAx9geHgYwzD41Kc+tUDYOTw8zN13383111/PD3/4Q374wx9yxRVXcNFFF/H1r38dgFNPPZWNGzfy3e9+lx//+Md87Wtf43Of+xyXX345b3nLWzrn6v3vfz9nnXXWPsdzwAEH/I+PJcsyXvSiF/HBD35wn9+3Dc1TjTRNedGLXsT09DR/9md/xiGHHEKxWGTnzp1ccsklC+4RwzB43etex1e/+lW++MUvcvPNN7Nr1y5e//rXLzgOgM985jMd/cajUSqVFnye/1bZxbMbXZuk0bVJj42uTXr60CVNT4CbbrqJqakprr32Wk499dTO8s2bN+9z/V27dtFsNhe82T388MMAnaJea9as4Z577uHMM8/c6w3iqcSaNWu4//77H3edq6++mtWrV3PttdcuGMtll12217q2bXPeeedx3nnnkWUZ73jHO/jyl7/Mhz70oY5h6e/v541vfCNvfOMbaTQanHrqqXzkIx/hLW95C6tXrwbAsixe+MIXPu64Hu+8PNZ3a9asodFoPOG2Hw/tN8/5+3j09Xs07rvvPh5++GG+/vWvc9FFF3WWP1bY4KKLLuKzn/0s3/ve9/jhD3/I0NDQAoO9Zs0aQLvhf5dj6eLZia5NmkPXJq3c52+6NunpQ1fT9AQwDANggfo/iiK++MUv7nP9JEn48pe/vGDdL3/5ywwNDbF27VoALrjgAnbu3MlXv/rVvX7v+/4T1rR4sum9r3rVq7jnnnv4z//8z72+ax/Pvo7vtttu49Zbb12w/tTU1ILPUsqOi7udkvzodUqlEgcccEDn++HhYU4//XS+/OUvMzY2tteYJiYmOn+3Dfzs7Oxe6xWLxX0uv+CCC7j11lu5/vrr9/pudnaWJEn2Wv5o7Nq1a8H5qtVqfOMb3+CYY455TDf4vs6hUmpB6vN8HHXUURx11FF87Wtf45prruG1r30tpjn3/rJ27VrWrFnD3/7t39JoNPb6/fzz1MX/PnRtkkbXJnVt0jOBrqfpCXDiiSfS19fHxRdfzLvf/W6EEHzzm998zBTK0dFRPv3pT7NlyxYOOugg/uM//oO7776br3zlK1iWBcAb3vAGrrrqKt7+9rdz4403ctJJJ5GmKevXr+eqq67i+uuv57jjjnvMMT3Z9N4PfOADXH311bzmNa/hTW96E2vXrmV6eprrrruOyy+/nKOPPppzzz2Xa6+9lle84hWcc845bN68mcsvv5zDDjtswcPxlre8henpac444wyWLl3K1q1b+cIXvsAxxxzTSb897LDDOP3001m7di39/f3ccccdXH311bzrXe/qbOef/umfOPnkkznyyCN561vfyurVq9mzZw+33norO3bs4J577gHgmGOOwTAMPv3pT1OtVnEchzPOOIPh4WHWrl3Ll770JT7+8Y9zwAEHMDw8zBlnnMEHPvABrrvuOs4991wuueQS1q5dS7PZ5L777uPqq69my5YtDA4OPu45O+igg3jzm9/Mr3/9a0ZGRviXf/kX9uzZwxVXXPGYvznkkENYs2YN73//+9m5cyeVSoVrrrlmLx3BfFx00UW8//3vB1jgBgdt/L/2ta9x9tlnc/jhh/PGN76RJUuWsHPnTm688UYqlQrf+973Hvc4unj2omuTNLo2qWuTnhH8vtP1/tCxr/Tem2++WZ1wwgnK8zw1OjqqPvjBD6rrr79+rxTP0047TR1++OHqjjvuUM9//vOV67pqxYoV6h//8R/32k8URerTn/60Ovzww5XjOKqvr0+tXbtWffSjH1XVarWz3u+S3quUUlNTU+pd73qXWrJkibJtWy1dulRdfPHFanJyUimlU0k/+clPqhUrVijHcdSxxx6rvv/976uLL754wT6uvvpq9eIXv1gNDw8r27bV8uXL1aWXXqrGxsY663z84x9Xxx9/vOrt7VWe56lDDjlEfeITn+ikNbexceNGddFFF6lFixYpy7LUkiVL1LnnnquuvvrqBet99atfVatXr1aGYSw417t371bnnHOOKpfLCliQ6luv19Vf/MVfqAMOOEDZtq0GBwfViSeeqP72b/92r3E8GitWrFDnnHOOuv7669VRRx2lHMdRhxxyiPrOd76zYL19pfeuW7dOvfCFL1SlUkkNDg6qt771reqee+5ZkLY7H2NjY8owDHXQQQc95nh+85vfqFe+8pVqYGBAOY6jVqxYoS644AL13//935112um9ExMTj3tsXey/6Nqkrk3q2qQ/HAilnk1Vp7ro4n+OlStXcsQRR/D973//ad/X5OQkixcv5sMf/vBjZu100UUX/7vRtUl/eOhqmrro4hnAlVdeSZqmvOENb3imh9JFF1100bVJTxJdTVMXXfweccMNN7Bu3To+8YlPcP755z9m9ksXXXTRxe8DXZv026FLmrro4veIj33sY9xyyy2cdNJJfOELX3imh9NFF138L0fXJv126Gqauuiiiy666KKLLp4EupqmLrrooosuuuiiiyeB/SI8l2UZu3btolwuP63VarvooovfDUop6vU6o6OjSPnsfCfr2qMuutg/8HTYo/2CNO3atev32vuoiy66+N2wfft2li5d+kwP42lB1x510cX+hafSHu0XpKndafnoV3yJoUWLyLKQLGjiT4yz8e5fMTu5gySaZGS4Ql9fkULRolgoksQRjZkZiq7N6OLFjE9OYNgWrTAgVSmlSoVyTwXXtenpqZDFKdKALFOYhoUQgpmpGcZ276avt4/VK1dQKRdZuXIlwyPDeEWPRYtGKFcqTIzvZunS5WzduImVK9ewZ2ycNIyJgpB6vUGxXCZJYkzToNlqMT49jumYDA4OYgmTuBWQRCHjE+MsXb6Uhu/z67vu4mc/v4Vt23eQphnFQoEXvPglTDaarDr4KJTdz1QTxusxzdjAcnpxcDASA0vYgESRIWxJkDVp+jUwFQ2/Rd/AIDMzVYLAJ01S+nv7OfTAg+ktutzwo+9x2vPWcvNPf8xvfvVLVBqyeOkyjlx7Agcd/hwMu8LYZJVNW3diug5KZmT4bH9kHSuHHD7/yfcSVcfwZ8e55+47SdOYZcuWMDIyQk9vD4NDg5imSb1eJ45ipJS0Ap9Wo4EhJYaUNJpNess9CCmJg4hmq8mOnbu46ec/45777mFoYJAkCTnhucdz6imnYBgGM9UqrltgeHiEar3Of3znar76z/+CaUqKpRKNZpM0TjEtUECSgOGapGECGRiGwJAS23aQ0qDVaqGUQilFphSmAAFkuQrQACwDwhRS9DZBbwdApQoBmBKkAWkyt5JCb0ugY+QCvY0s/26+0FAApimIEoUBCClIlSKdt5Jh6HGpTO/LNOY2ZhgCKSVpmmKaJr6fYJsSIRTSMLAsizRNkULQbIYMD/Vx5plnMrpkMRsefojbbr+LsT2TGPkxq32MMcvHbwPhvGf22Yj2sf3whz9c0M/tt4VSiizLUEphWRaGYSCEIIoikiRBSolpmhiGQRzHpGmGEBIpDQQGUhr6eymJk5Tp6Sm2bNnCAw+sY8MjDzO2a4xWq4lSCmlIAj8kTROEEFiWhWnqVhtJkubbTwHdR01KiUKRqoRUxQhDYFomQkhQgvYdrFRGHCdkqW4H4noeju3osVsWpWKR/v4B+vv7sWwbpTKSOMErFBkcHMB1XcIgYsvmLZjYDA0tplKuEPg+O3bsYHx8nEaziWVamKZECIkQAtM0MU2r0+qj3VBWqYxMKbI0JU1TkiQhSVKSNCFNEtI01edDCizbplQs4DgeUkriOCLwfaLcJpUqZYaHhli8eBFDQ0OYps3k1AStZoskSQgCn1bLJwh8fD8kjqPO9kEhJAipkBKkVBiGpFgs0N8/wMDAAD09PRiGgWGYmKaB63o4jo2UEtu2sSwLy7Kwbbvz2TCM3GPS9nAulCS37dW+/u0L7Wv96G3s616d/xto37u6WrgQYoHX9dH7fPR3j172dKHZbHL22Wc/pfZovyBN7ZNbKpWIogbbtj5MdXyMokhZsqSfV7z4OFJ/mj3jWwiyJikJcexj2yZLl47gGSajI8N4xQKNIODg0cUMj44yMDzE8MgwvX092JbJiqVLME0jn8j0jRSFCb7vY1kGpWIB13VxXZfZ6gzTM9NgSqaqM2zaupWZWoNmvU4ritmyYTM9pR6Wji6hWCpQnZ1hemaKnTt3glAUy2WEZbJzx27CVghpSpqEtFo1JmbG+dUdd3DfuvVEMQwMLWJ4eIienhKTk3u45bbbuX/9I6w86Dj6Rw/BNUuEiUQqG8vpQRomcZyhUgUiRSpAWEjHQ1kK1zRJpaQVRqxYtpqVo8spOh7j27fyvWuvYtuDd+I2dvDac0/lE/+/8/nvn/6Ub//bv3PTD+7nwfvXMrz8cA466nkcccThPPDQI0hpsGhkGbXJScK4ge1U8MpN8CdZuWIJAwMDLFmyhMnJCe789a/JUBx77LGEYcBDD62nv7+f/v5+fvazn7Hh4Q0MD41QKhQ5bu1x/P0X/pF3vO1tfPPb/0r/0CCHHnYYaw5YQ7Hg8YPvfY8kSSgUCqRJgkpSkigijiMc26Gnpzc37fqBti2LRIDKMtJUgQAVZ5RLRZqNFmmqiRFpPhF0JjT9mCRpglAgxZzJivNm4WrevZopIFMdYqQySPL1TCkx20ZPgSIjy7ShF4AhIFWahLi2hedYhIFPphSVgkWl1MNMtU4QhXoMEkzTIM0yBIpM6G24ph5UGEEcKUwjJckgiBNMIEgyPeY4gSBBovcJUK1WufXmm+nt6SXLEvrLBVaOHsHS0cU4lsXs7CyGNEAKCoUCjmNjug7SMLjlF7/kno07ntVhq/axFYvFvTq7/zaYP6G1J/80TTuTmOM4neVRFJIkCUoJpLSwDAfDsJDSIMsyWtNTbNq0hVtvuYWbb76VqelJTNNgaGgQKQ3CMCRNM5IkQwhNMqJo7hrNESkTKTRhypS+Z4TKkG3ShCBNFUmS5JO9iWWZJIl+XmzbwvVcTFMT8UazSZYpklQ/p4ZhEEUJQRChlKJUKpEmGX4QsGx0mMMOOwTHdtk9tpudO3dQq9XYs2cPcRx3mtY6jp0TDAfXdent7aVYLKDU3LURUiCFRH+cIxjteTxJIgBc1wEEvu9Tr/s0mg0Mw6B/YJCDDz6EI488kiOOOJxly5Zj2zY7duyg0WiQpgmtlk+9Xmd2dpbZ2Vmq1Sq1Wo1ms0kQBCRJTBSHtFoNfL9BHIdMTEzyyCObUSrDMCSO0z4OB8/1cFwX27bo7e2hv3+QUqlET0+Z3t5+enp6KJUKWJad3xeauKRpm8Bo0iulzMmVgZRiL0LTRpt86d+2SefCXnX640Li1b4Oj7Xd+dt5NGkTQnT64v0+bcRTua/9gjS1ce/P/ovSoj6couTwQxezetEgroo4csUovYWDsJwTKPR4WI5FBpimRcFxcQyJWyiSklH3A/oHh+gfGmRiepodO3cgBWzbuZPx8T1IFFEQYJkmtu3QbLQYH58gigJIU1q1Okma0PIbTE9PowSUKiWEECwaXkTBcbAtjzSICRs+UbOFaWjDFvpNhgb66OnrZXjRCMK0aLRaGBjYpsHY7p1s2PQQ1aCJsiRHHHcsPb3DtJohY7t2MLltM5OTk4zt2MRrXn8xy9c8h1ZSZPdMSrSnRaMZIZIAgUESJQhhEEYB0pJYrokQJbJIQZSQRQaHLDuaxQN9BNM1Hnnk16y/+1Z2rv815bLgvl//mLLYyZLBP+K8l70AtyL4t+98l4e3/Iqp5h6E6XPeH72JR7YmSNvBMCRCGvh+xsZNYxyyVFCrz1AsFjoP5caNm7jiyq8zMTnBCSc8j5GRYR566CGmpqbp7e3h4YceJk0zRhePsntsN7OzVW785c2UXJet27ezdcd2ZmdnOeaYozjkoENIXxyT5G93QRASxwmeV8RzPJIswzKM3IsjEZnSNDhToEC0eYvQBEka4LkOKoMwikizjAwwTIk0JGmWkikQc/Y3Jz0ahiFJ0zbJMrBM7dlJ0wzTlFiWQxBEIHKPUr79thdLCT0mlYEwQKVg2QaFok2j3iIGwiQmTWYJo6TjlSIDolR7eSR4BniOpOAYKBSqYNDbP8DQcD+7xycYXrSURn2WRaNLcQsllADLtBEqQylBT6WMISVbNm0h8H16ymVcxwbAc12yLMGPA00gDYNm5JOkCiUlpgl7qns38+xiIR79Bi6E6Hid2g1cTdPU3p7OhCOQ0gSl/5umGdPTU+zYsYONGzfxyCMb2Lp1C+Pj40RRgmN7QEaj3qJQ8CgUCti2TZIkpGlClqVkWZZPYrLj5VJKkaba85SkMRkKwzZAQRzFZGlKkmgvTqlUplQo4RUKxHFMq9UCwJAG5VIZKQ2UUlQqFRYvWkRPby9KKcbGdrNnzzjbt20nyb1B9XqdrZu2sXXLFnrKPQS+z/btO5iZmUER43pz5wMgjn2CoMHsrGJ6ZgIrJ3RCas+JmXtwDMNAGoY+RimRhh5Ts9kkyzI8z8WQBlEc02w0CIIg9/4IduzYCqTs3rMLz3WxHZtSqaTJrGFi2TbFkku5soTly5d0PFz6mZ7zfEVRhO83qdWqjI3tYuPGTezYsYPZ6gxKSVSmSfHk5DRxFINQOI5LwSvguDae51HwingFD89zcR0Xx3GxHQvHsbEsTa4dx2JoaIiBgQEsy8Z1nfw7KydI+tpm2UIio71xyYL7UXs050jRQs6hcoKqf9/28rXRJm3z7/d9kaf9FfsVaTr7hYez7JAVjC5fzAErlzHSU8afmcBD0eN5DI0MEMYRO3eNUa3ViZOYielpQr9Bgr6Zwzgl27QBQ5ps27GNhx56iCgM2bVjJ416FZWlmjQZJr09vViWRbPRRCBwTZOw0cC2TUrlIq3AxzANlixdxkB/HyKOMW0HM83o7+vDsmzIUtJ8aq2UK/T0VdixayePbNmoJ8FUUfAKeI7Lnond3L/+fmbqNVpxxNDiRYRTu9m9e5Lt27cxO6XJm2FZDA70Mb1nJw8/MkFqDlLpW41lOMSpIIpDlFCYjkngB2QJZE1Fve5jGS695QrFUoVl5VG2PvgA6+77Fds23kMwswWvmOIWJa3qLDfd9H0m9mzklNNOAgu8SoZVDEBMsOWhX7Jn7ASk0cSQZv6GbIPhEkQJpUoPwpAIJZiYGKfS04NlO/T19SOkDkEYhsGqVatw3QJ9fT0sXbqU4eERSoUS999/P8cceyxvf8sbGR4c5uxzz2FyYopWELBy1SpWr1zJiqVLuf/ee3BMi8HRQZYtXU6xVKLS00u90aSUu2SlAENKsiQlSzKQ2gi0p64gCJFSP8ipyhAGmIZBnGb5srSzHEUnLCYNMKUgSRRGHmvL2nErzYi0lynNEDIjVYooyRaQrQ4UkIKU4BUtHMumr6fCQG+ZSqWMxOCgAw9g99geykUP13FIMrR7PI/DmSKDLCKJQvxWk0azTiYU/YMlyuUCUxMGI/0eSVDHtixcxwIpsEyLtB1KNDRBsmybJI2JkphGs0V1tkoUBYRRyNT0NJlKkZZBFKcEgQ5BOB5E8bNT/P10oT15zA/VSSmRUnYmoza5MU2JUpI0zqjVWmzfvpN16x7kgQfuZ+PGjczOzuTeHgfLMskyTUh83ydJEyzLyvcnyDKVP7MCKecmMO25SDv7llI/KJnK8tBTRpqkxElMlmYY0sRzC5hGTBwlRFFEmqRkqY4VJ0lCGIQEQUQpTkEIkjilXquzc+cubQekgWVbTEyMMzmxh77eXoQQ1Kq1PFQmsG0D0xRk2dwY4yQijhOCsImQUr8MCdE5JiHniJJhGBjSQBp6Pd/383NlYxhmHmbUIUrDMKg3BDt2ZkzPTGrvWBxhGiajo6OUy2Ucx6FYLFEqaW9joVDA9fRz6TgunuN1PEKmaaBURqPRoFypEEX6JS9JEwzDwjJN0iwh8JtUq1XSNEEppUmfqUO0lmlhWhaWZWJZNrZl5x43B9f1sG0b13MZGRliYGAQ13UpFosUi0Vc18W27QVhtDZJNgwDy7I16cxDnzInnvsiNm3vkyZMCqWyDtl+rHu7/btHL9tfsV+RprPOeT79Aw6GVBjJBLPTE1hKEaQZszO7eXjTetav38CvbruDHTt20QpbtFpVfL8+5y6ct7041g+5UhlkCtu2UGmKZZqoNKM5Pc3w0Ai9Pb309/XSUy6TJjGWKRkaGSJTKUIa9PX26wkzTRECZmdnkEhKpRJBGLJnzzgTU5M4rk1Kyi9vu43t27fpbfb24Td9JicnCLOUYtFh0dJlFCpFfn37bbSmmxR6enE8l3K5iN8Cx/a48p+/gmv1MjOrGFx2JEccv5iiV6DWCkkznyTLIE0w7ZRCsUAcxQiVsGTIoafQS3Nqhg13buTWX/6E6tRWHMenVIwwzBZBs0Ffr01W6mPLlo1s3LSBTGakUlAo92GYGY3mJPfc9UsGhg4gzASGMrAMl1RGlCpD+q1neAlSZVTKFRYtXszg0BB9fT2kWYppWfRUyhSKBQzDplSukEQ+CnAdj4k9exgcXsTZLzsfv15DoCi6BTIBM7Mz1KtV+vv6QCnGx3aTJCmZkOwZn2BgYIjegX76+vpQ7YuuFEJIDKmJRpwBUmAYJikJZArfj5BSe41AILKMJE0hYU541CZEQks72h6iNMkwANsEKTIMdDhQ2gaGLREYDPZXQJooBKZlYDsupmXl4TltrA1L4toWpiGoFIuUikUa9SpSCZ5/0nPZsP5h+vt6KRVLoCSZSonDCJQiDkOajTrVmSkm04xWs0UUhOzZOc7WzbuYnA4I6rM0mjHr1u0gkxJD6PBlgsIElNAULE0SHZZBQO59QGlypARAikhiokQzQNs2sG0X2zJo1MPfj0F4FqEd8mjrmIQQuadSh1y0nsUkSxXVZoNdu8bYunUrExMTZJmiv78/ly+EhGHU0Z0kScz09DSTU1NUKnqyT5Kko78xDDGnacuJWpu4ubYNEsIwIFXocJxhkhgJabNJkqSEkQ77ZZkiyxRxnJCmGVAnjmMajQZCCMbGdjPQP4jnFWg2G9RrDZI4xTJtent7WLxoEbXZWXbt2kkY+gghSNKIONH3tlIZhmnM83gIpAGOYXa8Ipp46nXTLIMsIZ7nHRYIaHtS9AJafkx7VhBCYEhD6yiTkJnZaaamJ4njiDDU88TGTRs618K226TFxfMKFIuFPGRbplQq5lIOj1KpTLFQ7Og4hRAUCkV6Kr15CE0SRRGW1cI0zdwLmHuFVEaWZiRxigjCBWExKUVHE6XDcQIhwDANCoWFYykUCnieh+u6FAqFjufR8zwWLWprtozOfdi+Jx99j7bvDyG0nRRijty312+vM/937X/70lDtb9ivSNMn/uKvMI1aHn/NwywChNJagCzNCPyYViMkTlMykaLSiCzTbwkAfQMDFLwCQRhiSoOyV4QsxbVsBgf6SMOI3p4KWZrSaraI44Ta+DhRvc6M61Bv1IiSiCVLFlPuKWPbNrOzMzi2Q5ImxGHC9u07WLViNcVSmYce3sBDDz+EYVuYtkmqFNVGjaOOOYpXvuIVHHvMMWx+ZDNXX3MNj2zZxMvPP48LXvt6gijkT97/Hu6790H6+vsJmg3CRoO+cplatcaWqXX8yXv+gkXLDmHTrhZTjRlmp6ZJhInpODrUIyKWLO1j1crlOLbEM5ocfdAgD97d5L+//zNu/9lPmJnYSdFT9BQMFC1qjd2UioZev1yhXKjQbPoEaYS0LUy7QBClqBSmd2znhOeczkRVMNNICIMUkUj8IKPZUiSpgSUEg4PD2JaFUIqVK1ZSKBexTJMojklyd/bszBTbt23DdV1WrVhJtVrlll/ewgvOOIPrvvtdAJ573HEoBT/4wfe5+667ePWrXsX2rVtIk4TZao2NW7YQJxknn3QyLz33XEzT0oZSaW+SKbUiO0rSXD9kUHBdWs2WJnKGfiOM4rTzNqX5k8S09VtfluTiR1No71QmsERG1EoxFXgelMsuxaKXu9AdbNfFNG16e/vwiiWiJMN2XEZHl9A/MEArCGi1WhTLRZp+iz27x9m9eze2ZRJEMXfc9QAig/vXbyP2A8olE6szsWY0gwyRKcjDAyKNEWmMShPSJH8zlIqCzKhPJTiWQEUJSaqwLQPDFARRguu6SMMkjGPCJCZJMwzbxHEdsgySKMI2LRzXpdVsEscJZu7tyvyURssn6ZbK/R+hPVG1vUzzSVPbK5CmKfV6k7GxMbZv30Gj0WJocIiRkUXak6oywjCgXm8QhmFOoHxqtRr1eg3fb2kPUZYRhiFxHHRImVKaYCVpQpZq3RPCplQuMTw8jJNr1jzXIwgC9uzZQxRH2KaNaWgdlGu7ZEmKkBLPc3EsmyxJ8YOARq0BqaBQDAD97NmmrbVVYUzgByRpjBAZSRJiGNqzpJSelIWRdcJF7cnXMLVuR2UZWdbW87Qn53kanPY5bv8lBELqkKMmKPp5NwyJlJBlEMdpx6PVDkeiFI0kRs/5UieO5KSl7QFq66x0aMzOvUIOrquJixCCZrPJbHWWVsufR1I0ee3p6SXLUuIoJopzmUCakaZzHkBNUuYIiCZNMtdTJiiVdUTy88XkbS1YmzS19bl9fb1UKj1Ylg7jzV/fdhwc28F2tHerHTa2bZu+Pv2b9pjmk6pHh+LmE7H5BOrJYe4K/nZ4dLrKU4P9ijRt2vAQphVi2RmQoESK7TgUCxWCVkRvqY9KsYxreRTKJQollyCoE0c+o4sXU683CaKIYrGMFAaGkKgkZmp8Aoliao/+bxaFpHFCkr99hWHE5OQkpm3ieC5uwSGIQpKqjgXbpkVvby9ZBgWvxED/AJZj0Qp8LMfm9BeewTkvfzmVnjKpSlFC0NtTZvHICD3lCgcefAhHrX0OjWaDxYsX09NT4aqrv8P4rjH8ap2GXcUyBFHTx+vrp5XBkuFFrD3mUJ53ymls2R3yi9u3cttvNuB6NouWDlOslBhaNEip5JCEdR568E72bN/A9vsrbF63lZJpMDqU0OPaZGmLNKsTZw1KBahUHLI0IAospPIgExApUCl+q8psrYHCZkzdR2t6HMsaQgpwXIdEJYyP76Z03Ap+eMcd1GenUGlEb28PWZYShAGLF4+CgHXrHuDhDRtotVrEUUS5XOYFp7+A3koP9993P1/+0uXc9qtfsXPnTsrlMsVCASklDz30UOeNbe3atUxPTzM9O8vo0mWUKxWOOOIoRhYtIr3/PoQUWJahM2iUfqBNKTCBFEUShhhCh9kEWkxtCUEmACmwXQuv4LJ4yWL6+vtoNZskaYJpGtimhWfbNOs1pscncG2bgmshhDbiUgpsQ6DSmHq9weT4OGGY4YdauFTwXCzHJkq0R8syDcI0JQgCWn6AaUokivpsi8HBCrt27MKQMD2jQ46WJTFNSYbEtkwc2wKZkWUJhqmQpsByDFzLwRCCmZkqnqe9qD1lB78Vo7IUI5PIBEyV4De10bVFhilARSlpGmCY2likYUoctzAyhVACU2hVfAbECgqWIkieUTOx32F+eOPRE0x7YhRCEPgBExMT7Ny5k4nJCQQmw0Mj9PX1USwWQEAURbRaTXy/he/rzK40TYjjmLHdu2g06oD2HrVaTYLAJ060ximJY8Iw6EzWlmlQKhbzkFQF0zKplCs0my0MaTA9PQMoPNdDCokpTWzLRghBqVTGMLS2qdXyieIIKQwkBqZlojwIg5BGs0GjrrNms0yH4jKVIFWGZUukMecFk4b2EKVpHhaSEsOEJE5JM50ViNBeEAFzKap7oe0u1iRrvm5HkZKkWce7pfLsU8vWL2DtzDiVv7ArlZEkEUkS4ft+Tuy0DlIKnVKkMoFh6JcNx3Ex8kzWKIqIkwTDkBQ8TWLK5QoAcRwTRqEmbIm+fnEcd/RH7bEJIZBChyL1cZsdbVwYRgu0a21C5DjOvIw8s+Op0t4z/X3B8ygUihTz8GOxVKRYKOJ5HqZpUiwWiaJFeJ63z/v20Z5TnQUqUbmD48k2Ilm4XvsZefQ6j/f7J7Wb3wr7FWlyHAPTsDBEpF21lkGiMjIylIAgiMi0/JUsM8kUZMQIAyzPYffmTZTKPVTrVWrVOgaCguvS9FskYYBKUkxD4vstsvxBtR0HhQDTQElBnCXErZQwjbBtnQLqFVyKSQmVgtfvkdkZURTTbAUIQ9Lb349Ccfsdv+bBh9azcvVKXvrSs9m4ZQt33/kbDjjgQM466yWoLGP3tm385Qc+wC233oZqhqxatATHdlFpTOwViap1Sq5L1Gryta98mauvvZYjjzuZZYc+l4e2SFpBRLlgMjW2jbtu+i9iv0oUzDIxsYXazBiuTIn8Br2FImGroV3wZFgmWIZDEMckSUytWkdmEUW3D9uw8UoVDENQrc2StmbJhEWrajC+czPuiInj9mFYim1btnHHXb/mgpesJEl8du/eydbNm3MLk1HpqeA4LlmWEUcxhx58MEuWLMUrFojDiMGBQZIkZWBgkAsuuICjjz4aaVnYjsOi4WGklKxcs4aZ8XEOPPBARhctIk4T4kwxOz1LvVZjYHAQ13PzqJzCMAyWLl2CEGBbFgXXwrIMiuUC/b09xFHE9i1bGBwcJo1jlALbcZianiZKExrNJo3JKTJfi0QbDZ8oiLANA8s0aIUhvt/CMAxsy0DKjDRre6t0PDAKE+IkBiWwbA/LdqiFAXEU6RIXtk4Bz5RACYEKI2ozLQwpqBRd4mbAyGCFSrnM7j1TRHGIMARJlpGpjCTNCOMUoRRpmiFRWFIhVUYz0KUKWjH4KiPOMow41il9mSJL9FgjPyFNtVGwJPRXbPoqRTzXxnFM+vp6KJVKCDIqvT3UazWK5RLlSoVm4BOGEaZl8rff/Okzayj2I+wrVXs+TFOHgQAazSZju3YxMTGpkxeEYGp6mkajiWmbOkSVtQm7xDANypVK7gFI6evvQ+QhnTgKafktfL+Zi75Toiik1WzSbDUJgwCF0uG4JKPR9PE8l7QAKhMkqSJJMgzDpFAodzJY/SAgCmOUyvBcj6GhRQghCMOQIAiJIp21Zhp+/nLga69KGCENpcPVKiHNtBc0UwrRyeDSRCFOYtIkIVMpClsnaZBqIThAmzS1/0/pOWEuE4zcc5R7Vcy5LMQ40fZvTiRvYJoGRu7ZbYfO6GyrXZIEreFK8+9ila8itYCfiHqjgWGYeJ6LaZooFGEYYpkWpWKJUv5iGIYhpmV1Mvvanpw0axMoHV4No4goDInCKB+3Xp7EMRk6ZJokGVmaoAApJGmSEUcJUvo5ycyF29LIRfQGpmF07jvb1pl9bi5Atx1Hz3mey9DQIMViAcsy8TyvQ8baHi7HcbSIfZ5ny8i3PT+TUc0jqI/1HORL20/NPj4/1t9PvYZqvyJNtmlQsA2EEmRERHFCEMVEYR2VGSANUhIMCX4rJVZN4jRGINmxaycz1SqG7UAGzVaTJIrxHSfXbQBSaNG2ACVzb4PKMEwL07ZRZERpTJZkhHHYieu6LYdmtUESaxFmpdxDkqSajccJD9x3P5u3buE3997Dtm3beM7xz+HII4+g6BVZ/9BD3HfvvTy87kGmJ6eoTk5y1b9dRbXR1LFo28MtSVzTpIUBcUZPb5kwy7j37t8w2/CxShWOeN5z6e818Mea7N7xCPf8+tdsvO9uDDPFcVJM0WSgKBjoKTAz2cCv7SFNEyLhkKUKw7KxpIXCII6h0YjxLBdlK5I0Ikwj4jQijAIOO/BAzn/lBRhuDxMtlz21WYrDfdiuQaNVZWZ2AsuEU046gSMPPZjJqWmiMCRTip6eCitWriRNYuI4oVwu4To2tVqN++9/AN/38VyPJUuX4JgmrusyNj5OFEXc8vOfIw2DAw48EGlIrrnmGp53/PMYXTpKmkGz0dBvWWlKy/dBKYYG+ukpl+gf6MOUoNKUKApoNJpAhGvFRK2EZsPHlrMkiZbt26ZFq9YgShP8VoNavUnTqVPwCjTqPkEroOA4OI5DM/RRhiDOUrTWVWk3eV7QSSLJUqWzKFNQSYoSsQ6nJSlk+q1MZTqDTSEwkwQ7yyADESVEQYSBIJIBJCkihTjMiHNhpsoyJAmeZ1D0TIrFAkP9fbiWw9TENLXqLCcceyyr16whyTKkIUmiENex6entJYlCWo0aRc/DkpAGASXXoFx0UGmMymIGB/oolzxmZqfo7etnanoKaZoUigWt/VKKaqP1DFqIZwc64adOTR4w8qyv2WqVam0WIRwEgmZrJj/32ZzWxRBaZ1Pw8FxXE4ZMUekpaW1NLhKPo4gwCsky7RpMU53pFYYBYRjq0Fkc0Ww0UQi8QoFisQfbKjAyElDwKkgp6B8Y0lqpKMayNREK/RDDcvAKZRzHJk1SwijQxCkvfyAME4SBUoIoTrDQHppMu3nJlARhdLK0VD4hZgqtWUqSduUO7XlSuacpdwPNTZ8q/99cuC7NEgQSMBFSIHIvlui4qPK0eqnLFwipl88vDwLopBIlkGqOBLQTPYTSG5SGJFOCNElI4gQhwVa2rreWJZAKwiTCz4mq7/u4rotpmUjDQBhSyxukpoUK/WKUxLG+XnmZlSSvl5UmKWmq9xXnnqmOiD/VpLod+k0TnaiUiHaYFuaTjnYmnGkaGIZ+sRNCYNsW5UoJ13VwHAvPK+jMPtfNiZaL53m5rkoL0jWxcjthQCnn6m1pj5fWk0nRDnsaefLCfGKV27sFGaiPfn5AexLlvGN56rBfkSYpBCKDXLeNEBJT2MSpvpls08YVFhBDmhAFIanKQJiMj4/jOg5hEGpBo2WjMkizvNCZY5MmOqXWtmydGp5lZDIvqJanu2bZ3IMXRzFkKX6jRdDShsAxbZYsWkIYxlR6+5itVtm4ZTOGbTE1PUUaRuzcup0ffO8HVCoV1q1bx9bNm/nJ9T+hOlvVcfNmk2VLlnL8c5/LuvvuZ+uunSyu9OGZrg4NSYtIRcRpjFIpvt8kiQNGhstMjk+w8cEH2brhHiQ+xYKLaaQUbY8li/roqzjEjQZpK8aQFimCNEuJwogEheuVcB2XrK9AHKa0WiFB1CJLAvr7ejj00CN5zjHP5eSTXkBL2dxw+8M0ZkMcZWB5HrbrYJoCxwLSiF27dlCvt7BsG8d2aDSa3HP3PdRqtU52TxwGgGLLlq1UKhU8z8MwDDZsfISdu8ZoNVusWLEcPwyoVWsUCh6jo4vZtm0bi0cXo4TC9wOkYdDb16ejiXHUeVtEKGq1KiqJCXy/o/kwHYlXlGQh+H7KxPQsKIHKlHatC128T6FAmqQp+VuadjkrNfcmYwhBpjJ9T6mMNMkfXqFLjKoMyBRxokiyACFCHQrWvAgzDJGAkZtyoRRlAbYjSJOIHk/iOII4bLCoZKKEheV59A4O0ttbyUeR0ddXxnW0wHvRyAi9pR7Gdu5i5/btHHnUkRx04IFMz0zT29dDGAVARm9vD61mg6npcYoFj2LRIYl1yCFLY8JAkCbgFBVKBYR+jZpIydKQZjOkVlVYefG9sbFdv1ebsL/j0W/U80N1hmGQpilhGNLb24vneSRJQrPZwjQVpuWQZClJlqCyPDsBkKkgzRRhHNNq+Z1JxQ8DHLuuM6XyjLksyxASnUJv2ZTKBcqVPMySk60sUwgpMU0Lx3bIsozhxUvxWy3iKEQBYRTiixZJBq6S2JaLYWpvfxTp0JnjFrAdj1azRcsPkNLENG1MS+vrlEgxdQ4nSgjSTCAwkYauG4VSZCikYeZFVrWHlfxFt13NVaBAiVyXpR9CIefVcFICU2mylpERpzEy0yEjhUKaEqlkp9aTylP1FYooiUCRh9/ybDQp9Iu8yjpJJ20vhwAyUjIlOqQnVanejsg1VTkZbrS0ZzlOEkxDl7wxpMQw5zLmLNvGtq28ppaF5bg4ntcRw9uGiSGNOdsax0SRJlda4xYSBEH+Oeqs066zpQnJnEZJh3Z1gpMQoSbxQmFISbPZ1CVZOuRnrm6XlWf3OY7Wd9m29kK5jpvX8dJZe8VigXK5QqVSxnVdbUdMq1Mmo1Qqag987lFXWbtkgr4/hQApjHx8HcV/h/w9Hcl6+xdpMkzSNCZLFMI0saw8YyJRSGVgGTYmkiyNUUmCUok+0SiSKMY0LFr1OtIwyVIdN04BZRg6Wwj9xh5nGULqIoUqzUhJMIXWvCRxkr/VtQXoWiCrUoUlTSbHxiFURGHM1J4p6s0m9WqVUrmkdSAIdu/azS2/uIVVq1dRLlc44ICD8olaP4hSSp5zxFG84pWv5Kaf/jdfuuJrTIyNsXxwKWEY0vJDmnHAopERJmemabVq2FJx9GEr2PjgfWx96A6C+gSuZRE3m/hRndS16HENspZN0BI0W2BZHsKUSCMlSQLSVGCaHipzKRZLTLR20KhX6R+ocNghR3L82udw+OFHIc0Sv1m3iZ3TAdtnEmKjjwCDVJq4hQJJootAPvTgOn70X9fjhwmLFi+hUimTxBFhEFJvNIhjHXP3PJelS5cwPDyM53lEcUxPoUDfwABNP+DgQw7lucetpdzbx/ZtW6jXqqxZvZpGo8GypcsolIo0Gk1d4NHUQmYHEIZgemaWMPTpKRdIwoCgFRAnCUpobUTUUog008Uns0x7hbIUEHkmizbDCZI0SonSSGebKYUfRzTTBJUlOJk2ipZtYJg2wtLGzvM8zHbdEqWwXRclBSrNsAyZpzwb+p7KUgqmIGr4VGdmcWVGT8UhCAKKRRvbc6jVatiOjZKSYqXMAQet4sCDDiJTKb7folAsorKUZr1GuVSiUiwx6CxmpChxCQindjCxZQtBqYhpG6RZwsx2hR/51JqzCENQKHsYjkkQBQSRr8M0wmCm6mAlgmajiVWdRVoGcZyQZAmGaaKA2drMM2ki9ks8Vjp2+3Oapjkp0kUigzDETAU2sqNrwTDmvB9CkCmdyZbkld6FEIRxgiH9jgerTZyEEJiGkU/IdscTQJ727roFDNOkXavDMAU9PS7lci9JHBGEejIuFEIqcUymMk0ogCzPvBSAaZkgoO7WSdKUpu8jpaGrcCcRcRQQxxFtR0GSJB3PjyLrhNny6RKldGRgfjacyP8QeSbZ/FPa9liBwBQmKH3+siwjzQtD6gui/6NrqWWQzl2jBQLmeZlsSuXeLKWJZiccKPJx5GGv9jpJlsyF+DJFnLZQ/ryMs0yv1/b0WKbduT5apK0F2u2XlXZILbUdHMvKhyfz8hMOhYLK63NlHc9Tp3RDrpdauCwhSeZ0VO2K9DqjT3updF0uXQRVtcmiYC7EZ7WF6HZH02SZmvBJQ4dGC4Ui5XKJcrncCf/ZtkWpVNIlenoruK6Tl5AgL4dgYJkmlq1JvCZklk4MEEbnnsmUIrKeeoqzX5Emw7CQQiGljeOaWJ5FPc8IMQ0LyzAxMlBKIJTOjZapQhqCUrmX2WoVmSm9DWGgpO49YZs2aRpj2w5K5VlUCIQpQErt1TIMDAxiEeVMXruGwzgmihOGenvpKZaJA107pFKq4PsBQaNFyfEYGhikTA8jZCAEJ518EieceCKWbZFGMaVCEZUp/HqLgaEhlq1cwcGHHsLI6CJ2jI/xD//4Dyy3LEhSZmdnGF6yiLNf9hJuu/vXGCLDIqLHUbQmtzK5/cG8+q9DsVxicLBMGEbsHttNX6UPx60gzRikhTQNTFuhQkWShvitFo3GDFKYFIoui5cMc/Lzn89LXvJCFi1exPhElUe2TvLwzik27WqQeP24gxVCJcmkxC04NJs1UOA5NitXrmDx6AqWr1qNbVs0qjVsx8kfXp3uazsWfT09OqsnivSDY1ksX76cRaOj9PT047gutbquhN0KfMYnJjBtG8uxqfT2UiyXqdcbNBpNojhCGNotHiUJtBQlzwIyhEgxpS7GmCLIUoMsyYiTDKGSThTcMiVpEiFk7o42JNLVWUJSKWTRwrBMTCmQWUbBhELRxXMsiuVCRxMwPDyMFOB5Llma0dfXh2FI4jDUQttSEcexSJJYe5RKHtNjk9z7m/sJ/DqDAyVC36dQcoiiANe2EIYOGSjhE/sT2GIJSiga/jTV1iRKpQS+j1+V1GxH65Yin9l6iEGTsDXF7h0P09vXQ7FcIoxD4iwhUTEiE8zM1LEKWtQfJiEIsIRBmrSwMy0YjbIQI5EgJYYUpMTEcYrj7Fcm5Q8W8wmUEAI/8KnWavi+r6tjYyKMmDZTEB2iMOchaQvI217TtkgZkZOFNMsneAiJEX64gAQolSGlgesUkHn2cbsQ5nzPgmGYlEpuHlLRngcjr0athC5focXLmvQVi0UdwvE8fN8nCnWYsFarMjs7Q5omnRYoaaqLbKZpQprpkq4qr/8h8oKNbb9Yzkh0GRAdN1sgaVFoT7CWqJoYGHmpBB2+aBd9bOttHl2Usd2+xZBGrp8SeWg8r8ydZTrE3iZNKJAy10QZeoS5l0QTQDpide0xyUvxCl3wMsvaNeW0FyuOo5w0+8hGXkdJdJRcyPxFzzbNnERZnYw+/W8uu68dDtMFQ+no3vS8kXQ8UUEQEoYBvh8QRWHunYpzD1QuNFdaK6VJl95GmmjiFZkxphljGgbSmDtfbS+RIS0sWwvQpdDEyrQMXcKhUKJUKlAo6qKetm1h29a8sJ8mW506Wa7bqYslhI4M+S3/KX829ysLl2YKie7p4xU9Kj0VlJCEQRXHsJBCIAXYlk0mIAl1obVi2Wbp8CLiVkCxXEFIAyWMXAsAxUKBNIlwHZcg9HEdmzjWWW6GbeGHmigZUmiGaxk0m3WCwEeaJkWvwMDgIEMDwyRhRBIlDPQNUa/VidOU3t5ejnv+81h96IH5jZNx/InPp9zXy5133kG91WTxyCJavs/23bs46KgjMIsuVb+B11dh6ZoVxEAsoR74BFHCkiXLOOH5JzDTnGXz9m1s2rCewA/Zs/0R+is606Ovt58Tjj+Ro44+kgfXr+dHP/wx1WaDnsoAllsgjhMMBKVyEcNMmJicIYoDHNcFFfPis17Jc487nqOOOpxKX4Wbb/4VP/rJzQSyh8LASsy+HlKzSGI5xALsgoNXcKjVdiGEYnhokKOPPpqDDzmKvsFBwiBkYs8eTNMkikPaQsooF0P6rSY7tm1jbGy3rvHSrDM4NIyUFnEUsH3HDqIkZvHoCJs2bWR8fIJao84xzjFUKj26zozK8rcd7aOVAqRQRGGEIcGQ2rWPlGQIogwsaePYip6CRRKneK5JpeIRxQE9vZW8ordJb08ZQwpUmlD0XP1ZQBzEmFJQ9DxUllEoF5BSG9K+/j5afksL0zOF4xmY0iCOUpIkwjRreiIjwrQy3KJNZdghcxW1WpOiEAQEWDKjFlUxi4JUJViOQbXWZPP2LfQOFbFti3qtntcKU8RRRBLFuLaN7di65IJh0IgN3D6brJXiCx/XLiAtA5lkWFh4BQ8/r5NjGhaOqQMhaZSg0gTDsnAdG99vgVQgVC4Ylpi2TUT0TJqI/RKPlUk0v11Fo6E7ENTrDeIkxrCcOS+KyFv7yHb9Hpl7l/R3SmgykaTtrK+2hyT/O9fJZPlknimlBeVKv1wmST6hC7GgWGSbNNm2mddwspB5Gr72hlianNiaaAOd4pulYplFixZ3yGEUxXkPt5b2Rkchvu/TbDU64nTfb+ahpVB7PrKcVMUxSRrpauWZ9oSgEhBZ7p2Q80gkICRCGFqgLRRK+6sQ7cy7trdqXjbYfOEyOVmaa1+yUMis68HpF+42sUnTpEPv6Jx3HUZUSjBXLFKH7EVe1VyL2ufCpVmaaHF31t73XI0soXTYzDRMDCE6bW465QNyUXe7+KkuH2As6HvYPnYpRa5XslGq1MnC08Qy7Xio0jTRpXbywqZRHBJFMUnuqdKlgBRhEqPQGqok0V7TttdItAusKtk5/20Ple1oomTZFpZpdGpjeV4hr0VVyGtQFTpC9LYeql0B/6nGfkWaWr5P2dONRlthgOFrN6QlJQYZSRQAQr/5pookyjQDDiKiVgtDgSl0tY4kyzCk0B4nlM6eyKvepllGkqWkSvciS+KIWGmNQZQmpGQEcUICHHToIaxZtYrZqWkMYdDb34dtOpjCoG+gH7fogRCsXrOak08+mT3ju6lWZ5FCMjszy9TEJLMz05SLRcIoZmJmirHJcTbt2kalr8LU5AQ/vuEnuK5LLWjhJwm2WyAII+644w7Gxnbh2RYqDBFJzHHHPofnPed4lFIM9A9w6qkns3h0MYYl+dnPbmLn2DhRlCKESZwkuG6RcrGIbSdMTKYgMl728vMoFsu87LyXY1kuD2/cyAPrN3DLr+5h845ZRg48jlXLF+OVCqSJwE9BZimeZWKakrAREAYxvh8zM13lv2+4EYGRV71t0Ww2aNRqlMq6/Uyz0WDZ0iUMDvSzdfNm7rvvPqRp6oriQuC6Jd2axLap9PWyfPkKVJbmGowyMzMzug5X/mabJDFZkmCZkrKnywAUPBfPsXEME9e2KBa1dmtqdpa+ngpZ7LNk8SB+s45hQF9/iSBsMjTUhzAUgpSenhKolCyJqRQL9Pb2kIQR9VoLWxYwpEO9XkUYkjSLCGKfVjVmamYG07JAkjdtNpECnRoethA5+ZASenr7kIbLjqlJZmozOH0h1WqDULo0WwGDAxXSNEZZJsJRRFFII6xRkB6mK7EMC1SG39Q6ErtgYDoGAoVTcqm16hRLJQaWD2sdg60wTZuoqV8gMgVCSWQmMIWBKSVKZQRJRqo0eZQGKKGFvFmmtR6mZYEQnZ5eXfzPMefpUZ2Ck/V6ncmJCWq1GnEU4bhaIpBHgDrhKJlXu9dkAFCqs147BPPoVHBpyEd5HHRCAkrbRyHbSjvybFDtpYrSiDiKabWytpOHLMt7qrm6VpnrOnlLjfY+tQfEqxRwHZdiqairz8dxnnGlva5BGOD7Le09rteo1mrUalXq9RqNZgO/5efC8lb+uZVnlAVEYUAUh3ll7RSV6eSKNNPeJ4HQ1fRlnpWWzhVs1AJlo5OC366Wro8tJYriBbWI5lqSzBHWtnBaGjocmSQ6u60dzpvriSc7+pv2uWt7fAwMbQMFuVdKh0vn6h+1W3vrsKPMCbIWrKd5f8AIP4C2Mr1d6VuHZtvEWofJLNvKdW1mJ/zn5iUSHNfF9bT3v02w2jfV/FpWugioDtX6fkAQBAR5Vm2cZ2XrsF+Se+20R6gt9G4XJtVRiJQk8fF9Xy9TbV2VnHeN5ryeHZJl21iWvnY9lYq2S08x9ivSFAYRQwP9KAJ279lDtVZj7dq1lA4pouIUmUGr0dIakEKBYtEjCBpM7BljbGwXqZCM7dpJLYyIsrYuSZLmNTuE0Hr7Tkbp40AIQblSZPmqlRyzdi2/+PnP2bljB73lXlYsXU6z0aSnXMH0bLZu28pNv7yJe9ffy623/Yo94+Oc/oIXcM6553LowQdrAbpKEYbJ6PLlNAKf2+/8Ndt2buWWX/4cIxOcdvopbFy/hd7+QUwMxscn+Y//uArshDPPOJPjjn0Oo4uWoTB01eBqjWarwczsLDf97CZ+9vOf47fquJYkDptkKTiui2dbOJaJaRUpFEoo4MwXns7o6FK2bdnFf1x1Hb++/dcUekZYfeRpnLr2SKQ7hOUNsH1sksw0kKSIOCPJQtIoxVQ2k9M+rUARRhl33HEX09PTLF26lMWLFhGGEZbj0t8/gGWZNAoevf19HHXMMRxz7LEcu3Yt/QMDLF+xgmKhQIZkdmaSYrmM43nUarPUqlXSJEEaBrfecgthEDKyaBEoRavRIPBbeI7F6KJ+Bvt7WbViKUXXwiTDcy2GBofo6+1n0+ZNLFuymKmpCVYuXUJtdprZ2Sls1yRTZQoFhyBq4vsBzdoUYRiQRCEN22JmwiFotmg0fSqlYWzL061L0JV8lUqxHItmK8C0225jfW8Ztk3JKuGmNoYlcWxT6zmEyOu2SLLMpadSJvKbWKaRNwrWb6SB38KSUO4rYAmFyhsMx0mitVlKZwAqIM4SpGPqGmEqI4pCiq5HFMaYGNjSJpUpnufRajV1wU7TwjG0xzJJY0xpYzkOUlpkGXhusZMsIaQgS6HlNzHkU2+kns3YlxC8/d/5+pl6vc7k1FRepNLHMB30pDs3yesMJxPTpFNpen4PMcPQsoIsU52JR4ebJArIMp1tleaiYP0CamFaul5YJ1WftrdKLpiIdYHF/BgUJElGHKekeR2yudYsOhvMcR2aTR/LsojiqJO6rvvICgxp09fbT29PH6NZRpJq70WS5BliWUIY+dSqVRqNBlEcEwQtWs0GTb9Jq9mg0WzQbObeqqYucxBHIX4U67C1SjsC6HaFbV0YEkxDdcI8mY4raXIkjZyktMlP+2LmQle9qs7wEyCliW0b+eQ/P+TX7tk2d31Ae/WyXIQ9d588Wu8GmovqbgNtja1SqtPLKVMg2tW5hSJTgtyNhVRG7vnKCMKIlj9HGvU9pHVDZk5G2s2cLdPMPY0mruPpwqYdr6PWKJVK5U5oeL4Xrp3U0Gq1OmUU4jjG9zVB1gVXo04x0XaItn3P6ENTxCpBZVHn/Akp8wiQo8mdrUXzWao1pk819ivSJISFMCT1WgvHdXnpOS/lU5/6v6xcvpK06UOUsHnzVsbGdjO6ZAnLVy5nYmKMn914Pffedy8pgu07d/Hw5q20gkC/GadZ521CS5xy92hbXDjPJTs3Ds3os0zx61/fxa233pZX0k3ZPTnJAw+tZ3JySrdkkQYIxYOPrCdKtWDWcSy+//3v8aMf/YhKqYRt6d5GqUKXOpASs2CTSYXtmLzpDRfz0heew+X/8FXuvO1ePMult79CRJNGOMumjZv5wue/SKsZYVoeSapY99B6tmzdCCJGSknBs1m2aBGHHrCGsZ27mZ2u4ZgC1wTXNgkSA9vymJqd5B3vejuNWR+Jw8jigzn2+PNYc/jzUM4AY1MxrToUyoqgkVHpK+HKBE8YSKWwEpM4UPiRyfI1hzG0ZAXnnP8qBvv78mafMZZr47daRC39FpEonSLbrNUIgoCVq1ZTKJeYna1x152/YcuWrezctY1Vq1YxODzEunXr+M2ddwJQKBSYnJzk2GOO5fknPp/evt5crwBJEFD0PJaOLuaIgw8mDupMT+7Br80yHjZpVvdQmx7n4douZmemEfEs1ZkZZmenKZWLSAOkKYmSCCEUnueRqQy/FRH6EZQNioUepFGg3vCxXYhVhlf0sCyTIAo1GbYt7ILu/5SmKVGSkEUpjmsjpcgLDKZ4rk7NLhYK9JUcMl9hpZKi6VCyPGKjhcygYhdpNn2yTFDqqVAwPaQ0iNOENMkwpYU0beI0BsC2XYQUNOtNPMcjbAZI4ZA1QhphA9vxSNOMHqdMnGnDmCYZQZLoSQMJ0galiEIt+iwUC4RRhFKCJIxptZoYpknBKf3e7MH+ike3qXiswpbzlzWbLWZnZmg0GpogJBlOy9fEpV1Xx7LyTKu8V1l7gpN6ojfNtoYlJy+5d3POm6K9D2qeF0JKA0NKhJK6Bt68lHtp6N8I5jxX7YSHLNO6lkQmKJV7uZJEE4k2QbEMXKeBaZiEsW69Y5pmnjlm43oOnuvklfVtSk6ho8XRDcJ15m/g+4RhqJ+tWFdBDwKfZrNJvdGgWp2lWp1hdrZOq1HXticMieOQONfuRNFcfSatPRLEmUJlcV6VW3tUDcPCyDPmpNEmjLmuqCNaz70m5B7kfP12mE97aOamlk6YtO2pklITw6zdtVQseItvh/x0WDYfSyfcp3dBJueyzDriqTx0aLTJdN7OpFOGICNJ5tez0oKrthZufqFV03QoFkpYlg71tVvGeJ7XKT3gOE6HbLV/lyRpp3WZbrsT0Wr5eUHWIM/ia2f+tdv95F5VRS5GTzpFPtvidaV0PaskThGkkIvTBf/LSVNv3wBBFFCrVrngda/hc3/7eUaGBsniUMejTVh98CpWHrwaqbRQcGTpKK98zWs4//yXoWxdtl86uq5MnKa6A7SpBYVSiDwOm3UyHeIkRmBgW04u0MtQZGTo9gbSNJitzfLwww8zvnsPK1asYGLPBB//6McY27WboYEBkLph7K6xXSxbvpQ1a9YwMznLyPAiXvqSsxkZHmRsbCeGaWA5Ll65RCuJEKbiBz/4Lpd/4R/5xF9/EpFIRGZRtEuEmY+f+aQq1m8dSAQWCBuEQSbAMB0s06avr8SRhx3EUKXCPXfdRckpsnTNKsrlPrAdSn09zDSqjIwuZsWBK6k3Z9mxfYzDDnkeLzzzVRR7lrN5V53xiQamVWGwf4BKpY+kVcchYffYdsyeIiMD/ZRdl6kphekUSLISv/r5z1m6ZBEDfT1s27qVwG+xauVKbrjhv7nv3nsgyxgf34NpmRRLBYYGhhBScO+997FzzzimlCxevJhzzn0pQ8PDhFFI4PuUy2XWPuc5rFq9mmq1qlOxG03SNGV4aIhSqaQ7klsWrVqN+++9G0ukZJFPlsa0DEW9ZmLbkvrsDGXPwRSKLIkoFz0G+nqZqc4SBiHS1A+e74dIaeB5JQwpiRNBPY1QQpKgKBcdRCxIRUYU+ggJvX19+gFPFSqvraJSRZxpLYYQujmwY1iIVNGq+ahQ4tdDslDpgqvCwET/85sBsZkQBDHVWkKUWIyOLsO0LIIwQZoWluWQxDrEoeIU07ZxCy6u7WIKg3owi907gHJSJBLXc2i1WkShLkGQpinNwCdRWZ6dY2p9i9IlP5Ik0W+nzRaVnpJuswIUix6z9W6dpt8Gj1Vy4NEZdWEY0mg08VstavUqabWGkGZHq9PWcbRbarTbaBi5h9IwTHp7e3EcV4f98to/7cnNNE0dunYsQHs7TcMCJFGkyNI5T1OWqTy9Xs/IaaqzlNNEl+NQ6D6chmGQxBYC7TkReWq60JXxco4hEEgsaRMnMX7Dz0OO8/RZub7HMA1MK/dq5MTJzGu52baNECaebVEu9HQIoxS6wniSJoSB1kDFsS7i2WrUaNRrzM7OMj09w/TMDNXqLPV6Qwufw1z8HCekmd5GFKWdTL6OEDv3EhnSwDDzayG1vkuRESUhcRCBSvOaT0YnU7qtIUJp+5CmGalAE0vR7tGm3Uid+0IpVKpD5Fl+DtticsvQnkPyUK12AMwXtutQ2rw7Dsuy8sreqhOqbGuY5u7FOc2VUoI4CmlkIITf0XK1r9VciNLsEOn2cs/zdKZcfs1c16W316Wvr79DvHVIWgvJwzCk2WzSaDQ6lcTb2X1RFOYhQF1CIU0TokhXQpfSoLe3l0LhqX+J269IU39vL7XWHgBc29UGJs1oTkyTRSFuTxFpSZIghlR/Nz09zuz0OKNLRtm6bTtDo6MsXroCt1AAKRCG1WHUkAIOQhi0anUaM7O6VIFl41YMvGJ+AUwDskT7Qw3BooEhBo/vJUkSLMMijmKOvvo7pFGKSlNUrvvYvGUTo6OLaDQbbN+2i7AVEvot1t13L+O7dzEzPU0YJfQNDbJ5+2bCOGDzlo3UJxogJJ4tsLAJVUhsRuimBA4Lg4laU9Lm1wqYmZnill/eip3f3aUsYydbEKZNLAyUaeBnMQEx0rEhVZz9kpdz0Jpj6e9ZRLF3EOkOUKmFKKNAsdyH32gRFmNKns0JhzyXYrFArd5ABsuo9Qh2T8/QXzG4+tprefD+eygVXFrNBsPDQ7zuwgvp7+/j4IMOZvmypaxcuYIkSTEtk5HhETzPZXpmhmK5jBCCnr5e4jiiWp3FMCTPfc5aBDA9NUWz2WTFsmU0my2q1SpxnFfJjUKkITGFwLEtiq6LbSmaswEFy6Gnt0IQNDEMsItlfN+nOl0ljmIcVxtg03KwHBdhGDq0hZ7QoiRBxVpTJU2HLEuxXItm0CROQgqeh1ew85o6Nd32IdO3l22YGLaNzH34hkS3P1EZYdPnkBUH0PIjttlbUZ6FKS1UHuYwbO1JkIZEmQYybiIcg2K/bkwcNrWI1s90o127VIQsJg4ikjBianKKvr5erFxf5ZYKpKJJM2wSxAFZCI5n06jX8cpaIO4HPvWgjlACQ5gITAquR7Hs0YqaxEQIaRMlAXE9pFzpe5qtwLMP8wXfj0WadEaTvq/TJEahhbuZysgy/bI332PR9hBIqYX8UhoEgY/jOLSr5GsNkZ2LZ61O5pv2dph5MUMby3KRwuzodgxDYgpjgedh/r80TQkXhFb0uqY082rQuhZaOzTWJmKGNJCOi25Wm3UK1UJKLBI9EVsGlpXloUgD08iI4wwpg87YdIq+Di11CJZp4ti2PhcFxcjQIiSKKAzwfT/3drRotZr6b79Fq9Wi2WzSarZ02xnfp1arEwYBcf5Sonv86YzpLE2JI0VMDCRIqTP2UhIylSCkwkCHOjHoeG/aep4sr+Y+V2ogzxLM9U9Szrfzmry0a0u1mUuWppop5V4nPbXN9bebP1e0PUppXu8v3+qC30khUPN/l9emU0pXM2+HDdt6pDbm67vmkyZ9z7l56QRdNbxNsOaKYrqdptVKqZzUW3MC+rwmVpLk9adC3aev7fBo1+crlyvziOdTh/2KNMVxgt9s6hslj49neR8wKW3CZo0wjTGkiWW4WNKkt1zBdUyKvX2MximFUpnZqWlSoNRbwbAspiYnkYbQrQOSmJHhRUzuHqdeazC8eJSeSgWpFHGtShhFCFMLBKMkwrQtipWSTvE0LO79zd0MDQ2xZHQJmA76lUB33F6+ajnSMoj8gIPWHEwcpURxTBT6xIGvq2YLA1NAM2ppN2UaEwcBKhaYdi9C2aRxRmJlJNJHGHkn7EwhpYWQJlknBRWESlFxSFirEdYapI0WD92/jj3jM/z3z37BWGuWNFRzj5Kv3eSeU+CWX97Kxo3XoKQLZoEoATCxHQ+VxBgGHHHYYfziBzvIMoFX7GHPxBiVIgwPvpElIxaf//z/n4ldOxEqQxoGjusyNDCA32wyOTFOb28Po6OjtJracGVZiu8Huj2D67FixQq27thOrTpDHIVI2yaJQrJUYRkmnuOSxqnOFMur0BqG1EJwKXFsmyDwURUXA0lPuYgpMuqzk4ztHmP16lUYpkUYztI/MEChWKDZamhipARRFJOqME8dNkFIDFNPMpZp0S79Io2Ull9HCIhTA5nqiaxWq1GpVHAMGyUUQqW55gBQma4VlSlElhK2fKrTMyAt4hharYjAD4lSKID2LEgz19xl2LZBlARs2LQR1ytqV7hlEscxAoVhmtjSRYURcRjhWEUKXg+xEzFdm9XhBtNCug6WIfHjkJHhxfhZRCozwihBWCYHrT6cgw88CMsw2bp5Gzu2bWemPoPtWQgpaTZ1Jl2h7GG5T707/H8DHt1KJcuyDiFpL9NNdXW7ENd1MPNCk+3sqfaE2/YQ6CyldjhOJw1o3cxc/Z+5/xod7dOcFkriuSX6+4ewbQ+BkWddWdi2/m+7SrQWTevftEW/WVs3KubrhfSUkyQJKtHeKZU3m7by4odpphbUDHp0HzNdPNHsVElvNf18AledEgoqDyG2vVvtprpSCBzbZGR4kJ6eHopFl56egY4Aur3NJEl09l6zSb1ex/d9rSubnKTVauH7AbValVqtTq1WpdFo0mw28PNQoW5No2v62dJAGDYYqpPhOJ/E5HIpsnQuKy6TOrMuT1qcFyJbeN9oJ1CbJOnmvpnSWYOCvDAzbUmJnFeeAbRCUrdVaRNckSdH6V5/ufBaiLn5YY6fdRboawOwkKDIPDlhTtOUEYY+09NTC+759gtC2+NVKunmxrbdbtlSzMtU2LrHYTvcJ73OudOlKLI8A0+ftyjQob+nGvsXaYoiRBQjAFslCBR+q4YfNPKGkhIRx6RJRq1eJ00SKj099C9aQhgEDI2MImwLP4pIsizPaBK4ntNJUUyjCMuxMSxTN1gslbArFeJmQ2c+uTZKgWnpGkdSStIoIUoinEJZu4s9j1ajjiGbWI6NtEyyNKE2O4Mgo1jpoTI0qFm7QvtRZS7SUwKM+U+GLqqpfqdLpYAYohjVCti1YxeN2TpnXfBqpmp1QhIiMlKh3b6hH3H0UWuZna6yc8cE45M1mq0I2/ZQSHZs38n6B9axe2wHI70HUTQL7Ny5iyMPWY46uMRXv/Yx/vxPJSefeCyL+npZPrqYnt4eCqUiXrGEKSWDfWUOPGAVCN2407ZsBgcHyLKUIND6BCkE1ZlpbClxDBPTUpgGhM06vh/heAWGh0cYHFlMtVZj19guqtUqtoQwSfCjgEKliClivIIDaUir1SQNtWs+CpukaUQKebZRhh/6VOt1EjFX9Vcb4gwpwXEs0kzpFgZZihQGcRzgeBLTkBQLRYQhtLg/S3BsC8+xMYWh4/NJpt3zRl5LBKUrsicRQRYzWZulVKhoIqwyEgSxglasCMMUr2CSpGnHWNiGyeKhYXr6epmamcKPQqSp6+GkQiAtG8f2qCxeysjQSjKRMVPfzc5dm8nI6O8fZGTRUly3zLbtWymVPTZt20Zv7wAj/cMksWJmKuCmHfcwM9Uii2JWrCwyMDDAbK2KUgaeo0Oh/UMFdnUrgv/WmC/+hjmCNP9zm5AopbRkIJJ0Wkqgp992BtVc+4jcU4GR/y4ijtpERnZ0m21Rt578RF70UqeAm0adeq2ee5rIiYgOl2nPQFvH4uF6ukWGzr5ysS1H9wjNg3FGHjrKdBEitD80F1hnGVEQEAVBTnrIs8baqh6JFAapTIjnecSATjVrQ855NbQYPfeKoCCFRGUIBGkcsmPHGLvH9uiSJFK3b2lnYc1v66FF6y6O5dHXM8DKFWsQQocpkzghyov1BkHQ0VA1Gg2aTd0MudlqMFOdpFqb0WG6OCYMwrxvXEAcJ3mWHx2PkJmXZ1CZ6rRiykTW0S8tLOYpOsQIYaBETn5oX+M5IiPy+0JzFDG3VBqY5tw9OJdE0A4X6yiMtjlzhMu02qSs7RFrl05oR27nsguTJF5Ajhbq99rlGhS+rxO5pJzufK8THbQ43ZDmvNBzOwyYi9QtA8tyOl7I3p5eevt7fruH8UlgvyJN5UqF1nTKGcc9h5efcxaloiSNA6QjUKaBgYMgIUt9LbYDMimJlaBabTC8ZDEYEjdPQ8xydlrprWBYNuTF+ZPQxysV8QpFLDKmd+8iCkOKhYLuo5SmREmsFfuOQ6oUaeQjPMXSpUup9PQQRRFZlhBGAVmg3zCSWBf5isOQNMnyOiGCONINLguFQscIGoZBGseddOAoatBotrCkgVcuIYUkikMUiigICVthnkavw0KpUlh5o0kpBVmqPThplrLkiCMBOPjkE/d5nuM46fQYAvCbCb6fYNsmSim2b9/Dww8/yNZNmznm6KMpl4pMTI2zavVKTEMQNh+h2Zhl5yPruX3LZpYuWYybV/q2XV2YTAlJoVgkTRWzs7OdFgHlcpmeSg9SCPr7+tm+fTtHHX4Etm0TRy0c06V/YAClJEoYjE9Oc8uv72TFqlWsWLGCUm8PtXoVw7HpGRigb3iAPbu24JVdolZMqVJkaGAJnm3SqNXo6+vL+2DptOhWFFEWErfg5RknWrweBgFBGBInuhlqFCZkaYbnFUizPHwmde2vgutBpoiSkJ6eihZ6RzFpkmLZWvwdphLbsTEMk0yA29PDyLJlxHFGT7GXZc2AUn8fwyMj+IGPaRjMzE5R6fHwWy2SKEJSp1wqs3r5KoYWDfHAg+tIZmfAkCRBkzhLsJwCruVSbwYEwTQ7d03g9TQREvoHB6jV4cGfPUJPTy/jE5OcePLBWE6FUnmEwcGVbNmyhxtuvJt779mEyDKOOmSAoaFD6PUqgEkYZAjLpqdSwrUNNm3e+nSbgWcVHkvT1GkBlK/T1isBnawiTRryyUdIJG3CA/MTV9qfdb0l5nmwRO6dAkh1KENJRB72SVOIlCYEqsNk5ode5kIqrufh2LZuteI4FEtlCp6Xewf0C0I7w2pBqxjbxpCSNJ8025N8W8DcJoaCufpAadqe4OeOUwiRS3sEMvd4tD1Onb9ThWFoElCdrRNHAXOekjZ5MDttSnT6up2n2evq/aVSCSlNpGni2B5lOef90Wn3cd6qxCeKIvywRbU2Q71RJYp1K5PADzqhwGZL69SCXD/VLuIZBgF+4Ofep7wMTpK3N+mIusHoiNHzOkuaPs/dF53wnj5X82+3NnkSwuhk7rUJ9fz7Rq+sX/Dne0Tne0IffbeRi+Dn15/SwnU6QnSRi9lRBu2GxFma6fZNeQ0qrY/L61CJOQLVDl+2SZNl5qUHbLtDsNo9/J5q7FekKc0UjaTJK1/9al7w0lfh2BHK9ChW+knijDTMyASYtsvAsIdp6VoZUZIibBtsm7DVRKEFrvotP8NQgiQNMWxLZ1P4gS6W5RZIoojW1CQqU4hSiSTVlW2jIECZFkrolgW6SqzCcV1UlmkjZ0pdsC1o4ToOA4tGkNIg8n3CIECYGUKaBHGEAArSQKI9aqZpkcYJCBtpGiRxTKPZoFgs4JBhmJLI101fWy2fRrWhXeXSIGi1SNMEUSxocXucEYURfqtJKwxJgWq9xuBAP65XRJgShW4bI9A99aRyQUjSJMU0BH29rq4BksJhBy3lsCOWzl2YhLk7SWVc8a1vEMyOs3vXLn71q5tJ87GPT0ziBwFTs7Ps3r2HWq2m3bYKDGkQhBHlcpnhgUGUUixbupR77r6HerVOf18vjfoMhgH9/f0MDA1zxFHHsnX7Fj7xfz/JqpWruPTtb+fAgw/EMLUHJwPG9kywbt1GZCZQccBgb4XevqUUrSJKSiarWnBpewXSxAezSLm3QqFUwrJNkiSk5Tep9OkWAGmmcG0Xx3bzDBiFkClTU7to+VWqtRqV3j5M06Ve9zHtAo6rK9bqzuqCRlPXn4kDsB0Dy3Ip9w2xdPkK6o0G5WKF/mZIKxsHsxev1IeUkmQ2ptqQDPT1kkRNwlAQhHDPPZsobZmlWmvS2zeAV3aRYgbblli2w7btk9x+20Ps3FZjZzXmBc8f4tDDl3LYYUt5ZMMm/vFbN2BkChc47bQTOWDNMYDF5i1T3HzzA9y7fjfC9jjh6KUceXA/s9Mz7Bnfw7JlKxBGQuCHIPqozqY8cP/s788gPEsw36v06JBFG53aPxIW1uiZH05bmH3XTvFve670JGMt0Iu0w18wJ+QVQmiNkaFFw2lee2vue5ln3GVEsU8U+1Rr051srfaE7dg2nlfotGbRXqkCjuN2MqxKpYqe2ITCcVwKXqmTjt7OlGprnuY3a+3IUGGezkZnfyUknRBlOxNa62lsbMvUjn2010RnwbXF20YuQlYEga411B7L/HM0X6RuWia2ZefZY3NZY729feiebAZewcV2Lf0SluuWoijG931qtRozM9NMTU8xNTXF7OwMzWaDmZkZZmamCUMtcg7CgMAPOqE/XSyzXSBSYhg6nGubWouGyDr3R/tEZdlcaK19HvUxzXnt5vdXbUMf71z4dqFHNOsQnLn7jQXbFVJgmfO3n3QIWCfkKOZqjOn+dm1BuU5M0IkIc16uDhFDEcchUaQeNV7Brl07OyUynkrsV6Rp48YNQEKq9EnLMkiVQiSJbu5oSBzLQUhJkuhCWyJOkJZFpVxGRRH1Wl2nhKMzSjIUU5MTjO8Zp39ggM2bN9Pf30dvXz9GrY4hJb39/Z2u05j6yhYrPagkoVWv02zlPb+kxHEcAOLAR5oGxVKJgufRptkqy5uberpOTuAHuvuz52kilWZYnkeaJtiu1izEQYDtFVje16fv/DyUV+ntQylFuacXlkiyJCWJQ9xyUVdHlxIlJfVGA9KMwZERwjjm3rt/wwPr1nHueefqrECVFytDaPGkZ+ZaHUGU6HAnrk4f1b3+MtIw1YdkCsJWC5FlWLaJYQiEbeD19bGqp8KKQ48EFYIwQOWNN3OZ+qPmhQ4EulzA7MwM/3nNtYwuXsz4+G7GJibZvHkjmzdvZmhwiA9d9mFGRkZ4ztFH86tf384XL/8CzzvheSxdOkqlUmL9+nXcfc/9eHaJhx6ZxrMMtm/fw62/2kIShpimIEjBkjG7mzFrFpUplx3cQoFVq1aycuUKEA6tMGFk0SJWrVyFZVpkmaJUKGKYFhMTkwgRY3suvl/DnPj/2PvTZ02y+84P+5yTe+az3aWWW0tXVe/oRjcAEgsxnOFwqJFnhpalkDQeWeGQ5Qi9cMiO8L9g/QseO/xCctiWxxFWyJY0MofD4XDTEARIAAQHQDcINHqp7qrq2u72rLmfc/zid/K5t2qaJMABQDXBDFxU3a57nyXzyczv+f6+y2PibOq1GREqDLn17AvcevY5lgvRRHRtx2ZTcnK68Ks6zft3Trn34R9LoGqS8O577/H9792mKu026ffw8ZJjC3//33idSe44njdsVg2/95Uvc7jp2Z+l/P1/+29z+eJVkmDMeJrxwZ07/MbvvMH33l6Sobi8u8PNGy8zLlKuXH6el58zXA2hbCFNYm7deJVVOeedd9/ny1/5Dr/7je+hnOFvfuFV/ie/9LOMk46v/8GXOX50xM3rz7Izy+k7S55NuXPvmPv3fqyXgL902/mVO5wBqEFbM3x//u+DQ07hcFbOwyAUN5mzMr4TdxYyyhXkIHUqztK1zVmwpQ63Fu6u6+Txh/EMUjlltzdg59lmaJqaJE4IAi3VRa1UsKRpRhRFwr72FipDXWs/glGiN92O8BIfSijJ02mWk6VSPhtHEXGS+iBYGUWGoUgm5DliSRnaisYHNCDXlW2dBxKRYPoe03c09QalRLM0BPJZ5/vl1JkjbhgNnjnWkH0YCeJyg+6qd5jGiH4mLOX9hD5hWyt0qAnjgCQR4Bgnks4dBAFZWhCFCePRhEuXLtO1nV+US5ZaWUp2UV1VbMqNH/ut2Wy8QL3cUFalr6JpfCGvaKoGPdgQ1iwgRUarZwnpQxWNH2Uq/CjuDJxseSQl0RJnoMkhximQ9gOJr3DD86jzn21hCeW4eM0VA9j3z7N9TCU1UVsNm6Lv7TmQeMaSDg/vo7G295Mtu2Utzv2UJ4KTjrlcjNi9eIDWjmqzoTU9m+qELB/TN4bAgnKOo8ePOD09IYoTprMZWVZw6eAie3u7qEBTNzVdueH4+JAv/e6X+Oa//Je8/ulP8+Uvf9lnjYBG7O7PP/88L7/8MleuHBDFIdmoIAgTtFPkoxH5dIrDsZ6fkhWF6J205KegNcqftH3To53XZoUBXW8pNxvMaiVM0mrFbDIlzTJR/49GrFYr5osFxXjE3t7emY7KGJQ2bMqSrusIwghnxO2QJjHGWuqmIgpCcJa2a1kvl2R5zs9+4ef4zOc+T4BDRRGb9RpQ5MUIOgM6wA3zaeVQ+oxmHWhXERC2rE/XvPv9t7h86SKXDy5SLkpGo4I0E2ZKhz2L4xPCOCErRgT+99HBln7VQQDWOz+UnMhxHHPx4ID/5H/3v0Ubg9XgmgYbaKrNhvfffou8yInShH/4f/mHGNPz3u13+OrXvsqb3/kWFy9e5IUXnuf/8J/9Z7z4/HN85vVPE49m2GbNe998gze/8x2SLMH0hjt37vHf/ff/mA/vPeTRSQ0s+ePvP8LxVQCUlouJseCsBzEKHIrWOYpQ80t/+3X+3t/726xLuP3BA569dYtnX3idr/7BH/KP/l//OUeLjs449meaf+ff+iU+97nPMd45wBj45pvv8p//l/+cRIHTCkKNDoSO71tNjO+VigJGVU/ddFjXMZrssruX8+Dwe6jVks2i4oMPHnFw5QpBBN9/6wO+/eZ3ODxckkaaa1cu8Z/+b/4T4mTF9773DT744A5Hxwt6qVEkDgL2nnmOe3/4VX7t13+Lr/7hd1kuaw4O9rhwcAEVWpbrE27dvIZG0VUVTbniwsXLFMWM77/9BvMfve7yL912nkV62jF3PgDy/M8NbIfy4Gc7gnIi9tZK2JPeGs9UywjD9K3/HVmoDMyS1kpG3M5hjcH6G6xSChXK2GcY1Yh13XoWwMcOmIYgiInjAJBydKnk0CTJoLtzOHpMLyClazufuaO2eqG+l+8HoW8QSExClmZMxjOyTEAYQBwnFMWI8XhCngu4stb5G/758ZHcgIfxFdZhnFSQyDjSEqgBGCJAyzgMQ3AnnB/74Rl4rSVxf8i6GkIvre9aa2kYbuYKGRdaNQizNWkizFte5KR+ETowiFlSMJvE3oofCVvupHS5bRoqb62vPUjalCXr9ZrVasVqtZIAz3LN6fxEUuO7bvt7TdNsnWVPxBZsP5CST3i2nQF5NwAlJwdf+/iD7S+eYz3B7+/tmPDs8y7PfVY1s93xnH8s5/f9oIuyvmRYgkzlsGo0Z2zf8FjbobQ6B+yiAGt/yt1zcRDy+PQd7j94gLWafDQmNC2j6Q5BEMs83jroeiajnLa9igPCJEGHIVGR09U1kQpI8xG62hAozeuvvcYv/0//LZI44pf/7t9FhSEqjLC9odxsOHz8mHffe49f/dV/wjvvvo2xPVcvX+Pzn/scX/z5n+fyzVvQteAc6/WaNE19BYis/KyRE7vve5I0I0tiQJEoTTEeI6yZ9RUEEbbvCKIYrTU7Wcr0wj7WrwTzJAag7XsePHjA48ePGU8m3Lhxi2JSyKqx6yDQ5OlIUhGShNOjI07mpzwzm2HaltOTE4o8IyvG5KOxRCO0LUpp6rLm6OSYIssZTackceLHgxsCrcnyHJ2GRIkmK1Kmo4wwCgnjkCzPPL3qVw6dZTSe4YC+7bFAFGmsleBEHWjpFFQBKgyksqYt0TogizIiL3wNVYgqCrkQas0zt54ly1OiNCaMBRiOxq/y4svP0Xe978IK6TtDEsZkUSrxEkHCJ37mizz/+udEx2R7eqf4j/7T/z193WDskKWiZXVkew/+Ik5O5/zxG2/wR3/4DX7zt/8F3/zjt3DAAvjVf/oNfv233qQzPc5asvCrhKGmrDvKppOaNuDhI/j//uN/we9+5Y+4efM6n/70ZyQJ3BpcmvDFv/Yq/+F//B/w3vvv8/j+Ia+89EnqdcXudEoaab72td/Hupq63jCZzAiCiJ07d2ldyXO3nuXf+g/+HV555SX+xW/9Jr/zP3yTt96+R2dgOs757Oc/hXE133v3O5TdChVDnAcUU3h0Cq7vufvuu3z1D77E2++8j3ENP/dzz/PX/8bnuXRhxKY6Juo3GFNx+fIezlq5dSvDh/c/5P6Hj8kUzP8EBvGvtj99e1oA/qf9XG974iDeMkbOOdHc+VZ6Y8VD27at7+QKhXlXijSVcl1jDHUtgYKDEWYQ6eLPMwFXw6iso6okC01rhQ4cWluSNCCMUtFcKUNZrbdZQMEg4o0C4ihF1EnnRMIIWGvaDU1b+QoNAWB5mpNmGWEYeXCovGYqJUlir5OKmIynYsDwz7V1+EWJ1IOEoWd3UrnOAL0P28Q9OfY7X2WyBRceHOlee0OIPgOUnlYZAkT19qatvRpbfF0gP9+2HcasWa82W9ZE+xFhFAUEYUQYCmXTGyNgLQgIA3Eu5vmI2WyPOI6f0J06ZL+t1wKghuTt1WrFYrFksViwXC5YLERCsFouKb1MpK5r2q7dOg8FRYr0YGDYBHhbP06zoAyoXv5EDEwSgeEBvV90i1TtDBwNrr6n9U4OK+ND/5kYRq7gwWp07r06zsaCw9h2q03T55ygevu8P8rtYwWaknxMuzEYJydeEMXEYXBWAAiAQkUhcZgTxJEEXhnD4nRFWldMxmOqzYbjoyOqzZp8VPDSq6+SxAkqCCimk+3qwjkwu7vsX7zIMzee4VOf/iSbciM0rRH67/j4iNV6RZqkXL3+DMYa380DKC+8U4rOGE7nC6ZjR5okcnBDOVlQklvadT3lZu0F5p5+tE7KF+NIftavjpIkYX9/n6IoPO0bnZ2kTuy2xvSUm5JISw9PFsesFwuK8diL2hOaquLw7l0ODw/J0pTnXnyBOEm4cPECQRjKisn1hLEmiDKfltYLVRsE2K4ljkVwjr94r5dLNlXF7t6udEplmVDVviC5N3Zb+tl2HV3XkSTivNGhJkTm2A5w1lCWmy3t75zcDEbjMX3fimarqdFhKEWdOFnJxamsbhKNViGrxZokUsRh4lekAVmaygWtbaFq2b9yEaXA9B1xIu5IFBgfQnnp2g1u3HqWn/v5v8E/+F/+RyyWpVSW6ADVdazXa5LxiFCBsj1hFNJ0Hffv3mUyyfmNf/7P+fVf/w3uH55wOC957/0TvvwHb9Eay9pIxMCXfv+PeeuDf8imKjFtx+7oD7DGMokj0BFNdYS1HZevXuTv/L1PMJrscfKbf8TjdU/5/Q84PlnR9z3vvX+Hw+NjXnzpKn/rb/11XvnE61RVy6hIKPtLnC4M6XhMmK0IMs3FsODgyjXeeecd3n7vXR6fbphMpnzxiz/DL/z1z/HB7be4c+cBF6cx63IJvQT4jadjHPDO7Xt883v3qP4KMP3Q29P6pT9r00pJa7wHNH0nye9t12KNsBqRZ3IA/5nusV4fJOM5iQoIg4AolPG8lLb6G1Iv+lCtFUGocM54pkaeK4oijGlpXIdS1vd9BXRdh7PGM1Vy49TaEeozRs0IZevt/8IGWOcQYj6iLDtM31BWlt60RFF0Tld0Nj50ToTBeZ6TJBlaSxJ6FEckcUIcp1JQm5wJusPAO7C0xA94Tmigl7ZuQrkub/kmGOTXTlg5GXM53OAeUyDwaLhRO89QeTec1iLm7kTPxFM3+iAYOuGESRwSs7WPTJCRZkSXJqSZxRonhch+FBjogCiJSZOM3d29Lfgc3Hp1XZ9FKHiGqqokOX29WlOWG9q281U00hlXlVI70zSNNFb4ehM5ho62q3GulUWl0h60DsdH9Heyz85nj/lyXs4WCWek1MDSiaZp2O9nWjp/viAkhHUKq/Aoans2IU6/87qnH+32sQJNxih0eJHxbA+tNE3dEEaS0l21FVGQEOlQ4vWNYzV0kI3H5OMR1WZN7Vchs709OWi9JYlSlLXYvkeHAdZ3/uhQVkhKK5IkZvfCHs7IRUGh6dqe5WLBcr6gbRp0GIJhCygcnvZUijCKGY3GxEmKtZaq2qCjiDhJMEY6n+I4JkkTgsD3kFlLGMUEvmk+DAKaptmuUEajkU9yFX3WsEqVEDc8uEqxnThtiqkUzGItaZKikwRHw6goZH6cJChkhRnoWADQekVdV6RZQj4aEcQhtutoyoowlkoBsCwXa3QYUYxHRHFMoUXDpINQBJcogjAi0IGAOt+vBLKS6LoW1GAb9qWgfYcONXEs+UMq1GAcCk0QpugWbNehAsBZYct6g9MGFwkw6nsD2pFkBZoAdIAKQpRz4LvgQixojRuyUAIRkZuuwzgpdlbgU741l69e5cr1G37F46lg5zC9EaDVNqzmp0gOkmZvtsPNZ29y69bz/K1/8++w2GwwvRT/Wgc6FMDurMHZFoNBadisljy++yHL4zm2Lvna1/+Iz3zmU3zv+9/nu9+7z+PFrxHGKe9/+JhNCcuq4b/4R/8N//if/DrvvnebR4clZR/x6OEpLzzX8C//8FtMdgqSoieIC/YPnuHduzUfHDvyoME8PuJXf+M3efvdu+B65qcrTo9bZrPL2Ks1y6MPKQrF/PiIvq44ONjj9U99mgeHp3z3rd8nCBL+wb/9d/i//jf/3U/2wvAx257OZPpBt4EHCAIR/CqlaNp2O9IzxpDECdPZlNF4LBEqXU/n2eemaVit1ywXC/qgIy8KxqN8C3YkuV40UMNNFxxRHCDONUugz4p+27aRQN8okmuUzwsKAuUZKykbFi3MkB0lN3Ln8NlIiZhj2kZ0TWkmsSI+ODEMIQjOXG+ymDVY2/vaD8Nms/T7NdgyzBIbIAGKUSw6ozCK0EpcVtPxjCLLt/qewYk1ZEkNY8IBzGjtYxu6VpxcyBfOCmPiBiCgfEq3737TmsgzUc66M+DBsK+E+RJDkT2LRvDsjgW6Vvr7mrphs95so1C2onTk6QMdkBVSY3IWXpow2hsTRbGANyWgdXu8+86HeIrVvyxLlktJST89PWU+n/ui5LUHWQ3GdrRtxWIJfa+lcPycxkgqTc70TYM5QP7xzMF3pkECpQVsncUWnIEmkXF4EKq90kr5oFAvnZDHOtO1yR9nTsMf5faxAk29g9c//TluPfs8dVWzWjzm4sV9gjih6hpcD4RyYvdd550IG57ZmVGMCt+1pAnjmCgbcfToEUePHnPt2efBBwyiIlQQ0rWNWP7DoZ7AtzErJ8+jFC7UjMYjkiSha1qWp6cU0wl9bwiUEq0OA7rWpKl0hPW90I/a6xOsOnOy5EUBQFdKsnVgDLEVANN0HdYY8jzfriK1t4pKIaaRpuow3CrjkiTF6M6PkM8+rW3X0bctWZYx2dtjpBTWdKyWSwoLVVmK8N0N4k/DanGCQwSZtu+Zz48xvWFvf58wCnxqcSv7N4/BGNqqIwoANIv5krppmEyniJ1f0l5FjGq2NK4w3g7TWzbLJVmS4bAS0+AzSUzfEwYCnAOEJo91SJiI9mgxn9N1hiwrCNJILkgomqZGKU1SZKClECdMQpIslhPQGnrbYXvlT0gLKqA33RY0ORtKT6BzxGHsX29AtdlQhCMIQ+Ispa4bNuuKvf0LFKMZL73yGrdefhWDX2V5Onw4NioItloJlMN0S07vP2S1WNGuNnz+c9/hhU9+gvffv83tu/dYNx1Wa37h39AYAzglrKQzfPKTn+Hk6JhH9+5ycjjnzW99l1/5td8hz0LGexlBonh4At9/5yFla6lpWbTHHH3lX9I0NfSW3sKXv/JNJsWIg0sZ9abncVkThiOyiWI8nlHWhu9+/wHff/eYMIxp2r8SNf24tkEnYrYjLPVEAGQYhly9doWf/dnPcuvWzW01RZKIoeSdd97h9u3bPHz4gNP5nKqspDJns9mWqIY+oqTvZXTXda0kWQeKOE7Y3d0lSRO6tts6uYqiIAxDuk7CHOVa1m/HXUEQorSUR4tGctAIKeIkJNABbVdLPIvr6U2Hw/hxj/NaK29T1z5s0TP4cqceQhv94sX2mLan62rq+py7MAhkwaU0h8FDYZsCLdfMKPSi9NinVSc+ekAAmPY6y8lkSpJE2wXqoDHjHBCWG71CewLEDgXDfn/Yc9e5rQh9OyYUkDCAHuUkZ8p6a7MxiO6U80+pvN5MUVaVzy+S6IQhMytJYqLQV7s4iUfoexGdJ0nChQueZTdnDFXbtluGqiyFlaqqmt601PWG45OHbDYL6qZks6l8L+LKC9hb3xvny58dXgultiwbuO09WQc+LkMPuVAS9umckfcWCO0nLrozbdZ2hLj9OAhbN+iqBJz+aLePFWgCzSdefY0rV66LU6yVNFmZ98Yop7xIUcY0cZzgELeaMYZ8NPLjFIUKlAiZN6XURZteqMjNhnyUY52EpgVIiaS4DKS12iE9QdVmjVIBk50d2rLi+PiU8e6uuDmGOa4/0PiE3+HPNM8IIp9sGgbCUnkdwiCO1kGwLS8EmcOHUYQxhjCSeASlpeARawX4+RRX5wPjcI4winHWYrpWxlFa01sjq0qdEyYxAY6m7jw7ZumtQVtFmuc+iqGmXC+8XkgTRgHSq9TSm54sy1G6p2tbiGLiaBBnui3Aa5uGqiwp8pzOyGvP0gRjDabvRJPmxOkjrhlD37UEoxHrpQgeR+MJ4+kUNdigfUGo8t1ubmCxOkNXtxRZAU6SaAMVol0gPVRWM54UaCRsMstigkDTmtYzjYN7Q2h40xmiMCCKYw/aZF+ZXioqcIq2rgm1JooD0X1FAnYnkylBFELnCIysMK1zoP2KHEXTNJKImyagHU25RkUzdl68iKk7rHG88InPkF3Y4eVXHtFi6ZxEXgSBREbIwnfoBDMsjx/xzpvf5vjxQwId8+lPn8h4NwYVwXIdMN29zN/95V/eMuTOGJbzOZvFgjgKUMbyja9+i0+8dI1x7nhw/x7jPGZvOuK99x/wxvfu8vVv3ePR0pCFDb/9+1/7iV0Nflo3ayyd67YaJBl5C1N9/foz/OLf/EVeePEFsixjPp8DwkLs7e7x/HPPc3x8zHe/912++8ff5d69e5RlSRzF7O7tsr+/T5ZmWGto2pbNZs1ieSqGFSNWcdNLno4ABmGTurajaVthqb3zz/pUb4WUAA/6nSgKPAiwktsTKrR2vpLkXG6SCjxDJVUxCr3NrZPzEz8GCtBqeE6HsWfgxPT+pun8DVUpsNA1IowWluk8aDoPnKLtGFMrTZKlXNjf91U0PMFOBf56HZzT0wxWffyiKpQ361+O24KHwYZvjXT6OV8QbI050+cMovVh8wyTjPMkHNlaMRzZZhDbD68v9JUzgccszoMjYfxmsxmz2UwytbQmjhOyLD83rnXepdjT9wbnDG1XMV8cUZZLqVparVkuliyXc5bLlfQkVpI/1bW9d/H1Hmy324J701upqHKii7LG65u26eiyOAgR0kJckudcjeDHqH7UOojQnUToDCzsj3L7WIEmhWJ3d48sL8iznJAdQq3BOpIowlnoOxHiaqWYTiY4RoCjqjaMx1MWyyV5lpPogCRJmE4mWwamrEqWyxVKXyDLUiLf9aWCQS2Frywx2/GSRpitodTV+qRxiw9lU9rrAYxc4IyT80Z78IWMm4aiSxAXWeQj/a0VjYJSyveOBdRVRYLXLlnrdQmhn/ErAUh9j2nFORPGiRfmaaJUSovDJCY+V444uGNG0ymB1ozDAOczLpyV95hnOTZzhFEkoWpxQt9bAiXvOwhDnJELVt+0KAVJmhKEAcb0FKOcrMiJ4gjXGEkL9nqLpq78BdRfgFSKwpLEkeSaNDXr1YLedIDFtQ1aadLRiBAZeZi+o2k7ATs6JE81oZYwUY3DmoYwSqg2CzYnJUVxk7quiOLEj1wtWjmiQISKbVORpono5zQEOkJWR5Yg0linUNavnJwjSSPatkbphCRJSLUm3NnBti3Wj3UbfyFp+x6lFPmoIE0z+q6hrkrC3RmOjuPDR4RByt7sCqbqyUZTjGoo4pw6jMnzmHyyA0FM77TcxDxO1ziU62l3p1y9dNGPcDSf/hv/Jg4lFvJhleub7L2yA5zl8ME97r5/myJJ2CzX3L97l0sX98hSWCwqVstjIuVYLk9orcYSc+XKAVcOLvLWe3d/MheDn+JNVtiKKBTbfqOFAYjDmAu7F3jlE69w6eIlHj56yLf+6FvcuXuHQAd85mc+w2uffE1ACZoPbn9AXdVEQcQrn3iFz3/+87z88stMZ1OUAxUolssFb7/zNm9977u8/f3vc//BA9pWgnQno4kI0L3uxRrLeDIm9aXA0kIvadfOX18CLdqXHrOtWwkDTag1vWc6wmiIQ/A3+aHI1nYYI6Obs3JY77xFbrJbYfCwcPUMh3y+B15XAJtS1rM4lrbr6U1L02oBPoH2r/V8LpPi7p3bWz1mEsc+QkC0U2makibpVkMlWqSELCu8ED3cjtasrzzpei/c70Xnpfw1vO87qs3mLEPqXNK58mKrbXRAIF9Ka4JQ9FqDkF2Ogdm+/uH9nzGUsFot6ftOFl9uuG8Ni8azxO0tkIxCorggTkKcu4D2FSoD+GvbjrpuqarNtnJmCPycz+ccHZ2IC7CpWK9Kcf5VG9pa2K3edvJ5YQh4FZH4wM4NC9otSFJPnx/y58CY/ai3jxVoAgmPjGPfE2RzMAbb9bRdh7HOU5Carm4wXQta0ZUlvTUkTcM777zDszdvEXsB3XgyRvkckzhJmE6VZC35VVTgR13WGhEBWq9b8fZ+HQRYZ2mbmjRN6dqGyOeXoAaxooMgODsBkUAuY+UCAdJzZKxlOp1ijZEuOy+ABDlBAq3pjNn2N8kq023zRrAW28snpu972trrn4baDn8hqNuWOE3RYUjXttiul5myF3LL48lJ1TS1b0QPtsJ0cCLC15YkEReO5G4EBE5WPH0rWTFxKiWZOEuaJigd0HYNcSLuD6xBK0egFX3XYvpWLhB+DFg3NaY3xFHI/qVLrJdz3n3ruygHRV5w7ep1wjzHGIfWMYFWlKWscgZBaG8NURRieotzrYgXkcyWe3fvsb9/kbZtmU13BAAHAcb1VOsNgbMC7rybripLsjQlyTIwUrXiEJdk31th4qKY2Mnq3tDz+PFDLl266Kl+SAhJUqnwCQKFsy1BIMet6yqapqSsNiQBlMGacr4iCRNOj45JRhFJGNC3DZvVKTrKMcSs1w0YS+AckQZcQ10uiSPNhf0LoEIuXLpOEMdESUxVbqSbLk39R8iPBTQ09cvMjx8TBSHlquTR/QfkWU42Srnx/Mt8+N5bBIhTZ7Z/kaN5xfG65vN//Rf4tX/yq/yf/4v/+0/6wvBTtQ0jIYWUkp/PKIqiiFExIslSQq9z7NqO1rUsF0uaumF3d9fHERjyLOeZZ57hF3/xF/mFv/ELXL9+HWsty+XSszrwwvMv8uytZ5lNd/jy732Z+WLB9avXefbZZ8nyjLfffptHjx5Rlhv/fKJziqMYo30XXdf6UZ3Ur9jeUJel6AWdmGqiMCD0N38F4lxVQ2VKSGBFwzS4pgYxjziPYbC0DyJg78V5ou9N+XFe4J3NcKYrk98XION6touhYQQkLjphzobxWeR1qYnPnUpTAU1xFBOEIVEUk2UFWVYQRcmWHRyE7GZYsALaa7KiMEIrPGjw3YAebJ1phPwBHxgrP9mwzqKs3K+sMzg7aITUOeA0FOzKOV/XFX3fbffB8HkaBPHb8WU8vC9h4MNQHH06jkiTVP49DresnzECjNuuk9qprmW9WjOfL2hbAVKbTcmmlN6+zVpKkstqQ1U1tE1D17eSPdWW9L2A7L4TDbDxY08BxwKJh1GpUpLvZM1PuaZJo7YWWRyYrsfalljLzR8dEEQiTnZWhMG96Wj7Vhrbyw3GdCL6BYIwlDRab19M05TRZIIKQymcPKcBcsbQNBVd1zCejDEYyrIkjGKK8Zi+60SzFPhRWXCu88dnNTV1Q6DPLgx4OhXYfsAG94Y1ZmsLxqN+ccAY0QFxJihtq4a6rDw69+nBsZy0Q/FkkiYypux7mqYhyTLZj36KpgO5uK1XK8bFmPV6Ix0+kQTp6UCLcc43bXdty3KxIE0zUHIhlFGZQmMhCL2YVE4YrRRRHDG4ZoJAGK62E0t/mkRYz0gpv4qSsk05nl3bUkwmRBE05QalAvKswLQdve44OT0FHTCaTEjTXKjtSJg6rRTVZi35J1FIMUpJk0hCc/34sq0aurpBx4l0EhKIXsnKjUn5nBssOKswrax+kyQCHaJwREkGVjE/PqUb5exdvABtQxgEqCCgaRt0qCjSDJAohbaraVuDDkR7onRPmGiyIiFSAToy6ARUAsXOiKZZQehwtqfvLMo6ur4hVNKXqJ0lVAbXWSItrBhK+jNc14ANaDZrjh89Ioojpjs7hGFIVTUcHj1g/8IOxajg4MplnHNMpxOm0wlV3ZCNCq7euMHy8DOYrqbvGyY7u6zrlrJpef6lV9nfnf0VaPoxb1rLTb/vpN5HtClykxpSpseTCTs7O3zuc5/jlVdeoes6jo6OePDgAdZajo6OqOuaa9eu8YUvfIEv/vxf49Zzz1JtSt555x3eefcdrLVcu3qVT3360/zsz3yOJEqpq5YP79/nmevX+cW/+Ytce+Ya3/jG13nnnXf58MMPeeedt1mtVhRFwWQyIQi0vzmuqaoS5yDPCy9ClnGe82MUrbVnpQ1dJyJlsZtHJHGKUiFhKP8m40HRPtqBTfIgagjmlf8/62sDzpgan8T4hJbIsh2PCWN1lgbunEMrLSDB28Ks62lbS9e3VKW3uwdn4nE1sLg6JAwkKiHNMrI0I0piuS5oYbbiITE9zYjThDiKyfL0DCCfG885NxToOv/+/KRCi2zAmF4AEw5nRait1dnIcNBcyXs7K0bWgdfZenei1sq72KDrO7q+ZWMV1hmCQJPn2TYPTKpLQt8JF3gx/iCqD0jTmDwfszPb5/p1OD9i63tD27VU5VCQvGS5XG37+1brBcvFiQdTpddOSb9fVdVbZ9+QPzaQE1oFdPyUM01hFLJcLVlvSqzd2bIiOgxIEmFORCOCABgH7aqhNx3T8Yi6qrh+/Rrj8Yi6qVFaMZ5NccpRVxVxksjJhPN2VJnLDymxMjMXhIse2AdB8UJfykrC9nZ7Qlp3ljLbtCJaJhFHR9u0bNYrwigS15x6UtxtrdiJlfb0tQdSA6hqmgacw/rVF+eEhpF3onSdaA3quqZqGl+xYihGI3HsJbFny5xP8RUQF8QRSRKTpKks6JwRrVPX+ZA6w3Kx8oW1PX3SkeWebWk7oiihqipQlrZtt2WTYRgyGo9RKPqmozcdUSyrGB0pnBPn35D1NJnNMF6Qao30/V24eIE0yQh0RLMWAWtT1RjnSNOMvMgoRgUqDDCmxTnD8dEjxqOcOBGgrDKNdiEX9y4yGU9pNzXr5Rqb9eR5QRAGFFmB0mBNj3bil4mC2JeXamxv6ehRkWZxumC6MyPRAavTE+qypKkkA2VnZ4Z1hqPDx4zHBePpGNN3tE2JU44wHhriRfSY5zFtlxCokGKcEUSKPmgZ7RWcnoqA1SmIgoAg0vRtTZbF5Gkq7Yl9g20NRTbFaocJRcPXdy19D7Q9YSy1P31doqKQtlwzP3zAONO0mwWj2RQdinM0HsWsmw3r5YLi4iX2Lh3408TSO0O+JxfGejPn4Orln+xF4adw0559sNZu3bRDWv9iseC9997bCrQvXrzIaDQCDe+98x5lWQo4V5LXdP36dV577TVuPHODpm345je/yZe+9CXefPNNAF566WW0Dnn55Zf51Gc+w3pT8sa336QsN+ztX+D11z/NeDLmhRdf4v3b79P3hvv37/P8c8/z/IsvsLMzw1nL6ekJjx49YL2WKAtjeo6ODrddbcfHR7Rts2XWRb9ifCij8SxavP03AQSiXVXqLF5hcLGdOa8GsOSzmHxq0hC/qBQyuPOTK+1EfyWjJgkAHW7wWmnCOPLXQ7m2C8Cy9MMoUF4Ew2rU4utLUERRTBLL6D7wYEIpKT8Ow0iaIZKUOI6F2c3zrUZqq50KxR04gJQoEiF9EPmFOL78Fy+o1hrlk86Hwl7r7xPCckliuDb99nnCMCQgwKE5K+D1Ces+nVtYyJAwcE8wPQObF4ahT3VPfE5YvGUfozjajhq11qSSAsN4JOO0rm1p2pauben7jrotqas1TVvRNDWbTbkFTUNCelXJGLBtWw/IhQlfr9ccHj38kZ5/HyvQ1NueosiJEhlthGFA21uqspQLvPJuJwum7en9zT0YHCTGMBqNiMKQxXyB0prJbEpnDIvFgktXDkSD5DVLCkHVSil0GJLnBUnsK0CMZTIei122a2malmIceebIr2KU3a4CwzAUy+tACRuD8amtAMWo2Io6tSgcpd/OU9rWg7AojjGdiJCrshJdTF6Q5SMwIpwegJq1hiCKyIKAuq5Zr9ZUdUWe5b5/0Q1mBi861BRFIWDFjwml50guTkEQ0BtL3bSEUURejEmzAkeNQ9G24p7rux7rFE3Xys1VKdquoxnqG+JYRnWBJiDAWmjbjsCvhMq6YnU6lwT0OEY5x3i6gw5geXrEarXEWXDdhlBHBHHEhYPL29XGZr0mTBMSnUmVA4YojHDWsVkuidOUYrKD7WXcS29oqoq+l5/r/fs1vSFMIkBTbUoJm9MhcRoQhjG2s5SbiqxQrBYnbNanXD64zKWDixjT0dUbqs2SPE0wXUVVrtHKCoUdyCo0iCLiNMEYaXlfrE4pRhnGSQp63Veo0NFs5h5Iiu7M+BTiSOZ9lMsTgr4jjiLqzVqypvIEFwQ0bYuxLU3ZUFcNo3REUYwIkcb3znZkScyVS5fJ04zDo0OJ24gjrJK4jCRL6MuWcrEk1IowDqjahuV6QTrKmc52mM9PWC8XP/kLw8dsezpm4HyB7dOVKR/5e8rfnLQiCAPfqSYO0aZpWCzk2hYEAZPxmPF4TFKkzGYzDg4OiKKIi5cuceHCBa5fv87lgwPSJOWDDz7g61//Or/7u7/L97//fdI0ZT5f0LU9Ogz5m7/4N/jCF77AerXhV3/1V/ngg7v8zGd/llu3nuXK1WvcuiFVQW+//TbPPPMMn/3cZ3nhpRcpipyT4yPe8WO8vd09tC+gnp/OuX//Pn/8x9/h+PgQY+UaHfhr1mq9otyUtG0nrmQvMh+AjLUiyNZKFjXyPw90UH4Ed7aAHSpRBgixDbdWZyDKJ7aeAbDhODhfByLf4KU8WxPG9jhx9nsaESTLoTN0fU1vGj9GdNsRIUg9zsDypJ6VCraMzVn1zMBIZZl8pamMxoIo3DJJSp0BmLO+vACQOpteWfBifhFcG/+aRLN7PsBzCIkc9E4MdTNOgfOs2zlGDixda2mDnrDuCMNGwF0QctaPqJ44D7af5yAgCiPiKCIZ5YShIoqkTgVtGUqbxand0TT1NsRzuRzS0TdUVcXR0SHHx8e88+73foiz88/ePlagqV4tePbZG1y+tIcONZ0zLBZzbKCJs4x8MoFAetOGZOeiKDDOsJyfkCY5putpVSslgkEsEQBNQ902LFYrmrYHaySmLAjJkgitFF3X0nU1URQwHo+oNpvtSK5aLLh3/wF7l6/StjVRlMgJiqP2XUDT8YQkiZ44UfNRQZZnQq8GgV+ZOLSVlU/fddukXgl/87PiridKU2G/kEg1aw0K6S0DEcS3TYsKNGEYk2UjkqwAZHUUhuF2Vbd1R3TiRgxCybqqS9ETdb0I84oiF+Do+6HyfCQsWVbIBcU6YXH86HE0HcEgwfSrMulHaghDeV9dI0LS3vTkeY5SmocPHvHB++/zqddeI4xCsmJMGAK9IUtykos5Sgcsjk6Jk0CYxjDailLLVUXvw/gWywXj8YjpdJc4ClivV3RND50I7OMsAx/jkOQFcZyANXRN7ROURZi5KSu6vmf/4kXCJMEauUgb54izlCvXLvG1r32F0SxjMhrjGkuUpMTZHtVqSZamHFy5SlVWNI0kMDd1j6MnLMWN1JmOw0enGAPT2QTrP7dxEJElKVprZvE+2mnCJKVVHW3VEgSRdIFFMQrHelPinCUaT0jGBX21oSs3jCcTuvaEzWbNcj5nPBoReYfmaDZjtrNHGEfMZo4wjuhMT2c6ut7Qdx3O9tRNCUCsUg5Pjjg+OWT/wi5pFtM0Jfu7k5/0ZeGnYjsPoKwVVjgMQ+IkQTmJCAijgN29PV566SXQmg/ef59vf+tbnM7nRHHEJ15+mddff52bz95iZ1ccU3mRy2JNQWc66uZM1C3apgXf+c53+NRnPk0Qhlx/5jqXL1/i0aPHvP/++xw+OuTgygHFqCCJEv7W3/oldnf3+eD92xweHvHSKy+zs7sLzlEUD0jTJRcvXuLCpUtorVivV7x/+zbWWh492iFJI27cuMF0OqWuGz744APee+89Hj58uI1IkKb7gIFNGm6iZwWwA0j4U/bn1oYz7NThL+7s/93ZwlJkQU/GC5w9x3Zw5n9vYLW89AElukf/oOdLkofcIhmvdlgnq9iyLInW0fY5hvc8sDUSI5CQpAlxnAgoiWPyXCpaOA9A4mg7BRmkFpEKCEPRabJ1ng1vZRgFDizT8PczKb21lrKsCPSgO9PnxqDC41kLfWexRgCvVt35vbQVtJ//e+DF7ENIaRgq4jgkSUOp8xqOB9q/p4RRMWFv9wLGGNFPeXH9er1mfnLCr/yT//bPPL9+mO1jBZqCOGL/0pS8CGmqkraqSbOUZDImSlKsUvSmJwwiXKCZHy/ZrBbkoxxjDeWqZL5YcHDtOscnx4Q6ZHf/AnGSMNnZ4d133uXdd98hjmOhseOY69evkUQRp6cnLJan7OxM+fRnPk1dllKJUhTUdSMfkq4jimPqusY6S5pmpInMpUM/NrR9jzMySpNUcL86MIau7+mtIUtTHI44jrcrT6UUgY8ZAGGq+q6j3Gyoy5LRaESe5z5lu6XpGqQLKkKh/UUFwjiSvCe/auuHks4gIEiSLTgbnBhBEJJk4oYJgxClQwIlWiylxHW1vXBokGWc5GOoQNLadRD6DCKZXbet2KUb7/LL8ly0WL5j6sqVK+zs7jEpCqxzHD58zIWLF4kjTW8k9yVQAdM9qZdB+3oCn4gcpz79NwoZTcYUE3EEqq4jy72WrO1xGrq+JwliitmEKIhoa8nnstbQdS1hFBAqzWgyJowTwiBgs1qhdSRjPA8SozTh6tWrJEFAGEf0fUez2RDGkbgZjSEtBuAqYFcHLX3rg1SDEB3G3HzuBbIsJtQO07a0dIBGRyGb9YYkjimrasuctW1HoA1hGKOjmCjPmQaiDSNQlKuSqm7QQUie58TXUqr5gnK5Jhvn5HnO6viEh/fvc/XmswLGTlc+EDRH9Q1d35HlOWEKm/kS4xyT8ZhLaUKSptLiniReoN7+pC8LH7vtrA7C/Yn/7YluLc7cTHAWCRIqCZXsW7lJROmY8XTMpcuX0EHI8ckRp6enfPet79J2PaEHPTe4SVEUJGnCYrXk6OSI6+11ptMpz9y8wfUbz7BYLsBJ/tvRyYlEFxhQoWI0HlNWJY8fP+bBw4dSY0KOsT2vvPoqQRBw5877PH78mHJd0rUdJ8dz3nvvNrdv30apgMl0xq3nnyUKI8ajCR9+eJ80S5lOJ3zxi1/k2rVnaLuGb3z9D1EoyrKibVoPmjRBEHmdpRX2puu9niXYZv882U/mAc+wWbtlTZyTRer52pSzHKCzY6KV9lEowsCg/9VjpgemcBgzIroguY4LQBrUrmqrVVJPPKfx4mVre6+xck99VvRWPxX6sZ0KNGEUyX2gyNGBr5yJBVhlaUqaFqRpTJJEBGHsJQGIO+4cCyTC+3MgyIH12idr8dE+0LYN0EtwaBhJ1dUguOcc8LRWMqow2/15ngGT/SRmlM4N+01LnpeSgOAw8pE/HjRpPVTLRH78lxDHGUUx9cYdORab1fqHPT3/zO1jBZqMM9x/8IhV2RDUFZvVkguXL5BkGTqK6ABjHSrQRE5uahpHNsrRUUDftBhriaKQyXgiKdqAwTGdzXx1iNDdXScurulkQgCMihEX+4vkRYZSAfP5gp3dABVHZKMRFy9e3MYGBF4QrZTaNnlrrWUUZ6VCRAOm7eia1gvoAtG5dB1ZmtLUNbGffQ/ZO85IF16WSgr4eDQSp4XWosdSkiqrPeA5WwFIRsgwcx5yp5RSW+egVr4WZRgNwJbJiqKz+ANr8fkpkoHiAuWrG0SL5az1jzkcNUXbSZGy1hBGCaOJ5BOZrgOlRc+kg60gMc0yRpMpCtmfk854110PKiCIYmGJgpDlfI7qGn/SGZxGHCyeidNRRBD6kDhnOX58SrUp2d+bkY1jTN/SWktdGUZ5QV1LWneWZYwmBWEsuVhJFBDnGU3bsak2BHoobpbQPqdDLl66CtbSNYZAx1hl6FtDEMSsVmvakwWj0WirH0vjxNffaJqu5f69e1y4uEcRKapqw3q5pm4NnXFcvHSBJEkJIkWzruj6GtNZutaQ5zOMhbJqifIC0ohyVVLPlzRVRxSkXLl+gMYQhUCekGcJBqjaEqsdaM3p6ZzZbBdrFdYoMGA7Q1uVaJuSxImMA61D6YhsVNC3xvc9avq2o67/Ktzyh9nO33DPfz8ApOH7QWsiN1+xyhvbb88ZpaXQ+eGjh/zWb/8WL770Eteeuca//w/+fY5PTtiUJbdu3ODmzZs477Iqq5Lbd94nTmOeufkMN2/c5K/9/F/j0eNHvP/BbR49fEzVNFy+PCLLU29AkBykNEtByyKVQHE6n/Po8WOuXLnC5StX+MVf+jeI45idnT3Wqw3vf/ABX/36H/K1r32NN77zXR49PuLvxTHPPHOdLC9kzO8Cmtaws7vPpStXMX3L9773FsZBby0GJZl6Pu5FRluKIIxROtyaaIZR23D9+qhNwIrb3tnVwCptpUnujGEZWCP8bdsJXgqUFKe3bbt1H0dhKNfhIR9oq786G0EpHW6PrQ89ECCgJBU7jIKzY6+2XL28jqdAXe98pIBV9Kalqta4Q+d1TBqlAoIQwiAiihKSVKQRWZpv4wOyLCdNM+JIXOlZJoupLE+JosS7MIfgXZEVWANV1WGNiMAldFh0Zj4dgGGvDUDsibGckmmQ//AzBLcqPyNVauCrFFjlXXDDiBVQPa0y6KZnU7YEwVqAlvLHSDlCHdL+GK5HHyvQ5PqO3/qNf8Ff++zr/MwnbhAlETqKsH0vH77AI3qrCJRmurPDdHcqBgDvahtPp6ggZDydAvjcHwjjkBvP3uLGs8/7Jxue1Hgln5UlgtbM/coryzJwPil3ZwfnReRJkm5dGMYYWX14R8YQNKa95sh2LdYqCGLCOJYDrjVd24FSJIEEqzkrr0MrfLItZMWIrBgx1BcYO7RC262jASeOCu1zPtq2lZGM76CKk4QginBDKKgWhG+3jdTDBx2hkrcfblml9L2USjrrMKb1Fy5HU7ckaYbWEWBRXrAZxYnPvXKEOiYiBqW3tLZcGxSm6yVfCtjZ36eva5q6JwxjnF955lnms6sECDotJZdBFIAegtAUy+MTQh2RhDFaRwRRgvM6q6RI6eqWpt6QxBFaK2ygpL4lS4UxKisRSXoBY5an4CTaoK5qdnZ3WS9XZGkOQFNVBEFImk+pyjXGtERJRtevpSdM+Q+YF4GrMKB1HUo7Nss51GvCQBOgCVHoKCCIcvIiw9FSTHJwDtNBZBRpmNEbceM1bYvTltZ2qDAgDGVlaDvDsjyl7SqM6blw8QLaQm8NNgqo2hq7XBIoEZgqNM5Y4iBEp5ksqo2REFkL5WZDbzfUZUsahnRNh9YBSZz+WM79n5bto7RMZ0WknmmyojtxyGdfawFTfd/x3u13+ZV/2vDqB7f5xCde5srBFcaTMcVYmKXHh4ccnxzzrTe+zXvvv8eDhw9J0oRXXnmFixcvcvXaVX7+5/8adV1z584dlNI899zzvPap1wBYr0o+vP8hKMV4MpYOSNPz7nvv8e1vf5tPf/rTvPjii3zilVe2QO/k5IT5YsnR8Ql3797j9HRBkqRcPrgsWXnTKc89/7y0NgSaJMtYzOe8887bvPHmd7h790PWa3HeaR0KO4+MJIfy4CGG5Yl9+SfvZJxnUNiO1wbxNluh9/nDoAbWydrtr4k0QgCdcgqlndcvOXoji9FAD2yV8Mv6qbGh8noox5mwfOjC27rmlNqahLavwbkn3t/AolVVRds15zKZBOqdL7OVIM94u1hP05QkSb2ZSVx8IsXIt52ggyNQNFYxUZgSRyJUZ8vYPbnXB+fiWRXNILC32591w85h0NcOLJUPRQWMkf253Ufb2SfnRqbyjEMsxEAQ/NSDJhVGHD56zPzkhLR4lQjD8vEjRpMxURbjDPRWaFTnXV1d35JEgdggw5DeGpw1MvpCoxwkceJBjcF0lQipnZwo2sfvW2twtsPYnuXxIxlrBcISmLZjfnJK/fAxKgyZ7eySZRkq0DR1Rde0JD7TIwgjsJKqGoSaOIu3KDuIItIgkI6nIMAZh+t6ej/vDnVAEkdESUy9XqN9ZYozhq5vPYJX3mYvDI5B0WxKPyIMWZ7OmU4nvm27F1G9szR1JaOyQsZHOG9tDgLJUOrFNSdOjVhKOVVA11YikA6EfVJIQGXXdpLQbRAQA2iffG5NL+PHKKJtG+YnS4YE4UAHFKMxcRRRblZYY7ZRDnEmo86uaSjXa9IkoZhMOT56JPRskhJp0TqIKFzeR93URIElCkIObt4E7ejWc6p+TZSldF3PaDIh8JEENC1t3+O6jvVyQVVuKMYFbdfQtT1xlJCkGat1yWq5ZFyMWByfkB4cEMUxD0/uEwSaq89cJ4wSlssV070dZru7chx9FMR6vaZsKiIrx+D5F1+g3mx4dO8Ou7s7XH7mKs5A1XbUxrJYzNGBozGdHwOIYSAIFcUoo+k6FNLLFWgtn8FUs16sOXz0ALTBYpmfHGKMJS8K8vEYghBjLNk4oWkqrLGkaYx1oqmI44Qo0H7UYLB9z3q1ou17kjiR420MxXgmi4y/2n6s29YC7w0ugZYKo7IqOZ3Peffd9/jmN78pIOjqVS5durRNfd5sNrz99tu89dZb3L9/34/S7vDlL38Zay2vvPIKn/r0p3np5ZeZz8V8cOOZG2R5Rl02fOMb3+CNN94A4MqVAw4ODui6ju9///t85StfQRLwJ1y7do0sy7ZGkjAM2ZntsLe3T1mW3L59m9///d9nNtvl537uC3z+858jjEK6ruPevXv8xj//DX7t136Nt956i5OTE9GgRpFc10L5XErNR+VDF+Mffj/+SX8/NwrbLuaG/+7BidwjBHxt3c5uWLiK1EKAS4hSwTmR9Vmv3HkHmbWW3vTbMEqtA+I42v7sk4Lus/oWrc5ricDYGB0gcTHnSpvPRn9Guk67s87C4fHOs5xDfMDwJYXBsYjOo4TJeIcbN16gyMY4dy4ZfQgG9Qnpkuc0hJX69+nNUecF/VIp5bVgHkMZBdKVJ/vFbn/+bCEhqecSczCYa7a9fdZiuuaH/lz8WdvHCjSNxwW/9It/i0+89LJky/SWu/c+5NlnnyHNRGRt+xbnND2ao+Njlos5+/szspEImJeLJeWmZHdnj8lkJiK9XmLc21osi6brwWt7Yh+pH0QaZwPmh8c8fPCA2c4es50dHLBpGt5+9x3mp0uss7z00sscXL7MZHeHujOsTk8x3vGge19DEoAKFEEkgrmub5Ai2pC+bsjzQgTVpmfjA7+SOCLQijjLOD46JEsSJtOZUO3rFV3fMZpMyIsC0wmQ0sGZmNA5h+17Ecp3HdOdKUopTFNTlxuariOIIsLY08HKC++spWsaNqu1zMlnMwIcpmsIo5DVakmSZYzGY+g6HIrp3j7Lo0PmxyeEgawd4ixlsjMjzbItSJ0fH/L++++jjBWBexRzcPkK09mU08WczlmuPXONJMlRvbyHMEmY7e5LnpLSvP/+XcbjMdefuUGapPIefU6KM5bLV67gla4409I1Lcb1THZmdG1Dby2T8RQVJrRNg6laX6rZUFUNQRSTjyY0TcPjx48psoL9C5cY5yPSSwkYy3g8xrQ1bbnm9PiIJE1xRtx3YRyyOD1mZ29PIhy0pi4rTuZznLPEcYR1jjTPyWe7qMePqRpDZy3GWR48vM/pqTzmzu6Me3fvogPls2wsL734EsVoRJoExGlIVXVUqw31qpTSVu3QoWJnZ59sPKGuGx4+OkSpYy7uX+TKlQNeeeUTVKuSsqpYlhvSLKSqKuanJ/R9x85sxu7+HmEaU84rMD17O1OiKKKuG9q2JQg0Vf1XmqZ/3e1pd9357exm4Tsn+x6rhXkIw9AH0jbM53M2mw2Hh4eMx2MvEE6p63qb0TRk6BweHvKlL32Ju3fv8tprr/Hqq69y48YNf4OMKcuK2+99wJtvvsnXvvo1Hj56xKc+9Slee+019i/ugNU899xz3L59mz/+4z+maRp++Zd/meeev0XoNY/Gx6XkeU7f95yenvKHf/gNRqMxSRLzmc+8zigeodD80R/9Eb/9W7/Nd978DsvVkqH3bYhbaZrG3/yNZ0GC7T57Qgf2J+3fP+dxOf/Y25v++WP1FPAYeuv63tJ1AoYGl9kQEzFY8Ye/D+9rOMZKnY3mnngPnrGxSHfdsAmQTs+u9/ZslGe32qghi+lsDPgEUHSSzO6c1KZAs33/A9uUJEc8fHBEHIvDb7hPSrhnQpomZ0npSUIUxVLQnCQkccI2L2p4Y9o9ofUaxpJKK/peSIOzBZnzrxGsVRjTofqz8Z88CMRR7KcwP9rtYwWaXnj2GqfHh5yenNJfuUAaxbz2qU8RphGurUFrkrQANOiIa/kEd60jCIVqdT5E7O6mxOHou55qXTIaj6RaRGnS6UyezEm4mcDaAKdBacX+wQH7B5dBBzgUtmlIwoBP/8xnGO1eENDgU0r7piVwjr3dPQKtuXv3LuvNioMrl9m7LAWJXdtijPNFhIpxOiNJFWiN6SXssphOINIcPrzPtWvXsKZjPj8hvXwFFWi0ceRFgbWGMI6pqorToxOcc+xfvEgxHnldFFy6ehXTdx7wWWzfo4OA8XRK0jSAOPHqWrrQlA5QzpGnKYmPO2ibGqU1y9WC2d4F9i5eotmsqVerbYp5qBTj2S6TyQxMj+s7XKDRcYQzEmimFBxcvsrBleuAgl5OCmMt9XJJMRqx3kinUYGU8NZlyc7OHrOdXfq2QQM/+5mf9X1NPW1Vb8eTUSzap76ut1Ze2zc0dYlxhqgTyrhpWrqRJVbivkmyDI3yoWwZOgoJtCKNYkb5iLps6OuOIBarc1QUUmpsDToo+NRnPws4LyjvubC/z2q9pN5s6JpGYiDShOeefxYLUjWw2dC1LUmS8NInXwcHm+WCpml47hMv05Vr+l76xS5euOSDTy1luQGkIzCKE/q2I88Kiqtj2q7BOkuciq7ONi3desOLL78iujMtqfWr42MOHz3imevXGY1zLl49AK1Zz+fgZGXYdB0nx8dMphMRdCiH6Vr6rqXxiwxjLat1+RO+Kny8t6dHcT/Idn7UIQ5Ptb0Jhz6wV3rcWubz+ZYxOv9cg13dOcdms+Ho6Ij79+9z9+5d3nvvPV566SX2L+yTJTl9Z3j33fd44403eO+998jznJs3b9J1LfOTJXGcsLOzw7Vr1/iX//Kb3Llzh4ODA/b2dphMJyhfLDywKM452rbl+PiYu3fvcu/eXV544XlGkxGnpye88e1v8+03vs18MZc+xix7gg0ZymSHm/QAIH4c21bQzJOMzPmv4d+0//eBdUmShCzLCMMYY85ed9d122vUwOIMx/M8mzSAwUHT9DS4gXPgyb9/sfYHT0QAnI3Etu9k+xl6Min93Ps+p7EbgNxQwqyUZr2ueHD/MZIzFW3ZPhn1+UoZ/zXoeiMvVJ9Mp1vQr8+NniUYMxRxexhKOwMywYiigGAY6W1ZqYF56uUeOrxHr3tSWSqaux/x9rECTTeu7TAaKdJE0ZuGcrVgZzpGdRalUz9mAGs1QRSjQ1itSh4/+hDT19R1y97eHqPRhHw0kTTvwKeII6nT1hjatqPrmm2nUhyExGlMnImlsywr7t+/z2g05tKly+TjqYi4N2viNEUlEYEOWC/mki01GhHGEddvXef927c5PD5CRVIxsilLjBVt0Hq14epVSxKGVFWNDhSTmTw21nD12jXRJkUp127cIoliYZKimL4sWayXZMZRlhVvvf0uAHGasxsmcuNrpShx98I+GEtdNVJdEIVo59DGFyArTYCirWq6uiFKxCWllKKra/quYzQakWY5SkHfNERpKuGYlYw337/9vuRYWcVkNCbJUvrO0jcNKMW6XHuh/VSKLB2gY5yzaGfJdy8RVxt29g9wON59922iKGB/dxdreo4fPsA5Q101JHHCeDIhDEM2yyWnp6diz5/N6Pueo6MjrHEyZsgy0iCktZYHHz5Aa8VisWaz3rAz2xcXiC/tdL4D0PQd86MTVquVf9w9sqyg3GxYLTdcvHSBx48fsr+3i05Tms2G1ekJOorZuXyBdrMh1IpvfvtbrNYrrly9yvXr1zGm57333qWsK565cYMojKhWG/JsxGg04cM7D7h75wP++t/4BZLRjLBtUcrh6OhbyVWJdIoxlnrZUtNydHTEcrlkOptx6coVsjTm9PABdz74gLZu6I1hb2+PZ194gSgMuffBBzR1zfMvvshmtcJWckNdrdc0bUOW5+x5/d/Ro0fc//Ahe3t77O3uU9U1pu+YzKaMJhOa5ZKjw8d/gVeIj8f2UUzSDwKenr7BDUBkqCc5bzyJomh7cx5+93yH2SBgHrq5Jv78OTk54Stf+Qp/8Ad/QBRJPUaSZNRVzXK5omkaVqsVv/mbv8l3vvMdnnvuOV577TWcc9y+fZvFYs5yueSf/tNfQynFF7/4RQIto5myLDk8PKSuG3Z2Zty8eZNXX/0kL7zwIrPZVBYKPrSw83ErSqkt0Npa7r2reLghN03zYwNNH3UMhuN3/jmf3r+y4Eq4cOESzz//ApcvX9kygIvFguVyKeP5sqSqqi3IHY7JE2BM6+1jDl/n3ZRPgjlL37snXtdgCBoA0PnPDbAFPIPJYMhtOmOmhHSQEZvkPSkVkCbbZ8E56PuG1bpmuTpjr85/XrXWpGlKlmUC8AMpQE/TlDzLyHIvQPfZU3EsVVt5Jv+u9UePYJ1z9MN7Mka0xNbSdw3VT7um6Z/89/81zz17k7/9C5/llZefIcki6rYmz3N0JHUWUgLZEwUa04jL5/r169TVmmw0ZjVfMh6PRQvUnxXS9r1QojqKCJHFdJKIq0t7keFQ+RgEAXt7+6S55AU1ZUVZlezs7YO1lJsNzsmHUbKHZMUSxzG3nn8elCMIROk/Gk9ECO1LVyMdeoeIAInjw2PiNGZ3dwfTdagw5IPb73Fh9wLJaCxApqzoO8N0NJMVTpjwuc9/AR2GxGFIU9cAjIqCrmk4OTxiZ3ePMIyoyzWxExu/dhBlGaCIonC7InTWYvyq9c6dO2it+dSnPkVdVzhnefToEaEO2NvbI00SwjjmuVc+AU0rie1Ko9IUuhZbVyRJxoXxZfC6AJSWvrq6RinlAzxbotEYlAUV8OzzL4LrCLVoAzq/wt67kAhb5hQ6CNnZ3WUyncqqLQho61ouSlVNXW1o24Y0yykmY3GVZCmXroa4riMIU/q2p2sa5ssV69USlKOta7I05fLBFZxSlMsNlS0ZzWZkY3FXXtjdRSuH6luSOCQ6uIQxltXhobgcixGf+/wXOHz8kKoqaaqSnd0dPvPZz3q3iWM9n2N6x3q1oW969nZ22d3ZRSvF6uSUvuvI8oRkVKCMoatrgjAky6R1va5KLly6xGx/n77tQAfoMCLLCg4uHzC7dAllDK5pefzhhzLum83ouo6ToyNCf+H84M4HcoHzujy0ZLvESUyWZ+RFTpplJEXB6ekxm8UC7U0IO7s7P/kLw1/y7TxY2lq0n/p+uCG2bbu9QZ4f5Q2VTAP4OP/nMP4CpDnA38TFth2gnJKk/VwcV13Xc3JyzHx+yv37H/LgwQPCMGQ+P/XVFzVvvvltnDPcvXuXoii4c+cDTk6OiWMpvc6ylNFoxN7eLvv7e/LfFezt7/Hqq6/y6NEj3nrrLcqyfIJNGfRRSZLgnOQZ/aQA0/nt6WiI84DmbJ9LYnccJ+R5zmg0Yjwec3Bw4MMZ2+3+Hr6apqHruu1/H0pnhz8H1udptutsLHhe+wTiXhtGcoqh886/C5wbnJpDXpQXU9th8nhWv3LGTnnWcmB+ENONNR5sDVU0bsirOhtjNs2G1drvoyGTaZs9FRP7UV4cRYRepzQqCqbjsYxotZbC4EjKg+Mo9gXyktcXhWKfGaZ00Y8B4XysQJNWlu/+8RvcvXsb4zSjSUHXlCyWxyRpRt10WAKKyYwg1vRNS11tKG3L3v4epyfHfO+7b/Hapz5NrkLqsiJNU6wTEXa5WmE2G7I8J05Tural9EW6cRxJMXDTMNvZYTpO5EZ3Ome1WgvLcbokiSOSJMP2HaoXKz4K6motls4goqlLqr6V0EggyQsgoFptmE1nNJuKAE0UBkTTHdq24ejDR1hjmEynbE6WXBjvcv/d25RVRYAiiTN2L10kTvKzHdYbnHXEKhbxdWdJwoS2aqDtoJXX2K8rVCIXzeXRCVGSCPGjtWR+aI2zlslswgvZCwRRRFJkhEmEM5ZLVw9wvSEK5ON0+PgR9+7e48Yzz7Azm6KUZX7yUDKo8pyur5gvj7HWkkQRSmkBc8fHBEHAwdUrZMWIk8P7fPVrv8+zz7/A1evX0AoaIwL25XLJarVCoUiTlN2dXcn58PUReV5IX5Lr0XHApd0DsiynaRrWmyXH88eEoWbnwh73P7zDpd19ir0xh/ce0DY1ly5fJgoDmnKN8g5I0zQiQu0Ni+Upm/WGPC/AGR49vs9olLFel+zu7ZFkKSfzBcp2lJsl+5cvk06nXJ+O+d6b3+a92+9yvZc05t6vMPuuQ6Ep8ox0NMIai+ukYPje3XeZTqd8eP+UZ248w3g6RStHXVeUm6W87+mUfDri6PFjHj34kPTkkDzLwRomkwmpDlmvNgTAdDzl+OSEuhT3X1WWPHz0IRcvXSQIJFguSlI2m4r5QhYao+kMHcY8fPSYTbkhDEMmkxHWGA4fP2J3d5fjo8Of8FXh47/9sCzTeeC01b2cYxHO/855V1nbtk/c7Icb/HnmamCrkiQh9CXXq+WKKAyZzaZbRsSYkKZpODk55tvf3myBTZIkhKHm5OSYr3zly3z729+SwFvPGE0mY7quI44jX4mx4uTkGHDs7e2iFLz00ks8ePCAe/fuSeadf0/DeG9gI8qyZLPZ/LmF4H+e7YkuOs/0DGDiX3U4Wj+V+JC6bsjznKIoGI/HzGYz6RD1AHeo/hiO4Wq1Yj6fb8HU+b61YQJy/ufPM4rDdja+GgDUk4zV8Fq7rnuC3ToDS+opsPTUKG9IaXYOHQhgC1wADEyYOvc4nhHqzZYBNdZg2o6mkTDepwlY+R1Ik5Qiz0niiDiS8XOeC4jP80Ku90VOlksuYhLFhFFIkmSQ/ug/Fx8z0DRYE4cALAvOECWaMFEkQYTSEVGs6PqKsi4xtiOJYgI/ex2PJ+BP8MBXmzgnVtAoTdFdB85JEWbXS/N0JM47i4hvy7IUdktJaOVoPCJUIqxbr9ckeY5y1gszxamA6WmMIUmga1rKaoPFEsYReLdEmqaoICDESuFuoAi1RlmLCQKifETohdJxlpO3higQgWScZIQ6YHl0QlM3pHlG4ssTsQ7tFH3dsKpK2rZlPJ6wWS5Js0xqPbTCOkeolRTbJumZrVNLfUdSeL0YYI2AvjAIKZKczXqDUZBkKVmdsre/TxyHnJ4eoQOF8j1Jxnb+RG+J4oQgkiyNrm8ZT0YkSUxvOk5PjgjDgJvP3WJnb4rDoMOIMEzoTcB0d4fElxZFcSIuxq6T9Pf1krIuhU4OAnYu7FNkI7yxF6cKEjK0s8RaM8lH9H3P5vhEVjxxTKAhm07IRznjusY5CGOx6Y6CGPSSdl3S64owlHyqNM0I0wwFNOsVkYZsukO1WmGahsePHpJkKdPplMQLQDerNSoMyIpCIhSCkFCH6FBjlMWgifKEC5cvEAB7asrjh/fpu4bJdEKaJ3Q2ojwuqao1UaSJlGW2MyIKY7GToiiKMaBlRNe0KK0ZF2OUVpSbkqqpme3uMt7bYzybyfkWhfTO0bUNoReOFvkIdh1ZMcJhCbXGBT1BkBBpRXyuqPqvto/ens5k+qiYgT9tGwDS+bHH+Zvb8P3wc4NwevjZwb20LWr1eqinHVpRFGEDh8vl9S0WC05PT4miiCtXrlKWm+3iZTQasbOzQ+3NNKPRiMHhVlXVEyCt7w3r9Ya7d+/wO7/zO7zxxhus1xt2dqZcunSRppEk8NPT0+3vhmG4fc0DwBhe4wAcf9zb+eN1fr+f//uTGiDji5JPEPY+2up9hjHjoAMawNR4PKYoCqbTKXt7e9vHPP/YAxNVluW54tpKqrLOjTebppEIgral9cG/A9gbjsUA2s6P0rYidO2Tkj6C0XROdEbb4AX3tOZKYcx5wDSM7CxaC8CS4uFB5O6nOUNO0zmwhbOUm+W2FDnYslRyX5b0cGGeZOwn9/Wd6ZQwjH7kn4OPFWgyXcPFvSl7O1O07lmvF1TlmrzI0IEiTxLQkbRe44jCgDjIScKI5XyFs3Dz5i3yrPB9dOJkMp3UhKRZhkskfqCqK9CQpClRHG0/FEEkKyzjwUSUpaRFISnT1hL4dmkVaPBp0ChQcURb1QQ2JMoSslCBxp/0EUqFRKnGOUWYxRj/WQqR8t0imGC1IkgSpsEOTmmmuztiA9fSdddUFW3f+Togi1UOFSiMNWgNQRhhKsemWtPaTrI1lGO5WYITa7+xUqS5XFekHiQ4rdChjAz7XpKFJ9MJR8cn5FlGqLXowRzUZUmkNQcHl3GmozWScp4nKQ7JEZFMEe9uc4qybOjbjt29PUaTCavFnAcPHnLx4kX29y+RZAmr5RKtNFmSoBAnSKIjur5HGTlxkyAicEjHXl/6iIeQ2WyHxfExkQ6E8raGbCyg94M7d7h+7RZHh4+JIsVkMmG1XPLhnbtMZjOSJIZeUtN7K6WV1khcwmh/H9NbtLZc2t+X0WSSYnCYpmUEBElMZBymbtgcn2DznL3LlwkmM8r1CmcsWZaDhfVihXESepkmMWGaECQBVblhZ3fGaj7n4qWLPHj4gLv37hA+CpnOZmRFQZLGWNOyXi/QgWYyHRNFCcpID2NVlWw2JXEYYAONShKUc0J3pykuksqFzWZDU1dMp1OctdumcWcdWZySJRlBELF3YUwQatqqRCmI44BqsSD6cfDhf8m2P80d9yf93NOr/AE0PW1df1rwPfzuYPcGtpqmoX5kyFN6miUJfK5PkiU+fb4FnGcQ5AYnWpdARilaNDHAlvkZnkNcoj4nLZDryMnJCVVVonVAWW5IkoTZbEqapk9oe4YbeuyvR4CUlcNWRP3j3J4Wf5/fp+dBzbANY9K+73yJbEfb9mdi53OAZRDuD0BpOp0yGo0YjUYURbEFVufBVprKaHN4nvNaqAE8DWB1syklT66ut/+96zqsEd2PHbL9zLkRp5P3IK68Yazngek50b07S2DYslpP7rfz3z3JcA6jwbP9Kc81/Mz5L2sNxvUY652Czj2RVTWI8IdQ58A7/EaFz5H6EW8fryucNVy/eoXZeIw2QzGsBBr2na+b8EFaaRJ7+7mh2pSs1muiOGZvMiVKi20YFli0T+lGaVSgUEFEHgQeFUuytvP2+ziL5fexvsxXo1SIC4C+YTTbZXhkoS8tYRLjXA9KqlB0GBKrMQxRkR48GNODk8TYviqJ4hjiCEVEEDR0dYXSKXZYDSK1AUoHWJ/MO97bkV46a8AYVBSgW2GKgiRjlkW4QEpsJ3tTetOxWa7Eolmk2NawXCx5cP8R+aggLyRwDi0Xh7bvyfKM0URYCuf3f5ymtFXF4vTYu10QSn//Ik1VgbPedhwRhTF9X9E2ljAwWKNxhBijwGgCnZAkOVXdstlsmO7sUJUtzvTYpCXQelt0XNe1XGynU+IwQmlHmmQkkcUaS9/01EtZieV5Tt+0rDdr2qoiziIeP37MqJiy3qy5dLCLs5bl/JTlYiGr4jii8dqtyIc41nXDaDzl4uUrJHmKbRuKNKVrG5qqwaYppu8wGymSdjgCHbC/Kzq4RGmq9YZ6tZHYhEZWjsePD1mXa/b294WxG+WEccJqfsru/j5HJ8dEUcTBlSscHR1x985dlqslo9GICxcuEgSaum22TsG6bSnSgiTPOT58zKOHhzz73LPUdc3+hX1uv/0OCsezn/wUYwxHH96jr0ToG/mx6en8lNV6jVaanckOWCmKLtqcYpQTBSG9s1RNg7FmezP7q+0H2z6KIfmTQNV5we/AEm0dVn78NQh8z9vaxcEVslwuPUMQbd1LZ0G4Z7qngc3oTS+akWAIkI0Yjy8BsFrNKcsSYwxFkQOOk5Pj7WN1XYtzVqzi2CcAVuTz2bqupapK0jTl4OCAtm04Pj7eaqyyLPOP1W3daEEgC5+yLBlyqv4itvPgdDge51k60YOe9eENr/3pkWjf975sdsmDBw+eYAuHfKOBjcqyjKIotqzUAK6GsV+SJBR5QZwkmN7QG+kfbbuGuqrZbMrtmG+9XlNWJXVVbwXpA4MnwKql66TGZaiPGYTggQ5QWoqS8QvYj9LaPQ1mlfLap62Tz53bf2eZTU9uHrT5qhjl9VHODgwZnAdkMhkBpywnc5m6/Ki3jxVoaowjiWJJHQ5C8tGEUbRLoJxoiPCCNh9g6ayvVVGKvCiY7uzRNxWLo0fY3mD6HpB8nCAI6NpWgIo6oyF7v9oJwxC8dTaORNUvFxaDVhrrk7gditSzVcaarYWyNx227wnCUBT+OJQeLpACvLQShGyM3b6+tqqwWJwPhNwsllKg2vf0bUfoBXE4qVmx1sJyISBQKxGP4zwoXGFsh9bQtA21szjTMcpzVBjI6w00lw8OuHj5qlj2fQqrz7vEKQmrVEoJfQyAItSKUZ4x292Rz20Q+soAcUygtaTwGovSAdlo4sd/isnunszFdUCgNUUccWM8BsXWGWStkZ9RPcr0WC0g11knPXkegEZKkynFUC3ulEZhmTh87pVlp20Agw4ckwsXmB8dM53tEIUBysLe/j77F/ZJ8oxAwfLkhKapybIUPVasVxvKqub05JhLlw/QUYjpDXGW0bUdy9MTTuen4GAyndC2LRf295ns7sjJby26a0FrqqaVHKjxmIMwZrlZkRU5ZV2xmK8oRo4kLYjyMVXd0bSGfBxx67kX2L94mbZtWa83rMuGNE0oy5owjkmzjMXpKWVUc+3KNfLRiPFOSzYZc/f+fXa7jjt3P2C9WrF3cEBv5fPVWct0NuP09JQwDNnf3+PylSs0dYOyECpNF4kwtSxLwkBTlmu6tvJ0+I9+ZffTtP0gLNR5tmLQtQxjtgG0DI9zxvScjV/63vcdelA1aGrOi8IHNkokCto3GNgnQPHATg2jvyfHOHb734efC4LgCddWURSAGGvattlqrp4Gh8NI7nyi9EfZ5H8c20exfcO+G/59YN+H/z6Ao3P6BsRxdvaehsc7vw1g9fy/n48lGCIlhr+ft/UPjNSokPHeUPSeJIn0s/pYiJ2dne1nYisw7zsBWX2/HeltNpvt2E9Ys15K4DsZ9fV967Vt/b/CFg2hk+d1Uef3odZnzJ0AqLPiePnz/F6xWCcEwRaUcpbB5cB38w2SHbzrz4IS5ulHvX2sQNPBhQsUxVis7ga61pBGCcpZ+tagQ4VpOrq2x1mo6wbjYHfvItNZjjUdb775Jv/sn/0z5vM5dS3N9a+99kkuXrzEfH7KrVu3mM8XrNdrlqslDx88oG4aLl64QJqkbDYbrl29yuVLlyg3JYdHj9lsVnRtRxhF3Llzl5/7/Oew1vHg/gNGkzEXL1zg5PSENBbrfllucE5E4nVd07WGIIwZFxOmsxknx8eMxmMBQx70tV3Der1id3eXRw8f8+H9+5RlyZWDK+zv7Unnj7XUpTgwUu902axXAOR5hjU9R8fH5HnCwcEV6qbCmp4iT0myRATxKiDNp+TFlJu3brFYLDg8PgItQWNOQ5pnKAcWR9919G3HdDJhZzojjWPSLCWKE7/6EnBDoLeOQPrOnxleJQ+gxUHXeXCotMY0rYC9cChqlEJH660dgfarbGuFSXYgTYLnSiOd9CHpLbVsJMZLawzy38c7O9LrrQMIA7I4FKbOyYoon4zJ3MhLuxRxUTC1CoXcBFSgsYmAap0kjJKIbGcmvXdRiGmkrkZFEU4pHAFROGUnz8E6zz4G6Kwg299D6YC8bsA5Aa7OoYl47oVXiBPp/BpP9ymm+9RliXr8mN4Y2g6MDSiXJScnK6qyBGepu54L+7u4AHQScvPlFwnTMV/9+tf4zd/8Tb76h38IOKbTKV/6vd/jk6++SpKmXD444BOvvMLVK9ek29BYulayw3Z3dymrNUpBkkSEgaKqa8LoL2bl/5dlO29p/5O2gXVR6iz/6LxbLgzD7Wp+2412buXfdd32BjRk+pzXPA1MVNM0mL73lS3yXG3boIOAPMuJ43h7o5WbtOhH5AbmV/0KQp/Y7Kzbio6HzJ4B+K3Xa6y1W63P8N/Pi5wH8fOwiH1a7/OTOB5PsydPO9jOM0Tyu+d/lidA03mQOxzL4ViI81o0jsM49DxUGJx3AzvknOXpzKQ8zxmPx9ux33Q63YacDg7EAVwNoO+8q2+IQxiYveGrbmoxoPjsuAE4nhfJw5l7T9go64uInX+vT+qfZF+Le+88qJLr+ODCkwT084J7AU12u+/csFgGokiceD/q7WMFmv7h//H/xNXLF3jppRcIw5i+azl88BBrekm9TTKMaXHWEcUpSZr7D6rl7e+/xZe//BW+993vcvPmTX7+538e5RmXGzdukIQxnenpu46DgwNZgQHL5YrlYk5RiFj49PCI55+Tdu7lcsmNG9dJkphyvQGlePDsA159+ROsN2uuXLpMMSpIkpTZZMrB1atkSSKgyRoc1n/wLXGcMh6NCaOYe3c/JPQ5HzLLlxHR/fv3uXZwhUt7+7z+yiuoOGE0maCMYXFyzGazovUFwKNRgVKarmuJk5goDKkrSXyOIln5dX2Dc8bbfwuUUqzWJU7F7MymtE3Dl7/8ZX7vy7/Har0WkGRlJHp6siDLY65evUrbNCwXC9q6I9AhaSrsWtNanFY4X6Qo58/QYyehbUGgCUKxwfbGgyOkw0k73z806Di8vVVpCENJivUk2NlqZrDBurN1nlzMRHiolUNioYaLl0KhuXx5l5PjJS++9DIvPHeLcrNiOZ+TxPG2MTsIZcWdpQXOQbkpicOYdFzQ9A3Pvfg8ZVWKgygIMdZQjEaU6zVJmrK3u0dVVSxXK+I4YX9/n9lsRmwttrG0dcvuzh7lpgTrKIoCrTVNVZNFijTJ0HEKCFAMtEaPEw7iHOv7Bwf625oeazocliAKiaOQfDJBhxFFGOKU5t/7n/99vvDFL7K7twdWNCef/fzn2dnb24pWx8WIzAdnKjS2FeYyyTPSLqEqNzRNSRCFTHd32Gz+Ktzyz7MN7EkUReeqND5aq2Ot3QIfYCuSPp/F87QGZ/j3p0XoT49Szocgaq0l/kR+A6UdYTTc4I13BluC0Bel/iuMgRtumXKTtLLwCUKxwbdtsw2+dFg5z556DwOIG17r+df89Hv5Ybfzv/80G/KnbU+DqKeF6OfB6BmAGMAA2+/Pvvz+VWz/vr+/y61bzzL2wcsbX4WVZeIaE3au5eTkhMePH9E0LWEYUZYVddXg3Fm46Z07dwiCgCzLniiQT5JkO+47L04fgk8HMTqwZTOtleNkjWG1XvmuU8Pp6YLVaklV1X7sV22ZrLbrqKuKpqmxVuJ2BpAV++vrEHmglCYMgy2oks+lxTipxRKRzFkHquia/PdKelMZtFbe9f2j3j5WoGkyGjOZTAmDCGssTdvTNC04R9cbKQ1VmizNCNMcVMjRo/v8wR98lW+/8W02Zc1sNuETn3iZl19+kbIqWSzm5HlKsxFk/e4773D54IAg0JRlJReotsNEiTRbhxodaqq25nR5SpImFJOCsItYrpY8c/MZsknByXJOmEZkhazIZuwShAF121LWNWW1ASxZljGajjG95fBEohPSIuXSxcvc+/Aujx5JaOLB5cvcyn0Evfa5IFECWrNazNER7F/cF4eaOWuPtjbbUuXZOGfv8gW6rvV9RWCdVMj0pmO9WjJfrEmSEc9fusw7b32f4+NjJpMJeZFjfAZHnKRcv2qYL44JveD98qVLRGFI08h4arVccPHyFVSgcNpJXhWKuqoZgtadv/E7JQJxazsR+yECZdN3hDrcRhlMZjNWm5JHjx9RFDldJ8nl1hrfjyej0t2dHfZ293AO2q6lyAuOjg5J04TJeMxyccrpYs7O/h6npydUm5rxKObrX/8qFy/sc/OXfoHVasXtvkfhCHzpZOhvINlQFZAk1FUtjF2oufvgAY8fP+LC/gWCMODDDz/0F1PlO8I0aZbS94ZyU+JgK/Ysyw3r9Ybd2Q6LkzmmM+zs7hAnMcvFitl0hsWS5yOSJGK1XHPnzh3KsuLa9eu+hDPi5OSYqqrJ0oSqLgmTkEsHl+jahrfeegtnHUkcU1Utk9GImzduULc1pu1Yr9d0vWGxXEq9gxdVgqzyinxEkeYEWnPj2VuMd3cJ4oT2UKo5VpsN053ZT/qy8JdqG3RF5wMrnwY4A2japiqfA1cD4Bi2jwIAT9vSzwOQpzOP9KAjQZYXAyNhndnWdwTBMKo6+90tsNhKHYbn5lwFTMdQoyGA4V99DWfC4T95+/OAph8GJP1pz/U0gPuI3/hXANJ58Hf2mGdBkuDIsowLF/Z9mHJJ6UvDkyRmOp1sWbq9vV0mkwnG9GRZTllKCOnG65ZWqxVlKQuZvuupogrnJElea70FYQPrNGjgBoZq0EoNI8cojIhCcVfHSbwF4pPJjNq7jLu2pW7qbTxF0zRUZSVaUj/qk4iVHtQZu9W2Hc5Zf38YFn8Oi8E6mQrY0HfNBUOkgaAkpRXKui0ZN1x3LT/4sf1Bt48VaPrv/tv/lqsHl/h3//1/lxc+8QmiMGayI7bMKBTBdhDFOOe4f+ceb37nO7zz9vc5mZ8ymU35mZ/9LAdXDphNp2R5St+3KCt6oSiJOHz/EVmSMC5yTk9PWa+WTGc7jEZj4jDCOcNsd0oYB0RJxMzM5KKjFWmR0fQtu/t7qFATxSFpISWtCuiMYblaEgQRFuhMTxBqkjwljCOadoPUqVriLCOIQxGfBwoVKlQItnMsyyWXL1/GAaulCAjLzYog0OSjDB0FtLXosJI4ISSUk66tRBCaxszXC+I0JQzEwdY2LeuNdM+NZzMm413u3r3DG298m81mw6VLF8lyaa+3VkDTaFTw8OH9rWW0yHLyIqfvelbLJQDFuMCqHuN60jQhDhNM29N1PpW2FW1FmueoIKCsKspKRNdaK/q2Y393h6qUEtliNObo+IT51QNu3LzJ6ckJp/M5WZYSBiF916F1wO7OLvt7+/Sm5/j4hNFoRLkpsc6SJil1XVJ3NbO9He4/+BDlNAdXLtPUDa+//kle/eQnWa9WJGlKGkdo7URwHsfkWU6gw62VdbXa0NmeRVVilDgop7Md8jwjimIBIl2/DXgTXUK0pdWHVbokMzuOjg/JEwG6q/WS3OZ0fcPR8WPG4zHHZUmgFXXdsF4tpTLn+BF937O/t0e9WbFerXCmoKordKfZ63aEDS03LBcLdnZ26NqOti6ZTSeEgeLuB3dIkpSj4xMZLXB2w3Jey5ImGdPxRBYTSnPwzDX+zb/zt9nbvUTZ1Dx6/Jhvv/nGT/y68Jdpe9K2Ldufxaz8oIzLnwtcwBbY/Lk2eYAnvhf2Vxiosx96cvthXuu/jrbpB/3dP+31/GmA6fyf7pzb7KMf80l3nrWW+XzO4eEh9+/fB2A6nXBycsLOzg6z2YwoitjZmRGGIePxeMs2fXjvQx4+eAScVebszHYIo5C2bTk9Pd0Cp0EDN2RADXEOg/h8m+Dtx47DSC/PM9I02/YCFkXObLZDlqUopX3cTi9/9iJK32w2rDdrqZRpWlar1fZrvVnT1M22nFfaORrpwRs0wAYxWGlgCN7Ef0aDM0eeEHqeefoRbx8r0ORMS9c19H3n3XKK0cSLa/3PnByfcvvdd3nnu9/h9nvv0lnDsy88ywsvvcDOdJfRdExdlZyeHNF3PabrmJ+csLO3j3KGa1euMC4K2rpBoRiNxxJW2XX0vWGUZ6KR0VBkqVDpvjE61Jq2qUVHZA2BimkqqR1p6xrbG9I4JYtHpFGAjiQ8Egs2TSiKEVEQgQ5o64pJUZCn19jf3yVNE+anEv44LjLqpgHb4fqOOPSjPBx924A1ognqey/gtoT+g661put68lFE3zcorUVopwMmsx2uXrlGGo/53X/xe5yenGyzMDxrDNbRNw2ls4QogjDasjCtP+GyLGVnZ4eT+QmGns40dG1DHMRglawwUDgjY7gw7BmNc1wK1lhJlHWGOA2ZTmdoB23T0rUtSsH+/i4XL+wTaojjiKLIUEpjjZC2YRAIMGpq1qs5dbUhSVKcs8znJYEOmExGmF5C9tIoZXdnhysHB+Acp6enOGuJwoBilIOztG1LFAtz2NQNm3JDFMeEcYjrHK7rKYqccDqjyAtGo4JJlosNuKxwQBxFNG2DMT27Pi8sz8XZVO/vE2jNcrXk4OCAzXpD1dRMJxM2mw1t27Iz3aGppRrAWsdzz90kjqTm5fT0mL29PW9o6IjTRMLjcEx3pOH+2pUrPHjwgMSvJterDXmayaj5dM6Vq1dJk5TGCzzbVj7X2/Fo31OVGx49esz79+6x/B9+k+OTx/x7f/9/weUrN2h6w9e/8Ud/AVeGj/d2/uZ5/qZ5fjuv4xi2j4oW+NMe+8/56p7684f/9Y/CFNtp+tb9pM7m6X+O7UcBfv51tj/tcc/A0p/+BoeHGMZSWutt6fKjR49wzrFarTg6OqYoCnZ3d58Q0+d5zs7OLkpp6qbGWEOSxEymE3Z2dtjf3ycMA5q6ZTQq6Lp+G3ngnGW+WHD4+HArArfWSEF7mgi7FAZodTYSLoqCKIr9Y2RMJmN2d3cZj0e+NFjYoDCMiCL5mSiKGY8nEuWCkkV/WW4DO4cvYVNbqZvZrKmryovOve7LG616023zp5waXOuiZ8WPrX/U28cKNP3P/u2/S5qkzKZjurrEokhDCbO6d/ce6/WGd976Pm9+81tgDTduPMe1G1fZv7RLXZe8f/tdZjszNhsRHSpgtVjStC03b93Cmp7l4pTVaoE1jqZuWZwuxK7vZ6tBKBUj4j7ot2K3tmlwwGlRyEFtO3GQKQ0OwkDmyM3ahyG6HrSjXC0FNFmHDiK0U4RJTKgCKZa1HaYdYQNNuVwQJwkP7t7l+OSYMAhJ4wSdSDWBaVvRAnhty7ppvUNQxkPaQlPWmM6QBBFdXYNyuN4RqojpaEaeFrz1/Xe4d/9Dmr6V7KEwFNDSiN1fByGb1QqcjNfQkOU5fd+xWq29e0NzcnKC0w5je/k9V2E6i7JiWw2DmN5a1uua0/ka4wQwKQV93xKHESdHp9i2ZbXeYBX0ztCajvsP76MVhHEo4yVjsb31Ym8Rx3ZtR9uJxq3tevb2dqnrWubzdDRtTWcM63bFwaVLmL7n0YOHPHzwgJ2dHfq+Y+MFqmKPxgtfW5qmJYlT0jTn9OSEdlOhmo6uaxnpiKrtWC5XMsoLA3DgTANNTRwEOAvr9ZqT+Zw0EfdLkMRkQcB6ccK63BCEIc511NUK6xyr5Qnj8QRjNKv1mkBb0lFKq0Bry3J5QpEXUhFkWxFQGsvi5ATX90xGI/TlAw4fP0JZSxRoqnJDZQXQhYGmKHJG42Krmajqhs4X8uZJShyGZHnKxcsX+P2vf43/36/8Ck4H/IN/8B9y7fotPve5L/4FXiE+/tvTF/nz+qLzjq2fhHMMnr7N/3gAx9N01I8L2PxFbk8L1p8ekT79s9JdJ+y+sNIG7fVfq9VatLbLJUmSbB13YRiys7NHmmbM56d0bUcUR+RFThBorDX0vcM6Q5zE5EW+jS1wzpHlUuBsbO/jQ2qRbvQyNlO9ZzStvP7NRsrCnWPLRA0atNiH96ZJQl5Icvfg9CuKgp2dHR+VMGI8nmxZLIlqaL24vGOxkK6+9WpJux331ZRVTVlXkk3YtvS9wbgzJ54MkxXG/pQX9r7w3C2WqzUKRxgnOK1YzOfUTcuv/Mqv8OHde/RtR5YkfPLVV3juuVskaUBnGx7c/5Cjo2PCEJbLJVmeoxycnh6jlBTRxmFIVW4kODHJ0Dj6psYaS5qlOCxNWaMDTdM0OCduD5xocmazGbZtybKcVSmitzwr/OrfcHp6TJZkOAVtWxFEijRLsMb5+W5AW7dMZlOyRFx2dVPS1hVZnnJ8+Ii6abijNPP5gmefvcXlgysoZT0Lo7390luHg5C684JKpXHGsmnWYC3LxQpjBNg1dUeaJUynO5yczPl//1f/FetNhUaxu7NDkWU0VUXftiSxRP9X5YYgkKqXsioZjUdYBavVCmuNsDzrNU3foENNnmRopzCtIVQBxhlMYLEOyqahX65wiFssjALapqYNQ7R1OJ9l0jlDj6EzhsViTlHkklrsnFSRNHKyRaG4fxyOJJHjs96ssKZHK0djWjYbcZcFYchivpTspWJE2zY0tfQZmt6wqERLoBTUdSXN62Es9LM1khPjwHY9q81aglGjiLZtefjoEXmasbu3y3w+p+vbrdDS+aqEqiyJk5jRqECrABcqjlcnqDBgd2+Pk5OOo8PHtG1PFIY8c+MmfdexXi44qipOvLhzU64pNxu5MCWJB88SuWEdpJkwSuPRiPt378nIuJPmcmOsB5M98/mSJJOLm7FGwvGqmkApbNaRRlJlUIz3ef31T/LNN7/DP/v1X6fpOv7j/9X/mivXbv4FXyU+vtvT4Ynntz8puPKHffyPeuyP/Nmtjkka5dUPCJiefo4f7OWesVlukPh6wfSfRMz8MO/lf2zb+dd+XjT/9L9J+K9vrwgkumU8HtM07XbxNwRbDgvFw8PHgKKua7TWW9G4nMvSiWqMSAOSJPGBpWdhpIOmaTRqCYIn4yTOxwTg8Dlh4Tb+wBjDfH7qA1OHfKmEzDstg0AThvKzk8lYroNOOl6LIifPR1707wjDyOctphSXM+Jr1wAfA9TL+278tbquajZ1tc2dalrp79uUpZhqfsTbxwo03bt7jzQvCKOQMEk4Oj7kt3/7d7hz70P+8OtfJQg0r3/yk/zMp17nxjPXaeqKB3ce4TCs1kvGk7F0asURo6JAAVmWEkUxk0nBerkhjWK63hAGEYQRodaEQUgUR5T1hkp7nZGSC9loVOCsIwgU165dYbPecPHSZQ6jCGth6qPcl/Ml88Wc3d1dHI7VWhMlAXkhLfVt02P9BzEMA6q6RGkkkNF0lJuOPM9YrRe4ICTJYsaTMUkSCYBDBJmST6HI84IwillvNlR1LUWzdU1ZVmRFzqPDI/IsRYfibJnt7hLGCd/40lf49htviqsiHxEGoYzdmkZunJ2vhbEyY26DgLoq6doGpxWNX/UEYYDBsFovZITVtzgDrrNkcSbBk71Fh4FovHoRmRvboltFXVfY3tDlmR/P1Wyait45dBhQb9acHD4my3OSOKbzKxCQLrokTui6nrquJOS065ifHuOMpbeWRosoPYwiVssV9+7eZTGfc7pY8vwLL/Diiy/SNg1NU0sxZBzhrKOualSqCIOQuitpshoUdKajbmvSJGG1WVPVFVVTU9UVxXTEqlzRtA15nlE1FThHEAbEeYKzhrqrscahIxmvjfKMQGvqTYntDE1ZY5OYalPRt73/6qisI0CThDHrfklT1rhOQiaHmiCFot2UHK9W1LMZtu3Rgaata5q6put7VpsNQRBydHRElMgqsbeWqq5pPcOYxwmRDoTNCzRJkfH666/xR996k//6v/7/UHctz9y48Rd7kfhLsJ3PKRq2QXsCH617+vFsDhgcrX+e3/3hAI1D3LLbrIIf5Hf+R81KnWfQ3EdomrZ/89+fCcEls6l5Ql+0t7fLhQsXBRBsNh5IS7VXEGjiOPFJ4M02Ly0Igm1H3/Acwh6ZraNub2+XohhhTM9mIwXJA4ganJxDN9w2I8sfI9Exzbh48TJ933P33h3KUiQQSSrGpSECY7PZsCk31FXtoxmCrVtOnHxjfz+W3rgsyyiylP3dHS7s7pFliYC0cGC15BzpOkPVVGyqDeu1xCTUbc38dM7p6Zwv/f6/+JEe1Y8VaOp7+NRnfganFd/7zrf4jd/+Hf6f/+gfkWQ5X/z5z/FzP/d5nnv2JnEQcHRyn/nJKcb0foUGk8mIxWKOVprNekWgpUF5tVpSbjas1yvSSCL8a1fRdzLiydIMHQZsyjW9685GQl7IG0cRm3LNarNivV6T5jld31HXLU3bEkcx1jhGozGTmYjHsyLFaUkVz4KQaDehblrpn9Ma6yx5lpIXGWDQCrLsgN0Lu14noNnf25NLWtdK8rRS9P7D3VmD6TrqtqXte6ztadoWB0ySlKbv0V2H6wyXLl/iwsFVvvu9t/hv/vF/z61nn2UynhBHMXEYYdqWHiVON9NL5UyS0LUdm3VJFAeEYYxxFq0MTduyPlkx25tKOGbfUVqDaXviIGY2nRDogOVyhVKGNE7oTEPft/RGVkJt09DXDZtTy+5kCtbi2oY4iYmjkPVqxWq5wlQVCyO2Vpwkx260MG7rTclisSDPc/b39lidnkr9hA5wWmOd1KGslyu+88abfPDBHd65fY+93T2+8PnP0zaSpNt3HdbIydx3PRtT4qylrhuWS6kZ6ZqG1rQQaTalNJTXtsNaQ9nVTPd3Wa2FTWv8yjAJE3SoqKqG5WpDoAKoNToMKNcVddmAdbSNRDm0dcvp8SlVWVHXFXEUEQUxXdvLBaRX9BhsL3ldcRwTFuLAasoK23Y8/PA+QRD6mAFHW9dopZmMx1RVSdPUGOdXsJ24GcMoIvYXvtKv5KzWmMUpo+mM55+7hXGW/8f/7b+kmI7+gq8SH+/t6eDDARQM7qVBIPyvUx3ygwONcyOzHxo4Pa0A/7N+9qOf9Qd9iB8WPP3ExptbzdZHAabzr/m8dsxt09KN6YnjkMlkzKVLl7h+/TrGGJbLpY8BELlAkiRMp1PiOEGYpoq2laaB09NTVqsVdV17957cD43p2GxEbD10Dw7gTj5jiiAQh5pzA3MY+O9ljKa1ohgVXL58EWsdZbUmS1OSJOH/z95/xdq25el92G+EmVba+eR7bq5cXV3N7uogiYJpmSIIybYYYEAwn2QTDoDhR/rBfrQpioBhG2CQDJACKUuMlk2RMpsEKXWs7uru6gq3boUbTj47rjzzCH4Yc669z65zY90Kp+uOi3N3WmGuGcb8xvf//t83Gge3cq0jrLWsVksOHx+y6rY9SWKSJKWqWhbLOersrCvfDTowFTEZDXl8XyPxRL1TfJqRDcK/QTYgzQYBoCUxw739cJ2oIHlZFwV/+//1tz/SY/pMgaZPfOFnOT6b8tU/+H3+4T/+x3z5K1/l6rV9/o//p/8Dt2/fIl+vMG1NVTSUVUnT1uTrNdY4tnd2Qqv2aslkMgl+ILYLzPWC2XRKkVes/Zo4TnDG0tRNKP20Bh1HeOGxxtKahkhryqZhNpuyvbWNUpqz0zMGgwGPHwU7/DgJnQXGhjKHd5633n4rdGJlMV448vWKOImZTCYYY3Hes1wseP7552nqhun0FCE9WZYQRxHOOU7PTrvHG6yHumkQUjEcDDHWka9zZstVF0YMkY4QSpCkKXGScXh0HFayBJH33t4BZ6cz/tk/+W/58m//Dn/qT/9ppI7J8xIbGyIZfDOqYo3r6NqyaxGtW0PRVOwNBoDCtg1bkwnHJydss0usk86KXzHZ2mZvb5+9vf1QO3/4kGKdEycpGQ4KhxIQiQSjIwY7u5wdHqGqYOsvbBvq63VF5D17kzECQV7XZFIHk0ilQ04egiwbsr93gPWO8WjEcrnENZZIKryQtLbFC8l4ewulFJ/+zKeYzhbUbYuXAqTospzqIHwcDCnKclN/B9npCSzz5YLWBBCZZgnGWtrWkqQxDw8fsb29TdNUKCVIh0ko99U11bqiqmucteG8sx7RlV+FEOx0IFsJiWkbZtPpE6Gdy9UqMGJVxWgwJIrO9WfOLrtYoeCoLqWisYb1esViscCL0IVnrMV1E3prTHC+x2B6vx/AGcNqtqBcF2zv7FC1JU4GbsBLya2b13nuuef4B3/vH/2YZoc/GuOiSeVl08SPgml6/3oov/n/B+WLgs7lw+muvP+g7/eTPp4ERufg49KjLpTnejbIdfYrPSuTJKHc1YMVa8M+1lp1zTfbTCZbRFFMVYV5a71eB21mHr6Gx2ugN5xsKctio6PqtUk9OH/CZFL0+YaK3tAS4RHCk6RJ54kUSqxSSeI4YjBISZKM4DeoWK2WaK0YDDP29/c52D9gtV7x6NFDyqJECBiNR4ExywvqfE1dFlSrNcI7Eh384wJoGjAaDRmNJ0wmY0ZbE0YdWxWnwYMqi37KA3t/76tf5a/8x/8Jr33rW3zus5/mL/7F/4jPf/7zPP/cLXCO1WLFerVEdczSeLTFaDAiz9c0TUWWDRiPmy77qGG5WLJe5bRVTVPWxElKHMVkaUpTtwipiZOEtm2wrmMoXKCsh9mgiy9xTMYjxuMhb7/9Nuv1migKAZf7+xlCaeqyRuoIEMEsUUsiLdFKECuF1BodxSAMo+GQ3U4kZ6IoiPfahiwakKUpx2fH1HXLzZvPUbcNp2dTyqomTjO8VHjhww0fgXU+lP6MDcaECObLnPl8jpaKG6/e5NatWzw6fMzf+Bt/k//y7/9Ddra2eON7d0L3RNsyGQ5QsAlmHQ47R9koZjCZYJRkua45fusOdRO6vbZ3D2idIElGCCKqqmFv9wrGNPz6b36FN+89JC8qROe4ffPGNf6tX/5j3Lh6jXKxRLWW4XjAZDiiPZ6ToLh68waFrZlXa6q2pa5qhumAIs8Zb4fS4jIvaK1lMtoiG45Y5zmL5RqsJdIZUWzY2jpgPNmiaGuW+ZqqqmldBb7l5774M9x9eIh1jqYNZTepNFkWIYUCNHhNXRkKW1N3IChJ4o0bbdO2RDpinefMZwv29nbwwmHMcXAjF44o0ptIibpqsNYFM0xTBU1WFJMlWYjqaVqOjw+ZjEakWcZ8frbZLmvdRpRZG4OsKlQTym7CB3aibgweS5YNmM6mKCVZzNdY55E6CtmFWlOVZbC+MC21bfEQ2D5j0FIhEJjG0TrBo8MThAKdaoqqYpWvsd5z/eYNfvFLX+Bf/6sv/ziniWdyPGEI+ZTy3EUDyo/qvd55XPAw6m6CH6xE94MI1c/b89/PuOh59JNcqutNLp+6b/zFx4SvWutNq39f2jo9PUVKyWKx4PT0tJsroo3552QyoWnaMC90r9V3o5VlyWoV0iF6wfZF8H3R3LQPSu4NRnvrEaU0QnSpCkAcB/CV52vWqwVN2/D48cOgGc4yFoutzlMqAJkgaG9QSnDz5g1+5mc+z6uvvsrx8TFf/epXOTk5ZXt7m8985jMURc6D+w+499bb2KZmlKZoIVDdtWGbmrxtyJcLjni0ieySQiKkRChJNhyE6KyPeDxToOl/9b/535IlKb/8y7/EZz/zaax13Llzh7apuXf3Dq998+toLXnxxReIIs39e/dYL5dcu3qFsi544423uHf/Hs/dfp7WNFRlxWQ4wVtPGifcunGTRw8fce/OPdbrHB3FjMbj4COxXpKmCWkaykOPHx0iVBBYx9EDpJacnc0YDEfs716hNYaiLJktVhwdHZOmGev1kr29vcCadCJl5z29O3XdNIzGIxbzkBYfSYXwnpdefAlvPX/rb/0thpMR/+Yf/7c4PDqhKEuEEOzu7aOiiMeHh6xXa+Ik6ez3Q53bGkfb1mgdIQgX0nB7h8Viwd7eHr/167/F7/727zBOU65fvc709DSUGeOYYrmgWK8xTY1zHkNw5lYIGgFaSYxzWNetnjz85m98GeE9v/vl3wcPjfd8/pOfwluLijX/8V/+P/PHvvTHaMsKWxZ8+Td+k7/zd/8LvG3YHw1Yn83Z3d7m+rVr5GXJummZzqY0viUZDYiSlKYoGUQxTVFycnjEfLnkdLZiUYRSopAS4zxNf4FHCmMcnnBBBf1U123nPdY73v7udzk8OmaQJnzr618POYCtwbYhF6/qGC/nQhimJyyL54tFmA6tY7FYhdftTP2KomC8NWI2mwEW5wxRHBHHOmyP80RROEZ1XYcA3MaRr3JGo0Ew8WxqYERV5qxWa0bjUbCUaE1n7umIowQnBM5Y6ralqWqkEBjnMK0jivJQmhSBQVuu1iAk1jlaGxLEnQir2rPpafBfiaIgFDcGKRSDdMRgMGS+mAc2yzQYZ7He4pzl7ptvMBlPfmzzw7M4Lt/oL+eqvVP33I9o635E7/Ph3/Uio/VuQO0nAVBd3Nbe3w6eZNYulmPjOEF3TS19abZtW+bzOYeHhzjnNpEoRVGwWq26spnpzE/DObNarbqOcdt1tkUbb6ZwvgUmS8reTNV1+YQBLPdaLCnpDCVDd3ZvAVBVZRdeHz5ZL1Bfr5c0TcViMd8Au9VqRZxotrbG7O3tcuVKMFtOkhhrW9q2RkpI0pg4iYJko20YpCmJ0kRShGguKcAHX7zGtLQm5MBa11nWIJBaYf2TZq0fxXimQNNnP/05Dvb3GWQpaZKSxgmutTRlzdZ4wi9+6UvojYOxY/ypIXGsGWQpJ6fH7O7t8fzzt0NGmFbEOkKhaKoapTSDOKXtdEVt25sPRgxGA67Ja4xGw84TyPP40SOE9Ozt7ZImCXXbsr17JbispgPKMuTaKaXY398DPFtb400nxHCY4bwNZRIfHLMHHYtTFkUw9jKWSCmOj49o25aXXnmVn/+FXwABi/WaphP0npycYDudQ0gcF8Rd5k5dBm2WMRbSlP39A/AO4SXr1Yr5dMoXfuZz/Cd/9a+ye+UKcRyHkGAh0F1umm0bhAqBu6ZtQ8SDVpigPg/tqAIipZFOgLNo4bFOgg45blf2r/CvfvVX+Rt//a/xt//mf8bu9h6/8gs/z+H3vsPp3fs08zWf/dyniITj937r9zFe4CNJ04ZummD4AJmC8TBBCIuWGu8de3t7fOpTn+KVT3yS7f0DnJBB24WgtgG8Pnr0gK3JDsPhkKIoaaqWQTJiPpvRmIoHD+6zzmccPzzh5OEh3/3ma1y7eo08r6jqFqE0zofyldQilGq9QSqBqRu8daRxgo4UrekM2XwAJEdHR4wnI6xtMDaYutnW4hxYY2mkJY4STGuxOMqixLYhV7EWNaa2NFVDXuYY56m7smjTWqwLk5qtG+qqJlKKYZYSJylt3eLKYHEQ6YhlVYXOlfEWRdfd6LxHaY0xNUVZhFUaobwplMY5y7qsEV6Q6AFxLIiFQuqYslwjteyYW4e3jmK1/rHMDc/SuNxu/k439MtC8B+0e+79jydURR+sXPZ9T3i3Z78DkPnx45t3H32H2wd70pPP8Oe/853b52VmUalQMek753rnbuccZVkCdL57LXVdM51ONwz2YjHvOtOSTZZgWPSnnV9fHx58kUkKWaD9PUqIc42d1nLzu16P1Zf7rDWdeDsmy1LqOkFK1QnLa1YrH/LhfEh6GI6GCHyXX5cH3Vak8c6yXq84PHzcdZzXCOGR0iO8QwtPrDVxFKG6EqJ1lsZENF1ki3Gha9gLaEyLqdsPd4zfZTxToOnmjVtIEbqezk6meOuZjAeMhgO8s1y9cpU4jpjPpjhniUfDDm3D1atXQQhu3LhOY1pUpMFBXQYRnGkseVkQJTFXr13FO5jOZjQm3FzTLENKifOhRTsdDEJQaZSwLkra1nDz5i1WqzXT6YwsHXQnju9ozdAZNxoNQ2K3EsQ6mE1WdR0o2SgYiQ3Hw1BWXK2ZjEYoKVnMF7z66ieQSjGdzjAutHsKJSmqgrquGQwGXahmg4ljpFBBc2NDa3/bNKxXIS/INoHCnc1mm9WGMw0q1sTdaqIp8w6seMoip+1WKulwSDocUna5ecGB9j5HDx9Sr9c0ZUVblCSDAQZH1Vqu7u1x/43vcbxc8Ma//Ffspynf/PSnufu9N/jyV79KqxS/+LM/ywvP3eDG/j5Xn38esTWhwqMFCGtQQvPmt77NN37nt2nbmrqoiSPN7du3UUJw9PgR6/UKFUdMtrfY3d/DoLl2sM3uJHRJNnXDIBZsD7fZHR9w9Pgx48mQ6nOf5dqNKzw8fMzx8TFJGlOuV8G13IWOSaUUZdOwnq/xXRdl09TsbW+zmC9woxFCQmvCPpdadp0oCVIoVkUdmCprwmRF7+EVIzKNtWFFl6ZD9EjjgOPTKWXZkGYhfDqOA7jKixKkJM2GVHVLvl5hjSGJok23G87RFBVxFLM1GPD46JCmbdFK4/DkZYk1ljRLscZijSVJNINsSGsNeVEGAznvsa1htlhQqIK2aYiTTivgPd4E89QkjpnO5z++CeKPyHinWI6L7NMPnzl50gbggwzh6Wjn9zIqeBczzg9QnvtB9sUHeu47Ybz30G+d2wiErEvv5ZOaJh/28kVAEgTeNXmxpuxSEpRSoXM8CqW74WgY0hqyAc4G8fje3h5aKe7dv0++XiMuZPfVVU3TBlfuXtukdXAI70OCkzhGqqB1xfuNrrFnocLnAQQMhkGrlGYJ2SDtGPhgBdBLEJSSRJFE0oM0h2lqHj28j9aK7a0Je3v7/Mqv/Ao7uztBKL5YBCBYVUzGE4xUiLoh5Ks4wOGd7WQULXXb0nYu4kBXqqPzx/spD+ydns3Ae2KtiLVmejalKQv29/YQ3vHo4WOkAGsapBJ4XFdDlSRpTGuCA2oyyILrcduv8FvKvGTYefMEus+yLkuMtQitibyjrhtWy2VnJKawVrDOS/Iix7QtV62jqhvOzqZcuRIziBJME/LwoihmNp+xvT0JWhXTdheAwvlgYU8jiJOQzCyEIh1kDEdjdEelWuu4e/cuUmriLOmCiQON29Qh4wcCOAv+FII0zUJenoMkjsGDbS29WLPvxDg8PEQIwf7eXmCabGCU0kFG3TYsVmuk1iTZACMEUZJ2IZHB58h3OUHFKqdarajzgrRYU/mGw5NTTPUcg1HCz7/8PMXBDq5Y8ta3v8mjt+9Sn51w+8UXiF3Ddqz5lZ/9Asl4wODqPp/+4ueZnpyhhWT7uRdozub8we/8LsvjE1aLJbPplMlkzP179zg6OWZZ1UglmR8f8/Dtt1GRYqtzqVXCkykRxIlCsJyeUq6X7G6NUEmMxrEzzFglitl8SmMtUZzipaKpDUIrhHIgg+Gb8AIkJIOYrErZ2t1iMV1gWo+IFLQOYw2DJKOsKpwHFQUn3LatQ6lORyilqZsQM6BURJwmIUbAWVoBxJrWBvDSVgEEBWldEIOa1oCHSEYIJ6jqCm+DqL7M12FCNS33H9zn4OCA4XiE7cpAweVXBCNW69FSB7fdxlKbFidASB3sJawNtgt1CVqCVrTOBBd6CbFIaH8IdPhPw3ga+3QRPD3NEfxHtGUf+BnhnhpCVN/z2f0D/CUd0zuAmcvg5LJp5Ls99glfpL489n4/Xvdw3ynVRc8Q8STIPY/wEJt/4T0DGx/e/WJG3zlgklJ1i1SL9QbT+Sk1psYZj8cSRQqtFWmWcO3aFV586SWSON34E73yyivs7+9z4+690ATlLE3ThvLdcklVV5RFyWw+33RTGxNiTvp90wvBff/ZpNgYtm8+b5f5JqUImZw+sElSBBfwIF4Pwe1KSCRBElE3NaY1VEXJ/OyMe3fuMUgybt68ydZowttvv8W3vvkaWE86HLM3mkDb4MoS6QJoMtZuGLJ4kBG3hrwsA/u26SQI+qfoPXILP8x4pkBT27Qknbo/jqKQB2cdbVUxGo1YzKfUTUWk1QaVqkiB9KzWK/b294NJovOs12uEUOxsJaxWa/LVmq3xhOPjUxpjQqdR24YW9qoMJT2pMT68tLCevKjQWnaW8Yqz6YyyLBBKUVYVUZyE9n/r0HGIKlkXBb1rad3UlGVBawxCCmzjUWWJD8I7WgABAABJREFUVpr5YolWkuPqNLhpSxnEfOucrZ1dnIeyqYlc0CvlRRlWIlrjOgOwwKwqnLEbozBc1yoqFFJK1us1UkquXLnSrSgEzjuSLCUbDkJgZFkyGE2Y7OwQZRmnZzOs86TJgN3dXQ72al59+SVefekltLWYsmB+dsLO7gSnPY8PH3Lt1i3y+Yp/+d/8E954/XVuXr3Kjb19Hr55i1fv3+bGzZu4suC1r36V1WzOG2+9STYZ8+nPfZrjoxOG4wnPPf8i1196hb3dKwxR3LpyndV8zmQ85tXnnicvC+IkQSjBg0ePuPfgLtY7muUCKz061nz+859jOBjz+utv8PD0BGtaqqogz5f8+m/+a5brJaPRiPH2FnGWUpYFq3WBAeIsDQylDUDB4YjSmIcPHzE/W5FmKYvlnLpuEVLSNG3XJVd2K7KgPyidx7pQgsMJbFvirMdb0IkjbyoaZxiNhgyGGa1rsRDyAaua0XiM0jHWedbrfKNp8sYinEf6oDNrTBv0DMWaZbHk+OyYrb1tEuFoTAhtTuMUiaDIV5SrNRLJOi9wwuOkwBDiCaRWodPUeQyO2jYh10mCseGcaoWH6KOfpP6oj6eFx1prnyjH9df2jw449V5J/sK/9/vU9wmYLoxNweviW73PF/hQnXoevHj/qMnjnwSzUpxjLtmBIwHWu83fhRTgglQEPF4Efz/vzsX1IVct6IVUpJBKYisTAmqVIEoUQkLZlBhrUFqBCB5vu/u7XLt+NSQUNBVSCrZ3tnnp5ZfZ29/feCydnJxwcnLCarXaRLN84xvfYD6fI4xAKolC4QXUbUNj2k5Q3XXLCbnppZSiL9t5qrJGK4XNLGVe4D2kccYoG3WsWIQUAtyF8F0TGPbJeIsszrj75ttor5hkY1555RV2hhPmR1PKqiJNYmKliBVEOLwz5Os1J6dnVE3DYDhiMBzhEJydnXF0dETbNFjTdl3eP5wF3DMFmnQUTCajKMJZhzWGLE45PjqhqWviOCJLQxCv6XLK+pMbIRFCIbvOObzAGMt8Me9oPsPpdMbjw8fMZnO2trbDit8aVkXBbLFECo01FuE9Sks8IexXiLBty8USa0Opa7lcE0VZKCWenbG1tYWUkpPjU9Z5YKucc9RtTRTHDIYjEFDULaPhmLKogk9QEZy5t8ZjjLW8decuz78oiOII4ywISV1X1E0IlM3XOVVZMhgOiXTEbL6gaRqyJGW9KijLiqoKJcmDq1eYnU2xzjEYDJBCsM7XrFYLdvd2yYoiXGTrnLyoSLJDssEQLyRRlDA9m1GuKuJI8y//+b/gv5fw3I3rYFve/O7rvPqJV/CRQCrBcrnCt4aibVlZw6xpGHvP+MY1PnXtKrv7eyRJwvp0ipcRv3TreXQa8av//Fe5fvUqi2XJ7//BN1h3+jMl4MXbz/OJl1/h2sE+3jnyfE1rWqQSCGu4trtP4xoeHx0xPznlu996HWFa9vf3ee3r36KpPGXVIrFcvXqFva1tfuYzn+HzP/MzCCl4+84dXnv9O1TOMRkOEUoym85Y5znIoPlaTmfcu3efew9mrE6P0UpTN23XvhsTaY1L0+Aiv85RWlHkOUVeBJv/1tE2hsl4i/29A3Qc42QAxNJYTF5SL9dE4yF1Xga/sOUq6I0QGOuJk5iGBiUCQC7KvNOetcyrHKE8+/vbJLNj5tWKo/tnZNkQ6UAhwMDybEYaRWQq4vBsRustUZLQekvrPUmWYqMG10UHTfM1cRqRjTJEt9rUUlG75sc9TTwz453YkV5H0v/8tO9/dENc+vo+H06vLf0gwz/pX9QLpd+FMdo889LPT+s+fPKdnvjfBx4X2aUeHl4sm56/t++Yl04zJEK+5jnDFmKvemYGBEpqpFABnFiP7eKhpAxApo91ajpn8NVyzeHhIdPpFCEl0+ms82QKb9L7NNV1TZZlXeyIecf9473H4ZBehq5gQPYdncJvzlFrLbLbzz5UHs8ZN+vw0uGl3LCGsjOKxkMap6RJRtO0rBYL7rz1NpLQGXh8eMyjR48wpmV3Z5tRFhPrUG6z1lA2NdaBFxKdJCRphk5SBsMxxoRmnbosWEzPWHTh8R/leKZAUxwnLBcLykjB1phYj5FdZ0FZ5CTRmCJfU6zX7O/vkaQJs8Wc1luywYAHDx6SDgZB3d+2m3bE0Ang0VHMy698giwb0BrLfLmkKEt8kIfQVKEkk8YxxraByVKSKFIo67BtS5YNaVqLAGbzRSea1rStwTnLYDhgZ2cfj6eoCkSliZO0c+9eY51gPl0ynU65ce0aQuoQ+tq0mNayu7NPVdSs85KqqRmORkilcF5iHGSDMVIlob7sBNlgxGAoaeuGxlh29w+oqoqqrCnKhtFkm6LMWaxytFToJGX32gAdR6yrkjRJGG/vMhha6qqlLS1xHDLvhIHVbImUgtFgxHic4XGoRPP8J15i1VYcH81QWnM6XyO8wEQZz3/uZ0iShJXzuI7jXlcNKYJ0a4ud/WtsjcZ8983v8Ut/+t9jezRmMZ2zu5jhnMU2LavZnDSJOTw+5stf/m1wjuEgA+HDiizSXL9xndsv3EYIiVQQac3dO/e4f+8B3iuG2ZD7d++TaEG9XvM7//1vcOu5W8jWUVc1b7zxBnv7e/z5f/dPA4Kv/N5XEDLnpVsvsLe/h9IRRVlwduN5vviLXyJOYr7zxludnYLYdLQ8eHCfnf09iqJAe8HBaIxNBxRFRSUqVJays7XFMInRSlO2DdJ7EilRQmCEQLcWbSxZlm66Jsu6pWhqdgdXQntyHNO2hlZ7MA2alO1RaBP2g4gmlpTC4hPJ9pVdFmczpmdzJnHGrWtX+eTLr/D8jdv8mnGczaZY4amtpSVMivUqaCGyQcYqL4mVpFisaNsG1ZmEnpyc/PgmiD8i42nxGh/167/3uCgG78HBezzjiU39wSwHflij587olVof8O16sPQ0Wwh48ngFn6NgCNl3wIUH2fPHCEJjDuesovcBZGkdE8CnRKkIKTV4iXNgTDDXXa8KlssVp6en3L17j/l8ydtv39kwlUop6s46xHvPaBTMZ6uqekfQdFk7J3zPkp13KV70cHLe4QKlFkqKbRvKlz3Qu9QJqjp9lNKaSZLiPTx8+JCiCOkDZVFy+PiQk5MTrl7ZZzxMiLRnOEiJ4hjnPEiFLxREmhESpTWj7S2sMTR1iVSC9eqjB0zwjIEmaw37+3toJVHCU5Y5RbkijSP2dndQ0rNeLXDOduVkgXMheHe1Ktna3eHsbI7WQbfhnafRFmctWmlWq5zVMmdvfw/jHat1jnE26D2kIIqTrrOI4LGU6K7lsiIRivHWDk3Tdq3qDoEKJmRtTStMlxy/oGnrQOV2ac2ZEwyERAoNPmzzc8/dRiAxrqUoVqzXJbdu3KIxntPpnMn2FtZJFougNXI48rIOF6fo/ukYJWOiKEaKEDXSWo9HIqOUxoT6bxQPSX1wfFZxTGNqvPUInVC1nvV6iRKayXiCcILFbEFTW7YmO0RdWzreBV0TgsVigaNha2ubUW3IsiGj4QhnLRZJVeZkccbe1jbVOmc6n6GiiFoqqtowW055NJsTb+8wm81YTuekWiGHA3YmY5JI893XXidKMyajIc9LuoulpixzEpkSJRGz9Zqz174ZDEq73LednR3iOMY7ifAh3wgv2Nra5j/8D//nHB8d8fjBY2xrOH78mN/5zS/zm//613nxxReYbG2H55YN3/v6t3j86DHj8Rbbe9t868tf5c233+bVVz/B7vY2Z9MpN3b3+Pk/8Sf56h/+Acenp6hritFowK1bN8nznD/4gz9gPofxaMxbd+7x+197jTSOeOn2c4huRbe9s01rGtI0YStNwFkmOzuUrUEAUSSpV0vOjg/RccTWzg5gWK9mOB88oaJIYU3D9RtXKaqSqigo8jXWtHhnKMuCWdPy5lvf43vf+jbT6RQnQHdhzdYajLfEkQLnaauS3e0JSktaa0BplJREUcYgHbBcf/R5T38Ux0+6t1AYH6zMdlFAHp77/stfFyHNBx3vFIb7wxiXAcbTAO6TxzWIv50P5UcXkFKnKfWdSaUFgui6LBtAdbEiI9J0gFIRWidAVzGRGlB4LxBC4Vwwqj09Pd2wTD1g6QGOEKIzxpQbL6be4uJpn+2dPw8b7ZDvFqlN2+DxVE1F1dQIIYg2zvWdQa5zOOtQUlOUFVtty9bWFgKoyorT6SmDbMjLn3iVdVlQNyHpwliDsxalRCh0emhtxTIvOJvPSdIBURQyUU3bUFclVb7GNvVGd/ZRjmcMNLWUpUVJSRJHZEkwnxykKdkgDYnOoyHWWtZFzuHJMcZ6rt24yWy+IM8L1nnJtWtXg1i6KLE+1FgjnZAmGcfNjOPTs+C6bA1RkmBcCwiE0AjniXVEXTXUbYvWktY4WlMCgiIPQjdjLQciIOq8rLDO09Q1CDC2JU5idBwjncMax2qZ07QtcZLgrWe5WGGMwzYGiWA0HOO8wLQOJRNWqwqpFIPRiKZtqcqC8Xi8Me40ztEWLetVvWkxdc4jlAk5bU2L9Z6DvX1a01LUoatKoXAiYrkONvxJFBPrGK8UVWPx1mMQeOOo6hWDLEPgODp+zO7+Fq9+6iVW1QK8IC9yYi0xVU1pHJHWJB6s0Ix1gqwtdl2TeI0ips5rnGlJdRRiTqSgEp7dnW3y9ZK1qakKRxZFJPtb4AU+S7jy8os0VdCHResli+WCZROaAVznQ+QRWKUoF0u8tUHL4xRECpnEPD45ZjQYhkBo79nb3eNX/o0/zr/zP0xZLpZdoOSAx48eoZVmcCVFWoH3gv3xLvnZim/91u/xva98HSchThI+8+nPUh5PefveHabzKcPxEKng7MF90jRheXzEdDrlc/+jP8n/+n//v+O7X/0ab775FlevXOUb3/g6xyfH3Lx5k4f37yKkYH97m6OTY6q2ZTSesLO/T2Mt67JifziibhrK2YKqKlFNjZYC2RqMN9R4RllGEiUsREWT5yRKIYdDqnXIbbLWkMYpo/0trIDG2S6qQeBdyJ6LpabI10jVdX12ItbQpqzZ3t7m8PRj0PTTOX7SAeD5+LBberlz8bJB5EU2RsrACvXhuwIVWF/RC8qDrumcudIIQre1c9A0ocFDqcA4CSGxFpSKSJMBcZSiVYzWCUkyIE2GeBzWuk3Y7uVYnj7It//dReD0NBD/bsBe6yBKRxC82mwo2TnvQiRUJ4vzhJKkc703ng1hu21LFMeb0Pv5YkmaDnj+hRcoqpDbuZhPqZoKrTzaKGSXV2espagajHVIHVI8dBThrME0TbDJ8Rb7QzgnnynQpFTwhPDOI4RBqYThKENHinW+oqqLkJcjRZfebvEI5osleVmBFCRJxmpVIIWkrluaOmcyGtM0hvX6hKqsGA5HRBG4qsQYi/WdAaKpwQpGXQkuFrpbyQdDsLpqg34oG2GLgtaENk4VxZ2mSjMaDbHeYF3wTqrbpjuhQ9dT07oQcpjE2LrEuuAQ7Rycns66NGvLeGcb6xxV1VKbButgnZd4l2OMIYpi0jhGqiicZKa/AVoQQRwfSck6D8GydV0TJx5tE7RIQqdZEszVvPMY47G2oalqTGPY3t6mynOEDi7pcRIzHA3RkQrZZK5FDwdUVUFTGXLnSaPgGZJIiL3HrHPqVbgBG2soihwlBelI0rYNRjmKKmc8GVJjKH1LkVeksSZNYlrreDA7ZTIaIb3ASo8YDoiloMlzWtsGj5IuAdzHoWxpCZYAtqpwWuK04tHRIdevRbQCTGtQRQ7edQC5JG9KsqqkMpZYhLiB5156mRdfeJHbz91iECV88Ys/x+OHj4izFC8E0+mUfLUmlopBFDNOUu7ceYt8NuOll19gbzRiPZ3y8Dvf4Wu//mt877Vv8fjwmD8sf5+T4yMGacw40igTWKWTu/eYzqeUdU2cZBgvab1g52CH3b1dhoOMbG+XJAkxCuvVmqJYU9UlVV2Tz5bsXTmgmi1xVY3rzOLi8RDTBvq+xVI3ObWziDhCxKFL1FuBTIIPV9s0xGmY7KomxMBYPFKG8N+Pxx+F8ePQT/0oRsdm+U2B7gONiyCkBxQXOx0vap2CgaQAH7rNJBKEQMpes+bBy00DTqhkQJa2tI3l+OiU1arAGof3grax1FUNXpIk2aZ8J4QkidMQrWUNWgf9lOsWPRdHH39y0Sj1MlP2XpYX/fe6s2EJbD1PgEcl1eZv3vtzmVr30YMRpQUZGmSUlJyezTDWsLW7zYsvvchyueBbr61YlisQwZPJWIuOFHTZoVVT45qGqG2J44SOi0IohTOG1n3089EzBZriRId2bk/QI0lBkiaMxsH7aDBISOKOaWlbBlGMR2K94ODgCqdnU7RW1E3DMBuSpUNMGzLCTB18MfCSsmo2Dt3BFaI7uZwgkiEOpW1aqjKnqjRRFOG9p6lDnThNhzjnOTubdaxOQ5YkVFXJ1tYE6yyyMworqwpPYCasdeRFw9aWwBlHvl6jhEQJRVEuKYuKydYW0+kMnSZ476lMS1nXwX+naxWVInTq1W1gqYJhGTSNxedl6K6zjiiKWNYrhJIdGxVcr4UXFHmBVIo4CasIY4KJkHMWpSVOC/KmZOAy2rogHiQIBcdHx3hruX79KrPZlJaWeKDwxmJdjXcN3oMxEViLogEXnKtHWXBAb4slToIXNU25oMwj2qbG49BxDFqxKkvqqkJ6UG2DQpKv11R1HdzApaRuPR6BjCI8YKUKeYCpp61rnPWk21tUeCo8q6amtS1VU7JsgxA/jWO8d+hS4adnrJZL4jjBIyjLhtcePmBvZ4fJeMx6nbNezLl54zpJmlJFkvG1fdQoYd8fcPP6NZxr2b92hRdeeoHXv/aHrI6PObWerzUtR48fc+vWberZmqhuoW05fesOzgWTVA/sZQOO1jn4mrpuOZwtOXz8mHSQEXXXQpJlYbK0hkRJBnGM8prZdIUYThg4yc6VfYSS1MYQpwlewGK1pm4bVnmBkWB8g2mCniGSisa12MbSOoOtPTqJiNIYnSYYH/y6iqb+Mc4QH4+Pxw9v9KzM07r1Lv4sZWCWvBcoabE2sD90uqVQUgv/QiBuMLCMo6RjiByz2YJvfOObVFVF07RoHRpL6rohzwva1lAUJdPpjOVyxWq1xlq/2UbwG9DUM2BCiM7pm01Abw8CLwKmy9EqlwFh/xhjzca2prda8NClLViEEZ1ZLpvX852sK4o1UstwL/aeKIlpTcN8NWc6myKUYGd/l/H2hLJc0rYVTSuI4hjRNZ2opoVGYFoTMjBFZ3QhQEGn6/0pdwQ3pnP9NAatJNYZTs7OqOqKJNYY16JkODGkgDhOg8ZH68CYeMiSAc4LBqMxo8GQQTagrmqkh/39fRyCdVFQNzXD4QihVRd1IvDG460niiMmaoKQbmMrr7WkbZLQWp6kWBvYDSUVg+GQQZYxX8xIB9nG4j6OgyePR5CmKQ4oi4ooirGdUF0rTZIGhqYqq9AmniYYEwy9jAtaInzII4u0RgBaRxubAudCllHvJOts8JmqKkNZ1aRZSpIEM8zVao20grasg8heSKxw1G2DF8HXxwnH6WLKMl+RZhHr5YLxIGG1hpOTFbGCqQqt9vP1nJ2dbVCOqqk7PyeH8w2+anFVg9YRRVuTDTKUEOSrJRZopEMLS7mc0RiDQ+KVpmna0I1nHVpIlKoRCNZVRVEUAThqjROSyliEc5jWUjUlgxSE8BuPKa8cy7JAJAk1jsZZKu/x3mJwoIJVvxYSIyynbUFE8Lk6nc/4ztERSkum8xll43HWsDvK2BqNmQyH3L51k2K1YhBHGBmAWdYaTmczTk9PsE3NQGu0dcQOnrtyBdFY7GpNXRdUixXr1RyEYnt7i1RpXGlQqeL5mzf41Kc/g9cRPs1oTYPA0RjDo8ePOXzwEO3g1sE+W1tb7I+3WZ/O+OIXvsB0teb+44esqhIRx6Aky2LNlesHRMKihMA3IWoIAVoHwCwI57/tOnh0ErIUccGSITYt8HEH3cfjj+boS1mXtUBPsjRBx6PUuW4pMDqd5+cFp4Pw2PA1gBuFUhF5XjCfzzdWE95F5HmYy5umxTmoqpr5fBE646oarVWwmRHn9ggXmbBzQMX36Z0uj4tlxv5f/5xQdrRhEa4CWJJKhftFF5ZurMO6Nvg2bWJaeruMMIdEcUTdLfjTJMELzzpfc+/hfba2ttjamfDiyy/iXcPD+3coqxodx6g4wXU2CJ6gDfbCBsNnKZCekAjhHOanPUZFiM5TwiuiOELgmS8W1FVNmsY4QrhtpDWDLAOWeCTZYMTDR4+p6prRaEieF1jnaZsWU4fymHcuMDUdE2VtoA6tD54pSmu8ILizWsd4PGQ8GQaPpDzUVOM47roGII5D8nSWDcg6N3EVSQ4O9qmrqus46uMqwgnXtk0IC25aJD7QrEIyGY8QQrFcLEjTmNGt61g8i+UKbyHVCTHBhkDJYDSWpSlJkiB6wzIlkEm0ySMbDQehdV54pAorHq0VzrZEImacZSBFEHorEFpiCflkzoeVxXCUobRikGWkaRS8s+QIZ2oOHz7m1U+9wmw5Q3R0rYg9SiqsCUxXVVXEQpEkEaYpqdqWyXDIYDBktlzgJNy6eYPp6RmpjollhOkEj8IJdrd3Wc4XVEUQwCdJhlKaoiyxzqG6rkWlFFJpTBNadJ11eG9IhkOKoqapgsN1OhzipaTtQpnRCpFECGtwMrQCTw72UCoE3SaTCR6Fl4Lm7e8SG8dyseKsKlnUNclixtFySpXXxFLwvbffCI7sb95hpBWRqdBpRilgXpc8Xi75zp27PH7wiLOTM8bjAcPxuANDChVFVHUNQtEay2g04vNf+Bl2rl4j2tkhX69o65K6qXnzzTf5fdPy4M597h+e0DjBYJjx+OEhf/6zn+cf/Nf/hOm84OqtG1TGYIVnHCXUjaHtLBMyqVFRTNu0+LLFCodwAokgimJWxTpoFKSgdZbBYESWJMDqxzhLfDw+Hu8yNtU5/wTQeT/jacLodzPW7Js5hAzXTAAdstMyhREwTJhvQgC32MgovIMoDdEkWoVFcMiPlESdbKLpuqqV0ozG46DpsQ3GtJ3AXHwfcOrZISHEE6G8F4HSO5UdLz6m/6q1Jk1TAJI4QUkVFrndIit0y3VMkw/aJqWDxKFpw5yslELHMVW+5uGjhyAlV67s84W9n0VJz/HRY+oyVBKQGi8Uxli8F3h6H6kOvBHiqdyHKsC+93imQFPWxYSkScJgEKJTnDGMRiOcC34XbVNj2hDeN1/MKaua7e1d8rwgz3OiJObo6JiqqcBBpiNu3LrF6dExbdMwGo+ZbG1jvWdd5BR1hXWONBuQxDG2alFCsphHIFzQ7DShu0lpjZKKOE4xrSGKY7QOq/K6bljMZzz/wu3QjWRa4ihGSIkX4LynbVu0Co/3JnQARkoTS0lrLCcnR+xd2aNpKlQSAw6lRHCKbhrA4VVfSw415/FoGPaTd2gVMRwMmC+XpLHGtBFpGtGYFu+D2B0X8vaSKKa1LdbWqCgmS2Ia11I3NVkSc+XgCm1RgrXs3byO9JYs0ezt7NA2BVWV8/KrL2GMJRsOkF6E/LI0IS9LqrLEjFtSFaHjCDmaIJVib3sbaR3i4UOiQcYLLz5PW75GkqSgYsq6JZc5sVdc379KqhKqskQoFfyK2hbrPHmR09Rh8lBJymg0Qg2HeOuxpgUckY4woqaoa2Kl0EqHPDsXnLylEsRxhJAxeZnTNg06ivAIiqrEGE+WxahY8+qrr7C9vcVisQz6r8awWi7Bu+CG27YkGtoIjo6nnM4rBpEmSyLy+ZQoX3K6XpM/uEMxX1I3FWmyw+0XbrP33A2SKKGsCqRUbN9wlHXDYpXzm7/9O8SjMel4hHWhQ3M0GrOzu8u/+cf/bd648Sbf+d5bHLctg1JCmvHm/Xu8eece//af+pP8uf/Zn6GuK1QScffeXf7v/4+/yXx2ymQYk0UxkRBo4yjzirIJWggdRUz298iXa5QKzGpjLJFQlPbdr+GPx8fjWRyX9T496/I0IHX5q5IKoVQXpyUvmFsG/6Xg09S/NgSWSiKl2sSRSCWxxmDahqYJc3VRFqgmlOmF8CQqJsoSnEvI8zXWytBp1rZPWgR0HksQ3MAvf57LlgP953CbbrjQwaa1Iu6E3H25L4ljojjqynfByLN3CH9SKB8ApCekFdRNjY40CMFssWC8vc3B1QP2rxzw0suvcHp8xHx6Cj7Y1DgkSZqRZpairqnbZvMZXR84/EPqTH2mQNNrr30LCCGEcRQjBGRJHPJ2opDvZkyLNQYlJDpKuDLZ5eDggLzIOTo+Drk4cUSUREgEsdIY0zKdnrFYLHjuuef45Kc/xWAwYLpYsC4LkJIsG5BEMU1RsFwscNayWq+JIs3+/i5lVQT37HWBcy2LxZzRZIIybWjBd46mqTk8fISxpkPoCmu7IMOuKyCclIJitaZtQrpzXYRsuePjx2zvbTGbTtFZQlFUDMdjpNKUdU2WdloW7ykGA7xzKCHJ0oSQXaTRwnP46D6j4QghJdYb8irol5I4oVjmLBvJ9mSbqilpvUElEUSS1hkaE6hUYRpmp2e41tBcu0GkBFka01QVZbkmShRf//pr3LvzgEGa0dbBuyPOMhbFCus8WZwiPVg8XkmMs0znazKlWS5zfFExHE84PZmhVYxUumNZDPk8Z5EtqYqwSmnalqIsaa2hbtugRXMhbTvP1+T5it2dHSKlMbbFmZZ8uaAua85OTzkD6ibkr82Xc6wLLfZ5vkJqSVmXtNYyGAywDlbrnKo2JHFGNhwQRYJSBzZLSkWUxggmIGF7MmE2neK9ZSuL2LqyT6yT4FzflXadc6jdOaM0YefmDZZnM5zSrKVHCMm8WPH46IxYKUbjATt7e8QOyqZlsrPH1etXg2u31uzu73Pz9nNMtiaczWfcOzrGComwlna95p/8f/5rThcL7j14yK//xm/z/PPX+MLPfZF/+s/+OadHpwjf0jiHcTWRhwRBiiBKB4hYo6KIF2/fJo5iZos5GEckFViPbz9GTe81fhJtBt5rm348xpo/mvFBjkd/87/szH65HAZdi73zCKlCnIhS9DY43ashRcgFVQriSCOUIjKaoEnqopTaZsMIGdMG0OQ9dROCeJ11SBU687LBGK1DI0gPgC6Dpov/LgIheIrJ5SXh+8WyZF9aDFYCduN7qJWCJEK70CQkEJvP3r9X0NUasmwAMqRjiM5FvDEtZVWyznPyomBvb5cvfekXOD064vT0LFRYhCDLRiAV6zznZHrGbDYL+8labAcGgxHWRzueKdD0rW+9/sTPUshQrhqNyAYZ3geKUytFmqRsbU24efMG6SDj8eNDWtOyzlc4B0oHgXXp4OTkmKIIQaTL5ZL5bEpVl+RlcGAejoYMh0OkEChnGKT7ZFnKo0ePiGLNjZvXOT4+ZHd3j4cPHrG1tc14PGI4miCVZDweEkURcRyztTXh+OQ4tIXLkEHnBUSdEHswGFAWBXZrQlPXbI0neOtYLucMhimPHz+grmrmD1dY73nxpZdw3vPg4SOGwyFKyuCOHsWUec5ysWJ3Z4vhcIh3jjRJeeOtN3nllWC1v1zNKaoKHUUMBgMioVnMlkzGA9b5krxao9MYFSusd2EfF5pmuaTMcxZnc47v3qMscuqmQShBa1pu3rrGcDRgdnLK9avXKPMC01qy8YhVHTL9vIX1ahnKhTtbQTzYWraGQ5QU3HvwgAcPH3B4/yHj0ZgsHaKjkPK9XCypmxA+mQ1TrHcbYCO0JM1S4mGGTDSmbamKgrP5WSesDx2WAhhtT7CEuniUJEQxbMngfBvHiqJYo2OFTjRVXREnUfC50pqssRgTLBFOTw45LguccwwGQ7wXVFVLNsg4PT0L7u5VibEtuztbTHYmRFptJiFrHPtXdBDg64hSK1ZNRTU95fDxCflyRbEuONjbwRw5tmdLtra2GWQDhkDpwbcW6aCZLznJvx3YS2cROiIdZGgPaZaxXhcMBwMe3bvLP3z7LW69eJ15nvN3/4t/QKQdf/bP/E+4eXCV2aMTTu49oF4uQ7k4yyicYb5aUZYVCEdZFTTWEmdZsMj4ycMDP9Hjvbya3sk35+MRxocGc/79l+cus0xPdIU95bE90Oh9iXxnDtt3yQVtU8hoE1KgEGgtg7xCK5xNgBBN1DQ1xrREUbwBTQFEOYwN2ljv++QLQzZIECKiqqquizraOIBfFH4DG/PLp3XLXTbpvLifNkJw09LFnTKbzVBSsV6taZoG3SUieB90wM55jA8WCiE1oaEqS3Z2gtdfXdcIKUjSFO1ihBAsl0vumIabV/Z46aWXGSQpTd2QFyVxmnHj5i1Gky2WqxXqrbcpipK2qnDChYqAtT+U+eiZAk3//r//75EkCSA2FKcI6YmbLoEeyVZlyeHjx3zt61/De8/R0Qn7+7tUdYFAYp2lrhqEDTdHLRVN1XL/4UMeHT7GekdR1WSDAVs72wE0AcrD1SsHbG9v8fDhA5qm5v6De8zmcw72dzk+PuHFl15Gq4j1cRmCEJ3rNE8R125cwxxaVqsVWZqFGrsUlGWJEJJXX32Z9TpnazShLiuuHBxQFSVplpImKYfHh2xvb9N6h/OOF196meksGBleuXKFKwcHFHkeXFzxHB8eh+DCKORWRVHMdHbCF77wuVCWPD3BOktrDPPFApxgf3ePl196lYePHrLMl+gk6sSFwf9DenDGM7l1i7tv3aUsCuoipzUtpnEUVcUtHZMNx3jreeHFl8jXOWVZMdreomwbdByjpSZfr3F40tEw+EUtFozSAVevHIBWXL12g53JNmmSBtM3L8myjN29vdCG6n14/mDAZHeH2raUdUmSJcRJzHq14vrVA8qi4PjwMWW5QgmFFMF5t2wBJRmMhkihiaOYdDSgLEvKqsQBtosr6XOZhNDEcUaaRTSNYTQa8PhRw3g0DOXIrhHAGc8wG3D4+DGTyYjtra2QTVc1HD0+xhqzEeyb1rC7vUNrPekwZbi/h1KKg709hvs3WM0XbI8TXnj+Nocnc/LGYJqWvG54MD3jcL3Eeo/psqy8d9R1TVlWrNdrynXO1nDAtYM9fu6LP8etW7co6pKHR8ccHZ3w5S//HnVrSJKIV159lVdfeImjyV2GeJYnIRB61TTUrcEIz5v33iZKYmQSk0nJYDJhuVqTCQlHsx/bHPGsjXe68V4eF29sPyh4+jDPf1rH2LuND7et7++xl0tiH0aT9GHHxRLT017/YmnOOYfztjOHNbRGbuJIeofwHkA5p4DgDB5uYW4j4wi6KLrct+Dv5Am6UOdDd5q3Hu8NRZHiXEJVVRttVFVVQfrRdVf3Xy+W5d7p/LoMLC8yTnXdvYcxPHzwACEEs7Mz2qYhTVNEkmzYKN/FhsmumahtmnDPI9wCe41VkgSZSxRFrPM1JyeHxNJz/eCANE03ui5rbWiyGgxASEbDEXGUYFsT7Bt0qHqYH4IFyjMFmv7yX/m/BJpTarJ0gBCSqgylpWAsuUYgGE3GnJ6c8I/+wT/gP/87f5vDoyPiOKZpGlarPAiWkyRolNoW5zzLdWjZtEC8WGKtZbUusP4M+eBRONk7P/A41mgtQxSLEGglqRrLeBxuLt/53ttY6yiKGo9gMEhJ4piqrrl16yZnZ1OqsmIwSIm6Om5Z18RR0jnGekbDEYv5nP29XXa2d8jGIwaDIeOmYv/aNV75xCfIy4IkicmGA3Skef7281y5ciVk4JmWydYWxhhmZ2fMpzMGoyG7e/sMxgM+87nPsbO/x96VA/Cek9NTjk7PWMwWfPITn8SKEEY7ibdo2pblaon3jkGWUtUtq/kSb8ELwWA8Zmt3h2w4xDjLg0cP2dk/oKprVJzipKQFWqBqDY117O92q4csQwhweBbLJQ/v30cj+dSnPslwe5vtrS1irTk9PuH1b30bpTTXn7uFMW1gm9oGISXD8ZC9gz2s8JycndAaQ1VXPD48Ynt7m/2Dq6FrUEkirZnPZsymSx4fz0liz8g6trd20WlY5ZmyZLEu0ZGiLINJXJKNQmSAA6EjkiRlMAoGq2k24Pbt25yeThFCsLe3HxoAlGJ/f5uyyHnxpZdo24blcsFqsaDI8+DWLiVVUbK/t8/+9avsXrmK7zL/JqMxv/grV7j39h12drcZDoa8hERFMXlesFqtGYxGXTi0xzobgJ5pWS6XzE7PmJ2dkS+XjLIMrSU/+/M/x+HRIb/6L/4lDx49BiU4m32Zg71tyjzn7/7nfyeYnJYVsq1JhCSKE/K2pRGSnav7eGcZ7uwisxKkZGtnBxklVPXHnXPvNZ52c+p/d5EJuNjifjGG48Pe/N8vo/VurMNH8V7v8qQP9VofdnvfL9i6vG8uHpeLfw9MUgAJxpgNi6wUm25q7z1SiMCEeB0qDV2JzHtLFEXkXdXDeYfSgRF3roUOgCgVLGScM1gbch9DBpwLuatGkSTJptN8488HwTdwUyZ0G7KhB1KXS4yX98FFcBXHMbLT3Z6enGCMoW1rvO9YsrZFK0USJaHbve9k63w926bh7PSUOE1Dh/qGxQPnLU3dMJvNOB4NOLt6BecsaZqFQPvZnO8032H0+DFSaaZdac46hxKhe9xpi5Qf/Xz0TIGmuip488F9tsZbvPzyJ/A4zo4PGY1GTLZ3ODs+pi5KPvGpTzEajqjLmkcPD/FC0LaW5XJNUxvaxlHXLRACEYUPgrVejbFYrroOAxc8mlqHEwETey2RFprWYJ0nSROsF+RljcNS147F+gS6PBwIppPBmgCWq7xzavXIefBUQgZNCx7+6T/7VaRU4AXWhtVBlqboKAIJTR00WeOtyQUxX4uzhsFgSKRDt5PzoRvQ42mqpltphJpxnq/45//qv+ucZ4NtQlXWzObBXOwrf/BV6LoGEZ62tbRtC10nhHMO25rQUditEJRSxEmCkJDnBUIq8MEBdjTMaNtw8SqlsN6TpRmj0YhIR0F/3gnh66oED8PhiLqpSRKNdwJjWhazGYNswJWDfdZ5EHqDJ8syJpMJ2zvbCCVCNpsMER/5Omcy2eH5524zTIbcvv0ct59/IZRj65baOCQeFUVEKjBfy9UKISTD8Sh0ZfruxqWD6WZeFJR9h57UaKV4+eWXuXn9GicnJwgh2dneBqAsc9q24e233ySKAv08GsWU2yOmZ2eUVUXTNAjtkRFM5zNOpjPQCScnU+bTGVeuXqHIc6TwSB0HLZQK3SPgSJKE1vSlgG50+xNj2d0eceX6AY8fPGBnb5fX3/4ub7/9NpWruXn7Kk3T4PB86hOf4N69+yxmZwghUQhULEFpShwrZ8irhpO7JXXbcv9oivOeJI3IzhZBU9Z8bG75g4zLXVkXu5h60PRObeJPG5cB2tOA2kc1LguIP9RrcCG4RXw/AHq/zNz7eowAxHsHvTwNHG0AUAc4Lh4nOD9G58xUJ1J2BuM8UoGTIeWibZvwXAmRicjzNUHrGKF1YKOMMXgRmH6lQixL6B7zSHXe5Ra0T5YsC9Y2ZVl2QcCBJYvjuPPkO49Sufg5nwaaLgvEIfhKDQdjxqMBgzQN86G1XNnfxbSOIi9Yrla0dUMkFDqOghO683jrcMaSr/OQx9mRGLoTkSdJQhTHWGsoyoLj42PeeitjlKT4LkB4Np9z9/59hNIMhiOs811JskUohRAKKeTmvvtRjmcKNP2n/+lf5+zslCt7Bzz33HMUecnp0Qm3bt5ikA6IdERT1UxPz7h64yZ5WeIQaKG6G7/pLDKCNgcE1vkNRdjbZ7QmgBEBRFJ0J2oQ2SqtQseaUTStQUcxkY6Jq4Y0G6KiwErEScLWZBuAIi82XRC9wdj5SSnQKoCbsixxXpKlGYv5kjiOsA6qpkVYy2K9Jo41xekZ4ugQJWXwxfCB9hxmWed0HlC+NYambTflQdH5Nm3tbHNyOsUTDDbTNMVZx3q97uJlquB9lQSH2rYN7avOWbwP3SBSCuo65BhZ0yKlwvc3akK4ctvlEQklwiryfKERVhvi/Of+1O53Sx/7Ak/qPoQQRG+8GY6Rh7jzrdJR1InpRRBWdo0B3kNdG7a3tijWJbs7wT3bdUHjSsvgo0KwnSirgrKqUFozHo064OufiANomobGNDgXtsFbh6lrrh3scXI2Q0eSuGsJvnZ1B+9tSO8+PmY4GnDl4IAsy7h+4yaTyZjBYEBd1+zv7XNytuBsumI42ubBo0c8PjziT/yJ/wFta7hz922uXb9BksR4D16EyUsiwnH2/eTZlQEALRWT8RDvWn7rN36Dz3/hc1R1w+9+7Q+ZrlfEWpOXBbFSzKdTvLUMByMQ0BiD8R4jJUXVkDc1TigkEVG3DRJP00JjAtXu/JMC2Y/HDzYut3h/PJ4+PjBIE+F/H3SvXtQswZPMy0Xjx4uP950WVOvgw9SHt4cyv0bKrotNBg85pQLQihONdwFoOWc7ACC6zjrdJWS4bp5z9FYCxpwbWkadVrVnlJIkWOMAaK03/n1P7JqnaJye6LBzDiUl4+GQm9evc/PGdSajMYNBRpIkLOYrHtx/wBtvvMHx8THOWNrOMDqIxSVVWdKuWoqqCmbPWodmpDRlOBoyGo+I4oiqKDhpG1xdMc4GWGNYLoOetTEWbzzOF1jnMG27AZauk6d49+HA+7uNZwo0/aN//P8GB6PRiEQrinWJtZ6rV65gmpbr164jhQQp2b9yhbfu3NmsXKQQOO835o+e3uSroykvnDdShptt1xQaHMhdMDnEgDW+izNxuKKiUYYg3AtfBd1rmjZgBRfyu4QIHQfehY4578NjLd1F7z1a6VAGlOcXYhAMSpIo6k78uAsL7gCFswgPSgXLA+/DqsL1F4XtPKikxLtQG4+SNNTDBUQ6xgpLpGOkNKRJhBAQ6+DrpETQ3IS6u0fK0AKrRBBDWhlKXs55vAk7ctj5D1lnEB3FHFYJ3SrGh+3okY3sVpWtDa8/yBLyvMJaF6wEmgatQinTtJZIxcHuxDlsK2jqOhwfQsOE79p2EZKjo9/f7MsgAIV3ny0vr24v/CC+78+bX2WRpDIeSWDOAG7f2A8TYBzo79EwY3dnhzQbkiYpB1f22d6a0BpDlh1TVi1l2TAYrZjNFjRty2q9pq5r8rIMRnZN5wfG+c3CP9HlEkBTv/Fn0zNm0zPuPzpCJRlV0/D4aMrjozPwHlM1xFJycrIIuXpZSuscZeeIjxDUbeh0+bmf+yK/8sv/BnEcbY6ls30ZQpDnJX/5L/+Vd9u5H493GU8rk13+3U86gPog2+e/75v391oflSj+fOn6Ph9/iYF5mnfR5cf23kS9pUBwAfdYG8Ldo1ijdCjPhXuI6+xrJK1tO+dth/cSCMBDCbXJlwuVi14bFaoCF928oygKC/5+4d8xlhe382ks4Tsxht478A7hIVaK8WDItatX2dvdZZBlzKZzvDFMT09ZLxaUNnTyKSFJ05TBcIyOuiqIsV3Hc8M6z6nKkqIsyMuCLMvw1pA7y/z0hGGaIqViXZSs84KmtSAVWEt3QwmLd6CqK7wxH/s0nRwtQgbXSfiqBTgPh8enAHzzu9/jImZWCGIhMa4rJRmPFAqlQ9qz954oklx2Wg8nYPjeOI9xNkCbpu1W793jANsYPDVSQFO3AYT54MKxmM3CDfUDHLlgv/90F9OeOQHQicK0Fp7y0HeaADwBTE3Pphfu/RdRgCSOFJUMHXA4UF2XiOvYLAji60gpQHYlN0ndtHjniYRGSkG5ztGRwliHMX6zUaIDLFJAFHcXvgtGbs45pPco0V/I4TNrpWkJE4yUCuENkYwDELYWrTRSeyK6qADvsM6gdKDOtQoHLIljnLHYtkUqQAaPK3XBo8S6foLrtqHzyDKtCYyjVh09b/EuUOiR1piqoW4t43QAEPafh6PDGa1zJJHCOxdiYrzHdkxRv+t1N/majeme2Bzv//K//Hsb5k0qeaEJ4slj++Q3738oCO7njUPQslyVhMJfFyFEOC8n44Rf+Pkv8Zf+0l9iOj3pukF3ESI4BDvXUhT1x6DpIxw/6QDp4vhoSn+XVyg/ueMyaHqncb5wtkgVQIwxlrIqQxZnx/6s12vqjuEHhfMWY9uNSFpIuoVKMJSVQuAJzHowz1Q45zcVgBDn8v3eS32loy8h6k54fVGv1T+2f0wf/ut9mMuddVR5wfxsxlGWgXE0RcV4PKbMC5qyRFhLJEJQOh6U1mxPtrh2/Qa7+3ts7WyHIHopmK8WvPXW27x1520eHT7u9kWNkgJvWnzTsOr0WE3bUrUhIUJLiZZBw3QuKm/Jy4JI9O7sH+14pkATQJyoLsDQkyQR3juaxqIIrqvOdbbuItwIg8V6p/jHU7c1yoogNnOeqj73ldmUiAg3aynCTb6fC0QAs9ggP0JJgZYBWAlBiCJRiqIoiaNwgwc6YVy7ced+Ylro6uodbYD3DiXpLqwQThxphVCCqjLoSGCMx3Xml1L32yaIOo1SyIm79EYE8WE46Rq8tahIY1ob9FIiAMm2dQGYdPvR2ZA51+9PvEeGWaAz5bQoHzRZAkEcBZaqMTXSBhdte9lorGPxrLEI7zuHcN+V/gTCOdarEk9wVm+aGuccVVWCECihAvARXVZdW4dt68SNjtBR5+uefZHgw/Pl5mAKpA4lVoBIS6I47sSM4XdSSuIkuNkKCTrSJEnSTSphm9I0ZTgYsDJLhDGAIk0TYhcmpjSNWec5w2GG856mqtBtuwm0bFuD847JeMRgOKSsK/KiQGsd8uuamuVyye7uLovVkmvXrjGdTmlNl0fnw34P6eHhswWWMwBRiUbQmchZAz0oD9Rfty9C4KZWwVel9TYsLjp3edtNlicnp7z22mv82q/9d/y9v/9fsbOzwy/+4i+ytbXF/fv3mc1mvPjiix/oev54vP/xfm/SPwnjwzBBHxpmfRjWqStff9D37Pf9ZW+jd+t4DH/vdE4IlAxzXlMbTBsWX2nWd7MFPU7QMYVruG8CEAQNrrUtQgSjWe+COSQIIt2bYzqcOy+n9dviuuByeJJFappms50Xfah6a4VeC7X5XHiEc2As89kM0zQcPnjEIMsYD0c451gtl5wdH2+sZrwHpQx20KKFYGs04ta1G9y4eYt0MGBVrBkOhjhCCG9IqwjvNxiNmAwy0jR049VtQ1W3NMaFuBbrMN39J9IKJSKcsyG65YdANT1ToClOZNdmCXioWssgjtHKY4wLrqudUaDrSkGorrtAKSbZhPU6iOw2VKkAHUnqxqIjGV6HwGBdZKA22Maf/+ycD54/3eMB8rwEoGkMSgZA5n0APrJDC9YFQY/ogJt3oTylZHgd4wJr1Q/fd2EIQZbEVD64p1pnMa3vGCEfOgE74CFluMC6hUG48XuDtw4henv5sCNV50brOqRlXdjuJNJBBO3YlDe99yRdjl9twwXY2hYtFVIqahNq10pGNKYJ7FC33xQgZIgXcC6sVlp7XhpVErSWXYHTEGw+gxC9H1qHz9rWFustkQz2E6Y7J/qD00+GQnbAD4iU6BitbsXVnouWW+OwrumOV/idtY6qrDeTS1021OWT3RjrNme9WncTsKSqG+qmxrhQqtVlMONcFcX5Z5ASLT20nXcKcHQ2R02XWB8udutK5HKNjoLmarVaYZqWBw8ehPNOyU1Jrme+nA2smVIhtgYp8E4E/YOOMGXRHfVQLrQmOIgH2waBtabbb4Gt85d8Kr33fO0Pv8Zf/2t/jfl8xmKx4He//Dvkec7bd+6wvbPNZz7zGT4eP/h4mlbmJ338oJYI51Xzn9zP+rTjcLmkdfFxURQxHE42i2BrQ+NN07Xc13VNEEwKkjgj0klgpcy5GaWUEiV11+3WbjyXVNc17h0bYMYFpvq8JPj0smE/ehapf81e69R3bfYAcZM9J0D1ppjOspovmLUtEkGWJMFw0znqogxkhtabudm2hny9ZrVYst5akm+tOl0mZEnKeDBkNBh294ZguTAej7n9/HNsbW8RxTEIQWssRdkwmy85Oz1lPl9SVRU4gRQelERGGvVDOJWeKdB069ZtVssVw9EAJQVNVRMnCWVZIqUkTYKgWUnFaDRCa81oMmEwGvOVr3yFv/gX/yL/6B/9Q15//dvs7u4Sx5qqLtnb3eeNN99mPB6R5zlpGnx26qrpXFnPQZZ34cbkvA8H13uW+RrpoazOb6iCDhx1o23tZlLYAJYLNyXrob+FS+gYrJ5u9BjriWNBsa7xEJiGkAuJIzwWAgjsqlHB2EtArEKJrTF9wSW8runASNsBNA9EOjBZgiAEvqij67815hzQKdVTXeICOgsarFQPupw6tylZSiGQSIRwCCXCTd51WUj0NHZ4qUgL2iZsbydPAjzeWXRXypYShBII64MrrhS0JoCxzVzR7ROh/GYfhf0jSNIMrTWr1aq72IOSrV9hbahx0WmIQiJmKMtFgemsqgoQJJHqonAUkVRdcLQg7py667oJ8yOexppOa8dGD2St7dLMI3y/+iOUj4uiIYkltXF4D7Y/lr0szPYfNfy+B/weSRYnpFnCOl90O9FtohiC4k8QaU1jmu41+uJfX0tVIMI5ee3KFX75S7+EihV7+3vsbu9QFAVREvMLX/oS0+mU//af/Qs+Hh/9eFbAE3xYpukn/7Nd9Cl6mvbnImiy1jIcDvnc5z7Hyy+/DEBRFFRVRVmWrNdrjo4Omc/n1HVDkobGm9PTs43PkpDh2uyZqtAxF3RRUqoLnksB0AT3b7v53fspmV70bboo+O71Ur2UZeMqbi1ZHPHC9RuMshScD47gQKSioLe0DrNlNgtWgcA4S1U3nBwdM5/OeOO73yMepMF4N45ZFyWzxYJynQf9qwvWBQCT7S1uv/gCV69fY//KFbLBkLJquXfvAV//xjd5/Vuvc//efZZlToiuIsRgafXUz/yDjGcKND14+JDdnV2Gg1F3MC2j8Ra/8KVf4rvf/U44yTwsFgtW+ZokTXFrwboocc6R52vKsuD27Vtcu3qN07MT5ss5Hsdka8j+/j6j0YjhYIAXsusaC6UO2RloDbIBdd1QVSXXr11DacXxyRFJHNO0NS+99BLHxyfM5zOuX7+Oc47vfOc7eGP4Yz//8xhjuHr1CnXdUJQVt557jr29Pe7eucM3vv41pBSkacp6ueKLX/xiWAXYUMN+/duv89nPfpb/7z/5bzi4cpUrVw64cvUKpycn3L9zl2tXr4afz055+PAhddXQmroTj6cYY1guV+zt7bJarZhMJhRFwWQywTnHw8en1MaTpRGDzuBRyFC6CV1xYTXTdgZifUddL3oOZR4BWLgArHSXSG0BY905Le48Arcphbr+Qd33tTnnw5w7ryYpGYCPJwDKHthZHyYU121KX4GKohBY3DThj7EOjJe3dKaidKsqgZIgCWJCZ2zXZBPeTMpQp3OdH4rrmBmBRwhP6xosECmFjiSuFlRdN2FtbGefEOhkOgAYMFqgA4UIlLzpujelFBjjiBWoOKaqmg0b2ZOgsisZi0jgrN902ATnKwBH3eRYVwMepbuJD7rOksBGNhc6RoUAqSTOg3cWXM8owte/+Q3evvM2EEzqhJKbbU2SFOc+thz4KMezVJL7QcfGrBh+Ysmmp4m+L4OmiyW8UObS3eJSsLu7z+7uLtZYzs5O2dnZY71aA7Ber1iuVkQ62XTDeu/xztFe8FzaADbvu3nRYWzTdchdiCrZ+D89CZwub/vl0QOki92A/uJn79j9Yp0jjSXWGi0VSgbj415yonSEjGNUFw/mgbptKco6BPUWBYvlAqEVKo6wPvxdIkjjhFQQUhO85/D0GB8rCtuwbmpG4y1a6zlbzVlVOaVtMQJEEhFJSRpHvPD8c+xOJvzWb33lozr8wDMGmtrGMJ8vWK9yQveBYz5fcHp6xmq1CmUhJTu0HUzEwsEKbf5//+//febz0Gq/XK6o65q6aXjw8BHGOE7PzqirlrOzeSf4DaUdF/hPZvNFCNQ1BmcN09m0u6E6xuMR8/mM6XRBXVe0bUNR5EFDVFfgHG+88V2qquF733sDax2ttbz+7e+QJDF5njOfLwLLpBXOGMqyCiCjK+UtFmuiOKGuDY8fH+GcYzabsZjPqYqSJImpmorZfE5ZNwyyFF862rbpjMEStA6Os9YatFakSRK8QJzgYG/CcDxmtpgznowYjofEcXBgr+v6vOaOJBsMyPMiCMvnc3Z3d9nZ2WW5XGGM4ZWXX+H1117j7OSY8XBIEscXauZhQjHGUJShsyKOdeiQiyLiKOLw+BghRAfw1hjThs6LQYa1LfPpvCtTCbSA7Z0dPDBfrKmaBq0IN/ReS9Dpzjx0GUeKtgwTU5alobzl/CazSErJRibUoTVv2XStAUEcz0aOtvm+qluU4ZykEWBMYAcvlnfDa9AJOc+B46bK6AOgMTY0HGxsGDrmzPuwadaC7B4b/h6+lx2GtdZhbYPSASj1fxOKziSPTfC6js/BsJASoc5p/6BVazg5m/7wLvKPx3uOP9Ig6kN8rA/Fannxgd/raZYC/e+fNsIc4lkul9y79wBrHdeuXuPTn/psp3O17DeWa1dvcOXKFb797W9z7949bt54Dq011gabnKLIWS6XG1+l3jyyZ62apsW0LU3TkKQxSRI9sW8uArqL5baL7NHlLL2Ln7e3yul1YNJ7hLWcnZ6ykophmjJIUqIo2jTVCCRKytAwpCNEZ9czHo7YGm9Rd35367KgJeiOXWuQBDlNlMToJCbzAypTc+fBfR6cHjF4K2g/k3SAUJqyrDmbzpiul1gtGSQpWRwzSBOef+lFbl6/9sEO8vsYzxRo8h6qsgaqJ36/6pD6uw0hBA8fPtr8XFz4W9s0GAOL2RJ7Mazmkpi6pAT6C8exys/fd52vqKqW2Xx1vl3rPJRgZCjtLd66H05Uvr+j7qJmqucKzhbrJ7TcWsDXv/ktnPfUZcXDh4/w3mFs0FbV7VH4PNaglaJtQrijc5bWWHTUgAdjHE3rmS+WOOcp6wbd0ZjL1ZqiqHEsAYGUZdAzOdvdrDVSaoqqxNlgqGjalvVyiTOWuvOHenD/PlVVhbJU09CasE266z4LEQOOpg1lKKUVrbXUraGNmk77BHlR0Xbuuk3TQldTz7IB41HEep0zHAwYZCPyskJKyWQ0Yjga4XDUTYN3jnydkyQa19HbIImThGwguXHzJuv1mvVyRVV5kjhhazLuSmr9PgyhlNY5ojiYsdV1CAO+efMG88WU1pR4L4g7E7amrTrgfm7Q2bQBrMadj5S1QY9njAnaLgmRkmjVdbM4ixehe5HuHOm7O5UIxFxf9nQdaNqAukv9APai7kv2gPDyae43YMz3Zlb9X0QAV0LQNVKw+Xt/3n5IOcvHoxuXNTH9988a43QRXLzbNnuenRPmvcXe/vt+llKRpgOSOGU2m7NcrlguVxweHvI7v/NllssVN2/eYHd3j6YJfndf/OLPdferh1RV6KbrS3HD4XCT1bZYLDojzKBHrOqS1jRdaHuYU6qq2pT64FyX1JffAJI0IU3Sc1bp0r/LZqrCO5R1KA+2i3KqdRGYJkLTUN/oopTepDDoKEZpjVCd36FWTEZjnAoayuVqTdPUNHWF8w7rPRaLTjXJYIjXgtpZ1rMZzs8QUgfjamORUUSUOoz3LPIV0/kU+Q24e+/OR34ePFOgCX8uFXzfQ/STeTjocRwxHKTdywm8txRFvulMePL93uk15fnSHMB5qqq98PeNtKfrbnhCIdJ9fXJSDNsTSlfywmMvflrrwVRNl19E8KnohvEeU7UbgFYZQ1ldFDpbqOwTGqWiaGlN0FrFWiKkpGktQkHdgz//5G4QiCcuLkEIncyXq27XBOr25PAI74OYvG9Z7Vc0vcN2/+GkkFRV3YVQ2sBEdQxHWS7pmrioa4vIKyKtGSQpgTpROCNYzAvKpsY5CSLC2gAirAGtY7wvUTJBKEtdV7hOcI2Auq5pmiaYrHW1QGssUWdsGSwgAtOiOvo7gD9NVamunCVwbdg3vY1FVYGXwR5AK4Hoy4sqRK/0HmHGWIoidIIoJdna2urocUkURcxmM7SM2NvbY7VYkCUZrsuW20QkeI9QoQPw2vVrDEdDjk+O8Xh293ZZLJY0TcWnXv0kf/i1rxNpj3PhOA7HGVIltE3LfDEjjmFnZw/ZmaE2TShTN3UTJvaOxfPOnQOlD3FpPuvjnXxs3s+4LJq+DJLez3N/1OODAZwnYPhHvzEf0fi+bXuPTe3LVsC5/kcKhOt/d+GlfJgvldQMBkOkVOzt7TEajVHqFBDEcUyaZsQdE6+UZm9vbxPBVHXebIPBgOEwSEiyLCPPcx4/fsx8Pt/YBTRNTVkW1E2Jc5aiKFiulriZoyzKzbl2MaPVex/88i7ugh4oOX8+V1/kEjqGPY4ioMXUDbbPles8pKQw58/xoQOwdyJHyhCOnoacORFrnKCzLbGdrtXgG0FtW2KVEusBMosQ3uHrhtZ0tgtSk0YROo6ImxhnDG0lqLzldHbGcjn7AGfD+xvPFmjajP7MfB8XY3e8hRAMhxlaB5E4eJI4RinNyckxAEVZIFWM1hpjQoit1poojnHOd+2aAalrLTf6DaUktg1rfqkkUoQOtCRK8M5TVfWmhbxpwskUvG385oSSSuCswRnbuVSbzc1qPB6QZhmrdYGKFGVZobVCSIXb6IoC+EiSsK1VVeJ8EK3jxXm5UimapiSOUsB3HYOm0xUJIh0AQOj+67x63LknkPMBPGgV/IsiHToa8d1jbO+wHhi1pm2IulBa6+ymM0/KIDoKxqIO2wRtk+gYEGtDMjgdILHm/Fi3xrAya2RRoYRklecYHFKEDr48z1mtVuHJIgjErXU0xuBtCLrEdtEwFaxXD7oJxaN8KAPnRRUEn/78swetu8Cuc6SSRJ0P1dt37iOk37CUTR2OcdMGbwoh7cYw1TmQ0qFU0EO5zUqufw+BVJq6LlHKk6YZTduSJgmj0Zi6KJlsTbrk8k7wSfDVkkqSZinXrl9jd2+PtjPFu/3cbY6TQ/LVis9/9jO89cb3oOvAs94zHk4YjkeUVcl6vUJIyd7eHlor5rMFVVkHMKskXjqUUoxHI8qyxhpDHAfxZ1kUKB1TFOUHu5x/yse7iXU/KoD0Tm3xH9Xjf9Tjw3TqPbGfPyTK37yGOF9ASiFx/YqI8+3yPnj81U2Y/2/fvs1zzz3H7dvPdfN0mAvH4zGTyYTRaMR6vWa5XDIYDHn+9m3atmU+n28ec+3aNZIk4fT0dMMgnbt8e6w1tKamaWrm8zlKKeqqpqmbcP8SYWGrld6wTqY1tE0bAGS/ELiwe2SnSRJSdveJEHCfxhGuUXiC/2GoJARDZG8dbdN081TXEGNbhAxaSYo8fC9CeoOKNBawznb3Qx00UHVNQ4uLPaN4i2w4ZDga4QmyG08fwdV56Qnw3mLbBtM0NEX9gY/xe41nDDRdPskvgKfLf7p0LWmt2N7aYrlcMj+b0raG4TDj4OCAg70DhICHjx8xGI3YmmyzLgryPCcbDtne3qZpW/I8R4ogTNMqABspIEtTvLVY0xInMWWn04lU1DEWEWmaMZls0bQmdCbZcEJYE0760XiIkpLlaonWgrIM72+t5erVqxxcucLZbA5CcHp6RpwmZFmKs571uqCuW+I4IY4ijLXEyQBrDUpqEBLThEiTLM1YLOZsb+8yn0+ZTLa7Vtagt1ktZwgRqFflBRrwsncz7zUwHmEdGoikAtt1mwmxyT9TUuJF6Jhwpg0lJCkQSgWQ5M7LTZsSz+Xjtin9+A179+ThttCxHVr2Am3TiaT95rH95FRXFfiL+Wx0oDq8Rud3ufljH+OiVZgsjA30ryeIxHs/LCEDo9WPvuNvswmu0yp1P1tjKfKO+u5YrP6zW+s4PT3rWLyQVRhYq5o33nwL6R3zLhuxF4P24MsbS1k3zH/3K0ipQqwA8MYbd8I+cI77D/4r6iLvOjTDpPP4eNq5xQcvJwG8/q3vAIEJwwdheJImZNkABFy9eo31ekVd1ezsbCOF4sH9+2xt7XD33n1+msZ76Vwu3uCf1nHV+/D0Pjq97iT4cJ2bDW5iLC6VTz7suMhyXTZA3Pzs4ckL9P2++Dt8/y7j/PL+cEDo+17vaQDUd/Vl+qaVy+//dCAluhKBs6HpI2h8Ok1hvxAUstNOdlpa7yiLgnt377Jerdna2ubu3Tt85zvfwZiW9WqNjjSLxZzjo6Ow0POeO3fusF6tGI1Dw1PbtEFbay3r1YrFYs7p6Rnr9QrwjMcjrl+7xsHBPkLAfLHg7t07LBcrqrLCtDZEX8VBgmEbixeuA0EB9PluXvSdbokLwEl4AQ420WM4jPQ0vdYp1ngkViq87BbM0mO1xKFBhZ7epptHgzzFdf9AWo/qTBD7bXDd9kSxxklPXZeYM4POV8RxitJRaFIKKC5EsETBj05KRTQchZLloPq+Y/mDjmcMNMGTCowL37/DtSMkRDqg6qPjE+jMAGOtMNZyeHgcIktEyOmqypKy8+ZxwHq1oshzvBAbEW3PZvqudpyr0DnQti3jQcY6r4giReHKzvfJU+YVi/m6sy/ouhus2zhtn5xKBC6kXuMR3SrFeU959z4PHx0itKSqOuSuVAciBGEzwuv2F7MQil4ETbfteLsBYtPplKZtaM9CuHBvZqikxrQlSni0kMRJ3FksBL+mADLChVVWFcJ5NCJ0TyhF69vgoO5DZE2WpTRtGz6XAKUVsUoCk9fd1FVnieCs7dimvgQWUdfVxvrfNAacRW1WNRpTt2ilieKExhjyujrvauv8uZyzCHzXpn9e9rx4ygTvpPNWYomkj9qxHRO0KZl295LeL2zT7r+5oV0+X3vAFN7VuYurZP/kaez9E1YVvS+F8+DadlO6jVQoC/SfyXfPVUribEgI71ky03XwCQSmXXX+V9BYuslG05h200GntA4RQX32H0H0bgu7AbBvfPetTefkbLoAIXCmJf8hrOx+UsdlL6Wn/f1pP18EOxfbvIPflumYbL15bN8F9bQk+h+UDboImJ6q2UGEm+YHHBeq7x/kWR+qlPd+2uo3f++6S/v/Li/Sevbo+56HCNrCzuy3f51wvQbvJalUmAsudODWZc3D/CGPHj4iiiLSNCXLUtI05LTFSVjo9sw7HpbLOY8fDRmPxwFUKMUgSymKNWdnpxs903q9Dn5KzjEcDrh27eoGbD969BBrLFUZFuZChNgr6yzWWKwPDLzWehM91sH5bqI7v6+eM1D992HxWJk2OHYrifehRca7IBXwdPfJSOOjcC+yFyj73kdRiNAKHYySw/F3PizKhZSBnFCOtm0pyqDxjbpAX6lDZUFFiiRJ0XFEa0xoABpkIeor+dgRvBvv/3L0LvgRCdFrahR/6t/9d/jzf/bP8sJLL2NsuKH2dzrftccjAv0HG0IivLOHuq1ZrVZ4Z0njOJz0eNbrnLauUDomjRPqpsHYADbmpzNOpzOKokKpIAJ2NvhYGGtpmpqjw8f85m/9BifTkxCn0X1SW9ehe+1CiC3m3Vq7xYV/0DXhA12/Oo66aQF/4WU6ASPB8bv10HpLVdUdRDoHAz0rg3c0JpCkznpkByI9BHPEtqHuzBODP1Oggmm7Y9FNTtbYzWQpZUinNq3B2QZvoaktUgTDUgkg+rZbt6mBN6YFEbyd4iRFSInszE77G1HbNPSu5aa1lEXJIBsQxynGtLRNjfeOWMfEURz8qETfft+BWBe21TpQQhKliqpsUDIiUvHmxocI2iXn6cCg6KwEzpPPnQ/mk/2h8pfdJJ92ZDvUZHpthTwPnYbAVIVz1m9Kd8770JUiIyQO5wwOj+6ch6smACwpJVKrDZiHc2ree4e3riupempbbc4H05rNSlyI9/4Mf1TH5Zv3O7VzXxw9u9SD9d6Yt9eqXWR9ftRi8A2wu/D/9zs+uOdSmO1+mPqnJ/dbeD9x4T2fts2b57zHZm0O6/e9RABa4Ti3NE3Fei03BpKBVQwL4D4TzlrHaDRka2urO/Yhs204HJDESadnbIInnNZEkSKKFHUdWJXFYsFivsB7T5qkrHUeynOuCeBJyo2OKcyp/dz/xAfffB554S+BEQ+VB9Mv8Dqwc85f+Ev765xJ7xQTGyYpEBbhl7ZbOIfFngARDIzR4LUHEzrZbaD8umahMGeZpkYoSVnVWO/QcdA5/TBKzM8oaHrHM/T7Rg97rAOtIIkV/9P/4D/gf/xn/hzj8ZjVfMqjR494/vnng2VBFKOiqGODzieoTRHFB9+Mqgqp7rHWAXS5oPZvypJIpyipNi7N1nrqsqYoKowLERV9B563nfjNO/Ii58/f/XOsVguc8CitMKbBAzpNED4YS3oCGyA2DNA5y9QzHT21Lrq7bDgZ6cTFelNnDp5LsmuP7YCJ97SuRSoVmA3vO78mT1M31GVJvlrx1ptvsM5z2rphuVxSlgUSQWsNVVXjrGEwGnF0fMxsNqO2YYWzOXpP0SU463D0JmodkHU+XCgE3sV1V6BxfQt/AIKBYRMI67pjx0aYjqALlmVjL+AAYzyNzcMx6LRZtTG03iGcC0ac1uHEuSAcESZa5xw0HUXvLZI+08l2ALe7Abq+ESCwnL6bZAIl7roJgg0673VdUkikDswRPkyq1thuIg3gRWtFmsbUTQM+nDMB5NBpniSx0njrcMaGAGfoApZDeVFKQV7X4dy2HQ3fHRLX1RY3rFh/7DrwdrExwPV03E/JuMz0vB8wc5mZuiws7m+kF7ub4Fy8ezEa46McF8tz3/83+LAH9odpWPlR6K6eBtIulypFd71/KGsDIYiic9awB8Z1XXemuOfjoji7qkrKstzoHaMuwilN0ydYSCkVZVmwWCy4fz+Uxds2uIbXdR261DoW0xqL1kGYfrEcbMV57NjF7e6/ngOqPqala4QSF+atsEH9o57YV75vu4VQEgwvHKo+Um7Kcra7F0LP5HWVEhPmKOm73E3jsd3CL8xfirZu8VJQ1UF/KbRExxEfhu98r/GMgaYPeeF2xzaOU/6X/4v/iD/5p/40pql5/GjJ/v4e+3t7xOkgnIRKbWqlm1Nlc9GEbRBApKNO5xMOfr8iyLJRACodueMN2LplsrVFWxqc9ZxNp1RVhVLhhBHeY5ogXLt94xbjyado2pJVvsJ5x2gywgk4Oztl/+CAo+MTkiSlKEvKugxt6T7ohMJNNGNrss2w81IKF1uCtRZjHCcnJ3gPN65e48aNG+FEM4adnd3gfyG6yVz29W02Kw9rDKZtqaqShw/uUeYFbdOyWgUb+wC++vJQAJXzxYLlakXZiwJlJ2L3oUNMXSg79Cvuft8KH2waQyZf8F2yHXWLD2XMAGLOwQiwkSv4nq7rXrupW+q6pigKFrMFtrU0bZjA5tMZWZoyHA1Zr1YcPn4UdARVwarIu5MJeh8k78G1AdB472h8s6GzhRWbzLbzMzfY953nMXe/7WsZ/sI2d6/jrOtMKMPn3hiJdvjEWkfTthvdC6YDeAQfMyE6ANpZPHgb6HVrwz4N03Ro/7167QrT+ZSmaUInqe/2W/eeWsmuSaFzHd9cjv7S5/zpGk9rNe9//7RxuazXl1SAJ76/7KGz8cu59Frvdxsvb9NlW4D3YxHwgccHfCnxLj991OP97osPC8z686LXrPVdxP2/jcP2pfKolPKJztgeWPfBulEUhbLaBfuAi1lxWmvSNN2w3pePff/48w19Ejx+H2DqQdPmZxAdaOolIhd30WYxIc4BVD8N99vYmlBlUd391ncLzD4iqj/XHS6I7J1AOIVw4TUDGRFyRoXyyM6wGCOQKLAS34rN3PlRjmcMNMGG4/tAF5QgUpI4SfkLf+EvcHJ0yOzkmFc/+UnW6zW/9mu/htTJBhhs1tR9ydufE9QBIwXGIk3ijQB6e2sLCF0JsY5wxiMijW1CoO2V69cwtWNrZ4J1QdSnZMc2Cajbhvl0ys7uNqPJiLKAqqmQSjLKBtSmQUpFlqTEcYTSsmMpHEL6DqEbhPBo5UhiSFNF23qSVDAcxjRtQ1kYpDR4B2mm2Nkbk68ETdOytT0mTga47oZ7enpKmmVMxmOECJ1wNAIZKWSsufX885sLQXVMVVHmLOeLIIBPEq4fHHB6cspitSKKYiKtqZtgddCalqqqKaqauqzQkSaOkw3zNhwMqKqKnd2d0BXShKDbKAr73bQtWE+UJJ1VgQpRBJ1GKmgPJHESEeuIuHOsbpuWMq+Znp1RFiXgqcqSo8dHjIZDtre3mM9n3L1zh6ZtmOdLFqtFaMHty6bi4g1GBAB3YUUq6F1wO/OlzZnFBtxtGKYOMIUulHO/pQ0Yp9dP9D/TsTqds/elSfecbeScZge8s0GQOplw5+23ODs94+zsjKJu0LEiScJ5hWFD3kG/6BBoHcqdAo/vQ5o3r99149kNIvypGU/TKvXjnQTiTyvTXbyZbcq8sNE72a7jMyw8PrhW48Pc/Pvz7YcNYH6U40dU3bzwfucg+aJD92Vwdrn8evF6DgvewB6dl/W6Mn9XzvXebwBTciEDrr2gk7t4fvXvy4X3eQIkdoBIAD4gpQvnwkWK6emf+YnRz8dCbLz6rPcXwNH5PuiBn5QSh8O6EGoc4re6qoj3QAsusG3Shb9pHYGSCB0isBr70ScUPIOgqR/+0td3eZQQG3HxtasH/PP/3z9lNBzzy7u/xG/99u/yV/+v/zfOT44nihDn319gI50PAbtJEhPJoJG5evVa17IfM8wGeOuROmhnsmzAyy+/gkDxwideCN0LNpRWnLMgoSoLzs5OcMphI2jrMpR8hKSp6g1LcHZ0gu88eoSAJE5CjAZB4yOEQAkVuiaMoW4aqqZlnRe0TUNd1UipcM5wNp3Svv46bd1gvefR0QnWeGrToiLN3Tt3SbOMg729UBqywca/LEvmszlSiBBV4yw72ztEkWY6PaPIC8qyxDQtP//Hfo6333yT2WzG9vYuw9GIpmnpp+KirJhOZyzmCwaDAVmWURbBpG08njBbTHnppRcBwfR0ShwFvyJrLFVVUpVVEIE3LZPJhKvXrxEnaRCjd6BBKY2TMB5PSNM4xBI0lv39vWBs6lrqouKTn/wkTdvgveP6zWvcev4WAFVTkxc5UaRJBwOc9zRt2P9SSoq8QOCpqoo0y4giTdsY6roJIbj9ROU8OtIbcBV+GSB6mAglxlpUlKB0RJLEJFEc2D1jGI1GtE1D2glIpejBVAfmucQoXKTbEThnWMym7O7u84d/8HvcvXOXO/fv8Qd/+DXevvM2pydT6rbpfCbogn8FvtM9SAl0fiwblZzsgV/QSvw0gabL+qL3A0qeBqp6lrBnJS7eCPubrHNBDBuinex7slnvZ7wfZkVsWIQP9j4/Kt3VhxkXP89lgPJu++LDgs4nok94koHqRw9mLoOn/jx42teL2/6E/o1OvylCN5nr36/zc4KOxaQHceeWJxtivvsnOY+puvj7nl4QPbvvL3BVF7/vKPH+M1vnEMZt3j9IKmTnLyjpOxN7LaUUahMIpUSwPVAd6JKEcp3UKjRqEYCh0CGDtLYNxn30GstnFDR9gJNXCDznXkJNXXD1yi7Xr99AK8m3vvVtvvL7XwUfsoEEBu8tfedUOEkEvfitnwv6iU6JEHg4Go0692dFliSb9lPjPEmSsLe3h5Sag6sHaKXCiSIEzgfdiXOWpqrIRgNUFIVQ2kiHUrCxaBHa9VvTIpPovGW/K6FJzi/yvkjV30QdhDJLd7ElSQbedq3wblN+c94HF3Fj2dndYrVcUzdNoIOVZjDIGI1GLJZLHj54yNbWFmmacnpywv7BAcPhiKqquHr1KsvFgrt33+b173yH46PHSCkZDEYIBOPxhP39A4bDMTqKkFIzHI1JOpft6WyOc5Z1XmK94bvfe5Od3V1Oj46RQFOFqJi6qWmbhiQdsFwsGU/GrIucbDgI9K8PVgppmmGt4drVa6RpSt2EwMxindO2Da1pODs748b1mxwdH7FcBgBXlhXj8RjnHKvVmslkzLXr17HWsl7nG9r9wb27ZIOEo6Njrl+/zvbONsvFkpOTk24y67y8rA0RAJ2ninMO2xqcs4yGQwbDIcv1msnOLpOdHQbDAZOtbYoip1kuOTi4SpGv2d7eYm93N3SQSEXcCR61jjadiJuJTQQHde87TyxjUDri4GCX+w8eUtU1r/zWb/D//M/+Fscnx8HbSwlwMBwOSJOIqixw3Q3cWht8WC7otMJ5D64LEOanBDc9jRW4/Pf3ev7FLriL7FF/U+1X3X1w6YcVgr8TyLrYPdf//ScZ8PwwxmWQ9FHti/71epDT/+4yo/Q0cNYLxfvzou+g7F/ncim1B2BKKQShY9Y5jxeh+y50N8vzsn/XRR4mh2DvvwFMFwo55z514a++e9DFMl2PQTfPv/A9F8pzeHCtoWnNpgsRqYg0Qesk1QaYWe87c00QQoUFog+PU0JD97lQILs0CWNt0Ef5oDEVThD9ECDOMwqa3v8QXYO1604Waxt+9gufJRsMsQiywQgQZOl+4JV8gXMV3jc4b7EenA/uz0oJoigc1sD0nLcBV01DXdc45xgPR/9/9v470LKsLPPHP2vtcPLNdW/l0FUdoeluukEBxQTomHAQA2MAGQM6BgyjP2dGxziOg2lglO/oqKgIKKZByXSDNNA5h+rK4VbVrZvPPXGntdbvj7X2PufeulVd3XQ7Ns0L1fecfXbea6/1rud93ufF9y1Brx/bLLWTZ0/ZRpDPEESBWBZgp/Q9wlLJysK7OK8nJL6QeO5lSVw2mhUlH4RqUmUJx0HgI4SNGSsXRszPMUnTAjEzQ38H92rgjjYadQLfp9vtWZKxWzY2Nka/b52WkdExfE+yuLBIrV6jUqnieT579u6l3+9z+vRJjh47jjaK6elpPOmxttZifHScmZkZ6vUG1VqNWrXmyI2CdrtNs9l0nYHkhddfxz333MOVVx4g6feJ+32WlxY4dWYWpRW7d+1iz8QEZ8+e4fjJYxw5foRSqUTmyp6kLpQRRX1qtRoIQb/fJ476KKVot9vs2r2Ds2fPMTo6Yh2bao1ypcy5uTmmp2colUr0uj1KYYnasaP4no/RFuXzfZ/Z0yfZtWtnodyrtabfjzDGEIShI2A67R2XuWJBMCsBaknoxjrmNvcWkynifkzH75DEMd1Oj1arRRJZtd+obwUkhRCUStYJqzcahGHoMlNsiFUKQSkIMUAUR9Tro3SXFzl1epYHH3qYer3Ol37py+n3Y975znfS60bkUcOxkRHq9SrLS9aZ9jxJv9enYPsJi57lM0b1LPAH/jXbsG4SXOhEXe6Au5GAOzwQbgypbFRzfirox6XW3egs5OcyCA8/NXuuOF+b3ZOL3Yuna3lCyPB+hgn9mzlQm4XRhgngw3/XJQgYg1EaoyVGWOqHdJEWxICbm4sUCwbaSJDPn4VjpBgG/8Xu283G7flIEBfk1llgYd19HWT3Zg41t3ViPTxPFSE3m4ctUBqk0uDCbUaDUZazq5W9PnupAs/3CbwAo1OUNkhjZR90qhDSwzdfdJqeotmHYIlstkFF/T712ggApdoo+/bvRwiDH5bsYKVTEKlNBdcDrzn32MPQR6mMUqnE3r172b59G2EY0hgZ4Y7PfY7z589RH6kBglRneFmG1oJyqUSjUWdtzRZeLJXLRXFhk1klU78UUh0ZoR9bWQKjNKH0CYTn0Cbox3awL4UCoRSBFKTa0OwnUA7Ys2c39UaDM2fnOD8/j+f77Nq9g1K5zOzpM6w1W0gfVAqeL5yiuA2zeEI4HSlBt9tzDdO4si2CXq9fqD1rbWiurrpwgqTb6dBzBXwXF+bBwbqttSaAPa4n6XV6HE4Pr3tG+bykkDbAcoLK5Tqrq8t84tZPsHfvHsphQNLvk8QxZxdsnb2bb7qJOEl4/PGDPPrYY3RywvbwETbM5HIbGWnQbrfJ1C20Wi1OnjzB9u07GJ+YQBntnJ8+lUqF0bEx+r0+c3Nz1Co1GvWGQ+8CxsYn2LJlmp27dljETmkajRFGxywSp5VaF6LLeSIF/J46ZfkgYMYYtIOae+0O7bU1fM8njmOOHD6MFNb56XV7xHFEktpinVIK6o06fhAUg2nidJ0q5Yp9Rr0uE1MTLC+u0Gw2efDBh9Ba84M/9AP8zM/8NH/6J++i340KZfNSKaBcKuF5AbJkNVN63X6BtuYJ27hO2Liw3vPFhlGijUhR/vvFttuMeD0cOhkOvyilbHFxRw7+fM85tycjgpscVngaYannmsN0effiqdswETzfz8X2NczpAS4gjOfO0XBY74Jwr5MA8H0fXwgrTDy07vDxhz+ve16XeHRm0F1bWZd8OwdKXOy88mU5IT538mw4P12XvW2ReUPoeVSrNVSm6Xd7JHFk0SRjgQgpJWXp4Usf7Rm0Z6iEVp8p8UPSLCVRycUv5mnaF7jTNGiIWllm/tLSCtdee5VTJVXkRT9a7TZGGQxdBJElWIPjw9qOKjWKLO0XKt3f8R2v58abbqTT6fCSl7yEH/zBH+Tc3FnmF+YtNJoLQxs7a3zNa17NI488wuOPP8GePXu4Yv8+FhcWOHLoEJ12m7Qf02zNDRqtwGULgCcEgZSkShFKiUhApJqx8VHSQBLLgL3XXcVLXvoSdu3YycHHD/Kxj3yc7dt3csMN1zM5Ocmjjz/Ohz74YfJEKyE9UMrGkd29qpRDrr72WkBQLpdc0ciYrdtmGB0ZJVOalZUVDh+2jk+5FNI3VkBtass0+/fvdwRExV333otAsWPnDrZumaHf73H23DmbtaZs2ROBdFl0uTKtQRtFvd7g+uuv5447P4cxmnIpZGpyCk/A4uICy6vLXHPNNfiBz6FDhzg3d440TTelJw53QsPtot3uIKXg7rvuLZafmZ1b9/2pmOcJlLpwlvpkVnRqbKSMc8H1bHZ9XGTZBcdxvCcpJD/0Qz/IG9/4vfi+z4H9B+h1++vQRwkuSzIv9mlF/HJkyZiC1WCdJn2ZJ/EFZPmgo5wOVz7gbXSmNm4zHJ4ZRgvy7fOMp2EnKucyeZ5H4MQQnwoK8tQH/WF2y79+B+ip2ecX3rxch3Djs82XbWwX+XPfGLobzrS7FKI5/FcrK6Yc+LYcGHLAixrO5s1f9CIhZd1tybPiBue3ztY5Yu5vHp/LP9sPg3vheE/S8yh5nkO7dFFvVCgrNZCXldLGhicbjQaBXyIdSel2e7Y/yiy31mAIwxJBEJIWCJbB823B4Ay1rkTsM2Vf4E6TwegMIyzMrbXm/NwcUb9PmmVURgYZcwgPpAYjEW7QyMMUuJm/ffC2jlilUmbfvis4ffo0v//7f8C/f/P3cccdd5Cl9imFoURjS3EYYGR0lDe96U187OMfo7nW4ku+9Et485u/j9XlVf7X7/8+H//Yxwg891JIMNqJKuYzeSHoKYUHVKtlRD8iBG6+8hrSwKcZSL7tzd/DfQ88yNLiEl/z1V/N4sIS3/u938tnPvNZpPR4yw+9hQ998MOgHRtFa6RnU/eN1iSZYvuOHbzrz95Fu9WiWq3zd3/3t8zPn+c1X/t13HTTi1Eq45Of/BQ/9ZM/RaYyut1uMXN41au+hl/91V9zWT4xr3v9t5LFCd/6Ha/nzd/3fZw8cYrf/Z3f4TOf+Sz9Xr8IyaHsQK6NRhuF5wuuueYq/uH//j2vetXX0Gyu8gu/8AvcfMvNTE1O8tnP3M5PvfWtvPOd/x9TM9P4QcCf//mf8973vpdjR48Rx7F7+tZyQu16xMkQBiGlcpmJiQlmZ2eZnJyk0WggXAhveXWZcqnMSKNBEAREUczq6qrLuLOhX4FBOQc7R2h8384E09SKXPqe1ewa8oUHRZmFsEgeFCKgATZTDgMeksTtXzqnJ+cZ2PCelWTQOSI6tN+caCCkzSL0rI+MUprTp09x4sROrrvuOqanp+l0Oq4zhx07ttPtdvG8wGajYOj3+3hSOA6c/ZtzJp6vNpzVJKUtrpwvz2fNF1SI3zBA5uvmBPB8n0qpos/Kj2GMoVwuUy6XP2/EKT/PSxGh4emF5p7aOQw+/0uCUxt9yCe/F08fbdps+2H+1GZO2LBjlaMzMJCh2NRpMu46pMDHWN6Pc66L9d1/BTjnx016ihBazlVaHzIs2isDJ0ppNbiPZj35e/hW2XMzRYgxKIWAcwqjyEmjCJtUQl6j1IbnKpUKW2e2Mzo6Rppm9Lp9kshKxPSiPpnKSOKETq9Hp9PDGFsEOeeASfnF8NxTNs+3GjRpHCPCgKWlZfr9iLGJCWRYI0pTQNqsOdfwdN588oHIGAwKz5P4nk/qUoGN0aRxxpHDx/jbv/1bVGoHRs+DNNV4UuB7kGTQ6/X4+K2f4Oixo9z04hs4deokv/3bv8P+/fvdC2RrmwnhiuFqjU4tzGSwA6rzL+h2+1SMYQQIeglaplRHqzR8n0BpBCmBhHLoUynb6vVJmFFx2WnaGOIoQqWaIHCwsbtffuCjlObt//MdfOb223nDd76Ba6++jv/zh/+HU6dO8fKXv5yrrrqKm2++mc9+7rMWcpUeCNixfRvtVpPXfN3XMjE+ztve9jZe/OKb+Jmf/mn+4PffybZt22h3OmitkFKgdFZkWzk5SqsQ68I+nU6L48ePceutt/KOd7yDt/zQD/IL/+U/85KXvsQWhgwCfuqnfhpjFN/13d/Nm9/8Zv7sXX/Gw488ckE7GOaSuUkPvu/zzd/0DfzXX/olvuqrvpq3vvWtfNM3fxPVaoU777iD//G23+KVX/Hl/Nt/+y1snd7K448f5J1/8E7++Z//mX4/skV4tSbwfBDaVt4GlHKFMb0hbR01PCuz/3xPEPgBvoPbhRCEwkckKSa1EDRCUPVDhDbIwBJDbakdQVgKAWhH/QKt0xiE9Ah8j8zYWVxQKlGplun2umyZmuT8/BLf//3fz+joKA888ABbtmxh27ZtRaf4W7/1m3ziE5/g1k98gvnFBWJXmsZVSHCzwUFWjX0XPt839blnw4Nqr9fj/PnzpE4zaxglyAfA3GEaThfPl4VhOOAeJklRiNXzPJIkYXV1lcXFRTqdTiFFMEwkvxy7WMgk39eltivkMS7jGE83NHfh6T1bjWojdpsf//KRu8ux4faxGUq0kce0mXO90RHP29TwOkVIz7h3UkCqlC1b5fo9O3ESjoJiHRwvFzV2mdk50jx8nsP7t/3nwNnPUqv1VjhTxX0c7Me4cQvHnVJakyRpIYcwfF+CIEApRRRFxHFCFMcobQhLJUbHxmzWt7IEdqU0/ahPP45ot1ooY/Xq6vU61WrNZiJn2QUCos+EfcE7TdoohEMAtBCMjIxRqtRIUk05zJxaa542mSdSD16anGUjsBlDysXcut0u995zL71eH6U0jz7yWBH2AixK4AoUikw57aOM1dVVXve6b+X06TPMzp5h+7btrCyvFgOp5wuyTF/YX0iBKPkEAtJeSgOo4ZGsdkiFYnS0RlUGBDY5PD9jDIZUZzz86KO84IUv4uu/4Rv527/5G8p+CYQhUzaeLKTNoGi1O3zsIx/jsccft6reDp07d+4c7U6HM2fOsLKywiOPPIxSqmjoRhvuvPMuVlZW6LTadFptjh07xle+8itQSvHuv/gLvvEbv5Hv/M7vRCvNXXfeafU18hi3i4lrbcUiDx06xPd8z/cQRRHXXHMtS0tW6+nc3DkC32NhZZkf+v5/z/FTp/jP//k/c/bsWf7xAx/g+InjF7SB4ayTYev2etxz973s2bOP8fEJ/viP/5iPfvRjBIHP8soCnU6ba6+5invvuYc777iLqckptm3bys4dOzly9CjC0t7IdDZ0LAbhKgGZUevg7zyTzbZN2ya0UpbEKARKJHhKExqbwmCEQQKxSvFMhic8B18b0sSiSPksT+ezPa0xmSmacZokZCojTTOaq6sArpSPYNu2bVx77bWsrbXyZsb//L13cH5+npXVpkXScji8eEY26cC4dr7hZXneWLfbLdpVu93mzJkzhYpzPrjlCNEwJyV3nIqZdxBQKpUIXEp4mqYFqpQ7TWtrazSbTdbW1uj1esX68OQIyMU4fZutV6w/FG55qvbUSPD5NmARiqd+0GeDP7XRodns8+XuZ+M1bRZey/mNw20l50IN1yXMOU45OjncphAuedVotNIkqcbzBuvljooeOh8p5aDywobntjFcWDhNUtokJOnhiXy83ByR2/g5rzmah5iHOV95tqAxhiRJ6PV6dDpdVptrduJXLhOWylQrNWphSF3XiZOEcrlMu9MmzRJGRsYYHR0hDAP6vYiW69eeSfuCd5qMNhhp8AOJVpqTJ0/S60fUGyGZytWOrdaD1sZVFXTVk9fBjBpbcU2CUaytrfFP//TBAinpdCOX9m/XlhJbrFdYuDOJIm679ZOcO3eOfxD/l16vT3NtjXNnz7G4uGjPVVBwoaTMtXwcJ0uC0QolBJVSiIgTUqPYOjPDxMgIJ5fn+MS7/pJTCws0pqe4p7XGwccO8pfveS+lWo1tO3dw1913c/TIMcCA0oSBT+hLqyDtCTQRxlhH6N++9t8yPTXFyZMnWVld5Qd+4AfYf+AACwsL3PqJT1inUIpCxkAIge9JXvnKL+flL/9SPN/n7W9/B54neNWrX8UrX/lKtm3fztLSEkkSuxcGYFBkFpf1JaXHlpktfOd3fgd33XUXWmd8+Zd/GR/72MeQnk9YLhP6Pt/02m+k0+1x6tRJPvWpT9vq4SpDDtXoy2dJG8MkuZ2ZPUOWZnhScvzUKY4eOYoBSiWf6elparUGDz34KO9//99x04038JJbbqYx0rBObq4FhSVFWsVuJ/k/5ByJnD9gFdwsnUDb0j5aaxdqc4iiggCLKEos2kmqiTGUhCBTKZn7LVY2YUH6lktlnW4PbSBWLp7vkCEyW1S007PieEplHDt2nDS1kgyf/ewddoAH7rzrbgB8X+LJXGjTdo15FurziO99Ubv//vsZHR2lUqkwMjLCtdde6/TMBgTe4XRzuJDsOzyAbiT55hIPSZIwPT1NkiTcf//9pOmFM/VL2aXWuejgby5jnaeyv8u0p+o0bURtLn/f6yfHl7OPjejLZue7ka+2sUzOxc4fBtzL3HKUcnid3AHKty0mvORojQ9iUMLJ5BOqXCInL7/lOE5ZvL7Atq1TyToHLnfoNobp7KXn5zyEoA1YTUUf79ZAGpec5USK/aDk3pOMOE5RWhGWyvbeIjg/P0+z1aJ8vExQstnKvmcnGZVKGd/zSJKU1tqadRo9AR54gcfo2AiVSvkST/Tp2Re80ySljZUarTFZyr333sfr5hcpVxqQ5TETO1BrqVxRecEgTSA3UQxsxlgF7RMnT687lhCelSmwchEAxayh1+1zzz33A3D27PkN29mHrbVz0oRrbHrotbajNCqDoF4mEJI4iqlPjHLlrn0szc3x8Mc/SQuFatRpSsGplVWOnznPlddcxczWbZw4doKjh56gJELKQUDoeaQ6sbFpLIcqyWLuf+B+xkfGEcDd995NHCfs2rWLlaUlDj3xBGfOnCUMAxtic85Xmimaa6ucPXeWmekZVKZ49JFH+d3f+T2+8Zu+gRtvvIknnniCj330oxw9cmTdwGFrsKnBdTPUAWD47d/+bVZXV7jpphuYm5vjT/7kXVSqFb7qq7+GBx+8n7vvvodSOWRsfIy1taabheWpu95g5iwsf6jgCggf4Xl0Ox26nS7f931vZu/e3QSBz8LiIrd+/DY8EXDjjTfxfW/6Hl7wguvIsox7773PCZm6+k1GIyUo56TkyKQYfnhD6NJw32lccxOORKcxZEDf/R5i0SaNbZLKw6bk5m1RGFsJPA+feHkIzbV/4/4NEYfBykfMzy8QRRFaGVZXV10HCWFoEQytMtI0F+Jzg8CFr9iGi3n+WE7KFkJQKpWYmZmhUqlc4DQ9GdKQrwcDeZD89zRNSZIEIQQLCwtUKpXP65wvNdhvvkH+56k6Jf8ydile0Mb1NprIw0aXey+4eHht4zpPJ9S3EeXJz3tjaO+CLE2d0zhEUXjb6ve58iQuDXzdZTpJm0td80anb926QiCEt27ZelQJhjuEfB/D6Jjtj+21ZBkkDmEdFnhdaTbRqysIh0TZkitWWqdWq1EulxHGlvcC20emyiJZ1XJl3TU+U/YF7zSBfYBZZgg9w1pzjSS2yrr4qS0bASiV2phK7hkxcFgsAimxBFhja9sAwhfsv+IAL77xRubnz3PbP3+y6Fy0oeCqpCqjXC5zww03sLq6yuzsLNPTM+zZs4dut8OJE8dZWVm1SpkekJkiZCWxfCctwAgJUpNkGcb3SQSs9vt4pZD9+/aysrrMyeUFVvoxKvRo1KqsRn0evv9B4EEAAr+C1FDyQ4RRtOKIcmgRtCzWrGUd/uLd7y7uXe463nXvPZveW9+TtqyLJ3nwwYd58MGH1/1+8sRJ/tfbf3/dshx9EyK/9xkwQIdAceTIUd7ylh8B4Nd+7depVMrccsuLuffeezhy9DhjozV+4zf+Gw888CDf+I3fyPUvfCGdTpdWq2WLUjrytX1urtSLE4xLU1tbLQxLZGnKe97zHlabq0xNTvGSW76EWr3Cgw8+SKYUDzxwL7t37+alL30pk5OT3H333bTW1vCELXqbZCkmo+AzYbBFbzc2ovV3wF3/gEdned/WKQkrITdcfz1JnPLwQw8jfOcwOc4buYPm9p8qVdQ61JlyhKmB450jX2lqcBx1wrDM1NQUcZwgPElzrYkxxgkp2k5HSA9RoGlWmuL5yF26mL3iFa+wYqlxXGS95YkRwzP1jYPyRqRiWMdnGFkYDtkMc5/CMLxg4Pyi/cvbxZzEYR7T081wzMNWG3lReXu54DjCbpMZjfC8wiEEitC9dtmv5OG/YckAhhzQC84r38sGbpa8tLjrZteWO015sWBgXTjbYIqxQSlFnMQW+Zd5YV9DolLiLCFRGWGvh4BC+2mt3SYM7UQm9IMvOk2wvvFdzszG6hDZjLfMrR6GoZVixxTK2iZL7AArBWgPSTaUrQTWj/esMjc2blytVvmWb/5mfv03fp377r6Pr//Gb6LZXBk6+kD4a9v27fzBO9/JJz5+K+985x/w6le9irf+xE8wv7DA2972Nj78sQ8jQuuBK21T5z0h8JQrASMgzjLwoRvHKBNT0nCu16JtMirjo/jVMjtH93HV9hkW04jT7TUeO36SXpwwMjlBJSgxf3YBaQztXhtDxu6ZbYxMjJJoxWKzycryMlIKfOkPeDLGygBIIahUKpRKJZTOmJiY4PTpWTrdPmFgyccCqNZqqCxlbGKSpeVlW1DWKBeGsmFSBGQqKxzUMAwYHR23Ct1pwu7du4mjmKPHjjM5MUa73eX22z9n24CEldU2H/norQD88Z/82YY2Yh0Fnc+ujBUOFdIWwsyyDJ0p0jijUq3wEz/5VsIw5Jd+5Zf5pV/55XX7uve++9Z9l0Dg2VBkrx/h+bYEgFaX4U3kiA/WQZKeZwsSZ6ogZmqlqI82+NGf+DFWV1b5yR/7yQJ9xGVVbtynvWj72c7GhA21OXKnrWlnP3uejxPHp9vrkSrFyVMn+fN3v5soiZ2OFOsInkoZHBD4RRuy/F0A1jlK+ffNHKYnC9EMWz6I5bXG8ir3OZfpi/b/3oZDWM+0beRVbeZA5ci2nVhbpW+0Qnge0pPY8kkDfl0OQfueB76PEL4thZLvNz/28HkM/3cdTK7RJg8PiPz/QxOrjeHQPGtYkCn7roRhiPQAIYiTxPJ5kQSBRHp5KE8gAx8v8C1pPdUkaUKSxvh+QOgSZDAglcAkFt3taWPrkz7D9pxymsrlcqFXctlWhD5Aorj7oUdZXFxkZGwE7Qd4Mo9jAGkGJsUTbm3HsdEuNCZReHiW3VTy2TGzk1s//kn+9I93Mj0zzZ/+6Z/y7d/27URpn1qlDhiyLLWZU9o6cEeOHGGtucaHP/ghUJrrXngdYcl1glqjUmVRAg0ohz1oMeBFOGqVKAVkKD714H089tBDhEYxPjXJD/zcz9GMIt75//0BlalJwlJAtT7CW3/mZ7jhhhv5hte8Gj+skyU9YuDv//GD/Mmf/SnnF8/xute/jh//sZ9kYX4BT3g25m0UnuMfVaoVbrjhBm684UYqlQo/+JYf4Ou+9ms5dvwkaWoH/muu3s/rX/+t3HXX3fzGb/4Gv/gLv8onP/VJ+v3MRjy1crFniXHAnhCSa665jp/+mZ/izOwZbrvtk3zoQx9iaXGeF7zgepZXVvE96zi6OraFlUohWiub0u8UZHMUR/pWeDNXoY3jGIxx9Yt8VKbp9XrU6yN0u51in35RlTwXWrPeSpZmtnPSeYFkizBd0l3KHR7XFrXWxTKlHCqEQRhRtOvl+RV++C0/jPQEInTXW5CcIPAlKrNcMN/3ieMUrTRh6PGC666hVKrwxBOHWWut4bm2nynwZN7p2Q743Lk5lBPwnDt3Dp1ZHSBtmyLGyQrAFx2mzSxHmPKBZGM9sI32ZGGfiy3PSbJBEBRZdpvxXS7HLjuE9q/ggV/q2p4u9+nzjSFvdJIu5jRtPL+ner7DobGNGXjrzE2WCuqIW8+Kzg4qRiitnc6a41BKgTQGLYzNEt9Q5gUHDgjHVZJDy3DnlOq8RJc9D2EoJEguQNmcQ6W0cdIyNgxXrdXwXJmgbj8iVYkFyoMQ6fo34VmEtVypgDCWbiOMzVQ1Co3n5AqsFEqq7LVrxwd8pu055TTl6YO5nHwO713S7P2zPpGGIADteezYuw+tMh55/HG7njCUayXSCEoegBXEVEaBFgRhCSk8TGYohSVGJxv83M/9HNu3b+Pt//MdXPeC6/ja17yK0A+I0j69fo/Az2+voFwq26yvhXl+4Rd+gU5rjYX5ecZGx9i5facL65iCV5dllrtSwhLrJMLWL7OxH/pxSsX36ACRSpEGOknCY0cOsRJFXP3SL+Hb3vAd/NEf/zE33PRi3vzGN/KZT38WgF7SI3BFwqI0op9EjIyOsXXrtqGQYgq4DC6dZwz2eOCB++n1Otx00028+lWvYnb2bH6bAQhKFaa2bOHxJw7yoQ992NYzU5owCPE8YQd4rVGptqFRIciywQAQlkJuu+02piYnmd4yzWOPPc5Xf9VXcfDQEwhhic4Ig8o00vNsgVnXW3z7t38bP/TDbwHgH/7hH/jkJz/J2NgY//W//lfe/H1v5vSpU0jfZWloq6+EEPR6A4cJ4N9//5v56q/+Gn7nd36XHdu38uY3v4n77nuA//n2d7C6sjqI4LqO4EkR4CHHSQ4pA7sPxb4Giwxp3xE0M/B82+bjJLMdj3LCq0qRpZqykGQIskTxla/4Svbt3s0f/skfkyYJu7ZtRYQBhw5ZMdI4TvGDwGactDtkmSLqJ6wsLTvSqH0mSisuNr58MUJnrdfr2bqM7j3frFxGbhv5L5caRPPfhh2yPMNoOLzxRfuXtc3CbjBAgjZDFvP1n45zm2de5sfY+Hcdl0tYCocQoKVA54W0Ddi6qcJmKjsHSPq+C+NZB8Q4p6e4JuxkNq8G4QlROFA4FD8zyhXsdTGV3GHapIfIr962W+1qkVaZmpqiXC7T7/dpra3RabfRSlEKSwSByyj1JEEppN6oUyqXiJOESqdDv9slyyd6WpOlKX2X8S0lhF6I5z3zqOxzymkCCmG39HJht7xNSVGUeJDasHBmFiOgHFqtG9IuWoLWKf0sQ5Di+9bjRduSFB4aaSS9qE80H/Pf/ttvEAYB8/PzPPLoI9z2ydvoRd3iwFYdGjzpc/r0LD/8I/+BI0eOMHv6NGmS4AEnTxxnbmHexQGxTpP7Zx2WQRM0ysVIwgDSlChVSI1L14dAQlar8NKX3ExYqjCzZYa11TZ//b6/5rprrmPf3isACIOQJLUeeKYzrrv2GsJySLVaJmcpW7GzC2dLQkiuvvoavvu7vou/fM97i/Bmbrt37+b1r/82/uPP/ieuuuoa/sOP7OYd7/h9Hn30MaIoGXQCEqcPgpspCcqVMtu3b2f7thnmFxY4d36OyanJYt+e9K2j7DoDrXL4xZ7DRz76EY4eO8rk5CRnzp7lyKHDfMVXfgUz0zMWLcI6gGma2th+zj/TmjAMyTJb5DJTmr1793DDi67nE5/4BO9//984HpTNLsvpS1IIfN8Sh9TlOE7YUJyQ0iJtxhQVvDd2tlk2UKQ3GcRkVALfwu15aQW7IalR1Eplfv7nf543ff/388C99/GC/Vcy2RgB3+PkmVmMKR5t0SEf2L8fpc0QWVkWhMqLnD5C8EWvydmwA7SRCHsxe7J1Nu5n4yD5VPf3RXv2bfj5b3SQLokSXcSGM++Gl+V/12VeOmfF9mYCzxeIvGh3DskLgfA9AjybwZbLXggnzysMxskISPIC8IPyY8OIUx6yz38pUCkpi/7KxvcHsjl2u/x3G0oMgoBavc7k1CTVSpVWu01YsqKUmcoopSW8wLPZ6S7CEJbLjE+M40mPKIrodrukSQJGkKQJvV6XbquLyjJKoc2yM89CPcznnNOUJEnx0C97tiWsOGTelKUBk2XgS0xem0bHJFFShEsMVpRSOL4JQObKqdgfBcdOHnYhF4Nowdnzs85btwOpVU+2pLskTbj33nvIsoy1tSbCaELf5+zZM8R5+Q8JMpBW1NLSSIoG7OGR5gqJ2i41xhZ4yRP+VuM+f/fRjzBx190EXkAQ+Bw5cphOt8vv/fZvM7VlGokkUyl+4JGmmv/8C79AP+rhSY+/+7u/Z2V5xRGU3YsoKMiC2hiiOOauu+5itdkkSazj6ns2WyNNFQ89+DA/8zP/EYHhfe99H+12h1OnTg9mTcKGu4wxaGULDyNgfn6e97///WiluebaF/Arv/Jr3P7pz/Kbv/GbrK6s2vuvNAivSOuWQjiUx0JNUT8mimKuvPJqXve613P48GFOnDy+vvaTdsWOhX3KXuCjEkWSJvi+Lf/ie5Kx0REOHLiC2dmruOmmm9ixYztKGT760Y/TXFuzMglao7V4cqQpb4Yuq8UY66igbduYnJripS99KXt27eY973kPL3v5yzh9+jSLi0vceMONRHGP8clJPvyhD6HSFHcbCwDL8ySxyvjHD3+YRw8e5PzcHCePH6cfx2Ra0e52EMJ2VMZYqYMsyzh79hzS87j22msBXMgyd5ouFs74oteUW66jtFGkb7MsqNwuB424mOM0vP3F9nMpu9j+v2iXts3u18YQ3UZ7Os9neNvL/b3QEHSfjdUCH/B1h1BL3/PwpIeXO105D8ihTI6Z5H5z3836ZcaRRI3J27rt14rEhHzgMLI49jpHz+0rCENK5RLVWo1qtUqiMjzfR2NIs4w0S/FS307ulSBNy3jCo1Fr0Gg00FoT9WxlDwwWqWqtsRbY7OlarYoUkjR+nofn8sHyKdlQS3A+P2kcW9gvMFxz9QH+48/+BFACIxHaedBSg4MttTFoI0iimHZzjV67S5Zm9KIerXab8/PnOX36FONj48wtzg0dy5rSlpyXL00dypNmGb04dmERILO1g4Y3zgDrtmWDxXE6iOXkZiDuJxw9fBwpTxWx5yRJ0Frz6MGDeIcPg3uZjNMzuPOOO4sXTwhBHNtsBSvMuR52FgK0VpyePcPc+fNorfF9l92nNJ4nmZub4x//8YNobbj99tvp9SJX1sQU+knGuBdN6ILAt7K6zKc+9Sl0ppmcnOLRRx/l7NmzfPKTn6Td6uDhgwtX2LR9z7m2OTxn+Udra20OHz6CEJKV1WWOHzvGH/zB77PabA61B7H+/kmxTh/n3nvv5R1vfwfHjx9jYX6eO++8i7Gxcc6dm8Og3X2w16H0Jvu7iA0LBg67HkopmqurbJ2e4fu///t57b/9Fj7wgQ9w7733s3P3bhYW5/n5//Sf+OhHP2odVTnQespNo7nvoQd54NFHUJkVzLSnZUgz5cJ/xqKfLgvLEvrXZ24NbHiwH1r+xcF2nW2sRv90bLNQz2b73JiV90X7f2MbkaXNfh/+7ekSxYf7pM32P7TEJmxIgUoUmYRMKTtpxyYU+Z5fyBAIM8iiE9qiSzJHrcidJVEMnRv/mfw3QVGiSbiSLBsz/GCTUCI2ypDzTD3PI02SAhkTMk9AUaSp5Tgl/YQsVWAkgV+yKL/w7cTT84irMaWgROiF+J7Pli1bKIUhaZzwiY998Cnf+0vZc8pputz0xvUbYZUmdV5HSYDnsbi4wMnTx7j2+hfwYz/6k3YQNhQNx8aMik8ApElGp9mhs7ZGHMfOu20xe2aWo8eOMDW9haNHDxOnqeXw+HnqpyWT55yHIPAJPCu4o5WyWV7CavQoo1FSo4xBK42vJUJb2MnzA7IixmgH6tzzh4EukFYalSnSJCGKIrIsJY4iFhcXWVhYICiFtFot62htEDfLLa8HtG5W41A4rbOCT6bU8Ats/3a7PQBWV5tDMgID00oXodL85sZxwuLCEhhotdr8xbvfTb8f0ev1nNaIcc8Q6/DhA9o+NwfTpVnG/Pk51porPPjAfZbkHPU5M3uGZmvNcpA2tKGi5IkQqMwKYz7++EFOnDhFmiZkWcaxE6fxPI84joijyDpuOnfBNdLlhDxZq/Scw4fJOxt7f1utFg8+9BBLy8u89pu+mdlTszRX14j6EU8cfIK11hprzTWGu64C0DQG40TBoiRe5+p4vlfMCD1PFqreGMH58+fp9rrUavViIjJ4Vpt18KJwwr5o1jYiTE9mF7t3G0Nww2GezRCNLz6D/ze22bPejBh+8fDc5aO06/SYzKB32YxXlU+6Mw2ZMaQin2jn00nhqg9YdFsJF+/SxobkhJMfQBQyN0VPIwYOlZ2IY8NxDt8ahO3yrfKw4Hq0NV8HcHIDil6vz9paizi2obU4ThnOFi0mJNqNaakijVLifoInZSFKHPgevufjez6BZ/mak2MTVKtVsuc7ETzLsqcHLRdt1T5IPwxpdVo89OADyEDy6n9zHUrB2dk54jhlcnySkdE60pcYITDWpbahtxmFSlNHEtcondGP7OAehiHdbsfyYtKYleYKUT8CDEEYUi6X8QOPXrdHEsUIDaHvUalUKVUrCM+lVvqlApUQZoBiCCEvY2gevFBpktJpd+j1uqyurnLk8CGeeOIgIyMNzpw5y+pq0xLLhUWL7ABLESPP6xnBQBTSuJVyJENrCAK/0NmwoVNL7r7Q8XLCju549poGxzVmkF0mpSVrd7tdKqWKrb/VatGPeiitSTNFliQuszEDl+mW6ow0TWi328X96DsnLr836zouR0ZSOis6n34/ot8f1Cwa/nzhzbYoz+WYHWTdZkMSBXl234kTJ/ir9/81//ShD9JcbdJxZTq00vzyL/0KiUPs0GZd/2uMthQCbYoMO6PBpHmNKivBkKY2Y2XHzp0cPHiQw4cPc9NNL14PrQv38Ne1s9xRG8LqvzhuD9K4uZBrMmwb+6uL9WEbM7I2226z/T+TWWZftMuzzdCjSyGEn48kQdF7boZs5ccceiGlNyitIqUYTOqULamFtggT2NfZFxKjJFKwjtMkHf+JgstkkSouuJZ8ed49DAf6bBwD8gmBXWqpHV2Wl5cJgpAoilwJIvC90KqF+z6+H7plAUYZon5ES7QwRqOyjDAIqVYtItVtd2mvtdGJIokSAi+wRdOfYXtOOU15qu1TcpwMFoKUAqM0YIiTiAMHrmL/VVchPKs6ncQxt912K0tLTV72slfwwhe9gFqjhpKWM2Qct0eGHvgenoFQWMBnxB0HDMJl3t35udu59557OHNmliRJmZiYZO++vWijuPeeezkze5aSHzA9McHOnTuZ3DJFUC4xMjnGrt17qY80KJXKlIKQkhcghY/Bok/SH3B6hLCOVe6NC2lPyobAQG3JSBKrfXT11Vfx5V/x5QRBQBTFxP2+IyILiksQeTMfdOByiOg3GBwEaIX0fKQUeJ6PcZBrDisLK8AxeBC5bTYwDzll2hh8acvanD0zx9aZrSwvL/Pgww9z5twZev3I1eFq0lxrEfe7aOVeDgGFJoEowJjCpLTnprW9V0ZrkBLfaS0NZz/lpOcifLpp8zLFLGxwHLv+oI0OPJwi00oNakDlUDcIkighjmK0UoR+YJ1zqTl65AhjY2MEgU+WZqSZrZvoeR5xkjin0abhKq3c+Zhc4aEosBv6Pm/83u+l1WqxtrbK1q3TQ0KJDrkzeujZ5OduGGoqT2XS/AVs+ex5c3ThgrWH+q3hsN5m6MTGAXhjiONyHKXNBvVnXE/oIm3gchTEn/bO1x3nMrfdOAe4rL1fZM9DKODwsuHf8s/rUfgL+4OLzUKMwYrzDsfzN16LHJrhom11gFJAqRzihSF+GNhKF1qTxjFxP7aZyig8ARJZTIa10UWBeoywRevzPhkLHtgwnpXB0cKW9kLIwQTYDN0TYwYyBvlkyzh+rPDAZCRxQrvdwZOSKI7odXt2/JDSoV8+gR8ihKBcKlntQA0qzbMKbcHzSrmC53l0Oz20MmRphs4MWaIK3u0zac8pp2lqaor5+XnbEBw0sh72zKESBuiMAemKl0ph1xWej1cqE2Uguwngo03EtdddS6fXZ3LbFKkHHZ2RYkhc+ExIK0boCUlgwFeGEEHoSYyyxG9fCIQwTM9s4boXXEOjXqfX7TE5Oc3uPXv4oz/+Iz7ykY/QbDbxhKRWrTI2OkqlWsULfMrVCuOTE5QrFQI/IPB8fOnhCen0c7SrTO2UU8EN7LagcI4UwfoQkEBYsbO88W6mKJy/f84R2NhhB75HpVymVq3RaDQIHIxaqVbYtnWGLFPMnZ8jCAIajRG+7uu+jpOzZ5ienqZeryMcdygHNETOAicvLzIIT+S125qrLRqNBlHc54U3voC1tRaZUvR7Pavt0e0RxzFC5Bo5BhhoIeWzG6utZBWvBZbEbUvXKKRn72+WZu7e5DdxUPZ4XedWzLwsx8uT3sBpMrbRrfMJhzq8PI18Y90w4RxJgdUa6Xd7NJtNlpdXWVpe5AUvfAGd9hr7rtjP6VMnWFxcYMf2HYyMjXLi6DHSLKHf7XHmzBkWF+dJ0sTVtrPn1Y9jzs3NEfg+r/2W1/Knf/qnnDx5AjCFlpOF1lXxef3f/EIGDtYFI9bzzIkSQrrCx7YvGi6ECps7L0/m9AwPvht1evLtngzJeirLNzoAm09ILwffvnCL4T7k8ra59PfNt8knLM9c47v0vXhyGw6rbkTah47i/smh70PojMnLntixyxgx6NvzDwJMXvBU+PiBpNKoUWrU8EvWcZK+R6aVjXh0u+hOl6QfoTNFgC3BJDW2qL3ecIquT9LGuGxq8HIUSdq+wAgDwgUDhesr875tnQPl+uV88o1BKU3U74MxRFGfKIpsoovvg7FAR66ZVwpL+H5gUTPtxnUpCf2QSqmC7/n0yj3CoITvh0XZKPOkac1P3Z5TTpMtaOkP8UnAjYzFw5bOUdBOcdTGaF3n4PbjBSGlygg7d1/B2Og4a2tr/N7vvR0pBGGlwt0P3E+pUWNkaor6xDhho4ZfKuGHJWq1OmP1BjXhM16uUPYD264dDJr3Z1fsv5Jdu/Zy9NAhzs6e47prr2dsYoLf+72302p1MEaQak2z3aY5FEraaJ5rZHlqvxQCZQxhGLiQWN4oBgM9Q3fnAnOrWJRlfad+qQ5CCgj9gFIYEIYlatUqpTDE9zyCUsj01CSZUswvLhIEPtVqlfvvv58z56xsQK1WAyEKnaKcaK4dMVm4Rm5J+AKVpm6mI4oXSOTijLlPlMO9rgkIBk6jwb6oRcKsEGRKIfGKMKcQNv6eQ85GazzhEbjrEkJQq9UIfB/pBAY936dcqVCtVghLJfzAt7M1DbVKlVqtTr1Rw/cDMNZJq1QrtkaSHLSPJLZ8M9/3KZVLRP2IJEkZHxvDYIiiiHarTbPZpN1a48CBAzRXl9m9dx/nz82y1mwyNTVFrVZl7uw5MpUR9fosLsyzurpiy7sgMVKCkCRpxvz8PEIIDhw4wMTEODe/+CauPHCAfr8/eNBPNkgYMei4nePuJrrPOxtM2CQIDUaxmaOU22Yk74s5QZv9y1HeS5GQLxyo1x/7Yryc4XUuQL027UsGHMr1qKMo2sRTcrVMvv5TR6iMuUhvJ4Y/igt2vbHHvHC/Fz//jY5w/i9/PsX91oWoyGC/rBuy3D7Wn7SdDNuzy5OR8h+MACWsTpLByuKUqmXqIw3GJieojDRQaJQnML7AK4Vo16cszp1n+fwC/dU+yggCWUEKiTQGoTU4wUopfISQrnanJFMKo2x/GgQB0vPQOnNJJwYlvGLsxYDIJ/UFYq9RWe4AglEKozVxv0+mMuKoj0pTJ3tgJ83SrBfpFMags4w4y4oEpziKSJMEISCNYsLAp1QKbf9vT+Siz/Dp2nPKaVpbW0NKD6VScrKZ9LyiM/C8gf5NbgYn+sWgkSZxQq1a40te8iW0Wh3+z//+Q37rN/8HmVaUq1WE71MfHWFy6zYmtk4zMj5GUC7jl0qMNEYZb4xQET5TtRqNILSZBsLWxCmFgUvNj/iKV76SLEk5deIkWyZnqFZqBF6AJz1UlgsKXPqhCiDwfDzPoNLUIjBC4EmJUprrX3Q9tWqNw0eOsrK8csl9gXWWPOegWBkGq7qah9YuZtKhI91+n063R3N1FYTAkx7aaI54Ag0kmbbQr4B773/IFZC08fC8Ey1QL7BOi5QFjJxfdZaltqPLISfn6NlYk+tPnNNnnzngMttshh3rsss8TzpH2t3VAhpab570CPwA3w/wPI9arWbL7rjirH7gU6mWqdVqVKoVSuUSRtsw2EijwejICKNjY4Rh6PSfAmq1OtVqxQpzYs8vjmI6nQ5hKWTrzAxZlnHq1CwzM9MgBC+6/kU06nXm5+fZtnUr7bUW3W6XwwefKMrszM+dJ00TfOlZna1KhR3bd7Jt6zaUzmh1uqy2WrTaXUBw1VVXsXXbNh544AH6/T4vfOH1bNu2lUceeeQyZ9S2E7KlYIRzPAcdu20jzx+4SQiram+EFezbiCJdzHHaaJdCpIa/Dw/Kw9tt3OZSz3Kz0NLG423GyVm/LhSj98YurNjNxTyZS9lGZbjLM6s6ssn1bHpem37F7uLJj77RyR12mDYLoQopEUOZ0/l/RYEYDZ3vUOxbCBBGgrAoj1FO000KjOcKe2uDluBXAirjI0zMbKHeaCB9n14a0c9iEqWp+D7VRo1Rf4Qk0PR1TBR3SXoJsYrxReAiJKCMtt2qdMQN6SFcORaDrRAgPB/he5DYem/KWOEbC2rIwfkPiWNabmae0efujzGkma0Dq9IUTwo86TtpBNuwtLIcLJVmqCwlSxO0UvR6NsLQabfo9ToEvpVTyTJFlqW02y1KpdLlCWA/RXtOOU02DJMXMRx0IrnQn3SFCrMsKwZMU3B8tJscG7rtNlnUp1StceTQ/fz33/hN4jimXCqj4gTdj1lu91idX0Y+7pCRIMAPQusEGEFoJLUgQCiNzhQT4xPM7N7KA488iDKGbq/L7/7Ob/Oar3kV5VKZ2VOn2Dqz1RHsrPeOGSrBMfwWD727Smt8aQh8D5267AJhxbyUMrz61a9i565d/Nmf/hkry0tPfhONcSmemQMMJEJItDZMTExy7bXXsrq6wsjICMeOHaVerxOGIQvzC2it2L59B5Pj46wsLrLaarFt+3bOnz/P8sI823ftYuvWbRw6+Dj1eo0DV13NQw8/TGOkwc6dO2k21zh9+jSVSplSqcxac412p+Wey2AmqPWgw63VqnS7Paanp4njiLVWq5jNFEhZZjWrCtAxR76HbDiyLTZU+F6/XmrbmYPNV5urAI7Xlc/EKdJsbZjTYPSgGKXneUghC95UPtDlkHEeGk3TlFIYcvMtN3PVVVfx6U/fTqlUYnW5yQ//8I9wyy03c/ToMcIgoFap0u60ydIMpVMwkCQxcRxR8gNqVRvXT9MUlWVEUcTc/HlOnZll9tx5Ot0eBw5cwVVXXc3c+fOcPHmSNE2Zm5uj1+ttei82mg2vyoL3UAxyxnb6pVLIy172cm677ZOXtb/nuvmeLdmTZhlgCAKPjRIEwyjRZoNtnkCRL8vb0GZO0sZBeTO7FJ/pctdZ50BdHIdZ9+zththQzEXP7nJs2Hm4lD2dowx7eOufxVMNww07nxvRuWHCs8O83dFyZHYQ8jfC7Qf3Njl0CWOdlDzsq9FWnk9gQ2JS4AcBMgyojtapT4xSadTpRj06nTaRylhLerRVhF+tMDY1zpbpLZQbVaZ3bkUqRevcIt21LlIqPL+ML60wpkWbjEWXRGYL1ItBNCDOMlAKjbZUBuMNAot55KNAnQYhOozlx0onlKedM2U8D5yGHAg7SXVim0ZrFAaVpfafyhAMSgtJV68zVgMJg25X0u/3KZfLTyOw/OT2nHKaBnFgCk9da2VDOgKL3ogh2FqA8K33mya6EOeaO3uWqNOjXG2Q9SPGx8fZvnMnu2ZmHCktRWmD8CRKWyZ+uVKhVK7Q6/dAacoyZLTWIO716fQ63HjTTXz9676Zt/7/fppOr8vCyjLNTpup7Vt50U03MD+/CL4k0RkqL7oGg3f4Es82yzJb1kLYys2ZVmQu+2psfJzp6S2UyqXBBhv7nOEQlhQkhSK2VWNQmULplKuvupZf+C+/yP/9p39g+/btfPQjH+HGG25gx/bt3HbbJ5mdPc1XfPkreNmXvoxHH3qYO++6k9e+9lu4/bOf5hOf+Dhf9cov4+u+7t/w8z//8+zdt5f/9t9+nZ/8ybcyNj7Om978fRw/foK/et9fcc01VzE2NsZnP3sHjz32uA1jSRt2NAZ8KShXauzffwXVaoU777yLV7/6azh95jSfuf1zVMplduzYyamTp5mYmKC5tsZIowFCkCYJ5XKJTGUsLi5Sq9bYs2c3q80mpVKJ5mqTdqdzgZL5+ltnyY6AE7Z131xlkSAIGB8bJQxDZmfPWG7U5RTs3cTq9Rpbtkzxmte8mn6/h5Q+C+cX2bt3L1dffTUG62T1+hG+H+L5IVpnGK0plcs0GMMTEHiSIAzRLkEgCHx27t7NTbfcAp5PlMS01tbYvXsPx44fZ35+ni1TUyilmJ+fv6xBw/M8pOeRZZnrrBzZUwgq5Qo33/xifuRH/sPzxmnKRVozR75P03zwXa+nNOyMDDtE+fdhZ3rYQcp/Ay7YV77scojhn5c9WdzrSdd9qgcbcj6eZM11ZzTk+FzsXlwK178cR3T4+6U4aTkXyapf5xpoQxM6MTSUu7BkLlAphgYEYSzyU4RIhSh8SulJ/MAnrFWp1etUq1WEJ1ltrrK4uEBYrZDohF7UpbeywlqzSRbFzGyZZnxkFG8mg17CYqtLpFJCLwBh67wJYTDallJSqUEqz/GEbFSn79CesOTj+xI/j1wMTwCMKK4pv/C8zXuO+qCFQHue1Vxywptgk2VkXjDWtXml7GQ2DErF9rnTJIQsqoREUYRSqnCano334jnlNGk9qFE2SO/Nl7vQixPbyoloRhuU1Fb9O7M6SY889DCnX/qloDQiTXnLG9/EgeuuZXx0FOlJlDJoIWwj0RqdZvhBiFcKyLIUozSB8SiHFVJl46tjk+NM7JzmV3711zh+5iS//hu/wUJzhbVWhy07ttFLUys6JsB4luh2uSQ1geVm4Rqd7/ukub6TMg6ZMeTZoPk/M3j/3L2y9y1LBwRTIWXBR5kYn+BFL7qeP3/3u5ibO8cVV1zBtddey3XXXsPExDgf+chHWV5dRUjBgauu4jOf+xxBKeSKK65gz5497N6zh207trOwuIQfhFx/402UyiUee/wgwgj27NrFzh07efWrXoPWGQ/c/8BghuKI2gBB6HPVVfv5T//p5zl16hT3338f3/zN38hdd9/FnXfczYtedANvfNOb+KX/+iu87nXfyqdvv52bX3wzAHPnz7F37x6aqyv8zd/8Dfv27+NNb3ojd999Fy960Q188pOf4taP33rRez3otgcd5UD2zXZw4+NjvOY1r2Hr1q287W2/NYjbY4rOEUGRxLdRLLfwlw34fsDk5CRbt24FoFKp8O++67u55SW38Pjjj/Oe9/4VC4sLhJ7ndL8cauo600xl6CxDCBgdHUVnitGREfbu2Uu5VmZm+3aue+GLyFaW+Ku/ej+vf/3r2LF9O7VqlVqtRrfT4djRo5elrq+K0i159qGF7avVKjfc8CL+w4/8CHfffceT7ucLxaz8Rd5Z9+n1uiiVFQ5Trss2nDE3jDoKIWg0GoyMjBTr+b5/AcqUO0tJkqDyEjpDSNZmDsPlZMptzPTKt79UeG4zYvqTOSv/quxCoOmpbb7JPcs/X0D4NhduS+E4ieJ7bjljQIDlGbmi5hgX+pU2i80G+6zzEfpO808b4igi7kcIBNOTkwSVMu2ox+y5s7TbXVZOnaOSQjgxSTWsMDo2RtTukXR69FWGwhBIH89Fc3Tm2rAxSM9H4GE8QRorMpUSCMv1lA4xtfpLNhw2XCNxmCuXvxfD9yx3pHJbv23ukFn+kkAUzlueFJW/M1rrdWg/UBRWfybtOeU0bZwr5HDecKdETnAuGqgDQn0Pk9nOpt1qIzzJvffey/333suP/PhPUJsYZeHcOaTnMzE1hV8qF8cYPNyhODXCIg/CEHf6fOr2T/O+f/x7fumXfpH5Tot3v/e9fOTDH+GGG27gy17+CjpRl1RleK4Aock2iSFtEqKzXrjViNJuNlqpVEjTjExnBUEu3zx3mHCfhxFfyAdyUzRGz5GwM5UQRX3iOOKlL3kpC4vzvOxlL+PI4SM8fvAJbrzxBiYnJzl06BCTU1OM7Bvh33z9v2Hv3r0cuHI/11xzDQeuvJIkTak3Gpw8eYo07hNHMWfPnefosSNMTk5SLpfYtm073W7bpZbmyG2u5mF1n2a2TLH/iiu48YYb+LVf/XX27dvHY48/xsTkGN/9Pd/DN33TN/M7v/W7vPZbvgUDfMVXfAWrK6s8+uij3HTTDSwuLvCpT32KrVu3MX9+njiK+Yav/3oWFha49eO34vkSlV3oKAyco8H3XNohd+pqtRo33ngD+/fv521vs4iD7wtXNFJd2CEX1zg4gJR5oV/78ne7XWbPnGH7th286EXX8+CDD/K233obhw8d46YX38gV11yN1gqVKYeqpqwsLzN/fp6zZ8+wtLJCmiRoY2jU60yMjhKEAeOTkxw4sJ/l5iof+MAHOXPmDL/7e79DvV5nZGSEtbU1HnvssctymoyyldKFEBhlis7pxhtv4Du+49s5+MRB/vf//sMn3c8Xis3OzlKv14G8nxgQwTeSuIeXDTsXSZLQarUKYms+ix4O52qtKZVKLC8vs7q6SrfbvWCguRy7GBH8ybZ5di1HZgbn9K/N97qYI3mxsN5652mAPtolrhPIvSMLxRRoUxHeFE7iRgjA9tFIm+avjbYIbw4KJCm9TofMaKJ+hC99GrUGo6MjjCV1eitrRCstuu0VWiKgRkC9Yd9/NZ2ybBbpt7pk2lCRktDzbKgOAZkLjxkFeBgpMb6wSJIcOI8qy0jTtNBS3Mxhyp9vvnz4PuZgwDAoIqW0ITYDQRBisAXHwXKa88nDsKM0/N7EUUycbC7e/PnYc8ppGk7Xzm+YHHoYwqWKm6FigbZhCoQyGAUyEDRGGsxs387tt9/Ox267lbf+4i/Saa6ytLwEvke5XqEsDFmSEscJ/X6EShJ8X1CqlinX6wReGaGsHP3qyhInTh7n8UMHOXzqFPc++BD9Tp/DjzzCh//xg3zJLbewtraKQFOvVvGRJMMaHPl5bmLFID7U4YWhTam075wp0vWl8NyLt8k+hvZvG6e9n0ZrjLAe++yZWW775Ce54op97N9/BZ4n+dwdn2P29Cxzc6/mhS98AXv27qXb67HabLLviv3EaYoyim3bdtDt9zl65Ajbtm5ltbnGqdOniKIYjOHQ4cNUq1WWV5ZYXJxjfGyM8fEJq1arbdq7J0AbgfRs5tmhQ4d57WtfRxiWSJOMteYae3bv4ZVf/mXoTLF163amt0wyMTGJJyVpkjjFbsm+ffv5yq/8akZHGxw+dJix8XFGGg284bYiLM889xc8x0/aSGYWwjqeRuRp1FZJXCnbQVQqJZI0QwhDlg3xVVw71EOpvMMoU947CiHd8zFMTEwgpOHdf/lu7rvvHn78x97Kd33Xd7Pvit0XtI2VhUWOnzjBkSOHOfj4QRYWFshUWsz4tLZOztz8PMsrK2it+NwddzI6MkKpXGJkZITV1VUefviRy6/jyGBQyJGRXbt20uv1+OVf/tXL3scXgp04cYKpqSnCMGRkZISpqQl83y/6p40lT4YHg7wv6/WstEQcxxeUiBpGmcrlMktLSywvL9PpdKhUKoSu2PhmKNGTIUaXQqiGv2+0S6FMG397qlagtP8SVoBBT98pXD+hvpAfVkyUHCleD4SQXKcymNAKN8EXQ+dljLFlQnB9kyfJyLPcNCZVZHFCH0On0yXOEpI0pVIuo+IUEkWoBHV8KgqiTp/u4iptv0xJ2oQWb8sUcRzT60dWkVsrhPCQ0kOGAinBZJrMaLTOkF6AHwaYwEP6vuVbpSlRHJOmqSut5RfO/kY0KV8+rPgNF7YXIVyWnpT4vs3YNkagM21FZTXFvhKn+j3MczLG0O/3Li1M/DTtOeU0CQSetESyWGWELrTR7nbo9XpuFi7WdwDaQKKKjOjA9zh24gRnz8wipGCltcbhQwdJ4h6eH5KkMYdPHKPT6dJsrrKyvMLi/AJxHNMYrTM+Oc7E2DiToxOUSjUkgjD0+bbv/DZ+8Iffwm/+3tv5H7/127RWVzHGEAhJr7lG1usRGEOtVCLwQ6wQwkWgw6H3WBmDUNrqYxhDr9ejUqkUse8sy2w4Zt8+4iQhDIJ8zlLsbLg5CgG+L4ljhZQ2k3Ct1XKFhJf5oR/6wU0zDu666+5NnsfFOQJSCP7kT/6U8+fPY4zhr9771xiTMT4xwSduvY3rrrvOhhi17baU6zg0kKkMZQz1eoOHH36QJIl59NHHmDu/wJnZs3z4Qx/m+utfxK6duzh18hS3f/p21laaSE8yN7fA/PklvuIrr+PVr34Nq6tLvPWtP47nC9rtFqdnT9n7mrnYu01QsS9paGFeS6Y2AxV014HlCrtwIVclSdIC1QwCH8RAJmIoKcZlDVotJoOx5HIBSZqCkGzdvo3FpSXa7Q7/7t99N7/0y7/I/PwC9913PwhjFeLdMQPfpzHS4GWveBlf+2++jkZjxNHXjdVocfF+YzRHjjzBt73+9Zw8dYZur8vC/Bw7tm9ndbXJ0mVkXdp7Zd89G9a1cgxpmvCe97yP97znfZe1jy8k279/H2Nj43ieR7lcolqtrivgu5kjujHLqlSy28VxfNFMH2MMlUqFKIrYsWMHCwsLpGl6ybDZZsfN9/Vk62/8beN6T7avp0Os/hezotMaIFxP1TY6SZuhikNruwmS9ZByFMmSTIfQZ+Mmuw6A0pkhJRsgL96A0lRUclcGnSlSA7FK6UU9u/9Mcf7MWaKVJiXhQyeipjziTKLW+rS9ZUI/YMyfolStUp+YoBMl9Fod0lSBzjAGSl5ow29IlHYCmCj8ko8vJJ6BLI5J0rRw+od5ecOTh0vx9fJ3YeMkw/d9wrDEyMgoExOTCCHtZFANqAlxHBP1bX3TNLNIV5rExEmC7/nU6pWn95AvYc8pp2lkZASbGp+iNdx40418/dd/Pe997/s4dPjQJYQDLYIAhl4/5lOfu4MTZ88wtnUrDz9xmBtvunnwIogNs61hDchCJ8jNItzAunV6nO/63jfxwz/64/zCz/4c1990C7NKsLLYp9Nc467Pfpa77rqD666+FpWkWD0hp0G0Dm0aCg7lyASDMdcTEoXV20gzq9GxsrLM1Vdfyx/90R+Rkw6L3W3GMnfTOeMg/zNnznDvvffQbrdZXFzkfe/7a86dn7czMBd7F9iw3vA9tcJjksyFPPOKK8Y5IJ4neNe7/oIwlExPT1jUQ0GSaD74wQ/xgQ98ECEkk5PjhZhkjvAIIbj33gf4ru/6LrevgJ/5mZ+lVPaRnuRXf/VXUZmmVKryd3//DwRBmbvuup8sSwHDhz/yQbQe1Bz62Z/92XX3xQ8EUgqUyrM/1s+uBcJW/pYSzMCxE2I9zK6dcrZSinq9RhwnxFFiRTJdu7F8RtuuLGHRzhKFExq190tSqza46qqrefHNL+Y9f/kePvWpT/Ld3/3dxHHCBz7wf/npn/6PGKMK6YdarUa1UqFWr7Jz504OHDjAS26+hcD3AMOevXuZ2bqV8YlxqtUq27dv40Mf+hAveMENfPYzn2XHjp3cdNNNrK1dXCNso0mn1JumKaOjo3zJl7yEpaVlHnnkUbIsxfc92+k+T+yKK66gVqtdMFgOk7cvhbgIISiXy4yOjhZ8pXx5bvl+c6dp27ZtnDx5ktXV1XXq9ZfDYRq2y+U8/cvYcFjg2bZLTfUubZcKs26+Xo4om2JCZUTxnyHYmWL9QWSBIgSHtHxc4zSU0FZjDmUg1WTKkKYJaZQgpSTVhuX+IpG/ykhQxhceI14J7ZfpRH3aSyvgecgwYCwMqNbrjG+ZQngevbU2ST+xmkzCI5QB0peoLEOhQWf40rfyKalNSMkcwiSlLQQ+zDHKuU55W80RpjyctrGN5Sh5EAROsiWkUqlQq9WKkN8wZylNU3q9Hv2+FceMY1sCK4oiwlLoogfPrD2nnKZmcxWl7Yz/Dd/5nfzWb/82SRzxsY9/DA5ToAG+71vZAWMxFw9hU/cD26mXKwHlWoWOE/Xbc8UOTp86y4Er91Gt1Tg5e5pypczE1BSZ0cRpShD6lKslOygpTdZLOHviLKvLa8wvrfJ7v/t7/K//9QdonfKm73oj7333X3Lv0jmyKKYWhhzYt5dGo8bMtq00RkZtBpfKQ3T2BZKeBOleysQ2Mt/pGeVq1vbN0njSVqd+5zv/N3/0R39s1b8d3wryF87eNyGsTIHvWWfB9322bdvGTS++kR07dvDCF7yQl7zkZqZnZvjxH/9xPD8glyHItx8OFeRxZBg08iIjQin8ICg6CrCDSKbUuu3s/vIna0NiOc8qh3iTJLH6Slrg+x6eL+h1u8yfn+fc2fMsLCyjleHBBx/m5MnTdLsdOh3r/M0vzLHWWsEUGKOypWYMZOnA+dFuamfh3Lg4N/uj5Q/l/awlJA7uaZ7hoZSi24swbiZptCnUdXMXy/43nyIKcuFtow1RL2Zlpcny0iqL88ucPXuGXq/Drbd+gve97738wA/8AG94wxsQwqJYcRIjkfiBFZTLz18Iz51TPn2VLC4u8Jd/9Ef84wf/iT9/15+jdcYv/eIv8uVf9gpe/iUvo9fu2OspztKFLF34Vkg3CdHGZlk6dLTdXuOhhx7hV37ll+h0uvz0T//MBeGlL3RLU6sxA4XfPcQnXM/ZuFiozBhT1B5cn+CyngOSr5sPRJsN3MPfN/tto21cPz/upcJ6l0MEf7rOlngWhAifTdt4zzby14r11m2ElQwYolGIIUaFBqRznoS0XFYJCJ2Tsi1CjScQ2mAyZdFyZfCRBEGI50kyV4rJoKjXqzSCEp4yJL0+K6ttIp2hPNCepDExzsTUpOX9GkMnXSPq20mnCKHslyzapJSjJSi09DDOIQIolUoEQUCpZLO4c37TMHqUj8lBENhrdW0ZKBwipVQRbsvfh+XlZdbW1jAGh+qWqVarLkQdUKvZpJZhBFcpRau1RrPZfAafuLXnlNOktOHGG67nR3/0x3jDv/su5s+f4y1veQv3338/ntPQGS6gKcBqPRhbiytzs+Con7LUbNp0TgylsMSVV+xFpRn9dodqECKEJI4iMgz9JKYbaUTHEEd9TKKYaEwwPjHG/n27mZycYG5unuXlJuWwwcTIKDlFs14tc+CKvTRGq4RhwNTUFKPjE8wvLKB1ivR9G5ZSCp2pCyZCmdL4zjOPswzf9+hHEV7g86M/8P3c8pKX0up0OXr0OAefOMSJY8dZWlyi2+kQx30uZvPLTR589CC+gEqlTLlaQno+YAdeMeShD5PzrGlXnsMKinrekDdv8smRtHXSMrWu0/CcLP5wpz/ovEXxN1+Wv2BxHDE1Nc7VV13Frl27qFbqjI9v4Yp9B3jd676VNE0Jw5Dx8XFqtTK+L8lL3x09ephv+qZvYGm5ecF9MMV/hpatC01ceO/ydhaGlhuWJVbCweQ727TvtwOKF1gFeVsrT1KpNvCDElGUMDU1je97zJ45Q5ZlNJvLLC7Ok6YRN9xwPZ4XWOTNH5rFkpfWgTCwtZnSodpMSZrS6bS54YYbkV6IMYIz5+bpdnucPHmS06dPFTvKKzLkLpfWGk96lEolm9I7VMdJa1haWuK//JdfXDezDoLnD9okRF5yZ7jNXKjHlP/dzBnJZ+Ibw3bD+9zoQG1E0i/m/Az/ttlgfjGH51II1JM5ahuv+6mYGQJfnl0bIPo5tvxkINdmz27YWdrcQbWTMws2DYfm3JFN3t0XZ1G80xKBJySezENXBjOUPS6MRGhsiE4Z0BpfeAQy71+tLECqIlK/RDko0ajU6NVHSLSmm2rWlpsYP0AEAfXxMeqNBmkUk8UpfaVJlCI0GaEIcsIpJlPWWRIZZFaDpVQuU6tWqVarrq5pRKvVugCBzR2knPOU96M5cpT39Tk3KpcQSNMMpWwmXE4YL5VKlEqlQnh4eB/5siRJXf2+Z9aeU04TwI/+6I/yPd/7Rh649y5+8id/igcefIiXf+nLOHn6FKdPn7bp4ULYdHzH7vEdqpGgEd4AEpWeJU6rLKPiB0Rxgo4VFQJUBlm7j/EkgYTMgC8DxqdG8BBk/ZS0EyENhJ6HyVKyKGJxpcepYyfInKc+P3eOBx+8j1K9jF/2KTdqSKfAbSs166KjkJ4AT9p3TWmM0/5RWtuCidJ20v3YoLOU+x94iJOzs2RK47nK0DfffAsT4+MEfoDWisCTTE1MMr1lilq1jOdL6vUGE9NbkFIwPzeHENoiOsZQqzfodHqcPTtnC+cGAUEQEEcRi0uLxFHC2NgY5XKFqalJ+v0+R44cod1u2xmwNoRhyPLyEuVymX6UsLK6SpzERFHE6soqa2ston5Ep2O3kZ5HHF3cwcvt6FG46867aTQaNOoNPK9EEISMNEaQ0ndOlo/nHCbPIWszM9O8+93vZmpqqhigpJT2hTJDM/t8AIQLBj6braGKEOjCwjz/8Pf/F8/z2L17J8dPnNrUAXN7cH8lOhuEZmu1Ggf2H2DXzl0sLy3TaXXYMjWDJzzC0MpdvOPt/4t/+sd/5MSJUwUC5Lm3Nq9RXKuViZO0mBTYcx5GPwT3P3A/3/kd34rWKQZIM8XcubPMnz+PAcqBR5wqi3QO9TNaKaJM0Rhp8B3f8e3s3LGD//7f/wdC2KzBJLFEy3I5JM0S0mehk/rXatKJfUIesoGNThOsd0g2Oin57HvjusN8qI08kctFcjZzlL5oz7xdOlTn/q1bOgQtDcNMQ9sIYast+E67CJ2Rx/ekkAjhIYwrMo52NVEFwtWRC/0AkSWkaUan1cVUbDLBxOgEIghY7LVp9WKWzi8QlEsEpZCwXGZkbAydWeXvqN0lQVMizxwHYRRkGcZIdJIigVq1xszMDKOjowAsLy/T7XaLdpdPkpMkuaAdK6WKyXGpVCrafT6ZSNMUpew0LkdkhxHtnPSdW76fcrmM53ko9TxXBAfo93r87V+9l3e+853c/9DDvOiFL2B0dKTQk4BB55T7+YXAdM5XEriyHS5Ek2k6/Q69tTblUpnRkVESlZHqDKU1nV4XPIFfq9CJW1TLZUQGtVIZk6acPzNLr9Ui9D0SFMuLi/TjHgZDpRSC0Bw7dog77/gsR48dpdPr2SrS9mRtZ+v4Uja0Y98020jzKzcYI8i0QngCieG+++8nTTMQglq9QbVaY3xsjLGREcIgxPcCpDDUqhVGGyOUwsCeU6VCY7QOSLqdFkIOitQGYYUoTmi3m2gXTpOeR5JktNsdjNFUK1WEkNTrdbIsZX5+fvBCuBDp2lqLcrnE6Ng47XaLcqXC5OQkAEJ6jO0bt7XrXPx7ZWWFLVNTSClpNpvs3LkLgVV2HR8fRwhDkvQBjecFqEzT70c0GqOF+rYdsCyPyA8kKks5eeokjzzyCLfddiu1Wp04jl2nZNE7VXCynEqt0z8aYpcVBHCt7cwqjiMOHz7EfffeSyn0XO02UcwX12fkiHWfzVAPGvg+E+P2PpyfmyNNU2rVCpVyiZe+5CXMzMzwwQ99mNkzZ9gyNUaSxASlCkIYPN8jTVK0NoyM1On3e0RR4siTAUpl9HoRICiXK6gs4zOf+SzGQL1SIYkjnnjiCc6cmQWc8rwnUEN1HXPHyGY3euzdu483vfGNRFHE29/+dgBmZmZI04xTp0/heYMCwc8Hs+GEYafHc87q5Ts1+cx5OKtOKVU4U8MO00axy437gs0RnosRs/+18Joude7PhD35fnNs9ZmxJw/RWUhtgEMJNjJvlNagU2Tgwk1aY5RGCMuj9Zxmkza6SP8VUhD1+vhSMjHSIChX0f2EtB/R7rTQWuG56gIdldDrpUT9mLXlVSq1Glu2TjMy0sBoQ5KkpHGC1gYlTFFmRRiDMBq0Io0TyqUS9Xqd7du3MzExQZIkxHHM4uJicf156C13iIIgwPf9gjju+37B7cvbfL/fL2gaOd81XzeneAxzmPNQeeakD3KnLHm+Sw4A/OW73835uTnOzM2zY9sM1UqFublzdFptFxkZdPrAEOoEoR+SZMl6th0wOjLC7MmTbN+6FaMM1VoN0evTa/XQwHjNat6stVu0Wm380YbVwxCSNWFI4h4jo3V2bNvJ4fZJPDG4saO1KlOjI5w40SPrdal4PmUvRCLW586ZXAdokPvmJha24LCw4UmJQfq2llu/HxXhvLWVFdZWVjh/9swFM5cLZr524eAeiI2/PvXMEuOcvLz/ybcfGRmh1+tRH6nTaDQcUdq+cOPj45RKIVmakcYxSRw7bhTkxXeFcAW1sW5Jjnz1dUIYhlSrlXUhvjxkEoQBKssYHR1ldvY0v/M7b3+a9M8LzZOCUikgDH08T1q19/zCnaO7vrsU6z9baXo836NWr6G1ZmFh0aWrW1X7l7/85WzbupWPfuxj+L7H2Ng4CwsLjnyfkSaqCGdEUULivhtjrIaXG3R93yMIPCrlEmHgs7jUZGpynLW1Jk8cOsT8/IKd8Wl7n4coXBjXJoUnGB0bYWxsjCuu2M+P/uiPAfC+972PXq9XqIPDphPnL1izfL5cT0Ig5eaOzEYbfh+HOUzD213Mydn4/XIcjafq+DyZE7NZiG/jupc65rPhHF0sXDn8faMT8/mexWahzWICts4s4m3WL7K0DGy/JtxJ5e8fxqC0JpMKT1rKhCc9BMYiTU7JWGgbkZBCIH2fJIrJUAikrcvqe+BJl0SuwBhSo6gGZZKSRsc9knaPtcVlwiBgbHLCToanMuIoImp36KYRjaCM73sgbdYcSpGr92qtabfbxT3OVbnze5Pfj2GZgTz7MwgCRkdHmZ6eZszV7EyShG63S6vVYm1tzXGcjJvE5Q7TevHXfEKhtXLil7qYfDzT9pxymgRwz30PoI1hamyMWqXMyRPHGRkds/XMeiFJmrkOf0hFVWtbW1IOenSVZag0Iww9JienOHr4KN/zpjfyqdv+mTRJUQbW5ucZHWnw6q95FSrN+PSnPsmBa3Zx8803c/DgQWZPnSIsBdTqZYIQGmMT3PfAo9SqFcp+YF8GYyhLyfToKNdfcw1JXOX2xn2cQK6rh2ZDcoBvz9MoUyhKK/e7wKXmK+OKKgLeYNuiJpuyo7YUksBzquZa4QtbDDHOYkqejxGGVBtKYQAI4iQBYflKUm4kA1tV1kHIxzbgImYsineoIC9qbJFlgNWVVZorTSfWJmiurqzrtQZOj1jX8VxyYDAbOuBNVis4INrVpkOAGWRv5C9VGIZ4gU9eXNhyhzw8TxLHEVmaEvi2gna1UqZUCun3+/T7yyi10VXf5ETt2WDRJrusUimzffs2xifGad5/P2HJ1jY0Bvbu3ku1WqXd7hH4kna76zTDYsIgIM0y0iwjLAWUS9CPErIsJQxzXoDVjRrORvGlrRHl+wGrq6tkaUaWWYhdGhuC9jbcQ60NE+PjvOzlL2N6yxSPPPwQL7z+et785u/n7rvu4Z5776IxUgcBz4L47r9ys8VJNwvJXa7lA++w4B8MZtH5+zAclhvm0gyvm+9vo9M1zInaDHW61G8b17vY75fjIF5qv5e77uUc+0nJ6WJYkuXyj3kprhlwwTPKf/M8D+HCbHmkAxyKnU8Ic4cph3gzBUqRZRoZeIRBCe1p0syW+Mo1WoTTLpK+R+j5mCAky1L6vYjYIf/10VFKYQgGVlZW6Pd7NMZGGfd9DIJukrC2tEqSZmRKs23XDqamtxBFPc5HPTrNFhXfo+SX8AIfTwPa4Hu202+1WiwtLRUFzpMkIYps2H4YQQ3DEGMM/X7fJfl4jI+PMzU1xc6dO6lWq4yOjjIxMUGn0+HUqVM8/PDDtNsdN5m2SS6D+X4eXbCyP6GTjMkjDjb098xP4Z5TThO4hgkYY8uI5GJyo6OjJGnGymqT1MUxDcbWbLNfiGML1RljiLo9m7GCtCiOELz227+NO+65l/5qk8QolNGMTUzwki95GfPn5rjrjru5+aVfyute9zqOHDzEww/fz8hYlfEtDR55/FGWmx28kodfK5N4llMVlAKq9Qrlsk+v2yaLU6SryyPcdGPdC61AK5e+WXGPx+TKsMrK6CtDuRKgMls12kKWkizJinCezgy+tKmiUkoC6bl6c1YqXwg7gJYCQZqkSM+jXK6QpgkCS7wLfEHmwjW+425k2aCWUubUqS23w0MpbdP4BXjCQ2qrJuuEr/F8STkMQVgl5DRVxYwvvwMSS4LMOxFjDJnzxnxpw0cK8KXnSt5ojFZoctja7U24Ts5oBFaDKUsVOu+cxIB0m8fWM6UsDF54f7Z6Ui4bEEWx1e9yPZ02VpTT9yG9rAmNcf/svaxUKuzetZOtM9P0ox5hGOB7Es/xGYTLztMaFheXqFbLxJEVsPN9H4NNbhBlQaVcIss8qtUKUko63S4qS4vstzRJKdVChMBqmqmUUliyz86dXTUISLK0eBaetCjtzMwMX/ZlX87i4iJve9vb+NVf+1UOPfE4q6urpJmm03EFf4cf5PPA8tDasGjf0yFDDzs5FxuQn6ptdLCG2/rlIEiX8/1iy/LjXmr55mHETTe55H42cxA3XvsF51Qsu/zns9myy3FEi++DH9hIadrMgiCwuoFxbDWQUsvlSbIUIyAIQ0p+hVJYwghBmiVE3Z5Flktl0sT1E4HP3gP7uWr/AaYmJzn8xCEOHjxIp9Mm1ZqyH5ColH4vJhKCzuoarZE61UadkfExsjRj1Qj6cUzWi6goQTUo40tJts5hNCiVkJO7c7rGcGhuPU9JFevmxPFOp1OE5XLHy2anDqIIm93b4fa9kYsq5VNTzb8ce045TUEg0UqjtEWKhDE0GnWiOMH3AzzPhj6EGFRjV9rGHYQQmMxCndZp6lpOiNF0owjjSe576CFOz53j0MFDKKWIo5jmWot777uP5mqTs4sL/NXf/h3/8KEPEccRUdzHCwRBxaNU8/nSV7yCSCsi35A4hHax3eb0+bOsttfo9LrFgGxdKksAERJLUDcCreyMc2x8lG3btrniqxLfD2m3WoTlEt2oT61eJ8lSlNbU61XCwKfTapElKVmaoTJNOSjRbXXRyirNCmP5LXnJBqUU7W6HMPAISyWEZ9VX46iP0SlCSDwJmbaFJz1PIkNZOEtaW7K1DbEppJD4YYDKUoIgIEkTAiHJtCYMfXzpFSKQQeDhCUtktDWWXFFKY5wOlkYD1XLZZgtK+zy10ATk6t2CSrVssxxVVmwrpVPtzuw5l8uhDV1lVoYh8ALL08qSdTP9IrbkYqJCCoy2iJ3v22raNvSlLclcCJJUDyD1TWxjd+s7eFlpgxS4/aZEvS5bJicsKmqsMnriiN0TY+Mkccq3fdu38sEPfYT5hfM06iMIIWiurbG21i54Mc1m24b4cI62tB1WGPiMjDRoNpt0ex3SOLEVxJVFMhvVMnv37uHkqVP0osgmxjj+Vb/fZ+7cObZt3cro6Ah/8Pu/z/vf/7ecO3fOtg+VWWQyV1R43jhOYh3Ks+6XS3gAFxvUh397ymfyJCjRk53XxRyPi32/3HO9mJO2WQgwDzf/i9pTdNQ2u+bLvfcmv29QaFmuOw83gxTunvmej9G2f6mUymijiVxpEKO0jZT4PpVyGdPXtNptatUKfmA5jWmaojGUa1X2X3M1L3nxzezZu5dKrcqD9z9As9nE93wSzyLXJIpeu8PSwiJjRlFrNJia3oJnYP70LFG3g5ABVT/ADwJkkiOtUCqF5ITuXOgy5zDl5O3833CSQ5IkNJvNwpkql8ssLy8jhCjKC+Vo0mYO94WTjcHNFGLzd/PzteeU05RluohItXt9gtU1du7YRpokjI6O0Y9j5OoqQmGzCoyb2QuKqqk5+WxhcYFMadJM0et22H9gPzfc9GKuv/kmZnbu4Gtf83VkUcKD9z3Ann37uOa6Co2xEa68+kquuvpqPv3Pt/OxT3yUTtSiOlrCSMOuK/aBD3OL88QO7WqMjrF1xy4OHzvCqdmzTM8cYHp6ykKbeVCnEI405Lys1lqbqB/biYmwpTYyleF5ApVpwmrZhhi1ttliQqAyhXFqqYHvU6tVUakmjRI84VEOy3T6vQICrtZqjIyPEQQBlWoVpRVra6uMjo6QJJF9GYVVqxaeR6VcJghLDmWys4VqtVqEgDA5QiLodjporbjqyis5cfIExmhq1ZoteGwMnY4VVSyXKy7+bOj3+iilbAgxjp24pG3858/Pk6Qpo6MjTqsoLeLmmdHU6nWbWdjvY7Qi9H0I7PPuRzECgY3QCacZpVE6AwSSAe+t6MqMKAoZ2yKsrikJ8DwfMGitKJc8wtCG6vJO35jBP/dIAdsMtVLr0MVcm2ppaZF+316/53ksLs5TCm39Q+06mp/9uf8fiwvLfPAjH+INb3gDBw4c4M/+7M944vAh9u3bx5VXXsns7CyPPfYYMzPTfNVXfRXbt2/jjjvv4J6772ZiYpITJ2bpdSN8L5dzsHIc/Sjm1OnTxHGC0gNOE8C5s2f5P3/0x1QrFV75yi/nK7/yq/jLv3yvbRdyEJZdR4h6HpjWmtiRZaWk0Cp7MrtcJ+tyOUuXe4ynsv6Tfd9s2dMnjT815Ofzts/z3jz7JkjSDE1Go95g68xW9u7eg+95rKyuMDs7y7nz5+l0ugRZSmN0lFJQwpc9+r0+WmU0RutoIegmMYvLK3SjiJ2797Bz9x62TM9QCss8/sijrKyukjmENDaaqNUhMRlaGEpOD0lMjtNeWWGt0yVOMzI0IvAIyiEqNbavFoLR0VFmZmbIsozV1VVWVlaKLLocdcoR2Xyykdde7Pf7xHGM7/t0u10nGZC46JApIhqDJjZwljYuy5f7Ts7nmbbnlNOUZ8EJwGhDp9tlZWUVsKGOMLA6NPYmDnpwYcCkyoWILLoxv7BAqlLGJ0e56cU3ceLkKT74gQ9w4sgxlNKcOXnaQpMrq5w+NcuWLVOMjI4QlgLOn5/j8JHDaAO7d+9FBNBsrVIPG6AEgZEErtxFGFSo17dQrU0yNbWNnbt20xirIjyHUOR1PBhk0RmN5ZuojHUhWQukYYwgiZPCKVy3ivuSeKktbmhcuA9JP47IARUhBN1+18aoPQ/f99BGEccx7a6VArCZC9iMPbE+c0G5opF+EFhZe62LcKNEEKcJQhv6UURrbQ1tDEHgO6TNkMRWksEPfOdgWM6ULz0C36JUUkhGRhqUS1XC0hrlap3xyUkXfuqQpAlhYImOY6OjVGsVer0eUa9H4Pv4gU+r1abZbGEweCKwKA4WkbJmU3YvMAO5lsrGe2wrNbtdZBqlU4aoXcVfMdxe3b9c1iAH1QwaZRRR1C8kDSSQRlHxLDudNgbYe8U+Wq0mBkO9VuHA/n3s27eX4ydPcGZ2liSOaHc6pGlKc7XJfffexxO1GmfPncUY6PWss1TyA/csDGmWFVfvl0tk/Zjhsa9SrbBzx3a2b99Go97ghde/gJmtWwlLru6Zdu1yg1TB88Gk9PCkz8ZySP/yg+wzb89kRttmyNrmaJV5Vga5fx0mBp0BbO4nDqFNWZaRKcXoyCijI6Ns3baNUimkXK0QxTHdqI+Rgkqtyo6du/B9n+XVZdqtNZTKaIzUSbRmtdPh3NwcZ8+dox/HbJ2aYueuXUxvmeZUtU671aZaKiN9DxP1iOI+qUmRgUe1WsWfmqBSrzI5MwVZSn+pSYwi1Clppq0On+tAarUa119/PZOTEzSbTR5//HGOHj3K6mqzmAwO87xyTqnWmiiKSJKkyKzLs0jTNC3GnBz5Hp7cXnQZAz7rM23PKacJ8nRu+zdKUlbXWmzdMk0cxfR6EQLh2PNWSRtph6vAlzYs5eCAdrdF5AQTX/jCF3Di+Alu/+d/ptNsUylXOXzwECrTxFHC4uKi1RQyGUvLC3S6Xc6ePYdSkKUglMY3IVOjUwgtGClVCBwfRYoACNE6BBkytXUL9ZGK5ck4v67oOIRDKnIWNRsm7Xrw3uXZSuvfOqtMBaAyg1bp0E4UyQbejUjiAs60kalilF73gg/I34O48hBYD5vMMI2xZ7Ky1rQhJ4YmeA7GMMIiP/lxbW1BiS+9Qj4iyRTVUkwUJwhPIVttEBD1IzKdkoY+aZbS7fXIlCJJYpI4sdpVmSKOU+ewCCfSbS9MCsulvKRdZOY8XNBXKbNOC2S4X9wYtlvn4hY/WsQqCH0mJsedFpMiDAJXEgVGanW+/Cteydv++3/nzJlZvvs7vxOtNO9973s5cuQwlXKJq6++mvPnz1MKA3bt3MHS0hJHjx4lCEKk5zM+Pkmr1QYjOHBgP9NbtnB+4TzHTpxCCCiXS8RJapudtAkMIBhp1JmZmWZifBzPkxw9eoSjR4+uU9q1s0AJ0kNl69IbvqAtCEJKpQrGKJRWLiyaix3maz23vYCnix5dKpy3mfP0hWt5J8+GpjAI1RU21DVoBkkcURzR6XTIshJplhKWy0xMTlKp1xifmODAVVdRrVVZba6ytDDPWnOVVKesuPDW0uISp06d4sixI3S7Hc6fm7PaeW5iVglLeCYgUiki7WOSjG6rzdLCPIEvmZqaYmrrFgSac/0+kVaQRKg4QxqPQATESUwYhlx99dXc9OIbMFpz++0zGANPPPEErVZrHQcw5wHm2kuw3tnJM99yHSelVVGBYdM2aS6c2mqXAPVM23PMaQKQGDTXXXst+/bu4Z6777H1drIMozSetGnbyoVbhEMKimwsGx2h3+vR6/cIvJCJ0TEatRpJFLFz+06ElPjSpxz61Ko1jNHEaYIUhqifYDLD9JYZWu0ugRfgeVDyS1x31bV4wiN1xD0AhEeqDGutDqfOnuXKF/ZAuuE7j8fmUwwDwpOg9GBQtW/QOrt0Ze5hiBIuFS/JvfRn09Rlevo5aVspRYotKaGNIWut0ZU9yz8C+nGMzdSyFPN+3wpPZmnmeFbalWSxnJ5MWaKN5wiJBoMU3oC/9CyY2eTzukMVj8QS0PMU3dGR0QLCzpVywZJCrzxwgP/9h3+IyTJ27dzBwYNP8Jk77qDTabNlyxa2b99GksSkaZ0wtPyCMCgxMTFFkqbML8zTajYJAp+9e/cyOTHB8upywRkAiCPH8XL8gZorVdBaa5EmCbt27aS5usan/vmfabvyK/l4YIxZl536fLA8o1Rrm5DxhTT+Xy5P53LsycjnX/COk3OYBK5Pvujl5tqC4PmWA5pmGYtLS0W4SSmFEVCt16mNjDAxNcnU1BSVWpUgDKiUQmqVMrNnzmBS5Xijivnz89x5x11UK1Xm5+Z46MGHWF1cwvMkvhfiCSj5IdWwhNAJSZzSWlmlUi5Tq9dojNRpTI4TLi0RtbuoLKFSCql6VUIZ0DnfYXV1lWazSaVcYcv0FItLyxw7dpzZ2VnW1tYKnSVjjEscUkW4zkY7fEuAd7+Ba4eunzRyUBoMN+HOPxcm3FIHRuhnAf5+jjlNud4GbNu6lRdcex2PPfooY2NjeL5PENqwQ5Jl5AVnbWdmUGnqwmH2vkf9Pkkco1JlBRRHx0gzxfj4FpZXVjHaqmEbY0gym3UX+D7VUg1/tMRKs4XREiM0vX6HaqXEnj17kEJy9tw5otimXPaiPivNVXpJn1a3RbffIdUp6xg00nKWCncoP1EpCuEy8uXFRpd/z55xe1rn8SS7LAZfQFhSuAFbudqFP4wAlcVD0JdF3Dx/APNuvm9HIhSWeG9sgA4ulY76eV5b3vltcjbrftHGcgI6nQ5xEhedwdraGmFoQ2CZVhw+fJgzZ84wMzXFp/75nzl37hyddhtjNEmacPjwYbrdLkEQuBIvARPj48zMTNNsrnHy+DHSNKUU2lIsVgelTZpZ4nmcpEVJkNyq1arjky1QLods27YdYxQLC0uDa8j/GJ53mgNZmrmkClv+yPedxthzrIbak9nn6zxdygH7Qkedhue+GNuHWdtk6juYO1tERkiUMayuNd27bp2N0fExGqOjtkB2lrHWbtGPI/r9HkbbSWG/10Nnmka5Soah1Wxyz513k6YZq0tLLMydRwoYGx0FT6HQhEFAQ9YQiUTFPeJun7Vmk2qjigw9TOhR3jJKLDVZJ0bqgJJfpoTveJlLPPLII+zdu4fruIZKucK2bduYmppytePWZ7ldmOkmL5DPkDKvvZqto90UXN+h73Y/zpnKb/izYM8xp2lgS0tLHDlyBJUptm7bRqu15rgVAxExi15YtVWlLOHVldBBpRlCG5Io4tzsWcqlMmFoFZB7vS7dbo9SuYzn+zbTQSl6CEy9Qbkk6HX6CCBJMlZWVqju2E7oiLvNtWYhwLGwvMDs2VlqjRpbZrbQaDTwfH99DAdscUZBEd4QvnQSARtsY7znX9qcQ2dnTc/cCeQk6oJIvclxixBiMV0bhPUuFQWxMw6L7tkBTRck7X9xGzpPIaWdUfo+CGFJj65G1erqKpVyBQ+bKfj4448TeB5plnLffffh+T6VcsmR6jscOnQYKQWlUqkYiKIoZrXZJIkTBIbRkQa9bpuFhQVKpZBOu2u1uLDyEb4/qDqec8wCP6Ber5FmKUeOHnWSEqIInwIXdIDPF8ud9FzVfpib+Dy8HRe1zXhMm3GavtCczeGXfSPenztPZkN/LkT+1yLlNjRlM8uyNMWTHl7gW+kZo1laWWFpZdlyFF2d0rjXY63ZwheCidEx+nFE1O5y9NARMpWRugxmPwjIjEGlCUJCWAqQeKQ6oxcDKqPbbjG/4JEIRXWsTnVmHBV69OabxCsp3X4PI8KiPuWhQ4cYGWmwutpkamqS6elprr76arIsY2lpiV6vd0E4LrfhsFwuzjsoDD+gQGx0vAaO04XI0/OeCG7NVVUul2iMjoAQJElMnCQkSYLSyo3rxa10AT1r+XtptEEa6LbbnDx+nKmpaTSWzDbSaIAU+GGA5wcIaVn+aS+iF0V4XsmWLKlW0TpldKzBC194LWEQotFs2bKF5koTMStYWlzk/Pwc09Nb2DozQ73eKNLOi8sRTn5AYDmlEuvZ6SES00Z7VvsXyfoDDwWZjL2rZriDu3gE8NI2tPtBWIj1XtPQjMGYDefiTGs9NAsZtvXrIGwSpbFs8GfWLgd9E0MdhQGE1beqVmpMTk4xMT6JkNZZzjkA1VIJKSW9fo9azYq/RXFsSfi+j5CCtN+37V74pGlaZFu2WmssLy0jgJFGg4nxMVqrK6RJTL/XoR9Zx99WkJGFUKlwgrBRFDExMc7u3bs4fvw4Dz/8qA0bep4NIWMnKSL3dIGiQN7zwqQrTB0ghCFT6YAGV7TfjS/wuob97Nlmh/lX/liGx9B/kVMdPsiz9DgG9R0YeET5bxeJZgthEejM2MQYW2vO9o+pVnSjPqlW9OOYKEnwfI8gDAj9AB+BShM67TblcplGpY400Ol26XVbVlamVKLsso0zrTFZRhDYyZsUEEiPUHhkwiOOIlaWl0ikYqrqsWX7dkSlhE4NvZVF2t0uWqTUalWElCwuLvK5z93B/PwCX/qlX8ru3bu55ZZbGBkZ4dFHH+XkyZM0m811St3DWnmwvp5ckSiluMDJuhiC+UxpnV3MnltO01AjGxkdZfuOHQhPcn7hPCqzJOA0saEvm66ODce4jXxp3SdjDHEUIQXEUcS5s2fZsX0H5UqV/VdezZVXXYWWtiK9diG+NE1JehGekdSrdXy/4uLIHn7gsX//HkrlEgBTW6Y4fuwEAmi11jh79izjEyOoNKW5ski/13OzCYlNOcdKBbiRXCAwTgnWaG0Rp818mM/zXm6+r1wEbwgKK7DloVjMM3EOGzqtdSHLPEQJF0XcbHkZMUTMHpyz7Z5yZGnwIj1rnfFlPptcsNII+1xVZgVBK6UKU5NTdDodkiRhTWWUKxUajQZZllEp2SriExMTxHFMa20NpRShX6JSqVDSmkqlUihLl8tl0jQl7sf40mOk0bD3y8DUxCTzi3NkqVMDl7bMQiF+aESBfFWrNa668mparQ5JmuIVStXWsc51z7R2IqD/ygfmZ9IMZsChcwKog9+sbYbICnHhss//XC48N3v89ee0vpkOoTyIdcsuNdSsSwYxG5cPBrKN2jm5GrbgMkJy63a+2ct1aXT5krbZrb+kw/b0HV1xwTdxwWVsPKDBOk1aKbRwWcnGkCnb3qI4IdOGNEtJ04w4TQm1RpQFSZaRRTE6y5AaPC0oiQDllZChB1Igg4AgDFFoVGJDy0YKJ0Js8IWkHIQ2i1v1iZKEOEvQvqA8NoIOfIJmh7BWRvT6mMzKE0gpCyXvVqvFzMwMBw4c4EUvup6tW2fwPFuns9VqFZzNnLdpC6hnRe24PFxn3c5BBGIzjab880YnaWMx32fKnltO09D133rrrdx2222Ug4DxkQatVotut2tFuoQtkaGVJjMa6W68LcoKJjMsLy0yPjbC9q1b6HXanDtzlm07ttOPekzPbCWsVJCeRDj5doMdpAMhMUaQKYEQToRRG+bmzvHY449itObRxx7n9JnTaGPodLocO3KUkZEa27ZuY3FskZkt07z0lpewvNKk0+0RRRG9bpeV1WWEsBoWmVZ40sa0C6fhYqjT53kvL/hpYyfvXvaiUy16zWewQQ75Y+6ghZrs0ImsW9+I9UE6e6aDAR2s0KUQBuVi4nmK/NNGxy51/k/ybGxmp3S+vHVCtVL0+z263S5JknDPXXfT7XWpVioIKdm+fTuLS4uAVbQ/P79AvVGnXCkXxYfzmk7lcpkoiopCy3ka79TWbezZtYNHHnkYIwTSg2qlQhiG9PqJBTi13SbvmIWw9aGSOMH3fKse7sTisixDuHcsv4m23I4cZGw+D2xpaRmBZHQ0IPB9jFFusuEalthssIdB4zMbvm/8/ORWrGkGHJnc+d3oJImNqNeT+QLP0PtxYbhy47WKC9bPVxMbzm8jAG2Gf9hkd2KTixMbvzzbjn5xwOGTvPC2Cwa9mF1gQ3MCge+I0krbShhKKYT08ANTVMAQUuSBAHzPBwNpnEKqCD2fwA+JlRXnjZMELUEJ+95rKYjSBGM0wpNUKhWE8UkSg5KKWr1BvTFCWK3Qi2OMJwlKIUEFyqksJgd5v7O4uMjs7Cyrq6vs338FW7ZsYXR0lEqlQhAE61Cl/F9ej26z8K1ArqvtKIberXzSvJn4pX4WUO/nltM0ZL4nKQUeSmU8fvAJtHY4jbCtRicDqFw7pylxkGDJk6w119i5fTsvvvHFdDs9er0+h584xKc/c4fN0JIWxci0lXnv9yOSNEGniiwzGOGjMgVGWUTDWPK5MYb777+PqGeJ4Mpl9JlU01xeZfSGEX7o3/8AP/B9P8jxE6d58OGHWFxaYm7uHHfffSf5C5Xk3KZ1V32pBiCK+Pilxm5ziV/zhm9T9pWrHJ2RpClpktrwZ55eP9zZPFPtckMnboy6YFYm3Xp5aZbBjFoWKI7l21gOk1IgPC7oHP9f6ArlJX+0HlyX53nUa3W2b9+G73u8//1/Ta/XZWZmmnqjzjXXXcfY+TnOnz+PUopeFLFz907q9TrLy8t0Op0i26TRaFCr2QLA5UqFSqWCGVPs27uXrdPTfPazn6HkS44cOUKv1yNNIoSworG2zp6HUmkhBIsxjI2PsWv3Lk6cOonvBwRBSBRH5HWf8usRQmDU84sIfvDgYZL9Gs8vMeLVnSK6QAiPnHu3KajxpJONJ0Fh1q3lPjluR4F+5VMcYwfTvP1rowupD8sFdBlbQwqlw9HWTY87/B65l9W4wSuPQG28cuN4iDIPwTM0KBpjFfgZLF8/PzLr5mrr9m8vYd39GPgo607UrTe0AU81Q3Dzda2Dui4QxyAPbsiLc0TK4Tnb8HXmpaPA4EmB9DxUBlLYqgtAUTpLCIHwJNITmEQX6KX0PPxyCJnnEjwisjQDTxKEAb0kohP1UbHAK4X4lRAv8EgFdOIexmjK5RJeWKWEoq4CyiWPbdt3MTmxBakEaTcmacd4ma2fWS+HtFotWwszDKnX62itWVxc5LHHHiMMQ7Is48SJE6ytreH7PqVSaZ2DlGfPDUQp7Z2wQpgW+ZYuW2iDMqG928a24YITa0QBaDzT9px1mhCWtO37PpmKUeug5iEP1RGstVJFFdm+0uzbsoU9e/aw74oreOF113PgwJUgBCkCISXSs7HUTr/LyRMnuP+hhzh+4jhrzRYrK6ukmWBleYUs6gIUL83oaIO9e67k4MEnSJIeM5NTXHfdC7jm2msw2nDw8Se4+qpr2bl7L7fccgsvuuEGlMpQKqPf79k3RwpUmtqOZN1sdbgB5C+pffVMAVdvBqMM5jDrmtBGR0IIsjRj9vQ5lpdXWFxaZG7uHGfOnuXMmVnOnZuj2VxBZcllPKBnCha79H5yYTSt7QuYZQrrG1vHSWsrRbFxl8/K5PIiO81vc5ambsZvn1XuoK6urHD65Cl6rQ6PPfYY3V6PbrfHw488wtTUFt7whu/goQcfola1mTDn589TLVs5gGazSb1ep922pVQajQblcoV6vcbExATbt23jqv0H6HQ6/J93vYvtWybptDvU6jV6/YROv+MSEJQrj2NDwtKzsgzHjh3ntltv49iJ49aJzqwcRBCEtkyDVoXuipUveP7Y6dk5ur2MufNLTG+ZZMvUGFNbJqjXKi7M4GbuRUV2XCjT6jl5Xl5bMHeScOi1DSsLkb+3OVqzXmBWk4duBD6CUqVMmiZ0mx38ICAoha6Qt0UBsySxYqZau8QXW14oSRPSNLUFvsPAIutOGDAvjQG2DxUIMm3JumEQEgR2CEkSO7HKBQxt5QU7GJp8WyFA+oRBqahIH0UxWZbheRIv8PDcPRLYWpVpptCZss6W51n0fah6vXTLBEU5bPt5GIJyg63YyCla3/kNPg494+EkE4Nx+ngDL026PtVs+F+xPgahFTm3zeR9dO4ErkPF3PGMAWkzMaXjDBaOkgThZozCGDwhCFxiUa5phBBoz5aA0gZkGOB5EuFbdFj6HpnQpEahdIYyEi0gEprUZPSNpOx7lKtVJupT1EYbjG2ZRIqQ9twy7TNL9BeblNoxijLKHwgUa60pl8sYY2g2m9x3333Mzs4CMDc3x8rKCkopyuUyvu8Tu9p6FjU360J2w5QKwbAPbN+lgcafjfhgcLXmrK6T1dB7vksOyEEWvlKaNM0KkSxPSgJXRTnLFEZ6jmg3qHA/LGY4O7dA59Of4d77HmB8bILxsTGklESZnZl/yUtfyotuuoHdu3dy9YGreMUrvox2p4XSiiRRaCNIk7xjVBidEfi2Q6iUR+l02sRRTCkIGBurMzpSQxmDJ0PCcgUZeGQmRQYgQ58An1KtNHCShjBt+/J4RTqmrQPnYbsJWayeNy8pRbH+07FdO3cTxwlJlhDHEb1e3yptR32iKCbNUqwkjy2SmxcYtllVClzmlc0Ks6VQrPaUACkRxgpbGjd4CM8rZsG2T7FhOU9a5XFPSPIpRAHdCokwktXVFp/73J10uz26/T6nTp3m/Nx5MpVhjKLdadHrrl3WdW8WdVzvtMLGOfD6ZZfnKOYzrEq5zML8PP1ul3K5zMLCgm1TSnFubg6lNcvLq/z5X/wlvuczOTnJxOQka6stdMNQKpUYHR1ndGyUdqfH2Pgk9XqNc+fmePzgQaJen1IY0KjXbZhNemzbto0nDh2hUq0Shvb1V8NOpACMZueO3ajMnsf8+fOkaULeORkjC4cJuKCe1PPFOu0ea2unOXnyDNPTE1x91RUEQUC1UgEcV9GhBznfwobw3MCbS8ZjisEuR0itDYrswpDT5A90xowwKKMRRuAFHpm25ZbGpyaYmtpC1I+Ioph+FCE8D2ksIq+xvEVlNKnSpKnC9wWh9BDSQ6sUpQ0CQxA4VMq4d1iDzjQylAR+aJ+9SVGZFSAc1JK0aAjGDJTjja2pZtWgBZi42C6/FyZ3ELDogVKZyzT1kNJWdDAqs+iaqwEp837GwmYufDf0fg5IMesmTZui80KsX5D7WsaNJ0NOkyOzua+6+GfTj/IKo7kmX57p5Xbr+AXFkiGnzhjLecwdvSxzemrF+q6vxWrQWV9LFe+yNrZvDj2J5weFTp1fCqhIjTEZiTAkRhGrBC0EWSjJREAaSEzZpzI+yvj0FibGx8AY1potls6dpbOwTLbcQbQzekIjSyA9iY/Vkcq1lqIo4oknnuDhhx9GCEEYhsUYVq1ajqaUVnS63++vc5iKe+qSdoonmd8j94wsL9j9LkBK16cpbevUZs97RfAhX8JAmloHRngeV+zbx6tf9SpWmy1uve025hcWgYCic4IhlxVa3R6trqvOzvF1w95oqcSnP/UpduzexZbpKeojI67qtEB6Eq2lnd25MiNKKzxfYhROngA7Yy9eWk2aRlg4WGJ0HtawPYktxGyVpa3YmJ1hVcoBoCmVymzftpWZbdPU63XW1prs2m2l8623LUkSyz+RUjI5OfH/Z+/PYy3LsvM+8LfPfOc3R7wYMyLnOatYWVWsIqtYlGyV2BJEC7Balu2WgPYIGA2hW2i40XaPQPcfsmHJ8NDobqvRtiFbNE2ZIpuiOKiKQxbJGnKuyiEiMuaIN9753jPuvfuPvfe59714kZlFJkWlShuIzPfuu/ecc8+w91rf+tb3cfbsBVZW1/nDoD3Cg6QRkhBCp/VDf/7Ito5ncMu/P+Rv7sE4gpw7jobN1Oqat4QXX3yRNDWq4Yf9AdPJBFAITzObzxgOR3iGA2nJqHYRUib4KoqcMIyM4TOasjTZeBSFBH5Qu9ebTE/YBjGFsj6GnhegEciqxEPY7FsgpTLmm0FAVRTgCXw/ZDKZkKYps+mU9957l/fefY8sTdnb3aWqCoMKaNPMkFcl33/nHcCYdjYaDfI8J45jwshMTlEcMxqNaLfbeL7P4cEBa2tr/M9+5md4+skn+NYrr/ALv/ALaM9nnqZWGd+YApt5XxzlzQE/+7N/ga997WsMDg75/d9/hW9885tcufaBuxgPKO3+cRAu/2kfK6trxEnLoh6au3fvIYQiS2dsbG4YXohnzpWUJlQBaj0sF2geQRzqFeBDhjDzUOiHBJHpmMyLnGowwPd9trZP89M//dM899xz7Nzf5dVXX+Ptt99m89QpE0AXBbs7exwe9lFaWe/JBnGSWFRJUHpGGNahAGFgngVZVQjPNBoIIUxyVRgdu3a7bQ9P1BpjDn2oLZmqitFoZM+DWSQbjYZZLLWRgHH3Uu6+rn2v255nty/cImv3UZdDtX7gfjSJ2mK+qRHfh43lfOjY9XCBz8frzFou0T14YY/HZouPKXO8evEXYYNsjePT2qNZumU0FhwQoJW24r7GdByljPSK5+ETEPigfUhVRaUgaDTorXRp9Tq0Ox1azSae7zPoDxn1B4z6A+bDMbGCJGxSMDMItWekENw1cqKVQB0ouTnU3QtlWRqKQFnWqKHzhV3mNDn9Pc/zahRTLaGgvh/U1itu+25fcDQB/qTGpytoslF0HAdGWyaXrK/3+Kmv/hRf+1N/in/hz/xZXnnl93n9jTfZ3TXkWcPPWLphl+5fgTH/9QDftUoLIy9w9do13rl27cTOdA9z4ko+en77RIaAM2c2OXV6i16vR78/4Mz2GZqtBBBUhaTIC8qixPN81tZWOXfuHBubmwRhZCcnv057jR6GNhOkEAax8zyC0EdWBoI3kPBi/4tx9DwqiwQZHSxN6BuEL/B9POGz0utR5SVBGBInCUHg44cRHrC2tkqjkZDlOe12izhJSBoNWq0GUlY2A5nj+QGz2dTww/yQZqvN6soaSkGr3eTipfOfyGk+SU+mqgxR2vcNalaWBdPpiKLIjQu5tuKUszm9TgeBYDIeo6QpgZRFwWAyIfA9VlbXKKuFp9+V968wnc6YTefEUcxjlx8jLwrmufGV2lhfp9FocPfuXdI0RQhBt9slTVMmkwlZnpPO5xwcHADQ7/cJgoCVlRW+9rWv8W//O/8WZ8+cY3N9nclkRJaljEdjNHD50cvcubvL/YMBgaDmAzoe887OfSbjEY2mCdSOLhA/egHSSaNSkl4cIRoNijxlMBhSFHPybI7wBNtntonCwGANcmGv4rqCpDIq9rUfYT05nbzALgo/2Kzaq/27KinJ8pxWq8XW5ik+//Ln+cpXvsr7V66wu7vHe++9x+dffpnHHn+cQb/P7/3e7zNPU9qdDqdPnaLb7TKdzdjb3WU+n4MtlURRZH0lLfphlux6gVRK1gRe4IhNhtPd8X2f0AoFV1Vly+lm8XQq0AZ8kygrd6KUrhdGt00nbbHorAKsRZNeCpQeCOCXuWUu63YJ2g95zesOU7FI3B4u3rlUyqv3ZREnsfj55HHUiHa5VOv+VieXAPqowGP9u7Y8KOuX6aoDgeeZBF9AgIcXBXRWeqyfPkV3bYUojqnKivl4xPjgkOHBIdlohlcp2nGDxAvAD9FVRZ4bCxV3naSU9XWPY9PdW9n3uXPngnHnL7csO+A+q7WmKIqa5+SCI2V1q0xwpCwy6e7HhXbaHxf6/akLmgAQHlEUsLXZ5Wtf/Sn+/X//f8cTTz8LwuOb3/ydhVaNbz3MtAdiyVF1ad2XWEBcmZup1JoAiHyfUGsq20kUBiaIqKSmHQe0Wi0mk5RZURB6ZjupcoiJ6dfzwgiQqLJYzIVeYI6nLqepBSwpNFVhhMbiOKQoSpSNzO7d3efe3f36+N98/Z1/Iqf8jzq2VlbQlfFSa3c7JM2mgfS14tz583Q7ZrLe2tqg1W7S6XRYXV0lz3OyLKM/6NNqt9nd36MoKqI4YXXFBIVRlLDSW6HdaeN5gYHpLUqvlHDVQITwAWOAjMBwPDCdaw5OllKxu7trlLTX1tBas7u7x879+yilaLUaRFFIlmUcHh4ihGBjYwMhBMPhkL29PV544XkCL6B/cABK0+uskOU5N2/d5NJjl7l06VH6wxGnt7c5f/4c22e2+at/7a/RarSYTaeEYchoNCLNU0TgcfHiRVZXV3nrrbeYTqd4nsf58+fp9/s1CXw0GjEYDCiKos7az549y5NPPEESJ7z+2qsMR0P+0v/8L3H69Cl+4zd+A+EJnnjqKaapyeVlzamxBGbh83M/93P8wi/8AlEQUFYVsqqOSjCdvK7/SI3bt24iS836+jqeJyirkr29IVk2o9Np0+60WfG7RxZYWJTqZGXLQx+zil4vmdp01LoM3G1TKYUnPNrtDufOXuDxy4+R5wXdbpdms8nP/MzP8LWvfY2rV6+S5wV5XvDMM8/ywgsvkCQx3/ve97h75w6Hh4e0223WVlfpdboEvk+eGv8zFwi5ORYMmjCbzZhOp3S7XVqtllVKV7V1hksU3M+Oa+L4We5np1jvfNdggUgBRxALh2i47lF3bn6YcWJe+FHX4SFo+YfuRBzfyw9fATgpqTNbsrp52txXnvsHCKUNp9Nq/pm/g4+HVApPaJpxTNLtsLFxivWNU8TNBlmRMxqM6O/tMen3kWlOiCAOInxtXOXjMKKwCbs7rtrFwF47R/IG6uDadcq5gMjdJ+77OGFLF2QDddDUbDYBKIqCPM8fQJbA3EeRpeoYWsEnOz5dQZMdeVogNLz00ov8G//mv8H5C+fZ29vl/fev8Z1vf5fhcIrwYnzPRxU5tcCcXoqWWPyosfOWEDz5yCWiOEJWFUVVopUi9L36zZFvYM6qkqyvbuEFAXmZgdAo3+fWvZtkBZSFRBXp0R35QFUtF2jBA1XJI4uRriBT5VEBxqW6Otht2a/lAaiFJqbbtNZGAV0vfVRZhC0MwrpDb3kXkVV2NVG/CeqOP6QPSAHY4XF0/tfAwXBIIwiZzabcH/RRQIDRBHn17bce2MY/7SMKA1ZXu2xsbLK2tsbKSo8gCMiylDCM2Fhdpd1ukYQJwvNIGg0uPfooz7/wAqe3zzIYDk0HE9BbWTEEdqXprPTwMOKq+KZrsyxLDg8PuXDhQu3PpLWm2Wxy7ty5WlgRqIMmt3gcHBzwjW/+Fu+99w7ddptnn34KJTVbW6f4F/7Mv8hbb7/N3Xt3CQOP0tb9taqwUSVCGARDVpUhm/qe0Y75ESzDPWzs7u7iezGNJKHXbdLrdOiXKZPxmPv379Fb6RIGPu12m8D3cZ2rjuxdjxMX1ZPG4jksqxKpJNq2jMdxDHqhd1M/n3qxmHW7XVpxQqdtSu5BEPDCCy/wZ//s14mikDzPeP3114miiOeee44Xn3+Bxx69TKvVRJYV9+/t8t3vfoc333yTO3fuIKUkjmOazSZJktRIglsM3bzhvBUdutRut2k0GrY7t6wDsKqUZMokMVEYEsdxjSoUhZXXsOjacqB0/Aw97A51pX1t+TD6h49bTgySPjRu0h+GJn38sRwwPSx4cl2RBknyjVc9NsDWZkHQYmmetsFU0mrQa7RohzGiksyGIwbDIXv37zMeDtF5TiOMaSYRkRaQV+hKEYchBBVFmtXzTxiGD5TSHEHdSQ2415aDXiMQuyhbu/un0WjUxP+iKOqgexlldWW95WaUPM/rQP2THp/KoMnzIcsKRqMxm6fPUVWKf/ff+Xf5B//gV0D4aC0MzOuBHybIcv7hSCgg/IBcVvyn/+V/wU985SvosmD3nsm6kjgmCgP8wCgrZ7OU8Thl49QZNjY3mAwPGc5GrJ02+kvvXruFF0SEkb+AYQ2GbJ0WFvIAeKIWBdTSwPdBJKhyVR+vMwYFTRwZc2KlTXeEVAqlFhre/jJs7BBsS5AUWCI9UMmqPh2+7+F7PlIqY+XhDg1RQ99HhlaEgdl4VakFHcP+P/R9tJRG3FyYDNwFpkuAsv3dcYQ0yywZ22XqhKfr76I1HykXcJJFgxthaFpxnfq1G74vbAfeYv+etxDONHwnkLKi3x9weNinki7YNojWP/zV3zT7EBAGPrJSFFoTeIJKH+NosSjLnCje+QmMOuv0bEnWnsswWMTuwhN4gbC2ccvwvzmm0PeIwoBKGlPkfz4WYzqdMhoOGPU6xJGg2TAB1GRacHh4yN7uHqsrq7RaLUPexqvLCdqSqj8WLebY0FohK4Oc+oF/pGyhtCbLMrI8o5AVZVXWi1BZuiRJ1EjQmTPbbJ/aotFocP7cec6ePcvFixf5+p/5Ol/+8R/niccfMxwsDbPpnO3t03iex3g8ZjKZ0Ov1eO6552g2m4xGI27cuMH+/n6NKLiSi5ErkUdKbstlFld60xobRHmEUYyUyi6YVY1aCbsgO5uRmuOC5ftwnNO08CpjiXn9wz91y3IMD4opnhRQfZjEyyc9hCv/4uG7DjOhkNZKzBU3fJu0etr8C/EIFJSzlDzLmWZzBv0Bg/4hMi+IA58k8mmIAF9pSqmgUkSRj7ZlOXedq6qqieAORXKBUc0BXjLjBRO8N6xunNaa1PIuhRA1ajmdTknTFK11vf3lbUSRsXIJgoA0Tcnz/AjK9UmOT1/QJBYetp4IUVXJ3/7bf5vvfPd7/O//w/+A8+ce4e/8f/5rvv2dVy3B3i3GYnHXnJCOlMo8lH/5r/yroKTJsAHfs+iN0HXnnu9DWZmyXxwGgKaoSpJWwMHhBBT8b/7GX+ff/Lf+l5w9t01Zpqgqx9cCIT28IMYLQ/A9lIW6AdCa6XhEu9Mmnc3I8xzP87h39y6/9/u/z+3bd1lZXeE3f/M3eebZZxn0B3zve69x+849yrKkEcd85jMvcv78OdI0pdvtEscx169fpz8Ysr6+ThTFjMcT5umc8XRCv9+vIXxP+NzbuY+S0hYOT55WlAZdLf6qgU6jge95TGYzUuvz5wHl0uTlXlvuDwJ9pKtxeR8cnfdM4GSksRA+JEmCrArKYulI9VKwJBb0Bcdb1kvBi+8vJt2FRIWdU+17osinqjRlpQgDnzgJKMsKrSRxaAjUhbTbF5DEAapUZKU5B5HvUcil43P8lWWyquebHdeTiQvYzI0ehmGdQTlD3uVFwS1Oy0igQwPlUrDkvt8y2Lk4Hyby84PQmCJbmyGBBiXx0PgCqn8ONNWj2YhBK+7fv0dRzDl3dpsgDGk0jBq7U3fXLmmxdjVSLjUWOIRJY+FhwckeP7p+fXnRDvyAMIwQCGRl3Odd4hP7Qd2hZIROjcBrt7dKs9k6UhIpygrh+Zw+fYZz587xxR//cS4+crFueddoWp0mf/7P/TlWVlZI05QPPviAS5cu8e/9e/8eKysrvP766/zyL/8yeZ7X37uqKoLAHIe7n5VSpGlaL7BxHJPnOa1Wl06nzcHBIVmWgZUjqKTCD0MajSZtq0OWZhmV5VGa8+hKy3qJ+mCHO8d2MlhM/26S+bCy2YNBkSNhH+cZmeFwHMFymviHHQ/nSx0b9ZTn/Ni0TcRNw4paEpMUeITCg0DgodBlxXQwYjQekVcVaWGunycVsR8a/lJWkekSv1T4CALb6BQGAd1ut0YTHdoYRdGRgNbxktxwSKGjRzgk1L2/stImzWaz9rWbTqcMh0PTuNBoEMcxcRzTarVMZ/HaGmmaMhwO6ff7Nc/qkx6fqqCp2UjIy9KUN6Rmf3+Pt996k7/8r/wV/rX/xV9jbW2dX/zFX2Y43EdWc4SXgG3/FDiTEv3AYuyGUoq8LDi1tUlZFMiyIAx85jMz+UkNk3lhaqZhSJrN8YSgnUR0um3u7x7S6YWoWcl/+9/8f/mVX/5llCr4qa98kf/w//R/Rk7H3Ltxm+98+1XeePv7zPIcPwpJ0xRZSU5tbbG6ssL6xhqhH9BoJJw/d56yKAgVvPD0s5w7e5af+NwX2Tp9iv/3f/V3qNIcISWxL1hf6bGxusqFc+dM7VfDNJ3z3HPP0e12SeKEg4M+3ZUe49GIL3zxi/zg3XdJMxNgVZVEKcn2qVM88/RTjCZjbt2+zeFhn/l8Rp4XzOcz3n//fX7nt3/XlvkCXnrxJV54/gU2NzZpNRq0Wi08ITh/8SLj6YSyqvB8j0azgdZwcNjH9wMmkwn7B/toNKPRiMPDAzxfEEYhBweH+L7HeDIi8H0GgwF3795jOBwjlXkon3/+WSbTcZ2lZmlOVZVEcUyn00FoSOdz2p0O7793pW5vzfOCRhJRVkYd140kjqmqkspGGUoudZ8YggBFkRsuCgsuUD00ZNlRNEYe14g6CVXSEpZwNq3B80M836csCsqqRHgeQRhQFSbAdpkcmEBqmS+wPFmBCTDddjXUnYTSRa8a8AOQFbLUdVaigGarzeb6CtPZjN39/oPH/iM8fN9Dqop+f0yWTRFC0m41CMIH+TVSSSNeKBbo5UJeBOoWqI8BTBgitG8TCfM5pe31tdm9G4m1t1hIRGiadsFxWknOqqnd7nD58iUeffQyZ7ZPM5+nXLtylTu3b5MkCRcunOfyo5d56aUX+epXv4rv+3S7XR599FE2T20QJzGvvvYqb731Vn0v9no9Ll26xBNPPEGz2WQwGHDlyhVu3bpFlmV0Oh0ajQb7+/tmsZSKeZoyn83q7+EkGszfpf2O2rbVe0sIj1PuOf6MWXRpCRnS9TpwHCd/+M91gFsHQt4JSKFY+vdHHz8MR6tG65SiEgKkRBUVsiwNr8kifIH2CTzPaFyhKaQim8wptSSXFaU0Qs2hH5CIgEh4CKkQlYJS4QUBgRCoShIEYX0vpWlKZrsfXTnOBUbLqJNDgJaRueWS4zIp3JXZXNdwURhNMRdsO57T6dOnuXDhAgDz+ZzxeFxLGbzyrU/kUtTjUxU0ZUVu9T/MuHr1Kn/nv/o7/N//5jP86i/8fR599AnS1ES6CAVCLxKP+oE6NpYTBU+Q5Tn7h4ek8zmyqmgkMUhJWVU8/vSTnHvkAr/zO68gKvixz32OdD7jnR+8wzwv0ECamvrt/t5e3cG3udIjn41ZbzWYNBu8/4Mf8Ou/9puM0zlRaJTFy0qSNBLarSZRGAAecRzSbrYQGvKiIIwTms0GjUYTfJ+rV64yHozw7QSVpSlvvvkWH3zwAWFoVFillmgEUWgi/yzLaTabzLKMN958k9F4fKRbIQh84ihmZWUFrRWT2ZTZfG7r0gY67/cP8ZyPotQcHh7y6muvmc97Ho2kgdCapNmkkCYI8QKfMAyQSlHklZHTj0LKsgJhgqBer8vKao+qqoytSFnQaCTs7OyYoC9pIOWIuBGx0uvgtI6arQYokF1FHMeWM5GQzlMG/QHdTof5uTO0WkYxO0kabG5u0my1uHPnLuvra+zs7LK1uWEftAylJINBHyE87t27y+3bd1hZWePxxy7RiBOK3MDHYRzR7nbpDw65fPkxRqMx89mMVqtNHIUUZcnp09tcvXaVzVOnabRapGnKdDY1KvMWFVjpdjl75gyNZpN/8D/9Iu+//96C3KpMBl0WZc3LW4a4HQrlJpvjfDPtTKBtgKRc0u1ZsVOpQSv8KEIW5YIQh1EUPhz0yfPqKCL4zxEntFYoWVn9silSZmxubbC+vkYQNOt2aOEJpFy4ti9KFguNtR9mifU8QRD4VNbE1CxK5m9uMVruZltwEEX9Hqe6LKUp6QVCsLq6wuVLlzl/7jy+73Pt6jW++Y1v8srv/g7tVosvf+lL/IW/8Bc4dfoUX/7yl9nf32cwGNSlkHPnzhq0wHp1mme6x+c//3l+9md/llOnTnHlyhX+3t/7e+zu7jKbzdja2uLcuXO02y0OD/umqSHPazTBLbrFEnHYSWTYOMiO5Tb8k0tlf/SxWNCPoIQsL/gfwXH6w+75Q8qAy0Gj0oALUsqq/scSx8zzfQINgfAQvqAsM2RZoDFK5EYDR+Mj8DUECnzt4SHAE0TCwwOKsqyDJlf+df9395hDvRf36aL7UQhR8+/yPCdN0xqBcnyl8Xhco0pJktT3wHIw6Xhyp06dotVqEYYhSin29va4f//eJ34tPlVBk9Kuag2BL8iKilde+T3+y//0b3HuwgUunj/LZDIBK5JovJaWVVpPyvIxlYnAetWVBZNRyU/85JdYW13j9dff4O6dO2xtbvH1r/8Znn7uWV5/9Q2Q8PWv/xm2Njf5pV/6Byip+Mff+CZlaZVJfYgCyAqNtHpArc1TXBARfpwwyXMmWUZYmBu5UppJWTKYTLFVKJzgnctbFAueihICUVVL+jqKNEuZpilyx2iSOAjaTJdmO8retBq4dv265Ttp+6RblXG7D7CZa92lYbdnOzG0NqrE9/d2uXv/Pnkl7Xm3D7HLfNxnl7bpCaMq7fwAG42E1ZUejWaTqqqYTifGNR4YDgdsndoinWdG6EwpsjTj5o1bAMRRhCtXhFFUa4OUZUE6TwnCiHSWMp9nTKczojhhMp3RaDQYDYfMZzOG9v+uLm86hAouXDjHZDIy6I+Vr5jN53W3W0M36a2sMBqO2N3dIfSjuvSQJAmTyQTP89jfP+D09jZbmxt4nmc6P6Qk9I3Mw3g8RhYFpRCsr63SajWZTedWCiIgLwqSODSLhWe8qJRUlLIyi/MSOVJ4RljRTVxZljEYDEzC7Zsgyl1yB3CgJFqaO80PTNkYrcmKEqU1rVaLII4ZjcaL0oQ4ysf6oxUiPn2jKnMCP6TZalCVObPZBO9AIdAE/sK92D3Bjt/o1OuDwAVNi4DmKJJx8hDCSIgIpVBSUxgfg5ok6wLpQkmUMu8XYsHtmM/n9h6V9m+CMPTpdjusb2zQbLWYz41Q7Ouvvc7v/u4ryKpkf2+fy5cvs7a+zqXLl3j88ce5efPmEoHXknqDRXNCVVW02222tjZZXVthZXWFLMuYzWZ0u12+9rWv8fWvf500nfOLv/QP+Lmf+x+4dPky58+fZ3Nzg7KsODjY5/p1w5WaTCe2nLOw4Ygio0xukNTlBXUB3R0B9f5ol32pTLe4Zv8kdMoeHiy5hAmw87WUEiUrS7Uwk7W3zC9D46kAHVghYQRBFBq1cCmpypKqKNGeQoQhcRgTBZ6Bp6UxGUc7+RqTwAVBQBRFNXfNIULHj3mZ8O1QcRd4N5tN2u02cRwzm804PDw8QkdwZHIn9eK66waDAf1+n7IsWV1dZXV11VpM/YiX56yvSE30C31Bmqb8vZ//Bf6P/4f/gEcuP8I777+HEFa1VVcYKdoFH+CBsSB31Al2q9Xkz/+5P88Xv/gF/u7f/e/4+3//f2Jza4ONtTVmozFIRRw3eO65Z/nTf+pP8/TTT/Hmm2/yj7/xzXpBUWpBbM6LnBtXr3Ows4dQkuF8isR0MijPTaqmHFRqjW3Qc9/46BQqBEp4CITtZnKlJFBaWh6XETMLAp+8lOhKmq4JN3HYbTuez/HMyCWljqDurATM+dFGEbjO6DSVLK1XncLzDFeoLDW+b/bxsOkkz/N659PJhMPDfp2VmMxY2+NR9L1D0tQgKrKsmEnJdJraYzg6eTiugbkHjPOgk5QQNsNxCubStrsqKRFWP8dxRkynjsd4PEUqTZ6l3L9/j/F4ymw2JwgEfhBx9959dnZ2ufbBTVa7PSPaVpWEUUiW5YRhQJ4XTCcTtq5eJUkM/ysOA1qNBkVRcv/ePYb9ARqotMK336nT6XDh/Fnefe9dNtfXmM/m+IE176yk4XX4RsLCdKFE5h5RhjAZxUZl3gVN5iRheVP6SLVQK6fO6sJjc68EYcz22TNUZcVoNK7PsUAgWbSOO97Ej8qoqhIpC8Iwxvcgzyumkyloal4TWGTIckAWZYqTAqXjPz98GG8tExwrWwJ2QbJDmxxx2uiLLbSUCuspacqEXs2DiaKIZrNhLHIqyXyeghCcP3+O4WDAZDLh1q3bHB4esn3mNGfOnKl5l2COI/CDmpC7ECYUBGFQW8k4ou/KygqPPfYYzz7/LForbt+5wze++Vu8/PLLfOYzn2Fzc5Msy7h37x6nTr/D3bt3ydKURqOBHwT0+312d3YpKzP/LCOtSqo6gKobIgLflvPcef4442iYVWsjLSFay52CH3dbxzvgjo8PI5k/NHgSi/Ks4zMqvZhHAZSsKJVEVAUi8CHwkZ5G+8aKxxeeWZLZQJ0AAP4wSURBVCeEqVAIqUzVxpOmacl2+2KTSFlJZrNZXXJzmkwu+XRcpuPE7eXOueVmgWazydraGmEY1s0Gxvs1rbvz3Dl3CLvWmv39feI4xvd91tbWuHjxIv1+34Aon/D4dAVNtuNEa5BaIzzotFucP3+et994k19b/xWuvH+FNE3NYqkqW2Y4tmwfT+Q0dQavteb06dPMZlOmsymXLj3CqdOnmEwnvPKtbzEcDJhOxmgteO/dd3n6qaeQUnJ/Z6cW9vIsp7ey9YzJdMLvffs7hL5g+/RpJvOZCXgAaR9oVyZxRrRuSl3OkNxPzmdnudhYa01Zvzrf9/GDkEBZwTAbzAVWg6A2uj0pohFmO0a5eAk1wgQwysLgruNXO0TM7lt4VhfLQmYfOp3oRTWoKqvFcS19Z09AWVT14mA+Z7Inx9NZftmgyJ7tTlsEYApRE3nMRGdFnZzSrq6rUmiMKfTu7g5pmiMwiu3GTiajLAuk9KjmBfN5SlVWiLygzAqkVEi9mMTduHnzJnfv3sPzPLqdNpvra5RFyf3dPbSUREIwTVOUZ6wwwOh19VZWiIKAwPfxPYHQGmnZ3GHgU9muR98q7Zp2Xk0QRnS6XUs0p/6OCPB8Y7ejl0jqDqEymePiuMMwZHV1jTw3Vg61wu9SWQIeeMr+mR9SVbUOTBiaEsFwOOCwf2iaOdLUBk6CIAhxliDLJdSTF80PX4CXk4plfSL3z3Wn+ZjuI6OLVNVK5J4QpmPWLmSmomMEb4UwiUIYBnS7HcNf+spPcHhwyOHBAUopdu7vcGpri263y+bmpiXbCoLQt11MCXGc4HmF3YeoUSBYiHK68k0cm66pXq/HM888w1e+8hU+//nP0+l0KMuSwWDA888/X2uVtdttiiLnjTfe5JVXXuGDDz5gPB4f0XJSWh05H3geQmmUUMfKaA8PRk4aZq5a7sbjgef8kxofhSw9cGA2Edc2eNJgAh08w2u0UjlVKZFammqF7+GFPkEUmLXVKW3j4YUhQmmwyH6hNL4WhL6PF4QmGSwqZvNZ3b3mut1cYLysx7VcOnYBtRM9ddep2Wxy5swZer0eo9GIg4MDdnZ2KMuy5jAB9TyntekWPTg4qBH1ZrPJ5cuX8X3f+Ll+wuPTFTRZ9ARbOqiUpt3t8b/663+d3/i1X+M/+Vt/mzTPOTg8RHjuYbWR9vFAyd1zS4uIkiYQ293d5b/9b/8uv/zLv0Kapty7f488zxkPR0hVUlaS8XjEz/3c/8Bv/MZvkmU5w/EAjbG+wBMEnglQirJkPJnxwfXrPPnE42xtb9Nsd8HzTJZfKeTSDS9spOQpOzkeO3CHgpgc0eaJwtXzNZUCLTRVWVEqhaxkLauvBBRLEgHHT4UbTsRQyUUU4cp9plxn3ucQKVmZCVcrkEKb7ER/OMpU73s5wV7K3hwK5P5WllX9AC7PT8tzh0PQXNkQKeoXNNTlWvceIdQi+VsKvtwmldZMJlPjZaiptUKchUgtwFZWeMK0kxdViXOQ10vopfCNqamZLDRntk/z/PPPs3P/PleuXKXTbtLtdZmkc/KiqI9lOBrz/R+8i1Sa3b1DUJX9fuZIHfqnpAJR1Nmm1ppIKbp0abdaBGFIVS50uZRy5qMGUdTKOYpbJV3hgipT2paWROwCpuMZ9rIFwj+JUsU/DcP3PcLApyxy0BVRu0kYBpRVznA04N69u9zfucfa+iqrqysIYYLahYekQ/z8xXO9dP5cOc+4tUsbDFc1oXtxHCb4cXyPNM3Y39/ncDCk0+3yzNNP8aUvfZn19XWE5xHFMe12x5JrM/KipNFqITDob5GXrK0mXLp0mfW1Nc5sn2Y2nbBzf7fuhAPPutUnZjFk0cWplu4VJ4dQd6kqIwZ67tw5VldXaTQa9fdotzo89uhjXLr0KKdOnUYDsVI0Gk02NjfQaqFJJaXk9PYZWq0Wv/Zrv84bb75BluUEYUCv1+PUqVOs9FZst2nJYDDkxo0bpFlGFEVEFvlabpV3pWaNrgUSHRJSlZX1ejM+buZBs3ZKlq7g+55Fspx+lPv5ZCTqwTLf0deXuYnHn7flnx36ZV43gIISmqTVpNNu02w0kGXJ4HDAeDymlLmpcmiDxHsotLCUDq3w/cB4CAofzyxEqEpaHplAWXFgrc28sNx97OQHHAq0LHK5aD5YJAzLIqUrKys8+eSTfPWrX6XX63Hjxg2uXbtWm/w6fpTjOy0ji86SxSW1Sqmaa/VJj09X0KQkp09tM5/NmAyHljsQ8Oijl/hGIHjhhecIo4R/+I9+jensHkdW4mPI0vGhLeLgCY/p1JRfXNu3qdULtjbXefTyZfYODtjb2+envvZVprMZv/Ir/5Af//KPc/XKNcI4ptPrkFjjztFozGg84bd++xXefPsdXvnW7/P++1epZEkcBWiHhAiPIDAZYhgENhCRlr9gyi3uBlPSmHS6r7JoNTd9gq5spyqFXjI7dMjMSWM5dvHcInoMwlkGel0w4OrariPI/L54gD9qaLdy6EXXS90yX2fPUJQLP6NlRGiRMdpzsTjc+jdXsjx+PEeoDycfnM1oDBpTVhKR2xZyhDE0FaaDUCvqrjswqKXAdkq5wMlb7Mw97EJAu2X4XK1WG8/bZ0kkg6qsmIxHSKUQVEYKQCwO2UzUAnkMUfU841Q/HA4Q4xFKSXzfmIgqG+DWO8FOykqDD8IGgA6FS9OUu3fv0ml32NjYZH9/z+7bLNZu0vxRCpiAGpExXT0VUeEThD6xjJjPZty+c4urVzZZWemRJIbIaoIJd960fb69WoftaELgEAx3bk13K8cWT4fYGBX9nPF4wvvvX+Xd997js5/5DJ956TMI4XH2rBFETZIGZ8+e49KlS6bdv8ityGRBlmXkeUHgB2xvb7O1uc7aag9ZSbZPn6EoCkvQdpIYuuYlgtExc35iDvH2fFMScwKtjz76KGCEXDc2NtBaM5uaEs/pU9usrKxQScnu7i5FURBHMRuba8RxTFWZ0lsQRjz66KMoKbl79x43bxmj7larzRNPPMlLL32GRy5eJGk0rDzCdaTS3L171zS+eNY02AahztjVPDfGZNc8wR7OZFlJjfaxPAfzjChpNPQ8+5rwjH2U1rpOnpZRqeXxYUHTcTTSBUkPKwU6IQWlFVIrFJA0G2yePsXayip5mlKUFdN0jspNQuQLgcRWN6REVKWdKwWelWNxlQhDZzAUB+2ZhFaVC7XuZd8582z4tdCqEy+ty8VLf18Wv+x2u1y8eJEXXniBbreL7/usrq7SbDaPWK24z9eBuF5okDnF8Ol0amQr/hjGpytoQrO5sUFfCMajEe1mk/W1Ff6b//q/5rXX3uCnv/ZTJM0WQRjAEc7AR27WLNoWpdEavvKVn+D8+XN873vf4/33ryCE4IUXX+Tf/Xf+bd599wp/97/7u3zuc5/jySef5NTp01y89Ai/9Pd/yfI6BEVVolKj31TKkns7O1y7fsuYJGqNHwSEUXBEuNITHmEQGghSGwJuGIaG9GtvLtNAJUEatMFzqJMQBFFEFIZogc1KFxoZtZrzcSXvYyiLTaAW7t+21GViEL0IjpaGQhuVNMeVcujNcXTvYadfH5049DGp3uUAyHPB2pHPmn2dhK4Lh5hot7gvvlf987HDXA68itKR200JI9eVLafaCQwjMOqcuE3b72IfRzYqFqdfKoXU2vjora6B8Eiz/KjitnAegIZH4PkeroDpRO7Pnj3LT/zEl/nW7/0BN27cwBeCRy49wle++lOcOrXNK9/6Ft/4xjcMmuo5sih2UjQLt8Z01EhVoZeUpB1SVpYlw8GYp596lueefYGf//mfr7PDZWj9RylgApCVCTTCMEApyXw+M2KhvpG2ODg44N333iFpJERRZK1/IrQ2nKIgWPA5FufRbHvRnbW4h33fJ4yMFIVruw89D99ydUzwZQxx33nnHV579VXOnTvH2bNnWVnpsbKygicESRTy2c8aztDFCxfwPZ+93X0ODvrIStUda3megzaq3zrSOJNys/CZEktVLcp+ixL3UR84p/4shEe32+Pllz/HU089SZIkXLx4kelkxrvvvsvO3i6tdptmq81sNufNN9/i3r17tFotXn75c2xubjKZTNnf3wetOXNmm/WNDR5/4gkeu3IVKRWPPfYY/+K/+Gd46aUXOHv2LGEYks4zHnnkEr1el29/+9t8//vfN8iJRZCk9YNstVrEcYKUxietLEtbkjaBQZJE9rwvURO0S1oNeigeWHMW6ahDhRfX9+OVBj8qYLKnHo2Zi10iqtCUVUVRllTKdDD7UYgfmmVfC/CWktX6utX/lr6FMNUGZas3JgkVdQOCC2Bc8OJKxAsOn64DH9eoA0cTLtctt7e3V5+fXs/ctzOrW+i4ectdoU4gVUpZazO5IO24sfgnMT5lQROMhkPSdI7SmlOnNvnqV7/K/+9X/iHXrl83TvVBQP+wz4mr9ocs4nWibu/Jl19+mZ/5mT/L6dPbjMf/Pffv32c0GrN/eGil3RXf+ta36K30OH/+PG+8/gZgFsPpZIJU5kHUddDidF2MZYope8h6MUeDQqJ1aerOlbRdCJrKdyRSVbcI18e9/N1cPZvFPyGcCas2reUfQaHQWJfs5cCktqFZ2ufy35dkID7eOGliWdRLzUT7IGpx/FP6+GsnXF+HMn1Y/PxQoAksT8xwQAzapxeMevueZS0m05nojl0cPbdLO0rnKYeHfeIopN3pMhoOGc3mBmnUgtLC2poFcqe05l//1/5Vtre3+cY//se88dbbbG+f5V/+l/8S93b2ufbBdS498ghPPvEErVaTRiOm2WzUB6qkxvOpy8CW6oHJml0rtzJkeQvBC88kEVKagMoF8FpbTysnfvgjFjCBOSdZltJqtdBaMJ3Ojc5YaMo6WZZx69Ztoiih2+nSarY4c/ZMPdE7JNIYjjrUAU66WV2WHi2VuxwqYAJfVaNWZVly48YNvv2d77B95gzPPPMMnXabe3fvG2Sz3eaRixe5eP48jUbMcDThxo2b3L+/U4thVlVlTZvHKGm4obOpuT97vR6NRqNe7NwCaMjli2fBfc/Kks59z6Pb7fLEE0/i+aK2z7h/b4fvvfoqw9GY8+cvEMUh89mcq1ev8vb3v0+73ebCxQs0mk12dnd56803mU6nvPTSSzz19FOcO3+ep556iqIoeOmll/ja136Ks2dP4wchk/GEpJFw+fIlTp3apNlsMBqNuHfnLtPpFCklKys9zp49y5kzZ+h2O1RVxWAwYm9vl93dPYbDwZGStBknBz8njeUy/fHX3TaWn5+Pep4eLM2Z46nnZTvpZUVBf9A36tqARBPGIWEZUzrRVXcfublTabD3k7aSJI4ygku8wcwNDkETi+u9fP2dcrcrf1ZVVSNE7ridKbRSiuFwyDvvvEO73eaJJ54gDEMuXLhgOMP377Ozs8N4PK4Ry2U+lOM4ObkDd3+Wx6zCPonxqQqaPCG4d/cOUikuXjjPl7/8JbrdDlJptk+f4vDwkIP+gCzPOQI96OXV/uThyhWuPnzr1m1u3DDttI1GgtaK1159jf/ob/5HrK+vcfv2LXZ2d7h1+xZC+Pz+H/yBIdJ6giLLwTfkT4GgkhqlSiN/YW8+cyMtFKDdMVRSLpWrHNzpUk5dZ3P1EObGlkqTFyaDNSU/817hAimtF+zy4+fh+MNs/+MQDW2zkaP7XXzuoc/3Q8/3RwVN6ijatBR7HL+UJqhYfPz4sZggx2VRS4f2sGMTJ52gZfTpWDB37B7TNkPT7mfv6AacdcpgOOD7P/g+rWYLVVXMs4zpbEYYRUZ88giKgw2uNX/hL/x5/uyf/Rl8D96/cpW8yBmNR6z0OnTaLdrtFu12h/5hnzt37nD9g2v2MJczVW0XYa9GOmUp8QMfadE4w9swPws0RVnw9ttvsbW5xSOPXOTunbuGn+OQMG10hx4Q8/xneAhhOujAoICuDCWEIYYXhVEFv3fvLtev32Br6xRr6xvEcUQcW/f5JXTJBecLBG9xv2k0CCs3EfgEYVCjx3lW4JIN04oP/X6f1157jSiKuXbtOuvra/T7A8Iw5NKlR3juuec5t71NFIWMRnd57733uHfvPufPn+fixYtUsmJnZ5f3332Xa1eu1ujCmTNnePzxx9nY2Fh06tmgyXFJHBq1TNZ1mX8YBiitjB6c5ZxMp1OuX7/OfJ6ysbGJEGbRnqVz+oNBrSMXxRFZnnHj5k12dnfZ2j7N8y+9yObWFucvXGA0GXP5sctcuHAeIeDWnbu8+eabNBpNzp87x6XLj/CZz36Gm7duMRwMOTw8pNVq8eyzz/KTP/mTPPHE43Q6LSqpGA3H3Lx5k9dff53vfOc7HBwcUBQFSZIQhguT5IeV3uqxNC8Id9O435d+dts7+tGjQdVJaFO9XYtkYY9HeIKyLBlNJuR5QWzRnSiOiUvjKSkd34jFfKWVMTKvS5M2YKqnOXufCa3wPTOfLXOwXKntuG2KUsrIYBRFLXp5HHWbTqdcvXqV2WzGeDzm2Wef5aWXXuLChQv84Ac/AKj1nJbPl/u8QzVdUFUtfcdPcnyqgqbAZlFSw1d+4sv8S3/xX+Jv/+3/lLt37/Lii88znk7YOzx84Cb8uEnwcoT/W7/1W/zgB99nOp2xt7cHCGazOe/84D2iOCTPC5TSjEZjfM9nMp7ghQFayjo4KfN8sdoKgbQk7Fpm8zhU4sAcu/C65q4PC/qWX1JS1VpOpqxjKvPHodaPPhGLDS/vfulE2Tjn2HaPxBtOr+Ckg3ZffBEsmX0so4Nq6WlevHzS1ziOfB3flTqOhNXHeTQwM19NHPndlTaXlR3V8vFrYbsuF+TwB/az9LXMJANFWbC3t0/gDwwnyWVjSqHwjIwAoKrKIj6Gc/SNb3yD7e2zzOdztNbs7e7yO7/9W4SBYHWlx9Vr19jc3OKnvvpVDgd93nnnPXMoltTtgukkiYjjiDQrKKyxdK0Ipl0ZU9c6XlKW7O7u0Wy0ePyxR9nf2yUvCqPvpKl9wH6Uhu95huenNb7wSJIEpSpMULrIhKfTGTdu3GRtbZ3V1TXOnz9PEjeY5BMUijAMLOonMMRh50+3aB/Xapk3xlLCoi05XNkSkukwmkwm3L17lz/4gz/g+vUbtNttRqMRSZLw6KOPsr9/yBOPP0Gz2eCdd97lO9/5HuPxiCAwSv2rqyuMRyPee/99du7tUJY5a6srfP7zn+f06dMIT9Qkb98S24WAtbU1zp07RxwnTKcTIw+wxD+ZTqfcunWLqqrY3NpkbW2dOI6MGG2akTl7FN8jCE0p0g98Gs2EOImRSnLQP+Te/XuMx2NTVmu3Wd/c4MLFi2yf2SYIfK5+cJ1v/tZv849/8zfp9Hp8/uWX2dza4JFHLvGlL32Jd995l0G/z6OPPcaXv/wlfvqnv8alSxfNNQOKvGR3d5fNzU2klPze7/0e9+/fJwzDI9pDbq49KeDRWtfzubDJ7cd5Qh6GMp0YMAlRz23aHo9nAx0ppXW2MChxo9kkCCOiuDLAgqxslUDbcpuphIjKbswPwLcef660b9u7l3Cxet10QZMjay8biLsgqapM6XMZcXLBt9aawWDAcDi0vLuzfPnLX8LzBEkSMxqNGA6HdTnOSRosB+XL5eBlOsYnOT5dQZPv1USzTqdDnCTc292j0prhaMR0SXrfxsSIj3nWXH3U83yEUOzu7rK7uwssB18aWUnmS2av86ltaRSu20HUC5zSGmGhITfp1RmB/YzZbn3IR0CXj3XoBgowgYyzZ/Bs279Si4DJBTofwxxWfxRY8LADO4aALcScTvj7gx9Y+r89EUtB2Unv/LjPQ324D3zwWHR43H3+oRs8eiRHDHeX4zBHqF76m0OilHLBSGEWTCGsbISy335BEq4nX+Dv//1f5MrVq9y+dZs0S1ntdVBKMplOSRoNvPGUDz64bsTh5nMO9g+OHONyHGeECAN83zPBmloEqlo9iJJVlWR/fx/PE0RxTJAZ8UJnruw0eU7ivf2zODzf2ZSYBKfRSMiy1HYOCXzr/VaWFXfv3q1VrhuNBmfOnKkTp2UezDK65NCj4/9cF2clFXES1qTsIAjr1u+qKknTjNFoxGw2R9hFNIoi9vf3uXLlKqsrq0RRxMHBIXfu3KbRaLC+vs7t27cQAkbjMf3+gFdfe5X5bMqF8+c4f/4cs9msNgJe1uNpd9o88+wzZFnG/fv32d/fJwgCVlZWAFO+29nZ5Vvf+j329/c5f/48X/nqT3Lq9Cl+7Md+jA8+uGGP03RvJklC0kgIoxBh52e36KZZtiCbBz6dToezZ8+wurrOPCv49ne+yz/61V/lW6+8QhCGjIZDnn7qKT772c/y2R/7MX77t3+bsiz5yZ/4Sb7wxS+wvX2KLMs5OBjg+x6NZszW1jovv/xjgBFPPDgwz5JDMpaDpeWfj84hS0jTUlL+w5TkTkKiHkSezL3kIWxZTbubCakVlZIotFEDj0LCKKpFLLV97pVVvXX5oQA84VvJkoUHoWtCdonV8qR30jlwAY0r1Tluk0OcXNDUsr6C4/GY3d0d7t27R1kWbG+f5vLly1y5cpWbN28ymUxqSQOXVLjtuO2naWqC+vBHXNyy2+mwc9in1Wzy5ptvcnDYZ9AfcOb0KQ77ffqDobHlcEP/8JHmR964fMiibTulFpARi1KbC1xYcFSOpB36If8/qZJ1LII4/sBpZTptjrxP60Up6eMMG2M5UKWGQJUjCx47pvqYjfZH7e79sXd5/Itz7PixYpvmfHpLpaCT+AIP3YU4/sJDfn3Y+T82zLlf1o96yOeWkIGTjnU5cwMbPAlH3rYTogfT6YzvfvdV5vM5VVmxubnFF7/44/zcz/08n3nxRfr9AYPhkM3NTdjfr4/t+H3teHSetwjiTdBv2qnrcu6RLwuj0YgwDHnxhRf4/vffsSgstkRHXfL7URhhaDS0ZrOZ7SqLLSl6wfXxPI8iLxhOTMt7u91mY2ODZrNFHEdLLdiGYCzEsr7QYsFxxrnmWtntC69W1V9WYTaPuhE6LcuKyWRKnue2tBQxHo/Z39s39AEhrB9jTrfb5f3336eqKlZWVjnY3+PevXtWQXxudXMO2dnZ4e7tu9y6dYvd3V1OnTrF5fASjWaDJ594gm6nS7/ft2rOJZcvP0ocx4Apr9y+fZsf/OAH3Llzh8cff4xHH3uU559/Ac8P2dvbr9FhP/BrnSfzrBsScaPRMIhaFBjivZU2aLc7JI0GeVmabrpdQyyfTCbcvnOHGzdv8dTTT7O1tcHFRy6ipOKF55/n0iOX8Dyft956mzfeeBOt4ZFHLvDMM0+xurrGCy+8wFtvvcX169eZTCZ1udCN48HM8bGYdh7sjnzYOE4DOKmUt0CazPyghXVZ1dr4YiqNh8D3LNInJZ7nE1hfUwFkQFma9cI0zKrak1JgGhCEsveeZ2xXHPrv1oLlY3JyGrAQsHRBjbM3yfO8BjfiOK4DH+cxV5albWb4Ab/7u7/L888/jxCmnBtZxweXOCyX45bLgQ7NckjXJzk+VUFTs9nCHwz58hc/T1VV/INf+mW6nTadXpfheHI0YIKPD0Ww0MSo0aKlhfhIQMIRo/pj+7I3r1tt6mjclT0WnQr1Z056bj7quJc+4/lm23qpfPRAUFN/yY95QixycLQ8Zl48rhJuFsijEZRwPX2WF/OHw0gf/A52OXrgq7lM68TdfCgsdTwC/ZCXxdKG7DX1XEaHePA7HrsPPuyaOnTBvNW82emMKblQB9XaWPecOXOa3/2d3+HNt39At9flueee49d+7df5whe+wN279xgMh/zVv/qvc/XqFW7fucPewWF9GhwQ6wufJEnQeXaMh7REfDvh/AsBYRSyvr7B+fPnGNusr/708e7Mf4bHxsYGp0+f4fr1G4xGQ5xWmu8HeF5Qo0JuYp9MJrZM931arTbPP/88cRwxmUxwnCFsUAROegMbgLnW/UUnHXh1ti6EcUdI0wwQdDodut1evf88z01nbhAuBVbWsFebbqQ0TXnvvfd49913kdI4E4RBwOrqKt12C8/z6Pf7vPfee0gpee+99wz6oqHICy5dvkSvt8Lq6hq+75FlGfN5SmTLb2BawoUQ9Pt9RqMRe3vGn/P8uXOMJhNmeY7wfao8t3w74/tZVhLh+bTaHVZWV+kNBiRJo+7cs5h+HawKITh/4QJf/PEvcuPGTYNgTMYcHB7QbjfYWN9AV4rTp0/TbDYYDAZ861u/z//4P/48aZry/PPP8/Wvf50vfOELrK2t8tRTT/Hee+/x1ltvMZ/PiaII+PDynBmLyeSBKeJDUKeTxkloFe52EdiqisDTmEYSrQnCwIhRIpBFiQiMBlWr0ST0jWDufDYlz7O6JOxyfiUE0vKWPM8zt6Yn8PCNK4TStZae4zEti626Mpm7l53FSp7n9etOc8t5FcZxzNraGqPRiLfeepvJZMrLL1/jqaeeqi1VHHEcOLIP31+o7rv7/uMn7R9/fKqCpvF0wiMXzvMX/+JfZGd3j7e//w6ra6tcv3GTyXiy4PK4C2k/t3yvep5nHy518npZRwuCD7NeWY53Hp5bHIeSFmhV/bkfJp44VsESHni+Z9vdT7QjfuCjH3eXy7wJ00q7eGgbjQbdbhcwyEM6n9to0hyLsmrWIgxtMPeQIO7Eo3n40ZlMXFjF9eUuHW05IS4AOWFzJz48H30mhF3JHNnVIDRioT1TSYQv0Cd1tj4sKP7ovSJ83wTZCvwgMLwEDX/lX/0rfOYzLzGbznj3/auMRiPefvttsizl29/5Nnu7e5RVxW/8+m8wGA6Yz2f1Vj1rHeUJQbvTZnVlFUbDj30Pam2en4P9ff7hP/xVnn32aR5//DGuXr3GfD635+cP830/nePLP/FlVlfXAXjvvfeYTMaWV5QQhiFlKUnTFKUUnU4bz/OZz017fZIkXL58mSSJ7ZxjuhVdCXexGIPNUo7c3BrDM8uyDN8PTNAlTGt8GEZcuHCB7e1t8jxnNpsxn88J/JDZzHA0lTTq5GD253ziTNBn1J6VBOVJPBFSKk2Wzfjggw/q7rwbN24wGo0YjUbs7u3y+OOPs7W1xebmplUKt6KbeU4a+DQazfrcGH/JKWmamjTLljqddVEQeHWA4AyJPc+3AalnicUleV7azqyyLtkEQcClS5dYW1vlc5/7LLdv3eHuvfusrq7WhtxSKsIostfFYzgcc3Cwz8HBAcPhsO4SXFtb5Stf+QnOnj3D9vY277zzzhFy+8OCJurLpW2ybMbDEKPjf6uTqGOB0vL7jmzHlQbs7aIdCu+2oxSlDS4iG8DEUWT4WUqhqgqpjXq9e7+WEuWB8qygpdWiwhr9GiVlfaRcuWyD5c6TK525f3EcM51OawRqWa/JST84tOm9996nLMv6ugwGA8AE38782iUlyyR0sN18xSc/IX2qgqaqyPm5n//v+fVf+3X+47/1t4wp6rxBVVZgCXF6adauA/Glxd/8wEMXiqPw6nG4xYzjClB1LGMX7oWnm5vgOBI3iA8LcE6KxI7szKIaGqPAXZ2gkn3CJsBMy6bbCVhyBD8JvTGmrebhDcKg9pJCw/b2Np/73OcYj8dcuXKFW7duUWS5+f6htdiQGm0tJh4+fjgEKo4jut0OSikGg2F9fl3D4ZHz8CHXeDE+IpQUTv7fPCZKKQQCqS0x101yH9UxdvxafsRxeb6PqiRoeUTeAOBnf/Yv1psRAt56+23+13/jf8s8zSjLqv5Gv/3K79Wf8d03tX9stRJ6va5BmvoPC3QePFBDU1PIUlGUFW++9Taf+cxLvPjiC/zBH3zbTp4CKX+46/ppHS+//DlazQ55lpFlBqVxrelG5busW6BbrRa+b9Ccmzdv4vs+L774oumECwI7PywkHoSwXZPY9m+9/M/d98oQfX0nrigIAo84jjlz5gxPPPEEWZbVfzs87LNzfwcwSs2OS5JlGePxmPl8jtFS6tJqtUjnM7I0tRYrkqLI2d3dYTqdcvPmTfb390nTtA7Erl27xubmJtvb25w9e5YkSepS1traGpubmxweHjKZTGrNHddxh6C2mDFB00IwEcy8aqY+VQsmOisOKVVdrjHnO+DixYv4/iW+8NnP8tSTT3P9piklemJR1g98jzD00VpRFPmRsttgMOC9997lueeeQYifZHNzg9OnTz+gaH38/0eDJ7NuHEmwtT4xwDppHA+cTvq7WJr4nMClQNsQXBv5ACtTozHrhfSlRaFCWp5HmaWUWUhpt+m+n+ukU55Aeka6xjSyaEN3FIvjOB74OZ7e8nrqDH2dAKVLKNznnQaTK8HFcWz9Dm/V8hBaa9rtdo30udKyQ5fcPe2QMfnHkMV9qoKm4Szll37xF3njze/XJ/n2nbtobQxFjDedWWdcYnbEkHS5TvzQ+3Y5UFp+jSOvnbQsaE0tclgT5k5470eqGtVPmTi2rh+DUWzgdHxzx7e+wLxsGzPLAdPJ+5elqonAZW4cr/3AlAhu3rzJjevXwdbJhTAO2YYoX+HKVx/9JX+4v33mM5/h2Wef4cqVK7z66qvMZgtfoQdKcw8LOh+6L33iS1IqXnzpOc6fP8+Nmze5dvUak9FkMRV+FMH+4wNp9ljNgoun8IXRS9JS4UfRiVmn1prZ3CjfOjG55b8B9QQnNcSRxzwteP/KDT7z0gtsbW1y6/bdjzgoM1zZo6qMMXRZFnz7299ha2uLRx+9zNWr135kAiaAtdVVer01Pve5l+uk4t69u5RlgdNU833DIXE1DyUVs5lpsf+d3/ldtFY899xz9YTleY5T5vTdJEoZt3pjwGqeXWPIHRBYLRyzaJSUpcTIdRjfOYCtrS3W19d59dXXmEwmPPHEE7RaLatQbtSTDw/7DAamM6nVahGGIbs790nnc6qqxPMEjUYTKY2Ip+OTeJ6g3z9kPp9z+/Ztms0GrVa7LsNprWk2Gpze3ubcuXMURcHrr7/G/v4+nU6H4XDIrVu3ieOY8diYqwZ+UC9+ziHBfEdlVc8zpDTioEkSkaZBTTQWwtyjcZwAikxq2q0GZ7a3mU2nJI2EVqvJ6moPXwiUFoBBwVqtVi28qJQiy3Lm84yyVKyurnP27NmabPxh4ygf6cHkYznYORr48MB7P2wfDwRSdsoWWpAEIZVWhuytC3w/MGa3WlMVJak0NiNhEBL5IUmc4AuvLndJKa2rhKDECC17gC+AwLfGvl6NjDpaSz2EUxH3bGPS4m+Ou6SUYjab1UiRSzjyLCO0wVWz2STPcw4PD+tGhmUSvukW9CirkkIXCGHFM4UgsN2rn/T4VAVNAvjP/7P/gr/5H/9NfvpP/TT/l//r/82owx7jYTj9QSHMTVQrVB9ZSC1hWamTF3m99IMQi0DG4+g6e+yaKG3IsCbSXYJRMQGPJ8yDrbXxCHJdMuoIqc4FR/rBZ+gkpOJjoSomB/k4H3CNd8vVSa2hKiVUC4XshXi0PmKmu1B7fdg+/nA38muvvc7Vq1eQUjKbpUfjR41RvPU8kx254zke7x65zCcETCcc6mgwpCoKbt++Qzqb2RZxQVGY7p0kjo/wej7+OCGa04ZL8tNf+2n+0l/6S5zd3mZ3b5eN9Q329nZ57tmnuXvXCPPN53N2d3fJ85yVlVUGgwHj8dQaVabs7+2RZSmJ5QBIKfnyl7/E3uFBPRm98sorJxDpT8YqpTIlIOGbYL2qzCQ3mY5ZXVvhhRee54033vxDnIdP53BUgFOntnjxxReQUvIH3/4Drrz/PkWRkSQJq6s9lII8z+pMudlqobTie69+l1Ontnj++eeRSlFJReTVuKBJcJSuPRw9IazZrtFzQph2bYRA2xIWmOev1WrRarUYDAY1X6TVanH61GlWV1cJgoD5fM6tW7dI04z19XUeeeQRAKbTOcPhkFOnT7HS6yLLEt+S0cfjMdPplLzI6XS6IIxQ63yeMplOAfCt3ILL+KMo4uat27x/5Qqe57G7u29KbUHA97//AzT/I2EUMppMKKqKjc11pJRcv36dw8NDNjc3CMOwDmhcM4Xxh4uJo5goDE1gEPiAZGfnPsPhkDhOWFlZqcuUvW7HniMfzx6fb/+5pGBhJuzb3xf0hONr8MNKc4vxIcnUAx/7YUWCT96o7/lEYUReFsiqopLSNg2EKG30kmZ5QWhLyVot/u4Jg9BopVASKipbGQHfqt1r6aGt/Z6wk4chn1cYKxl7Hm15lpo2I00A7vtsbmywurLCnbt3mc1maK1pJA0836u9N5NGQhgYeQeXFIBmPp9bVNAYUSvbHYg03DylFAQ++sMuyx9hfKqCpsCHg8GYv/Wf/Gfge4zHYwuj+HaGscRj3CRjF3fJg4iA1kZTyb35wVreYsfCbgxRo1lHFuFja0wllTGtPaH846GtirLLKBeZyceOik+KfU5Y58Sx/7vvVWvqLCU4CpstaHPjVaWsyeAu+Dlac1/U45fFxVzW4erbnyQxOM9zY+2wNPwlXytlM5Yj44dFeh4Ymjt37gDUJpRGR2dxLo4f08cbJz3R5rWyLHnjjdcZHB7W+jXtdos8z2kkxkvLlCYkZZkTBAHdlR5aYcsxPu12u144kzih1+3S7xtewHg85Nq1q1y/foObN28fWwjsbHjCiRIWnnelJGwAlWYZH3xwnc2NDf7Un/ppfvM3//Ef4nx8+kZeFHjzKd1ul7NnzxLHCbPZjMODQw4PD8116XYpipLJZEya5TSaDVbXVynLihs3b3Ll2lUOBwPj5+b7xs9waR8ONTLyEz5BGBEEhvytYCmIMEmW55ksPo4N6XZ/f588z2s18e3tbTY2NuqA+80330QpxRe+8AWeeeZpWq0W1659QJrOWV+/QKvZYD6dElgF5v39Aw4PDxhPJga5EILhcMjUNgSUZelEpa0PJsyznOl8h529fULLa2m1O3hewKuvvc73XnvdzB8CWp02lSwRnuCN115jd3+XKDIuDwf75v4t8hy0psgLJhPDi6qkMQk2tALJtWtXef/9Kwz6fbq9ngkYpGSl16MsJUVZImVlhGQDYcuD8kinVxgGRFFAGPqUVWEXd3UkSHLEaZdb21cfes88NMAS7lPH1qF6i678tkCZjqDKSwGMJyAIA+NPWojasSHwfcrKyjZkGcSxQRXjiCgKTXlTFGglqYQHQtrynkYKCZWH71VGhoDFOik8Bz5oKiXxhcb3fPAtF9SuIUqVlHkFnseZM9skccJwNGY8nlCWBVEcE4YBlZKUSiKKAgXEUcRqs0mRZ6TZ3Ho9GpFMhLBeewaZlVpRaYWnBULJj6ZO/CHGpypoQkPgwetvvQWAH/o2oLHQiADfCxAW+ThSKlguyx1HiDxhg6ljJRpX46uN2PTDSz0ueHLoy0OgVonC+LHah+/DymTL2/44r7nXH5rcOLkAUbcwH33dfFBpbaxThMukqIMkp8MjpVpwtrRTNT66Y2ca+0kApCcAhab8agO0GtT6oeDYj36v1pBli6CopsXZc2HQg082pVFKs3t/l937ux/7M75VpDacEMOhaTSaBqb2fVpJgzSd8dprrzJPcybTMaPRDAWEgUdZucnl4d/F3NJL2bYw94oqJVUh2dndoWlLQj8Kw/MEQehRlgXCilt+/vNfoNPp8s1vfoMbN29QydKgIUmE8E2mXskSBHR6Xd6/epX/7r//Of6Vf+Uv8+hjj3FwcGgWNDRxHJprSomS5rlN4oQgiMhz01EWBDFCeEipKYqKwHpX3rUZ/M2bN2m320gp2d7eptNtU1YF09mU6WxKlpmAYzqborWi2+uwtbXBcDig0+kQRVHtJxY1Gmyc2mJlfQ2ALDXoUtJsGm+wLCNpNAiCgDzLGI3HjMdjirK0+apGeB5RGKKAkeWpAMaTD81kPuFXf/VXEAgmwwl5lfPBB1f5f/0//x90uz3G4zE7O7sEgc+N6x/wG7/+j6jKivl8RrPVRmiB85K6fesWP3jrbUbjEatra/wLf/pP88xTTxF6vmUeaHwh8D2BH3hoFEWZU1YFZVUwn0/JsjkIaLVbrKz2jMyBVSw3ht2inoMMoqJtIGH4aY5Pthwr1UvDiY+a+YNLSpZfdZSP4wmqCRatAbGSVGVFXhUgIAh9lNBoVZGmM6RWBL7H6dNbPPPMM7zw4ovM53M+uHaVGzduGJNkyxXzgwAPH2Wtv1QpKRQIJfCjGG115IwAqbEBq1SFFAolNIEf1A0zWmnDhSxKsrKkdXjIam+FbrdHnpcMBn3yoqBSFUEUugwNhaKSFUKb0odpLogsF0oZBwMUfhgYhCsKDLoSeBRq0ezwSY5PVdBkNc/wfWwEK/CDwBCWOQr6ODqQqFV7zSJuzAqFgSQ941HTbLVQgKpKlFSmc8nydarSvGYyCvOwIAw3oaqM0qoQAuELZFUZtr4r4Z1QvluMHwJZOultJ712DIGq44gHAimr4/GQ3VeltXex7c4OvFno+Ni/LQXxJxr5/hEDpuNzyknxoEPqHvjAxyxZ/rBjGYj8YUidR4dDGpe3e/SG8SwSuCCDuk+ac19vSSyCWEdwlZWiKEb0DwcPPQK3icA/fvwP89oxf9Lug8eTDwFFUfGuVSD/URim8cMJ7kp83+Ps2TMEgc/u/i7TdEa/f4gqFEkzISahLCsqWaERNNst+sMB3/nud/nxL32Z7TPnAR9EgFY2oxdO2R+E8Gg0WjQaDQQYs1mrwOy4cI7ftLOzw87ODru7u/R6PQKr3u08wNI0RaNYXVsFoNlsoDClm3LJGkZ4xgdRKkUQRbQ7HRrNphXJPEAEAZ3eSs2D0cosclmW0bFlsTwvSNO5CaCKAu3QiaVkp6yk8UVUFffu3kUrRRIlCF8wm0x49bvfNTQAbfSpup0u7/zgBxzs75nzr2FtYx1VSfZ29rh14ya793e4c+sWu7v3Wd86xRc+9zKB55PEPkpVZOmcvChQ2qiyJ0lMFAX4vkCpikqWFsGCIPCJotCWRcHd+I575jogtSXw10CM5fIsJ6jOytONBxNsE1waTpB540Lh2uFRZn5wQsrLHWTGL9N0IXqeTxQFaKCsjK1Np9Ph8Sce50tf/hJf/OIXKcuCt79/jsYffNuS6iVZliJq/a8lVElgnAUqo9ruC98eK4CRzZB2jRS4QFKghEGBiqpElTnj8YTIj/D9gCiM0VqQphl+6BMlkbk/UAgNUlUILSw+YmyElBKU9rs61XIv8A3lxrcCn1rzydv1fsqCJqvgbpAaeYx0Y/9UHUfjjrSmmwXHBAPmJg+CiCRp4AcBVVUaUTXfs10WFXleUFbOhVzadnqF8s2DUknDaxEIRCAotemWceJiLjP3PNMKHMcxeZrjW5ExJU3k73kelayI44j5fGa4C4HJYqVVWq4qiSNxR3FgyZguMDkSMtoTsqQTVI+lJ1ofe3n5Ybcn/MijrG2rvzi6cD8w7INeK0w/cAzH3rs8EeCg6PooOLIyu8nKTRoY0qwfLALdI8Hose910rG4P590qA87fNNJZ47f7O5hUdryARy9PouypnpgG2ZxMJOU4+S5U+W+Xhga8muRm66jMPTtBLoQRlwWSUQbHl3gL/z6FBpVHT85J5yg5T8fD06X//YjNhZX1mTcUkriOObFF15Ea8V3v/sdBsMBoVVhzvOC2SylKivCMDb8szzje6++StJo8cRjTxBFMZUUpoyPJdNiFJQ7nTbdjul8nGc5ZVnVViZ+kuD7PmVZMhyNKPKcvb09qqqi11shz3IrChnRabeJwohHHnmEJEk4e/YsaNjd3eXW7dvc37mPVJJWs0WWZWRFQaU0rXaHpGHavf0gIGk0WV9fp9PpIKXk2rVr7O7u4vsBjz72OKdOnUJKya1bt3j33Xe5d+8eGkEUJySNJnmWkaYpsqrwPJ8kaRMGvjGHth10Ds0pyxJZFKR5Sjo3YptRFNWdeEEQ8Obrb3Du3AXef/999vf2SBoNeitrdNsdyqIkL0yn1WyJ/H769LYNSBt0uz0mkylgFN1d56zr2DuaKC2Vx9xzeewZOLGr7EgydPJD4/tGR809w8vdgU5XK47iOmh3nmzub2EUWo2rkna7jRCC2WxmOs6ThCeffJKnn36KixfP02q1OHVqy2g2zecUZcHh4SFlabo/6/Ju5BNahW1ZKYSWSJwUgaCWCndzmNGfQSBQUtcBnVbGWaMoKsqyIk0z8rwwfKfSJIF+6OMFHp4f1qtBvQq43XmGBK7BdDLbEq/WlhvmGTDjkx6fqqDp4sXLBH5oFxlpSkiYGrSrq2pXZtPCnmxvsYgLQVUUJviw7yurip37e3VEq5WiqvJ6ZXrYTe1KFUdGKPC9ELuU4/lBLQvg+TGd7hpbm6fYP9gnCSNaSZM8S2kmMY0kYTafsblpbAyiKCBJIsaTMUWRE0chszRFapNNdHodBv0BeV6A55n4DM/ASnYmV5WyC69nrAiERzqfW5NNrHy+RiNRviHuBVpApfC0QTuqWpnanEdp6+PCMw+OK2HW0vvuhtWAD34o0HIhmobG1JmV1ZgSwrTX20VHYFA8jZHwV1rhCQ+lDTJogkZFGJpJQdrV35A27QMjF0iZE2ZzGh5GGdc8uJ4NQkLPR2rTqiuEZ/ld9rbRug4QnR5UfQ9oV7b0zcQhHCG+bvg1x2xJ/gLzEBuNMG2rCAZexnP8FG/BWcAGkAKr1+RTVQrPtqUrrREWBnRWGlobN3lzzjRKqDpuC8IAVZQm1bBoqNIsYKfjN7i259AFbOqYPIE49n+HsP5xpHf/NA9331nUudPpcPnyZYoy5+Bgj+KKXdStjUqa5VRZju87Re6SN998k067y8XzF2m1m2h8o9CMa8k2XUEb6+usr60ZQ9yiMNc5COoyUOUJpCV+SylBSdLZjL3d+zXyG1prk06nC1rXZOCDvV0yG2jN5nPiOCK0CtxlVTEZDenHEVpJYwysFL1ul4sXLtDrdRmNxrz15hvcvnWLJIk5d/YMlx55hCgKSeII3xN0Ox2m0wnNZpNTp04zm02ZzWYkScx0OmZvb5eyrGxJSBOGMZ5nkKiyMublUdJAK02aZUxmU4IgxPd8yqri3v0dxtM5g0EfhKDb64FdWN94802UVvzet17hnXfeoaoqWq0We3t7TKdTPvjgA6bTKVVV4ftGgbrRSIiiCCEEKysrNZJXX3jg42QLP2wXl9OnOq5o7biizWaT1dVV2u02YDogR6NRrXvkgp3lwGc0GgFY0dF5LXJqGgfatNumHBuGIXEcH7E9WebdVmWFkoUhbnseSlrk3MPOqWb+1C5I0nbOlQofQRTGRGFE4PsWNDDdeBVmTpNKgRL4+HjCoG1m/lw+BwJfmPVOaZaMhReK5K6z7pMen6qg6dd/7R/R7XZqrKEeDu88AnM+mA5nUnHv7j36hwMjhJZn3L13n3fe+QG+8BgOB8ynM65eu8LOzn2qqrQLsQnGXGeF8AzCVG9aG1SoKtxrxh06iCIQHlVegvCZjifMRlMQPnMx51AdoKQ1ZMWswHfv3KKSxj1eqhJp/YB8u4hKrVAamzWYGq9vlX1VVdQImvB94laTTqdLM2lyZnubQPh857vf5dSpDcIgoChzvNBnnqXksmB1tcd6b43712+y1lkxImTzOQpodzuE1mW8qCo836ff75vypzZqx0pKQt+4m8dxwjydoVRJ0ogpZWW7aiKyNKXMC4tECfwwIPCM4F0lJb4IzM1ur2cYRsjK1MuFgNAL0dIs4lEYma6M0rQ/myDRr923TTnCalnZSMCIrbkgw+wmCAKKsjJdOMKjKHKUQ+qUNno6vrDbsqrDNiCXVVWLWwphuEWe59eKxotyqVq8b1lW3h6EWoo2tNK2C8VHKwkBpkPEF9aU2aBGsjTBoFYaT5vOJYfYeb6H7xlXeWNyWRIEPpU0lgmeX8e8dWC0OAB7WCdpwDq06fj/jz2W/6yP5ZhRaSM2GYYmyGh12ly4cIHHH3+C0XjMvfv3TMmq07WdZYUt65lF8MbNm2xsbPGTwyFxHNUm0Eax2zQexFHExvoGGxvrJI0IUChZWa8xExggTdIQ+B6+ECRRhEAzm0zI8pyyMF6H09GITrdL4PsEYcjertFvKoqcdJ7iB8Y3LwwiPM8snrPplEPPI88y4jgmSRLarTWajYQoDEErsjRlOh5RlQ2UlCRxSBTFtFstLl28SK/b5datWzQaDR577DHG4zGz2ZS19TVu3brJ3t4uUkmqSpoigfDxfY+8MKWyIAhoNZsEYUg6n5POU6I4JopigzQoxWA4IMtzkjgmCCOarRZpmnL1gw+4decOzSRGKUmv26UsS959913yPOfatWtMp9N6njcq7lPu39/F8zxmS96mH1WWP0n644cZTlJnucnG7Tuw12ZtbY2trS3C0IiWHh4e1u8djUa1YKmzLXGWJXmec+PGDd577302N09x9uxZDg+NPc54PK7fG8dxjXItNLFkXeFYaCIJhE0yPbuaoRW6snQWbR0rpCYQHkmYkIQRoR+gA+ogTSqJ1PZ+Vy65NHOhtZ6vz0dNRtcaoTSqDprE0oP5x2Mi/qkKmt743utsbG7S7XYIwsDI4W+sI6VRl+10OnQ6XYvzCKRSZHmOVIooToiKkq0XngVtMhctzaQzz1K76EKRFRRlYQm+ZtICHOwA2pgeAnZR1XZ/VvhQKQPpalfDMktYWUp27u+zt3eI5wWgNYPDQ/Z27zMcDJhNp0wmQ7IsIwx9dnd2uHnzOoNhnwqNsQBdrEmqlGigKCtEtVySMu/SVUU+KynTGUPhsbdzG4Egz+fcuzczi7EWIBRKVWg088GQHf8OqiiZDkbgiVoUjR2xpONgvmdVFDUiIg30QeWZ34s8RdlVN0fjByGy1JTpzARXYUgcJUZPRRtX8UbSoml/n85nVGWF1gGdVpdut0O/f4hQmlarxWgyptlssLKywnA4Is9zhGfUyPMip6xKNKaF2/OMjocnPPwwQmq56ITzoNQSUZosvLKBmkFjPPwoQEtpyrCeb1pZlTTf1xO2QuwiDoXGaFVJUbn0xwRQVCyeZrUkwro0oS5zsW3AUpsWu4/b37ULVAQIzyMMjWp4KVUdNJltGa2gRqNlJ9ES4dtAzPNM42mpTHB0vNT2YQn18TLdj1CwVI+l76+Upqokvm95ZVLSbrd54oknGY6GHPYPGY3H+H5gfdg8hoMxnhcQxwnj8cSUxm7dotlq0Ot1UdLwKRGGyxOEIatra2xsbNQlujCMSOKkRl6XRR993zfO9laFOQgCqtgEDJ7vG80juxgvKysrZZBc79Yto22UGCPVqqqM56E1yl1bW6PZbHLv3j3CMGQ6nRLHMWfPnTMaQGHIwcEhWZaxt7dHt9vl/Pnz9fY3Njbq9v5et0ez2SSOY8qyQVkWVJXVqpKqRkwEhjS+vr6G520xmUwYDkdUVUW706Gw0gLOUFgfHtpOuLBGb8aTCVop8ixjNjMSHS648H2/RpQmkwmvv/66NV0uuXbtGrPZ7IiFynL3MDwYLB3vij5pET9J4dtdD/d31yXphivfOSHSdrtdt9trrWk2m3UnoAt6qqqq7Wu+//3vU1UVk8mEjY0NDg8PefPNN7lx4wZ5ntdBsUOb3HUvysJKC5gKSuWXNRrtW90mD2ERIIuSKlNd0Ipa4sGR9X3PJwpiosjw/VSVoaVEoilt/BOFAaFNAnC0DLcfIew0qFHiGOH+2Ln/pManKmjqttpcuHCeOAw52Ntl/96QdDJGa02apYwaTcIgNARJ64sTRCHdlS7zPCOdp/S6PcaTMZ7n0Wl3kGXFeDymt7KCEJoyz+l1ewSBgWGNt5smabaMP9NsYiHMyJLPvNryAhEBmnQ2ptFsgTCkc9c9kD9dkKUFQntoZVpmszQ1HRtlYVybdQVSkRU5s8mEPEsNquUJ/Dg0woKeQFeV+b8NzNC2dKatCau1GqkqabWLFGVlkKB+/5BbN24xmU/rFs7+cMD+4T6T4ZBbN27SarcYz6bMcyOcWN96wunO6hMXygqORneALCSqtFmDvYmrokJVc7LUbt+WNuZzk0FIaVAggWA8GZlApKjQSjKuKvIyRypDhM2yrEb+tFJIJamrgVqbv2lDRFRlsYgFfPDDgKqwvDDPt/wxjyhpmOCwKowzeBKZbXqe8UD0AubzlDzLUZn5DtiFbaW3gu8HzKZzSlmRRDHdlR7ZPGWeGrFAZTMqIajLkrIy7drSEfMcDGbROFmVC572Iv5CKYl05bp68rZlS8zEnufpQlfL8tJUZWEkp9Txh5lfjgdSP0pjGey26EStbAxEUcz29jaPPvood+7dZX7lCpPJhLX1TXq9JpPxHMOrNAKOs9mMq1evsr6xTqfdqjNljVucTePK5uYmZ8+eZX//kPncyE+AWUhDa1qKDRB8a2Hhex46CMx9pkzJQlp/rhpJ0Lpe1LRSRvdrMsUPIqI4JonjOtlwTvJO6wnMQj+fz+l2u3Q6HbTW3L17l+FwyGRihCubTcOBAiPhEQYha2trrG+ss7u3U3NVtMYo3AuTEPi+b5IwBJ7v02y26HS7bG2dMomjEDSbTSaTCYcHByht5tfpdGYeF89bkJzzwtzyFjUpigUnallt2jkeuIBtOBwuWt2XrvuHLcwftXCf9Pnl8tIy+dsNN+fNZrMaGTvOd1pZWam3kSRJzY+6fv0677//Pnfu3OHdd9+lKIr6vO3s7DCZTGxZslEH2y4Ay7LMli5N0GNQKLMe+b6Hp6xyO8J0MSqryK0MyiS0rQRoUUsZeMLoeDWtcKoooVIl2vo4KqlQnjScJc88dELbqEmAZ0tzAmq6hVtnBO6FT3Z8qoKm9dUVWnFsvLNaDeLIs+Z9pnYfeR4eUCkFUhF4HpHvocqCWAga3Q55nlKlc5N5FDHpbM7uvbuoIsPzfbL5HGHJ3Vor/NAQK5OkQRzHjEYDEBCFMUVpBMKCKGDQ79PprjIejem0mzQbLQ76fWSlWFtbIwoTirKg0WiS55XhVGkDh2s0cRgQ+gKlAwJPsBGuIs5sk8QxUpYcHB5QKsnjTzzBeDy23lCSRruN5wdkaYHvBURJVJcvDc9roenhsqIiLzjc3Wdugw2lFdN0xnA0ZD6Zcri7T7vVYjafkeXZUucHOLKdM5J1iq9aWTdrYRA2rZTF36g7Qdx7lDICZe618XBc8xiEMAHe/v4+1658wGAwYDIeM5z2wcK+ttBlgs4yM4ifQ8TqYX9bFuhELzxuMJuShVxod9lropSgstmUsoKOqqpc/MKskgg/ROalRSIXNSxlO4cMx8KQ+ItSM5/5JpivSpQ95xptz5cJ2pTSdmGARWSkrFFpm/FwAIDnWwE3aQjwzUbTAj+CKIzMOXToQWVUdFvtFnfu3qXT6RqfOAFBEKJQlOVSKbIOyOwPy8zz4+NHLUg6PmrepJF8MGJ7hpTr+R5hFNFumzLdU089yWAw4MaNmwiBVc1uUNiSfrPZRFYVV65c4cLFC5w9e5okDvEC3y4EhsfpeT4bGxs8+eSTTMZTPvjgOoPBgCzLaFhft8CiA4ZcbZ7FIzpmFl3SS/8XQhBaLSanLJ9Z2YBKmuPrdDq1gnOWZUynU/b392sfsSiK6HQ6rK6u4nmeVRo/rIMNt59Wq0VVVdy9e5dms8nm5iZra6tGrsIh10pRlKUhm/shnW6jFgcNwwjPD2k0WibgWt+g026jgcGgz8HhIVEQMp3PuXXztkGebOADxq4Frez3NRICzsLFnRMXJOR5XgelZVnWPKMfpjx30u9wFKk6PpwwqPMDdNwmR0ifTCbs7++jtUHepZRMp9Na2HRtbQ0hjCRKs9ms0bP9/f0aeZtMJly5cqUOuoqiqMVIXVnOqKvHdWnQiPga0U9dWd6VNDQITwRWgBUb6LPgQSrb1KCFKadVJlgPQsOhajabKBR+LiiqzGhMofE9Ydg3SqHxrM2QXnDOl+FemyjX87n3xzNBfaqCpiSJqPKMRrNBt9shzxaKrnEU1Uq0CA/h+XhBAFpSlQVB4KERVGnKqa0tkjhGKI0v4PyFc7RbLXzPZ9g/JA4ND0QrQRwGxIFv9CG0NOaGSlEqUKpCBL4h4yrrFSZLqCpkmaPyjDIvyKIQHRn4l7IizXLDuxGCvCgIwgBPRBTZ3EDnScSBFbyrGomFlSvyNKXMc0bDgQkWtUGPPOGbdk0vIGk2jDeaNnVpg2pooiA0gmeVZDIa0223WF1dAU8wmk5oF23OnzsHUqPKEilLZFVZ6xRD2KuUI9+bluE4jomjiEpKZFmSJAlRGDFP57Q6XabDEVGjgdKaPMvxPI9Gs4Hv++RZgScEZVkxHU/tBGJb7TXs7uzy/rtXGE3GTKZj+oNDIzPhLR5mRwI3uiksUdqcX5dxh0eAb8ngWpgMxATa2gpw2sxEGNXjLMvJspwoiSmrkr3dXabTMXEU0e52iZKEw8MBd2/fpZKK+WRMWeRk2dxm2zMTaNmHV1VYaw1DFjbdHTa0Fbbjj5MmVvOttJJks3kdACrbYm1+VrX3khDCNB1YPpiqUTZIkgZaG+6XUZ9XpkwsQCths/tl9FCcUK/75+PoqMlqCM8zQagrqywtiKurqzzx+OPs7Oyyv39YBzDtdpvJZEaaZkR2Ybp16xZ37tzh8ccvk2xuEEYRSpVUZU5RVPi+ZH19g8+//DJxlBAEIVeuXOHw8JA0TU2Z2i6WQO3/dtwcddm2wvgr+rWxquPzgHluHDl4GZVypayqququLCdpUOQF/X6/3pdztR+PxwwGA7a2tijLkv39fdbX11FKsbLaI45jNje2LLlYk6YFURSzsrrKmiW/F5YS4LoE0zRlMBjU3mSz2Yy8yOludmi0Wgi8+rtrbZK14aDPzr17KFmxuWGCi8lkzHQ6q7+jQ26WTWCFMArsy0jQhwVHHyZa/Ifh2jgk05gVF7Xa+zJXKYoisiyrz5MLfobDIVJKdnd3awJ8ZjsXXeDsrr8jkrsSr7M9cSiqu7fcnGXOmVFOD7SzsgEtPMtpWiTcAlN1KcsSvyjMWi1Mk0NTN/B8CEqPSlUmSRCLQMikyotnTyPMeqRN242Zruw7tAGmfuSJ4EkUomSJ0AlRGJCnGlmWhHFMZC926PmIIMQPI/zAR8uCEiPEmKYZcRiwurJC4PtMRyMaccTGqU2qsjToiKxoJHGNXsRxZNcPQwJuN1u2k8wQ2JJGghDQTBokzSZr7RZ5ltNoNmnGMXmaIYRPGMXE9hg9NGEUIzyfoqqIIhP5z6amxttsNSiynGazUZeWVldXiaOI2WRCmedUZUUQBqRqjpKasjDQ7GQ8RqPwAhNMzqzLfbPRIAxDijxneDhg69RpiqIiaiQcHPbJy8J0migNsmQ+myzaXn0PaRcDqcwCH8UxAk3gwWQ8oSoKgtU1hFQMDw/odXpk8zme7zFPM+azOXESE4UBHhgiYBhQ+AU+gjhJTNlIQxRHbK5v8NTjTzCbzymlKWkqFKWu6jJUYR/sRjM2pGsnF+GbjgthOVlFWZhALTeoXqOREAiBzAvCpFF3OPqBmRiLrGA8mdJsG8j41u2bDPqHJI0Ga2vrNFtt7ty5yzvvvEuRlwz6fbJ0ytQGT5W0kLXwjM5Lzf9ZJgIthBUckuP7zn7GLMJmgpSm4066jrpFl5ZZCKhlFuxUTlGUpLO05rcUeUEpKzbXNgwaUpUEYYLvB8zTueVXLQVMdfS5TGz65+OjxrIflrCkZC1Md9r29hkeffRR7t/fYTKdMU/nNJoGdZ6MZySNCKkkg8GAe/fusbOzw+pKl0YjAjy0JfNLaTqnLl9+FM8zSVCn0+H+vXuMxmOjjl2WRn/IEnodedct+Kr2FlsIEDr0RLAIBEwAFRAniSldV9IkSDaQUFY80JncAvVCnOcmSXJlobIsGQwGDAZ9BoMBUkoODg84PDw0yaTQFEVBt9tjns4Yjyc16tHr9Th37hydTqfmVTlUxOwvqxFsV6JKYsOX9H2fbrdLu90mDEPKouDw4IAiy6iKnAsXLtLtdkjTlH6/z3A4rLvLXGB23NPxuHXKh/GZll//sFLc8eECGReguIDXHYu7puPxuA6mnNxFURS1Ir2pjoyQUprmJ3tfODRp2b3BBUhAjcw5VCoMQ5IkMfITmbUEsu2yUik8KZHSQwU+PqYMp4VGC98g6QITTGHuvyLP0dpQSDz7nepjUPYz2jRIaZvkaenkXUBrgRauw9gFTYLjhq8/8kTwKAyQZUk6nxLHkblkArPKaNO6qJUknxV4XoYfBIBEa9N1VRU5nWabUECezsnTOX4YEmjTXVKWJb7wiMPQmGMq47Pmurq0hsjeQFJKojAi9HwqWRH6PgEaaSHFMPCIghhPG1+5OArwbE3e9yITbPkBeWk4AkJLAs8jSWJjvZDltLodxuORyUo7HdulJUlCI/4VxrGRxi+liao9oy0lhE+cxEhZUUWBJdP5RlVVVqyvrdBMIvZ2dlhZXaMZm7ba0A/I8hRZZjYgMpO1Vhphb2DrSYOSJdlMEniCyXBoHsQgoMxL9vf3aEQJB3t7NNMpk+kMraFZGil8IQQeppzh4ObADyiLEg9jtukJn+3t0xwcTJFKsrV5irwq0NkcPwprnv1sNqPdXrOaJJIoNsGysUCISDOjfh1FEVJXyMpk+J1Gg/loxOraBr7wmM8z2u2O8T5KC9Ispd1tE8cRly+eZ57O7SSU4HsBT1x+lB976SWqUpHO50hVIKuSIsvQGIVege3eKyuCOAS9ZH+gdS3QqqwlihcYRM/3IzxrUKqUBqHqTF54PrI0C4ayyJWRNXCqxB6T0Zj9nT2yLKesKg72D9jZ3aHRaOD7HteufUCr20YI2NndZTKd4PmecUC3ZUgzlkuFi9cWry4hLfY1qRTXPvjgj2kG+KdxLBZMF0zUfnCWYOH5Hs1WkwsXLvDkU4e88867DIdj1tY65l73fTwhUDYQ2dnZ4YMPrnPu7DbtVsNuf7EQO9+tjfV1PvvZz/Loo48xGY9NEHJwSL/f57B/WJfNRqMx/f6h6VjVi7KcsJy3IPAJfXvMyjYp2LKMxgorVtKIY2jTpeoHRvqltCVAzzMLl/CE4aEo283qG+RYKssfKhflrrkNrubpnOFoYMp2UjKbzknnWY3UBUFAq9ViY2Oj9swz6JY8go4JYVDi2WzOdDrn4OCQNM1otVqsrKyYMpYNugCj4dTrsb19Gt/3mU6njG3gOZ/PmU6nNY/HBVLj8fiBEt1HcZqWfz6+iD+MJH48aKq7xqzFi0POHIrotuVQstFoZL0PVxmNRjX/qdFo0Gq16n05lNEFvm5+Oi6m6RoJwjA0AXklLYkIKq0RWhHYwFX7lrwuPLBin9oTOIsmKSWlzMiKkjAvajXvSpWUVUFVGW6vtp3jQpvwTFlFZWHbh+u0zv0sloyB9dHA9ZMcn6qgyUrxURY5gQeBJ/A8o9JalIagprUAVdSEYCkrI2cfCFvTryizjCLLEIAvBPl8ZrhMQhBYF+XSZujowEwGtnyj7Y1b18i1NjwVJdFVxXgyotvpglJUlUFnDDEO0EZW37RSKkCiZElZmrp1UZiMqydWqGSJVkZh17NcA7QmDA3RPUriBeFNa6LAp2WdxT3fM2U6WdHOWzVs63seU2B1bQ3fD+i0WvQ6HZJmC6kxWiwCAhETBr5BXqwGllKKNMsoS5ux+h4egtX1dULPGOVGYchoMGS110OWBb6AbrtthDxt1iPtA15VFUqa8+dhSoIeijhO0FpyeHhIEgfcv3vHKB3jkZUZpSyJGzFFZdpxD/b2SCLD8Smq0hI+K4QvaDQaFHlhyoXtFpPplCzNKNM5kyRhdHDIbDSmETfoD4a0Wh2UMhmv53uM+jHdXpcsz5jNp0RRzBTBZGy4Aytra4jIp53EIBSeL5iOx7RaLeIkobTeSZPRmE6va6+/VwdDQWCy/CLP7bUXZFlGt7ti/MXstcyytEaqNja2GA37KKmYp2byP3fhEdLp1LSAd3uMBkMO9/bxgxDPD+jvH9Af9On0enhCsHP/Po1uqyb7pla8rioKw3EzxmEIowJbkyqdrpRpUTFkTjMVLib7dD7nr/+Nv/FPclL4ExsGkDsaNNX8E29hP+TKQ1tbWzz5xBPs7OzS7w8RCOLYuLmjjdNBs9nk8PCQa9eu8uLzz7CxvmK3o+tFU0qJLCWBH3Du7DkuXzKowWg04uDAoDd7e3uMhiOkkhwcHHD//n2m02mt+J1neV02q1vaLXpR2UAoThKEZ5SaTZIW1cHColPLsH9dIBYEAUEQGZ6uQ908QRSHIIyoYZqltWK10orRaMhsNqXT7dJsNhmNxnX5T0rJZDJhb29vib+jahkRF1C12ybBKYqKwWDIvXv3yLK8DnhGo5HRJioKZtZYOLYWMfP5jDA05O9Go3GEm+Xm+SzL6Pf7dRdZfQ88ZFE+CWk6aTwMCTleAnTXCXgAcVq+79x1cHIFrvzmym2uw9Jt1wWbwJHtuNLqcpnO8ZuMb59CYSgSSis8pZHKN4KW9tr7nvWe0w4h8gHPFjMqqqIkL0pjtB6YkpyUJVIWpqPb0mF8zyMKQnwPg25pZQIlsQiaFBZl0qJORh1C9UmPT1XQJLWk2Uhqk9ayLA0jP/DQhUJpSRhEtqWXWkHb9zyjSyRKBv0BrVanLm3IuhXSrzs+tIbZ1BhBdrsdg174RkpcaSwMbKwFDEPfihjaz8ZRzDydUxbVUo24QGAEFSttMjFdSPJ0jheENFpNlJZMpmPWynXG4xFCGJPaMIrN4mS5Bp7vkVhz0DzPEUAcRHgIwiBABCZz9cIApQOydIasjHeU5xkydZ7lrK2u0G63EJ6PKiq0rAg8QRLFBIG1SxGmW0UK43JNENBqNQmjGC2lkQjodgiDwEg4tFt0ux3KoqAscrZPnWI6N0hT4Pt1SU1rS3o2Eq4IBI1mk063x3w2N5IFquL82W0rLucTxx6NZoIf+AyGQ9bW1mhHIevr6xRlQZbnFGXBZJqD1PhRRIQm6bSJopheo8lkOjUoGhD6Pvl8RhwGzKdjstncBCsW7p5Px2TzKUWZM51OieIYEOzu7OF5gjw9Y1rwNMznU0TgMez3WVlZIbLK7mEUM+z3CeOIwDMIoEHXKoLQcOWK3BDGwyBgPBoT2Y6VMAhBQGVV4WezKU88/jg3b96kqkpLyhRk85SDvT3arRYbG1tkeUFVFDSaLSqpUFXJ2dOnycqSlV6H0Bc02qYseXpr0xh2hglVmRGEIePJxPAYmk2qsjSCqDZQzquCMI7xhUc2Twl8nziMEJhyZL/f/xOaHf6kh2bZ2NrzjdhqvagJTbvd5uzZs5w+fZrd3X2kNIHPysoK08kUrSra7TbjyYg7d+4wGA5I09PGBqPmnPh4XmDQX9t95MoazitufX2d8+fPU1qRS0fYnozHzGYzZvM5h4eH7O/v14FAURQ18bksSsMe8YRdnBVhEFo5gAKlFgGctmzfsiwpq9IGTQb5cbyg2HbeBb5PlhnNIITxRIzjGKdWPZlMjAxCltWlPcfDGY1GNJtNYzAbhrV0g3ut0+nQ6/Xw/cDMNUFgNOqajXo9yK1OFWh6vR5xGDAcDhgMDmsLIt8mqGYf5mdnfK215tatWyciRsfH8RLeDztc4OICR/dvedsuQHKokEOMzHfv1HyzJElqOQcXNC2rrS94SVVdEnUlv2W1dafd5Eq7WVWhSmU0ygBpWUcaTNKlTcDsaYHWFtf0AhSCUipUWSGVRCgPX3omYFKl4QoLhUAZnrAODMigjcesEovuOa1BWf4U3qI857hTJ9l7/VHHpypoQkCn20EAeWEWyFIWJF5sFzrzNt+3Cs+eIggDw6WJzGSTF8YWJQx9irKCsqLZbNnMP7Q8EElRlNa41qMsTdkiCCMUUJQleV6YjpkwJApDyjwn8gMaSQuET5YX5FlOp91GeII8zQmtO7kxOfTI8ow0ndFZXaO3toKsCoaDPqoqTfnRM4EGSpGmKZ4QJogKQ+I4YjIZEwQ+jaSBr0CVJXlVgicocs/44amKqsyRZUEgqF3AZ5MpzWaL0PfMsaY5YPQ9ZJkyT82CqJR5KHzfJwoD/CQiSWKarRaz8YTJaMB8Pifw/LrTY2tri/FoxOpKj0bDiFxigy6tDGkcwAsCiw4CFvovrJSA7wui0OfCmYtorRmPx2htbCRAo/Oc9W6HThSRNBpMphN8pYgFxB0jxBnFMel8TrPVIs8zVno9mr5PKSt6vS7dVoMiK+j1OsxmU6Ig5tT2GZTGelpV+IFHb6XDmTOn0WBRO7MwrK6uUZWKIAjZ3ZUEcUQraZAkMUa3y5TpehfOMxqP6HQ6BEFgIX9JGBolZxknJlBqNBh1RkwmU7rNJn4QkOc5K2trDIdDSCKKdGZ4d2FIt9WirCpGBwfMxxPm4zEHe/sY89gYKRUHe/tURcmp06e5fuMG22fOcP/+fXor3RrZCuIQ3wuZTEdsbmyyu7dLHMX0el3SdM5wOCQII/w4YpKZACqJYmbTKa04oREnpkvPj+gPBn8CE8OfzFgWQVR2gvY8o8tl7B1sAUGbEoVrBX/88SeYTlM+uHYd34/Y2tpkNp1RFAWddgII0nnK3u4eZ7ZPs7m5USd5rihqUG+T5RdFYflwhoPZaCT0ej3TeNFoUFUl0+kFsiytS0/9/oDDwwP6/T6TyYQ0TWtV6dFoTJrOQQgqKU2i4fmGClCYphTXPStwkiqq7nJCK8tFMT50ge9DGNbm0U7dQtjuUVfeqyqHdEoj2msDBocUuQU9tB1XSRKTJI06KHByAVpjrKi8gDC0QZwtLbVaLeI4YqXTQSvJ3p4RdASD8jm5gSwz5cHQqrg75MVd9+XOt4cRvZdLasf/tvye5dfgKAfKBabHO/uO7/d4qdAFo04iIgxDQ1dI07qLcHl/y/9c+XEZKXP0gCAITDCLJp1OUEKD9SjEavQprZDLDlqewHWaaqs5JrWksvxUT/gIrSnLHHRFGAb0Om0ajQglK8qioiolZSGNSbLlrSqswK/WeFYgWfg+wvNrFEqdcG3+qONTFTSlVtPHEZKjyAiW+UFg2lVrBWaPODFcACVLoiikksbGYHt723ShBAFRHFmym/WEy3PSLKPX7bK+voESEDUSqjxHViV+aLyjlFLEScOKxxlYeJTvG1PMMCIIIjbWN00Xm1KoqqJhidhxnOCXhbnwpTEZlKpiPOwzGQ6MuaaFnJMkIYqMM7RSim63a8t4BVorfE8QJU2aSZMqzZBVSVUaTadA+LXcQGh9ohqNhCiQCAFh4JPEMaHvUwUK3wN8n3avSzrRVGWBb9WtS2mUslvtNlpr43GkzIOUpSlJnOAJU2YKAhNEVkVJr9NFA6X1ewpbLdZWe4RhyHA0QthMdjadU1WK0Wgf3ztEIJjPJly6cJ7JdEhZGr6XUhXz2YRmkrC2soLvCQLfA1kRCEHgCeJGwtbWBlEjIU1TQqFptRsMZYUvIA4DkihktdshSSL2dvfxgoC1jTWajTat1R6TyQSpJV4gWF1fBaGs71eXLMtpdZoIbVS/5/PMwPndDgBBs2lJtyFeu8PBwT5nH3+cVqtJr2e27QlIbKBUWkPopGECj0aSkEQx7XYDpTR5FNJqN0FWNE9vorXm/LmztFotGo2ESkrDIdGQzo3rvFKKOEoQnsf6Sg9VScIootfrMhwPef75Z4miyHTYWNR2b3cPQYvNtTW2t7Zsdq7JiwxZSbPQBQFpWTAcDul0ukSBT6/ToyoKhoMh3Wab1dXVP5G54U9iCKsxYxSSqzrxCgKj9yWVbRd3dkHK3EfPPPM0WsOVK9fwPM3GxgY3b94my3PiqKDRSIiTmHv37nH27Gm2tjYIQ+ON6fy7hDPONq1FSKnqRUpJQzIzooSm5BIEHu12i1ariZSrnD59ypRYLJm7KApmsxmj4ZDDfp/+YZ/RZEx/MGB3b5/ZbG45nywIxVLVdIXQ94gCH7SiyIyEhS8MOo6SZLaUrGv+m0ZVFenMlMqkdp1RCj8wBd+yNJK+rrMPsF1iGVmWMp06AnNwpMxkmXaEYWT0nDodeisrrK6s0Go26XbarK+tkmcp/f4BSWKCzLNnz9Lr9ZjP5xwcHNSkcIfkjEajRfmVRenMBU/LQdIysfnjIFPHh0PjXeCyHHi54M0dgyNvm/Or6+TVNfI4xMgN9/5lLtRyh2Bl9bvc9sPQ2P241/wgIExi5HSCstszXdZGlb5S0tIuDNpk63UmWKpKiqoil1VdyvN8rB6iIhCClXaLyxfOsrW1YTv/xuzvHdLvD5mXGZ7262sspaTUmsjz8GPjOIHnUVqtPvWjTgQvipLheGzFG03NtCgKKnszK6XJ8wKBZ+wKohglS4oiMwhRUdUk7k6zQeh5zGZzZmmKH0bc390xD1q7xdxqFLVabUuS06jCePV0ez2ana7pAKhK8jQjLyWIwnrXGTXxZq9HUeRMR2N6HdOqbmc5yjzHF4J2I2E2m7J77x5VWXHx0mUz0QioqpI4btQ8m0aS2DpzTlka1VYtK0b9A6hMR0SSJHiB6TTLS8OVMcJ1gvF4QpHleAiiKGF/b492p4MIQpSSjIYD9nbuc/7sGZpxw9Sq/aU25MoEq1VRIkujOlzmJc3EBIRREBpNJmE698bjMZWqGA365kELfILQI9CeIYR7Jjsuy5zAD1ld6ZLEiekGUxWdXocb167hiYAkNnpcaTZjMOwjhEev10UIj1argcLw1YwGFGTZjEoauDmdTxkNDvGFAY/n8xQlJM1eh3E2o7Xaoz8e0x9PaM+nzGdzRpMRW1vrdFa73L9/l2KYozDaNWVZ4ns+WgkqrfB1QBT6pvOx2TTdQZVkpdtldaXLqe1t7t+9jVaSRhLRSAw3JM9zsjQlCEJCP6AqTNBrgkEf5WnTLRVHpIGHwGhGFXmGLAvKvGEtfUxTAlriYXgmVWFazwNrp2OCrBavvv49up0OSSO23C2f4WhIOp/x9LNPMxwMiT1zvoWAOApQyggCFlqTxDHra2skUQJas2ad7LM0Y+vUltHA+REZhpQb1R5eNQLBQo+s7kgTi8W13W6zsbHB+toaRWEWorXVVcq8pCwqPN9DCM/ykwZUlSQKHVGXRVSw+AFO+knzgKo0GBkBorBe1F2nX2mToNlsxnQ2ZTqbMZ3Nmc5mllxtkKjhcMhoOGQ2mzFP07pbLs8ysjQ1tAW7P9dlZ4jBDlXy7bO/+AaBF1huiq7nci0tsZgljXv7gxBATXwua94dwihTC9+z4sGZPe4h/cMOjSShkcSsra5QFgV379yqgwdnPLzcYg/UdiK+73P79u0jBOnj3KWTyMcLNOnD76flj/0wJOaTUKrF/h4kq3ue4PimT0K1zHsX19BxnrRYCtysKa6wbhEO3VG4xhRwXbgaY+KrPUGjadbCsixAge8LOs0Gjchnrddhvdvh1MqKKS/3Vmk32oRhxP5Bv+4Q1e54lbINNyWhNTH2MRzkwPP5pMenKmhaW1uj1WyjdGVbLnPKSuILo0eSNJpUlaLIioXRn+eRlxmBrck3Ggmz2ZzC1u6NumyT6WzGqVOnkUoRR0b3QtpW7MCWSfI8M+RI4SG0Jp1OKYuCZqtJaOvI0+msNnzM5ymyKolCo2WUjUa2A8oYDpqs1Ey87W4PLwjxEIgwpNvtUeY5hSXtuVZSgG6ng+8HhI2EssjJJzN0abtYrOqXWzBFlNBuR8RNAzGXs4x0PmdlZYXJZEIlK2IRsrG5QdxpcbhznzgK0KqiqjRRaDg+ykLDaGNjEPg+SRQznU5p9VZIbDvvbDYjzzLCyIhstpoNNtfXCZLYoCllyXQ6Ic+NZ1QQNFldW8VDMBqNyLI5vW6P1dUV5tMpUWgI7GVZIDxodw2xvCgKFJpOu0EYhRRViR8Yg98sz9BCEwYhQWh4E92VnFKWtFot8AR5kdMQHZJmg5W1NaZpilZG8DHWmq1mg+3TWzRaTZJmA6UlSlckjdhw3DxjqDqbpaYEHCc0koTuygqdbpfhYEBR5DSTBulswnRiCOLLLb3G2btCScV4PLZCp4o4NhO2qkpC37PlHWOUubrSo7Kf8xy8b9E/3w/oWrK3skggmHsty3MExvoiikOroK7pdju0W0ZY8alnn6UoSw7+/+z9eZNlyZneB/58OevdIiIjM2tDdTUAdpMipRHHbEya0WiM+pajryMbmShNUzMaiqQZ1VgKqCW32O5yNt/mj9fPiZuJLGxdALpY5bBEVkbc/fpxf/15n+XrFxRGE5wjBI8tGtbrhCoqiqZhHCfqSojul1eXTONE8mKid/306V9odfhLDNnyU5Kg0HnT8f5ReSRFztKokILWOcqy4rPP/pqXL8Uc8tnz5xhd8Iuf/0Jabylxe3fH/f2drDulzBttchF27uB+tjn+xst7fObHTe8ddGSmDCitaFcr2tWK6/RUbFeUxljLNDmOpwP3d/e5tScqvbu7u/z37aLSGoYx815kvZgVe7OrRUwRnXIbMxPCVQ7VjSn7wUXP7LUW833ngmv25lNKojq8E/TN5JaeyYIdIbd7+tPE6XTg5vWbjAAptps1xMjDw92yvgMcDoclcsXkvWG1Wi2twhlxmYuG9xVQ77bS4G2k6Btnk+I3ipnfZ/x2ovlv//lc7J/Ph3cf8xz1UkrhY2DK1jNovajXYJ6SM+fo8Ulmj6WkxVZnvdlQFJbj4cA0jqiY2K1bNm3Fui7R3pH6jqauaHY7yqIiJYVz0q6dxkl8GJc0Bc80jlRNi1EivDJaC/f8Wx7fqaKpKArGYcAaTdPUIpFXPTF7GUlyvXhBjGOGiAvZSLv+xGZ7Id5NKTEOA6CoSiEJDl3HarUSDyBjWa3WuHGkO3aQyd1ay4U9ZYQgeJ9NL+UCG4YB5ye++vWXVGXJdrNePJh8VoGs12uM1Yy5ryzETo1B5PP7+wd8jAxZ4k5S2W9EczjsaduWaQpcXq5FPZWkfu+HjqZuFlWLCx4fI5uL3WKDP/Yj0yQeLjOEu95uJWTzcEdRVaw3K5x3i2O31hrcozfSvBDoUv7bO8frFy+4vLigrusF5q9yMri1BVVZSfacjxyPHZObSEmQw0MmHc/EPVNK0K01kvlW1RXjMLC7vOB+f8/gJq6vr/HHhJsmBj/x5cuv+eqrryjLimfPnoq0PqNj4zShC8tqs+F06vD55Do4z93DnnFyvLm9JcZE26wAxeyavN/v2T/cEaOX9mOOUPjVL3+FnzwffPAh3icqWzH0PRrFrz//nL7vRcZtxSleK1FurtoWYwyn0xFSZLNuc0Eo6NLQdwvR1fuwhCCXRYH3kZcvv2K9as94AxJubLSRCB0eNxbnR7pB7B3quib6gHMTu4udxFNoRUiJ07HDGE3d1Nzd3jKNE2PfU6xENGCMJNRPzksWX1L0hxMGlX1vHClEmZduoiiL9126/8mO+/sHdFaHqhycPRcnjzZceVOVf6AUrNcrfvrTn+Kc59e//pJPPv4Uoyw//9nP830it7e3vHnzhuPxSNtUGHtGLlbz2f0PGLmAU4/NPebtLAaPinOLSS9k9plwawvDei18oKsnl/zVZ5/i3MQ4TgxjT98J4uS9qNd+9etfcfNGDDdPpxPHw4HD8UjXdYy5qAJp381cmUXFlaS4VEm8y6wSjtiCXCRIweOT8MWslvZQmiM88ucnjxHzdaGw+SAtYduChszF0GwU+cUXXyzFQ13XCwl8tiSYzUPfVca9D3H6Q8efgH7zR4930au5CATAK6Y0G6Pqt1qTj+jSjC+d1e5KEfP/lAZbWAkQHxMpOlQwpHFknAbuh564v+fwZsvm6inFaktdlOJ955xwe1PCqvn7liczSua2c26htXzb4ztVNI3jhEngMqnSKI1GY7Nc1I0jCrkwjBaCpgpyoklIwnvXd5Jgn92gPZ7hdliKIK0l/qCuSuqqIfpI8gFba6ytOQ0dSsE4DpQ5cmBGgKZJEK5msxFUa7OmsoaYlQnNqqVZrQjBMY4DpERRSCtu6geGLIVVOSCxrmsKWwIyGSc3sbEbDg971utNRmJE3eKco7CF9HSVIgbpY5fjBMhEC5NMpFlVIQVW4Hh8IKXIRm/ROuL9Y4bR4t8xQ6D5MzKza2wSjyuT2w4ikR/wzrHZbiUjqeswoyzA3vkcSVNz6jq6rhO+k7FcPbmirlr6ozhrt5utIG2IfcBx6PAEggYXPZ7E19nhtm5bNtst9Xotqp2+l1al0XTDyJiT270PhBCFBF+3jGOgP3W40REKKWwKazgdDuzv7lEkytLK9uIjLoy4YSDlHKXSalZ1jZ+kCB77ntPxwMWFKBPv7u4IBFZtg9FqKfDJ7YT1qpUWREoodZVjgWJ2PTfS78+ZWPf3Nc4H6qamKEuMEZg9omiakmEcOXW9mBEaccQvy1Lc892ImwLtekXfD9RNjZ88wSdsYdhdXPD6xQtSjFhjsUUhJ7gsjIg+oJKW998PDEqT6go3CL+O3Fbph/7PvCr85YaQlGP2vzJIwR2yz9EZOfesXZaye3Fd1/zoR5/w4sUrfvnLX1GUBU3bZFWTytfokfv7e+7u7rjYbWitkMRTSo8n+rf+/u0jprOk+DOfoVnscR6porUWWRJvF2dzIv3MM5rRnjlUN8bI8Xjks6/+ittbaaV0J/E3mlt7h8OBYy6gJPfNL/J270WVPPvDkVv9xtjHQ9zc08sqvMUPivkzV+iElIZGNlGlNSY/1qwglvX30en87u5uKQ601ksG26y+k414WhDFc1Tm2zBRfHyI355n97sf5+227e8a39ROPP/dWygaMjVMRs+0zq3X+c/CWsuyhaWnPM+38OhYHgMmeznp4FCTFwsaPzGMA+54ZBoc5cVEdzxlXmBJ06TF+kbWpvmKyP/LiKB33z5d4DtVNLlhpF6vxS7/2AEy2TdzwnNIovKwEpQ7uQkf/SIdFd7OvagvskO0dx6FXCBi3x+Y+h4/jZS2YLfZ5NtJgTANA6vNiroqCT7QnU6kJItgURZM48Q//xf/nKQUY9eR3EQMnpQiq7oV1EBDqsVJ3FqLD5GqadhdPZFToNaEJHDzOIghnS2q5YIWVKvPcK5Mnk0maXvvhFmXF+39/T0gLcC2adhuttmzCTabDV9++QXD0PHk6TV1WzOMAz64ZSMoy5Iy+5ns93tSjKw2G5q2pTud2G63C1Fw9v/QxrA/Hrm4vGTsh2xCJ/L6tihoVqvsFyOu5hLKWWDLKmd3DcQktg6nrqMtKrrjkcJYqqrBO08MwnnwzvPZj3/MbrslIiGbLx9ecDoeSCGy3e7QSnPf9ey2WyE1x8TlxY7VbgdBYbTmdrzD9SPFZkdb1oR+Qlew3qwY+hN+GqmyieQ/+ezHNM0KkmZ/OKBSYLtecXt7x9On1zx79lRQtsJwsdvSDx2rtuVw2It5pTW4aaI7ndhsd1hbcNjvefbsGX2OQCjKksvLKyHLFwVNJe1n54QoDHnxGR0+jLTrNT5GRucxtmC1XrG92GHzST71GrSgRXX2olHKULdSkI3TyKuvv+b6+gkxBKqyXMKiTTKUZUVVNWhVUGAoqoJj36Eyd0D4byaHan4/hnNyncBvbpyPCqtcUOafoaTQtdZydXXF1dWlXF/ThHcuI8mPhcRsF/D8+VNW6yxuma1Oslng77dB5ggfHnkq54TmeZN5t600I03M7yK3c/R7Cq+5+Kjrmn/6t/9UCsmMyoTsl3Q8Hrm7u+PVq1e8ePGCly9f8ubNm6WoOhwOTNPANDmm4AlhWlpls7miMcLHmsOJRakX8vvIhHCtF+jmsShISzzW7I0nmYFyMJxl/eek8z5ztt515j7/nt/9zt+njpPX8bt5TW99Y2/xk775jr/v7X6f5/pdVgnp7LYmf9ZKgVGIEziP5dFjmTS/NrHcgUTfd/hhQBGxCiqrsClQotlWFbvCUgKD8+xfvub05oFjkBbzk6sr0FKsd8OAPxyJPCYmKKWoq4pYJELxPS+a5l5yYSxKgc9oUfCSI6aUhpjweVEJ3hMQN/Bh6MVYrTvx5JMnlGXJYb8n+kBdNbKBW8vtzY3I5LueN/f3NHVFXVeAZpxG/DSSQpVdSnO2F6LKOh4OghCVBd3pxP6wX/yfmrrkeNhTVVX2wbCEKO603TBSti2XqxVjP/D11y9Q2rCrVxxdhzYOW+ww2tB1HVUh3ks+BDrvqKylKRti7rm7IIhQ07SEGMWrJ0rvv6kq+vueqizExsAaqrqW01eWAHenju1qS98J92sO2Bz7XkibRQFGs3944OrqSnywqhJrC0Y3URuDLcXvpywKyqIStVhZMnrHMIwZnRMvFiFdQrcX+DuGiJscYQrc3d7TPHueDe38YlcQo/hrfXD9AW274fWrG+7evEEBm3bFs524hPennsvLCz5+9iEKeP3mDeMwUuiC7jjw6s1rtpstcfLouuB4fyB6zzj0GKupqwYDdM6DF3WHUYbpJCfow/HEw5sbirrh1evXPH/+nIvLS9w40p/E32ToOq6vrvBuFFRUafw0YbSckrp+ZDh1HPd7+mnC2lJiZZIIAIZxwhtLipHJe9xBsvqMNfK5xIQ2lna9xhTSCvVZ/t6PI7Hv0MZy6nppswHp1EmwsDGchh5ikrziKHlRx67juBeZ98VuBSHhBkdVGlRIFNqSXCCoiCkl+DelJKar35NhjFkIwyLJLhBUOKMZM69mRp6MIB0+CEetKCuurq744IMP6Lqe4+HEbrfj9u6G4dRRZrXu119/zV9/9inPnj6RAmHefR9TS3+vMaMxkAhRxCHv/l7PyLLWc6PrrfbMvD/L3jdvtGInojBilaNE1WqyV5VSCgpLXZas2pbLix0fPH/OT3/yY+FA9QOTmzgeT+z3D3Tdkf3hwMP9nv3DnsPhwKk7nRHOJ/qul9ZwRu3KsswIg/Cj5u1aKbI45Yy7BUx+Dtsu34qBmRG0R36aX6gMs5Ls3fFuofFNRdM/ZPxDkKc/5rHf5Tct8zije0mBReXoEiV2NuS8t7l4ytM0AUnJnlzVJbU2QuR2Dp0SJkUKBWtrebpq+XC340ldUirFMHl+fX/g5nZPnxTVZsvl9ROatuHU97x6/ZrTqcc5McIMzqNrxcXFBVdXTyiLkv/tf/1fvtXP6ztVND199oy6LNnvH6iqmsvdlfCOTif6caQqK1KQcNiUZJGylZhylVXBZnPBqxcvFkWZMYZ1K6ZlL15+zaeffioQ+aqhNoqylrydcZBe/Xq9pl2vJErAKFSSttEs05zN1qZJQnnXazHRHIZBEJpVg3MelAThkjSzUPj0cODzQw9GCNHr9ZZpHMSsMiX8OPLs6XOKwvDy5Svq0uTYkhEXPU2JcJ/KkjJKQrc2htIYcc9NInPv+4F/82/+DkXiX/7L/xNXT54wTj3H05FTdxI0bF1K+8/0S3yAhO22csrSGp8iXXfEFhofI9u6pB8Hbm7esF5vaJpmgbmHocP5kbKqCClSlCVNU6FGlnBJk43uyrKk73qRyxYlZWl58vQpDw8PFEZhUsQFx7Ora6qq5Ne//jWHV6/pjkcqW7DLad4pBqq6ZLVqcc7x8uXLRXKtjcG5CaPEMXfsBz7+6BM++Ku/4usvvhQjvaZBxcjh5obLyx311SXd8YAtCup6xel4pCxL/tk//ZTXr19jq5qLy0suLy8p12u6w57+sKddtdhC8sKcmyjLAuc82+2WqpIsp1BaPvnRRzw83BNTwFYFIThAEuGXJowSjkeITvh6RlzbXQyM0yTWGc2KPqfTF4WhLAuaVSOmiflns7CgbRuMsdRVzTQM1GVJSpG+75jGCe8jWkfevH7D/mGPSprL3QXT5AivA0nB6EZsVXBx/QRjNC9evfmLrQ9/7jG3qmbX5RhFDXreRkppJl9nXEhpwBNCxOjI5eUVf/3ZX/O//W//P7pOiqb9YY/3njYrHF+8eMH+cFhO0fAOV+T3HEvEBML5Oad7PJLCHwna+YbirBzjkgyQ8pOL8w6Ltm22GEiZ6+jP/ICEqy7P0dRV5vc9eysseBzHXBxJ9tz93X3Oq7vj7u6e+7s77u8fuL9/4GG/53Q6Luj2TJ4Xv6yYzSofmfJLIDFxue1cK5zzdWZFodZaRBozh/NMafiNn+873J73+Sl9V8a7hd/StlWakEN4tdLMjgKW86LpEQOdJ6ogrVBaQ1FVeKOYiDBNEBwKWDc1zzZrPn3yhKtSkCaPxiXDi8NAN4nn13azZbPdYOye+4d95jEJ+Xuec7vtlk8//ZTNavOtfzbfqaLp65dfi8/GdktKgV9//kuMldaOMlDUBcEFTLS0zQpbWMZpZBzEdsCNXjyGUsIY4fOEmF27C4npUFZzerin7455woiBnAsBvKOoCojSNvPZV0drxf3hyEfPP2AYBtZty3F/oCpr2tWKw+EgPXo0TbNidI6+6xfFRtOuMc5ze3vHkydP+OijjSi5jKVsS4iR7nDM6dQN/fHI68OBjz76iHW7wTnHqesZhju0EWWCMVLNl0px2j9QVDVXz54R3rzh6uoJP/3pjylLy939Pd6PdENHIhPeEe8khaKpKomjiUEMJ8sCP5zo7jqePr+mbVtub++wRjFNEiOjjKIfOp5/8JxpGBn8SKkNUzfQDz1t2+JD4P7uQF2tePr0OUQYukG4ZLWYsD3/+AOKuqAbOj7/1a+IKfLZZ59xeXHBy5dfM+TPb3IOozX1aoUpCw4P9wyDcJpISVqMlWWzyyamVYEtS/anjtFNDF3Hervh9f/331Bow4cff0jbtrz4+ite373h0D1QFeLKbLTmdDpRFmLwuT/sCTFA8PTTyP5XezbrDV134ng88Pz5c+4f7s9Oq5q+H3DTSHAS+zINPa9efJ2NEj2FhZvbN8So+PijTzkeTxwOe548eUI/9gsh2MeJlBRlVVLWFSEm3DiiteLqySVKRbruwM3re477ks2mEh5CShgV8dGDUkxjx1dffom1lpubW5qm5uLiYmm9ygHDUhrLdr2i63tCiozesbJyndnSULct20P7F1wh/rzDGHO2sSYJH838GpWR2yWkNMVFqSStsMg4OjabDZ/99V/z7//9fyCEyO56x8uXL4ghUhYF0zTx6tVL9vuH3PoqRTL+R7zeuRASmfY7CENKWXEm2XFLGyYXTedjtlHI/zj/hXD1tKZYWlgpU/geydHC4Qtn3D2V2/uSOyaty0uuri5J6a/z7YL46HViwnk6nXLI7h0//8UvePXylfjGZZFO1/WLPYgorc8LQ724W8+xNHN78dxj6Tx/bSarn/O+vonw/bvUaP/YxjcVdu9Fz4zICHRSBB6L97lgMsRHlAmWyl5nblOMkegD1hhMWTBOI2EccDHS7LZcNg3b0lKFiAmBsmp4ttvxsU8MN3d0bmK/fyBEzyGr4JVSlEWBycamMUaM0kJNKL7nlgNKKdq2ISWxG4hKctcSQkqbSWdJgQuekCKTm8REi8devlKKaRpzqvaWupU0+9PptNjpF1UlwYjWSuVsLSEmxtEha4hAtUVZoiDzfuQEOUYhpBftCl3X2L5fLrjRezCGSjdC5rWWumkgQtuuljDFsRffk6ptSUn8gTQQXMRqw6effsp6tZIcurKiblppQ/Y9MQZ88ExuoKoKKXTcxJuvvqQfJ/6rf/Wv0GPPy1dfY4zm4uKaRMIFT11UHB46HvZ7KUTaWjbFwuYTmCKERFmKK7n3jrqtmZwjEfnRp5/y5MmTbNEwgoYPP/qAdrsi4jnc3TEMA0UqWbVbqnKFtQV9J+3Tkz+htQR/fv75L1i1a5SG7cWOQhs2uwtUdvn96OqKcRyp6lo8bd68YbVqxQG+KhnHkYf9Hmstzo+00wprDKObiAp2uw0fffIhHzz7gMJa/o//+PccuxOvX78SZZgb2V3saJta3IonJw7muQVsonCqiqqkbldgDYeHfeaYbSjrit3Tp7y+eQWwcOge7u8Zx4Hr6ydstyK9FTKuzOunFx/xUVXw8qtX3N7eslqtaaqVRD/UGqUjITqGcWAcpX0w7T3R5w0qa6KsBVJAJTFYdT7gtSH4mB3d9yglGWPExCc/+oTLi8uMhrlcCEhy+uVuSwqe/nTCWp35Ymt0UeD8tHgTfZ/G+ab6SCA2qDMir84xEsHPOVhpsYqIIVDXjViprNdU1R11LXwzk1W1zo3s957D/sA4jDkpQOKc/tB9OKRHM8NZSTer0t7i3ohPQBYsyM/Of3++jn7TeN9ceFd19r6NelbdCjdSY2ZPpyhh0D5LyyX813E6nvirTz/l/v4B5yUj7nQ8cjgeOR4O7A8HDvsjXdfT912Oe5EYozlDbi6YQNDDOWJkLp7gMdh2UUaevZ8ZiXpXZfZNvk2/7/hj7vfHtgPPkaXz5363cJqnjlaPjCUN6ASGJK25dOasNbeQlRwigpsIzlEYg04RFQLaB0wI2BAoQqAKgdJ5TIgUZWJdluzaFnu3pz8ceBVfUe1rRidK7ZQSthC+W1SyNnVdx4sXL7i9+fZjnb5TRdMm+xOF4BYPjXMI1IfswAskL3wLsmxVabtcDG12bZ5PFlqJvLTO5pFzkjQ8KibmvJ2EQmlDaYT8TJSQylXTCoSN2AgkIiojRA8PD3z4wQeSjTaOKKMXW3utlKSPhyD/PTlxEQ9RFgwlF7UtCjarFZObeHItrak5psCHzFtCQn+NEafiStdMwWOKgqQU/TjgnOf1119RWU1Z1Uxu4tR1aPtoYlaWVoolbSRvTclnG2IQM0XvCF5c1d+8eYMpLNoYhn5E5QiJYRyWhWYcAw+He7SBsigEBTQl4xBw00QK4F1kGh1NW7Ja1fRDh06KoT+hlFsWUjcMjMeRV69ecXF5yTCO+CCRJHOWmwQoy/coeVSGrjtx8+aW6+trirLk5vaGqmnwk6eqSqp6xbNnz+lOJ1G5xUjSmsJYyqKAFIlassIiScQEVUk/jhAU3d0NSYGPfpHU+uA57h+oK+Ft9X3P6mJHu9twf3PDOI3c3vnFuZfMv3h480bmgZOkeKsNQ9fz6uVLXr35irqybLZttr/wxKgZp57CVpRFScgtIaPFS8tmg1Y3TgRA2URVVISmpaxKiqLi9HAUJapSNGVNXVRSSLuJ/nSkPx7QKmGUEhds7zLCqxmniI+RqIQY/30Z0+Romjla43yTZDF4tFajk6h35yJLaY01CpdvU1Ulu92Oh+0D1so1V9dC9hdFGZxOHX3fs163KC2tiEgivWePPKeGP7KYENPYM3Xckm+WhSPziElEFinzppR6bMHJe33cYL/Jz+e8aDonnr+7KZ//ba3JHnm5GI1JCvL0iFaBrOd1VVKWBdvNmk8//XRB/SS7rud4PC1qvTdvbri9ueP27o7ueOTUdbx4+Upanvn1zsXRee7cu+/vPP/t/P2fezadv+93/Y3+3OObaq75paT3FMSP9/0trzsJOGF0bkWmuXCKGWFSb6FM+U6QxE/Je09A2rlFkkBeGwKu6xgOB1zTyL4dQTmHSpJuEbxwgA+nE8ZaAgkXAyFBOQtelLSQ9/s9x1OHP/u+vq3xnSqaQAjd1poF5RA7f/lgXCYXixp1dqVVuf+u3prUc9ESQuC43y+cpNevX7PZbBbPEckyypb0ZYnzgbKs0IjkfhxGghEZfnAe7x06xysUZUn38MDh/oHr62t5Xmul3TeNODctZEMFQvwcJFRyVpSdboQc3bYtGM1p33F1sWOaJu4f7sX2oChxXtR9RuAFur6j7zu0NlxePUEZw+g9l8+e8/rrL7n45GNCyCGbRYGxOpOOBwyWpm0w2uQi1ROdqPgKLQuKc074Sv1ApWq2qxWFLYiBJb181awoi5Lj8YAbPUon3OiySeMTDocT3XHko48+ycGie6yXzK7j8cjz62tubm5Zr1pGJeZ7Klvtb7bbRd2y3W4pq5IQnNhNFKKucV7Ud03TsN7t+PLXv84uzRB9pKlrgjvy5tUryrJh6Lvsi1SxXq2IsSV4kca2TUtVluI4fpKAU3EfH8Q/bBpIKs+rStqLwzBQ1zPnRbLAbm9u0FqJbcBuS386iSkbsqn1XSfRMHk+T+OesqworGUcBrarDVUt7uo+eJqqoapaTl3PanNBWUhQcAgTxohviVGa0paE7D6NSgy9eDg1VU0/DLx88VKIzXVDWY6AIgSxqCDJwcBoieQpipKHw5F+GIgkur4nKUWN/sPhj+/wqMqaaXILB0ZS1aX9JO1sOTSkhUwdiDFgC4PWhphkEx6HkSbzIckWF3VdEYOY+BqjmNyYla0eqwopls5J2fxunpNReiZWLRvbkmXGWdGXHydnn6J5bOud/w2/WTgJL/hRkTa/snenxSyLny0C5jHlyKW339nj4y4/N2Q52iMHaT4Ul2XJarXm6upK1rR+YBjGjPQGxmnk9Zsb7u/v6fqeh4cHbm9vuc8u59M0vR1gnNc6KaQkNPldy4H570eu1GPxOL/HP6Ruevfzmh/znU9x+Wx+c/w+T/abirm3/abm2zx+H7kDi9bij6XQpJw1qFBoZYCY73x2v/xCq6qmqliMoeuqQpclwTlevHlD4TwrW2A2W9ZlhQ+RN13Hq7sHjkMnocBaoeYMuyyOiVqhrKWsapSSTM2YwIW3c/a+jfGdKprGYaQsBU2Y+9whBOmlGpO/tOzLFPNk0EpO3jxWzjOxeUZWTuP4lsvrTOyeHWBnuau1hSySWkmkSEahJsSRXNKco0SFWIPPdgNlU6MLKwWK0RBy+rIQBx77v+/A4TMpcz75DDnz7TD7OSmB6e8fHjicOna7HbvdlvmCnb1jfPAEHyiKgu1uy93rV6DEAFRkvCWSU01GUSzBzzl+Em0gVwugEPuA3JpcbdZUTS1t0yyuKMty8fbRlXiszEiKm0QdF7zHaIu1kXEYRT10PFJWliJ7cM32+NIKlE1EolEU7aplHCVUsyyz18pM/s/+ROF04nQ8sllvltfgndgcbDcbmqrGjQ5iZOxOKMQB3SqTN43IcX/kdDhSNxXF3KJUEp4avIT1brdbTieNi566qel78aXx3om5X3fK7vLwcHeHVorrp9cSwWEnCbvNLvSr1ZrCSmRM2zYYVUv8zGZD8CPt6gJbKUY3oCYlzvC24JR6UYyGgJ8mlBKVi/MTLoFK4hJeFCUuOFIQXoE452uauqI/nQhu4nTas1qvsUa+A2MMVVmKXDjJ6d8Hj5qz6/J8tcawatd/ptXgLz9u7+55+vSaoqgAxZQjnLTW6Lz5COcpLA7W4hEUSCmiFaTkUSnw4QdPcdPAi5evMUbR5nilmQc39IO4Jz+5wtgyb0LpreJmhpjmU37Kv8tsgqUtmOZdNqZsI/C2B4/cXOVKhXc2UZiLoHkDnYfKT7yoB+db5w30fYXTfB9QpHSG5OTfKR5//1gcKEgqk5CToAkhnK1POtulVKxWwJV8LimjICEEUYceT5xOJ+7v77m5ucmu5ntOJ/n5I3eqw/uQI1vGt5AlQZ9ipi2EXOSaZW8RFEuh9YyWyef37ni7sFKQ45AeC6PfbIeeF2fn340oCFnuu3xm/Ob3OP/uET3kLVQvvVX8zLeV+/ocJyXtaEnQCEminoxSuYU3v9Y8L4z4kNmyIqKwZY3W4p93P46wP1C8ueUUYde2TAm+PHZ8vT9wCgFK8Y+LuQOTpFcoQEme/DEl8c3Tipi+50iTUixGY3PLbF6sS1tIfz7JBxiSZNygIHhJVJ4Tq8dxZA7PLDIqUVWVkLiz51PTNL8ROzCOI9MwUtiCFAKK9NjKSRFtLYW2wnshMnQnCiOkXF1YpvsRsmOy1RKiO5/2fA7iFD+cannO3W5LUVVMw8DhcGB3sePu9oakFav1ihgixxvxOVmtVrktpbJSTKwUxNcoUFcVfhgoy4Lj4SjxDFWZfZF8XgiFS3A8HklResUyEWMuQOUUYQuLj0HUatZIq8g5gs/ESWPohwGjLX7yUvyVlqasSe2G4CNVYTFtweFhT9f1WK2pyxJrLOvVSiz6g2wu1mgGJ6afxggK4jWQQx99cAupVFAgac1WZU1pSw43d1hl8+lPitzueEKlSFOXjMNEXbUUthBPma5nGHrGyeH9yKk75sT0gsKWghhGeT5bFJKFqDQheLruxDhOOZ5hYHd5ic1xNLuLCzHYU5rh1JFiom1X1HWTFZrgJgkJ3W22bNdX+BiZnMcp4VBNwRGCy5thktfa9YCGhETUlBZVWlJMYnQ6TML9Ky3H4SRFvtLCydIFT58+pV6tJKfv4Z66qlg1G1KSwrXvu3wgUaSYKKy0LYtKENXJS+6fendn/E94/Lt/9+/5V//df4ctKoIPhAimsBRliQ6GcerpugGtE2VhKcoin4JFvVtUNVoZ6rrgb/7mJxij+PKLX+c5WTP0p7zhKg7HAzd3dzz/8EPa1Vp0azG9vS2mb/hvWPa9NMNIb6FUjxt2euf+C4p19ngpqeVnv4lyZF+m32Ovks35HCeDxehrfr1vbbzve5B3XsBcJOb/e6wd8msmoYymbRuqquTycseHHz7PB3C/tPj6vhcF3/09Dw+CQN3d3fHFF19wOBxwzmWyeZ/vJ6bAzo0Z+TJLoSMFpHzo7+M4yQH+saiWkN+4oFSz67bW54gQpCwumBWDkIhRCrS5EHpE59J7nv+xACJnxc3F21w4yZ+4fNdKkbsPkaGT1Ioq54UG7wQJNZaqLFBzZmuacxgTyQslpqpbbLsSwDAE0JqA4s04cby944vRsWobIpq9c9xOE84WGF3gYhCCv6ilJNbKCkUkEuWwHYOYPf8JSpzvVNFUlKW0QsYRcsGiMLmNZCDO1X/+knKrzgfPOEmxIxC4kHLnfn6ygkC5LFcUye2j4du5YqLMm3pSClNIG0pCgue+vGW/30vro6mWx+0e7hn6jrKUNo3P7rnzCfQRVamW9yjE9MA6ExBloa3YXuy4vbtjGEdWbcsnn/yIZ8+lvZiy26rPES9aqfyY8tnsHx6yakRlN3OH0glj5YjhXMCYia7vhdcVxOVYCks54YYYMIXleDzRNg3jMGIL8RHquk5eZ1GwXm0IQbK5CiN8BaVg6kdRkLmIwlKUFc+3O4rCYEzKYaBHrp5eY+xrAEJw+QIXFeJqtUJrzddff82TTHCdHYWFS2Ipi4rddpd5VdlvKgT6cSAEOc1BpDse2O+PrFYbyqKCpAg+oJTm+QcfAOLv1fWdZFOZgrKs83cSc16Wk9PmGKjqmvV6Izw1a8W3aZposms5SXLkUpKQ5VnePI4jh8MDpChFWYh0fQ9aM41iM6BV5NQfCUHyy6wRFHK33WKLkhAjxiiMEa8ctCUQ8T4JF1BbHvIcGMaBNAga8uLFC/6r/+5fQQx8/fkvIUbh5WkJnAjBS6J9khbIqm0gz11jDMpLQXt7++0TL/+xjn/37/4Dz59/yI9+9Anb7QV13coc8LLJ6HziV0rl+KYoporRi9O7SigSGFitWnbbDWVZYMwjqpIShBDp+4HTSYjMLLyRXPK8gzRBpp2kx6IHNSMt8FhZzIjS2ZvK9zvv+f1GU+gbOj9v1wK/b/H8NorBssmfFz9/zEjf8Bjyj7l7IOu6PPe5VcB58XQ8Hpc4mLu7W45HUe4dDgcOhz2HgyBSXdctbfm5ree9XxR8M69rbtsu7dskBQXIZzivYcKzmlt86TcsIua/zw033+WNPSoW41vvdd7fZnL3uy3Xx7bc/PPHSZEymjTbl0jiLuik0Rjh1gZPUAmd0WwFpAg+RhH6FIaIYcjpE5gS1jtiNXH0AZcUpp8YnWOKAa8EVVdK4QGfeyNLWza7g881tJrf3x87fX7L+E4VTW5yuIzUaKUhf3kp5QTmIBtA8GKkJdb7BqsMQ0gEFyBCaUsIsmDMJ3GxIIhLJIpSIv8WrxqNSHEDMUS8m8RIi/ToIlsUxCgkcGvNUnChckjnNDInAMzZZgI/p/waAj4EvJ8WaHtW3B0PB4GgtSSf73Ybthc7CSmc/aE2W7xzdCcJ4S1MQSJyOJ7kPZcVznn2+yNt29C2NdMkhWNVlAJzeydeP1XN9oIll0/iYYSvk2IkefnMTqeOwpbc396z2W5omxatesmE6zp2OyNFhY8kqxf0JoSESprNakOz2lE0K5SGqTuiCIxh5P7ujk9+9AnbzZabm1uqql6k3Lawi6XAOI64GNleXrJarzns99ze3qFQ7HYa70MOs52/E3k9hSlomhLnPQRBxozSjIOQ5efsOoGgFZvtjna9pizvKW3Fql3jQ2SaHFVVUjQlPpNW5wK9qSqGaeJ0PIoxZUbARP0kczY6B0kM+rbbDcZoxv4EKGKCU9cRgiycVVPgos+nWDn6LdyZ7IKvjaFUReYYJEQMKsTjaZxIEYqiWopDrTRUUDcN06nj4eGOrjtRWIOZZMJ6J6iWmb1wgidFRSCBn8nEijqrFr8v4+9/9jOqquKf//N/wY9//BOePntGWZRAQCmR3ksU0XwA8zgnMRyFtdK+yJVMStIW32w23N7eCo8w841SSiKlzwjVvMmq96As81DvFBuP/363ClK/9X7vG+dF0/uBxW96kHe3sPfd7gzx+r0Lpvdtjb/9zueu3oLSRM4Lp7kL0bYtFxcXi7XETN3ouo7j8Sj5lPv9Endzd3eX0akHuq5b2nzDMCyUiXMSPrBwn94mxdtMQ0lvzQPgreJrMezMr3vmdZ3bKLyNbL0tApiLxPl274uIOX9d8+1nAIEkX5Y2BqUNKgScd0whoGPE2rTQS0iy14lCPeCVqIVJYKzFVjW2rnNnKNF7TxdHQg5jlrJNBE9pfs9aS2iv0fJnLh7znAi/D+T5B47vVNE0T5Y5FFMhNvhz9IgQhRUC9GmMLSjqksq0kj1kDC4TN5f+cEKya5SibloOhyNl6Uhpyq0gw3q9Fr+ayS2QeIhhOVkURcHl5eUZGlXkk8pwVumHJevNOY8iYXXB7LpbWrlIhWs0EANYY7JiD5x3vHr9mqurK4rS8vrNa6psi9D1PdZ2wjvJEl054UgitDUFjTEYUyxS25TSWzEQolhJ+DBRekF1+qGXU3J+TOd89lKRk1Pf9VzsLkjA2A9UZUWbH09OxNJO6k491q4oS0MMieDFg0ZrUXT1wwREpqFDq4gpFNvthq+/+opVu15QMR8E1Xl4uOfly5f81V99xma9lsDh/R6S+Cg1TUtVN9RNw+3r16SUWK1WSwL7NE0Ebylsy+3tG55cXtM24v/lJs/Dfo9WYizp3EQ3jLSqkf0lKgpbUtUN2vmF65DGxMPhgXbVEGKg3/ccERVVCGEJM9ZaPyqjvCd4yaKbW7w+E/rrqsbokr6Xk1hZVQQECTXWCg8gBCbnqJRhHD0oKZ7mLCexDPAUpmK1WhOC5DY9uXqSX3egKHTesAvu37ziYf8g3mdlCfm0VtYlpERlSwyG4+lISJGyrlHWMnlHCKLw3G23f7b14C89Xr58yel45P7hgVevXvPpp5/x4x//Nc+fP0UpOQTJWvPYPrH2cclNKWU+yKyQtVxcXPDVV18tOW6yocp1LOIU2QTmhs8fPn6/e808qN916/ejTr/vKzuDxt65729Qcn5j/K6K6v13nuuHd5Vh4h31yGl9txB5RKWk/VjXNZvNhidPniyk8b7vl2Kq67qlWzCTzGcEai6muq5bfnZOBQGW+Cp5bb+pRJzfw7sF1VxwnYue3n0/54+xfFrv+SLn277bUpxvG2NEKyNoUr6dwGFSSEmGoEfI88XCPR2G8fEwMMhBsygTeglezt5gxlLUDTYl6f4YDZnna0x+b7lbNO//82ewtC3/BGyB71TR9KhEeIQIJaIgw4w5+FQbI8VQnIl5j4ova+2ySXnvKcqS9XbL0PdURbl4L3nvl8c9n1BS4Vs4a61YaygKm2FYabvIRm/O8oxMVutl9Z4tHl9LQnxZjCEhJ5nBTXR9x2qzYXN9jXq4J758yXqzwflJ5OvrFVVdMQ6TnGCVSJwBnPNoLdydpmpYtStCFKfwMrc5t9udZOGFCYJI1FPwHA576nr1iJblEfJpq6hqSi18mLqqefbhh/TZbK6qKpq65hSz4rCW1pjNMLQgecI7Oh16nIuURcl6vaIsLD6MBOfZbNYc9gf2zrPdXGELS0w+b+TZwMxoUdGFkInlmrIoKcqKpmlRxtDkOBKTT19ayUVmC0tZVUzjKC7wk2xqKhdoKUnRWpYtMUrhpzKfx3lP1/VSUHW9+Cy5SN91GJMjGGLEpdk2Qi2nOJ3NJVNmhRpjKHJxfDyeSElsBlTOVvIhoE0hShESTVUxuiCLB5LPF+ZWmpJwaElvl3gh5zwkIxu4tcSUsEXBOE0zn5aYEu16hbbCuVJGUeUFLqYg7ykljLYYhI8V8wZirSVkzoK1VpCr78k4nY7c3t4wuYm723u++uoFXXckhL/l+vqKqragIiFMOX+RZfMVBMEDGq2EI2KMZbfbLYaiwFL8juOUuZhhqSjeje397o8/3zs530uAtwqR9/1+/t1vokGWtm2Xa/yRHC5IzzAMi/XB3Op7eHhYUKk5b28uss6L5flwe14Unbf35uLnnHcrfLK3KSVzC/13IVznLb93bSXOvaeA3B2JFIXB2mKJmwFYrzfUVbXsN+M40HU9VRWXx2qalVBdbCGpCCG89d611lht0IVe5nlUwMz7sjy2p89e73mxGNOfxo39O1U0+fD4xdiiwCiyjb5Mgkf1gn389zBgghe4OwTUWRXuctuhWa1kwy9LyqKgqmuKLMOdW3YxBGbbguDFxLKua6wVspsxmnGcC4w5p0gg1mnKzs1ZVWaMEe+fGHFuyl+4EZ4Lsllaa/EnicfQ3YnROXa7nRQ8WJ48ecJqLUqlokhZuRaZnLwGa4WXNVfdwzDQ5Yv2k08+YRz7xazSREXIeHhZFAyjXy7OyU2C6Kf5ItSUVYUtSo6393THjs2TS7z37G9vSSlRlRXBe5GfKs04TnhfY6xc0NIaElKmzQQ+5x0JzTD2xOh4+vSKuiiF6GeLjNhEqrJgt9vhfWC9XnM4HAFBhaqqWvLmuhBp2lY29XxBFoVIZLVSbLdr1rvtkh2WSFlq7KWNq6QtKaR8mWdz6nkIEjg5D2stLjjqqqIwYrJW2mJZwOZE79m+wjnxuZoXXm0MGsSuQucCPCt1XJC2WIgBVCSW8p1YpCAfxpGyqKmamhjE5ytGT4xy2iuKEq0th6M4yhe5PWjLErHATGhradqWchYgKCkeU/C4IATZOTIhItEtMQhSRl6Mq0o8otz07Ut8/7GOkNeVu7tbmf/7Pff3t7x8+YL/+v/6f+EnP/ksI6QTMQbhOWVid0oinFBo4Z6lgFaGzWYrYoQwG/KaZSN2zr2DkMzEpG+/2PjzlS9/7DP9kTibUst1Of9bvedwDG8XUPNtZ6TwvID6bQhOCIGLiws+/PDDpbCY0aa5gJpbe/f39xwOBzHlfXjg4eFBgs2dW/hRc5twVum9+5yz6vu8aJp5P3N78N0i632f0btF5Pl7Fg8tRYpZSW0tXSc2LE3T8OGHH/Lxxx9DghcvX/D5559zc3PD6XRa8hovLy+5vr4W65z7e25vb5eiad57QEjis6ozBlHLGSM85ne/p+W/zzl5f4LxnSqajJa+qFZCStZINW5yH9fkynSGvFNiIYbZLM221jLmrLM6oxBD1zH0PW69Jqa0BN9aa5mcozudqMuKzWrFmAszIZLXgFn8i0BCL1X2cDon+83uykVRMvtYOC8Xgy0KDGIkGIIUae16w2q9kmLn668xRcF6s6XrOna7DU3dSLCtd6JmMmVGzxxlVeTiR9K5XZTnGYYBBdKqOu6Z3Ehbi6Td5MJFK8t2s2Vy/pGDozQsfJbI1A/ESQiO9/f3KJMRqUxinKYx974Dp67j5uYNdVNy/eySixhIEabRcXF5wbrd0fU9h8N+yairVyXNasU49GyfXDPsB8ZpwkdPhZxO2rbFGJGqNk3D7MKcYqLve2afq/PTX103BDcuSOBMcO+6jqGf2Gx2xBjph16oJkoxjCPjOCD+LxLIOk0DIFl/bVsgRTILRD0/b991koWU51Sa+Uz59czo6DRNcnKLScjw1uaYnUBVi12CSpq+PzI6jQ/i9ROSqER0JrUObm7fSf6XUlDVNWVRLYtx27bS0jMG52RRL7OXUEQxTkNGSFPm2InNg1GaaAoKFUFpjAWfIjEGMTTNCp8/BYfgH+vQWrNer1EKadV3A69eveTVq5cYq9AGPv7kQ4rC4oMnBo/zPs+T+VFkrfJ5bqzWK0xRLHJqbTQG2ehCDEtL7z8ddOkfx3i3aDpvxc3IzCM5m7eQl/M/72tznUe2nPOInHMLwXwukvb7/VI03d/fv8WLmgnps9DnvIB6n7HmeXtxfv3nvKV5bXy36JsLl/Mxr1tSoGSFeng8SHsviHdV13z88Sf85//Ff05RFPzq81/Jc2stB0WgbRqePn3Gp59+itaa/X7Pq9evOOz3ch0NwxLdY4sir0NObFvSY97hck3Ex+KQ9Pb3FL7v7bnCWmkhxUAKQYyu8mTxXk7D2gj51ztRPxmbK9Izgl/f9xS5ZzxX8FVVUa/XHDsxG0wx4vKXHMPjhNRKJoxAp2R1XqDNiEGIonwKMRAzGTciLcNhmChLURvM6VFzEniInkTMqgYpoGyRw4aVot2sMcZy8+a1ZI/pRx8Qay11VRFDpE/iv1MWBW4awVjKoswngoLgPNPkSEHk4bNvCSRBNLynXZUcs4GjNYIWzcb4KSZccOhasV6vCd7jBmn1lGWVX7twoowR9Ye1FluWlHVDimLg6IKj1op23VJUFdZoxrEHC822xWQZf4qRU38iRYWtxKJBCo6c4eYcq/UKN7rFhK4oH7li8+fTNE1WC45ScITAw90tKHGDv7vbs1pvQSkm78SRWwl3YZ438+IwS3C1VrRty+Gwl2I7RqZxXNpt4ziK260xOC9IZ1lV1FWFqqql7z9bQiglB4MYhN9kbeYoaYNWhtMxYqyh1rUYJBZiPhpJGdoP1FVNXZeMYy9hxNMguXF1xTiNrDZrwl7y8pQRAndCWnQmL+oxRYHEFaAVxMfbhBTz4m9QOb8sBi/3z+qW78uwxrBarQAYxonj8SQZad7B//A/cOpP/Df/9/+aH//4s9zyV7hpIkaf52mJUpYU5LNVRlM3LbYoRLeUkrRhlTh2x/i2XPyH8YeP8zbX+xCmb0KQ5k14LlTOx29DZc5/9j7S9mq1omkaLi4uFkTpvCA7nU7s93tub2/faunNBPTzHL7FauXsNZ2/npk3Nf+9GJuete/eLaDebdnJ5yBr6xQnfE7KUMpQFJq6bmhXKzabLev1ihQT4zRxcXklB8n8One7HevNhrZtefb8OZ/+1V+x3+95+eIFv/r1rzmdTigle0xMMZs1D4RsJaCUkoi0IeBd7owUj/xmk70cY/j2r5PvVNGUvMc76e2rHgnjS2kh0xolRlvTODJOjqqqKXQhlSti1FjkEEyQAmj2bEIp0uLILDycmBLb3U74Hkl4Qn0/iF+NEnfYGDJ07sVd23uPalv2hwMXVYXSmnF0tG2JDyPd/T2Tm1g1NetWNj+lVH5tBq2toO1GyNdlqYnzApk3rHEcubi4pBu6PJl1JpcLohaD9JKDD8va6pwj+kBZlLz86ivZC4lyijk8oHSiqRucD4SkOR1PslmbYuGHzReMUgprLL3vskJRionCFhSFKM5mhdt2u8WaQgqbScj1ITutG63F7DKvQSEElIYUPA8PHa/fvKFdbQBxOi9ym85NQjwexynzlQxTmpiyWWdZyEU1S4bbtqUsS4ZeiO3iZ+V5ONxTlCXtakVZSXagNprtdofLHKm6rnJhIM7N8+cZgsDsVSVopTxXJyaRfcc4OiDRbLe4yVEWBUFrIUqmuVhWWOQwMCsZy1KK+rpqqNs10wT7/YG6aeT3RUFUMIxdbr9JTpx3geAi6/Waum1wfsC5USwl0BSlRHP44IkpMgy92F0Yg8vXVFmWGFMw+ZEQPQol36mVyB6rNIUyC9diPnX64InOYWee3vdkKMjqN9lEFIqiFMT357/4OZMbiQRi9Pzkpz+mqmrhqC2b0uwMjlzbRlHXuZjSOv88K72UksI85UPYH6/F/2HwdtF0zlP6pmJn/vdccLxbcH3TYeHdguwc2Zkfe47tmu1wgIUvNU3TgkadG26K3cHjn7lwmgupubNwjkydt3fPX9f7ImPOUbPfbF3mDLozWoyxlpQi4zjyxRdfLFyvYRg4nU5y+C9LmqbhdDrhvOfu7g7nHFdXVzx99ozLy0usLTgcT+KHV1ieP3+OLazwwU4nORgrjQ+ew+HAw/39W4XiUuQq4ULpP8Hh4ju3wgUvbQ5TFdiqEufdrKpSSvgqktYtXg1aa0F7ovCEgvOUc/uk75eMqBQjU9cJMmCt+BwZw2qzQQFT1xFcQBtBO8rK5BBfUSF1XU9SSbykmprJO1kgQ6QfRowtRZZZWFTOcAOyXFNjCyH7ajVLxy0hSctm7Ae6vsstKYMxiqqs6DKvxntHN5ykBVlI7Il3brEu8M4zDCNKaa4unnDY73lyfcXkPJHsC6IVyhhMEiPLuq4pjc0mjrPDuPy3MpqyaRlfv6GuqyXrzlp5bQBlabm7u+fZ848xxpKi57Q/cDoeqOuG1apFK0Pfd2KU1vcMYw+jJ7keHz03Nzf8k3/ytxLEmDk6bvKMg2TcWZvkFJ7U8r1ZaxeIOcZHjtk49BmJehQNKCVcnG7oxVwySqvqydNr0IbTQfLWUErQMS+8pTnnSha0TiwmELl2UzcopbLflVkKyt1uJ1LcccR7h88ROlVVSfyMtYLqkRfMomS13mCnxKtXrymKkrppKIsSHyNjP4oH17okKLFMIMo8dtOEdz6TzEus1gTnWDctp+MB7wJl8ZjYTmLxASvzpu+dqD2tKRHvIU9pLaW2TG5aUFZjDAYIYbZT+P4gICGIghYlBg9NU1PVDcpo7h9u+fnPf8409aTkadqaTz/9JAfS6nziFwK9MQUpBYQUDmVRYUxBDHJt2GWDZymcxLf3+/NZf9vjvFA6/9n7/pzzn85Rp/ln7yNSn//+3ec9L8rezbKb/8wWBTMitdvtuLy8XG47cyWHYWAYBqZ8IP36669F1ZmdzmeEaiabnyNl50XGOYH8XJF3fptHNEov9AdQy7oaAuz3e/7tv/23/If/8B+W9XhWGs7G0c45TqcTr1+/pixLPvroIz755JPcDTCsVi1FYRd+1Hq9JgGn7sQwCDWi73tevXolHNH8nhYe1Nnnm/4Eh4vvVNHUNC1lYbDbDdV6RVSJ8XDA4ZbJZbRFG9l0ZtPL3XbLNI1UZck0TSItzz3hOdW6z6qyZ0+fYQqBTWcQ1lpLsJYU4frpM2KK9KeT8Dg0KKdR2rDbrSV3brvh8njAoEg+oLSmHwY22y1Xzz8A7/FjR3SjeC1NIk/WOTfOx8g4nHJl3qKNpRv6zNmJuKi4eXNLwKOMyqbY4iPlFZlnIr45wQeCD6zqlma1omlWXF5ccnlxyTCcaDdbPn7+jJTEa8p7h65acOL4Ol8Q1liMkQnvU8KWFdfXT4HEMPQLKX3/cEvTtHjvs7Q+sd8/EKPDFjmTzxpCEI8ppQtRz203bC+2GJOIceJw2PPkyROAhTyt1MwF0WIwmt/32I+o3CqbL0rvPW3bstls8F4KtjmQuTse0FazXm9Bw+s3bzC6oG3XjKPDn07YsuT+QQJUt7st9aqVdqfWGKWRdyYgQ1mWbFZrdustVS2Gpm2zku8qc5rmYGaVkTrnpXgiRVIh4oaQIsM0sVmv6YeR29s71rsnXGx3aFtQlHoGNSiz90ldVlidIEAoItF7Tg89KUa2K+G+TZNwy9brJwxjRwC2m634f8XAKmeebdZbnJsgigVGVcl7mcYJUkRZsTkIQQ4PKGlRnfMdnPN/vgXhLzxmg1xtLXXd0DQtMUTGSRDIECbevHnD//I//y8olfhv/pv/G//iX/yL3Fru8S5izGzSKy1P9HwAsTncNGDsI8IL/G61/Q/jG8c5EXzm7syo07uk6vfxleZD2bny7FxdNj/HN7X93odkvU+t9u7939cmM8ZQVZW0sDJF5enTp0sRNav19vv9wpE6b+vNBdd+v+dwOCwF1XnBJI7fglLPQpiyrEgR9vujmLWeoVIxSh7obBRdFAV1XXN/f794SMHbPlkvXrzgF7/4BdvtNvN/44JMHQ4H2cMvL9BGL2tSWZaculNGx83ymPP38K6Fw7c5vlNF0ziOpGSYvKPrO4ZRvpztdisfptYSXBvSMrmdd9hcSXfZrdVM0noT12yZFHNA7zBNTG5WTQlJd5aUjv1If+yWRe2wP5BSYLPbUDcNITq+evkVz8NzDLKZVlWdFTDCVznc3nE87FFE2rqA3LcWcrolIj4UXdfz/PlzTvHEOIllfN/3rDcbhr7n9ZsDdVNRNxWlKRb7A2kTZEl9YQVlShJnopXi+LDPEzuHzvo7vB8hRYyV28d4S0gq+z/JhuBNIEW5yNw04fqecZqyl5GQ5WMM9L1waPb7e0KQ4N0kQYCiXtOFSP5tIbEgYb7QBryfqOuCqpLW5Ha7JcaUYd6HhWRtbUHT1JxOJ9brtYQzxsA4inllCH4xp5uLluNJlCrXT65Zr9fsuwPBpWxlkPDjyPPnH7M/fMUXP/uFeE61Dav1ipDd1ZUSdVphDbYsFouLGCPH41GUeEPK7UwhYiv16MMzDAOFLSR41U/stluaXMgpYBhHiJHueCSi6E73HG+PbHYX4mk1jbx5/QZdROq2Eo+g2zs++OAjQAm6qKSoUxqi9xweDvTDgA9J8vAmR3foqCsxBh0nKeR0UnTHjhAcwzAtSODkJLxUo3AmYLTCliWmsLld4BZSqFIiJvi+jMeNTfpoSsGYW53NqqGqCrruyOef/4oQBcW7urriyZNrJC7kMQZjPgyIerIUkcOM/qUi+9CIT80P448f7yJM8/V7bur4bsFyzv35JpTp3cc8l/Wf/31+3/f9/Lx9Nz/+Odr1vmJrFhsVRcFms3krLHlW683cqLu7uwV9mlt+r1+/5vXr10srb7Y+mMUqRVEshVPbttR1Q4wwjlP2TDKLlYoPYTm4zoXLXMDN7UF9djCQw7Diiy++WA65FxcXbLfbxV19u9vyrH++HBbqus6UBL9QYuDtNupjcfr7z43fd3yniibvPRcXW1ExKWh9kwupRNd1KDQxCs+iLFj4TtYayMZl282GpMVJdO71nn/YVVXy5IPnGCUtieB93vgTV1eXFFo8Kbq+y0hQBRoeHh7wYWRyjpvbW5ybiElx7Drc5NhtLnj9+g113UC2lUdp6rpge3khi+TkCSkxDANPrq+pqppxdDSt+DEdTkd2W1GbffZP/gnBT6gYmMaRYZho6hqrxdXae0dQ0DQ1bd1kPo5i9+RKpNLOU2jJL1cJqqalaVv2Dw9YU1AUNdv1hnK1IqHpjwfC5NiuNzlbT1xgm6ZhGMWfaSZ+N03D119/yScf/4hpHHDBid9PUDRlRSIxTCN9N2F1SVVWS9tMG5HXD8PI5dWOm9sbgovLCcN78SDqczr59fX10lKS789QlgUpRU6HI36Uovezzz5lmibq5nHONG1DIrHZbHAucjgeuX76jA+efyQ+UqcjTSNE8L7vSClSN6W0wpJ457hRWm9DN2K1Zd2u8T5wc/sGiDx9+pT9/l5OhG1L26wyiV/4U/JdeVZtKyrMouXV69dcbC+oyxX7h46mbjmejtjK8vEnH3Pq9lS1ZdOuGfoBYmLoBrQylIXFatmQffCY0nJ1ccHoHG/evOYnP/kp//u//Xes1y1tW7N/OLLdbNBG8+bmZoHRyZtJYQpMLQWbcw5dFhSZ2H86nYgpUZYlu13L1dU1tij+3MvCX2zMOYLDODKMI10/yJHFGBEcWJN5ZI6bmxv+9b/+n3HO89/+t/8P/vZv/xajLSEkgpu95uZYGmlrxJAk+Dqrg+bMQiXyIeARdPqhlPrDRlmWv4EYzZY17yto5vVltg95lzQ9j/chG+8jlp8/x7t/3kWt/pC/57Xt/HmstWy3W9q25fr6ekGlzg98MwL18PDAzc0Nb9684fb2ln1WtLkcTu+c43g8YYyADRcXq4WLNT/uPM7Rrr7vmZ3K5+ediygz29tMEzc3N9ze3lJVVc7uFK5lu5Z1s65qnlxdYQrL7e0ttze3EvSehV3ne/m5B9W3Ob5TRVNT1wTv6fJp1hYyGfpOkJS6rjFG8tTGcRIJZFvTNjWHw0FaCYVAhNOZUiHGSNs0DOPI7cM9h+ORyTlZ+LSRSIym5WF4YBhGNpsNSsPxeOT+4Z6ysIglkUGhISm6Yw9Js1lv0RiGYeDi4oKmqmVTjx5UEnm39+zv90zjxHq3ZbPZ0PVDRlAitSqYxpGbl69ZrVq8c3z9q88Zhl7g0hwyG53LXlLZZ8jJCUMjTuiTczStOGc3ZUWIYldwd3dL1Vc4N/DyxSuuLq/pGMSDabfDWsupO2GNoapks3z5+iUXl5fEFCgy+XfKbut1XfPs2XPKuqJ/OGL03KfW2fwzMS1+PkKiXuDdEHIUjWTHhXBEacvV5TX7w4EEbDYbHh7uBJYOkdWq5Xg80J16jDWEUCwnrznJ/fb2lnEcsEY2/BQTd7e37C4v6E4dYChtYOh6UpDCrK1r6aWfTigldgDDIPYDhbFLrM7d3b1E11Ql//E//kfW6xVVXdJ1jyTIu7s7ClvwcHgQgqKWWA1po9b008A4DNhMeNfa5HYiTMOR0+mBN796w1//+EdYYzjtj/jgsKaAwGIq6b0nIJE4KYkE+DR0HA4nlNY87B/YXV7w8PCAsYb1usH5iZsXtzSrmuPxkDduLShc8PRDL35eynA8dthC02abijJHuAz9wFdffsXPf/mLP++i8BceYiyaF+m8PksbXf5ba4NOEuXzy19+jnMeayRc/Kc//RtpMxwGZpRp5sCN44TORqkiwlCYrBr6jY3gT+xL830b736+5+2n9yFLv+v+v+33v8/jvfu7d8nr7/57iWp6B7WSzkf1G8XezM+cSdszuXx2Np9/fnNzw5dffsmvf/0F9/cPtO2atu2Ei2RsjnMyS35qeYZOTdNEiIFpckuxBG8jaz6rll3uJKUkfF2QkGDZWyru7+4pq5JpnBiHccm2C86Tzcgz+Dtrvr/d8Z0qmqqmoigLur4TqDpGDscOW5QM/cipf5DE+ZkI7APDOHB/d4sPjpTky6mqSlCofLLYHw5cXl5RlhVt3dI2Db4MxHZN27YUxmKKkuAD9uGBOqsAttsLjt2JFB27qwuwBdP+gDaaJ1fPxNrAaDabLafDEWM0hTV0xyMuZ8zFmP17lMaWFW50RB8hRqxWlHUNKXJ82DN2Hfc3b8SA0U/UpaWpSlHC5Yw3ayVhWgrDKZO8oawLyrqgrioppKzGT7BatyhznYN9S7bbCy4+/ADtI4W5xwdHVIl12wAJ7wYKq/n0k48xZUE/jCQSXX8iAXVV41zg1A1UtRhOrjctxoLzI9pKjMg0ekgWrUTNGHKhInJ+nVuiWmwaCktZ1Uy3d+wPD0ASHyEf6LoHVtsdzXorwrRswjlHE1hjqOqavo9cPrnOeUlw6juON695uH9ARyVhxFEUlKSczB0c2hjWq0Y+y+Bo2iYTNUecD1RNzWqz4f7+DmUU24sNF7udyMZTQmtoVi1GG8ZppFmL07rKJocuL1YpJZSxjJOjrCsmP8EIVS32CbtNQ0xr2kZczAsrRqxaGawpsneYeJkoDTobq9pS5OvVuhZvFxWJKlIYSARGNwhi2NYYq7F50UkqkVQkJE8kUNiKzWYrthZ5QbPWkkgE70gJqqKgrau/5BLxZx3DWfyF1oaqsksbZVHwJGnXpgjHw4mf/+wXkKRle3X1lKdPn4vvlREX+nEc88m8E7Vo4ZdNsCgsi2GtSuJBl36ol/6Qcc5pene86030u8Yfypn5fW//7ut4XxH3Tf+eOVrvtql+25hbfOv1mvV6zfPnz99q/c2dnM8//5y/+7u/43A48vXXLxnGibv7+9xSLhZ1XNu21E1DU9cURcFqs2aTX9OMWM1gxSxAmaYph4MDWmVAQxTvMURimAhagu6HflgCl40xWVyhFh7VWwT2P8HF8Z0qmm7vbkWGmRfpuR+6u7hkyoaV8OgMHoJHa0XTrJB5KBYBVc70mU8PIQTazY7ueGSzWrO5upKNc5LHnzffpm5obCFy4BAorWXTtKIAGzwBL4HPhXByrNZEH/DjhNEK70YmlXBuQCjMWja9ql4KueTFk6Kuy2xtIO+9rWueP7tm1a5whWNyktOjtexydVHmCSiTrSgKCmOytYCoCKw2lEXJYAYmN2Ksloo+RrSVk8l6tWJ/e0uh5L5GK4xWVHWFsRJF46Zp8V5KKXHqTrjgKaz4QSXARyjKChM8JEFwRtdLJIdSeCffwyyrneNolJaLXC4kJ3l304k3Nzf0fZ89s2b/pIaHhwdu7+4IbkKRMuwOSslRf5ykOP36xdf86Ec/EuuBaSTEwCarMp5+8DH3r1+jYqQqCpSS4ssFBylgbUnCEKJjjh5BiWGb0poQI6e+o6hLJjehjaFtGk7Hg8zRpCnbktN9x8XFBSklDscDxhpMlpMX1hJ9yGhgk6MHjlRVxRzauWoqHu7uePXqFZv1mvVmzeQdRRkzsmYXJVdMEZ8S3fGEjz77TEHMvmFlXeGdEy8pY9BWduCYktBtgKiEhFwmCX+eppEUFePoREFHYrNaSbu065nG4Xu1gS9IdT7aaiPeZPLHQPZAC0F8bLS2OBd48eIl/+bf/H/QuuC//C//z/zNT/8ZZVNxPBz5+5/9jFevXgvvo6qkaI2BWSsnGVzy2GkxI/k+feo/jG97vGsvMP9sLmxmx/8PP/yQf/kv/yW73QV/+7f/jL//2c+5v7/PVIaOw+HAzc3NwlsqyoKqFGPPIrf1m7alqSWf9NycMyVJIajqiipVy5RWiE/irJqPMQq6G4W7KlxC9VaerBTFfjmIfNvjO1U0KaVAC1KkjRF0IpsHqqoSVCiTy2YFWWFtjq+QxSulRJFtA2YDOe+9HAVTZHJi664Bg0Kl7DCtNU1do1AYazkdj4tyaBgH3DiiM9fHeccwDoJYeYlZaNomx5I4yqqS/K5MADXGZpNDh4qP5mMoxeF4xGhBXow1YivPY1r2OIzi8t20BGPwfViKSZHqy/uUIF+DC56YxJW7rFrGYRJ+jvP0w0ihC8LkSLrAGiFsxxTpuhNFRi26/kQIgcur6+Vk0xQNIOTxqqooqwpTFHTdidKKpYLS4js9jtk/SFmcS3Rdv/SjjRb5+n6/hyScrBQ0yhZcPnkCURCpOeKkbYWv5aaRuq6oyuzMnu0mQowUTUuIUeThkPPcpMV2c3/HBx8ZDg97ri+uqcoim2aOxBwnMgwDyojkuzudFm8qYwzBBU5RlJSr1Yo3L17x5s0bTscjx+MDVd1gkllOVzM6MTlHqZCFIEWaqkFZIWuTzVBNUdCs1vT7/RLb451nu92xubigtJbueHwkiaIWE1cQC6H53/NcTykJCT8E/JmRnvee1WrFFCastpn3JuKAqq7w3tMPA0SwVtO2jSxMSuHHCTdNlGVFab8/nCaVW8ApJUKcs0pVzpDMuYDBo42SE/hK1HUxJb7++iV/93f/KxcX1/yzf/pfoNC8fPma/+l/+tf87Ge/oOtHZqJ4yg1AlQ8WxhiIv+PF/TC+k+Ndi4I/13i3lXfOe5q7M5vNhh//+Mc8ffqMzz77a378k59wd3dH1/VZnffAfn/P8dgtCj7vheYyDD2QKA6H7G33Nt9I1mub479mewMxfjZa0j5k/XpEqGaqx/w6bVaZzrfROocHf8vjO1U0XVxcSMFhpMVjrPBK+l5IqlVdo+eKWZe5aBAvlVnZVVUVTksYYIK8gQtKtd5uOZ5O9Pf3GKUlj6ssF3VV8J4UBMUZpkmy55RdcoCqVb0YSwrrf4cpCo7HI1pbiqLMMmVAs1TMgg7lU2OIC9QYjeTRBa2wzi2hjtpKy0lbQyRJcadUDuw1ufVjAE2MfiGXyulXTsartkUZzWa7w7mJfugJzqFV5MnTa1Ik+8hIO2YYe5wvsjcHoKQgmYONN5sN4+S4ubtHWwknHoaRaRxZry7RJhGTRSmBelOc6PsBks8nEyXIi4HoBOooq4rruiVMsLq4hOiJ04TzTjZwxGRys13jXLVcHtIzjwIZG0O73fDxxx+LwWQ+OTVtSz+IGdz9y5d4N2WDNoWbFWPWUBiLdxM66eWxldIUpkBpcSaPkFUlNU1T40NkcBNos7hspyhZeFppohYTylXbCFmy62TeKvE9CiESQ6BdrVjvdkzZHVcg9A3Xz5+jypKx7wleWogpJYZxxIcsozYGU1jKWvyzUko8PDwsn7WbpFg+J8SWZUmYRP0y+z0pFGVbYo3FKUcKUBSW7WYjbaSuywuizIFhHP8sa8E/hiG+WRVKK5wP9N0gggYtBycplAOrdcvu8kII9sDx2DEOI4fjAec9ZVXTDT3/x8/+nv/X//g/cpcR9aoqJN+P7BOkHwNXY/qhavpjx7dVmPy5Cpw/5nl+3/ucE8/nNeacqD4Tu+f1YbYduLq65J/9Z//ZQugehnHhQ93f33Nzc8vr16/ynxvu7u7Y7yWc+O7u7gwl0szZeDMXajZ51lnQkhZxV0FRFpRlQYgSAyUqPU+InhAVRkvGXiIS4qL/+lbHd6poOp1OzCGv534S3stGMZ+cZ5dvY3ROpZ+wVudq1BKdx6fHxOiZqFat1+xCpKlrkY0HUbUYY7DZz2k4HVEomhips9nkyVrGaUQhbtM6K/NcSqxWKz7KPCufzfCGYaDO/V5ypSwhviUqR3+4rFQoa4E3Y/YVWeU0eq01LgSJWrHiIOzGEe98jsfQeQF/zAQLwWNLK1l3tmAcOmzTELyEEa/WG0CCRFMUd/F5QrdmRdM2UnRmawYSTNOIz8WTtZZdVmkcDifxfMob+t3dLZMb2G7XC6dnmibKQnrpZVkiF1EEIhcXF1xeXuLGkbvuyOnmDfv9HqUV24sddb5437w5Lpv10PeCFk4jRSHBvlPXkYK4pRdFge97hnGkqEqePnsmfB9bcXl1JQtAdDTZfDPGAFrRTz0qJWxuo1VlKVls3lNVJSFGjt2J2zc3xBB59uGHXH38EeP+gddffAExLTElZVmCB+cnTl2isJb1agMxcnd3R9uuKMuCcRw43d8Tu55Xr19zdXXFzc0Nm43YO2itl2iT7cWF5O9pLRE5ZG+rcWIYAnVdiSXGwr/J0nYriKv3nqYRJPTRNRiMtQunQIpyTUiBoe/pTqK4Wa1W7Ha7ha9w2B/+UsvDn30cTz11mbKx7iz91tIAzetLVVc8ubrmo48+5vLiAqUV+/2BvpM14OrqCdZavvzqK/7+7/+eFy9eoBTi7F5XFGFCKbnuZ58x5xyaP4x/88P4YXzTeJdIPqNLM93lPOBea53XCGhaMWoVU1xZj7ebNVdXV3z00YecTp9xOp445YOVeEIdeHh4oMvu5g/7PYfDXmx3RuHgDv1EiLOqUaoea0ToVcdKuh8xLvYIs+Jxtl4AiSUri+oP5qn9PuM7VTQFL6GwReYkzd5KoLBWCMXLF640Vhl0qTF2loea3PaCc88IpcQYcf/6tRhrFZbT8UR36imyiaDWSrLuVGKz3ZCiuFiPU0AXGov41vicVxdi5HB7y5g3OJFo1ux2l0LOVjIZp4x0zUVeXVdL2K+2hrv7e8qy5Pr6mjqJtHPMEn83jkxhQjdQGIO8NUVpS1GexSDFk5GstWGcwDlCEFuEshIvmMIY6rxpjv3A5CaGfqTKhmZ126AUODdxc3sixCg+G0VBu2qYnKiDbFFQluKhcXt7y8XFBff391Lua5b8trKU28WgCV7iaYIPGA2oiPcjk5u4efWK0/EIqeTyyVraGzEQvON03FOVFpUiuqqoUiQFl32cdgzDwO3tDcZoNts19/d3TNMk6FNGy5xzlGXJNAinqh86rC5o6poUPIfjnnKOUUEuTI2CmJjcwDiOWFMw5iKtbRrc2KOSJ+7v6e8fpNBdreiGHrRGebdw7pxzHJzDGsOmXbFei/AABU0ucqKCq6fXfHD9jFf6FZc5w2mJMgmB4BwDUuwqJZ9zArEcSFJcdaeTRKEYI7En3nOfjePquib4kL8ftRTxxphs/9ALchnAT04KvfWGpqypGrGa6LNXVvoe9Y2MtZiigOCxRhyMpS035pO0kFTruqau6hzWDW2zom3WXF8/5fnzDwD45S9+yS9/+UuUktbner3GGIXxoA1s1htW7WrZpKyxPziC/zD+QWMulM5DiJeM1TPDz/l256TtEGPOnNOZ1ydhurYo2JYlu91WrDHO5mhMkWmcOJ6OC+L06tUrXr96zc3NDQ8PD1mxd6LvB6ZJLHyC9znBgVy8SdtvGMYciZbyAV8RQlpuQxJ1+rc9vlNFk7UWFzyp6xbil6hKBNYLIYh6CPEeWqrP0jCN02J2uFmtF1SqruvFGfxwOHB5dQVIa6qqK9pGUKjudMr+OsIFmYuzub1B9mepq4pTL5lr88lwtg4IIdA0zcKZmaviqqryZi/FWVmWC3JzcXGxxLocDgc2m83iAhtiRHmoiswHMpqUZsm6ztD+o/2/LQqqWmSg49jLpujlvfkQKCtJlL5Y72gbCSI9HI90fUdhRe5ZVRW2kMc/nMQfo64rCUkOEaXFMmC92fD02TOeXF7gvZyWUZLDNU2OvhspbENR1GgtGUFKSUFblCV1U9NWtRhIRkthNKfoCSlgMQvKWFYVeE+hFC4jkCq3C601HA7idjs702qt8UG8rU59R9d1XGwvWK/WuMnjx8A0TiiVqMoKoyXPUGn5bgorWXcuJFRRURclY9djS3jYP+BDoD+dON3vceNIu5a5tmpX9H0PMaIzR0B8tcRyQgog4V05P1HksGaQbDpx8T2x2WzFt2Q2XM2t5RAjwft8OJDcpRgkL9BqQ9Sa9XotUHo38Pz5c+pacvOMMbx69YrVZrWYlarcKnTeLQaY7arFW8fxeMzOvYH+vmeaZC5vdzvK8vujnvvp3/wNq2ZNDCJ1NkrTdSfuH+4yl0OiLu7v75cWfoyRoih59uw5V5dPuLy4IsbEz372c77++iXPnj2naWqKwjIMJ6ZppG4qmrZls92KWa4yP7iC/zC+tTEXSPNh7F2n9Pf5UXkvfL2kADTaqJxVaiXTUinh4qVH3qaOEqRbNSW7ix3PP3jOj3/8Y8ZpZBgH+q7nmBHs+/t7Dg8Huv7E/e0dtzc3TOOEm7KxtXP4SZIKyrKkqeuFKzWOo3RqxgnvvuecJmMVVkk0SeLRbl5r4bZYrdEoaTdlozjvPcGFxZp9JqgdD4e3ZKfLhMmZPs5P2biuwzmfJZWWYQxEpM1htJFQzhQpywpNwSEjS1KhA+gcP+KFUJ4J2taKSV1MAQUYKzyF06mnrZslVDglIQVXTUPVixR5Vi6N07iQ0YGlDTeOI2Wmj858Gu/9UvN3XYfW5MDZAlPIRC+rCpc9rgSFkfestfCfUAqXkbSyKs/afkLQmybHME65RanouiMaJR5PNkEKZ3L1MkvxS9x4wLkJraGs7ULmTiFSGAu6wGakLqlIWRSsVg1KZb8jJwq5FCPjOJBSZLNZL2RlIJPGczu16ximAVsKzBt84H64py4k+yjFRIxeCIUBabfEhMpqxBQDiiTFKbKYNHWN7zrWqxUKOA0jSks+08Phgcu6FMPR5EErVNTiM5IdwEMIgiRAdtktCF5ajaEs8T6w2mzwMYrHWOaSBe8x1i7I1cxFMFbabzElhozIzkVzl0+UM7QtRfS4uLcH7zGFxRZSrI/9wISiMAV1WeHsSAqSd1evKxKrt2Tx35fRNi1FVeG9ztYhSYrMfP3Mp3Y5JLmloH/69BlXV1f8zd/8DU+ePOHu7o6vvvyaoRv4yU9/IqrW7sTh8MA0Oaq6pK6b3LITRHi2xfhBPffD+IeOdx3Q4W0n83fVdUu+Z67c04woKZW5ROcVveSnzl5mSoMxBcYK33K1WWXENC1h7F3fCeJ0lLbeMQfzDl1P33UcDoJUDVmtK1zSlqqqpKMy9BwOB/Z78Zr66uuvvtXP6zu1wvngqatmybCZfOZvZDO/ol2hPIxTl/1rLIGzcFzkA7bWLq2IcwhSK7XI35tG/Hhsbl8UhRWUKZqF0F1WNUarTO5WpKiWx31soQCoR5VNCCK71wUhegiywVlrGHOhNBM+50VXHkvQrbKuaC+21E2Du70lZpt8CestUDn/3BhDSI8us2VZYssSH6R4W60aURKWBWUQdEAbg1LClRJyfFa9GSF9z22l+QQSQmC1WmVJtaAxwyh978IatFL0fU/fD7RNSVmKUajPxeo0jmgVMdk5mRTQWgqwoR+yU2zBNEUxoEwxS+NZcubEAVe8nay1pFRhrVmQPaUU682G7nRa5oACqrISzyQUdVEz9ANN05JCYhoG8dSqGikG3Yj34t9kjc199oQ1CkXCFobgA3Xmns0mn7YwuGli1TS5TVOJJYGWJHttNG4U4nXTtKxqMWFVWpNCEu8uWwhPKYlz+Sy5neetyYGYs8Ll3AVXa4NKiWQeF73Nak137JeFcIbkV6sVhZXg2LoSya9PYZkjhS4otCGGAEna32b2U0lx4Rmczj7n/9THqe/p8+k3RBGJ7B/u2T/c4yaxnhD0V3zDvA8ZmW14ev2MTz/9FGstX3zxC+4f7rGF5cMPPySmyOvXcu3HjATKxtCIN5fJ38MP44fxLYxzL6d3c/jOLQFmTjBKYQpLoQXRDjEIBykEYgiy5yUFKknXR+nMrc2K3hRFwZ0iJEHGtWEx9m3amsvLC8lNDV5cYxN4NzH0A91JOgSTk/2yLMrcrbGEIIfB4+nIV19+xdcvXvD//jd/961+Xt+toim3Deb21ZzKDgqVyAhBXNAcbTTBPSZUj6NwT7z3QujNifXz5rq5uFhaGsboHJjpJDwzeMZJUAyl0tKS0GWxRLGQJ5X3MxQvuVJz8vtjNS8oVAzSBzZGfJ2WpGaFIDuQDfIibhzJCnLhMzmH0ZqirLDaoJJMcJ0X6fmxXEYlKmOwRgwW59OCOHPLqVgb8bDSSi/EdznApkWZp7KZWEqJvhP0wlhL35/EJ8hKLIRWEvnQVDWHh4PIpLWRYqWuKBNEn/CTIwYvWVtaSyQInrou8H4iKiPoUiQT3IEYcdOwFC8pBryblm5FVZXL5z27guuMOPZ9L5YQ+f3OMSAxCZI0TZNE2XhPXYt5pA+ONMnCYQtL2dQE53HjJD36acRktWO7XrF/uBdUsyyJPuD8RNPUotozGhOlMIW0fP/WWuo8F5XKJzaFQN6ITUVK8S3jxHk+pfTYhpO2tIQ6zyTxiBTlbdlyPB6zIk6c18/JnlUlpMl5gRT4PQDprWxDP2a1o9GgkmQVpricPmeE9PswulMnIdvZuXjoxatmjowos92JUoKEe+8IIXJ/f8+vv/iS//1//3es2i1ffPkVJMXFxRWbzYaErC8PD/dUlSTcr/IhYVasqh/acz+Mb2mcx7Y8yv0feU7zYX+JecliJWMM6Lx+phmdmhGlua0ne5gUTWnxnIO4cJ6UEr5SFhDn/xY/P2tEgGK1FZHURcT7mPf/zEtWOrvnK2IQcGWaBp49fconN5/w//zv//tv9fP6ThVNxhYLh8g5R3AeU9VCeFaamCHwqqwoypKUIjbqxYRxmiZM5pCoPDnmCVEUBavVirHv82k6GxzOnk+5R1tUQuaMMRCCo0h2Qaqqqlps3+f4A0ElzFK168yNSZk4Nxck59LPReVkzONk0moh5k45/FCjsgJKEVwQMrUx4hWjFdaWj5LRmPA5RLEsH710vJfQQx2E+6KNIDbRS0szRSkglVKLXD8huUIxBKLz4llVlKxWTW6VOvpTx2a1Fg+ptsVoRYiTFFTG4KPHJ0FNcLKR931HWRqMreU50SilsUaLdFRHQvT4rIYjpGyWmT23UIs/03zhhxAYs0Bg/mxjlOiLbpJw4dJI4T0rDKVQlEXAWIu1wpez1lCUhRRumTOktMT5BO8w2cupKAoKaxgz4T+lxDgMS3GTgJAz9LQyRCUHApKYdkYARSZbRokziZHRTQt/bZ5jIaOCirchdeYCPRdlZVmKPUfXY7TJ8zst6si5IJsPACEE4SHM8zOKCELlolopAdWtNSLtTVGuKfv9UXX140SZFHVdYY14qFWlmIbK9WdzYRvzhgMxeu7u7vn7v/+PxJC4vLwG1BKEejr10rbPDvlVVbNatZRVJdma2Y5CiuQf2nI/jH/4eLdomjsw587a85j3SymUJH0g5XVBaZXXbCVzcymE1KNZpSLvf/M6oZeiiRSz8lcMjIXPmqk4Sq4HoxVVVeQosjIXdeT2XwKjsEnc+derNZvt9lv/vL5TRVNRyOLvls2/XBb4JaUaceWNwWdEwVBXNZObsueScJu8c1nF9phU7qYRo0VKHmLIflDVUsQUVUVZl8Qs0Z4VZCVkJ3CLc9OiKFBKAmhBJs65hb9U1GqRcpL/22jJZ5vfkxRaijK3BJWStlZZloTcmiMJ8Z0kBnvSjrQZCyVD+kpaQ8jvZyShbkp8zuELXnKx5oslhLR4B0XvBenL7rCqrulPJ9yiWCS/TyH+dV3HKRtBFqYkJS/OrimgCFnGLm3JGKSgs9aKWi1/FkLSn0he0BDvHC6MoCBG+d4LaxfIeM4TFHt9zTT5rEQUblHTNIzTSJrl9H524y6hgNKWOahZIGbnHLY0ORDYC4ITsjNtCBhTUNcVymimIL48ZVlSWJstCcJiRTErIqNUN9kSwsjCM7dRM+/OhcDkJuGP5dbqXPRbLciomVG/GJfveXYXP+cgLIX4bFSXW8XmTBUzXztzMbQo54LP95cC0UeHVQVlURKCQ9vHQt65CWXA6m9frfKPdTgfsYWiKGtWdcPV5RXrds3d7Q3jKGpCn1sMEhM0F/KeV69eczz2fPjBR3zwwUdcXF4w9gO//vWviCngppG7uzvKcl6iRWwiXM0c+L3gqz8UTz+MP3ycWw3AY5vunBYCj9Es82EtwXJglaJpzr4UYCE3KB4RKMTRfvZlMkqR5v06SfftkZ+X8mPkfVnL45IJ5SEp2eMUzHYE5APcfFjUChQGrH4LIPi2xneqaJqGER8EabCFhMfGmBbl3Lz4ey8tFh8CpVLE7I80K+SePX2GSHcHilLiJUIMjMeBKqeShxCoypKinLPRAiSxGZjVT0UhcvMxk5+Px6O4p8QkHJe2lZbGNEGWqj+6l5oFtp/ciCyIYPJ7IVfyi6QyG3nWVc0YBLHpxolpHMWzqG6WU2jMssuQN3db2FzU5IDgKL1nhUSrWGMIzmGURKaEacqE3syryI+jyAkbSYqHkB2u67pitVoBChcdVSkKvZSVYtMwEdKI8wMpBkG0tKEuS7QuGEcHMWGMxWrN0PWcTh1lKRwgo6VF13UdITrqRny0gvfy+LkVGbIcNsUIKkmIcYoYJQTlGfVRStAhbTQhROFxeS+FTvYGiVERvGMYgzwf4kk1F9ExJerCUpQlzjt88LiTo67E4JIIY9fjJicFGoJATd4tnIHgPS6jn1VZLp5JMct558VrhsiFuyau7korQlTElBbHnrnoSblgjykRQsQlaZkZJeanwyCS4XmBnMUR52GeSkGM6rFdmITkjJbHGIeJEosujJi/Onmf36ft2xYtIK72ZlXx9MkTmnKFVpbjYU8ILjvg9xyO+0WUoLUIFo7HI6u25dNPP+Xq6ppxGBndQN8JR2qmCcimIUiVnOhNdu2bNzwx//th/DD+mHGOKp0XTt9UNBkteRlaGSmaBOOXtj/CP5qdR5aiLEleotGyiyxFToRlHs+3zT06ndce2bpUFj8kiDOfL897NZdr891zAkKU9e/bHt+pomn+VIzVmKSz39JEUVRLK2WGF43W2LzpzDwLk8nYKW+yQig/a43puc2gUPGRJHvuwuvGgSKTque2iwsBW5VM9w+sVit8ECTFaDFxnJ8PbdCFtO2MNRgrvVjnpsVDpypF4WcLS1kWuQ0oPkhDP8imnK0MSFIwNnVNXZWMw7gQ68jtRLKlfJnVUKKykqJIYUg+QBQ0oyxLoo9M00BZbnJRF5aJb3K70DtJo9bIxG9yOGPfDwQfqNYVuwUWlZO1SiJJVVrnoFlNYS0gikel5YSQQiTmtpZwQBLNdo3vpaBVqszBpXqxfpjbFYW1wr3K7cQ51gYABeMkRbfSj1b7c3s0BI/PbShprWgIghbNJGt5TVLUWCOtFzdNDE6y+IahZ2JA73ZE4iO6FCJ1I2amPgThSJ6d8myeizPiJwuGzghXXFqKRVGIR5lS+b1qUohLJpM1Ztk6Vf5TWJEBa21ymy0u87G0xaP4QM2iCPMWXD8XUKAIKeJjwKpi4SZYXWTeGEur+Psynl5/wOR8LtZB65qqDDT1DpLBFoq6Ljie9vgvPYfDQ27jS0ajcyOJQFUZNpuGq8sd603L/f09+/2e0+lIjIG6btDaZi5c7kWjWBJ7fxg/jD9inKNM77bgzn8HvFVEkTKylCAFKYaY0fKZbJcDOpeCh1zYxKxGVmdUghmFmteO3DV5/M38upZX+84bYca3Hu+hcviv+vbXo+9U0SSGb5boBQ2QNspjZRxTwmbXZYUgEZMThV3diJv1breTIsIWmLYVy4C8edZVvbRDghY0Zjb6K3LBEVO5ENFjFNWQLUqc87RZJZWCKLpi8IRRCp6mlpZQUZS4aSIS0ErQssKaM2KuqAwKWywn/xCDCAgyCbSpa0m6L8SdurBWjD+zxUBZymuanGPoe5J3GGsoM1rlvaMsLIURdZVRCluUlLbg1B8luiUExkFUgXOY8KLYyUhN0zQ5GkSKB+99hmmz/YH3gqgZiy40KAtJWpBucgzDAMngfcynFznVtG2LLS3BeaIKFGXJcDxlZ+uKlIJwbbQWMiCivLA5sDYmgXiFo5TjVuIjF8QUIsXv+15iebTGLjEz0pKcSdhlJUaRtigx1ohHkTaUlSgRvfe44KnXrcSvDKM40nuxANhsN3gfFhf7EETNqQtLWclnbrQWX7Fc6iSkhWpz4OQwDKAVu90F1sxeVGeRGoA+s5cIISzKyaIQc0VpCctnpjPnYFaHFrk4k9fnF6GAUjqjnjn7bC7Gk4QxK5OWhW1GRbT5/hRN//Sf/XPGydOfOpIPjL3neBzxLmFNSdvUbLaNxKG4kfV6RV2XXF1dkkh8/vmvGMeOX3/xOYnE9ZNnbDZriqJgvV6LD1bfkYjUdUNhi3cKpzy+Px/5D+NPMN4tmN433ldIJem95Ronyvp1VunIWUu9XbjM/KOzAkgpQarmiazeuu3ZP5bfvjvh59LqrGmd5sf+vjuCh0DfnZi8oyykqLC2wJpCvqQoKrOQN+8l6DJv6gpyvMlEqmvKslhI5Y/eSo9huD5nzUEUgnMmhCutOWQ35e1mR/CBw80dm81mIZDPNgcppswHmgjeoUgMw0DwjlDaRb0nPBXzSE5KamkLVk2NsQW+H0gp0Z1OdF1Hnd2q3STqqsf2TmBlpR3lQuYrpUjrpXBUInVY+F8m6EV1JcTh4rFVFSHasBB/F8jWiNllXZcc98dsxlnQ1jUpBNwoPlduzO3FQgGZ3wEMw0iMmqpspdBMaSHOa1PQli2nwxGIhK7j4e6OupHQ39nQrMgcMmOtFJaTNLtDDBIlk79HnU0F5ygTrSTSJoTA0A/0ul9OUoJuebyX4hVVCi9pkkyxyTs2Kym0xkmI2SHzjWYDwzCJDB2VoWyrHnlrWbIfckZcyEabpS0obMk0eZSZidoz307lwq1g7KWto5XKFgECedtFvSiF/szZmyZRbKWsgmzqmtOxY7/fE6IIIGJ4jB7qutMi350/E5XncS6PGCYJp47JM01jdtf3GKMeeQbfg/Hjn/yUoqzx08Td6xt+/Ytf4V1EK4O2OrdgBSn6yY9/QtNWXF5e8OTJJf3QU5UFv/zl5/z7f//v6bqe4UcjT58+kza1NdR1S1FYpmmirhvKssYYC7kR8sP4YfxjGG+jy+qtguhtIPQbqvv0viLod9znt97u/L+/fST2O1U0xRSJUVCG2ZV6mkRFN5+i5cScSAhZtq4KlLUwiKx6miaSF45TCCVNK+jSMAyLis5ai4vurNU6a7MUaAVGvIzK7FwdnJziZz6VUgqrNaassEUldzF2IZOHGKHIJo6Kt+JfztVMQz/gvFgLaGSjbdsKN405r63IxUBA52JbFH+BcRwoq5rLiwv6YSAhG7giE8MhO0gXS4zLul1RVTUwZiNEKYCEfMpiUxBy+zGmyND1jPnxzzOBtFI53FcsE4xJxARKiSVBBRANVd4MUpS24OQGuu6UM7wSbdtQlAVPnz7Flhqt4XQ8MAyDcNm8p9lu5f7jkC0WhF9jGiFx+2zgOFtLzKqxZ8+eZUPHRFGWFLaASG7RacpSOEvKKLqhw09OfKsykV6KwQJTFoTcjmuqekFqNut1NlWt8TEsruxjLqqskZZZhEzol/uZwuJjeOtkNz+nNZLVl6IUM1LYym1jVp9YayWcOGRFZZ5TD4d9Rg9NDi+WUN+ZC7VarYiZS+N9YJokfLfMwdfWVNR1y8PDAzmVCmt1di63b/mhfR/GzZs3XF4+oapKnjx5Qm1Lxv5jpvHE8fTA4XDPNJ3YbNZ89tef8tHHH/D06RMuLnb0fUfbNDRNyy9/+SsuLy+z51ng7u6OvuuWKB3hM8YlMeCHltwP49sY7wv1/X2CftP5/EvvFk3f3pi70Gd//TGP8K2P71TRZIwYVo7TJLk31uK7jqHvWbUrTJblU0l4bYwBkqTUAxyPR5q2IfkogI4SZEFrzWq1ZprGJc9OK0VZFdiiIAa/ZMQpbbC6ZFW31GWFn6RFM29qSitWTcsw9MRJlG1iW5CJdl3PNAxZiin8GYVk581uzPPmXmZPqXEc5fWU4ug8Dj1VVS5qLp3begDtjNoEabUUhcV7K0TnLMU3uS03t17cJAhYDH5p+4QQ0IApy+w9JYhVQgjNZV1zOhwYuj4jOmLwGFzI6izxzpKcooC1YIucoq0UWhlSEvftvu+zYtGjDdnhWkizPvO5MotbyOX59c0+VGTi/ExWrwrh/QQnYoBhGBbUKSYJQ54NA6dpIjmkpamtcKqitDpTkuf3+fuossmpnPYDOgb6caCsKla7rWS6hUA/TeIDVhQMk7iVV0UpJpWZHJ/mxSmjX3P0zenUkSYgF1jWFpiiwCgJ9G2ahnEYKK2VdnKMBJ+9mxKi7gsRZQwqpqw0TaxWUubMRPEYxMogId+RtZYUE2VZZb5XyO1nKcC8MoSQGKYRUxjqqiRGh/cTIYkbv9Watmn+fAvCX3j84uc/Y3+9Z7vZcLHd8fz5MzarBmLg9ZsXfPnF59zevebp9RV/89Of8ulnP2K73bDZrAjB0zZtNuEt+au/+ozd9pLD4civ7+/56quvFp7h5eUlzJw1pbJI5Ieu3A/jLzPeKpL+hJPwH/7Q72vl/cPHd6poSklS4UMMoo6bJiRsd4XSmmEcMZkjQkpiLplXl7ISHlJZlKhCgk7HrD7TRk7LdZ2dsbUWFVpKaGvRpkCbSFFkx+8QIURCiItia/YGCiGf1J1HZ5fu89BA76RAEc6PIyDmg8Zo8buJUYqV3FYiFzfeOWkrTVPOMmtZvJ6yN5G0YBBUNDwSkefW2mxp4KYJzj2qjJXIjRA5no7srp/k5/XLY0qrZyTGtCAXVmt0UVCWVS7uJowtqNsWN44SOGo0KfnMGzLElH2FVP53CLm9KJEytiho6mqxyCc7nHvvIGgKqxeZu/hieYITb6zZzHJJuiahQ0Dn4qquaynQgtgZjMOQ0blA9AqXxE7AiGZVkE0vbZCyqjBGfu5D9t9SooAb+oAqjLh3uyG3GpOETeairahKMc88K8KKbGgZQxADt8Ky2W7ohwGf7QGcE7Vo3dSM454nV9eIX0/Inxn4ELGASz7nMT0SM6XVU+CzW3fTtthC0Cy0WlRvl9sNdze3QlI2lkTKLvYIW8FYgpsYp1O27sjhnc5hrDgEh5Dox/HPsxj8IxgxOB4e7jk83PNSfUlb1fzVj37EJx99yAfPn7FeV5y6D7m62vLxJx9TlgX7/T3T1NO2Dc+fP+fy8hKloGlqdrstKQl3U9So4nx/ff2Ei8sLUXEqQZwNj0yOH7RzP4wfxp9vfKeKJpX7+IW1+BBw44DRhqZt6bse1FxoyGbipomoFNWqwXlp1x2PBwoz83MiWkHwEyc/SduDiLYFzkt7bHKRoqhwU4AYqUuRWWZziaVVVRUSixG8oFIzH2RWIs15UcoojH7ctFLwi+WBd47m7KSeYhQfIm1ISPmQIliVOTrZZRiV8EGkyIq4tCgFkXl0a17AjRhJs6t0krZlXVUMgwQnNuMoXjHBYZHPKOUw3ZSEjD71PT77JM3KK2utEO4zUV0bhdJRMgMzEhZcwDsvnkhGeEJ1WRATjFPMaBL4KUuurcJNY+YqJYKTjXoaR+q6YRwFRfRZ5WcLKXAB0Cq7kxtR+ymd3WUNpba4wZFcwGoxRnU+UBRW+DwpERBvKaW1OIafRVeorCDRVmwnVOYWKQVt20AmcCutcN6hjKgm4+zBlV3BtVEyV/1E0omqqrHeMvYSRmmMwbtAqhRVUcs8NgXeR4ZpQqUkQclKkZBCn8Qjj0lriroErRiGAR8DPkVUihDAOQ8oVBBLD6sNRIMyCq3zdeYDWlu5tkIADVOYxEMrJQpbUlU1x+ORh/3hT3X5/6MbXXfM36VCRegOeyyJ4EdWq4ayMux2W66vr1m1LWi5doYh5pa15Xg8cn9/z/FwZNWuc9u5ZLPZEELI93/CatWeBaFqHpO/fhg/jB/Gn3N8p4omISoHjBaDqxDCYkaZYhQ1UlVLayomphAJKaG15XA80K5W9IcuE8kNdVVhC0M3CGl3HAdOpxPrwmKqEpsUKENICh8SwTlKM9C2zVLciGRfNi2VxIFiGEdBqJRiHAamaWSzWkvLxdociSIqPU1B8BNTP2U38XoxvCzqGsWjV8voR8qqhKaBEIVcHgNJKVJIUpxYabnElLJpojhZGytuwiFnqM3O3OLtFIlBPJuqqsSNPRgom1IcNZSi0IKYeR8eI1u0QhtDd5IYlbKsQElczd3dHc+ePsX7CRDr++SkiCxMidGijgOVjSYDSokL7DhMjIMUEat2JVEl1kDSYAwBhfdhIXNrW+CGEe8cVUqgJSsPRDVXZr7S0A/iEWUVQzdQlzU+TLTtShRQJPF0sgWRSHRhsaiYydjnqhBxUBe7gu12x+FwT4yeqt6gjJGWaCb6z15i5w67KUlIZSIRVaSfeoZpRCWbEUklMTJdh/eR1XrFqetpaollSTHQzPEryHcqbbb5O/WE4AhUtOsN9bqlqmuG/QGTItZIjtk4TPjhhtIIp2vyTlA6LRy+rEVEWUVlykcTzSjGnNoYaU2m+CeR+P5jHTevX3B19ZTLywuaqkElxd3DPS9efg0pcP30kn/yNz/myZMLvPdUTclms87eWD2vXr3m5z//OV9//TXbzY5xnOh7ERSs12ucG9luN1xeXi48x9IWqMKSgifFH/Lnfhg/jD/3+E4VTd2pY7fdLh4RAvkIaTUlMTMcY4+xBc1mTbVaEaJD6UhbF6zahk1T0N8/EPyERpLsDZ7NZoOPil6Joqxt12zaOR9OozdrgpuYhg7n3eK1I8bxhtPQUdUVurT0pxNlUWJKS1QJnXx2TVWYqqCfeqZTJzEsVSk5aylQtxVJJYZpoO8GLtZrVFIQxNTr4WGPNrKJNU2NLo2gFykRFkWmkjwspIUprayMPvkgfzIRuqoqhl7S7b0TRaIxhosnVyA2WAzjhJsmrFKUTUuKEJynrEpWqzUhRo7HEyhFmb2TpuORN2/ecPXkQswUpwkXHAohIJdFxThNjLlNp1224LcGZQwRsHVNjI7VboupCt7c3lIUBVdPLgFFV0gb0GhDs9lIoZRd3+fPwcfI/cM9q/WKsii5u7mhbFsUicP+RF3W0tIzFj+T44PHe4UtHx3Vi7KkqiuU1jgnETHRe4rKIA7bJa4fcJNjHEZ8ShQ5O6ksSnov0SU2m1yauchAjNpExWgJcSLGQGEL6tpS1SVKR/pxT7UqMKbk6fUlRWmZXp7wzlO3G4hRkD2dcpyHZPspo0kuEJNHG8Xl1QV1UxJSKynjtqT2Jd2pYxoc280GABfFcDUqBCHLcULWaCpbAGJlEJ0XBV+M7O8OVFXJ1e7iz7MY/CMYQ9/xcH/LNA5s1xt26x3BO8axx4eJrVtRlBJAagtLURhSEsTPOXGQ915UjsM40HUdzgmvUJSIlu12ywcfPGe9Xi1ZhyolKZL/0h/AD+OH8T0c34miaVYRnbqBw6lntle3tuR46hmGQVpkdU3iUc49jQNdf6CqDIXR3L96QWENVWkxCrqjFBMoxee//CWb7Y6h63DTxI3/isk5SCLtNqZYQnVD8pJgbqWNkaKQqStXUVc1x9MRYy37k7QqlFI83D9IWG1Z0vcDzk0cxx5IuDBhjGW9WWNCIKrEFCNffP45McFmu6Ouavo0Md7coFSinVrGaQBy6O7ouNhdUpYlXdcvyFulDVO2Kkgp0Q8jV9fXnA4HXt/cUlrJ8uuGkWPfM4w9XqX/f3v3GxtVlf9x/NNOOzNlpQW3y7TFwQYM4h/+rEW6FQnxl65NNLg82MiKgS5RWZU1SrMrIMKoKGWVZUmkSmRVfKALasQYaepqV2LUGhKgiS5/DIJS3Z2BrtIpUzpDZ87vQWG0tsVzCzMF5v1K7oMezun9DnC/+cydO/d2h0B3nnJcLsWise4H6+Z0P7Ymerz73kc/+9nP1N7erosuyleuZ4gikQ51HO8OCMUlJfr2u6PqSpzoPvPm6j6zFzUJHYtGZYzk8eYpcfKbi263u/vv/uSNIhPG6FhHRJ1GCrcfTV6UHO6InDyjJ3UcP65oR4d+duyYuuIJRWOdyefjefOGyJPnkTcvT8ciEZnEMYXDYbWFw/J6uu/wHWoNyeVy6VjkmHJzcxVpb+++OWeeV1mu7O6P1bKzlTjW/TFMlqv7Oqzuj0fjam/v/up/Xl6e2o8dU/eDVrsUaY8okWjX0f99J6/Xq87OToXD7Sfv/9T9UeuJkxe+S1LWyY9Ku05+dNd9q6O4huYPVU6uW5GO75RwnVC8K64hed7ue0x1HFcsFlXr//6j3Fy33Dk50sl7cGWp++aXiUT3fnLa3TrWHlb4WIda//e/7o/ksrLl8eQqS923QXBlueTKzVIs1imT1f3Ina54XHJ1X7RuEgm5s3PUlZXb/bDik/eDyh96kbKypWjkuLo6Y4p1RHscsxei5HMMT0R1rD2uY+1hxY53Kvfknbq7H5vS/e/b/SiVmDqjHTJyy5i4IpGI2tvD6uzsSD6qJhaN6fjx44p3dV+Dd+rLDllZUo6r+81I5NjJ52Cq+7rCU/c2A9C3SCQi6ez2oyxzHnS3r7/+Wn6/f7DLAGCppaVFl1xyyWCXkRL0I+D8cjb70XkRmhKJhP7zn/9o6NChVveRADA4jDFqb29XSUlJ8llVFxr6EXB+SEU/Oi9CEwAAwGC7MN8KAgAAnGWEJgAAAAuEJgAAAAuEJgAAAAuEJgAAAAuEJgAAAAuEJgAAAAuEJgAAAAuEJgAAAAuEJgAAAAuEJgAAAAuEJgAAAAuEJgAAAAuEJgAAAAuEJgAAAAuEJgAAAAuEJgAAAAuOQ9MHH3ygGTNmqKSkRFlZWXrzzTd/cs22bdt0zTXXyOPx6LLLLtPGjRsHUCoA9EQ/ApBOjkNTJBLRxIkTVVdXZzX/4MGDuvnmm3XDDTeoublZDzzwgO6880698847josFgB+iHwFIpyxjjBnw4qwsbdmyRTNnzux3zqJFi7R161Z99tlnybHf/e53Onr0qBoaGga6awDogX4EINVyUr2DpqYmVVZW9hirqqrSAw880O+aaDSqaDSa/DmRSOjbb7/Vz3/+c2VlZaWqVABnyBij9vZ2lZSUKDv73Ltkkn4EZI5U9KOUh6ZgMCifz9djzOfzKRwO6/jx48rLy+u1pra2Vo8++miqSwOQIi0tLbrkkksGu4xe6EdA5jmb/SjloWkglixZopqamuTPbW1tGjVqlFpaWpSfnz+IlQE4nXA4LL/fr6FDhw52KWcN/Qg4P6WiH6U8NBUVFSkUCvUYC4VCys/P7/NdnSR5PB55PJ5e4/n5+TQp4Dxwrn5sRT8CMs/Z7Ecpv+igoqJCjY2NPcbeffddVVRUpHrXANAD/QjAmXAcmo4dO6bm5mY1NzdL6v4Kb3Nzsw4dOiSp+1T23Llzk/PvvvtuHThwQA8++KD27t2rZ555Rq+++qoWLlx4dl4BgIxFPwKQVsah999/30jqtVVXVxtjjKmurjbTp0/vtWbSpEnG7Xab0aNHmxdffNHRPtva2owk09bW5rRcAGmU7mOVfgSgP6k4Vs/oPk3pEg6HVVBQoLa2Nq4hAM5hmXCsZsJrBC4EqThWz70bqQAAAJyDCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWCE0AAAAWBhSa6urqVFpaKq/Xq/Lycm3fvv2089euXavLL79ceXl58vv9WrhwoTo7OwdUMAD8EP0IQLo4Dk2bN29WTU2NAoGAdu7cqYkTJ6qqqkqHDx/uc/4rr7yixYsXKxAIaM+ePXr++ee1efNmPfTQQ2dcPIDMRj8CkE6OQ9OaNWt01113ad68ebryyiu1fv16DRkyRC+88EKf8z/++GNNnTpVs2fPVmlpqW688UbddtttP/luEAB+Cv0IQDo5Ck2xWEw7duxQZWXl978gO1uVlZVqamrqc811112nHTt2JJvSgQMHVF9fr5tuuqnf/USjUYXD4R4bAPwQ/QhAuuU4mdza2qp4PC6fz9dj3Ofzae/evX2umT17tlpbW3X99dfLGKOuri7dfffdpz0dXltbq0cffdRJaQAyDP0IQLql/Ntz27Zt08qVK/XMM89o586deuONN7R161atWLGi3zVLlixRW1tbcmtpaUl1mQAyAP0IwJlwdKapsLBQLpdLoVCox3goFFJRUVGfa5YtW6Y5c+bozjvvlCSNHz9ekUhE8+fP19KlS5Wd3Tu3eTweeTweJ6UByDD0IwDp5uhMk9vtVllZmRobG5NjiURCjY2Nqqio6HNNR0dHr0bkcrkkScYYp/UCgCT6EYD0c3SmSZJqampUXV2tyZMna8qUKVq7dq0ikYjmzZsnSZo7d65Gjhyp2tpaSdKMGTO0Zs0a/fKXv1R5ebn279+vZcuWacaMGclmBQADQT8CkE6OQ9OsWbN05MgRLV++XMFgUJMmTVJDQ0PyYsxDhw71eCf38MMPKysrSw8//LC++eYb/eIXv9CMGTP0xBNPnL1XASAj0Y8ApFOWOQ/OSYfDYRUUFKitrU35+fmDXQ6AfmTCsZoJrxG4EKTiWOXZcwAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYITQAAABYGFJrq6upUWloqr9er8vJybd++/bTzjx49qgULFqi4uFgej0djx45VfX39gAoGgB+iHwFIlxynCzZv3qyamhqtX79e5eXlWrt2raqqqrRv3z6NGDGi1/xYLKZf//rXGjFihF5//XWNHDlSX331lYYNG3Y26geQwehHANIpyxhjnCwoLy/Xtddeq3Xr1kmSEomE/H6/7rvvPi1evLjX/PXr1+upp57S3r17lZuba7WPaDSqaDSa/DkcDsvv96utrU35+flOygWQRuFwWAUFBWk7VulHAPqTin7k6OO5WCymHTt2qLKy8vtfkJ2tyspKNTU19bnmrbfeUkVFhRYsWCCfz6err75aK1euVDwe73c/tbW1KigoSG5+v99JmQAyAP0IQLo5Ck2tra2Kx+Py+Xw9xn0+n4LBYJ9rDhw4oNdff13xeFz19fVatmyZ/vrXv+rxxx/vdz9LlixRW1tbcmtpaXFSJoAMQD8CkG6Or2lyKpFIaMSIEXruuefkcrlUVlamb775Rk899ZQCgUCfazwejzweT6pLA5Bh6EcAzoSj0FRYWCiXy6VQKNRjPBQKqaioqM81xcXFys3NlcvlSo5dccUVCgaDisVicrvdAygbQKajHwFIN0cfz7ndbpWVlamxsTE5lkgk1NjYqIqKij7XTJ06Vfv371cikUiOff755youLqZBARgw+hGAdHN8n6aamhpt2LBBL730kvbs2aN77rlHkUhE8+bNkyTNnTtXS5YsSc6/55579O233+r+++/X559/rq1bt2rlypVasGDB2XsVADIS/QhAOjm+pmnWrFk6cuSIli9frmAwqEmTJqmhoSF5MeahQ4eUnf19FvP7/XrnnXe0cOFCTZgwQSNHjtT999+vRYsWnb1XASAj0Y8ApJPj+zQNhnTf+wXAwGTCsZoJrxG4EAz6fZoAAAAyFaEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAAqEJAADAwoBCU11dnUpLS+X1elVeXq7t27dbrdu0aZOysrI0c+bMgewWAHqhHwFIF8ehafPmzaqpqVEgENDOnTs1ceJEVVVV6fDhw6dd9+WXX+pPf/qTpk2bNuBiAeCH6EcA0slxaFqzZo3uuusuzZs3T1deeaXWr1+vIUOG6IUXXuh3TTwe1+23365HH31Uo0ePPqOCAeAU+hGAdHIUmmKxmHbs2KHKysrvf0F2tiorK9XU1NTvuscee0wjRozQHXfcYbWfaDSqcDjcYwOAH6IfAUg3R6GptbVV8XhcPp+vx7jP51MwGOxzzYcffqjnn39eGzZssN5PbW2tCgoKkpvf73dSJoAMQD8CkG4p/fZce3u75syZow0bNqiwsNB63ZIlS9TW1pbcWlpaUlglgExAPwJwpnKcTC4sLJTL5VIoFOoxHgqFVFRU1Gv+F198oS+//FIzZsxIjiUSie4d5+Ro3759GjNmTK91Ho9HHo/HSWkAMgz9CEC6OTrT5Ha7VVZWpsbGxuRYIpFQY2OjKioqes0fN26cPv30UzU3Nye3W265RTfccIOam5s5zQ1gwOhHANLN0ZkmSaqpqVF1dbUmT56sKVOmaO3atYpEIpo3b54kae7cuRo5cqRqa2vl9Xp19dVX91g/bNgwSeo1DgBO0Y8ApJPj0DRr1iwdOXJEy5cvVzAY1KRJk9TQ0JC8GPPQoUPKzuZG4wBSj34EIJ2yjDFmsIv4KeFwWAUFBWpra1N+fv5glwOgH5lwrGbCawQuBKk4VnkLBgAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYIHQBAAAYGFAoamurk6lpaXyer0qLy/X9u3b+527YcMGTZs2TcOHD9fw4cNVWVl52vkA4AT9CEC6OA5NmzdvVk1NjQKBgHbu3KmJEyeqqqpKhw8f7nP+tm3bdNttt+n9999XU1OT/H6/brzxRn3zzTdnXDyAzEY/ApBOWcYY42RBeXm5rr32Wq1bt06SlEgk5Pf7dd9992nx4sU/uT4ej2v48OFat26d5s6da7XPcDisgoICtbW1KT8/30m5ANIo3ccq/QhAf1JxrDo60xSLxbRjxw5VVlZ+/wuys1VZWammpiar39HR0aETJ07o4osv7ndONBpVOBzusQHAD9GPAKSbo9DU2tqqeDwun8/XY9zn8ykYDFr9jkWLFqmkpKRHo/ux2tpaFRQUJDe/3++kTAAZgH4EIN3S+u25VatWadOmTdqyZYu8Xm+/85YsWaK2trbk1tLSksYqAWQC+hEAp3KcTC4sLJTL5VIoFOoxHgqFVFRUdNq1q1ev1qpVq/Tee+9pwoQJp53r8Xjk8XiclAYgw9CPAKSbozNNbrdbZWVlamxsTI4lEgk1NjaqoqKi33VPPvmkVqxYoYaGBk2ePHng1QLASfQjAOnm6EyTJNXU1Ki6ulqTJ0/WlClTtHbtWkUiEc2bN0+SNHfuXI0cOVK1tbWSpL/85S9avny5XnnlFZWWliavNbjooot00UUXncWXAiDT0I8ApJPj0DRr1iwdOXJEy5cvVzAY1KRJk9TQ0JC8GPPQoUPKzv7+BNazzz6rWCym3/72tz1+TyAQ0COPPHJm1QPIaPQjAOnk+D5Ng4H7ogDnh0w4VjPhNQIXgkG/TxMAAECmIjQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYIDQBAABYGFBoqqurU2lpqbxer8rLy7V9+/bTzn/ttdc0btw4eb1ejR8/XvX19QMqFgB+jH4EIF0ch6bNmzerpqZGgUBAO3fu1MSJE1VVVaXDhw/3Of/jjz/WbbfdpjvuuEO7du3SzJkzNXPmTH322WdnXDyAzEY/ApBOWcYY42RBeXm5rr32Wq1bt06SlEgk5Pf7dd9992nx4sW95s+aNUuRSERvv/12cuxXv/qVJk2apPXr1/e5j2g0qmg0mvy5ra1No0aNUktLi/Lz852UCyCNwuGw/H6/jh49qoKCgpTvj34EoD8p6UfGgWg0alwul9myZUuP8blz55pbbrmlzzV+v9/87W9/6zG2fPlyM2HChH73EwgEjCQ2NrbzdPviiy+ctJYBoR+xsbHZbGezH+XIgdbWVsXjcfl8vh7jPp9Pe/fu7XNNMBjsc34wGOx3P0uWLFFNTU3y56NHj+rSSy/VoUOH0vLu9Ww6lXTPx3el1D44zufaT52Fufjii1O+L/qRc+fz/y1qHxznc+2p6EeOQlO6eDweeTyeXuMFBQXn3T/aKfn5+dQ+CKh9cGRnXzhfzKUfnVuofXCcz7WfzX7k6DcVFhbK5XIpFAr1GA+FQioqKupzTVFRkaP5AGCDfgQg3RyFJrfbrbKyMjU2NibHEomEGhsbVVFR0eeaioqKHvMl6d133+13PgDYoB8BSDunF0Ft2rTJeDwes3HjRrN7924zf/58M2zYMBMMBo0xxsyZM8csXrw4Of+jjz4yOTk5ZvXq1WbPnj0mEAiY3Nxc8+mnn1rvs7Oz0wQCAdPZ2em03EFH7YOD2gdHumunHzlD7YOD2gdHKmp3HJqMMebpp582o0aNMm6320yZMsV88sknyT+bPn26qa6u7jH/1VdfNWPHjjVut9tcddVVZuvWrWdUNACcQj8CkC6O79MEAACQiS6cr7gAAACkEKEJAADAAqEJAADAAqEJAADAwjkTmurq6lRaWiqv16vy8nJt3779tPNfe+01jRs3Tl6vV+PHj1d9fX2aKu3NSe0bNmzQtGnTNHz4cA0fPlyVlZU/+VpTyenf+ymbNm1SVlaWZs6cmdoC++G07qNHj2rBggUqLi6Wx+PR2LFjB+3/jNPa165dq8svv1x5eXny+/1auHChOjs701Tt9z744APNmDFDJSUlysrK0ptvvvmTa7Zt26ZrrrlGHo9Hl112mTZu3JjyOs8G+tHgOF/7kURPypieNNhf3zOm+14rbrfbvPDCC+bf//63ueuuu8ywYcNMKBTqc/5HH31kXC6XefLJJ83u3bvNww8/7PheK2eL09pnz55t6urqzK5du8yePXvM73//e1NQUGC+/vrrNFfuvPZTDh48aEaOHGmmTZtmfvOb36Sn2B9wWnc0GjWTJ082N910k/nwww/NwYMHzbZt20xzc3OaK3de+8svv2w8Ho95+eWXzcGDB80777xjiouLzcKFC9NcuTH19fVm6dKl5o033jCSej0o98cOHDhghgwZYmpqaszu3bvN008/bVwul2loaEhPwQNEP6IfOUVPypyedE6EpilTppgFCxYkf47H46akpMTU1tb2Of/WW281N998c4+x8vJy84c//CGldfbFae0/1tXVZYYOHWpeeumlVJXYr4HU3tXVZa677jrz97//3VRXVw9Kk3Ja97PPPmtGjx5tYrFYukrsl9PaFyxYYP7v//6vx1hNTY2ZOnVqSuv8KTYN6sEHHzRXXXVVj7FZs2aZqqqqFFZ25uhH9COn6EmZ05MG/eO5WCymHTt2qLKyMjmWnZ2tyspKNTU19bmmqampx3xJqqqq6nd+qgyk9h/r6OjQiRMn0vJU+B8aaO2PPfaYRowYoTvuuCMdZfYykLrfeustVVRUaMGCBfL5fLr66qu1cuVKxePxdJUtaWC1X3fdddqxY0fydPmBAwdUX1+vm266KS01n4lz5Th1gn5EP3KKnpRZPSnnbBflVGtrq+LxuHw+X49xn8+nvXv39rkmGAz2OT8YDKaszr4MpPYfW7RokUpKSnr9Q6baQGr/8MMP9fzzz6u5uTkNFfZtIHUfOHBA//rXv3T77bervr5e+/fv17333qsTJ04oEAiko2xJA6t99uzZam1t1fXXXy9jjLq6unT33XfroYceSkfJZ6S/4zQcDuv48ePKy8sbpMr6Rz+iHzlFT8qsnjToZ5oy2apVq7Rp0yZt2bJFXq93sMs5rfb2ds2ZM0cbNmxQYWHhYJfjSCKR0IgRI/Tcc8+prKxMs2bN0tKlS7V+/frBLu0nbdu2TStXrtQzzzyjnTt36o033tDWrVu1YsWKwS4NFxj6UfrQk85fg36mqbCwUC6XS6FQqMd4KBRSUVFRn2uKiooczU+VgdR+yurVq7Vq1Sq99957mjBhQirL7JPT2r/44gt9+eWXmjFjRnIskUhIknJycrRv3z6NGTMmtUVrYH/nxcXFys3NlcvlSo5dccUVCgaDisVicrvdKa35lIHUvmzZMs2ZM0d33nmnJGn8+PGKRCKaP3++li5dquzsc/d9T3/HaX5+/jl5lkmiH9GPnKMnZVZPGvRX53a7VVZWpsbGxuRYIpFQY2OjKioq+lxTUVHRY74kvfvuu/3OT5WB1C5JTz75pFasWKGGhgZNnjw5HaX24rT2cePG6dNPP1Vzc3Nyu+WWW3TDDTeoublZfr//nKxbkqZOnar9+/cnm6okff755youLk5bc5IGVntHR0evJnSq0Zpz/LGR58px6gT9iH6U6toletJgOSvHqtMr1FNh06ZNxuPxmI0bN5rdu3eb+fPnm2HDhplgMGiMMWbOnDlm8eLFyfkfffSRycnJMatXrzZ79uwxgUBgUL/i66T2VatWGbfbbV5//XXz3//+N7m1t7ef87X/2GB9W8Vp3YcOHTJDhw41f/zjH82+ffvM22+/bUaMGGEef/zxc772QCBghg4dav7xj3+YAwcOmH/+859mzJgx5tZbb0177e3t7WbXrl1m165dRpJZs2aN2bVrl/nqq6+MMcYsXrzYzJkzJzn/1Nd7//znP5s9e/aYurq68+aWA/Qj+pET9KTM6UnnRGgyxpinn37ajBo1yrjdbjNlyhTzySefJP9s+vTpprq6usf8V1991YwdO9a43W5z1VVXma1bt6a54u85qf3SSy81knptgUAg/YUb53/vPzSYTcpp3R9//LEpLy83Ho/HjB492jzxxBOmq6srzVV3c1L7iRMnzCOPPGLGjBljvF6v8fv95t577zXfffdd2ut+//33+/y/e6re6upqM3369F5rJk2aZNxutxk9erR58cUX0173QNCPAukv3Jy//cgYelKm9KQsY87x82kAAADngEG/pgkAAOB8QGgCAACwQGgCAACwQGgCAACwQGgCAACwQGgCAACwQGgCAACwQGgCAACwQGgCAACwQGgCAACwQGgCAACw8P+J+/AF07Vw5wAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# The mistakes look all quite reasonable.\n", "show_img_grid(\n", " [batch['image'][idx] for idx in error_idxs[:9]],\n", " [f'pred: {imagenet2012_label(preds_labels[idx][0])}\\n'\n", " f'label: {imagenet2012_label(preds_labels[idx][1])}'\n", " for idx in error_idxs[:9]],\n", ")\n", "plt.tight_layout()" ] }, { "cell_type": "code", "execution_count": 53, "metadata": { "outputId": "4fae2533-5598-4f2e-c133-50bfba463311" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Constructing tf.data.Dataset imagenette for split validation[0:3925], from /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0\n", "WARNING:absl:options.experimental_threading is deprecated. Use options.threading instead.\n" ] } ], "source": [ "# Define parallelized inference function in separate cell so the cached\n", "# compilation can be used if below cell is executed multiple times.\n", "@jax.pmap\n", "def p_get_logits(images):\n", " return model.apply({'params': state.params, 'batch_stats': state.batch_stats},\n", " images, train=False)\n", "\n", "eval_iter = train.create_input_iter(dataset_builder, config.batch_size,\n", " input_pipeline.IMAGE_SIZE, tf.float32,\n", " train=False, cache=False, shuffle_buffer_size=None,\n", " prefetch=1)" ] }, { "cell_type": "code", "execution_count": 54, "metadata": { "outputId": "d01e1993-28ab-4a4a-ac58-01c83b80e6c9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step 1/7...\n", "Step 2/7...\n", "Step 3/7...\n", "Step 4/7...\n", "Step 5/7...\n", "Step 6/7...\n", "Step 7/7...\n" ] }, { "data": { "text/plain": [ "Array(0.9118304, dtype=float32)" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Compute accuracy.\n", "eval_steps = dataset_builder.info.splits['validation'].num_examples // config.batch_size\n", "count = correct = 0\n", "for step, batch in zip(range(eval_steps), eval_iter):\n", " labels = [imagenette_imagenet2012(label) for label in batch['label'].flatten()]\n", " logits = p_get_logits(batch['image'])\n", " logits = logits.reshape([-1, logits.shape[-1]])\n", " print(f'Step {step+1}/{eval_steps}...')\n", " count += len(labels)\n", " correct += (logits.argmax(axis=-1) == jnp.array(labels)).sum()\n", "\n", "correct / count" ] } ], "metadata": { "accelerator": "GPU", "gpuClass": "standard", "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: examples/imagenet/imagenet_benchmark.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Benchmark for the ImageNet example.""" import time from absl import flags from absl.testing import absltest from absl.testing.flagsaver import flagsaver from flax.testing import Benchmark import jax import numpy as np # Local imports. import main from configs import v100_x8_mixed_precision as config_lib # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() FLAGS = flags.FLAGS class ImagenetBenchmark(Benchmark): """Benchmarks for the ImageNet Flax example.""" @flagsaver def _test_8x_v100_half_precision( self, num_epochs: int, min_accuracy, max_accuracy ): """Utility to benchmark ImageNet on 8xV100 GPUs. Use in your test func.""" # Prepare and set flags defined in main.py. config = config_lib.get_config() config.num_epochs = num_epochs workdir = self.get_tmp_model_dir() FLAGS.workdir = workdir FLAGS.config = config start_time = time.time() main.main([]) benchmark_time = time.time() - start_time summaries = self.read_summaries(workdir) # Summaries contain all the information necessary for the regression # metrics. wall_time, _, eval_accuracy = zip(*summaries['eval_accuracy']) wall_time = np.array(wall_time) sec_per_epoch = np.mean(wall_time[1:] - wall_time[:-1]) end_accuracy = eval_accuracy[-1] # Assertions are deferred until the test finishes, so the metrics are # always reported and benchmark success is determined based on *all* # assertions. self.assertBetween(end_accuracy, min_accuracy, max_accuracy) # Use the reporting API to report single or multiple metrics/extras. self.report_wall_time(benchmark_time) self.report_metrics( {'sec_per_epoch': sec_per_epoch, 'accuracy': end_accuracy} ) def test_8x_v100_half_precision_short(self): """Run ImageNet on 8x V100 GPUs in half precision for 2 epochs.""" self._test_8x_v100_half_precision( num_epochs=2, min_accuracy=0.06, max_accuracy=0.09 ) self.report_extras({ 'description': 'Short (2 epochs) 8 x V100 test for ImageNet ResNet50.', 'model_name': 'resnet50', 'parameters': 'hp=true,bs=2048,num_epochs=2', 'implementation': 'linen', }) def test_8x_v100_half_precision_full(self): """Run ImageNet on 8x V100 GPUs in half precision for full 90 epochs.""" self._test_8x_v100_half_precision( num_epochs=90, min_accuracy=0.76, max_accuracy=0.77 ) self.report_extras({ 'description': 'Full (90 epochs) 8 x V100 test for ImageNet ResNet50.', 'model_name': 'resnet50', 'parameters': 'hp=true,bs=2048,num_epochs=90', 'implementation': 'linen', }) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/imagenet/imagenet_fake_data_benchmark.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Benchmark for the ImageNet example using fake data for quick perf results. This script doesn't need the dataset, but it needs the dataset metadata. That can be fetched with the script `flax/tests/download_dataset_metadata.sh`. """ import pathlib import time from absl.testing import absltest from flax.testing import Benchmark import jax import tensorflow_datasets as tfds # Local imports. from configs import fake_data_benchmark as config_lib import train # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() class ImagenetBenchmarkFakeData(Benchmark): """Runs ImageNet using fake data for quickly measuring performance.""" def test_fake_data(self): workdir = self.get_tmp_model_dir() config = config_lib.get_config() # Go two directories up to the root of the flax directory. flax_root_dir = pathlib.Path(__file__).absolute().parents[2] data_dir = str(flax_root_dir) + '/.tfds/metadata' # Warm-up first so that we are not measuring just compilation. with tfds.testing.mock_data(num_examples=1024, data_dir=data_dir): train.train_and_evaluate(config, workdir) start_time = time.time() with tfds.testing.mock_data(num_examples=1024, data_dir=data_dir): train.train_and_evaluate(config, workdir) benchmark_time = time.time() - start_time self.report_wall_time(benchmark_time) self.report_extras({ 'description': 'ImageNet ResNet50 with fake data', 'model_name': 'resnet50', 'parameters': f'hp=true,bs={config.batch_size}', }) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/imagenet/input_pipeline.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ImageNet input pipeline.""" import jax import tensorflow as tf import tensorflow_datasets as tfds IMAGE_SIZE = 224 CROP_PADDING = 32 MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255] STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255] def distorted_bounding_box_crop( image_bytes, bbox, min_object_covered=0.1, aspect_ratio_range=(0.75, 1.33), area_range=(0.05, 1.0), max_attempts=100, ): """Generates cropped_image using one of the bboxes randomly distorted. See `tf.image.sample_distorted_bounding_box` for more documentation. Args: image_bytes: `Tensor` of binary image data. bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` where each coordinate is [0, 1) and the coordinates are arranged as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole image. min_object_covered: An optional `float`. Defaults to `0.1`. The cropped area of the image must contain at least this fraction of any bounding box supplied. aspect_ratio_range: An optional list of `float`s. The cropped area of the image must have an aspect ratio = width / height within this range. area_range: An optional list of `float`s. The cropped area of the image must contain a fraction of the supplied image within this range. max_attempts: An optional `int`. Number of attempts at generating a cropped region of the image of the specified constraints. After `max_attempts` failures, return the entire image. Returns: cropped image `Tensor` """ shape = tf.io.extract_jpeg_shape(image_bytes) sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( shape, bounding_boxes=bbox, min_object_covered=min_object_covered, aspect_ratio_range=aspect_ratio_range, area_range=area_range, max_attempts=max_attempts, use_image_if_no_bounding_boxes=True, ) bbox_begin, bbox_size, _ = sample_distorted_bounding_box # Crop the image to the specified bounding box. offset_y, offset_x, _ = tf.unstack(bbox_begin) target_height, target_width, _ = tf.unstack(bbox_size) crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) return image def _resize(image, image_size): return tf.image.resize( [image], [image_size, image_size], method=tf.image.ResizeMethod.BICUBIC )[0] def _at_least_x_are_equal(a, b, x): """At least `x` of `a` and `b` `Tensors` are equal.""" match = tf.equal(a, b) match = tf.cast(match, tf.int32) return tf.greater_equal(tf.reduce_sum(match), x) def _decode_and_random_crop(image_bytes, image_size): """Make a random crop of image_size.""" bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) image = distorted_bounding_box_crop( image_bytes, bbox, min_object_covered=0.1, aspect_ratio_range=(3.0 / 4, 4.0 / 3.0), area_range=(0.08, 1.0), max_attempts=10, ) original_shape = tf.io.extract_jpeg_shape(image_bytes) bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3) image = tf.cond( bad, lambda: _decode_and_center_crop(image_bytes, image_size), lambda: _resize(image, image_size), ) return image def _decode_and_center_crop(image_bytes, image_size): """Crops to center of image with padding then scales image_size.""" shape = tf.io.extract_jpeg_shape(image_bytes) image_height = shape[0] image_width = shape[1] padded_center_crop_size = tf.cast( ( (image_size / (image_size + CROP_PADDING)) * tf.cast(tf.minimum(image_height, image_width), tf.float32) ), tf.int32, ) offset_height = ((image_height - padded_center_crop_size) + 1) // 2 offset_width = ((image_width - padded_center_crop_size) + 1) // 2 crop_window = tf.stack([ offset_height, offset_width, padded_center_crop_size, padded_center_crop_size, ]) image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) image = _resize(image, image_size) return image def normalize_image(image): image -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=image.dtype) image /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=image.dtype) return image def preprocess_for_train(image_bytes, dtype=tf.float32, image_size=IMAGE_SIZE): """Preprocesses the given image for training. Args: image_bytes: `Tensor` representing an image binary of arbitrary size. dtype: data type of the image. image_size: image size. Returns: A preprocessed image `Tensor`. """ image = _decode_and_random_crop(image_bytes, image_size) image = tf.reshape(image, [image_size, image_size, 3]) image = tf.image.random_flip_left_right(image) image = normalize_image(image) image = tf.image.convert_image_dtype(image, dtype=dtype) return image def preprocess_for_eval(image_bytes, dtype=tf.float32, image_size=IMAGE_SIZE): """Preprocesses the given image for evaluation. Args: image_bytes: `Tensor` representing an image binary of arbitrary size. dtype: data type of the image. image_size: image size. Returns: A preprocessed image `Tensor`. """ image = _decode_and_center_crop(image_bytes, image_size) image = tf.reshape(image, [image_size, image_size, 3]) image = normalize_image(image) image = tf.image.convert_image_dtype(image, dtype=dtype) return image def create_split( dataset_builder, batch_size, train, dtype=tf.float32, image_size=IMAGE_SIZE, cache=False, shuffle_buffer_size=2_000, prefetch=10, ): """Creates a split from the ImageNet dataset using TensorFlow Datasets. Args: dataset_builder: TFDS dataset builder for ImageNet. batch_size: the batch size returned by the data pipeline. train: Whether to load the train or evaluation split. dtype: data type of the image. image_size: The target size of the images. cache: Whether to cache the dataset. shuffle_buffer_size: Size of the shuffle buffer. prefetch: Number of items to prefetch in the dataset. Returns: A `tf.data.Dataset`. """ if train: train_examples = dataset_builder.info.splits['train'].num_examples split_size = train_examples // jax.process_count() start = jax.process_index() * split_size split = f'train[{start}:{start + split_size}]' else: validate_examples = dataset_builder.info.splits['validation'].num_examples split_size = validate_examples // jax.process_count() start = jax.process_index() * split_size split = f'validation[{start}:{start + split_size}]' def decode_example(example): if train: image = preprocess_for_train(example['image'], dtype, image_size) else: image = preprocess_for_eval(example['image'], dtype, image_size) return {'image': image, 'label': example['label']} ds = dataset_builder.as_dataset( split=split, decoders={ 'image': tfds.decode.SkipDecoding(), }, ) options = tf.data.Options() options.experimental_threading.private_threadpool_size = 48 ds = ds.with_options(options) if cache: ds = ds.cache() if train: ds = ds.repeat() ds = ds.shuffle(shuffle_buffer_size, seed=0) ds = ds.map(decode_example, num_parallel_calls=tf.data.experimental.AUTOTUNE) ds = ds.batch(batch_size, drop_remainder=True) if not train: ds = ds.repeat() ds = ds.prefetch(prefetch) return ds ================================================ FILE: examples/imagenet/main.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Main file for running the ImageNet example. This file is intentionally kept short. The majority for logic is in libraries that can be easily tested and imported in Colab. """ from absl import app from absl import flags from absl import logging from clu import platform import jax from ml_collections import config_flags import tensorflow as tf import train FLAGS = flags.FLAGS flags.DEFINE_string('workdir', None, 'Directory to store model data.') config_flags.DEFINE_config_file( 'config', None, 'File path to the training hyperparameter configuration.', lock_config=True, ) def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.experimental.set_visible_devices([], 'GPU') logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) logging.info('JAX local devices: %r', jax.local_devices()) # Add a note so that we can tell which task is which JAX host. # (Depending on the platform task 0 is not guaranteed to be host 0) platform.work_unit().set_task_status( f'process_index: {jax.process_index()}, ' f'process_count: {jax.process_count()}' ) platform.work_unit().create_artifact( platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' ) train.train_and_evaluate(FLAGS.config, FLAGS.workdir) if __name__ == '__main__': flags.mark_flags_as_required(['config', 'workdir']) app.run(main) ================================================ FILE: examples/imagenet/models.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Flax implementation of ResNet V1.5.""" # See issue #620. # pytype: disable=wrong-arg-count from functools import partial from typing import Any, Tuple from collections.abc import Callable, Sequence from flax import linen as nn import jax.numpy as jnp ModuleDef = Any class ResNetBlock(nn.Module): """ResNet block.""" filters: int conv: ModuleDef norm: ModuleDef act: Callable strides: tuple[int, int] = (1, 1) @nn.compact def __call__( self, x, ): residual = x y = self.conv(self.filters, (3, 3), self.strides)(x) y = self.norm()(y) y = self.act(y) y = self.conv(self.filters, (3, 3))(y) y = self.norm(scale_init=nn.initializers.zeros_init())(y) if residual.shape != y.shape: residual = self.conv( self.filters, (1, 1), self.strides, name='conv_proj' )(residual) residual = self.norm(name='norm_proj')(residual) return self.act(residual + y) class BottleneckResNetBlock(nn.Module): """Bottleneck ResNet block.""" filters: int conv: ModuleDef norm: ModuleDef act: Callable strides: tuple[int, int] = (1, 1) @nn.compact def __call__(self, x): residual = x y = self.conv(self.filters, (1, 1))(x) y = self.norm()(y) y = self.act(y) y = self.conv(self.filters, (3, 3), self.strides)(y) y = self.norm()(y) y = self.act(y) y = self.conv(self.filters * 4, (1, 1))(y) y = self.norm(scale_init=nn.initializers.zeros_init())(y) if residual.shape != y.shape: residual = self.conv( self.filters * 4, (1, 1), self.strides, name='conv_proj' )(residual) residual = self.norm(name='norm_proj')(residual) return self.act(residual + y) class ResNet(nn.Module): """ResNetV1.5.""" stage_sizes: Sequence[int] block_cls: ModuleDef num_classes: int num_filters: int = 64 dtype: Any = jnp.float32 act: Callable = nn.relu conv: ModuleDef = nn.Conv @nn.compact def __call__(self, x, train: bool = True): conv = partial(self.conv, use_bias=False, dtype=self.dtype) norm = partial( nn.BatchNorm, use_running_average=not train, momentum=0.9, epsilon=1e-5, dtype=self.dtype, axis_name='batch', ) x = conv( self.num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)], name='conv_init', )(x) x = norm(name='bn_init')(x) x = nn.relu(x) x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME') for i, block_size in enumerate(self.stage_sizes): for j in range(block_size): strides = (2, 2) if i > 0 and j == 0 else (1, 1) x = self.block_cls( self.num_filters * 2**i, strides=strides, conv=conv, norm=norm, act=self.act, )(x) x = jnp.mean(x, axis=(1, 2)) x = nn.Dense(self.num_classes, dtype=self.dtype)(x) x = jnp.asarray(x, self.dtype) return x ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock) ResNet34 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=ResNetBlock) ResNet50 = partial( ResNet, stage_sizes=[3, 4, 6, 3], block_cls=BottleneckResNetBlock ) ResNet101 = partial( ResNet, stage_sizes=[3, 4, 23, 3], block_cls=BottleneckResNetBlock ) ResNet152 = partial( ResNet, stage_sizes=[3, 8, 36, 3], block_cls=BottleneckResNetBlock ) ResNet200 = partial( ResNet, stage_sizes=[3, 24, 36, 3], block_cls=BottleneckResNetBlock ) ResNet18Local = partial( ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock, conv=nn.ConvLocal ) # Used for testing only. _ResNet1 = partial(ResNet, stage_sizes=[1], block_cls=ResNetBlock) _ResNet1Local = partial( ResNet, stage_sizes=[1], block_cls=ResNetBlock, conv=nn.ConvLocal ) ================================================ FILE: examples/imagenet/models_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.examples.imagenet.models.""" from absl.testing import absltest from absl.testing import parameterized import jax from jax import numpy as jnp import models jax.config.update('jax_disable_most_optimizations', True) class ResNetTest(parameterized.TestCase): """Test cases for ResNet v1.5 model definition.""" def test_resnet_model(self): """Tests ResNet V1.5 model definition and output (variables).""" rng = jax.random.key(0) model_def = models.ResNet50(num_classes=10, dtype=jnp.float32) variables = model_def.init(rng, jnp.ones((8, 224, 224, 3), jnp.float32)) self.assertLen(variables, 2) # Resnet50 model will create parameters for the following layers: # conv + batch_norm = 2 # BottleneckResNetBlock in stages: [3, 4, 6, 3] = 16 # Followed by a Dense layer = 1 self.assertLen(variables['params'], 19) @parameterized.product(model=(models.ResNet18, models.ResNet18Local)) def test_resnet_18_model(self, model): """Tests ResNet18 V1.5 model definition and output (variables).""" rng = jax.random.key(0) model_def = model(num_classes=2, dtype=jnp.float32) variables = model_def.init(rng, jnp.ones((1, 64, 64, 3), jnp.float32)) self.assertLen(variables, 2) self.assertLen(variables['params'], 11) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/imagenet/requirements.txt ================================================ absl-py==1.0.0 clu==0.0.6 flax==0.6.5 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html jax[cuda11_cudnn805]>=0.3.16 # change to jax[tpu] if running on tpus ml-collections==0.1.0 numpy==1.22.0 optax==0.1.3 tensorflow==2.11.1 tensorflow-datasets==4.4.0 ================================================ FILE: examples/imagenet/train.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ImageNet example. This script trains a ResNet-50 on the ImageNet dataset. The data is loaded using tensorflow_datasets. """ import functools import time from typing import Any from absl import logging from clu import metric_writers from clu import periodic_actions from flax import jax_utils from flax.training import checkpoints from flax.training import common_utils from flax.training import dynamic_scale as dynamic_scale_lib from flax.training import train_state import jax from jax import lax import jax.numpy as jnp from jax import random import ml_collections import optax import orbax.checkpoint as ocp import tensorflow as tf import tensorflow_datasets as tfds import input_pipeline import models NUM_CLASSES = 1000 def create_model(*, model_cls, half_precision, **kwargs): platform = jax.local_devices()[0].platform if half_precision: if platform == 'tpu': model_dtype = jnp.bfloat16 else: model_dtype = jnp.float16 else: model_dtype = jnp.float32 return model_cls(num_classes=NUM_CLASSES, dtype=model_dtype, **kwargs) def initialized(key, image_size, model): input_shape = (1, image_size, image_size, 3) @jax.jit def init(*args): return model.init(*args) variables = init({'params': key}, jnp.ones(input_shape, model.dtype)) return variables['params'], variables['batch_stats'] def cross_entropy_loss(logits, labels): one_hot_labels = common_utils.onehot(labels, num_classes=NUM_CLASSES) xentropy = optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels) return jnp.mean(xentropy) def compute_metrics(logits, labels): loss = cross_entropy_loss(logits, labels) accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) metrics = { 'loss': loss, 'accuracy': accuracy, } metrics = lax.pmean(metrics, axis_name='batch') return metrics def create_learning_rate_fn( config: ml_collections.ConfigDict, base_learning_rate: float, steps_per_epoch: int, ): """Create learning rate schedule.""" warmup_fn = optax.linear_schedule( init_value=0.0, end_value=base_learning_rate, transition_steps=config.warmup_epochs * steps_per_epoch, ) cosine_epochs = max(config.num_epochs - config.warmup_epochs, 1) cosine_fn = optax.cosine_decay_schedule( init_value=base_learning_rate, decay_steps=cosine_epochs * steps_per_epoch ) schedule_fn = optax.join_schedules( schedules=[warmup_fn, cosine_fn], boundaries=[config.warmup_epochs * steps_per_epoch], ) return schedule_fn def train_step(state, batch, learning_rate_fn): """Perform a single training step.""" def loss_fn(params): """loss function used for training.""" logits, new_model_state = state.apply_fn( {'params': params, 'batch_stats': state.batch_stats}, batch['image'], mutable=['batch_stats'], ) loss = cross_entropy_loss(logits, batch['label']) weight_penalty_params = jax.tree_util.tree_leaves(params) weight_decay = 0.0001 weight_l2 = sum( jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1 ) weight_penalty = weight_decay * 0.5 * weight_l2 loss = loss + weight_penalty return loss, (new_model_state, logits) step = state.step dynamic_scale = state.dynamic_scale lr = learning_rate_fn(step) if dynamic_scale: grad_fn = dynamic_scale.value_and_grad( loss_fn, has_aux=True, axis_name='batch' ) dynamic_scale, is_fin, aux, grads = grad_fn(state.params) # dynamic loss takes care of averaging gradients across replicas else: grad_fn = jax.value_and_grad(loss_fn, has_aux=True) aux, grads = grad_fn(state.params) # Re-use same axis_name as in the call to `pmap(...train_step...)` below. grads = lax.pmean(grads, axis_name='batch') new_model_state, logits = aux[1] metrics = compute_metrics(logits, batch['label']) metrics['learning_rate'] = lr new_state = state.apply_gradients( grads=grads, batch_stats=lax.pmean(new_model_state['batch_stats'], 'batch'), ) if dynamic_scale: # if is_fin == False the gradients contain Inf/NaNs and optimizer state and # params should be restored (= skip this step). new_state = new_state.replace( opt_state=jax.tree_util.tree_map( functools.partial(jnp.where, is_fin), new_state.opt_state, state.opt_state, ), params=jax.tree_util.tree_map( functools.partial(jnp.where, is_fin), new_state.params, state.params ), dynamic_scale=dynamic_scale, ) metrics['scale'] = dynamic_scale.scale return new_state, metrics def eval_step(state, batch): variables = {'params': state.params, 'batch_stats': state.batch_stats} logits = state.apply_fn(variables, batch['image'], train=False, mutable=False) return compute_metrics(logits, batch['label']) def prepare_tf_data(xs): """Convert a input batch from tf Tensors to numpy arrays.""" local_device_count = jax.local_device_count() def _prepare(x): # Use _numpy() for zero-copy conversion between TF and NumPy. x = x._numpy() # pylint: disable=protected-access # reshape (host_batch_size, height, width, 3) to # (local_devices, device_batch_size, height, width, 3) return x.reshape((local_device_count, -1) + x.shape[1:]) return jax.tree_util.tree_map(_prepare, xs) def create_input_iter( dataset_builder, batch_size, image_size, dtype, train, cache, shuffle_buffer_size, prefetch, ): ds = input_pipeline.create_split( dataset_builder, batch_size, image_size=image_size, dtype=dtype, train=train, cache=cache, shuffle_buffer_size=shuffle_buffer_size, prefetch=prefetch, ) it = map(prepare_tf_data, ds) it = jax_utils.prefetch_to_device(it, 2) return it class TrainState(train_state.TrainState): batch_stats: Any dynamic_scale: dynamic_scale_lib.DynamicScale def restore_checkpoint(state, workdir): return checkpoints.restore_checkpoint(workdir, state) def save_checkpoint(state, workdir): step = int(state.step) logging.info('Saving checkpoint step %d.', step) # Orbax can not handle host local arrays from pmap. Convert to global arrays. replicated_state = jax.tree_util.tree_map( ocp.utils.fully_replicated_host_local_array_to_global_array, state, ) checkpoints.save_checkpoint_multiprocess( workdir, replicated_state, step, keep=3 ) def create_train_state( rng, config: ml_collections.ConfigDict, model, image_size, learning_rate_fn ): """Create initial training state.""" dynamic_scale = None platform = jax.local_devices()[0].platform if config.half_precision and platform == 'gpu': dynamic_scale = dynamic_scale_lib.DynamicScale() params, batch_stats = initialized(rng, image_size, model) tx = optax.sgd( learning_rate=learning_rate_fn, momentum=config.momentum, nesterov=True, ) state = TrainState.create( apply_fn=model.apply, params=params, tx=tx, batch_stats=batch_stats, dynamic_scale=dynamic_scale, ) return state def train_and_evaluate( config: ml_collections.ConfigDict, workdir: str ) -> TrainState: """Execute model training and evaluation loop. Args: config: Hyperparameter configuration for training and evaluation. workdir: Directory where the tensorboard summaries are written to. Returns: Final TrainState. """ writer = metric_writers.create_default_writer( logdir=workdir, just_logging=jax.process_index() != 0 ) rng = random.key(0) image_size = 224 if config.batch_size % jax.device_count() > 0: raise ValueError('Batch size must be divisible by the number of devices') local_batch_size = config.batch_size // jax.process_count() platform = jax.local_devices()[0].platform if config.half_precision: if platform == 'tpu': input_dtype = tf.bfloat16 else: input_dtype = tf.float16 else: input_dtype = tf.float32 dataset_builder = tfds.builder(config.dataset) train_iter = create_input_iter( dataset_builder, local_batch_size, image_size, input_dtype, train=True, cache=config.cache, shuffle_buffer_size=config.shuffle_buffer_size, prefetch=config.prefetch, ) eval_iter = create_input_iter( dataset_builder, local_batch_size, image_size, input_dtype, train=False, cache=config.cache, shuffle_buffer_size=None, prefetch=config.prefetch, ) steps_per_epoch = ( dataset_builder.info.splits['train'].num_examples // config.batch_size ) if config.num_train_steps <= 0: num_steps = int(steps_per_epoch * config.num_epochs) else: num_steps = config.num_train_steps if config.steps_per_eval == -1: num_validation_examples = dataset_builder.info.splits[ 'validation' ].num_examples steps_per_eval = num_validation_examples // config.batch_size else: steps_per_eval = config.steps_per_eval steps_per_checkpoint = steps_per_epoch * 10 base_learning_rate = config.learning_rate * config.batch_size / 256.0 model_cls = getattr(models, config.model) model = create_model( model_cls=model_cls, half_precision=config.half_precision ) learning_rate_fn = create_learning_rate_fn( config, base_learning_rate, steps_per_epoch ) state = create_train_state(rng, config, model, image_size, learning_rate_fn) state = restore_checkpoint(state, workdir) # step_offset > 0 if restarting from checkpoint step_offset = int(state.step) p_train_step = jax.pmap( functools.partial(train_step, learning_rate_fn=learning_rate_fn), in_axes=(None, 0), out_axes=(None, 0), axis_name='batch', ) p_eval_step = jax.pmap(eval_step, in_axes=(None, 0), axis_name='batch') train_metrics = [] hooks = [] if jax.process_index() == 0 and config.profile: hooks += [ periodic_actions.Profile( num_profile_steps=3, profile_duration_ms=None, logdir=workdir ) ] train_metrics_last_t = time.time() logging.info('Initial compilation, this might take some minutes...') for step, batch in zip(range(step_offset, num_steps), train_iter): state, metrics = p_train_step(state, batch) for h in hooks: h(step) if step == step_offset: logging.info('Initial compilation completed.') if config.get('log_every_steps'): train_metrics.append(metrics) if (step + 1) % config.log_every_steps == 0: train_metrics = common_utils.get_metrics(train_metrics) summary = { f'train_{k}': v for k, v in jax.tree_util.tree_map( lambda x: x.mean(), train_metrics ).items() } summary['steps_per_second'] = config.log_every_steps / ( time.time() - train_metrics_last_t ) writer.write_scalars(step + 1, summary) train_metrics = [] train_metrics_last_t = time.time() if (step + 1) % steps_per_epoch == 0: epoch = step // steps_per_epoch eval_metrics = [] for _ in range(steps_per_eval): eval_batch = next(eval_iter) metrics = p_eval_step(state, eval_batch) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) summary = jax.tree_util.tree_map(lambda x: x.mean(), eval_metrics) logging.info( 'eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch, summary['loss'], summary['accuracy'] * 100, ) writer.write_scalars( step + 1, {f'eval_{key}': val for key, val in summary.items()} ) writer.flush() if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps: save_checkpoint(state, workdir) # Wait until computations are done before exiting jax.random.normal(jax.random.key(0), ()).block_until_ready() return state ================================================ FILE: examples/imagenet/train_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.examples.imagenet.train.""" import pathlib import tempfile from absl.testing import absltest from absl.testing import parameterized import jax from jax import random import tensorflow as tf import tensorflow_datasets as tfds # Local imports. import models import train from configs import default as default_lib jax.config.update('jax_disable_most_optimizations', True) class TrainTest(parameterized.TestCase): def setUp(self): super().setUp() # Make sure tf does not allocate gpu memory. tf.config.experimental.set_visible_devices([], 'GPU') def test_create_model(self): """Tests creating model.""" model = train.create_model(model_cls=models._ResNet1, half_precision=False) # pylint: disable=protected-access params, batch_stats = train.initialized(random.key(0), 224, model) variables = {'params': params, 'batch_stats': batch_stats} x = random.normal(random.key(1), (8, 224, 224, 3)) y = model.apply(variables, x, train=False) self.assertEqual(y.shape, (8, 1000)) def test_create_model_local(self): """Tests creating an unshared convolution model. Uses smaller inputs than `test_create_model` to due to higher compute. """ model = train.create_model( model_cls=models._ResNet1Local, half_precision=False ) # pylint: disable=protected-access params, batch_stats = train.initialized(random.key(0), 64, model) variables = {'params': params, 'batch_stats': batch_stats} x = random.normal(random.key(1), (1, 64, 64, 3)) y = model.apply(variables, x, train=False) self.assertEqual(y.shape, (1, 1000)) @parameterized.product(model=('_ResNet1', '_ResNet1Local')) def test_train_and_evaluate(self, model): """Tests training and evaluation loop using mocked data.""" # Create a temporary directory where tensorboard metrics are written. workdir = tempfile.mkdtemp() # Go two directories up to the root of the flax directory. flax_root_dir = pathlib.Path(__file__).parents[2] data_dir = str(flax_root_dir) + '/.tfds/metadata' # Define training configuration config = default_lib.get_config() config.model = model config.batch_size = 1 config.num_epochs = 1 config.num_train_steps = 1 config.steps_per_eval = 1 with tfds.testing.mock_data(num_examples=1, data_dir=data_dir): train.train_and_evaluate(workdir=workdir, config=config) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/linen_design_test/attention_simple.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import functools from pprint import pprint from typing import Any, Optional from collections.abc import Callable, Sequence from flax.core.frozen_dict import unfreeze from flax.linen import initializers from flax.linen import Module, compact, vmap from flax.linen.linear import PrecisionLike import jax from jax import lax, numpy as jnp, random class Dense(Module): features: int use_bias: bool = True kernel_init: Callable = initializers.lecun_normal() bias_init: Callable = initializers.zeros_init() dtype: Any = jnp.float32 precision: PrecisionLike = None @compact def __call__(self, inputs): inputs = jnp.asarray(inputs, self.dtype) kernel = self.param( 'kernel', self.kernel_init, (inputs.shape[-1], self.features) ) kernel = jnp.asarray(kernel, self.dtype) y = lax.dot_general( inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision, ) if self.use_bias: bias = self.param('bias', self.bias_init, (self.features,)) bias = jnp.asarray(bias, self.dtype) y = y + bias return y class SoftmaxAttn(Module): @compact def __call__(self, weights): norm_dims = tuple(range(weights.ndim // 2, weights.ndim)) return jax.nn.softmax(weights, axis=norm_dims) class Dropout(Module): rate: float @compact def __call__(self, x, deterministic=False, rng=None): if self.rate == 0.0: return x keep_prob = 1.0 - self.rate if deterministic: return x else: if rng is None: rng = self.scope.make_rng('dropout') mask = random.bernoulli(rng, p=keep_prob, shape=x.shape) return lax.select(mask, x / keep_prob, jnp.zeros_like(x)) class SoftmaxAttnWDropout(Module): rate: float = 0.0 deterministic: bool = False @compact def __call__(self, x): x = SoftmaxAttn()(x) x = Dropout(self.rate)(x, deterministic=self.deterministic) return x class RawDotProductAttention(Module): attn_module: Callable = SoftmaxAttn @compact def __call__(self, query, key, value, bias=None, dtype=jnp.float32): assert key.ndim == query.ndim assert key.ndim == value.ndim n = query.ndim attn_weights = lax.dot_general(query, key, (((n - 1,), (n - 1,)), ((), ()))) if bias is not None: attn_weights += bias attn_weights = self.attn_module()(attn_weights) attn_weights = attn_weights.astype(dtype) contract_dims = ( tuple(range(n - 1, attn_weights.ndim)), tuple(range(0, n - 1)), ) y = lax.dot_general(attn_weights, value, (contract_dims, ((), ()))) return y class DotProductAttention(Module): qkv_features: int | None = None out_features: int | None = None attn_module: Callable = SoftmaxAttn @compact def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32): qkv_features = self.qkv_features or inputs_q.shape[-1] out_features = self.out_features or inputs_q.shape[-1] QKVDense = functools.partial( Dense, features=qkv_features, use_bias=False, dtype=dtype ) query = QKVDense(name='query')(inputs_q) key = QKVDense(name='key')(inputs_kv) value = QKVDense(name='value')(inputs_kv) y = RawDotProductAttention(attn_module=self.attn_module)( query, key, value, bias=bias, dtype=dtype ) y = Dense(features=out_features, dtype=dtype, name='out')(y) return y # Trying out a slightly more compact vmap notation: def concise_vmap(module, in_axes, out_axes, axis_size=None, **var_specs): variable_axes = { k: v[0] for k, v in var_specs.items() if isinstance(v, Sequence) } splits = {k: v[1] for k, v in var_specs.items() if isinstance(v, Sequence)} return vmap( module, in_axes=in_axes, out_axes=out_axes, variable_axes=variable_axes, split_rngs=splits, axis_size=axis_size, ) class MultiHeadDotProductAttention(Module): qkv_features: int | None = None out_features: int | None = None attn_module: Callable = SoftmaxAttn batch_axes: Sequence[int] = (0,) num_heads: int = 1 broadcast_dropout: bool = False @compact def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32): qkv_features = self.qkv_features or inputs_q.shape[-1] out_features = self.out_features or inputs_q.shape[-1] # Now, vmap attn.__call__ along heads and spatial dims. Attn = concise_vmap( DotProductAttention, (None, None, None), -2, param=(0, True), dropout=(None, not self.broadcast_dropout), axis_size=self.num_heads, ) for axis in reversed(sorted(self.batch_axes)): Attn = concise_vmap( Attn, (axis, axis, axis), axis, param=(None, False), dropout=(None, not self.broadcast_dropout), ) attn = Attn( attn_module=self.attn_module, qkv_features=qkv_features // self.num_heads, out_features=out_features, ) # evaluate multi-headed-attention. y = attn(inputs_q, inputs_kv, bias) return y.mean(axis=-2) # run it. if __name__ == '__main__': inputs = jnp.ones((8, 97, 256)) rngs = {'params': random.key(0), 'dropout': random.key(1)} model = MultiHeadDotProductAttention( broadcast_dropout=False, qkv_features=256, out_features=256, attn_module=functools.partial(SoftmaxAttnWDropout, rate=0.1), num_heads=8, batch_axes=(0,), ) y, params = model.init_with_output(rngs, inputs, inputs) print('input shape: ', inputs.shape) print('parameter shapes:') pprint(jax.tree_util.tree_map(jnp.shape, unfreeze(params))) print('output shape: ', y.shape) ================================================ FILE: examples/linen_design_test/autoencoder.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Tuple from collections.abc import Iterable import jax from jax import numpy as jnp, random from flax import linen as nn from flax.linen import Module, Dense, compact # A concise MLP defined via lazy submodule initialization class MLP(Module): widths: Iterable @compact def __call__(self, x): for width in self.widths[:-1]: x = nn.relu(Dense(width)(x)) return Dense(self.widths[-1])(x) # An autoencoder exposes multiple methods, so we define all # submodules in setup(). class AutoEncoder(Module): encoder_widths: Iterable decoder_widths: Iterable input_shape: tuple = None def setup(self): # Submodules attached in `setup` get names via attribute assignment self.encoder = MLP(self.encoder_widths) self.decoder = MLP(self.decoder_widths + (jnp.prod(self.input_shape),)) def __call__(self, x): return self.decode(self.encode(x)) def encode(self, x): assert x.shape[-len(self.input_shape) :] == self.input_shape return self.encoder(jnp.reshape(x, (x.shape[0], -1))) def decode(self, z): z = self.decoder(z) x = nn.sigmoid(z) x = jnp.reshape(x, (x.shape[0],) + self.input_shape) return x # `ae` is a detached module, which has no variables. ae = AutoEncoder( encoder_widths=(32, 32, 32), decoder_widths=(32, 32, 32), input_shape=(28, 28, 1), ) # `ae.initialized` returns a materialized copy of `ae` by # running through an input to create submodules defined lazily. params = ae.init({"params": random.key(42)}, jnp.ones((1, 28, 28, 1))) # Now you can use `ae` as a normal object, calling any methods defined on AutoEncoder print("reconstruct", jnp.shape(ae.apply(params, jnp.ones((1, 28, 28, 1))))) print( "encoder", jnp.shape(ae.apply(params, jnp.ones((1, 28, 28, 1)), method=ae.encode)), ) # `ae.variables` is a frozen dict that looks like # {'params': {"decoder": {"Dense_0": {"bias": ..., "kernel": ...}, ...}} print("var shapes", jax.tree_util.tree_map(jnp.shape, params)) # TODO(avital, levskaya): resurrect this example once interactive api is restored. # You can access submodules defined in setup(), they are just references on # the autoencoder instance # encoder = ae.encoder # print("encoder var shapes", jax.tree_util.tree_map(jnp.shape, encoder.variables)) # # You can also access submodules that were defined in-line. # # (We may add syntactic sugar here, e.g. to allow `ae.encoder.Dense_0`) # encoder_dense0 = ae.encoder.children['Dense_0'] # print("encoder dense0 var shapes", jax.tree_util.tree_map(jnp.shape, encoder_dense0.variables)) ================================================ FILE: examples/linen_design_test/dense.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from jax import lax from flax.linen import initializers from collections.abc import Callable from flax.linen import Module, compact class Dense(Module): features: int kernel_init: Callable = initializers.lecun_normal() bias_init: Callable = initializers.zeros_init() use_bias: bool = True @compact def __call__(self, inputs): kernel = self.param( 'kernel', self.kernel_init, (inputs.shape[-1], self.features) ) y = lax.dot_general( inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), ) if self.use_bias: bias = self.param('bias', self.bias_init, (self.features,)) y = y + bias return y ================================================ FILE: examples/linen_design_test/linear_regression.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax from jax import numpy as jnp, jit from dense import Dense X = jnp.ones((1, 10)) Y = jnp.ones((5,)) model = Dense(features=5) @jit def predict(params): return model.apply({"params": params}, X) @jit def loss_fn(params): return jnp.mean(jnp.abs(Y - predict(params))) @jit def init_params(rng): mlp_variables = model.init({"params": rng}, X) return mlp_variables["params"] # Get initial parameters params = init_params(jax.random.key(42)) print("initial params", params) # Run SGD. for i in range(50): loss, grad = jax.value_and_grad(loss_fn)(params) print(i, "loss = ", loss, "Yhat = ", predict(params)) lr = 0.03 params = jax.tree_util.tree_map(lambda x, d: x - lr * d, params, grad) ================================================ FILE: examples/linen_design_test/mlp_explicit.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from pprint import pprint from typing import Optional from flax.deprecated import nn from dense import Dense from flax.linen import Module import jax from jax import numpy as jnp # Add `in_features` to the built-in Dense layer that normally works # via shape inference. class DenseExplicit(Dense): in_features: int | None = None def setup(self): # We feed a fake batch through the module, which initialized parameters. # Assuming we're in a jit, should use no FLOPs -- "just shape inference". self.__call__( jnp.zeros(( 1, self.in_features, )) ) class MLP(Module): def setup(self): self.dense1 = DenseExplicit(in_features=3, features=2) self.dense2 = DenseExplicit(in_features=2, features=1) # explicit instances are materialized immediately at init pprint(self.dense2.variables) # {'params': {'bias': DeviceArray([0.], dtype=float32), # 'kernel': DeviceArray([[ 0.6704609 ], # [-0.90477365]], dtype=float32)}} def __call__(self, x): return self.dense2(nn.relu(self.dense1(x))) # Return an initialized instance of MLP by only calling `setup`. rngkey = jax.random.key(10) init_variables = MLP().init({'params': rngkey}, jnp.ones((1, 3))) pprint(init_variables) # {'params': {'dense1': {'bias': DeviceArray([0., 0.], dtype=float32), # 'kernel': DeviceArray([[ 0.18307537, -0.38739476], # [-0.902451 , -0.5190721 ], # [ 0.51552075, 1.1169153 ]], dtype=float32)}, # 'dense2': {'bias': DeviceArray([0.], dtype=float32), # 'kernel': DeviceArray([[ 0.6704609 ], # [-0.90477365]], dtype=float32)}}} ================================================ FILE: examples/linen_design_test/mlp_inline.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax from jax import numpy as jnp from flax import linen as nn from collections.abc import Iterable from flax.linen import Module, compact from dense import Dense # Many NN layers and blocks are best described by a single function with inline variables. # In this case, variables are initialized during the first call. class MLP(Module): sizes: Iterable[int] @compact def __call__(self, x): for size in self.sizes[:-1]: x = Dense(size)(x) x = nn.relu(x) return Dense(self.sizes[-1])(x) # Return an initialized instance of MLP by calling `__call__` with an input batch, # initializing all variables. # # Variable shapes depend on the input shape passed in. rngkey = jax.random.key(10) model = MLP((2, 1)) x = jnp.ones((1, 3)) mlp_variables = model.init(rngkey, x) print(mlp_variables) # {'params': {'Dense_0': {'bias': DeviceArray([0.], dtype=float32), # 'kernel': DeviceArray([[-0.04267037], # [-0.51097125]], dtype=float32)}, # 'Dense_1': {'bias': DeviceArray([0., 0.], dtype=float32), # 'kernel': DeviceArray([[-6.3845289e-01, 6.0373604e-01], # [-5.9814966e-01, 5.1718324e-01], # [-6.2220657e-01, 5.8988278e-04]], dtype=float32)}}} print(model.apply(mlp_variables, x)) ================================================ FILE: examples/linen_design_test/mlp_lazy.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax from jax import numpy as jnp from flax import linen as nn from flax.linen import Module from pprint import pprint from dense import Dense # Here submodules are explicitly defined during init, but still materialized # lazily only once a first input is passed through and shapes are known. class MLP(Module): def setup(self): self.dense1 = Dense(features=2) self.dense2 = Dense(features=1) # shapes aren't yet known, so variables aren't materialized print(self.dense2.variables) # FrozenDict({}) def __call__(self, x): return self.dense2(nn.relu(self.dense1(x))) # Return an initialized instance of MLP by calling `__call__` with an input batch, # initializing all variables. # # Variable shapes depend on the input shape passed in. rngkey = jax.random.key(10) mlp_variables = MLP().init(rngkey, jnp.zeros((1, 3))) pprint(mlp_variables) # {'params': {'dense1': {'bias': DeviceArray([0., 0.], dtype=float32), # 'kernel': DeviceArray([[ 0.18307537, -0.38739476], # [-0.902451 , -0.5190721 ], # [ 0.51552075, 1.1169153 ]], dtype=float32)}, # 'dense2': {'bias': DeviceArray([0.], dtype=float32), # 'kernel': DeviceArray([[ 0.6704609 ], # [-0.90477365]], dtype=float32)}}} ================================================ FILE: examples/lm1b/README.md ================================================ ## Language modeling Trains a Transformer-based model (Vaswani *et al.*, 2017) on the One Billion Word Benchmark (lm1b; Chelba *et al.*, 2013). This example uses linear learning rate warmup and inverse square root learning rate schedule. Based off of Machine Translation `wmt` example. ### Requirements * TensorFlow datasets `lm1b` need to be downloaded and prepared. A sentencepiece tokenizer vocabulary will be automatically generated and saved on each training run. * This example additionally depends on the `sentencepiece` and `tensorflow-text` packages. ### How to run on Cloud TPUs Setup the TPU VM and install the Flax dependencies on it as described [here](https://cloud.google.com/tpu/docs/jax-pods) for creating pod slices, or [here](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm) for a single v3-8 TPU. First create a single TPUv3-8 VM and connect to it (you can find more detailed instructions [here](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm)): ``` ZONE=us-central1-a TPU_TYPE=v3-8 TPU_NAME=$USER-flax-lm1b gcloud alpha compute tpus tpu-vm create $TPU_NAME \ --zone $ZONE \ --accelerator-type $TPU_TYPE \ --version v2-alpha gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --zone $ZONE -- \ -L 6006:localhost:6006 ``` When connected install JAX: ``` pip install "jax[tpu]>=0.2.16" \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html ``` Then install Flax + the example dependencies: ``` git clone --depth=1 --branch=main https://github.com/google/flax cd flax pip install -e . cd examples/lm1b pip install -r requirements.txt ``` And finally start the training: ``` python3 main.py --workdir=$HOME/logs/lm1b_256 \ --config.per_device_batch_size=32 \ --jax_backend_target="grpc://192.168.0.2:8470" ``` Note that you might want to set `TFDS_DATA_DIR` as explained below. You probably also want to start the long-running command above in a `tmux` session and start some monitoring in a separate pane (note that we forwarded port 6006 locally above): ``` tensorboard --logdir=$HOME/logs ``` You should expect to get numbers similar to these: Hardware | config | Training time | Loss | TensorBoard.dev | Workdir -------- | ------- | ------------- | -------------- | ------------------------------------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------- TPU v3-8 | default | 13h18m | 3.127 | [2021-08-08](https://tensorboard.dev/experiment/n30WkNOZTJq3RHWD7wNslg/) | [gs://flax_public/examples/lm1b/default](https://console.cloud.google.com/storage/browser/flax_public/examples/lm1b/default) ### Downloading the LM1B Datasets We recommend downloading and preparing the TFDS datasets beforehand. For Cloud TPUs, we recommend using a cheap standard instance and saving the prepared TFDS data on a storage bucket, from where it can be loaded directly. Set the `TFDS_DATA_DIR` to your storage bucket path (`gs://`). You can download and prepare LM1B datasets using TFDS directly: `python -m tensorflow_datasets.scripts.download_and_prepare --datasets=lm1b` ================================================ FILE: examples/lm1b/configs/default.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Default Hyperparameter configuration.""" import ml_collections def get_config(): """Get the default hyperparameter configuration.""" config = ml_collections.ConfigDict() # Path to load or store sentencepiece vocab file. config.vocab_path = None # Vocabulary size if `vocab_path` is not given. config.vocab_size = 30_000 config.max_corpus_chars = 10**7 # Name of TFDS translation dataset to use. config.dataset_name = 'lm1b' # Optional name of TFDS translation dataset to use for evaluation. config.eval_dataset_name = 'lm1b' config.eval_split = 'test' # Per device batch size for training. config.per_device_batch_size = 32 # Per device batch size for training. config.eval_per_device_batch_size = 32 # Sampling temperature for language model inference. config.sampling_temperature = 0.6 # Top k cutoff for logit sampling. If 0 then no top-k cutoff is used. config.sampling_top_k = 20 config.num_train_steps = 500_000 # Number of steps to take during evaluation. Large enough to evaluate all. # Large enough to evaluate all samples: 306_688 / (32 * 8) = 1198 config.num_eval_steps = 2_000 # Number of steps to generate predictions. # -1 will use the whole eval dataset. config.num_predict_steps = -1 # Base learning rate. config.learning_rate = 0.0016 # Linear learning rate warmup. config.warmup_steps = 1000 # Cross entropy loss label smoothing. config.label_smoothing = 0.0 # Decay factor for AdamW style weight decay. config.weight_decay = 0.1 # Maximum length cutoff for training examples. config.max_target_length = 128 # Maximum length cutoff for eval examples. config.max_eval_target_length = 512 # Maximum length cutoff for predicted tokens. config.max_predict_length = 50 # Final logit transform uses embedding matrix transpose. config.logits_via_embedding = False # Number of transformer layers. config.num_layers = 6 # Size of query/key/value for attention. config.qkv_dim = 512 # Size of embeddings. config.emb_dim = 512 # Size of the MLP. config.mlp_dim = 2048 # Number of attention heads. config.num_heads = 8 # Dropout rate. config.dropout_rate = 0.1 # Attention dropout rate. config.attention_dropout_rate = 0.1 # Whether to save model checkpoints. config.save_checkpoints = True # Whether to restore from existing model checkpoints. config.restore_checkpoints = True # Save a checkpoint every these number of steps. config.checkpoint_every_steps = 10_000 # Frequency of eval during training, e.g. every 1_000 steps. config.eval_every_steps = 1_000 # Use bfloat16 mixed precision training instead of float32. config.use_bfloat16 = True # Integer for PRNG random seed. config.seed = 0 # Prompt for language model sampling, # taken from MaxText (https://github.com/google/maxtext/blob/main/MaxText/configs/base.yml). config.prompts = 'I love to ' # Parallelism config.mesh_axes = ['data', 'fsdp', 'tensor'] config.logical_axis_rules = [ ['mlp', 'tensor'], ['vocab', 'tensor'], ['embed', 'fsdp'], ['heads', 'tensor'], ] config.data_sharding = ['data'] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. # By default, product of the DCN axes should equal number of slices # and product of the ICI axes should equal number of devices per slice. # ICI (Inter-Chip Interconnection): A high-speed connection between # sets of TPU chips, which form the TPU network. # DCN (Data Center Network): A connection between the TPU networks; # not as fast as ICI. # ICI has around 100x the bandwidth of DCN, but it is not a general # purpose connection, which is why DCN is necessary for scaling to # extremely large ML models. config.dcn_data_parallelism = -1 # recommended DCN axis to be auto-sharded config.dcn_fsdp_parallelism = 1 config.dcn_tensor_parallelism = 1 config.ici_data_parallelism = 1 config.ici_fsdp_parallelism = -1 # recommended ICI axis to be auto-sharded config.ici_tensor_parallelism = 1 return config ================================================ FILE: examples/lm1b/input_pipeline.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Input pipeline for a LM1B dataset.""" import os from typing import Dict, Optional, List, Union from clu import deterministic_data import ml_collections import tensorflow as tf import tensorflow_datasets as tfds import tokenizer AUTOTUNE = tf.data.experimental.AUTOTUNE Features = dict[str, tf.Tensor] class NormalizeFeatureNamesOp: """Normalizes feature names to 'inputs' and 'targets'.""" def __init__(self, ds_info: tfds.core.DatasetInfo): self.ds_info = ds_info def __call__(self, features: Features) -> Features: features['inputs'] = features.pop('text') # Unnecessary step used for uniformizing with examples/wmt. features['targets'] = features['inputs'] return features def get_raw_dataset( dataset_builder: tfds.core.DatasetBuilder, split: str ) -> tf.data.Dataset: """Loads a raw text dataset and normalizes feature keys. Args: dataset_builder: TFDS dataset builder that can build `split`. split: Split to use. This must be the full split. We shard the split across multiple hosts and currently don't support sharding subsplits. Returns: Dataset with source and target language features mapped to 'inputs' and 'targets'. """ num_examples = dataset_builder.info.splits[split].num_examples per_host_split = deterministic_data.get_read_instruction_for_host( split, num_examples, drop_remainder=False ) ds = dataset_builder.as_dataset(split=per_host_split, shuffle_files=False) ds = ds.map( NormalizeFeatureNamesOp(dataset_builder.info), num_parallel_calls=AUTOTUNE ) return ds def pack_dataset( dataset: tf.data.Dataset, key2length: int | dict[str, int], keys: list[str] | None = None, ) -> tf.data.Dataset: """Creates a 'packed' version of a dataset on-the-fly. Adapted from the mesh-tf implementation. This is meant to replace the irritation of having to create a separate "packed" version of a dataset to train efficiently on TPU. Each example in the output dataset represents several examples in the input dataset. For each key in the input dataset, two additional keys are created: _segmentation: an int32 tensor identifying the parts representing the original example. _position: an int32 tensor identifying the position within the original example. Example: Two input examples get combined to form an output example. The input examples are: {"inputs": [8, 7, 1, 0], "targets":[4, 1, 0]} {"inputs": [2, 3, 4, 1], "targets":[5, 6, 1]} The output example is: { "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0] "inputs_segmentation": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0] "inputs_position": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0] "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0] "targets_segmentation": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0] "targets_position": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0] } 0 represents padding in both the inputs and the outputs. Sequences in the incoming examples are truncated to length "length", and the sequences in the output examples all have fixed (padded) length "length". Args: dataset: a tf.data.Dataset key2length: an integer, or a dict from feature-key to integer keys: a list of strings (e.g. ["inputs", "targets"]) Returns: a tf.data.Dataset """ shapes = tf.nest.map_structure(lambda spec: spec.shape, dataset.element_spec) if keys is None: keys = list(shapes.keys()) for k in keys: if k not in shapes: raise ValueError( 'Key %s not found in dataset. Available keys are %s' % (k, shapes.keys()) ) if not shapes[k].is_compatible_with(tf.TensorShape([None])): # type: ignore[wrong-arg-types] raise ValueError('Tensors to be packed must be one-dimensional.') # make sure that the length dictionary contains all keys as well as the # keys suffixed by "_segmentation" and "_position" if isinstance(key2length, int): key2length = {k: key2length for k in keys} for k in keys: for suffix in ['_segmentation', '_position']: key2length[k + suffix] = key2length[k] # trim to length dataset = dataset.map( lambda x: {k: x[k][: key2length[k]] for k in keys}, num_parallel_calls=AUTOTUNE, ) # Setting batch_size=length ensures that the concatenated sequences (if they # have length >=1) are sufficient to fill at least one packed example. batch_size = max(key2length.values()) dataset = dataset.padded_batch( batch_size, padded_shapes={k: [-1] for k in keys} ) dataset = _pack_with_tf_ops(dataset, keys, key2length) # Set the Tensor shapes correctly since they get lost in the process. def my_fn(x): return {k: tf.reshape(v, [key2length[k]]) for k, v in x.items()} return dataset.map(my_fn, num_parallel_calls=AUTOTUNE) def _pack_with_tf_ops( dataset: tf.data.Dataset, keys: list[str], key2length: dict[str, int] ) -> tf.data.Dataset: """Helper-function for packing a dataset which has already been batched. Helper for pack_dataset() Uses tf.while_loop. Args: dataset: a dataset containing padded batches of examples. keys: a list of strings key2length: a dict from feature-key to integer Returns: a dataset. """ empty_example = {} for k in keys: empty_example[k] = tf.zeros([0], dtype=tf.int32) empty_example[k + '_position'] = tf.zeros([0], dtype=tf.int32) keys_etc = empty_example.keys() def write_packed_example(partial, outputs): new_partial = empty_example.copy() new_outputs = {} for k in keys_etc: new_outputs[k] = outputs[k].write( outputs[k].size(), tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]), ) return new_partial, new_outputs def map_fn(x): """Internal function to flat_map over. Consumes a batch of input examples and produces a variable number of output examples. Args: x: a single example Returns: a tf.data.Dataset """ partial = empty_example.copy() i = tf.zeros([], dtype=tf.int32) dynamic_batch_size = tf.shape(x[keys[0]])[0] outputs = {} for k in keys: outputs[k] = tf.TensorArray( tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] ) outputs[k + '_position'] = tf.TensorArray( tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] ) def body_fn(i, partial, outputs): """Body function for while_loop. Args: i: integer scalar partial: dictionary of Tensor (partially-constructed example) outputs: dictionary of TensorArray Returns: A triple containing the new values of the inputs. """ can_append = True one_example = {} for k in keys: val = tf.cast(x[k][i], tf.int32) val = val[: tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))] one_example[k] = val for k in keys: can_append = tf.logical_and( can_append, tf.less_equal( tf.size(partial[k]) + tf.size(one_example[k]), key2length[k] ), ) def false_fn(): return write_packed_example(partial, outputs) def true_fn(): return partial, outputs partial, outputs = tf.cond(can_append, true_fn, false_fn) new_partial = {} for k in keys: new_seq = one_example[k][: key2length[k]] new_seq_len = tf.size(new_seq) new_partial[k] = tf.concat([partial[k], new_seq], 0) new_partial[k + '_position'] = tf.concat( [partial[k + '_position'], tf.range(new_seq_len)], 0 ) partial = new_partial return i + 1, partial, outputs # For loop over all examples in the batch. i, partial, outputs = tf.while_loop( cond=lambda *_: True, body=body_fn, loop_vars=(i, partial, outputs), shape_invariants=( tf.TensorShape([]), {k: tf.TensorShape([None]) for k in keys_etc}, # type: ignore[wrong-arg-types] {k: tf.TensorShape(None) for k in keys_etc}, # type: ignore[wrong-arg-types] ), maximum_iterations=dynamic_batch_size, ) _, outputs = write_packed_example(partial, outputs) packed = {k: outputs[k].stack() for k in keys_etc} for k in keys: packed[k + '_segmentation'] = tf.cumsum( tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1 ) * tf.cast(tf.not_equal(packed[k], 0), tf.int32) return packed dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE) return dataset.unbatch() # ----------------------------------------------------------------------------- # Main dataset prep routines. # ----------------------------------------------------------------------------- def preprocess_data( dataset, shuffle: bool, num_epochs: int | None = 1, pack_examples: bool = True, shuffle_buffer_size: int = 1024, max_length: int = 512, batch_size: int = 256, drop_remainder: bool = True, prefetch_size: int = AUTOTUNE, ): """Shuffle and batch/pack the given dataset.""" def length_filter(max_len): def filter_fn(x): source, target = x['inputs'], x['targets'] l = tf.maximum(tf.shape(source)[0], tf.shape(target)[0]) return tf.less(l, max_len + 1) return filter_fn if max_length > 0: dataset = dataset.filter(length_filter(max_length)) if shuffle: dataset = dataset.shuffle(shuffle_buffer_size) dataset = dataset.repeat(num_epochs) if pack_examples: dataset = pack_dataset(dataset, max_length) dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) else: # simple (static-shape) padded batching dataset = dataset.padded_batch( batch_size, padded_shapes={'inputs': max_length, 'targets': max_length}, padding_values={'inputs': 0, 'targets': 0}, drop_remainder=drop_remainder, ) if prefetch_size: dataset = dataset.prefetch(prefetch_size) return dataset def get_datasets( config: ml_collections.ConfigDict, *, n_devices: int, vocab_path: str | None = None, ): """Load and return dataset of batched examples for use during training.""" if vocab_path is None: vocab_path = os.path.expanduser('~/lm1b_sentencepiece_model') train_ds_builder = tfds.builder(config.dataset_name) train_data = get_raw_dataset(train_ds_builder, 'train') if config.eval_dataset_name: eval_ds_builder = tfds.builder(config.eval_dataset_name) else: eval_ds_builder = train_ds_builder eval_data = get_raw_dataset(eval_ds_builder, config.eval_split) # Tokenize data. sp_tokenizer = tokenizer.load_or_train_tokenizer( train_data, vocab_path=vocab_path, vocab_size=config.vocab_size, max_corpus_chars=config.max_corpus_chars, ) train_data = train_data.map( tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE ) eval_data = eval_data.map( tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE ) batch_size = config.per_device_batch_size * n_devices if config.eval_per_device_batch_size > 0: eval_batch_size = config.eval_per_device_batch_size * n_devices else: eval_batch_size = batch_size train_ds = preprocess_data( train_data, shuffle=True, num_epochs=None, pack_examples=True, batch_size=batch_size, max_length=config.max_target_length, ) eval_ds = preprocess_data( eval_data, shuffle=False, pack_examples=False, batch_size=eval_batch_size, max_length=config.max_eval_target_length, ) predict_ds = preprocess_data( eval_data, shuffle=False, pack_examples=False, batch_size=eval_batch_size, max_length=config.max_predict_length, drop_remainder=False, ) return train_ds, eval_ds, predict_ds, sp_tokenizer ================================================ FILE: examples/lm1b/input_pipeline_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import pathlib import sys import tempfile from absl.testing import absltest import tensorflow_datasets as tfds from configs import default import input_pipeline # We just use different values here to verify that the input pipeline uses the # the correct value for the 3 different datasets. _TARGET_LENGTH = 32 _EVAL_TARGET_LENGTH = 48 _PREDICT_TARGET_LENGTH = 64 class InputPipelineTest(absltest.TestCase): def setUp(self): super().setUp() if sys.version_info >= (3, 13): self.skipTest('Test (and tensorflow-text) does not suport Python 3.13+') self.train_ds, self.eval_ds, self.predict_ds = self._get_datasets() def _get_datasets(self): config = default.get_config() config.per_device_batch_size = 1 config.eval_per_device_batch_size = 2 config.vocab_size = 32 config.max_corpus_chars = 1000 config.max_target_length = _TARGET_LENGTH config.max_eval_target_length = _EVAL_TARGET_LENGTH config.max_predict_length = _PREDICT_TARGET_LENGTH vocab_path = os.path.join(tempfile.mkdtemp(), 'sentencepiece_model') # Go two directories up to the root of the flax directory. flax_root_dir = pathlib.Path(__file__).parents[2] data_dir = str(flax_root_dir) + '/.tfds/metadata' # pylint: disable=unused-variable with tfds.testing.mock_data(num_examples=128, data_dir=data_dir): train_ds, eval_ds, predict_ds, _ = input_pipeline.get_datasets( n_devices=2, config=config, vocab_path=vocab_path ) return train_ds, eval_ds, predict_ds def test_train_ds(self): expected_shape = [2, _TARGET_LENGTH] # 2 devices. # For training we pack multiple short examples in one example. # *_position and *_segmentation indicate the boundaries. for batch in self.train_ds.take(3): self.assertEqual( {k: v.shape.as_list() for k, v in batch.items()}, { 'inputs': expected_shape, 'inputs_position': expected_shape, 'inputs_segmentation': expected_shape, 'targets': expected_shape, 'targets_position': expected_shape, 'targets_segmentation': expected_shape, }, ) def test_eval_ds(self): expected_shape = [4, _EVAL_TARGET_LENGTH] # 2 devices. for batch in self.eval_ds.take(3): self.assertEqual( {k: v.shape.as_list() for k, v in batch.items()}, { 'inputs': expected_shape, 'targets': expected_shape, }, ) def test_predict_ds(self): expected_shape = [4, _PREDICT_TARGET_LENGTH] # 2 devices. for batch in self.predict_ds.take(3): self.assertEqual( {k: v.shape.as_list() for k, v in batch.items()}, { 'inputs': expected_shape, 'targets': expected_shape, }, ) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/lm1b/main.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Main file for running the Language Modelling example with LM1B. This file is intentionally kept short. The majority for logic is in libraries that can be easily tested and imported in Colab. """ from absl import app from absl import flags from absl import logging from clu import platform import jax from ml_collections import config_flags import tensorflow as tf import train FLAGS = flags.FLAGS flags.DEFINE_string('workdir', None, 'Directory to store model data.') config_flags.DEFINE_config_file( 'config', 'configs/default.py', 'File path to the training hyperparameter configuration.', lock_config=True, ) flags.mark_flags_as_required(['config', 'workdir']) def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.experimental.set_visible_devices([], 'GPU') logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) logging.info('JAX local devices: %r', jax.local_devices()) # Add a note so that we can tell which task is which JAX host. # (Depending on the platform task 0 is not guaranteed to be host 0) platform.work_unit().set_task_status( f'process_index: {jax.process_index()}, ' f'process_count: {jax.process_count()}' ) platform.work_unit().create_artifact( platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' ) train.train_and_evaluate(FLAGS.config, FLAGS.workdir) if __name__ == '__main__': jax.config.config_with_absl() app.run(main) ================================================ FILE: examples/lm1b/models.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Transformer-based language model. Reusing decoder only model from examples/wmt. """ # pylint: disable=attribute-defined-outside-init # See issue #620. # pytype: disable=wrong-arg-count # pytype: disable=wrong-keyword-args # pytype: disable=attribute-error from typing import Any, Optional from collections.abc import Callable from flax import linen as nn from flax import struct from jax import lax import jax.numpy as jnp import numpy as np @struct.dataclass class TransformerConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" vocab_size: int output_vocab_size: int share_embeddings: bool = False logits_via_embedding: bool = False dtype: Any = jnp.float32 emb_dim: int = 512 num_heads: int = 8 num_layers: int = 6 qkv_dim: int = 512 mlp_dim: int = 2048 max_len: int = 2048 dropout_rate: float = 0.1 attention_dropout_rate: float = 0.1 deterministic: bool = False decode: bool = False kernel_init: Callable = nn.initializers.xavier_uniform() bias_init: Callable = nn.initializers.normal(stddev=1e-6) posemb_init: Callable | None = None def shift_right(x, axis=1): """Shift the input to the right by padding and slicing on axis.""" pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) padded = jnp.pad( x, pad_widths, mode='constant', constant_values=x.dtype.type(0) ) return lax.dynamic_slice_in_dim(padded, 0, padded.shape[axis] - 1, axis) def shift_inputs(x, segment_ids=None, axis=1): """Shift inputs and replace EOS by 0 for packed inputs.""" shifted = shift_right(x, axis=axis) # For packed targets, the first shifted token of a new sequence is made # 0, rather than being the EOS token for the last sequence. if segment_ids is not None: shifted *= segment_ids == shift_right(segment_ids, axis=axis) return shifted def sinusoidal_init(max_len=2048, min_scale=1.0, max_scale=10000.0): """1D Sinusoidal Position Embedding Initializer. Args: max_len: maximum possible length for the input. min_scale: float: minimum frequency-scale in sine grating. max_scale: float: maximum frequency-scale in sine grating. Returns: output: init function returning `(1, max_len, d_feature)` """ def init(key, shape, dtype=np.float32): """Sinusoidal init.""" del key, dtype d_feature = shape[-1] pe = np.zeros((max_len, d_feature), dtype=np.float32) position = np.arange(0, max_len)[:, np.newaxis] scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1) div_term = min_scale * np.exp(np.arange(0, d_feature // 2) * scale_factor) pe[:, : d_feature // 2] = np.sin(position * div_term) pe[:, d_feature // 2 : 2 * (d_feature // 2)] = np.cos(position * div_term) pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] return jnp.array(pe) return init class AddPositionEmbs(nn.Module): """Adds (optionally learned) positional embeddings to the inputs. Args: config: TransformerConfig dataclass containing hyperparameters. decode: whether to run in single-position autoregressive mode. """ config: TransformerConfig decode: bool = False @nn.compact def __call__(self, inputs, inputs_positions=None): """Applies AddPositionEmbs module. By default this layer uses a fixed sinusoidal embedding table. If a learned position embedding is desired, pass an initializer to posemb_init in the configuration. Args: inputs: input data. inputs_positions: input position indices for packed sequences. Returns: output: `(bs, timesteps, in_dim)` """ config = self.config # inputs.shape is (batch_size, seq_len, emb_dim) assert inputs.ndim == 3, ( 'Number of dimensions should be 3, but it is: %d' % inputs.ndim ) length = inputs.shape[1] pos_emb_shape = (1, config.max_len, inputs.shape[-1]) if config.posemb_init is None: # Use a fixed (non-learned) sinusoidal position embedding. pos_embedding = sinusoidal_init(max_len=config.max_len)( None, pos_emb_shape, None ) else: pos_embedding = self.param( 'pos_embedding', config.posemb_init, pos_emb_shape ) pe = pos_embedding[:, :length, :] # We use a cache position index for tracking decoding position. if self.decode: is_initialized = self.has_variable('cache', 'cache_index') cache_index = self.variable( 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.uint32) ) if is_initialized: i = cache_index.value cache_index.value = i + 1 _, _, df = pos_embedding.shape pe = lax.dynamic_slice(pos_embedding, jnp.array((0, i, 0)), (1, 1, df)) if inputs_positions is None: # normal unpacked case: return inputs + pe else: # for packed data we need to use known position indices: return inputs + jnp.take(pe[0], inputs_positions, axis=0) class MlpBlock(nn.Module): """Transformer MLP / feed-forward block. Args: config: TransformerConfig dataclass containing hyperparameters. out_dim: optionally specify out dimension. """ config: TransformerConfig out_dim: int | None = None @nn.compact def __call__(self, inputs): """Applies Transformer MlpBlock module.""" config = self.config actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim x = nn.Dense( config.mlp_dim, dtype=config.dtype, kernel_init=nn.with_logical_partitioning( config.kernel_init, ('embed', 'mlp') ), bias_init=nn.with_logical_partitioning(config.bias_init, ('mlp',)), )(inputs) x = nn.relu(x) x = nn.Dropout(rate=config.dropout_rate)( x, deterministic=config.deterministic ) output = nn.Dense( actual_out_dim, dtype=config.dtype, kernel_init=nn.with_logical_partitioning( config.kernel_init, ('mlp', 'embed') ), bias_init=nn.with_logical_partitioning(config.bias_init, ('embed',)), )(x) output = nn.Dropout(rate=config.dropout_rate)( output, deterministic=config.deterministic ) return output class EncoderDecoder1DBlock(nn.Module): """Transformer encoder-decoder layer. Args: config: TransformerConfig dataclass containing hyperparameters. """ config: TransformerConfig @nn.compact def __call__(self, inputs, decoder_mask=None, encoder_decoder_mask=None): """Applies EncoderDecoder1DBlock module. Args: inputs: input data for decoder decoder_mask: decoder self-attention mask. encoder_decoder_mask: encoder-decoder attention mask. Returns: output after transformer encoder-decoder block. """ config = self.config # Decoder block. assert inputs.ndim == 3 x = nn.LayerNorm( dtype=config.dtype, bias_init=nn.with_logical_partitioning( nn.initializers.zeros, ('embed',) ), scale_init=nn.with_logical_partitioning( nn.initializers.ones, ('embed',) ), )(inputs) x = nn.MultiHeadDotProductAttention( num_heads=config.num_heads, dtype=config.dtype, qkv_features=config.qkv_dim, kernel_init=nn.with_logical_partitioning( config.kernel_init, ('embed', 'kv') ), bias_init=nn.with_logical_partitioning(config.bias_init, ('embed',)), use_bias=False, broadcast_dropout=False, dropout_rate=config.attention_dropout_rate, deterministic=config.deterministic, decode=config.decode, )(x, mask=decoder_mask) x = nn.Dropout(rate=config.dropout_rate)( x, deterministic=config.deterministic ) x = x + inputs # MLP block. z = nn.LayerNorm( dtype=config.dtype, bias_init=nn.with_logical_partitioning( nn.initializers.zeros, ('embed',) ), scale_init=nn.with_logical_partitioning( nn.initializers.ones, ('embed',) ), )(x) z = MlpBlock(config=config)(z) return x + z class Decoder(nn.Module): """Transformer Model Decoder for sequence to sequence translation. Args: config: TransformerConfig dataclass containing hyperparameters. shared_embedding: a shared embedding layer to use. """ config: TransformerConfig shared_embedding: Any = None @nn.compact def __call__( self, inputs, inputs_positions=None, inputs_segmentation=None, decoder_mask=None, encoder_decoder_mask=None, ): """Applies Transformer model on the inputs. Args: encoded: encoded input data from encoder. inputs: input data. inputs_positions: input subsequence positions for packed examples. inputs_segmentation: input segmentation info for packed examples. decoder_mask: decoder self-attention mask. encoder_decoder_mask: encoder-decoder attention mask. Returns: output of a transformer decoder. """ config = self.config assert inputs.ndim == 2 # (batch, len) # Target Embedding if self.shared_embedding is None: output_embed = nn.Embed( num_embeddings=config.output_vocab_size, features=config.emb_dim, embedding_init=nn.with_logical_partitioning( nn.initializers.normal(stddev=1.0), ('vocab', 'embed') ), ) else: output_embed = self.shared_embedding y = inputs.astype('int32') if not config.decode: y = shift_inputs(y, segment_ids=inputs_segmentation) y = output_embed(y) y = AddPositionEmbs( config=config, decode=config.decode, name='posembed_output' )(y, inputs_positions=inputs_positions) y = nn.Dropout(rate=config.dropout_rate)( y, deterministic=config.deterministic ) y = y.astype(config.dtype) # Target-Input Decoder for lyr in range(config.num_layers): y = EncoderDecoder1DBlock( config=config, name=f'encoderdecoderblock_{lyr}' )(y, decoder_mask=decoder_mask, encoder_decoder_mask=encoder_decoder_mask) y = nn.LayerNorm( dtype=config.dtype, name='encoderdecoder_norm', bias_init=nn.with_logical_partitioning( nn.initializers.zeros, ('embed',) ), scale_init=nn.with_logical_partitioning( nn.initializers.ones, ('embed',) ), )(y) # Decoded Logits if config.logits_via_embedding: # Use the transpose of embedding matrix for logit transform. logits = output_embed.attend(y.astype(jnp.float32)) # Correctly normalize pre-softmax logits for this shared case. logits = logits / jnp.sqrt(y.shape[-1]) else: logits = nn.Dense( config.output_vocab_size, dtype=config.dtype, kernel_init=nn.with_logical_partitioning( config.kernel_init, ('embed', 'vocab') ), bias_init=nn.with_logical_partitioning(config.bias_init, ('vocab',)), name='logitdense', )(y) return logits class TransformerLM(nn.Module): """Transformer pure decoder stack for language modelling. Args: config: TransformerConfig dataclass containing hyperparameters. """ config: TransformerConfig @nn.compact def __call__(self, inputs, inputs_positions=None, inputs_segmentation=None): """Applies TransformerLM on the inputs. Args: inputs: target data. inputs_positions: input subsequence positions for packed examples. inputs_segmentation: input segmentation info for packed examples. Returns: logits array from transformer decoder. """ config = self.config # Make padding attention masks. if config.decode: # for fast autoregressive decoding we use no decoder mask decoder_mask = None else: decoder_mask = nn.combine_masks( nn.make_attention_mask(inputs > 0, inputs > 0, dtype=config.dtype), nn.make_causal_mask(inputs, dtype=config.dtype), ) # Add segmentation block-diagonal attention masks if using segmented data. if inputs_segmentation is not None: decoder_mask = nn.combine_masks( decoder_mask, nn.make_attention_mask( inputs_segmentation, inputs_segmentation, jnp.equal, dtype=config.dtype, ), ) logits = Decoder(config=config, shared_embedding=None, name='decoder')( inputs, inputs_positions=inputs_positions, inputs_segmentation=inputs_segmentation, decoder_mask=decoder_mask, encoder_decoder_mask=None, ) return logits.astype(self.config.dtype) ================================================ FILE: examples/lm1b/requirements.txt ================================================ absl-py==1.4.0 clu==0.0.9 flax==0.6.11 jax==0.4.13 --find-links https://storage.googleapis.com/jax-releases/jax_releases.html jaxlib==0.4.13+cuda11.cudnn82 # Make sure CUDA version matches the base image. ml-collections==0.1.1 numpy==1.24.3 optax==0.1.5 sentencepiece==0.1.99 tensorflow==2.13.0 tensorflow-datasets==4.9.2 tensorflow-text==2.13.0 ================================================ FILE: examples/lm1b/temperature_sampler.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Fast decoding routines for inference from a trained language model.""" from jax import lax from jax import random import jax.numpy as jnp # Constants # The default End-of-Sentence token id is 2 (SentencePiece). EOS_ID = 2 def temperature_sample( prompt_inputs, init_cache, tokens_to_logits, prng_key, temperature=1.0, topk=20, eos_token=EOS_ID, ): """Temperature sampling for language model generation. Args: prompt_inputs: array: [batch_size, max_decode_len] int32 sequence of tokens. init_cache: flax attention cache. tokens_to_logits: fast autoregressive decoder function taking single token slices and cache and returning next-token logits and updated cache. prng_key: JAX PRNGKey. temperature: float: sampling temperature factor. As it approaches zero this becomes equivalent to greedy sampling. topk: integer: if nonzero only use the top-k logits to sample next token, if zero don't use any cutoff and sample from full logits over vocabulary. eos_token: int: end-of-sentence token for target vocabulary. Returns: Array of sampled sequences: [batch_size, max_decode_len] """ batch_size = prompt_inputs.shape[0] max_decode_len = prompt_inputs.shape[1] end_marker = jnp.array(eos_token) temperature = jnp.array(temperature) # Initialize sampling loop state. # initial loop PRNGKey rng0 = prng_key # loop position counter. i0 = jnp.array(-1) # per batch-item holding current token in loop. token0 = jnp.zeros((batch_size, 1), dtype=jnp.int32) # per batch-item state bit indicating if sentence has finished. ended0 = jnp.zeros((batch_size, 1), dtype=jnp.bool_) # (batch, length) array containing prefix prompt tokens for sampling loop # as well as the generated output of newly sampled tokens. sequences0 = prompt_inputs # Sampling loop state is stored in a simple tuple. sampling_loop_init_state = (i0, sequences0, init_cache, token0, ended0, rng0) def sampling_loop_cond_fn(state): """Sampling loop termination condition.""" (i, _, _, _, ended, _) = state # Have we reached max decoding length? not_at_end = i < max_decode_len - 1 # Have all sampled sequences reached an end marker? all_sequences_ended = jnp.all(ended) return not_at_end & (~all_sequences_ended) def sampling_loop_body_fn(state): """Sampling loop state update.""" i, sequences, cache, cur_token, ended, rng = state # Split RNG for sampling. rng1, rng2 = random.split(rng) # Call fast-decoder model on current tokens to get next-position logits. logits, new_cache = tokens_to_logits(cur_token, cache) # Sample next token from logits. # TODO(levskaya): add top-p "nucleus" sampling option. if topk: # Get top-k logits and their indices, sample within these top-k tokens. topk_logits, topk_idxs = lax.top_k(logits, topk) topk_token = jnp.expand_dims( random.categorical(rng1, topk_logits / temperature).astype(jnp.int32), axis=-1, ) # Return the original indices corresponding to the sampled top-k tokens. next_token = jnp.squeeze( jnp.take_along_axis(topk_idxs, topk_token, axis=-1), axis=-1 ) else: next_token = random.categorical(rng1, logits / temperature).astype( jnp.int32 ) # Only use sampled tokens if we're past provided prefix tokens. out_of_prompt = sequences[:, i + 1] == 0 next_token = ( next_token * out_of_prompt + sequences[:, i + 1] * ~out_of_prompt ) # If end-marker reached for batch item, only emit padding tokens. next_token_or_endpad = next_token[None] * ~ended ended |= next_token_or_endpad == end_marker # Add current sampled tokens to recorded sequences. new_sequences = lax.dynamic_update_slice( sequences, next_token_or_endpad, (0, i + 1) ) return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended, rng2) # Run sampling loop and collect final state. final_state = lax.while_loop( sampling_loop_cond_fn, sampling_loop_body_fn, sampling_loop_init_state ) # Pick part of the state corresponding to the sampled sequences. final_sequences = final_state[1] return final_sequences ================================================ FILE: examples/lm1b/temperature_sampler_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from absl.testing import absltest import jax import jax.numpy as jnp import numpy as np from temperature_sampler import temperature_sample jax.config.update('jax_disable_most_optimizations', True) class TestTemperatureSampler(absltest.TestCase): def test_temperature_sampler(self): tokens = jnp.array([[5, 0, 0, 0]], dtype=jnp.int32) cache = None key = jax.random.PRNGKey(0) def tokens_to_logits(tokens, cache): jax.debug.print('tokens: {}', tokens) logits = jax.nn.one_hot(tokens[..., -1:] + 1, 10) logits = jnp.where(logits < 0.5, float('-inf'), logits) logits = logits.squeeze(axis=1) return logits, cache new_tokens = temperature_sample( tokens, cache, tokens_to_logits, key, topk=5 ) np.testing.assert_array_equal(new_tokens, [[5, 6, 7, 8]]) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/lm1b/tokenizer.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Provides op for tokenizing a dataset.""" import os import sys import tempfile import time from typing import Any, Dict, Tuple from collections.abc import Iterable from absl import logging import dataclasses import jax from sentencepiece import SentencePieceTrainer import tensorflow as tf if sys.version_info < (3, 13): import tensorflow_text as tftxt Features = dict[str, tf.Tensor] def _dump_chars_to_textfile( dataset: tf.data.Dataset, maxchars: int = int(1e7), data_keys=('inputs', 'targets'), ) -> tuple[str, int]: """Write part of a TFDS sentence dataset to lines in a text file. Args: dataset: tf.dataset containing string-data. maxchars: int: approximate number of characters to save from dataset. data_keys: Tuple[str]: what keys in dataset to dump from. Returns: name of temp file with dataset bytes, exact number of characters dumped. """ char_count = 0 ds_iter = dataset.as_numpy_iterator() with tempfile.NamedTemporaryFile( delete=False, prefix='/tmp/ds_chars' ) as outfp: while char_count < maxchars: example = next(ds_iter) for k in data_keys: line = example[k] + b'\n' char_count += len(line) outfp.write(line) return outfp.name, char_count def _train_sentencepiece( dataset: tf.data.Dataset, *, vocab_size: int, maxchars: int = int(1e7), model_path: str, model_type: str = 'unigram', character_coverage: float = 1.0, data_keys=('inputs', 'targets'), ): """Train SentencePiece tokenizer from subset of tf dataset. Args: dataset: tf.dataset vocab_size: int: size of vocab tokens to train. maxchars: int: number of characters to use for sentencepiece training. model_path: str: path of model file to save vocab model to. model_type: str: type of sentencepiece vocab to train. character_coverage: amount of characters covered by the model, good defaults are 0.9995 for languages with rich character set like Japanese or Chinese and 1.0 for other languages with small character set. data_keys: Tuple[str]: keys of dataset to use for training. Returns: path to the trained sentencepiece vocabulary model. """ if model_path.startswith('gs://'): abs_model_path = model_path else: abs_model_path = os.path.abspath(os.path.expanduser(model_path)) fname, _ = _dump_chars_to_textfile( dataset, maxchars=maxchars, data_keys=data_keys ) with tempfile.NamedTemporaryFile( delete=False, prefix='/tmp/sp_tmp' ) as model_fp: pass # we just want a prefix'd tmp-filename argstr = ' '.join([ f'--input={fname}', f'--vocab_size={vocab_size}', f'--character_coverage={character_coverage}', f'--model_prefix={model_fp.name}', f'--model_type={model_type}', ]) SentencePieceTrainer.Train(argstr) if jax.process_index() == 0: # Use an intermediate filename that is renamed to the target name to address # create and fill delays. copy_rename_path = abs_model_path + '.rntmp' tf.io.gfile.copy(model_fp.name + '.model', copy_rename_path, overwrite=True) tf.io.gfile.rename(copy_rename_path, abs_model_path, overwrite=True) logging.info('copied %s to %s', model_fp.name + '.model', abs_model_path) else: while not tf.io.gfile.exists(abs_model_path): time.sleep(1) time.sleep(1) return abs_model_path def _load_sentencepiece_tokenizer( model_path: str, add_bos: bool = False, add_eos: bool = True, reverse: bool = False, ): """Load a tf-text SentencePiece tokenizer from given model filepath.""" with tf.io.gfile.GFile(model_path, 'rb') as model_fp: sp_model = model_fp.read() sp_tokenizer = tftxt.SentencepieceTokenizer( model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse ) return sp_tokenizer def load_or_train_tokenizer( dataset: tf.data.Dataset, *, vocab_path: str, vocab_size: int, max_corpus_chars: int, data_keys: tuple[str, str] = ('inputs', 'targets'), ): """Loads the tokenizer at `vocab_path` or trains a one from `dataset`.""" try: return _load_sentencepiece_tokenizer(vocab_path) except tf.errors.NotFoundError: logging.info('SentencePiece vocab not found, building one from data.') vocab_path = _train_sentencepiece( dataset, vocab_size=vocab_size, maxchars=max_corpus_chars, model_path=vocab_path, data_keys=data_keys, ) return _load_sentencepiece_tokenizer(vocab_path) @dataclasses.dataclass class TokenizeOp: sp_tokenizer: Any data_keys: Iterable[str] = ('inputs', 'targets') def __call__(self, features: Features) -> Features: for k in self.data_keys: features[k] = self.sp_tokenizer.tokenize(features[k]) return features ================================================ FILE: examples/lm1b/train.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Language Modeling example. This script trains a Transformer on a LM1B dataset. """ # pytype: disable=wrong-arg-count # pytype: disable=attribute-error import collections import os from absl import logging from clu import metric_writers from clu import periodic_actions from flax import linen as nn from flax.training import checkpoints from flax.training import common_utils import jax from jax import random import jax.numpy as jnp from jax.sharding import PartitionSpec as P, Mesh, NamedSharding import ml_collections import numpy as np import optax import tensorflow as tf import input_pipeline import models import temperature_sampler import utils def rsqrt_schedule( init_value: float, shift: int = 0, ): """Applies a reverse square-root schedule. The reverse square root schedule is simply `lr = init_value / sqrt(step)`. Args: init_value: Base learning rate (before applying the rsqrt schedule). shift: How many steps the rsqrt should be shifted. Shifting the rsqrt schedule makes it less steep in the beginning (close to 0). Returns: A schedule that applies the reverse square root. """ def schedule(count): return init_value * (count + shift) ** -0.5 * shift**0.5 return schedule def create_learning_rate_schedule(learning_rate: float, warmup_steps: int): """Creates a rsqrt schedule with linear warmup.""" return optax.join_schedules( [ optax.linear_schedule( init_value=0, end_value=learning_rate, transition_steps=warmup_steps, ), rsqrt_schedule(init_value=learning_rate, shift=warmup_steps), ], boundaries=[warmup_steps], ) def compute_weighted_cross_entropy( logits, targets, weights=None, label_smoothing=0.0 ): """Compute weighted cross entropy and entropy for log probs and targets. Args: logits: [batch, length, num_classes] float array. targets: categorical targets [batch, length] int array. weights: None or array of shape [batch, length]. label_smoothing: label smoothing constant, used to determine the on and off values. Returns: Tuple of scalar loss and batch normalizing factor. """ if logits.ndim != targets.ndim + 1: raise ValueError( "Incorrect shapes. Got shape %s logits and %s targets" % (str(logits.shape), str(targets.shape)) ) vocab_size = logits.shape[-1] confidence = 1.0 - label_smoothing low_confidence = (1.0 - confidence) / (vocab_size - 1) normalizing_constant = -( confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) ) soft_targets = common_utils.onehot( targets, vocab_size, on_value=confidence, off_value=low_confidence ) loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1) loss = loss - normalizing_constant normalizing_factor = np.prod(targets.shape) if weights is not None: loss = loss * weights normalizing_factor = weights.sum() return loss.sum(), normalizing_factor def compute_weighted_accuracy(logits, targets, weights=None): """Compute weighted accuracy for log probs and targets. Args: logits: [batch, length, num_classes] float array. targets: categorical targets [batch, length] int array. weights: None or array of shape [batch, length] Returns: Tuple of scalar loss and batch normalizing factor. """ if logits.ndim != targets.ndim + 1: raise ValueError( "Incorrect shapes. Got shape %s logits and %s targets" % (str(logits.shape), str(targets.shape)) ) loss = jnp.equal(jnp.argmax(logits, axis=-1), targets) normalizing_factor = np.prod(logits.shape[:-1]) if weights is not None: loss = loss * weights normalizing_factor = weights.sum() return loss.sum(), normalizing_factor def compute_metrics(logits, labels, weights, label_smoothing=0.0): """Compute summary metrics.""" loss, weight_sum = compute_weighted_cross_entropy( logits, labels, weights, label_smoothing ) acc, _ = compute_weighted_accuracy(logits, labels, weights) metrics = { "loss": loss, "accuracy": acc, "denominator": weight_sum, } return metrics # Primary training / eval / decode step functions. # ----------------------------------------------------------------------------- def train_step( state, batch, config, learning_rate_fn, label_smoothing=0.0, dropout_rng=None, ): """Perform a single training step.""" # X_position and X_segmentation are needed only when using "packed examples" # where multiple sequences are packed into the same example with this # metadata. # if such features are not present they are ignored and the example is treated # like a normal, unpacked sequence example. train_keys = ["inputs", "inputs_position", "inputs_segmentation"] (inputs, inputs_positions, inputs_segmentation) = ( batch.get(k, None) for k in train_keys ) weights = jnp.where(inputs > 0, 1, 0).astype(jnp.float32) dropout_rng = jax.random.fold_in(dropout_rng, state.step) def loss_fn(params): """loss function used for training.""" logits = models.TransformerLM(config).apply( {"params": params}, inputs, inputs_positions=inputs_positions, inputs_segmentation=inputs_segmentation, rngs={"dropout": dropout_rng}, ) loss, weight_sum = compute_weighted_cross_entropy( logits, inputs, weights, label_smoothing ) mean_loss = loss / weight_sum return mean_loss, logits step = state.step lr = learning_rate_fn(step) grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, logits), grads = grad_fn(state.params) new_state = state.apply_gradients(grads=grads) metrics = compute_metrics(logits, inputs, weights) metrics["learning_rate"] = lr return new_state, metrics def eval_step(params, batch, config, label_smoothing=0.0): """Calculate evaluation metrics on a batch.""" inputs = batch["inputs"] weights = jnp.where(inputs > 0, 1.0, 0.0) logits = models.TransformerLM(config).apply({"params": params}, inputs) return compute_metrics(logits, inputs, weights, label_smoothing) def predict_step( inputs, params, rngkey, eos_id, max_decode_len, config, temperature, top_k ): """Predict language model on a batch.""" target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:] initial_variables = models.TransformerLM(config).init( jax.random.PRNGKey(0), jnp.ones(target_shape, config.dtype) ) cache = initial_variables["cache"] def tokens_ids_to_logits(flat_ids, flat_cache): """Token slice to logits from decoder model.""" # --> [batch * beam, 1, vocab] flat_logits, new_vars = models.TransformerLM(config).apply( {"params": params, "cache": flat_cache}, flat_ids, mutable=["cache"] ) new_flat_cache = new_vars["cache"] # Remove singleton sequence-length dimension: # [batch, 1, vocab] --> [batch, vocab] flat_logits = flat_logits.squeeze(axis=1) return flat_logits, new_flat_cache # Using the above-defined single-step decoder function, run a # beam search over possible sequences given input encoding. seqs = temperature_sampler.temperature_sample( inputs, cache, tokens_ids_to_logits, rngkey, temperature=temperature, topk=top_k, eos_token=eos_id, ) return seqs # Utils for prediction # ----------------------------------------------------------------------------- def pad_examples(x, desired_batch_size): """Expand batch to desired size by repeating last slice.""" batch_pad = desired_batch_size - x.shape[0] return np.concatenate([x, np.tile(x[-1], (batch_pad, 1))], axis=0) def tohost(x): """Collect batches from all devices to host and flatten batch dimensions.""" n_device, n_batch, *remaining_dims = x.shape return np.array(x).reshape((n_device * n_batch,) + tuple(remaining_dims)) def evaluate( *, jit_eval_step, params, eval_ds: tf.data.Dataset, num_eval_steps: int, config, ): """Evaluate the target an return a dictionary with the metrics.""" logging.info("Gathering evaluation metrics.") eval_metrics = [] eval_iter = iter(eval_ds) # pytype: disable=wrong-arg-types for _, eval_batch in zip(range(num_eval_steps), eval_iter): eval_batch = jax.tree_util.tree_map(lambda x: x._numpy(), eval_batch) # pylint: disable=protected-access metrics = jit_eval_step(params, eval_batch, config) eval_metrics.append(metrics) eval_metrics = common_utils.stack_forest(eval_metrics) eval_metrics_sums = jax.tree_util.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop("denominator") eval_summary = jax.tree_util.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums, ) return eval_summary def generate_prediction( *, jit_pred_step, params, tokenized_prompts, eos_id, inference_rng, decode_tokens, config, predict_config, ): """Generate text from the prompt.""" n_devices = jax.local_device_count() logging.info("Generating text.") predictions = [] # Use batch of prompts provided by user. for pred_batch in jnp.array_split( tokenized_prompts, int(np.ceil(len(tokenized_prompts) / n_devices)) ): cur_pred_batch_size = pred_batch.shape[0] if cur_pred_batch_size % n_devices: padded_size = int(np.ceil(cur_pred_batch_size / n_devices) * n_devices) pred_batch = jax.tree_util.tree_map( lambda x: pad_examples(x, padded_size), pred_batch ) # pylint: disable=cell-var-from-loop pred_batch = common_utils.shard(pred_batch) inference_rng, sub_rng = random.split(inference_rng) inference_rngs = random.split(sub_rng, n_devices) predicted = jit_pred_step( pred_batch, params, inference_rngs, eos_id, config.max_predict_length, predict_config, config.sampling_temperature, config.sampling_top_k, ) predicted = tohost(predicted) # Iterate through non-padding examples of batch. for s in predicted[:cur_pred_batch_size]: prediction = decode_tokens(s) logging.info("Sample: %s", str(prediction)) predictions.append(prediction) # Save generated texts for tensorboard. exemplars = "" for prediction in predictions: exemplars += f"{prediction}\n\n" return exemplars def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ tf.io.gfile.makedirs(workdir) vocab_path = config.vocab_path if vocab_path is None: vocab_path = os.path.join(workdir, "sentencepiece_model") config.vocab_path = vocab_path tf.io.gfile.makedirs(os.path.split(vocab_path)[0]) # Load Dataset # --------------------------------------------------------------------------- logging.info("Initializing dataset.") train_ds, eval_ds, _, encoder = input_pipeline.get_datasets( n_devices=jax.local_device_count(), config=config, vocab_path=vocab_path ) train_iter = iter(train_ds) vocab_size = int(encoder.vocab_size()) eos_id = temperature_sampler.EOS_ID # Default Sentencepiece EOS token. def decode_tokens(toks): valid_toks = toks[: np.argmax(toks == eos_id) + 1].astype(np.int32) return encoder.detokenize(valid_toks).numpy().decode("utf-8") def encode_strings(strs, max_len): tokenized_batch = np.zeros((len(strs), max_len), np.int32) for i, s in enumerate(strs): toks = encoder.tokenize(s).numpy() # Remove EOS token in prompt. tokenized_batch[i, : toks.shape[0] - 1] = toks[:-1] return tokenized_batch tokenized_prompts = encode_strings( [config.prompts], config.max_predict_length ) logging.info("Initializing model, optimizer, and step functions.") # Build Model and Optimizer # --------------------------------------------------------------------------- train_config = models.TransformerConfig( vocab_size=vocab_size, output_vocab_size=vocab_size, logits_via_embedding=config.logits_via_embedding, dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, emb_dim=config.emb_dim, num_heads=config.num_heads, num_layers=config.num_layers, qkv_dim=config.qkv_dim, mlp_dim=config.mlp_dim, max_len=max(config.max_target_length, config.max_eval_target_length), dropout_rate=config.dropout_rate, attention_dropout_rate=config.attention_dropout_rate, deterministic=False, decode=False, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6), ) eval_config = train_config.replace(deterministic=True) predict_config = train_config.replace(deterministic=True, decode=True) # Mesh definition devices_array = utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) start_step = 0 rng = jax.random.PRNGKey(config.seed) rng, init_rng = jax.random.split(rng) rng, inference_rng = random.split(rng) m = models.TransformerLM(eval_config) learning_rate_fn = create_learning_rate_schedule( learning_rate=config.learning_rate, warmup_steps=config.warmup_steps ) optimizer = optax.adamw( learning_rate_fn, b1=0.9, b2=0.98, eps=1e-9, weight_decay=config.weight_decay, ) state, state_mesh_annotations = utils.setup_initial_state( m, optimizer, config, init_rng, mesh ) data_sharding = NamedSharding(mesh, P(config.data_sharding)) if config.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. state = checkpoints.restore_checkpoint(workdir, state) # Grab last step. start_step = int(state.step) writer = metric_writers.create_default_writer( workdir, just_logging=jax.process_index() > 0 ) if start_step == 0: writer.write_hparams(dict(config)) # compile multidevice versions of train/eval/predict step fn. jit_train_step = jax.jit( train_step, in_shardings=( state_mesh_annotations, data_sharding, None, ), # type: ignore out_shardings=(state_mesh_annotations, None), # type: ignore static_argnums=(2, 3, 4), donate_argnums=0, ) jit_eval_step = jax.jit( eval_step, in_shardings=( state_mesh_annotations.params, data_sharding, ), # type: ignore out_shardings=None, # type: ignore static_argnums=(2, 3), ) # Since the inputs and rngkey args for predict_step will be batched, # we must vmap them, otherwise the global arrays will be seen in each device jit_pred_step = jax.jit( jax.vmap( predict_step, in_axes=( 0, jax.tree_util.tree_map(lambda x: None, state.params), 0, None, None, jax.tree_util.tree_map(lambda x: None, predict_config), None, None, ), ), in_shardings=( data_sharding, state_mesh_annotations.params, data_sharding, ), # type: ignore out_shardings=data_sharding, # type: ignore static_argnums=(3, 4, 5, 6, 7), ) # Main Train Loop # --------------------------------------------------------------------------- # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap"d training update for performance. dropout_rngs = rng logging.info("Starting training loop.") hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=config.num_train_steps, writer=writer ) if jax.process_index() == 0: hooks += [ report_progress, periodic_actions.Profile(logdir=workdir, num_profile_steps=5), ] train_metrics = [] with metric_writers.ensure_flushes(writer): for step in range(start_step, config.num_train_steps): is_last_step = step == config.num_train_steps - 1 # Shard data to devices and do a training step. with jax.profiler.StepTraceAnnotation("train", step_num=step): batch = next(train_iter) batch = jax.tree_util.tree_map(lambda x: jnp.array(x), batch) state, metrics = jit_train_step( state, batch, train_config, learning_rate_fn, 0.0, dropout_rngs ) train_metrics.append(metrics) # Quick indication that training is happening. logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step) for h in hooks: h(step) # Periodic metric handling. if step % config.eval_every_steps == 0 or is_last_step: with report_progress.timed("training_metrics"): logging.info("Gathering training metrics.") train_metrics = common_utils.stack_forest(train_metrics) lr = train_metrics.pop("learning_rate").mean() metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics) denominator = metrics_sums.pop("denominator") summary = jax.tree_util.tree_map( lambda x: x / denominator, metrics_sums ) # pylint: disable=cell-var-from-loop summary["learning_rate"] = lr summary["perplexity"] = jnp.clip( jnp.exp(summary["loss"]), max=1.0e4 ) summary = {"train_" + k: v for k, v in summary.items()} writer.write_scalars(step, summary) train_metrics = [] with report_progress.timed("eval"): eval_results = evaluate( jit_eval_step=jit_eval_step, params=state.params, eval_ds=eval_ds, num_eval_steps=config.num_eval_steps, config=eval_config, ) # (clipped) perplexity after averaging log-perplexitie eval_results["perplexity"] = jnp.clip( jnp.exp(eval_results["loss"]), max=1.0e4 ) writer.write_scalars( step, {"eval_" + k: v for k, v in eval_results.items()} ) with report_progress.timed("generate_text"): exemplars = generate_prediction( jit_pred_step=jit_pred_step, params=state.params, tokenized_prompts=tokenized_prompts, eos_id=eos_id, inference_rng=inference_rng, decode_tokens=decode_tokens, config=config, predict_config=predict_config, ) writer.write_texts(step, {"samples": exemplars}) # Save a checkpoint on one host after every checkpoint_freq steps. save_checkpoint = ( step % config.checkpoint_every_steps == 0 or is_last_step ) if config.save_checkpoints and save_checkpoint: logging.info("Saving checkpoint step %d.", step) with report_progress.timed("checkpoint"): checkpoints.save_checkpoint_multiprocess(workdir, state, step) ================================================ FILE: examples/lm1b/train_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import pathlib import sys import tempfile from absl import logging from absl.testing import absltest import jax import tensorflow as tf import tensorflow_datasets as tfds from configs import default import train jax.config.update('jax_disable_most_optimizations', True) class TrainTest(absltest.TestCase): """Test cases for LM library.""" def setUp(self): super().setUp() if sys.version_info >= (3, 13): self.skipTest('Test (and tensorflow-text) does not suport Python 3.13+') tf.config.experimental.set_visible_devices([], 'GPU') def test_train_and_evaluate(self): config = default.get_config() config.max_corpus_chars = 1000 config.vocab_size = 32 config.per_device_batch_size = 2 config.num_train_steps = 1 config.num_eval_steps = 1 config.num_predict_steps = 1 config.num_layers = 1 config.qkv_dim = 128 config.emb_dim = 128 config.mlp_dim = 512 config.num_heads = 2 config.max_target_length = 32 config.max_eval_target_length = 32 config.max_predict_length = 32 workdir = tempfile.mkdtemp() # Go two directories up to the root of the flax directory. flax_root_dir = pathlib.Path(__file__).parents[2] data_dir = str(flax_root_dir) + '/.tfds/metadata' # pylint: disable=unused-variable with tfds.testing.mock_data(num_examples=128, data_dir=data_dir): train.train_and_evaluate(config, workdir) logging.info('workdir content: %s', tf.io.gfile.listdir(workdir)) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/lm1b/utils.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copied over from MaxText (https://github.com/google/maxtext/blob/main/MaxText/max_utils.py). import functools import logging import numpy as np import flax.linen as nn from flax.linen import partitioning as nn_partitioning from flax.training import train_state import jax import jax.numpy as jnp from jax.experimental import mesh_utils # Mesh utils. # ----------------------------------------------------------------------------- def create_device_mesh(config): """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas.""" devices = jax.devices() num_devices = len(devices) try: num_slices = 1 + max([d.slice_index for d in devices]) except: num_slices = 1 num_devices_per_slice = num_devices // num_slices logging.info(f"Devices: {devices}") logging.info(f"Number of devices: {num_devices}") multi_slice_env = hasattr(jax.devices()[0], "slice_index") dcn_parallelism = [ config.dcn_data_parallelism, config.dcn_fsdp_parallelism, config.dcn_tensor_parallelism, ] ici_parallelism = [ config.ici_data_parallelism, config.ici_fsdp_parallelism, config.ici_tensor_parallelism, ] # Find possible unspecified parallelisms dcn_parallelism = fill_unspecified_mesh_axes( dcn_parallelism, num_slices, "DCN" ) ici_parallelism = fill_unspecified_mesh_axes( ici_parallelism, num_devices_per_slice, "ICI" ) if multi_slice_env: mesh = mesh_utils.create_hybrid_device_mesh( ici_parallelism, dcn_parallelism ) else: mesh = mesh_utils.create_device_mesh(ici_parallelism) logging.info(f"Decided on mesh: {mesh}") logging.info(f"Mesh shape: {mesh.shape}") return mesh def fill_unspecified_mesh_axes( parallelism_vals, target_product, parallelism_type ): """Evaluates unspecified DCN/ICI parallelism values""" if -1 in parallelism_vals: assert parallelism_vals.count(-1) == 1, ( f"Found unspecified values (-1) for more than one {parallelism_type} " " parallelism axis. At most one axis can be unspecified." ) determined_val = target_product / np.prod(parallelism_vals) * -1 assert determined_val >= 1 and determined_val.is_integer, ( "Unspecified value unable to be determined with the given " f" {parallelism_type} parallelism values" ) parallelism_vals[parallelism_vals.index(-1)] = int(determined_val) target_type = "slices" if parallelism_type == "DCN" else "devices per slice" assert np.prod(parallelism_vals) == target_product, ( f"Number of {target_type} {target_product} does not match the product" f" of the {parallelism_type} parallelism {np.prod(parallelism_vals)}" ) return parallelism_vals # State initialization utils. # ----------------------------------------------------------------------------- def unbox_logicallypartioned_trainstate( boxed_train_state: train_state.TrainState, ): """Unboxes the flax.LogicallyPartitioned pieces in a train state. Args: boxed_train_state: a train state that includes LogicallyPartitioned leaves. Returns: a TrainState where all all LogicallyPartitioned leaves have been unboxed. """ return jax.tree_util.tree_map( lambda x: x.unbox() if isinstance(x, nn.spmd.LogicallyPartitioned) else x, boxed_train_state, is_leaf=lambda k: isinstance(k, nn.spmd.LogicallyPartitioned), ) def init_train_state(model, tx, config, key): """ We pass in "static" objects like model, tx, config as JAX compares them by object hash, and instantiating them inside causes pjit top-level annotations to fail to match as pytree prefixes if we re-instantiate. Args: model, tx, config, key """ input_shape = (config.per_device_batch_size, config.max_target_length) initial_variables = jax.jit(model.init)( key, jnp.ones(input_shape, jnp.float32) ) state = train_state.TrainState.create( apply_fn=model.apply, params=initial_variables["params"], tx=tx ) return state def setup_initial_state(model, tx, config, rng, mesh): """We initialize the model and optimizer state, and optionally load from a checkpoint as necessary. Args: model: the flax model to initialize tx: the optax.GradientTransformation config: config object rng: jax.prng key mesh: jax.devices() mesh Returns: state: the initialized train state state_mesh_annotations: the mesh annotations for the train state """ init_train_state_partial = functools.partial( init_train_state, model, tx, config ) abstract_state = jax.eval_shape(init_train_state_partial, rng) state_logical_annotations = nn.get_partition_spec(abstract_state) # Initialization with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): state_mesh_annotations = nn.logical_to_mesh_sharding( state_logical_annotations, mesh, config.logical_axis_rules ) state = jax.jit( init_train_state_partial, in_shardings=None, # type: ignore out_shardings=state_mesh_annotations, )(rng) state = unbox_logicallypartioned_trainstate(state) return state, state_mesh_annotations ================================================ FILE: examples/mnist/README.md ================================================ ## MNIST classification Trains a simple convolutional network on the MNIST dataset. You can run this code and even modify it directly in Google Colab, no installation required: https://colab.research.google.com/github/google/flax/blob/main/examples/mnist/mnist.ipynb ### Requirements * TensorFlow dataset `mnist` will be downloaded and prepared automatically, if necessary ### Example output | Name | Epochs | Walltime | Top-1 accuracy | Metrics | Workdir | | :------ | -----: | :------- | :------------- | :---------- | :---------------------------------------- | | default | 10 | 7.7m | 99.17% | [tfhub.dev] | [gs://flax_public/examples/mnist/default] | [tfhub.dev]: https://tensorboard.dev/experiment/1G9SvrW5RQyojRtMKNmMuQ/#scalars&_smoothingWeight=0®exInput=default [gs://flax_public/examples/mnist/default]: https://console.cloud.google.com/storage/browser/flax_public/examples/mnist/default ``` I1009 17:56:42.674334 3280981 train.py:175] epoch: 10, train_loss: 0.0073, train_accuracy: 99.75, test_loss: 0.0294, test_accuracy: 99.25 ``` ### How to run `python main.py --workdir=/tmp/mnist --config=configs/default.py` #### Overriding Hyperparameter configurations MNIST example allows specifying a hyperparameter configuration by the means of setting `--config` flag. Configuration flag is defined using [config_flags](https://github.com/google/ml_collections/tree/master#config-flags). `config_flags` allows overriding configuration fields. This can be done as follows: ```shell python main.py \ --workdir=/tmp/mnist --config=configs/default.py \ --config.learning_rate=0.05 --config.num_epochs=5 ``` ================================================ FILE: examples/mnist/configs/default.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Default Hyperparameter configuration.""" import ml_collections def get_config(): """Get the default hyperparameter configuration.""" config = ml_collections.ConfigDict() config.learning_rate = 0.1 config.momentum = 0.9 config.batch_size = 128 config.num_epochs = 10 return config def metrics(): return [] ================================================ FILE: examples/mnist/main.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Main file for running the MNIST example. This file is intentionally kept short. The majority of logic is in libraries than can be easily tested and imported in Colab. """ from absl import app from absl import flags from absl import logging from clu import platform import jax from ml_collections import config_flags import tensorflow as tf import train # pylint: disable=g-bad-import-order FLAGS = flags.FLAGS flags.DEFINE_string('workdir', None, 'Directory to store model data.') config_flags.DEFINE_config_file( 'config', None, 'File path to the training hyperparameter configuration.', lock_config=True, ) def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.experimental.set_visible_devices([], 'GPU') logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) logging.info('JAX local devices: %r', jax.local_devices()) # Add a note so that we can tell which task is which JAX host. # (Depending on the platform task 0 is not guaranteed to be host 0) platform.work_unit().set_task_status( f'process_index: {jax.process_index()}, ' f'process_count: {jax.process_count()}' ) platform.work_unit().create_artifact( platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' ) train.train_and_evaluate(FLAGS.config, FLAGS.workdir) if __name__ == '__main__': flags.mark_flags_as_required(['config', 'workdir']) app.run(main) ================================================ FILE: examples/mnist/mnist.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Flax MNIST Example\n", "\n", "\"Open\n", "\n", "Demonstration notebook for\n", "https://github.com/google/flax/tree/main/examples/mnist" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The **Flax Notebook Workflow**:\n", "\n", "1. Run the entire notebook end-to-end and check out the outputs.\n", " - This will open Python files in the right-hand editor!\n", " - You'll be able to interactively explore metrics in TensorBoard.\n", "2. Change `config` and train for different hyperparameters. Check out the\n", " updated TensorBoard plots.\n", "3. Update the code in `train.py`. Thanks to `%autoreload`, any changes you\n", " make in the file will automatically appear in the notebook. Some ideas to\n", " get you started:\n", " - Change the model.\n", " - Log some per-batch metrics during training.\n", " - Add new hyperparameters to `configs/default.py` and use them in\n", " `train.py`.\n", "4. At any time, feel free to paste code from `train.py` into the notebook\n", " and modify it directly there!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "outputId": "8520b2f8-2b9d-4216-ba1f-d96175455bbc" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[?25l\r\n", "\u001b[K |███▊ | 10kB 21.5MB/s eta 0:00:01\r\n", "\u001b[K |███████▍ | 20kB 12.7MB/s eta 0:00:01\r\n", "\u001b[K |███████████ | 30kB 9.4MB/s eta 0:00:01\r\n", "\u001b[K |██████████████▉ | 40kB 8.4MB/s eta 0:00:01\r\n", "\u001b[K |██████████████████▌ | 51kB 5.2MB/s eta 0:00:01\r\n", "\u001b[K |██████████████████████▏ | 61kB 5.2MB/s eta 0:00:01\r\n", "\u001b[K |█████████████████████████▉ | 71kB 5.7MB/s eta 0:00:01\r\n", "\u001b[K |█████████████████████████████▋ | 81kB 6.3MB/s eta 0:00:01\r\n", "\u001b[K |████████████████████████████████| 92kB 4.4MB/s \n", "\u001b[K |████████████████████████████████| 634kB 8.5MB/s \n", "\u001b[K |████████████████████████████████| 102kB 6.0MB/s \n", "\u001b[K |████████████████████████████████| 61kB 6.3MB/s \n", "\u001b[?25h Building wheel for flax (setup.py) ... \u001b[?25l\u001b[?25hdone\n", " Building wheel for jax (setup.py) ... \u001b[?25l\u001b[?25hdone\n" ] } ], "source": [ "# Install ml-collections & latest Flax version from Github.\n", "!pip install -q ml-collections git+https://github.com/google/flax" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "tags": [] }, "outputs": [], "source": [ "example_directory = 'examples/mnist'\n", "editor_relpaths = ('configs/default.py', 'train.py')\n", "\n", "repo, branch = 'https://github.com/google/flax', 'main'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "cellView": "form", "outputId": "2dfbdfa6-d213-4b5b-dc82-ee1765705255" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cloning into 'flaxrepo'...\n", "remote: Enumerating objects: 330, done.\u001b[K\n", "remote: Counting objects: 100% (330/330), done.\u001b[K\n", "remote: Compressing objects: 100% (298/298), done.\u001b[K\n", "remote: Total 330 (delta 58), reused 126 (delta 14), pack-reused 0\u001b[K\n", "Receiving objects: 100% (330/330), 1.81 MiB | 6.59 MiB/s, done.\n", "Resolving deltas: 100% (58/58), done.\n" ] }, { "data": { "text/html": [ "

WARNING : Editing in VM - changes lost after reboot!!

" ], "text/plain": [ "" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", " ((filepath) => {{\n", " if (!google.colab.kernel.accessAllowed) {{\n", " return;\n", " }}\n", " google.colab.files.view(filepath);\n", " }})(\"/content/examples/mnist/configs/default.py\")" ], "text/plain": [ "" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", " ((filepath) => {{\n", " if (!google.colab.kernel.accessAllowed) {{\n", " return;\n", " }}\n", " google.colab.files.view(filepath);\n", " }})(\"/content/examples/mnist/train.py\")" ], "text/plain": [ "" ] }, "metadata": { "tags": [] }, "output_type": "display_data" } ], "source": [ "# (If you run this code in Jupyter[lab], then you're already in the\n", "# example directory and nothing needs to be done.)\n", "\n", "#@markdown **Fetch newest Flax, copy example code**\n", "#@markdown\n", "#@markdown **If you select no** below, then the files will be stored on the\n", "#@markdown *ephemeral* Colab VM. **After some time of inactivity, this VM will\n", "#@markdown be restarted an any changes are lost**.\n", "#@markdown\n", "#@markdown **If you select yes** below, then you will be asked for your\n", "#@markdown credentials to mount your personal Google Drive. In this case, all\n", "#@markdown changes you make will be *persisted*, and even if you re-run the\n", "#@markdown Colab later on, the files will still be the same (you can of course\n", "#@markdown remove directories inside your Drive's `flax/` root if you want to\n", "#@markdown manually revert these files).\n", "\n", "if 'google.colab' in str(get_ipython()):\n", " import os\n", " os.chdir('/content')\n", " # Download Flax repo from Github.\n", " if not os.path.isdir('flaxrepo'):\n", " !git clone --depth=1 -b $branch $repo flaxrepo\n", " # Copy example files & change directory.\n", " mount_gdrive = 'no' #@param ['yes', 'no']\n", " if mount_gdrive == 'yes':\n", " DISCLAIMER = 'Note : Editing in your Google Drive, changes will persist.'\n", " from google.colab import drive\n", " drive.mount('/content/gdrive')\n", " example_root_path = f'/content/gdrive/My Drive/flax/{example_directory}'\n", " else:\n", " DISCLAIMER = 'WARNING : Editing in VM - changes lost after reboot!!'\n", " example_root_path = f'/content/{example_directory}'\n", " from IPython import display\n", " display.display(display.HTML(\n", " f'

{DISCLAIMER}

'))\n", " if not os.path.isdir(example_root_path):\n", " os.makedirs(example_root_path)\n", " !cp -r flaxrepo/$example_directory/* \"$example_root_path\"\n", " os.chdir(example_root_path)\n", " from google.colab import files\n", " for relpath in editor_relpaths:\n", " s = open(f'{example_root_path}/{relpath}').read()\n", " open(f'{example_root_path}/{relpath}', 'w').write(\n", " f'## {DISCLAIMER}\\n' + '#' * (len(DISCLAIMER) + 3) + '\\n\\n' + s)\n", " files.view(f'{example_root_path}/{relpath}')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "outputId": "e9061488-ac3e-4d23-f24f-06e1988e7541" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/content/examples/mnist\n" ] } ], "source": [ "# Note : In Colab, above cell changed the working directory.\n", "!pwd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports / Helpers" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from absl import logging\n", "import flax\n", "import jax.numpy as jnp\n", "from matplotlib import pyplot as plt\n", "import numpy as np\n", "import tensorflow_datasets as tfds\n", "\n", "logging.set_verbosity(logging.INFO)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Helper functions for images.\n", "\n", "def show_img(img, ax=None, title=None):\n", " \"\"\"Shows a single image.\"\"\"\n", " if ax is None:\n", " ax = plt.gca()\n", " ax.imshow(img[..., 0], cmap='gray')\n", " ax.set_xticks([])\n", " ax.set_yticks([])\n", " if title:\n", " ax.set_title(title)\n", "\n", "def show_img_grid(imgs, titles):\n", " \"\"\"Shows a grid of images.\"\"\"\n", " n = int(np.ceil(len(imgs)**.5))\n", " _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n))\n", " for i, (img, title) in enumerate(zip(imgs, titles)):\n", " show_img(img, axs[i // n][i % n], title)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Local imports from current directory - auto reload.\n", "# Any changes you make to train.py will appear automatically.\n", "%load_ext autoreload\n", "%autoreload 2\n", "import train\n", "from configs import default as config_lib\n", "config = config_lib.get_config()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "outputId": "bb4525f4-8ca4-4e9d-d1cc-48a3e0533645", "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Load pre-computed DatasetInfo (eg: splits, num examples,...) from GCS: mnist/3.0.1\n", "INFO:absl:Load dataset info from /tmp/tmpqyu1t56xtfds\n", "INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.\n", "INFO:absl:Generating dataset mnist (/root/tensorflow_datasets/mnist/3.0.1)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1mDownloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your\n", "local data directory. If you'd instead prefer to read directly from our public\n", "GCS bucket (recommended if you're running on GCP), you can instead pass\n", "`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1c176027dbbf459b8ed946ccc58de845", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=4.0, style=ProgressStyle(descriptio…" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Load dataset info from /root/tensorflow_datasets/mnist/3.0.1.incompleteJ6TJES\n", "INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.\n", "INFO:absl:Constructing tf.data.Dataset for split train, from /root/tensorflow_datasets/mnist/3.0.1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\u001b[1mDataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Constructing tf.data.Dataset for split test, from /root/tensorflow_datasets/mnist/3.0.1\n" ] } ], "source": [ "# Get datasets as dict of JAX arrays.\n", "train_ds, test_ds = train.get_datasets()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "outputId": "89de05b0-aede-414f-cf43-5e7c71871140" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA1MAAANRCAYAAAAGcOaXAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdebzOdf7/8dfLcpBdSTE4FTIVGdF8lVC0qGnKpMVSGi2GNMUwE9+0KISMqUgLRX1Li1TTIC1M2qcUqdMyKrJn32V7//44x2/OeL2PPud9Xde5tsf9drtut+PZ5/P+vM7p7bhePj6vo845AQAAAAAUT6lkFwAAAAAA6YhmCgAAAAAC0EwBAAAAQACaKQAAAAAIQDMFAAAAAAFopgAAAAAgQFY1U6q6RFU7RDjOqWqDwGsEnwscjD2LdMOeRbphzyLdsGdTS1Y1U6lOVXNU9UtVXZ7sWoBDUdUzVXWuqm5W1SXJrgf4OZpvpKquL3iNVFVNdl1AUVR1lqpuK/TaraqLkl0XUBRVLaeqD6nqGlXdoKqvqGqdZNeVaDRTqWWgiKxNdhFABNtF5DHJ37NAOrheRC4WkZNFpKmIXCgivZJaEXAIzrmOzrlKB14i8p6IPJ/suoBDuElEWkn+99jaIrJRRB5IakUlICubKVU9VVXfV9VNqrpKVcepas5Bh52vqt+p6jpVHa2qpQqd37PgDtJGVZ2tqvXjUNMxItJdREbEuhYyT6rtWefcv5xzT4rId7Gsg8yVantWRHqIyBjn3HLn3AoRGSMiV8e4JjJICu7ZwrXlisgZIvJEvNZE+kvBPXuMiMx2zq1xzu0SkWdF5MQY10x5WdlMicg+EeknIkdIfgfdXkT6HHRMJxFpISLNReQiEekpIqKqF4nIYBH5nYjUFJG3RWSq7yKqekvBBve+Djr8gYJ1d8bjE0TGScU9CxxKqu3ZE0VkYaFfL5Qs+EMexZJqe7awq0Tkbefcklg+QWScVNuzk0TkdFWtraqHiUg3EZkVp881dTnnsuYlIktEpIMnv1lEXiz0ayci5xX6dR8RebPg41kick2h/1ZKRHaISP1C5zYoZl2dRGRWwcftRGR5sr9WvFLjlap7ttBaHURkSbK/TrxS55Wqe1by33Q0LvTrhgXraLK/ZryS+0rVPXtQLYtF5Opkf614pcYrVfesiFQVkWcKzt0rIp+KSI1kf70S/crKO1Oq2khV/6Gqq1V1i4gMl/yuvrBlhT5eKvn/9lNEpL6I3FeoG98gIioiQQ/YqWpFERklIn8MOR/ZIZX2LBBFCu7ZbSJSpdCvq4jINlfwDgBIwT17oK7WInKUiEyLdS1klhTcs+NFpJyIHC4iFUVkumTBnamsbKZEZIKIfCUiDZ1zVST/NufBU53qFvq4noisLPh4mYj0cs5VK/Sq4Jx77+CLqOpg/e9JPP/1KjisoYjkisjbqrpa8jfe0QW/MXLj9Pki/aXSngWiSLU9+4XkD5844OSCDDgg1fbsAT1EZLpzju/BOFiq7dlmIjLZObfBOfeT5D/CcqqqHtzgZZRsbaYqi8gWEdmmqo1FpLfnmIGqWl1V60r+dJJnC/KHRGSQqp4oIqKqVVX1Ut9FnHPDXaFJPAe/Cg77XPI3erOC17Uisqbg42W+dZGVUmnPiqqWUtXyIlI2/5daXu1Dr8huKbVnJf/B/f6qWkdVa4vIn0Rkclw+U2SKVNuzoqoVROQyYa/CL9X27EciclXBWmUl/58VrnTOrYvPp5uasrWZGiAiXUVkq4g8Kv/ZWIW9LCLzRWSBiMyQ/IfqxDn3ooiMFJFnCm6pfi4iHUMLcc7tdc6tPvCS/Nus+wt+vS90XWSclNmzBdpI/rCUmZL/N107ReS1GNdEZkm1PfuwiLwiIosK1ptRkAEHpNqeFckf579JRObGYS1knlTbswNEZJeI/Fvyf9TP+ZI/FyCjKf9cHAAAAACKL1vvTAEAAABATGimAAAAACAAzRQAAAAABKCZAgAAAIAANFMAAAAAEKBMcQ5WVUb/oViccwf/8LgSxZ5FcbFnkYbWOedqJuvi7FkEYM8i3RS5Z7kzBQBAelua7AKAYmLPIt0UuWdppgAAAAAgAM0UAAAAAASgmQIAAACAADRTAAAAABCAZgoAAAAAAtBMAQAAAEAAmikAAAAACEAzBQAAAAABaKYAAAAAIADNFAAAAAAEoJkCAAAAgAA0UwAAAAAQoEyyCwAAIFSjRo1MNn78eJO1b9/eZJMnTzZZnz59TLZr166w4gAAGY87UwAAAAAQgGYKAAAAAALQTAEAAABAAJopAAAAAAjAAAoAQNo67bTTTHbWWWeZzDlnsh49ephs3759JrvhhhtMtnv37qglAkBWqFy5ssl83z+HDx9uslWrVpnshBNOMNnmzZsDq0sc7kwBAAAAQACaKQAAAAAIQDMFAAAAAAFopgAAAAAgAAMoUkDnzp1N9txzz5msV69eJnv00UcTUhNwQIUKFUz24IMPmuywww4zWZcuXUy2f//++BSGrHLeeed587/97W9xvU7Pnj1NlpeXZ7KxY8fG9boAEA/HHnusyXyDIS655BKTlS9fPvKavmzhwoUmu+qqq0zmGwh09NFHR6qHARQAAAAAkCFopgAAAAAgAM0UAAAAAASgmQIAAACAAAygSAFdu3Y1me/hvBo1apREOchiqmqyhx9+2GTdu3ePtN6IESNMtmDBguIXhqziG2YydOhQ77G+B6vjbciQISZjAEV6mzt3rsnatWtnspEjR5rslltuSURJwP9Xrlw5kx1zzDEmmzBhgsl+9atfmaxKlSom873PLA7f+4WTTz45pjXTFXemAAAAACAAzRQAAAAABKCZAgAAAIAANFMAAAAAEIABFCWsfv36JuvYsaPJ5s+fb7Knn346ITUBB5xwwgkmizpsYsuWLSZbv359zDUh+7zwwgsma9GihffYqA9R+wafNGvWLNK5ZcrwR2W68D0Uf/zxx5vM95D+/v37TXbTTTeZbN++fSabPn26yYram19//bU3P9hZZ51lsmOPPdZkS5YsMdnMmTNNtmfPnkjXRcny/T997rnnTObbs1G9++67Jvv2229NNmPGDO/5mzZtMtns2bOD6/FZsWKFyXbt2hXXayQKd6YAAAAAIADNFAAAAAAEoJkCAAAAgAA0UwAAAAAQIOOfqvU9jOoT60+CjuqPf/yjyXJyckz23XffmWzZsmUJqQk44NJLLw0+94cffjAZexY/59prrzVZu3btYlrT9/2zbdu2JvMNuujQoYPJfAMojjvuOJP5HuhGyWrSpInJPv300+D1fH8+33LLLZGyZHr77bdN1qlTJ5Nt3LixJMpBAd/AsaKGPhxs69atJps7d67JRo8ebTLfAIriuPLKKyMdt23btkjHVa5c2WRvvvmmyTZv3hxpvWTjzhQAAAAABKCZAgAAAIAANFMAAAAAEIBmCgAAAAACZPwACt+DzGPHjjXZH/7wB5N98MEHca/H93Csz4IFC+J+beDn3HTTTZGO27t3r8lGjBgR73KQYa666iqTjRs3zmRly5aNvObixYtNdu6555rM92D0+vXrI12jXLlyJvP92cIAipJVv359k7300kvB623ZssVk+/fvN1n16tVNVpwhVr7BWFHP9z2QX7VqVZO1adPGZMOGDTNZnz59Il0XxXfiiSeazLc/ff/v//Wvf5msc+fOJluxYkVgdcUzf/58k40fP95ky5cvN1m/fv1MVqlSJZP17t07sLrk484UAAAAAASgmQIAAACAADRTAAAAABCAZgoAAAAAAmT8AIqdO3eazDcEom3btiaLdQDFL37xi0jX8f1U6ylTpsR0beBQqlWr5s19DzL7rF271mRTp06NqSZkljp16phs0KBBJos6bGLVqlXevFevXiZbsmRJpDVj0b59e5NNmjQp4dfFf1x//fUm8w2l8Bk5cqTJ/va3v5nM9x7irLPOinSNRPj8889N9s0330Q6t3LlyvEuB4fQtGlTk5UpE+1t9/nnn2+yjRs3xlxTqLy8PJPdeOONJuvSpYvJatasabIdO3aYzPd7LV1wZwoAAAAAAtBMAQAAAEAAmikAAAAACEAzBQAAAAABMn4AxY8//pi0a3fq1MlkvoetP/74Y5MV9bA1EA9Dhw6N6fxFixbFqRJkAt+wnZkzZ5qsUaNGwdcYNWqUN//nP/8ZvGYsTjzxxKRcN1u1bt3aZDfffHPwevfff7/Jor5fePnll4OvG6sGDRpEOs45Z7Jzzz3XZOXLlzfZrl27il8YjF/96lfB555yyikme+ONN2Ipp0QMHDgw0nFjxoxJcCUliztTAAAAABCAZgoAAAAAAtBMAQAAAEAAmikAAAAACJDxAyhq1KiRtGvXrl070nHJeoAa2evaa6+N6fz77rsvTpUgE0yaNMlksQxoWLBggckmT54cvF4ipFo9mc43HMI3PGH37t0mGzdunMk2btwYn8JKWNeuXSMdp6ommz17tskYNpE4Tz31lMkGDBgQ6dzXXnst0nH/+Mc/TObb276hZi+99JLJPvjgg0jXFRHp0aOHyZo1a2ay1atXm+yOO+6IfJ10wJ0pAAAAAAhAMwUAAAAAAWimAAAAACAAzRQAAAAABMj4ARSdOnUyme/BzFjUqVPHm/fu3TvStR977LG41gPE06ZNm0z2+uuvJ6ESpIJzzz3XZGeffXbwetu3bzfZxRdfbLLNmzcHX6Movu/HUf982Lp1a7zLwSH8+9//NplvyInv/8uKFSsSUlMyVKlSJdJxzrkEV4Kfk5eXZ7ILLrjAZMOGDTOZ7//zMcccE2k9H9/3tX79+pls/fr1kdYTEalatarJfPvuhx9+MNnJJ59ssoULF0a+dqrhzhQAAAAABKCZAgAAAIAANFMAAAAAEIBmCgAAAAACZNQAinLlypns+uuvN5nvAbkuXbqYLDc312Q1atQwWdOmTb31VK5c2WSffvqpyb7//nvv+UA8+H4iedmyZSOfP378eJPt3bs3ppqQHqpVq2ayiRMnmizqw+6+YRM9evQw2bJlyyKtVxw5OTkmO/LII03m+1z27dtnskwaapAOfP9fvvrqqyRUUnKGDh1qshtuuCHSub5BHJMmTYq5JkS3Z88ek82aNStS5nv/GHUAhe/7tm8Ahe/3lO/7sYhIzZo1g9ds2bKlyT755BOTLVq0yGQDBw40WSoOwOLOFAAAAAAEoJkCAAAAgAA0UwAAAAAQgGYKAAAAAALQTAEAAABAgIya5te1a1eT+abv+TRp0sRkvil9UadWFeWee+4x2f79+2NaEziUUaNGmaxMGf9vfd/0Id80P2QH34TU2rVrB6/3yiuvmOzFF18MXq84brzxRpO1a9cu0rm7du0ymW8CFxDqrrvuMtmgQYNM5pug5uObuvnPf/6z2HUhOXzTGD/77LNIWVQdOnQwWa9evSKfP3/+fJONHj3aZOeff77J2rdvbzLfe+7nn3/eZM2bNzfZd999V2SdJYE7UwAAAAAQgGYKAAAAAALQTAEAAABAAJopAAAAAAiQUQMoWrZsabIdO3aY7LHHHjPZypUrTbZhwwaTrVu3zmTTpk2LWqK8+uqrkY8Fiqt+/foma9WqlcmKGqSyePFik61evTr2wpDy2rRpY7K///3vwev59tjMmTOD14vVb37zm+Bzc3JyTNaiRQuTffzxx8HXQOYpalhEt27dTPanP/0p8vkHmzNnjsluueWWSOciO9xxxx0mGzhwoMkqVKjgPf/dd981WY8ePUzmGwTx3HPPmax169YmmzdvnsmqVKliskqVKnlrTCbuTAEAAABAAJopAAAAAAhAMwUAAAAAAWimAAAAACBARg2g6NOnT6QsFp07dzZZUQ+JTp8+3WRbtmyJaz1AYQMGDDBZxYoVI58/atSoeJaDNDJu3DiTVa5cOXg934PITz31VPB6xXHmmWea7PTTTw9eb//+/SbbuHFj8HrIPLm5uSa78847vcdeeeWVJitqKNDBvv76a5P9/ve/N9nevXsjrYf0VrZsWZO99NJLJuvYsaPJfHuuqO/Rffv2NdnmzZujlOjVvHnzSMd9/vnnJsvLywu+bqJwZwoAAAAAAtBMAQAAAEAAmikAAAAACEAzBQAAAAABMmoARUno2rWryYp6cPSjjz5KdDnAf2nXrl1M50+ePDkudSD9+H5KfVEP0Efx7LPPxlJOZN27dzfZHXfcYbLSpUsHX+P222832bfffhu8HtLbSSedZLKRI0ea7LzzzvOeH3XYxIsvvmgy35Ch5cuXR1oP6eOoo44ymW8A2uWXXx7p3J9++slkvj3ry0REdu7c6c2j8A3B6t27d6RzR4wYYbJUHK7CnSkAAAAACEAzBQAAAAABaKYAAAAAIADNFAAAAAAEYABFMbVt29ZkRT1M+tZbbyW6HGSxk08+2WSNGjWKdK7vJ6Qju61evTqu6+Xk5JjsmmuuMdkpp5xismXLlpmsqOEqbdq0iXRtn/3795vMN4hjzJgxkdZD5qlTp47JJk2aZLIWLVrEdJ2+ffuabMKECTGtidRToUIFkz344IMm69Gjh8miDi554403TDZo0CCTTZs2LdJ6sWrSpInJfO9VVqxYYbKZM2cmpKZ4484UAAAAAASgmQIAAACAADRTAAAAABCAZgoAAAAAAjCA4hCaN29usjJl7Jfstdde857/wQcfxL0m4IBx48aZrGzZspHOHTp0aLzLAf7LgAED4rpeqVL+v/vzDZHwWbNmjcn++te/muzee+8tXmHIaDfddJPJWrZsaTLfcIBt27Z517zllltMNnHixIDqkMp+/etfm8z357ZvCI+qmsz3/WrYsGEm27hxY9QS465evXommzFjhsl8n99dd91lss2bN8ensATjzhQAAAAABKCZAgAAAIAANFMAAAAAEIBmCgAAAAACMIDiEEaOHGmyypUrm6x9+/be83v37m0yfqI5QlSqVMlkxx57bKRzfQ+j5uXlxVwTMovvJ8379skJJ5xQEuUYvgf8RUTWrVtnskceecRkkyZNMtmSJUtirguZw/cAvG8AhW8v+h6UHzRokPc6Dz/8cEB1SDeXXHKJyXyDzYr63nawL7/80mS+96S+IRCJcNppp5nMt+erVatmsm+//dZkvu/b6YI7UwAAAAAQgGYKAAAAAALQTAEAAABAAJopAAAAAAjAAIpD8D0U6Mu++OIL7/nTpk2Le03ITo0aNTLZ0UcfHenc9957z2S7d++OuSZklpUrV5qsTZs2JrviiitMNmTIEJPVqlUruJbJkyeb7B//+If32Pfff99kq1evDr42soPvofiuXbuarEwZ+zZJVU32zDPPmIxBE9nN933swgsvNJnvz3cf34AG34Cp6tWrm8y3Z6MOviiKb03fewvfcCPf77V0xp0pAAAAAAhAMwUAAAAAAWimAAAAACAAzRQAAAAABGAAxSH88pe/NNn27dtN9rvf/c57/tq1a+NeE7KT76HVqCZOnBjHSpBNfA83T5gwIVIGpLIuXbqYLDc3N9K53333ncmGDx8ea0nIMHl5eSZr1qyZyXyDfk4//XST+fZnhQoVTNa5c+eIFVq+mkVE5s+fbzLfoJ+XXnrJZB988EFwPemCO1MAAAAAEIBmCgAAAAAC0EwBAAAAQACaKQAAAAAIoMX5CciqGtuPS04z69atM5nvgeyGDRuWRDlpyTlnf0R2CcqUPXvEEUeY7IsvvjCZ7/fzcccdZzLfIBXkY88iDc13zrVI1sXTcc927NjRZDNmzDCZ73tq7969TfbII4/Ep7DswZ5Fuilyz3JnCgAAAAAC0EwBAAAAQACaKQAAAAAIQDMFAAAAAAHKJLuAVOZ76B9IBt8wlFq1aiWhEgBIf3PmzDHZhx9+aLLjjz8+0rkAshd3pgAAAAAgAM0UAAAAAASgmQIAAACAADRTAAAAABCAARQAACCr/PTTTyZr1apVEioBkO64MwUAAAAAAWimAAAAACAAzRQAAAAABKCZAgAAAIAANFMAAAAAEIBmCgAAAAAC0EwBAAAAQACaKQAAAAAIQDMFAAAAAAHKFPP4dSKyNBGFICPVT3YBwp5F8bBnkY6SvW/Zsygu9izSTZF7Vp1zJVkIAAAAAGQE/pkfAAAAAASgmQIAAACAADRTAAAAABCAZgoAAAAAAtBMAQAAAEAAmikAAAAACEAzBQAAAAABaKYAAAAAIADNFAAAAAAEoJkCAAAAgAA0UwAAAAAQIKuaKVVdoqodIhznVLVB4DWCzwUOxp5FumHPIt2wZ5Fu2LOpJauaqVSlqmeq6lxV3ayqS5JdD/BzVHWgqn6uqltV9XtVHZjsmoBDUdVZqrqt0Gu3qi5Kdl1AUXhvgHSkqs1VdV7B99k1qnpTsmtKNJqp1LBdRB4TEd6QIl2oiFwlItVF5DwR6auqVyS3JKBozrmOzrlKB14i8p6IPJ/suoBD4L0B0oqqHiEir4rIwyJyuIg0EJHXklpUCcjKZkpVT1XV91V1k6quUtVxqppz0GHnq+p3qrpOVUeraqlC5/dU1S9VdaOqzlbV+rHU45z7l3PuSRH5LpZ1kLlScM+Ocs594pzb65z7WkReFpHTY1kTmSXV9uxBteWKyBki8kS81kT6S7U9y3sD/JxU27Mi0l9EZjvnnnLO/eSc2+qc+zLGNVNeVjZTIrJPRPqJyBEi0kpE2otIn4OO6SQiLUSkuYhcJCI9RURU9SIRGSwivxORmiLytohM9V1EVW8p2ODeVwI+L2SulN2zqqqS/8b0ixg/R2SWlN2zkn9X9W3n3JJYPkFknFTes4BPqu3Z/xGRDar6nqr+qKqvqGq9uH22qco5lzUvEVkiIh08+c0i8mKhXzsROa/Qr/uIyJsFH88SkWsK/bdSIrJDROoXOrdBYH0dRGRJsr9OvFLnlep7tuD8O0VkoYiUS/bXi1fyX2myZxeLyNXJ/lrxSo1Xqu9Z3hvwOviVqntWRL4RkU0i0lJEyovI/SLybrK/Xol+ZeWdKVVtpKr/UNXVqrpFRIZLfldf2LJCHy8VkdoFH9cXkfsKdeMbJP/5kTqJrhvZK1X3rKr2lfy/5b/AOfdTrOshc6Twnm0tIkeJyLRY10JmSdU9CxQlBffsTslv5j5yzu2S/L9sPU1Vq8awZsrLymZKRCaIyFci0tA5V0Xyb3PqQcfULfRxPRFZWfDxMhHp5ZyrVuhVwTn33sEXUdXB+t/To/7rlYDPC5kr5fasqvYUkVtEpL1zbnmcPk9kjpTbswV6iMh05xzfg3GwVN2zQFFSbc9+Jvl3tA5wkgWytZmqLCJbRGSbqjYWkd6eYwaqanVVrSsiN4nIswX5QyIySFVPFBFR1aqqeqnvIs654a7Q9KiDXweOU9VSqlpeRMrm/1LLq32AENkt1fZsN8n/G7CznXM8HA2flNqzBetUEJHLRGRyXD5DZJqU2rO8N0AEKbVnReRxEemkqs1UtayIDBGRd5xzm+Pz6aambG2mBohIVxHZKiKPyn82VmEvi8h8EVkgIjNEZJKIiHPuRREZKSLPFNxS/VxEOsZYTxvJvzU6U/L/1mCnZMEoSRRLqu3ZuyV/7OlHhf526qEY10RmSbU9KyJyseT/e/65cVgLmSfV9izvDfBzUmrPOufmSP7dsRki8qPkj0bvGsua6UCdy4o7cAAAAAAQV9l6ZwoAAAAAYkIzBQAAAAABaKYAAAAAIADNFAAAAAAEoJkCAAAAgABlinOwqjL6D8XinDv4h8eVKPYsios9izS0zjlXM1kXZ88iAHsW6abIPcudKQAA0tvSZBcAFBN7FummyD1LMwUAAAAAAWimAAAAACAAzRQAAAAABKCZAgAAAIAANFMAAAAAEIBmCgAAAAAC0EwBAAAAQIBi/dBeAAAAANmjfPnyJjvnnHNM1r9/f5P99a9/NdlHH31kslWrVgVWl3zcmQIAAACAADRTAAAAABCAZgoAAAAAAtBMAQAAAEAABlAAAAAA8Bo/frzJevToEenc1q1bm2zChAkmu/HGG4tfWIrgzhQAAAAABKCZAgAAAIAANFMAAAAAEIBmCgAAAAACMIAigcqVK2eyd99912THHnusyTp06GCyTz75JD6FIWU88MADJjvllFMinfvqq6+abOnSpSZbvXq1yWbPnh3pGgCA1NG4cWOTLViwwGQfffSRyc4444yE1IT0lJOTYzLfexIRkauvvtpkzrlI19mzZ4/JPvjgg0jnpgvuTAEAAABAAJopAAAAAAhAMwUAAAAAAWimAAAAACAAAygSqHr16iZr3rx5pHMnT55sspYtW5rsp59+KnZdSDzf8BHfTxDv2bNn8DVatWplMt8Dofv37zfZxx9/7F3ztttuM9lrr70WUB0AIN5at25tstKlS5vspJNOMtlxxx1nsm+//TY+hSHtXH/99Sa75pprYlrTNwTr7rvvNtlTTz0V03VSDXemAAAAACAAzRQAAAAABKCZAgAAAIAANFMAAAAAEIABFAl0xx13BJ9bpUoVk9WsWdNky5cvD74GEufPf/6zyWIZNuET9aePlypl/87k1FNP9R7rG5LRpUsXkxU1wAJIlDZt2njz+++/32THH3+8yfr372+yCRMmxF4YkCAdO3Y0mW9IUJky9q3cjh07TLZr1674FIa04/v+2a9fv5jW9A0v6dChg8mWLVsW03XSAXemAAAAACAAzRQAAAAABKCZAgAAAIAANFMAAAAAEIABFHHQqVMnb96rVy+TRR0akJeXZzKGTaSP2rVrRzpu+vTpJlu4cKHJtm3bZrInn3zSZOXKlTOZ7yeNn3baad56jjvuOJM98sgjJmvZsqXJ9u3b510Tma9SpUom27t3r8l8++ukk04ymW9/FjWAokmTJlFKlFatWpmMARRIFaVLlzZZnz59TFa3bl2T+b73vvnmmyZbsWJFYHVIJ1WrVjXZ0KFDTZabmxt5zTVr1pisW7duJsuGYRM+3JkCAAAAgAA0UwAAAAAQgGYKAAAAAALQTAEAAABAAAZQxEHjxo1jOt/3UGjPnj1jWhPJ5Xuw/YcffjDZqFGjTBbvQQ7t2rUz2auvvuo99pxzzjFZs2bNTPaHP/zBZOPHjy9+cc5TUHEAACAASURBVEgZhx12mMlmzpwZ6dzdu3ebrEGDBiarVauWycqXL28yVTVZ1OE9Rdm6dWtM5wOJ5BsQ8Jvf/CbSuR999JHJrrrqqphrQnoaN26cyVq3bm2y4nxP9Q2dYqDJf3BnCgAAAAAC0EwBAAAAQACaKQAAAAAIQDMFAAAAAAEYQBEHsT7o+cgjj5hs9erVMa2J5Prss88iZcly9913e3PfsIqcnByTDRkyxGSvvPKKyXxDN5CaKlSoYLIzzjjDZLEMh9i1a5fJtm3bZrLHH3/cZIcffrh3zcsvv9xkpUuXNplvSAaQDL6hVTfffHOkc30DinzDK5KlRYsWJvv444+TUEl2uO6660zWuXPnSOfu2bPHZH379vUey7CJQ+POFAAAAAAEoJkCAAAAgAA0UwAAAAAQgGYKAAAAAAIwgKKYunTpYrKGDRvGtOayZctiOh8ornfeecebjx492mT/+7//a7IjjzzSZLm5uSZjAEX62Lp1q8kuuOCCuF5jyZIlJtuyZYvJVq5cGXnNU0891WQNGjSIdB0gkQ477DBvfvvtt0c+9mBTp0412axZs4pXWALt2LEj2SVklYceeshkUQcCPf/88yabNGlSzDVlI+5MAQAAAEAAmikAAAAACEAzBQAAAAABaKYAAAAAIAADKIppyJAhJitVKnpPunbtWpNNnz49ppqAeHn55ZdN5htA4dOkSROTzZs3L+aaUDJ2795tsldffTUJlfhVq1bNm/se3FdVk/mGXwCJdOGFF3rzK664ItL5GzZsMNnDDz8cU02JlpeXl+wSMpZvAFpUvj+L+/btG0s5KIQ7UwAAAAAQgGYKAAAAAALQTAEAAABAAJopAAAAAAjAAIpiql69ekznjx071mRbtmyJaU0gFfgetvb9dPZ9+/aVRDnIMMcff7w3r127tsmccyY788wzTfb444/HXhggIu3atTPZlClTIp/v27P9+/c32TvvvFOsupA5Bg8ebDLfsB2ft956y2S894wf7kwBAAAAQACaKQAAAAAIQDMFAAAAAAFopgAAAAAgAAMoDuHKK6802ZFHHhn5/G3btplszJgxMdUEJNLatWtNtm7dOpMdccQRJmvQoIHJcnJyTLZz587A6pDNmjRpEtP5ixYtilMlgHXbbbeZrFy5cpHPHzdunMmKM8ACmaVZs2Ymq1+/vsl8g0t82UsvvRRTPVWrVjXZOeecY7Lf//73JitfvrzJnnnmGZM98sgjgdUlH3emAAAAACAAzRQAAAAABKCZAgAAAIAANFMAAAAAEIABFIdw9tlnm6xUqej95969e022Z8+emGpC+qpWrZrJateuHelc31765ptvYq7pYDVr1jSZb9iEz9ixY03GsAnES6wDKBLx+wXZqXfv3iZr3bp15POXLl1qsltvvTWmmpBZfO8XDjvssEjn5uXlmWzx4sUm8w2IOv30071rTps2zWS+oRRRtWnTxmS+oRt9+vQJvkZJ4s4UAAAAAASgmQIAAACAADRTAAAAABCAZgoAAAAAAtBMAQAAAEAApvkV8E0RufDCC03mnIu85qhRo2KqCemrY8eOJvNNu2vUqFGk9Xbv3m2yO++802QzZ8402cKFCyNdQ0TkoosuinzswRYtWhR8LjKPby/5JvJ9//33JuvWrZvJGjduHFM948aNM9kpp5xisttuuy2m6yCz1KpVy2R/+ctfTFa2bFmT+aawioiMHj3aZFu2bAmoDrB80/x8k6R9k/Luvfde75qqarLivB+OokuXLibzfd/2fX7Jxp0pAAAAAAhAMwUAAAAAAWimAAAAACAAzRQAAAAABGAARYGGDRuarGrVqjGtOWPGjJjOR/p6+eWXTVamTPhvt5ycHJMNGzbMZLfffrvJXnnlFZMVtTf//Oc/R6rH9zDrTz/9FOlcZJ6JEyea7PLLLzdZxYoVI60X68POvoEtvt9DQGG+79FTpkwxWf369SOtV9RQnvHjxxevMCBGDzzwgMmuueaamNb0vc95++23TVbUUIuDValSxWQ1a9YsfmFJwJ0pAAAAAAhAMwUAAAAAAWimAAAAACAAzRQAAAAABGAARQKdfvrpJvvss8+SUAlK2ooVK0wW9aHlVatWmcz3IPM555xjMt9D9pdcckmkrDgWL15ssg8//DCmNUM1b97cZHXr1jWZ72FZxMfdd99tsjp16pjsuOOOM9m6detM5htAUa9ePe+1jzrqKJPNmTPHZL6BGFu3bvWuiex00kknmezcc8+NdO7evXtNdtddd8VcE3CA7/uiL7v00kvjfu3u3bubbObMmSbz/Tnrq9GnVKn0vb+TvpUDAAAAQBLRTAEAAABAAJopAAAAAAhAMwUAAAAAARhAUeCKK66I+5qjRo0y2YQJE+J+HaSeoUOHmuzhhx82WZky9rfg/PnzTXb99debrHz58ibz/fRx3yCAWDVs2NBkvqEbeXl5JjvhhBPiWku1atVM5nvg9bDDDovrdfEfS5YsMVnHjh1NVrlyZZNFHQLhGyoh4h9A0bhx4+DrIHsNGTIk+Nz77rvPZC+++GIs5SCLrV271mRbtmwxWZUqVUzmnAu+Ru/evb3H+oZbPffccyZr3bp1cD2vv/66yRYuXBjp3GTjzhQAAAAABKCZAgAAAIAANFMAAAAAEIBmCgAAAAACaNQHw0REVDX6wWnm008/NdnJJ58c05rbt283me8B7EzmnIv2o68TJJX27JlnnmmyRx991GTHHntspPXef/99kzVr1sxkFSpUiLReOvjhhx9M5hu6MWbMGJMtWLAg0jXYs8mXm5trskWLFnmP9Q0WWblypcnq1q0bc10pbL5zrkWyLp6Oe7ZFC/vlmjdvnsmifv8844wzTPbOO+8Uv7DswZ4tJt9wiHHjxpks6vv6PXv2mGzz5s3eY2vWrBl8HZ+NGzearF69eibbuXNn8DUSoMg9y50pAAAAAAhAMwUAAAAAAWimAAAAACAAzRQAAAAABCiT7AIyGT/9HIXNnTvXZP379zfZ6NGjTeZ7IL9Vq1aRrrt7926T+QauiIgMGzbMZF999VWk6/j07NnTZL6fpD5//nyTffTRRybbtGmTydatWxdYHVLVL3/5S5P5Bk0U5YUXXohnOchAAwYMMFnUYRNvvPGGyT788MOYawIOZcKECSbzDaCIqmzZsiY74ogjgtcT8Q+1eP75503Wt29fk6XYsIli4c4UAAAAAASgmQIAAACAADRTAAAAABCAZgoAAAAAAmTlAIq2bduazPfAc1SfffaZN7/qqquC10R2+Pvf/x4pa9asmcmaNm0a6Rrz5s0z2ZIlSyKdG6vBgweXyHWQWXwDV1Q18vmrVq2KYzVId0ceeaTJog7w8bnnnntM5nvwHki0W2+91WR33XVX3K+zceNGk02dOtVkEydONNnChQvjXk+q4c4UAAAAAASgmQIAAACAADRTAAAAABCAZgoAAAAAAmTlAIrDDjvMZDk5OcHrzZgxI5ZygJ+1YMGCSBmQCY444giTOecinz937tx4loM0V716dZPVq1cveL39+/fHUg4QNyNGjIiUIbG4MwUAAAAAAWimAAAAACAAzRQAAAAABKCZAgAAAIAAWTmA4vXXXzfZzTffbLKzzz7bZN9++63J3nrrrfgUBgCQRo0aRT52yZIlJvvss8/iWA3S3ffff2+yBx980GR9+vQx2YYNG0y2bNmy+BQGICNwZwoAAAAAAtBMAQAAAEAAmikAAAAACEAzBQAAAAABtDg/VV5Vox8MiIhzTpN5ffYsios9m3xPPvmkybp27eo99osvvjBZ06ZN415TipvvnGuRrIuzZxGAPYt0U+Se5c4UAAAAAASgmQIAAACAADRTAAAAABCAZgoAAAAAApRJdgEAAIR64YUXkl0CACCLcWcKAAAAAALQTAEAAABAAJopAAAAAAhAMwUAAAAAAdS56D8Emp8YjeJyzmkyr8+eRXGxZ5GG5jvnWiTr4uxZBGDPIt0UuWe5MwUAAAAAAWimAAAAACAAzRQAAAAABKCZAgAAAIAANFMAAAAAEIBmCgAAAAAC0EwBAAAAQACaKQAAAAAIQDMFAAAAAAHKFPP4dSKyNBGFICPVT3YBwp5F8bBnkY6SvW/Zsygu9izSTZF7Vp1zJVkIAAAAAGQE/pkfAAAAAASgmQIAAACAADRTAAAAABCAZgoAAAAAAtBMAQAAAEAAmikAAAAACEAzBQAAAAABaKYAAAAAIADNFAAAAAAEoJkCAAAAgAA0UwAAAAAQIKuaKVVdoqodIhznVLVB4DWCzwUOxp5FumHPIt2wZ5Fu2LOpJauaqVSmqs1VdZ6qblPVNap6U7JrAn6Oquao6pequjzZtQCHoqqzCr6/HnjtVtVFya4LKAp7FulGVfup6nequkVVV6rqWFUtk+y6Eo1mKgWo6hEi8qqIPCwih4tIAxF5LalFAdEMFJG1yS4C+DnOuY7OuUoHXiLynog8n+y6gKKwZ5GG/i4izZ1zVUTkJBE5WUT+mNySEi8rmylVPVVV31fVTaq6SlXHqWrOQYedX9Bdr1PV0apaqtD5PQv+Nn6jqs5W1foxltRfRGY7555yzv3knNvqnPsyxjWRQVJwz4qqHiMi3UVkRKxrIfOk4p4ttHauiJwhIk/Ea02kP/Ys0k2q7Vnn3LfOuU0HlheR/ZJ/gyCjZWUzJSL7RKSfiBwhIq1EpL2I9DnomE4i0kJEmovIRSLSU0REVS8SkcEi8jsRqSkib4vIVN9FVPWWgg3ufRU69H9EZIOqvqeqP6rqK6paL26fLTJBqu1ZEZEHCtbdGY9PEBknFffsAVeJyNvOuSWxfILIOOxZpJuU27Oq2lVVt4jIOsm/M/VwfD7VFOacy5qXiCwRkQ6e/GYRebHQr52InFfo131E5M2Cj2eJyDWF/lspEdkhIvULndugmHV9IyKbRKSliJQXkftF5N1kf714Jf+Vwnu2k4jMKvi4nYgsT/bXildqvFJ1zx5Uy2IRuTrZXyteqfFiz/JKt1ea7NmGInKXiByV7K9Xol9ZeWdKVRup6j9UdXVB9zxc8rv6wpYV+nipiNQu+Li+iNxXqBvfIPm3MuvEUNJOyd/8HznndonInSJymqpWjWFNZJBU2rOqWlFERkkW/DtohEulPXtQXa1F5CgRmRbrWsgs7Fmkm1TdsyIizrl/i8gXIvJgPNZLZVnZTInIBBH5SkQauvyH5AZL/gYqrG6hj+uJyMqCj5eJSC/nXLVCrwrOufcOvoiqDtb/nsTzX69Ch34m+X8DcIAT4L+l0p5tKCK5IvK2qq4WkekicnTBN/PcOH2+SH+ptGcL6yEi051zvv+G7MaeRbpJ1T17QBkROS74s0sT2dpMVRaRLSKyTVUbi0hvzzEDVbW6qtYVkZtE5NmC/CERGaSqJ4qIqGpVVb3UdxHn3HBXaBLPwa9Chz4uIp1UtZmqlhWRISLyjnNuc3w+XWSAVNqzn0v+N+dmBa9rRWRNwcfLfOsiK6XSnpWCdSqIyGUiMjkunyEyDXsW6Sal9qyqXquqRxZ8fIKIDBKRN+P1yaaqbG2mBohIVxHZKiKPyn82VmEvi8h8EVkgIjNEZJKIiHPuRREZKSLPFNxS/VxEOsZSjHNujuT/bcIMEflR8iefdI1lTWSclNmzzrm9zrnVB16S/08D9hf8el/ousg4KbNnC7lY8p9PnRuHtZB52LNIN6m2Z08XkUWqul1EZha8Bse4ZspT5/gXZQAAAABQXNl6ZwoAAAAAYkIzBQAAAAABaKYAAAAAIADNFAAAAAAEKFOcg1WVaRUoFufcwT/voESxZ1Fc7FmkoXXOuZrJujh7FgHYs0g3Re5Z7kwBAJDelia7AKCY2LNIN0XuWZopAAAAAAhAMwUAAAAAAWimAAAAACAAzRQAAAAABKCZAgAAAIAANFMAAAAAEIBmCgAAAAAC0EwBAAAAQIAyyS4AIp9++qnJnnjiCZONHTu2JMoBAAAAEAF3pgAAAAAgAM0UAAAAAASgmQIAAACAADRTAAAAABCAARQl7IEHHjBZ3bp1TfbUU0+VRDkAAAAAAnFnCgAAAAAC0EwBAAAAQACaKQAAAAAIQDMFAAAAAAEYQJFAV199tcn69OljsrFjx5rsxx9/TERJSANnnnmmyfbs2WOyefPmmezKK680GcNMAADIXr73o0ceeaTJqlat6j1/8ODBka4zZMgQk+3evdtkTz75pMlWrVoV6RqpiDtTAAAAABCAZgoAAAAAAtBMAQAAAEAAmikAAAAACKDOuegHq0Y/GPLll1+a7KijjjLZ2WefbbKPP/44ITWVNOecJvP6qb5np0yZYrLLLrvMZL7fpxUqVDDZ008/bbJu3boFVped2LNIQ/Odcy2SdXH2LAKwZw8hJyfHZL4/y0eMGGGy0qVLm6xatWqRjispmzdvNtkjjzxiskcffdRkixcvTkhNERS5Z7kzBQAAAAABaKYAAAAAIADNFAAAAAAEoJkCAAAAgABlkl1AJhg0aJA3P/744032xz/+0WSZMmwCxffVV1+Z7McffzRZvXr1SqIcIKEqVqxoslq1apnM932yKN27dzdZjRo1Ip377LPPmuy6664z2bZt2yLXA/wc3/CgDh06mGzatGkmK1PGvm275ZZbTDZ69OjA6pAKypcvb7JJkyYl/LqffvqpN1+xYoXJGjZsaDLf+16fqlWrmmzgwIEmO//8803WpEmTSNcoSdyZAgAAAIAANFMAAAAAEIBmCgAAAAAC0EwBAAAAQAAGUBRT27ZtTfaXv/zFe+wnn3xisqlTp8a9JqQv308vL1XK/h3H3XffHWm9H374IeaagOJq06aNyc477zyTnXXWWSZr2bKlyVTVZM65yPVEPfayyy4z2RtvvGGyknjwG5mnUaNG3nzIkCEm69KlS6Q1fXu7adOmxSsMKW/fvn0m+/DDD0126qmnmsz3/XPz5s0mW7x4sckuuOACbz1RB2P5Mt8wlF//+tfe6xysTp06Jqtdu7bJVq5cGWm9ROHOFAAAAAAEoJkCAAAAgAA0UwAAAAAQgGYKAAAAAAIwgOIQatSoYbKHHnrIZKVLl/aef+2115ps/fr1sReGjLZly5bgc8eNGxfHSpDtjjzySJP93//9n8natWtnsqK+L4Z6/fXXvblv6Er58uVN1q1bt0jXyc3NLVZdgIjISSedZLK5c+d6j61evXpcr52XlxfX9ZB827dvN1mrVq1Mdu+995rs888/N5lv2MSmTZtM5hs0URTf915f5qvx+eefj3SNatWqmaxr166RrlGSuDMFAAAAAAFopgAAAAAgAM0UAAAAAASgmQIAAACAAAygKFCqlO0rH3vsMZP5fqL573//e++aCxYsiL0wZB3fg8w+u3fvNtn+/fvjXQ6yxMUXX2yy22+/3WRNmzYNvsYnn3xisunTp5vs/vvvN9lPP/3kXXPv3r0mq1WrlsmiDqBAevP9WT5w4ECTvfnmmyb7+OOPTVamjH2bdN1115nsjjvuMJlviJWIyKJFi0x22WWXmWz+/PkmW7ZsmcmmTJnivQ4y34ABA5Jdws/yDQSKyjlnMt97n2TjzhQAAAAABKCZAgAAAIAANFMAAAAAEIBmCgAAAAACMICiwODBg03229/+1mTjx4832RNPPJGQmoBDeeihh0y2atWqJFRStD/96U8m27x5s8kmTpxYEuXgEEaPHm2yY4891mTr16832dSpU032/vvvm2zu3LkmW7NmTdQSgZ9Vs2ZNkw0bNsxkP/74o8l8D8r/5S9/MVnHjh0j1TJq1ChvPm7cOJP5BllVqFDBZEOHDjXZypUrI9UDJNqQIUNM1q9fv+D11q1bZzLfgKJk484UAAAAAASgmQIAAACAADRTAAAAABCAZgoAAAAAAmTlAIqjjz7aZP379zfZF198YbK77rorITUd7OKLLzaZ7+HYWbNmmcz3gD9SU7Vq1UzWpEmTJFSSGDfccIPJcnNzTbZixQqT+fY2EufWW281WcuWLU3mG3yyePHihNQU6vLLL092CUhxjz76aPC5vkEqvmETr7/+uvf8Fi1amMw36GLjxo2R1wQSqW3btibzDZj6zW9+E9frvvvuu3FdL1G4MwUAAAAAAWimAAAAACAAzRQAAAAABKCZAgAAAIAAWTmAYuLEiSbzDXfwPcTs+6npRfE9ZNqpUyeTDRgwwGQ5OTkmc86ZbNu2bSbzPQA4b968IutE8hx++OEma9WqVaRzS2oYSlQNGjQwWaVKlUy2du1ak61atSohNSG6Z599NlKWDurUqRN87r59++JYCUqa7/tLz549TeYbuFKjRg2TDRs2zGTjx4832e7du6OWKB06dDBZhQoVTOYbarF+/frI1wEOpXv37iYbMWKE91jfsKyKFSvGtZ4ePXqY7IUXXojrNRKFO1MAAAAAEIBmCgAAAAAC0EwBAAAAQACaKQAAAAAIkPEDKNq3b2+yc88912S+B/Hy8vJMVrp0aZPdcccd3msPGjTIZBs2bDCZ7yHTb7/91mS+wQTXXXedyS688EKTMYAi+XwPcBb1sGcUW7ZsiaWcuPM9VF2zZk2T+fbiggULElITMp9vaEDfvn2D13vwwQdjKQdJtn//fpM98cQTkbJ4833/ExHp1atXpPPHjBkTz3KQxXzDhDp27Ggy39CoWPkGt/neu7711lsm27FjR9zrSQTuTAEAAABAAJopAAAAAAhAMwUAAAAAAWimAAAAACBARg2gqFy5sskeffRRk3355ZcmmzNnjslq165tsieffNJkZ555preeN99802SdO3c22ebNm73nR+F7iG/dunXB6yFxGjdubLJLL700CZUULScnx2QnnHCCyc4++2yT9ezZMyE1AYfyu9/9zmTly5ePdO7cuXNNtmnTpphrAkREevTo4c3r1q1bwpUg2/kGSyRi2ITP22+/bbJXXnmlRK5dUrgzBQAAAAABaKYAAAAAIADNFAAAAAAEoJkCAAAAgAA0UwAAAAAQIKOm+XXp0sVkubm5Jjv11FNNVq9ePZM9/fTTJmvQoIHJHnvsMW89N954o8l27tzpPTYKX92+yX1TpkwJvgYSp2vXrsHnzp4922RHHXWUyfbv32+yfv36mcw3qVJEpEKFCia76KKLopQIpB3f76vdu3cnoRJkogsuuCDysc8995zJdu3aFc9ykMWuueYak3Xr1s1kN998s/f8qlWrmqykpgGmA+5MAQAAAEAAmikAAAAACEAzBQAAAAABaKYAAAAAIEBGDaDo37+/ydasWWOyrVu3mmzOnDkmO/roo012++23m+yvf/2rt55Yhk10797dZL4HCEeMGGGy1atXB18X8XHSSSeZ7JJLLgle79xzzzXZ0qVLI53rnDPZnj17vMdu3rzZZDfddJPJOnToYLILL7wwUj1APLVt29Zkqhrp3Hnz5sW7HGSp3/72tyZr166d99ivv/7aZL6BVb6BQkAI3/vCMWPGRMpERO677z6T+fasT7NmzUxWv359k0V9T5OKuDMFAAAAAAFopgAAAAAgAM0UAAAAAASgmQIAAACAABk1gKJixYomW7t2rcn27dtnMt+widtuu81kf/vb30y2ffv2qCV6dezY0WTDhw832cyZM012zz33xHRtJMYxxxxjstq1a8f1Gr494ntgefHixSabMmVKTNe+//77TeYb7MJPSEe81KhRw5ufdtppJvMNXfnmm28iZUAI3yCqogZIvPvuuyZbv3593GtCYvi+5yxYsMBkO3bsKIlyUt5xxx1nslq1apmMARQAAAAAkGVopgAAAAAgAM0UAAAAAASgmQIAAACAABk1gCKqTZs2maxXr14me+aZZ0y2bdu2mK7dqVMnk/l+srRvaMDAgQNNtmvXrpjqQWLMmjXLZO+//77JWrVqZbJ58+aZbOzYsSZ7+eWXTeZ78B7IBPfee683z83NjXT+xo0bI2XAz6lfv77JKleubDLfUB4R/wAfpI+5c+eazDdUZPTo0SbzvTdIB0OGDDHZjTfemIRKUhN3pgAAAAAgAM0UAAAAAASgmQIAAACAADRTAAAAABAgowZQ7Nmzx2THHHOMyX7xi1+Y7PHHHzfZ3r17TaaqJqtevbq3Ht9gic6dO5ts2rRpJhs0aJDJli9f7r0OUo9v7/j+35cvX95k69atM9mWLVviUxiQBmrUqGGy008/PfL527dvN1nv3r1jqgnZqUwZ+zbJ9+D94YcfbrKnn37au+bChQtjLwxJ880335isXbt2JmvZsqXJ7rzzTpNNnTrVZCtWrAgrLkH69euX7BJSGnemAAAAACAAzRQAAAAABKCZAgAAAIAANFMAAAAAECCjBlB07drVZLNnzzbZJ598YrIPPvjAZJ9//rnJfAMt2rdv763n+++/N9mVV15pMt8ACmSelStXJrsEIC3cfPPNJmvQoEHk830DKHjoHyEaN25sMt/+9JkxY0a8y0EKGD58uMkmT55ssooVK5ps1KhRJrv++utNNnHiRJN9+OGHESsMV7VqVW/+hz/8IXjNuXPnmiwvLy94vVTEnSkAAAAACEAzBQAAAAABaKYAAAAAIADNFAAAAAAEyKgBFL4hEj179jSZ78G+//mf/4mU7du3z2QjR4701jN+/HiTLV++3HsskO7WrFljskqVKiWhEqS7W2+91WTOucjnP/vss/EsB1miTBn7lmjQoEGRzn3uuedMxj7MTFOnTjVZbm6uyXxDG+rWrWsy33Cde+65J6y4FPTVV1+ZbNu2bUmoJHG4MwUAAAAAAWimAAAAACAAzRQAAAAABKCZAgAAAIAAWpyHelU1+sGAiDjnNJnXZ8+WHN/DthMmTDDZvHnzTNa2bduE1BSCPVuybrjhBpM98MADJivqz6q9e/earHbt2iZbv359QHVpY75zrkWyLp4pe/bPf/6zyYYPHx7p3KZNm5osLy8v5poyWMbvWd+wPHj1NQAAIABJREFUifbt25vsyiuvNFmrVq1MVr58+fgUFidr1641mW8QXPfu3U22devWhNSUYEXuWe5MAQAAAEAAmikAAAAACEAzBQAAAAABaKYAAAAAIID9cd8AEMD34OmGDRuSUAlSVbNmzUzme8Bf1c4AKWoAxZQpU0yW4cMmEAennnqqyYYMGRLp3IULF5ps2bJlMdeEzOLbE5MnT46UnXPOOSbzDTnxufXWW01WpUoVk61YscJk9913n3fNffv2mWzs2LGR6skG3JkCAAAAgAA0UwAAAAAQgGYKAAAAAALQTAEAAABAAAZQAIiLBQsWmOzwww9PQiVIVfXq1TNZpUqVTFbUsAmf6dOnx1QTstOAAQNMVqFCBZNt3brVZL4H/H3HAaFee+21SJnPvffeG+9y8DO4MwUAAAAAAWimAAAAACAAzRQAAAAABKCZAgAAAIAADKAAAJSI7777zmRbtmwxWdWqVU12/fXXe9ecM2dO7IUh67z33nsm69ixo8mWLl1qslmzZiWkJgDpiTtTAAAAABCAZgoAAAAAAtBMAQAAAEAAmikAAAAACKDF+Unzqhr9YEBEnHOazOuzZ1Fc7FmkofnOuRbJujh7FgHYs0g3Re5Z7kwBAAAAQACaKQAAAAAIQDMFAAAAAAFopgAAAAAgAM0UAAAAAASgmQIAAACAADRTAAAAABCAZgoAAAAAAtBMAQAAAECAMsU8fp2ILE1EIchI9ZNdgLBnUTzsWfy/9u47zKrq+v/42lQF6SK9JIIYEEHASp2ICqgUg2AgdLDgA4IFUSQUEQLql2ASpAmKGiE0RZCiBFQkkRaRL4oE+A1SpUmHoe3fHzPmO2HtwTP73pl7z53363nu8wwfzzl7DW7uzOJw1oRRrPctexaZxZ5F2GS4Z421NjsLAQAAAICEwD/zAwAAAAAPNFMAAAAA4IFmCgAAAAA80EwBAAAAgAeaKQAAAADwQDMFAAAAAB5opgAAAADAA80UAAAAAHigmQIAAAAADzRTAAAAAOAhRzVTxphkY0zTAMdZY0wVzzW8zwUuxZ5F2LBnETbsWYQNeza+5KhmKl4ZY/obY7YbY44ZY/YYY8YaY/LEui4gI8aYRcaYE+leZ40xG2NdF5ARY0xRY8xbxpj9aa+hsa4JuBz2LMLGGJNkjFlujDlqjEmOdT3ZhWYqPswXkTrW2sIicoOI1BKRvrEtCciYtba5tfaqn14iskpEZsW6LuAyxopIARGpLCK3iEgnY0y3mFYEXB57FmFzUkSmisgzsS4kO+XIZsoYc4sx5h/GmCPGmL3GmD8bY/JdcliLtLtFB40xLxtjcqU7v7sx5ltjzI/GmCXGmEqR1GOt3WatPfLT5UXkoohwaxX/EW979pLaKotIQxGZHq1rIvzicM/eLyJjrLWnrLXJIvKGiHSP8JpIIOxZhE287Vlr7Wpr7dsisj2S64RNjmymROSCiPQXkatF5HYRuVNEel9yTBsRqScidUSklaS9gRljWonI8yLygIiUFJHPReQ91yLGmIFpG9z5uuTYDsaYYyJyUFLvTE2MzqeKBBF3ezadziLyedoXe+An8bhnzSUf3xDJJ4iEw55F2MTjns15rLU55iUiySLS1JH3E5F56X5tRaRZul/3FpFlaR8vEpEe6f5bLhE5JSKV0p1bJYIaq4rIiyJSOta/X7xi/wrJnt0qIl1j/XvFKz5e8bpnReQdEZkrIoUk9c7/NhFJifXvF6/Yv9izvML2itc9m+5aTUUkOda/T9n1ypF3powx1xljFhhj9qXdDRopqV19ejvTfbxDRMqmfVxJRMal68YPS+rfFpWLRm3W2n+LyCYRGR+N6yExxOueNcY0EJHSIjI70mshscThnu0rIqdF5N8i8oGk/g3srgiuhwTDnkXYxOGezZFyZDMlIq+LyGYRqWpThz48L/99K11EpEK6jyuKyJ60j3eKyCPW2qLpXldaa1dduogx5nnz3xPP/ut1mfryiMi13p8dElG87tkuIjLXWnu5/YycKa72rLX2sLW2o7W2tLW2hqR+/Vsdxc8X4ceeRdjE1Z7NqXJqM1VIRI6JyAljzPUi8pjjmGeMMcWMMRVE5AkRmZmWTxCR54wxNUREjDFFjDEPuhax1o606SaeXfr66ThjTE9jzDVpH1cXkedEZFm0PlkkhLjas2nXuVJE2onIm1H5DJFo4mrPGmOuNcaUMMbkNsY0F5GHRWRE9D5dJAD2LMIm3vZsLmPMFSKSN/WX5gqjB2IknJzaTD0tIh1E5LiITJb/21jpfSAi60TkKxFZKKlTdMRaO09ERovIjLRbqv8rIs0jrKe+iGw0xpwUkY/SXs9HeE0klnjbsyIirUXkiIgsj8K1kHjibc/WFZGNafWMEpGO1tpNEV4TiYU9i7CJtz3bSFL/aepHknoX7LSILI3wmnHP2NQHxQAAAAAAmZBT70wBAAAAQERopgAAAADAA80UAAAAAHigmQIAAAAADzRTAAAAAOAhT2YONsYw+g+ZYq299IfHZSv2LDKLPYsQOmitLRmrxdmz8MCeRdhkuGe5MwUAQLjtiHUBQCaxZxE2Ge5ZmikAAAAA8EAzBQAAAAAeaKYAAAAAwAPNFAAAAAB4oJkCAAAAAA80UwAAAADggWYKAAAAADzQTAEAAACAB5opAAAAAPBAMwUAAAAAHmimAAAAAMADzRQAAAAAeKCZAgAAAAAPNFMAAAAA4IFmCgAAAAA80EwBAAAAgAeaKQAAAADwQDMFAAAAAB7yxLoAAAAAAIln6NChKhsyZIjKkpKSVLZixYosqCj6uDMFAAAAAB5opgAAAADAA80UAAAAAHigmQIAAAAAD6EdQJE/f36VVa5cWWVdu3ZV2VVXXaWy3/zmNyorUaKEyiZNmhSswAwsXrxYZevWrVPZvn37IloHAAAAiKXGjRsHOm758uUqC8tQCu5MAQAAAIAHmikAAAAA8EAzBQAAAAAeaKYAAAAAwEMoBlBce+21KhswYIDKevXqleW1PP7441E/f/fu3Sq76667VLZ58+aI1gaAWCpQoIDKypcvr7Lu3burrEqVKs5ruoYHrVmzRmWu4T8jR45U2ZkzZ5zrAAAuzzVEokmTJt7Xc53LAAoAAAAASBA0UwAAAADggWYKAAAAADzQTAEAAACAh7gbQOEaNrFs2TKVVaxY0XuNCxcuqOzixYve18tI7ty5VZYrl+5fy5Urp7K5c+eqrHr16tEpDHGjYMGCKqtcubLKNm3apLIiRYqorE2bNiobNGiQylx/zkREjhw5orJ3331XZa6H9N98802VHT58WGU//vijyp566imV3X333SobPny4ylzvD4i9kiVLqsz1/6pGjRoRreN6765bt6539tBDD6ns+PHjntUhTFwDUlxfdx944AGV9enTR2VXXXWVc51//vOfKrvhhhtU1qxZM5WtXr1aZefOnXOuA2Ql13CISIZNhBl3pgAAAADAA80UAAAAAHigmQIAAAAADzRTAAAAAOAh7gZQVKhQQWVlypRRWUpKisq2bt2qsqlTp6rso48+Utl3330XtMTAkpKSVBb0YXnX5+waGrBt27bMF4aYuOKKK1TmGtrgerj5888/V1mJEiVUFnRIibXWmRctWlRlvXv3DnTNJ598UmV79uxR2d69e1XmGgTgMmvWLJXVqVNHZcnJyYGuh+goVaqUyhYvXqyySIZNfPPNN8785MmTKrvppptUlieP/nLnesD/3nvvVdmMGTOClIg4lTdvXpVdffXVKps3b57KbrnlFu91M3qfve222wId63rff+utt1T26KOPqsz1PRLgyzVYYsiQIdlfSJzizhQAAAAAeKCZAgAAAAAPNFMAAAAA4IFmCgAAAAA8xN0AihUrVqhs0aJFKnMNY6hZs2ZWlOStWrVq3udu2LBBZQybCLd+/fqpzDVswqVRo0Yqcz0U/+tf/zrzhaUTdNiE66Fs18P8ZcuWVVm5cuVU5nr42pX16dNHZbt27cqwTkSfazjO/PnzVXbjjTcGup5rgIRrbz/zzDPO848ePaqy+vXrq2z58uUqy507d5ASESLXX3+9ykaPHq2y+++/33uNHTt2qOzcuXOBzx81apTK7rrrLpXdfvvtKuvSpYvKli5dqrL33nsvcD3Az3ENm3ANpYi2oUOHZvka0cCdKQAAAADwQDMFAAAAAB5opgAAAADAA80UAAAAAHgwGf2EbufBxgQ/OIpcP5H+tddeU1njxo1Vlh0Pp+fJ457j4XpI1fXwtkupUqVUduDAgcwVFgestSaW68dqz7qsWrVKZbfeemugc1evXq2yli1bqiy79kjt2rVVtm7dukDnnjp1SmUvvPCCyrZv366yDz/8MNAakWDPXt6SJUtU1rRp00DnnjlzRmWdO3dW2Zw5czJf2M9ISUlRmeu9u2PHjipzDcSIM+ustfVitXis9myxYsVUtn79epVVqlRJZa6BTuPHj1fZzp07VbZ48WKVnThxIsM6fVWsWFFlX331lcrOnz+vMtfQos2bN0ensOjIkXs23rkG9Yhkz7CJpKQklbmG0sVQhnuWO1MAAAAA4IFmCgAAAAA80EwBAAAAgAeaKQAAAADw4J6cEGcWLlwYKMsOBQsWVNm0adOcxwYdNjFv3jyV/fjjj5krDDFTvHhxlQ0aNEhl9eoFe9bW9bB7nz59VHb48OFA18sKkTyQ36lTJ5W9//77kZSDLNK/f3+VuQb9uJw7d05lVatWVdmePXsyX9jPcA1IyZVL/93hsWPHVFahQgWVuQYBfP/9957VIVquuOIKlbmGTbgeYm/fvr3K4m3Ik2uPdejQQWV/+MMfVOb6vWnQoIHKXIODTp8+HbREhJhrqER2DJoQcf+ZjLNhE5nCnSkAAAAA8EAzBQAAAAAeaKYAAAAAwAPNFAAAAAB4CMUAilipU6eOyqZOnaqyG2+8MfA19+7dq7KPPvpIZS1atFDZ/PnzA6+D7ON6kNn14L7Ll19+qbJYDptwPbQ8fvx4lVWrVk1l1uofKD9ixAiVMWwiPrkGpIwZM0ZlrkEOFy9eVFnbtm1VlhXDJlyaNm2qMlfdhQsXVpnrYf5hw4apbM6cOSpzDVdB9nINb+rbt6/K4m3YRFCrVq1SmWvYi+v7hfLly6usefPmKluyZIlndYhXQ4cOVdmQIUOyZW3XYImkpKRsWTu7cGcKAAAAADzQTAEAAACAB5opAAAAAPBAMwUAAAAAHhJ+AIUxRmVFixZV2cCBA1XWu3dvlRUsWDCiesqUKaOyyZMnq8z1QPfXX3+tshdffFFlrgf8XcMBEB2uoQ2u3+8333xTZa59l13DJlxatWqlsi5duqjMtT9dD+m/+uqr0SkMWa5QoUIqcw1tcJkwYYLKFixYEHFNQbj+DA0aNCiqa+TPnz9QhtgrVqyYyqZPn66ywYMHq2zhwoVRreX666935q6hVa6vGcWLF1eZa7jRdddd51Fdqpo1a6qMARTh1qRJE5Vl17AJF9f3BomGO1MAAAAA4IFmCgAAAAA80EwBAAAAgAeaKQAAAADwkPADKB599FGV/eUvf4lBJZnjevC7du3aKpszZ47K7r//fpVF+8Fa/J8//elPKnvnnXdUdurUKZWdPHkyS2r6Obfccoszdw1dcXH9RPPZs2erLFafHzIv6P97l48++ijQcXny6C85rofny5cvr7IOHTo4r9myZUuVFShQIFA9kTh06FCWr4HLc72nTps2TWUpKSkqy5cvn8o+//xzlZUuXTpQLa7vNXr16uU8tn379irLjiFRo0ePVtm4ceOyfF1kr+XLl8ds7aSkJJW5vl9INNyZAgAAAAAPNFMAAAAA4IFmCgAAAAA80EwBAAAAgAeaKQAAAADwkPDT/G644YZAx+3YsUNlrsk+WeG9995TWZcuXVTWrl27QNd74YUXVMY0v6xz/vx5lR04cCAGlbjVqFFDZS+99JLz2IYNG6ps//79KhswYIDKvvnmG4/qEAvFihVTWbVq1byv17hxY5W5Jq09//zzKmvatKn3uiIi69evV9lnn32msp07d6rs1VdfDbTGxo0bVfbcc88FOhdZ5+jRoyrr0aOH9/Vce2T48OEqq1ChgsqKFi2qspkzZzrX6devn8pat26tsvHjxzvPv9Tu3btV1qZNG5Vt2LBBZefOnQu0BuJTLCf3uab05YTJfS7cmQIAAAAADzRTAAAAAOCBZgoAAAAAPNBMAQAAAIAHY60NfrAxwQ+OE8YYlRUpUkRlrocwT548mSU1BVGlShWVbdmyJdC5mzZtUlnNmjUjrsmHtVb/D8hGYdyzkbjyyitVtnLlSpXVrl3beb7rz8vf/vY3lT300EMe1YVDTt2zrgfWZ8+eneXrnjhxQmXvvPOOyv761786z//Xv/6lsjJlyqhsxIgRKgs61Kdnz54qmzZtWqBzs8k6a229WC2eyO+zefLoOV25c+dWWUpKSkTruIah9O/fP9C5rVq1UtmHH34YUT3ZgD2bSUOHDlXZkCFDsmXtYcOGqcxVT4LLcM9yZwoAAAAAPNBMAQAAAIAHmikAAAAA8EAzBQAAAAAe9JOVCcY1YOPIkSMxqCRzInnwe+rUqVGsBGFyzTXXqKxw4cIqcw2aEBH54YcfVPbUU09FXhji3oYNG1T2ySefqKxp06aBrnfx4kWVTZgwQWVjx45V2fbt2wOtIeIe1rN48WKV/eIXvwh0Pdd776xZswLXg8Ry/vz5QFlm5M+fX2WNGzdWmet9eseOHSr75ptvIqoH8adJkyYqy45hEytWrHDmOXDYRKZwZwoAAAAAPNBMAQAAAIAHmikAAAAA8EAzBQAAAAAeEn4ARbwpXbq0ylwP+NesWdN7jRkzZnifi3Dr1q2byn75y1+qzDWYRUTkyy+/VNnu3bsjLwxxzzX0oWXLlirr37+/ynbt2qWyjz/+WGWuASdBXXvttc48kmETx48fV9moUaNUduLEiUDXA4IYN26cyurUqaOy/fv3q8w1AGbbtm3RKQxxY/ny5TFZNykpKSbrhh13pgAAAADAA80UAAAAAHigmQIAAAAADzRTAAAAAOCBARRZyPXAtOunSHfs2DHQ9c6ePauyP/3pTyrbt29foOsh3KpUqaKyxx57LNC5GQ0p6dmzZ0Q1IbGkpKSo7A9/+EOWr1upUiWVffLJJ85jK1asGOiarvfPu+++W2VfffVVoOsBP6d48eLO/JFHHlHZ+fPnVTZhwgSVbd26NfLCEFeaNGmS5WusWLFCZcOGDcvydXMK7kwBAAAAgAeaKQAAAADwQDMFAAAAAB5opgAAAADAQ8IPoChfvrzKdu3aFejcvHnzqqxBgwYqa9eunfP8Hj16BLqmtVZl27dvV9mLL76osrfeesu5NhJLwYIFVfb444+rrHDhwoGut2bNGmd++vTpzBUGZIEKFSqoLOigCRGR48ePq2zu3LkqW716deYKAzJQsmRJlS1evDjw+ZMmTVLZkCFDIqoJ8cc1bGL58uVZvq5r2IRrKAX8cGcKAAAAADzQTAEAAACAB5opAAAAAPBAMwUAAAAAHkI7gML1k8UHDhyosqpVq6qsTZs2KmvdurXKfv/736usdu3aQUt0cg2b+O6771SWlJSksn379kW0NsLLNViibdu2KsuXL5/K9uzZo7KZM2dGpzAgC9x3330RnT9q1CiVjR49OqJrAj8pXbq0yubPn6+ym266yXn+iBEjVDZlypTIC0Pccw2giDZjTJavgf/GnSkAAAAA8EAzBQAAAAAeaKYAAAAAwAPNFAAAAAB4CO0Aivbt26vs6aefDnTu2bNnVZY7d26VRfoQ3+bNm1U2fPhwlc2aNUtlFy5ciGhtJJalS5eqrFy5cio7c+aMylq1aqWyvXv3RqcwIEINGjRQWdeuXQOf7xrMM2nSpEhKAv6jTJkyKnv//fdVVq9ePZW9/fbbzmsOHjw48sIAEVmxYkWsS4BwZwoAAAAAvNBMAQAAAIAHmikAAAAA8EAzBQAAAAAeQjuAolatWt7n5snj/2lv3LhRZa+99prz2OnTp6vs3Llz3msjZ2jTpo3KatSooTJrrcpeeuklle3cuTM6hQFZoGfPniorWbJk4PN/97vfqezHH3+MqCbgJ7169VLZzTffrLKtW7eqjEETuNTQoUNV1rhxY5U1adJEZa5hE0lJSVGoCpHizhQAAAAAeKCZAgAAAAAPNFMAAAAA4IFmCgAAAAA8hHYAxR//+EeVnT17VmV169ZV2alTp1Q2f/58lS1ZskRl+/btU9mxY8cyrBO4nHz58qnsgQceCHRuSkqKylauXKmyAwcOZL4wIM58//33znzz5s3ZXAlyklKlSgU6zvU+6/p+AbgUQyTCjztTAAAAAOCBZgoAAAAAPNBMAQAAAIAHmikAAAAA8BDaARSuh4779u0bg0oAf3fccYfKOnToEOjcCRMmqOzTTz+NuCYg1nbv3q2yZs2aOY/du3dvVpeDHOLOO+9UWa9evQKdu3btWpW5hmIBSDzcmQIAAAAADzRTAAAAAOCBZgoAAAAAPNBMAQAAAICH0A6gABJBzZo1Ax23bds2lQ0ePDja5QDZrmvXrrEuARARkUGDBqksTx79bdLXX3+tso8//jhLagIQ/7gzBQAAAAAeaKYAAAAAwAPNFAAAAAB4oJkCAAAAAA/GWhv8YGOCHwyIiLXWxHJ99iwyiz2LEFpnra0Xq8XZs/DAnkXYZLhnuTMFAAAAAB5opgAAAADAA80UAAAAAHigmQIAAAAAD/pHe1/eQRHZkRWFICFVinUBwp5F5rBnEUax3rfsWWQWexZhk+GezdQ0PwAAAABAKv6ZHwAAAAB4oJkCAAAAAA80UwAAAADggWYKAAAAADzQTAEAAACAB5opAAAAAPBAMwUAAAAAHmimAAAAAMADzRQAAAAAeKCZAgAAAAAPOaqZMsYkG2OaBjjOGmOqeK7hfS5wKfYswoY9i7BhzyJs2LPxJUc1U/HKpBptjDmU9hptjDGxrgu4HGNMHWPMZ8aYE8aYH4wxT8S6JuBy2LMIE2NMUWPMW8aY/WmvobGuCbicnLpn88S6AIiIyMMi0lpEaomIFZGPReT/iciEWBYFZMQYc7WILBaR/iIyW0TyiUj5mBYFXAZ7FiE0VkQKiEhlEblGRJYZY3ZYa6fFtCogYzlyz+bIO1PGmFuMMf8wxhwxxuw1xvzZGJPvksNaGGO2G2MOGmNeNsbkSnd+d2PMt8aYH40xS4wxlSIsqYuIvGqt3WWt3S0ir4pI1wiviQQSh3v2SRFZYq1911qbYq09bq39NsJrIoGwZxE2cbhn7xeRMdbaU9baZBF5Q0S6R3hNJBD2bHzIkc2UiFyQ1L+dvFpEbheRO0Wk9yXHtBGReiJSR0RaSdpmMMa0EpHnReQBESkpIp+LyHuuRYwxA9M2uPOV7tAaIrIh3a83pGXAT+Jtz94mIoeNMavSbuV/aIypGLXPFomAPYuwibc9KyJiLvn4hkg+QSQc9mw8sNbmmJeIJItIU0feT0Tmpfu1FZFm6X7dW0SWpX28SER6pPtvuUTklIhUSndulUzWdUFErk/366pp1zGx/j3jFdtXHO/ZLSJyRERuFpErROQ1Efki1r9fvGL/Ys/yCtsrjvfsOyIyV0QKiUgVEdkmIimx/v3iFfsXeza+XjnyzpQx5jpjzAJjzD5jzDERGSmpXX16O9N9vENEyqZ9XElExqXrxg9LauddLoKSTohI4XS/LiwiJ2zazgTicM+eltQ37DXW2jMiMkxE7jDGFIngmkgg7FmETRzu2b6Sum//LSIfSOpdg10RXA8Jhj0bH3JkMyUir4vIZhGpaq0tLKm3OS+dnlch3ccVRWRP2sc7ReQRa23RdK8rrbWrLl3EGPO8SZ0a5XylO3STpA6f+EmttAz4Sbzt2a8l9W+tfkLjj0uxZxE2cbVnrbWHrbUdrbWlrbU1JPV7ttVR/HwRfuzZOJBTm6lCInJMRE4YY64XkcccxzxjjClmjKkgIk+IyMy0fIKIPGeMqSEiYowpYox50LWItXaktfaqjF7pDp0uIk8aY8oZY8qKyFMi8mZUPlMkinjbs9NEpI0xprYxJq+IDBaRldbao9H5dJEA2LMIm7jas8aYa40xJYwxuY0xzSV18u+I6H26SADs2TiQU5upp0Wkg4gcF5HJ8n8bK70PRGSdiHwlIgsldSKJWGvnichoEZmRdkv1f0WkeYT1TBSRD0VkY9r1FqZlwE/ias9aa/8uqX8DtlBE9kvqv43uEMk1kXDYswibuNqzIlJXUr8vOC4io0Sko7WWf7WC9NizccDwWA4AAAAAZF5OvTMFAAAAABGhmQIAAAAADzRTAAAAAOCBZgoAAAAAPNBMAQAAAICHPJk52BjD6D9kirX20h8el63Ys8gs9ixC6KC1tmSsFmfPwgN7FmGT4Z7lzhQAAOG2I9YFAJnEnkXYZLhnaaYAAAAAwAPNFAAAAAB4oJkCAAAAAA80UwAAAADggWYKAAAAADzQTAEAAACAB5opAAAAAPCQqR/aCwBAvGvbtq3KZs2apbKBAweqbPTo0VlSEwAgMXFnCgAAAAA80EwBAAAAgAeaKQAAAADwQDMFAAAAAB4YQAHgPypXrqyyPn36qKxu3boqW758ucqGDRsWlbqAzBg0aJDKrLUqa9CggcoYQAEAyAzuTAEAAACAB5opAAAAAPBAMwUAAAAAHmimAAAAAMADAyiAHCh//vzO/Nlnn1XZww8/rLLJkyerbO3atZEXBkSBMSZQtnTp0uwoBwCQwLgzBQAAAAAeaKYAAAAAwAPNFAAAAAB4oJkCAAAAAA8MoLiM2267TWWtW7dWmeuhfRERa21U65k9e7bKJk6cqLLly5er7OLFi1GtBeHWoEEDZ+4aNjFhwgSV9e/fX2Vnz56NvDAgk0qWLKmyEiVKqCza78cAAIhwZwoQ7kTBAAAQRUlEQVQAAAAAvNBMAQAAAIAHmikAAAAA8EAzBQAAAAAeGECRplWrViqbMmWKyooXL66y7Bru8Jvf/CZQ1rdvX5X95S9/yZKaEE4tWrRw5nv27FHZmDFjVMawCcSLIkWKqKxs2bIqO3funMr+8Y9/ZElNyB6NGjVSWefOnVXmGj7SsWNHlV155ZXetRw6dEhlnTp1ch67aNEi73UAxB/uTAEAAACAB5opAAAAAPBAMwUAAAAAHmimAAAAAMBDjhxAkSeP/rTvvfdelbmGTWSGazDF+fPnVZYvX76I1rnUSy+9FGjdiRMnRnVdxKf77rtPZU888YTz2OHDh6tsx44dUa8JiJbBgwcHOs710P/atWujXQ6yiGs4xAcffKAy10CSoFxfJ13DK1zfQ7i+X1i4cKFzHddAKNd774EDB5znA4gv3JkCAAAAAA80UwAAAADggWYKAAAAADzQTAEAAACAh4QfQOEa7jBs2DCV9ejRI+prv/322yqbNWuWyh599FGVFS5cWGWun/buUqhQIZWVK1cu0LlIPOPHj1fZunXrnMe6HoIG4kG9evWceefOnVWWnJyssueffz7aJSEb5cql/+43kmETr732mspmz56tslOnTqmsWrVqKuvXr5/Kbr75Zufajz/+uMoaNmyosq5du6rsq6++cl4TyCzXnh0yZIjz2KJFi6rsww8/VFn37t1VdvDgQY/qwoU7UwAAAADggWYKAAAAADzQTAEAAACAB5opAAAAAPCQ8AMonn32WZUNGDDA+3pr1qxR2YwZM5zHTpkyRWUnTpxQ2aJFi1RWsmRJlU2cOFFlrVq1cq6NnOmxxx5TmWsvjR07NjvKAbwULFhQZfPmzXMea61V2YQJE1T2zTffRF4YYqZWrVre5y5evFhlTz/9tMrOnz8f6Hrr169X2dy5c1V21113Oc+/4447VOb6/L744guVbd++XWUjR45U2cyZM1V28eJFZz1ILG3btlWZ6/tH11AJY4zzmq732fvuu09lrqEpr7zyivOaiYQ7UwAAAADggWYKAAAAADzQTAEAAACAB5opAAAAAPCQUAMoXA8tt2vXzvt6W7ZsUdmDDz6osp07d3qvkZEDBw6o7PPPP1cZAyhyrrJly6rMNVwlX758KnMNRwFiwfW+vXr1apWVK1fOef78+fNVNnr06MgLQ1wpVaqU97muh++DDpsIKiUlRWULFixwHuvK8+bNq7Lp06errH379ip79913VVahQgWVjRkzxlkPwqFYsWIqc73/uQacZDRY4lJ79+515mXKlAl0vmv4BQMoAAAAAABONFMAAAAA4IFmCgAAAAA80EwBAAAAgIfQDqBwPXzvesi0evXqga63bt06lbVs2VJl+/btC3Q9IKtVrFhRZZUqVVLZ8OHDVXb8+PEsqQnIrD/+8Y8q+9WvfqUya63z/I4dO0a9JsSfpUuXep/bs2dPlbmGQFy4cMF7jUidO3dOZdOmTVOZawCFS69evVTGAIpwmzt3rsrq16+vMtfX9yFDhqhs1qxZKjt06JBz7VtvvVVl77//vsquvPJK5/mJjjtTAAAAAOCBZgoAAAAAPNBMAQAAAIAHmikAAAAA8BDaARR16tRRWYsWLbyvN2XKFJUxbALx7N5771WZ6yF9194GYqFp06Yq69Gjh8qMMSpzPVAvInLy5MnIC0PcO3PmjMqmTp2qsu7du6vM9V7peqB+1apVntUB0eUaLFGvXj2VrVmzRmWtW7dW2d69e1VWoEABld1+++3Oelzv3a4/k59++qnz/ETHnSkAAAAA8EAzBQAAAAAeaKYAAAAAwAPNFAAAAAB4CMUACtdPVH799dcDnbtlyxaV3X///So7cOBA5gsDskmpUqVU9uijj3pfz/XgqYj7Qe3rrrsu0HGuB08/+eQTlS1btixIiUhA48aNU5lraMr06dMDZRkpWbKkylwPb7ssWrQo8DrIPhcuXFCZ6z2wbt26KqtVq5bKRo8erbKnnnpKZatXrw5aYkTKlCmjsldffdX7ekePHo2kHMTYe++9p7KCBQuqbNu2bSp77rnnVFakSBGVub6OFy9ePGiJTq1atVJZ3759I7pmGHBnCgAAAAA80EwBAAAAgAeaKQAAAADwQDMFAAAAAB5CMYDinnvuUVnZsmUDndu+fXuVbd26NeKagOzUu3dvlbkeFF2wYIHKXANXWrRo4VwnozyI2267TWX9+vVTmeth1MmTJ3uvi/jUvHlzlVWvXl1lycnJKuvatavKMhog8corr6isUaNGP19gBj777DOVtW7dWmVHjhzxXgPRcf78eZU9+OCDKnMNvalfv77KFi5cqLK33npLZRs3bgxaYmCu98oaNWqozBijMtcQl9mzZ0enMMRE/vz5Ax330EMPqez7779X2e7du1X23XffqWzevHnOdVyDLubMmaOyoHUnGu5MAQAAAIAHmikAAAAA8EAzBQAAAAAeaKYAAAAAwAPNFAAAAAB4CMU0vwEDBsS6hGxXsWJFlbmmtiBnKFSokMpcU51ck/tatmypsl27djnX+dWvfqWyLVu2BClRmjRpojLXFLSJEyeqbNGiRSrLqEaEwyOPPKIy19SxCRMmqGzKlCkq+93vfudcJ1++fCr79NNPVbZjxw6Vufa7axLgoEGDVPbMM88460Fsuab1Nm7cWGWuCX+/+MUvVPbkk09Gp7Aocf0ZcnnjjTeyuBJkpTp16qgsd+7cgc51TRo9duxYRPVUqlQp0HGHDx+OaJ2w4s4UAAAAAHigmQIAAAAADzRTAAAAAOCBZgoAAAAAPMTdAIpnn31WZTfddFOgc4cPH66yb7/9NuKaYmHp0qUqq1q1aqBzT506pbJDhw5FXBNix/XQcdAHkSdNmqSy//mf/3EeG3TYhMuKFStUdvDgQZV169ZNZa6HbRlAER5VqlRRmWvwiWtoyqhRowIdd+LECefa99xzj8pcAyhc2rZtq7K//e1vKnPtY4RHcnKyyho2bKiy22+/XWW//vWvVeYa9FO6dGmV5crl/vvq5cuXq+yKK65QWf369Z3nBxH06wPi0+7du2Ndwn+pXLlyoOPWrl2btYXEKe5MAQAAAIAHmikAAAAA8EAzBQAAAAAeaKYAAAAAwENMB1CUL19eZa6fPu/6Cfeun2b/2WefqezcuXOe1WWfu+++W2XlypXzvt6IESNUNm7cOO/rITwefvhhlb399tsqO3v2bHaUE9j69etjXQIi0LhxY5UFfQDeddyBAwdUltHD+Fu3bg20TvXq1VU2duxYlZ08eVJlrj9DCLc9e/aobM6cOYGyxx9/XGUlSpRQWUZDo/75z3+qzPV9zsqVK1V28803q4xhE8hq9erVC3Tchg0bsriS+MSdKQAAAADwQDMFAAAAAB5opgAAAADAA80UAAAAAHiI6QCKatWqqaxZs2YqS0lJUVmnTp1U9sUXX0SnsCzkGjYxefJklRUoUCDQ9UaOHKmyV155JfOFISFs2rRJZdk1bKJQoUIqe/3111XmevD7xIkTWVITos8Yo7IHHngg0HGuzDXwISkpSWVBB02IuIdNzJo1S2VFixZVWbdu3VTm2rNAeocOHQqUZcT1Pp2cnKyyoIMAgGjatWtXoOO+/PLLLK4kPnFnCgAAAAA80EwBAAAAgAeaKQAAAADwQDMFAAAAAB5iOoAiqAsXLqgsDMMm7rnnHpXNnj1bZUGHTezfv19lS5YsUZnr9wvh5npwf9++fSrbu3dvROvkz59fZddcc43KevbsqbJBgwapbMGCBSpr2LChZ3WIB9ZalR04cCDQcS5z5sxRWaVKlVTWvHlz5/mu/dSoUSOVHT16VGWtW7dW2bJly5zrAEBO9fTTT6vM9T1I0EEViYY7UwAAAADggWYKAAAAADzQTAEAAACAB5opAAAAAPAQigEU8a5Zs2bO/K9//avKIhk28cgjj6hs5cqVga6HcHM9zF+6dGmV1alTR2U7duxQWYkSJZzrzJw5U2VJSUlBSpRhw4apbPjw4YHORbj9+c9/Vtlvf/tbleXLl09lnTt3VlmXLl1UltFAC9dgiddff11lY8eOVZlrcAYQL06fPq0y1zAi15+Nm266SWUff/xxdAoDROSNN95QWXJycvYXEge4MwUAAAAAHmimAAAAAMADzRQAAAAAeKCZAgAAAAAPoRhAkT9/fpX9/ve/V1l2POx+5513qmzSpEnOY4sUKeK9Trdu3VS2ePFi7+sh3P79738HOs61F5955hmVuf5MiYjUrl1bZX//+99VNnr0aJV98sknQUpEAlq7dq3KevXqpbKXX35ZZVdffbXKPv30U5XNnTvXufaMGTNUxmAJJIJ3331XZZ06dQp0bt26dVXGAAr8nIwGTl1//fUq69OnT1aXExrcmQIAAAAADzRTAAAAAOCBZgoAAAAAPNBMAQAAAICHmA6gcP3k+t27d6usXLlyKnvhhRdU9uyzz6rs4YcfVtm2bdtU1r9/f5Xdd999KsudO7fK8ubNqzIRkVOnTqlsxIgRKhs3bpzKUlJSnNdEzjRlyhSVlS5dWmWDBw9W2a233qqykydPOtdp2rSpyr744guVnT171nk+8JPp06cHygAEZ4xRmbVWZYUKFcqOchBiefLoFsA13E1E5ODBgyrbtGlT1GsKK+5MAQAAAIAHmikAAAAA8EAzBQAAAAAeaKYAAAAAwINxPbiY4cHGBD/Y01133aUy109kdg2biJX58+c78xUrVqjMNWwikVlr9dOy2Sg79iwSC3sWIbTOWlsvVouzZ7PGjTfeqDLXQKCCBQuqbP/+/SqrVauWyn744QfP6iLGno2xNm3aqGzOnDnOY++9916VLVq0KOo1xbkM9yx3pgAAAADAA80UAAAAAHigmQIAAAAADzRTAAAAAOAh7gZQuOTOnVtlAwYMUNmIESOiuq7rAc4OHTqo7Ntvv3Wev2/fvqjWE0Y8zI+wYc8ihHiYP4d44oknVDZ27FiVGaPfxl5++WWVDRw4UGUXL170rC5T2LMxtnHjRpUVKlTIeaxrGMqxY8eiXlOcYwAFAAAAAEQTzRQAAAAAeKCZAgAAAAAPNFMAAAAA4CEUAygQXjzMj7BhzyKEeJg/hyhatKjKxowZo7JOnTqpLH/+/CorWLCgyk6fPu1ZXaawZ7PRL3/5S5WtX79eZe3atXOev3Tp0qjXFEIMoAAAAACAaKKZAgAAAAAPNFMAAAAA4IFmCgAAAAA8MIACWYqH+RE27FmEEA/zI2zYs9lo2bJlKluxYoXKXnzxxWyoJrQYQAEAAAAA0UQzBQAAAAAeaKYAAAAAwAPNFAAAAAB4yBPrAgAAAABErnLlyiq75pprVDZu3LhsqCZn4M4UAAAAAHigmQIAAAAADzRTAAAAAOCBZgoAAAAAPBhrg/8QaGPMARHZkXXlIMFUstaWjGUB7FlkEnsWYRTTfcuehQf2LMImwz2bqWYKAAAAAJCKf+YHAAAAAB5opgAAAADAA80UAAAAAHigmQIAAAAADzRTAAAAAOCBZgoAAAAAPNBMAQAAAIAHmikAAAAA8EAzBQAAAAAe/j+4x3N5ZeguRgAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": { "tags": [] }, "output_type": "display_data" } ], "source": [ "show_img_grid(\n", " [train_ds['image'][idx] for idx in range(25)],\n", " [f'label={train_ds[\"label\"][idx]}' for idx in range(25)],\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Get a live update during training - use the \"refresh\" button!\n", "# (In Jupyter[lab] start \"tensorboard\" in the local directory instead.)\n", "if 'google.colab' in str(get_ipython()):\n", " %load_ext tensorboard\n", " %tensorboard --logdir=." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "outputId": "a0eb78b5-ee73-4f4f-8400-41b521f42b75", "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Load dataset info from /root/tensorflow_datasets/mnist/3.0.1\n", "INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.\n", "INFO:absl:Reusing dataset mnist (/root/tensorflow_datasets/mnist/3.0.1)\n", "INFO:absl:Constructing tf.data.Dataset for split train, from /root/tensorflow_datasets/mnist/3.0.1\n", "INFO:absl:Constructing tf.data.Dataset for split test, from /root/tensorflow_datasets/mnist/3.0.1\n", "INFO:absl:train epoch: 1, loss: 0.2352, accuracy: 92.94\n", "INFO:absl:eval epoch: 1, loss: 0.0592, accuracy: 98.00\n", "INFO:absl:train epoch: 2, loss: 0.0584, accuracy: 98.15\n", "INFO:absl:eval epoch: 2, loss: 0.0575, accuracy: 98.14\n", "INFO:absl:train epoch: 3, loss: 0.0423, accuracy: 98.66\n", "INFO:absl:eval epoch: 3, loss: 0.0357, accuracy: 98.78\n", "INFO:absl:Load dataset info from /root/tensorflow_datasets/mnist/3.0.1\n", "INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.\n", "INFO:absl:Reusing dataset mnist (/root/tensorflow_datasets/mnist/3.0.1)\n", "INFO:absl:Constructing tf.data.Dataset for split train, from /root/tensorflow_datasets/mnist/3.0.1\n", "INFO:absl:Constructing tf.data.Dataset for split test, from /root/tensorflow_datasets/mnist/3.0.1\n", "INFO:absl:train epoch: 1, loss: 0.2745, accuracy: 91.83\n", "INFO:absl:eval epoch: 1, loss: 0.0478, accuracy: 98.36\n", "INFO:absl:train epoch: 2, loss: 0.0508, accuracy: 98.42\n", "INFO:absl:eval epoch: 2, loss: 0.0382, accuracy: 98.81\n", "INFO:absl:train epoch: 3, loss: 0.0374, accuracy: 98.85\n", "INFO:absl:eval epoch: 3, loss: 0.0264, accuracy: 99.09\n", "INFO:absl:Load dataset info from /root/tensorflow_datasets/mnist/3.0.1\n", "INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.\n", "INFO:absl:Reusing dataset mnist (/root/tensorflow_datasets/mnist/3.0.1)\n", "INFO:absl:Constructing tf.data.Dataset for split train, from /root/tensorflow_datasets/mnist/3.0.1\n", "INFO:absl:Constructing tf.data.Dataset for split test, from /root/tensorflow_datasets/mnist/3.0.1\n", "INFO:absl:train epoch: 1, loss: 0.2676, accuracy: 91.85\n", "INFO:absl:eval epoch: 1, loss: 0.0485, accuracy: 98.57\n", "INFO:absl:train epoch: 2, loss: 0.0483, accuracy: 98.54\n", "INFO:absl:eval epoch: 2, loss: 0.0461, accuracy: 98.64\n", "INFO:absl:train epoch: 3, loss: 0.0341, accuracy: 98.91\n", "INFO:absl:eval epoch: 3, loss: 0.0396, accuracy: 98.74\n" ] } ], "source": [ "# 3x 3 epochs trains in ~1 minute in the GPU Colab...\n", "\n", "# We don't use TPUs in this Colab because we do not distribute our\n", "# training using pmap() - if you're looking for an example using TPUs\n", "# checkout below Colab:\n", "# https://colab.research.google.com/github/google/flax/blob/main/examples/imagenet/imagenet.ipynb\n", "\n", "config.num_epochs = 3\n", "models = {}\n", "for momentum in (0.8, 0.9, 0.95):\n", " name = f'momentum={momentum}'\n", " config.momentum = momentum\n", " state = train.train_and_evaluate(config, workdir=f'./models/{name}')\n", " models[name] = state.params" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "cellView": "form", "tags": [] }, "outputs": [], "source": [ "if 'google.colab' in str(get_ipython()):\n", " #@markdown You can upload the training results directly to https://tensorboard.dev\n", " #@markdown\n", " #@markdown Note that everbody with the link will be able to see the data.\n", " upload_data = 'no' #@param ['yes', 'no']\n", " if upload_data == 'yes':\n", " !tensorboard dev upload --one_shot --logdir ./models --name 'Flax examples/mnist'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "outputId": "3af424f7-4433-475d-817c-5c0bbc4599ae" }, "outputs": [ { "data": { "text/plain": [ "0.0126" ] }, "execution_count": 13, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "# Find all mistakes in testset.\n", "logits = train.CNN().apply({'params': state.params}, test_ds['image'])\n", "error_idxs, = jnp.where(test_ds['label'] != logits.argmax(axis=1))\n", "len(error_idxs) / len(logits)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "outputId": "949487f5-8aa2-45c8-9b54-efbf34ab58f1" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA1MAAANRCAYAAAAGcOaXAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzde7zVc9r/8etKqelEaiuFUigR6TDMOFQynYZu/RJjKodRQjRKdJIGHSZTFCoGk+PcIjWYSEqoUJSkExWddNBOEZUO+/P7Y+/u2XfXZ3d/12evvdfp9Xw81uNR7/09fNZ2WXtdfff3WuqcEwAAAABAbEokegEAAAAAkIpopgAAAAAgAM0UAAAAAASgmQIAAACAADRTAAAAABCAZgoAAAAAAtBMxZGqrlXVSxO9DiAqahaphppFqqFmkWqo2djQTCUBVe2tqltU9UdV/Yeqlk70moCCqOpZqvq2qmarKh9Uh6RHzSKVqeosVXWqWjLRawEKksmvszRTBSiuFy1VbS0i/UWkpYjUFJHaInJfcZwb6aUYf9DuF5GXReTGYjof0hQ1i1RT3A2NqnYWkVLFeU6kF15ni17GNVN5ly4HqOpyVd2hqhNVtYyqNlfVjaraT1W3iMhEVS2hqv1VdY2qblfVl1X1uHzH6qqq6/K+NihwSdeJyNPOuWXOuR0i8oCIXF/4Z4p0kWw165z70jn3tIgsi9dzRHqhZpFqkq1m845zjIgMEZG74/AUkWaSrWYz+XU245qpPJ1FpLWI1BGR00Xknry8mogcJ7lXiG4SkdtF5AoRaSYi1UVkh4iMExFR1foiMkFEuuZ9rbKInHjoBKr6R1XdeYTHyXmbnikin+db2+ciUlVVKxfB80bqSqaaBaKgZpFqkq1mh+cda0tRPWGkvGSr2czknMuoh4isFZGb8/29nYisEZHmIrJPRMrk+9oKEWmZ7+8nSO5lzJIicq+IvJTva+Xy9r80xvWsEZE2+f5eSkSciNRK9PeKR3I8kq1m8+1/au5LSOK/RzyS60HN8ki1R7LVrIg0EZHFeceslfe+oGSiv088kueRbDWbb/+Me53N1JsZN+T78zrJ7cRFRLY55/bm+1pNEZmqqjn5soMiUjVvn/85jnPuZ1XdHrCWn0SkYr6/H/rzroBjIX0lU80CUVCzSDVJUbOqWkJExovIn51zB1Q1lt2RWZKiZjNdpv6a30n5/nyyiGzK+/Ph00c2iEhb59yx+R5lnHPfisjm/MdR1bKSe2n00N87q+pPR3gcuiy6TETOyXfOc0Rkq3OOQkZ+yVSzQBTULFJNstRsRcm9MjUp756XT/J236iqF8X1GSPVJUvNZrRMbaZ6quqJeTffDRKRSQVs97iIDFPVmiIiqpqlqv+V97XJInKZql6oqkeLyP2S7/vpnHvROVf+CI/1eZs+JyI3qmp9VT1Wcn/f9Zm4P2OkuqSpWc1VRkSOzvt7GWWcPyxqFqkmWWr2B8m9WtAw79Eub/fGIjI/zs8ZqS1ZajajX2cztZn6p4jMEJGvJff3S4cWsN1YEXldRGao6i4R+VhEzhMRcc4tE5GeecfaLLk3822MdSHOueki8qCIzBaR9ZJ7mXZIrMdB2kuampXcXxfYI/+Z2LNHRL4MOA7SGzWLVJMUNetybTn0EJFteV/a6pzbF9tTQppLiprNk7Gvs+pcRn2ulqjqWhHp5pybmei1AFFQs0g11CxSDTWLVEPNJo9MvTIFAAAAAIVCMwUAAAAAATLu1/wAAAAAIB64MgUAAAAAAWimAAAAACBAyVg2VlV+JxAxcc4l9KPbqVnEippFCsp2zmUl6uTULAJQs0g1BdYsV6YAAEht6xK9ACBG1CxSTYE1SzMFAAAAAAFopgAAAAAgAM0UAAAAAASgmQIAAACAADRTAAAAABCAZgoAAAAAAtBMAQAAAEAAmikAAAAACEAzBQAAAAABaKYAAAAAIADNFAAAAAAEoJkCAAAAgAA0UwAAAAAQgGYKAAAAAALQTAEAAABAAJopAAAAAAhAMwUAAAAAAUomegEAAABHkpWVZbKbb77ZZB07djRZhQoVTDZ79myTdevWLXB1ADIZV6YAAAAAIADNFAAAAAAEoJkCAAAAgAA0UwAAAAAQgAEUAIBCqVKlislmzJhhsjPOOMNkH3/8caR9R48e7T33vn37oiwRKW7nzp0m69Kli8lOPfXUSMe74YYbTLZ69WqTvfHGGyZbtmxZpHMARa13794mu+KKK0yWnZ0daTtV9Z7HOWeyuXPnmqxPnz4mW7hwofeY6YQrUwAAAAAQgGYKAAAAAALQTAEAAABAAJopAAAAAAigvpvKCtxYNfrG8N7IV758eZNVqlTJZD169DDZ1VdfbbLatWub7O677zbZY489ZrK9e/eaLN6cc/67GYsJNetXr149b/7ZZ5+Z7IsvvjDZ+eefb7KcnJzCLywJULOx89VTYW7S9712FjSA4q677go+TxpZ6JxrkqiTJ6pmfXV34403muzWW281WZkyZUzmq7s9e/ZEOt6kSZNMVhw/Y1NYRtZsvPleZ+vWrWsyX2373v/HMoDCt+22bdtMdsstt5hs6tSp3vMkuQJrlitTAAAAABCAZgoAAAAAAtBMAQAAAEAAmikAAAAACMAAijioVq2aNx8yZIjJbrrpJpP98MMPJvPd9Oq7wfWbb74xme9mv4MHD5rsggsuMNmPP/5ossLgZv7EO+GEE0w2ffp077a1atUyme+G7smTJxd6XcmKmo1dqVKlTFanTh2T9evXL9LxrrvuOpPNnDnTu+0VV1xhst27d0c6TxrhZv4jaNiwocmeffZZkzVo0MBkUd8jTZkyxWR9+/b1brtu3bpIx0xz1GwcdOjQwWQPPfSQyapUqWKylStXFurcjRs3Npnv/5dFixaZrGnTpoU6d4IwgAIAAAAA4olmCgAAAAAC0EwBAAAAQACaKQAAAAAIwACKI6hYsaLJhg4darIePXp49y9ZsqTJFi9ebLKrr77aZKtXr46yRK+rrrrKZP/93/9tsho1aphsy5Ytwef14Wb+4tWsWTOT9e/f32StW7f27t+pUyeTvfrqq4VfWAqhZhMvJyfHZAX9rGrTpo3J3nnnnbivKclxM38cPPXUUya74YYbgo9X0KAe38/8DETNFhHfsImiGEDhq2/fQCDfQLVrr73WZFOnTi3UeooBAygAAAAAIJ5opgAAAAAgAM0UAAAAAASgmQIAAACAAHZCQoY6//zzTfbII4+YzPeJz1u3bvUe84UXXjDZ3XffHbC6wvv+++9Ntn///gSsBPkde+yxJtu5c6fJVO1MhF69eplsyJAhJvPVcd26db3rqV69ujcHitO4ceNMduutt3q3/eMf/2iyDBxAgTjo1q2byY4//niTXXbZZZGO5xvoI+J/Pe/cubPJ+BmNENnZ2ZGyWGRlZZns//2//2cy36CgcuXKmaxVq1YmS4EBFAXiyhQAAAAABKCZAgAAAIAANFMAAAAAEIBmCgAAAAACZOQAinPOOcdkvhv3fcMmfJ5++mlvPnjw4NgWFidXXXWVyV577TWTbd++vTiWgzy+TyAfNWqUyXyDS3y1dMopp5isa9euJps2bZrJWrZsWeA6gUQbPXq0ya6//nrvtr4bmU866SSTbdiwodDrQua57bbbTFanTh2T1atXL/Ixr7zySpPt2LHDZD169Ih8TKAoPffccybzDZvwZT6pPGzChytTAAAAABCAZgoAAAAAAtBMAQAAAEAAmikAAAAACJBWAygqVqxosj/84Q8mGzlypMl2794dabubbropcHVF4/zzzzeZ74bs9u3bF8dycAQ9e/Y0WZcuXUx23XXXmezf//63yS655BKTrV69OtJaPvjgg0jbAYmwdu1ak/lugBYRufnmm01WtmzZeC8JGWr9+vUm69+/v8mefPJJk2VlZUU+T4MGDWJbGFAEChq81qhRI5OpaqRjLlq0yGQzZsyIbWFJjitTAAAAABCAZgoAAAAAAtBMAQAAAEAAmikAAAAACJBWAyhq165tsgkTJphs9uzZJhs0aJDJtm7dajLfzc6JdOedd5rM90nqK1euLI7l4Ah8n/j9008/mWzatGkmW7VqlckOHjwYn4XlU6IE/76C5LRixYpELwEQEZE33njDZMOGDTPZmDFjIh/zzDPPNFnbtm1N9tZbb0U+JhCrhx56yJtXrlzZZM65SFkm4J0TAAAAAASgmQIAAACAADRTAAAAABCAZgoAAAAAAqTsAIqKFSuabOTIkZH2/ctf/mKy4447zmSvvfaayb799luTjRs3LtJ5C+vGG280Wbt27Uzmuzl2y5YtRbImRLdkyZJIWXGYP3++N+/bt6/JHnnkEZNl6k2mAODz+uuvm6xLly7ebZs2bWqy8uXLm+zhhx82GQMoEC9ZWVkmu+iii7zb+n7mq2qk83Tt2jW2haUgrkwBAAAAQACaKQAAAAAIQDMFAAAAAAFopgAAAAAgQMoOoBg6dKjJLr30UpP5hkP4hjE8+eSTJtu2bZvJWrVqFel4heW7MfCOO+4wWZkyZUz26KOPxn09SC/vvfeeN3/qqadMdswxx5hs586d8V4SAKSEU0891WSnnXaaydatW+fd/9e//nWk85QrV85kxx57rMl4Pcb/xfee8s033zRZQcOlog6dmjJlislWrlwZad9UxpUpAAAAAAhAMwUAAAAAAWimAAAAACAAzRQAAAAABKCZAgAAAIAAKTHNr1q1aibr0aOHyZYsWWKy+fPnm2zatGkmO3DggMl+97vfmawoJvf51KhRw2T169c32X333Weyjz/+uEjWhPTx448/evNffvnFZC1atDDZ1KlT474m4EiGDx/uzVU1UobMVbKkfavz2GOPmaxJkyYmq1mzpsl8U3TLli0beT2+yWi+rHr16ib78ssvTfbSSy+Z7KOPPjLZF198YbJly5YVuE6kj169epmsUaNGJovltXP37t0mGzx4cGwLSxNcmQIAAACAADRTAAAAABCAZgoAAAAAAtBMAQAAAEAA9d30WODGqtE3jqMTTjjBZN9++63JfDfO+Z7fnj17THbnnXea7PHHH4+6xELxPb9Zs2aZ7KijjjLZpZdearINGzbEZ2Fx4JxL6J3giarZVPXQQw+ZrFmzZiZr2rSpyXJycopkTcWNmk1OP/zwgzcvX768yXzDenw37qeRhc45Oz2hmCR7zTZs2NBkH3zwgcl8tRTLe6Soor5Xifc5fO99+vXrZzLfgCHfe65ComaLSIcOHUw2efJkk/lqrqABFL5tfcMmRowYEWWJqarAmuXKFAAAAAAEoJkCAAAAgAA0UwAAAAAQgGYKAAAAAALYjwVPQgcPHjTZzz//bLKjjz7aZNnZ2Sbz3Zy3YMGCwNXFpmLFiibz3fRft25dk/3+9783WTINm0Di+W609tV7pUqVvPvXqFHDZOeee67J1q1bZ7Jrr73WZLNnz/aeBziSs846y2SlSpXybrtmzRqT/fjjj3FfE1LX4sWLTeYbouN7/TzuuONMVrp0aZO1bt3aZC1atPCux/deJaq9e/ea7MCBAyarUKGCycqUKWOysWPHmuzqq6822UUXXRR1iShG5cqVM9nQoUNNVtBgiajb+YaSpPmwiZhwZQoAAAAAAtBMAQAAAEAAmikAAAAACEAzBQAAAAABUmIAxXfffWeytm3bmsx3o+jrr79eJGuKwnez57x580xWv359k61fv95kS5cujc/CkBa6dOlisokTJ5rsqKOOinzMmTNnmsx3c/Pxxx9vslWrVkU+D3Ak7777rsl8N/2L+F9TN2/eHPc1Ib18+eWXkbKoxowZY7JHH33Uu227du1MVqtWLZPt2rXLZJdddpnJvv76a5OdffbZ3nNH2S7qsAIkXv/+/U3mG2DmnIuU+Ya2iYj06dMnYHWZgytTAAAAABCAZgoAAAAAAtBMAQAAAEAAmikAAAAACJASAyh85s6dm+gl/J9uv/12k/mGTfzyyy8mGz16tMk2btwYn4UhLbzyyism89WSzzvvvOPNf/jhB5MNGTLEZP369TMZN/0jXqpUqWIy383SQDKbP3++N7/ttttM5qtv3/Cgbdu2mWzTpk2RMp/p06dH2g7JqWPHjibzDRCJOlSkWbNm3tw3FA3/wZUpAAAAAAhAMwUAAAAAAWimAAAAACAAzRQAAAAABNBYbupVVe4A9rjrrru8+ciRI022evVqkz3yyCMme+yxxwq/sCTgnEvoR6lTs0Vj//79JrvyyitN9tprrxXHcuKKmk28nJwckxX0s8p3w3QqDCiKs4XOuSaJOjk169emTRtv/uabb5rMV9+rVq0yWb169Qq/sORAzR5Bhw4dTDZw4ECTNWrUyGS+WvINoFi+fLnJGjRoEHWJmajAmuXKFAAAAAAEoJkCAAAAgAA0UwAAAAAQgGYKAAAAAAKUTPQCUk2/fv1MNnToUO+2Bw8eNFnfvn1N9vrrrxd+YUAx+sc//mGyv/71ryZLxQEUSDzfzdIFDaA47bTTTJaBAyiQhKZPn+7NlyxZYjJu/Ed+vuElvmETJUrYayK+AT67d+822b333hu4OhyOK1MAAAAAEIBmCgAAAAAC0EwBAAAAQACaKQAAAAAIwACKI/DdPO+7KdB3A6CISM+ePU3GsAmkgyFDhphs1apVJjvrrLNMtnTp0iJZE9KH78b9Vq1aebd9++23i3o5QFzt27cv0naffPKJyapUqWKy7OzsQq8JycU3cMeX+YZN+LZbuXKlyaZOnRq4OhyOK1MAAAAAEIBmCgAAAAAC0EwBAAAAQACaKQAAAAAIkJEDKE488UST3X777SZr3bq1yb766iuTderUyXse3w1/QDrw3fC8ZMkSk40ePdpkvv+vgPyaN2+e6CUARcb3Wtm4cWOT/fTTTyZj2ERmmDt3rsm6d+9uMt8ANN9Qin/961/xWRi8uDIFAAAAAAFopgAAAAAgAM0UAAAAAASgmQIAAACAAOr7pOQCN1aNvnGSqFGjhsmmTZtmsgYNGphs+fLlJmvVqpXJNm/eHLi69Oec00SePxVrNlVVrVrVZAsWLDDZhRdeaLINGzYUyZpCULOJN2DAAJMNHTrUu+1JJ51ksk2bNsV9TUluoXOuSaJOTs0iADV7BGXLljWZ73Vx4MCBJpsyZYrJChqUhpgUWLNcmQIAAACAADRTAAAAABCAZgoAAAAAAtBMAQAAAECAtB9AUatWLZO9//77JpsxY4bJBg8ebLItW7bEZV2Zgpv5M9ujjz5qsiFDhpjs+++/L47lRELNIgVxMz9SDTWLVMMACgAAAACIJ5opAAAAAAhAMwUAAAAAAWimAAAAACBA2g+gQGJxMz9SDTWLFMTN/Eg11CxSDQMoAAAAACCeaKYAAAAAIADNFAAAAAAEoJkCAAAAgAA0UwAAAAAQgGYKAAAAAALQTAEAAABAAJopAAAAAAhAMwUAAAAAAUrGuH22iKwrioUgLdVM9AKEmkVsqFmkokTXLTWLWFGzSDUF1qw654pzIQAAAACQFvg1PwAAAAAIQDMFAAAAAAFopgAAAAAgAM0UAAAAAASgmQIAAACAADRTAAAAABCAZgoAAAAAAtBMAQAAAEAAmikAAAAACEAzBQAAAAABaKYAAAAAIADNVByp6lpVvTTR6wCiomaRaqhZpBpqFqmGmo0NzVSCqerjqvpTvscvqror0esCCqK5hqrqt6r6g6q+p6pnJnpdQEFUtbSqPqyqm1R1h6qOV9VSiV4XEIWqzlJVp6olE70WoCCZ/H6WZqoAxfWi5Zy72TlX/tBDRP5bRF4pjnMjvRTjD9pOIvInEblIRI4TkY9E5PliOjfSSDHWbH8RaSIiZ4nI6SLSSETuKaZzI40Ud0Ojqp1FhMYfwXg/W/QyrpnKu3Q5QFWX5/0L5URVLaOqzVV1o6r2U9UtIjJRVUuoan9VXaOq21X1ZVU9Lt+xuqrquryvDYrD2sqJSEcRebawx0L6SMKaPUVE5jrnvnbOHRSRF0SkfhyeKtJEEtbs5SLyiHPue+fcNhF5RHL/QQAQkaSsWVHVY0RkiIjcHYeniDSTjDWb73gZ9X4245qpPJ1FpLWI1JHcf6U89C+U1ST3X9prishNInK7iFwhIs1EpLqI7BCRcSIiqlpfRCaISNe8r1UWkRMPnUBV/6iqO4/wONmzro4isk1EPoj3E0bKS6aafUlE6qjq6Zr7q1LXicj0InzuSE3JVLMiInrYn0/Me7MKHJJsNTs871hbiuoJI+UlW80eklnvZ51zGfUQkbUicnO+v7cTkTUi0lxE9olImXxfWyEiLfP9/QQR2S8iJUXkXhF5Kd/XyuXtf2kh1jZLRP6S6O8Rj+R6JFvNisjRIjJWRJyIHBCRb0TklER/n3gkzyMJa3aoiMwTkSzJfZMxP69+T0j094pHcjySsGabiMjivGPWyqvXkon+PvFInkey1exha8uo97OZejPjhnx/Xie5nbiIyDbn3N58X6spIlNVNSdfdlBEqubt8z/Hcc79rKrbQxeU19k3F5HuocdAWkummr1XRJqKyEmS+y+mXUTkXVU90zm3O+B4SE/JVLPDRORYyX1z+ouIPCki54rI1oBjIX0lRc2qagkRGS8if3bOHVDV/2sXZK6kqNn8MvH9bKb+mt9J+f58sohsyvuzO2y7DSLS1jl3bL5HGefctyKyOf9xVLWs5F4aPfT3zvq/p5oc/jj8smhXEZnnnPs6Xk8SaSWZarahiExyzm10zh1wzj0jIpWE+6bwvyVNzTrn9jjnbnPO1XDO1RaR7SKy0DmXI8B/JEvNVpTcK1OT8u55+SRv942qelFcnzFSXbLUbH4Z9342U5upnqp6Yt7Nd4NEZFIB2z0uIsNUtaaIiKpmqep/5X1tsohcpqoXqurRInK/5Pt+OudedPmmmnge6w8717Ui8kwcnyPSSzLV7Cci0klVq2ruTa1dJXfa1Or4P22ksKSpWVWtoarVNdf5IjJYcm/sB/JLlpr9QXKvFjTMe7TL272x5P6KKnBIstRsfhn3fjZTm6l/isgMEflacn+/dGgB240VkddFZIbmzsr/WETOExFxzi0TkZ55x9osuTfzbQxZjKr+RnJv9suIEZIIkkw1O1JEPpfcX5naKSK9RaSjc25nwLGQvpKpZuuIyIci8rPkTpfq75ybEXAcpLekqFmXa8uhh+TeyC8istU5ty+2p4Q0lxQ1e0imvp9V5w6/EpjeVHWtiHRzzs1M9FqAKKhZpBpqFqmGmkWqoWaTR6ZemQIAAACAQqGZAgAAAIAAGfdrfgAAAAAQD1yZAgAAAIAANFMAAAAAEKBkLBurKr8TiJg45xL60e3ULGJFzSIFZTvnshJ1cmoWAahZpJoCa5YrUwAApLZ1iV4AECNqFqmmwJqlmQIAAACAADRTAAAAABCAZgoAAAAAAtBMAQAAAEAAmikAAAAACEAzBQAAAAABaKYAAAAAIADNFAAAAAAEoJkCAAAAgAA0UwAAAAAQgGYKAAAAAALQTAEAAABAAJopAAAAAAhAMwUAAAAAAWimAAAAACAAzRQAAAAABKCZAgAAAIAAJRO9AADxU7p0aZP99re/Ndm9997r3b958+Ymy8nJMdmoUaNMds8995hs//793vMAAACkA65MAQAAAEAAmikAAAAACEAzBQAAAAABaKYAAAAAIIA656JvrBp9Y0BEnHOayPOnS81Wr17dZNdee63JLrnkkkhZQVTtf66orxENGzY02dKlSyOfO1lQs0WnQoUKJrvtttsi7duqVStvfv7555vsoYceipRt37490rlTwELnXJNEnTydazaqJk3st3/kyJHebVu0aGGyefPmmezyyy832c6dOwNWl5SoWaSaAmuWK1MAAAAAEIBmCgAAAAAC0EwBAAAAQACaKQAAAAAIkNABFJUrVzZZlSpVTPbzzz+bbNOmTSbLycmJz8IQN9zMH7vmzZub7B//+IfJTj755LifuzADKIYNG2ayIUOGFHpNxY2ajY+6deuabMGCBSYrV65cpOP5alMken3u3r3bZAMGDDDZuHHjIh0vyXAzfzHyvfa+9957JqtVq5Z3/6g16xvOMmHChEj7pgBqFv9L7dq1TdaxY0eTnXvuuSbz/WwZM2ZMfBb2HwygAAAAAIB4opkCAAAAgAA0UwAAAAAQgGYKAAAAAAKUTOTJJ0+ebLJKlSqZbM6cOSZbtGiRySZOnBifhQFFoHr16ibz3WDcp08fk5UqVcpksQyPOZzv/z0RkTPPPNNkZ5xxRqRj1qxZM3g9SG2+wUHjx483WdRhE0WhbNmyJnvwwQdN1rp1a5O1b9++SNaE5HfWWWeZbNKkSSaL5fVv165dJqtQoYLJ9u3bF/mYQLKqU6eOyfr27Wuy7t27m6xECXvNZ+3atSabOnVq2OLihCtTAAAAABCAZgoAAAAAAtBMAQAAAEAAmikAAAAACJDQARRlypQxWYMGDUy2ZMkSk/k+Bdx3E/Tf/va3wNUB8fXYY4+ZLN43to8aNcpkDz30kMmys7O9+2dlZZls5syZJvMNpfB9Uvldd91lsm3btnnPjdRw/PHHm+zFF180WbNmzYpjOYVSunRpk/l+jiAznHrqqSabNWuWyXw18ssvv5jshhtu8J7n008/NVmTJk0infvKK680mW+40TPPPGOyH3/80bseIMQ555xjsieeeMJkjRo1MlnJkrb92L9/v8l87/VHjx5tMt9QiuLElSkAAAAACEAzBQAAAAABaKYAAAAAIADNFAAAAAAESOgAiokTJ5rsvPPOM1mXLl0iHe8vf/mLyS6//HKT+W4KXb9+vclq165tsp9//tlkGzdujLQ+ZA7fjcwNGzY0mapGOt53331nspYtW5ps+fLlkY5XkK1bt5ps7NixJvPdZFq2bFmTtWjRwmQvv/xy4OqQDDp06GAy33/nqA4cOGCyQYMGmez999/37u+7Ib9v377B60FmqFGjhskWLVpksnLlykU6Xs+ePU02adKkyOtZs2aNyW6//XaTjRkzxmS+nyM5OTkm8w1BAvLzvXcR8ddO8+bNTXb00UebzFefq1atMln37t1NVtDrfrLhyhQAAAAABDZduXYAACAASURBVKCZAgAAAIAANFMAAAAAEIBmCgAAAAACJHQAxccff2yyDRs2mOykk06KdLxf/epXJrvwwgtNNm3aNJP5bnLz3Qznu0n0X//6l3c9vk8vnz59unfbePI958aNG5vMN1gAsTvmmGNM9tZbb5ns5JNPNplzzmTbtm0zWVEMm4jq1VdfNZnvxugzzzzTZL66YwBFavMN8CmMr776ymS+T7gviG9oEfB/uf/++01Wvnz5SPtec801Jotl2ISP78b/W2+91WS+m/l92Ycfflio9SD91alTx2RPPvmkd9tmzZpFOuaLL75ospkzZ5rslVdeMdnu3bsjnSMZcWUKAAAAAALQTAEAAABAAJopAAAAAAhAMwUAAAAAARI6gGLJkiUma926tcnuu+8+k3Xs2NFkJUpE6w1PP/30SJmP74a9O++807vtnj17TLZ//36TTZkyxWQrV640WfXq1U1WsWJFk5UqVcpkTz/9tHeNKDzf4JNTTjkl+Hi+G6OLa9iEz4EDB0w2b948k/kGUAD5rVixwmTt27cv1DE7d+5cqP0P5xsAg9TWpEkTk/kGqfgGAo0fP95kvqE8sfANuvCdx/e+xLdG3w3+ixcvDlwd0lHt2rVN5ntfePHFF3v3972fvfbaa032+uuvm8z3vjfdcGUKAAAAAALQTAEAAABAAJopAAAAAAhAMwUAAAAAARI6gMLHN3jh6quvNlnLli1N1q9fP5OddtppJitXrpzJqlSpEnWJkR1zzDEm27Fjh8neffddk/luRn3iiSdMVqFChUhrKehTrVF4f/3rX4P3/frrr03m+wTxeCtoWMQdd9xhsgsvvNBkvv+vfOrXr2+yTp06mcz3aehIP2vWrDHZ2rVrI+17+eWXe/Nzzz23MEsyxowZE9fjoXiVLVvWZI8//nikfffu3WuyUaNGmcw3lMenWrVq3tz389j3niaqhQsXmiwnJyf4eEhtHTp0MNmgQYNMlpWVZbIhQ4Z4jzl37lyTzZ49O2B16YkrUwAAAAAQgGYKAAAAAALQTAEAAABAAJopAAAAAAiQdAMoopo1a1akzMd3U7zvhr1LL73UZFWrVjVZLAMDbr/9dpOVLGn/M7z99tuRjqeqJvN9QjqKzhlnnBG874cffmiyH3/8sTDLMXzDJt555x3vtr76Lkw9tW3b1mRt2rSJlD3wwAMmizqsAPFxzjnnmOzkk08OPt6GDRuC923cuLE3L1WqVPAxv/zyS5OtWrUq+HhIvIYNG5os6pCSoUOHmmzdunWR9j377LNN9tRTT3m3LaiWQy1dujSux0Pq6N69u8lGjx5tsqlTp5rsb3/7m8mopTBcmQIAAACAADRTAAAAABCAZgoAAAAAAtBMAQAAAECAlB1AURjLly+PlI0dO9ZkvmERO3fujHzuVq1amWz69OmR9z+cbzjA008/bTLf80N8fPDBByZr0qRJpH19A0QKw3fztW+YSZUqVbz7lyhh/30lJyen8AvL56ijjjLZ9ddfHyl74YUXTHbdddfFY1nwqF27tsmOP/744OMdc8wxJitTpozJBg8ebLL+/ft7j1mYASkbN26MlCF1dOzYMXjfJ554wmQVK1aMdI4RI0aYrKD/VwpTs998843JZs+eHXw8pLYBAwaY7P333zfZ+PHjTcawifjhyhQAAAAABKCZAgAAAIAANFMAAAAAEIBmCgAAAAAC0EwBAAAAQICMnOYX1U8//RRpu0qVKnnz0qVLR9rWN9nnu+++M5lvAptvkqBvCuGBAwe8a0Th+f77RZ3WVJipTr/61a9MNm7cOJNVrlw58nn37dtnsvfee89kzz//fIQVijRu3Nhk11xzjckKmi54uN/+9rcmO/XUU022evXqSMfDkU2dOtVkCxcuNFnU6ZWdO3eOlPn4Jk2KFG7a5KBBg4L3RXK64IILTOabmurLJkyYYLI2bdqYrEKFCpHWsmTJEm/um2pZs2bNSMf86KOPTLZp06ZI+yK1+aarZmVlmcz3HnD+/PlFsibk4soUAAAAAASgmQIAAACAADRTAAAAABCAZgoAAAAAAjCAIkann366yV577TXvtnXr1o10zF9++cVk48ePN1mvXr1M1rt3b5MtXbo00nkRH6+++qrJ+vTpE2nfli1bmqxhw4YmW7x4scl8tXjeeedFOm9BBg4caLLRo0cHH++FF14w2bPPPmuyRx991GS/+c1vTHbKKaeY7O233zZZnTp1oi4RMSrMwJXCKGjQRHGcG6mjf//+Jps1a1akfa+88spI223evNlkzzzzjMnuv/9+7/4PPvigyW677bZI5/70008jbYf0s3XrVpN99tlnJtuwYUNxLAf5cGUKAAAAAALQTAEAAABAAJopAAAAAAhAMwUAAAAAARhAEaOOHTuaLOqgiYL4hkjUqFHDZJUrVw4+x29/+1uTlS9f3mQzZswIPkemWrlypcm++eYbk/mGJ5xwwgkmmz59usmqVasWuLrYFGbYRFS+YRqdO3c22bx580zm+37VrFkzPgtDJCNHjjTZK6+8koCVANacOXNMdvPNN5usb9++JsvOzjaZ72fipEmTTOb7OVCQBg0aRNpu9+7dJps2bVrk8yC9/Pzzzyb7/PPPTda2bVuTffnll0WyJuTiyhQAAAAABKCZAgAAAIAANFMAAAAAEIBmCgAAAAACMIDiCK655hqT3XvvvQlYSS7fp6736NHDZL6bFCdOnGiyUqVKmaxcuXKBq8tcP/zwg8kGDx5ssmHDhpnMNzyhSpUqJluzZo3JBg0aFHWJkd1yyy0mmzBhQtzPc7gWLVqYLOrAleJYH/7Dd0N+69atTearJZ+zzz7bZLVr1459YQF8w398w1CQOg4ePGiyJ598MlIWb7Vq1fLmTZs2jbS/b1jP6tWrC7MkpJkNGzaYrLBD0RA7rkwBAAAAQACaKQAAAAAIQDMFAAAAAAFopgAAAAAgAAMo8rRp08Zk9evXN1np0qXjfu7x48eb7P333zeZb2DEhRdeaLILLrjAZFu2bDFZhw4doi4RMXrppZdMdsUVV5jsxBNPNFnJkvZ/S98N+S+++GLg6gr24IMPmmzlypUmW7FiRaTjNWnSxGS+T2e/+eabIx1v7969JnvzzTcj7Yv48A24mTVrVqTMp169eiZbunRp7AsL4Bv2AsRLQUNYypYtG2n/hQsXxnM5SENfffWVyXzvXVG0uDIFAAAAAAFopgAAAAAgAM0UAAAAAASgmQIAAACAAAygyOMbQNG1a9diObfvZv4ZM2aYbO7cuSZ7/vnnI53j6quvNtn8+fMj7Yv4+MMf/mCyYcOGmaxfv34my8nJMZlzLj4Ly8d3Y/Q777wT13OoqsmiPhffsJa33nqr0GtC4mzatClh527UqJHJGjZsaLLFixcXx3KQZqpXr+7Nfa+BPps3b47ncpCGeL1KDlyZAgAAAIAANFMAAAAAEIBmCgAAAAAC0EwBAAAAQAAGUOR5+OGHTdazZ0+T+QYB7N+/33vMf//73yZ77bXXTDZnzhyTrVu3znvMUA888EBcj4f4eOyxx0z2008/max79+4mO/nkk4tkTcli+/btJhs3blwCVoJ0ValSpUgZEE++gTt79+41me89BDJXrVq1TOYbbHXnnXcWw2qQH1emAAAAACAAzRQAAAAABKCZAgAAAIAANFMAAAAAEIABFHl8Ax86dOhgMt8nmv/9738vkjUh/fk+4X7EiBEme+aZZ0x27bXXmuySSy4xWcuWLcMWV4x27NhhslGjRpls7dq1xbAaACi89u3bR952w4YNJlu2bFk8l4MU169fP5N98sknJps2bVpxLAf5cGUKAAAAAALQTAEAAABAAJopAAAAAAhAMwUAAAAAARhAcQR8+jiShW9QxciRI002ZswYkx1zzDGRz9O8eXOTNW7cOPL+UXzwwQcm++ijj0z2/fffx/W8SE7OOZPt27fPZKVLl477ub/88kuTrVq1Ku7nQWaqUKGCN49a80htRx11lMlKlrRvu0uUsNc1Lr/8cpO1a9fOZF9//XXg6hBPXJkCAAAAgAA0UwAAAAAQgGYKAAAAAALQTAEAAABAAAZQAGnkl19+Mdl3330Xef+XX345UgbEy65du0zWtm1bk7377ruFOs+yZctMNmLECJNt3LixUOcBDpk9e7Y3P+uss0w2ZMiQol4Oiln79u1NNn78eJNlZWWZTFVN9vnnn5vshhtuCFwd4okrUwAAAAAQgGYKAAAAAALQTAEAAABAAJopAAAAAAigvk/iLnBj1egbAyLinLN3URYjahaxomaRghY655ok6uTULAJkZM3OmTPHZDVq1DDZTTfdZLKZM2cWyZoQWYE1y5UpAAAAAAhAMwUAAAAAAWimAAAAACAAzRQAAAAABCiZ6AUAAAAA6e6iiy5K9BJQBLgyBQAAAAABaKYAAAAAIADNFAAAAAAEoJkCAAAAgAA0UwAAAAAQgGYKAAAAAALQTAEAAABAAJopAAAAAAhAMwUAAAAAAUrGuH22iKwrioUgLdVM9AKEmkVsqFmkokTXLTWLWFGzSDUF1qw654pzIQAAAACQFvg1PwAAAAAIQDMFAAAAAAFopgAAAAAgAM0UAAAAAASgmQIAAACAADRTAAAAABCAZgoAAAAAAtBMAQAAAEAAmikAAAAACEAzBQAAAAABaKYAAAAAIADNVByp6lpVvTTR6wCiomaRaqhZpBpqFqmGmo0NzVSCqer1qnpQVX/K92ie6HUBBVHV61R1oar+qKobVfVBVS2Z6HUBUajqLFV11CySnar2VtUtea+1/1DV0oleE1AQzTVUVb9V1R9U9T1VPTPR6yoONFMFKOYftB8558rne7xXjOdGmijGmi0rIneISBUROU9EWopI32I6N9JIcTc0qtpZREoV5zmRXoqrZlW1tYj0l9zX15oiUltE7iuOcyO9FOPrbCcR+ZOIXCQix4nIRyLyfDGdO6EyrpnKu3Q5QFWXq+oOVZ2oqmVUtXnev7L3U9UtIjJRVUuoan9VXaOq21X1ZVU9Lt+xuqrquryvDUrg00IaS7aadc5NcM7Ncc7tc859KyIvisgFcXq6SAPJVrN5xzlGRIaIyN1xeIpIM0lYs9eJyNPOuWXOuR0i8oCIXF/4Z4p0kYQ1e4qIzHXOfe2cOygiL4hI/Tg81aSXcc1Uns4i0lpE6ojI6SJyT15eTXK76ZoicpOI3C4iV4hIMxGpLiI7RGSciIiq1heRCSLSNe9rlUXkxEMnUNU/qurOIzxOzreec1U1W1W/UtXBxf2vtUgJyVaz+V0sIsvi+myRDpKtZofnHWtLUT1hpLxkqtkzReTzfGv7XESqqmrlInjeSF3JVLMviUgdVT1dVUtJ7j8ITC/C5548nHMZ9RCRtSJyc76/txORNSLSXET2iUiZfF9bISIt8/39BBHZLyIlReReEXkp39fK5e1/aYzrqS253XwJEWkgIstFZECiv088kueRbDV72Nr+JCIbRaRKor9PPJLnkWw1KyJNRGRx3jFriYgTkZKJ/j7xSJ5HEtbsGhFpk+/vpfLqtlaiv1c8kuORhDV7tIiMzavTAyLyjYickujvU3E8MvUKyIZ8f14nuZ24iMg259zefF+rKSJTVTUnX3ZQRKrm7fM/x3HO/ayq22NdiHPu63x//UJV7xeRu0RkRKzHQlpLmpo9RFWvkNw6vdQ5lx16HKStpKhZVS0hIuNF5M/OuQOqGsvuyCxJUbN5fhKRivn+fujPuwKOhfSVTDV7r4g0FZGTJPc3ALqIyLuqeqZzbnfA8VJGpv6a30n5/nyyiGzK+7M7bLsNItLWOXdsvkcZl3ufyOb8x1HVspJ7afTQ3zvr/57Qd/ijoF+ZciLCT3scLqlqVlXbiMiTInK5c+6L+D5VpIlkqdmKkntlalLe/QOf5O2+UVUviuszRqpLlpoVyf3V6XPynfMcEdnqnAv+BzCkpWSq2YYiMsk5t9E5d8A594yIVJIMuG8qU5upnqp6oubefDdIRCYVsN3jIjJMVWuKiKhqlqr+V97XJovIZap6oaoeLSL3S77vp3PuRfe/J/Qd/lifd8y2qlo178/1RGSwiLxWFE8aKS2ZavYSyR060dE5t6Boni7SQLLU7A+S+y+vDfMe7fJ2bywi8+P8nJHakqVmRUSeE5EbVbW+qh4ruffCPBP3Z4xUl0w1+4mIdFLVqpo78KKr5P566ur4P+3kkqnN1D9FZIaIfC25v186tIDtxorI6yIyQ1V3icjHkjsKWpxzy0SkZ96xNkvuzXwbA9bSUkSWqOrPIvKmiEyR3BulgfySqWYHi8gxIvJmvn+ZeivgOEhvSVGzLteWQw8R2Zb3pa3OuX2xPSWkuaSo2bzjTBeRB0Vktoisl9xf4RoS63GQ9pKmZkVkpOQOSlksIjtFpLfk/qPrzoBjpRR17vArgelNVdeKSDfn3MxErwWIgppFqqFmkWqoWaQaajZ5ZOqVKQAAAAAoFJopAAAAAAiQcb/mBwAAAADxwJUpAAAAAAgQ04f2qiqXsRAT51xCPzOLmkWsqFmkoGznXFaiTk7NIgA1i1RTYM1yZQoAgNS2LtELAGJEzSLVFFizNFMAAAAAEIBmCgAAAAAC0EwBAAAAQACaKQAAAAAIQDMFAAAAAAFopgAAAAAgAM0UAAAAAASgmQIAAACAADRTAAAAABCAZgoAAAAAAtBMAQAAAEAAmikAAAAACEAzBQAAAAABaKYAAAAAIADNFAAAAAAEoJkCAAAAgAA0UwAAAAAQgGYKAAAAAAKUTPQCABS/nJwcbz5lyhSTqarJli9fbrLBgwcXfmEAAAAphCtTAAAAABCAZgoAAAAAAtBMAQAAAEAAmikAAAAACKDOuegbq0bfOIldf/31JuvZs6fJ9uzZY7Lhw4ebbPr06XFZVzpyztnpBcUoXWo23g4ePOjNfa8HvgEUvu06depksqlTpwasLrGoWaSghc65Jok6ebrUbNWqVU12yy23mKx58+Ym27Rpk8kKen+1ZcsWk915550RVphWqNk4KFeunMnq1atnsu7du0c6XlZWljf31XJ2drbJfD/z33777UjnTgEF1ixXpgAAAAAgAM0UAAAAAASgmQIAAACAADRTAAAAABCgZKIXUNQmTJhgsqOOOspkv//970123HHHmWzKlCkmO/HEE73nfuqpp6IsESh2vpuqCzJ06FCTVa5c2WQDBw40WSoOoEDsfK+LHTp0iLTv3LlzTdauXTuT7dq1K/aFASJSq1Ytk914440m69Gjh8l87wOiDuWJxbHHHmuyu+++22Tbt28v1HmQGnyDIHyvqX/+859NVrduXZNFrVnfdrFs261bN5Oly3CqI+HKFAAAAAAEoJkCAAAAgAA0UwAAAAAQgGYKAAAAAAKk1QAK341vxx9/vMn+8Ic/mGz//v0m++6770zWpIn98ONTTz016hKBpPD3v/898raNGjUyme//NWQu383JUW/Iv+CCC0y2du1ak40aNcq7/7PPPmuyTZs2RTo3MsO8efNMVrVq1eDj7d6922S+wRAFDafyuf766022evVqk40YMSLyMZG6evXqZTLfkKfCDpYI3S6WbZ944gmTrV+/3mQLFy6MfO5kw5UpAAAAAAhAMwUAAAAAAWimAAAAACAAzRQAAAAABEirARR/+tOfTHbppZeazDdsIqozzjjDZPXr1/duu2TJkuDzAMnMd+PpnDlzErASJIO//e1vJuvQoUPw8SpVqmSyYcOGebdt3769yQYPHmyymTNnBq8HqaNx48Ymq1atmsl8N+l/9dVXJnvggQdM9t5775ls3759Jjv77LO9a5w6darJKlSoYDJfHZ900kkmu/XWW73nQeryDZuIOtQn3tsV9piVK1c2mW+IFQMoAAAAACDD0EwBAAAAQACaKQAAAAAIQDMFAAAAAAFSdgBFixYtTPbZZ5+ZzPdJ5YXhu/mzb9++3m2ff/75uJ4bSATfIAHfTaZTpkwpjuUgCX3++ecme+ONN0x2+eWXRzpex44dTTZy5Ejvtuedd57J/v3vf5usc+fOJnv11VcjrQepwze0oUQJ++/GOTk5Jtu7d6/JfLW0a9euSGuZPXu2Nz/22GNNdt9995nsnnvuMVmPHj1MdtRRR0XaDqkjas36hkFlZ2ebbP369YVaT7169UxWrly5SPv61phuuDIFAAAAAAFopgAAAAAgAM0UAAAAAASgmQIAAACAACk7gOJ3v/udydauXVvk5/XdeHrKKad4t61fv77Jli9fHvc1AUVp6tSpJuvevbvJ5s6dWxzLQRLas2ePyTZv3hx8vAULFpjslltu8W47evRok5199tkmGzBggMl8AwK+//77KEtEkjr33HNN5rtx3/ez/IEHHoi0XVHwDVi55JJLTPab3/zGZDfeeKPJfANgfMM0kHi+IU++mvUNfpozZ47J+vTpY7JFixYFri7Xc889ZzLfUB/fGjMBV6YAAAAAIADNFAAAAAAEoJkCAAAAgAA0UwAAAAAQIGUHUDRt2tRky5Yti+s5srKyTDZ8+HCTTZs2zbu/75PYr7nmmsIvDCgkX22L+G/S990cyyAVFLdZs2Z588cff9xk48ePN1mjRo1Mdscdd5js3nvvDVgdEqFly5YmK+i17XBffPGFyXzDdorL7t27TfbQQw+Z7JVXXol0vNNOO63Qa0LxWLFihclmzJhhsq5du5osOzs7+Ly+/1d69erl3dY3bEJVI53Ht11h1p2MuDIFAAAAAAFopgAAAAAgAM0UAAAAAASgmQIAAACAACk7gMJ3Y+a4ceNMFvUTzX2fmv7oo4+a7MUXXzTZvHnzvGv03TB91llnmWzp0qXe/YFY1axZ02Tbtm0zWZcuXbz7//nPfzaZ78boTp06BawOiKZVq1Ymmzhxonfbjz/+uKiXgyR19913m+zoo49OwEqKxubNm4P37datm8kefvjhwiwHRWTlypUma9u2bVzPcdNNN5mse/fuJvMN6hERcc5FOo9vO99glxEjRkQ6XqrgyhQAAAAABKCZAgAAAIAANFMAAAAAEIBmCgAAAAACpOwAirfeestkb7/9tsm++eYbk23dutVkO3bsMFnv3r1NVtCwCZ8JEyaYbMiQISa76qqrTBb1Zj8gvwULFpisT58+Juvfv793f1/dDR8+3GS+G2aBeHnyySdNVqZMGe+2H330UVEvBylEVSNlY8eOLY7lxJ3vufhUqFChiFeCZNC4cWOTvfnmmybLysoyme/nfdT6imXbW265xWS+wVapjCtTAAAAABCAZgoAAAAAAtBMAQAAAEAAmikAAAAACJCyAyh8fDe5+W5kLlHC9pCffvpp3Nfju3F/7ty5JuvVq5fJUvXmWBSfDh06mMx3k+nAgQMjbScismLFCpOl2yeVo3h89tlnwfv6XqNr1arl3ZYBFJnrqaeeMlnLli0j7fv73//eZJMnTy70mopa1OFUDLHKDL5hE5UrVzaZrx5iqZHC1J3vvcrf//73yOdOBVyZAgAAAIAANFMAAAAAEIBmCgAAAAAC0EwBAAAAQACaKQAAAAAIkFbT/HwWLVqUsHP/+OOPJrvhhhtM9sYbb5js4MGDJnvsscfiszAktXr16pmsY8eOJuvfv7/JfJN0Xn31VZPVr1/fe+4rrrjCZIMGDTLZsGHDvPsDh/gmrX388cfBx/vqq6+8ed26dYOPidT27rvvmmzbtm0mO/7440126qmnFsmagHioWbOmyRYsWGAy32Re3/sAVY103qjbxbLt0KFDTbZw4cJIWargyhQAAAAABKCZAgAAAIAANFMAAAAAEIBmCgAAAAACpP0AimTju8GuZ8+eJnvhhRdM5hum8eGHH8ZnYUgI302mvuEOHTp0MNkHH3xgslq1apnsn//8p8nKlSvnXc+KFStM9sADD5hs7dq1JnvxxRe9x0Rm8g3R+fzzz+N+nho1akTa7sCBAyabMWNGvJeDYrR9+3aT7du3L9K+jRs3NlmjRo1MlsghVshcVapUMVnlypVN5hs24ct8om5X2GP61t2tWzeTMYACAAAAADIMzRQAAAAABKCZAgAAAIAANFMAAAAAEIABFEngX//6l8kef/xxk40fP95kzZo1M9kPP/wQn4WhyD333HMmu+CCC0y2bds2k/Xp08dk69evN1l2drbJypYt612PbwDF1KlTTTZw4ECT7d69O9K+QH7ly5c3WYsWLUx2zTXXePf3betTsqT9cTdhwgSTffHFFya79957TbZ69epI50XxevPNN03Wo0cPk5UuXdpkn3zyicnOO+88k3366aeBqyvYSSedZLJevXqZTFWDj3fllVeabPLkyZGOh+Ll++8cNVu5cqXJxo4da7Lly5ebrH79+t71+IZgtW7d2rvt4aLWbCrjyhQAAAAABKCZAgAAAIAANFMAAAAAEIBmCgAAAAACaCyfgKyq0TdGofzqV78y2bJly0zWu3dvk7322mtFsqYQzrmE3nmYTDWblZVlsq1bt5rsgw8+MFnz5s2LYklx9d1335msTZs2Jlu0aFFxLCcYNVt0fEMgZsyYYbJE1vvBgwdNdsMNN5jshRdeKI7lRLXQOdckUSdP9pp9/fXXTdauXbvg491xxx0me++990y2efNm7/6nnXaayXzDJq666qrYF5fHd9P/1VdfbbIEDqCgZo+gcePGJps/f77JfP+dmzZtarKi+Lnre6309RS+NfqGalWrVi0+Cys6BdYsV6YAAAAAIADNFAAAAAAEoJkCAAAAgAA0UwAAAAAQwN4NjKSwZ88ek02dOtVkv/71r02WTAMo8B++TxD33azp+++cCrp27Woy36epJ/sACsRH6dKlTTZ9+nSTNWvWrDiW4zVnzhyTjRo1ymRvvPFGcSwHRWTgwIEm8w15atGiRaTjjR071mS7du0yWSwDKHw36ccyIAzpZf369SbbsGGDyWrW09aY7wAAIABJREFUrGky33uNwv7c9Q3Q8tWsj2+7qPumCq5MAQAAAEAAmikAAAAACEAzBQAAAAABaKYAAAAAIAADKFJIvXr1TLZ48eIErAQhfJ/4vX37dpPddNNNJvPdjJrIQRW+G1ynTJlispycHJO98MILRbImJJd9+/aZbMGCBSbzDaB4/vnnTdauXTvveSpXrmyyn3/+2WSdOnUy2ezZs032yy+/eM+D1LV06VKTXXbZZSa7/fbbTXbPPfeYrEKFCiYrX768yXyDJmIxbdo0kzVq1MhkJ5xwQqHOg+Tje7/Qp08fk02ePNlkAwYMMFm5cuVMtnLlSpP5Bk2IiHTr1s1kvgEpUYem+N4vpDKuTAEAAABAAJopAAAAAAhAMwUAAAAAAWimAAAAACAAAyiSlO9mwQYNGphs4sSJxbEcxIFvYITv08tvvPFGkz333HMmGz58uMlGjBgRuLqCDRo0yGT9+/c3mW/YxLBhw+K+HqQG343II0eONNkTTzxhMt/AlU8++cR7Ht8Aip07d5ps+vTp3v2RmXyDRkaNGmUy35CSiy++OO7refjhhyNt9+GHH5qsevXqJlNVk/kGZyB1+N5DlChhr4n4fhb37t070na+48WybdTtsrOzvedJVVyZAgAAAIAANFMAAAAAEIBmCgAAAAAC0EwBAAAAQAAGUCQB3815d911l8mWL19uMt+nXyN1jBkzxmS+G+V92ZNPPhn39Tz//PMmq1evnsl8N4/6Pp3dd8MsMtf3338fKSusRx99NO7HRGZauHBhpKy4bNq0yWS+YS8+f/rTn0zGEKvUNnToUJMNGDDAZL7BEL668W0Xy7a+7V599VWTFcWwrETiyhQAAAAABKCZAgAAAIAANFMAAAAAEIBmCgAAAAACMICiCFWuXNlk7du3N5nvxv1q1aqZrHv37vFZGJLaypUrTdalS5e4nsM3VEJE5IorrjDZ8OHDTeYbfpFun2iO1LB3716TffHFFwlYCVD0Jk2aZDLf6zYyw+DBg03me030DarwDZDwDUQraNtFixaZbMqUKSZLt2ETPlyZAgAAAIAANFMAAAAAEIBmCgAAAAAC0EwBAAAAQAAGUOSpVauWyS6++GKTtWvXzmRnnnmm95hnnXWWyXbs2GGy3r17m8x3k6nvpkJkhrlz58b1eL4hFyIiFSpUiOt5gKLmq+W33norASsBit60adNMtnnzZpNVrFjRZC+//HKRrAmJk5WVZbILL7zQZL4BEs45k73//vve8/gGUc2YMSPKEjMCV6YAAAAAIADNFAAAAAAEoJkCAAAAgAA0UwAAAAAQQH03oBW4sWr0jQERcc5pIs9PzSJW1GxyeuD/t3fv8TbWaR/Hr1+2QzaRQxEhppNDsTORLSGnSEzimTSmRtOknFUqkRGlV3TSSDPVYColNaJ5MuNJitHjKUqTw8Pg2aJsbaKdc/g9f+xtZo/rt5t7/fZa61732p/363W/Xnzdh2vvflbrcu/7WhMnOvMePXqoLCsrK9HlpJrV1tqWYV2cNQsPrFlETbFrljtTAAAAAOCBZgoAAAAAPNBMAQAAAIAHmikAAAAA8MAACiQUD/MjaliziCAe5kfUsGYRNQygAAAAAIB4opkCAAAAAA80UwAAAADggWYKAAAAADzQTAEAAACAB5opAAAAAPBAMwUAAAAAHmimAAAAAMADzRQAAAAAeMiIcf/dIrItEYUgLdUPuwBhzSI2rFlEUdjrljWLWLFmETXFrlljrU1mIQAAAACQFvgxPwAAAADwQDMFAAAAAB5opgAAAADAA80UAAAAAHigmQIAAAAADzRTAAAAAOCBZgoAAAAAPNBMAQAAAIAHmikAAAAA8EAzBQAAAAAeaKbiyBiTY4zpFHYdQFCsWUQNaxZRw5pF1LBmY0MzlUKMMUuMMdYYkxF2LUAQrFlEDWsWUWCMudkYs9oYk2+M2WGMeYw1i1RmjHnOGLO/yHbEGPNd2HUlA81UMZL9omWMuUlEyibzmkgvrFlEDWsWUZPENVtRREaISA0RaSUiV4vI3Um6NtJIstastXaQtbbSyU1EXhWRecm4dthKXTNVeOvyfmPMemPMXmPMTGNMBWNM+8J//bnXGJMrIjONMacZY+4zxmwxxuwxxrxujKlW5FwDjDHbCv/sgRLUVEVExovI6Dh8iUgzrFlEDWsWUZNqa9ZaO8Nau9xae9Ra+6WIvCIi2XH6cpEGUm3NnlJbpoj0EZHZJT1XFJS6ZqrQTSLSVUQaicgFIjK2MK8lItVEpL6I/EpEhopIbxG5SkTOEZG9IjJdRMQY01hEZojIgMI/qy4idU9ewBjT3xiz7we2ekXqeaTwXLmJ+oIReaxZRA1rFlGTamu2qHYisi6uXy3SQaqu2T4ikiciy+L9Backa22p2kQkR0QGFfl9dxHZIiLtReSoiFQo8mcbROTqIr+vLSLfi0iGiDwoIq8V+bPMwuM7xVhPSxFZU3jOBiJiRSQj7O8TW+psrFm2qG2sWbaobam2Zk+pbaCI7BCRGmF/n9hSZ0vxNbtERH4d9vcoWVtpfZhxe5Ffb5OCTlxEJM9ae7jIn9UXkfnGmBNFsuMicnbhMf84j7X2gDFmTyxFGGNOE5FnRWS4tfaYMSaWw1G6sGYRNaxZRE1KrNmijDG9RWSyFLyx3e17HqStVFyz9aSgobvN9xxRU1p/zO/cIr+uJyJfFf7anrLfdhG5xlpbtchWwRb8/PLOoucxxlSUglujJ39/k/nXqSanbvVE5Awp+BfTuYU/1/px4eE7jDFXxvUrRtSxZhE1rFlETaqs2ZP7dhOR50Wkp7X28/h+qUgTKbVmCw0QkRXW2q3x+iJTXti3xpK9ScFt0c+l4OdBq4nIX6XgZ+nbi8iOU/YdKSLvi0j9wt/XFJFehb9uIiL7RaStiJQTkakickxiuC0qIkYKfq715PZjKfgLUEdEyoX9vWJLjY01yxa1jTXLFrUtldZs4Xk6isgeEWkX9veGLTW3VFuzRa61UUQGhv39SeZWWu9MzRGRxSKyVQp+vnRSMfs9LSILRWSxKZiVv1IKRpSKtXadiAwuPNdOKXiYb0csRdgCuSc3KXhYT0Rkl7X2aGxfEtIcaxZRw5pF1KTEmi00TkSqiMg7Re4ALPI4D9JbKq1ZMcZcIQXNXakYiX6SKewiSw1jTI6I/NJa+27YtQBBsGYRNaxZRA1rFlHDmk0dpfXOFAAAAACUCM0UAAAAAHgodT/mBwAAAADxwJ0pAAAAAPBAMwUAAAAAHjJi2dkYw88EIibWWhPm9VmziBVrFhG021pbM6yLs2bhgTWLqCl2zXJnCgCAaNsWdgFAjFiziJpi1yzNFAAAAAB4oJkCAAAAAA80UwAAAADggWYKAAAAADzQTAEAAACAB5opAAAAAPBAMwUAAAAAHmimAAAAAMADzRQAAAAAeKCZAgAAAAAPNFMAAAAA4IFmCgAAAAA80EwBAAAAgAeaKQAAAADwQDMFAAAAAB5opgAAAADAA80UAAAAAHigmQIAAAAADxlhFwC3M888U2V79+4NoRIAAAAALtyZAgAAAAAPNFMAAAAA4IFmCgAAAAA80EwBAAAAgAcGUKSAF198UWXNmzdXWffu3VW2a9euhNQEAMkwduxYlY0ZM0ZlOTk5zuN79+6tsk2bNpW4LgAAguDOFAAAAAB4oJkCAAAAAA80UwAAAADggWYKAAAAADwwgCLJunXrprKf/exnKitbtqzKnnnmGZUNHz5cZTt37vSsDlGSkaH/+nbo0EFlr776qsqqVavmPOfmzZtVlpWVpbL9+/cHKRH4F08//bTKhg0bFujY6tWrO/M1a9ao7NZbb1WZ6+8BSofMzEyVPfbYYyobNGiQyrZt26ay7OxslfH/XaD04s4UAAAAAHigmQIAAAAADzRTAAAAAOCBZgoAAAAAPDCAIslq1aqlMtewCZcWLVqorH379irjQevSYdKkSSq7++67Ax1rrXXmDRs2VNnpp5+uMgZQ4N958MEHVRZ02MTo0aNV9sILLzj3vfTSS1XWvXt3la1YsUJlX3zxRaB6EG133HGHylzDJg4ePKgy14CTw4cPx6cwIMUU9370wgsvVNmUKVNU1rVrV5VNnz5dZUOHDvWoLnVxZwoAAAAAPNBMAQAAAIAHmikAAAAA8EAzBQAAAAAeTHEPojt3Nib4znByfRL7//3f/6msRo0agc73/fffq6xLly4q++CDDwKdL96stSaUCxdKlzXbsmVLlb333nsqq1ixosoOHTqksiNHjjivU7VqVZWNGTNGZY899pjz+HTAmo1dzZo1VbZ161aVVapUSWUvvfSSym699VaVuV7rYjFy5EiV3X777YGysF4/Y7DaWqtfJJIk1dfstGnTVDZ48GCVffbZZyrLyspKSE1gzYatSZMmKnviiSec+3bu3FllGzduVNnOnTtV9umnn6rsrrvuClJiqil2zXJnCgAAAAA80EwBAAAAgAeaKQAAAADwQDMFAAAAAB4ywi6gtDlw4IDKTpw44X0+16dVV6hQwft8SK46deqozPXw/Z133qky17CJffv2qWzixIkqe/fdd531uB7AbtiwoXNf4CTXJ9y7hk24Hlh+4IEHVFbSYROua48bN05lZ555pspatGihsggMoMAPcP1/1+X8889XWevWrVW2cuXKQOcbMGCAM7/ssstUtnr16kDnfOONN1TmGjKE0qt69eoq69Onj8pc7w1c7wFERC655BKV5eXlqez0009XWU5OjvOc6YQ7UwAAAADggWYKAAAAADzQTAEAAACAB5opAAAAAPDAAAogRNddd53Kxo8frzJr9Ye1f/zxxyobM2aMypYuXaoy1yefF6dRo0aB90X6cz1g3KxZs0DHuh543r59e4lrOpXrYWvXsAmXLVu2xLschGzPnj2B9nOt7T/96U8qW7Nmjcp2796tsr59+zqv43o9D2rKlCkqu/baa1W2atUq72sgOlxr7LbbblOZa+jJHXfcobI333zTeZ3jx4+r7IwzzlBZuXLlnMenO+5MAQAAAIAHmikAAAAA8EAzBQAAAAAeaKYAAAAAwEPaD6CoUqWKyurUqaOy7OzsQNk777yjsrVr1zqvvX//fpVlZWWpzPUQH9JPZmamykaNGqUyY4zK8vPzVda9e3eV7d2717O64q+9cOFC7/NVq1ZNZd988433+RA+1+vVRRddpLIdO3ao7LXXXktITaeqV69eoP2+/vprlS1btize5SBkU6dOVZnr9Xjw4MEqq169uso6dOgQ6LoHDhxw5q73Bm+88YbKXINUatWqpbL27durjAEU6ce1Rpo3b66ybt26qWzz5s1xr+eSSy5RWd26dVX285//XGUjRoxQ2aZNm+JTWAi4MwUAAAAAHmimAAAAAMADzRQAAAAAeKCZAgAAAAAPkR1A4XrguXbt2ipzPXjaokUL7+u6HqRzPbQvIpKbm6sy10OvFSpUCHRt13W2bt2qsk8//TTQ+ZBcP/3pT1V23nnnqezbb79VmetTzksybMI10EJEZN68eSqbPXu2ys466yyV3XnnnSq74oorVNa1a9cgJSJFudaxy7Bhw1R2/PjxeJfj1LZt20D7ff755ypz/f1D+pkwYYLKpk2bprIePXp4X+OTTz5x5uvXr1dZ5cqVVda5c2eVnX322d71IDqGDh2qso8++khlAwYMUNmhQ4fiXo/r/bXr/+WuARSugRiuAW8MoAAAAACAUoZmCgAAAAA80EwBAAAAgAeaKQAAAADwEIkBFK5PeJ4+fbrKXJ96X6dOnYTUVJS11pnH+0HR4q5zqtatW3tf489//rPKjh496n0+/FOXLl0C7ff3v/9dZe+++25ca9m+fbszdw0XcA1Nefrpp1XmGpJx4sQJlbVp00ZlH374obMepJ5+/foF2i/o61VJlS1bVmVVq1YNdOyqVaviXQ4izDXU5+WXX477dZo2baqya6+9VmXnn39+oPO98cYbJa4JqWXmzJkq279/fwiVFNi5c6fKxo0bp7JevXqp7K9//avK0u3/+dyZAgAAAAAPNFMAAAAA4IFmCgAAAAA80EwBAAAAgIeUG0DhenB46dKlKqtSpYr3NVwPa/bu3VtlGRnBvj3GGGfuegDbtW9JHtQ+77zzVPbWW295n++hhx5S2a9//Wvv85VW5cqVU1nNmjUDHfvII4/EuxzFNVRCxD2AYsSIESq7+OKLA12nTJkyKqtfv77K0u1h1HTmWtthatu2rcouv/zyQMfOmzcv3uUA/3D66ac782effVZlrsE8Lvfdd5/KcnJyYqoLqS8ZwyYqV66sso4dOzr3rVWrlspyc3MD7bdu3TqVbdy4MUiJkcGdKQAAAADwQDMFAAAAAB5opgAAAADAA80UAAAAAHigmQIAAAAADyk3za9Vq1YqCzq5zzX1r3///irbvXu3ylzT/IJOeoplGl9JJvclw9NPPx12CWnhxIkTKjt8+HCgYzdv3ux93a5du6qsYcOGKhs1apTzeNd0yHhr1KhRwq+B9FOnTh1nXpLXrPXr13sfC/w7xU3Zq169eqDjXZP7nnjiiZKUhDRTvnx5lR05ckRlAwYMUNmGDRtUtmfPHud1FixYoLKsrCyVzZo1S2UdOnRwnjOdcGcKAAAAADzQTAEAAACAB5opAAAAAPBAMwUAAAAAHlJuAIXrAXqXnTt3quz2229X2a5du1SWmZmpsttuuy3QdcPkGl5x/PhxlWVk6P+sc+fOVdmKFStUFnRIAn7YsWPHVPbdd9+pzBijsg8++EBlBw8eVNk555zjWZ3Iaae5/x3FNTjDtcYWLVqksmuvvTbQdZYvXx6kRKSooK8RlStXVplrvbter1xr6amnnnJe5+yzz1aZa7BE48aNVda5c2eVLVy40HkdlE6uB/zbtGmjMte6qVSpkvOcrtfZRx55RGVTp04NUiLSkOv1s1u3bipbt26dylyvf4sXL1bZoUOHVJafnx+0RLnllltU5hrctmbNmsDnjCruTAEAAACAB5opAAAAAPBAMwUAAAAAHmimAAAAAMBDyg2gaNq0aaD9ateurbIpU6aobMeOHSpzPTTnGkoRlOuhahH30ICyZcuq7E9/+pPKlixZojLXAIPc3FyVnX/++SpzPRT49ddfqwyJs3nzZpW5hopUrVpVZVWqVAl0bFCuB6BFRN5//32VTZo0KdC1e/ToobJ9+/apLC8vL0CFSFWuh+Kzs7NV9oc//EFlAwYMUJlrbV9++eUqK27N3n333Sr75ptvVDZr1iyVVa9e3XlO4KTJkyerbNiwYYGOLW7Nul4/Xe9LXK/Hy5YtU9n3338fqB5Eh+v9nut9XFCuYWyxaNKkicrq16+vsieeeKJE14kq7kwBAAAAgAeaKQAAAADwQDMFAAAAAB5opgAAAADAQ8oNoPjLX/6isk6dOgU6tlevXvEuJ5DiBgEMHz5cZTt37lSZa9jE4cOHvetxfdI1wjdhwgSVuR6A7927t8pca8z16eWuh+xd9u7d68yfe+45lbkebg56HdcAGNensyM6FixYoLKJEyeq7LbbblNZ586dA10jPz9fZYMHD3bu+/LLL6vs5ptvDnQd4N9Zvny5yi6++OJAxxY3nMr1ML9rqJbr/VCbNm1U9tFHHwWqB/A1ZMgQldWtW1dlH3zwQTLKSTncmQIAAAAADzRTAAAAAOCBZgoAAAAAPNBMAQAAAIAHU9zwBOfOxgTf2dMZZ5yhsg0bNqjM9bBmMrhqWbVqlXPfESNGqKy4B//TlbXW/QRukiRjzZZGc+fOVVmfPn1U5vr70qxZs4TUFC+s2fgoV66cyrKzs1VWuXJllbkevD9y5Ejga7sGULiGpriG9XTt2jXwdVLIamtty7Auni5rNlnOPfdclY0bN05lAwcOVNmiRYtU1rNnz/gUllys2RR04403OvM5c+aorHv37ipzrc80Uuya5c4UAAAAAHigmQIAAAAADzRTAAAAAOCBZgoAAAAAPKTcAAqXBg0aqGz06NEqGzRokMo2bdqksjJlyqhszZo1KnvvvfdUNm/ePJXt3r1bZSjAw/zpiQEUicOaLTnXUIv8/HyVLVmyRGWdOnVKSE0JxsP8EXfBBReobNmyZSqrUaOGyoYOHaqyGTNmxKewxGHNhqx+/foqe+ONN5z7Hjx4UGVXXXVV3GtKcQygAAAAAIB4opkCAAAAAA80UwAAAADggWYKAAAAADxkhF1AEDk5OSobNWqUyh5//HGV7du3T2WnnaZ7yLy8PL/iAIiIiDF6bsPevXtDqASl3eHDh8MuAYiJa1jWPffco7JZs2apLDMzMxElIc1lZWWprHr16s59+/btm+hyIo07UwAAAADggWYKAAAAADzQTAEAAACAB5opAAAAAPAQiQEULq4HjLds2RJCJQBERKzVHyj/6quvhlAJAKQn1+tsjRo1QqgEUVK+fHmVjR8/XmWrV692Hu8aBId/4s4UAAAAAHigmQIAAAAADzRTAAAAAOCBZgoAAAAAPER2AAUAAC4nTpxQ2datW0OoBEi8q666KuwSkOJq1aqlsvvvv19ln376aTLKSTvcmQIAAAAADzRTAAAAAOCBZgoAAAAAPNBMAQAAAIAHBlAAANLK8ePHVfbJJ5+o7Mwzz0xGOUBCvfnmm2GXgBTStm1blVWvXl1lCxYsSEY5pQJ3pgAAAADAA80UAAAAAHigmQIAAAAADzRTAAAAAOCBARQAYrZu3TqVXXjhhYH2A8LQt2/fsEsAYnLZZZeFXQJSXIMGDVQ2cuRIlT355JNJqKb04s4UAAAAAHigmQIAAAAADzRTAAAAAOCBZgoAAAAAPBhrbfCdjQm+MyAi1loT5vVZs4gVaxYRtNpa2zKsi7Nm4YE1GwfZ2dkqe/3111VWp06dZJST7opds9yZAgAAAAAPNFMAAAAA4IFmCgAAAAA80EwBAAAAgIeMsAsAAAAAEJsVK1aojGETycedKQAAAADwQDMFAAAAAB5opgAAAADAA80UAAAAAHiIdQDFbhHZlohCkJbqh12AsGYRG9YsoijsdcuaRaxYs4iaYtessdYmsxAAAAAASAv8mB8AAAAAeKCZAgAAAAAPNFMAAAAA4IFmCgAAAAA80EwBAAAAgAeaKQAAAADwQDMFAAAAAB5opgAAAADAA80UAAAAAHigmQIAAAAADzRTcWSMyTHGdAq7DiAo1iyihjWLqGHNImpYs7GhmUohxpglxhhrjMkIuxagOMaYm40xq40x+caYHcaYx1izSGWsWUSNMeY5Y8z+ItsRY8x3YdcF/BBjzEhjTG7ha+3vjTHlw64pGWimipHs/9EaY24SkbLJvCbSSxLXbEURGSEiNUSklYhcLSJ3J+naSCOsWURNstastXaQtbbSyU1EXhWRecm4NtJLstasMaariNwnBa+v9UWkoYhMSMa1w1bqmqnCW5f3G2PWG2P2GmNmGmMqGGPaF/6L5b3GmFwRmWmMOc0Yc58xZosxZo8x5nVjTLUi5xpgjNlW+GcPlKCmKiIyXkRGx+FLRJpJtTVrrZ1hrV1urT1qrf1SRF4Rkew4fblIA6xZRE2qrdlTassUkT4iMruk50L6SME1e7OIvGitXWet3SsiE0XklpJ/pamv1DVThW4Ska4i0khELhCRsYV5LRGpJgUd9a9EZKiI9BaRq0TkHBHZKyLTRUSMMY1FZIaIDCj8s+oiUvfkBYwx/Y0x+35gq1eknkcKz5WbqC8YkZdqa7aodiKyLq5fLdIBaxZRk6prto+I5InIsnh/wYi8VFqzTUTksyK1fSYiZxtjqifg604t1tpStYlIjogMKvL77iKyRUTai8hREalQ5M82iMjVRX5fW0S+F5EMEXlQRF4r8meZhcd3irGeliKypvCcDUTEikhG2N8nttTZUm3NnlLbQBHZISI1wv4+saXOxppli9qW4mt2iYj8OuzvEVtqbam2Zguv3a3I78sWvqdtEPb3KtFbaX0Ad3uRX2+Tgk5cRCTPWnu4yJ/VF5H5xpgTRbLjInJ24TH/OI+19oAxZk8sRRhjThORZ0VkuLX2mDEmlsNRuqTEmi3KGNNbRCZLwQvubt/zIG2xZhE1qbhm60nBm+PbfM+BtJZKa3a/iJxR5Pcnf532g1NK64/5nVvk1/VE5KvCX9tT9tsuItdYa6sW2SrYgp+531n0PMaYilJwa/Tk728y/zqJ59StnhQstJYiMrfw51o/Ljx8hzHmyrh+xYi6VFmzJ/ftJiLPi0hPa+3n8f1SkSZYs4ialFqzhQaIyApr7dZ4fZFIK6m0ZteJyKVFrnmpiOyy1nr/Y0JkhH1rLNmbFNwW/VwKfh60moj8VQqeWWovIjtO2XekiLwvIvULf19TRHoV/rqJFHThbUWknIhMFZFjEsNtURExUvBzrSe3H0vBX4A6IlIu7O8VW2psqbRmC8/TUUT2iEi7sL83bKm5sWbZoral2potcq2NIjIw7O8PW+ptqbZmRaSbFDz731hEqorIeyLyaNjfp2RspfXO1BwRWSwiW6XgZzwnFbPf0yKyUEQWm4LPd1gpBWN1xVq7TkQGF55rpxQ8zLcjliJsgdyTmxQ8YCpS0Mkfje1LQppLiTVbaJyIVBGRd4r8y9Qij/MgvbFmETWptGbFGHOFFLxRZiQ6ipMya9Za+2cReUxElorIF1LwY4fjYz1PFJnCbrLUMMbkiMgvrbXvhl0LEARrFlHDmkXUsGYRNazZ1FFa70wBAAAAQInQTAEAAACAh1L3Y34AAAAAEA/cmQIAAAAADzRTAAAAAOAhI5adjTH8TCBiYq01YV6fNYtYsWYRQbuttTXDujhrFh5Ys4iaYtcsd6YAAIi2bWEXAMSINYuoKXbN0kwBAAAAgAeaKQAAAADwQDMFAAAAAB5opgAAAADAA80UAAAAAHigmQIAAAAADzRTAAAAAOCBZgoAAABULNc6AAAS/UlEQVQAPNBMAQAAAIAHmikAAAAA8EAzBQAAAAAeaKYAAAAAwAPNFAAAAAB4oJkCAAAAAA80UwAAAADggWYKAAAAADzQTAEAAACAh4ywCwAAAAAQm9q1a6ts4MCBKhszZozKKlSooLK1a9c6rzN+/HiVvfXWW0FKLBW4MwUAAAAAHmimAAAAAMADzRQAAAAAeKCZAgAAAAAPxlobfGdjgu+MEqlWrZrKfvOb36isRYsWKrv44osTUpMPa60J8/qsWcSKNYsIWm2tbRnWxVmz8MCajdGPfvQjlT3zzDMq69y5c9yvfejQIZXdcsstKnvzzTfjfu0UUuya5c4UAAAAAHigmQIAAAAADzRTAAAAAOCBZgoAAAAAPGSEXQBEqlatqrL//M//VFnr1q1V9u233yakJoSnXr16Krv66qtV1q5dO+9rGOOesdCnTx+V5ebmqqxRo0Yq+/jjj1W2fv16j+oKXHLJJSrLy8tTWbdu3byvgR9WpkwZlT344IMqcw3CufbaawNdw/XftEKFCir77//+b+fxa9asUdlzzz2nspycnED1AD5q1Kihsv79+zv3HTJkiMpcwwW++OILlc2bN09ljz32mMpcf6+Qmk47Td/XcK2diRMnqsz1fiGWwXJBVaxYUWV9+/ZVWZoPoCgWd6YAAAAAwAPNFAAAAAB4oJkCAAAAAA80UwAAAADgwcTyoFoUPzE6qMaNG6ts9OjRKrvzzjtVdvDgwcDXueyyy1Q2bdo0lbVp00Zl3333ncpcD3kvW7YscD2JZq11TzpIklRfs7Vr11bZrFmzVNapUyfva+Tn56ts165dzn0T8eDqqerUqaOyzMzMQMfu2LFDZfXr1y9xTUWxZv8pOztbZcuXLw+hkti4BvP8z//8j8puvfVWlX355ZcJqSnBVltrW4Z18VRas/F21llnqcw1HGDs2LEqO/PMMwNfxzXoZ+DAgSqrVauWylzDVT744IPA1w4Ja7bQvffeq7KHH3440LGbN29W2bPPPquyr776KtD5Hn/8cWdet25dlW3fvl1lrsFRrvcgEVXsmuXOFAAAAAB4oJkCAAAAAA80UwAAAADggWYKAAAAADxkhF1AqmjZUj9TdvPNN6use/fuKluwYEHg6/zsZz9TWYUKFQIdO2rUKJWl0rAJxK5MmTIqcz1Q6srmzJmjsuPHj6ssLy9PZVu2bAlaYolUqVJFZa4127RpU5W5hgh06dIlPoUhkBYtWgTa78SJEyo7dOiQyt577z2VuQY+NGvWTGU//vGPndcuX768ylzrzrV23nnnHZX94he/UNknn3zivDaiq2zZsiq76667VHb77berrF69eio7duyYylasWOG8tuvvlWuwxAsvvKCyefPmqezdd991Xgepp2vXrip76KGHAh27du1alXXs2FFl33zzTeyFFRo/fnzgfV1DKQYNGqQy1/C0devWqSzK72e5MwUAAAAAHmimAAAAAMADzRQAAAAAeKCZAgAAAAAPDKAo9Morr6isSZMmKhsyZIjKfvnLX8a9HtdDfLNmzYr7dRCuHTt2qGzw4MEhVFJybdu2VZlrOEvVqlVV5hpMMHLkSJVt3LjRszr46N+/f6D9/va3v6ksKysrrrU0aNDAmbsGU4wZM0Zll156qcpcgy5cD/3H+2tBctWuXVtlL730kso6dOgQ6HwzZ85U2ejRo1VW3CAA12ugawjWjTfeqLLhw4erLDs7W2U33XSTylyDjJA4ruFi/fr1U1lGRrC34suXL1dZSYZNZGZmqqy4gWjGmEDnnDx5snc9CxcuVJlraNuBAwe8r5Eo3JkCAAAAAA80UwAAAADggWYKAAAAADzQTAEAAACAB2OtDb6zMcF3TlMXXXSRyqZPn+7c1/XJ1C5PPfWUylwP30eRtTbYU4sJwppNDNewCdfDo1WqVFHZ0qVLVdarVy+VhfWQKWv2nz788EOVtW7dWmWuoQ2/+tWvElJTEBUrVlTZlClTVHbHHXeo7NixYypzfS0pNhBotbVWTzBIklRasy4TJ05U2dixY1Xmej/09ttvq6xv374qO3r0qGd1sWnevLnKZs+erTLXcKMePXokpCZPab9mr7vuOpX98Y9/VNnixYtV1qVLl0D7de/ePVAtrsESriEsP/nJT5zHuwZQxNI/BOG6xptvvqmyAQMGqOzIkSNxraUYxa5Z7kwBAAAAgAeaKQAAAADwQDMFAAAAAB5opgAAAADAQ7CPXcY/uB6Kb9q0aeDjDx06pLLnn3++RDUBiTR+/HiVDRs2TGWuT1P/7W9/q7L77rtPZan4ieYIbsOGDWGX8C8OHjyoMtcQgmuuuUZlDRo0UFksr/FIPa4BDVlZWSrr1q2byubMmaOyZA2bcFmzZo3KXAN8li1bprL7779fZZMnT45PYVC+/PJLlbkGNbkG4biOdQ2lmDdvnspc7zNda7tatWoqi8Xq1atVdsMNN6isXbt2Kps0aZLK6tWrp7Lrr78+UC39+vULtF+icGcKAAAAADzQTAEAAACAB5opAAAAAPBAMwUAAAAAHhhAEaP27dur7Kyzzgp8/MCBA1W2fv36kpQExM24ceNUNnToUJVVrVpVZQsWLFDZnXfeGZ/CEIolS5aorHXr1irr2LGjyp588smE1OTL9bC1a9gE0s/mzZtVtnTpUpXVrFlTZfPnz09ITfGUk5OjssWLF6usUaNGSagGJ+3Zs0dlrgEi1lrva7gGNJTkfMX56KOPVHbdddepLC8vT2WvvPKKylauXKmy999/X2W1a9dW2ZVXXllcmaHhzhQAAAAAeKCZAgAAAAAPNFMAAAAA4IFmCgAAAAA8MIDiB9SpU0dlY8eOLdE5y5cvr7IHHnhAZa6HvPPz81U2bNgwlbkeegSKmjp1qjMfMWKEylwPsw4ZMkRlzz//fMkLQ0qZMGGCyh555BGVnThxIhnllEi7du28j929e3ccK0EquOGGG1T23nvvqezo0aPJKAdpyDUY5KGHHlJZRoZ+K/7EE0+obNSoUXGp66SvvvpKZa5BUiIi48ePV9k333zjfe0tW7ao7OGHH1bZb37zG5W5hgndddddKnv88cc9q4sdd6YAAAAAwAPNFAAAAAB4oJkCAAAAAA80UwAAAADgIe0HULg+8bty5coqa9asmcpcwyYuuOCCEtUza9Yslbke8P/ss89U9tprr6msXLlyJaoH4apatarKXOvT9QD8oUOHVOZ6kNU1RKC4B1ldD1v/7ne/U9mMGTOcxyO9HDt2LFAWBZmZmYH2c/0dWLhwYbzLQRLVr19fZa1atVLZlClTklFOUvTq1UtlP//5z0OoBP+O6zV19OjRKrvmmmtU1qRJk0DXcA2IGjRoUKBjk2X58uUqM8aozPU+p0ePHipjAAUAAAAApDiaKQAAAADwQDMFAAAAAB5opgAAAADAA80UAAAAAHiIxDQ/18S6cePGqaxnz54qa9iwocpc09ISIT8/X2XPPvusyl599VWV/e1vf0tITUgtruk1jRs3VplrPUybNk1lN9xwg8q6desWuJ7/+q//UtmwYcMCHw+kqj59+gTab/369YEyREfz5s1V5pqgtmXLlmSUE3eXXHKJynJzc1W2aNGiZJSDJHJNg37//fdVNnTo0CRUUzKu6duur8+lRo0aKjvjjDNU5npfHg/cmQIAAAAADzRTAAAAAOCBZgoAAAAAPNBMAQAAAICHSAygaNOmjcoeeOABlR09elRlGzZsUNn27dtV5nrwzTUIwPXwWnEP6M+ePduZAye99NJLKnv00UdVdumll6rsxRdf9L6uMcaZ79q1S2UXXnihyjZu3Oh9bSAMxa35U61YsSLBlSAV7N27V2Vr1qwJoZKS+8UvfqGyt956K4RKkAry8vJU9v3334dQSWw6d+7sfWxmZqbKypcvX5JyYsKdKQAAAADwQDMFAAAAAB5opgAAAADAA80UAAAAAHiIxAAK16c516xZU2UnTpxQ2cGDB1V25MgRlT388MMqcw2g2Lp1q8oYNAFfTz31lMrmzp2rsjlz5qisdevWca/H9SBz3759VTZ9+nSV/fa3v1XZtm3b4lMYIqdMmTIqcw36cb2Wd+vWrUTXzsrKUlmLFi0CHesauNKqVSuVrVu3TmX79+8PdA0kV/PmzVWWm5sbQiUl96Mf/UhlN954o8quvvrqZJQDxE2dOnW8j83JyVGZaxBHonBnCgAAAAA80EwBAAAAgAeaKQAAAADwQDMFAAAAAB4iMYDCZc+ePd7Huh6Mvv766wMd+/bbb3tfFzjVsWPHVNaxY0eVuYZNuAauvPbaayrLz89XmTHGWY/rIf3LL79cZffee6/KhgwZorJVq1apbNCgQSr7+9//rjJrrbNGhMv1+jlixAiV/cd//IfKWrZsmZCa4qlTp06Bsv/93/9V2e9//3uVTZ06NT6Fwdu5556rso8++iiESmJz2mn637td62nJkiUqcw1IAVKFa2hKr169VBb0fYDrvU8ycWcKAAAAADzQTAEAAACAB5opAAAAAPBAMwUAAAAAHiI7gKIk7rrrLpVddNFFKtu1a5fK5s6dm5CaUDp1795dZS+88EKgY4cNG6ayGTNmlKieypUrq6xHjx4qcz082rhxY5VdddVVKtuwYYPK5s+fr7JXXnkl0H5IrsmTJ6vs7rvvDqGSxHANbFm2bFmgYxcuXBjvchAH+/fvV1mlSpVCqCQ29erVU1l2drbK2rZtm4xyELLnnntOZc8884zK+vbtq7KnnnpKZStXroxPYUVkZOi2okOHDipz1e0auOIatJWXl6ey559/PmiJCcGdKQAAAADwQDMFAAAAAB5opgAAAADAA80UAAAAAHgolQMoGjRoEGi/FStWqIxPFYcv14CGP/zhD4GOXbp0qcoS8cDld999pzLXJ4u7srp166ps9OjRKtu6davK+vfvr7KePXuqzPXA7M6dO1WGxAn6+nns2DGVrV27NtCxCxYsUJlrMISI+3W6X79+Khs1apTK/vKXv6jspz/9qcq+/fZb57URDa7XT9frS6pxDbxyreONGzcmoxyEbPPmzSqz1gbKXH8HJkyYoLItW7Y4r22MUVmjRo1U5hqqdcUVVzjPeSrXsAnX1/L4448HOl8ycWcKAAAAADzQTAEAAACAB5opAAAAAPBAMwUAAAAAHkrlAIqgNm3aFHYJiKhKlSqpzPWJ31WqVFHZrl27VHb99derzPWAf5h27NihsmHDhgU61vXp7EhN99xzj8oefvhhlbnWZ7IG+PTt2zfQfvPnz1cZwyYQBtdQn9NPP11lL730UjLKQQpyDcxZtmyZyq688kqVlS1bVmWTJk0KfG3XAArXcIiS2L9/v8pcNTKAAgAAAADSBM0UAAAAAHigmQIAAAAADzRTAAAAAOChVA6gcD2c5/Lhhx8muBKkqzFjxqisffv2Ktu3b5/KXA/P5+fnx6UuoKS2bdsWKAtTy5YtVXbkyBGVvf3228koByE7evSoylq1ahVCJQV69uypsiZNmqisS5cuySgHEdarVy+Vvf766yrr1KlTMsoJbObMmSp79NFHVbZly5ZklFNi3JkCAAAAAA80UwAAAADggWYKAAAAADzQTAEAAACAh7QfQHHdddeprGnTpio7dOiQyjZu3JiQmpBeRo4cqbJ77rlHZZ9//rnK2rVrpzKGTQDxl5ubq7KdO3eGUAmSbdGiRSp7+eWXVfaTn/xEZfPnz/e+rms4gIjICy+8EGhf1if+Hdf7hX79+qnMtbZdg7Jc74VFRJo1a6Yy13CIBQsWqGzWrFmBjnUNCYoK7kwBAAAAgAeaKQAAAADwQDMFAAAAAB5opgAAAADAQ9oPoFi5cqXK1q5dq7IDBw6obNOmTQmpCdHlGmjy4IMPqmzVqlUq69Onj8oYNgGUTKdOnVSWnZ2tMtdD/yi9hg8frrLRo0er7ODBgyqrVKmSyoYMGaKyc845x3ntW2+9VWUffvihc18gVq73FbNnzw6UwQ93pgAAAADAA80UAAAAAHigmQIAAAAADzRTAAAAAOAh7QdQfP311ypzfZIzEMSMGTNU5vok73vuuUdlX331VUJqAkqz8847T2WuwUNz5sxJRjmICNd6qFWrlsqmTZumMtcAiieffFJlrv9fiLgHXgGILu5MAQAAAIAHmikAAAAA8EAzBQAAAAAeaKYAAAAAwIOx1gbf2ZjgOwMiYq01YV6fNYtYsWYRQauttS3DujhrFh5Ys4iaYtcsd6YAAAAAwAPNFAAAAAB4oJkCAAAAAA80UwAAAADggWYKAAAAADzQTAEAAACAB5opAAAAAPBAMwUAAAAAHmimAAAAAMBDRoz77xaRbYkoBGmpftgFCGsWsWHNIorCXresWcSKNYuoKXbNGmttMgsBAAAAgLTAj/kBAAAAgAeaKQAAAADwQDMFAAAAAB5opgAAAADAA80UAAAAAHigmQIAAAAADzRTAAAAAOCBZgoAAAAAPNBMAQAAAICH/wc2hN/DNrkeJgAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": { "tags": [] }, "output_type": "display_data" } ], "source": [ "# Show some of them.\n", "show_img_grid(\n", " [test_ds['image'][idx] for idx in error_idxs[:25]],\n", " [f'pred={logits[idx].argmax()}' for idx in error_idxs[:25]],\n", ")" ] } ], "metadata": { "accelerator": "GPU", "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: examples/mnist/mnist_benchmark.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Benchmark for the MNIST example.""" import time from absl import flags from absl.testing import absltest from absl.testing.flagsaver import flagsaver from flax.testing import Benchmark import jax import numpy as np import main from configs import default # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() FLAGS = flags.FLAGS class MnistBenchmark(Benchmark): """Benchmarks for the MNIST Flax example.""" @flagsaver def test_cpu(self): """Run full training for MNIST CPU training.""" # Prepare and set flags defined in main.py. workdir = self.get_tmp_model_dir() config = default.get_config() FLAGS.workdir = workdir FLAGS.config = config start_time = time.time() main.main([]) benchmark_time = time.time() - start_time summaries = self.read_summaries(workdir) # Summaries contain all the information necessary for the regression # metrics. wall_time, _, eval_accuracy = zip(*summaries['eval_accuracy']) wall_time = np.array(wall_time) sec_per_epoch = np.mean(wall_time[1:] - wall_time[:-1]) end_eval_accuracy = eval_accuracy[-1] # Assertions are deferred until the test finishes, so the metrics are # always reported and benchmark success is determined based on *all* # assertions. self.assertBetween(end_eval_accuracy, 0.98, 1.0) # Use the reporting API to report single or multiple metrics/extras. self.report_wall_time(benchmark_time) self.report_metrics({ 'sec_per_epoch': sec_per_epoch, 'accuracy': end_eval_accuracy, }) self.report_extras({ 'model_name': 'MNIST', 'description': 'CPU test for MNIST.', 'implementation': 'linen', }) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/mnist/requirements.txt ================================================ absl-py==1.0.0 clu==0.0.6 flax==0.4.1 jax==0.3.4 --find-links https://storage.googleapis.com/jax-releases/jax_releases.html jaxlib==0.3.2+cuda11.cudnn82 # Make sure CUDA version matches the base image. ml-collections==0.1.0 numpy==1.22.0 optax==0.1.0 tensorflow==2.11.1 tensorflow-datasets==4.4.0 ================================================ FILE: examples/mnist/train.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """MNIST example. Library file which executes the training and evaluation loop for MNIST. The data is loaded using tensorflow_datasets. """ # See issue #620. # pytype: disable=wrong-keyword-args from functools import partial from typing import Any from pathlib import Path from absl import logging from flax import nnx from flax.metrics import tensorboard import jax import ml_collections import optax import tensorflow as tf import tensorflow_datasets as tfds tf.random.set_seed(0) # Set the random seed for reproducibility. class CNN(nnx.Module): """A simple CNN model.""" def __init__(self, rngs: nnx.Rngs): self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs) self.batch_norm1 = nnx.BatchNorm(32, rngs=rngs) self.dropout1 = nnx.Dropout(rate=0.025) self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs) self.batch_norm2 = nnx.BatchNorm(64, rngs=rngs) self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2)) self.linear1 = nnx.Linear(3136, 256, rngs=rngs) self.dropout2 = nnx.Dropout(rate=0.025) self.linear2 = nnx.Linear(256, 10, rngs=rngs) def __call__(self, x, rngs: nnx.Rngs): x = self.avg_pool(nnx.relu(self.batch_norm1(self.dropout1(self.conv1(x), rngs=rngs)))) x = self.avg_pool(nnx.relu(self.batch_norm2(self.conv2(x)))) x = x.reshape(x.shape[0], -1) # flatten x = nnx.relu(self.dropout2(self.linear1(x), rngs=rngs)) x = self.linear2(x) return x def loss_fn(model: CNN, batch, rngs): logits = model(batch['image'], rngs) loss = optax.softmax_cross_entropy_with_integer_labels( logits=logits, labels=batch['label'] ).mean() return loss, logits @nnx.jit def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch, rngs): """Train for a single step.""" grad_fn = nnx.value_and_grad(loss_fn, has_aux=True) (loss, logits), grads = grad_fn(model, batch, rngs) metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates. optimizer.update(model, grads) # In-place updates. @nnx.jit def eval_step(model: CNN, metrics: nnx.MultiMetric, batch): loss, logits = loss_fn(model, batch, None) metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates. def get_datasets( config: ml_collections.ConfigDict, ) -> tuple[tf.data.Dataset, tf.data.Dataset]: """Load MNIST train and test datasets into memory.""" batch_size = config.batch_size train_ds: tf.data.Dataset = tfds.load('mnist', split='train') test_ds: tf.data.Dataset = tfds.load('mnist', split='test') train_ds = train_ds.map( lambda sample: { 'image': tf.cast(sample['image'], tf.float32) / 255, 'label': sample['label'], } ) # normalize train set test_ds = test_ds.map( lambda sample: { 'image': tf.cast(sample['image'], tf.float32) / 255, 'label': sample['label'], } ) # normalize the test set. # Create a shuffled dataset by allocating a buffer size of 1024 to randomly # draw elements from. train_ds = train_ds.shuffle(1024) # Group into batches of `batch_size` and skip incomplete batches, prefetch the # next sample to improve latency. train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1) # Group into batches of `batch_size` and skip incomplete batches, prefetch the # next sample to improve latency. test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) return train_ds, test_ds def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> None: """Execute model training and evaluation loop. Args: config: Hyperparameter configuration for training and evaluation. workdir: Directory path to store metrics. """ train_ds, test_ds = get_datasets(config) # Instantiate the model. model = CNN(rngs=nnx.Rngs(0)) learning_rate = config.learning_rate momentum = config.momentum summary_writer = tensorboard.SummaryWriter(workdir) summary_writer.hparams(dict(config)) optimizer = nnx.Optimizer( model, optax.sgd(learning_rate, momentum), wrt=nnx.Param ) metrics = nnx.MultiMetric( accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average('loss'), ) rngs = nnx.Rngs(0) for epoch in range(1, config.num_epochs + 1): # Run the optimization for one step and make a stateful update to the # following: # - The train state's model parameters # - The optimizer state # - The training loss and accuracy batch metrics model.train() # Switch to train mode for batch in train_ds.as_numpy_iterator(): train_step(model, optimizer, metrics, batch, rngs) # Compute the training metrics. train_metrics = metrics.compute() metrics.reset() # Reset the metrics for the test set. # Compute the metrics on the test set after each training epoch. model.eval() # Switch to eval mode for batch in test_ds.as_numpy_iterator(): eval_step(model, metrics, batch) # Compute the eval metrics. eval_metrics = metrics.compute() metrics.reset() # Reset the metrics for the next training epoch. logging.info( # pylint: disable=logging-not-lazy 'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f,' ' test_accuracy: %.2f' % ( epoch, train_metrics['loss'], train_metrics['accuracy'] * 100, eval_metrics['loss'], eval_metrics['accuracy'] * 100, ) ) # Write the metrics to TensorBoard. summary_writer.scalar('train_loss', train_metrics['loss'], epoch) summary_writer.scalar('train_accuracy', train_metrics['accuracy'], epoch) summary_writer.scalar('test_loss', eval_metrics['loss'], epoch) summary_writer.scalar('test_accuracy', eval_metrics['accuracy'], epoch) summary_writer.flush() # Export the model to a SavedModel directory. from orbax.export import JaxModule, ExportManager, ServingConfig def exported_predict(model, y): return model(y, None) model.eval() jax_module = JaxModule(model, exported_predict) sig = [tf.TensorSpec(shape=(1, 28, 28, 1), dtype=tf.float32)] export_mgr = ExportManager(jax_module, [ ServingConfig('mnist_server', input_signature=sig) ]) output_dir= Path(workdir) / 'mnist_export' export_mgr.save(str(output_dir)) ================================================ FILE: examples/mnist/train_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.examples.mnist.mnist_lib.""" import pathlib import tempfile import sys from absl.testing import absltest import jax import flax.nnx as nnx from jax import numpy as jnp import numpy as np import tensorflow as tf import tensorflow_datasets as tfds from configs import default import train CNN_PARAMS = 825_034 class TrainTest(absltest.TestCase): """Test cases for train.""" def setUp(self): super().setUp() if sys.version_info < (3, 13): self.skipTest('Tensorflow 2.20 required for this test, which conflicts with tensorflow_text.') # Make sure tf does not allocate gpu memory. tf.config.experimental.set_visible_devices([], "GPU") def test_cnn(self): """Tests CNN module used as the trainable model.""" inputs = jnp.ones((1, 28, 28, 1), jnp.float32) cnn = train.CNN(nnx.Rngs(0)) cnn.eval() output = cnn(inputs, None) self.assertEqual((1, 10), output.shape) def test_train_and_evaluate(self): """Tests training and evaluation code by running a single step.""" # Create a temporary directory where tensorboard metrics are written. workdir = tempfile.mkdtemp() # Go two directories up to the root of the flax directory. flax_root_dir = pathlib.Path(__file__).parents[2] data_dir = str(flax_root_dir) + "/.tfds/metadata" # pylint: disable=unused-variable # Define training configuration. config = default.get_config() config.num_epochs = 1 config.batch_size = 8 with tfds.testing.mock_data(num_examples=8, data_dir=data_dir): train.train_and_evaluate(config=config, workdir=workdir) if __name__ == "__main__": absltest.main() ================================================ FILE: examples/nlp_seq/README.md ================================================ ## Part-of-Speech Tagging Trains a simple sequence-based part-of-speech tagger. The following sentence shows an example. ``` From|ADP the|DT AP|PROPN comes|VBZ this|DT story|NN :|: ``` ### Requirements * Universal Dependency data sets: https://universaldependencies.org/#download. Download via command line: ``` curl -# -o ud-treebanks-v2.0.tgz https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-1976/ud-treebanks-v2.0.tgz tar xzf ud-treebanks-v2.0.tgz ``` ### Supported setups The model should run with other configurations and hardware, but explicitly tested on the following. | Hardware | Batch size | Learning rate | Training time | Accuracy | TensorBoard.dev | |:---:|:---:|:---:|:---:|:---:|:---:| | Nvidia Titan V (12GB) | 64 | 0.05 | 5:58h | 68.6% | [2022-05-01](https://tensorboard.dev/experiment/F5ULHlyzQlieVJn5PG8mRQ/) | ### Running ``` python train.py --batch_size=64 --model_dir=./ancient_greek \ --dev=ud-treebanks-v2.0/UD_Ancient_Greek/grc-ud-dev.conllu \ --train=ud-treebanks-v2.0/UD_Ancient_Greek/grc-ud-train.conllu ``` ================================================ FILE: examples/nlp_seq/configs/default.py ================================================ # Copyright 2025 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Default hyperparameters for NLP sequence tagging.""" import ml_collections def get_config(): """Get the default hyperparameter configuration.""" config = ml_collections.ConfigDict() # Model directory for checkpoints and logs config.model_dir = '' # Experiment name config.experiment = 'xpos' # Training hyperparameters config.batch_size = 64 config.num_train_steps = 75000 config.eval_frequency = 100 # Optimizer hyperparameters config.learning_rate = 0.05 config.weight_decay = 1e-1 # Model hyperparameters config.max_length = 256 # Random seed config.random_seed = 0 # Data paths config.train = '' config.dev = '' return config ================================================ FILE: examples/nlp_seq/input_pipeline.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Input pipeline for the sequence tagging dataset.""" import codecs import collections import enum import tensorflow as tf # pytype: disable=import-error # Values for padding, unknown words and a root. PAD = '

' PAD_ID = 0 UNKNOWN = '' UNKNOWN_ID = 1 ROOT = '' ROOT_ID = 2 class CoNLLAttributes(enum.Enum): """CoNLL attributre names and indices. A UD CoNLL file looks like: 1 They they PRON PRP Case=Nom|Number=Plur 2 nsubj 2 buy buy VERB VBP Number=Plur|PTense=Pres 0 root 3 books book NOUN NNS Number=Plur 2 obj 4 . . PUNCT . _ 2 punct For details, please see: http://universaldependencies.org/format.html. """ ID = 0 FORM = 1 LEMMA = 2 UPOS = 3 XPOS = 4 FEATS = 5 HEAD = 6 DEPREL = 7 def create_vocabs(filename, max_num_forms=100000): """Loads corpus and create vocabulary lists. Args: filename: file name of a corpus. max_num_forms: maximum number of tokens included. Returns: Dictionary containing named vocab dictionaries. """ form_counter = collections.Counter() xpos_counter = collections.Counter() with tf.io.gfile.GFile(filename, 'rb') as f: for line in codecs.getreader('utf-8')(f): line = line.strip() split = line.split('\t') if not line.startswith('#') and split[0]: form_counter[split[CoNLLAttributes.FORM.value]] += 1 xpos_counter[split[CoNLLAttributes.XPOS.value]] += 1 special_tokens = {PAD: PAD_ID, UNKNOWN: UNKNOWN_ID, ROOT: ROOT_ID} # create word form vocab vocabs = {'forms': {}, 'xpos': {}} vocabs['forms'].update(special_tokens) vocabs['forms'].update( { form[0]: id for id, form in enumerate( form_counter.most_common(max_num_forms), start=ROOT_ID + 1 ) } ) # create xpos vocab vocabs['xpos'].update(special_tokens) vocabs['xpos'].update( { tag[0]: id for id, tag in enumerate( xpos_counter.most_common(), start=ROOT_ID + 1 ) } ) return vocabs def create_token(token, attributes, vocabs): """Map for a token a selected subset of attributes to indices. Input example: CoNLL 09 representation for a token. ['Ms.', 'ms.', 'ms.', 'NNP', '_', '2', 'TITLE] Output example: Indices as defined in self._attributes, e.g., [word form, part-of-speech tag, and head]. [1025, 3, 1] Args: token: CoNLL token attributes. attributes: selected attributes. vocabs: dictionary of vocabs. Returns: List of attribute ids for a token, e.g. [1025, 3] with word id and pos id. Raises: ValueError: CoNLL attribute requested but not covered by mapping. """ selected_attributes = [] for attribute in attributes: index = attribute.value if attribute == CoNLLAttributes.FORM: selected_attributes.append(vocabs['forms'].get(token[index], UNKNOWN_ID)) elif attribute == CoNLLAttributes.XPOS: selected_attributes.append(vocabs['xpos'].get(token[index], UNKNOWN_ID)) elif attribute == CoNLLAttributes.HEAD: selected_attributes.append(int(token[index])) else: raise ValueError( 'CoNLL index %s not covered by mapping.' % str(attribute.name) ) return selected_attributes def create_sentence_with_root(attributes, vocabs): """Create a sentence containing a root. Args: attributes: attributes extracted from token. vocabs: dictionary of vocabs. Returns: A list representing a sentence containing the root only, e.g., [[2, 1, 0]] for root word, unknown xpos, and head 0. """ # Create the token properties of an artificial root node. token_properties = [ROOT for _ in range(12)] # CoNLL 09 has 12 columns. token_properties[CoNLLAttributes.ID.value] = '0' token_properties[CoNLLAttributes.HEAD.value] = '0' token = create_token(token_properties, attributes, vocabs) if len(token) == 1: token = token[0] return [token] def sentences_from_conll_data( corpus_filename, vocabs, attributes, max_sentence_length=1000 ): """Load and returns conll data in list format. Args: corpus_filename: filename of corpus. vocabs: dictionary of vocabs attributes: list of conll attributes to include into the batch max_sentence_length: cut off sentences longer as max tokens Yields: A sentence as a list of tokens while tokens are lists of attributes. """ with tf.io.gfile.GFile(corpus_filename, 'rb') as f: sentence = create_sentence_with_root(attributes, vocabs) for line in codecs.getreader('utf-8')(f): line = line.strip() if line.startswith('#'): continue split = line.split('\t') if split[0]: # Not an empty line, process next token: if len(sentence) < max_sentence_length: if len(attributes) == 1: sentence.append(create_token(split, attributes, vocabs)[0]) else: sentence.append(create_token(split, attributes, vocabs)) else: # Sentences start with an empty line, yield sentence: yield sentence # Reset sentence. sentence = create_sentence_with_root(attributes, vocabs) if len(sentence) > 1: # sentences does not only contain a root. yield sentence def sentence_dataset_dict( filename, vocabs, attributes_input, attributes_target, batch_size, bucket_size, repeat=None, prefetch_size=tf.data.experimental.AUTOTUNE, ): """Combines sentences into a dataset of padded batches. Args: filename: file name of a corpus. vocabs: dictionary of dictionaries to map from strings to ids. attributes_input: attributes for the input. attributes_target: target attributes empty targets is not included. batch_size: the size of a batch. bucket_size: the size of a bucket. repeat: number of times the dataset is repeated. prefetch_size: prefetch size of the data. Returns: Returns dataset as dictionary containing the data as key value pairs. """ data_keys = ['inputs'] if attributes_target: data_keys.append('targets') def generator(): """Generator to create the data.""" input_generator = sentences_from_conll_data( filename, vocabs, attributes_input, max_sentence_length=bucket_size ) if attributes_target: target_generator = sentences_from_conll_data( filename, vocabs, attributes_target, max_sentence_length=bucket_size ) for inputs in input_generator: data = {'inputs': inputs} if attributes_target: data['targets'] = next(target_generator) yield data output_types = {k: tf.float32 for k in data_keys} output_shapes = {k: (None,) for k in data_keys} dataset = tf.data.Dataset.from_generator( generator, output_types=output_types, output_shapes=output_shapes ) # cache the dataset in memory and repeat. dataset = dataset.cache() dataset = dataset.repeat(repeat) # static padding up to bucket size. padded_shapes = {k: [bucket_size] for k in data_keys} dataset = dataset.padded_batch( batch_size=batch_size, padded_shapes=(padded_shapes) ) dataset = dataset.prefetch(prefetch_size) return dataset ================================================ FILE: examples/nlp_seq/input_pipeline_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.examples.nlp.input_pipeline.""" import os from absl.testing import absltest import jax import tensorflow as tf import input_pipeline # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() CONLL_DATA = """1\tThey\tthey\tPRON\tPRP\tCase=Nom|Number=Plur\t2\tnsubj 2\tbuy\tbuy\t VERB\tVBP\tNumber=Plur|PTense=Pres\t0\troot 3\tbooks\tbook\tNOUN\tNNS\tNumber=Plur\t2\tobj 4\t.\t.\tPUNCT\t.\t_\t2\tpunct 1\tThey\tthey\tPRON\tPRP\tCase=Nom|Number=Plur\t2\tnsubj 2\tbuy\tbuy\t VERB\tVBP\tNumber=Plur|PTense=Pres\t0\troot 3\tbooks\tbook\tNOUN\tNNS\tNumber=Plur\t2\tobj 4\t.\t.\tPUNCT\t.\t_\t2\tpunct 1\tNY\tNY\tNOUN\tNNS\tNumber=Singular\t0\troot """ class InputPipelineTest(absltest.TestCase): def setUp(self): super().setUp() self.test_tmpdir = self.create_tempdir() # Write a sample corpus. self._filename = os.path.join(self.test_tmpdir.full_path, 'data.conll') with tf.io.gfile.GFile(self._filename, 'w') as f: # The CoNLL data has to end with an empty line. f.write(CONLL_DATA) f.write('\n') def test_vocab_creation(self): """Tests the creation of the vocab.""" vocabs = input_pipeline.create_vocabs(self._filename) self.assertEqual( vocabs['forms'], { '

': 0, '': 1, '': 2, 'They': 3, 'buy': 4, 'books': 5, '.': 6, 'NY': 7, }, ) def testInputBatch(self): """Test the batching of the dataset.""" vocabs = input_pipeline.create_vocabs(self._filename) attributes_input = [input_pipeline.CoNLLAttributes.FORM] attributes_target = [] # empty target for tagging of unlabeled data. sentence_dataset = input_pipeline.sentence_dataset_dict( self._filename, vocabs, attributes_input, attributes_target, batch_size=2, bucket_size=10, repeat=1, ) sentence_dataset_iter = iter(sentence_dataset) batch = next(sentence_dataset_iter) inputs = batch['inputs'].numpy().tolist() self.assertSameStructure( inputs, [ [2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0], [2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], ) self.assertLen(batch, 1) # make sure target is not included. def testInputTargetBatch(self): """Test the batching of the dataset.""" vocabs = input_pipeline.create_vocabs(self._filename) attributes_input = [input_pipeline.CoNLLAttributes.FORM] attributes_target = [input_pipeline.CoNLLAttributes.XPOS] sentence_dataset = input_pipeline.sentence_dataset_dict( self._filename, vocabs, attributes_input, attributes_target, batch_size=2, bucket_size=10, repeat=1, ) sentence_dataset_iter = iter(sentence_dataset) batch = next(sentence_dataset_iter) inputs = batch['inputs'].numpy().tolist() self.assertSameStructure( inputs, [ [2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0], [2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], ) targets = batch['targets'].numpy().tolist() self.assertSameStructure( targets, [ [2.0, 4.0, 5.0, 3.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0], [2.0, 4.0, 5.0, 3.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0], ], ) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/nlp_seq/main.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Main file for running the NLP sequence tagging example. This file is intentionally kept short to allow config-based execution. """ from absl import app from absl import flags import train from ml_collections import config_flags FLAGS = flags.FLAGS config_flags.DEFINE_config_file( 'config', 'configs/default.py', 'File path to the training hyperparameter configuration.', lock_config=True, ) def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') # Convert config to FLAGS for train.py compatibility config = FLAGS.config # Override FLAGS with config values FLAGS.model_dir = config.model_dir FLAGS.experiment = config.experiment FLAGS.batch_size = config.batch_size FLAGS.eval_frequency = config.eval_frequency FLAGS.num_train_steps = config.num_train_steps FLAGS.learning_rate = config.learning_rate FLAGS.weight_decay = config.weight_decay FLAGS.max_length = config.max_length FLAGS.random_seed = config.random_seed FLAGS.train = config.train FLAGS.dev = config.dev # Run the training train.main(argv) if __name__ == '__main__': app.run(main) ================================================ FILE: examples/nlp_seq/models.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Transformer-based language models.""" from typing import Any, Optional from collections.abc import Callable from flax import linen as nn from flax import struct import jax.numpy as jnp import numpy as np @struct.dataclass class TransformerConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" vocab_size: int output_vocab_size: int dtype: Any = jnp.float32 emb_dim: int = 512 num_heads: int = 8 num_layers: int = 6 qkv_dim: int = 512 mlp_dim: int = 2048 max_len: int = 2048 dropout_rate: float = 0.3 attention_dropout_rate: float = 0.3 kernel_init: Callable = nn.initializers.xavier_uniform() bias_init: Callable = nn.initializers.normal(stddev=1e-6) posemb_init: Callable | None = None def sinusoidal_init(max_len=2048): """1D Sinusoidal Position Embedding Initializer. Args: max_len: maximum possible length for the input Returns: output: init function returning `(1, max_len, d_feature)` """ def init(key, shape, dtype=np.float32): """Sinusoidal init.""" del key, dtype d_feature = shape[-1] pe = np.zeros((max_len, d_feature), dtype=np.float32) position = np.arange(0, max_len)[:, np.newaxis] div_term = np.exp( np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature) ) pe[:, 0::2] = np.sin(position * div_term) pe[:, 1::2] = np.cos(position * div_term) pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] return jnp.array(pe) return init class AddPositionEmbs(nn.Module): """Adds (optionally learned) positional embeddings to the inputs. Attributes: config: TransformerConfig dataclass containing hyperparameters. """ config: TransformerConfig @nn.compact def __call__(self, inputs): """Applies AddPositionEmbs module. By default this layer uses a fixed sinusoidal embedding table. If a learned position embedding is desired, pass an initializer to posemb_init in the configuration. Args: inputs: input data. Returns: output: `(bs, timesteps, in_dim)` """ config = self.config # inputs.shape is (batch_size, seq_len, emb_dim) assert inputs.ndim == 3, ( 'Number of dimensions should be 3, but it is: %d' % inputs.ndim ) length = inputs.shape[1] pos_emb_shape = (1, config.max_len, inputs.shape[-1]) if config.posemb_init is None: # Use a fixed (non-learned) sinusoidal position embedding. pos_embedding = sinusoidal_init(max_len=config.max_len)( None, pos_emb_shape, None ) else: pos_embedding = self.param( 'pos_embedding', config.posemb_init, pos_emb_shape ) pe = pos_embedding[:, :length, :] return inputs + pe class MlpBlock(nn.Module): """Transformer MLP / feed-forward block. Attributes: config: TransformerConfig dataclass containing hyperparameters. out_dim: optionally specify out dimension. """ config: TransformerConfig out_dim: int | None = None @nn.compact def __call__(self, inputs, deterministic=True): """Applies Transformer MlpBlock module.""" config = self.config actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim x = nn.Dense( config.mlp_dim, dtype=config.dtype, kernel_init=config.kernel_init, bias_init=config.bias_init, )(inputs) x = nn.elu(x) x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=deterministic) output = nn.Dense( actual_out_dim, dtype=config.dtype, kernel_init=config.kernel_init, bias_init=config.bias_init, )(x) output = nn.Dropout(rate=config.dropout_rate)( output, deterministic=deterministic ) return output class Encoder1DBlock(nn.Module): """Transformer encoder layer. Attributes: config: TransformerConfig dataclass containing hyperparameters. """ config: TransformerConfig @nn.compact def __call__(self, inputs, deterministic): """Applies Encoder1DBlock module. Args: inputs: input data. deterministic: if true dropout is applied otherwise not. Returns: output after transformer encoder block. """ config = self.config # Attention block. assert inputs.ndim == 3 x = nn.LayerNorm(dtype=config.dtype)(inputs) x = nn.MultiHeadDotProductAttention( num_heads=config.num_heads, dtype=config.dtype, qkv_features=config.qkv_dim, kernel_init=config.kernel_init, bias_init=config.bias_init, use_bias=False, broadcast_dropout=False, dropout_rate=config.attention_dropout_rate, deterministic=deterministic, )(x) x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=deterministic) x = x + inputs # MLP block. y = nn.LayerNorm(dtype=config.dtype)(x) y = MlpBlock(config=config)(y, deterministic=deterministic) return x + y class Transformer(nn.Module): """Transformer Model for sequence tagging.""" config: TransformerConfig @nn.compact def __call__(self, *, inputs, train): """Applies Transformer model on the inputs. Args: inputs: input data train: if it is training. Returns: output of a transformer encoder. """ assert inputs.ndim == 2 # (batch, len) config = self.config x = inputs.astype('int32') x = nn.Embed( num_embeddings=config.vocab_size, features=config.emb_dim, name='embed' )(x) x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=not train) x = AddPositionEmbs(config)(x) for _ in range(config.num_layers): x = Encoder1DBlock(config)(x, deterministic=not train) x = nn.LayerNorm(dtype=config.dtype)(x) logits = nn.Dense( config.output_vocab_size, kernel_init=config.kernel_init, bias_init=config.bias_init, )(x) return logits ================================================ FILE: examples/nlp_seq/requirements.txt ================================================ absl-py==1.0.0 flax==0.3.6 numpy==1.22.0 tensorflow==2.11.1 ================================================ FILE: examples/nlp_seq/train.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Sequence Tagging example. This script trains a Transformer on the Universal dependency dataset. """ import functools import os import time from absl import app from absl import flags from absl import logging from flax import jax_utils from flax import linen as nn from flax.metrics import tensorboard from flax.training import common_utils from flax.training import train_state import jax import jax.numpy as jnp from jax import random import numpy as np import optax import tensorflow as tf import input_pipeline import models FLAGS = flags.FLAGS flags.DEFINE_string('model_dir', default='', help='Directory for model data.') flags.DEFINE_string('experiment', default='xpos', help='Experiment name.') flags.DEFINE_integer('batch_size', default=64, help='Batch size for training.') flags.DEFINE_integer( 'eval_frequency', default=100, help='Frequency of eval during training, e.g. every 1000 steps.', ) flags.DEFINE_integer( 'num_train_steps', default=75000, help='Number of train steps.' ) flags.DEFINE_float('learning_rate', default=0.05, help='Learning rate.') flags.DEFINE_float( 'weight_decay', default=1e-1, help='Decay factor for AdamW style weight decay.', ) flags.DEFINE_integer( 'max_length', default=256, help='Maximum length of examples.' ) flags.DEFINE_integer( 'random_seed', default=0, help='Integer for PRNG random seed.' ) flags.DEFINE_string('train', default='', help='Path to training data.') flags.DEFINE_string('dev', default='', help='Path to development data.') def create_learning_rate_scheduler( factors='constant * linear_warmup * rsqrt_decay', base_learning_rate=0.5, warmup_steps=8000, decay_factor=0.5, steps_per_decay=20000, steps_per_cycle=100000, ): """creates learning rate schedule. Interprets factors in the factors string which can consist of: * constant: interpreted as the constant value, * linear_warmup: interpreted as linear warmup until warmup_steps, * rsqrt_decay: divide by square root of max(step, warmup_steps) * decay_every: Every k steps decay the learning rate by decay_factor. * cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter. Args: factors: a string with factors separated by '*' that defines the schedule. base_learning_rate: float, the starting constant for the lr schedule. warmup_steps: how many steps to warm up for in the warmup schedule. decay_factor: The amount to decay the learning rate by. steps_per_decay: How often to decay the learning rate. steps_per_cycle: Steps per cycle when using cosine decay. Returns: a function learning_rate(step): float -> {'learning_rate': float}, the step-dependent lr. """ factors = [n.strip() for n in factors.split('*')] def step_fn(step): """Step to learning rate function.""" ret = 1.0 for name in factors: if name == 'constant': ret *= base_learning_rate elif name == 'linear_warmup': ret *= jnp.minimum(1.0, step / warmup_steps) elif name == 'rsqrt_decay': ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) elif name == 'rsqrt_normalized_decay': ret *= jnp.sqrt(warmup_steps) ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) elif name == 'decay_every': ret *= decay_factor ** (step // steps_per_decay) elif name == 'cosine_decay': progress = jnp.maximum( 0.0, (step - warmup_steps) / float(steps_per_cycle) ) ret *= jnp.maximum( 0.0, 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))) ) else: raise ValueError('Unknown factor %s.' % name) return jnp.asarray(ret, dtype=jnp.float32) return step_fn def compute_weighted_cross_entropy(logits, targets, weights=None): """Compute weighted cross entropy and entropy for log probs and targets. Args: logits: [batch, length, num_classes] float array. targets: categorical targets [batch, length] int array. weights: None or array of shape [batch x length] Returns: Tuple of scalar loss and batch normalizing factor. """ if logits.ndim != targets.ndim + 1: raise ValueError( 'Incorrect shapes. Got shape %s logits and %s targets' % (str(logits.shape), str(targets.shape)) ) onehot_targets = common_utils.onehot(targets, logits.shape[-1]) loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1) normalizing_factor = onehot_targets.sum() if weights is not None: loss = loss * weights normalizing_factor = weights.sum() return loss.sum(), normalizing_factor def compute_weighted_accuracy(logits, targets, weights=None): """Compute weighted accuracy for log probs and targets. Args: logits: [batch, length, num_classes] float array. targets: categorical targets [batch, length] int array. weights: None or array of shape [batch x length] Returns: Tuple of scalar accuracy and batch normalizing factor. """ if logits.ndim != targets.ndim + 1: raise ValueError( 'Incorrect shapes. Got shape %s logits and %s targets' % (str(logits.shape), str(targets.shape)) ) loss = jnp.equal(jnp.argmax(logits, axis=-1), targets) normalizing_factor = np.prod(logits.shape[:-1]) if weights is not None: loss = loss * weights normalizing_factor = weights.sum() return loss.sum(), normalizing_factor def compute_metrics(logits, labels, weights): """Compute summary metrics.""" loss, weight_sum = compute_weighted_cross_entropy(logits, labels, weights) acc, _ = compute_weighted_accuracy(logits, labels, weights) metrics = { 'loss': loss, 'accuracy': acc, 'denominator': weight_sum, } metrics = np.sum(metrics, -1) return metrics def train_step(state, batch, model, learning_rate_fn, dropout_rng=None): """Perform a single training step.""" train_keys = ['inputs', 'targets'] (inputs, targets) = (batch.get(k, None) for k in train_keys) weights = jnp.where(targets > 0, 1, 0).astype(jnp.float32) dropout_rng = jax.random.fold_in(dropout_rng, state.step) def loss_fn(params): """loss function used for training.""" logits = model.apply( {'params': params}, inputs=inputs, train=True, rngs={'dropout': dropout_rng}, ) loss, weight_sum = compute_weighted_cross_entropy(logits, targets, weights) mean_loss = loss / weight_sum return mean_loss, logits lr = learning_rate_fn(state.step) grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, logits), grads = grad_fn(state.params) grads = jax.lax.pmean(grads, 'batch') new_state = state.apply_gradients(grads=grads) metrics = compute_metrics(logits, targets, weights) metrics['learning_rate'] = lr return new_state, metrics def pad_examples(x, desired_batch_size): """Expand batch to desired size by zeros with the shape of last slice.""" batch_pad = desired_batch_size - x.shape[0] # Padding with zeros to avoid that they get counted in compute_metrics. return np.concatenate([x, np.tile(np.zeros_like(x[-1]), (batch_pad, 1))]) def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') # Make sure tf does not allocate gpu memory. tf.config.experimental.set_visible_devices([], 'GPU') batch_size = FLAGS.batch_size learning_rate = FLAGS.learning_rate num_train_steps = FLAGS.num_train_steps eval_freq = FLAGS.eval_frequency random_seed = FLAGS.random_seed if not FLAGS.dev: raise app.UsageError('Please provide path to dev set.') if not FLAGS.train: raise app.UsageError('Please provide path to training set.') if batch_size % jax.device_count() > 0: raise ValueError('Batch size must be divisible by the number of devices') device_batch_size = batch_size // jax.device_count() if jax.process_index() == 0: train_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, FLAGS.experiment + '_train') ) eval_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, FLAGS.experiment + '_eval') ) # create the training and development dataset vocabs = input_pipeline.create_vocabs(FLAGS.train) config = models.TransformerConfig( vocab_size=len(vocabs['forms']), output_vocab_size=len(vocabs['xpos']), max_len=FLAGS.max_length, ) attributes_input = [input_pipeline.CoNLLAttributes.FORM] attributes_target = [input_pipeline.CoNLLAttributes.XPOS] train_ds = input_pipeline.sentence_dataset_dict( FLAGS.train, vocabs, attributes_input, attributes_target, batch_size=batch_size, bucket_size=config.max_len, ) train_iter = iter(train_ds) eval_ds = input_pipeline.sentence_dataset_dict( FLAGS.dev, vocabs, attributes_input, attributes_target, batch_size=batch_size, bucket_size=config.max_len, repeat=1, ) model = models.Transformer(config) rng = random.key(random_seed) rng, init_rng = random.split(rng) # call a jitted initialization function to get the initial parameter tree @jax.jit def initialize_variables(init_rng): init_batch = jnp.ones((config.max_len, 1), jnp.float32) init_variables = model.init(init_rng, inputs=init_batch, train=False) return init_variables init_variables = initialize_variables(init_rng) learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=learning_rate ) optimizer = optax.adamw( learning_rate_fn, b1=0.9, b2=0.98, eps=1e-9, weight_decay=1e-1 ) state = train_state.TrainState.create( apply_fn=model.apply, params=init_variables['params'], tx=optimizer ) # Replicate optimizer. state = jax_utils.replicate(state) p_train_step = jax.pmap( functools.partial( train_step, model=model, learning_rate_fn=learning_rate_fn ), axis_name='batch', donate_argnums=(0,), ) # pytype: disable=wrong-arg-types def eval_step(params, batch): """Calculate evaluation metrics on a batch.""" inputs, targets = batch['inputs'], batch['targets'] weights = jnp.where(targets > 0, 1.0, 0.0) logits = model.apply({'params': params}, inputs=inputs, train=False) return compute_metrics(logits, targets, weights) p_eval_step = jax.pmap(eval_step, axis_name='batch') # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. dropout_rngs = random.split(rng, jax.local_device_count()) metrics_all = [] tick = time.time() best_dev_score = 0 for step, batch in zip(range(num_train_steps), train_iter): batch = common_utils.shard( jax.tree_util.tree_map(lambda x: x._numpy(), batch) ) # pylint: disable=protected-access state, metrics = p_train_step(state, batch, dropout_rng=dropout_rngs) metrics_all.append(metrics) if (step + 1) % eval_freq == 0: metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop('learning_rate').mean() metrics_sums = jax.tree_util.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop('denominator') summary = jax.tree_util.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary['learning_rate'] = lr logging.info('train in step: %d, loss: %.4f', step, summary['loss']) if jax.process_index() == 0: tock = time.time() steps_per_sec = eval_freq / (tock - tick) tick = tock train_summary_writer.scalar('steps per second', steps_per_sec, step) for key, val in summary.items(): train_summary_writer.scalar(key, val, step) train_summary_writer.flush() metrics_all = [] # reset metric accumulation for next evaluation cycle. eval_metrics = [] eval_iter = iter(eval_ds) for eval_batch in eval_iter: eval_batch = jax.tree_util.tree_map(lambda x: x._numpy(), eval_batch) # pylint: disable=protected-access # Handle final odd-sized batch by padding instead of dropping it. cur_pred_batch_size = eval_batch['inputs'].shape[0] if cur_pred_batch_size != batch_size: # pad up to batch size eval_batch = jax.tree_util.tree_map( lambda x: pad_examples(x, batch_size), eval_batch ) eval_batch = common_utils.shard(eval_batch) metrics = p_eval_step(state.params, eval_batch) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) eval_metrics_sums = jax.tree_util.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree_util.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums, ) logging.info( 'eval in step: %d, loss: %.4f, accuracy: %.4f', step, eval_summary['loss'], eval_summary['accuracy'], ) if best_dev_score < eval_summary['accuracy']: best_dev_score = eval_summary['accuracy'] # TODO: save model. eval_summary['best_dev_score'] = best_dev_score logging.info('best development model score %.4f', best_dev_score) if jax.process_index() == 0: for key, val in eval_summary.items(): eval_summary_writer.scalar(key, val, step) eval_summary_writer.flush() if __name__ == '__main__': app.run(main) ================================================ FILE: examples/nnx_toy_examples/01_functional_api.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # %% import jax import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np from flax import nnx X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None] Y = 0.8 * jnp.sin(X) + 0.1 + np.random.normal(0, 0.1, size=X.shape) def dataset(batch_size): while True: idx = np.random.choice(len(X), size=batch_size) yield X[idx], Y[idx] class Linear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): return x @ self.w.value + self.b.value class Count(nnx.Variable[nnx.A]): pass class MLP(nnx.Module): def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): self.count = Count(jnp.array(0)) self.linear1 = Linear(din, dhidden, rngs=rngs) self.linear2 = Linear(dhidden, dout, rngs=rngs) def __call__(self, x): self.count[...] += 1 return self.linear2(jax.nn.relu(self.linear1(x) * 0.5)) graphdef, params, counts = nnx.split( MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0)), nnx.Param, Count ) @jax.jit def train_step(params, counts, batch): x, y = batch def loss_fn(params): model = nnx.merge(graphdef, params, counts) y_pred = model(x) new_counts = nnx.state(model, Count) loss = jnp.mean((y - y_pred) ** 2) return loss, new_counts grad, counts = jax.grad(loss_fn, has_aux=True)(params) # |-------- sgd ---------| params = jax.tree.map(lambda w, g: w - 0.1 * g, params, grad) return params, counts @jax.jit def test_step(params: nnx.State, counts: nnx.State, batch): x, y = batch model = nnx.merge(graphdef, params, counts) y_pred = model(x) loss = jnp.mean((y - y_pred) ** 2) return {'loss': loss} total_steps = 10_000 for step, batch in enumerate(dataset(32)): params, counts = train_step(params, counts, batch) if step % 1000 == 0: logs = test_step(params, counts, (X, Y)) print(f"step: {step}, loss: {logs['loss']}") if step >= total_steps - 1: break model = nnx.merge(graphdef, params, counts) print('times called:', model.count.value) y_pred = model(X) plt.scatter(X, Y, color='blue') plt.plot(X, y_pred, color='black') plt.show() ================================================ FILE: examples/nnx_toy_examples/02_lifted_transforms.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # %% import jax import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np import optax from flax import nnx X = np.linspace(0, 1, 100)[:, None] Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) def dataset(batch_size): while True: idx = np.random.choice(len(X), size=batch_size) yield X[idx], Y[idx] class Linear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): return x @ self.w.value + self.b.value class Count(nnx.Variable): pass class MLP(nnx.Module): def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): self.count = Count(jnp.array(0)) self.linear1 = Linear(din, dhidden, rngs=rngs) self.linear2 = Linear(dhidden, dout, rngs=rngs) def __call__(self, x): self.count.value += 1 x = self.linear1(x) x = jax.nn.relu(x) x = self.linear2(x) return x model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0)) tx = optax.sgd(1e-3) optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) @nnx.jit def train_step(model: MLP, optimizer: nnx.Optimizer, batch): x, y = batch def loss_fn(model: MLP): y_pred = model(x) return jnp.mean((y - y_pred) ** 2) grads: nnx.State = nnx.grad(loss_fn)(model) optimizer.update(model, grads) @nnx.jit def test_step(model: MLP, batch): x, y = batch y_pred = model(x) loss = jnp.mean((y - y_pred) ** 2) return {'loss': loss} cached_train_step = nnx.cached_partial(train_step, model, optimizer) cached_test_step = nnx.cached_partial(test_step, model) total_steps = 10_000 for step, batch in enumerate(dataset(32)): cached_train_step(batch) if step % 1000 == 0: logs = cached_test_step((X, Y)) print(f"step: {step}, loss: {logs['loss']}") if step >= total_steps - 1: break print('times called:', model.count.value) y_pred = model(X) plt.scatter(X, Y, color='blue') plt.plot(X, y_pred, color='black') plt.show() ================================================ FILE: examples/nnx_toy_examples/03_train_state.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # %% import jax import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np import optax from flax import nnx from flax.training import train_state X = np.linspace(0, 1, 100)[:, None] Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) def dataset(batch_size): while True: idx = np.random.choice(len(X), size=batch_size) yield X[idx], Y[idx] class Linear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): return x @ self.w.value + self.b.value class Count(nnx.Variable[nnx.A]): pass class MLP(nnx.Module): def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): self.count = Count(jnp.array(0)) self.linear1 = Linear(din, dhidden, rngs=rngs) self.linear2 = Linear(dhidden, dout, rngs=rngs) def __call__(self, x): self.count.value += 1 x = self.linear1(x) x = jax.nn.relu(x) x = self.linear2(x) return x class TrainState(train_state.TrainState): counts: nnx.State graphdef: nnx.GraphDef model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0)) graphdef, params, counts = nnx.split(model, nnx.Param, Count) state = TrainState.create( apply_fn=None, graphdef=graphdef, params=params, tx=optax.sgd(0.1), counts=counts, ) del params, counts @jax.jit def train_step(state: TrainState, batch): x, y = batch def loss_fn(params): model = nnx.merge(state.graphdef, params, state.counts, copy=True) y_pred = model(x) loss = jnp.mean((y - y_pred) ** 2) counts = nnx.state(model, Count) return loss, counts grads, counts = jax.grad(loss_fn, has_aux=True)(state.params) # sdg update state = state.apply_gradients(grads=grads, counts=counts) return state @jax.jit def test_step(state: nnx.TrainState[MLP], batch): x, y = batch model = nnx.merge(state.graphdef, state.params, state.counts) y_pred = model(x) loss = jnp.mean((y - y_pred) ** 2) return {'loss': loss} total_steps = 10_000 for step, batch in enumerate(dataset(32)): state = train_step(state, batch) if step % 1000 == 0: logs = test_step(state, (X, Y)) print(f"step: {step}, loss: {logs['loss']}") if step >= total_steps - 1: break model = nnx.merge(state.graphdef, state.params, state.counts) print('times called:', model.count.value) y_pred = model(X) plt.scatter(X, Y, color='blue') plt.plot(X, y_pred, color='black') plt.show() ================================================ FILE: examples/nnx_toy_examples/04_data_parallel_with_jit.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8' import jax import jax.numpy as jnp import numpy as np import optax from flax import nnx from jax.experimental import mesh_utils import matplotlib.pyplot as plt # create a mesh + shardings num_devices = jax.local_device_count() mesh = jax.sharding.Mesh( mesh_utils.create_device_mesh((num_devices,)), ('data',) ) model_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec()) data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('data')) # create model class MLP(nnx.Module): def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs): self.linear1 = nnx.Linear(din, dmid, rngs=rngs) self.linear2 = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x): return self.linear2(nnx.relu(self.linear1(x))) model = MLP(1, 64, 1, rngs=nnx.Rngs(0)) optimizer = nnx.Optimizer(model, optax.adamw(1e-2), wrt=nnx.Param) # replicate state state = nnx.state((model, optimizer)) state = jax.device_put(state, model_sharding) nnx.update((model, optimizer), state) # visualize model sharding print('model sharding') jax.debug.visualize_array_sharding(model.linear1.kernel.value) @nnx.jit def train_step(model: MLP, optimizer: nnx.Optimizer, x, y): def loss_fn(model: MLP): y_pred = model(x) return jnp.mean((y - y_pred) ** 2) loss, grads = nnx.value_and_grad(loss_fn)(model) optimizer.update(model, grads) return loss def dataset(steps, batch_size): for _ in range(steps): x = np.random.uniform(-2, 2, size=(batch_size, 1)) y = 0.8 * x**2 + 0.1 + np.random.normal(0, 0.1, size=x.shape) yield x, y for step, (x, y) in enumerate(dataset(1000, 16)): # shard data x, y = jax.device_put((x, y), data_sharding) # train loss = train_step(model, optimizer, x, y) if step == 0: print('data sharding') jax.debug.visualize_array_sharding(x) if step % 100 == 0: print(f'step={step}, loss={loss}') # dereplicate state state = nnx.state((model, optimizer)) state = jax.device_get(state) nnx.update((model, optimizer), state) X, Y = next(dataset(1, 1000)) x_range = np.linspace(X.min(), X.max(), 100)[:, None] y_pred = model(x_range) # plot plt.scatter(X, Y, label='data') plt.plot(x_range, y_pred, color='black', label='model') plt.legend() plt.show() ================================================ FILE: examples/nnx_toy_examples/05_vae.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # %% import typing as tp import jax import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np import optax from datasets import load_dataset from flax import nnx np.random.seed(42) latent_size = 32 image_shape: tp.Sequence[int] = (28, 28) steps_per_epoch: int = 200 batch_size: int = 64 epochs: int = 20 dataset = load_dataset('mnist') X_train = np.array(np.stack(dataset['train']['image']), dtype=np.uint8) X_test = np.array(np.stack(dataset['test']['image']), dtype=np.uint8) # Now binarize data X_train = (X_train > 0).astype(jnp.float32) X_test = (X_test > 0).astype(jnp.float32) print('X_train:', X_train.shape, X_train.dtype) print('X_test:', X_test.shape, X_test.dtype) class Loss(nnx.Variable): pass # %% class Encoder(nnx.Module): def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): self.linear1 = nnx.Linear(din, dmid, rngs=rngs) self.linear_mean = nnx.Linear(dmid, dout, rngs=rngs) self.linear_std = nnx.Linear(dmid, dout, rngs=rngs) self.rngs = rngs def __call__(self, x: jax.Array) -> jax.Array: x = x.reshape((x.shape[0], -1)) # flatten x = self.linear1(x) x = jax.nn.relu(x) mean = self.linear_mean(x) std = jnp.exp(self.linear_std(x)) self.kl_loss = Loss( jnp.mean( 0.5 * jnp.mean(-jnp.log(std**2) - 1.0 + std**2 + mean**2, axis=-1) ) ) key = self.rngs.noise() z = mean + std * jax.random.normal(key, mean.shape) return z class Decoder(nnx.Module): def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): self.linear1 = nnx.Linear(din, dmid, rngs=rngs) self.linear2 = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, z: jax.Array) -> jax.Array: z = self.linear1(z) z = jax.nn.relu(z) logits = self.linear2(z) return logits class VAE(nnx.Module): def __init__( self, din: int, hidden_size: int, latent_size: int, output_shape: tp.Sequence[int], *, rngs: nnx.Rngs, ): self.output_shape = output_shape self.encoder = Encoder(din, hidden_size, latent_size, rngs=rngs) self.decoder = Decoder( latent_size, hidden_size, int(np.prod(output_shape)), rngs=rngs ) def __call__(self, x: jax.Array) -> jax.Array: z = self.encoder(x) logits = self.decoder(z) logits = jnp.reshape(logits, (-1, *self.output_shape)) return logits def generate(self, z): logits = self.decoder(z) logits = jnp.reshape(logits, (-1, *self.output_shape)) return nnx.sigmoid(logits) model = VAE( din=int(np.prod(image_shape)), hidden_size=256, latent_size=latent_size, output_shape=image_shape, rngs=nnx.Rngs(0, noise=1), ) optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) # %% @nnx.jit def train_step(model: VAE, optimizer: nnx.Optimizer, x: jax.Array): def loss_fn(model: VAE): logits = model(x) losses = nnx.pop(model, Loss) kl_loss = sum(jax.tree_util.tree_leaves(losses), 0.0) reconstruction_loss = jnp.mean( optax.sigmoid_binary_cross_entropy(logits, x) ) loss = reconstruction_loss + 0.1 * kl_loss return loss loss, grads = nnx.value_and_grad(loss_fn)(model) optimizer.update(model, grads) return loss @nnx.jit def forward(model: VAE, x: jax.Array) -> jax.Array: y_pred = model(x) return jax.nn.sigmoid(y_pred) @nnx.jit def sample(model: VAE, z: jax.Array) -> jax.Array: return model.generate(z) # %% for epoch in range(epochs): losses = [] for step in range(steps_per_epoch): idxs = np.random.randint(0, len(X_train), size=(batch_size,)) x_batch = X_train[idxs] loss = train_step(model, optimizer, x_batch) losses.append(np.asarray(loss)) print(f'Epoch {epoch} loss: {np.mean(losses)}') # exit() # %% # get random samples idxs = np.random.randint(0, len(X_test), size=(5,)) x_sample = X_test[idxs] # get predictions y_pred = forward(model, x_sample) # plot reconstruction figure = plt.figure(figsize=(3 * 5, 3 * 2)) plt.title('Reconstruction Samples') for i in range(5): plt.subplot(2, 5, i + 1) plt.imshow(x_sample[i], cmap='gray') plt.subplot(2, 5, 5 + i + 1) plt.imshow(y_pred[i], cmap='gray') # # tbwriter.add_figure("VAE Example", figure, epochs) plt.show() # %% # plot generative samples z_samples = np.random.normal(scale=1.5, size=(12, latent_size)) samples = sample(model, z_samples) figure = plt.figure(figsize=(3 * 5, 3 * 2)) plt.title('Generative Samples') for i in range(5): plt.subplot(2, 5, 2 * i + 1) plt.imshow(samples[i], cmap='gray') plt.subplot(2, 5, 2 * i + 2) plt.imshow(samples[i + 1], cmap='gray') plt.show() # %% ================================================ FILE: examples/nnx_toy_examples/06_scan_over_layers.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax import jax.numpy as jnp from flax import nnx class Block(nnx.Module): def __init__(self, dim: int, *, rngs: nnx.Rngs): self.linear = nnx.Linear(dim, dim, rngs=rngs) self.bn = nnx.BatchNorm(dim, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) def __call__(self, x: jax.Array): return jax.nn.gelu(self.dropout(self.bn(self.linear(x)))) class ScanMLP(nnx.Module): """ An MLP that uses `vmap` during `__init__` to create a Block instance with an additional `layer` axis, and `scan` during `__call__` to apply the sequence of layers iteratively over the input / output `x`. """ def __init__(self, dim: int, *, n_layers: int, rngs: nnx.Rngs): self.n_layers = n_layers @nnx.split_rngs(splits=n_layers) @nnx.vmap(axis_size=n_layers) def create_block(rngs: nnx.Rngs): return Block(dim, rngs=rngs) self.layers = create_block(rngs) def __call__(self, x: jax.Array) -> jax.Array: @nnx.scan def scan_fn(x: jax.Array, block: Block): x = block(x) return x, None x, _ = scan_fn(x, self.layers) return x model = ScanMLP(10, n_layers=5, rngs=nnx.Rngs(0)) x = jnp.ones((3, 10)) y = model(x) print(jax.tree.map(jnp.shape, nnx.state(model))) print(y.shape) ================================================ FILE: examples/nnx_toy_examples/07_array_leaves.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # %% import jax import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np import optax from flax import nnx, struct X = np.linspace(0, 1, 100)[:, None] Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) def dataset(batch_size): while True: idx = np.random.choice(len(X), size=batch_size) yield X[idx], Y[idx] class Linear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.w = jax.random.normal(rngs.params(), (din, dout)) self.b = jnp.zeros((dout,)) def __call__(self, x): return x @ self.w + self.b class MLP(nnx.Module): def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): self.count = jnp.array(0) self.linear1 = Linear(din, dhidden, rngs=rngs) self.linear2 = Linear(dhidden, dout, rngs=rngs) def __call__(self, x): self.count += 1 return self.linear2(nnx.relu(self.linear1(x))) def is_param(path, value): key = path[-1] return key == 'w' or key == 'b' model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0)) tx = optax.sgd(1e-3) optimizer = nnx.Optimizer(model, tx, wrt=is_param) @nnx.jit def train_step(model: MLP, optimizer: nnx.Optimizer, batch): x, y = batch def loss_fn(model: MLP): y_pred = model(x) return jnp.mean((y - y_pred) ** 2) diff_state = nnx.DiffState(0, is_param) grads: nnx.State = nnx.grad(loss_fn, argnums=diff_state)(model) optimizer.update(model, grads) @nnx.jit def test_step(model: MLP, batch): x, y = batch y_pred = model(x) loss = jnp.mean((y - y_pred) ** 2) return {'loss': loss} total_steps = 10_000 for step, batch in enumerate(dataset(32)): train_step(model, optimizer, batch) if step % 1000 == 0: logs = test_step(model, (X, Y)) print(f"step: {step}, loss: {logs['loss']}") if step >= total_steps - 1: break print('times called:', model.count) y_pred = model(X) plt.scatter(X, Y, color='blue') plt.plot(X, y_pred, color='black') plt.show() ================================================ FILE: examples/nnx_toy_examples/08_save_load_checkpoints.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from tempfile import TemporaryDirectory import jax import jax.numpy as jnp import orbax.checkpoint as orbax from flax import nnx class MLP(nnx.Module): def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): self.dense1 = nnx.Linear(din, dmid, rngs=rngs) self.dense2 = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x: jax.Array) -> jax.Array: x = self.dense1(x) x = jax.nn.relu(x) x = self.dense2(x) return x def create_model(seed: int): return MLP(10, 20, 30, rngs=nnx.Rngs(seed)) def create_and_save(seed: int, path: str): model = create_model(seed) state = nnx.state(model) # Save the parameters checkpointer = orbax.PyTreeCheckpointer() checkpointer.save(f'{path}/state', state) def load_model(path: str) -> MLP: # create that model with abstract shapes model = nnx.eval_shape(lambda: create_model(0)) state = nnx.state(model) # Load the parameters checkpointer = orbax.PyTreeCheckpointer() state = checkpointer.restore(f'{path}/state', item=state) # update the model with the loaded state nnx.update(model, state) return model with TemporaryDirectory() as tmpdir: # create a checkpoint create_and_save(42, tmpdir) # load model from checkpoint model = load_model(tmpdir) # run the model y = model(jnp.ones((1, 10))) print(model) print(y) ================================================ FILE: examples/nnx_toy_examples/09_parameter_surgery.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax from flax import nnx # lets pretend this function loads a pretrained model from a checkpoint def load_pretrained(): return nnx.Linear(784, 128, rngs=nnx.Rngs(0)) # create a simple linear classifier using a pretrained backbone class Classifier(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.backbone = nnx.Linear(784, 128, rngs=nnx.Rngs(0)) self.head = nnx.Linear(128, 10, rngs=rngs) def __call__(self, x): x = self.backbone(x) x = nnx.relu(x) x = self.head(x) return x # create the classifier using the pretrained backbone, here we are technically # doing "parameter surgery", however, compared to Haiku/Flax where you must manually # construct the parameter structure, in NNX this is done automatically model = Classifier(rngs=nnx.Rngs(42)) model.backbone = load_pretrained() # create a filter to select all the parameters that are not part of the # backbone, i.e. the classifier parameters is_trainable = lambda path, node: ( 'backbone' in path and isinstance(node, nnx.Param) ) # split the parameters into trainable and non-trainable parameters graphdef, trainable_params, non_trainable = nnx.split(model, is_trainable, ...) print( 'trainable_params =', jax.tree.map(jax.numpy.shape, trainable_params), ) print('non_trainable = ', jax.tree.map(jax.numpy.shape, non_trainable)) ================================================ FILE: examples/nnx_toy_examples/10_fsdp_and_optimizer.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import dataclasses import os os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8' from matplotlib import pyplot as plt from jax.experimental import mesh_utils from jax.sharding import Mesh, PartitionSpec as P, NamedSharding import jax import jax.numpy as jnp import numpy as np from flax import nnx import typing as tp mesh = jax.sharding.Mesh( mesh_utils.create_device_mesh((2, 4)), ('data', 'model'), ) def named_sharding(*names: str | None) -> NamedSharding: return NamedSharding(mesh, P(*names)) @dataclasses.dataclass(unsafe_hash=True) class MeshRules: embed: str | None = None mlp: str | None = None data: str | None = None def __call__(self, *keys: str) -> tuple[str, ...]: return tuple(getattr(self, key) for key in keys) mesh_rules = MeshRules( embed=None, mlp='model', data='data', ) class MLP(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): self.w1 = nnx.Param( nnx.initializers.lecun_normal()(rngs.params(), (din, dmid)), out_sharding=mesh_rules('embed', 'mlp'), ) self.b1 = nnx.Param( jnp.zeros((dmid,)), out_sharding=mesh_rules('mlp'), ) self.w2 = nnx.Param( nnx.initializers.lecun_normal()(rngs.params(), (dmid, dout)), out_sharding=mesh_rules('embed', 'mlp'), ) def __call__(self, x: jax.Array): return nnx.relu(x @ self.w1 + self.b1) @ self.w2 class SGDState(nnx.Variable): pass class SGD(nnx.Pytree): def __init__(self, params: nnx.State, lr, decay=0.9): def init_optimizer_state(variable: nnx.Variable): return SGDState( jnp.zeros_like(variable.value), **variable.get_metadata() ) self.lr = lr self.params = nnx.data(params) self.momentum: nnx.State = nnx.data(jax.tree.map( init_optimizer_state, self.params, is_leaf=lambda x: isinstance(x, nnx.Variable), )) self.decay = decay def update(self, grads: nnx.State): def update_fn( params: nnx.Variable, momentum: SGDState, grad: nnx.Variable ): # v_t = β * v_{t-1} + (1 - β) * ∇J(θ_t) momentum[...] = self.decay * momentum[...] + (1 - self.decay) * grad[...] # θ_{t+1} = θ_t - α * v_t params[...] -= self.lr * momentum[...] jax.tree.map( update_fn, self.params, self.momentum, grads, is_leaf=lambda x: isinstance(x, nnx.Variable), ) @nnx.jit def create_model(): model = MLP(1, 32, 1, rngs=nnx.Rngs(0)) optimizer = SGD(nnx.variables(model, nnx.Param), 0.01, decay=0.9) return model, optimizer with jax.set_mesh(mesh): model, optimizer = create_model() print('Model parameters sharding:') jax.debug.visualize_array_sharding(model.w1.value) print('Optimizer momentum sharding:') jax.debug.visualize_array_sharding(optimizer.momentum['w1'].value) @nnx.jit def train_step(model: MLP, optimizer: SGD, x, y): def loss_fn(model): y_pred = model(x) loss = jnp.mean((y - y_pred) ** 2) return loss loss, grad = nnx.value_and_grad(loss_fn)(model) optimizer.update(grad) return loss X = np.linspace(-2, 2, 100)[:, None] Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) def dataset(batch_size, num_steps): for _ in range(num_steps): idx = np.random.choice(len(X), size=batch_size) yield X[idx], Y[idx] losses = [] for step, (x_batch, y_batch) in enumerate( dataset(batch_size=32, num_steps=10_000) ): x_batch, y_batch = jax.device_put((x_batch, y_batch), named_sharding('data')) loss = train_step(model, optimizer, x_batch, y_batch) losses.append(float(loss)) if step % 1000 == 0: print(f'Step {step}: Loss = {loss}') plt.figure() plt.plot(losses[20:]) y_pred = model(X) plt.figure() plt.scatter(X, Y, color='blue') plt.plot(X, y_pred, color='black') plt.show() ================================================ FILE: examples/nnx_toy_examples/hijax_basic.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # %% import jax import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np import optax from flax import nnx X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None] Y = 0.8 * jnp.sin(X) + 0.1 + np.random.normal(0, 0.1, size=X.shape) def dataset(batch_size): while True: idx = np.random.choice(len(X), size=batch_size) yield X[idx], Y[idx] class Linear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.w = nnx.Param(rngs.params.uniform((din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): return x @ self.w + self.b[None] class Count(nnx.Variable[nnx.A]): pass class MLP(nnx.Module): def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): self.count = Count(jnp.array(0)) self.linear1 = Linear(din, dhidden, rngs=rngs) self.linear2 = Linear(dhidden, dout, rngs=rngs) def __call__(self, x): self.count[...] += 1 return self.linear2(jax.nn.relu(self.linear1(x)) * 0.5) nnx.var_defaults(hijax=True) model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0)) optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.1), wrt=nnx.Param) @jax.jit def train_step(model, optimizer, x, y): graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) def loss_fn(params): model = nnx.merge(graphdef, params, nondiff) return jnp.mean((y - model(x)) ** 2) grads = jax.grad(loss_fn)(nnx.vars_as(params, is_mutable=False)) optimizer.update(model, grads) @jax.jit def test_step(model: MLP, x, y): return {'loss': jnp.mean((y - model(x)) ** 2)} total_steps = 10_000 for step, (x, y) in enumerate(dataset(32)): train_step(model, optimizer, x, y) if step % 1000 == 0: logs = test_step(model, X, Y) print(f"step: {step}, loss: {logs['loss']}") if step >= total_steps - 1: break print('times called:', model.count[...]) y_pred = model(X) plt.scatter(X, Y, color='blue') plt.plot(X, y_pred, color='black') plt.show() ================================================ FILE: examples/nnx_toy_examples/hijax_demo.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # %% import jax import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np from flax import nnx # ## Data # We create a simple dataset of points sampled from a parabola with some noise. X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None] Y = 0.8 * jnp.sin(X) + 0.1 + np.random.normal(0, 0.1, size=X.shape) def dataset(batch_size): while True: idx = np.random.choice(len(X), size=batch_size) yield X[idx], Y[idx] # ## Model # Here we define a MLP made of a stack of blocks. Each block contains a linear layer, # batch normalization, and a dropout layer. # # In this version we want the Modules to be pytrees so they can be used with JAX transforms # so we use a new Pytree type as the base. The main difference with current NNX is that # attributes that contain arrays or other pytrees now need to be explicitly marked as # using `nnx.data` to be included in the pytree. class Linear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.din, self.dout = din, dout initializer = jax.nn.initializers.lecun_normal() # nnx.data is used mark attributes as pytree data # Param, BatchState, and Cache are built-in Variable subtypes self.w = nnx.Param(initializer(rngs.params(), (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x: jax.Array): return x @ self.w + self.b[None] # Block implements linear, batch norm, and dropout. Its behavior # is controlled by the 'use_stats' and 'deterministic' flags. class Block(nnx.Module): def __init__( self, din: int, dout: int, *, dropout_rate: float = 0.05, moumentum: float = 0.95, use_stats: bool = False, deterministic: bool = False, rngs: nnx.Rngs, ): # ----------- linear ------------------- self.din, self.dout = din, dout initializer = jax.nn.initializers.lecun_normal() self.w = nnx.Param(initializer(rngs.params(), (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) # ----------- batch norm --------------- self.mu = moumentum # momentum self.use_stats = use_stats self.mean = nnx.BatchStat(jnp.zeros((dout,))) self.var = nnx.BatchStat(jnp.ones((dout,))) self.scale = nnx.Param(jnp.ones((dout,))) self.bias = nnx.Param(jnp.zeros((dout,))) # ----------- dropout ------------------ self.dropout_rate = dropout_rate self.deterministic = deterministic def __call__( self, x: jax.Array, *, rngs: nnx.Rngs | None = None ) -> jax.Array: # ----------- linear -------------------- x = x @ self.w + self.b[None] # ----------- batch norm ---------------- if self.use_stats: mean = self.mean var = self.var else: mean = jnp.mean(x, axis=0) var = jnp.var(x, axis=0) # ema updates # stop gradient is used until a Hijax supports updates from grad tracers sg = jax.lax.stop_gradient self.mean[...] = sg(self.mu * self.mean + (1 - self.mu) * mean) self.var[...] = sg(self.mu * self.var + (1 - self.mu) * var) x = (x - mean[None]) / jnp.sqrt(var[None] + 1e-5) x = x * self.scale + self.bias # ----------- dropout ------------------- if not self.deterministic and self.dropout_rate > 0.0: assert rngs is not None keep_prob = 1.0 - self.dropout_rate mask = jax.random.bernoulli(rngs.dropout(), keep_prob, x.shape) x = jnp.where(mask, x / keep_prob, jnp.zeros_like(x)) # ----------- activation --------------- x = jax.nn.gelu(x) return x class Model(nnx.Module): def __init__( self, num_blocks: int, din: int, dhidden: int, dout: int, *, use_scan: bool = True, rngs: nnx.Rngs, ): self.count = nnx.Variable(jnp.array(0)) self.block_in = Block(din, dhidden, rngs=rngs) self.linear_out = Linear(dhidden, dout, rngs=rngs) # 'blocks' is either a list of blocks or single block # whose parameters contain an additional 'layer' dimension, # here created using jax.vmap if use_scan: @jax.vmap def create_block(rngs, /): # return nnx.stateless(Block(dhidden, dhidden, rngs=rngs)) return Block(dhidden, dhidden, rngs=rngs) # self.blocks = nnx.stateful(create_block(rngs.fork(split=num_blocks))) self.blocks = create_block(rngs.fork(split=num_blocks)) else: self.blocks = nnx.List( [Block(dhidden, dhidden, rngs=rngs) for i in range(num_blocks)] ) def __call__(self, x: jax.Array, *, rngs: nnx.Rngs | None = None): self.count[...] += 1 x = self.block_in(x, rngs=rngs) # on the forward pass we either iterate over the block # list or use jax.lax.scan to apply the blocks, if we # had shared state we would use split and merge to # pass the shared state as a capture if isinstance(self.blocks, nnx.List): for block in self.blocks: x = block(x, rngs=rngs) else: def block_fw(x, block: Block): x = block(x, rngs=rngs) return x, None x, _ = jax.lax.scan(block_fw, x, self.blocks) x = self.linear_out(x) return x # ## Optimizer class OptState(nnx.Variable): ... # Optimizer are an interesting case as they are inherently stateful and # pose a good use case for MutableHijax. Here we implement SGD with # momentum. The optimizer receives the params as constructor arguments but doesn't # hold a reference to them, it only uses the params to initialize its state # by creating new OptState Variables that reuse the param's metadata. class SGD(nnx.Pytree): def __init__(self, params, lr: float, decay: float = 0.9): self.lr = lr self.decay = decay def make_opt_state(x): if isinstance(x, nnx.Variable): return OptState(jnp.zeros_like(x[...]), **x.get_metadata()) else: return OptState(jnp.zeros_like(x)) self.momentum = nnx.data(jax.tree.map(make_opt_state, params)) # during the update we simply map over (params, momentum, grads), # for each triplet we implement the SGD update rule which updates # both the optimizer's state (momentum) and the params in place. def update(self, params, grads): def update_fn( param: nnx.Variable[jax.Array], momentum: nnx.Variable[jax.Array], grad: nnx.Variable[jax.Array], ): momentum[...] = self.decay * momentum + (1 - self.decay) * grad param[...] -= self.lr * momentum # is_leaf might not be necesarry as MutableHijaxVariable are not pytreees jax.tree.map(update_fn, params, self.momentum, grads) # ## Training nnx.var_defaults(hijax=True) rngs = nnx.Rngs(params=0, dropout=1) model = Model( num_blocks=3, din=1, dhidden=256, dout=1, use_scan=False, rngs=rngs ) optimizer = SGD(params=nnx.state(model, nnx.Param), lr=3e-3, decay=0.99) # Create a copy of the model structure and set its attributes to eval model. # This works because they share the underlying ArrayRefs so both models # will always be in sync. eval_model = nnx.merge(*nnx.split(model)) eval_model.set_attributes(use_stats=True, deterministic=True) # The training step uses 'jax.jit' and receives the model and optimizer as arguments, # this is supported as they are now pytrees. The first thing we do is group the model # state into the params and the non-differentiable state using 'split'. We differentiate # the loss function using 'jax.grad' with respect to the params-only. Inside the loss # function we merge the params and non-diff state back into a single model and then # compute the loss by calling the model with the inputs. @jax.jit def train_step(model: Model, optimizer: SGD, rngs: nnx.Rngs, x, y): graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) def loss_fn(params): model = nnx.merge(graphdef, params, nondiff) loss = jnp.mean((model(x, rngs=rngs) - y) ** 2) return loss # For the time being we have to use 'immutable' # as 'jax.grad' doesn't support QDD types yet. grads = jax.grad(loss_fn)(nnx.vars_as(params, is_mutable=False)) # 'update' mutates the optimizer's state and the params in place # so we don't need to return anything 🚀 optimizer.update(params, grads) # simple test step that computes the loss @jax.jit def test_step(model: Model, x, y): return {'loss': jnp.mean((model(x) - y) ** 2)} # minimalistic training loop total_steps = 2_000 for step, (x, y) in enumerate(dataset(32)): train_step(model, optimizer, rngs, x, y) if step % 200 == 0: logs = test_step(eval_model, X, Y) print(f'step: {step}, loss: {logs["loss"]}') if step >= total_steps - 1: break # ## Sample # Sampling is trivial, just use 'model_eval' print('times called:', eval_model.count[...]) y_pred = eval_model(X) plt.scatter(X, Y, color='blue') plt.plot(X, y_pred, color='black') plt.show() ================================================ FILE: examples/nnx_toy_examples/requirements.txt ================================================ matplotlib>=3.7.1 datasets>=2.12.0 ================================================ FILE: examples/ogbg_molpcba/README.md ================================================ ## Predicting Biological Activities of Molecules with Graph Neural Networks [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/examples/ogbg_molpcba/ogbg_molpcba.ipynb) This example trains a Graph Neural Network to classify molecules on the basis of their biological activities. ![Prediction on a caramboxin molecule](https://www.gstatic.com/flax_examples/ogbg_molpcba.png "Prediction on a caramboxin molecule") We use [Jraph](https://github.com/deepmind/jraph/), a JAX library for Graph Neural Networks, to define models which are trained on the [ogbg-molpcba](https://ogb.stanford.edu/docs/graphprop/) dataset, part of the [Open Graph Benchmark](https://ogb.stanford.edu/). You can run this code and even modify it directly in Google Colab, no installation required! The [Colab notebook](https://colab.research.google.com/github/google/flax/blob/main/examples/ogbg_molpcba/ogbg_molpcba.ipynb) can even create visualizations of model predictions: ![Visualizing predictions of a trained model](https://www.gstatic.com/flax_examples/ogbg_molpcba_predictions.svg? "Visualizing predictions of a trained model") ### Requirements We depend on [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/ogbg_molpcba) for ogbg-molpcba. ### How to Run To run with the default configuration: ```shell python main.py --workdir=./ogbg_molpcba --config=configs/default.py ``` Since the configuration is defined using [config_flags](https://github.com/google/ml_collections/tree/master#config-flags), you can override hyperparameters. For example, to change the number of epochs and the batch size: ```shell python main.py --workdir=./ogbg_molpcba --config=configs/default.py \ --config.num_training_epochs=10 --config.batch_size=50 ``` For more extensive changes, you can directly edit the default configuration file or even add your own. ### Supported Setups This example supports only single device training. The model should run with other configurations and hardware, but was explicitly tested on the following. Hardware | Batch size | Training time | Test mean AP | Validation mean AP | Metrics -------- | ---------- | ------------- | ------- | ------- | --------------- 1x V100 | 256 | 3h20m | 0.244 | 0.252 |[2021-08-03](https://tensorboard.dev/experiment/AAJqfvgSRJaA1MBkc0jMWQ/) These metrics reported above are obtained at the end of training. We observed that slightly higher metrics can be obtained with early-stopping based on the validation mean AP: Hardware | Batch size | Training time | Test mean AP | Validation mean AP | Metrics -------- | ---------- | ------------- | ------- | ------- | --------------- 1x V100 | 256 | 2h55m | 0.249 | 0.257 |[2021-08-03](https://tensorboard.dev/experiment/AAJqfvgSRJaA1MBkc0jMWQ/) ### Model Description The default configuration corresponds to a [Graph Convolutional Network](https://arxiv.org/abs/1609.02907) model with 695,936 parameters. We noticed diminishing gains when training for longer. Further, the addition of self-loops and undirected edges significantly helped performance. Minor improvements were seen with skip-connections across message-passing steps, together with [LayerNorm](https://arxiv.org/abs/1607.06450). On the contrary, we found that the addition of [virtual nodes](https://arxiv.org/abs/1709.03741), which are connected to all nodes in each graph, did not improve performance. ### References - Weihua Hu, Matthias Fey, Marinka Zitnik, Yuxiao Dong, Hongyu Ren, Bowen Liu, Michele Catasta and Jure Leskovec (2020). *Open Graph Benchmark: Datasets for Machine Learning on Graphs.* In Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual. - Thomas N. Kipf and Max Welling (2016). *Semi-supervised classification with graph convolutional networks.* arXiv preprint arXiv:1609.02907. - Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E. Hinton (2016). *Layer normalization.* arXiv preprint arXiv:1607.06450. - Junying Li, Deng Cai and Xiaofei He (2017). *Learning graph-level representation for drug discovery.* arXiv preprint arXiv:1709.03741. The caramboxin molecule diagram depicted above was obtained and modified from [Wikimedia Commons](https://commons.wikimedia.org/wiki/File:Caramboxin.svg), available in the public domain. ================================================ FILE: examples/ogbg_molpcba/configs/default.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Defines the default hyperparameters and training configuration. Uses a Graph Convolutional Network model (https://arxiv.org/abs/1609.02907). """ import ml_collections def get_config(): """Get the default hyperparameter configuration.""" config = ml_collections.ConfigDict() # Optimizer. config.optimizer = 'adam' config.learning_rate = 1e-3 # Training hyperparameters. config.batch_size = 256 config.num_train_steps = 100_000 config.log_every_steps = 100 config.eval_every_steps = 1_000 config.checkpoint_every_steps = 10_000 config.add_virtual_node = False config.add_undirected_edges = True config.add_self_loops = True # GNN hyperparameters. config.model = 'GraphConvNet' config.message_passing_steps = 5 config.latent_size = 256 config.dropout_rate = 0.1 config.num_mlp_layers = 2 config.num_classes = 128 config.skip_connections = True config.layer_norm = True return config ================================================ FILE: examples/ogbg_molpcba/configs/default_graph_net.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Defines the default hyperparameters and training configuration. Uses a GraphNetwork model (https://arxiv.org/abs/1806.01261). """ import ml_collections def get_config(): """Get the hyperparameter configuration for the GraphNetwork model.""" config = ml_collections.ConfigDict() # Optimizer. config.optimizer = 'adam' config.learning_rate = 1e-3 # Training hyperparameters. config.batch_size = 256 config.num_train_steps = 100_000 config.log_every_steps = 100 config.eval_every_steps = 10_000 config.checkpoint_every_steps = 10_000 config.add_virtual_node = True config.add_undirected_edges = True config.add_self_loops = True # GNN hyperparameters. config.model = 'GraphNet' config.message_passing_steps = 5 config.latent_size = 256 config.dropout_rate = 0.1 config.num_mlp_layers = 1 config.num_classes = 128 config.use_edge_model = True config.skip_connections = True config.layer_norm = True return config ================================================ FILE: examples/ogbg_molpcba/configs/hparam_sweep.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Defines a sweep for the hyperparameters for the GNN.""" import ml_collections def get_config(): """Get the default hyperparameter configuration.""" config = ml_collections.ConfigDict() # Optimizer. config.optimizer = 'adam' config.learning_rate = 1e-3 # Training hyperparameters. config.batch_size = 256 config.num_train_steps = 500_000 config.log_every_steps = 50 config.eval_every_steps = 1_000 config.checkpoint_every_steps = 10_000 config.add_virtual_node = True config.add_undirected_edges = True config.add_self_loops = True # GNN hyperparameters. config.model = 'GraphConvNet' config.message_passing_steps = 5 config.latent_size = 256 config.dropout_rate = 0.1 config.num_mlp_layers = 2 config.num_classes = 128 config.skip_connections = True config.layer_norm = True return config def sweep(add): for add_virtual_node in (True, False): for add_undirected_edges in (True, False): for add_self_loops in (True, False): for layer_norm in (True, False): for skip_connections in (True, False): add( add_virtual_node=add_virtual_node, add_undirected_edges=add_undirected_edges, add_self_loops=add_self_loops, layer_norm=layer_norm, skip_connections=skip_connections, ) ================================================ FILE: examples/ogbg_molpcba/configs/test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Defines a CPU-friendly test configuration.""" import ml_collections def get_config(): """Get the default hyperparameter configuration.""" config = ml_collections.ConfigDict() # Optimizer. config.optimizer = 'adam' config.learning_rate = 1e-3 # Training hyperparameters. config.batch_size = 32 config.num_train_steps = 10 config.log_every_steps = 5 config.eval_every_steps = 5 config.checkpoint_every_steps = 5 config.add_virtual_node = True config.add_undirected_edges = True config.add_self_loops = True # GNN hyperparameters. config.model = 'GraphConvNet' config.message_passing_steps = 5 config.latent_size = 256 config.dropout_rate = 0.1 config.num_mlp_layers = 1 config.num_classes = 128 config.skip_connections = False config.layer_norm = False return config ================================================ FILE: examples/ogbg_molpcba/input_pipeline.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Exposes the ogbg-molpcba dataset in a convenient format.""" import functools from typing import Dict, NamedTuple import jraph import numpy as np import tensorflow as tf import tensorflow_datasets as tfds class GraphsTupleSize(NamedTuple): """Helper class to represent padding and graph sizes.""" n_node: int n_edge: int n_graph: int def get_raw_datasets() -> dict[str, tf.data.Dataset]: """Returns datasets as tf.data.Dataset, organized by split.""" ds_builder = tfds.builder('ogbg_molpcba') ds_builder.download_and_prepare() ds_splits = ['train', 'validation', 'test'] datasets = {split: ds_builder.as_dataset(split=split) for split in ds_splits} return datasets def get_datasets( batch_size: int, add_virtual_node: bool = True, add_undirected_edges: bool = True, add_self_loops: bool = True, ) -> dict[str, tf.data.Dataset]: """Returns datasets of batched GraphsTuples, organized by split.""" if batch_size <= 1: raise ValueError('Batch size must be > 1 to account for padding graphs.') # Obtain the original datasets. datasets = get_raw_datasets() # Construct the GraphsTuple converter function. convert_to_graphs_tuple_fn = functools.partial( convert_to_graphs_tuple, add_virtual_node=add_virtual_node, add_undirected_edges=add_undirected_edges, add_self_loops=add_self_loops, ) # Process each split separately. for split_name in datasets: # Convert to GraphsTuple. datasets[split_name] = datasets[split_name].map( convert_to_graphs_tuple_fn, num_parallel_calls=tf.data.AUTOTUNE, deterministic=True, ) # Compute the padding budget for the requested batch size. budget = estimate_padding_budget_for_batch_size( datasets['train'], batch_size, num_estimation_graphs=100 ) # Pad an example graph to see what the output shapes will be. # We will use this shape information when creating the tf.data.Dataset. example_graph = next(datasets['train'].as_numpy_iterator()) example_padded_graph = jraph.pad_with_graphs(example_graph, *budget) padded_graphs_spec = specs_from_graphs_tuple(example_padded_graph) # Process each split separately. for split_name, dataset_split in datasets.items(): # Repeat and shuffle the training split. if split_name == 'train': dataset_split = dataset_split.shuffle(100, reshuffle_each_iteration=True) dataset_split = dataset_split.repeat() # Batch and pad each split. batching_fn = functools.partial( jraph.dynamically_batch, graphs_tuple_iterator=iter(dataset_split), n_node=budget.n_node, n_edge=budget.n_edge, n_graph=budget.n_graph, ) dataset_split = tf.data.Dataset.from_generator( batching_fn, output_signature=padded_graphs_spec ) # We cache the validation and test sets, since these are small. if split_name in ['validation', 'test']: dataset_split = dataset_split.cache() datasets[split_name] = dataset_split return datasets def convert_to_graphs_tuple( graph: dict[str, tf.Tensor], add_virtual_node: bool, add_undirected_edges: bool, add_self_loops: bool, ) -> jraph.GraphsTuple: """Converts a dictionary of tf.Tensors to a GraphsTuple.""" num_nodes = tf.squeeze(graph['num_nodes']) num_edges = tf.squeeze(graph['num_edges']) nodes = graph['node_feat'] edges = graph['edge_feat'] edge_feature_dim = edges.shape[-1] labels = graph['labels'] senders = graph['edge_index'][:, 0] receivers = graph['edge_index'][:, 1] # Add a virtual node connected to all other nodes. # The feature vectors for the virtual node # and the new edges are set to all zeros. if add_virtual_node: nodes = tf.concat([nodes, tf.zeros_like(nodes[0, None])], axis=0) senders = tf.concat([senders, tf.range(num_nodes)], axis=0) receivers = tf.concat( [receivers, tf.fill((num_nodes,), num_nodes + 1)], axis=0 ) edges = tf.concat([edges, tf.zeros((num_nodes, edge_feature_dim))], axis=0) num_edges += num_nodes num_nodes += 1 # Make edges undirected, by adding edges with senders and receivers flipped. # The feature vector for the flipped edge is the same as the original edge. if add_undirected_edges: new_senders = tf.concat([senders, receivers], axis=0) new_receivers = tf.concat([receivers, senders], axis=0) edges = tf.concat([edges, edges], axis=0) senders, receivers = new_senders, new_receivers num_edges *= 2 # Add self-loops for each node. # The feature vectors for the self-loops are set to all zeros. if add_self_loops: senders = tf.concat([senders, tf.range(num_nodes)], axis=0) receivers = tf.concat([receivers, tf.range(num_nodes)], axis=0) edges = tf.concat([edges, tf.zeros((num_nodes, edge_feature_dim))], axis=0) num_edges += num_nodes return jraph.GraphsTuple( n_node=tf.expand_dims(num_nodes, 0), n_edge=tf.expand_dims(num_edges, 0), nodes=nodes, edges=edges, senders=senders, receivers=receivers, globals=tf.expand_dims(labels, axis=0), ) def estimate_padding_budget_for_batch_size( dataset: tf.data.Dataset, batch_size: int, num_estimation_graphs: int ) -> GraphsTupleSize: """Estimates the padding budget for a dataset of unbatched GraphsTuples. Args: dataset: A dataset of unbatched GraphsTuples. batch_size: The intended batch size. Note that no batching is performed by this function. num_estimation_graphs: How many graphs to take from the dataset to estimate the distribution of number of nodes and edges per graph. Returns: padding_budget: The padding budget for batching and padding the graphs in this dataset to the given batch size. """ def next_multiple_of_64(val: float): """Returns the next multiple of 64 after val.""" return 64 * (1 + int(val // 64)) if batch_size <= 1: raise ValueError('Batch size must be > 1 to account for padding graphs.') total_num_nodes = 0 total_num_edges = 0 for graph in dataset.take(num_estimation_graphs).as_numpy_iterator(): graph_size = get_graphs_tuple_size(graph) if graph_size.n_graph != 1: raise ValueError('Dataset contains batched GraphTuples.') total_num_nodes += graph_size.n_node total_num_edges += graph_size.n_edge num_nodes_per_graph_estimate = total_num_nodes / num_estimation_graphs num_edges_per_graph_estimate = total_num_edges / num_estimation_graphs padding_budget = GraphsTupleSize( n_node=next_multiple_of_64(num_nodes_per_graph_estimate * batch_size), n_edge=next_multiple_of_64(num_edges_per_graph_estimate * batch_size), n_graph=batch_size, ) return padding_budget def specs_from_graphs_tuple(graph: jraph.GraphsTuple): """Returns a tf.TensorSpec corresponding to this graph.""" def get_tensor_spec(array: np.ndarray): shape = list(array.shape) dtype = array.dtype return tf.TensorSpec(shape=shape, dtype=dtype) specs = {} for field in [ 'nodes', 'edges', 'senders', 'receivers', 'globals', 'n_node', 'n_edge', ]: field_sample = getattr(graph, field) specs[field] = get_tensor_spec(field_sample) return jraph.GraphsTuple(**specs) def get_graphs_tuple_size(graph: jraph.GraphsTuple): """Returns the number of nodes, edges and graphs in a GraphsTuple.""" return GraphsTupleSize( n_node=np.sum(graph.n_node), n_edge=np.sum(graph.n_edge), n_graph=np.shape(graph.n_node)[0], ) ================================================ FILE: examples/ogbg_molpcba/input_pipeline_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.examples.ogbg_molpcba.input_pipeline.""" from absl.testing import absltest from absl.testing import parameterized import input_pipeline import jraph import tensorflow as tf def get_dummy_datasets(dataset_length: int): """Returns a set of datasets of unbatched GraphsTuples.""" # The dummy graph. num_nodes = 3 num_edges = 4 dummy_graph = jraph.GraphsTuple( n_node=tf.expand_dims(num_nodes, 0), n_edge=tf.expand_dims(num_edges, 0), senders=tf.zeros(num_edges, dtype=tf.int32), receivers=tf.ones(num_edges, dtype=tf.int32), nodes=tf.zeros((num_nodes, 9)), edges=tf.ones((num_edges, 3)), globals=tf.ones((1, 128), dtype=tf.int64), ) graphs_spec = input_pipeline.specs_from_graphs_tuple(dummy_graph) # Yields a set of graphs for the current split. def get_dummy_graphs(): for _ in range(dataset_length): yield dummy_graph datasets = {} for split in ['train', 'validation', 'test']: datasets[split] = tf.data.Dataset.from_generator( get_dummy_graphs, output_signature=graphs_spec ) return datasets class InputPipelineTest(parameterized.TestCase): def setUp(self): super().setUp() dataset_length = 20 self.datasets = get_dummy_datasets(dataset_length) @parameterized.product( valid_batch_size=[2, 5, 12, 15], ) def test_estimate_padding_budget_valid(self, valid_batch_size): budget = input_pipeline.estimate_padding_budget_for_batch_size( self.datasets['train'], valid_batch_size, num_estimation_graphs=1 ) self.assertEqual(budget.n_graph, valid_batch_size) @parameterized.product( invalid_batch_size=[-1, 0, 1], ) def test_estimate_padding_budget_invalid(self, invalid_batch_size): with self.assertRaises(ValueError): input_pipeline.estimate_padding_budget_for_batch_size( self.datasets['train'], invalid_batch_size, num_estimation_graphs=1 ) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/ogbg_molpcba/main.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Main file for running the ogbg-molpcba example. This file is intentionally kept short. The majority for logic is in libraries that can be easily tested and imported in Colab. """ from absl import app from absl import flags from absl import logging from clu import platform import jax from ml_collections import config_flags import tensorflow as tf import train FLAGS = flags.FLAGS flags.DEFINE_string('workdir', None, 'Directory to store model data.') config_flags.DEFINE_config_file( 'config', None, 'File path to the training hyperparameter configuration.', lock_config=True, ) def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.experimental.set_visible_devices([], 'GPU') # This example only supports single-host training on a single device. logging.info('JAX host: %d / %d', jax.process_index(), jax.process_count()) logging.info('JAX local devices: %r', jax.local_devices()) # Add a note so that we can tell which task is which JAX host. # (Depending on the platform task 0 is not guaranteed to be host 0) platform.work_unit().set_task_status( f'process_index: {jax.process_index()}, ' f'process_count: {jax.process_count()}' ) platform.work_unit().create_artifact( platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' ) train.train_and_evaluate(FLAGS.config, FLAGS.workdir) if __name__ == '__main__': flags.mark_flags_as_required(['config', 'workdir']) app.run(main) ================================================ FILE: examples/ogbg_molpcba/models.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Definition of the GNN model.""" from collections.abc import Callable, Sequence from flax import linen as nn import jax.numpy as jnp import jraph def add_graphs_tuples( graphs: jraph.GraphsTuple, other_graphs: jraph.GraphsTuple ) -> jraph.GraphsTuple: """Adds the nodes, edges and global features from other_graphs to graphs.""" return graphs._replace( nodes=graphs.nodes + other_graphs.nodes, edges=graphs.edges + other_graphs.edges, globals=graphs.globals + other_graphs.globals, ) class MLP(nn.Module): """A multi-layer perceptron.""" feature_sizes: Sequence[int] dropout_rate: float = 0 deterministic: bool = True activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu @nn.compact def __call__(self, inputs): x = inputs for size in self.feature_sizes: x = nn.Dense(features=size)(x) x = self.activation(x) x = nn.Dropout(rate=self.dropout_rate, deterministic=self.deterministic)( x ) return x class GraphNet(nn.Module): """A complete Graph Network model defined with Jraph.""" latent_size: int num_mlp_layers: int message_passing_steps: int output_globals_size: int dropout_rate: float = 0 skip_connections: bool = True use_edge_model: bool = True layer_norm: bool = True deterministic: bool = True @nn.compact def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: # We will first linearly project the original features as 'embeddings'. embedder = jraph.GraphMapFeatures( embed_node_fn=nn.Dense(self.latent_size), embed_edge_fn=nn.Dense(self.latent_size), embed_global_fn=nn.Dense(self.latent_size), ) processed_graphs = embedder(graphs) # Now, we will apply a Graph Network once for each message-passing round. mlp_feature_sizes = [self.latent_size] * self.num_mlp_layers for _ in range(self.message_passing_steps): if self.use_edge_model: update_edge_fn = jraph.concatenated_args( MLP( mlp_feature_sizes, dropout_rate=self.dropout_rate, deterministic=self.deterministic, ) ) else: update_edge_fn = None update_node_fn = jraph.concatenated_args( MLP( mlp_feature_sizes, dropout_rate=self.dropout_rate, deterministic=self.deterministic, ) ) update_global_fn = jraph.concatenated_args( MLP( mlp_feature_sizes, dropout_rate=self.dropout_rate, deterministic=self.deterministic, ) ) graph_net = jraph.GraphNetwork( update_node_fn=update_node_fn, update_edge_fn=update_edge_fn, update_global_fn=update_global_fn, ) if self.skip_connections: processed_graphs = add_graphs_tuples( graph_net(processed_graphs), processed_graphs ) else: processed_graphs = graph_net(processed_graphs) if self.layer_norm: processed_graphs = processed_graphs._replace( nodes=nn.LayerNorm()(processed_graphs.nodes), edges=nn.LayerNorm()(processed_graphs.edges), globals=nn.LayerNorm()(processed_graphs.globals), ) # Since our graph-level predictions will be at globals, we will # decode to get the required output logits. decoder = jraph.GraphMapFeatures( embed_global_fn=nn.Dense(self.output_globals_size) ) processed_graphs = decoder(processed_graphs) return processed_graphs class GraphConvNet(nn.Module): """A Graph Convolution Network + Pooling model defined with Jraph.""" latent_size: int num_mlp_layers: int message_passing_steps: int output_globals_size: int dropout_rate: float = 0 skip_connections: bool = True layer_norm: bool = True deterministic: bool = True pooling_fn: Callable[ [jnp.ndarray, jnp.ndarray, jnp.ndarray], # pytype: disable=annotation-type-mismatch # jax-ndarray jnp.ndarray, ] = jraph.segment_mean def pool(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: """Pooling operation, taken from Jraph.""" # Equivalent to jnp.sum(n_node), but JIT-able. sum_n_node = graphs.nodes.shape[0] # pytype: disable=attribute-error # jax-ndarray # To aggregate nodes from each graph to global features, # we first construct tensors that map the node to the corresponding graph. # Example: if you have `n_node=[1,2]`, we construct the tensor [0, 1, 1]. n_graph = graphs.n_node.shape[0] node_graph_indices = jnp.repeat( jnp.arange(n_graph), graphs.n_node, axis=0, total_repeat_length=sum_n_node, ) # We use the aggregation function to pool the nodes per graph. pooled = self.pooling_fn(graphs.nodes, node_graph_indices, n_graph) # pytype: disable=wrong-arg-types # jax-ndarray return graphs._replace(globals=pooled) @nn.compact def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: # We will first linearly project the original node features as 'embeddings'. embedder = jraph.GraphMapFeatures(embed_node_fn=nn.Dense(self.latent_size)) processed_graphs = embedder(graphs) # Now, we will apply the GCN once for each message-passing round. for _ in range(self.message_passing_steps): mlp_feature_sizes = [self.latent_size] * self.num_mlp_layers update_node_fn = jraph.concatenated_args( MLP( mlp_feature_sizes, dropout_rate=self.dropout_rate, deterministic=self.deterministic, ) ) graph_conv = jraph.GraphConvolution( update_node_fn=update_node_fn, add_self_edges=True ) if self.skip_connections: processed_graphs = add_graphs_tuples( graph_conv(processed_graphs), processed_graphs ) else: processed_graphs = graph_conv(processed_graphs) if self.layer_norm: processed_graphs = processed_graphs._replace( nodes=nn.LayerNorm()(processed_graphs.nodes), ) # We apply the pooling operation to get a 'global' embedding. processed_graphs = self.pool(processed_graphs) # Now, we decode this to get the required output logits. decoder = jraph.GraphMapFeatures( embed_global_fn=nn.Dense(self.output_globals_size) ) processed_graphs = decoder(processed_graphs) return processed_graphs ================================================ FILE: examples/ogbg_molpcba/models_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.examples.ogbg_molpcba.models.""" from absl.testing import absltest from absl.testing import parameterized import jax import jax.numpy as jnp import jraph import models class ModelsTest(parameterized.TestCase): def setUp(self): super().setUp() self.rngs = { 'params': jax.random.key(0), 'dropout': jax.random.key(1), } n_node = jnp.arange(3, 11) n_edge = jnp.arange(4, 12) total_n_node = jnp.sum(n_node) total_n_edge = jnp.sum(n_edge) n_graph = n_node.shape[0] feature_dim = 10 self.graphs = jraph.GraphsTuple( n_node=n_node, n_edge=n_edge, senders=jnp.zeros(total_n_edge, dtype=jnp.int32), receivers=jnp.ones(total_n_edge, dtype=jnp.int32), nodes=jnp.ones((total_n_node, feature_dim)), edges=jnp.zeros((total_n_edge, feature_dim)), globals=jnp.zeros((n_graph, feature_dim)), ) @parameterized.product( dropout_rate=[0.0, 0.5, 1.0], output_size=[50, 100], num_layers=[2] ) def test_mlp(self, dropout_rate, output_size, num_layers): # Input definition. nodes = self.graphs.nodes # Model definition. mlp = models.MLP( feature_sizes=[output_size] * num_layers, dropout_rate=dropout_rate, activation=lambda x: x, deterministic=False, ) nodes_after_mlp, _ = mlp.init_with_output(self.rngs, nodes) # Test that dropout actually worked. num_masked_entries = jnp.sum(nodes_after_mlp == 0) num_total_entries = jnp.size(nodes_after_mlp) self.assertLessEqual( num_masked_entries, (dropout_rate + 0.05) * num_total_entries ) self.assertLessEqual( (dropout_rate - 0.05) * num_total_entries, num_masked_entries ) # Test the shape of the output. self.assertEqual(nodes_after_mlp.shape[-1], output_size) @parameterized.parameters( { 'latent_size': 5, 'output_globals_size': 15, 'use_edge_model': True, }, { 'latent_size': 5, 'output_globals_size': 15, 'use_edge_model': False, }, ) def test_graph_net( self, latent_size: int, output_globals_size: int, use_edge_model: bool ): # Input definition. graphs = self.graphs num_nodes = jnp.sum(graphs.n_node) num_edges = jnp.sum(graphs.n_edge) num_graphs = graphs.n_node.shape[0] # Model definition. net = models.GraphNet( latent_size=latent_size, num_mlp_layers=2, message_passing_steps=2, output_globals_size=output_globals_size, use_edge_model=use_edge_model, ) output, _ = net.init_with_output(self.rngs, graphs) # Output should be graph with the same topology, but a # different number of features. self.assertIsInstance(output, jraph.GraphsTuple) self.assertSequenceAlmostEqual(output.n_node, graphs.n_node) self.assertSequenceAlmostEqual(output.n_edge, graphs.n_edge) self.assertSequenceAlmostEqual(output.senders, graphs.senders) self.assertSequenceAlmostEqual(output.receivers, graphs.receivers) self.assertEqual(output.nodes.shape, (num_nodes, latent_size)) self.assertEqual(output.edges.shape, (num_edges, latent_size)) self.assertEqual(output.globals.shape, (num_graphs, output_globals_size)) @parameterized.parameters( {'latent_size': 15, 'output_globals_size': 15}, {'latent_size': 5, 'output_globals_size': 5}, ) def test_graph_conv_net(self, latent_size: int, output_globals_size: int): graphs = self.graphs num_nodes = jnp.sum(graphs.n_node) num_graphs = graphs.n_node.shape[0] # Model definition. net = models.GraphConvNet( latent_size=latent_size, num_mlp_layers=2, message_passing_steps=2, output_globals_size=output_globals_size, ) output, _ = net.init_with_output(self.rngs, graphs) # Output should be graph with the same topology, but a # different number of features. self.assertIsInstance(output, jraph.GraphsTuple) self.assertSequenceAlmostEqual(output.n_node, graphs.n_node) self.assertSequenceAlmostEqual(output.n_edge, graphs.n_edge) self.assertSequenceAlmostEqual( output.edges.flatten(), graphs.edges.flatten() ) self.assertSequenceAlmostEqual(output.senders, graphs.senders) self.assertSequenceAlmostEqual(output.receivers, graphs.receivers) self.assertEqual(output.nodes.shape, (num_nodes, latent_size)) self.assertEqual(output.globals.shape, (num_graphs, output_globals_size)) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/ogbg_molpcba/ogbg_molpcba.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Flax ogbg-molpcba Example\n", "\n", "\"Open\n", "\n", "Demonstration notebook for\n", "https://github.com/google/flax/tree/main/examples/ogbg_molpcba." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "outputId": "6508ab2f-b0e5-4693-f6a0-7bc495ec1344" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[?25l\r", "\u001b[K |████▏ | 10 kB 20.9 MB/s eta 0:00:01\r", "\u001b[K |████████▍ | 20 kB 10.6 MB/s eta 0:00:01\r", "\u001b[K |████████████▋ | 30 kB 8.8 MB/s eta 0:00:01\r", "\u001b[K |████████████████▉ | 40 kB 8.0 MB/s eta 0:00:01\r", "\u001b[K |█████████████████████ | 51 kB 5.1 MB/s eta 0:00:01\r", "\u001b[K |█████████████████████████▎ | 61 kB 5.3 MB/s eta 0:00:01\r", "\u001b[K |█████████████████████████████▌ | 71 kB 5.5 MB/s eta 0:00:01\r", "\u001b[K |████████████████████████████████| 77 kB 2.7 MB/s \n", "\u001b[K |████████████████████████████████| 88 kB 5.5 MB/s \n", "\u001b[K |████████████████████████████████| 4.0 MB 38.2 MB/s \n", "\u001b[K |████████████████████████████████| 73 kB 1.3 MB/s \n", "\u001b[K |████████████████████████████████| 118 kB 38.3 MB/s \n", "\u001b[K |████████████████████████████████| 57 kB 3.6 MB/s \n", "\u001b[?25h Building wheel for flax (setup.py) ... \u001b[?25l\u001b[?25hdone\n" ] } ], "source": [ "# Install clu, ml-collections, latest Flax version, and tensorflow_datasets.\n", "!pip install -U -q clu ml-collections git+https://github.com/google/flax tfds_nightly jraph" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "example_directory = 'examples/ogbg_molpcba'\n", "editor_relpaths = ('configs/default.py', 'input_pipeline.py', 'models.py', 'train.py')\n", "\n", "repo, branch = 'https://github.com/google/flax', 'main'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "cellView": "form", "outputId": "8261a349-b41e-4e1b-a2ca-8d23412155be" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cloning into 'flaxrepo'...\n", "remote: Enumerating objects: 350, done.\u001b[K\n", "remote: Counting objects: 100% (350/350), done.\u001b[K\n", "remote: Compressing objects: 100% (312/312), done.\u001b[K\n", "remote: Total 350 (delta 65), reused 140 (delta 20), pack-reused 0\u001b[K\n", "Receiving objects: 100% (350/350), 2.10 MiB | 20.10 MiB/s, done.\n", "Resolving deltas: 100% (65/65), done.\n" ] }, { "data": { "text/html": [ "

WARNING : Editing in VM - changes lost after reboot!!

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", " ((filepath) => {{\n", " if (!google.colab.kernel.accessAllowed) {{\n", " return;\n", " }}\n", " google.colab.files.view(filepath);\n", " }})(\"/content/examples/ogbg_molpcba/configs/default.py\")" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", " ((filepath) => {{\n", " if (!google.colab.kernel.accessAllowed) {{\n", " return;\n", " }}\n", " google.colab.files.view(filepath);\n", " }})(\"/content/examples/ogbg_molpcba/input_pipeline.py\")" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", " ((filepath) => {{\n", " if (!google.colab.kernel.accessAllowed) {{\n", " return;\n", " }}\n", " google.colab.files.view(filepath);\n", " }})(\"/content/examples/ogbg_molpcba/models.py\")" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", " ((filepath) => {{\n", " if (!google.colab.kernel.accessAllowed) {{\n", " return;\n", " }}\n", " google.colab.files.view(filepath);\n", " }})(\"/content/examples/ogbg_molpcba/train.py\")" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# (If you run this code in Jupyter[lab], then you're already in the\n", "# example directory and nothing needs to be done.)\n", "\n", "#@markdown **Fetch newest Flax version and copy of example code.**\n", "#@markdown\n", "#@markdown **If you select no** below, then the files will be stored on the\n", "#@markdown *ephemeral* Colab VM. **After some time of inactivity, this VM will\n", "#@markdown be restarted and any changes will be lost**.\n", "#@markdown\n", "#@markdown **If you select yes** below, then you will be asked for your\n", "#@markdown credentials to mount your personal Google Drive. In this case, all\n", "#@markdown changes you make will *persist*. Even if you re-run this\n", "#@markdown Colab notebook later on, the files will still exist. You can\n", "#@markdown remove directories inside your Drive's `flax/` root if you want to\n", "#@markdown manually revert these files.\n", "\n", "if 'google.colab' in str(get_ipython()):\n", " import os\n", " os.chdir('/content')\n", " # Download Flax repo from Github.\n", " if not os.path.isdir('flaxrepo'):\n", " !git clone --depth=1 -b $branch $repo flaxrepo\n", " # Copy example files & change directory.\n", " mount_gdrive = 'no' #@param ['yes', 'no']\n", " if mount_gdrive == 'yes':\n", " DISCLAIMER = 'Note : Editing in your Google Drive, changes will persist.'\n", " from google.colab import drive\n", " drive.mount('/content/gdrive')\n", " example_root_path = f'/content/gdrive/My Drive/flax/{example_directory}'\n", " else:\n", " DISCLAIMER = 'WARNING : Editing in VM - changes lost after reboot!!'\n", " example_root_path = f'/content/{example_directory}'\n", " from IPython import display\n", " display.display(display.HTML(\n", " f'

{DISCLAIMER}

'))\n", " if not os.path.isdir(example_root_path):\n", " os.makedirs(example_root_path)\n", " !cp -r flaxrepo/$example_directory/* \"$example_root_path\"\n", " os.chdir(example_root_path)\n", " from google.colab import files\n", " for relpath in editor_relpaths:\n", " s = open(f'{example_root_path}/{relpath}').read()\n", " open(f'{example_root_path}/{relpath}', 'w').write(\n", " f'## {DISCLAIMER}\\n' + '#' * (len(DISCLAIMER) + 3) + '\\n\\n' + s)\n", " files.view(f'{example_root_path}/{relpath}')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "outputId": "14b17380-5077-4354-e651-027f3d933cfe" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/content/examples/ogbg_molpcba\n" ] } ], "source": [ "# Display current working directory.\n", "# Note: In Colab, running the above cell changes the working directory.\n", "!pwd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Base imports\n", "from absl import logging\n", "import flax\n", "import jax.numpy as jnp\n", "import jraph\n", "from matplotlib import pyplot as plt\n", "import numpy as np\n", "import pprint\n", "import tensorflow_datasets as tfds\n", "logging.set_verbosity(logging.INFO)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Local imports from current directory - auto reload.\n", "# Any changes you make to train.py and other modules below will appear automatically.\n", "%load_ext autoreload\n", "%autoreload 2\n", "import train\n", "import input_pipeline\n", "from configs import default as config_lib\n", "config = config_lib.get_config()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TensorFlow Datasets supports customizable visualization of the ogbg_molpcba dataset." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Visualization helpers\n", "# Dictionaries used to map nodes and edges to colors.\n", "atomic_numbers_to_elements = {\n", " 6: 'C', 7: 'N', 8: 'O', 9: 'F', 14: 'Si',\n", " 15: 'P', 16: 'S', 17: 'Cl', 35: 'Br',\n", "}\n", "elements_to_colors = {\n", " element: f'C{index}'\n", " for index, element in enumerate(atomic_numbers_to_elements.values())\n", "}\n", "bond_types_to_colors = {num: f'C{num}' for num in range(4)}\n", "\n", "# Node colors are atomic numbers.\n", "def node_color_fn(graph):\n", " atomic_numbers = 1 + graph['node_feat'][:, 0].numpy()\n", " return {\n", " index: elements_to_colors[atomic_numbers_to_elements[atomic_number]]\n", " for index, atomic_number in enumerate(atomic_numbers)\n", " }\n", "\n", "# Node labels are element names.\n", "def node_label_fn(graph):\n", " atomic_numbers = 1 + graph['node_feat'][:, 0].numpy()\n", " return {\n", " index: atomic_numbers_to_elements[atomic_number]\n", " for index, atomic_number in enumerate(atomic_numbers)\n", " }\n", "\n", "# Edge colors are bond types.\n", "def edge_color_fn(graph):\n", " bonds = graph['edge_index'].numpy()\n", " bond_types = graph['edge_feat'][:, 0].numpy()\n", " return {\n", " tuple(bond): bond_types_to_colors[bond_type]\n", " for bond, bond_type in zip(bonds, bond_types)\n", " }" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "outputId": "d9336190-e685-43e8-e3f1-73e3f1ce1cd2" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Load pre-computed DatasetInfo (eg: splits, num examples,...) from GCS: ogbg_molpcba/0.1.2\n", "INFO:absl:Load dataset info from /tmp/tmpmdoxgxq7tfds\n", "INFO:absl:Generating dataset ogbg_molpcba (/root/tensorflow_datasets/ogbg_molpcba/0.1.2)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1mDownloading and preparing dataset 37.70 MiB (download: 37.70 MiB, generated: 822.53 MiB, total: 860.23 MiB) to /root/tensorflow_datasets/ogbg_molpcba/0.1.2...\u001b[0m\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "78bb0e82e40b4c3c937ae1f248ba1d54", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Dl Completed...: 0 url [00:00, ? url/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f933ef3ecafb4973b5c54cc4ff039ca1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Dl Size...: 0 MiB [00:00, ? MiB/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0607a3a123334ae882e26c39c543c9e7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Extraction completed...: 0 file [00:00, ? file/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Downloading https://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/pcba.zip into /root/tensorflow_datasets/downloads/snap.stan.edu_ogb_grap_csv_mol_down_pcbapc4I82Cv1THcU-IggPHK8IHZ8qM-BJ3VDk-q_rtqrf4.zip.tmp.495be609cf3840d3a942e7f9ba5d705e...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "83b57edf962a4faa9e875b4fecbb1eb1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating splits...: 0%| | 0/3 [00:00" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA1MAAAM9CAYAAAB5Rim2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdd3hUdfb48fe9d2p6IaGE0BJ6B6mCgFhWXOwitrUgYC+rruv2dcvPVXe/axfsYK/rWtaCSu9VekkoSSC9J9Pv/f0RQSAzyUzIzAQ4r+fh8XHaPUlm5t7zKecohmEghBBCCCGEECI0arQDEEIIIYQQQoiTkSRTQgghhBBCCNECkkwJIYQQQgghRAtIMiWEEEIIIYQQLSDJlBBCCCGEEEK0gKmpO9u1a2d069YtQqEIISJh3bp1pYZhpEU7jhMl309CnHpOhe8n+W4S4tTT1HdTk8lUt27dWLt2bXiiEkJEhaIo+6MdQ2uQ7ychTj2nwveTfDcJcepp6rtJlvkJIYQQQgghRAtIMiWEEEIIIYQQLSDJlBBCCCGEEEK0gCRTQgghhBBCCNECkkwJIYQQQgghRAtIMiWEEEIIIYQQLSDJlBBCCCGEEEK0gCRTQgghhBBCCNECkkwJIYQQQgghRAtIMiWEEEIIIYQQLWCK6tENA/YugmVPwaGN4K4HzQxx6XDGLTDkarAnRTVEIYQQorjayRur9vPhugIqHW50HWKtGmOyUpk1PouBnROjHaIQfum6waLdJcxdnMuOQ9XUu31YTSrpCVZuHNuNS4Z2Js4a3ctBIU5m0fv0/PAefPNHcFWBu+6n270OcFXDd4/At3+C/pfCBY+BLSFqoQohhDg9FVY5+c1Hm1maU4oCuLz6kfscHh+f/3CIBduKyUiy8cjFAxib3S56wQpxnLdW7edf3+zC4fZR5/Ydud3l1al2evn7Fzv46+fbuXxYZ353YT/sFi2K0Qpxcor8Mj/DgAV/gk/vhpqDxyZSR/PUg9cJWz6CuROgpjCiYQohhDi97Sqq4YInF7NwVzFur35MInWYbjQkVXtK6rj59TW8uyYvCpEKcSzDMPjdx5v5y2fbKa11H5NIHa3e7cPp0flgXT6XPLuMijp3hCMV4uQX+WRq2b9h1RzwOIJ7vM8FFQfgtQvBVRPe2IQQQgjgUJWDaXNWUFHvQTeCe47To/PH/27hq60y+Cei67GvdvLh+gIcHv9J1PFcXp3c0lqueWkVziCfI4RoENlkqnQPLHy0YdYpFIYXKvPg20fCE5cQQghxlPvf20SN0xvy85wenXvf2UitK/TnCtEaNudX8eqyvUEnUod5fAa5JbU8+/2eMEUmxKkpsnumVj4H+rEf7m7/rqHeA3vviSPWogDw0no3b/zgYeGNsT890OeCjW/CuY+A2R7JqIUQQpxG8srrWbe/Al+AKam6bQupXvMfPGX5qBY75vQeJI6dhq1zfwAUBT5en8/1Y7pFMOrw21dax9urD7C7uJZal5dEm4mBnZOYPjKT9HhbtMMTP3pxSS5uP0tSofn3rsurM2/Ffu6e3BOzJgWfhQhG5JIpdx1seht0T6O7fAY8ucrNb8Zbm3+dLR/B0GvDEKAQQggB81fsRzf8J1LVqz+matUHpJ53B7buw1A0E46963DsXnXkgrTe7WPOolyuG90VRVEavYZhGH5vb6sW7izm6e/2sKWgCl038ByVZC7eXcoz3+9hXHY77p7ckyGZUoE3mqrqPXy1tdDv0tRg3rsAXl3n2+1F/GxAxwhGLsTJK3LJ1J5vQfE/yvHgWAuPLXNx+wgLSbYmTjDuOlj7iiRTQgghwuaD9fl4fI2vRnVXHZVL3yR1yr3E9B575PaY7FHEZI865rFl9W52FdXSu0M8Hp/ON9uKmLMoh51FNTg9OiZVISXWwrWjunDNqK6kxQcxmBhhhmHwjy938Pry/QGXjB0uyvH9jmKW55Typ6n9mT6ySyTDFEf5elshmtr4OiqU926dy8dbqw5IMiVEkCKXTNUWge5/DfkZnTQmdjPxxHIXfz27maUCtcVhCE4IIYRoUO1ovIICwFWwA8PrJqbXmGZfw6QqlNQ4WbanlH8v2IXPMKhz/ZSQeHWD4hoXzy3M4bmFOUzoncZjlw8iKcbSaj/Hifp//9vB/BWBE6mjGTTsF/vzp1vRVIUrz8gMf4CikZJaFy4/f69Q3rsARdWu1g7Nr4JKB19vLaS01o1P10mOsXBmdjsGZEjfNnHyiFwy5fOA4X8NL8Ajk6yc+Uod94xq5kSiS9lOIYQQ4WEYRsC9Uj5HNWpMAorafC8eQzeYsziXtfsqmkxGjp7ZmfLkEj64bSydkqK/L/i7HUVBJ1JHc3h0fv/JFoZkJtGzfXyYohOBuL263yV+obx3Ady+wNdrJ8owDJbsLmXOohzW7q8AfvocmFSFfy/YTUayndsmZPHzwR2xmqT3lWjbIpdM2RJBNYHPfzI0IF3j571MPLrUTd+0JjY9WuXLWQghRMt4fDoLthUxZ3Euu4pqcHp8mDWVtHgrN4zpxrQzMrGZNb9JhGZPQK+vxtB9zV6Uurw6q/eW++1N5T8ug6JqF9PmrODzu8eTaDe36OdrLU8u2B0wkWquiIHHq/Piklweu2JwJEMWQILNjNmkNipAEcp7FyDWVYenuBhzenqrxufx6dz37ka+21FMvZ/eV17dwKv72FNcy+8/2cKcxTm8NXM07eLa3jJYIQ6LXDKVOarJmSmAP0+0MWxOLfePCfChUc3QfWLrxyaEEOKUZhgGLy3N5elv9zRacufy6uRXOPjnNzt54uudxNtMfhMJa0YfFJOZ+l0riO0zrsnjeXQDf1METSUiPsOgqNrJP7/eySMXDzjxH7qF9hTXsrPQf1/HYIoY+Az476aD/GFqf+KskS0afLob0S0FzU9xk1DeuxZ0hhXvJnfqX9Di4rAPGYJ96FDsQ4Zg690LxdyyRN+nG8yct5aVuWU4Pc0PMtS7feSW1DH16aV8cfd4kmPbzhJYIY4WuW+5dtnQfgAUrA34kOwUlav6m3lqtZuB6X5mp1QNRt8WxiCFEEKcanTd4IEPNvG/zYVNLls7fIHnq/eg0LAP6GiqNZakcddS/s0LKKqGrftQFNWEc99GnAd+IHnSzQ0PNHQURcHg2IvaYBIRj8/g/bX5PHxBX+yW6CxvenPVfrx+EsFQihioisJnmw5KMYoIG9g5kU5JNnJK6o65Pej3LoDJxB1/uZ30uHtx79uHY8MGHBs3UvnuO3gKDmLr3/+nBGvoEEzJyUHF9vhXO1mVWx5UInWYVzcorXVx82tr+PiOM4N+nhCRFNkho3H3wsezG6ryBfCHCVbm/+B/8y8dB0NqVpiCE0IIcSp65LOtzSZSR/OXSByWMPIy1Nhkqla8S+lnT6BY7FjbZ5Mw5qqfHqSojRKxUBIRRYFPNx1k2ojoFHHYVVTr93cQShGDerePfWWBz/UifG6bmM0fPtnSaBldMO9dBTgzK5X2CQ3FwKw9emDt0YOkyy8HwFddjWPTDzg2bKDijTc4+KtfYUpNPSa5smZno2jHDgTUu728vnxfi5aOenwGOwpr2JRXyWApvS/aoMgmU70ugKSuULr7SL+pffceuwcqM1HF+buExs812eHcv0QiSiGEEKeI9QcqeHdNvt+LuOb2/mAYDZnNceL6TyKu/yS/xzOrCjoGx+/fDzUReWfNgaglU3Uu/5V3Qy1iUFkfYGBUhNXPB3XkqW92ku/yoB/Xkqap9y6AzazxwPm9A96vJSQQN34cceMblgoaPh+uPTk4Nm7EsWED5a+9hre0FPuggQ0J1pAh2AcP5pOdVf4+SkBwM7Yur48Xl+TyzDXDQvxtCBF+kU2mNBPc8CnMGQ+1pcFX5jPZ4KKnoMuo5h8rhBBC/OjFxbk4vY0TqaAamIbYWNdmUkmOtVBW68J33NxUqIlIaW30KtfGB9jnFGoRg+Q2VOb9dKJVlPH46pe5tfc0qhUtYHXK49nMKk9fPZT+nYIvS65oGrbevbD17kXyVdMA8FZUNCRXGzdR9tLLOLds4dkJv6Te2nhWKdgZW92Ab7YVUVXvITEmusVZhDhe5HeGxraD2Uth/sVQvhfctYEfa45p+O8Vr0DvCyITnxBCiCYVVDqYv2IfX20tosbpQVEUkuxmLhuWwfQRXdrMRvHyOjff7SjGOO5aMpQld2ZVAaVhD1CgynyqAlaTxpisVIZ3Teb/vtnV6DGhJiJePXylqZszICORlXvLGjUuDqWIQaxFo1eHuHCGKfxwHzjAgRm3kHX5Zfzv6vO59qVVFFY5qfNTOe+wGIuGosDc68/gzOx2JxyDKTmZ+EmTiJ/UMAOmezwc/MPXjTchEtqMrUVT2V9ex6AYWeon2pbolNmJTYVZi2HvIlj2JOxfDiYrGDounwGGgTU2EcbeBUOuBbt8cIQQItp2FtbwyKdbWbu/AsMwcB91sV1S4+LJb3fz7wW7Oadve3738750TIxuv6RvthWiqY1nl0K5gPPoBiO7JzOyWyrzV+4/MspvYKAqCm6vznn92nPL+B4Mzkzi3TUHMGsqXv3Yi9dQEhFoKHEdDj7d4Lsdxbyz+gCHqhx4fAaJdjPje6ZxzagupMVbuXZ0F15Ztpfjr35DKWJgABcM6BiWn+F0UufyUlLjwuHxEWc10T7BhsXkv32Mc8cO8mbNpt3tt5E8fToAX983gcW7SnhhUQ4b8yqxmFR0w0BBQTcMUmIszJ7Qg0uHdQ5b5UU3gdvdhDRjq0Ct0/8SVCGiKXo1S1UVsiY1/KsqgKIt4Kxi60EH/9uv8ttZ14W8xEIIIUR4LNtTysx5a/32hjnscJWu/205xPKcUt6eNZo+HfzsgY2Q0lo3Tj97pUJdclft8PLA+b2555yerNtfQVmtG6+uk2A3MzQziaSjlrON6JaCfvxUGKElIhaTyqTerdvfx+nxMWdRLq8t34vbpx9TGh5gc0EVzy3cw1m90njgvN4M65LMityyRq8TTBEDs6oc6dclQmcYBhvyKnlxcS7f7ijGpCqoioJPN1AUuGpEJjeN7U6X1Jgjz6lft478u++hw+9+S8IFP63k0VSFSX3SmdQnnfyKenYW1lDj9GIza3RKsjEwIxElzNdaVlPjgixH4gtlxtYgahUuhWhK22gAkZjR8A+Ib1/Dgi3r+K0kUkII0SZsyqvkltfXBl0NTzegot7DVT82oO2cHNP8k8LA49P9tXoKecmd58dqEmZNZXSP1CYf2yMtjn4dE9iQV9novqAqAdJQUe0XY7s2G1ewKuvdXPfSKvYU1+IMsFTx8BLGBduLWLanlLsmZ7N2f3mjpX7QfBEDk6YyY1z31gn+NJNXXs/Nr62hoNKB0+NDN+D43XNvrNzPW6sOcFavNJ6cPgTf8qUcevg3dHr8ceLGBS4f3jk5JiqfRUVRSImxUFbXeB9gKDO2bp8e9dluIfxpG8nUUbqmxlJQ6cDl9WE1yQiEEEJEk083mPH6mqATqaPVOn3c8eZ6Prmz+WVt4ZBgM2M1qY32OoW65C7RHtqSu1snZvHLdzf63afSXCICMLJ7SqtdNDo9PqbPXUlOSa3fxOh4htFQTfCJr3YRY1ZRFSPgXjF/bGaVZ64ZSmZKdBLok9muohqueGE5tU6v30GAwxr+jgaLd5Uw9f99yRPf/Zs+zz+HfciQiMUaqutGd+WFRTmN3kuhzNgOzEikQ6It0qEL0aw2l0xZTCqdk+zsK62nd4f45p8ghBAibL7bUdyi3jAAPsNgZ2ENu4tq6Nk+8t/no3qkoPpZ5RDKBZzVpDIxxCV3k/uk0z0tlp2FNUElMEezmzUevqBvSM9pyl8/28be0rqQ4/DpBroB953bi38v2IXb63+W7zCTqmDWGhKpyX3bn2DUp5/iGifT566kxuENuCTueC6vzgGXjz9f/DAfDBwU1vhO1LWju/DCohy/9wUzYxtr0bh1gvQZFW1Tm0umALLS48gpqZVkSgghouyFRTmN9tdAkKXFAY+u8/LSvTx6eeQv9vp3SiQzxc6uosZVY4NdcgdwzaguIR3XpKm8MWMUU59ZSlGVC/fxTacCsJlVnrt2GBlJdt5ctZ+88npqnF5SYi3065jAOf3aY9YCb+Y/Xp3Lywfr8wPOLDWXDBtAvM3E+7PHMmdxDl9vK0JVftobBw2V4AwDLh+ewS3jetCtXWzQ8YmfPPXtbqodHr+JVJMNbTUTuyo9fL2tiCkD227Bj/R4GxN7p/H9jhK/n4emZmwVINZqYlKf1t1HKERraZPJVHZ6HHuKmyiZLoQQIuxKa11szq9qdHsopcV9OvxnQwH/77KBYd/o7s9tE7P47cdb/BbOaG7JnaLAhF5ptIuzhnzcpBgLn901nptfW8P2Q9VH9r/4E2PRUBz1PNBN4eMNBdz6xjpURTlmRjDWqqEpCteN7soNY7vRPqH55U7/2VDgd2YOgkuG690+XliUw+IHJ/HMNcOoqHPz300HyS2ppdrpITnGSp+O8fx8UEdiLG3ycuKkUO/28uG6Arx+3iDB/p2eX5jTppMpgMevHMyFTy7hYJWjUVPrptgtGvNnjPJbmVOItqBNfvtlp8WxaFdJtMMQQojTWnG1C4tJbTSSHEppcWgoL+706FGpxDVlYEee+W4P+8vq/V6sNsVu1njw/N4tPnai3cwHt45hY14lLy7JZcH2YqwmtWHKR2kobJGZHMOtE7I4mJvPo6sL8ZgO+k26Ds8OvrR0L/NW7OfVm0YwoltKk8d/dflev0lkKMlwWa2brQerGZCRSHKshRvGdgv59yCa9ummg36LF4fyd9pdVMOe4hqy09vuip4Em5kPbx/L9LkrKahwNLsXTwVibSbm3TxSViqJNq1tJlPpcby8dG+0wxBCiNOay+vze5EXamlxTVFwenxRSaasJo23Z41m6tNLfyxpHlxCZTOrvHDd8BPe66UoCkO7JPPctcOpqHOT8+Osjs2k0T7RRlZaHM8t3MNzP1Ti0sx+G5seze3VcXt1fvHyaubPGMkZTSRURVUuv7eHkgyrisLBSgcDMhKbfaxomU83HfKb9Ibyd/IZBgt3lrTpZAoalvt9euc4nlu4h/krGvq2HV+oxW7R0L0+Jpbv5Hd/m0VmqjR/Fm1bm0ymeqTFsre0Dl03UGVaVwghoiLeZkb3k3yEXFpc14m3NZxuquo9fL75EAWVDmqdHlJirfTrlMCk3mmYQtgPFIr0eBuf3z2e619exf6yehxuX8CcJcaiYVIVXr1pBMO7Nj3zE6rkWAtnxB77mot3lfD0t3tCrpbo8Pi48dU1LHxwYsBliIH2aoWSDBuG0aJKjiJ45X5KhkNofyePz6Cs1v/rtDWxVhMPnt+He8/pxYJtRby/Lp/SGhde3SApxsx5/dpz2bAMym+cT8LyXjD159EOWYgmtclkKt5mJtFupqDSIeVVhRAiSjJT/JfnDrW0eNeUGHYU1vDikly+3FKIqio4fhyNVoAYq4ZZVblxbDeuG9O1RXuUmtMuzsoXd49nZW45z76zhNXVGla7BcNo2Bvl9Rl0TrZz64QsLhzUMWINZ//59c4WV0v0+HTeWnWAuyf3bPTcercXs6bg8jZ+3VCSYVVRiLeayCuvp7zOjQEk2c10SYmRwc5WEmgrYaiDFifb38OsqVwwsCMXBNjrpd13H4f+8AcSfnY+ijm09gRCRFKbTKagYanftkPVJNjMxNlMsvFQCCEizGrSuGpEJvNX7j+mtHYopcVjLBp9O8ZzxQvL8XgNfMaxc0IGh/cD+Xh+UQ6vLNvLvBmjGJKZdMLx+3SD73YU8+H6fIqrnfh0g+QYC6MKt/LwpCFUDR5OjdOL3azRKckW8SVSe4pr2VlY4/e+YAoPuLw6ry7by+0TszhQXs+GA5VsyKtgw4FKckpqA543Q0mG691efvPxFiod7iOVBL0+gzibiZnjuzPtjEySYiwn8FsQqXH+f3+h/J0sJpWU2FPr7xA7ehSWzhlUfvgRydMbV9kUoq1oc8lUTkktryzdy8rcMpbtKcWkKXh1gy7JMcya0INLhmQQa21zYQshxCnppjO78+aqAxy/mSfY0uJur87CnSXHlNMOxOXVcXl1rp67kndmjWZwCxMqh9vH3MW5vLZ8L26v3mhPxkprX55dr3Iph7jr7Gw6JbVOg9xQvb58n989XKEUHqh2eBn8yNck2S0M6ZLE0MwkLh3amf6dElieU8pdb29oVNo+lGTYAAqrncCxJdEdHh//+mYX//x6F/ed24vZZ/WISrXGU8ElQzJYt6+i0fs0lL+TQkN/s1NN2n33kX/nXSRecjGqTRr2irapzWQlh6oc3PnWBrYUVOHTjSMnmMOjofvL6/nb59v5y2fbuGVcd355bu+TbkpbCCFONpkpMZzXrz3fbC9qlBA1V1pcVcB71Pd5sBweH9e/vIrFv5oU8qxHaa2Lq+eu5EB5fcBqYU6TFbw67609wOebD/LWLaOjUmBhc0GV399NKIUHDAxuPrM795/XuOrghF7pWE2a3z5hwSbDTf3pDr8fnlywm4OVDv58UX9JqFpgysCO/P4/W/zeF+zfaUBG4inZ48s+cCD2QQOpePMtUmfc3PwThIiCNpFM7Smu5coXllPt9OJr4pv7cLWbl5fuY09xHc9eO0yW/wkhRJg9MW0wV81ZwfZDNc2WMz7MbtYwawrVTj+bdmh+P5Dbp/PumjxmT8gKOs4ap4crnl9OfoUjqATOpzfM7Fw1dwWf3HFmxJf51To9/uMKofCAbhDwvKmpCved05O/f7HD776s5pLhYDk8Pt5fm0/nZDuzzgr+7yUa2Mz+l9Me1tzfKcaiMfusHuEMMarS7rmH/b+4gaRpV7KnrqGU/MFKBx6fTrt4K2dmtWNSn3S5HhRRE/VkqrjGyVVzV1BZ77/ztz8Oj49Fu0r47cebefTyQWGNTwghTndWk8Y7s8Zw+5vrWZFTiqOJJXuaqmDWFM7pl84324r8PiaY/UBOj85LS/cyc3yPoFchPPTBDxyscoY8E1bv8nH9y6tZ+tDZEb0gswdodBtK4QFNgThb4FP5daO7svVgNZ9sPHhCVfmaS34dHh//+noXV4/sQrxNigWE6u7JPflicyHFNc4mZwOPZ9UUhndN5py+7cMXXJSZe2Sxcvyl3PWPb9iv2PF4dY7OOd9bk4fVpHLTuO5cP7qr7OETERf1ZOrRL3ZQFSCRaurL2+Hx8cnGg1w9skuL19ULIYRoXkGlg9eX7WPtvnKcHh2FI31n0RQwaQomTcPj0/n5oI7MGNeDeSv24fEzixXKfqB6l5flOWWM69mu2RiLa5ws2FGMO8DMWVPnEwOodnpYvKuESRHcd9IjLZatB6saXTyHUnjAZtbonBy46q2iKPz90oHE2Uy8ufIALq8vpIt1CC75hYZqch+uy+fGM7uHdgBBUoyF92aP4dLnllHpcBOgqv0xrIpB9+oinps68pTd9uD0+BoGccz9cXh1oPEvps7to87t45nvGnpXvTNrND3SpDeViJzwNPUIUrWzod+Iv1HE6tUfU/7tiySOnkbnO98g47ZXiR82BcfuVUce4/L6eHFJbiRDFkKI00ZFnZsbXlnN2U8s5LXle6l2ejH4qRSFASiqgs+A0T1SWPmbyfxz2hD6dUpgc0EVflYshbQfyOMz2FXkv9rd8d5edYBAl5PBnE/qXD5eWJQT1LFayw1ju2H1U4L96MID9btWoHucGD4vjpy1VHz/yjGPNYDz+jU9K6GqCr+7sB9vzRzFz/p3wGJSsZu1I78vs6ZgN2tYtMa/wcPJb8q5txHTeyyqxYaimYjJHnVMAQRoWIo/d0kuhhFitiYA6JIaw//uHc+QzGRsJpVAbdesJhWLSWXqsExe6lxG2T13ozudkQ02Arw+nZteXcPyPaU/JlJNc3l1SmpdXPLsMvLK6yMQoRANojoz9eHafFQ/m1WDHbnUDfhmWxEVdW6ST7GSoEIIEU0HKx1c9txyyupcfvdxHHb4vmV7yrj77Q28fMMILCaVWn8NjghtP5Dbp7Mxr4KVuQkk2hv6DybYzcRatEaFDuav3O93P1coM2Eb8yopqnbSPiEyVcOGZiaRHm9lf1njC79gCg+YVYVpZ2QG3RNraJdknrtuOOV1br7eWkhJjQunVyc5xszAzolcPXdlo+eEkvwClNW6Kap20SFRKq+1RHq8jQ9vG8ue4hpeWbqPjzbko+sNy2e9uk6c1cRNZ3bnmlFdaBdnxdAHcvChX1Nw/wN0fvLfKKaoLzhqNY9/tZONeZU4g9ynCWAYUOvycs2LK1n04KRTdsZOtC3RTabW5/tdwx3Kl7dJVVi8u4SLh2SEI0QhhDjtVDk8TJuzgpIaV6O+UIE4PD7W7Cvnvvc28szVQ7EHuMAPZT+QAmw/VMO/vt5FlcNz5J/Hp5NwOLmymYi3mSitdft9jVDOJxaTysFKR8SSKUVRuOvsbH7/n60tKhChaQo3ndkt5OOmxFqYPrLLMbcV1zixaGqjC9dQkl9oaMRa5fBIMnWCstPj+ftlA/nLJQOocXqod/uIt5mIs5qOGUhQVJVOf/srebfeSuEjf6HDn/90SlRUdLh9zFuxv0UNrXUDyuvcLNpdwqTep165eNH2RDWZqqg/8UpGXt2gvM7/SVQIIUTonv5ud0OTWz+JVFMXMU6Pznc7ilm6p5Tu7WLZWVjTaD9sKPuBYqwavzy3FxcM7HjM7W6vTrXzp+SqtMbFspwy/OV9oSYDDnfLizS0xOXDOvP9V6tZ4NRwacEXbrCZVR67fBBdU1unHLaqKH73LoeS/EJDqXaZDGg9mqqQFGMhKfC2OBSLhYynnmL/L35B6XPPkXbHHZELMEw+3XSQQDlhMHv46tw+5izKkWRKRERUkyk9wIhnKF/ehmGEvJlWCCGEfy6vj7dX5+H2s7QvmIsYh9vHI59uw6z5vzgPpREpKJzdt/HFkMWk0i7OSrs4K9BwHjhcFON4ISUDRtOV8cKh7KWXuHvRh1iu/iNf5lQ1W3FPAaxmlb9ePICLWnFFRqLdjNfP3zyU5BfA4zVk2X0UaHFxdJkzh31XX4MpPZ3kK6+MdkgnZO6S3CPtcFU8eJ0AACAASURBVI4WyrLdDQcqOVTloGNidJpyi9NHVAtQJNr9j8Id/eXdHJOmBnwdIYQQofnf5kL8TfGEUoggp6SW60Z3oWOApV4JIy8j+ewZVK14l/ynryX/+RupWf8Z9p4/LcUzawrXjMzEamp+NkRRFDIDVLQL5Xzi9ul0SWliCqCVlTz3HFUffUyPefP41w2j+ccVg+jVPg67WUM7bljealKxmlQm9k7jnVljuOKMzFaNxaypjO6R0uj2UIphAHRPiz2S5IrIMqWlkfniXEqeeoqa77/3+5hDVQ62HqxiS0EVeeX1bbZYSKACEqEu291XKoUoRPhFdWbqvP7t2Vta12jTcCgjl17dYGxWaqRDF0KIU9IH6/Kp8zMiHNJFjKbi0+HOSdn89fPtLdsPpCrcMLZb0HHPPKs7f/9iR6PR7GDPJ4quM85cR1x9NcQ0X4r9eOV1bt5dc4C3Vh2gvM6NVzewWzSGZCYx66wejOmRemQvi2EYlDz5JLXffkvX+fMwtWs43kWDO3HR4E5sKajirVUH2FtaR73bS4LdzNAuSVw7qmtY93PNnpDFxrzKRn//YIphAMRaNG4LocmyaH3W7t3JfOYZ8m69DdPzz2EfMgSH28enPxzkhYU5FFQ6MP9YJtCn6yTaLcwc350rzshsUwPTgVochLpstyZAY2whWlNUk6nrRnXlhUX+S5sH++U9slsKnZJkClcIIVpDaa3L7+2hXMS4vDrldW7untyTpTmlfL+jGGcTjX6PZzOrPHHl4Cb7Jx3vkqGd+evn2/3eF8z5xGbRuNq9l5wpFxJ/9tmk/OJ6bP36NXvcijo3v/vPFr7ZXoSqcMzP6fLqLNpZwuq95STazfz2wr5cOLAjxY89Tt3KlXSZNw9TcnKj1xyQkcjfLxsY9M/eWsZltyPWavKbTDeX/AKgwAUDO4QpOhEs++DBdHr0/5F3511s/sNT/H5pQ/PswwMNRw9gOzxOnvh6F499tZNfntuLWWf1aBMFLCwm1W91zlD38EV62a44PUX1XZaeYGNcVju+31Xsd+Nwc1/eMRaN2RN6hDFCIYQ4vbTKXlYaVg0oisJT04dy77sb+W57cdD7gf5+yUB+PqhTSHHHWU1cP7orb6zcj8NP4tbU+cSsKvTskMDkO36F74FZVL7/AXm33Y4lM5PkX1xP/OTJKFrjnzm/op4rX1hBaW3g8vEGDRex9W4fD7y/iY3vf8G129fQ9dVX0JLaVsN5VVV4cvpQbnptdUjJLzQkwP+8cnBQyzJF+MVNmMBXV9/PM9/sb7awyeHP5b8X7Kag0sGfL+of9YSqY6KNfX5aBoSyh8/tjeyyXXH6iuqeKYDfXNg3YAndplg1haGZSZyZFfpyDCGEEP4ltcJeVoumkhTTUITArKk8c/VQ/nLJALq1iyHGojWq9nZ4P9CEXg37gS4b3rlFsT/0sz4M65KMzRz8qc2kKqTEWXjtppEoioIpOZl2s2aSveAbkq+5mvJXXiXn3PMoe/kVfFVVR55XWe9m2gsrKKp2NtmH62hOj858Vxrf3/6XNpdIHTYmK5V/Xjk4pN+hzazy2yl9+dmAjs0/WETEF5sP8myhNaQKkQ6Pj/fX5jNncWSbV/tzy/juxFhOrKF1/04JIc1uC9FSUZ//zE6P45UbR3Dza2v8Vm7xx6oYZFYX8ew5/aUhmxBCtKKfDejI5oLqRrNIoexlVdWGJWOHKYrCFcM7c/mwDDblV/Hmyv3sK6vD4fGRaDMztEsy143uesK9iUyayis3jeDutzewZHdps+cUu1mjY5KNd2aNJuW4CnSK2UzClCkkTJmCY/NmyufNZ8+555Fw4RRSrr+ev6ytpqTWFXI1WZdq4h/f7eX8YV3JaKNL1C8c1InUOCu/fHcjlQ4PDrfPb6XEWIuGzaLxj8sGcU6/9hGPU/jn9en85uMtAWcXm2pv4PD4+L9vdnP1yK5R3UN1ost2Yy0at8r+PREhUU+mAEb3SOW92WO46bU11Dvc1Hn9n53MqoKqKkzqncYfjCpKZ88i9o35mNOlj4AQQrSGy4d35h9f7vB7X7B7WbPS4ujdIb7R8xVFYUhmEkMywzcrYzVpvHDdcL7bUcycxblsyqtEN4wjs0eqAjazRnq8ldsmZnHxkAxszayOsA8cSMbjj+EpKqbinbfZduNMPh1zDx7F//OauliFhgIU85bv4+EpfVv3h29Fo3uksuzXZ7Myt5y5i3NYvKv0SN8f3TAY2T2F2ROymNAzTQY125jvdhTj8flPpIJpb6AqCh+szWPG+Ohto4izmph2Ribvrclr1Egaml62qygQbzNxdh+5NhSR0SaSKWjYcLvy4cl89NCjvBXXl61uKxaTiqI0fHEbBkw7I5Mbx3ajW7tY4AxK62vIm3ELXefPa7NLJoQQ4mSSaDczZWBHPtlY4HfWJZi9rNEeEVYUhcl92zO5b3sOlNXz1dZCimuceH0GqXEWxma3Y2hmUsj7Qszt00m/5x4+GzwF9eud4Od6NZiLVbfP4M1VB7j/vN5YTFFfbR+QoiiMyUplTFYqhmFQ7/ahGwZxVlPU99SIwF5YlEOdq+U9mhweH3OX5HLzuO5R/Tv/7sJ+bM6vYuuh6oDV/fyJsWi8OXM0Jq3tfrbEqaXNJFMARlUlgxZ+zGULbqdCs1Fc7cLp9ZFgM9M52d5o9DB19mx81TUcmDWbLq+8ghbXOp3ghRDidHb/eb1YsL2IGqc3pOeZNYWstDh+NqDtVHTrkhrDzLNad4T97fUHceqNLzJDaSgKsCK3jAm90lo1tnBRFIVYa5u6ZBB++HSDjXmVfu8Lpb1BlcNDfoWDzCgWcLCYVN64ZRQzXlvDpvzmG1qbNYU4q4k3bxlNVlpchKIUog0UoDha1UcfEz95MlpiIu3irPTrlMCwLslkp8f5XYahKArpDz6ArXcv8u+8E93lv6SvEEKI4HVOjmH+jFHEWDSCHZe2mFQ6JdmZP2PkkT42p6qyOrff20O5WNUNg9IaOWeJ1lXj9KAFWHYZSnsDk6pS5Yh+j6ZYq4k3bhnFH6b2o2tKDDafu9F3UqxFI85q4sax3fjqvrPo1ykhKrGK01ebOeMZuk7Fe++SdNW0kJ6nKAod/vQntOQkCn55P4bH/4ffpxtU1ruprHfjC3XHsBBCnGaGZCbxnzvOpEOijVhr4IsvTWmo5nZG12Q+vWvckSp+p7JA55BQLlYNA9wB9rUI0VJNLcs7ur1BcK/VWlGdGJOmcvXILnw5tQOP7fqIa0d14Zy+6ZzVsx2XDOnE3y4dyLrfn8NvL+xHenz4mloLEUibmbOvX7kS1WbHPmRIyM9VNI2Mf/yDvDvv5OBvf0unRx9FUVV8usHCncW8sCiHdfsrMKkqBgY+XTbPCiFEc3q1j2fZQ2ezeHcJcxblsu5ABRatYS+rYTRUDZs6uBMzxnenT4fTZzQ41qr5HbUPpReXpipRrZYmTk3xVlPAZD+UHk0en05yGxsYqf32O0aN7s9Fl0a+obUQTWkzyVTFu++RPP2qFm92VCwWOj/5JAdmzqTor39j8xUzeejDzbg8viPd3I8eBVyZW87m/CrsFo1/Thty0qxbF0KISFJVhYm905nYO53SWheFVU6cHh/xP+5lPR330YzNasfHGwoaXbSG1FDUpzO0ixROEq1LVRXGZbdj8e7SxveF0N6gY7yFjifYqqC11Xz7LR1+83C0wxCikTaxzM9TXEzdihUkTJ16Qq+j2u1kPv887+51cOe8NZTXuY8kUv7UuX2U1rqZPX8t76/NO6FjCyHEqa5dnJUBGYmc0S2F3h3iT8tECmDGuO6YtcYDf6E0FB3VPYWOiW2zz5Q4uc2ekOW34S00tDdIPnsGVSveJf/pa8l//kZq1n+GvedP+/zs+Lh0+XscfOghHJs2YRjR3xrhzs/HW1SEfdiwaIciRCNt4kxY9dFHJJx/PlrciVdf+fZAHc9njMcVQhlNp0fn959sIS3eysTe0pdACCFEYH07JtA9NZbthTWN7gumF1eMRWP2WdJQVITH2KxUEu3mgE2rm2tvoFgszHjp77g+/S8FDzyIlpBA8rXXkjDlAlRbdGarar/9lrhJE1G05vcjChFpUU+mDJ+PyvfeJ+Opp074tdxenfvf3+S3wVtzTRSdHp373t3I2t+dG7ASjhBCCAHw+6n9uPm1NTg9oTUUtWgKvdvHMzYrNdwhitOUoij831VDuPHV1X7fn02xmVX+fukA4tJSibv5JlJu+AW1S5ZQ8dZbFD/+OEmXX0bS9KuxdM4IU/T+1Sz4lpSbboroMYUIVtSX+dUtXYqWkoJ9QP8Tfq2vthai+9l4Wb36Y8q/fZHE0dPofOcbZNz2KvHDpuDYveqYx7m9Ot/tKD7hOIQQQpzaxma1488X9cdmDv40atEayse/dvNIKXwkwmp0j1T+eeXgkN6fNrPKL8/txSVDOx+5TdE04idOpMvcuXR7520Mr499V1xB3m23U7t0GYYe/oqU3vJynNu3Ezu2+ZYDQkRD1JOpw4UnWsMLi3Ia7ZE63EQx5dzbiOk9FtViQ9FMxGSPOmazJTTsoXphUU6rxCKEEOLUdtWILvxrWsMFq72Zi1ZNVRjYOZH/3jVOqviJiLhwUCdeu2kkHRJsxAbYQwUN1SmT7GaeuHIws5pYfmrp2pX2v36I7O8altwVP/EEuVMupHzePHw1jZe8tpba7xcSO3Zs1JYYCtGciC7z8+kG9W4vdrOGSVPxHDpE/bp1ZDzx+Am/dlW9h11FjT/MoTRRBNiYV4nD7cPexBePEEIIATBlYCfOzErj/XV5zF2cS53bi4qCQUOfHrdXZ2xWOzbmVfDHn/cjwSaJlIic0T1SWfHw2SzPKWPu4hyW7C5FUxUUFDy6zpDMJG6dkMXkPumYgmy2rcbEkDxtGklXXolj/Xoq3nyLkmeeJeGCC0i+5hpsvXu1KNb8inrmLd/PfzYWUO30YBgN+wuHV5Zyy+jJdG7+JYSIirAnU1UOD++vzeOlpXspqnKiqQo+wyA5xsJlSiGXTrkYNSbmhI9TUe/GrKl4fMfOTIXSRBEalmFUOtzYLVJlSQghRPMSY8zcMr4HN5/ZnR8KqiiqduL26iTYzfTrmEBavJX5K/fz+Nc7mT9jVLTDFacZRVE4M7sdZ2a3Q9cNalxeDMMg3mY+oT3iiqIQM3w4McOH4ykupvK998mbORNL164kX3sN8ZMno5ibHzw4UFbPQx/+wPoDFeiGgcf303YNl1fnO1MHlm03k/GvRfz1kgGM7iH7DUXbErZlfm6vzm8/3szIvy3gn1/vorDKiQF4dQPDgPI6N/Oq4rnUNZhb31hHrct7QsfTDQN/XwmhdvxueK0TCkUIIcRpSFUVhmQmcX7/Dkwd3IkJvdJIi7cCcNUZmewvq2dFTlmUoxSnM/XHZtFJMZZWLbZlTk8n7c47yP52AcnXXE3FG2+y55xzKXnuObwlJQGf90N+JRc+vYRVe8twefVjEqnDdFXD4dHZU1zLja+u5qP1+a0WtxCtISzJlMPtY/rcFXy4Ph+XV8fh8Z/IuDUzbt3g+x3FTH16CWW1rhYfMynGckxT3sOObqIYDI9Pl/XsQgghWpXF1LC5/7GvdrSJvj1ChINiNpNwwQV0fWM+mXPn4C0sIufCn1Nw/wPUr99wzHt/X2kd1764ihqnN+hBbKdH5zcfb+Z7KRYm2pBWT6Z8usGs+WvZerA66JKcLq9OfrmDa15chaOJJrtNSY4xk5HUeGleKE0UAbLT44g7TRtRCiGECJ+LBnfC4faxYLtcCIpTn613bzo+8meyF3yDfdBADj78a/ZedjmVH3yA7nBw9zsbqHP7X5VUt20hh16/lwP/uoL8Z66n6L0/4szfCjQkVHe+tR5ngIF6ISKt1bOGL7cUsm5/hd+muU31evLoBvvL6nht+V5um5gd8nEVReHWCVk88tm2Ro3qgmmiCBBr0bh1gjRSFEII0fpUVeGB83rz+Fc7ObtPuvQ0FKcFLSGBlBtuIPn666lbtpyKN99kzfPz2DliJrqfMf3q1R9TteoDUs+7A1v3YSiaCcfedTh2rzrSGxTg000HufKMzEj+KEL41erJ1POL9vjtuh3Mh8Pp1Xl56V5mn5XVoh4cFw3pxJ8/3eb3vuY6fkNDQnbBwA4hH1cIIYQIxuS+6Ty3cA//3VTApUOlPpk4fSiqStz4ccSNH8fT81bg3dZ4/+DhdjapU+4lpvfYI7fHZI8iJvun4i2HW9lIMiXaglZd5re7qIY9xbWNbg+l15PD7WPx7sCbFZsSYzHx2yl9sJtDL2tuM6v8cWo/rCYpiS6EECI8FEXhVz/rw7++2YXb44U938Ind8Abl8P8S+Gj2bD9U/CdWFEmIdqyz/dU4fNTNiyUdjYFlQ4OlNWHIzwhQtKqM1P/21KIx08RiFA+HHVuHx9vKGBi7/QWxXDdmG4crHLy6rJ9AQtfHM9mVrl9QraMcAghhAi70Z3t3Gb5Au/jM7HgAHfdsQ/Y8RloZhg5C0bfBvbk6AQqRBh4fTr1Aa7PQmlnY9ZUSmpddEk98fY6QpyIVp2ZKqx24ieXCrnXU3F1y6v6AfzqZ334zZQ+P3alD3zMGIuG3azx54v6c/c5PU/omEIIIUSzakvgxbOZXjOfGHdp40QKwF0LjgpY+m94/kyo2BfxMIUIF6/uv5UNhN7Oxt8AvhCR1qozU3qA2pZHfziCSai8+ol/OK4f041Lh3Xm4/X5vLAol7Ja15Hu3h6nizSrwu0XDuDiIRnESvU+IYQQ4easgpfPgaoCVN3T/ON9Lqg5BC9OhluXQkLH8McoRJjZzBqKooCfFgFHt7OJ7TOuydcxDKSVjWgTWjWLSIuzogDHfzxC+XAApMRaWiWeOKuJ68d047rRXTlY5aSy3g2A+s3/SNixmYxRU1vlOEIIIUSzPpoF1YcgmETqMEMHZyW8dRXcujh8sQkRQf07JfBDflWj249uZ6OoGrbuQ1FUE859G3Ee+OGYffaGYdC9XWwkwxbCr1Zd5je+Vxp2S+OZp1B6PcVYNH42oHUr6imKQkaSnf6dEunfKZHuo4bg3LihVY8hhBBCBFSZB7nfN8w2HaXbv2tIf7yGOvdPw5AvrXcz8bWjlv/pXijbDQXrIhWtEGF124QsYq3+VyoljLyM5LNnULXiXfKfvpb852+kZv1n2Hv+tO/erClcNSITWwsKjgnR2lp1ZmpEt2SSYyzUux2N7gu215MCXDAgvEsZrD174i0uxltRgSlZNvYKIYQIs9Uv+l3WBOAz4MlVbn4z3hr4+V4nLH8WrmzcbF6Ik825/dpjUlXA/96o5trZqIrCjWO7hyk6IULTqslUQ+PcHvz9i+04PI33PTX34TCrCtNHdgn7SIOiadgGDcSxaRPxEyeG9VhCCCFOc4YB614Fn9vv3Q+OtfDYMhe3j7CQZAuwNd/QYedn4KoFa1wYgxUi/EyaysMX9OHPn24LuvLyYXazypSBHaWKn2gzWnWZH8AVwzPpkGgPubO7AsTbzdw6Iau1Q/LLPmQIjo0bI3IsIYQQpzF3HXgC98M5o5PGxG4mnljeTCVb1Qw1ha0cnBDRMX1kF24Y2zXk3qDZ6fE8evmgMEUlROhaPZmyWzTemTWadrEWzEEmVKoC8TYT78waTVp8E8scWlHM0KE4NkgyJYQQIszcdaA2vRDkkUlWnl7tpqSuiWq2igLumlYOTojo+fUFfbnv3J5YTSoWU+BLUk1p6Ak6rEsSZbUuKur8z/IKEQ2tnkwBtE+w8cU94+nbKQG7WaOpnCrWopGZEsNnd42nV/v4cITjl33wYJybN2N4pcu8EEKIMLLGNRSRaMKAdI2f9zLx6NImLhINA6wJrRycENE166wsFj04iVnje5BgMxFn/elfvNWEzaxy+fDOfHLHOD66/Uymj+zCzHlrcbhDWx4oRLiErcFSapyV/945jk15lby4JJdvthVh1lQOtxZwuzwMN9Vy1w3nMKZHakPPgQjSEhMxdeyIa9cubP36RfTYQgghTiPmGLDENZQ4b8KfJ9oYNqeW+8cEWKGheyG+davdCtEWdEi08cD5vbnnnJ5szKukrNaNTzdIijEzqHMi8baf+knddXY2e0vr+OV7G3n2mmGoIW4rEaK1hb1b7eDMJJ65ZhjVTg955fXUOr3EWk2kHNyH4w8Pk/XX6eEOISD7kMHUb9woyZQQQojwURQYOQuWP9VQlS+A7BSVq/qbeWq1m4Hpxy0cUTTofylYpK+OOHWZNZUR3VKafIyiKDx6+UCue2kVj3+9k4d+1idC0QnhX1iW+fmTYDPTv1Mio3qkMiAjkY7DB+KrqcF94ECkQmhE9k0JIYSIiBEzApZGP9ofJliP6Tl1hMkCY+4MQ2BCnHysJo0515/BF5sP8d7avGiHI05zEUumjqeoKnFnnUXtwkXRCkEq+gkhhIiM+A7Q50Iw2Y65ed+98ZzT46dFIpmJKs7fJbDwxqNmoDQLdBwMHQZEKloh2ryUWAsv3zCCx77cwfKc0miHI05jUUumAOImTKB2UfSSKUuPHviqqvCWyodQCCFEmF38LKT0aEiOgqWaILYdTH87fHEJcZLKTo/jqelDufvtDeSU1EY7HHGaimoyFXvmWBwbNqDX1UXl+IqqYh88WGanhBBChJ8lBm7+smGWyRxEw1GTHZK7w8zvIabpfSRCnK7GZrfjwfN7M+O1NVIyXURF2AtQNEWLi8M2aBB1K1cSP3lyVGIoH3gGHy/fj9u3E5+ukxJrYXzPNPp2lPKzQgghWpktEW78An54F5Y9CdUF4HECh/tLKbhVGw4tnsTJ98Ow66XohBDNuGpEF3JL65g9fx3zbxmJ1RRaI2AhTkRUkymAuIkTqF24MKLJlK4bLNpVwguLcthQmIHh9eL5fg8AZk3hX9/soltqLLdNzOKCAR2bbCQnhBBChMRkaUiShl0PBetg26dQewh0HeI7UJo+nos+VVgxYjJmTc4/QgTjofP7cPub63n4o83888rBEW+5I05f0U+mJkzgwCuvYhhGRN74To+PO99az/KcMuoPN3w7qjO9x2fg8RnsKKzh4Y828+KSXN6YMYqkmBDWuAshhBDByBje8O8onYAuy5exeFcJk/u2j05cQpxkVFXh/64awrQ5K3j2+z3ceXbPaIckThNRH/Kydu+OYrfh2r497Mfy+nRufHU1S3eX/pRINaHe7WNnYQ0XP7uMGqcn7PEJIYQQAJcN68xH6wuiHYYQJxW7ReOlG87grVUH+OyHg9EOR5wmop5MQeSq+v3ls21syqvC6dWbf/CPPD6DQ1VObp2/LoyRCSGEED+ZOqgTi3eXUFUvA3lChKJ9go2XbhjBHz/ZyvoDFdEOR5wG2k4yFeZ+U1UOD++sycPh8T8jVbdtIYdev5cD/7qC/Geup+i9P+LM3wqA26uz7kAFu4pqwhqjEEIIAZAYY+asnml8KqPrQoSsX6cEHr9yELfOX0deeX20wxGnuDaRTMWMGIErJwdveXnYjvH+2jzUAHuyqld/TPm3L5I4ehqd73yDjNteJX7YFBy7Vx15jMer8/LSvWGLTwghhDjaZcMy+Gh9frTDEOKkdHaf9tw6IYsZr6+h2s9Wjbzyetbtr2BFThlbD1bhDDDYLkRzol6AAkC1WIgZPYq6JUtIvPjisBzjpaV7/c5K6a46Kpe+SeqUe4npPfbI7THZo4jJHnXk/30GfLKxgD9O7UeMpU382oQQQpzCzuqVxkMf/kBuSS090uKiHY4QJ52bzuzG3tI67nhzPa/eOAKvbvDF5kM8vzCHvPJ6zD9WazYMA92AaWdkctOZ3eiaKu0IRPDaTFZweN9UOJIpj0+nqNrp9z5XwQ4Mr5uYXmOafR1NVSiocNCzfXxrhyiEEEIcw6ypXDQ4g5eW5BJjMbG/rA6HRycpxsyo7ilcMjSDeJs52mEK0WYpisIfp/ZjxutrmTV/Lav3VmAYBnU/FiE7fg/9myv38/bqA/xsQAcev2KwtMYRQWkzyVTs+LNYNPcdXvvvFgpr3fgMg7Q4K2f3bc/47HaoasvLptc6vZhVFbevceEJn6MaNSYBRW2+wZuqKFQ7vS2OQwghhAiGYRj8b0shC3cVs7ekDlVpWCFx2Hc7ivnbF9uZOqgTt0/Kpns7GUkXwh+TpjJ1UEce/OAHjGYe69EN0A2+2lpIQYWDN2eOkgbAollRT6bcXp2PN+Tz/MIcCodej2v5/mPe7B+syyfGYuKW8d25ZlSXFo3C2S0aXt1/BT/NnoBeX42h+5pNqAwDYizyoRJCCBE+Xp/Ogx/8wFdbC4+08fAddxV4+PaP1ufz+eZDzLl+OON7pkU6VCHavLX7yvndJ1uaTaSO5vTobDlYxb3vbOT564Y3/wRxWotqMlXt9HDDy6vZUVjTsJ9JbZwo1bl91Ll9/N+CXcxfuZ/3Zo+hU5I9pOPYzBoxFhO1rsazStaMPigmM/W7VhDbZ1yTr+P26nRIsIV0bCGEECJYhmFwz7sb+W57ccDqs0fzGQ2J1cx5a3ntppGM7pEagSiFOHn84ZOtOD3+B9Trti2kes1/8JTlo1rsmNN7kDh2GrbO/XF6dBbuLGFLQRUDMhIjHLU4mURtMajD7ePKF1aw9WB1UCcMp0fnUKWTi55ZSkmNK+TjTR+ZiVlrvFRQtcaSNO5ayr95gfpdK9A9TgyfF0fOWiq+f+WYx47JSiE51hLysYUQQohgzF+xP+hE6mhOj86M19dSWe8OU2RCnHx2FFaTW1rr975gKjm7vTovLsmNVLjiJBW1malff/QD+0rr/O5jCsRnGFTWe7jp1dV8dvf4kI53w5huzF+xH/xM9CaMvAw1NpmqFe9S+tkTKBY71vbZJIy56shjYi0as8/KCumYQgghRLB03eCZ7/c02Q8x0Cg6gE/XeXdNHrMnyLlK//9HjAAAIABJREFUCICXl+zFc/waWUKp5Gzw5ZZCquo9JMZIsRfhX1SSqdJaF19uKcTlDX3a1asb5JTUsSmvksGZSUEfMzMlhpHdU1iZW+b3gxXXfxJx/Sf5fa6iQEqshTFZsnxCCCFEeCzLKaXOz3J0aBhFr1r1Aann3YGt+zAUzYRj7zocu1cdSaacHp2Xlu5l5vgeJ1S0SYhTxcJdJfj0xtd8oVRyNmsq6/MqmNQ7PRwhilNAVJKpt1YdCHhfMCcMl9fHS0tyefqaYSEd98npQ5ny5BJKalz4jOC3IsZaTLx+80iUAE1/hRBCiBP1ytJ9R0o2Hy3YUXSAereXlXvLGJvVrtHr5NXksbF4I9XuasyqmVR7KmM7jcVuCm0fshAni0CDE6FUcjYMg6r6xk1/hTgsKsnUa8v3+Z2VCvaEoRvw9bYiapyekKr7pcRa+Oj2sUybs4LiGhfuADNjh2kqxFnNvDVzlDRMFEIIEVa5Jf73doQyiq7rsL+snrE/rvTz6T6WFizllS2vsLVsK5qi4dW9qIqKpmrohs7FWRdzbd9r6ZbYrRV/GiGiTw0wCB5KJWdFUdBkplc0IeIFKNxePeAG2VBOGCZNobDKfyPepnRKsvPFPeO5aWw34m0mYq2NP0QxFg2bWeXK4Zl8ee94+neSKi5CCCHCK9BeqVBG0b26fmQ0vspVxdWfX82vFv+K9cXrcflc1HvrcetunD4ndZ46HF4HH+z6gCs+vYKXfngJI4RVG0K0dQl2/3MGR1dyDkaqFB8TTYj4zFSdy4tJU/3OCkWqgW6CzczDU/py/3m9+WprIR+uz6do/yF8LjdpWZlMGdCRS4ZmEGuNehsuIYQQpwl7gD6GoYyim1SVOKuJKlcVV312FcX1xXj0ppcoeQ0vXp+XuZvnUu2u5pdn/LLFP4MQbcmlQzN4ccneRtecR1dyVlQNW/ehKKoJ576NOA/8QPKkm495/PBuyZEMW5xkIp4txFg1vAEq+IXaQDfedmLhW0wqUwd3YurgTpS/+SbunBw63HLlCb2mEEII0RK92sdzoKy+Uc3ZUPohqgp0bxfLbQtuCyqROprD6+DtHW+TlZTFxdkXt+AnEKJtuX50N15astfvfcFUcjZrCteO6oLV1Pwgvzh9RTyZspo0EuxmKv1s5gupga5Pp0Ni6zbQleUNQgghomXGuO4s21NK/XFFKEIZRU8wqtEqPiKnco/fRKpiSQWlX5XiLnaj2f4/e/cdHlWVPnD8e++dnp6QQkiAJBBKpPfQFTvYQUVRRAWxoru6u7qiq+vqb2Vdu4BdcREVewcVAQm9NylJgBAICSF1+r3390ckEDJJZoD083kenl1n7sycIeTe+57zvu9RCO0XSuw1sShBFTeLTtXJ8+ufZ1zKOGSp0baiFISzIi7MwuDkSJbvPuqz8VhtnZyhIgtq0pAO9TlEoQVolDPlTUM6YDZU/2h/N9CVJTivawyhATSfqJPo1CcIgiA0okFJkYTXsJdN6MCriDj3VoozFpDz0g3kvDaZ0vVfY+18osbYapSZ1l3l3Y2v4PA4qr1HwXcFHP74MHET4uj+aneSH03GfdRN9qxstJPSoOweOytzV579LygIjeCpK3sQZAl8ZclskLnnvE4kRNjqYVRCS9IowdSNgzr42Dq3gj8XDLNB4fYRyQ0zWEEQBEFoAJIkMeO8VKxG3zd+wWmjaXvz87R/YCGJd88jZvzjWBK6VT5vVGRGXTyaFSYF/ZT5QdWhcuTzI8TfGE9IzxAkg4Qp2kTinYm4C9wUryiuPNbutfPW1rcQhJYgIcLG/NsHE2ox4G9TPpMiI0lwrthbSvBDo3RYiAm1MKZbLD/tyPPZIr22ZVdF12gfEUSfADbsFQRBEITmYHz/BNbvP8YXG3Nr7O7ni9Wo8N6tg8gt345JMeHWqnbNte+2o3k0QvuFVnlcsSiE9AyhbFsZESNOFNnvKNxxZl9EEJqQtPgwvr5nOPd+uIGdh0rwajpeH5v5BpkUFFnioYu6Em4zMvntNSyYNoSkNkGNMGqhuWi0hOhnr+lJQoQVo+J/ep0sQYik8uSy2XgO5p79QYmaKUEQBKERSZLEv67swbUDEmpcoTqZUZEINhuYd9tAeieGU+opRfeR+6GWqRiCDUg+rrmGMAPesqrdcR3e6mmCgtCctY+y8fldQ/n63uFcOyCRILOCJIEiSSiSRI92oTw7vhfrHj2fGwd3YGzPeB44P5Ub31hFbpH4fRBq1mi9v4PMBhZOT+eGN1aRmV9e5wyc2SATGWTiw6mjCPmmhH3XX0+7F17A1rfPWRmPJGqmBEEQhCZAliUev+wczu0ay50frMPp0ZAk8KgngqQgkwKSxPUDE7l1WBJtw6wAmBUzEtWvZ0qwgrfMi67q1QIqb7EXQ3DV2wGDLLYGEVqmTjHBPHVlD566sgdeVcOr6ZgNss/7wOsGtqfM5eXGN1bx0R1DaBNsboQRC01do54tw20mFk5P58PV+5m7LJMiuweHW60ypxZkVjAqMrekJzE5vSNhNiPcdBOmDh3IuftuYv/2N8LGjW207yAIgiAI9SHEYiDcZuKD2wbx5aZcDhQ6sLu9RAWZ6NM+got7xFVr2Rxji0HTq6fP2zrZkAwSJetKCBt4YiN61alSurmU2GtiqxwfaYmsny8lCE2IQZGpq+v5bcOTKXF4uOnN1cyfOpgw61lsfia0CI0+9WQxKkwemsTN6R1ZnVXIV5tyyStxoeo6bYJNnN89jnO7xqCcUjUYPHIk7d9+m5zp03FnZdLm7ruRZNHGVRAEQWgZ3votm8npHekQFcQ953b26zXdIrsRag7F7rVXeVyxKcRcEUPuvFxki0xw92A8xzzkvp+LMdJIePqJOmSLYmFC6oSz+l0EoTm7//xUSpxebn1nDe/dOhCbqdFvn4UmpMn8a5AkiUHJUQxKjvL7NZYuqXT8aAE5d92NKyuL+KefRracwd5TomRKEARBaAIOFTtYuiufp648J6DXSZLELWm38Pz656vVPUVfEo0SpHB4wWHcR9zIVpnQvqEkTktENp6YjNR0jas6X3VWvocgtASSJDFzbHceWriZae+v442b+4uNfIVKzX4px9CmDe3fexdJVth308148/NP741EzZQgCILQRLyXsY8r+7Q7rf0UL0u5rMbnIkdG0vmpzqS9nka3F7vRbnK7yg17AYyykdHtRxNuER1zBeFksizxzFU9CDYbmPHhRrxq9XRaoXVq9sEUgGw2Ez/rWYJHjiDr2mtx7tzZ2EMSBEEQhFqVubx8sGoff/lkM1PfW8sDCzby0s+7ycwvY8GaA9wytONpvW+wKZhZI2dhUQLL1FAkhWhrNI8OfvS0PlcQWjqDIvP8db0pc3n566db0Hy0VxdanyaT5nemJEki+q67MCclsf+WKbR96ilCzvW9V1WNRGt0QRAEoZ5l5pcxZ2kmX2w8iCxJ2N0nutmaFJkXFu8m1Gokt8hJh6jT299mRMIInkh/gpkrZuJUnXUeb5SNxNhieOeidwgzh9V5vCC0VmaDwpxJ/bjpzdU88fV2HhvXXXSEbuVaxMrUyUIvuYTE2a9x+PHHOfrW2+h+B0jiF0EQBEGoXz/tyOPSF5fzyboDOD1alUAKwP1Hq+bCcjdT3lnDrB9+D+A6VtXFyRfz+gWv0yu6F2bFjEGqPn9qNVixKBau6HQFH4/7mLiguNP6LEFoTWwmA29OHsDqrEL+u3h3Yw9HaGQtZmXqZNZevej44XwOTL8Td1YmcY8+imQy1Xh8mcvLQZeEXTdjtbsJsxrFLIMgCIJwVv3y+xHu+t96nB7/ai0cHpU3l2fh1TT+enG30/rM3jG9mXfJPPaV7GPe9nlkHMqgzF2GQTYQZY3ims7XcGnypdiMttN6f0ForcKsRt67dSAT5mQQajFw2/Dkxh6S0EhaZDAFYIyPp8MHH5D75z+z//apJLzwPEr4iYJar6rx884jzP51L5tzijFiA+0cvE8tpm2YlWkjk7midzuCzC32r0gQBEFoIIeKHdz1gf+B1HEOj8q7K7Lp2z6CC9JOf9WoQ2gHHhn8yGm/XhCE6toEm5l36yDGz84gxGLg2gHtG3tIQiNo0ZGCEhxEwisvc2TWf8i+9joSZr+GOSmJ5bsLuGf+BtxelfI/Uiy8SIACqs7+QjtPfbODJ7/ezoMXduXWYUmN+0UEQRCEZu29Ffvwqr7T9cq3L6Fkzed4juYgm6wYY5IJS5+AJSENAIdH4/nFu88omBIEoX7Eh1uZd9sgrpubQZDZwNie8Y09JKGBtehgCkBSFGL/8hCmpI7su3ESG2c8xczN9jpnB4/nsc/64XcOFNpFgaEgCIJQxa68Ut5dkc223BLKXV5sJoXU2BBuTu/IOe1ONHFwezXmrdqH20cr5ZLVn1G86hOiLrgLS1JfJMWAI2sdjt2rKoMpgMyCMnYeLqFrXGiDfDdBEPyX1CaId24ZyKQ3VxFkNjC6S0xjD0loQC0+mDouYsIENlrb8ujSo7gU//ftcHhUFqw5QNswC9NGptTjCAVBEITm4KcdeTy3aBd788vweDVOXnDacrCYrzbnkhBhY8aYzoztGc9PO/LQfDSR0FzlFC3/gKhLZmDrkl75uK3TIGydBlU51u3VeHdFNk9f1bPevpcgCKevW9tQ5t7Un9vfXcurN/RlUHJU5XOHi528vzKbX3flU+LwYlAkYkIsXD8wkYvOiRMbADdzrSaY0nWdR7d7agykakuzcHhUnlu0iwn9E4kIqrmRhSAIgtBy6brOf37cxZvLM3HUkN2g6eD0aOw5UsaDH29mZeZR4kItOE7p2gfgOrgT3evGljqkzs/WdNh5qPSMv4MgCPWnb/sIXry+D3d+sJ53bhmIIkv83/c7WZl5FJ2KSZHjMvPL2ZJTxCOfbWXiwPbcN6azqNNvplrNT211ViGF5W6fz/mTZiFJ8OGa/Uwf1akhhy0IgiA0ES/9vIc3l2fVGEidyuFRWbjuICnRQfja21N1lCDbQpFk/2aly1zeQIYrCEIjGNqpDU9f1YMb3lyJy6Ph8tZ8vjhet/9uRjaLd+bx4dTBxIQEttm20Pha3D5TNZm7NNPnzODxNIvI86dj65KObLIgKQZsnQYRMXpK5XFOj8Yby7PEbteCIAit0Lp9hby2ZC8OT/XrSG0cHpXf80p97mSoWEPR7CXomn/vGSxmrQWhWQixGHG41VoDqZO5vBr7j9qZMDtDTJo0Q63mzLxsdwG+wqBA0iwcbpU9+WWkxoac/QEKgiAIDUbTdEqcHsrdKsEmAyEWA7Jcc5OhV5fsxVlDIFVXNz6PqiNLcGrZlLldVySDEfuuDIK6Dqt1vIosiWuPIDQDDrfK7e+txVND986aeDWd3CIHj36+lf9e27ueRifUh1YRTLm9Gl7N9+xAIGkWiixRZPec7eEJgiAIDeRAoZ13V2Qzf81+3F4Ngyzj1TTMBoWJA9tzU3oHEiKqbmB7pNRZ44Scv934NE2vyBc/iWwOInzYDRQumo0kK1iS+iDJBpzZG3Hu31wlO8KoSNyc3vFs/lUIglAPvtx00GfDGah74sWt6ny75RCPj0sjzOZ/szShcbWKYKo2J6dZ1BVQicbogiAIzVOJ08O98zeQsfcomq5Xzhp7VPWP//Xyzoos3s3IZlinNrxwfZ/KtLovNhz0ef4PpBvf8a01Tr3FCh14FXJQBMUZCyj4ehaSyYo5thOhQ66tclzHqCC6x4u26ILQlOm6zmtL9lZur3MyfydeZAk+WnuA20ckN+TQhTPQKoIpk0HGIMs+9/gIJM3Cq+mEi5kCQRCEZqWgzMWVr/5GXrHL53XgOLeqAzrL9hQw7qVlLJw+lMggE5kFdp+1D4GkievgM9UPIDhtNMFpo2t8rdWoMGNM5zo/QxCExvV7Xil5Ja5qjwcy8eLwaLyXkS2CqWak1TSgGJHaxufM4slpFvZdGWgeJ7rqxbF3Lcd+eavKsTaTQqfo4IYZsCAIgnDGHG6Via+v5FCRs9ZA6mRur0bOMQc3vLEKp0elvIaC8EC78aXGhmAxBnbZtRqVP/aiaRvQ6wRBaHiHipwYlOp3m4FMvAAUlPnuPi00Ta1iZQpg6ogUVuw96nPp1Z80C4tR5rZhSbUWKAsth6rplDg8qLpOqMWIydBq5h0EoUX536p97C+04/XRibW2+gWPqpNdUM7Haw/UmJEQSJo4QFKbIO4/P5UZH27E5VV9tks/mdWoMGlIB/52cVe/vqsgCI3L6VF9rj4HOvHi8XPiR2gaWk0wNaBjBJFBJuxuh8/n60qz0HW4dkD7+hqe0ATouk7G3qPMXZrJst0FKIqERMVJrVvbUO4YmcKFaXEisBKEZkLXdeYuy8TpY18of+oXHB6V2Uv2Mq2tG6vuxSFVvWQGkiZuMcr0SgznwrQ4vrx7KLOemccSSyKywVCl3bpRkZAliT7tw7lzVCdGpEafhb8JQRAaQojFeGqfGSDwiReL0b+gS2gaWk0wJUkSs8b3YvLbq31eWGtjNSrcf35nIoJM9TQ6obGtzS7knvkbKHZ4KlcvVe+J6aVtuSX8deFmHv5sCzPHdmd8/8TGGqogCH5asfcopc7qKXqB1C8cPVpMUPZqiBoGp1w6AunG51F1bEaFIyVOEgtz+PPa+fzf19/x+dYjbDxQRJHdQ5DFQHKbICb0TyQxsmpHQUEQmr7UuGCf9ZWBTLwAdIkWv//NSasJpgAGJ0fx7NW9eHDhJr8DKoMsMXFQe6aOSKnn0QmN5cdth7n3ww11/ps4vlP5zC+2cvCYgxnnpzbE8ARBOE3fbTnkM7U7kPoFl8HMtrGTGK/pzF+1H88puXn+duPTNJ1nvt/JP7/dQT/XESZPuJ1OoTYmD006sy8pCEKTERNiIT05iiW78qs8HsjEi1X3Mu6bORzM/JywKy4nKD0dydCqbtebnVb30xnXO57IYBP3zN+Ay6tS7vK9CaPNpKDpOookcUmPuAYepdBQ1u0r9CuQOpnDozFnaSbRoWZuGNQh4M/cc6SUd1fsY8ehEspcXoLNBrrGhXBTekexKacgnEV5JU6fjwdSv6ADeaUu/npRVz5Zl4PHR3BWV5r48fc5HtitlCPZdMRA+ntreXliX5HSIwgtyNSRyazOLqw2kePvxIslyMr1816k/IfvyX/lFXIfeYSwS8cSdsXlWLqK+smmqNUFUwBDO7VhzSNjWPL7EWb/upcN+4swKjKSVFEf0y7cyvRRKVzWqx3L9xRw34cb+ebe4YRZRVv0lkTXdR5YUPMqZW3F6Q6PyhNfbWdcr3hCLf79u/h5Zx7PL9rNrrxSPJrGyfWlG/Yf45P1OXSKDmbGmFTGdI89G19REFq0w8VO9uaXUer0YjMpJEbaSGoTVOfrAq1f0HWdxEgbcyb14/b31gacKl7t/SQZu0dj+Z4CbnhjFfNvHyxqMQWhhRiSHEVSmyB2HS6ttpLt3zYIqZijIjFPnEjkxIm4srIo/vJLcu68Czk4mLDLLyd07FiMsTH1/VUEP7XKYApAkSXO6xbLed1iKXd5KXJ4UFWdMKuRUKuhcoPF87vH8tueAv66cDOv3tC38nGh+Vu//xj5ZdX3gwD/itNlSWLh2hxuGVZ7mo6u68z64Xfe+i27SqH5yVQdVI/G1twS7pm/gZvTO/CXi7qKf2+CcApN01m+p4A5v+5l7b5jmAwyuk7lZFhym2DuGJXCRWlxtAmx+HyPQOsXokPMAAzvHM2bNw/g9vfWomq6z9qIQDg9Gttyi/nbp5v5z4TeZ/RegiA0DZIk8d6UgVzy4jKOlrl9dhL1xWqUuaJPO24aUjXjxZyURMx99xF9zz3Y166l+IsvyLzsMqw9ehB2+WWEnHcesk3UWDUmSffVw/EP/fv319euXduAw2manB6VK19dwY2D259WWpfQNE17fy0/bs+r1sZUc5WT88rNRF0yo84brbgwCxl/PbfWoOe/i3Yxd2lmjYGUL1ajwq3DkvjzhV38fo2/JElap+t6/7P+xg1MnJ9an8PFTia+sZK8YmdlDaMvQWYFs0Hh3vM68ez3v/s8tmT1pxSv+pSoC++qtX4hyKQw96b+DO3UpvKxvBIn72fs472MbFRdx+Guuc15bSvcx5kNMksfGk1sqO/grzVpCecncW4SAPJLXUx8fSUHixw+azePUyQJo0FicnpHvydRNYeD0p9/pviLL3Bs3ETIeecRdvnl2AYOQJLPbJW72O7ho7UH+HDN/spg0GZSGJQcydThKfRICDuj92+uajs3tdqVqUBYjAovT+zD+NkZ9O8QSZc4UdfSEvyyM9/nfhCBFKcX2z1kFpSTUsNmziszjwYcSEFFS+Y3l2eRnhJF+kk3cYLQWh0otHP5K79R7PCg1jHTW+6qqIf99/e/+9xAE/yvXwixGElPiaryWGyohT9f2IX7xnTm551HmL9qP8v3FFSbgfZnhfu49zP21cvkiSAIjSM6xMw39w7nh22Hee3XvWTml6Hr4PGqyIDJZEDVdC46J47bhydzTjv/gxTZaiXs0ksJu/RSvPn5FH/9DXnPPINaXEzYuHGEXXE55uTkgMZ7rNzN419u4/tth5Glivrw48pcXr7ZfIjF248QH25h5rg0RoptGyqJYMpPKdHB/PXirtwzfz1f3DUMq0kUDDdnbq+GV/OdohNIcbpBkSiy17xT+au/7KkxkKprxtrhUXn5lz0imBJavXKXl+vmZlBkd9e50e3J7G4VkwRG1YNHqV7bWFf9gkGWGN0lGlXTfQZlRkXmgu6x/P3zrdUCqUDar7u8Gu9lZDNjTGcMiqidEoSWwmSQGdcrnnG94tl5uIQ12cc4+O2PWMLDaD8qnfO6xZ5xPb4hOpqoWyYTdctknL//TvEXX7L/5skY4uIq6qsuvQRDRESt73Gg0M6EORkUlLqq1Xkdp+kV9yV788uZ9v5aHr64GzeldzyjsbcUIpgKwPh+CSzfXcCT32znX1f2aOzhCGdAqyW9NZDidN3rxbnvAFqMFdlqrfLcoWIHq7IKfb7O3xnrdfuOcbDIQbtwq8/3EYTW4OO1Bygs9x1I1TUp4dU0LGYjmibVuaJ1KlXT+XJzLt9uPcxNQzowaXAHYk5JxSt1eX1OqASywg0V9V4FZW7iwkSqnyC0RF3jQukaF0reLwUYIoxE9U04659h6dIFy0MPEvPA/ZRnrKT4yy/Jf+EFbAMGEHb55QSPHoVsqrpn6rFyN+PnZHCkxOn3ZJXTo/H0dzsJsxq5vE+7s/49mhsRTAVAkiSeuvIcLn1xOd9tOcTFPdo29pCE02QxKsiS5DOoCqQ4XfV4cT73b3ZlbkUJC8PYoT2mxPaY2icyj/boPha/Apmx1nWdj9Yc4H6xp5XQSum6ztxlmVVSTo7zZ1JCk2RUZCJtBorsnhpnXX1+NvyxfYbK3KWZvLEsi+cm9Kpy7i91ejHIMh616gp0ICvcALIsUeL0iGBKEFo6xQBaYKn/gZIMBoKHDyN4+DDUsnJKf/yRY//7H4dnziTk4osIu/xyrL17I0kSM7/YytEyV0Cr/lCxSvWXTzczPDWayCBT3S9owUQwFaAQi5GXru/DlHfW0CMhjIQI0UGluRqcHMXyPQXVHg9oc71gG6M+nYesa3jz8nDvP4D7wH48+w+w59Ah3IbqAXcgM9ZuVSezoPzMvqggNGOrsgopsnuqPR7IpIQsSUwe2pGfd+azLbcYj6oHvEp1vHPf/R9txOFRueqPWWWzQfY5KRN4+3WwGET6uCC0dJIio3vrN5g6mRIcRPhVVxJ+1ZV4Dh6k+KuvOfTwI+iain7pFfyY1wGP6vt8WGcDHR0+XL2fO0d3arDv0xSJYOo09EoMZ9rIZO77cCMLpg4WOe7N1B0jU9iw/5jPTl/+FKdbjDJThnVEkSVAwRgfjzE+nqDBFTdy2rtrYUdetfcOdMa6zFn9RlIQWovVWYU4fPyOBjIp4fCorMk6xsLp6ew8XMIby7L4auNBXDXcQNTG6dF4+LMtdGwTRN/2EYRZjT4b2QTaft2jakQFt+7ZXUFoFWSl3lemamJs1442d0wjatpUnFu38tJHGeB2gVL93OPPyr/Tq/Hm8iymjUz5416odRLB1Gm6bVgyy/cc5fnFu0UHpmYqPSWKYIuhxhbLdRWn6zpcP6B9jc+H23wXlQY6Yx1uEzdYQsvgUT2UuEswKSaCjEHIUt0TUQVlLnyFPIFOShz7o66pa1wos8b3IjUmmOcW7cLpY6+oumZjnR6N//zwO+9N6EbZ9z8wvCSfX6yJaCeNJZAVbgkY3imaILO4JAtCSycpMrp6ZnvUnfEYJAlrjx58810BLo+92vOBrPw7vSob9h+jf8fIeh93UyXO3KdJliX+M74XY19aJtpXN1OyLPHM1T2ZPm8dTh/1GLWxGhXuGJlMVLC5xmN6JoTx9ebcarUegcxYW40KvVrpng5Cy+DwOvg+63ve2voW+0r2YVSMHN/fcEyHMdycdjNpUWk1vt5Uw8p/oJMSxpMmTXVd5+0V2T4DKX+bw6zZc4QVl/ydpAE9mHrepaxY56n2u+5v+3WrSWHqyMDaGAuC0Ew14srUqQrLfXcjDmTlX0Iiv9R1tofWrIhg6gxEh5iZNb4XD3y0iW/uHVbrjbXQNI3uEsPMsd154uvtfgdUVqPC5b3jufe8zrUed0Wfdvzzmx3VHg9kxlrTda6sh44/glDfdF3n7W1vM3vTbCQk7N6K2U+3euLi/UP2D/yy/xfahbTjuVHPkRxWPaCIC7NgMsi4vac/KYGuE7R6GXu+fBJLl65s7diT4rK4aocF1BxGlln24H8YfkUvEoAO2UvZlVdarYi7rhVuCWgTbGZQUuud1RWE1kRSZDRP00jf95yFLWJ0XcfdyCttjU0U+5yh4Z2juaJPOx78ZHPlbKvQvEwc1IEXr+tDsNlAUC37h1l9I9lNAAAgAElEQVSNMmaDzJ2jUnj6qh517lIeYjEyrle8zzzi0IFXEXHurRRnLCDnpRvIeW0ypeu/xtr5xCyQIklc0qPtGe9BIQgNTdd1Hl/xOK9tfA2H11EZSJ1K0zWcqpPMokyu//p6NuVvqnbMRefE4aso6eRJCfuuDDSPE1314ti7lmO/vFXl2CCzgSl/vYXEV14h5IIL2F4uVzaUOFkgs7EeXWJt7onmMHMn9T+tND2bSeHNm/vXeT4RBKEF0DQkyYPkLQa18QMqm8n3Oevklf+6yLJEqKV136eIlamz4E8XpDJ+dgZv/ZbNrcOSGns4wmm4IC2OdY9G8+3Gg7zwzmJyw+Iw/dFZy6tphFqMTB2RzPh+iYTVUAvlyx0jU/hm8yEcPk5Idc1Ymwwy00elBP5lBKGRvbjhRb7L+g6n6vTreB0du9fOtEXT+PDSD+kY1hFd11mVVciby7NqbNnrbxqdzWxgeJdYZDkOc6dOaMbfUX/ZU+39Aq3DKnGcWGVrH2VjwdQhTHx9JSVOT51thmWpIsh7b8pAOseG+PV5giA0Q7oOB1bDihdh1/dEHj85/PN5iO4Gw+6H7peBoeGzm/p1iGDxjrxq81WBrPy7vBpp7ULrcZRNnwimzgKjIvPS9X244pXfGJQUyTntRI1Lc2Q2KFzgPUTvw98R9PT7HCv34NU0wm0m2oZakE+jU02nmGBeuK439364IaC6LItR5j8TepIqbrKEZuZA6QHe3/4+LrV6Dv2xZcco+KEA9xE3ikUhtF8osdfEogRVBC92j51/ZDzBZbFP8ObyLOwulSnDkpjQL4H7FmzE7qNZTF2TEhajzO3Dk6r8/lqMMrJEtYAn0Dosi7HqMd3jQ/n2vuE8/e0OftyehyRR7ffebJDRgXO7RPPwJd1pHyW21xCEFit3I3xyC5TmgccO6FSeiXTgyDb4ekbFn/OfgAG3Nujwpo5I5rc9BdXOrf6WI1Q0z2lDTEjr3h9PBFNnSWKkjccvS+Oe+Rv46p5hBIuuTM1S2ZIlBI8eRXSI5aydHC5Ii+Pl6/tyz/wNeDWtxv0cAIyKhEGWeeG63lyQVr2mQxCaug92fIDmY7fqgu8KyP8un4TbEgjuHoznmIfc93PJnpVN0iNJyAYZHZ21hzdQemATD5zfn1GpMciyhK7rXJgWx3dbDwU0KWFUJLrEhnBzescqj8eFWbEYlWo3EIG2M28XYa32WHy4lZcm9qXI7uajtQf4YmNuZSfBMKuRS3u25foB7UWNrSC0dJlLYP71fwRRtXCXVfzvj4/AsWy44Mn6Hlml/h0iiAwyYXc7qj3nz8q/1aQwdYRoniPu+M+icb3iWb67gJlfbOW5Cb0bezjCaSj75Rfin/vPWX/fMd1jWfTACN5Zkc38VfsBsLtVdCpmdixeF7LNxrUDErllaBKJkWK2Wmh+nF4nn+3+DI9WtRZAdagc+fwI7W5tR0jPitVWU7SJxDsT2fXgLopXFBMxIgIAgywxpM9Ozu16aeXrJUni2Wt6Uu7ysmx3AQ5P3Xn8ZoNMcnQQ700ZhPmUzXAvSIvlkc+2VHtNIM1hgkwKNw7uUOPnh9tMTB2RwtQRIlVXEFqdQ5v9C6RO5nHAmjcgOAbS76m/sZ1EkiQeHdud+2rInqlt5d+oeki1yQwUzXNEA4qz7bHLurM5p5jPNuQ09lCEALkys9CcTizdu9fL+ydE2Pj7pd1ZP/N8nrm6J/efn8qtw5K4//xUHji0lN+ubsfMcWkikBKarQ1HNvjcO8q+247m0QjtVzWvXrEohPQMoWxbWeVjXt3D91nfV3sPgyIz+8Z+TB+VQrDZgK2GZjHHG8Vc2acdn9051GeNY6jFyNiebU+7OQxU1DsNTRFbYgiC4MMXd1ULpDo+X0rMs6WUu09kp7yx3s2od040ssFjh5+fhLL8hhopF6bFcf+YVCxG/0MCkyITH2bh8R+eo/Dtd1p9AzaxMnWW2UwGXrq+Dze8sYreiREktQlq7CEJfipbsoTgUaPqvauW2aAwrld8lccOLQ1H27geevWo188WhPpU5CpC97HFrlqmYgg2ICnVf7cMYQYc+6qmmJR6Sn2+vyxL3HteZ6aNTOb7rYeZszST7IJynB4Vk0EmNtTCLekdubpfAiF1dJe6fUQy32w5hOqjU4Q/dVhTRySfVh2lIAgtXN42KNjt8ylVhxdWuXl4eG1pvhKsextGPlQ/4/Nh2sgUwqxGHvtyG4DPbqd/jAyrSaFb21DevmUA1ik9OTD9TlyZe2k7cyaSydRgY25KxMpUPejWNpQZYzpz7/wN1fZGweuC0sNQdABcZb7fQGgUZb/8QvCokY3y2bb+/bGvXdsony0IZ4umaz5nKJVgBW+ZF91HvaC32IshuOq8Xl2znGaDwuW92/HtvcPZ/sRFZD59KTufvJhfHxzN5KFJdQZSAF3jQrl/TCpWo3+d+058tky/9hHcMlR0bhUEwYeMV0H1vRnug+kmZq1wUeSs5RzndcKq1xp8Y9/rBrZn6UOjmToimTCrkWCzQpBZwWZSCDYbMBtkRnWJ5s2bB/DJHUMItRgxxsfT4YMPUI8Wsv+22/EeO9agY24qxMpUPZk0uAPLdxfw7+938vdLu0HWUvjtBcj6FWQjSFLFL1t4exg6A3pcAyaxitVY1OJinNu3EzSk7v1l6oOtfz/ynn4aXdfFfjNCsxVmDvOZ5mfrZEMySJSsKyFs4Ilup6pTpXRzKbHXxFY5PsjYMOfCqSOSsbtV5izd61djC6tRpndiBK/f3N9niqAgCAI7vwLddyDUP15hVEcDs1a4+Oe5tTS58rrh0CZo17eeBulbbKiFP13QhfvO68zqrELyy1y4vRqhViO9E8OJDa0+ZiU4iISXX+LIf54j+7rrSHxtNubk1jXZJIKpeiJJEv++pid/+u9bOLdfhcVbAu4/8mI174kDCzPh+7/B93+FUX+rKDoUN9MNwulRcXk0gi0GypYtxzZwILKlcdp7Gtu2RbZYcGdlYU4WnXGE5ql3dO9qzScAFJtCzBUx5M7LRbbIVbr5GSONhKeHVx5rkAyMShzVIOOVJIn7z0+le3woz36/k4NFTlxetVrL9CCTgtEgc9uwJO4YmYJBEUkdgiD4oOt1Zh09MdrM0LfKuW9QLSlxkgyOwrM8OP8ZFJn0Tv7XhEqKQuxDD2JOTmLfpEm0m/Vso01ONwYRTNWj8NylzNUeQ7HXsXGl548ga8nTFcHV2P+KgKqe7M0v463lWXy6/iAur4oiS3g1nXjNzuS0MUx0eQlqpLb21v79sK9dK4KpZmTTgSI+3XCQ3CIHbq9GVJCJEanRXNwjrloHudYg2BTMRR0v4uvMr1FPmZmNviQaJUjh8ILDuI+4ka0yoX1DSZyWiHxS4bMiK9zY/cYGHfeFaXFcmBbH5pwi3lyWxZaDxZS5vFiMCgkRViand+TcrjEiiBIEwQ+1pymfE6MwNtXAM8vddIuu5ZzSDJs6hF9zDcbE9hz805+IvvtuIq67tu4XtQAimKovuRtgwSQUtY5A6mQeO2xeACFxMOqv9Te2VuhQsYO7/7eBrQeLUTUd7x9Tz9ofNRwHJRv/zZP5zz8XcduwJB44v0uDF5fb+vXHsW4dERMmNOjnCoFRNZ2F63OYvWQvh4qrr2T8sO0wj3y+lesHJjJ1eDIxPtIiWrKb0m7ih+wfUNXqaS6RIyOJHFl7G93O4Z1JDmucCYWeCeG8cH2fRvlsQRBaAEkCo+3E3lE1+McoC33nlPGnITU1otDBGnH2x9cAggYNpOMH8zhwx3RcmXuJ/ctfkJSaJxe9qkaxw4NX0wmzGqttht4ciGCqvnxxj8+2mHYPZN0XTJCp4kb9jfVu5m32sGTyHzUCHjssfw763gSh8ae+q3Aa9hwpY/zsFZQ4vT47dx1n/6Nm4s3l2ew5Us4rN/Rt0LoI24D+HJ0zp8E+Twicw61y+3trWbfvWI17HZX/sRHsuyuy+XhtDh/cNohz2oX5PLY5KXd5WbwjryKA9GiEWAz0SAijf4eIKnV+qRGpnJc4hm+zf0THdxF2TSyKhUcGP3K2hy4IgtBwUs6FnV+Dj83Lj+sUKXNtmpEXV7vpEVPD6lRc8+3ua+rQgY4fzidnxgwO3Hkn7f7zH5Tg4MrndV0nI/Moc3/NZNmeAgyyhCSB26uREh3M9FEpXNKjbbMJrETOQn04vBWO7vH51PG2mLXSgTVvnv1xtUJHSp1cOzeDIrun1kDqZA6Pyq+78vn751vreXRVmZKS0JxOPIcONejnCv7xqBqT3lzFmuxCvzaN9ag6xQ4P187NYHee71bfzcGeI2X87dMt9P/nYh7+dAuzfvid/y7exTPf7eTmt1Yz/N+/8H5GNmWuilrQcpeXfbsuJVxKxaL4vypnUSzMGjmLc9qcU0/fRBAEoQGk3wuGus99M0eaq+w5VUkxQb9bwFBb+/SmTwkLo/3cuRjbtmXf9RNx5xwEYP3+Ywz9v5+57d21LNmVj6rpuLwaTo+GpsPuI2U8+vlW+j25iPczshv1O/hLBFP1YeUZtsVUXbDmdVCrF3ILgXnm250U2z0+M5jLty/h0Lsz2P/cNeS8PIm8jx7DmVOxx4LDo/L5hoNsOlDUYGOVJAlbv77Y165rsM8U/Pd/3+1kW25xjftv1MTuUpn4xio8amCvawre/i2LsS8t46O1+3F4VMrdamWKrFvVsLtVco45+Nd3Oxk9awmb9hdx45ur6BAZyuKJ73NZymWYFBMmueZCa5vBRrg5nDnnz2FkYuNsTSAIgnDWJPSvKNc4RfaMEMYkn0gISwyTcf499ERm0nGSDANvr+9RNgjJaCTusccIHz+e7Ouv47tvMpj4+kpyi5zY3TVPSpa7K643//p2J099s6MBR3x6RJpffdj5zZm3xdQ0yN0IiQPqaZAtX4nTwzdbDlXe/FV5bvVnFK/6hKgL7sKS1BdJMeDIWodj9yosCWkAuLwqbyzL5KWJDdea1NqvH/Z1awkbN7bBPlOom8Ot8sGq/ThqaJ9dvn0JJWs+x3M0B9lkxRiTTFj6BCwJaeiA3e1l8fY8Lu7RtmEHfgZe+WUPL/+8x6+W4Q63itOtcuVrv3FV3wSevqoHkiTx6JBHmdZrGh/9/hH/2/k/vJoXRVLQ0XGrbrpFdmNKjymMTBiJQRaXI0EQWgBJgnEvwgfjweuo+/iT6LIZqf+Uim1zWghJkoi8aRK/hycwY0keLqXufQCPc3hU5q3cR0yImdtHNN3mXOLqVR/qKDz0ry2m1KhtMVuChWtzkH10RdRc5RQt/4CoS2Zg65Je+bit0yBsnQadOE6HH7fnUWR3E25rmF29bf36U7xwYYN8llDBq2rsPFxKsaNiJTjcZqRrXGiVermvNuXW2GDTn8C83KXy2q97m00w9dOOPF76ebdfgdRxOhXNp5buysfp0bCaKnLdY2wx3N3nbu7odQeHyg5R4inBJJuItEQSZY2qp28gCILQiJKGw2Uvwpf3+h1Q6YqFslwzOiMJrefhNYaZe+QaA6naJiQdHpVnf/ydq/q2Iyq4aaY+imCqXrTetphNycL1OT5rW1wHd6J73dhS694DwSBL/Lorn8t7t6uPIVZj6dYVT+4hvMeOYYhonp18mosjJU7eX7mP9zL24dW0ysBb03TMRoUpQzty3cD2tAk28/qyTJ8pCf4G5gC78krZd7ScDlFNf3Pu//t+Z42BVG0XPYAyl5evNucyoX9ildcZZAOJoYm+3lIQBKHl6TkBbJHw8S0VzShqmmg3WAAdaeA0jAnXsn/qHSDJhF54QYMOtz5tySkm55jvoNKfCUkZmL9mP3eP7tyAo/afCKbqgzEIXCW1HtKS22I2FcfsvmvOVEcJsi0USa67S4xX0yksD6wj2ZmQDAak3v34adEanKkVJ5Fwm4n+HSIabf+rlkbXdf67aBezl2Yigc8aqHK3ysu/7OGln/fwwPmpHCzyfREIJDA3KjLZR+1Vgild1yl3qxQ7PBgViXCrCZOhcUtZtx4s5kDh6V/07G6V2Uv2VgumBEEQWp1OY+DBvbDzK1j+POT/DgYTIIHmBaMVBt8F/W6GoDZYgPavz2X/7VMBWkxA9cayTFze05+QdHo13l6ezfSRnRq0y7K/xN1ZfUi9ELZ+WmPdFPjZFjO+dz0NsHXQaljZU6yhaPYSdE2tM6DSdR0/mwCescz8Mt76LYtPYsairHfCti0AyEh4NI0r+yRw67COdIoJaZgBtUC6rvOXhZv5atMh3HU0kji+MvP84t04a+jeF0hgrusVne6gYvXm8w05zP41k8PFToyKjKbrqJrOyC7RTBuRwoCOVVuON5S3lmf5/LsJZBXuULGTzTlF9EwIr/fxCoIgNGkGE5xzdcWf0jywF1Q0GLOGQ1ginHL9sHTr1uICqkU78nzeSwUyIen0qPx+uJTu8U0vCVIEU/VhyN0VTShO2WfqVDNHmnl/s4/VkxbSFrOxhVmNHCquvmmyuV1XJIMR+64MgroOq/U9DIpMmNX/YsnToWk6T369nfmr9+PVdLz6H8G1q+oN/MdrD/DZ+hyu7pfAE5ef0yRnZ5q6l3/ew1ebDvnV2vy42o4NJDCXJAgyKbz0025eWbIHWZIqUwe92onP+HnHETL2HiUq2MTsG/uRFt+we1RtyilC9TEREchFD2B7bokIpgRBEE4WElvxpw4tKaDSNL3G62ggE5KyLFFkb7hMoUCIYKo+xPeu6MSSv7PKw9kzqq4oHG+LWY0kwcCp9TnCVuGCtFiyCsqrpXHJ5iDCh91A4aLZSLKCJakPkmzAmb0R5/7NRIyeUnmsqukM7VR/RfK6rnPfhxtYvOMIzjpWSryajlfT+XT9QQrKXLx2Qz9kEVD5rdjh4eVf9vhO66ujDqgmgQTmLo/K/1bvZ+muglobO1R0/1OxFzoYPzuDtyYPYHBywzVqOL5f1KkCS4/VKHX6fh9BEAShbi0poKqplUAgE5JAg2UKBUrsM1VfLnsJDNaAX2bHTG6XmyBc1BucqRsHdaixFUjowKuIOPdWijMWkPPSDeS8NpnS9V9j7Vx11n1gUiRtwwL/Ofrrv4t2sXjHkYBXSpbuKuCZ73fWfbBQaeE6390dS1Z/RuFPrxM2eAIJd8+j3fS3Cel7CY7dq6oe6GO15uTA3L4rA83jRFe9OPau5dgvb1U5NtxmYumugoB+1na3ypR31vD74Ybb9Nds8H1BO/miVxdFkrCYmsfO9YIgCE3V8YDq8JNPUvLDj409nNMiyxIWo+/rwckTknXRdJ1wW/1mCp0usTJVXxIHwtWvw8Lb/d9nwGijLOF8Ltt5ATM35XJZr/j6HWMLFxNqYVhKG37ZdcRnY8TgtNEEp42u8fU2k8K0ESn1Nr4Sp4c5SzNPa6XE4VF5d0U200emEBHUMG3bmzNd15m7NLNaIBNIHVBNvdFDB16FHBRBccYCCr6ehWSyYo7tROiQayuPMXtdFJZqeH3MX9X1s7a7VR7+bAsLp6dXe219SIiwsr+weopyoOmx8WG17KMnCIIg+MWvFSpdh6L9YD9a8d/WCIjoWON1q6ENT23Dou151e7FAskUUiSJLnFNs2ZcBFP1qds4uPET+Ohm8DpraYtpBXQYcjcxox9mXl4pU95eQ26Rg2kjkhulCL2lePjSbqzMOlrrTtu+mA0yfdtHkJ5Sf+lVn9ayUlJXxzSoOEcuWHuAO0bWX8DXUuSXujjmI9c60DqgmtQWmMsSKFYrqqd62pu/P+utB4sbrK36zekd2ZRTRPkpNXuBXPQkCUakRtf7WAVBEFqDGgMqVyls/gh+ex7K8uH4Pk6aBywRkH4v9L6+otlFI5o2Ipnluwt83ov5NSFpkLk5vSNGpWkm1Ilgqr51HAZ/3gW7f6xoi3lwXdW2mNYISL8Hek+sbIXeNS6UT+8cyuS3V5NzzM7j49IwNNF/QE1dp5hg3po8gCnvrPE7oDIbZDrFBDP3pvqrSdJ1nTlnuFLi9Gi8uSyLqcOTRe1UHYodHgyKxKnlQIHUAUFFYGRUZJ+rib5IVKxwqhrVVqUC+Vlrms7bv2Xz+GW113CdDed1jcEoy8DpXfRMBplJgzs02YueIAhCc1QtoIo+DN8+VDF7dbzh2cmZUB4H/PwE/PQ4nDcThtzV8IP+Q9/2EbQJNvvMeoC6M4UAbhzcoT6GdlaIYKohyAp0ubjij6OoYhn2eFvM4Fify7BxYRY+vmMId36wnjvmrePF6/tgM4kf1+kYnBzFR9OGcMs7a7C7vJTXEFQZZQlZlhjdJZrnr+tTY47v2VDi8FJQ5qr2eKArJSVOD0fL3USHiM6PtVFkyWcBbKDFrwZZYlBSJGv3HaszODcbZILNBv5yURee/HpHtecD+Vl7NJ1vthxqkGDKoMjcOiyJV5bs8dkoo66LnizBpCFN96InCILQXB0PqMr+dTl652NIWh3d7Y4HWT8/CcUH4aJ/1f8gfZAkiWeu7sGUd9bU2oDJF6tRYcrQjsSGNt3UcTF12NCs4RCVAjFdISSu1nzWEIuRtyYPIMJm4vq5K8kvrX7zLfjnnHZhrPzbebx4XW96FB/AJEsEmw2EWAwEmRVsJoUbBnfghxkjmD2pf70GUlARBPmauQ90pcQgS5Q4fW9OLJwQFWTGrVY/gQdS/Aqg6jBnUj/+fU1P0uJDsRhllFN+hYPMCuFWI9NHprDogZEEmY3oPiK5QH/WZQ3YHW/6qBR6JoRjDnADYYtR4dlretZr0xZBEITWzOLaQFRqUd2B1Mk8Dlj3Nqx8rf4GVof0lDY8c1VPLEb/rytWo8LF58Tx5wu71OPIzpxY6mjijIrMv6/pyYs/7eHq11bw9i0DSIkObuxhNUuKLJHuyaPjvq8Ie/ZTjpS4cHpVQi1GEiKs9R5AncxkkH1uKhzoSokOmEQ6VZ3CbEa6tw1lU05xlccDqQMCGJoShdVkYGzPeMb2jGdXXilfbDhIbrEDl1ejTbCZ9JQ2jOkWU2dqbqA/a18yizP5cOeH7D62G7vHTpAxiK6RXbm267V0CD391SGDIvP25AFcOyeDrbklfr3GYpT5x2XdGder3Wl/riAIglALrxu+ewhJrTq53vH5UuweyLovmCBTxQzfG+vdzNvsYcnkP2ptPXZY/A/oMwnMjXMfeUWfdkQEmbh3/ga8qlZjppDVKKPpFbVW943p3OR7B4hgqhmQJIn7xnSmXYSVa+es5LUb+zKgY2RjD6tZKl28mJDzz6dNsJk2wY2XGhdmNeJVqwdTgXRMA/CoGpGim59fpo9K4U8fb6rWWMGfOiCo2HB32inNPlJjQ3jwoq61fm5FK9fqF4JAf9ZB5hMB19Kcpby28TV2F+1G1VS8+olVqw1HNvDRro/oGtmVO3vdSXq7mrsAappOmduLIknYTEq1C1ap08MVveNZtqcAp1utduEzKhKyJNEzIYyHLuoqzkuCIAj1aceXPrfpgIrMiRdWuXl4eC33NpIEmxfAgFvraYB1G5kazdq/j2HR9jxmL9nLjgNHMZlNSLKMV9MItRiZOiKZ8f0SCWuirdBPJYKpZuSafgnEhpq54/11/OPyNMb2FK3TA6HrOqWLFhP/7LONPRQsRoURqdH8svNIlQSwQFdKBiVFEWQWv8b+GNMttsbGCv4Uv4ZYjKfV3bFP+3BUHzsNBvKzNsgSF3SPQ9d1Xlj/Ah/s+ACn6vT5eV7di1f1sil/E/f9ch+39riVaT2nVQZKXlVj8Y4jzP51L5tyijDIUuW1+dyuMUwdkUy/DhH846ttDEyK4tnxvdA0nV935zMvYx8Hixw4PSqhViP9O0QwOT2J9lG2gP9eBEEQhAAt/2+NnaEfTDfx799c3DnARLilhpUcj72i81//KY3aNt2oyFzSoy0Xdgxm7ZhLiPziG1QqJppjQyzNrqmWuAtrZoZ3jmbebYO49Z2K1um3Dxet0/3l3rsXzeXCck79F/H7Y9qIZFZmVm/bHthKSXJDDrlZMygyL03sw+3vrQ24ANZilHnlhj6n9btmMxm4qm87Fqw5gPeUoMrfn7VBlpgyLImXN77MBztrDqRO5VSdvLnlTQySgdt63sYXGw7y2Jfb8JyUXuE5aYV00Y48lu8pwGZSMMgSi/80CuCPxiwxjO4SE/D3FwRBEM4Ctx2OVG9mdFz/eIVRHQ3MWuHin+fW0qyhNA/KjkBIbD0MMjCuXbuI6RhPUlxoYw/ljIhgqhnq1jaUhXemc8vbazh4zMHMcWkV3cqEWpUuWkTIeec1meBzYFIkUUEm7O7qmzr7u1IyNKVNfQ2vRRreOZp/X92ThxZu9jugshhlXpnYl34dTj+FbcqwJD5Zl1MtmAL/ftapcSEUqFt5b9t7fgdSxzlVJ3M2zyE7N5ZPfjPU+r11vWKTYLtbxWSQWZNdKAIoQRCEpsBZBIqpavvzUzwx2szQt8q5b1At6f+KseK9mkgwZUlNbexhnDFRud5MtQ2z8tEdQ9ibX84d89bhCHBT2taodNFiQsaMaexhVJIkiTmT+mOVfec/18ZqVOp1H6yW7LLe7XhgTKqPKqbqJKBtmIU+7SPO6DNTooMZ3y8B62k0ObEaFf51ZQ/mbp7rM5A6tuwYu/++m21Tt7Hz3p3kvpuLWl71fOBUXXyW+X5AK3Jur8ad89az8UBRwGMWBEEQzjJJxuceHyc5J0ZhbKqBZ5bX0elPahq3/87ff8ec2rQ79fmjafxtCqcl9I/W6aEWI9e9vtLnvkVCBc/Bg3gOHcLWv19jD6WKhI3LeXLTfGwGya+be6i4uX79pv70TGjcHc2bq++2HOK5xbvquCRV0IGcQgeXvLjsjH+//nH5OQzv3CaggMpilHn1xr5EhJWxKX9TtecLvivg8MeHiZsQR/dXu5P8aDLuo26yZ2WjVdlYWEey7kYyVO3MV759CYfencH+564h5+VJ5H30GKHHL3UAACAASURBVM6cbZXPOzwq9y/YiF5DwbMgCILQQKwRFXuU1uEfoyy8vt7NwRLf523d7cB1pLRRzusOt8rPO/P4eO0BFqzZz+L9dso7dGrwcZxtIphq5kwGmVnjezIyNZqrXl1BZr7vwsTWrvSnnwgePRrJ0HQyW0u+/57DTz/N2Oce4/N7hjMgKRKzQcZ46sZFVGwobDbI9OsQwcLp6QzrLNL7TsemA0U88NHGgFZoPJpOQamLia+vxONjryp/KbLE7Bv7MWlIB0wGuda9NoJMCm2CTXxw22BGd4nh892fV7vwqQ6VI58fIf7GeEJ6hiAZJEzRJhLvTMRd4KZ4RXG19zWErq/8/yWrP6Pwp9cJGzyBhLvn0W7624T0vQTH7lVVXnO42ClWpwRBEBqbwQztB9d5WKdImWvTjLy42vfqlFcLZ/9df2bPyFEcfOghihYuxJ1z8GyPtorM/DL+/vkW+j65iHvnb+SxL7fx+JfbeTpyEOf9UMi099eybl9hs524azp3lsJpkySJB85PJSHcyoQ5K5l9Y1/6ixbFVZT+uIjIKVPqPrCBlPz4I4f/+RTt33gdS5dUUoGPpg1h/1E776zI4tsthyl1VcxABZsNXJQWxy1Dk+jYJqhxB97MPf3dDhw1BFLl25dQsuZzPEdzkE1WjDHJhKVPwJKQhlfTyTnmYNH2PC7p0fa0P1+WJR6+pBt3jExhwZr9vLE8i3KnF0WR0DUdt8tN/5QYpo1MYURqdGUtZGZxJh6t6oykfbcdzaMR2q9q4a5iUQjpGULZtjIiRpxIT5RkL7I5HwDNVU7R8g+IumQGti4nWqfbOg3C1mlQlfdzeVVeX5bJqzc0rVVdQRCEVmfoDDi0qcaOfsfNHGnm/c0+VrFMwRjH/ZtO/7gaz4EDlK9cSflvKzjy3H+RrVZsgwcRNGgwtkEDMcaceb2sruvM+uF33lieharp1euGDRZQdX7cnsfS3QWkJ0fxyg19G3Tfz7NBBFMtyIQBicSGWZj2/jqevOKcM7rpa0m8hYU4d+4kaGjN++00pNKffuLwE0/Sfu4cLF2r7lHUPsrGzHFpzBzXNDoOtiQHCu1s2O97haVk9WcUr/qEqAvuwpLUF0kx4Mhah2P3KiwJFT8Lu1tl9pK9Z+X3KjLIxPRRnZg2IoX8MhclDg9GRaZo/BV0v/dtTAlVL2LlnvJq76GWqRiCDUg+VjINYQYc+6oXKUtyRc2V6+BOdK8bW+qQOseq6bB4xxF/v5ogCIJQXzqdB0ZrtWAqe0ZIlf9ODJNx/t1HhzxJhm7jkCQJU/v2mNq3J2LCBHRdx71nD+UrV1H64w8c/uc/MURFETR4ELZBg7ENHIAhIrDaYV3X+dunW/hiYy4ub+1ZHbpekQL4254CrpubwYJpQzAbmk9AJYKpFmZkajTv3TqQ295dS26Rg1uHJfnsXqdqOmuyC8krceLyaoRZjfRoF0Z8uLURRn12ubwqxQ4PmgahVgPun38maNgwZHPjbdJ7XOnPv3Bo5mMkzpmDpXv3xh5Oq/JuRjaajxSCQFZpdh0pZc+RUjrFhJz6NqdFliViQy3EhJhZt+8Yr/SZwL63NuE0/Y7NqNApNoQpQ5MIMVa/KCrBCt4yL7qqVwuovMVeDMHVT++6WrGyqTpKkG2hSLJ/FyuvquH2apgMIjNcEASh0cgKXP0G/O+6Wrv6+WSwwhWvVaQLnkKSJMydO2Pu3JnISTeiqyrOnTuxr1xF0cJPOPTwwxjbtydo0CBsgwdh698fJTi41o97c3kWX2zMxeHxv0Ga06ux81Apf/poEy9P7BvY92tEIphqgdLiw1g4vaJ1es4xB4+O7V6ZLpRf6uJ/q/bzzoosPKqGDmhaRT2HR9Xo3yGCaSNTGNapTbPqFKdpOr/tLWDOr5lkZB7FKEtIErhVnUSPh9t6juEaj9qoS8dlv/7KoUcfJXH2a1ibyF5XrcmyXQVV9lQ6LpBVGglYt+9YZTDl8qp8v/UwX206xNEyF0jQJsjMuN7xXJQWV2fwoes6n64/yPOLd3G03I3D1BbdDtgrVpD2FpSzdFc+JtMw9LBCCFlVuc+irZMNySBRsq6EsIFhle+pOlVKN5cSe03Vtre6ZkRzVmz0rVhD0ewl6Jrqd0DlKxAVBEEQGljyKLjsJfjyHv8DKoMVLvwXdBvr1+GSomBNS8OalkbUrVPQPR4cW7diX7WKwnfe5eADf8LcuRNBgwYTNHgQ1j59kK0nJuNdXpXnF++uMZCqLa3e6dVYtD2P7ILyZlPaIIKpFio+vKJ1+vR567jzg3U8f20fvtt6iIc/3YIONS65/rb3KBsPFJEUHcT7UwYREVTLXgVNxJrsQu7+33rKnN7KjUjVk/Jys5VQns5WePrJRfztkm7cOLhDg4+xbNkycv/2MImvvYq1R48G/3yByhq0UwWySuNRdYodHo6Vu3l1yR7+t/oA6Hrlv7vjVmQW8PCnW7hhUHumj0r5f/buO06K+nzg+Gdmtl+vdO7g6FV6b3YNKhKjMYpGo6LGWGOJppmYYjS/WJLYNZbYEbsmakSlcyAdFO6Og+No18v2mfn9sYAcu3u3u3dHOZ7368UL2J2dnbnbnZlnvs/3eUh3hX+PDMPknrfX8/bXOw854TS9gfFd3ycNGr+HNbU79i5voSgmmksjd1Yu5S+VozpUkgclE6gOUP5iOdZMK+kTD6/2aBKoOwkAe7cBKBYr7m+XkDRgcov7rSrKcZfDLoQQHdawH0ByDsy/Fnx14A9PBQfAlhxKCzz379D/zITfTrFacY0YgWvECLKvvRbD58Pz9Woaly1l36N/x/vNNzgHD8Y1bhxJ48fxHzMnajGJWNLqDcPkucXbuPfc4+PGswRTHVia08q/rhjLXfPWctrfFlBR78fbQt4qQKNf55vd9cx8dCHv/2zyMR1Q/WfDbm569esWq7MduNj9wwcb2VHl5hdnDzwSmwdAw8JFlN95F93/8Xecw4cfsfcVTVnUyKNE8YzSqKpCvTfImQ9/SVWjP+JIF0CjL/R5e25RCe+uKef1uRPokelqssxv39vA21+XRS2IEca0EagbDkoQR5d3AMg5OwctSWP3a7vx7/WjOlVSR6bSY24P1EOqBZqmEnqt4Qjthz2J9MmXUPXJ4yiqhqPXCBTVgnfbarzb15Ixo2mxlrG9pKCNEEIcU3pPh1s3QckXsOhhKF4Aiha6J2fokDcxVLCi4BSIcv5LlGq3kzQ+FDhxExiNjbhXraJx6VL2/OnPPNL5LBpTOoe9Lta0+oBh8kbhDn5x1oDj4kaeBFMdnM2ict5JXXl3TXl4FZVmBHSTvfVe5jyzjHdvmHxMpvytLK2KKZA6lCdg8MKSUnJT7fxkcu923LqQxsWLKb/9drr//VFcI0a0+/uJ6HJT7Gyvcoc9Hs8ojUVVeHZRCR6/TixfJ79usqfOy6x/LOKjm6eQmxIKZhZ8s5c3CuMIpA4wbQRqR2JJ/gZLymYAMqdlkjmthWDHtBConNbkodSxs1GTMqhd8hoV7z+IYnNi79SH1AkXNVkuyaYxd1pBfNsphBCi/SlKKKjqPT2UyuCrC/1tT23zAKo5alISyVOmkDxlCqZpUnr3hxH7C8eTVq8qCtsqGxnQOUIhjWOMBFMngPs+2BQ1kGoubzWgm5RUNPLlln1M79/6EpltyTRNbnt9TdRAqrn98gR0/vLxN8we0b1dR90aly5j520/p/sjD+MaJWWlj7YfjevJpl11YSl58YzSePw6ikJMgdQBhgm1ngBzX1zJ/OsnAfDPBUUJ5ZIDYNrxVU4/GEy1xK7ZaSi/CMMf/h1OHjyD5MEzmn29y25hSh/payaEEMc0RQFHWsvLtTNPQEdVlIjzbONJq1cUqPME22MT25wEUx3c+p217KiOPEExlrzVRr/OE18UH3PB1KrtNeyp90V8Lpb9UhR4rXAH18Zxx93j13l/bTmbd9dT4/aT6rDSOzeZc4d3Jc1pbbJs4/Ll7Lz1Vro99BCuMWMS31HRZs4e2oVfvb0+4nOxjNKoZihwNyL0Om8pAAoaJpt21fHtnnocFo01UZrgxvLZBTC83VACuZjW6CXLraoVi2rh/in3s3JTV55ZWBJXVSUAh1Xlj+cPPSZHpoUQQhx7rJqKHmW+VLzFj46XCrISTHVwT31VTCDCPKl4ykGv2l5NWbWb7hmuw1dz1Dz1ZTHeCBeGse6XN2DwzFclXDOld4sXiqWVjTz1VTHzVu5EUUIFAQ5wWjXue38jZw3pzDVTCxjUNRV3YSE7b76Fbv/3V5LGjW2DvRVtwWHVuGhMD15auh2/Hv6daGmURlHViCNSsQZAgaDBMwtL6JzqaHWJdotipbf6I8otT6Cg4A5+l77osoS+p7P7zuaSgZfQPaU703uYlNd4+Gj97pgDKodV5RdnDeS0QZ1aXlgIIYQgFEy5rFpYFgjEl1bvDxrkphz9ljaxkGCqg1uxrSriHYJ48latmsKaHbXHVDC14Ju9RLrxEc9+NfqDbN3XQL9O0XsGfbJxDze+8jUB3YiYKnngwvS9Nbv4eMNu7hyWzKS/3UW3Bx8gafz42HdIHBG3nd6fBd/uo7SykQjxVFROq4Y/qIelgMcTAOkmvLN6J6cP6tTqEu1BA7o5RvDKhV/waemnlNaVUuurJd2eTn5aPqf0PAWHxXFweUVR+OuFw+me4eKJL4tQIGoxGpctdLfwgQuG8b1hXVvcFiGEEOJQs0Z047UVO8Kum+JJq++Tm3zc9D6VYKqDa/BFzjeNJ29VN6DOG7ms9NEQ1A18Ua6E49kvTVWobvRHff6TjXv42SurYipwoZsmesDkz8sr+Pl1v+PqiRNbfI048pLsFl69ZjwXPr6E8hoP/ijV+A7ltGqcO7wL76wuRz8sAIknAAIwA0H2rFoHWvgcpHgb6Xr8Ok6Lk3MKzolpeUVRuPX0flw+MY8nvyziyS9LcFg1LJqCaUJAN+ie4eTaaQWcM7zrcVFBSQghxLHnysm9eHNlWcSb0LGk1SfZtbimYRxtEkx1cDat9eWgFSX6eo6G5i5/483HjbaubRWN3PhKfJUCAXyajb9+62dkaRWj8qSc9LEoN8XBez+bzG/e2cAH63ahKkrE1Lckm4bTpnH7GQMYnZ/B+2t3hS0TbwCkqgqpmVlQFf5cvJ/ddJe1xWUOatgH9bsg6CPLkUpeqsr5I7rxs1P6UusJYFEVMpJsdDtO7gIKIYQ4dhXkJDO4ayprdtRGzI5qKa1eUxXOGBxeWv1YJcFUB5edbKeiIXz0JZ68VVVRyE09dvJWrZqKVVPxR0hTime/gj4/ts0b0HNHoiUnN3nu6YXFBKKMfrVUbMAbMHjksy08f+W4iK8XR1+Kw8r/XXQSvzl3MG8U7uDFpaVUNPgI6CYuq8bArqnMndqbqX1zUFWFvXXeiCeEeAMgQ9WYMHEAi/77TZO5dxDfZ9dl0xiVl9HCm+mw5RNY9BDsXAUWG6CAafD9QICJ+bPJ5+fQo0+L2y2EEELE45GLR/C9RxZS64kvs0kBTh3QCat2/BQ+kmCqg7tkXE/++OHmsDvv8eStKj4vwz17MM1sFOXY+HBPLsjm82/2ho0sxbNfNgzSX36aLb/cgC0/D9eo0bhGjcQcOpx5K3dGHJ6OtdjA0uIqdtd66ZzmCFuHOHakOa1cNaU3V01pvudYZpJtf9PfpgF2PAEQhCoTXTi6Bw/855uw5+L57JomzDqpW/Q32rUG/n0h+BtCfwD076pf2oG80jfh8flQcDJ8/xmwHTtzIoUQQhzfume4eG3ueC5+cim1nkBMLUUcVo0/nT+EF5aWcvNrq/nLBcOwW479lPNjJ3dLtIvzR3bHjJLMljp2Nhkn/4TaJa9R9ugllD32Y+pXvY+z73fzP2yawgWpDey95WZKzp9N1YsvoddELut8pJh+Pxeb27Hrkec7xbJfdovKVacNovdLL9Bv6RI6//JXWDvlUvv2O7xw/W8wveHl5A8UG8g87Tpc/Sei2hwomgVXn3FNLnQhlD748vLSNt1vcfRYNJUfju0Rdqfs0ADI/e0SjIAXUw/iKSqk+vNnmyxr01TmjM8j2WHh+6O6YYlQRTKWz66mhib3Jtmj3AsrXQzPngkNu78LpCJQjAAEvVD0P3jmNPBFX1YIIYSI14DOqXx40xROHdQJu0XFHqHUuVVVsFtURvZM5/W54zl/ZHdeuXo8/qDBnKeXNzu3/VghI1MdXLLdwnnDu/HW12URK4i12LQzGOSqy06j610X4162jJo357HvkUdInjKZtO9/n6QJE1COUJdtw+Oh5o03qXz2WfoVFJCZ933K3ZFT8WJpRnrx2J4AqDYbrpEjcI0cQdZVV1H73ga8i7aFLR9PsQF/0GD9zrqWd0ocNy6fkM+LS0o5fKZdLJNpAVBgzoQ8AOZOLeCtVTsJRigd29Jn127RmDs1ykhaxVb49w8g4I78fCRBL1RuhVd+CJe9C0fo+yyEEKLj65Lm5Mk5o6lo8PHq8u3MW7WTWncA3TRJtls4ZWAuV0zqRa/spIOvcVg1/vGjkdz/8Wa+/9hinrtiDHlZSc28y9ElwdQJ4J6ZA1lcXEF5jRc9lnHW/RxWlZuV7bgv+yGe++8nacIEkiZMQK+pofb9D9j74F8xamtJmz2b9NnnY+3aPmWU9YYGql9+haoXXsB50nC6P/oIzqFDeWBrBT95fkXcRSKcVo1rp/UmOznyPLAab+srIALUxZknLI5tPTJdTOuXwxff7sN32Hy9lgMglVMGdqJLmvPgup6+bDRXxvn5dVhVnpgzivzsKCeVT34F/sYmD+U/VI87ACU3JZNkC42GPb3Kz0trAyz48f71BL1QvgqK/wd9To15e4QQQohYZCfbueHkvtxwct+YlldVhV+cPZAemS4ueHwJj186quW5wkeJBFMngFSHlTfmTuTCJ5awu9YbsWHp4RxWlVtO7cfcaWdRv2AgZbfeSvoFF5Bz/fVo6elkXnoJmZdegnfjRmrenEfJ+bNxDBlE9sn5ONVvUNx7QxPgXVkwYCYMvQBsMd5VaKyEVc9jrngWs24viu4nQ7GTcflYtNN/BD2HADCpTzZ/PH8od89fF/MFqdOqMWtEV248JfqXOdUR+WsRb7GB5GhpWOK49dAPT+K8vy+itNId0/cIQvOk8rOT+L8Lhzd5fGKfbP51xViuer6QoGE0+xm2agoK0C83mQf+8w0Pf7qFbulOLhrTgwkFWaG5jA17YetnRKpRqZvw8DI/d09pppCMvxEWPSzBlBBCiGPGpePz6Jbh5JoXCvndeUP43rAuYct4vV7WrFnDli1bcLvdaJpGamoqJ510EgUFBajtnHEhV3sniM5pDt6/cTK/f28j760tRyG8HLSqhIZWc1Ps/PJ7gzh1UCcAUqZPx/nWW5TffQ/bLr2Ubg8+iK1HDwAcgwbR+a48cicqsOQx+MaLoh42slO6CD6+C4b/EKbeDqlRRrAaK+GDWzC/+QhTN1AJogBoAF7Y+RW8tDIUoJ35Rxh4DrNHdiczycaNr3yNbpgRO25DKIgyTJMbTu7D9dMLIhbSCOzZg3vZMrIWb8Ou98SnNS09HU+xAauq0LdTcrPLiOOPy2Zh3vUTufzZ5Xyzuz6sIl/48hqDuqTy3BVjIvZtGt87iy/vmMFrK7bz9MISvAGdoG4S1A0smoqmKviDBgHdxGFRWHtI6ujK0mo+2bSHNKeVuVN7c6nvdSxRCsTcPtHGXxb5uH6MjXRHM0VkdiyDmu2Q3jO2H4gQQgjRzmb0z+WFn4RuPu6odjN3am8URaG6upoFCxawYcMGFEUhEGiaEbRlyxasVisTJkxg/PjxWCztE/YoZoRyvweMHj3aLCwsbJc3FkdPvTfAvJVlvLRsOxUNPoK6icumMTovg6un9uakHukRgw3TMKh+8UUqHn+CTnfdSeq556I07oN/zYSa0lCqUHNUC9hTQvMyugxr+lz1NsynTgN3BQox3PG3OGH6XTD5ZiDUcPSTjXt4bEERm3fXhfpiKRDUTTKTbMyd2pvZo7qT6vguQApWV+NetpzGZUtxL12GXlWFa9w49DHjOH1zOr4Ic8zqlr9F7bK3yDrjp81WW7NbVP5z89To6VhHkaIoK03THH20t6O1jubxKaAbfLhuF49/UURJRSMB3eDAQJWmKlg1hYKcZOZOK+DsIZ2xxNCnTTdMFhdVsK2ikXpfkN21Xl5bsQN/0Gi2txqEbhYMV4t5ht+SpPiaPJf/UD1Pn+vknyv8DMpRue9kR3ia3wFWF5z9AIy4NI6fhhBtpyMcn+TaSYj2savWwxXPrWBkXgZXj0jl1Vdexu/301wsA2CxWMjNzWXOnDk4nYn1U2zu2CTBlIibd/Nmdv785zj796ZL989R6svBiDzP6HAmgD0F5eoFkB3qb+PbuArLazNRzUaUeEZirU44834YdXmTh6sb/VS5/eiGSbrTSk6KHUVR0BsacBcW4l66jMZlywjs2IFz1EiSxo0nafw47AMGHCymcfOrX/Peml0Rews1bPic+sJ3CFTuaFJswNF94MFlRuVlMO+6iXHszJHTES5W4Ng5Pm3aVcfn3+ylot6PokBWso1TBnSif+eUhNe5vKSKy59dhieO+VR2AgxVinnFdh9W5bsRswPBVOdkhUnPNrL1Z8m8800wcjClWuHU38LEGxLediFaoyMcn46VY5MQHVG9N8Ct/1pAlz1LUc3ms0MOpWka2dnZXHXVVVitcTS936+5Y5Ok+Ym4OQYMoNcbb+C7fzJm1Q4ULfaiFgpgeuvR/3kqnhkvUjPvLVI987B18RJ3C6uABz66A/qdASnfdcrOSLKRkWTD8HrxrF7NviVLcS9dinfLFpxDh5I0YTydf/0rnEOGoET5Qs2dVsDHG3ajB+KvgOi0avzsZGmEeqIY2CWVgV1S22x9Db4gV/5rRVyBFIAPK+vNfP4veAF3Wl8Le35IrsbMfhb+vNDPwJwody0UhfjuaAghhBBHjtOi0LfuaxpjDKQWLFhAVVUVs2fPprKykg8//JDzzjuvTbdJgimREDVQg1PbweGT3WOpHKYooPhrqfrV5egpfUgZ7UY57EsRUwUyCL1/4bMw427MQADPuvW4ly2lcekyPOvW4ejXD9f4ceTccjPOk05CdcTWRHdgl1R+NXMQ972/KWxuWXOcVo0fT8pnev/cmF8jxKHe/roMI0rGQOPGBdSteJtAZRmqzYk1tzdpEy882DDai50X9NO52TIPuxI+WnzvdAcjn2jgtglRClFo1tCcRCGEEOIYtHnzZoLB/XPqD7Fu3TqWLFlCRUUFdrudzp07M2XKlCbLBINB1q1bxxlnnIEjxuvBWEgwJRJT+GzUp2KpHKZYDLrPGYAv0B2zfEXEm+ExVSAL+jC+eoSyV0vwrFyNtUcPksaNI/OKH+MaPRotOfEiEJeMy8MwTP7w4aaYqgUeCKTuOKN/wu8pTmymafL4F8URC1vULZ9P7bI3yTr9pzh6jUTRLHhKVuLZsuxgMHXAR8Y4ZmmLwtbRJ1PlosFWHlnuZ2hupC9dUKr5CSGEOGYtXLgQv79pI98lS5awcOFCZs6cSUFBAZqmsXXrVjZv3ozNZmuyrKIorF69mvHjx7fZNkk+h0hM4TOg+yI+dftEGw8u9lHjjZ7+pwDqrkKc3qWoauRAJZb1AGDoZJ0ykIJP/kvv+W/R6a47SZk+vVWB1AFzJuTz76vGMa1fNjaLiu2w7t2Hdu7+56UjufPMARGLdwgRi3U7a6mK0O3d8DVSs/DfZJ52Ha7+E1FtDhTNgqvPuCaFTwAacfJ08Kyo7/HraXYa/RG+U4oK/c+EJBmZEkIIceyprKykoqKiyWNer5fPP/+cs88+m4EDB2Kz2dA0jf79+3P66aeHrSMQCLBs2bI23S4ZmRLxM3RwV0V9enRXjen5Fh5c7OO+k5sZRtXs4K5s9XpUm42kgT0ho32auY3Ky+T5K8exu9bLy8tLWb+zjjpPgGSHhb65yfxoXF6Tzt1CJGpHlQc1Qizu27kZM+jH1W9CTOspN7MP/nvbzU0LYfRIU/H+MsIcL4sDJvwsru0VQgghjpSamho0TSMY/C6NvaysjGAwyMCBA5t5ZVP19fVtul0STIn4BTygas1W8PvdDDuTnm3kpnG2qMsAocCsGTGtxzRBD7+b39Y6pzm49TRJ4RPtx+0PYkQYNNI9daiu1JiaRUOoGEVcrC4Y8n3oMSa+1wkhhBBHyOF9pADcbjculyuuxry6Hvtc+FhImp+In9XVYhB0aOWw6MxQefPWrkdVwZHW7HqEOB6kOCyoEdJENWcqhrsOs4Xv3QFOJRDq6xYLqwt6z4CZD8WzqUIIIcQRZbeHz6F3uVy43W4MI/YKuG3dvFeCKRE/VYW07i0udu90B0+t8rOzLsqcJz0AeZNaLMXc8nr80G1Ui9sjxLGuT24KQT38hGDvNgDFYsX97ZKY1tO3Z1foPiaUuqdGGaWyJoWaX4+/Hi56CTRJVBBCCHHsysrKapLiB9C9e3csFgubN2+OeT2ZmZltul0STInETLghdEe7GYdWDgujqDBgJky5NXTBl+h6UKDXtCZ9poQ4XvXJTaZPbnjhFNWeRPrkS6j65HHc3y7BCHgx9SCeokKqP29aWTPJpnHNyQPhyo/husWhptbWpNB3TrMBCmTkwxl/gNu3wim/Ct0gEUIIIY5hqamp5OXlNXnM4XAwffp0PvzwQzZv3kwgEEDXdbZs2cInn3wStg6bzcakSZPadLvkVqRIzEkXw6e/aXGxX0+z8+La8BxXLA6YeAN0HREa5ar4NrH12Fww6aZYt1qIY9610wu48821NB5WHj117GzUpAxql7xGxfsPotic2Dv1IXXCRU2Wc1g1pvXNCf0nqwC+99fQn4AXgh6wp4bmPAohhBDHmUmTJlFWVtakPPrEiRNJ/RYJxgAAIABJREFUTk7myy+/5K233sJms9G1a1emTJlCUVFR2DoGDRrUptskwZRIjCMNRl4OX78QKkixX0yVwzQbdBocCqQAzvwTvHoJBL1xrscOnYaGUgWF6CBOH9SZPzhDzaIPL0aRPHgGyYNnRH2t06px86l9USOVBLQ6Qn+EEEKI41SvXr1ITU2lqqqqyTypYcOGMWzYsLDle/TocfDfVquViRMnypwpcQw544+huUotpOk1oVogKQd+9Pp3j/U5FU6/r8ViFE1odkjvAZe8AdLXSXQgNovKq9eMJ8luieuj7bRqzBzehUvH57W8sBBCCHEcUlWVyy67DIfDEVdfT6vVSu/evZk6dWrbb1Obr1GcODQLXDIPCmaE5mS0xOqCzAK4ZgG4Dpv8N/ZqOOfRUGBmaS6oUkLv1W0kXP05OCL0yxHiOJeXlcTbP51EdrIdh7X5w7RCKJC6aEx37p89TJpGCyGE6NBSU1O55pprSE1NxWptuRWI1Wpl4MCBXHjhhXGVUI+VpPmJ1rE64KKX4duPYdFDsGsNGAYY+3NZFQ0s9lCBiMm3wNAfRB+BGvYDKDgZVr0AS/8RSh9UlFAfKUUF3QcFp8LEn0HP8TIiJTq0gpxkPrttGq8t38FTXxXT6Avi9uscyPyzW1RMYFJBFtdMLWBCQdbR3FwhhBDiiElPT+f6669nzZo1LFq0CI/H02QelaZpKIpCjx49mDRpEgUFBe12s1ExzSjlpoHRo0ebhYWF7fLGooOqLILNH0DDHjACoZS+3jNC6YDxfIgNA8pWhNaj+0NztLoMh+Tc9tv2E4SiKCtN0xx9tLejtU6k45NhmCwqquDr7TVUNvhw2jQ6pTo4e2gXOqXKPCjRcXSE49OJdGwS4lhgmialpaWUlpbS2NiIxWIhJSWFgQMHkp6e3ibv0dyxSUamRNvKKoBJN7Z+PaoKPce1fj1CdACqqjClbw5TDlTpE0IIIQQAiqKQn59Pfn7+UXl/mTMlhBBCCCGEEAmQYEoIIYQQQgghEiDBlBBCCCGEEEIkQIIpIYQQQgghhEiABFNCCCGEEEIIkQAJpoQQQgghhBAiARJMCSGEEEIIIUQCJJgSQgghhBBCiARIMCWEEEIIIYQQCZBgSgghhBBCCCESoJimGf1JRdkHlB65zRFCHAF5pmnmHO2NaC05PgnRIR33xyc5NgnRIUU9NjUbTAkhhBBCCCGEiEzS/IQQQgghhBAiARJMCSGEEEIIIUQCJJgSQgghhBBCiARIMCWEEEIIIYQQCZBgSgghhBBCCCESIMGUEEIIIYQQQiRAgikhhBBCCCGESIAEU0IIIYQQQgiRAAmmhBBCCCGEECIBEkwJIYQQQgghRAIkmBJCCCGEEEKIBEgwJYQQQgghhBAJkGBKCCGEEEIIIRIgwZQQQgghhBBCJECCKSGEEEIIIYRIgARTQgghhBBCCJEACaaEEEIIIYQQIgESTAkhhBBCCCFEAiSYEkIIIYQQQogESDAlhBBCCCGEEAmQYEoIIYQQQgghEiDBlBBCCCGEEEIkQIIpIYQQQgghhEiABFNCCCGEEEIIkQAJpoQQQgghhBAiARJMCSGEEEIIIUQCJJgSQgghhBBCiARIMCWEEEIIIYQQCZBgSgghhBBCCCESIMGUEEIIIYQQQiRAgikhhBBCCCGESIAEU0IIIYQQQgiRAAmmhBBCCCGEECIBEkwJIYQQQgghRAIkmBJCCCGEEEKIBEgwJYQQQgghhBAJkGBKCCGEEEIIIRIgwZQQQgghhBBCJECCKSGEEEIIIYRIgARTQgghhBBCCJEACaaEEEIIIYQQIgESTAkhhBBCCCFEAiSYEkIIIYQQQogESDAlhBBCCCGEEAmQYEoIIYQQQgghEiDBlBBCCCGEEEIkwNLck9nZ2WZ+fv4R2hQhxJGwcuXKCtM0c472drSWHJ+E6Hg6wvFJjk1CdDzNHZuaDaby8/MpLCxsn60SQhwViqKUHu1taAtyfBKi4+kIxyc5NgnR8TR3bJI0PyGEEEIIIYRIgARTQgghhBBCCJEACaaEEEIIIYQQIgESTAkhhBBCCCFEAiSYEkIIIYQQQogESDAlhBBCCCGEEAmQYEoIIYQQQgghEiDBlBBCCCGEEEIkQIIpIYQQQgghhEiA5WhvgBAiXJ03QE1jAIA0l5U0p/Uob5FoLY9fp8rtxx80SHFYyHTZUFXlaG+WEKKDMwyTKrefem8Qm0Ul02XDadOO9mYJ0WFIMCXEMcIfNPh4w24eX1DEt3vqsVlCA8cB3aAgJ5nrphdw5pDO2C1yEjxemKbJim3VPPFlEV9+uw+LqqIqEDBMUh0WfjK5FxeN6Ulmku1ob6oQooOpavTz2ortPLOwhDpvEKuqYJgQNAym9sth7tQCxuRnoChyU0eI1pBgSohjwAdry7lz3jpMTBp9OgBBv37w+c2767n7rXXcM389980awqwR3Y7WpooYfbunnqueL6SiwYfHr2MCAf2732lFg5+HP9vCQ59u4eKxPfnVzEFoMlIlhGgl3TD5/fsbeWX5dhQFvAEDAP8hy/xv016WFFWSnWznqctG079zytHZWCE6AJkzJcRR9uzCEm57Yw0NvuDBQCqSRr9Ogy/IXW+t5bEFW4/gFop4rSytZtY/FrGjyo17fyAViTdg4AsavLZiB1f+awVB3Tii2ymE6FiCusFPnl/Bayt24AsaBwOpw5mA26+zo8rN+f9cxMrSqiO7oUJ0IBJMCXEUfbC2nL/8Z3PUE14k3oDBI59tYf6qsnbcMpGokopGLn92ebNB1OE8AZ3lJVXcMW9tu26bEKJju3PeWpYVV+EJRL8xd6gDQdXlz66geF9D+26cEB2UpPkJcZQEdINfvLUuYiDVuHEBdSveJlBZhmpzYs3tTdrEC3F0HwyAJ2Dwy3fWc9bQLjisMoeqPfiCoQCnssFP0DBJc1oZ0TOd7GR7s6+7990NNPqDEZ9r7vfqCeh8tG43V0ysZWj3tPbYJSFEB7aurJYP1+2OGEi1dE5p9Ae5972NPH/l2CO92UIc9ySYEuIo+e+GPehm+NhF3fL51C57k6zTf4qj10gUzYKnZCWeLcsOnvgAMOGj9bs4f0T3I7jVHd/OGg8vLN7Gv5dtB0JFJExAVRQCusH0/jlcM7U3I3uGT9zeU+dlSXElEX6tMf1efUGdp78q5uGLR7T3bgohOpinvyrGFwwPpGI59pgmLC2uZE+dl06pjiO96UIc1ySYEuIoeeyLrWFzpAxfIzUL/03W2Tfj6j/x4OOuPuNw9RnXZNlGv85jC4okmGpDT3xRxP998i2maeLXIyfp/XfjHr7aUsGovAyenDO6SYnhF5eURnxNrL9Xw4SPN+ym1h0gzSXl8IUQsal1B/h4w26Mww5b8ZxTIHQM+/kZ/dt7c4XoUGTOlBBHQaMvyOZd9WGP+3Zuxgz6cfWbENN6SioaqXH7W15QtOj+jzbz0Kdb8AWNqIEUhO7guv2hFMDvP7YY7yEpNR+sK8cXDE/bjOf3atVUFhdVJLYTQogT0pLiCqxa+CVdPMceX9Dg/bXl7bF5QnRoEkwJcRTUeAIRT3y6pw7VlYqixjYPyqqp1LgDbb15J5w3Cnfw3OKSmCdtQ+jCo2hfA9f/e9XBx2o9kedKxfN71Q2TGo/8ToUQsat2B9APH5Yi/nNKnTfyMUwIEZ2k+QlxDNGcqRjuOkxDj/nkJ1pHN0z+9FH0iorNTdz2BQ0WF1WwaVcdA7ukRn2PeH6vhmny5sodvLxsO42+IC67Rt/cFH48MZ/hPdJbta9CiBOLnFOEaH8STAlxFKQ7rQQi9BSydxuAYrHi/nYJSQMmt7iegG6QLnNrWmXBN3sjTtqG2CZuB3STZxaWcNdZA6I23Y3n9+oLGny9vabJ3IeN5XV8vH43XdIc3HRqX847SZo2CyG+k+GyRjz+xHtOSXXIZaEQ8ZI0PyGOgiS7hQFdwjvOq/Yk0idfQtUnj+P+dglGwIupB/EUFVL9+bNhy/fKTiLdZTsSm9xhPf5FUcRmyQcmbmeedh2u/hNRbQ4UzYKrzzgyZlx5cDndMJm3sozpD3xOkk2LeEET7+81bBK5GepFVVzRyF3z1nHP/HUYEVJ6hBAnpgm9syPeoIvn2GO3qMwc1vVIbbIQHYbcghAiTkHdIGiYre7vdGU3uGdHAI/adGQpdexs1KQMape8RsX7D6LYnNg79SF1wkVNlkuyaVw3vaBV2yBg3c7aiI/HM3HbadN49sdjyctyMfUvn0ecuxDr77UlnoDOW6t2YtUUfnvukLheK4TomNJcVs4c3Jn31paH3YyJ59gzZ0Je2GPrymr597JSiisacfuDoZ57PTK4ZHxPuqQ522uXhDhuSDAlRAy+2V3PMwuLeX/tLjwBHQVQFIXh3dO4dloBJw/IxRKhoEQk7lWr2PfIowwq3402+qcQYapO8uAZJA+e0fyKFDhrSJf4d0YcZJomvihzpeKZuK2pCo2+IJ1SHUzoncUXW/ZF7DUV0+81Bp6AzmsrypjaL4eTB3Rq9fqEEMe/q6f25r8b90QspBPLsWdsr8yDPaZM0+TdNeU88tkWymu8+IJ6kyBtxbZqnvqqmHG9Mrn5tH6M7JnRpvsixPFEgikhmrF1bwM3vfo1RfsaCOjmwREHE8A0WbW9hlteX41VU/n1zEHMHhm955Nn9Wr2Pfp3/Nu2kX39dazKG41//kYiRlMtcFhV/jBrSKtHx050iqKgqkrEkaR4J27bLKFg+jfnDuacRxfS4GtdVazmCl9AKKD6+/+2SjAlxAnK49dZVlJJVaMfwwzNxT11YC6fbtobV2VSAIuqsKvWQ0lFIz0ynNz+5lo+Xr876nr8+1tAfLmlguXbqrj33MFcNKZnq/dJiOORBFNCRLFqezWXPbOcRl+Q5manhObb6Nwzfz2llW5uOa1fk+c969ax79FH8W3dSvbca0k/fxbPLCvjwXc24o+Q494Sh1Xl5lP7MUua9baJNKeVqsbwXl3xTNwO6ia5KXYgNI/t+SvHctkzy3D79WY/O9HEUvgCYEN5HSUVjfTKTkrgXYQQx6OSikaeW1TCG4VlaKqCaZqYgKooBHSDNKcVwzQj9rw7nAK4bBrPXzmWjbvqmP3PRRTkJLOhvBZPlFH7w3kDBr95dwM2TeX8Zm4oCtFRSQEKISIo2tfAZc8sp6GFQOpQnoDOk18W8/ziEgC8Gzey47rrKfvZjaTMmEHBxx+TcdGFfPxNJQ/+95uopbibY9UU/jx7GNdOk7lSbeUHo7pj01pXNCI3xU6f3OSD/x+Vl8H8n06iR6YLl00jco2/yGItfAGh4hevrdge1/4KIY5Ppmly/0ebOPOhL3l52XY8AZ0GX5BGv47bH/q3L2hQ0eAjaJgoCtgtkY8+B4KoHpku5v90EqPzM7lsQj6XT8hnZWl1zIHUAd6AwS/mr6O0srEN9lSI44uMTAkRwR1vrqXRHzlNq7n0K09A548fbGLkK4/iWLOSrKuvpttDf0O1h0YtdMPknrfXJ9TTCEBBYUrf7PbZ6RPUZRPzeW5RScTnYpm47bJpzJ3WG0VpetHSr1MKX9w+ncLSap74oogvvt2HRVVRlVA5dUUh4p3jeApfBA2Tkgq5eBGiozNNkzvnreW9NbtaHHEyTMA0sVtUclPsuP1B6r06Vk3BMCFoGEzrl8PcaQWMzss4eOwyTZP5X++MegOxpfOTrps8u2gb9547OMoa2pC7CjzVoX87M8CV2f7vKUQUEkwdx0zdILDHjeEJoiigJlmx5LrCLupEfLZVNLJ+Z23EAgKxpF+Zfj8f54/ntgf+iOpwNHl9a3saqQq8smI7N8zo27Y7fQLSDZPPNu3h8S+K8OvRxx9jmbg9a0Tkvk+KojAmP5Mx+Zl4AzrVbj++gEGq08o989fy0fo94dsVR+ELIGJZdyFEx/LEl8W8t2ZXXHOhfEGDfQ0+vjekC/fMHESdJ4DdqpLhskWcb7tqew37GnwR1xVTzz3D5I3CHfzirAHtM5834IWN78DCv0HVVtBCNynRfZBZAJNvhUHngdXR/HqEaGMSTB2HgrU+GpfsomFpeagSwoHYyTBRHBZSpnQjaXRnVKf8ehPx7KKSiD18DqRfZZ19M67+Ew8+7uozDlefcQf/79esvOJO4labPWwdLfU0amnd3qDBcwu3cd20PlEbxIqW/WfDbu6atxa/brQqGHFYVf7+oxG4bC1/1xxWrUkZ4azk8M8HxF/4Qpo2C9GxeQM6j3y2JWog1dyIkTdg8P66Xdx6Rn/yW5hb+cKSbRHfI9bzE4QuR/67cQ/nDm/jflVrX4f3bwn9298Q+lsPfPf8vs3wwa2hPzP/BsMubNv3F6IZMmfqOGIaJjUfFLP7gRXULyzD9OqYPj30t1fH9BsYdX7q/ltK+R+W0bB819He5GOaaZo0+oLsrvVS3eg/WNHtvTXlBCIEU/GkX/mCOusP619kmiYrS6sjLx/Huj0BXfLSW+H5xdu46dWvqXYHWh1I/fH8oQlX0xvePR2XLTxYOrTwRUucVo0RUpJYiA7t/bXRz+V1y+dT9dlTpI2/kO43vES3654jZeTZeLYsO7iMaZq8sHhbi+9TvK8xYkZGvOensmp3i8vFZeFD8O6NoSDqQCAVyYHn37sx9BohjhAZujhOmIZJ1aub8W6qgmDzJRHM/fNxat8rxqj3k3pKeBO+E1mdN8C8wjKe/KqYffU+rJqCbgImnD20M/XeQMTXxZN+pSoKVe6mFeIa/TqqomBEOFvF29OoxhN5G0XzPlxXzp8+2pRQ8Q8I/eytmkKv7CR+e85gxvXOSnhbZg7ryq/f2RD2+KGFLxRVw9FrBIpqwbttNd7ta5sUoTBMkwtGSfUsITqyxxZsxe1PfMTIr5v8e9l2bju9/8EWDgd4AzoVDT4qGvxUREnxi+f8ZJjQ4G1dW4gm1r4OC/4MQU/srwl4Qq9J6QLD42uKLkQiJJhKwK5aDy8tLWXNjlrqvAFcNgsFOUlcOj6PgV1S2+U9az/ehndT1cFAKRZmwKB+QRlappOkEbntsl3Nvr9psmp7NfO/3kl5jZegbpCVbGNav1zOGtoZu+XI9kgyDJO//Gczzy3ahqooB9MZgoeMQr23ZhfRps/ElX5lhibjHkpViDqxN97ULk3mxcXNF9S54811EQOpliZWAygKnDe8C1dPLWjV93xjeR1PLyzm883R58/FUvhCUxTOHNKZNKek+QnRUemGSfG+yJkI8WZL3PbGagJBk4oGH5WNfirqffiCofNydrI9am+8eM9P/9u8F01VKMhJpiAnmd45SSTZE7jcDPrg/VvDAqn8h+pxB6DkpmSSbKFz4dOr/Ly0NsCCH+9PZQx64IPbYPAssEROqRairUgwFYeVpVU89OkWlpdUYZo06RG0YlsV81aVkZ+VxI2n9OXsoV3a7H31ej8Ni3dGHJFaXraWP37+GN9WbENVVfpm5fGbU37GSV0GAqGAqvbdIlzDclAilH9uD0Hd4PXCHTz+RTEVDT48Ab1J6sB/N+zhnrfXc/HYHlwztTe5Ke0/WTSoG1z30koWbq1sthKSHinHYb94+g4BpO2fy2IaBv6iIrwrv0bRnaCEZ9fGs+6AbpCZZGvx/UVTH63bjRnh9xtrTyeXTeN7w7omHEit2l7N3W+tY1tlY5MG0NG0VPjCalH56Yw+CW2LEOL4UO8NYNEUAhHu8sUzYqSg0DnVwUk9MshOtpGdYic7yU6q03KwaNWv31nPy8u2N7nBCPGdn5xWlWn9clAVhU837eGJL4spqWggw2XbH1wlUZCbfDDQ6pRqj140a+O7RLsFqZvw8DI/d09pLlAyQwUrZP6UaGcSTMXopaWl3PfBxqjpQboRujjavLue215fw8ItFfx+1pA2KRLQuCxyvnS9r5Er3ryLP5x+K+cMmIFfD7K8bA12remFtmmYeDdX4hzc/iW1G31BrvzXCtaW1UafLLs/XeH5xdt4c2UZL181nkFd22dE74C7569j4daKuHtnHCqe9KtgMEjnD99g++qVeFavQUtLwzVyBFPSJ7KgVg07PcSz7k4pDrpnOBHxefyLooOfvQPimVjd6NN54stiThkY/xypj9fv4pbXVrfq83coh1Xl4R+eRL9OKW2yPiHEsclmUYl23yWeESObRWXWiG4M7poWdZnLJ+bz+oodYcFUPOcnh1XjjjMHNLn2MQyTnTUetu5roGhvA5t21fP+2l0U72vAGzBCAVZO8v4gK/TvvKwkbIseijpH6vaJNv6yyMf1Y2ykO6JcZ/kbQnOnJJgS7UyCqRi8tmJHs4HU4TwB/WCvhj+eP6RVpcpN3aRhcXnEUaniqh0AzBp0KgBOVWNar7Hh6/Dp1C0oa/dgyh80uOTppWzaVR9T5/WAblLjDvCDJ5bwzk8nNWl62pbWldXuLymbWHrXoWJKvzKCnF67FVtKEOcFF9D1D3/AkpMDwA2lVSx7ZnnE/PfW9DQSzatq9FO0L/ykHE+aDMDK0mq8AT2usr9Liiq5+bXVCc/TOpRNU9FUhX9cMiLhwhdCiOOH06phUZWII9nxZjTkpDSf7laQk0z/zimsKasNey6W85PdonLl5F5hN5FVVaFHposemS5m9G865aDG7adoXyNF+xoo2tfAmyt3UryvgZqaSpZZNhEtiXl0V43p+RYeXOzjvpObyW6p+Ba8teCIHkQK0VoSTLVg6956fvNu9Car0XgCOm9/vZOJBVmc04oSocEqD2aUwKR3Zg9UReWWD/7AuQNOYUS3waQ7It+pDpTVY5pmu16E3/fBRjbvji2QOpTbH+TSp5ey8M6TsWhtX2Dyya+KIs5NiTW963AtpV9ZbDZu/N1cOuWEB4cje2aQk2yntCpytaOW1m2a0Xsaieiq3X5smkpAb/o5iLenk1VTqPMEYg6mgrrBdf9emfA8LQCLqmC3aqgKzBmfx5wJeU1KrAshOi5FUThneFfmf70zLKCKZ8RoQOeUmFLqf35Gf65+oTDiMavF1GNN5eKxPePYO0h32RiVZ2NUXtOqpP59JWhP2CEYvTLg72bYmfRsIzeNaybtXbOFmvu2QzDlC+p8vH43T35ZzI4qN96ggd2i0iPTxdVTenH20C5HfG64ODokmGrB01+VENAjBwctXQx5AjqP/m9Lq4Ipwx0MVS6IIMWexFuX/J1/LnuZOz5+gH2NVcwoGMdfzryDnKTDuoErCqZfR0lkEmgMGnxBXi/cETXobO5nZZpQ7w3yv817OX1w5zbdrlp3gP9u2BOWJhFPelc8nFaNmcO6UBAhkILQifFvPzyJHz21NO4APVSKe0hMPY1EU5HmSkH8E6uBqCk3kXy6aU/E40esgbymwI/G9WRiQRYnD+gUVolLCNHx/WRyL95fWx5xdCqWEaMku8Z10wtieq8pfXO46ZS+PPLZ1rgaBDutGs9fOYbsKP3z4mWzKN/10IxiSK7GzH4W/rzQz8CcaMfGOA7YMdINk4c/3cKzi0pCLVYOyTTxBw02ltfxy/nr+eXbG7hiYh63nNZf+kJ2cHJV1oxGX5C3V+8kUiwV68XQ9io363fWMqRb83dFdMOkssHH7jovu2u97KnzsrvOi1HeyGy/TrT70H2z8/nb9+4GYGtlKTe+fx+//exR/nHubw5bv8GFTy4hNcVBdrKNrGQ7WUk2clLsZCXZD1bzyXBZExodevvrMtQoo16x/Kwa/TqPf1HU5sHUwq0VWDSFw4sUxZveFQunVWVUXgZ/mj202eVG9szgn5eM5Pp/r4o5oHJYVW49rR/nj5Qy2IlIc9oiTuCOt6hIIGjGVT3vsQXhTZrjCeRNoKLex5lD2q6gjRDi+DKwSyoFOcls2lUX8WZOLCNGp8Yx1/O66X2wWzT+8p/NBIJms4WZbBYVm6by/JVjGJWXGXW5uDkzQfe3uNi90x2MfKKB2yZEDuL0oJ9Pi/2c1NdLp9TWF7vyBXWuer6Qwm3VzQabBwKsZxaWsHpHLU9fPjqu9HBxfJFgqhkfrtsVMUCI52IoEDT51+Jt3DCjD7sOCZIODZj21HrZ1+Aj1WGlU6qDzmn7/6Q6yOuViX2rh1jurvTJyuPCIWfy0up3w55TNZV7zhtMRYOfyv1lUcuqPawpq6WywRcqldrgp9YTINVpDQVchwRZhwZgB6oAZSXbDpY7ffqrklb1wQDYUF5HWbWb7hmuFve1OaauY3q9GF4v+8r3okdIO4w3vas5Vk1BVRTOO6kb980aElMwevKATrx89XhufnV1xIqHByTZNOxWjftmDebsoW3cUf4Ekp1so3Oag+2HpVfGkyYDMKBLCs4IjXYjqWzwsWlXfdjj8QTyhgn/3bgHwzBR5c6mECesf/xoJDMfXRi1fHk0DqvKU5eNjvsm6ZWTezG+dxZPfVUcuhZSFTyHnOOT7BoWVeWyCaHU4zavyutIhdxBsHtts4v1yVS5aLCVR5b7GZobvo9VSb15Y30td33wJUl2C2PyMxmVl8Ho/Az65abEdVw1DJMbXv6aFSVVeGOczuAJGBRuq+KnL6/iqTmj5TjeQUkw1YwdVe6IAUI8F0O6aTJvZRnLSirpnOoIBUupoWpso/Iy6JIWeiw31R41t3Z3YSXBivCGdVsrS/msaAnnDjiZLqm5lNft4Z1NnzGy62HzfRRwDcpiRM+MsHWEba9hUu32HwyuDjTzq2zwsWN7Tej/jf6DAZiCQlayjZ3VkRvqxfOzspoGm97+Dy67B9PjxfB5m//b6z0YNB36txkMojgcqA4HlT3HYfScBlrT0YR407tS7Baykm3srvNi3X9SMkwTBYWLx/bgsgn59MiMLwgc2TODL26fzqrt1TzxZTH/27QXk1Bmg2GajMrL4NppBUzvnyspAq2kKApzp/XmDx9sCvtOx5ImA6GLh2unxZYqA1DZ6MdqUTj8EBJvIK8ooTTXAWIcAAAgAElEQVTYA6X2hRAnnvzsJF65ejyXPL2UBl8wpnRjp1Xj7z8awZj8xEaMBnVN5W8XncRvzx3Mh+t2saPKTYMvSKbLxoAuqZwyMPfg+bBdTL4Z3r0xakW/A349zc6LayM0srclk3PGnTw9dAyGYVJc0UDhtmoKS6t5ZmEJlQ0+RuZlMDovg1F5mZzUI73Zm2UfrNvFoq0VMQdSB3iDBkuKKnl3TbnMee6gJJhqRn2ULt7xXgwN7JLKhzdNSXg7UqZ3p+bdIkx/0y9wks3F6vJNPLXidep8DaTakzm1YAL3zLi+yXKKRSVlamzpYZqq7B+Jajnv2TRN3H6dygY/0x74POIy8fyszGCAqm+24nO6Ue0OFKcDNSkJS3YWit2B6nSE/70/aDr0b8VmO1hoo++acqzz1uI/7Io23vSu3FQ7n946jdJKN5WNPgJ6KN2rd05SqyaYKorCqLxMnpyTiWmaeAMGummSZNOkYl8bm3VSN37//saIz7WUJgOgKgpnxJGGGtANlAhJ//EG8qqiNOlpJ4Q4MQ3tnsYHN07hV2+vZ3FxJRCao3Moi6qgqQqDuqTyu/OGMLR76wsvpDmtcReWaBMDzoH3bg57eNvNTQtt9UhT8f4yQnsVRYGB5wKhioJ9clPok5vCD/fvy756HytLq1lZWsVf/rOZzbvq6dc5hdEHAqz8jCYjbo8tKIp4gx1ankPv9us89kWRBFMdlARTzUh3WVEIT7CL92Io1dm6H7NzWA417xaFPd4lJYfHZt3b4uu1NDvW7m1fdlxRFJLsFpLsFmwWNWIVv3h+VqrTSc85V9KlIKvNtnFiQVZYzwyIL73LblE5e2gXFEUhPzuJ/OykNtu+QymKEnMKmYhfkt3CPWcP5I8fboq735PDqvL784Zg1RRWllbx5JfFLCuuwh3Q0RSFNJeV74/sxmUT8g/m5ac6rK0uZwyhoKy1xxAhRMfQI9PFv64cy546Ly8tLeXt1Tup8wQxTJMUu4VTBnbiikn59I5SBOm4YrHBuY/C/GshGDn7JfprnXDOo6F1RJGTYufMIZ05c0joJpk3oLNmRw2FpdW8sbKMu95aR5rTyuj8DLqnuyK214DY59CXVjayoby22V5f4vgkZ+hm9OuUgsumhTX6jOdiyKopDG2h+ERLVJtGxg/6UfX6txDnRaBiVcm8eEC7j3L0yHSxdW/4gSaen5U/aJCf3br5UofLTrYztV8On27aEzYnKdb0LoBLx+e16XaJo2POhHzKa7z8a/G2mCtVKcCIHumkOa1Mf2AB+yLMcfPU6jz9VQlPf1XChIIs/jR7KF3TnTisatj7xDtPq1d260Y/hRAdT6dUB7ed3p/bTu9/tDelfQ2eBfW74NN7Yw+oLE445dcw5Py43sph1RjXO4txvUM3dA3DZOu+UGrgc4tLIt4wjmdeuD9oMH/VTgmmOiAJpppx6qBOEScLxnMxpCoKl03Ib/W2uIbmYDQGqf2gGDPGgEqxqWTNGYStW/vfobpmSm9++96GsCHweH5WI3qmx9U/xzBM6r1BGvxBXFaNVKc14tyiuVN7s3BLRcSL55bSuxQlNLrVFlWAxLHhzrMG0DnNzp8+2oyqKFHTNlw2DdOEO87sz6P/28LVLxRGHOU84MCJ9qtv93H2w1/x+twJ/HhiL/65YGvYSTjWQN5l07huep9W7rEQQhzHxl8HKV3gvRvB0KPPobIlg6rBOY+EgrBWUlWFfp1S6NcphSVFFWzZ07rG74YJ5bVxjrCJ44IEU82waipzxufx9MKSsLzkWC+GTuqRHndhgmiSx3dBS7dT/dYWTG8wbA4VAEpojpSW6SDzwv5HJJACOGd4V37z7oaIz8Xys3IGfVwSLMUMjEaxNj/RflethxeXlPLS0lI8AR2LqqIbJpqqcMGo7lw5uRe9DknFG5WXwZS+2XyxeQ++OKeeuGwav5w5KL4XiWPe5RN78f1RPZi/qozHvyimosF3cCJ1QDfokubguukFnDO8K59u3EOjT282kDqUbkK1O8CFTyzhvllDwo4dB8QyTwtg5jApiy6EOMENngUDvgeb34eFD8Ge9aGGvBAqod5pMEy+BQbMDCs41RYijUpB/HPoo50PxPFNgqkWXD4xn+eXbMMfoRZFSxdDDqvKzaf2a9PtcQ7IxPGLsfiKaqn/sgxfcU3o6s0MpfQ5hmSRMrn7EQuiDm6XTWPOhDxeXFIa9wiQqkBWehJDV35I8bvPk3vH7STPmBGWmujx69z2xho+27QHk+8OSgF9//vp8Ory7bxeuONgL6eMpNDB9l7fWq6u9rM1owfeGI9lLpvGcz8eG7UBrzi+JdstzJmQz6Xj89hV66XGHUBRIMNlo1OqHUVRqHUHuGPe2ogn0pYmHNe4A9z86momFGTx9fbquOdpOa0at5zWV3qTCCEEhIKkweeH/vgbwVMdetyRDvb2PU9nRSnKFe8c+gPXJKJjkWCqBZ1SHTx92Riu+NfymBusQuhC6Odn9GNCGxZTOEBRFBx90nH0SQfA1A1QFJSjXD77jjP6s66sllXbq6PexTmcokCKw8or10+mR+YZNHz1FXvuv5+q51+g05134BgUGhWq8wb4wWNL2FbZ2Oy6A4YJhklhaRVnPvwl8y4bjvnnewnu3cfrf/0/fr20gvfWlDcJxg6XZAulDD5z+RgGdY1QIUh0KIqi0DXdSdf08BTT1wt3RKjHF9uEY3P/uh/4wTB++84GFm6tjHmeltOq8f2R3bhqcu9W7JkQQnRQtqTQnyNkUp8s3l2zM6wJezzzwl02jcl9sttzM8VR0o4NAjqOCQVZPPvjMSTZNGwt9FTQVAWHVeXuswfwkyN0IaRo6lEPpAAsmspzV4xhUp9sXDFUpbNbVHKS7cy/fuLBVMjkKVPo/fbbpJ51FtvnzqX8F3fjLt/Fj59dTklF84HUoQK6SUW9jx/89b/4O3Uj/6UXcfXoxoM/GM4Xt8/g6im9SHVYcFo1ku0WkmwaVk1hSt9snpgzmkV3niyB1AnOMEye+qo4bETpwITjzNOuw9V/IqrNgaJZcPUZF1Y8AgX+vXQ7j88ZzeyR3XBYVCzNfFdtmoLdojJ3Wm9+P2uIlMcXQohjwOmDOqNGOB4fOi/c/e0SjIAXUw/iKSqk+vNnmyzr9uv8Z8MelhVXYh5eEUsc15TmfqGjR482CwsLj+DmHNt21Xp4fvE2Xlq6HROzyR0Kp1XDME2+N6wLV0/pzcAuJ+6FuGGYfLh+F48tKKJoXwOBoIF+yMcsyaZht2pcOSmUYpXuijzsrTc0UPnkU7y5YCOPDjoXT4TYv6VUK5sK18/oy82nhadbBnWDPfU+6jwB7BaV7BQ7qY6O3xhVUZSVpmmOPtrb0VrtfXxaV1bLD59cElbN01O8kr1v3kvPn8+PKa0j9//Zu+/wqMq08ePfc860TCaTQhohoYaWAErvAta1YwFfRdeGBSvr7v62u++6677qurvq6tqxACo2xLKsHaX3Jr0FCC0J6Zk+5/z+GIGEmSQzIQkB7s91caFnzsx5opmZcz/P/dx3gpVlvzsfgJ3F1by2sID3VxZiUpU6bRcUBW4c2ombhneKuEomxJngdPh8knun09MT/93MK/N3Rez7V73hW6pWzMF/eG+dfeG27N5AqLLzDUM70SnFzsylu9FUhUlDO3HVgA5Nvu8IBHW+2lTE64t2sbfUhcevY7do5GU5uWN0VwZ2SpYJuWbU0GeTBFNN4AvofLHxIFsOVlHm8uG0mclJsXNpv/ZnxM14LLYcrOKTtfs5UOHGF9BJdVgZmZvKuF7pESvvRXLBE1+zrdQTdry+VCvv3g11VgiS7GZW/O58TC3Zqf0UcjrcrEDLfz59s/kQD76zJqx5d/WGbyn79lVy7psR1etYTSpb/nJxnWNuX5A1e8upcPtQFIVku4Wzc5KwmOR3VJzZTofPJ7l3Oj0VV3m54J/fUe7yx/zcxDgzX/7sHNKdNgzDYPHOw8xcsof524q5tF97Jg3tRJ8o2+gEgjrPz9vBqwt24Q/qYRN+ihKa4G8Xb+GXF/XkirOlUXBzaOizSfZMNYHFpHJZvywu63eyR9L29cxMoGdm0/tg/LCvgsKq8A+uWHo7+AM632wu4sL8zCaPQ5x5/EEjrDcZxL7hOFIVwDiL1iL7KYUQQrSMtAQrb00exoQXF+HyBokmUU8htFdq5uShpP/YYkVRFEZ0S2VEt1SKKj3MWr6XO99cQbrTxo3DOnFZv/b1Fh5y+4JMfmM5qxooamQYoZRCl8/Nrz5Yz7rCCn53aW9ZpWpBMg0q2rSvNx3CGwjftB9Lb4caX5BP1u1vieGJ05jTZibSd0/tDcfRsEs1PiGEOC3kZTmZc+9IUhOsxFsb/myPt2i0c1iYfe/Ieled0p027j+vO/N/dS73jsvlk7X7GfHYN/zl043sKqmpc25QN7hz+gpW7I6+OqzbH2Tm0j089dW26H5A0SSyMiXatIOVXiK194m1t0Nxla+ZRyZOd/kdnBErPsbSiFoh1OdMCCHE6SE3PYGFvzqXLzYePLo33Kyp6IaBqij4gzpd0xxMGdONi/Izo0rf1lSFC/IyuCAvgz2HXcxctptrn19E7/ZObhzWkfN7Z/DW0j2sKIi+WvIRbn+Ql77fybhe6Zydk9TUH1s0QIIpcUqKNdVKKueIWDltZi7p256P1+wneNzvT7RNu+MsGneOkfLmQghxOjm23SOLHcXV7D5cQ5UngMNqolO7eHLTm973qmM7O7+5uDcPXdCDuesP8sr8XTw85wfcPr3e9hqNFePyBoK8/P1Onps0oMnjEvWTYEq0aWkOC6pC2OpULL0dAFLrabgnREMmj+7C3B8OEPSHB+ONNe2G0Kbj4V1lb5QQQpyuuqU56JbW/E2DrSaN8f07ML5/B95dsYfffvhDxPOi6XuoG/DVpkOU1fikcXALkD1Tok07t3cGVlP4ylMsvR3iLRqX9G3fWkMWp5H8rERG5qZibUKVPZtZ5Y+X58mmXyGEECdkwbbDBCPseYil76GqwGfrD7TWkM8osjIl2rSzshPJTLSFbcSE6FOtNFXhwvyM1hqyOM08d8MAJrywiG2HqvFEmaseZ1Z58Pzu/KSPBPFCCCFOTGGZK2L1wFiKcbn9OocqwtvMiBMnwZRo0xRFYcqYbvzx4w0Rc4UbS7WyaAo3DuuEWXpMiSaymTXeu3sE98xcxeIdh/EGghGLokBo5s+sqfzx8nz+Z0jH1h2oEEKI05Knnup9sRbjqvEFGj9JxEzuMEWbd2X/LLqkxmPWYkuXUhVIcVi58xwpACBOjM2sMe2Wwbx713Au75eF1aSSYDXhsGokWE3EWzWS7WZsJpWZk4eEAinDAG81+FxEbFglhBBCRCExzhzxeO1iXI1RFUiR/VItQlamRJtnNYUa3o1/biF7SiMvdR/PpCo448y8e+dwkuzy4SGaR9/sRJ6+vj8VLj/r91VQ4fZj1hTaOaycnZPEs19vZe282QxSPoaCBaFW9BiAArnnw4gHoNMIIjawEkIIISIY1rUdq/aEl0WPpRhXnFmjX7aURm8JEkyJU0JyvIVJwzryzNfb8AcNgrpBIEKulaYoWEwquenxvHrz4KMdx4VoTol2M6O6p9Y9uPUL7l9zH+6aCgw8KECdyH/r57BrPthT4OqXQkGVEEII0Yjrh+bw3LztYcdj6Xtot5oYlZsa9hrixEkwJU4JWw5W8cJ3O/nk/tEYhsG0hbv4YOU+IFRgQjdCwdXFfTK5Y3TXeruNC9EiVrwO//01asBNfL0nGeCvgYoamH41XPUC5I9vvTEKIYQ4JaUn2BjdPZVvNhWFZedEU4zLZla5Y3QXVFWyIlqCBFOizfP4gzzw9mp+fXEvuqSGblX/Mr4vf7gsj+IqL9XeAPEWE6kOK3GW6DZhCtFsNn8G//01BNzRPyfghtl3Q3wqdG68T5oQQogz20MX9GDh9pKIxSgaK8YV1A1pEdOCpACFaPMem7uZ3AwHEwZm1zluNWlkJ9vplekkJ8UugZRofQEvzL4rLJDq/FQV6X+rosZ3bA7xlVU+xr5eq8R/wA0f3A56dOXWhRBCnLnysxJ56rqzsZmjv3VXCPXavLhPJje9uoxth6paboBnMAmmRJv2zeZDfLnxEH8d31ean4q2Z+PH9VbqCxrw9FJfw8/3VsOuec0/LiGEEKedn/Rpz/M3DiTOrGFrpJm83aLRzmFhzn0jeeb6AdwzthvXvbSE//4gjXubmwRTos0qqvLwqw/W88/rzibRHrksqBAn1cJ/gq864kO/HGHhyUVeyj0N1J/0VcPCZ1pocEIIIU4343qms+BX47j/vO60i7cQb9WwWzSsJhW7RSPOotEtLZ4/XZHPgl+dS256AgATBuXw+q2D+fOnm/jb55sJ1tcwUcRM9kyJNknXDX7+7lquH5zDkC4pJ3s4QoSrLoaSbfU+PChLY2xnE08u8vKXcxuoKlnwPQT9oMmEgRBCiMa1c1i5d1wud4/pxtKdhyksc+PyBXDYzPTMSKBvduQiXP2yk5hz30jue2sVt72+nGf+p79MVjcDCaZEmzRt4S5qvAEeOK/7yR6KEJG5DoNmgWD9qXyPjLMycloNDw5toNeZagJ3OTjSWmCQQgghTleaqjAixnLnqQ4rM24fyv/N3cwVzy3gxZsG0ivTGfHckmov763Yy8YDVVS6/ThtJnpmJjBxUI60nqlFginR+kp3QsFC8JSHbiTj00INTeNCzeQ27K/g3/N2MOfekZg0yUQVbZQRbLT5bp90jct6mHhsgY/eafX9Liuh1xJCCCFagUlT+cNlefTtkMgNLy/lkSvzuaxf1tHH1+4t57lvt/Pd1mKAOs2Cv9h4iH99s52R3VK599xcBnZKbvXxtzUSTInWoQdh2xew4Ck4sAYUFfQAoITSm/QA5I3HM3gKD7xbzsOX5ZGTYj/ZoxaifnHJofS8RvxprI0BL1bz8+HWyCcE/WCTrvRCCCFa1/j+Heie4eDuGStZX1jBLy/qyawVe/nzpxvxBvSI9ZWOBFbfbili8c7D/OKiHtw+qmsrj7xtkWBKtDx3Ocy4Goo3g68m/PGgN/T3+vdQ18/md8njOfesF1p3jELEKqE92FOgcn+Dp+WmqFyXb+aZZT76pkdYncrIA7OkSwghhGh9+VmJfHzvKO5/ezUXPfU9+8rceAKNt+wwALc/yJOfb8UwYPLoMzegkhwq0bI8lfDyuXBwfeRAqjYjiMXwMq7qE/j4/npLTgvRJigKjHgQzI2voD48xlqn59RRFgeMnNoCgxNCCCGikxxv4ZcX9WD3YVdUgVRtbn+Qv3+xlZW7S1todG2frEyJlvX29VBR2OAm/eMpfhds+BDSe8OI+1pwcEKcoLOvh6/+GHa4YGpCnX/PSVTx/D7CBl9Fgd5XtNTohBBCiKi89P0u9HomsWs2zqNy+Uf4DxeiWuIwp3clccREbNn5AHj8QZ79Zjuv3TqkNYfcZsjKlGg5+1bB/lXH0vh+1PmpKtL/VlVnpv6VVT7Gvl5r5crvgu8eh0D0QZgQrc6WCGN/E9Xq1PEMUxz85HEwNVDpTwghhGhhZTU+vtp0iEitpyqXzab065dJHDaR7Ptm0GHKayQMuAT3tqVHzzGAhTsOc6jS03qDbkMkmBItZ/FzEIj8xgoa8PTSRgIlIwibP22BgQnRjEY+CGddH1NA5VNtvG2+isreE1twYEIIIUTjZq8ujFicVvfWUL5gJikXTMHecwSqxYaimbDnDiV53G11zlWAWcv3ts6A2xhJ8xMtw10eCoSMyLm3vxxh4YmFXu4ZbCHJVk95aV9NqPpfn6tbcKBCnCBFgUv/DonZ8N1jgAoBd+RzzXYwdMw/eYxNhYP58LXlvHHbEOKtxz6KDcNg1Z5yFu8ooaTah0lVSHda+Ul+ezq2kwqXbVWFt4KPd3zM5tLNVPmqcJgd5CblcmXulbSLa3eyhyeEEPXacrAajz/8fs27bzNGwIe9x/BGX8Mb0NlysKolhtfmSTAlWkbJ1lBD03pWpgZlaYztbOLJRV7+cm4DlcxKNrfQAIVoRooCox+CgbfAqjdDq7K+6lAfNQDdD/Z2MOIBOOt6FJuTPw0w+PWH65j8xgpeu3UwAB+uKuTF73ZSXO3F4w8eTbkwawp//2Ir/bITuXtMN87tlY7SSI8r0To2l27mlfWvMG/vPAC8tdKarZqV59Y8x8gOI5ncdzL90vqdpFEKIUT9KjyR23wE3ZWodieKqkX1OpX1vM7pToIp0TI8lY2e8sg4KyOn1fDg0Ab2jAS8oR5VUb6RhTip7CkwamooaCrdCe6yUE81ezIkd6nT5FdVFf7v6n489O4abnltGYcqvRys8OD2hzfw9QcNwGB5QRkb9q/mnO5pPHN9fywmydQ+meZsn8NflvwFX9CHToRZ3R8Dq3l757F4/2IeHPggN/a+sbWHKYQQDUqMM0c8rsU50V2VGHowqoCqvtc53ck3sWgZpnoalNbSJ13jsh4mHlvQwN4pVQvdjApxKlFVSM2FnMGQPRBSuhIpIV1TFX57SS9W7ymnoKQmYiB1PJcvyLytRdz2+nKCkXYLi1ZxJJDyBD0RA6naDAw8QQ9Pr3ya6Runt9IIhRAiOr0yE4gzh99rWTv0QjGZcW1d3OhrWE0qvTITGj3vdCR3qaJlOLMg2Phy75/G2nh5lY99lfXcFMalRLwJFeJ0YBgGt7+xAl03iCUs8vh1Vu4u5bG5m1psbKJ+W8u2Hg2kYuEJenhm1TOsKVrTQiMTQojYXdW/Q8RKfqo1nqRRkyj98gVcWxej+z0YwQDuHSso+3ZanXMNYOLgnNYZcBsjaX6iZbTrBildoGhjg6flpqhcl2/mmWU++qYfF9trVhh4awsOUoiTa9WecnYW1+CP8C3WWF8Pt19n+pLdTD2/R50CFs0iGIDKfeCtDL0PHWkQl9y81ziFTVs/DZ8evqJeNr+Mks9L8BX50GwazoFOMq7NQIs/lh7jDXp5cd2LPH/+8605ZCGEqFeS3cKF+Rl8tu5AWFDlHHI1anwyFYtnUfLpkyiWOKwZuTiHX3f0HAUY3T2V9IQG9sCfxiSYEi1n1M/g05+FNuI34OExVqavq2cVa9BtkY8LcRp4+fudeCKk9lUum03F0vdpd+G92LoMQNFMuHetxL1t6dFgCkBVFD5as49JQzs1z4AqD8DyV2D5y6GVZVUDwwg13c4ZCiOnQrdzQ2mMZ6hKXyVf7fkK/bhKpSVzSyieW0z25GwceQ78ZX72T99PwZMFdPldF9Qf97cZGCw7uIwiVxHp9vST8SMIIUSYKWNy+XLjoYhV/Rz543Dkj6v3uVazyn3jcltyeG3amfuNKFpe3pURC0cUTE3g/K7H4vicRBXP753MuyX+2EmqGbqNA2f71hipEK2u3OXjmy1FYbOAsfT1cPmCvPT9zhMfTDAAHz8Az5wFi/4FnopQ42xvVWgyJOiDgvnw3s3wzzzYf+amqX224zPU4746g+4gRR8VkXVjFgn9ElBMCpY0Czn35OAr8VGxqKLuixjw4bYPW3HUQgjRsLwsJ49c0QdbhL1TDbGZVX5zcS/6dzxzsxckmBItx2SF62ehm+Jie56iQnwqjJc0GHFqCZR58Gwtw7W+GPeWUvzFrnrP3X3YhTVCNb5Y+noAFJa6MYwTKEQR8MH0K2H9u6HqmbVKe4fxVUPVAXjtYtg1v+nXjHV8rtKo9mC2hm3l23AH6/YRc21zoft1nAOddY5rNo2EfglUb6i7Ou/TfWwt29riYxVCiFhMHJzDn68MBVRqFNvVbSaV313Sm5tHdGn5wbVhkuYnWlRF+iCesPw//ld/HLMexWZtzQKODLj1P6Ey00K0cYZu4NlSStV3hfgKq1FMSmgnrgIEDUypcSSMzSYuPxWlVvBU7Q1EfL1Y+3oYGPiCOlZTE9oHGAZ8eAcUrqy/0XAkfhe8fR1M/hrSe8d+3cZUHYIV02D5S6EG4KoJ9ECoV9fQu2HAzaF9XCdBlS+8KWWwOojJYULRwu8+TIkm3LvD/9tWehtvHyGEEK1twqAc8rKcvDBvB19sPISiUCf1z2ZSMYDMRBu9MxO4aXjnkzbWtkKCKdFiPP4gd765grzeF2EadD58+TDsXgT8uAejNnN86Hi/6+D8P8pmd3FKCJR6KH55PXqND8MX+rIxjouR/AdqKPtwO+Uf7yT19j5YshwAxFkiBz+x9vUwDPh+SzE9M51kJ8ehRjOdeMS+lbDtizqBVOenqnD5YdeDDuItodd6ZZWPGev8dVNxfS6Y+yu4+ePor9cYXw3MuRc2/ydUxfNI0+8jnxc1xfD93+D7J6D3lXDFM2COceX7BCVYwkv/ag6NQHUAI2iEBVSBigAmR/hXbaTXEUKItiA/K5F/3TCAcpePD1ftY9PBSipcfhLjzHTPcHDNgGw0VeG8v3/H1kNV9Mg4sz/PJJgSLULXDX7+7lpSHVb+cGkeiqrATz+CikJY9kroBs5bAYoWSunr/1PoNwEs8Y2/uBBtQKDEzaHn1mB4AjRW19zwBjG8QYpfWEvq5L5YOzrpkBSHLxC+0bd2X4/4XqMaHYfNrDF96R62Haqiwu2nW5qD7hkOemQk0DMjge4ZDjokxaFEajGw6F/HApZaggY8vdTHb0c31C/OgL1LQu/pxOxGx9koVylMuwjKdzecanhkvJs+DlULvfU/YEs88etHKTcpF5tmq1MW3Z5rRzEpVK6sJHHIsbEEPUGq1lWRcW1GndewqBZ6JPdotTELIURTJNkt3Daq/hS+e8bl8tf/bOL1W4e04qjaHgmmRLMzDIM/f7aRkmovb9w2pO5MeWI2XPC/oT9CnKJ0d4CiF9dFFUjVZvh0Sqb9QMbUAWQk2eiXbmL5/rqrtLX7eiiqhq1LfxTVhKdgDZ496+oUobCaVO4Z1437z+Zh4X8AACAASURBVO0OQKXHz7ZD1Ww7VMXWQ9Us2FbC1kNV1HgD5GYk0CM9FGT1yEygp9NHxta5KEZ4QPfLERaeWOjlnsEWkmwNrHQZBix7GS74U/T/ESLxe+DN8VBWEL5qXZ+AB0q2wsxr4Zb/gGY+sTHUwzAMlu0q5fVFBewsrqEmkIgnLRBK4/yRZtdIH5/O/hn7UW1qnWp+5hQzSSOSwl736u5Xt8h4hRCitdw0rBPTFxfw/dZizulxclKv2wIJpkSze3n+ThZuL+G9u0dgMzdhH4cQbVz1sgMY7siB1LLCdfz12+fZWlKAqqp0b9eJP553P2e3D+0tMnxBqmYvJJl/cHdNHBu1m6kJ1i1EEU1fjyOuH9Lx2PNsZgZ2SmZgp7ppshUuP1uLqth6qIpth6r5ZnMROQe/5GFdxRHh5xuUpTG2s4knF3n5y7kN9A0J+mDjnBMPpla+Doe31gmkoko3DPrg4A+w9h0YcNOJjeE4hmHwzrK9PPvtdspcPty+4NH/3TZrPibnOhTl2C9A2iVpaPEaB2cdxFfkQ41TcQ5wknNXDupx1bEGZw4mI77uapUQQpxqLCaV31zSm0c/28TI3FS0WNLMTyMSTIlmNWfNPl5fWMAH94wgMa5lZoqFOJkM3aB6/j6MCCl6Vd4abn3/1zx64UNc3mscvmCAZYVrsWqWYyfp4NoKiRNuZeyNV5Dw5HxcFZ6wuKzRvh4mlfN7Z5DqaCgVLyTRbmZw5xQGd65V1GXlZoy5QOQ6GDwyzsrIaTU8ONQS+YQjTrSQgmHAomfAH16kIap0Q78LFj7VrMGUP6jzs1lr+HpTEe4IfcB8h8dgStgISt0KgyljUkgZ03DhHJtm485+dzbbWIUQ4mS6MC+DVxfs4t0Ve+tM7rUWwzBYvPMwL323k+UFpbj9QTRVwWkzc1X/Dtw8ojM5KfYWHYMEU6LZLNxewp8/3cjMycNon9i6m8KFaC3e7eUYvvAbbICdpXsBGJ93PgBxqsaYLhFyyc02XP6ROMwWpt82hPH/XkiNN/JrRmJSFbKS4nj82n6x/wBHKZH3Uf2oT7rGZT1MPLbAR++0+rtoVHiCPPzOatIcVlITrKQ5rKQlWEn98e+UeEvDs5UFC8BTHvGhqNMNK/eHiml0GFj/OVEyDINfvLeWrzZFbl4JoHuz8By8Alvmxyhq9CXbbZqNe8++lwEZA054nEII0RYoisIfLs3jtjeWc/lZWTisrRdafLHhIH/8eAMVbn+d7AE9aHC4xscbiwuYvmQ3Z+ck8eSEs1osqJJgSjSLDfsreODt1fx70gB6Zp7ZVV3E6c2zo/xo5b7jdU3JQVVUfvbZo1zR6zz6d8gnyRb+fjB8Op6tZTiGtqd7RgLv3DGcG19dQo03SOD4Lr7HsZpUOrWz89Ydw07sS8veLmJT7dr+NNbGgBer+fnw+leGzPHJjOmRRkm1l+IqL5sPVFJS7aO4yktxtZdKt58ku4W0hCNB1o///GOwNXTDTDJ8LiKFSlGnG/rdsOGjZgmmPl67ny821B9IHRGoGIwHsGV+DEqgTsrf8RQUrJqV+/vfz0/zf3rCYxRCiLakb3Yio7un8sK8HVw/tCPTFxfw9aYiKj1+VEUhJd7ChIHZXDMwmwRb82QtvTJ/J09+saXBz2p/0AAMlheUcukz83n7zmHkZzV/wSIJpkS9fAGdbzYfYmdJDVWeAA6ria6p8ZzXOwNLrX45e0td3P76Ch65sg9Du7Y7iSMWouXpNfWvRCRY4/lw0rP8e+lb/L///o3imlLGdRvKEz/5f6TF103/0l3HXqdvdiJzHzyHf32zndmrC1EVBddxq1/xFg2LSeXWkV24Y3TXekurR63L6EYb4eamqFyXb+aZZT76pkdYndKs2Af+D1cPqL+anz+oU1pzLLgqrvJSUu1lX7mbNXvL6bBnF5kNVPGILt3QCDUTbgbPfrM9YmofQM3GeVQu/wj/4UJUSxzm9K4kjz0fZ/+9mBybAVDUY3mTVs2KYRgMyxrGHX3v4Oz0s5tljEII0daMPzuL215fwUvf7wQMfMFjn+sHKjzsKtnCY//dzGX9svjtJb1JiW8khbwBH64sbDSQqk03oNIT4PqXlvDZA6ObfYVKgikR5kCFm+mLdzN9yW50Azy+AEEDNOVIbxyFSUM7cvOIzsSZNW5+bRl3jenKpf3an+yhC9HiIjVmra17amf+eelvAdh+eDcPfPoX/vfrf/HcFX887nXqBidZSXH839V9+f2lvflo9T6+3HSIMpcfk6qQnmDlmgHZjOuV3nwbfG2JkH8VrHsXjPpTDB8eY2X6ugaCroG3NngZs6aS4bSR4axnZWlmImyr//nRphuu21vKt19tO7oCVnslLNqGxusLKygsi9y8uHLZbCqWvk+7C+/F1mUAimbCvWslNRs2YEm/DUWrxpS4EnPcQUb2sJMSl0i3pG5clXsVafYzt8qVEOL0N29LEVNmrPoxsyLy5NiRCcKPVu9j/tZi3r17OJ3axd4Op9zl4zez1+ONsG850oRX4oiJ2LLzAaj2Bvjl+2t5587hMV+3IRJMiTrmbSninpmrCAQNfMG6v6hBA6p/3Nfx2sJdvLl4Nx2S4rggL4NbR9bfh0CI04mWZA3NLAQbr4me264TE/v8hBlrwhvbaomRU+firSYmDevEpGGdTnisjRp+Xyg9rlbT3oKpddMScxJVPL93hj9XUaHbOEg4wap0jsYDjcbSDQ3A5MwgoOusKywPWwWzW0x1UguP/jnu32cu3R2x95furaF8wUzaXTIVe88RR4/bc4dizx0aGkPQgb90DJpZZfSg3vx0eOcm/ecQQohTyYqCUqbMWIk7ylWigG5QXO3lmucXMffBc0hLaLyIUm3vrthLpO2+9U14ubctPRpM6Qas3lPO3lJXs65OSTAljvp2cxFTZq6MatnUFzQgGGRXSTUDcnq2wuiEaBvs/dKo/HpPxMe2H97N1zsWc0Wvc2nvTGd/5SHmbPqaAVn5dc5TLCrxA9tAaezMPnDW9bDunVBVvFhYE+Dix098DD0uDgV0vup6T2ks3VCxOMgbM4G8buGfRbpuUOH2Hw2ujv6p9rL1YFWd46U1vohzqt59mzECPuw9Gp/N9Pj1ele3hBDidOIL6Nz+xoqoA6kjdAPKXX5+NmsNMyYPjf55usHL83eF3adGM+F17NoGbywq4PeX5cU05oZIMCUA2FVSwz0zV0Wdf3pE0ICps9bwSfpIctOl8IQ4/ZnaxWHpkIBvd3hJ8HiLnTX7N/Hy8nep9FbjtDo4v9twfjfunjrnqXYzli4RVntOhkufBFcJbP8quoBKUcESDz+dA8mdT/z6PX4SVcPdBtMNLQ7oMjbiQ6qqkBxvITneQo+Mhj+jrnx2AWsLK8KOB92VqHYnSiMFO46o9tRTb14IIU4j/91wkEAw8n1jYyl3AT1UGKKwzEV2cnSrRBsPVFLjDf98jWXCyx80mL16nwRTovm9+N0OfMHoN13XfkP4gjrPz9vB3yfK5mpxZkgYm03p25vDqvq1T0jj+fENN7BVzCqOc7IbLEveqlQNJrwB3z4Ki58NBUuRgipFA5M1FEBNnA6puc1zfc0EQ+6Chf+EgPfo4ajTDU1xMOI+UOvfTxWtRHvkDdFanBPdVYmhB6MKqNqdwMZqIYQ4Vbwwbwc1EVqFRJNyB6FWFNMX7+Y3l/SO6nol1d6I+4ZjnfCq9ETf0iIaEkwJarwB5qzZT6TJhWjeEEHd4NN1B/jjFfk4m6nkpRBtma1XCrZeKbg3lUIsq7kmBXMHB46hmS03uKZQVTjvDzDyQVg3CxY+HerdpJnB+PHny7sSht8LWf2b//rD7obVb0LVwWPXi4aigSMDBt7SLMMY1CmZpTsPh21stnbohWIy49q6mPheoxp8DbtFIy+r+VYdD1Z42F/hxu0L4rCa6JwaLw3RhRAn3d5SFzuKw9OzY0m58wUNZi3fy28u6Y2uG1R5ApS5fJS7/aG/XT7Kavyhv11+Nh+sDKt0C7FPeAUbaUESKwmmBB+t2RdxM18sbwhVUfhwZSG3SCEKcQZQFIWU63pyeMamUBPfaAIqs4o5M57UW/PDKvm1GTYnDLkDBk8O9W7yVIRWo2yJjfakOiFxyXDLZ/DK+aFr6lGkyakmiEuBWz8L7d9qBtcP6chz324Pv5Q1nqRRkyj98gUUVcPWpT+KasJTsAbPnnUkj7vt6LkuX5B3V+xFVRXO7ZWOuQn/r4O6wbebi3jhux2s31dxrBWFEcoEuCAvgztGd+WsnKQm/6xCCHEiDlV6sJjUsMmnWFLuAMrdfvo/8gWVngB2s0ZSvJlku4Uku4Vk+5F/NtMtLZ5ku5kf9lXi1usGVLFMeAEn3lrkOBJMCdbtrYgY6cfyhnD7gxH3GghxulI0lXY35VH1zR6q5u8Dw4jYzFexqGCAfXAGSZd0RTG10UCqNkUBiz30p7WkdIW7F8CMa6CsIBTMRSoHoahgskFqd5j0PjjSm20IaQlWzumRxlcbD4Vd2TnkatT4ZCoWz6Lk0ydRLHFYM3JxDr/u6DkWk8ptIzvTLc3Bq/N38bvZP3D1gA5MHJRDbrojqjFsOVjFT19dSrUvQM2P1VOPv1n5z/oDfL2piN7tE5h2y2CS6klPFEKIlhKpNDnEnnKnKPCfB0eT6rA2OvlU7Q3w0vydYcdjmfBSgMGdU8Je40RIMCUoc/kiHo/1DVFez+sIcbpSVAXn+Z1IGJuD+4cSqr4vJHDYg+HXUUwKWqIVx6gO2PunozbzTNhpyZkFUxbB3mWw6BnY9mUo1VBRQ+l/uh96XhraI9VhYIsM4cHzujN/W3HEYjyO/HE48sfV+1yLpnLbyC6kO21MGJTDzuJq3l1RyA0vLyEnxc7EQdlc2i8LhzXyV++aveXc8PKSiJNbtelGaAJr/b4KLnl6Ph/fP4pUR2zlhYUQ4kQk2EwR57tiTbnTFIX2iXFRXdNhNXHl2R14f+XesK0p0Ux4QWhV6q5zukV1vWhJMCWIr+eLPdY3RH2vI8TpTjGp2M9Ox352862SnLEUBToOhY4zwV0e2rvlqw6l8jk7hFIRW1CfDok8cU0//t8H62Kqbhpn1njjtiGk12pO3DXNwa8v7sUvLuzBvC3FvLtiL49+tomL8jO5bnAOAzslHy1EsrfUxU2vLm00kKrNHzQoqvJyw8tL+OT+UVE3JxZCiBPVNc2BP8Jm+1hT7qJdtT/i9lFdmLNmH0E99gkvgCS7mWFdZWVKNLMuqfFYNDWsSW8sbwizptAlNfZO1kIIUa+4pNCfVnbF2R2wmDR+NmsNQV0P9dWrh82sYtFU3rhtCP07Jkc8x6SpnJ+Xwfl5GRRXeflwVSH/74N1AEwclMPVAzrw1FdbI5b8hYYrqgZ0g8IyN5+tO8DVA7JP/IcXQogoOKwmLjsri9mr99Up6BBLyl28RWPK2NhWiXpkJHBe7wy+3nQo5nY+NrPK/16e3+zVdCWYElw7MJtnT3DTtaooTByU05rDFkKIFvOTPpn0yx7Dm4t3M3PpbnQDfIEggaCBWVMxmxTsZhO3j+7C/wzOiXrfUlqClbvGdOPOc7qyak8Zs5bv5bwn51HjCxKpwFQ0FVVdviDPz9shwZQQolVNHt2FT9ftD6uOF23KnaIo/KRP7NVt/znxbK5/eQkb9ldEHVDZzCq/uLAnF+Y3fzVdCaYEDpuJ9AQrhWXusMeifUOcnZNETkorblYXQogWlpUUx68v7sVDF/Tgm81FFJa5cP1Yorx7hoOR3VJRI/Q8iYaiKAzslMLATil0S3Pw5Bdb0I9bAYulomphmZsf9lXQp0Nik8YjhBCx6pXpZGiXdiyJ0FKisZS7OLPG/eflNik92WJSeeuOoTz07lq+3ngIv65HbO8Tuo6KYcBfx/fl6oEtM+EkwdQZzBfQeWvpbp79djv5WU5Kqr1N2nQdZ9a4d1wzNfAUQog2xmJSmzR7Gq3vthbjj5BKGEtF1YCus3B7SZODqaIqD0WVXryBIAk2Mx1T7NjMsgdLCNGw528cwOX/WsDO4ppI9SgiijNrXNI3kztHd23yda0mjeduGMDWQ1W8umAXc9bsw3ykebsSajHhsJq4Y3RXJgzKbtGqpxJMnYEMw+C/Pxzk8f9upmO7eKbfPpTe7Z387b+bmbawALc/+g3QcWaNm4Z14pweaS04YiGEOH01R0VVf9Co93XqE9QNvqndz0pTURTQDQNdD6WA3zqyM13TYtsgLoQ4c8SZNXplJlDm8uPxBxssoqOpCmZNYdLQjvz2kt7NsnepR0YCj1/Tjz9clsemA5VUuPyYNIVUh5W89s4mZw/EQoKpU4Q3EOSrjUXsKK6mwu3HaTPTOdXORfmZMc0ertxdxl//swmXL8ifx/dhdPdjQdAvLuqJXzeYvnh3VAFVnFnj+iE5/PriXk36mYQQQoRuMCIej7Gi6tFZ2SisL6zg1teX4fYHj/az8h2XpvP2sj28u2Iv5/RI41/X95eVKiFEmJe+38nuUhfzfzmOFXvKePG7HazcXYamKgR0AxXQNIVA0OCyfu25fVRX8rKavyqrw2pq9v5R0ZJgqo3bX+7mjUUFzFy6BwMDlzeIQajpmN2i8ZsP1zNxUA63jexCx3b171kqKKnhic83s3pPOT+/sCdX9e8Q9gWuKAq/vaQ3Azsl89SXW9l1uAZ/0KizsVBTFMwmhU4p8Uw9vzsX923fQj+5EEKcGdITbEBl2PFYKqpaNYV2jujSWBbvOMxtry9vdNIsoBsEdIPvtxZz1b8X8sGUEdgtctsgWke5y0e5y4+iQJLdQmKc+WQPSRzn2y1FvLpgFx/dO5J4m4kxPdIY0yONwjIXy3aVUuH2Y1IVUuKtjO6RitN2ev4/lE/FNuzbLUXcM3MVgaAelk9vADU/LqXOXLKbWcv38veJ/bikb1ad80prfDzz9TbmrNnH5NFd+fuEs4lrpHnoRfmZXJSfyaYDlbyxqIDNB6uo9gZwHP6BHn0GcfOYXuRnySZnIYRoDtcOzGbpzsNHP9OPiKWiqu730+sfv6d45VAc552HLS8vYgrN9qIqJr/ReCBVmzegs7O4hslvrGDG7UNbJW1GnJk8/iCfrN3PC9/tYE+pC4umYhBaNe2ZmcDdY7pxUX4mFlP0q7CiZeworuYX767lxZsGkpVUt+ludrKd7OQzpyiZBFNt1DebD3HPzFVRlXz06wZ+PchD764lqMPlZ2Xh8Qd5bWEBL8/fyWX92vPlQ2NIdVhjGkPv9k4eu6bfsQMvPgyjh4EEUkII0WwuyMuoN9Uv2oqqQ7pn0O+6h6j6+hv2PfQQhs9PwrhxJJx/HvbBg1HMoRnhRz/bVO+ehob6WXkDOmv2lrNwR0md9HAhmsus5Xv50ycbAI7+jvqDx35XN+yv5NcfruO3s9fzt2vPatGiMKJhFW4/d7yxgl9e1JNBJym1ri2RYKoN2lVSw70zV8fcjMzj1/nl+2vZU1rDW0v30qeDk/fvHn7Cm4e3lG5h+sbpLLRV4vr2ThTNjNPi5PJul3Ndz+tIt6ef0OsLIcSZzKyp3DS8Ey/P3xW2bwmiqaiq0i87ife8NpTRE0i+6AYGa9WYFn5H0dNP4yvYjWPUKNyjz2XhdiJW3Iq2n9UL3+2QYEo0u6e+3MoL3+9o9L7nyP6+qbNW89uq3vx0eOdWGJ2oLagbPPjOakZ3T+V/hnQ82cNpEySYaoNe+n4nvmDsM4cQCqhe/n4XL9886IQ34q06tIpHlz7Knso9+HU/QYKgA7qPGn8Nr//wOq9veJ0hmUN4eNjDtHfI/ikhhGiKe8bmMveHg+w57CIQqXtvPVQllIb35uKCo88zaQr+oMEFvQdxx5MTybf4qP72W176ZhNYu4FWd99CLP2sVhSUsb/cHZbWI0RTvbNsT1SBVG0ev85f/7OJDKeNi1qgCauo3xOfb8br1/n9ZXkneyhthgRTbUyNN8Ds1YURm49FM3MI4PYHyT3B1ai5O+fyh0V/wBv01nuOTw+V4V20fxHXfnIt0y6aRs+Unid0XSGEOBPFW028c8cwrnl+EYcqvfjq60BZh4FuhNIDa++38gZCf8/94QDfbC7i4r6ZPDFhAksOLMB3qCrsVWLpZ6WpCot2HObaFmp+Kc4sLl+AP32yMWIgFc3k8a8/WMd5vdIxabKHqjV8tHof/1l/gI/vHYVZ/psfJcFUG/Ppuv2oETYNxzJzqCjw/spC7jinac3QFu5b2GggVWdshk6lr5JbP7+V9y5/jw6ODk26rhBCnMnSnTY+e3A0v3xvLd9uKUYhtOp0PFWB0CJUw4UgdCM0uTZ3/UHKanxUuE+8n1UgqFMeYz8rIeozZ81+IrUainby2BfQ+WZzERfK6tSJKd0JW/4LNSVg6BDfDrqdBxnHVp/WFZbzyKcbeeuOoSTHt1wD3FORBFNtzLrCioibg2OZOfT4ddYUljfp+p6Ah5/P+3nUgVRtNf4afvX9r5hxyYwmXVsIIc50TpuZF28aRFGVh7eW7GH6kt2UunwogIKCzaLiC+jowehTAd3+IEt2Hq63Cl8s/awURam3WIYQsTAMgxfm7Qi754ll8rjmx318Ekw1ga7Dts9hwVNwYE0oiAr+OFGiWuCbRyG1O4yaSlGHC7lr+kr+elVfemU2f4+oU50EU21MfR3sY5k5BKhw+Zt0/c8LPseIuD0ZyuaXUfJ5Cb4iH5pNwznQSca1GWjxoTHphs7m0s0UVBTQObFzk64vhBAi1Htq6gU9mHpBDwzDwBvQKavxMebJeWGtMqDxlCi3X693HSuWflahnjEyKy1OXHGVl4OVnrDjsUweA6zZW44voEu59Fj43TDrJti9CPw14Y/rvtAe+YPrMObcT7mRzU8HPS8VFOshv3ltTII1ckOz2jOH0XBYmxYnv/rDq7gCrrDjJXNLOPjeQTInZpL37zy6/qErvsM+Cp4sQK+VhhLUg8zYJCtTQgjRXBRFwWbWmLF0d8SAqHLZbEq/fpnEYRPJvm8GHaa8RsKAS3BvW1rnPJOmYIqwqlS7n5Vr62J0vwcjGMC9YwVl306rc25ANxjbQyq4ihNX5vJH3HcT6+SxWVOp9DRtAvmMFPDBm1dCwfzIgdRxFH8NXQI7uXvbXeCtboUBnnokmGpjuqbFY40wu1J75rAxZk2ha1p8zNfeW7WX/dX7w44H3UGKPioi68YsEvoloJgULGkWcu7JwVfio2JRxdFzA0aAz3Z+FvO1hRBC1C8Q1Jm+eHfYHqojKVEpF0zB3nMEqsWGopmw5w6t09QXwB806q0U6BxyNcnn3k7F4lkU/msShc/fQtWqT4nrfmx1QFXgwvwMEu2RJ/2EaA6xTh6LGP3nF3BgHQTCVwXrY8aPUr4H3rul5cZ1CpM0vzbmqv4d+PuXW8OO1545VFQNW5f+KKoJT8EaPHvW1fnSVBWF6wbnxHztEncJZtUctl/Ktc2F7tdxDqybJ6vZNBL6JVC9oZrkc5KPHq/x16AbOqoisboQQjSHoqrIFf5iTYlSCAVFkbZcNdbPymrSuGN00wobCXG8ZLs54u90LGmnAP6gjtMmAX5Uag7D2nfguPu8zk9V4fLDrgcdxFtCq9evrPIxY52febf8ODkf9IZWs0q2Q2pua4+8TZO73TYm3WljVLfUiKkc0cwcAvTLTqRTu9hXpvzByMvkweogJocJRQsflSnRRKA6UOeYoigE9EDYuUIIIZqmyhPApJ54SpTVrOKMM0esoNaQOLPGVf070C87KbYnClGPtAQrHSL0K4sl7RRgQMdk2S8VrVVvUN+bP2jA00sbqdSpB2Hp8y0wsFObrEy1QfeM68ainSUR+y40NnMYZ9a479zuTbpugiUhYvEJzaERqA5gBI2wgCpQEcDkqPtrpKBg0WSDshBCNBerSUU3Inw+x1CJD8Aw4NWbBzP5jRVUePwEo2gQHGdWGdszjT+P79OksQsRiaIoTBnTjf/9ZENYRT/nkKtR45OpWDyLkk+fRLHEYc3IxTn8ujrnxVs07h7TrTWHfeoyDFjy73rT+345wsITC73cM9hCkq2e2RbdD2veggsfBbOtBQd7apFQvg0a1DmFe8fmEmeObqbxiDizxk3DOzGmR1qTrtslsQu6ER7A2XPtKCaFypWVdY4HPUGq1lURn1d3Fax3Su8mXV8IIURk7RwW/I2kREXDMELZC3OnjmZAxySsJjViUQoAu0XDZgS4MaGSf08agKrA8oJSPlxVyPQlu5mzZh8b9ldEfK4Q0bj8rKx6H3Pkj6P9zU/R8aEPyLlvBukT/hdbdt37C4tJZVwvKYgSFb8L3KX1PjwoS2NsZxNPLmqsNY4CVeH7689ksjLVRt13bi6VHj8vz98V1flxZo0bhnbkNxf3avI1bSYb47uN572t7xEwjqXpaXaN9PHp7J+xH9Wm4shz4C/zs3/6fswpZpJGHEv7iDfFc3vf25s8BiGEEOESbGZGdGvHd1tL6hyPZT+tpihc3DcTk6aS4bTx3t0j2FVSw2sLd/HBykLc/iCaqhDQDTql2JkythsXJQXYeettvNy3F9NWl1D1Y9W0oGGgKQq6AR2SbNw9NpfL+rXHFuMkoDizxVk0Hrkin9/P+SFiNk5DbGaVJyecJX3PouWtAtUcStWrxyPjrIycVsODQxvILlJV8FTW//gZSIKpNqzgsIuL+2Syv9zNloNVBPS6lZg0Bcwmla6pDh44r3uz1P+flDeJD7d/SCBYd89T2iVpaPEaB2cdxFfkQ41TcQ5wknNXDqr52AKnpmqMzRl7wuMQQghR113ndGNFQRk1TUyJspjUsAISXVLjeeTKPjxyZR+8gSAev06C1XS0we+iHSVMHv1z9HkFeJTItwzbi2v445wfeGzuJt6+YxjdMxKa8acWp7trB+VwqMrLv77ZFnVAZTOrPHxZHuf0TGFt8VoqBfbwbQAAIABJREFUvKEV0kRrInkpeZg1KUgRxhzXYCAF0Cdd47IeJh5b4KN3Wj3Ja4YBltj35Z/OJJhqo95etpf95W5m3zMSi0lle1E1by4uYNOBSqo8ARxWEz0yEvjpiE7N2o26k7MT43LGMW/vPDzBunm1KWNSSBmTUu9zbZqNqQOmYlLl10oIIZrb8G7tSLJbcPncYbtbG9tPqyrQsZ2dPh0S6z3HatKwmo6tLM3bUsTdM1biMVRopDprjS+Iyxdk/L8X8sGUEc36vSROf/eOyyXTaeOPH2/AMIywCYMj4i0abn+QiUNTKDHPYcw776Cjo/xYtsvAQEFhYs+JXN/rejLjpcnsUVYnaObQvqcG/GmsjQEvVvPz4dbIJwT94MhogQGeuuSutw3aUVzNk19s4d27hh2tUJOb7uCRK1tn8+9fR/2VWz+/lc2lm8PKpNfHZrIxscdEJvSc0MKjE0KIM5OiKLx6yyCu/veisA37jbGZNUZ2a8fvZq9HNyDNYWFMz3QGdExCiVDda3tRFffMXBVT6pUB1HiD3PDSUr7++RiS46UQkYjeNQOzueys9sxdf5AXvtvB9qLqo019/UGd/Cwnd4/pxndFM5m9dybmMhW/Hrn63IyNM5ixaQY39b6JBwc8GPF3/IyjKHD2DbDyjQYDqtwUlevyzTyzzEff9OMnURTIPQ9sMllSmwRTbYwvoDP1nTU8dEEPctNPTqqEWTMz7aJp/Or7X7Fg3wJ8ui9iYQoAs2pGVVTu6ncXt/eRvVJCCNGSemU6mX77EG6etpwaX4AIBf7qUAwdRVUJ6gZvLCo42l9KAV5ZsIu0BCt3ndOVqwdk19nv9PRX2/D4IwdsNRvnUbn8I/yHC1EtcZjTu5I4YiK27HwAXL4AM5ft5r5xTassK85cVpPG+P4dGN+/A1UeP+UuP6qqkBRnxm7ReGTxI3xz4DNQAjQU5/t+DLJmbp5JkauIR0c9KgEVwNApsHpGo6tTD4+xMn1dhHPMdhjxQAsN7tQlwVQr8gV9fLX7K344/ANlnjLizfHkJORwaddLSY1LBeCfX20lw2ll0tCOJ3WsFs3CP8f9kw2HN/Dmhjf5es/XmFUzuvHjcroSKoE+occEru91Pe0d7U/qeIUQ4kwxsFMKn9w/ij9/upEF20tQAG+g7p2l1aTgCxgogG6EP24ALl+Q3Ydd/PnTTby+qIC37hhGqsNKucvHFxsPEalqeuWy2VQsfZ92F96LrcsAFM2Ee9dK3NuWHg2mPAGd1xYUMGVMrhQHEE2WYDOTUKsZ76vrX+XTXZ+GbUFoiCfg4avdX5GTkMOUs6e0xDBPLam5+DL7oxQuw8yxyZKCqXUn73MSVTy/P271SVHB2R46DmuNkZ5SJJhqBQdrDjJj4wze3/Y+GFATqDn6mFWz8syqZxieNZyhKdfw4aognz0wus3MoOS3y+fxcx6nwlvBuuJ1VPoq0VSNZGsyA9IHyCZPIYQ4CbqkxjPtlsEUVXl4a8kePlt/gMofK+05bWaqPH5Ka3z4go13QHH7g+wsruHKZxfwnwfOYdbyvRH7eureGsoXzKTdJVOx9xxx9Lg9dyj23KF1zvX4g3y3tYhze8neCnHiXH4XL6x9IWIgVTa/jJLPS/AV+dBsGs6BTjKuzUCLD620uoNuXv3hVW7Mu5EEy5ldHGX1njJ+V3QH75m3YgqUodSTdRSRxQGT3qu36e+ZTIKpFrby0Eru/fpefEEf/gjLqkf2JH1f+D3f7VnERYMn0i7+vNYeZqMSrYmMzh59sochhBCilvQEG1Mv6MHUC3ocPfaL99by6dr9+IKNN+Q9IqAbFFV5uWP6CuxmLeJeKe++zRgBH/Yewxt9vRpfkNW7yyWYEs3is52fRZxkLplbQvHcYrInZ9dp21LwZAFdftcF9cd956qiMmf7HG7Mu7G1h95mzFq+hyf+u4XHrhlFfPtv4fVLoLoIgpH3nR2lmsCWCDd/CildGz73DCVNe1vQmqI13P3l3dT4ayIGUrUZGKD6+a7oA/658p+tNEIhhBCnk8PVXj5Zux9PIDwYqtk4jwNvTGXPP66l8NmbOPTuH/EUbjj6uD9osK6wnIMVkdOogu5KVLsTRY2ul1RJTSM3aUJEwTAMpv0wDXfAXed40B2k6KMism7MIqFfAopJwZJmIeeeHHwlPioWHWso7Q64eX3D6xiNbTI8DfkCOr//aD0vfr+TWXcN54K8DEjuBHcvgCF3hlacLI7wJ5rjQ+XU+/8UpiyCjLzWH/wpQlamWki5p5wpX02JKbcXQvm9b29+m75pfbmg0wUtNDohhBCno7eX7yFSEk40e50gdON1uCZyFVctzonuqsTQg1EFVHaLNPAVJ67aX83BmoNhx13bXOh+HefAunt7NJtGQr8EqjdUk3xO8tHjZZ4ySj2ltItr1+JjbiuKqjzcO3MViXFmPrp3JM5ae9CIS4aLHoXzHoaNH8O6d6CmBAwd7CmQfzX0vVZ6SkVBgqkW8sG2D+pdjWosv9cT9PDs6mclmBJCCBGT1xYUhK1KxbLXSTegtJ4VJWuHXigmM66ti4nvNarBcVhNKllJcU38KYQ4ptJXiVkzEwgE6hwPVgcxOUwoWvj0gSnRhHt33ZUsk2qi0ld5xgRTq/eUcc/MVUwclMOD53U/2og7jMkK/SaE/ogmkTS/FqAbOm9ufDNij6aSuSUcfO8gmRMzyft3Hl3/0BXfYR8FTxag1/oC3F+9nw2HN4Q9XwghhIjE4w9S5goPhGLZ6wRgNqnYzOG3B6o1nqRRkyj98gVcWxej+z0YwQDuHSso+3ZanXMN4PJ+UuVVnDiTYoqYnqc5NALVAYwIewMDFQFMjrrrBQYGZvXMKJr17vK9TH5jBX+6Ip+fXdCj/kBKNAsJplrAkv1L8ATC0/tiye/16T5mbpzZmsMWQghxCqv2Bo42Oa0t1r1OZlWtmw5Ui3PI1SSfezsVi2dR+K9JFD5/C1WrPiWu+7FATQFGdmtHutPWpJ9DiNoSrYkE9EDYcXuuHcWkULmyss7xoCdI1boq4vPqpqf5dT9J1qQWHevJ5gvo/OGjH3jhux3Mums4F+ZnnuwhnREkza8F7KzYGfGNH0t+r27obCnb0irjFUIIceqzWzQCEZpDxbrXScfghiE5vPj9LtwRGvc68sfhyB9X7/NtZo27xnSLbfBC1MNmsjE4czCLDyyuc1yza6SPT2f/jP2oNrVONT9zipmkEXUDp7PSzsIRqdDCaaK4yss9M1fitJn56L6R9U6IiOZ3xgRTQT1ITaCGOFNciy/z1le9L9b83hp/Tdh5QgghRCRxZo04s0a1t+5kXix7nQD8AYNbRnZh9d4Kluw8HNbwt7Ex3DC0I8O6nhn7UkTruLXPrawtXosr4KpzPO2SNLR4jYOzDuIr8qHGqTgHOMm5Kwe1Vqqq3WTntj63tfawm8TtCzJvSxHF1V58AR2nzcyATsnkptcfCK7ZW86UGSuZMCiHqQ3tjxIt4rQOpiq8FczZPoc3N75JkasIk2oioAdIsCRwbY9rub7X9WTGN/8SqN1sx6SawgKq2vm9xwdUkfJ77SZ7s49NCCHE6UlRFCYN7ci0hbvw19pHUnuvk6Jq2Lr0R1FNeArW4NmzjuRxt9V6DRjbK40ku4UXbxrI5DdWsKLgMJ5A4yWl48wa4/t34HeX9G6Rn0+cuYa2H4rD4ggLpgBSxqSQMialwefbTDZGZo1sqeE1ix3F1by2cBcfrNyHpoZaFeiGgUlVMQyDHpkJTBnTjfPzMuqk8767Yi+Pzd3M/13dl4skre+kOC2DKb/u5/FljzN7+2w0NNxB99HjEKoMM2PjDGZsnMGwrGH83+j/w2lxNvSSMclJyMGiWsKCqdr5vYlDEo8eP5Lfm3HtseaGCgqdEzs325iEEEKc/m4a3onXFxUQKgFxjHPI1ajxyVQsnkXJp0+iWOKwZuTiHH5dnfOsJpVr+mcTCOrYzBpv3DaEv079B+85ehA0manxhaf9xVs0EmxmfnZBdyYOyonYXFWIE6EqKv8Y+w8mfz455pYzNs3GP8b+Ay3KPYOtzTAMXvhuB09/tY2AboSl6vqDoffcusIKfvHeWjokx/HWHcNIjDPz5083smBbCe/eNYzc9ISTMXzBaRhMeQIe7vryLjYe3oivga7OPj302OL9i5nw8QRmXDKDNHtas4xhZIeRqGr4JuBY8nttJhuTek9qlvEIIYQ4M2Qn2xnerR2LtpfgO67KWWN7nQC8AZ0H31mN2aRy07BOXMV+btj6Fb+afS/f7Cjjlfm72FvmwuvXibNo9MxM4M7RXRnerZ0EUaJFnZV2Fv8Y+w8emvdQ1AGVTbPx+DmPMzBjYAuPrume+HwLry/8/+zdd3hUVfrA8e+9d3p6AgkthISAQOiE3uyiIpYFUbE37GV33f3trlvcblnXsiooIirFXhAFsQDSkd5bQhqkkF6m33t/fwwgYSbJDEwSgufzPDzo1JOQ3DnvOe95X/+WBoHUuVWyj9Zx+Ysr6RJrJT7CJM5HnQXOqWBK0zV+ufyX7CrbFbAseSAezUOJvYQ7vr6DDyZ+gM145ql1RtnIjb1uZM7OOSeCtuOCze+Nt8QzOHHwGY9FEARB+Hl5ceogrnhpJcXVzoAFKRqj6+D0aji9GrNWHuINj5vLrnyM/xiNTOjbkQl9RblzofWM7TKWORPm8Kc1fyKvOg+P5kHV6++WypKMUTbhcsTw1Pi/c2HXEa002qZ9urmAOasP4fAEfy7Rq+kcrXGh6zofTh+JwSAKc7e2cyqYWp6/nI3FGwMGUo01yvXqXgrrCnlr51s8OOjBsIxl6nlTeXf3uxDg96Op/F6LwcL0/tPFKp8gCIIQshibkU8fGMUNb6zjcIUjpAISJ3OrGsgGvi3RmDpzLe/dOxKL8exMlRJ+PjLaZfDxpI/ZV76Pd3a/w/L85ScKdtkMNsZ1GcetGbeyepeF91aVcnkP/aycT+m6ztNL9jUYSNXtXk71j5/hKStANlkxJqYRM+p6LF0yALC7VdYeKmNsj/BkVQmn75wKpmbvnI3D6/C7vXRxKUcXH6XL3V3qpdblPJdD6h9SkQ0ybtXNgr0LmD5gOgb5zL8tibZEnhv/HL9a/quQ8nstioVLul7CNenXnPEYBEEQhJ+nxGgLXzw0htmrD/HW6hxcHjXgeadgOD0aewprmP7uJt66faioFCacFc6LP49/jPlHg/f3GK3x/sZ8vtldfFb2W1qbVUaN07/yM0D1hk+pWv8RCZc+iCV1MJJiwHFoE44D6+sFUzNWZIlg6ixwzuwN5lbnsrd8r9/toTTK9WgeVhSsCNuYxnUZx7/H/huLYkGWmv5WWw1WJqRO4K+j/3pWrqIIgiAIbUeE2cDDF/bgxz9czIs3DGJi/44kRJgafHzd7uUUvv0Yec9PpuB/t1D8wZ9xFuwCfGepfswpZ3VWaUsNXxDOiFGR+ctVGfzty904j/VL03Wdijo3h0rryC2ro8oROJhpCTN/yA64wKG56qhcNY/4S+7Hdt4oZJMFSTFgSx9er/ImwMacCgqr/DcRhJZ1zuxMfZ/3PZruv1UaSqNcu9fOF1lfcFHXi8I2rotSLmJe9Dxe3/46y/KWIUlSvTRERVIwyka6xXTj7n53c2nKpSKQEgRBEMJGkSUu7pNEZrc4hv/zu4CPCXYlfOaKbLESLrQZY3q0I6NjDP/7/iBd4qzMWJHF4UrHidLibq9G747R3De+O5dm1C853tx2FFQFvN11eC+6142t58gmX8NkkNlTWE3HGGu4hyeE4JwJpo46joalUW6pI/yrbj3jevLc+OeodFby4b6P+e/KbxiWbsVmsNI1uiu/6PELesT1CPv7CoIgCMJx7/+YT6AMveMr4QlXPIbtvFEnbrelD8eWPrzeY3/MKedIpYNOsWLyJrQNfTpF8fw3B7AaFRzHdqiOlxsH2HG4it98vI3ffSLx/PUDubhPUkMvFVZ2jzfg7aqjGtkWjRREKXdN11t1d03wOWeCKVULnAseaqPcU6vChFOsJZYLOk5lvrM7cyac32zvIwiCIAinWrS9MOBh91BWwhVZYuWBo0wd2rU5hij8TOSW1fHW6hy+3FFIrdOLjk6k2cDFvZO4a0wqPZLC0zPp6cV7j/Vd40QgFUidy3ffQws28+eJfbhxeEpY3r8xRkXGGeD3UbFGo9mr0TW1yYBKQsIqisK0unMmmEqwJiAjo51SPi+URrkAcea4Zh1nfrmd5PgzL78uCIIgCKFoaAU7lJVwt1ej0i5WwoXTc6C4ht99soMdh6vQNB3PSaX7nR43H27K57Mth+mRFMk/ru1H/y6xjbxa495Zk8OcNTmNBlGncno0nlq0mw4xVi7olXja7x2M9pEmapz+u1Pmzr2QDEbs+9cS0WtMo6+h6zpJ0ZbmGqIQpHMmmBrVaRSzdszyq+YXSqNcq2zmkpRLmnWceeV2usaL9AhBEATh7BDKSjiAONbbdtU4PSzZWURhlRO720uM1URGp2jGpLdr9iqN67PLuHPOj9jdKg11P1M1UDWNHYermTpzHf+7aRAX9Q497a7a6eGfi/cE3PlpquS406PxxEfb2PD7i8P+PdFVFfumTdQsWcKEHWXMTr0Ap1y/4a5sjiB2zDTKv5mBJCtYUgchyQacOVtx5m2vV4Qi0mJgYPLpB5xCeJwzwVTfdn1JsiWRU53jd1+wjXJRXUxY9gLU1ELfX4Ap/DtIeeV2kuPEzpQgCILQsuJsRvLK/W8PZSXcZJCJtTVcEVA4O+0rquGNldl8se0IiizhOBbQKDJYDAo2k4G7xnTjhmFdm+Xfd09hNXccC6SC5fCoPDh/M3PvGk5mt4Z7cwbyyaaCgMW8gim0AuBwq/xw4Cjnn3fmu1O6quLYvJnqxUuo/mYphnbtib7sMu76+zRmz90PAQK+6GHXIUfEUbX2fUoXPYdksmJOSid65NQTj7EaFe4ZmyaKlp0FzplgCuDOvnfyrw3/CthrqqlGuYqkMKnnL7AmDIMf34Rv/ggDboTMO6Hd6RWH0HWdbQVVzFqZzbb8SurcKnUuL51jrXSMtTIhowMm0blaEARBaAFXD+zM/uJav7SnUFbCvZrO+J6iml9b8uaqbJ79eh8eVUfV6u8JqRq+uYlb5YXvDjDjh2zm3T2cjE4xDbxa6HRd5/65m0IKpI5zejSmv7uJDX+4GCXIXSJd15n5QzaOU94vlEIrdccqV55uMKVr2k8B1NKvMSS0I3rCZaS88w7m1NQTj7uqfy2fbzuCO0Bj7ciMC4jMuKDR95mSmXxa4xPC65wKpq5Iu4J3dr9DTlUOXj1wlZSGRJmiuHfAdLAlQs/LoCIHNs2Bty6HxD4w9C447wpQjE29FABLdhbx9JK9FFU5cXlVTr5+ZZfW8buPt/P7T3Zw68gUHru4pwiqBEEQhGY1ObMLTy/x78cIwa2ES8Do7gnijEYb8uryg7z83cGA6W6ncno0nB6NKTPW8tF9o+jTKbrJ5wRjc14FJTWugPc1lXLnG5fKsr0lQVfZK6hwUGF3+90eSqEVgPWHytA0PehUP13TcGzZQvXiJdQsXYoSF+cLoN5+B3NaasDn/HlSBhtzK8gvt+PVGkp+9GcxyvzvpkHEWIObkwrN65wKpsyKmTcve5MbFt1AqaM0YKn0U8nI2Iw2Zl06i0TbSSsQcd3g4r/A+b+DPV/Autdg8W9h8K0w+DaI6dzga77w7X5mrMhq9OJ1vFHb7NWHWJtVxjt3DSPKIn4pBEEQhOYRbTFyZb+OfL7tMGqAj6emVsKtJoV7x3VvxhEK4bRsXwkvfXfwRMPaYNndKjfNWscPv7mA6DDMS17/ITtgEYhgU+7q3CozVmQFHUxV2j0YZRnnKQXJQim0Ar7KlbVub6PfA13TcGzd6gugvv4aJTaWqAmX0XXOW5jT0pp8j0izgfenj+CG19dxuMKBK8AO1aksRpmnr+t/WmfJhOZxTgVTAPGWeD686kMe+f4RdpftxqN5Gix3bjPYiLfEM/OSmXSNbqDMq8EM/Sb7/hTvgo2z4bVR0G2Mb7cq9XyQf9pVeuOHbGauyA5qFQh8K0G7jlRx2+wNvD99ZIs2jBMEQRB+Xn57eS+W7ztKhd3dYBGAQCxGmfE92zMiLbSzK0LreXbJvgYDqaZ2hFxejU82FXD76MA7KsHyqhrf7SlBP+WHLZSUO4BtBZWUV9YR5apDraxArahAraw88be3ogK1ohK1spLDdhm146W++dtJQi200hBfALWN6iWLqfl6KUp0tC+Aems25u6hLzYkRln44qExPLd0H+//mI/ETwvuxxk0FcWo0LdLHL+/ohdDUsTv4dnknAumAGLMMbx9+dvsLd/LO7veYWnuUgyyAQnfVq1bdTO0w1Du6HsHwzoMC/7wXlIGXPkf347V9g9g6R/B4/Cdqxp4Ewdrjfxn6T6cQawsnMyt6uwurGbG8iwevkg07xUEQRCaR1K0hQX3juD6mWuocXoJJrPIalQYmBzLizcMEofd24i9RdVkl9YGvC+YHSGHW2XmD9ncNqrbGf2bVzk8KLLkl8IWasqdweVk02UTSTF5McTGocTGosT99LexUycsffpgiIsj3RiF+4tCTl0tCKXQCoCq6USafNNkXdNwbNtGzZIlVH+9FDkygugJl9P1zVmY09OD+2Y0IsJs4M9XZfDbCb34cnsh89bncrTGhUfViTQrdKgsokw3sPuIwuQZazHIEnE2E1OHJnPLiBQSReptq5L0U5cLTpKZmalv3LixBYfTPOo8dRypPUKdpw6rwUqiLZE4Sxj6Sek65K/3FazY/zX/Z/4/PjyaghrgWxpMXnCczcjGJy8J+pClIJwOSZI26bqe2drjOFPnyvVJEFpDfrmdK19a6StTrYMaYC5gMcpoOlw/pAt/mZSBoQUyJ86F69PZcG367cfb+WhTgV/BCc1VR8Ert5FwxWNNBhQRJoU3bx/KiLSEgPfrXi9qTQ1qZSVaVRXq8T+VVajV1ahVVRRV2ZkqDcUl1V+7r921jIplb5L80Nygvp5Is8IH00fSJ4jCGFUONwOe+ibgfUe/eBb73lUgycjmCIyJaVi7DUCzV9UrtAIwqnsCbw61UL3ka6q//hrZZiN6wgSiJ1yGuUfLLHy//2MeTy/Zh8vloc7r/ztqPnbefnR6O56Z3J92kWa/xwjh0di16ZzcmTpVhDGCHnHN8IMvSdB1BHQdQV1FMZ89uyFgIBVsXrBb1UI6ZCkIgiAIpyO/wk6UxcicO4bx7rpcvtxRiATIkoRX04i2Grl7TCpTh3YlPkKUQm9rtuVX+gVSENqOkMfjZeP8haR4clGrq3wBU+VPQZNmt6NERSHHxqDExKLExKBER/v+jo3B1KUzSb1i8G40+O0ShZpy57E7qXvydxSldsKc3h1zejqm9HQMcf4L459sOowhwG5Y9YZPceZsJWrQlTjzd+ItP4y7cB9qxWESJv663mONaFz15UyOfFpC9IQJJM+cgblHjxbbmdV1nX98tYd563JxNHJs5PgZqx/2H+WKF1fy4X0jSUmIaJExCj/5WQRTLeHbXC+KwQhnUorTpfLuulwRTAmCIAjNxqtqPLVwN09e2ZvBKXEMTonj2cn9qXJ4cHhUoq1GoswGkdLXhtW6Alc0DqUIg0eHOoMJc1pPX4B0PGCK9QVNclQUktz0bmWXrGXkltvr3RZqyp0t0krPSTfhzTqIc/duqhZ+gevgQSSjEXN6Oub07pjS0zGldWfmimq/QOrkuVhQKX5ITPrXb7H1Oq9Vfg9eXZ7FvHV5jQZSJ/NqOqW1LqbMWMviR8eSIHaoWpQIpsKkuNoZsApLqHnBRyr9e2QJgiAIQkg0DVxV4KwGUyRYY+HYBHre+jwSIk1M6NvhxMMNiiwmYGeZ3LI63tuQT9bRWuxulWirgUHJcUzJ7NJkY12rMXCwFMqOkFGRaTdmFPFjzqwIxfTxafz9yz31+kyF0tvMYpC5a2waMWN7wNifAiFd1/GWHMV18ADurCxce/eR99W3lHW6GuT609tQ52Kg4+7YhYhWCKRyy+p46bsDAeeUjR0Z0XQor3Pz9y/38N+pA1t83D9nIpgKE5dHC7ilHmopzmDKYgqCIAhCQNVHfOd4f3zDVyBJNoCmgiTDgBuoHHAPL313hPn3jBA7T2epZftK+N/3B9l5uApV0+vtsny/t4Tnlu7jkj5JPHRhOr06BO4F1TXexoES/wIUoewImQwynWOtZ/bFANcM6szfFu3xuz2Y3mYAGnDDMP+Ky5IkYUxKxJiUCKNHA+AsrcP00ko8p2QJhToXM+kau+9/hGH//SfGDh2afkIYvbU6By3AfDKYIyNeTeerHYX8ZVKG6EHVgkQwFSbRViMmg+wXDIWaFxxlEf8kgiAIQog8DvjsAdj7pe//1WNNUtWTmpdueZeIzfP4xNablMiPgagWH6bQME3znZOZvz4vYF8m4ETbla92FPLdnmKenTKAif07+T3u1pEprD1Qgv2UlwllRwgkLujV/oy/LpvJwEMXpvO/7w/6fV1N9jYzKtw4LDnowgoGWQpY8j/UuZhkMhE3fgw5U66n8/P/wTZ0aFDvf6acHpUPNubjaSRNsakjI7Ik8fGmAu48wx1FIXiiqVGY9O8Sgxxgle/kVaCmGBWJ4amid4AgCIIQAlcNzLoE9n3lC6KOB1Kn0rwYdTddHbthxmiozGvZcQqN+tuXu5m/PrfBQOpkmg4Oj8avP9zGkp1FJ27XNY2ab7+ly58exuqsC/jc6GHXEXfhXVStfZ+Cl6dR8Nrt1GxehLXHTylwRkXipmFdMRtOvx/TyR44vzsT+3dsMP0wEKtRYXR6Ak9e2Sfo58RHmHAHyPAJZS4G4FF1Uu+8jY7/+hcFjz1O+dx5NFb9OlxWHywNOJcMJU3R4VF570fxu92dazZ5AAAgAElEQVSSxDZImAxMjiUp2kxOWf1DlqGsAsmSxO2jxEqCIAiCECTVC/OmQOn+hoOoU0iaB+pK4a0r4L5VvvNUQqtasrOQ9zbkB11w4DinR+Px97fSJ3Ek0auXUTZrFrLFQvt77uHBiHSeXbo/4Gs2tSOkyBK3jkoJ+etoiOR18cx4C+10M7O32dGQ8AQqf4xvd8kgS1w7qDN/u6YvcpDtYnRNgy0b6adVsoX6JdRD25GDEWkJWE0KjBlNt/cWUPDQwzh37qTDX/6MbGm+nk5ltW60AEFbqGmK5XXuph8khI0IpsJEkiTuG9+dvy7aXe+QJQSfFzwgOZauCbaWHLYgCILQlu3+DAq31wukur1Qg90Dhx6NJMLkm4jO2uxm7nYPy28/VjZZV6G2GFa/BBf/qTVGLpzkhW8PNLgj1VSfSo/HywtPvMCjUg4d/vB7bCNHIkkSt2s6q7LKWXOwFGcI57EtRoVnftGPLnFhmI+UZcG612DrPCRJ5reSzE22eOY4xvGedCGSwXSiMApIqJrOdYM7c8foVNITI4N6C9eBA1QtXEjVF4tQ4uK448LJ7C+RqTsliAx2LhZhUpg+Pu3E/5uSk+m2YD6FT/6R3Gk30+XllzB28k+tDAePpgXcAQs1TTHQGX6h+YhgKoyuHtiZF747gNOj+nWVb2oVSJbgwQvOvIu2IAiC8DOy+gXw+KdzqTq8uN7N78c2ctZEdcOPs+CC34EiDqu3lt1HqsktC5ySF1TRASSWJA/ln39+EstJaXSyLPHazYN5YN5m1hwsCyp90GKU+eukDCYN7HxmX5THAR/fDQe/9RVA0Twn7kqmmj8acnhCWcBmpT+VhgS00b8irnM6g7rGYjM1PTX1Hj1K1ZdfUrVwIWppGTGTriL59ZlYevakm6bzt399R53Hf6e2qbkYQJTFyOju7erdJttsdPrPc5S/NYdDU6fS+bn/EDF8WJDfjODFWI0oAXbiQi0lH2kW0/uWJL7bYWQ1KXxw70iu+t8qapwev4CqIRajTGZKHE99sYs3bxtKajvRcE0QBEFoQvEuKD0Y8K4nRpl4ZrWLB4aaiLU0kialq7B3EWRc20yDFJryztoc3AFS3kIpOiDJMkt3FzNpQP0dE7NB4Y1bMpm/Po/XVmRRYXfjcKv1ijSYFAlJkhiSEsevLu3JkJQzPLvtroPZl0HpAfA6G3yYRfIwStsEbmDVarj5YzA1fCZIczio+fY7qhYuxLFtG1EXXkjSr3+NbfhwJOWUIHLaYG5+c/2Jgh3BshoVXr15cMDUQkmSSLjzDiy9zuPwr35Fu3vvIe6WW/yqYmqazooDR5m1Mpv9xbU43CoWo0yXOBt3jkllQkYHTIbAJQsyU+IDpj+GkqZokCXG9jjzwiFC8EQwFWZdE2wsengMU2eupcrhoc7d8EqQ2SAjSfD89QO4ol8nFmzIY8qMNfx36kDxiyAIgiA0bv+Seiv+J8vspHB+NwPPrXHx9wsbOePhroWdH4tgqhXtL64JmJYVatGBvAZ2t2RZ4uaRKUwb0ZUNh8qZuy6XvHK7r0GzxciQlDhuGZkSnrQ+TYP5N/jO8HmDO8MH+HZX502Ge1dAu5+ydHRVxb5hA1WfL6Tm+++xDhhAzKRJdHnxBWRbw+PN7BbPKzcN5sH5W3AGsSMHPwVSg7vGNfq4iFGjfOeoHn4Ex86ddHzqKWSrFV3Xmbsulxe/O4DDrdab/9W6oLTWze8+3s7vP9nBHaO78ehFPTAo9YOqDjEWhqXEsTKrFKgfpAWbpqjIEneNFefvW5IIpppBcryNZU+cz5KdRby2PIucsjoUScKr6SiyhCSBUZa5fVQ3bhrRlcQo3wfdjcO6ktYugocWbOGB87tz+6huog+IIAiCEFhNMWjeBu/+6wVmRs+u49HhjTd4pbYkzAMTQlHnCjzZD6XogKZDpSNwYA2wvaCS9zbkk1dux+lRSYwyMzwtgSlDkomxhTHFM/t7OLypXiAV1Bk+8O1offNHuHEBzv37qT5+DiohnphJk0j81S8xtA9+ofmi3km8d+8I/u/j7eSU1eFRdb+gVZEljIpEartInv5FP/p3Ca4Yi6lLF7rNn0fhH/9EzrRpdHrxJZ5cV8aX2wsbTac8HmC9sTKbH3PKeev2Yb5CF/iaENcsXsxViz9mY49JOCT/KXowaYq9OkTRvX1w582E8BDBVDMxGxSuHtiZqwd2Zk9hNXsKq6lxerEaFTrEWBjVPcFvRQJgeFoCn9w/irvf3sj+4hqemtS3we1gQRAE4ees8VzyvokKE3sa+PcqN73bN/Y5Ig6rt6aGzreEUnRAliDOVj9o1nWdTzYf5tXlBzlS6cTlrX+ee9XBUp79eh+XZXTgkYvSSU8MQ9+x1S+e/hk+dPR9S8mbciXuo3ZirppI11lvYO7R47SHMzA5liWPjWP3kWreXJXNt3tKsLt9CxA2k4FL+yRx19jUBpsfN0a2Wun07DOUv/02v/nDbJYmD8EZ3CYYTo/GlrxK7nlnI2/fOQzX5k0UP/MseL1c8cSv+XCbzraCqoBl3htjMcr86argS8kL4SGCqRbQu2M0vTsG/4uaHG/j4wdG8dh7W7l51npeu3kwCUE2rBMEQRB+JiITQVJ8554a8NT5FgbPrOVXIxv5DLG1a/g+odn16RTN1oIK1FPmzaEUHbCalHrnrV1elUcXbGXFgaM4GjhucLxk+pfbC/lmdzGv3jyYC85LPP0vpDIP8tcHvCvYM3y6rtPx6jSMN75Q7xzUmerTKZr/XD8wbK93nCRJ7Bs7kW8OrA86kDrO5dXYlFPGq798hsu3L6X9448RfeWVSLLM7EEern11DfnldlxBBlQWo8yzk/uf+Zk3IWRiy+MsFWk28PotQxiaGsfVr6xmT2F1aw9JEARBOJukXwKGxlP40uNlpmYYeWlDA31nTJGQcV0zDE4I1m2jUjDK/tOxk4sO2PevRfM40VUvjqyNVCybXe+xkqZxce8kwFcA4YF5m1m+v6TBQOpkqq7j8KjcP3cTaw6Wnv4XcugHX3AfwMln+BojSyqmms1hDaSa28wVWTga+DbX7V5O4duPkff8ZAr+dwvFH/wZZ8GuE/c7vDoLInuR+tWXxFx1FdKxn4Moi5HPHxzNkJQ4rEaFxlpt2UwKNpPCjJuHcNWAM6zCKJwWsTN1FpNliScu60XPpCimzVrPv6/rx6UZHVp7WIIgCMLZoNNAiE2Bo3sbfdifxpt5d3vD52noMynMAxNCkZ4YRc+kKLYfrvK7L5iiA0ZJZ2LeBo7cPp9206czX+/EmoNlIVeyc3o07nlnI6v/70JibU2cswvEURGeM3xO/+/D2aqoysm67PKA9wVT1h6gUjazpchOZrf6hWIizAbm3zOC7QWVvLEym6W7in3HPnRAAq+q0z7KzH3j07hmUOegSsoLzUN859uAqwd2pltCBNPf3cT+4hoevCBdFKYQBEEQYMzjsOhx8NhP3JTzWP2zL8kxMs4nA6SaK0bIvAMMIo28tT1+SU8emLc5YPGCpooOGI0GHnn+CSLWLKfwmed49bxbcBisAR/bVANgVdd5/8d8po/vHvTYdU3DW1SElpePSdVoaHYS9Bk+qe0kTS3afiTg7aGUtXd4VN7/MZ/MboHT8/p3ieXlGwdTZfewr7iGaocHs1EmKdpCj8RIMR88C4hgqo0YkBzL5w+N5t53NrKvuJZnJ/ev15xPEARB+BnKuA7Wz0Av3oWkNpDKF5AEtgQY/XizDU0I3gW9ErlnbCpvrDwUVHPd4yxGmdduHkKnhEi4aiIbu2finrcJArxEMDslTo/GrFWHuGdsWr1eS7qm4S0pwZ2TizsvF3eu748nNxd3fgFKdDRxfSTik3wVixsS3Bm+tnPm50ilI+CZplDK2us6FFQ6mnxcjM3IsNS28735ORHBVBuSFG3h/ekj+e3H27l+5lpevyWTDjH+/UM0TWdNVhlLdhZSUuNC13XaR1uYkNGBMentAjajEwRBENoggwlt2qeUvDCWBK0Io95IOt9xsgHM0XD7VxCR0PxjFILy+CU9MSgyry3PwulRG62xaJAlTAaZV6YNZnzPn8qFv7UmF3uAQCqUnRK7w8V3cz5hQPkh3Ll5vsApPx85MgJTSgqmrimYUlKImXgVpm4pmLp29fV8qiuD53sHDOSOO/kMX7/EADtQRhsMuLGRr/zs4mgglTKUsvYArhACaOHsI4KpNsZiVHhh6kBeW5HFNa+sZsYtQxiY7OuL4HCrzFufyxsrs6lxerGfcvD08y2HsZkN3DM2lWnDU4hooByrIAiC0Hb8a3khe+NfYk7Ua5DzA2hq4Ga+kgwGCySkw03vQ3Snlh+s0CBJknjkoh6M6p7Aq8uyWJVVigT1dj5sJgVdh+sGd+aesWl0O6mCH0BOA417Q9kpUd0esvfnM6hbFNFXXI4pJQVj1xSUyIjGnxiRAD0vg72LQG/4vFajZ/h0rU0FU+0iAp//CqWsPUCMNYy9voQWJ2bTbZAkSTxwfjo9EqO4c86P/GliH0ant+PGN9ZRUG7H2UAZzbpjHbmfX7qfBRvyee/eESRF++9sCYIgCG3Du2tz+G5vCZ/cPwbFdiGUZ8P6GbBlri9/6HjpdE31FZoY+SB0GtTawxYakdktntl3xFNS7eTTLYc5VFpHjdNLXISR/p1jmTigY4PFBpwN7HCE1ADYZMY44WrajUkNffCjHoGD357eGT7Z4EtbtYTe86m1DE6JI8Ks+DVeDqmsvVFhdLpoT9CWiWCqDbukTxLJ8cO5a86P/PEzLw6Pildruvmi06uRX27nmldWs/jRsadXtUcQBEFoFna3l0XbC9lXVEOF3U20xUj3xEgm9e9EjO2nFezv9xbz0vcH+ei+kT9dx+PT4PJn4JK/Q10JuGp8qVMR7cFka6WvSDgdidGWkApBgG9iHkgoOyUGWSLScprTw+ShvoBo1yf1AqqmSWCNh0ueOr33bSXjerbHYvAPpk4uay/JCpbUQUiyAWfOVpx524m74M4Tj9V0nSmZyS09dCGMRDDVxvXqEE2XOBsbKstD6mHv1XRKa11Mn7uJ9+9tettfEARBaF45pXW8sTKbTzYfRpKol6ptNcr8fdFuJmR0YPr47mi6zq8/3M4bt2aSkhAg/cpggpguLTh64WzQq2M0BRUOv/lAKDslAD0SI09/EFe9CI5yyF4eXEB1/AzfHV/5GlG3IYoscdeYVF767oBfVlAwZe0VSeLyvh1Eml8bJ4KpNi6ntI6t+ZUBA6mmSqB6VJ1teZUcLKkhPTEqwCsIgiAILWHpriIefW8rHlULmGFw/KD7F9uPsGRXESZF5unJ/RmSEtfSQxXOYneNSWX1wVK/M9Oh7JS0izSfOIt9WhQDTJ0H3/8d1r0KkhQ4qJKNICvQcQBMmdNmz/DdNLwrs1YdwuV1+83FmiprbzbKPHxRj+YdoNDsRDDVxr21+hBagA/eYJvFeTSN2aty+Od1/Vpy2IIgCMIx3+wu5pH3tgTVZFXTfQUJVE3naI2zBUYntCXDU+OJtRn9gikIbqfEZlKYPj7tzHsXyTJc/CcY+zhs/xBWvwCVeb7eZprq62024EYYcT+0a9vBRKzNxPv3juDaV9dQ5/aiB5kmZDXKvH5LJt3bn8EuoHBWEMFUG+ZVNT7cVIDnlGAqlBKoqgafbCngL5MyfJ21BUEQhBaTU1rHIwuCC6RO5tV0/rV4L307xzAkRfSeEXwkSeLxi3vyp893nVYDYJNB5pqBncM3IHMUDL3T90f1gqsajFbfn3NIj6QoPn9oNDe+vo46l5e6AMHscWaDjFGReeuOoQxtoFGv0LaIYKoNq7B7UAPsSoVSAhVAAsrr3AF7VgmCIAjNZ9aqbDxqAxVYm0jVdno0XvruAG/fOTzg84WfpymZyWzNr+STzYdDagBsMynMu3t487VNUQxtqiFvqLq3j2Tlby9gyc4iZizPIqfMjiL7Fq0VWULTdYyKTNd4KwvuHUmkaE9zzhD/km2Y3e1FCdCAN9RmcbIsUevyhnt4giAIQiPsbi8fbzoc8IxUsKna67LLKapyisUwoZ6/Xd0Xq0lh3rq8JgMqs0HGYpSZd/cIMjrFtNAIz01mg8LVAztz9cDO7Cuq4WBJLbUuD1aTgeQ4K6kJEYx7dhl2t1cEU+cQ8S/ZhkWYDQF3pkJtFqeq+lnxS737SDUbDpVR5fBiUCTaRZq4uHcSCZHm1h6aIAhC2H25vZBAR1NCSdXWgfkbcvnlJec182iFtkSWJZ68sg8X9Upi5oos1maXoes6bvWnOUOEWcEgy9w6MoVbR3ajfZT4rA2n8zpEcV4H/+JeEwd0Yv76PB67uGcrjEpoDq0/gxZOW6zVGHBnKtQSqLrLScUvJqEO6I910CCsgwZiOe88JEPz/3i4vRqLdxby2vIscsrq0HXfbZIEFqPCHz/fxYXnJXLPuDRRtUoQhHPK3qKagIUCQknVdns1dh6ubo7hCeeAkd0TGNk9gaIqJ19sO0JBhR27WyUh0sTA5Dgu7p2IQRHnpVvSbSO7ccub63ng/HRxVv0cIYKpNsygyEwdmszcdbl4TlptCqUEqkGWmDI8nbT7Z+LYsgXH1i1UvLcA75FCLP36YR00ENugQVgHDECJPYNSqQGU1rq46Y11FFQ4/CYUuv5Tj5Wlu4tYsf8oUzK78JerMpADBJCCIAhtTaXdHfD2UFO1qx2ecA5LOAd1iLFwz7i01h6GgG/Hqnv7SBbvLOTqcBb7EFqNCKbauNtHdWP++jw4pbtBMCVQwXco8o4xqZjbR2JOSyX2F9cBoFZV4di2DfuWLZS9NQfn9u0YOnb8KbgaNAhTauppl08tr3Mz8eVVlNa4Ap4XOJmmg8Oj8uHGAqocHl6YOvDMy7YKgiC0smhL4EadoaZqR1rER7kgtCW3jerGzB+yuHpgZ+weO+XOctyqmyhTFPGWeJQgF1KEs4O4ArdxKQkRDE9NYF1WKe5TgpKmSqAaFYkhKXEBexwoMTFEjhtH5LhxAOheL679+7Fv2ULd2nWUvvoaWm0t1oEDj6UGDsLary+yzdbkmHVd57bZGyirbTqQOpnDo7J0VzGv/5DN9PHdg37e6ahxevhk82FWHjhKhd2DSZHpFGthSmYyw1PjRTAnCD9DOwqqeHNVNjuPVFPn8mI1KqQkRHD76G6MTW8X8q55WmIkVqPiVyAglFRtoyzRI1H0qRGEtuSiXu3589KF3LxoDrsqNmKUjcjIeHUvFoOFab2mMeW8KbSztmvtoQpBEMHUOeD5EdFM3HGAUksMXoL7MDfIEolRZl6bNiSox0sGA5Y+fbD06QPTpgHgKS7BsXUrji1bOPr88zj378eclnbi3JVt0CAMHTv6BR4bcyvIOlpbLzXxuKZKATs8Kv9bdpA7Rqc2S65xXpmdl78/wBfbjyAh1ZvkSMDinUXE2UzcOy6NacO7ilxzQfgZWLKzkOe+3sfhSicur8rJa0DZpXVsOFSGzWTg/vO7c/uobkEFVbqukxxnxeX1PzMVSqq2LEPf7ke5/9v7OVBxAIfXgUkx0TmyMzf3uZmLki/CqATeARMEoeUdqDjAw98/jKddGdvKfI23vdpPFZVdqos3d77JrB2zuLr71fx+xO8xyGK6fjaT9EZaNWdmZuobN2484zdxelS+2lHI3qIaquweoiwGUttHMLF/J2Ks4iJ/Jhw7d5F/330YHnuC+/NjyS2va7L5o8Uo0yXOxoJ7RoS1eo/mcuHctevY2aut2DdvQTIYsA4ahG2QbwfL0qsX9763nW/3FPt1CW+oFLArf1e9yUOEWeHpX/RnYv9OYRs7wIZD5dwxZwNOt0qAOK8eq1FhQHIMb942tPl6cjQTSZI26bqe2drjOFPhuj4JQkN0Xee5r/cxe3VOUP16rEaF0ekJvDJtMGZD4DSdWpeXz7YcZu66XNxeDZtZYfeRagJt0tfuWkbNxs/xlOXXS9W2dOkNgCFmI1EdvsVkcmP32v2eH2GIQJIkpvWexn0D7msTE7Jz4fokrk1CQ7aWbGX6N9NxeB3oNJ2ZY1Es9G/fnxkXzxCLIq2ssWtTswZTeWV2Zq3K5qNNBQD1igxYjQqarnNFv47cOy6N3h2jT/t9fq7smzZR8PAjdPzrU0RdfDFOj8p7G/J4fWU2lXYPDrda71c1wqQQZTUyfWwaNwzritXUvDm5uq7jKSjAsWUL9i1bcGzZSklhKbee/wRuqf57a646Cl65jYQrHguqAmHfztEsenhs2Ma6Lb+SG15fF1KDQ7NBJqNTNO/dO7JNVeQ5FyYrICYsQvN76bsDvLY8K6TrgsUgM/689sy4eUi9Xfn9xTXMXZfL51uPMDItgVtGpjCqewJ7i2q49tXVTS6C1adjTvocY+wmJLnp4hMWxULfdn155aJXsBmbTsVuTefC9Ulcm4RAcqtzmbpoKnWeupCeZ1EsjE8ez7PjnhVHDFpRY9emZlum+m5PMQ/N34JX1fAEWHI7/uG0cNsRFu8s5M9XZXDjsK7NNZxzTu3q1Rx54jd0evYZIkePBnylxG8fncpto7qx/lA5S3YWUVLj20JuH2nmsr4dGJmW0GK/jJIkYUpOxpScTMykSQAU7yrAtGAb7lN6BIdSChjgYElt2Mbp9KjcNntDSBMmAJdXY3dhNc98vZcnr+wTtvEIgtD6thdU8trygzhCCnLA6dX4YX8pH24q4NpBnVm6q5h31+WQdbSOG4cms+SxsXSMsZ54fO+O0fxpYgZ/W7Q76GuQqf3ioAMpAKfqZPvR7Tzy/SPMvGSmONwuCK3gmR+fwe7x30GuWFlB6deluEvcKBaF6CHRJE1OQonw/Z46VSc/FPzA9tLtDGg/oKWHLQShWYKp7/cW8+D8zUGttKmajqrp/PWL3XhVjVtGdmuOIaFpOpUODzVOD1ajQqzNdFbtJjg9KqW1LhxulQizgXaR5gbHV/3NNxT9+S90efklbEP8zzxJksSItARGpCU097BDVicZkRQDeOtHU6GWAnZ5NHRdD0tguGh7IR418M9qU2e4nB6N+evz+PWl52ExigmKIJwrZq7IxuU9veuCw6Pyjy/38OySvaS1j+SWkSlc2qdDg9f0m4Z3RdU0/vHVniY/N2VrLqb4tUEHUse5NTfbjm7j/X3vc1Pvm0J6riAIZ+ao/SjrjqzzS+0rXVzK0cVH6XJ3FyL7ROKp8HDk3SPkPJdD6h9SkY9dM1xeF3N2zuG/F/y3NYYvNCHswVRemZ2H5m8JMWXh2IfPV3vo0ykmrM1ZS6qdzF2fyztrcnF4VBRZQtN8P85X9OvI3WNTyegUE7b3C9WuI1XMWnmIr3YUIksSsgyaBpIEUzKTuWNUN7q1izjx+KqFCyl+5lmSX38da9+MVhv36bKaFALVyAi1FLAOXPT8CjpEW+gQbSEp5tjf0RY6HPvvdpGmoApEvLb8IHUBGmc2dIbLcWD9iUnTcV9sO8KUzOQm30sQhLNfRZ2bb/cUBzzHFOx1oc7l5a/XDwi6j8wtI7vRp1MML393gLXZZQD1gjmjLCHLEvFd1lAnewOetghmhfutnW9xY68bRbqQILSgD/Z94Heb6lAp+ayEznd1Jqp/FACm9iaSH0hm/xP7qVpTRdw433xYQ+OHgh+ocFYQZwnfHFkIj7AHU7NWZeM+zdU8p0fjpe/28/adw894HG6vxu8+2c4X2wuRIOAK48KtR1iys4i09hG8cWsmnWKt/i/UTEqqndz9zkYOFNfg9uqoAc6uzV+Xy3sb8hiRlsD/bhqE97OPKX1tBilz3sKcnt5iYw2njjGWgLtAoZQCBkiKNjPj5iEUVTkpqnZSXOVkf3ENKw+UUlztu63S7iY+wlQvyEo6Fnwd/+9qp4cjlU6/19dcdVSumkfCFY9hO2/Uidtt6cOxpdf/+bS7VWatOiSCKUE4RyzaUUiggnyhXBc0XWflgdKQmnIOSYljzp3DKKpysmBDHjsPV1Ht9BBpNtAzKYrLB0Ry1/d70ANEecGucFe5q9hYvJGhHYYGPS5BEM7M0tyluLX6TbrtB+xoHo3oIfVrBigWhaj+UdTuqj0RTAEYZSPrC9czIXVCi4xZCF5YgymH29dYNVDvoGBX89Zll1NU5aRDjOWMxnHTG+vYU1TdYGAHoOo6Do/K3qIarnhpJR/dN4r0FujXkVdm59pXV1Pl8DTaZ8mj6aDprM0u44p/fMWLG+bSd+67mJLb7qQ9rX0kXeNt7C+uf+YplFLAFoPMbSO70TMpip5JUQ2+l0fVOFrjOhFsFR0Lsg4U1/huq3ZxuMKBO0BwF+oZrsJKR5DfAUEQmpvLq1Lr9GI1KViNSsi7MEcqHAHPSoVyXdB0yC0L7aD5cR1iLDx+SU+/2xfsXYAs+e+2h7LC7fQ6+Wj/RyKYEoQWVO2u9rtNrVUxRBqQFP/rkyHGgCO3/rzCq3upclU12xiF0xfWYOqrHYUE+swKZTVPB+atz+VXl553WmPQNJ3pczexu7C6wXz3U6maTpXdww2vr2Xxo+PCWi78VFV2D1NfX0uF3R0whSQQt1ejUIMnL/0ln3YMbznw1vDA+en84dMdfql10cOuQ46Io2rt+5Queq5eKeCT6cANQRQrMSoynWKtje44fralgN9/urNepUk4jTNcQf6sCYLQPCrtbj7YmM+slYcorXVhUGRUTcdskJmamczto7uRkhDR9AsBdadWyDkm1OtCqOnuTSmuK8ap+u+kh7LCraNTWFsY1nEJghA6JVLBW+tFV3W/gMpb5cUQWX+KLgXZR1RoeWENpvYWVftNSiG01Ty3V2Pn4dOPvJfvL2FjTnnAyW1jaYY6UGn38MK3+/nHtf1O+/2bMvOHLMrqAgdSjY3PKxs4VOVm4da2fzbn8n4d+PPCXYD/z0pkxgVEZlzQ4HNNBplL+yQRH2EKy1iiLEaUAPk8oYePjo4AACAASURBVJ7hEsUnBKF1eFSNpxbu5sNN+cgSJ3aUjmcl2N0qc9fnMn9DHkNS4nj5xkEkRDa+YBZvMyGB37mkUK8L0WHuo+jwBt4BD3WF26W6wjouQRAaF22KptRRWu82W7oNySBRvamamGE/nd1XnSo122tImpxU7/GKrBBjab0z/kLDwlrOrtIeuLpQqKt51Y7QqhSdbMaK7IABXfWGTyn/7g1iRlxPl4fm0vn+t4gafAWOA+tPPMar6Xyy+TD2BlYlz5RH1U40ajyd8TncKjNWZDXL2FqS2aAw+/ahIQcgiuw7c/XP68IX7KYnRgb89zj5DFcw0toHt+ItCEL4OD0q095Yz8eb83F5tQbLmHtUHZdX48dD5Vz+4koKKvzLE58so3M0ZqP/x2Mo1wWzQWZYt/AeFI+3xAdcnT55hftUgVa4o82ir6MgtKQrUq/ArNRfxFFsConXJHJk7hFqttege3XcR93kv5qPMd5I7KjYeo/3al5Gdgzu6IHQssK6MxVlCfxyoa7mbc6rZMzT39PxWJGAn/620iHGTIcYK4lRZoynVGrLK7OzLb/S7/VCSTOUJPhsy2FuGp4SzJcckqW7ilEDbEmFMr4jlU625VcyIDn21JdpU4akxPHmbZnc885GHB6VRnpHA76JSec4K+/dO4IoS/hWe1MSIujdMZqtp/zchHKGK8KkMH1cWtjGJAhC0zRN58H5m9leUIkzyDRbj6ZTVutm6sy1fPXoOGJO2TnaV1TDJ5sL+GzLYTwBApNQrgsANw4Pb+/Efu37YTVYsXvrB4OhrHCbFTPDOg4L67gEQWjc5J6TeX376363t7+iPUqEQtH7RbhL3MhWmejB0SRPT0Y+aUFHlmQuSL6AGLPYmTobhTWYSmsfidWo+DUeDKVSm0GWuHlEV+4YnUphlZPiaieFVU4KKhxszKk4VjjASWmti1ib6URltg7RFo5UOdACzMpDSTO0u1U+33qkWYKpr3YcCViCO5Txubwqy/aWtPlgCmB0ejsWPjSaZ5bsY/n+o0heL65TNksjTL7D4zeP6MrDF/Ygwhz+1mj3je/Orz7cSp3r9M5wGRSZi3vX344XBKF5rdh/lLVZZQEDqcZSplVd52iNi1e+P8jvr+xNSY2ThVuP8Mnmw5TXublmUGfevXs4S3YW8cqyg34p48FcFyRgTI92JEadfiGlQEZ0HEGEMcIvmDp5hVu2yPWq+Z26wq2jM7nH5LCOSxCExiVYExjdeTQrClag6fWvKfHj44kfH9/o802yidszbm/GEQpnIqwz06sGdOJvi3b73R7Kap4iS9w60ndQuLHDwqqmU1rrorDK6SuPXeVgT2F1wNXEUNMMD5bU8t9v9vvdHnDzJEDwFuhxug5bAuyahTo+TYeSmnMn3z09MYrXb82ktNbFq0/8h40pg6k1WDAqMu2jzFw/NJkJGQ03uwyHi3snEms14nCrfmfZmjrDZTUq3H9+WlD9rARBCJ8ZK7IaTOluqnKsW9V5Z20Oe4uq2ZJfyaV9OvCHK3szIi3hxBnKhAgTb646FPD8bVPXBbNR5vGL/avxnSlZkrm1z628svUVv0IUwaxwy8iM6zxO9KkRhFbwm6G/YWPRRmo8NSE9z6JYuCTlEjLatb3eoj8XYQ2mYqxGJvTtwKJthX59k4Jd5e/TMZq09k2XJ1dkiaRj/YM4Vo+hyuFhU27FGR8a1o/9OTUzXQL/aoWSFDCHPVBVQ7mB8ryhju9cKOhS7a5mac5SDtccptZTS4whipTSD/nN0w9iiW7ZbWyDIrPgnpFMfHklNS5vkymHx1mNMhf0as/0cd2bd4CCINSTX273S82F0FKmPapOSkIEM24Zgs3k/1GYEGlm3t3DuX7m2oBBW0MsRplnJw+gb+fmuY5d1/M65uyag0t1oZ/yadfUCrdJMfHAwAeaZVyCIDSuS1QX3rj0De5aehd2j93v9zcQi2IhMymTv47+awuMUDhdYc+Zmj6uO1/vKkL1+P+QBLPK//BFp9+MNtZmwmSQ/VYSQ20I2yMxkl8G6PFxpvYX11BQ4V+NKZTxyRIkNWPp9ua2r3wfc3bN4Zvcb5AluV51KutlEjO/vJwbzruBG3vdSHtb+xYbV9cEG589OJqpM9dR4/Q0eQbDalSYOKAj/7q2X8g9bARBODPf7ikOeHsoKdOqrpNXbg8YSB3Xt3MMH943kmmz1vsKXDQSVJkMMook8dKNg7ikT/Ol/Uabopl92WymfTWNOk9dUBMy8E3Knhv/HD3iejTb2ARBaFxGuwwWXLmAR5c9SlFdES6vCw3/+YZFsfhScntO5teZv0YJMrNKaB1hD6b6dIrmySt684+v9jRYWSkQq1HhlpEpXNjr9D+Ezj+vPf/8ao/f7aGkGdpMCpMGNE8vp4n9O7FsX4nf2ZxQxmfSVcZQhq7rZz6Jr8yDDW/Ajg/BWQW6BqZI6H4hjHoIOg44s9c/xdw9c3lh0wt4NI9fzjCAw6iDu4a3d73N/L3zefWiVxmcNDisY2hMWvtIvv3VeBZsyOPNlYewu731zrgZFQlZkhicEsd947szrkc7EUgJQjgU7YA1r0DWt+Cq9W3tm6Oh3y9g2L0Q163ew8tr3QHT70JN6S6tbTplOqNTDCt/cwGfbD7MzB+yqLJ7UHUdr6qjyBJGRUKRZW4dmcLNI1J82RLNLC02jQVXLuDOr++kzlPnd4bqZBbFgizJvHjhi4zoOKLZxyYIQuNSY1L5/OrP2VG6g7d3vc2y/GXIkowsyXhULx6PBXf5+bgqh/DWPhsfLV3GpRlJ3Dkmle5BZG4JLU/SG8lpyszM1Ddu3HhaL/z2mkP8e/HeoAKq44HU7y7vdcaT08mvrWFjbkXA+2p3LaNm4+d4yvLrpRlauvQ+8RiLUWbTk5c0S6EDr6qR+fdvqWyg9Hsw4+tu9PD6llnoThcxkyYRc/UkTCkhFssoz4YvHoP8db7DXKq7/v2SAgYTxHaDK/8D3UaH+JX6m7NzTsA8/8ZYFAuvX/o6gxIHnfH7h0rTdFYeLOXHQ+WU1rqwGBU6xli4sn9HusTZWnw84SRJ0iZd1zNbexxn6kyuT8JZIm8dLPolVGSD1w36KTs/igkkGTpnwqSXIMGXUvv04r28FqBNhCN7EyUfPUXXX38aVEDVu2MUix8dF/RwdV1nY24F+4pqqHV5sZkUusRZGdejfaucm3SrbpbmLmX2jtnk1+SjyAqqpiJLvrFEGiO5LeM2rk6/us1UATsXrk/i2iSEwqN62JCXzz++2k5WsYrbbUY/5TyHQZZQZIleHaP417X96dNJtDdoaY1dm5otmALYmFPOS98dYP2hcnQd3OpPgdXxH4zeHaN55KL0M9qROtm3u4t55L0tIeW4nzymyUO68O9f9A/LWAJ5/pt9zFyRHXBVtSlWk8I/r+3LNQM749y1m6qFn1P95VeYkpOJuXoS0ZdfjhLbRJW/w5vhnavBXevbiWqKwQqTXob+U0Ie73HrCtfx8HcPhxRIHRdhjODLa78kwZpw2u8v1HcuTFZATFjavF2fwaf3QQONaOuRZDBFwM2fQvJQZq3M5ukle/0KDmmuOgpeuZWEKx4PKqV7VPcE5t9zbuzW7K/YT1ZlFjXuGmxGG50iOjEocVCb2z0/F65P4tokhOKH/Ue5b+6moOetNpPC67dkMqZHu2YemXCyxq5N4d9+OUlmt3jeuWs4hVUO5q3LY+fhKqqdHiLNBtITI5k2IiXsW5YX9EpkYHIsm3IrQgpYJHzd6pvjrNTJ7h6bxqdbDnOk0hmw51RDTIpMrw5RTOzfCUmSsPbNwNo3g6QnnqB29WqqPv+ckv88T8TIkcRcczWRY8cimUz1X6QsyxdIuaqDH7DXAQsfBlscpF8c/PNO8vLmlwMGUhUrKyj9uhR3iRvFohA9JJqkyUkoET+tKHs1Lx/s+4D7B95/Wu8tCMJZKHt58IEU+BZ+XDUw91q4ZxljeiTx3NcSnlPOC4Wa0n1lv45h/KJaV8+4nvSMa97PL0EQwmtbfiXT393k11KoMXa3yj3vbOT96SPo36Xtt8k5FzTrzlRrsbu9XD9zLQdLanEGkWYoSxBpMfDRfaPomRTV7OM7Uung2ldWU17nxhNEQGU2yKQk2Pjo/lFEN9KwVq2upvrrr6n6/HPcWdlEX345MddcjaXfsSIJr58PhduC25E6lSkSnsgCY2jnAXKqcpj8xWRcav2zCaWLSzm6+Chd7u5SryeKWqOS+odU5JNKoceYY1h+/XIMcrPG/j8b58LKL7Td69PPntcNz6aDq+o0niyhJ/bhh6v+xQNz8qmrC1ziu7VTuoXTdy5cn8S1SQiGpumM+vf3FFWHnrUD0DHGwpr/u7DN7T63Va22M9VabCZfYPTER9tYuqvYL8XwOFkCs0Gha7yNWbdlkhzfMmdhOsVaWfzYOO6ft4mteZWomo43QFBlUmQkybfb9t/rB2I1NX4GQImOJm7KFOKmTMFdUEDVwoUcfuIJJEkm/srhxFbuRjolkOr2Qg12Dxx6NJIIk+8XctZmN3O3e1h++yl9vnZ/BgNuCOlrXbB3AapWf8VFdaiUfFZC57s6E9XfF7ya2ptIfiCZ/U/sp2pNFXHjfpokeVUvqw6v4vzk80N6b0EQzkJ7FvqfjSLYa5GO++geFq35F1NHPMJ7K+WAqTFNVY5VZLhmYGcRSAmC0GpWZ5VS4wx8hr6xxuPHVTk8rM0qY1S6SPdrbefsJ4nFqPDyjYM5Uulg7rpc3l2Xi8erocgSmu5r+ntpRhL3jE1jQHLLb5PGR5h4/96RZB2t5a3Vh/h402E03VcdyqvqmI0yt4xI4ZaRKXSMsYb8+qYuXWj/wAO0u/9+HFu3on/+EGguCHBGWtXhxfVufj+2kZLr7lpY9d+Qg6ldZbvw6t56t9kP2NE8GtFD6h+gVCwKUf2jqN1VWy+YcqpOsiqzRDAlCOeCVS/4ricBBHMtMgLPGFPwXjiBjXvXsKcocLP2xkSZjTzWDE11BUEQgjVzRXa9isHHBdN4HHzpfjN/yBLB1FngnA2mjusUa+U3E3rxy0t6Ul7nptrpxWpSSIgwYTG2ft3+7u0j+fs1/fjLVRlUOTzY3SqRZgMxViOyfOZbt5IkYRs0CL46BIEXQHhilIlnVrt4YKiJWEsj71mR4/tzUpli3etFq6tDq61Fra3z/XddLVptLVpdHVW1BX5NhtVaFUOkAUnxfy9DjAFHbv1zFKquUuU+nZQgQRDOKtVHoGx/g3cHcy2SdQ12fYrxupm8e9dwrnl1NUcqHUGdkZUlX+bC/HtG0CGm+UuYC4IgBGJ3e1mXXeZ3eyiNxwFWHyzD4VabzFwSmtc5H0wdZ1BkEqMtJJ6l1SQNikxCpJlmqVmnesDTcB+SzE4K53cz8NwaF3+/sOEJhubyUvTw3diPKmjHAifd7UaOjESOiECJjECO8P23HBmJHBmBuZsOp9TBUCIVvLVedFX3C6i8VV4MkfV/LCUkoozNf5ZNEIRmVlsMihm8gfs7BXstQvOCx0GMzcoXD4/h/rmb+PFQOd4GUqYlfNVQ20eamXPnMFLbRfi/piAIQgspr3NjVGS8pxyDCKXxOIBRkamwu7GaQs9gEsLnZxNM/aypHpBl0BquFvPXC8yMnl3Ho8NNDT5GMptJuOt2ErqNRomMRI6MRLJYGj382HPl7zmQ/WW9Dt+2dBuSQaJ6UzUxw37qfaI6Vf6fvfsOj6pMGz/+PWVKeoGEGgihVxGQ3sSy9oJtbaigYvnt6uu+um51d93i67rNdQUUkbWs4tpFsaEgJVSl94QAIUASQur0c87vj5ESZpLMhHTuz3W9716eOXPmmYScc+7z3M99V2yqoMP11cvkx+gxdI5vnEbKQogmZNQwPX6KSM5FqGowILPFEO/QeXXGKPYUVjBveR7vfXcQVQFVVTAtC79hMb5ne+6dlMWoHqmyWFsI0ez8hkW45KNoG48rCvjD1AQQTUuCqbOBLQbqWFIwKF3jij46Ty330T8tfPNJRVVw9B8KHXtE/NE/7PdDvtz/Je5TSiBrsRrp16RT8FoBqlOtVs3PlmojeWz1NWwWFlO6TYn4M4UQLZQzuc5qopGcizAC4KieZtArPYE/Th3Mr68cwMFSNxWeYFPdDglOkmJrroIqhBBNLSnGFnatpxaTiOkqxzKNiAIqv2HWWuVZNA0Jps4GigKdhkDBd7Xu9tvJTobNqeQnY2opRNG+d1QfPbj9YNJi0thfsb/a9rTL0tDiNA4vOIyv0Icao5I4LJGMmRmotpM3ULqic3XPq4nRZQpbiFYvtUewAW8d6jwXdRgQnJ0Kw2nTGrx/oRBCNKSUWBspcTaOlFdPeXZ06Yei23Dtyo6o8Xj7eAfJ8rCo2dV9VRNtw7iHg72iatErVeWmgTaeXeMLfVGzw/A7Qa8l0ApDURTuO+e+sMFQ6qRUev+hNwNfHEj/Z/vT5c4u1Rr2Auiqzm0DbovqM4UQLZRmg/NmBM8ntaj1XGSPh/GPNNIAhRCi8SmKwj0Tsog5rRDaqY3HXbuyMf0eLCOAO2cdx76eV23fGJvGzIlZkrrcAkgwdbbodzlEMGX860kOqnxhcgIVBUbeW6+PviLrCi7sdiFOLbrqWU7Nya/H/Jruid3r9blCiBbovHuC55M61Hou6n9lIwxMCCGazg0jMjCt0HNc4sippEyZQVn2AvL/eSv5s+6k4tuFxPSuXpTCwmLq8K5NNVxRC0nzO1toNrjwd/DZ4+A/uX4p7+HqVfIyklQ8vzyt5KEeAwOvhZT6BTWKovC7cb9DVVQ+3/d5tfVTNXFqTh4f+ThX9pSbJiHalKQucM7NsGlB9OciWyxM+VXUM+RCCNHSJMXYuGdCD15anofbX71AWF2Nx2NsGvdM6CHrpVoImZk6m4y4E867O3hDEik9BroMh6uePaOP1lWdJ8c9yW/G/IaeST1xak7U09ZO2FU7dtXO2M5jmfuDuVzX57oz+kwhRAt12V8gY3SwOE6kbLFw7u0wambjjUsIIZrQTy7uywX900PS/WoTY9O4aEAH/uciaTzeUsjM1NnmoichNg2W/CFY4c8I3+8FVQfVFkwPvHZ2cGbrDCmKwmVZl3FZ1mVsO7qNt3a+RV55Hu6AmwRbAkPShnBj3xvpGNfxjD9LCNGCaTrc+l/48Eew9X0wfGDV0LpBsweLVkz4SfD/hBCijVAUhWd/eC6//3gbr6/ej2la+MP0ygOwaQqqonDLqG784rL+slaqBVGsMPmax40YMcJat25dEw5HNJnyAlj7Eqx5EY73gLIIrkcw/DD4Ohj9YLBqlmhTFEVZb1nWiOYex5mS81MbUbgdsv8Fm98OPrSxrGCXXQCUYMGK8+6GJFkbcDZoC+cnOTeJ+sgrruLlFXv57/p8NEU50dFGAQzL4sYRGdw1LpPu7aTpeHOo7dwkwdTZzvDDgTXgOgpmAGKSocsIcCbW/V7RKrWFmxWQ81Ob462A/HXgKQ3ORMW2g64jQa+98p9oW9rC+UnOTeJMePwG3+47Rqk72OQ8OcbGsO4pOKNIBRQNr7Zzk6T5ne00G2SOa+5RCCHOdo4E6FnzgmshhDgbOG0aY3u1b+5hiChIAQohhBBCCCGEqAcJpoQQQgghhBCiHiSYEkIIIYQQQoh6kGBKCCGEEEIIIepBgikhhBBCCCGEqAcJpoQQQgghhBCiHiSYEkIIIYQQQoh6kGBKCCGEEEIIIepBgikhhBBCCCGEqAcJpoQQQgghhBCiHhTLsmp+UVGKgH1NNxwhRBPobllWWnMP4kzJ+UmINqnVn5/k3CREm1TjuanWYEoIIYQQQgghRHiS5ieEEEIIIYQQ9SDBlBBCCCGEEELUgwRTQgghhBBCCFEPEkwJIYQQQgghRD1IMCWEEEIIIYQQ9SDBlBBCCCGEEELUgwRTQgghhBBCCFEPEkwJIYQQQgghRD1IMCWEEEIIIYQQ9SDBlBBCCCGEEELUgwRTQgghhBBCCFEPEkwJIYQQQgghRD1IMCWEEEIIIYQQ9SDBlBBCCCGEEELUgwRTQgghhBBCCFEPEkwJIYQQQgghRD1IMCWEEEIIIYQQ9SDBlBBCCCGEEELUgwRTQgghhBBCCFEPEkwJIYQQQgghRD1IMCWEEEIIIYQQ9SDBlBBCCCGEEELUgwRTQgghhBBCCFEPEkwJIYQQQgghRD1IMCWEEEIIIYQQ9SDBlBBCCCGEEELUgwRTQgghhBBCCFEPEkwJIYQQQgghRD1IMCWEEEIIIYQQ9SDBlBBCCCGEEELUgwRTQgghhBBCCFEPEkwJIYQQQgghRD1IMCWEEEIIIYQQ9SDBlBBCCCGEEELUgwRTQgghhBBCCFEPEkwJIYQQQgghRD1IMCWEEEIIIYQQ9SDBlBBCCCGEEELUgwRTQgghhBBCCFEPEkwJIYQQQgghRD1IMCWEEEIIIYQQ9SDBlBBCCCGEEELUgwRTQgghhBBCCFEPEkwJIYQQQgghRD1IMCWEEEIIIYQQ9SDBlBBCCCGEEELUg17bi+3bt7cyMzObaChCiKawfv36Ysuy0pp7HGdKzk9CtD1t4fwk5yYh2p7azk21BlOZmZmsW7eucUYlhGgWiqLsa+4xNAQ5PwnR9rSF85Ocm4Roe2o7N0manxBCCCGEEELUgwRTQgghhBBCCFEPEkwJIYQQQgghRD1IMCWEEEIIIYQQ9SDBlBBCCCGEEELUgwRTQgghhBBCCFEPEkwJIYQQQgghRD1IMCWEEEIIIYQQ9SDBlBBCCCGEEELUg97cAzgb7D/qYueRCiq9fmJsGl1TYhnYORFFUZp7aEKIFqDM7ee7/ccoc/vRVIXUODsjuqdi1+V5lxBCCFGXY1U+NhwopdzjR1dV2sfbGd49BV1r/OuoBFONxDAtvtpRyOylOWw5WIZdUzGxUFAwLYv28Q5mTsrimqFdiHPIr0GIs9HWgjLmLtvLJ5sPYddVTMsi+Igl+P9vHdWNaWMz6ZIc05zDFEIIIVocy7LYcKCUF5flsnh7Ych1VFMVpo3pzu2ju5Oe6Gy0cSiWZdX44ogRI6x169Y12oe3VQdL3dzy4iqKK7xU+Ywa94u1a2iKwkt3nsfIHqlNOEJxNlMUZb1lWSOaexxnqjWfn3wBk/95awNfbS/EFzAxajgP2zUFRVH48QW9eWByT5nNFm1eWzg/teZzkxCthdtncP9r61mTV4LHb2DWEM44vs/w+OXl/bl9TGa9P6+2c5PkkDSwAyUuLn92Gfkl7loDKQCXz6DCG2DavNV8s6uoiUYohGhOfsPktrmrWbz9CG6/UWMgBeAzLLwBk+e+2sOTC7c14SiFEEKIlsnjN7h+9kqyc4/i8tUcSAF4AybegMkfP9nBs4t3N8p4JJhqQFXeADfNyabc7a/1Bul0Hr/Jfa+tZ/eRikYcnRCiJXjs7U1sOliKx29G/B633+CNNQd4bdW+RhyZEEII0fI98Pq37CmsxBuI7jr6/JI9fLShoMHHI4t1GtC73+ZzzOUPGyFXbVtC+dr38R/NR7XHYEvPImnsjTi7DgSCUfZfv9jFrNuGN/GohRBNZd/RKj7ZfCjsBaCuc4Tbb/D0pzu46bwMbE2woFYIIYRoabYcLCM7p7he11GP3+TJj7dx+ZBOqGrDpc1LMNVALMtizje5uP2hqX3la96jbPXbtLv4QZw9hqFoOu6963HvXn3iF2xa8NWOQkqqfKTG2Zt6+EKIJjB/ZR5mmFnrSM4RECxs8/nWI1w+pFNTDlsIIYRoEV5avhefUf/raJU3wIqcYib0TmuwMcnjzQayNu8YJVW+kO2mt4rS5a+TetH9xPYdi2p3omg6sb1GkXL+9Gr7KsCba/Y30YiFEE3J4zdYsPYA/tMuAtGcI6p8BrOX7mnKYQshhBAtQrnHzyebD2GYZ3odzWnQcUkw1UDW7zuGL8yUo/fgDqyAj9g+Y+o8hidgSiEKIdqovcVVhCvGF805AmBrQTm1VWEVQggh2qLtBeVh+y9Gex3dsL+0QcclwVQDOebyEQizWMpwl6PGJqKoWkTHKXX7G3poQogWoNztRw0TTUV7jlAUJariFUIIIURbUO4JQJhnidFeR8MtyTkTEkw1EGeYSBlAi0nEdJVjmZH94hw1HEcI0brZdTXsRSDac4RpWdg06TclhBDi7GLX1eM97auJ9jqqqw17ry137g0kLdFJjC30x+no0g9Ft+HalR3RcTomNV6HZiFE82kf78BnhM4oRXuOiLFp6FLNTwghxFmmfbw9ZL0URH8dTYxp2Pp7ckVuIJcM7EiY4iKojjiSx99KyRezce3KxvR7sIwA7px1HPt6XrV94+waPxzZrYlGLIRoShmpsXRvFxuyPZpzhK4qXD20S1MNWQghhGgxBnRKJCnGFrI9muuoXVe5YXhGg45LSqM3kLQEB5P6pPHl9iOcvjY8ceRU1LgUyrIXULzwGRR7DI4OvUgcc1O1/Zx2jUkNWKpRCNGy3DepJ796fwtVvuqpCJGeI3RNYcb4zCYcsRBCCNEyKIrCvROyePrznbjreR1VgGljuzfouCSYakAzJ2axfHdx2IVt8QPPJ37g+TW+12lTuXt8jwZtIiaEaFkuG9yJX3+wNexrdZ0jVAX6dEigV3pCYw1PCCGEaNGuG9GV//tsR9jX6rqOagqM7JFKp6SYBh2TpPk1oBGZqfxwZAYxtsiqiRxn11T6d0xkxvisRhqZEKIlcNo0nr91GM4w6yvrEufQee7mYY0wKiGEEKJ1SHTa+OuN5+DQo5t8UICkGDvP3HBOg49JZqYa2K8uH0ClN8DCjYciLr3YrV0s86ePxK6r7Dq2iy/yvuCI6wiGZdDO2Y7RnUczutNoVEViXyFau4l90vjz9efw6NsbIypxrioQ79D5zz2j6RZmzZUQQghxNhnQKYkEpw1vpS+i/VUFkmPtnbVmjwAAIABJREFUvDVzNB0SG77QmwRTDUxVFZ6+bgiDOyfx98W78fqNkPURALF2DcuCc7snc6TcxZL8z3l1+8vklecRMAMY1sn3LNi5gHhbPNMGTuP6PtcTZ4tryq8khGhgV57TmU5JTn71wRb2FlfhN6yQCkV2XUUhmJLwx2sHk5EqgZQQQoiz29aCMn44ZxVVvkDE74mz67xz/xh6tI9vlDEp1unVEk4xYsQIa926dY3ywWcDw7RYuquQF5bmsvNIBW6/gV1X6ZQYw/TxmVx5TmdQfFzyxgzKrd0YeGs9nlNzkupM5eVLXqZzfOcm+haiMXkNL17DS5wehxZhs7kzpSjKesuyRjTJhzWitnJ+2n6onOeX7GHhxkPE2DVURSHRqXPNuV24fUz3Bs/tFqIlawvnp7ZybhKipSkodXPpP5ZR5vZH9T6HrjKsWwqv3z2q3rUJajs3ycxUI9JUhSn9OjClX4ewr/tNP9M/nYlL2Y1h1R5IAXgMD4ddh7n545t5+8q3SYuVyn+tUV5ZHq9vf50Pcj7Aa3jRFI2AGaBrQlfuGngXl2ddTqxNZiHOFv07JXL76EwOlLh5/8FxzT0cIYQQokX6y+c7qfSED6Sqti2hfO37+I/mo9pjsKVnkTT2RpxdB+INmGzML2XJrsIa78nPhARTzejpNU+zo2QHXrPuQOo40zIp95Yz84uZvHPVOyiKVP9rLY5UHeHRbx5l29FtGKZBwApOUZtWcN3MgYoD/Hndn3l67dNMGzCNB899UNbJtWEHKw+SV5ZHlb+K1bmVdEmTAFoIIYQIp9zjZ+GmQ2F7upaveY+y1W/T7uIHcfYYhqLpuPeux717Nc6uAwFw+QxmL82VYKotqfBV8N6e9/AaoYHUsWXHKP6sGF+hD82pkTg8kQ7Xd0CLC6aBBawA+ZX5fFv4LcM7DG/qobcIHr9BucePQ9dIcOgtvqT83rK9TFs0jQpfRbX1cKdzB9wAvLr9VXLLcnlm0jNNlv4nGp9hGnyT/w3ztsxje8l2bKoNCwuv3wTF4IYPX2L64Olc2O1CbFpoY0IhhBDibPTu+nzUMBMIpreK0uWv0+6yh4ntO/bE9theo4jtNaravhsPlLL/qKvBizlJMNVMPsz5MOysQ/GiYooWFdH17q7ED4jHf8xPwasF5D2TR49f9EDVg+/xBDzM3zL/rAqmqrwB3v/uIHO+ySX/mAubpmJ+v+bvskGduGdiFoO6JDXzKEMVu4u569O7KPOWYVHzGsVTuQNulh9czu9X/Z4nxj7RyCMUTeFAxQHu/uxuyrxlVAWqAKo/TLFgx7Ed/Gblb3hqzVO8cNEL9E3t20yjFUIIIVqOhZvCV8n2HtyBFfAR22dMncdQFFi6q5Dbx2Q26NgkmGomr2x95cQsxHGG26Dw/UK6zOhCwpBgY057mp2MBzLY9eguylaWkTIxBQALi5UFKznmOUaKM6XJx9+ULMvi2cW7mbU0B1VRcH1fHdEbOFlWeuGmQ3y27QjdU2OZfftwerRvORUP/7b+b5R6S0MCqbpmID2Gh4W5C5naeyqD0wY3x9BFA9lbtpdbP7mVKn/VibTOmrgCLlwBF9MWTWPuxXPldy+EEOKsd8wVvgy64S5HjU1EiSCLxxcwKXVFV7wiErIgoxlYlsXhqsMh2127XZh+k8ThidW2a06NhCEJVG6trLbdrtk5WHmwUcfa3EzT4qE3NzB7aS4ev3kikDqdYVl4/Aa7Ciu48p/L2ZRf2sQjDa/CV8FneZ+FpPYVLyrm8H8P0/HGjgx4fgBZv8rCd9RH3jN5mKcEiT7Tx7+3/buphy0aUJm3jOmfTqfSV1lnIHUqV8DFzC9mhj1XCCGEEAK0mERMVzmWGVlv18YoNSAzU83AZ/qCrZhPy/gyKg30eB1FC/1N60k67n3ukO2V/sqQbWfKMC02HDhGYbkXn2GS6LQxuGsS7eMdDf5ZdXly4Ta+2HYk4gbIlgWV3gC3zV3Nwh9NqDEvNmCYfLWjkA83FlBY4cWyLNrFO7h0UEcuHdQJu94wzxnCpXNGMwNpWiZf7/+aMm8ZSY6Wl8Io6vbWzreo8FdEPTMJwXTPlza/xC9G/6Kphy2EEEK0GO3i7OQUVYVsd3Tph6LbcO3KJq7f+FqPYddVkmPtDT42CaaagV21hwRSAFq8RqAygGVYIQFVoCyAHh/664rTGy6draTKxxtr9jNv+V48AQMFBQsLVVHwBUwm9knj3olZjOie0iRVBHcfqeCNNfvxBEKf5tdWAhOCAdWvP9zC/LtGVntfpTfAS8tyeXllHn7DpMpbPUhbtquIX7y3hVtHdWPmpJ6kxp3ZH92HOR+GpHNGMgN5PJgC0FWd5QeXc3nW5Wc0FtH0DNPg1e2vhhSaiXRtZMAK8EHOBzwy4hFidOk3JYQQ4ux09bld2FJQHpKhpDriSB5/KyVfzEZRNZw9zkVRdTx5G/Ds30TK+dNP7GtaMKVfeoOPTYKpZqAoCu1i2lHkLqq2PbZXLIquUL6+nKSRJ2chDI9BxaYKOlxfvZyjz/TRKb5Tg4zpvW/z+dl7m8EibPAC8OX2I6zYU8ygzom8dOd5JDgbt9rYS8v34jdDxxJJCUzTgpU5RzlS7qFDohOAw2UefvhCNofKPNXWW52q6vs/0nkr9vLedwd5897RZKVF3jH7aKWX3YWVVHgCxNg0iqpKQvaJdgYyYAYo9baMtEURnRUFK/AGqgdS0cxMAigofLr3U67tfW2Tjl0IIYRoKa4Z2oUnF24L+1riyKmocSmUZS+geOEzKPYYHB16kTjmpmr7jcxMpXNywz+YlGCqmdzS/xbmbJyDx/Cc2KbFaqRfk07BawWoTrXaE2tbqo3kscnVjjEsfRjtY9qf8Vjmr9zLU4t24PHXvp7DsoJ1+jfkl3H1cyv44P+Na7SAqtIb4P0NBzFOG1I0JTAV4LVV+/jJxX0pdfm49vkVFFZ4Mcy6K+r5DYuiSi9TZ61k0UMT6JRU8x+fZVmszTvGC9/ksGx3cTBF0AoOwOrqRjntRxTtDKSFhRFhLrBoWdYcXoMr4Kq2LdqZSVfAxfKDyyWYEkIIcdaKc+hce25X/rvuAIEw93HxA88nfuD5Nb4/1q5x78SsRhmbFKBoJtf1vg6T0OAl7bI0OlzXgcMLDrPt/m3kPJmDLdVGj8d6oNpO/rpi9VjuGnjXGY9jyc7CiAKpU/kCJvmlbqbPX4tlRVbqO1qrc4+iq6H/PKMpgekNmHy4sQCAH73xHcWVkQVSx1kWVHgC3PVyzd+zpMrHVc+t4M6X17B4eyHegEmFJ0CFN0CFJ4BphAZhp85Anur4DGTcgOqpm7qqy3qpVqrEHf3MZKAyELL9mOdYo4xPCCGEaC0euagPybG2qItIOHWVsT3bMaH3mU9AhCMzU80kxZnCxd0v5ot9X4Ssp0idlErqpNQa36uikuJMYXTn0Wc8jt99tK3GQKq2dUm+gMnWgnLW7C1hVFa7Go+/5WAZLy3fy+aDZVR5g6lv3VJjuWNcJpN6p9XYbLekyneih9SpoimBCVDu9rP/qIs1e0vwh2mbXdfaK8O02HfUxcb8MoZmVJ8ZLK70csU/l3O00hv22ACBioGo9mIU9eQNcrQzkIZpMKrTqNMPLVoBuxa65q4+ayPtesMvmBVCCCFak7QEBwtmjuH6WSsp9wQiekAeY1MZ3CWZf906rNHW+0sw1Yx+PebX7CjZwb7yffjNSOveK8TaYnnxohfDNv2NxsYDpRwq84R9LZJ1SW6fwQvf5IYNpj7fepg/f7aTA8fc+AMmximBUW5xFWvzSoixa8ycmMWM8VkhQZVlEb5IxyklMCMJqCwrmMYYLjCL5DsCeAMGL36Ty79uHXZimy9gcsuLqyiu8Iadbj7OXzoKe7ulIdvTLktDi9M4vOAwvkIfaoxK4rBEMmZmVJuBBBjUfhAdYjuEHEO0fB3jOqIpWrXS+NGujVRQ6BjbscnGLIQQQrQElt/Ek1uKWekH00KN0cnsnsgnD03gpjnZHDjmRgXCPc926ioWcO25Xfnt1QOxaY2XjCfBVDOK0WOYf8l8Zn4xk5zSnGrrp8LRFZ1AwMnDQ/9BRmLGGX/+i8ty8QZC1+JEui7JApbtKaawwkN6QrDIg2VZ/P3L3cz5JqfW1MEqn0GVz+CvX+xmZc5RZt02HKftZHCUFGsLO2sVTQlMgHjvERZku/Cb1RcuRbP2yrTgi+1HqPIGiHME/2QWbTlE/jF32EDq9NkuZzc7adfEEt+3evpeXTOQx20q2sTl713O9EHTuazHZcTawpd7Fy3PJZmXMHfzXAzj5N9ZtDOTTt3J1b2ubuqhCyGEEM0icNRNxcoCXGsPBxtDHX8grihYhomtVzK9fAoPXT+E1XtL+GBDARagKQoBM9jS5+4JPbjpvG5nXJU5EhJMNbMkRxKvXPoKC3Yu4N9b/02FryJkwXqsHouiKNzQ5wYGx1/Fz9/ex9gMFxmpZ3ZTvSr3KOEmVaJZl+TQVTYdKOPCAcFgavbSHF74JjfiNVhuv0F2zlEe/M+3vHj7iBMB1KgeqfjCVNyLpgSmTVM4f1A33txUcUbfEcCmKhRWeOnxfTA1a0lO2AbC4Wa7vAe+ouLbN0OCqUgFrAAHKg7w9Nqn+b81/8fjIx/nuj7X1etYomllJmXSN6Uvm4o3VdsezcxkqjOVoWlDm3LYQgghRJOzLIuKrw5Q/vWBYABVwxIK745j/FLRSdhQyrW3D+JPU4dQ7vbj9hskOHXiHXqTtPA5ToKpFsCu2bl9wO3c1v821h5ey8LchRS6CgmYAVKcKUzOmMxF3S86sf7iwGSd+15bzzv3j602mxOtcMEARLcuyTAtyj3BFMWtBWX8Y/HuqIpZQLAU+8o9R3lj7X5uPq8bS3cX8fqqfVhY4XobR1wCU1UUrh07hP9uW4X/tDFFu/YKBaq8we+5/VA5eUdDG8fVNNvlzLwYLbYXlvkyihppOmeo4/2qnlrzFIddh3lw6IP1PpZoOjMGz+DxZY+H9BuLZGbSqTmZPmh6k14UhBBCiOZQtjCXqjWHoYb2NcepgN0C794yimZtJP2Bc0iJs5NS67sajwRTLYiiKIzsNJKRnUbWut/0cZlsyi/l5+9u5i83nlPvGy29huIP0axLUhRO5KG+sDQXfyD8U4S6Cj24/Qb/t2gHs5bsITnWzm2junPvxCzumLcWtz806KurBCbAORnJ9O2YQCDMk41o115VeQ2mzlpJ1+TgbGC4WbPaZrsMVxauffcRkzEfRfGhaN6QfSLlMTzM3zKfjrEdZYaqFTg/43wmdJnAN/nf1JnKeyq7amdw+8FM7T21EUcnhBBCNL/KVQVUrTmMFc0D+YCFv8jN0f/soP0dA+vev5FIMNUKKYrCn6YOZurzK3klex93jM088drWgjKW7CyiqMKLAqQlOriwfwf6dEgIOU5qnJ1yT2gZ5mjWJVkuN/qHb5N/oC+fbjGqFZo4LtJCD1U+g59e2o9bRnY7ESBO6Z/O4u1Hop7tirVr/PaqgcTaNeKdOqWu6jNC0a69ctpUVj4+hZIqP89/vYfc4tCZqbpmu0xPF6p2/wwtfif2dkvRY/JBCf35Axxbdoziz4rxFfrQnBqJwxPpcH0HtLjgsT2Gh6fXPs3lWZfj1J11jl80H0VReGrCUzz09UOsPbw2ooDKqTnpk9qH5y54Dl2V07QQQoi2yzIsyj7bFzaQWpO/iT9+PYtdxXmoqkrvdt154oIfMbRT/+AOARPvnlL8R6qwdajfcoozJVfpVirWrjPn9uFcN2slvdPjOVLhYdaSHA6UuPEZxolmt7qq8Ozi3fRMi+f+yT25ZGBH9O9nkn44sht//3JXSKAS1bqkGAfnpsfw7mfrURz94bQSzlEVejAtVueWcOuo7ie2/e3GodwxbzUbDpTijjCgirFpzLl9OP07BZui3jU2k+eX5OA9ZTYpmu+oqcFqMKlxDlLjHPRMj0dVCFlvFtlsl4pR2R93ZX+6dPsOX+IHITfXxYuKKVpURNe7u1YrTpD3TB49ftEDVT+5pubTvE+5ptc1Ef1cRPOxaTaeu+A5Zm+czSvbXsHtNzAJDapi9Bgsy+L6PtfzyPBHsGmN0xRbCCGEaCk824+G3lQBFd4q7nr7cf5w8SNc2e98fEaANfkbcZzWdsQKmFQsP0jqdX2aasjVSNPeVqx7uzh+e9VAps1bw8/f3cyuI5W4/ScDKYCAaeHxB3tCPfb2Jm6du5oqb3A25KYRGVg1lPVOHDmVlCkzKMteQP4/byV/1p1UfLuQmN4nU9gcuspdE3rRceY9eK+6Dk+YXjjRFHqwgAPHqhffsOsqr84YxVVDu+DQVRx6zf9k4+wa7ePtvHnvaCb0Tjux/eZR3cJVWY/oOwLYVJUZ4zNP/HdavANHmLVqp852RcJn/yokkDLcBoXvF9L5ts4kDElA0RXsaXYyHsjAV+yjbGXZiX1dARfzNs+L6LNE81MVlQeGPsBHVy7GPHIdfZL7EWeLQ1M0YvVYspKyeHzk4yy9aSk/HflTCaSEEEKcFcqX5mN5Q5d05JYcAOCaAReiqRoxNgeTeoykf3rP6jta4N5QhOkNn+3T2GRmqhWr9Ab4y+e7sCwLt7/uxmUun8GGA6XcMDubd+4fg7rkC8Yc3sbytH4EwsTVkaxLumV0NyC4piicaAs9eMKsj9I1lf+7bgj/c2EfXl2Vx6vZ+wiYFqqiYGHhC5ic2y2F+yZlMalPOtppa8HSE5xcNqgjn245jOe0tU51fUebpjCsewq90k+mSV7QP51ffrAlZN9oZrvi7D58tuKQY7h2uzD9JonDE6tt15waCUMSqNxaScrEk0ss8yvzKfWUkuxMPv1QooX6asdRxnW6iNlXP97cQxFCCCGanb+gMuz2rNQMVEXlfz7+A1f1u4Bzuwwk2Rm6bAUAVSFQ6MaeUcPrjUiCqVbsR//5loOl7poqR4blDZjkFFbw/x6byy9yF/GHn/+KqV+VUlThDVsmvSYxNo3fXDXgRH+plDj7GaS+nZTgrPlpfMckJ4/+oB//c2EfDpV5KHP7cegq7eMdpNTRR+Cp64awu7CSPYWV1dL9amNTFTomOpl12/Bq29vFO5jSN53Ptx0O+b4RVxq0e7BrTjxG9QpvRqWBHq+jaKHFQfQkHfe+6vvbVBtlvjIJplqRDzYUcPvo7nXvKIQQQrRxllFzCfQERxzv3vocz6/+D499+meKqko4v+conr7kMdLiTquGq4DplpkpEYW84ipW5hwNGxjUVTnPa1gsi+lK3PzX6ZAaz9t9XFw/eyVHK31hm9CezmlTefjC3tx0XrcT2wZ2TiTGplF1Wrn1aAo9OHSV8zLrLmypayoZqbFE07bYadNYMHMM019ey5aCshrLwh8XY9Po3i6W/9wzmqSY0ADv3klZLN1VVK9Kgw5d5epzuvJZeZgqg/EagcoAlmGFBFSBsgB6fOifrCrZuq3GkXIPWw6WcX6/9OYeihBCCNH8VAjbB+d7vdtn8rfLfw7AnqP7+PHC3/Obxf/kX1c9EbKvojdPGxG5C2ulXl6Zhxkm8Clf8x4li18kafSNdP1/r9Hl/pdJGHYZ7t2rq+2naBqvrTsIQEZqLIsemsjFAzpg11WcYdYlKUqwQl5GSgzP/vBcZk6qnq86sXcaTnvozNOpqW+uXdmYfg+WEcCds45jX4eu97l9dGY0P4aoxDt0/nPPKP40dTADOiXitKmcGq+oSjCIymwfy2+vGsj7D46rsXP2sG4p3DCiKzFR9vnSVOiSEsPDU4biN0J7TsX2ikXRFcrXl1fbbngMKjZVEDegeqUan+kj0VE9JVC0XB9tLODigR3PqD+cEEII0VYoioISE9ncTq923blx0CXsLMoNfdGwUBNqz1JqLDIz1Qr5Aib/XXcA/2nBVDSV87wBk1ey9/E/F/ZBVRVS4+w8f9twSqp8LFi7n/+s3s8xl5+AaRJr1xnePYWZE7MY3j0lbF8rVVW4e3wP/v7l7pDZskhT30ZnpdIxqXHLfOuaytVDu3D10C7sOFzO51uPUFjhwbQgPd7B5H7pnNM1KaLeXb+5ciDlbj+fbT0SdobqdHZdpVOSkzfvHU37OCcD2g1gc/HmavtosRrp16RT8FoBqlOtVs3PlmojeWz1dL6eST1JciRF90MQzebDjQU8+oO+zT0MIYQQosWIG96BypUFIel+e47uY3FONlf1m0KnxHQKyo/wwfbFDOsc2lNKS3ZgS4ttqiFXI8FUK1RU6SVMO6eoKucBuHwBKryBamlsqXF27p/ci/sn94p6XDeP7Macb3LDph7WlfrmtKk8clHT3mT265hIv471n9VRVYW/3TSUF5bl8txXezBNKyTNEcCpq5jADwZ04I9TB59YFzZj0Ax+vvznuALVKximXZaGFqdxeMFhfIU+1BiVxGGJZMzMQLWdnDWM1WOZMXhGvccvGodpWizdXcR/Vu0nv9SFL2CS4LTRv2MCB0rcjMlq19xDFEIIIVqM+DGdqcwuCNkeZ49lQ8F2Xlz7FuXeShId8VzYcwy/OP+BavspdpWESdEs/mhYEky1QpWeAFqYBM1oK+fpqkrlacHUmUiOtfOfu0dz/eyVda5JOpXTpvLU1MGck9H6iigoisLMiT25a2wPPt92mDlLc9lTWInHb2DTVdrF2Zk2pjs3ndctJGVwUsYkbGr4n33qpFRSJ6WGfe04VVG5sNuFDfZdxJkJGCZzl+9l7rJc3D4jJLDenF8KKNwxbw2PXtKPoa3w37sQQgjR0PRUJ/ZuifjyyqtVMuuUkMasa35b9wEUhdhz2jfiCGsnwVQrFGvXwlbei7ZynmFZxDbw2o0BnRN55/6x3PLiKrwBs9agyqGrqIrC3246h0sGdWrQcTQ1u65yxZDOXDGkMwCWZdWZKqirOn8Y/wf+d+n/hvSbqotTc/K7cb+TXkQtRJU3wF0vr2XzwZqbSwezFyxW5Bxl/QvZPH3dEK4a2qVJxymEEEK0RO1+2Jcjf/8WwxUgmjISik2l3W39UZpxLbIUoGiF2sc7wlbdi7ZprKYoJDbQrNSp+ndKZNlPp/CzS/vRNTmGGAwcioWuKjh0lXiHTqJT5+4JPfj6fye3+kAqnEjWXEFwduqnI3+KQ3NEfGyH5uCREY9wUfeL6js80YD8hsm0eWvYmF9zIHU6j9/ksXc28fnWw408OiGEEKLl0xIdmLf0pQwLK8JoSrGppNzQB2fvuitBNyaZmWqFYuwaPxjQgY83H6o2QxVN01hdVZg6rEtIg9uGEu/QuX1MJreN7s5HN9/H0RvuwNu+A7F2jS7JMUzsk4YtXK7iWej6PteTFpPGL1f8Ep/hC1lDdVycHoeu6vx23G+5oNsFTTxKUZNnF+9ma0FZxP3LjvP4TR56cwPLfno+7eMjD6aFEOJsZ5oWlb4Adk2V6qhNoWQvHFwPnlLQHJDQCXpMBL3hque5fAHuXriFWy/szGVHfLi3lwRLSZ/+kFIFVBVbx1iSr+yJo3vzVzSWYKqVumdiFl9uLwypIhdp5TxdVZg+vkejj9Nyu+mzcy19rvwnqr15Sla2BpMyJrHkxiUsO7iMeVvmsbFo44n1VH7Tz6B2g5g+eDqTuk5CV+XPtqXwBUzmr8zDU8OMVF0930zL4o01+/nRlN5NOWwhhGh1PH6DTzYfYtaSHPYUVaKrCoZp4dA1rh7amRnje9C7Q0JzD7PtMA3Y/Tks/zsc2gCqHtymqKB+3xxqxAwYeTckda3xMIfK3LyavY9VuUcp9wRw6CpdU2K4dVR3xvdqj6oqmKbFT97ayKAuSdx2QU8URcGo8lO19jCudUcwXX4sC1SHhrNPCvHju2BLb57KfeHIXVkrNaRrMj3ax7HzSAXGaSl/dVXO01WFIV2T6ZkW39jDxLN9O45evSSQioCmakzOmMzkjMn4DT9lvjIAEu2J2DX5+bVEn287HLbfGwR7vpWtfpt2Fz+Is8cwFE3HvXc97t2rTzbQDpjMW76XByb3arRZYiGEaM0sy+KFZbk8++VugBPFffzfl9F2+w3+u/4A7393kD4dE/jXLcPISG05N9qtUlUx/PsqKN0Hvsqa91v1L1g9Cy55CkbcVe2ljQdKeebznazZW4Jlgc84+dBxa0E5y3cXE+vQmTkxiwp3gCPlHt64d/SJZRJanI3EyRkkTm6+Kn2RkmCqFXvpzhFc+o9llLn9YUulh6Mpx3tKDWvcwX3PvXkzMYMHNclntSU2zUb7mOarTCMiM39FXthy+NH0fPMZJtk5RxnfW37fQghxKsuyePydzXy4saDWfo6GCYZpsuVgGZc/u4w37h3NwM7Sg7FeqophzgSoLALTX/u+hi/4v5/9DDxlMP5hAD7ccJDH3tlUY9YGBIPiKp/B05/uBCy+eGQSDr11pmzKopVWrFNSDO89MI60eAc2re6n2nZdpWOSk3cfGNtkazQ8m7fgHDS4ST5LiKZ2sNQddns0Pd8sEwpqOI4QQpzNnv5sZ52B1KlMC8o9AW55cTX5x8KvPxa1MA145arIAqlT+d2w5E+wcxGfbz1cZyB1Kp9hYgGPvb2JgBHd2uOWQoKpVq5H+zg+fXgid4zNJM6uEWcPjepjrUCwet74Hnzy0AS6pjTe9LdlWVR4/BSUujlW5aNy8xZihkgwJdqmmopORNPzzbCsiG8UhBDibJFbVMm85XvDnh+rti3h0L8fZv9fryf/uds58tYTePK3nni90hPgiQ+3hrxP1GHPYji2r1oglfn3CtL/XEGV72QK1NxvfUyeX1X9vQEPgUU/56E3v4s4kDrOb1hszC/lua/3nNHwm4uk+bUBqXF2fnn5AB79QV8WbT7Mx5sOcbTKi4JCqm4w6uN/M+3t2dhtjffrPlblY8Ha/cxdvpdSlx+bpmKYJtY593PV+gpmxJYzoHPzV1wRoiHF2TUkIBVsAAAgAElEQVRKqkK3R9PzTVcVEpxyKhZCiFO9vCIvZE04RLYe1bAslu0qprDCQ3qCM+QYB0pcrNhTTJnbj/r98ofz+6WTGneWr09e+Y+wa6QMC/6x2sfPJ9Se1WSVFzDQ2sM6skJeq6sgk8dv8vKKPB48v1erq/YsV/A2xKFrXHNuF645t3oj0D3v/B4zZw/069fgnxkwTH770TYWrDuAqnDiaUTA/P5Jkqbz/sZDfLzlML3S4pkzbQRdkmMafBxCNIfBXZI4WOoOaaJ9as+3uH7jaz2GaUG/jvKgQQghjnP7DN5enx/SUzOa9aiKAm+s3s9DF/YJvte0WLqriNlLc9hwoBRVUfAbJopC8AHwexZT+qVzz8QshnVr3r5FDerQpmBZc2856M5gWfPeF4HttHuxY/sgf23YQzw61s7TK7w8cJ6dZGfNy0pU08udykes46Fq2yMJgCF4T7l4+5FW139UgqmzQNzYsVStWImzgYMpb8Dgznlr2XCgFF8tPXYM08IwLbYfquCyfyzj7fvGSPlS0SbMmJDF17uKcJ9WhCKanm8ZqTEyayuEEKdYlXs0bIXTaNajegMm7353kIcu7EOlN8D0l9eypaAMV5iiQX4juO2zrYdZsrOIK8/pxJ+mDmm9VVb9Htj2Piz/W7AiH0qwWISigWYHTBh6G4yaCe16Bt9T8B2odsAbcrgRnTUmZ+o8s9LL76eEzvQdp2ExXN1VbVs0AXCVz+DlFXmtLphqXfNool7ix42jasWKBj2mZVk8/OYGvjtwLOL1HoZlUe72c9MLqyis8DToeIRoDsO6JZOeED7tIXHkVFKmzKAsewH5/7yV/Fl3UvHtQmJ6n7wJiLNr3D+5Z1MNVwghWoWjVT7MMGWKo1mPClDm9uPyBZj6/Ao25JeGDaROZVrBUusfbSzggdfX19j6okUr3Q/PnQcLH4GiHcHiEH4XmAEwvOCrAF8VrJsHs8bCmheD7/OWg1Xzz+d35zv45xofRVW1r4eKOy0YiyYABsg/1voKMsnM1FkgdtQoCh77KabXi+pomCp+2blHWbqrKOwiw9ryYi2g3O3n6U938swN5zTIWIRoDKZpseNwBSVVPgzLIinGRr+OCThtJy/iiqLw00v68ciCDXjCzM7W1vNNVSDBqXNpK3sCJ4QQjc20LAgTx0SzHvX4cf7ff75j31FXrRk0p3P7Tb7ZVczfvtzFTy7uG83Qm1fpAZgzKVimvJbACAgWmTD98MWvwVsBiV2CDXlrMChd44o+Ok8t99E/reb9/FT/vUQbAHsDra8gkwRTZwEtIQFH3764168nbuzYut8QgRe+yQ1JbYLI8mIDpsXCTQU8ceUAEpy2BhmPEA2l1OXjrXUHeHHZXqq8gRNpHpYVnJH94chu3Dk280RTyIu7OrnuyLe8024wHiK7WCgKxDt0FswcUy04E6IlMAyD3NxcSktL8fv9OBwOOnbsSOfOnU801BSiMSXF2FDDpNhFsx4VwKGrrNxTHLbyal0FEdx+g7nL9nL/5J7E2lvB7XLAB/MvjyyQOpXfBUufhon/G7w41eK3k50Mm1PJT8bU/GC+yEqu9t/RBsBxreFnfZrWN2JRL3Fjx1K5YkWDBFNHyj2szDka8tAoqoWhKLz77UHuGJt5xuMRoqG8uWY/T3y4FVUJPpkM55XsPF5btY8bhnflV2M7cPDuGfx40mQ6D+/PP7/ag9+wwlagOi7GppHgDAZS3dvFNdI3ESJ6FRUVrF27ljVr1mBZFoZhYFkWqhp8Cp2QkMD48eMZNGgQdvtZXvVMNKqRman4w/QcimY9qmYaxFeUUqKEtoOJtCCCosAHGwq4eWS3xvmiDWnHR+AqrhZIZf69Apcf9j4UT5w9GCjN/dbHa5v8LLnzlOtPwI214Q0sv1Hr+p9eqSo3DbTx7Bofg9ND93RZDl41Lqy2LZoAWFVgUNfW12xZgqmzRNy4cRx+8skGOdbSnUXoqoLvtO3R5MW6/QbvfSfBlGg5nvtqN//6OqfG3lHH+Q0LsHh7fT57Pl/KPy69jPT77+NBRWFKvw7MXZbLwk2H0FQFj9/ANC1suopNU0mNtTNzUhbXDutKvENOv6Ll2L59O++++y6WZREIBKq9Zny/OL+kpIRFixaxePFi7rzzTtLS0ppjqKKNsyyLnKJKOiY52Xc0tPFu4sipqHEplGUvoHjhMyj2GBwdepE45qZq+2m6RoEVg3Has61oHvy6fAazl+a0jmBq+d+Da6FOE3FZ88Icqqy+xDtyUczT7/BO+vUkB69uCt/QV1Us3jeqB0zRBMAOXeOeCaFl1Vs6uZqfJWKGDMafn0+guBi9ffszOtYxly9s7nG0ebGlrpr/WIVoSh98d5Dnvt4TVaNBT8BkQ3J3nuvcgye/T43o3ymRv9w4lCeuGsjnW49QkFdA4UefkDljGudkJDOie4qkSYkWZ+PGjXz00UchQVQ4fr8fv9/P3LlzmTFjBunp6U0wQnE28AVM/rvuALOX5nC0yldrsYja1qMe1zM9gQPHXHi91Y8TbUGE/UddmKYVNu2wxSjcDsW7w74UaVlzxaaS0Lc37MmDUy6FeQ9Xr76ckaTi+WWYCrSag4IuV2DlxcNpv7tIA+COSU7OkZkp0VIpuk7sqJFUZa8i6corzuhYwefyoWlM0ebFtsIaOaINMkyL33y0tcZAqra8eo+l8ta6Azxwfk86JZ3s2ZHotHH98K54U/3kv7CKnhN+00TfRojo5OfnRxxILVmyhJKSEqZOnYrX62X+/Pn8+Mc/xumsuVSyEJEoc/u5Y95qdh6ujLhCcG1ibBp3juvB7z/eFvJatA9+dU2h0hcgsSWv8T60EdTwCXqRljVXLAOObIGr/gkf/ThYBTBSig5JXel841+J/8c63H6D04sx1hUAx9g0HvtB31b5wFGCqbNIsN/UijMOppJjbNg1DbdZ/YQX7cLQ5JgWfGISZ42vdxTWWOUpkrx6y4JXs/fx2CWhfdwUVcUKU95XiJbiq6++CgmkNm/eTHZ2NsXFxSeKT0yYMCHkvX6/n++++44xYyJ7wi9EOG6fwY1zssktqvw+jfrMOG0qz91yLh0SnSE39BD9g1/DtHDqLbxQkKc8WPq8Br8738G4eVU8NKqOtY6+ChhyY/B4n/8SAhEEVJoDkrrCXYtwxifz5r2jufpfK6j0BsL+/MOJsWnMGN+DSwe3zuq2EkydReLHjuXo7DlYlnVGkf/43u3D9n+IJi82xqZyxZDO9R6DEA1l9tIcqsKkk0SaV+8zTF5dtY+HL+yDXQ8+GdxTWMnrq/exc18Rx/rcQNq81QzuksSto7rTOTkm5LOEaA5lZWXs27ev2rbs7GyWL1/OFVdcQc+ePdE0jT179rBjx46QohN+v5+VK1cyevToVvk0WbQMP3t3E3nFVWccSMXZNWyaygvTRjCyRypFFd6wD8qiffAbY9dOnNtbLN0RbMhbg0jLmqN9v65q5N3QLisYUJXkBisFnl4h0BYHWHDOzXDRb8ERTAfMSovnvQfGceM/vsLlt/CoNYcamqpg0xR+dEEv7p/UensuSjB1FjmSmMacHlNY8fvPqQwE0+wSnDo/GNiRu8ZlRlxZrGtKLMO6pZCdezTktUjzYk0LbhyR0RBfS4gzsuFAadjt0eTVm98vmD54zM2zX+1m15EKAoZFwLQgvhPbdxWzKqeEucv2cl5mKg9f2JsRmakN/VWEiMratWur/bfH4+Hrr7/m6quvpn///ie29+3bl759+7JkyZKQY3i9Xvbu3UtWVutbNC6a39FKL4u2HK6x8E9d5csBbJpCv46J3D+5JxcN6IBNCwYLaQkOhnRNYt2+Y9WOGc2DX5umtI57laQuUMcsWyRlzUk85SF3zylw/0o4vBmy/wX7VoKvEjQbxHeAETNg8PVgD7137OY9xktL/sK3P/8bczcdo8ztx2+Y+A0LVQGnTcMwLS4f3Im7J2QxoHOYNVitiARTZ4E9hRX88v0tfLe/FKPTMAJVJ6eCK70BXl+1jzfW7GdwlyR+f+0g+nWs+x/1zElZbKyhm3hdebGqAj8Y2JGkWEnzE83L4zfCzrJCdHn1CvDPxbv5emdRjfn+vu/L/C7fU8z6fSX87NL+TJNqlqIZ7du370SlPgiunwoEAtUCqboEAgEOHTokwZSolzfW7K+xtVGk5cvHZLXjlRmjwh7jvkk9eejN70KyDyJ98KsqCne2hvN0j0m1NtyFusuaY4+HkTNDt3ccDNfOjngolmVx+Ikn6DZ9GudePozpl1msyi1hy8EySt1+Yu0a6QkOfjCoY8tehxYFCabauDV7S7jr5TW4fMb3BR9C/4D8pgWmxbp9x5j6/EpenDaCcb1qr/g3sXcaw7qlsDavpM5S0qeLd+g8dkkr6igu2ixNVWoshBJNXr03YLJ4R2HEfwtuv8mfFm1H0xRuHdU9ylEL0TA8Hk+1/3a5XMTGxp7oKxUJ0zRxuULLVwsRifkr88IW/4mmfPmqvSWUufxhH9Ce3y+dWIceNpW7rge/mmUytFNijVk7ZS4/H20qIK+4igpPgORYG/07JXLJoI5N34xds8F598LKZ8Hw1rhbbWXNARhwVeSfaRrBRcNa9VCi7P0PCJSWknrHHQAoisKYnu0Y07Nd5MduZSSYasO2Hyrnzu8DqUi5fAZ3/3sdC2aOZkjX5Br3U1WFF6eN4KYXstl1uAJPBDeRihLsbP363aPpmhLaRE+IpmbTVJy6FnY2KZq8+uO9p6Lh9ps8uXAbw7ql0L9T605xEK2Trle/BYiNjcXlcmGaZlQBlTTwFfURMEyOVoVvkRJNmrVdU8kvdZEUG1pSW1MVXr7zPG6YnR1VlUBFgUTF5CcfPYNnSnucp8zWbi0o44Vvcvl0y2FURal23Fi7xs/f28xN52UwfVwPMlLP7F7H7TP4aGMBC9YeoLjSi2FZJDp1pvRL5/YxmXRIPKU633kzsLKf49SJvojLmuvOYNqeXnsvKg6uh5XPwa5F4P/+YYxmh8wJMO7HBBIGUPjMM2S8MAdFP3tCjLPnm55lLMti5qvrowqkjnP7De59ZT0rH59Sa1+FGLvGf+8bw6P/3cRnWw9jBgz8hO6vfL9vWoKDeXeeR8+0+KjHJERjuXRQRz7YUIBxWrpfNHn1Nakr398fsHjxm1z+etPQRvluQtQmOTmZQ4cOnfjvrl27ous6O3bsYMCAAREdw2azkZgoDwNE9Fx+A11VwhaeiCrNWoEqb833OoO6JDH/rvOYPn8trjAlu0+nqwqpcXbemjmZlNU6+6fPoMPPf0bSlVfyyso8/rhoO/6AFXLNAE7cc722ah8L1h5g1m3DmdQn+ubWZW4/z3y2k7fX56MohNzL5RRV8eKyvYzJasfPLutP344JmI5USstHkuxchqpGkTGk2SGtH0z5Rc37HFwP782EsoMQ8IB1yvENL+R8CQeywQPpV15FzMCBNR+rDZJgqo1at+8YxZXhp3ojWdBZ4fGzIqeYCb1rPwk4dI1nbz6Xg6Vunv3fv/Jxx3MxFBVNVTAti4BhMaFPe2ZO7Ml5mdKwVLQ8Myb04JMthzD8oRfGSPPqw4kk39+wLD7efIjfXD2wzeSOi9ZjxIgR5OTk4PMFZwecTieTJ0/mk08+QVVVevbsiaqq5ObmkpeXh80W+m/Usqyo1lgJcVysTQsW6QkjmjRrywouH6jNqKx2fPSj8fxp0Q6+2VUEEJKWHWvXMC2Lq4d24bEf9KVdvAMuvxxHr17k/+jHzF9XwGyzW0TN3f2Ghd8wmPnqOubcPiKqgOpgqZsbZ2dTWOGpscLh8bEv3VXEmrwSnr+2L5nP/Ao9rSfJU6+AL38VDHrqojshvT/c/l7Ns1K7v4C3bq+775SvCl2FpMB7sH4MDL+j7s9vIySYaqPmLM3BHWZWKtIFnVU+gzlLc+sMpo5rX3yQu3d/wZOzfkpRlZ8Kj58Ym0a7eEedJzkhmtPAzklktotj5+GKsIl6deXVhxNNvr+qKny0sUDWTokm16NHDxwOx4lgCmDs2LHEx8fzzTff8O6772K32+ncuTMTJkwgJyen2vsVRWHAgAHStFfUi66ptI9zUBTmwW80adY+w6RLBC0nstLieXHaCIoqvLyxZj+fbT1MmduPpiikxNm5YXhXrjm3C3Gn3bM4+/al6C8v8vxr3+FVolsj7vGb3P/aer54ZFJEYzxW5eO6WSspKveGnfk6nUVw1mrmGxt5bsBoLvzZAyiqCh36w5dPwJGtwf5Tp/egsscHS6mPvBcmPVpzIJW/PrJA6hRKwAOLfgpxadDvsojf15rJXW4b5A0YLNlZFHJjGM0NHsDqvUep9AYiCobKPv6YxEsvxW7T6ZKsA9JLR7Qez958Ltf8a0XUabGKQtiUkWjy/d0+g7ziqqg+V4iGoKoq48aNY/Hixfj9JxelDxkyhCFDhoTsn5FRvUS0ruuMHTs2ZD8hInXHuO48t3hPyLrrSNOsFQXO75sWVXXgtAQHP76gNz++oHfE73k2uwCvEv5eqM50bsPk5RV7+eXldafOPvbOJo5WRhZIncqr2vhJoDdrTQuHCvSYAPd8BUW7YNUs2J8N3vJg0JTQOdhHqt8VwcIVNbEsePuukEAq8+8VuPyw96F44uzBbKO53/p4bZOfJXd+X6wj4IZ374FHc8DW9h+2/H/27js+qip9/Pjn3jt90iHFhFCSUENRegfFtRcWFVR0VSyIBd1d26qr636tq/7WsvaGggXXvip2qdKLdAlggBBCKinT5977+2MgMMwkmYE0kvN+vVxf3tyZObOvuTP3Oec5zyOCqTao0unDoEghy+fR3OBBYHN+hcPbYDCl6zpVX31NxtNPH/OYBaEl9UiNZdY1Q7n6rRW4aitf1s9ilEmOMbOnInTGLpp8f4Aqd92d6wWhKQ0dOpRt27axe/du/P7IP4dGo5Fx48aRlpbWhKMT2rrLhnTm+R+3h/1bJGnWVqPCDWObtiz/nnIna3ZXhP1bJNk+PlXn/RW7ufPMnpgNdf8mFFe5WbCtpM7UvoaCNlXV+WZjEReenHH4Qck94Px/H9sb370UHKVh/6Tq8OxyL/eOaaBgxebPYUDDafEnOhFMtUEev4YUphBEtDd4EqE5xeG4N25EkiQsfdvXhkOhbRnaLYnPbh7FXz/8lbz91fj8KupR19GhYioJNiOPTOzHrF/ywwZT0eT7AyTZRDU0oWXIssyll17K+++/T0FBQdAKVV2MRiMjRoxg1KhRzTBCoS3rEGPmnH4nMW/DvrBVgetLs1ZkicxEGwM7J1Lp8vF7qaN2i8FJCdaI0uoi8e7y3WH7EUaV7aPDt5v2c8GAdOoye9muMHduAZEEbQ6vykvzdwQHU8djyXPgC9/24M6RJv61xMNNQ0wkWOoYtbcGlvxbBFPCiSnOasSvhX4pRXuD59d04qwNf0SqvvyKuHPPFcUlhBNej9RY/nfraPL2V/Pcv97lZ0snnFqgF5XFoDA8K4np47IZ1i0JSZJYvauCpTvL8B51ExBNvr/dpNAjTVS4FFqOyWTiiiuuYMmSJSxduhRVVYP2UUFgf5TBYCA+Pp4JEyaIohNCo3lsUj+2FlWxo9hR29y8ITIa8VYL95zdk5nvr+W7zfsxKTKHohGvX6NnWiw3jsvmD31SMSqRl/o/2oa9lWFXi6LJ9nF4VfKKqmFA3ee8v2JP2AnsaIK2/DIHu8ocdfbGipimwfbvqavlx+B0hfFdDTz1i4eHT6snja98J1Ttg7iTjm88rZwIptqgOIuBjjFm9lUGV3KJ5gYPIMZioKO9/iVcXVWp+vprOr8963iGLAitSk6ynVsXz+L/ffE5hpQUdJ2wbQKmDu/Ma4t2hhyPpqy6Dpzdt23/0Aitn6IojB07llGjRpGXl8fy5cvZsLOQ1BgjFouZk046iREjRpCR0Uiz3oJwkMWoMHf6CK56cwVb91U32A/KbJCJdVaTYoeb3l2Lx6+i6aGZNOsLKrnzo18xKTJvTxtab+/M+lS7w6/WRpvt8/mve9lT4cRsUDAZZMwG+eC/FYyKRJkjfAXmaII2oyKzr9J9/MGUtzqwIa0e/zzVzKg3Hdw2rJ7MCsUEzlIRTAknHkmSmD42iye++S3oSymaGzyLUea60d3q7TMF4Fy5CqVjR8xZTZuzLAjNyb1lC0pSEsbUVKDu35ST4q0M7ZbEorzQvPJI8v2NisSUIZlYjJH9GAtCU1MUhV69etElqzuP/PM7frvjLJF1IDS5OIuRuTeM4JM1Bby0YAcl1Z6Q/at2k4LFqHDp0Ew+XZlPXqUHn1J/4OXwqDhQmfLKMt68eggjsjtEPTa7KfytcrTZPr3S4hjbIxmvX8Pj1w7+W8Xr13B46u5/FW3Q5o6iOXGdVD9Q/2pe3xSF83oYeHyxl97JdZ0rhVYSbINEMNVGTRrUice/2RpyPNK+OboOU4Z0rv3vCoeXuav28NPWYqpcPgyyREqchT/k/cJp57SP0pdC++Fctgz7iOERnXv76d1ZmV8etvdIQ2XVDbLMtFHdjnmcgtBUSms8JMeYRSAlNBuTQebSoZ2ZMiSTtXsO8OWv+9hf5UbVdDrGmDitdwojszsGSoe7NHxK5LewLp/KtW+v5H+3jiY7Obq06u6pMaz4vYyjM/2iyfaxmRQm9E5h0sBOdZ4z65ddYdMcow3aYhujZ6ElHlRvg6c9NN7CwFdq+OuIOrKYND9Yjm1F8EQigqk2Ks5i5PoxWby+6PeQJfOGbvAsup/LB3cjyW5iZ0kNT3+3je+37EeWCL5hLKxiqb8LjxfYuOr735gxLgerScywCyc+x7LlJFxycUTnDuqSxL1n9+axeVsbTE85ksUo8+IVA8lMsh3rMAWhyRRXe+gY20ClLkFoApIkMbBzIgM7J4b87av1+/i91BF2D1ND1e5cPpUnv9nKy1cOjmo8Vwzvwoer9qD6jq18O4Cm65zbv+7iExAI2jYVVoUcj6rnll8jJ8pgMSzFACf1h32/1ntaTpLMlFwjz63w0i8lzOqUyQ4Jbb+Hogim2rC//KEHO0sd/LSlOOKbPKtRZqDnAFfMfZzF3R5l+seBG8Q6mpTjMphxeQINfr/dtJ/3rhsW6BouCCco3evFtWYNGU/+K+LH/GlkVwyKzD+/3IRP1VHrumAIpPYZFJmXpg5kfM+UxhiyIBw3TdNZmFfCO0t3savMQYXTh8evcddHvzJtdDd6pcW19BAFgZcXbA/bDzCSane6Dj/9VkJZjSeq+5QeqbFkJ4cPdCLJ9lEkuGBAeoNtZqaPy+Zvn6zH4Ql+f5EGbbIEZ/RJjarnVr1G3Q5f3BqoylePB8aZmb0+zL4ygxWG3wzysRf/OFFIej2NwQYPHqyvWrWqGYcjNDZN03noy03MXbkHv6qH9J46RJHAaJCZeHIGD1+Yy+InX2J6eSc89TV0O4pRlsjsYON/t4wO6SAutB6SJK3WdT26qblWqDG+nyqdPj5ctYf/rS+k0ukDCeJ1L+O2LuKGFx+MOl1ia1EVry3cyZfr9iL5fbjlw4+3mxR0YMqQTKaN6iZWpIRWQdN03lzyO68s2InT68dx1I2qIksYFYluHe3cdWYvTu3VtBMAbeH7Sdw7NY28/dWc//zikBLqmsdBwQtX0eGc2xtcubEYZG45LYdbTou8YS/At5uKuP2DdVFlH9S+plHmi1tG0yM1tt7zvH6NQQ9/T3UdfQdrNv1M9arP8ZXtCQraLJ0ClTWtRoW504cfc6GNEH4vPJkdaPZ7LAxm+PMWsEe/T601qu+7SdzxtnGyLPHQBX25cnhX3lz8O5+u3YtBltB0HU3X8fg1TAaZC0/O4NrR3eiRGovHr3K7twcepeF+I0fyaTp7K1zc/9lG/j3l5CZ6R4Jw/AoqnDzxzVa+27Qf6ej0VWBbylBee+QHzuufzl1n9iQlLrIO7r3S4nh68sncvOtnvvcnUDZgOFUuH4k2Ez3TYjmrb5ooNiG0Gm6fyo1zVrN8Z3mdN4mqFlhp3bKvmhnvrubW07pz86k5zTxSQYClO8vCFuqOptqd26/xw5biqIOpM3PTuHRIJh+s3BNVQGU1yvzjgr4NBlIQ2DN2y2k5PPN9XtjXqG+LhkmGfhlxjRdIAbqkUC2dRoz6ObISSSv7IxhtMOLmNhNINUQEU+1ETkoMj07qx/3n9WbF7+UccPpw+VQe/HwTK+49nTjr4Rn0bzYW4auj10NDOckev8bXG/bxj/NzG2+pWRAa0YaCSqa+sYwat7/O9FU3Cvg0Pl1bwE9bivlg+vCIfgwPkZYu5sq/3YNtiOjFI7ROmqZz45zVLN1RFlFzdghMOvznp+1YjDLXjhYVXIXmVen0hfT0g+ir3VW6opsoPuTv5/UBCT5YvhtXBNeM1ajw9/P6MGVIZsSvccOYLH4rqmbehqKIgzajBB0d5TzXr+7ApajSTWmNB7+mk2A10inRiqGe3ltqZSV777gT3QP2KbfCutfAF9qgPvyAbND7fDj1vsjObwNEMNXO2EyGoH0az/ywjWqPPyiYenH+jpBUD4gsJxkCebsfrtrD9WPFj63QuuwoqeGy15ZR44msVKuqQYXTyyUvL+WrmaPplNhwap6/ogLvrl1YB9TTnVEQWtjsZbtYvrM84kDqEJdP5clvf2N0TjI90yKfYBCE46UoErJESFW9aKvdKQ20fKmLLEs8eH4uPed/wTu2HuzEhk/VOHLu2WwIBCgjsjow8/TuYYto1EeSJJ66eADxViPvr9iN16/VOekHgfTx7JQYXu6fSNUdfyH26aewjww09vX4VeZtKOKlBTvIL3VgVGQkKbDabDLIXDOyK5cN60xKbHDmhfu33yi4dSaxp55Kyp13IBkMkNIVvr0XkMBfR1BlMAcaJw67ESY80GCfqrZEBFPtXHZyDDuKa8hIsAKws6SGXWWOkPOi6cDt8njg6PoAACAASURBVGm8vTRfBFNCq6LrOtNmrcThja7nhQ7UuP3cOHs1X84cU+dzu3wqFoOCc9kybIMHI5nqaWQoCC1I13VeXrCjzpnvhjIQfH6N1xft5MlLxISB0Hw62E2YjUpIAYpoqt0BdIyw+IRP1ZAlKSj4cq5dy6BV3zJ53m3srFF5d/ludhTX4PCoxFoN9O8Uz+VDu5AWH1lqeDiHgraLBnbitYU7+WZTEQZZwqtq6HqgiJGmQ256HDeOy+a0XikYFBnnc89SMPM20v/1BCuTe3Lr+2vRdb12cvzIiROnV+XF+Tt4cf4OrhzehXvP6Y0sS1R+9RX7H36E1HvvJf788w4Pasi10HcSrH0Xlj4P7mo4FLjqGsgGGDYdBk+D2LRjfu8nKhFMtXPZyTHsKKlhbI9kAPYecGFU5JA9JNHkJAOUVIfv5C0ILWVlfgUl1Z6wjREbunlUdZ0dJTVsLqyiT3qgqlm128ena/by2qKdFBxwIUuBvYjxuo+Lu0/g+io3qRHutRKE5rRsZ3mdqU6RZCCoOvxvfSEPXpDbYIUyQWgsE3qn8sDnm0KOR1Oi3G5SmDw4fK8nXdf5ZUcZryzYwbKd5fg0DXSwGBXOzE3lutFdsT/+OMm3345ss5FjgwfPzw37XI2hb0Y8z152CpVOHz/9tp+yGi+qphNvNTK0WxJZR5VAtw0ZQqcXX2DWQy/xXJ8LcTew6HwouHp3+W72lDv4R+HPOH74ns5vvYmlV6/QB1gTYeQtMPwmKN0GzrJAIGVNgOTegXLq7VT7fecCAFnJdnaUHC576QqT3gfR5ySHazwnCC3plYU7wn6+I01f9ao6byzeyRMX9eexeVt5d/kuZEmqnSVVD0ZpBzDyToWJd/71M+N6JPPUJYGUDUFoLWYv2xX2WogmA0GWJL7ZWMTFg+puQioIjaljjJlxPZL5fsv+kEmxSEqUQyBgOqffSSHP/dPW/dz7yUaq3L6QlS+XT+V/v+7jm/V7SUs/n/8MGUdztqGNtxn54ymRXWfrYjN5Lnci7jB9uOri8qks2FjIk24LT370X5SEBt6dLENKmGCrHRPBVDuXnRzDd5v21/53jCX8RyLanGSLQVQsE1oPh8fPom2lIZWgorl5VDWdL34tZN8BN2v3VISs3h7Jq+qAzoLfSjj3uUV8MmNkxBUBBaGp7S53HH9VNJ9KUWWEG9IFoZFMH5fForzSqKvdARjROWvnL1Q+sxHjjBkoMYGVndlL83nk6y31fqeruo6qS+SbE5n86nJevXIwo7t3PO7305h0Xedvn26oM5CqLwPDLRv4Kq47f9ZMRF4uQzik7XfSEuqVnRLDztLDK1M5KTFhq+UcmZMc0fMm2xttjIJwvMpqvBiU0M2w0aavqprO6t0VuOr50T2SV9XYV+ni0teW4Yiw6IUgNDVPHZ/faDIQNJ2Q5qKC0NQGdUniwpPTsUbZYkKRJdKSbPz9iZtQKw6w8+xzOPDpZ3z5694GA6mjOb0qN8xexca9ldEOv0mt2X2gzi0WVSs+pfzH14gfPplOt8whY8ZbxA48B1fe8tpzNF3n7V/ym2m0bYtYmWrnToqzUOXyU+32EWsxkhJrYURWBxZsKwmauYwqJ9msMH1cdvO/GUGog8unEq6AU7Tpq5pO2Opn9c34qRrsrXDx8oId/PWMnsf7VgThuMU2QgaCQZZIEO0vhBbwyB/7UeX28fPWkojKhxsVieQYMx9OH0FivJXERx/BtX49+Y8+zh2dCWqufkhD+2idXpWZ76/lx7+OQ2olVeteXbgz7P8fkWZg+FSd91fs5o4ze4p+iFESK1PtnCxLZCXb2VlyuILfDeOysJpCL6S4oZNIPO1aKpfOpeD5qRS8dDXVa77E2j14Vl9C4szc9lfNRWi9Yi0Gwm3jO/Lm8VhFMuPn8Wu8s3QXfrGXUGgFhnRNwhhmpTaaDASzQSY3Pb4phicI9VJkiRcuH8iM8dnYTQr2MPcrEPiMmg0y43umMO+2sZwUb639m7V/f9b/5TFkQ+jEQiTf6QD7Kt38WtB6VqcWby8NW2ApmgwMSZLYvK+qCUbXtomVqXaswl3BJ3mfUBL/PtfNr0GSNWwGG6eknEJa8sns3teRoyfhG8pJthoVZozPxmQQcbrQeqTEmgNpfkcVMIu2pO7Rotlz5dc0fthSzFl9xUSD0LKuHNGFWb/kw1E7p6LLQDAwMrvuJqGC0JQkSWLmhO7cMDaLrzfs4+UFO9hZ4kDVdSQg0WZi6vDOXDGsS9j9qrqu89LCnTi14EmFaL7TPX6V1xbu5IWpA5vkPUZD13VcdbT9iCYDQ5KOvalxeyaCqXaowl3BI8se4ec9PyNLMm7coAEauPwuftr9E6bExVitMbj3n4+vOrLUJKtR4bTeKdw0XqT4Ca2LQZG5YlgX3lj8e1ClyWhuHsOJZsbP4VH5eE2BCKaEFtcp0cbAzoks3VkW8rdIqqJZjTLXj8lCPsbmp4LQWCxGhUkDOzFpYKDanU/VMMhSg6l3BRUuig64Q45H852u6fDd5qJjG3gTkJA4eoIEoi8gJreStMUTiQim2pmC6gKu+uYqyl3l+PXwsxgaGm7VDUY3low5mErPxlE6Muy5ABKBL7SJp2Tw8MS+rSZ/WBCOdOWILryx5PeQ45GW1A0n2j1Xov+a0FrccWYPpr6+POzG+4YyECxGhcmDRc0vofUxKpFlxZQ7vBgNMu6j0m+i3kerBSpbtvQeI0mSsJsVqtyh93XRZGBouk6STTScj5YIptqRCncFV31zFaXOUjQi3Lsh+TClfEucOY4DxQPwaxq+g2U3LUYZTYfROR2ZPjaLod2SRCAltFrpCVbOyk3ju81FITeQDd08KlKgUWnI8Shn/Pya2DMltA6DuiTx0AV9efCLjVFVMrOZFB6b1I+dpTXEWoykJ1iwmcSthHBi8Wvhy4dH+50uSXU/V3M7q28aH68pCNkfHE0GhkmRaxvTC5ET34DtyMPLHqbcVR55IHWQX/fgS/wvr5wzma17NR6ft5WbxmeTGm/lzNxUUmJF/xzhxPDkJf3Z9bKDrUXVYavyhWMxyqTEmik84A750Yx2z5WY8RNakylDMrGZZO78aD2qx4tPrvuWwKSr6AYDXr/GHf9djySBpumoms55A9K5bkw3eqWJmzDhxJBgM6KGCYKi/U7XdL3OAhjN7boxWXzxayFqmEm7SDIwLEaZaaO7oYj03aiJYKqdqHBXML9gftjUvopFFZR+W4q32ItiUYgbFEfqxako9uAviK2O77lu9PU8Nm8rfz2jp1iFEk44ZoPCBzeMYPqc1azKLw/pdH8kCbCaFMb1SOa6MVlc8fpy/EdV/Ytmxs9mVDgjN7Wp3pogHJPzB2TQPX8jb321ma9OGohOYN+JX9UwKTKSJOFTNTS/jv/g8mzNUT3TPl27ly/XF9K/UwKvXTmYeFEyXWjluiTZMCkyTo79Ox2gf6eEVnMv1CM1lpzkGDYWhq/G11AGhqbDZUM7N9Xw2jQRTLUTn+R9cnBzYrDSeaWUzCuh03WdiOkTg6/CR+HsQvKfyqfbfd2QD1bl86ge5myew9SeV2M2yK3my0MQomU1Kcy6egiLt5fyyoIdrNxVgXLwhhECPUn8Hi8js5KY/ofejMjqgCRJpCdY2HFEC4FDIt1zpaHzx1M6Nct7FIRI6ZqG6fUXuP+vf+H/Ro/l59+KKTzgwulVcftU3lj8O6qmo9ezaqUeXKFau7uCc55bxOe3jKJjjLkZ34UgRMegyFw9sisvLdgRkqUQ6Xe63RyoXtyaPHnJAC566Zd6JwrDsRoV/npGD3HdHiMRTLUTH237CI8avPlddakUf1ZMxrUZxPaPBcCUbCLzpky23bmNyl8qSRybWHu+V/Oyev9azIbWsaQtCMdKliXG9khmbI9k9h5wsWhbCQdcvtqSuj0/fp1Me3eSssfUPubGcdk8+MWmsD9SDe65kmHiyRnYzeIrV2hdqr/9FslqJWZcoPnooR6BpTUezvz3QlxeNUx9sPB8qs7+KjdTX1vO57eMavFN+YJQn8uHd+alBTvC/q2h73QIFLuY0CulKYZ2zHqfFMcbVw3h2rdXRhxQWY0KV43swnVjspp4dG2X+GVvJw54DoQcc+Y50XwacYOC89wVi0Js/1hqNtUEBVMA+2vKsBjFvg+h7chIsHLpUakN1VVjKX9nNklXXlF77PwB6by2aCc7SxxRbziOMRuYOaF7o4xXEBqLrqqU/OcFUu+5OyTb4MWft1Pl8oUNpByb51O18jN8ZQXIJivGlCziR07G0ikXv6azu9zJF+sKmTxEVPwTWq+UWAtXjezK7KW7cPmiW8mxGGX+cX4uhgirBzanEdkd+OjGkdzy3hp2lzvr/L2ymxRkWeLes3tx2bAuzTzKtkUEU+2EXwvdK6XWqBhiDEhKaMqeId6Aa5cr6Jiu6zj9bswGa8j5gtCW2EeNovDue/BXVGBIDEwoWIwK710/nAueX0xJjae2qmV9JMBmVphz7XDSE8R1IzQ/ze3HsbYY764qNKcfySRj7GDFNjgV57KfUGJjsY8O3mjv9qnMXbkHX5ibsKoVn1K5/CM6nHEzlm4DkRQDrt9X48pbjqVTLgAun8qL87dzyeBOIiVcaNXuOasXew+4+GlLccQBldUoc/OpOUw8JaOJR3fs+qTHMe+2MQx99EcGZMbzy/ZATzlZAp+m0+ekOGaMz+aMPmmYDK0vIDzRiGCqnbAarIHeUUdQYhT8NX50VQ8JqPyVfgwxwR8PWZIxSXYsRnHhCW2bbLFgHzWKmp9+IuGii2qPd4wx8/VtY5g2ayVbCqvweP1ocvjrwW5SiLcaeefaoeSkxDbX0AUBAF+Jk+r5e3CtLwUJdO/hfSFuGaqX7EWrrCZ+8s0hAc/XG/aFfU7N4+DA4nfpcM7t2Hoe7j1oyxmGLWdY0Ln7qzz8WlDJyZkJjfiuBKFxybLE85eewmPztvL20nwAvHVUepWlQEPbB8/PDclmaI2+27yfPifF8c60YYHJcK+KT9WItRhFxb5GJu6K24mBqQNDClDYcmxIBomq1cGVX1S3SvX6aux97EHHvaqXTrbuIg9eaBdi//AHqr/7PuR4gs3EJzeN4tWUIibIZZgNMjFmA7EWA7FmA2aDzIisDrwwdSCL7z5NBFJCs3P9Vk7xc2txri1G92lBgRQAGuDXke2dcK6VqPg0D/2IVah5G4twhNlv4dm7Fd3vxdZjRINj8PhVFvxWfLxvRRCanCxL3HdubxbeeSrXj+lGnMWA3awEvtMtBmwmhbQ4C9NGdSPRbuTCk1vvitSRZi/bxRXDA+l7gaa+BhJsJhFINQGxMtVOXJ17Nb8U/oLLfzh1T7EppExMoXBOIbJFDqrmZ0wykjDy8IyihMSojFGY5QTMhqKWeAuC0Kxixo2l6MEHUWtqUGJigv6m+/2kfTqb5/7zPL7sHuwuc1Lj8Qd+dOMtovea0GLceRWUz9mCHmEjXt2n4VhTjK7qJF7UHUmSKHd4w56ruqqQbXERNTPVdCip8TR4niC0FmnxFu48sxe3n96DvP01HHB5USSJJLuJnJQYJEliX6Wbt375nZvG57T0cOv1W1E1+aUO0Y6jmYhgqp0YkDyADpYOFNQUBB1PPicZxa5QNLcIb7EX2SoTNzCOzOmZyEek81kMFq7KvQpHpSpWpoR2wW+1sXjE+cz89wJ2exU8fhWjIpMaZ+Hyjm5Gn5SJNTcXK9A3I76lhysIqNVeymZvjjiQquXTcP1agrlbHPZBadQ1ca1Y49CcVeiaGlFAJWbAhRORUZHpkx6+AfVfzujBJS8vZerQLq26n9q7y3dx6dDOGFthgYy2SPy/3E5IksQ9Q+/BooTOmCeNS6L7I93JfS2X3s/1JuPqjKCGvSbZRO+k3gxMGYjHr2EWmxWFNkzTdJ77MY+B//cDT8cN4jdHYEO9poPHr7G73MkzW71c1PUS7vl4Pe4oq0AJQlNxLN8XlK53pBUF65k4ewZ9/n02fZ89lz/OuYl1+7bU/l33aVT9uAdd10mODd9rxpzRC8lgxLltaYNjUWRJrNAKbU52cgxn5qbWWVK9Najx+Pl8XSGXDRXVNJuLuCtuR8ZljuPWU24NG1DVxSSbSI9J54UJLyBJEm6filmsTAltlF/VuHHOal6avwOHx49TCz+z7pYNeDSJT9fuZeILS6h0+pp5pIIQTFd1apYUgj80mKr2OLjmo3u4etBFbLjtS1be9Am3j7oasxLc5kKr8eLdXX2wJ1ro97xstpMweirl37+Mc9tSNJ8bXfXj2rGKip/fDDrXKEucKVKMhDZo5oTufLByN0WV7oZPbgGfrd3L8KwkTooXFWSbi0jza2f+lPsn4sxxPLzsYYCQRr6HSEhYDBZ6J/XmhQkvEGMK7BkRK1NCW6XrOnf891cW5pXgjjBNyuPX2FFSw5VvLue/N44QDa2FFuPOq6hzVWpn+R4AJvY5HQCrrDCu29CQ83SfRs2SvZw2pefB9KDQVde4oZOQ7YlULp1L6ZdPIZmsmFNziBsxJei8HqmxoviK0CadFG9lyuBMnv0xj8cm9QN3JexaCq4KkGSwJUGXkWCyN/xkx6iwppD3trzHVzu/ospXhaZr2Aw2BqcOZv2mU/jnmec22WsLoUQw1Q5NzJnI6IzRfLTtI+ZsmVPbg0rXdWRJxuF10zN+MPeMvJGBKQODyua6fce3Z8qjeliwZwFFjiLcqpsYYwy9O/Tm5OSTRT8SoUUtzCvlu837wwZS9TUp9ak62/ZXM2tJPtPHZbfAyAUB/KUudDX8JEBWUiayJPPnrx7hgl4TOCUjlwRLmEBHB99+JwZF5pqRXXlx/g48YcpEx+SeSkzuqXWOxWZSxLUgtGkzxmcz/clZVHueIXbHl6AYQT94rUgyaH7ofymMuBk6Nl7D9r01e3lgyQOsK1mHruv4tMNZEVXeKn7e8zN6zGIe3TCX++z3MTpjdD3PJjQWEUy1Ux2tHblxwI1c1+861hWvo8xdhlf1EmeKY9ueeH7N1xiUOhAIBFneXVV4d1fTaVs5sV4Vx8oirLkdkCPcgLmneg/vbnmXT/M+RULCq3lRNRWjYkSWZJIsSVyTew3nZ5+PzWhryrcuCGG9PH8HzjDloCNpUur2aby+6HeuH5OFLDbdCy1A96pQRyPpWLOdT6b+hxeXv8dd3zxJiaOcU7OH8a+z7iLZnhT6PMB1Y7L4bF0hu8udqHWseIVjNsgM7JzIWX3Tjv3NCEJrpvpI+GYmc6RPkX/zAhr4w6T8rZ0Nv74Pg6+BMx6FOnoSRmpL2Rau/e5aHD4Hmh5+4kRHB9lLQU0Bf/75z9w55E4m95x8XK8rNEwEU+2cQTYwOG1w0LE+CR6e+24+bocX/4ZSqhcUoDl86KpOjqqjAQeKdlDx+XasfToSOzYDU6e60zk+yfuER5c/iqZrQbMocDjNcG/NXp5e/TQv/voib535FlkJWY3+XgWhLgUVTtbsrgg5Hk2TUqfXz+LtpYztkdzk4xWEo0lmBRQp7J4pgO4du/Lvc+8FYHvZLmZ++TD/+PF5XrjgwdDnAexmA3NvGM4fX1xCSbUHbx2B2pEsRpleabG8+qdBopKf0DZpKrw3GXYvxag1sGdK8wf+Wf021JTARa/DMWbgFFQXcO2311Ltq474MW7VzZMrnyTeHM+ZXc88ptcVIiM2vwghkmPNjEyKYd/Tq6n86nfUCk+g6ePBH1MZAv/t13FtKKHklfVUfpeProf+2L635T0eW/4YHtUTEkgdzeV3UeGu4PKvL2fngZ1N8dYEIaxvNhYR7lYxmialDq/Kf1fvafzBCUIEjB2tSBGWQc7p0IXJfc/it5KjvmclMKYe3ueREmfh65ljGZXTEZOuYpTCB1QWo4zZIHPhyRl8OH0kNpOYpxXaqHl3we5l4HM1fO4hPif89jUseOKYX/a+xffh8DtCjlcsqiDv/jw23bCJrTO3Uvh2IarjcIaFW3Vz/+L7cfhCHys0HhFMCSH85W7uLtFRnP6G+5XoBzctL9rLgf8F/zAv37ecf6/+N2418oo3OjpOn5Nrvr1GXPxCsymqdOMNszckmialh55HEFqCuXsikiH8rPf2sl28suID9lUVA1BYtZ/Pt/zIwPTcoPMkg0zs6IygY/E2Iy8MMPDWhllMG9WNGLMBWQqk88kSJMeYuX1Cd5b9bQJPXNQfkyhQJLRV1fthzexAcHSErs9Uk/JkNQ7v4cmG19d4GT/riHsYnxMWPwOemqhfdnfVbjaVbQpJ7SudV0rRf4tIm5xGnxf7kPX3LLxlXvKfykc74vdMQuKL7V9E/bpC5MT0kRBEVzVKXl2P0a8TzWK07tNwrizClBGDfVCgHO6za54NG0hVLKqg9NtSvMVeFItC3KA4Ui9Ore1tpaPj8rv4cseXTOk1JeTxgtDYvHVs3I+2SakvglQoQWgKkiwRMyqDqp93gy/4c2g32VhXuIXXVn5IlaeGOHMMp2eP4L5Tbwo6z2838NOBGlyllcSYjXRPiaFrRzulr7xCnz9NYdR5ufztvFzcPhWnV8VuVkQFS6H9WPVmnWl6qg7PLvdy75jwPdqAQGGK9R/CkGlRvey7W95F1YL386ouleLPism4NoPY/oFtFqZkE5k3ZbLtzm1U/lJJ4thEAFyqi7c2vcWlvS4Vhb6aiAimhCCuTWVoTh/hsjlWFKzn0Z9fYltpPrIs071DFx6ccCsnn9QbONj08bt8bANTyK/KZ1vFtpDnKJ1XSsm8Ejpd14mYPjH4KnwUzi4k/6l8ut3XDfngrKbLH7j4J/ecLC5+ocl1sJuRICTV78gmpfZeDVdFSoywIIsgNAX70DSqFxSgH1XS/KTYZF6a+FC9j3Wj8/+qq1j40Xo0XUeWJHyqRq8EIxP3q0y98MLacy1G5biqugrCCUdTYcUr4QtNAHeONPGvJR5uGmIiwVLHPYvPAUueiTqY+mrnV/h1f9AxZ54TzacRNygu6LhiUYjtH0vNppraYArggOcA2w9sp3ti41UWFA4T6/FCkOoFBYH9UAf50clDZZGnkis/upuzB/2Rdbf9r+6mjy4/np2V9c6kpF+RTmz/WCSDVDuT4i31UvlLZdD55e5y1havbbo3KwgHDctKwmo6vialVqPC+J6i+ITQcpQYEx2vykUyRvfT7kLnS7x8rXqp8fhxelVqPH48fo1fSz082ftCTnvuF/aUOxt+MkFoi6oK6wykAAanK4zvauCpX8L37jz8PAXgjfw60nU9bNEJtUbFEGNAUkIDN0O8AX9NcPClSArl7vKIX1eIjgimhFq+Yie+4sBFXorG67g5n2puxsF95dtxAx/1Gc4k2cV/jTr9uw2md0pwLxHdq1GzsIB1xeuOaSblSJqmsaV8S+O/UUE4yrBuSSTUsaoUN3QSiaddS+XSuRQ8P5WCl66mes2XWLsHF6XQdZ1JAzs1x3AFoU7mrHg6XNUHySQ3+AuvoeNC5zO8PEvdN4FOXabwgIvznl9MfqnYyyq0Q+5KkOtP5vrnqWaeX+GlxFHPXnPZFHiuCOnoYYt7KTEK/ho/epjUcn+lH0NM6FgbKgImHDuR5ifU8hU5QILXcfMeXoCD/wtaUgZIMnu++n/Ye43lzYxezLLEcD1mLiM4R9i7uwyHpSjk+RuaSXHtCq6O49W81Hij36wpCNGSJIkbxmTxxDdbcYUputJQk1JFgvMHpBNrEWl+Qsuz5CSSevsgqhcW4Fy9H83tQlIOf0/rsoRX09iAyhw8rCK0v9rRNB2q3T6mvLqUH/4yTnzWhfbFYIYwQc2R+qYonNfDwOOLvfROrmMmQ9cCzxUhWZIxKabaNjKH2HJsSAaJqtVVxA+Nrz2uulWq11eTenFqyHPFmeJCjgmNQ6xMCbU0l5/HfA4+wIuXw4EUgGy2kTb1X4BE2TfPs+O5yyn4+J+84ijiPwQvfeseH2Z/6AxItDMpBtmAxWBphHcmCA27ZHAmSXYzyjHs0bOaDMycIHLRhdbDkGQhcWIOHa5Mw7fra2xD07D07YBtYAqLUgxcQQ234wwJpByb57Pv7dvZ/f8upuA/V7L/wwdxF2wCAgFVlcvHf1cVtMRbEoSWY+8IqrfB0x4ab+G1NV72VtUReOkamKMLagYkDwg5ptgUUiamUDinkOr11eh+HW+Jlz0v7sGYZCRhZELQ+X7NT05CTlSvK0ROBFNCrVfz9vOj7qOurGBjx0w6nvtnOt38NunXvoBaU86+H1/jM7x8ckSKiGSLISN9cMjjj5xJOdKhmRR7H3vQcZNiItUeOrsiCE3BbjYwd/pw4m3GqBqOWo0Kb08bQmaSrQlHJwjHxrV6BZYsM0mTutPxij4o52XxUGkF+8J0Vqta8SnlP75G/PDJdLplDhkz3iJ24Dm48pYffj6fxqsLd4ZNPRKENsuaCJ2GNHhaTpLMlFwjz60IE3hJMvQ+D5ToksKu6XsNNkPo70vyOcmkXpRK0dwiNs/YzI7/24ExyUi3u7ohH7Fv0iAZuCD7AmxG8RvVVESanwBApcvHK1sKw2bNOzbPp2rlZ/jKCpBNVowpWcSPnIy97wRq1n2DG3gJD+diwoyEEm/m0p6XsqpoFU7/4Y2WR86kyBY5qJpfuJkUTdcY32l8k75vQThSp0QbX88cw+WvL2N/pRuHt+70J7tJwWxUeGfaUPpmxNd5niC0JMfyFcSMG1f73x+u2hO2urPmcXBg8bt0OOd2bD1H1h635QzDljMs6Nwqt4+lO8oYmdOxycYtCK3OqNth36/QwPaDB8aZmb0+zP4kgwVG3Br1y45MH4nVYA26nzokaVwSSeOS6n28Iitc0eeKqF9XiJwIpgQAPlq1B1mWAs0SjlC14lMql39E/Igp6D4Xtl7j8JXtwrHxZ3yluzCn96w99yd8nGOyEjMqnVEZyVgMlpCLP/mcHYkbmQAAIABJREFUZBS7QtHcIrzFXmSrTNzAODKnZ4bMpEzMmSjS/IRmlxZv4Yc/j2PJjlJeXrCDVfkVmAwyuh5oMeLw+Dkp3so9Z/fizNw00aRUaLV0TcO5YgWpd91Ze2zhthLcYfYFevZuRfd7sfUYEfK3o7m8Kqt3VYhgSmhfciaAyR4STOXfHhv035nxMu77j07lkyC+E2QMjPplZUnm7qF388CSB8L27qyPRbEwPnM83eK7Rf26QuREMCWg6zqvLfo9ZOP9kTOV5oxeVPz4OtVrvkLzOJDNdqzZQ0k8NdAvwQXMwcs5WLH1S0aSZK7OvZoX170YcvFHPJPSW8ykCC1DliXGdE9mTPdkiird7Cytodrtx24ysCq/jD0Vbs4fkN7SwxSEennytiPHxGBMP/xZPeAKX9FLdVUh2+Iiak6tA2WOhvePCEKbIitw6Xvw9nngczV8/pFMdpgyp86mvw05u9vZ7KnawyvrX8WrNVB+/SCLYqF3h948OvrRY3pNIXIimBLYX+Whwhn6w3jkTKUkKyRPvKfe5ylAQx6cUtvj5Mo+V7Jk7xLWFq/Fq0X+w2tRLNw95G46x3WO7o0IQhNIi7eQFn94hTQj0cqlry5F13XRUFpo1ZzLl2MfNjTomLmOlVTFGofmrELX1IgCKqto2iu0R50GBwKqD6aCr+F+URoSfsWG6crPILlng+fXZ0qPa3h9YTGG+E9RZCmkwt8hiqRglI2MyxzHY6Mfw6iIyptNTeSnCFS6fBiV0I9CNDOVAAZ08LwBamDms7CmEL/uj6q3gUWxMHPgTC7ueXHEjxGE5tS1gw2DLLO9WJTtF1o3x4rl2IYNDzp2UkL41GlzRi8kgxHntqUNPq/ZIJMSF3l5Z0FoU7JPg2u/g8zhgX1QcphgRTGBYsabOYZLeYw1+vFV0tN1nXs+Xs9ZnSfy4+Tvmd5/OonmROwGOzHGmNp/zIqZC7Iv4L1z3+OpcU+JQKqZiJUpAUWWwjeFi3KmEoOMwbEXZv+RDaf/jRsW3oHT50QPUzUqnE72Tjww8gFGpDecsy8ILUWSJMZ078iivFK6p8Y2/ABBaAG6quJcuYq0Bx4IOn7pkM78tKU4pLiKbLaTMHoq5d+/jCQrWLqdgiQbcOevw717fW1KNwTS/M7pd1JzvA1BaJ3S+sG130L5Tlj+Mmz5CjxVgTQ+SzzkXgRDr8MS34kbNhZx2wdr+WrmGOKOsT/b7GW72F3u5N9TTsZiVLi+//VM6zuNTWWbKHeX49f8xJniyO2Yi91ob/gJhUYlgimBDnYTXjV0Q/KRM5X2XqMbfB6/BvGXv0H+D3dy/Q834ogyA6rUXYqm19M5XBBaidHdO/LJmr1MGy029Qqtk3vrVgwdOmBMSQk6PjK7AzEWQ9hKlXFDJyHbE6lcOpfSL59CMlkxp+YQN2JK7TnSwedIjRPFgQSBpCw4+1+Bf+pwVt80Fm8v4b5PN/LcpSdHnR6+cW8lz/yQx8czRmI5Ir1WkRX6J/c/5qELjUcEUwIJNiNJNhP7q4Pzb6OZqQQYld0Bk8nInepenGG+KyoWVVD6bSneYi+KRSFuUBypF6ei2ANfDm7VzV8X/JWFUxZiUkxN9n4F4XiNyu7IPR9vwOvXRDU/ocU5vX4+W7uXRXmlHHD6MBlkOhbt4uwhp5J11N4+SZK4fkwWT32zFXeYBuoxuacSk3tqna9lMSrcMDarSd6HILRV95/bhwv+s5iPVhdwyeDM2uOqplPu8FLl9mE1KiTZTUEBU7Xbxy3vreEfF+TSraNYcWqtRDDVzmmazv99tRmDImMzKTiPmq2MZKYSAj13po/LZlvFNvIr80MS+0rnlVIyr4RO13UK6i+V/1Q+3e7rhnzwhlTXdb7b9R3nZZ3XlG9bEI5Lot1EVrKdNbsrGJ7VoaWHI7RTe8qdvDR/B5+u3YskEfT9LekWvlZ6k/rUfG4an8NFgzrVNqO+sGQ9n5YXkZeQiTeK3rtWo8LFgzoxMluURBeEaFiMCs9fNpDLXlvGwC6JxJgNzFm2i7eX5uPxaRhkCU0Hv6ZxRp80rh+bxYBO8dz76UZGZHfgAlE9tlUTwVQ75lc17v54A/llDj6/ZRQTnl4AhKZ+NDRTCRBrNTIyuwMPLHk2pOCE6lIp/qyYjGsziO0f2GNiSjaReVMm2+7cRuUvlSSOTQTA6Xfy5oY3RTAltHqjczqyOK9UBFNCi1iVX87Vb63E7VPxa6ERkS7JuDTIL3Py4Beb+HLDPl6+7GSqn3+W6h9+4J1/P8e180vYuq8at7/h9GqrUeGsvmk8dEFuU7wdQWjzeqbFMnNCDpNeXILLq4EE3oPX3pF5QfM27uOnrcXEWQzYzAbm3TamZQYsREzkp7RTbp/KTe+uoaTGw+xrh9Ixxszzl52CxRj9R8JilHnh8lOQJIlv8r9B1YMDMmeeE82nETcouImdYlGI7R9Lzabgqmi7q3dTWFMY/ZsShGY0untHFm0vbelhCO3Q+oIDXPnGCmo8/rCB1NFcPpXlO8uYev971GzaTNe5H9AhtycfTh/J1OFdsJkUbKbwRYbspkDq0d1n9eT/TR4QaO4uCELU3D6Vz9YWUu3241W12kDqaJoeuGb3V3sornaTX+Zo5pEK0RLBVDtU4/EzbdZKjIrM638ajM0UWKAc2yOZxyf1jyqgCgRSAxnUJQmv6g3b90CtUTHEGJCU0B9hQ7wBf40/6JhRNlLmKovyXQlC8xrUJZEdxTVUOiMv/S8Ix8vlVfnTGytw+UKzCOrj8WtsMSbxyZQ7MSQGMgFMBpm/n9eHNX//Aw9dkEuvtFjirUYsRplEm5GhXZN49tJTWHnf6Vw9qpvoqyYIx0jTdKbPXs2WfVVEMP9Ry+FRuezVZRRVuptucMJxE2l+bYHXCcWbwXUAFAPYkyGlT9hO2xUOL1fPWknvtFge+WO/2hz6QyaekkFKnJm7P15PWY0Xl0/l6KrpsgRmg0J6goWnLhnAKZ0DP8x+zY8iKfj14OBIiVHw1/jRVT0koPJX+jHEhH4Mo+lNJQgtwWxQGNDVwNPL3qJDvAu36ibBnEC/jv0Ynj4cWRJzVULj+9/6wrDVVwEcm+dTtfIzfGUFyCYrxpQs4kdOxtIpkJrnRmHW8j3MPKMnZsPhlSiLUeGSwZlBG+MFQWg8i7aXsjK/HE+Y1aiGrtsqt4+nv/uNJy8Z0NzDFiIkgqkTWel2WP4SrHsPZIVA0VpAU8GaACNnwsmXBXoeAPur3Fz5xnJO7ZnCPWf3qnOWcWR2RxbeeSqrd1Xw6sKdLMwrweMLfAFYjAqn907hujFZDMhMCHqc1WANW9rclmNDMkhUra4ifmh87XHVrVK9vprUi1ODztd0jThT3NFPIwitxq8lvzJr4yw2KQvYuEdC2+MFQEbGYrBgM9q4qs9VTOoxSXyWhUb18vwdIYWCAKpWfErl8o/ocMbNWLoNRFIMuH5fjStvee1NGQC6zjcbi7jw5IxmHLUgtG/Hc92qWmAS5YHz+xB7jH2qhKYlgqkTkeqDz2+BzZ8FAqdwqzg+B/z4EPzwD7jwP+xKP5sr31jBpUMzuWl8w524JUlicNckBndNCrykpiNBvfnykiTRM6knW8q3BB1XbAopE1MonFOIbJGDqvkZk4wkjEwIeZ4ucV0aHKMgNDdd13lu7XPM2TwHj+oJaUitoeH0O3H6nbyw7gXe2vQWb575JtkJ2S00YqEt2VxYxb4w6T6ax8GBxe/S4ZzbsfUcWXvcljMMW86woHMdXpXXF/0ugilBaCZ7yp2s2V0Rcjya61aWJD5eXcDVo0Rvw9ZI5KGcaFQfzP4jbPkc/O7wgdQhPif4XWif3cx/X7ifG8ZmRRRIhaPIUkQbj6f1m4bNYAs5nnxOMqkXpVI0t4jNMzaz4/92YEwy0u2ubshH7NEyykYu6XEJRkXMvgitzxMrnmDO5jm4VXdIIHU0t+qmwl3B1K+n8nvl7800QqEt213uCEnNBvDs3Yru92LrMSKi59lT4WzsoQmCUIeFeSXhdl1Edd06vSpf/LqvCUYnNAaxMnWi+fxm2LsKfK6IHyKrbm5X3sOQeAbQtCs+EzIn8JD0UNi/JY1LImlcUr2PlySJy3pd1hRDE4Tj8sWOL/g472PcauQbgXV0nD4n076dxrxJ87AYLE04QqGtc3hU9KM3sQKqqwrZFockh6/Id7RDaduCIDS9A05f2Mp90V63B5zexh6a0EjEytSJpGQbbP48JJDq+kw1KU9W4zii++Lra7yMn3W4nKZBdcPXdxBSTaKRGRUjtw+6HYsS/U2jRbFwbrdzSY8RzemE1kXXdZ5f+3zYQKpiUQV59+ex6YZNbJ25lcK3C1Edh3PjDwVU3+36rjmHLLRBMRZD2L2uijUOzVmFrkVW4e9YWmAIgtC4or1uhdZLfKOeSJa/BJo/7J9UHZ5d3sCshasCdi1pgoEFm9JzCpN7To5qFt6iWBiQPIAHRjzQhCMThGOzav8qqjxVIcdL55VS9N8i0ian0efFPmT9PQtvmZf8p/LRjpiJdPqdvLHhjeYcstAGZSfH4NdCZ7jNGb2QDEac25ZG9DxZHWMae2iCINQhwWbEbAi93Y72uk2wmRp7aEIjEcHUicLrgF8/qDOYunOkiad+8XDAXc/Kk9cJS55togEGu2PwHczoPwOTYsIk1/0FoEgKZsXM6V1O56U/vIRBFpmnQuvzzuZ3cPmDV4RVl0rxZ8WkX5FObP9YJIOEKdlE5k2ZeEu9VP5SGXR+YU0hW8u3NuewhTYmJyUmbCAkm+0kjJ5K+fcv49y2FM3nRlf9uHasouLnN4POtZsVrh+b1VxDFoR2b1yP5LC9paK5bq1GhYkni6yd1krcuZ4oijYeLH8e3uB0hfFdDTz1i4eHT6trRUhvlpUpCOx9mtZvGudmncvc3+bywdYP0NCQkNDRkZDwa37OyTqHK3tfSU7isRXGEITm8Fv5byEFJ5x5TjSfRtyg4NLnikUhtn8sNZtqSBybWHtclmS2H9hOr6RezTJmoW26cXw2f/t4PY6jyizHDZ2EbE+kculcSr98CslkxZyaQ9yIKUHnGWWZ03unNOeQBaFd65RoY1CXRH7ZURbyt0ivWx2dPw4UFThbKxFMnSjcB6jtI1WHf55qZtSbDm4bVs9SsM8V2DfVTJ3sU+2pzBw4kxknz2B9yXoOuA/g1/3EmeLon9wfu9HeLOMQhOPh9IdWP1NrVAwxhpBG1ACGeAOuXUetZOkqNd6aJhuj0D6clZvGw19uxhmmoXpM7qnE5J5a52OtRoUZ47MxKCIpRRCa0/Rx2azbcyBsr6mGrltFhvP7p4seU62YCKZOFBFUe+mbonBeDwOPL/bSO7mOH0tJbrZA6khG2cig1EHN/rqC0BjMijnkmBKj4K/xo6t6SEDlr/RjiAn+epUlGavB2qTjFNo+k0Hm/RuGc+F/luDw+Bso0H+Y1agwvmcyN4gUP0FodmNyOjK0WxJLd5ThCVPZrz7xFhN3nNmziUYmNAYxPXWisCeD3nDFl4fGW3htjZe9VXX8xJrjwh8XBKFOGTGh6RW2HBuSQaJqdXBhCtWtUr2+Gnuf4FVXCUlUqhQaRXZyDB/PGEmi3USYhdEQNpPC2X3TeP6yU8JWAxQEoWnJssTLVwyiT3pcxNU0ZQnirAbev2E4qXGirUZrJoKpE0VqPzA1XIEpJ0lmSq6R51aEqewnG6H/JU0wOEFo267ofUVIM2rFppAyMYXCOYVUr69G9+t4S7zseXEPxiQjCSMTgs63GqxidVZoND3TYnnjT4MxKjLxViN2U3D2glGRMBtkhnRN5PnLTuHpyQNEep8gtCCLUWHuDSM4r386JoMctsIfBIIoq1GhR2osX906hp5psc08UiFaIs3vRCHLMOIWmP9ogw17HxhnZvZ6X5jnUGDYjU00QEFou07tfCrK0tBU2+RzklHsCkVzi/AWe5GtMnED48icnol8xOyjWTHzp9w/IUviZlZoPK8v/p1bJ3Rn+tgsftpazMr8ckprPFiMBtITLFwwIJ0uHcS+VEFoLUwGmacuGcBdZ/XkvWW7mfVLPpUuH+aDvxeaBmf1TeP6MVn06xTfwqMVIiWCqRPJKVeg//xISBmK/NuDZy0y42Xc9x+VzifJkH4yJIl8eUGIllE2cmWfK3lzw5shjXuTxiWRNC6p3sfLksyknElNOUShnVm35wCrdpXz5CX9MSgyZ+SmcUZuWksPSxCECKTEWrj9Dz249bQc+v3jWz68cSTJseaDPaka3iMvtC4imDqB5FUb+Vi+nr9Ir2LSPdE92BQDF77YNAMThHbg+n7Xs3LfStaXrsejRn79WRQLz4x/hgRLQsMnC0IEdF3n8XlbuG1CD2wm8TMuRKbIUcTa4rVUeapQZIUkSxIj0keIwjgtqMajIssyfTPEKtSJTHwLnyC+3VTE3z7ZwN/Ono7Jkww/Pwb++tP9gMCKlCkG/vQ5dMhu+oEKQhtlkA28cPoLzPxpJuuK14WsUIXQJSwGM0+MfYKRGSObZ5BCu7BgWwnFVR4mD+7U0kMRWjlN11i2bxlvbXyLNfvXYJSN+HU/EhKKrKBqKhfmXMjU3lPpFt+tpYfb7hRWukhPEMUlTnQimGrlNE3nmR+28dHqAt66eggDMhOA2yCxK8y7CzzV4HWEPlA2BvZIpfWDP74iAilBaARWg5WXT3+Zub/N5a2Nb1HprcTtdwc19DUrZnR04vV+DI2/jNM6n9aCIxbaGk3TeXzeVu46q6coKCHUq8Zbw80/3szW8q21vfK82hHFqQ4WCP5428d8tv0zpvWdxowBM0TFx2ZUeMBFeoJYGTzRiWCqFat0+fjz3HXUuP18fstokmOP6HXT50LodT78Ph8WPwsFy8HnDvSQMsdBv4th2AzomNNi4xeEtkiRFS7vfTmX9bqMVftX8dG2jyisKcSjeogzxTEkbQgX97gY1WfnzGcWcsPwGrKSG67EKQiR+PzXvViMCmeK/VFCPZw+J1O/nkpBdUFwABWGX/fjV/3M2jSLSk8lfxv2t2YapVB4wMVJ8SKYOtGJYKqVyttfzQ2zVzO2e0fuP68PxnAzkLIM2acF/oFAGRhJapGmvILQ3kiSxJC0IQxJGxL+BCvMGJ/NQ//bzKxrhojZXuG4efwqT327jacnDxCfJ6Fet/18G3tr9jYYSB3J5XfxSd4nZCdkM7nn5CYcnXBIYaWbDJHmd8ITwVQTc3r9bC6s4oDTh6JIdLCbyE2PR5Hr/iE8vD+qF5cMzoz8xWSR8iEIrcnVI7sxd+UefthSzB/6pLb0cIQT3Jxlu+mZFsvwrA4tPRShFdtUtol1xevCFsqpWFRB6beleIu9KBaFuEFxpF6cimIPVJBzq26eW/sck7pPwiCLW8SmVnjAxdjuyS09DOE4iSuliewoqeHNxb/zyZq9GA4FTlIg391iVLh2dDcuHdqZJLup9jHh90cJgnCiMhlk/nFBLvd+uoEx3TtiMYqSt8KxqXL7eGn+duZcN6ylhyK0cu9seifsilTpvFL+f3v3HR9VlT5+/HPvnZnMTHpCCpAAgdB7V6qIBQuKXdcKiIqr6677dYu66rquu+v6W3fVFRERe8MV7AVQmtKV3gMpJJBC+vS59/7+GAgZ5k4yA6EEzvv1yuulM/dOTkLmzn3OeZ7nlH9VTtadWcT1isNX5aPkrRLyn80n55Ec5EObyPpUH0v3LRX1nifB/mq3qJk6A4iljBbmUzV+++F6Lv3PMj5YU4TLp1Ln8Qe+3H4cXpWDDi/PL9rFuX9bxDurCoDAB+W0N9eyck8ln9w3SgRSgnCGGN01jd5tE3ll6Z5TPRShFZu1dA9ju6XTIzOh+YOFs1att5aFhQvRdC3ocdWlUja/jHa3tCO+XzySScKSZiH73my8FV5qfqxpONbpd/La5tdO9tDPSsXVLtqLYKrVE8FUC/KrGne8tpovN+3H49fwa3rYY91+DY9f46nPt/LEp5uZ9OIPZCXbeGfa8OBGE4IgtHqPXt6T137Yy74q56keitAKldW6eWtlAQ9e1O1UD0U4zeVV52GRLSGPO3c50XwaCYODg3HFqhDfL576LfVBj2+v3H5CxymAqumU1bnJSBT3fK2dSPNrQX+ct4mfCqtw+bTmDz7E5dN4/ccCbhqWzZ+v7HMCRycIwqmSlWxn8ogc/vrFNmbcMhiHx88Xm/azt9xBjdtHks1M14w4LunTVqQCCiH+s2gX1w3OEjPYQrPqvHWGj6v1KqY4E5ISWq9tSjThKgjet9KjetB1XTQ6OYHK6zwk2S3EmMQ1v7UTwVQLKTzo5NP1JXj8oYGUY+tiatfMx3dwH7LFhjm9M4kjrsea1bvhmG+3lPLUJL3JxhSCILRed4/tzHnPfs8dc1azcs9BZEnC6VUbno+1KDwybzPXDclm6sgcOqTaT+FohVNB13RQdSTzkaSRPeX1fLlpP9/99rxTNzCh1bAooatSAEqcgr/ej67qIQGVv8aPKS74dtAkmUQgdYIViz2mzhgimGohr/+4F00PTeurXT2PmlUfkXrRL7HmDEJSTLj2rsO1a1VQMOX2qSzdWc64Huknc9iCIJwkX2/eT2W9j8U7yg2fdxwKrN5dWcCHa4p4/qaBogPgWcB/0EXdD8U4fypDd6sgARKYM2KJPy+Lf20sZNqYziTHGt8kC0Jj6fZ0/Jo/5HF7rh3JJFG7rpbEYYkNj6tulbqNdWRcG3ytSYxJPPolhONUVOnkjR/zmfdzMbVuH35NR5Yk7pizmrvGdObczqkigG2lRDDVAtw+lQ/WFOFTg4MpzeOgevk7pF76a+zdRzQ8bs8djj03uCOTw6syY0meCKYE4Qz0v3X7eGT+Jrxq8ynAPk3Hp6nc/95PPH/jQC4Sm7Oekfw1Hirf3463qB70wIoUAHrgy7ffQcVHu7jfp5LWIUOkXAkRyUnIISM2g4LagqDHFbtC+qR0St4uQbbKQd38zClmkkYcaXplkS1c0+2akz30M1ZRpZOHPtrAz4XVaLoedK+o6jpLdpSzem8liTYzT1zRW2zI3QqJBhQtYHdZveGHnKd4O7rfi73buRG9zs+FVS09NEEQTrHNxTU8Mn8T7ihqKQHcPo0H3l/PnvL65g8WWhVfmZOy//yEt6AW/NqRQOoosk8jFgnXwkKq5+1GN8h+EITGJEliap+p2E2hacJpl6aRcU0GBz44wNbpW8n7Sx7mFDM5v8tBbpRaigQ3dL/hJI76zLW1pJbLnl/G6r2VePxayKQ7BOZPnF6V/TVuHnj/Z15dJjq/tjZiZaoF1Lh8GE0Yqq5aZHsCkhxZcaFf0/GpGmZFxLiCcKZ44btdhrWU0Hw9pVdVeWXpHv5+Tb+TOWThBFJrvZTP3IjmDE3FCkf3aTh/LkOym0iakHMCRyecCSbkTOAfq/9h+FzK2BRSxqaEPdckmRieOZx0u8iSOV5FlU5unLWCWnfk73W3T+PZb3eQbLdwzeCsEzg6oSWJYOo41Dh9zF1XxGs/7KXO4M2i2BLQnLXomhpxQKWINA5BOGNU1HtYvKMcowWFSOopVQ3mry/mT5f3IjZGXK7PBNWf5aG5fCGPr963kae/n8HOinxkWaZrakceH38/A9r2BAIBVf3yEmIHZ2BOE81JhPBsJhv/GvcvHvjuAdyqO+LzZGSSrEn8ZeRfTuDozh4PfbSB+jCBVFMTaW6fxiPzNjG+ZzpJdlEr2RqIT+djUOXw8sRnW/h68wFkibCt0GPa90AymXHuXEFsj1HNvm6sxYQsuvkJwhlj7toijN7R0dRTypLEZxtKuHFYhxM8WuFEUx0+XNsOwlEfGXUeB5M/+gN/vehBJvYYh1f1s3rfBmKO7symadT/UELypNyTN2ihVRrRbgRPj3qah5c/HFFAZZJMpNhSeP3i10m1pZ6EEZ7Ziiqdh2qkQp+LZCJNkmDu2n1MG9P5JI9cOBYimIpSUaWT615ewUGHxzD3tTE5JpakUTdTueBlJFnBmjMQSTbhzl+Pu3AjyeOmNByryDCxf7sTPXxBEE6ijftqcBuk+EVTT+n0qmwpqT0RwxNOMseaA4G7JII/O/ZUFgEwqdcFANhkhbE5w0JfQAPnulISL81BtgSyHdx+NwsLF5Jfk0+tt5YESwKdEjtxQYcLsJqsJ/TnEU5vF3a6kMzYTJ5Z8wzbKrehaip+PXilxGayoekaF3W8iIeGPkSyNfkUjfbM8saP+YYdniOdSHP5NGYt28PUUTlikr0VEMFUFCodXq59+UfK6zyGsw1GEoZdjRybTM2KD6j4/Fkki42YjFwSzg0u7jTLMlNHdWr5QQuCcMrUGqRzQfT1lFVOb0sOSzhFnOtKwSCToXNKNrIk85sv/soVPcYzsH1vkqzxxi8iS3h2V1OR7eTtrW8zf/f8wGv7nQ2H2E12nlzxJFflXsUtPW8hOyH7hPw8wumvZ0pvftnj32w/uIflB+aRV/ol/th4zIqZpJgkrul2DVd0uYJ4S5i/N+GYzPu52HDCPZqJNIfHz9b9tfRpL9rUn+5EMBWFR+dvotLhjTiQOiyu9zjieo8L+7wkQbfMeHLTxcVMEM4k4eqcoq2nTLCaW3powikQrulEfEwsH9/8Ii+tepffff1Pyh2VjOsynGcm/I602KOaBWg63x34nid+/geqpuLTQgP2w4HVhzs+5ONdH/P30X9nfMfxLf7zCCdeQW0BxfXFuP1u4sxxdE7qTBtbm2bPK6118/bKAt5cUYCqBfrt6/poZN8gvEos5/dIZ9o5nRmYnSRa7p8Ate7jn0iTZYmKek9LD004AUQwFaFKh5dF28rCpvY115WrKbEWE/++YUBLD1kQhFMsNz2O73eEXjeiqaeMMcl0Tos9kcMUTpYmWpt3bdOJ5y57GIDdBwv41edP8cSiF/jvFY8HHbcsdh3P7nsTj96UDVKAAAAgAElEQVT8TZZf9+NX/fxh2R94Sn+KiztdfHzjF04Kr+plYcFCZm+eTWFtIWbZjI6OhIRH8zAscxiTe09maOZQw0DozR/z+euX2wAMOonawK/xzZYDLN5ZztCOycy8dQg2S2Sr5EJktDA7YUQ1kabTbDmJcHoQPbgj9P7qQsNCcggUE1YumkXiOdeTdd/btJ8+h/hBl+LatarJ15QkiIsx8fadw+mcFtfygxYE4ZS6cWgHZIObncb1lM6dK9B8bnTVjytvLVXfvxZ0rA5MGtj+JI1YOJGkCDsy5qZ25Po+E9hRHrzfTIFlP89mvBFRINWYW3Xz6PJH2V21O6rzhJNvS8UWxs8dz5MrnmRn1U7cqps6Xx31vnrqfHV4VS8/FP/A/d/dz7WfXUuFqyLo/OcX7eJvX23H49fCbskAoOng8qqs2lvJNTN+xOVVT/SPdlaxWoxvrxtPpEUi0SayEloDEUxF6N3VhYaF5IeLCVMunI69+whkixVJMWHPHR7UYKIxsywRY5IZ1CGZz+4fxYDsJMPjBEFo3Tqk2ukf5v2dMOxqks+fSs2KD9j3ws3sm3EHdT99jq3rkVx6SYJx3dNoExdzsoYsnEDWHslgUEy++2ABM1e/z/7aMgBKakv5ZNsiBrULzmyYm/oNPkJTBauWVbHr0V1suWsL23+1nZI3SlAdwTfHXs3L7M2zW/CnEVra2gNrmfzNZKo91Tj8jrDH6eg4/U721Ozhus+uo9RRCsBn60t4afFuXL7IAyOPXyOvvJ7p76w77vELRwzrlGI4AR/NRJpP0+jVLuHkDFg4LiLNL0JVDuMC8GiKCQFiLQrXDcnmjhGd6NRGpO4Iwpnu/vNzuevNdYY3OM3VU8aYZO4Z2+VEDk84ieJHtsexppSjC29jLXbWl2xj1poPqfXUkxATxwVdzuWRcfc2HOOQXSxL+BlNDp7Uq/iqgvKvysm6M4u4XnH4qnyUvFVC/rP55DySg2wKzJlqusaCggX8cfgfSbCIG7TTTUFtAb9c9EtcflfE5/g1P1XuKqZ8M4W5l3/EX77YijvMVi1NlSJ4/Bqr9lSyubhGNDtoIXeN6cKqvZU4DVb8ImlMpshw5YD2xIn9BVuFs/5fye1TqXH50HVIspuxmo1zWP1huk5EU0wYG6Pw96v7iRbognAWGd01jSkjO/HaD/lRzRhLwPSxXRjYQbQqPlOUmqDQrNPhqMWltvFpzJj05ybPXZS4CumouW7VpVI2v4z2U9sT3y/QwMiSZiH73mx2PrSTmh9rSB5z5O9HQuKzvM+4uefNLfMDCS3mv+v/a7gfVNWyKiq+qcBb5kWxKiQMTiDj2gyU2MA9h6qrlLvKeWH1XByedMPXjmRfI69fY/byvTwn6rdbxDmdU0iymw2DKWh+Is2syEwdlXOihie0sLMyzc+vBoovr/rvD/R+/BvG/vN7znv2e3o/9g0TX1jOl5v241ODZ3fCFWc2LiZsjixJJIj8V0E46/zfxd25Y0RHbGEma45mMysMy0nh262l1DiNu0IJrYdP1Zi1dA+XP7+MXX1TwBz9R+8ueyEeOThDwrnLiebTSBgcvNKkWBXi+8VTv6U+6HG36mZLxZbofwDhhKrx1PBd4Xdoeuiq44G5B8i8PpNeL/Wi85864z3oJf/ZfLRGZQcuv4sPd7+BwxuaAhppKYKq63y5aT81YbZzEKIjSRJPTOyN9Rje61azzPk90umWITo8txZnXTC1cGspQ55ayIMfrOfnompUTcft03D7NFRdZ1NxDQ/N3cDgvyzgi40lDef1TzEh6aHL59EUE3r8Gr1F/qsgnHUkSeL3l/TkpZsH0T87kRiTjOmo2hmzEqilHJaTwuzbh/D+XecwPCeV2+aspi5Mm13h9LeuoIqJLyxnyc5yPr53JLdf3YvUm3ogRXOTZZJx2EJXLdR6FVOcCUkJrc4wJZrw14feXNd6xQbQp5t5u+aFXXVsd0s74vvFI5mkhlVHb4WXmh9rgo736FXI1qKQ146mFMFiktm2X/x9HI+Cgw4+WFPIzCV5FFY6ubh3JjGmyFvPW80yvdslihXCVuasSvN7d1UBT34ePqf4MMehZdnfzt3A/ho315atZ+LnH7J60K24jjq1cTGhJCtYcwYiySbc+etxF25smPmRJBjbTRSSC8LZbFyPdMb1SGd3WT3vrCpgZ2k99W4f8VYzvdslcPPwjnRItTcc/6fLe/KnTzYzec4a3pgyLOy+VcIJoGmwbzXUloDfDdZEyOwHSZFtgFvj9PH3r7ezaFspj1zWkyv6t2toY23rlUqbyb2peHMraDq6N9xnkgYmE9ZuyaR0bAf564OeVeIU/PV+dFUPCaj8NX5McaF/L7FmUat7ulm6b2lIil8kq46NUziRVBT7XjR3h6Djo9ogXA+/0bgQnqrpfL+9jJeX5LGpuAZZkvCpGrIUSNdTtcDKhaL58MnG2UmKHDh2fI8MnrthABbTWbfW0aqdNZ/MC7eWRhRINeb2afzzi83oBUu49cWn+NdHBbiqQ2cHIykmtJkV7hrTuUV+FkEQWrfc9Dgen9j8HnSSJPHkFX34/f82cucba5kzeWjYuk5d1/H4NWJMstiE83g4K2HdG7Dyv+A71AxA10GWQfVC1jAY+Wvocn7gsaPous789cU8/eV2JvTOZMGDYw3bG8d0TqLdo+fg2lxB7eIi/BVupMMz2DqgqtjNy4m7+wHMmXHkbOiMRbbg1Y6k+tlz7Ugmidp1tSQOO9I4QHWr1G2sI+PajKDvaZbNdErsdNy/IqFl1XhrQh5rbtXRVRDcqEKSNCTFGXJsVPsaSYib+CjVuHzc8dpqdpbWNUzEN+ZVA4/JmoaumLGZZHSvF5NJQTIFrgs+TePKAe2ZOipHpPa1UmdFMOVXNf5v7gbDQKq5zXY9KPy76yXc1imHP10Wy28+XG/4Ok0VE1pMMr3bJTCkoygkFwQhOrIs8fdr+vHgh+u56611zLptMDGmwE1RWZ2bd1cW8vaqAg46vA2JQh1S7Nw1pjOTBrbHbjnzL/MFBx288WM+64uqqXP7sVkUctrEctu5HRnUITny4HLrpzDvrkAwE66rWv4yKPkJkjvBbZ9CbJuGp/LK63l03mZqXD5m3Tak2W0vJJOMfUA69gHp+CvdqHVedL+GbDVhSrMhz/4DuMYAo7ky90pe2fhK0PmKXSF9Ujolb5cgW+Wgbn7mFDNJI4K/vyRJXJV7VWS/C+GkMcmh79GoVx11QA8NlqLZIFzVdNLiRfZMpOo9fib99wf2VTmb3VxXk2U0HcxI9HPuZ+ol/ZE7dSbRZqZXuwTRta+VOyv+9RZtL8NnsB11JB1uDvtq0wEmDWxPXnk9//0+L+KuXBaTTPskG6/dYbxTuSAIQnMUWeL/Xdef+9/7mV++8xP/uKYfj8zbxHc7ypGgYXPOwx/n+QedPPXFNv7y+TZuH9GRhy7ugWKwv1FrtyLvIP/v2x1sKq5B0/WgG5pNxTUs2FpKm7gY7hvXheuGZDd9Df75XfjiwfBBVGNeB5TvhJdHwd3LcMek8NL3u3lrZQH3n9+V287tiEmJbobflGLFlGINeszT91q+WfMf3t3yX8qcZSENCgDSLk1DiVU48MEBvGVeZJtMwqAEsu/ORj6qLmtw+mAyYzOjGpdw4qXb0tnK1qDHol11VCQLsh5akx1pKQIENojt1VbUdUfqrjfXUlztajaQaszlU9kQk8FuOZ5f9spo/gShVTgrgqmXF+fh8AQHP4c73KRe+mvs3Uc0PG7PHY49d3jQsQ6vyowleUwa2J77zu9Kos3MU19sQ9V1/E28iewWhZ5tE5gzeSjxVtHFTxCEY2dSZP5z40CmvrGa0c98j1/V8arh05YPt+R948cCtu2vY9ZtQ86oFJ45P+zlH19vD5u6reuB30FhpZPHP93Ksl0V/OuGAZiNgpz8HyIPpA7TfOCowDHrUib6/kb3dsl89cAYMhOtzZ/bDI/q4fmfnuejgrngdeJ0NR0Ip4xNIWVsSpPHWBUr0/pNO+6xCS3vqq5XsfrAapz+I2l60a46yjLo9capw5GWItw9prOY9I3Q5uIafi6sxus/hownk4UZK4q584KeDVkGQut2xgdTLq/KxuLQfORoN9vdU15PlcOLzaJgNStkJlgprm70wavrIEnIUiCVYmSXVO45rwvndk4VFydBEFqET9UornKH3bvEiMunsmrvQR78cD0v3DTwjLgevbuqkGeaCKSO5vKpLNhWyoMfruf5Gw1+B98+GhJIdfp3HU4f7H0gjlhL4PhXf/Ly9kYfi+841MRB80FNIf8edYB+F1xw3D8XBNpk3/nNneyt3YtH9UALrChaFSuT+0xmaObQFhih0NLGZI3BrJjhqOaLka46SkgMbzsMp68zy3eXYzTH0ty+Rjo6Vw/OMn5O16l1+al1+zApEsl2S9jazbPF7OV7DQOpSDOedF1vyHgSWr8zPpiqdnmxKBKuozbdjarDDWBRZFbsqeDhjzfjVbXQm5lDH86aDnazzI7SOjITrGfEjYsgCKeHl5fkBU/iNNLUbKjbp/Hd9jKW7CznvO7GG3ueLG6fyhcb97N0VzmVDi8WRaZ9so2rB2XRPyux2Wvm7rI6nvx8S1TNhALfV2Ph1jI+WreP64Y06shXvgPKthqeo+rwn1VeHh4dvo4kFjf98l8HbolqPEY8qoc7v72TvJo8fFrLdFWzKlZu7XUr0/tPb5HXE1qeSTZxc4+bmb15diCAbiSiVUeTlSl9ptBhWF8u+c8yqhxeIk88C2wQPig7OWQfvDq3j49/KuaVpXsoq3NjVmT0Q+m0I3NTuXtMF87tcvZNGNe5fXy5aT+qHvxbjjbj6eVDGU9C63fGB1MBoW/0qDrcENjQ7sEPjZtYHM3pVXH5VK548Qc+vneE6M4iCMJx86kab64oaKiPaiyS2VCnV+WVpXtOWTB1oMbNzKV5fLCmCAmCOl/JEsxdu4/MRCvTx3bhmsFZYWu8Xl22N2RT9cOaS69x+VRe/G431w7OOnIDuHIGaKH7MQE8NMLCMz94uHeohSRrEzeMBzZBxW5ok9v8L6IJMzfMZG/N3uMOpEySCUVW6Jbcjen9pzM6a/RxvZ5w4k3uM5mFhQvZU70Hv27892jEqli5JOcShmQMQZIkPrz7XG6YuYJql9dwhSrkfLPMHy/pydKd5dz+2mpm3DyYBJuJF77bzUuLdyMhNdSI+9Qj79klOytYk19Fos3MzFsH0y+r6WYrZ5K8cgcWRQ65Fkeb8bS7rL75g4RW4bQPpnRNx72zCue6UtRaL7qmIdvN2HqmYB+YgRzTdCCUaDMb1hVE0+EGiHoWVNfB4fFz4ysrWfTgWJJjLVGdLwiC0NiibaX4DRrpRDMbuq6gin1VTrKS7Ue/zAm1aV8NN89eicurGhZra3og0Nlb4eDxT7fw6YYSZt02BJsl+Pru8PiZv77Y8CYx0vSa8noPPxdVM6jDoe6q+cvCBlND2imc18nEsz96eOr8JmqhZAX2rTmuYMqn+nhv+3shKxMAVcuqqPimAm+ZF8WqkDA4gYxrM1Big38/yTHJdEnqQrfkbtzY40ZyEnOOeTzCyWU1WXn1ole5/avbKa4vDmqB39Q5o9uP5rFzHmuYHMhNj+OrB0bzyPzNLNkZ3KDmMFmCGJNC+yQrT1zRh1Fd23DLOR35y+dbueql5XRNj2fpropm73ucXhWnV+WGmSt55bbBjO6adsw/f2tS5/YZzdFHnfGkajo+VTOu4xRaldM2mNJ9KnXLi6lfXoLu09CPSqvz7q2h5ou92AekET++A6Yk4w86u8VEj8x4tpQE7+odTYebpjQ1E6oT+PB/a2UBvxrf9Zh+D4IgnF3KnGVsqthEnbcOi2yhja0NgzIG8dG6fSGNdCD62dBvt5QyZdTJu8necaCOG15ZEXGdl8unsia/ktteW8W7084JutH4evMBZIOUomgCSpdP5c0fCxjUIRld19HdtTR1K/PkuBhGvubggeFNTIipfnCH1uZGY1HRIsNufRVfVVD+VTlZd2YFNSHIfzafnEdykBs1FXGrbl664CVsJttxjUU4NZKtybx/+fs8tfIpvi34FgkpZDNfALvJjizJTOkzhTv73hmSZpeeYGXWbUMor/Pw3upC5q4totrlQ9V07BaFczunMm1M56DVJEWWeOKK3tz8ah3fbC2Natwun8pdb65j7j3n0qd9YvMntHLh6sWizXjSgfveWcdzNw48K7awOJOdlv96qsNH+axN+CtcYJDSAjTsGO9YV4pzUwVpU/tiyQ5Np6tx+shJjQ0JpiCyDjdNiWQm1OPXmPPDXn45LveMbE0sCMLx03Wd1QdWM2fzHNaUrsEiW1B1FQkJSZIwSSYkz0gk0wB0f3Dr4mhmQz1+jYOO5me8W4rXr3Hzq4EVqWh4/Bqbimv45zc7ePjSng2PF1e7DF8rmoBS1+Hrzfs55+mDVDq9LFZU2jVxae6TrnB5NxN/X+6lZ1qYsEuSwXR82QcfbP8gqJsbgOpSKZtfRvup7YnvF/h8s6RZyL43m50P7aTmxxqSxxzZv1BGZnnxci7seOFxjUU4dexmO0+PfprfD/s983fP593t71LhqsCn+ohRYshJzGFyn8lc0OGCQNOKJqTFx/Cr8V0jnswtqnSyNr/K8LlIUmgfnreJT+9rPtOntctMsBo2n4g24wlg8c4KrnjxB/53zwgS7aLrc2t12gVTmkel/OUN+CvdgerfZk8A3a1SPmsT6b/sjzkj0GWprM7N7OV7+WBNEeO6pRNrUQx3p26uw034cUY+E+pVNb7fXsYFYk8BQRCOUuutZfqC6eyu3t1wM+1VDQIey7fEdvkW94Er8dcc6coW7WxouHqjE+GbLQdweVXDYvjmbs7cPo23Vxbw4IXdGmaC6z1+w9eKNr0mI8HK+3efE+hK9nonKD7Y5PF/Ps/KoJn1/PbcMI0oFBPEHd/1vdQZuhrg3OVE82kkDA4OoBWrQny/eOq31AcFUz7NR5mz7LjGIZweEmMSub337dze+/aT9j3f+DEfTQ99h0WaQrvzQB27y+rITT+z68SzU+zkpse1SMaTx69ReNDBba+tYu49I86o7SvOJqddMFU1bxf+qggDqUZ0r0r5q5tR7+7DzOV7+HR9CVcNbM/n948iK9nOZxuKeeijjVHXPslSIJ//aNHMhDo8Kqv2VopgShCEILXeWm76/Cb2O/Y333RA8iNJYM38FI/swlc1BohuNtQkS6QcS/1mbQmsmQ07vgJ3NcgmsLeBgbdAv+shJs7wtJeX5BlOYkWzYfpnG0q4ZlAWBZVO9le7kCAkoIo2oEyLj6Ft4qFUuCFToHx7YCPeMHJTZG7obeb51V76phvc7OgadBnf7Pdtik8N/fdX61VMcSYkJXTpzJRowlUQ3NlR0zXjQFwQmuH2qby3ujCkpjGaiWOfpjF7eT5/u7rvSRnziVDv8TPvp318v6OcKqcXsyKTmWDluiFZjOzSBvlQhtE9Y7vwh483hqReh8t4srTtyv43fm04eeRVdXaW1jP/52KuH5ptNCzhNHdaBVOqw4drcwX4Q6OX1fs28vT3M9hZkY8sy3RN7cjj4+9nQNsjKSBuh5e/P7+C3BFZLPrteaTFH5lFnNi/Pftr3Pxrwc6IAyqbWSYl1kJxdWjOcrQzoQfrQ4uKBUE4e+m6zr0L740skGpEkn3EpC9A87VBre8V1WyoIkuM6JIa+SBLtwb2YMpfHvj/xs0RqgsCQci3D0O/G2D842A/0sJ5T3k9eeWh3aqiuTlzelUe+2QzT362lQSbmYyEGMyKhPeoG75oAkqzIjEgu1Hnsd5Xw5cPNfeb4LGxMby10eDfSTbDwFvBfHyb9cZaYiE4yw8lTsFf70dX9ZCAyl/jxxQX/BFukk3EW87sVQHhxNhQVG1YjxjNxLGqBVajW2MwVVTp5IXvdvHphhJkSQqp8Vy0rZTYGBPTRnfmthEdubh3Jn+avxloPuMpkskjly/QKl0EU63TaRVMOdYcAIN5xzqPg8kf/YG/XvQgE3uMw6v6Wb1vAzFK8AyrRYM/Z6TSdkIPw9e/a0wX0uOtPDJvU+D7hcnjtx/qIPX4xF58sWm/YTAV7Uzo2b7BnSAIwdaWrmVn1U7DQKq57m2S7MOa8TmO+p6AFHH9p1/TeeyTLfxieAcu79e26aLnvO/g/ZvB5wx/jO/Qas7P78DOb2HKV5DcCYCCSidmRQ6ZvIq2YYYOLPndOFJiLaiazpCnFuB1Bv/OogkoZUni9hGdjpxsscPAW9DXvYHUKFjM/3VwUJKdKON+NDjdLvCCCgy/O6KfpSlDMoZQUFMQ1BbbnmtHMknUrqslcdiRwn7VrVK3sY6Ma0OzHfq2aX03ssKpV+X0tUgKrcMTeVv308W6girueG01Tp8/bDt5h1fF4VX5fwt28MWmEubc0Ie/HVjErxNG4JHCX0ejmTzaX+NmQ1E1/bPPnjbzZ4rTK5j6scSw4cSeyiIAJvUK7DBvkxXG5gwzfA11vxN/tQdTknFu+6SB7ZnQJ5MvN+1nxpI88svqMWt+JGsMPr9O+2Qb08/rwsR+7bBZFNYX1Rim+kU7E9o++fhmLQVBOLPM2TwHtz90oibS7m2SqR7ZVojm6gg0X/9ptyg8MbEXKbExvLe6kL9+sY0r+rfjF8M70LPtUUFC0Rp4/xfgM94gOITmg/oDMPsiuOcHiEvD6VHRDeovor0586t6Q2qiIktMGZXDi9/tDmn3HGlAOSA7ieyU4Nbw2vmPUbnha5LUYkwGM81hme0w/jFI6Rz5OWHc0vMW5u+ej189cjOq2BXSJ6VT8nYJslUO+nswp5hJGhF805Udn033lO7HPRZBOCzaieNo+VQfJtl0yjb+3VJSw62zV0XcbdTt09hSXMt1f/mEVzNimXXjcO55d33Y86OZPPL4Vb7dckAEU63QaRNM6bqOWm+c6905JRtZkvnNF3/lih7jGdi+N0lW41QGySShVrvDBlMQWCW6elAWVw/KIu+1tykvLKHNvfeSZDfTJi74vOuGZDF/fXFIB6loZ0In9hO7XAuCEFDuLGfV/lXoR80FR9W9TfJhSVmKu/jWZr+fSZZon2TjigHtsZoVLuiVQUm1iw/WFDHl9TVkJFiPrFbJGrx7XeSB1GG6Bs5KmHc33PoxsTGK4Q3S8a7q/2JYB15Zusdw8+LmAkqrWebBC7sFPeZTNX7/SR51CX/jZf9jULsvOJ0xHLMNRj4A50xv/tgIdErsRPfk7mys2Bj0eNqlaSixCgc+OIC3zItsk0kYlED23dnI5iP1W3aTnSl9I9vSQxCOlmQ3G22dFHWHutiYpm8rXX4XX+39ijmb51BUV4Sma0iSRHJMMjd0v4Hrul9HG1ubY/wpouP1a9w6e3XEgdRhPk2nyJ7KK/0G8nSPTObdO5IbX1lBldOg7jGKySNNh9I6URLSGp20YGpvhYPdZfU4PH5sFoUOKfbg2VDt0JeB+JhYPr75RV5a9S6/+/qflDsqGddlOM9M+B1psSkhxx9um94Uv6pR6/ajHqygU0YimenGBdQDs5NIj4+h4GBoqkukM6H9s5PokHpyN8kUBOH0teXgFiyKJWRjzmi6t0mSjmIvaPZ7WRSJ1LgY3p12TlBg0i7Jxm8u7Mb95+eyeEd5w2rVwx22ca3fy9Ef/Z3+XYfTB3sfiCPWErjtevUnL29v9LH4jkAXVTQf5C9Hry5C1SyGNynR3pwdfe1MjYvhzSnD+MWsVbh8kd8E2cwKf7ykB8M7H6kZc/tU7nv3Z/yaxoxpl6DoY+HrP8Kmj0CSDFIcpUBaoC0VLvoL9J4U8fePxINDHuSeBfeE7C2UMjaFlLGhn3WHKZJCsjWZizpe1KLjEc4eA7KTUA1WkqOryYQLwzTa0nSNF356gXe2v4OEFLQNgK7rHHQfZPbm2by66VXGZo/lyRFPEmcxvi9rKd9sOYAnzDWkuW6jHl3i45+L+eOlPemeGU/vdoks310R8jpR7z1l8G8gnP5OaDDlUzUWbC3l5SV57CytwyzLaLqOJEmomk7bJCv3jO3CFf3bBT7kFSlsF7+ubTrx3GUPA7D7YAG/+vwpnlj0Av+94vGQY6UY4z9YTdNZsrOcl5fksSa/EpMig68L/oMSA2f8yD1ju3B+j/Sg/aAkSWL62C78+bOthh/ckaTW3DP2+FNABEE4c9R56ww3aI22e5skh+/cJksQY1LonhnPnDuGkhymi59JkbmgV0bDapUy61EUn3FnO1WH/6zy8vDo8Cv/fk3jvRf+xCsxt5FiN1N+VMZBNDdnsRaFu8eEXj8Hdkjm3WnDue211fj8Gu4w+xEeZlYkHr28BzcP79TwWL3Hz7Q31pIaZ+Ff1w861JI4Hq58ES5+GjZ+AGtmQX0ZqF4wx0L7wTDyV9Dh3ECw1cIGZwzmj8P/yN9W/c1ws1YjiqSQYEng9QmvY1GOb68r4exlNSvcMDSbt1cWhHT0i3TiWNOgW3ocmqY3dL2DQMv+B757gDUH1jT5d+05tCK8pGgJ139+PW9e8uYJXaU63m6jsiTxv3X7uGNkTkhW02HRTB5JQHq8KAlpjU5YMJVf4eCmWSupdfsaWke6j1p62lPu4M+fbuGvX2zjzSnDyEix4i9vPrUkN7Uj1/eZwNvrPw15TvdrmNqE7v6+eEcZ/zd3A65DRYQQWOJFCqRJrCuo4tfv/4zFJPOPa/pxUe/MhnOvG5LNF5v2s3pvpWFqSTg2s8KEPpmM654e8TmCIJz5zIoZySCpJtrubXazhR4dk9lUXINFkRv693hUjQt7ZjBtTGf6ZyVGXI/QjnLwhF/temiEhWd+8HDvUAtJVuPXNOk+bjIt5paH5vDNllJ+O3d9xO2DQzZMlyQm9MnEyMAOySz+v/N4d1UBM5bsaTJVx6TI/O3LHYCiPDkAABWYSURBVByo8fDA+K7Uuf3cMWc1vdol8tSkPqEbqlsTYNi0wNdJdnXXq4k1xfLoD4+iozfcYBqxm+ykWFOYM2EOmbHGvydBiNTkETm8u6qQ0M0HItuTMzvFzicbSnh/TRH3nZ/L5f3aIUvwyLJHmg2kGvNqXvbX72fqN1N577L3sJtbPrMnv8LRIt1G5/yYzx0jc7ikbyYLth0IudZFM3lksyiM7ynuF1ujExJM7S6r56qXfsDh8Rvu0dRYILBRufGVlcwc0YUuKzwhaXq7DxawKG8FV/Q4n7YJ6ZTUlvLJtkUMatc75PWs3ZJRYoN3kZ67tog/fbK52Zboh7u1/Or9n3nk0p7cem4nIFD0/MqtQ7hjzmo27quJKLXEZlY4r3saz1zT75QVVgqCcHpKs6VhVKAQbfe2NHsKH90ygtJaNyXVLlxelTiriY4psSTazUe/fPNqS0CxgEFjDIAh7RTO62Ti2R89PHV++BlUk7cagAt6ppNgNePyqiGfBc3dnNnMClNH5hBjCp8ak2y3UFjlornMmMM1r68u28vyXRXUun1c2CuT30/oflpeny/OuZjBmYOZu2Mu72x/B7/qR0ND13VkSUbVVXKTcpnadyrnZZ+HWT6Gf2tBOEqHVDu3nNORd1cVRpVCC4H3639/MYg+7RNYuquCFxbt4t8Ld3HRkBoWFy82DKSa6lrq1/3sq9vHa5tf476B90U1Fr+qsXx3BSXVblw+lXiriZ6ZCfTNOnJNLapqmW6jpbWBn2t8j3TMsoxRq/RIJ49SYy0M7pgccr5w+mvxYKrG6ePGV1YEdqqPIvXT5VOZviKPV3w2OhK8KWKsxc76km3MWvMhtZ56EmLiuKDLuTwy7t6g4ySLTPyYrKDHvt9RFlEg1Zjbp/HXL7eRFm9tmBW1WRTeuXM4/1qwkzd+zEf3+3FqoR/CsRYFkyIz/bzO3D2my2n5QS0Iwqk1IG2A4Q1wNN3brIqV67pdB0BGgpWMhBZIDwkTRDX25LgYRr7m4IHhTaWUSaD6MJksvDvtHK54cTn1br9h62UjVpPMsJwUHriga5PHPfHZFj7fsD/iGz+XT+Xnomo6ptp56OLTM5A6rI2tDdMHTGdav2msLV1LqaMUt99NvCWeHqk96Jwo0seFlvfIpT0pq3WzcFtZxO8rq1nmpVsGNQQrY7ulMaZrG1buqeQ3S+/BJblCJo8i6Vrq1by8t/097ul/Dya5+dvVslo3b60s4M0VBfg1DVXT0TQdkyyjA+0OlZZM7N8Op1c1vEeNttuo91C2kkmRuX1EJ15ekndMzXFsZoV7xop7xtaqxYOpd1cXUOc2DqSaK+hzeFVesPl41m+DRsFP2/g0Zkz6c9PfWAZTqg1LxyOF235V48EP1hsGUs2Nxe3T+L+5Gzi/RzoWk0yN00eV08t1Q7K5fURHvnjs38xN6cN+YnD7VGxmhc5t4pg2pjMX9EwP1GMJgiAYUGSFW3rewqxNs0LSuCLt3qbpGld1vaplBxZjsI/SUfqkK1zezcTfl3vpmRbmOicpYAoEWzltYvl4+ghufGUl9R5/s6nSdrPCmG5p/OemAaHpd42syDvI3LX7op5BByir9fD+6kJuPqdj1OeebCbZxDltzznVwxDOErIs8fxNA/nXgp28snQPsiSFfY/FWhRiY0y8fOtgBnUIXlGRJImOGR585j0hizXRdC31a36WFC1hfMfxTY77q037+c2H69F1Qq4xXjUwgLxyB49/uoV/frOD303oblj2eDzdRqeMymHu2kIOVLvQpMjvAU2yRIcUO9cMzmr+YOG01KLBlKbpvLpsr+GHZaQFfev9Pvwd0zAV1AcFVE2SJeRYE22m9AmK6r/bXobXYAe2SMei6zp//WIrq/ZWkldej0UJzG74VI22/kzuH9uViSO6ig15BUGI2rXdrmXWplmGzzXXvc0smxnfcTyJMYlhjzkmqbmBZgvN+PN5VgbNrOe354ZpRNEmeEWpa0Y8Cx4cy1sr8pn9w15cHhWfQQ54XIzClJE5PDC+K0ozE1Izl+SFvclrbrLM5VOZsSSPXwzvIGaCBeEokiTx24u6c+fozsxdW8SsZXuocvgwK1LDPdCQTincM7YLo3PbBDWbaOy7wu8Ma0Oj6Vrq9Dv5JO+TJoOpj3/ax8PzNkWUgeT0qrh8Ko9/ugWfwb1qtN1G2yUdqdGP1308t+MjpmdeTK0cg7+5OhcC3VbTE6y8M224uJdsxVo0mFqyqxy3P/TDLZqCPlmSWNozgctsFtw7Kpttcy6ZZeQEC+l39UOJD047mbl0T0gxYDRjcXhV3lxR0JCa4lOPvFZhXDqPL8rn8UX5PHVlH64WMwqCIEQh1ZbK4+c+zpMrnoy4MBtAlmRSrak8MvyRlh+UNQF6XwUbPwQ9/IpPborMDb3NPL/aS9/04KBHNceijPpNyDlJNjP1Hn/ghidM/FLvUXl1+V4+XLuPWbcNCapxaKy01s2KPQcNn4t0sqzS4WVtQRVDO4UPWgXhbJZoM3Pn6M5MHZXDQYeXGpcPiyKTEmtpdj8pgIOug4YNVKLtWlruLA/7PX4qrIo4kDpM1wNBlckgCIymYQQEGln86r2f+celuZT+cjrZHTry9e8vZtpbP7GztA6vqqMaBFUmWcIkSwzqmMyMWwaTaBN1j61ZiwZTq/dUhgQvEF1Bn9Orsmz3QW69bTDuHVXULSnCW1QXaC5zuF2ndCiIirUQf14W9oHpyJbgiN7tU1lfVH1cYwGjnjbBYwV4eP4m9te6+eW43IheUxAEAWBil4nUemt5bt1zTXZtO8wsm0m1pvLGJW+0/KrUYef+ErbOb3bT3sfGxvDWxtBNKl0+jT9v6cSvspxkpwS6cKmazvS317FsV0VDjUE4Tq+K06ty/cwVvHr7EEbmhrZG/nzjfsNzo5ksc3lV3l9dKIIpQWiGJEm0iYsJ2/47HL/uN3w82q6lahMTO//8ZkfYQKqpFWpdB/x+zLqO76h6rIi7jQJ+TefbLQfYsWYTMzp2pu1fnkCSZT65bxRbSmqYvWwvX27ajyxLyIe2BQK4elB7Jo/MITfMHqdC69KiwdRBh/HNQLQFfVVOL5IkYeuRgq1HCv6DLlxbDqLWedH9GkqchZguiVg6JoRN0ahyerEoMi4t+E0Y7Vgi4fZpvPDdLjITrCLnVRCEqNzc82Y6xHfgmTXPcMBxAK/mDdmDyqpY0dC4sMOF/HH4H09cIAWQ2Reyh0PBCmgU4OX/Oj7osOxEGfejR9VYme2YRvyWtv4kJr64nKsGtue+cbk8t2Any3ZVRFXf5PKpTHtzLR/fO4IemcHfZ3+1yzCdPJrJMh0ormp+Kw5BEI5NUkwSiqSEBEPRdi0tqZT57/e76dUugd7tEhr2YtpX5eSngirD7x3JCrVPUlAUCQVCVo8iaQV/mNuvsdeSxGM5l/Jmo8TG3u0S+dcNA3j66r6U13lweP3ExZhIi49pskup0Pq0aDAV7o8j2oK+GFNw2ogp1RbSpa854ToJRjuWwyJpWPGnTzZzWb+2Iu9VEISojM4azeis0Wyu2MzrW17np9KfcPqdmCQTydZkru12LZNyJ53YIKqxG96GWedDVUFQQNUksx16XI71vN/yoCRx6zkdefG7XYx7djFOr2pYP9DcddXpVXl03mY+mj4i6DxnmKAs2smyaPYNFAQhOsPbDmfmxpm4/MGTFtF0LY1RbIzKGEOty8ery/awpaQWiyLTu10CNS6f4XUlmhVqsyJhlmUc3gi28mnieuVF4afCapbsLGdcj+C9oqxmpWGVXjgztWgw1T7ZhlmRQnbPjnYH6MYFfccqyW42bD4RbXEhRJ6DLwFfbNwvVqcEQTgmfdr04dmxz57qYUBMPNy5EN65Fkq3gNcR/tjDnfsG3goT/s7hFllp8TH8+co+VDq8hml5kV5XNxXXUHDQQcfUWAAO1nuoc4WmF0L0k2XHtBeXIAgR6dOmDxn2DPJr80Oei7RrKWg8OvY24iyBdDhd1ymudrGlpJbHPtlsWI8UzQq1quncOaoj834u5qDDGzZlMJLrldOr8vKSvJBgSjjztWgwdVnftjy3YCdHVxpFuwP09UOzj3ssdouJrulxbD9Qd8xjgegbVsxYkieCKUEQWj9rIkz+GnYvhOX/hpJ1gcBJ8wMSKObAf/e8IlBn1W5AyEs4PH4WbisLqT2N5rrq1zR+99EGMhNtrC+qptLhpUOK/bgn7mxmmVEG9ViCILScKX2m8LfVfwtZnYLmu5bKyEzoNKEhkIJA/VZWsp2sZDvPfL2dUgwaXESxQu1TdSQkvvnNWB6au4GvNh8IOSaa69X6omqKKp1iJeos06LBVHaKnYEdkli5pzLkuUgL+lJiLQxpoR2gp5/XhYc/3oTDG5wSEk1xYbQNKwoOOqh0eEmJbWpDS0EQhFZAVqDbxYGvyr1Q8AO4qkE2QWwbyB0PtvDX67UFVYYds6KbOYaN+2q5ZlA2943LpUta4MZq6F8XctAR3MY9mskyTYfrBh//xJ0gCOFd2vlSXt/yOoV1hfg144YU4djNdqYPmB72eYvJePuEaFaoZV3DPf8jahbWUBIzFAgNgqK5XknA4p3l3NoK9rATWk6Lb9p7z9gubNhXg8sbmtMe2Q7QnVts348JfTJ5ZN5mw+ciLS6MNgffoshUO0UwJQjCGSYlJ/AVhWqnF82ggDXa66oiSyEZC3eOzuHfC3eF1D1FMlkmSzChd6ZI8xOEEyxGiWH2xbO54fMbqHJX4dOMU3SPZjfZmXnhTNrHtQ97TFaynW3760Iej2aF2moxkTNuFLGWKmrWGR8TzfXKq2rUOJvfq084s0S+RXOExnZL47K+bbGZo3vpGJPMwA5J3Di0Q4uNJcak8OQVvbFGOZbGGs9wCIIgCFEymBtrievqL4Z1JC7GZLhlVVzvcbS9/d90ePB/ZN/3NunXPYE1q2fD81azwq8u6GpwpiAILa2NrQ0fTfyIbsndsJlsyE3cesaaYmlja8Nbl75Fv7R+Tb7uLed0JNYSGuA0XqF27lyB5nOjq35ceWup+v61oGM1YOIVI0iaNAklKSnktUDcBwrNa/FgSpIk/n51X87vkYEtwq52VrNM3/aJvHr7EEzN7HofrasHZ3HfuNyIx3K0xjMckfCqGsl2sSolCIKQZLcgG4Q70V5X4ww2CE20m3n/rnOIjTERTTKD1Swz89bBDemCgiCceMnWZN677D1mXTSL8zucj0W2EGeOa/iKUWLon9afp0c/zYJrF9AtuVuzrzk6tw12i3GCVcKwq0k+fyo1Kz5g3ws3s2/GHdT99Dm2rkdS9RRZYmK/dsRbAyvU4TKKorleWUwyieIe8KzT4ml+ACZF5sVfDOSVpXt4aXEefk0z3Mw31qKgE5hdeOji7phbOJA67L7zu5KRYOWxT7YgSUc22z2aJIW2VI+2YUVOm1iSRYqfIAgCQzslG7Yujua6alYkLu2bafj6XTPi+fS+kdw0ayX1bn9IfWxjNrOCIku8dsdQhuWIjXoF4WSTJIn+af15btxz1HhqKKwtpN5Xj81kIzM2k8xY4/d5OLIsMW1MDv9asNOwC19z5RxmRWLq6COpy1cOaMeWktqQe8Rorle6DuO6p0X1cwit3wkJpiDwprl7bBemjsph4bYyXlmax55yBy6fitWs0C7JytRRnbn8JO3LdN2QbC7r15ZP15cwY0keB2rcDcGbT9VIi49hVG4bPt1QjNMbfQ4+BILD6ed1OeE/iyAIQmtgt5i4elB7PlhTFBJURXpdlSWJO0aEr9XqnBbH0t+N4+vNB3h5cR57DzowyRKqFph51nSdZLuFe8Z25qpBWYarXIIgnFyJMYn0Tet73K8zZWQO328v46fC6qj2jbOZFX59QdegDcGvGpjFU19sMzw+0uvV4I7JZCWLTn5nmxP+qWJSZCb0yWRCn+hmHE4Eu8XEjcM6cMPQbEprPVQdKhJMspvJTLDi13S+2ryfQBZtsIgaVkhwSZ+2J2DkgiAIrdOUUTl8tG6f4QpVc9dVCeifnUSH1KZvTmJMClcOaM+VA9qzs7SOXaX11Ht82CwmspNtDMhOarHGRoIgnD5MiszsO4Yyec4aNu6rwRVmQ+/GbGaZu8d05u6xwZPfsTEmJg1of8zXK7tFCXlN4exwVk7RSZJEZqKVzERr0ONmReKZa/vzwPs/h924LRybWeHpSX1PyiqbIAhCa9ElLY6bhmXzwZp9Ed3oNGazKDw1qU9U53TLiKdbRnxU5wiC0HrZLSbeuXM4Ly3ezezl+YalJbIUmHRpl2TldxN6cHFv4wn+By/qxoKtpVQ6vCH74zUlxiQzPCeF0WLvurPSWRlMNeXi3pk8fGlPnv5yW8QBldUs85sLu3LlwPAtPAVBEM5Wj13em/J6L99tK4s4oLJbFF67Y6gIjARBaJZJkfnV+G7ce14uC7eV8fqPe9lX5cLj14i1KPTLSuLO0Tn0yzLu2HdYeryV9+86h2tfXkG9249qsLXD0axmmV5tE5hxy2Bkg331hDOfCKYM3HZuJ9LjrfzufxtQNd2weQYEaqRkWeLpq/owsb8IpARBEIzIssSLNw3k2W938OqyvciSZBhUSQRWo1JjLcy8dQi92iWEvpggCEIYLVFa0jUjni8fGM20N9ayt8KB168ZBlUxhzYNvrxfO56+qm/YTYSFM58IpsKY0CeT8T3TWbi1lBlL8thaUhvUsKJbZjzTx3bh4t6Z4g0kCILQDEmSeOjiHtw9tgv/W7uPV5btoaLeg1mR0XQdv6ozplsad4/pzLCcFFHjJAjCKdM+ycaXD4xmc3ENs5fv5ctN+5EITAz5NR27RWHyiBx+MbwDafExp3q4wikmgqkmmBWZS/q25ZK+bXF4/FS7Ajt3J9rMoiOUIAjCMUiwmpk8Koc7Rnai1uWn1u0L7M1iM4uaU0EQTit92ify3A0D+Mc1/ah2eXF7NeKtJhJtZpHSJzQQEUGEYmNMxIoAShAEoUVIkkSi3Uyi3XyqhyIIgtAki0kmPd7a/IHCWUnkpwmCIAiCIAiCIBwDEUwJgiAIgiAIgiAcAxFMCYIgCIIgCIIgHAMRTAmCIAiCIAiCIBwDEUwJgiAIgiAIgiAcAxFMCYIgCIIgCIIgHAMRTAmCIAiCIAiCIBwDEUwJgiAIgiAIgiAcAxFMCYIgCIIgCIIgHAMRTAmCIAiCIAiCIBwDSdf18E9KUjlQcPKGIwjCSdBR1/W0Uz2I4yWuT4JwRmr11ydxbRKEM1LYa1OTwZQgCIIgCIIgCIJgTKT5CYIgCIIgCIIgHAMRTAmCIAiCIAiCIBwDEUwJgiAIgiAIgiAcAxFMCYIgCIIgCIIgHAMRTAmCIAiCIAiCIByD/w84YL7pSAqAmgAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Visualize examples from the training set.\n", "# It takes a few minutes to download & prepare the dataset.\n", "ds, ds_info = tfds.load('ogbg_molpcba', split='train', with_info=True)\n", "tfds.visualization.show_examples(ds, ds_info,\n", " node_color_fn=node_color_fn,\n", " node_label_fn=node_label_fn,\n", " edge_color_fn=edge_color_fn)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Start TensorBoard\n", "# Get a live update during training - use the \"refresh\" button!\n", "# (In Jupyter[lab] start \"tensorboard\" in the local directory instead.)\n", "if 'google.colab' in str(get_ipython()):\n", " %load_ext tensorboard\n", " %tensorboard --logdir=." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "outputId": "7696aae4-beb5-4df2-b72a-7f391ac30c2e" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: \n", "INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.\n", "INFO:absl:Hyperparameters: {'add_self_loops': True, 'add_undirected_edges': True, 'add_virtual_node': False, 'batch_size': 256, 'checkpoint_every_steps': 10000, 'dropout_rate': 0.1, 'eval_every_steps': 500, 'latent_size': 256, 'layer_norm': True, 'learning_rate': 0.001, 'log_every_steps': 500, 'message_passing_steps': 5, 'model': 'GraphConvNet', 'num_classes': 128, 'num_mlp_layers': 2, 'num_train_steps': 1000, 'optimizer': 'adam', 'skip_connections': True}\n", "INFO:absl:Obtaining datasets.\n", "INFO:absl:Load dataset info from /root/tensorflow_datasets/ogbg_molpcba/0.1.2\n", "INFO:absl:Reusing dataset ogbg_molpcba (/root/tensorflow_datasets/ogbg_molpcba/0.1.2)\n", "INFO:absl:Constructing tf.data.Dataset ogbg_molpcba for split train, from /root/tensorflow_datasets/ogbg_molpcba/0.1.2\n", "INFO:absl:Constructing tf.data.Dataset ogbg_molpcba for split validation, from /root/tensorflow_datasets/ogbg_molpcba/0.1.2\n", "INFO:absl:Constructing tf.data.Dataset ogbg_molpcba for split test, from /root/tensorflow_datasets/ogbg_molpcba/0.1.2\n", "INFO:absl:Initializing network.\n", "INFO:absl:\n", "+-----------------------------+------------+--------+-----------+--------+\n", "| Name | Shape | Size | Mean | Std |\n", "+-----------------------------+------------+--------+-----------+--------+\n", "| params/Dense_0/bias | (256,) | 256 | 0.0 | 0.0 |\n", "| params/Dense_0/kernel | (9, 256) | 2,304 | -0.000941 | 0.335 |\n", "| params/Dense_1/bias | (128,) | 128 | 0.0 | 0.0 |\n", "| params/Dense_1/kernel | (256, 128) | 32,768 | -0.000737 | 0.0623 |\n", "| params/LayerNorm_0/bias | (256,) | 256 | 0.0 | 0.0 |\n", "| params/LayerNorm_0/scale | (256,) | 256 | 1.0 | 0.0 |\n", "| params/LayerNorm_1/bias | (256,) | 256 | 0.0 | 0.0 |\n", "| params/LayerNorm_1/scale | (256,) | 256 | 1.0 | 0.0 |\n", "| params/LayerNorm_2/bias | (256,) | 256 | 0.0 | 0.0 |\n", "| params/LayerNorm_2/scale | (256,) | 256 | 1.0 | 0.0 |\n", "| params/LayerNorm_3/bias | (256,) | 256 | 0.0 | 0.0 |\n", "| params/LayerNorm_3/scale | (256,) | 256 | 1.0 | 0.0 |\n", "| params/LayerNorm_4/bias | (256,) | 256 | 0.0 | 0.0 |\n", "| params/LayerNorm_4/scale | (256,) | 256 | 1.0 | 0.0 |\n", "| params/MLP_0/Dense_0/bias | (256,) | 256 | 0.0 | 0.0 |\n", "| params/MLP_0/Dense_0/kernel | (256, 256) | 65,536 | -0.000616 | 0.0625 |\n", "| params/MLP_0/Dense_1/bias | (256,) | 256 | 0.0 | 0.0 |\n", "| params/MLP_0/Dense_1/kernel | (256, 256) | 65,536 | -0.000287 | 0.0626 |\n", "| params/MLP_1/Dense_0/bias | (256,) | 256 | 0.0 | 0.0 |\n", "| params/MLP_1/Dense_0/kernel | (256, 256) | 65,536 | 0.000217 | 0.0628 |\n", "| params/MLP_1/Dense_1/bias | (256,) | 256 | 0.0 | 0.0 |\n", "| params/MLP_1/Dense_1/kernel | (256, 256) | 65,536 | -0.00055 | 0.0625 |\n", "| params/MLP_2/Dense_0/bias | (256,) | 256 | 0.0 | 0.0 |\n", "| params/MLP_2/Dense_0/kernel | (256, 256) | 65,536 | 0.000274 | 0.0625 |\n", "| params/MLP_2/Dense_1/bias | (256,) | 256 | 0.0 | 0.0 |\n", "| params/MLP_2/Dense_1/kernel | (256, 256) | 65,536 | -0.000227 | 0.0626 |\n", "| params/MLP_3/Dense_0/bias | (256,) | 256 | 0.0 | 0.0 |\n", "| params/MLP_3/Dense_0/kernel | (256, 256) | 65,536 | -0.000224 | 0.0625 |\n", "| params/MLP_3/Dense_1/bias | (256,) | 256 | 0.0 | 0.0 |\n", "| params/MLP_3/Dense_1/kernel | (256, 256) | 65,536 | 1.42e-05 | 0.0627 |\n", "| params/MLP_4/Dense_0/bias | (256,) | 256 | 0.0 | 0.0 |\n", "| params/MLP_4/Dense_0/kernel | (256, 256) | 65,536 | 0.000319 | 0.0623 |\n", "| params/MLP_4/Dense_1/bias | (256,) | 256 | 0.0 | 0.0 |\n", "| params/MLP_4/Dense_1/kernel | (256, 256) | 65,536 | -0.000191 | 0.0624 |\n", "+-----------------------------+------------+--------+-----------+--------+\n", "Total: 695,936\n", "INFO:absl:Checkpoint.restore_or_initialize() ...\n", "INFO:absl:No checkpoint specified. Restore the latest checkpoint.\n", "INFO:absl:Checkpoint None does not exist.\n", "INFO:absl:Checkpoint.save() ...\n", "INFO:absl:Checkpoint.save() finished after 0.09s.\n", "INFO:absl:Checkpoint.restore_or_initialize() finished after 0.10s.\n", "INFO:absl:Starting training.\n", "INFO:absl:Finished training step 1.\n", "INFO:absl:Finished training step 2.\n", "INFO:absl:Finished training step 3.\n", "INFO:absl:Finished training step 4.\n", "INFO:absl:Finished training step 5.\n", "INFO:absl:Finished training step 6.\n", "INFO:absl:Finished training step 7.\n", "INFO:absl:Finished training step 8.\n", "INFO:absl:Finished training step 9.\n", "INFO:absl:Finished training step 10.\n", "INFO:absl:Created artifact [10] Profile of type ArtifactType.URL and value None.\n", "INFO:absl:Setting work unit notes: 2.0 steps/s, 12.4% (124/1000), ETA: 7m\n", "INFO:absl:[124] steps_per_sec=2.048078\n", "INFO:absl:Setting work unit notes: 2.1 steps/s, 25.2% (252/1000), ETA: 5m\n", "INFO:absl:[252] steps_per_sec=2.127820\n", "INFO:absl:Setting work unit notes: 2.1 steps/s, 38.0% (380/1000), ETA: 4m\n", "INFO:absl:[380] steps_per_sec=2.116415\n", "INFO:absl:[500] train_accuracy=0.9849439859390259, train_loss=0.06188433617353439\n", "INFO:absl:[500] validation_accuracy=0.9839196801185608, validation_loss=0.06475622951984406, validation_mean_average_precision=0.036894\n", "INFO:absl:[500] test_accuracy=0.9832062125205994, test_loss=0.06711383163928986, test_mean_average_precision=0.037820\n", "INFO:absl:Setting work unit notes: 0.4 steps/s, 50.1% (501/1000), ETA: 23m (8m : 31.9% eval)\n", "INFO:absl:[501] steps_per_sec=0.351547\n", "INFO:absl:Setting work unit notes: 2.1 steps/s, 62.8% (628/1000), ETA: 2m (9m : 28.7% eval)\n", "INFO:absl:[628] steps_per_sec=2.109877\n", "INFO:absl:Setting work unit notes: 2.1 steps/s, 75.6% (756/1000), ETA: 1m (10m : 26.1% eval)\n", "INFO:absl:[756] steps_per_sec=2.128443\n", "INFO:absl:Setting work unit notes: 2.1 steps/s, 88.4% (884/1000), ETA: 0m (11m : 23.9% eval)\n", "INFO:absl:[884] steps_per_sec=2.122990\n", "INFO:absl:[999] train_accuracy=0.9868282079696655, train_loss=0.052991833537817\n", "INFO:absl:[999] validation_accuracy=0.984052836894989, validation_loss=0.06297025829553604, validation_mean_average_precision=0.049514\n", "INFO:absl:Checkpoint.save() ...\n", "INFO:absl:[999] test_accuracy=0.9833315014839172, test_loss=0.06533924490213394, test_mean_average_precision=0.050113\n", "INFO:absl:Checkpoint.save() finished after 0.06s.\n", "INFO:absl:Setting work unit notes: 1.7 steps/s, 100.0% (1000/1000), ETA: 0m (13m : 0.0% checkpoint, 22.5% eval)\n", "INFO:absl:[1000] steps_per_sec=1.697496\n", "INFO:absl:[1000] train_accuracy=0.9897193908691406, train_loss=0.04319273680448532\n", "INFO:absl:[1000] validation_accuracy=0.9840479493141174, validation_loss=0.0629250779747963, validation_mean_average_precision=0.049538\n", "INFO:absl:[1000] test_accuracy=0.9833057522773743, test_loss=0.06529100239276886, test_mean_average_precision=0.050135\n" ] } ], "source": [ "# Training loop\n", "\n", "# Use a Colab GPU runtime to speed up training.\n", "# We don't use TPUs in this Colab because we do not distribute our\n", "# training using pmap() - if you're looking for an example using TPUs\n", "# checkout the below Colab notebook:\n", "# https://colab.research.google.com/github/google/flax/blob/main/examples/imagenet/imagenet.ipynb\n", "\n", "config.num_train_steps = 1000\n", "config.log_every_steps = 500\n", "config.eval_every_steps = 500\n", "\n", "# Construct the model and start the main training loop.\n", "# The default config corresponds to a 5-layer Graph Convolutional Model\n", "# with skip-connections and mean-pooling.\n", "state = train.train_and_evaluate(config, workdir=f'./models')" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "cellView": "form" }, "outputs": [], "source": [ "#@title Upload to TensorBoard.dev\n", "if 'google.colab' in str(get_ipython()):\n", " #@markdown You can upload the training results directly to [TensorBoard.dev](https://tensorboard.dev).\n", " #@markdown\n", " #@markdown Note that anyone with the link will be able to see the data.\n", " upload_data = 'no' #@param ['yes', 'no']\n", " if upload_data == 'yes':\n", " !tensorboard dev upload --one_shot --logdir ./models --name 'Flax examples/ogbg_molpcba'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# Create deterministic evaluation model.\n", "eval_net = train.create_model(config, deterministic=True)\n", "eval_state = state.replace(apply_fn=eval_net.apply)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "outputId": "51242dc2-5260-4279-b41a-9d7614b33c97" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Load dataset info from /root/tensorflow_datasets/ogbg_molpcba/0.1.2\n", "INFO:absl:Reusing dataset ogbg_molpcba (/root/tensorflow_datasets/ogbg_molpcba/0.1.2)\n", "INFO:absl:Constructing tf.data.Dataset ogbg_molpcba for split train, from /root/tensorflow_datasets/ogbg_molpcba/0.1.2\n", "INFO:absl:Constructing tf.data.Dataset ogbg_molpcba for split validation, from /root/tensorflow_datasets/ogbg_molpcba/0.1.2\n", "INFO:absl:Constructing tf.data.Dataset ogbg_molpcba for split test, from /root/tensorflow_datasets/ogbg_molpcba/0.1.2\n" ] } ], "source": [ "# Compute accuracy and mean average precision on validation and test sets.\n", "datasets = input_pipeline.get_datasets(\n", " config.batch_size,\n", " add_virtual_node=config.add_virtual_node,\n", " add_undirected_edges=config.add_undirected_edges,\n", " add_self_loops=config.add_self_loops)\n", "eval_metrics = train.evaluate_model(eval_state, datasets,\n", " splits=['validation', 'test'])" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "outputId": "2a529c9f-fed6-4221-a738-a8ffd9049d7e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "validation\n", "- accuracy: 0.984\n", "- loss: 0.063\n", "- mean_average_precision: 0.050\n", "test\n", "- accuracy: 0.983\n", "- loss: 0.065\n", "- mean_average_precision: 0.050\n" ] } ], "source": [ "for split in ['validation', 'test']:\n", " split_metrics = eval_metrics[split].compute()\n", " print(split)\n", " for metric_name, metric in split_metrics.items():\n", " print(f'- {metric_name}: {metric:.3f}')" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Helper functions for formatting labels and predictions.\n", "def get_formatted_label_for_task(labels, task):\n", " label_for_task = labels[task]\n", " if np.isnan(label_for_task):\n", " return 'Unknown'\n", " elif label_for_task == 0:\n", " return 'Inactive'\n", " elif label_for_task == 1:\n", " return 'Active'\n", " raise ValueError('Invalid label.')\n", "\n", "# Predictions are computed with a threshold of 0 for the logits.\n", "# This is the same threshold used to compute the accuracy.\n", "def get_formatted_prediction_for_task(logits, task):\n", " predictions = logits > 0\n", " prediction_for_task = predictions[task]\n", " if prediction_for_task == 0:\n", " return 'Inactive'\n", " elif prediction_for_task == 1:\n", " return 'Active'\n", " raise ValueError('Invalid prediction.')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can choose one of the 128 different tasks and see how the model predictions\n", "match up with the true labels.\n", "\n", "You can change this visualization to use any of the 128 tasks in this dataset, by changing the `task` variable below. See the appendix [here](https://arxiv.org/pdf/1502.02072.pdf) to understand what these tasks mean.\n", "\n", "The default is task PCBA-686978 (indexed at 93)." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# Define which task to plot labels for.\n", "task = 93" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "outputId": "38934096-13a6-4701-823a-ac83f3b7eaac" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:absl:Load dataset info from /root/tensorflow_datasets/ogbg_molpcba/0.1.2\n", "INFO:absl:Reusing dataset ogbg_molpcba (/root/tensorflow_datasets/ogbg_molpcba/0.1.2)\n", "INFO:absl:Constructing tf.data.Dataset ogbg_molpcba for split test, from /root/tensorflow_datasets/ogbg_molpcba/0.1.2\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA1MAAArLCAYAAAAvFZZhAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdd5hU1fnA8e+5U3a2F3pZYCmCgIgUKaKALUbsBSuKFVsSU9TE2JLoLybRxESjWImKBSuWaOygSJPepC11WWB32b4zszNz5/z+mFmY3ZnZnVm2wvt5nnmYvfXc2eXMfe855z1Ka40QQgghhBBCiPgYrV0AIYQQQgghhGiPJJgSQgghhBBCiEaQYEoIIYQQQgghGkGCKSGEEEIIIYRoBAmmhBBCCCGEEKIRJJgSQgghhBBCiEaQYKqdU0pNUkrltfS+jTzfPKXUjS11PiFE65G6SQjRHrSnuqoxlFK9lFKVSilLa5flSCXBVB3BP7ial18p5Qr5+apmPO90pdSC5jr+4VJK9VFKaaWUtc7y/yilHm6tcglxtJC6KTKpm4RoW6Suii5YV/Vv5nPsUEqdXvOz1nqX1jpFa20253mPZtaGNzm6aK1Tat4rpXYAN2qtv6y7nVLKqrX2tWTZhBBHL6mbhBDtgdRV4mgjLVMxqmnKVUrdo5TaB8yK9BQk9KmDUipBKfWYUmqXUmq/UmqmUiqxEee+Tin1o1KqQim1TSk1I8I29yqlioJPJK4KWd4kZYixnNOVUguC5ytRSm1XSv00yrbdlFJrlFJ3BX+ep5T6k1Lq++B1fq6U6hiy/XlKqfVKqdLgtscGl1+nlPooZLstSqm3Q37erZQaHnyvlVK3BLcpVUr9WymlmuOzEKKlSN0UUzmlbhKilUldFXa+h5RSbymlXgmWa71SalTI+t8qpXKD6zYopS6ss/9NIde0QSk1Qin1KtAL+EgFWgLvViGt90qpy5RSy+oc55dKqQ+b81qPdBJMxacrkAX0Bm6OYftHgWOA4UB/oAfwQCPOWwCcA6QB1wH/UEqNqFOujsHjXws8p5QaGG8ZlFJPK6WebkT5Qo0BNgXL81fgxbo3BUqpHGA+8JTW+m8hq64kcH2dATvwm+D2xwBvAHcCnYBPCFQU9uBxTlZKGUqp7sH9xgX36wukAGtCznEOMBoYBkwFfnKY1ytEWyB1U8OkbhKi9UldVdt5wJtABvAh8FTIulzgZCAd+AMwWynVLXieS4GHgGuC13QecEBrPQ3YBZwb7Nr31zrn+wgYqJQaELLsSuD1eK9VhNBayyvKC9gBnB58PwnwAI6Q9dOBBXX20QT+ABVQBfQLWTcO2B7lXGHHqqdcc4FfhJTLBySHrH8LuL+hMgT3zYvxnH2C12ats/w/wMMh17A1ZF1ScJ+uwZ/nAX8Pfq5X1DnOPOC+kJ9vA/4XfH8/8FbIOgPYA0wK/rwbGAFcDjwHLAUGEagwP6zzu5lQ53P6bWv/nclLXvG+pG6qdU6pm+Qlrzb6kroq7Lwa6B98/xDwZci6wYCrnn1XAecH339WU/76PvPgz7XqSGA28EDw/QCgIlgnxvV5y+vQS8ZMxadQa+2OcdtOBP44l4c8/FRA3NlUgt1RHiTwtMAIHndtyCYlWuuqkJ93At2bsgwEKhoAW8j7mp+9IT/vq3mjtXYGz5sSsv4qYCvwToRz7At57wzZrzuBa6o5rl8ptZvAExMIPAGeRKDynQ+UAhMJVALzYzyHEO2Z1E1SNwnRHhzNdVUkdf/fO1RwLJlS6hrgVwSCIQjUCTVdjLMJtFw1xuvA48AfCbRKzQ3WiZ1p3ms9Ykk3v/joOj9XEfjDA0Ap1TVkXRHgAoZorTOCr3QdMjAzFkqpBOBd4DGgi9Y6g0BXktDuKZlKqeSQn3sB+U1VhqC9BG5M+tRZnkPIzUQMHgqW63UVe5rOfAJdAgAIds3JJvAEGA7dsJwcfD+fwA3LRMJvWIQ4EkndJHWTEO3B0VxXxVPm3sDzwB1Ah2CZ14WUeTfQL8rudT/jur4AOqnAmM0rONTFr1Wu9UggwdThWQ0MUUoNV0o5CHwZA4EnlAT+I/wjGO2jlOqhlKqvH7xSSjlCXwT62ScAhYAv+HTlzAj7/kEpZVdKnUygX/DbjSxDRDqQUvNd4BGlVAellE0pdQWBZulP4ziUF7gUSAZeUUrF8jf4FjBFKXWaUsoG/BqoBhYG188HJgOJWus84DvgLKADsDKOsglxpJC6SeomIdqDo6auilMygaCoMHjO64ChIetfAH6jlBqpAvoHAzCA/UDfaAfWWnuBt4G/ERi/9kVweWtda7snwdRh0FpvJtBM+iWwBag7v8E9BLqNLFZKlQe3G0h04wk8Faj7+jmBL+0SAk2yH9bZb19wXT7wGnCL1npjvGVQgawtM+sp321AMYFB0wUEnphM0Vrvr2efMFprD3AR0AV4qaGbFq31JuBq4EkCT07OJTC40hNcvxmoJHCjgta6HNgGfK9lXgVxFJK6SeomIdqDo7CuionWegOBrniLCARHxwHfh6x/G3iEQKtSBYExYFnB1X8G7lOBzKC/iXKK14HTCQSMod2j4/28BaC0bqg1UAghhBBCCCFEXdIyJYQQQgghhBCNIMGUEEIIIYQQQjSCBFNCCCGEEEII0QgSTAkhhBBCCCFEI0gw1QyUUv9RSj0cfH+yUmpTC51XK6X6N/ExD15LS+57OJRSvZRSlXHMFSPEEUnqosPft6Uope5VSr3Q2uUQoiVI3XT4+7YUqZsadtQGU0qpHUopV/Cme3/wD7rJJybTWn+ntW4wraRSarpSqm5K0CajlJqnlLqxuY7fVJRSk4KV3T1x7LNDKXV6zc9a611a6xRJPSzaA6mL2pbmvv7gOSYppfJCl2mt/09r3WY/F3H0kbqpbZG6qe06aoOpoHODMzuPAEYB99XdQCllbfFSHd2uJTBfzDWtXRAhWpDURUKItkjqJiEacLQHUwBorfcAnxKcXTrYMnK7UmoLgUnkUEqdo5RaFZwEbaFSaljN/kqpE5RSK5RSFUqpOYAjZF2tKF8pla2Uek8pVaiUOqCUekopdSwwExgXfAJUGtw2QSn1mFJqV/Cp0EylVGLIse5SSu1VSuUrpa5v7PUrpd5WSu1TSpUppb5VSg2ps0lHpdQXweubrw7Nso1SalBwXbFSapNSauphlCMZuAS4HRiglBpVZ/1NSqkfg+XYoJQaoZR6FegFfBT87O5WSvUJ/g6tSqnLlFLL6hznl0qpD4Pv6/2MhWhJUhe1jbqoTpl2KKV+o5RaEyzXHKWUI7guUyn1cfAzLAm+7xmyb5ZSalbwcylRSs0N1nOfAt2Dn3GlUqq7UuohpdTs4H6fKqXuqFOO1Uqpi5rzWoWIRuomqZuC+0ndFIEEUwT+4wJnAytDFl8AjAEGK6VOAF4CZgAdgGeBD4P/ie0EZp5+lcDs028DF0c5jwX4GNgJ9AF6AG9qrX8EbgEWBbunZQR3eRQ4BhgO9A9u/0DwWGcBvwHOAAYQmMm6sT4NHqMzsILA7N+hrgL+BHQEVtWsD/7H+4LATNqdgcuBp5VSg6Ncf6lSakI95bgIqCTwGX5GoJWqZt9LgYcItFilAecBB7TW04BdBJ+eaa3/WueYHwEDlVIDQpZdGSwz1PMZC9HSpC5qM3VRXVOBs4AcYBgwPbjcAGYBvQk81HEBT4Xs9yqQBAwJlusfWusq4KdAfvAzTtFa59c53xvAFSHlHRw8x3/jvVYhmoLUTVI3BUndFInW+qh8ATsI3LiXEvhP+zSQGFyngVNDtn0G+FOd/TcBE4FTgHxAhaxbCDwcfD8JyAu+HwcUAtYI5ZkOLAj5WQFVQL+QZeOA7cH3LwGPhqw7Jlju/lGudx5wYwyfS0bwOOnBn/9DoCKrWZ8CmEA2cBnwXZ39nwUeDNn34Th+J18CTwTfXxH8rGzBnz8DflHP7/L0kJ/7BK/BGvx5NvBA8P0AoIJAJVLvZywvebXES+qiqJ9Lq9RFEa5/B3B1yM9/BWZG2Xc4UBJ83w3wA5kRtjv4uwhZ9hAwO/g+NfiZ9w7+/AjwUvB9vdcqL3k11Uvqpqifi9RNUjfVeh3t/Vwv0Fp/GWXd7pD3vYFrlVI/C1lmB7oT+A+1Rwf/aoJ2RjlmNrBTa+2LoWydCNzwL1dK1SxTQE2Guu7A8hjOWa/gU6BHgEuD5/QHV3UEyoLvD34WWutKpVRx8Py9gTE1ze1BVgJPPOItRzYwGfhdcNEHwHPAFAJPtLKB3HiPG/Q68DjwRwKtUnO11k6lVGfq/4yFaClSF7WRuiiKfSHvncFzopRKAv5B4MlwZnB9avBasoFirXVJvCfTWlcopf5L4MnuXwg8XLopuLq5r1WIUFI3Sd10kNRNkR3twVR9Qv/T7wYe0Vo/UncjpdREoIdSSoVUFL2IfOO/G+illLJGqCh0nZ+LCDTLDtGBvsp17SXwH6JGr+iXUq8rgfMJNH/vANKBEgIVUo2D51GBTD5ZBJ4y7Qbma63PaOS5Q00j0Cz9UUil6CDQ1W9u8Fz9ouxb97Or6wugk1JqOIH/+L8MLm/oMxaiLZC66JCWqIvi8WtgIDBGa70vWMesJFDm3UCWUipDa11aZ7+G6iwIdKd5UCn1LYG68Jvg8ta6ViHqkrrpEKmbjuK6ScZMxeZ54Bal1BgVkKyUmqKUSgUWAT7g50opW3AQ3olRjrOUwH/uR4PHcCilTgqu2w/0DPYtRmvtD573H8EWFJRSPZRSPwlu/xYwXSk1OPgE4sEYrsMaPGfNy0agybYaOEDgCc//RdjvbKXUhGDZ/gQs1lrvJtCv+Ril1LTgtduUUqNVYKBovK4F/kCgKbrmdXHw3B2AF4DfKKVGBn8H/dWhAZ77gb7RDqy19hLoo/03AhXcF8HlDX3GQrQ1Uhc1f10Uj1QCN3OlSqksQq5da72XwDiLp1VgMLhNKXVKcPV+oINSKr2eY39C4EnvH4E5wd8DtN61ClEfqZukbjpq6yYJpmKgtV5GoBnzKQJPI7YSHOSntfYQSJwwnUBK78uA96IcxwTOJTBIcheQF9we4GtgPbBPKVUUXHZP8FyLlVLlBMYUDQwe61PgieB+W4P/NuQZAv+5al6zgFcINH3vATYAiyPs9zqB/4jFwEjg6mAZKoAzCTT35hNobv4LkBDp5CqQGebkCMvHEviP+W+t9b6Q14fBa7tCa/02gWb21wmMeZpLIDAC+DNwnwoM3PxNlGt/ncBTpbfrPO2K+hkL0dZIXdS8dVEjPAEkEnhCvhj4X5310wAvsBEoAO4Mlncjgae724L1Vve6B9ZaVxP4/Z3OoYQ5cV+rEC1B6iapm47muknV7sIqhBBCCCGEECIW0jIlhBBCCCGEEI0gwZQQQgghhBBCNIIEU0IIIYQQQgjRCBJMCSGEEEIIIUQjSDAlhBBCCCGEEI1Q76S9HTt21H369GmhogghWsLy5cuLtNadWrsch0vqJyGOPEdC/SR1kxBHnvrqpnqDqT59+rBs2bLmKZUQolUopXa2dhmagtRPQhx5joT6SeomIY489dVN0s1PCCGEEEIIIRpBgikhhBBCCCGEaAQJpoQQQgghhBCiESSYEkIIIYQQQohGkGBKCCGEEEIIIRpBgikhhBBCCCGEaAQJpoQQQgghhBCiESSYEkIIIYQQQohGqHfSXtG+bCnZwq6KXTi9TpJtyfRJ70Pf9L6tXSwhhBCtZE+pix/zy6mo9uKwWuiWkcjxPdNRSrV20YQQosmUVHlYtbuUcrcXq2HQMcXOyN6ZWC3N324kwVQ7V21W8/mOz3lx7YvsqdyD1bDi134MZeDz++iT3ocbht7Aab1Ow2axtXZxhRBCNDO/X/PtlkKenb+NFbtKsFsN/FqjUGitSUu0cfPJfbl4VE/SHPK9IIRon7TWrNpdyvPfbeOrHwtC6joAhcVQXDOuN9PG9qZzmqPZyqG01lFXjho1Si9btqzZTi4Oz8bijdz0+U14TA9OnzPqdknWJJJtybzwkxekpUqglFqutR7V2uU4XFI/CRGusKKaq19YQl6JkyqPGXW7RJsFpeDpq0YwaWDnFixh/Y6E+knqJiGan8tjcuvs5SzdUYzba+KPEs4kWAMtU/dNOZZp4/o0+nz11U0yZqqdWlO4hms+vYbS6tJ6AykAp89JkauIK/97JZtLNrdQCYUQQrSkgnI3U/71HbmFlfUGUgAur4nTY3LL7OV8sja/hUoohBCHz+01uWTmQhZtO4DTEz2QAqj2+an2+fm/Tzbyr6+2NEt5JJhqh/ZV7WPGFzNw+Vwx76PRVHmruOGzGyh2Fzdj6YQQQrQ0j8/P5c8vprjKg6++O4s63F4/v3prNat3lzZj6YQQounc9toKthZUUu3zx7yPy2vy9LytfLSq6R8eyZipduiltS/h9rnDlpd8V0LRZ0V4CjxYHBbSRqbR5ZIuWJItB7dxep28/uPr3HHCHS1ZZCGEEM3o03V72VfmjhhIVW2YR/kPc/EeyMOwJ2Lr3Jf08VNx9BwCBAKqRz/dyBs3j23pYgshRFzW7SljUW5RxEAqlrruT//dwJRh3TCMpkvCI8FUO+PyuZibOxef9tVaXvRpEYWfFtLzxp6kDE7BW+Il/9V8djy2g5zf52AE+4x6/B7e2PgGM46fgc2QgcdCCHEkmDk/F2eErn3lS9+nbMk7dDjzdhw5I1AWK67ty3FtWXLwBgNgxa4Sdhc7yc5KasliCyFEXF5csB2PGf7QKNa6rqrax/e5RZw8oFOTlUm6+bUz/9v+PxS1o2nTZVIwt4DuV3cndVgqyqqwd7KTfVs2niIPZQvLam3v8/uYv3t+SxZbCCFEM9m4r5ztRVVhy/3VVZQueI2sM24laeB4DLsDZbGS1H8MmZOvr72t1ry6aGdLFVkIIeJW6vTw37V7Meu0wMdT11V5TGbOz23Sckkw1c4s3rs4LOGEc4sTv9dP2si0WsstDgupw1KpXF9Ze3ufkx/2/dDsZRVCCNH8lu8sibi8es9GtM9D0jHjGjyG19R8u6WADfnlfLelkG82FbBiVwlub/2JLIQQorkVV3l4Zt5WJv7tGzwRuvfFU9cBrNrVtGNEpZtfO1NaHf4HYFaaWFOsKEt4/09ruhXXzvBEFZKEQgghjgzlLh9eM/wGw3SVYySloQxLhL3Cbd5fySUzF2IJjiXQOtBidcnInlx3Ug45HZObtNxCCFEfrTV/+2wTLy7YjlKBMU+RxFvXuZr4IZEEU+2M3bCHLbOkWPBV+tCmDguofGU+rCnhv+YES0KzlVEIIUTLsVsNDKUwqd31xZKYht9ZjvabMd1k+DURx129sWQXc37YzXnHd+fPFx2H1SKdWoQQzcvv1/z8zZV89WNBg1n74q3rrEbT1mFSI7Yz3VK6Yajav7ak/kkoq6J8eXmt5abbpGJNBcmDaz9NtCor3ZK7NXtZhRBCNL+OKXbsEQKchB6DUFYbzs2LDuv4Xr+m2ufn4zV7mT7rB3wRWsGEEKIpPfLJj3z1Y0FMrUjx1nVpiU3bliQtU+3M+f3O5/0t7+M2D6VGtyRZ6HxBZ/Jn52M4jFrZ/GxZNjLGZ9Q6hsWwcHbfs1u66EIIIZrBqYM6hw3IBjASksmYcBXFX8xEGRYcOSegDCvuHatw71oTNjC7IS6vyfKdJfz+/XX85ZJhTVV8IYSoZXexk9mLd8aV/jzWus5uNbh0ZHaTlleCqXZmSMchdEvuxvby7bWWdzq7E5ZkC/vm7MNT4MFINEgbkUb2jGwMW+0nlgMzB5KTntOSxRZCCNFMUh02zh3enfeW51E3Y3DaiRdhJGdStmgORR8/hrInktClP2njLot6vPrmanF5Teau2sPtk/vTq4OkURdCNL1XFu3Ar+NLf545+fqY6joFXDO+d5OWV4KpduiG427g4SUPh03cmzUxi6yJWfXum2hN5Pqh8T2NFEII0bbdOKEvH63Ox4wwQDtlyGRShkyO6TixzNXi15pZC7fz4LlDGjiaEELEp9pn8sbS3XjNyOnPO5x9J0kDxx9cntR/DEn9xwAN13UWBSfmZNEtPbFJyyxjptqhc/udy+guo+NOIuGwOJjYcyKn9jq1mUomhBCiNQzsmsotE/uRaIstm1Uksc7V4jU1c37YLWnThRBNbt2e8ojL401/XpcC0hPtPHbp8YdRusgkmGqHDGXw90l/5/hOx+OwOGLaJ9GSyNhuY/m/k/8PpcJTqAshhGjffnHaAC7p6ifB9DZq/3huVgyl2Ly/olHnEUKIaMpcHiLdpsab/jyUxYDMZDtv3TKWLmmx3TfHQ7r5tVMOq4Nnz3iWZ1Y/w2s/vobWOmwyX4AkaxIWZWH60OnceNyNYZkAhRBCtH1r88p4ZdEONu+vwOkxSU6wMqR7GteO78MxXVIB8OTmMm32wxx7z+M8sb6KqmofVRFSnUcTz82KUlDmalzQJoQQ0UR74B9v+nOABGvgnvfkAR155MLjmiWQAgmm2jWrYeVnJ/yMGcNm8MXOL3h5/ctsLNqJzWqSaHXQK60X04dMZ3KvydgMW2sXVwghRJw+Wp3PP7/czJ5SN9U+k9CkfWvzSnl3RR79O6Vw54Rscn53B53vuosrLxjH5edpvs8t4rlvt7F+TzlOrw+7xaBjip3tRU7Ch3a3/lwtQgiRmWRHR0g+EZr+PHnQhAaPo4DbJvXjijG96JzaPEFUDQmmjgB2i50pfacwpe8URj38BZ/8/GQ6N1P0LYQQovn5/Zr7P1jHeyv2RJ1nxdRgev2syy/njjmruXL0xdx/4QUAGIbi5AGdOHlAp7D9hj74GZXVvrDl8dys+ExNx5TwSeSFEOJwDO2eFnxQU7vei3eqh0tG9uQXpx/TImWWYOoIYvo1pU4vWcnyBSeEEO3ZQx+t570VebgiZOeLxK2svKF6kvn1Vn522oB6tz1nWDfeWZ6Hr87cVPHcrGQl2+nfOSX+CxNCiHpYLQbXjO/NU19vpe70eTVTPRR//QJmeQFoDRYb9s59ceetP5hxFIj4IKnZytxiZxLNrrjKQ1qiDatFul4IIUR79eWG/by9LHIgVf8cUH7+PW8r4/t3YGTv6NNk3DAhh7mr9oQFUxDbvFSJNgszTukryYyEEM3i3OO786+vtkZc568qRXvddDr/d1GnbwD4YsM+zhvevUXKK8FUO6a1Zsn2Yl5csJ3N+yood3uprPYxfdZSbpiQw0n9OmIY8mUnhBDtyZNfb4nYtS+WOaCqvX6enpfLi9dGD6YGdEllQOdU1uWXEWFoQoNztWg0F47oEf+FCSFEDL7+sQCrocIe+MQy11SNzzfsp8zpJT2p+XMGSDDVDmmteXtZHk98tZlSpxeXx6w1mHjepkKWbi8mJcHKHZP7M21cb3mCKIQQ7UBuYSWb9oWnHI/1JkID320uorCimk6p0ecifOLy4Zz31AKqquObK8phM/jrxcNIdUhSIyFE83h/ZeSW83imb7BaFN9uKeTc45u/dUr6g7Uzfr/md++t5cEP15Nf6sZZJ5Cq4fSYFFRU8+dPN3LHGyvxmrH1uxdCCNF65izdfdg3EUrB3JV59W7Tr1MKr94whmRDo3Rs3w8Om8FvfzqI84ZLq5QQovmUOiNPuxDP9A0+U1Pi9DR10SKSYKqdeeij9XywKj9qdqe6XF6Tr37cz11vr46YalIIIUTbkVtYGTGYiucmotrnZ8cBJ6bfpMRdQl5FHsXuYkx/7e+NIb4Snlw8k2Gdk0iwGlgjdAtXQJLdQrd0B/++cgTTx+c0+tqEECIWOmIzQe3pG2Lhj1CXNgfp5teOfLu5MDgoOfyPqL5ByW6vn8837OeTtfuYMqxbK5RcCCFELJxRJtmNZw4oZStmZeV8TnrzW7ymF4thwfSbWAwLFw24iKsGXUWPhM7s+eWvOOGWa/hg6mlsK6xk1vc7+HB1PpVuH340iTYLJ/bJYsbEfoztmyXdxYUQLSLNYWN/eXXY8nimb7AaiswWym59xAdTFW4vH63ey9aCCspcXjKS7BzTJYVzhnUnOaF9Xf4z83IbPSjZ6TF5et5WCaaEEKINS0+MPBYpppsIoxpH9zewJm8lzw/aH5xLqqYXnx/mbJzDO5vfYUhlBr8dMIiMSy8FoG+nFP50wVD+dMHQg70YJHgSQrSGs4Z2Zee32/D4andBjmf6Bq9fM65vhxYpb/uKJuKweX8Fz327jY/X5KOgVorZJLuFhz7cwPnDu3PTKX3p16ntz5Wxu9jJil0lYcvjyWxSM7B5YNfUZi+vEEKI+A3vlcG8zQW4vfHdRGSdfhmJfZ7GsJahDF+UTjLg0z58po81tv38YozidfcBOiZ2rLWNBFFCiNZ09djePD0vN+K6WKZvUMCEfh3pnOZokfIekcHUe8vzuHfuWrymxozQX7KmG8Xby3fzwap8Hrt0GFOGtUwu+sb6YNUe/BHGPMUzKNnr8/PWsl3cf86QBrcVQgjR8qaOyuYfX2yOuC7qTcT4i0ns9QKGtRRlxDaWwGeFQlcR1//vet48502SbElNeRlCCNFoH6/Jr3ecf0PTNyTaLdw8sW9zFC2iIy6YenvZbh74YH3YU71ITD+4/Ca/fns1fk2LpE9srJ3FTrzm4Q1KNjXsKnY1R/GEEEI0gaxkO6cd25n/rdtHpLHTkW4ibBmLMexFYYFUyXclFH1WhKfAg8VhIW1kGl0u6YIlOfB94dM+8qvyeWvTW0wfOr25LkkIIWK2ZNsBHvtsU8T6r778ADUSrAZj+3ZgTE70ufaa2hEVTK3bU8b9H6yLGEg1lKDh7nfWMKhrKgO6tM0ucK4mGJRc33GEEEK0DT87dQBfbwzv6heZxt5hPsqonUq46NMiCj8tpOeNPUkZnIK3xEv+q/nseGwHOb/PwbAGkvlWm9W8suEVrhlyDYaSBL9CiNb1z6+21BqaUyOW/AAAqQ4rT181okW7Kx9RwdRTX28NG6wGsf0CPKafmfNzeXzq8IP7eU0/n6/fz1vLdlNQ4cZnajKSbJw2qDOXje7VYllCIPC0MpJ4MpsAZCbLRItCCNGWHdstjb9dcjx3vbO6wYDKkrgTZamqtcx0mRTMLaDHDT1IHRZ4QO69RH4AACAASURBVGjvZCf7tmw237WZsoVlZJ6SeXD7Km8Vi/cuZnz38QghRGvJK3GyfOfh5QeoqvY1eznrOmIeQx2orObrTQVhzYI1v4CsM24laeB4DLsDZbGS1H9Mrawfpl/z8Zq9lLu9OD0+/vbZRkb+6Qvufmc18zcX8uPeCrYUVPLDjhKe+GoLY//8Fbe/toLtRVW0hJG9M0m2h7c8hQ5Kdm5ehN/rRps+XLnLKPnmpVrbJtosjO7dcs2eQgghGufc47vzxGXDSbQZ2CzRn7BaUteBUXtiSucWJ36vn7SRabW3dVhIHZZK5frK2tv7nHy+4/OmK7wQQjTC60t2RRwrFU9+AEMpPlm7tzmKF9UR0zL19vK8iJFhvL+A2Yt28u6KPPJKXFRHaOUCDj4p/HTdXuZtKuDF6aMZ28zpF88a2pV7318XcV0smU0AtNZcOEJmrhdCiPbgrKHdGNI9nWdf+pR3C60YCQm4vCZ+DRYDHDYLRkJlIHVVCLPSxJpiRUUIwqzpVlw7w8fOFrmKmusyhBCiflVFcCAXlfsDx/ur2UNH8jmUZTSe/ABVHpPN+yuas7Rhjphgat2eMtwRgp94fgEub2AuJrfXH3EG+rr8OvBLu27WD8yZMZZhPTMaVfZYJFgtXHFiNi8v3BExEUVDmU0sKvCkM9Uh3fyEEKK96JmZyPXzX+bXv/0d36f2ZlexkwqXj7REG307JfN54Zd8vrP2PpYUC75KH9rUYQGVr8yHNeWI+eoXQrRXWsPOhbDwX7DtG7AkcJvH5Ga7JgEv63UfnvWdw1f+EXHnByhxehvcpikdMTVquTvyBxfvL6Cq2ow4P0d9CSxcXpNrX1rKkntPx25tvp6TN07oy5tLd+M14+8PmmCzcNvk/s1QKiGEEM3FtWIF2uslY/xYzokwoHrV0k4oFDrkmyupfxLKqihfXk76iekHl5tuk4o1FXS5pEvYcToktszklkIIQWUBvHohFG8HrxPQ4KsmGQ62tI9UW/i77RmqcHB5jzspjCM/QLTJz5vLETNmKiUhclwYmqAhFpECqfKl71P81fOkj51Kzztm0+PWWaSOOBvXliUHt/H4/Hy2fl9jih6zrukOXr5+NIm2hoPCUA6bwcyrR5LTMbmZSiaEEKI5lLzxJhmXXxY1M9WpvU4l0ZpYa5klyULnCzqTPzufijUVaJ/GU+hh99O7sWXZyBhfuxdFkjWJ03ud3mzXIIQQB5Xnw8wJULgRvFVEvvMOSFFuOlHKx2l/5tgJZ8WcH6Cl73ePmJapYzqn8qWlAI8Z36zxoUkoAr/Q2l9YsWYQqfKYzJyf2+xzVY3sncUbN4/lmheX4DU1Lm/0VOcOm4HFMHjx2lHNPqZLCCFE0/IVF1M5fz5d7/t91G1GdRlFmj0Np89Za3mnszthSbawb84+PAUejESDtBFpZM/IxrDVfo7qsDo4qcdJzXINQghxkMcJs84G5wHwx9bLylCQTDULT17CqOSp7GooPwCac4Z1a47SR3XEBFNTR2fz9PzciOtiTdAQNoqX+BJY5BZUsuuAk14dmncm+eHZGSz47am8tzyPZ7/dRrnLi9trglLYLQaGAUl2KzdOyOGy0dlkJLVcCnchhBBNo/Tdd0k9/XQsGdHH4yqluHbAlTyx8gmqjdoPE7MmZpE1sf4MrgmWBK4ZLHNMCSHCVXgq+GDrB3yx8wvKqsswlEGHxA6c1+88zuxzJgmWhPgOuPpNqNxfK5Dq80QFTi9s/0UKyfbAffgLKzzMXuNl3vRDLUxJuPnD8fu5d8gTUQ9vMRTnDmv5/ABHTDDVPSORE/tksWBr5IxEDSVoiCaeBBY2q8GeUlezB1MAaQ4b00/K4drxffhhRwn3vLua8f06MrRHOr07JDE2pwOG0XITlgkhhGg62u+ndM5b9Pj74/Vu51q7lhH3vkGXC5LY43Bh6tgnZrcoC52TOnP5oMsPt7hCiCPI3sq9PLnyST7f+TkKhdt0H1y3pXQLawrX8PDih7l4wMXMOH4G6Qnp9RwtSGv4/ongGKnaTA3/XOLh3pOjB2c25edCy/c87JuGE0fEbewWg5tO6dtwWZrYERNMAdw+uT/Ld5bU2/UtXnElsNDg8rbsZGFKKU7MycJqGFwzrg8Du6a26PmFEEI0nunXfL+1iO1FVVR5fCTbrfTtlMyw/A1Y0tJwHHdcxP20aXLghRcpfvllet5/Hy9PGs2V/72SIlcRXn/DmaysykqmI5OXfvISyTYZTyuECFhftJ6bvriJKk8VfiJPEVTTrfjNTW/y5a4vmXXWLHqkNDD1zu4lgRToEdw13s5fv6/mttF2MhzRGwL8KC60fMdr5hlh6xw2gz+eP4RjurT8fXC7DqaKKqspdXrQGjKS7Iztm8Wtk/rxzLzcuAKqRJsFj88kQsbxWgksGswgogItRq3hQJWHrGTpzieEEO1BcZWH15fuZNaCHbh9Jj5T4zP9WC0GVovC7nJy9aSruc7pDavbvfn55N99DyhFzrvvYOsWGB/w9rlv86t5v2J14WpMv4lPhz/csygLVsPK4A6D+efkf5LpyGyR6xVCtH3bSrdxw2c3UOWriml7r9/Lfud+rv7kat49712yHPV0K97+HfjC57gDGNXdwqQ+Vh5bWM3Dp0ZudQJIVtWcZqysFUwpIMFm8NC5Q7h0VHZM5W5q7S6YcntNPlydz8x5ueSVuLBZAxGs16fpkZnILRP7csO4bJ6ft4Vqo+HAJtFm4ZdnDOD9FXv4cV/4JF/xJLDw+PytkjHP9GvKXV4yk2QOKSGEaOsW5R7gxld+wDR12PyIps9PtQ+qjASeLzB46a/f8NK1oxgTTCJU/skn7Hv4EbKum06H669HWQ71mEhPSOfFn7zItrJtzN4wm49yPwICAZSpTTSaKX2nMO3YafTPlKkyhBCH+LWfW768JSyZDUDJdyUUfVaEp8CDxWEhbWQaXS7pgiXZgl/7KXWXcs+39/D8mc9HP0FVAejILV0Af5ycwEkvVfGLMfU3DGSqwL16gtVAAxP6d+Tnpw1geHbzzfXakHYVTL27PI/7P1gHgNMTaHnyhDRAbS+q4g8fbUB7vFzo3c32QaNYtbsUv9a1Jrq1WxRKKUb2zuTnpw1gbN8OdElzcO97a6nyhLdoxZLAQik45ZhOdEiJczBeEyh1ekh1WLFaZACxEEK0Zd9tKeSmV5bh9ka/qajh9vnB5+faWUt5bupQ+s/+N67Vq8l+7jkShw6Jul/f9L48MO4Bfnfi7zjgPkCVt4pkWzJZjizsFunBIIQIt3jvYsqqy2rNWQdQ9GkRhZ8W0vPGnqQMTsFb4iX/1Xx2PLaDnN/nYFgNfNrHyoKV5FXk0TO1Z+QTNFD3DO1s4ZxjrDy6wMOxnaLfzyYmJHBm/y4M6Z7OFWOy6ZwavSWrpbSbYOrf32zlya+3NPgFFAiyDOYm9+f2/h157NLjeW3JLjbtK6fCHZg1/thuqVw5pjc9Mg7NzXHW0K7c+/66qMdtKIFFos3CjFYY9AaBLn6tEcQJIYSI3c4DVcx4dXlMgVQot9fPza8sY7YjjRHvvYuRFFuSI5vFRtfkro0pqhDiKDNr3aywVinTZVIwt4AeN/QgdVhgLJK9k53s27LZfNdmyhaWkXlKoKuwX/t5Y+Mb3DX6rsgnSO0aCKhMT9Qy/GGSgxHPVvLrcdHvaQcNOIbnpo6K8+qaV7sIpt5bkRdTIBXK7fPz9LxcuqU7+O1PBzW4fYLVwq0T+/Lvb+IbbwVgMxR9OyUzsnfr9D0/UCnjpYQQoq2bOX8b1b7I3y9VG+ZR/sNcvAfyMOyJ2Dr3JX38VBw9Ay1QPouduSPPZVSMgZQQQsSq1F3Kiv0rwpY7tzjxe/2kjUyrtdzisJA6LJXK9ZUHgymv38u7m9+NHkwdey58/ad6y9E/y+CyITb+tdTDcZ3DW6f8pgVj+JUxXlXLafP9wqp9Jg98sD5iIFW1YR57X76TXX+/hLynprH/rQdx560/uN7lNXnww/WBOZhicPvk/kwa1IlEW+wfixU/HVLsvHL9mKgz1De3A1XVdEyRYEoIIdqqqmof76/Mw4zwTLB86fsUf/U86WOn0vOO2fS4dRapI87GtWXJwW18wLvL9+D0tGzGWCHEkW+/cz82S/i4e7PSxJpiRVnC72+t6VZ8lbXrI5fpotqsjnySzD7Qo+EWpQcmJlDliZARDqhQiTyd14cDlVHO0UrafDD1v3X70Dr8Q43ly6fGJ2v3xnQupRRPXTGC84b3INFmiTCFb21JNoPuzmLmjEts1ZahYsnkJ4QQbdoHq/ZgRHjg5q+uonTBa2SdcStJA8dj2B0oi5Wk/mNqJTiCwNjcj1bnt1SRhRBHCZfPhYpw12tJseCr9KEjpLv2lfmwptTu4GZVVtw+d9i2B510J9hqt67vuDOV0/seOk52uoH7vrRaE/YCuLSdmeYU/vXNNsY/+jW3vLqMDfnlsVxes2vzwdQz83LDkkLE8+VT5TF5Zl5uzOezGIpHLzqOl68/kcmDOmO3GiTajIN/YjaLItFmYUCXFB6+8Dg+uOwYqh/4Ld79+w/3UhutqNJDh2QZMyWEEG3V8p0lBxMnharesxHt85B0zLgGj+H0mCzfWdIcxRNCHMXS7GkRGy6S+iehrIry5bWDFtNtUrGmguTBtQMer99b/7x1A86AAT9BWxOjbxOBR1vZrrvykvlT3F4/1T4/n23Yz8XPLOSTta3/gKlNj5mqcHvZWlAZtjyeLx+AHQeqKHN6SY8xdXjNRLgn5mRRUOHmiw37OVDpwevzk55kY2zfDgztUTPbc0/Mq65mzy9/Re+X/4OytXx68uKq6laZpEwIIURsSpyRJ9I1XeUYSWkNTwpfc5yqhifkFUKIeHRL6Yapwx/2WJIsdL6gM/mz8zEcRq1sfrYsGxnja6cj75LcBatRT2ihFFz0HNWPnIjdvwvDaHgYjlvbyNOduMpzL9Uc6oWldWA4z6/eWo3dYuH0wV1iv+Am1qaDqVKnF7vVwFfnaV68Xz42i0GJ0xNzMBWqc6qDq8b0rnebDjffhGvlSgoee5wuv/stWmuW7yzhue+2sS6vjCqPSYLVoFuGg+vG5/DT47qSYI2t7JH4TD9fbSxg7so9FFRUs7WgkuzMUhw2C+cd3x2HrfHHFkII0fSS7JHrZUtiGn5nOdpvxvSdlpQg9bsQomklWhOZ0ncKc7fODQuqOp3dCUuyhX1z9uEp8GAkGqSNSCN7RjZGSI4Bh8XBqIwLuff9tRRXeUiwGvTMSOSCE3owIOSBf8l7cyle1JWcn50Nq2YBCrzhkwQ7tR0Dzafmidyx5hiKfnggYoIet9fPz95Yyby7JtElrXXSpLfpYAoCkWdd8X75QCAYbi7KMOj+l0fZfvElzO9xPE/uTeBAlQeX16xV/oKKan7//lp+P3ct08b25ldnDMRujb2nZVW1jxe+28ashTvwmn6qqg/9wZe5vDz04Xoe+nA9l4zsyR2T+9O5lf6ohBBC1Na7QxJWQ+Hz1/5SS+gxCGW14dy8iORBE+o9hs2i6JUl2fyEEE1v2uBpfLztY0wzvLUoa2IWWROz6t3f7TX5cEE3qty7Di6zGPDigu0M6JLKrZP6cYpzF4VP/JPer83GyMmBM+6Dde+iv/8n3gPbsWkfPiwUkc5/fGcyx5zMrqVfUrbkP3Q483YcOSNQFiuu7ctxbVlyMNupX2teXbST3/xkYNN+KDFq08FUZrIdb4TUR/F8+QB4fH4ymzlBg5GezpzpD/LKmiKqLdFTuNeM//rP9ztYlHuAV28cQ5qj4RazgnI3lz+/mD0lLqp9kY9f0x//9SW7+Gj1Xt64eQyDuqZF3FYIIUTLuXRkNi98tz0smDISksmYcBXFX8xEGRYcOSegDCvuHatw71pTaxywoRSXjsxu6aILIY4C/TL6MbLLSJbtW4bHH30uqEi034an5EQ87tr32qYfTL+ftXvK+PWcVYzau4F//eUvJOTkBDawJ8GIaSxOP5sbXv4Bt8eLPySdQ02OhA5n30nSwPEHlyf1H0NS/zEHf672+Xll0Q5+cfoAbJaWTwfRphNQpCRYGdwtPBgI/fJxbl6E3+tGmz5cucso+ealsO0Hdk2NKWA5HE/Py2X25iqqY5xd3u3z8+O+Cq59aSmeKMFRjTKXl4ueWciuA86ogVQon19T4vRw6cxF7DwQ3nQqhBCiZfXpmBwy1ra2tBMvIvPUGyhbNIe8J68i75npVKz4mMQBtccFH5+dQa8O0jIlhGgef5/0d7qldMNmxH7PrP1WTGdvPAVT6t3O5fOztPNAfrXVjr/OQ6X/LNyOy2PWCqQgvhwJptbM31QYc7mbUptumQK4ZVI/7npnda0ubRD48jGSMylbNIeijx9D2RNJ6NKftHGX1douOcHCLRP7NWsZN+2riDqpcH0TMXp8fn7cW85z3+Zyx6kDoh7/V3NWsb/cHfZEs6HjV1X7uObFpcy7a1KrzYElhBAi4PbJ/bj9tZURJ4ZPGTKZlCGTo+6bZLdw66Tm/S4TQhzdkm3JvD7ldW754hZyS3Nx+pzRN9agtR1f5UDcey6jpn2mvvvSaq1YvK2Yp77Zys9PO3Tfu62wikgzS8WTI8FranaX1FPeZtTmg6kzBnfBakQOBBr68oFAt4ifDOkadb3WmsXbinlvRR77ygIBS2ayjdMGdWHKsG4xJXN44btteCPk4C9f+j5lS96pt5+n2+vnxQXbuXVSfywRrnNvmYsFW4sadXy/hsLKahZtO8D4fh0bvA4hhBDN59RBXbh4ZA/eWZ4X8eFbNIk2C5eM7MnkgZ2bsXRCCBFIk/7KT1/h611f89K6l9hauhWtNR6/B4XCbrFj+jVeZ29cRadgVvWH4ARCsdz3urwmz3+3jVsn9TvYJS/SAyaIL0eCz/RHnH6iJbT5YMpmMfjbJcfz8zdXxvXlA+CwGfz14mERkzx4fH5eX7KTZ7/dRpnLi8tj1oqK528q5IEP1jF1dDa3TuwXNZlDhdvLR2vyMeu0GsXaz7OmLN9sLIiY1vHVRTsjRuuxHt/lMXlu/jYJpoQQog146NwhzNtUyN4yd9j3RiSJNgvnD+/OQ+cOaYHSCSEEWA0rZ/Y5kzP7nEluaS6L8hdRVl2GYRhkJmTy0ueJbM6vHULEc9/r15rP1+9nyrBuACTbI4cj8SXoMUh1tE5Y06bHTNU4c0hXfn/2sThssRfXYTP43U+P5afHdQtbV+byMvXZhfzlfxvZW+bGWSeQgkCiiCqPyezFOznziW+jzrL85Y/7sUToQhdPP88qj8lrS3ZGXPfakl0Rx1TFenwNLNx2gJKq+AYTCiGEaHrPfruNrGQ7D58/hJ6ZiSTZLdT9BlEq0K0vOzORP10whD9fdBxGlB4aQgjRnPpl9OPqwVdz+wm3c+vxtzIq6xx2FYbnB4jrvrfa5Llvcw/+PLh7GpGquHhyJFgNRf9OKfFfYBNo8y1TNaaN60OXNAd3v7smLC14qOQEC1bD4C8XD+OsoeHd+9xek8ufW0RuQRWeCJkC6/KamlKnl6nPLuKDO06iX51f1L6y6ohJIeKdCyu/1B3h3H7K3Yc/0aPdYrC3zN3sGQ2FEEJE9/XG/by6aCcf3HESXdIcXH5iL5btLOHFBdvZur8Cp8ckyW5hQJdUbpiQw8jemTLeVQjRpmwrrMRqGEDte99473t3Hjg0vumGCTn8b92+iN39Ys+RYGVs3w7xX1ATaDfBFARaqE4d1JmvNxYwc34uq3aXYtN+UOBTFo7rkc4tE/tx+rGdsUZJjfjAB+vZVhg5kKo3mYPHx1UvLOH7e06tNbbJ4/NH7KoR71xY1b7wPyCnx8RqqIjjpeI5vlLg8voaLIMQQojmsbWgkrveXsNz14w6OLGkUorRfbIY3af++VuEEKKtqPL40BEmgY33vjc0cBraI50emYlsLaiMuG1DORISbQY3ndy31Vrw21UwBWC1GJw5pCtnDumK22uy7R9PYiSn0Ofm6xpMFlHm8vLBqj0RW5IaGjSnNVS4vMzbVMBpxx4a25SsTGwGeOocMt65sNISw9NQpiRYI2bwi/f4WkNKQvOmhhdCCBFZudvLza8u4+6zBjKyd2ZrF0cIIRot2W7FiNBiHu99b9179t+cOZBfzlmJK878CBAYLzV1VOvNwdfugqlQDpuFjoYPw6Zjyrr39rLdEf8AYh00V+UxeebLTYzetRrnsmU4ly+jY4mJMfp6MGp/lPFMxGi3KMZFaJq0GIouaQ72lYV3AYzn+F7TT4/MxAY/H9E0yt1eFuUeGqeWkWRjbN8OZCRF7ma54cAGPtvxGfur9mNqkyxHFif1OImTup+EJcbmciFE22T6NXe+uYoJ/Tty2eherV0cIYQ4LDkdk/H6wwOeeO5LAXpl1Z4z76yhXVm9uw//Wbgzana/SBJtFl65YQzpSa3XaNCugykIpDYPG70bxazvd0T8BcUzaG7NrmI2rf+GnJFD6Xr//fQZPJgn/rWQHQfCc9vH2s8TpZg2rnfE8904IYfHP9/c6H6kFqWYclw3UhLa/a+6zVufX8aL323nv2v3YrMYB7t/WgyF1/RzxuAu3HRyX47PzsDn9/HJ9k94ce2L5Ffm4zE9+EP6H8/dOheH1cG0wdOYOnAqafbwyauFEG3f37/YRFW1j/vPGdzaRRFCiMM2oEsqvbKS2Lw/vEtePHPA3nhyTtj+d581iASbhWfnb8PtM4nQm/CgBKuB3Wrw8vUnMjw747Cv63C0/ztsTcwDdPeXh7fwQHyD5hKSE9G3PUjHnEN93G+Z2I8/frwhYn77WObCUsCavDJ6ZCSGXculI7P522ebou7b0PHtVoMbIvzBiqZj+jW/f38tc1ftwWtqTL+O2JX0k7V7+erHAiYP6oAzayYbitfi8rkiHtPpc+L0OZm5eiZvbnyTWWfNIju19ZqwhRDx+++avcxdmc+Hd5x0cD4VIYRo726d1I/73l9HVaPve1XEJHFKKe48/RhOHtCJ577NZd6mQoBa91TJdgsWI9AIce24PlGnLmpJ7b921zqQYaEBfr+OOv4odNBcLKqqaydzOH94D9ITbRHTOjbEYTO49+xjefLrrUx9dhGrd5fWWp+eZOOSkT3jSgtfw2ZRDOmexpDu6fEXTMTE79fcOns5H6zKx+2NnIzk4LY6MODyf+vzWbDseJze6gaPX21WU+gq5Mr/Xsm+qn1NWXQhRDPakF/O/R+s49lpI+mQktDaxRFCiCbz06HdSIhheE0kiTaD607qQ4I1+v4je2fy7LRRfP/bU/ndTwdx8Z6lTB3akdsm9eNvlx7P8vvP4K6fDGoTgRQcKcFUDP38DENhs0TeLnTQXCxS6kwKlmi38NaMcaQ64guoaiYVvnZ8Hz7+2QQuHZnNTa8s45dzVpFfeqjF4qHzhjCkezoJESYfjkYpyEyy88K1o2IvkIjbXz/bxHdbiuLq3+v3W/E5s6ned97BZSXflbDlvi2sv3k9G3++kfyX8zGrAsf0az8VngpmfDEjYgYdIUTbUlzl4eZXl/HguYMZ2kMeZgkhjiwOm4XXbhxDkj2+gMphNRjZO4s7Tz8mpu07piRw7fg+3LjqfR6degJ3nzWIs4/r1uZa+ttWaRojxpYpgJ6ZSRGXxzMpmMfnp3eH8ONkZyXx8c8m0C09keQG/rgcVoNEm4UnrxjBecN7AIFxNVNHZ/P1bybRMzORs//1HY9/Huhrb7MYvHbjGMb16xDzH67WUOr0cMPLy/h64/6Y9hHxKXN5mfX99oiBVNWGeex9+U52/f0S8p6axv63HsSdt/7QBtqOt2wkfl8KRZ8Wse/tfXSd2pXBTw+m7/198RzwsOOxHfiDTdumNtlbtZfl+5e31OUJIRrBa/q5/bUVTBnWjfOD9bsQQhxpju2Wxps3jyXNYY3aWBEq0WbhpAEdeeHaUbWmGGqIv6ICw27HsLfduVLbZTBVWe1j9uIdXP3CEq6p6MfVOzO54T8/8PGafDwRxqrUuOnknKjBSNqJF5F56g2ULZpD3pNXkffMdCpWfEzigNpJKcb2zaJzauRmxeysJL75zST+cskwhnRPw2EzSE6w4LAZJNktJCdY6JBs52enDWDBPZM5Y3CXsGOkJFj59ZkD+eTnJ7O72Mmpj8/jrR92Y7MYvHTtaP4+dTjDszNi+sP1mJrlO0u4/bWV3Dd3bb1d0ET83l2eFzE7ZPnS9yn+6nnSx06l5x2z6XHrLFJHnI1ry5I6W2rce4dTMLeA7ld3J3VYKsqqsHeyk31bNp4iD2ULyw5u7fa5mbV+VjNflRDicDzy3x+xWw3u/smg1i6KEEI0q2E9M/j8lxOZNrY3yXZLWGOCoQK9sI7pksL/XTSU56eNiin7diizpARLVtuei69dJaDIL3Xxz6+28MGqPRhKBRM+JEE1/LixgMXbDvA7Yy3Txvbm1kn9SHXUTpN4/vHd+OPctUTrFtjQoLlku4UZp/Srt4x2q8E5w7pzzrDubNlfwY/7Kih3eUmyW+iWnsiJOVkxReTdMxJ54vITWLW7lIc/3sCshTu4f8qxnDW0K+mJNq6dVffGPDqX1+Td5Xvw+jSPXnxczAk7RHRaa577bltYq1SsafYDB7FTuS4Tv9dP2sja2fosDgupw1KpXF9J5imBeWk0msX5iyl2F5PlaNsVixBHo7eW7Wb+5kLm3n5SXE9ehRCiveqa7uCBc4dw91mD+GTtXr7fWsSeJStI6dGN3sf04sITeh5Wd2dfcTGWzLY9P1+7CabW7SnjqheWUFntxYzS+FSTVeTFBYH01HNuHkfX9EArknvTZgofeojz0ofzYYehuM34WmkshqJruoNx/cLng4pmQJdUBnRJjes8dQ3PzuDtW8bx6bp93PPeGnI6pLBsZzEeX3j5qzbMo/yHuXgP5GHYE7F17kv6+Kk4eg7BHWZJgwAAIABJREFU5TX5cHU+o/tkckkrTmx2pKio9lFUEZ5AIp40+wBmlRNLig0VoaXRmm7FtbN2tj+7xU5+Zb4EU0K0MSt3lfCXTzcyZ8ZY0iNMwi6EEEcyh83CRSN6ctGInuz+5jkyhl9C6qlDDvu4Zkkp1jYeTLWLbn65hZVc/txiylzRA6lQ1T4/eSUuLnr6e0qKyyl4/HF2TZ9O+vnn8ciTv+KE3plxJXMwFKQ5rMy+cUyrtOoopTj7uG588cuJKKUjpmCPpWuZy2vyz6+2SBKDJlDh9kUcABlPmn0AS1IKZqUPHSG495X5sKaEP++o9IbP7SCEaF711ZsF5W5ue20Fj148jP6dD+8BmhBCHBGa6F7TLCmWbn6Hy+/XXPPiUqo8vrB19bXEmH5NYbmbGfe+zD+S99L3ww+wduoEwKzrTuT211ewKPdAxMAklMNmkJVkZ86McXRLT2yWa4yVzWKwbk952PJ4upYdqPKwbGcJo/u07T/Mts5uMfBHqChC0+zHNG9Z94Eoq6J8eTnpJx5qBjfdJhVrKuhySfi4umRr8uEVXgjRIK01i3IP8Oy32/hhRzEur4lFKVIdVs4f3oPrTupD7w7JVPtMZsxezhUn9oo4DlYIIY46h9vwULYH1rwFpTtwbF2NLdkKa96GweeBte1NNdHmg6kFW4sodXrCAtzype9TtuQdOpx5O46cESiLFdf25bi2LMHRM9Cs6NWwKqsv6u4bsWYcCoQcNgvPTxvFZ+v38cz8XDbvq8Dr99dq9UqyW0hJsHLTyX25/MTssPFXreH7rUW4feHBXzxdy1xek5cWbJdg6jBlJNkiBlOhafaTB01o8DiGI43O53Ujf3Y+hsMgZXAK3hIv+a/mY8uykTG+9qzeHtND95TuTXYdQohwn63fx0Mfrqfc5a01KaVPa0qcXl5bspM3lu5iWM90Oqc66JLq4I7J/VuxxEIIcQTY/i0s+Afs+B7QYHo4mPLt4x/h4zth5HQYe9v/s3fe4VHU+R9/zcz29IQECAkk9N57tSsoFiwodlFEvTv11PPneXd6d57neZxdURGwYEHsIoiKSpcOobeQkEJIr9tn5vfH0sLOJjsQSALzeh4fH3Znd7+72Z35vj/l/YGYpuOW2uTF1FtL9gVNWNbV5C+IfLAqm8fH1nZWEkWBsb1aM7ZXa/YWVvFdxkEKKj14/TKJUVaGd2jByI4tEJtQE/H+4hr8GuVgekrLVBX2FBplYidDTqmTDQfKqHT7sUoifVJiWZ9dxvF/keNt9gVRwpbeD0E04c7ahPtABnHn33Xc0QqSI5PEK+KQolQK5hbgLfQi2kWi+0eTem8q4gnDmge2GkiCPfy+PQMDA328syyTaT/swu0LXVPuk1VAZV1WGYIAc6cMa1LXCgMDA4NGR0+Zn6rCj3+FtTPB59Q+xnt477rmbdjwHkyaB+3C608/3TRpMVVS7WFtVlnQ7XoyMV5Z4cM1wWLqeDomRfHgRU2/zr3G68ev0TSmt7TMqVEyaaCNrKgs2V3I9F/3kZFbgUkUkBX1aO+c1qkievAExIg4KlbNpXj+NASLHWvLjkQPm1j7QNGHJWEpAPFj4okfU3e20GFycGfPOxvibRkYGGjwxfrceoXU8agE9gB3vbeW734/irYaMwgNDAwMzjn0lvktfBw2fhBaSB2P7A38N+cauO0bSB18cmtsQJq0mDpY4cZiEvGeICD0NvlXuf34ZKXJTUzWS6TVhEkSkU+YpaW3tCzC0qT/7E2GshovN7+zmuySmqPZ0WD/Pm3qs9kHBUF0ITn2hfV8oiDSwt6CIa2G1H+wgYGBbiqcPp74cgsejVmFdfXnAtR4/DwybxPzpg4PeqyBgYHBuUCFy8e8dTn8squQIsdwrGu8tDm4nhsGpTCmc1LocREZ84KEVNpLVTh9sP/BSCIsgce9s8HLnAwfv95xuG/c54I518JDGWBvXLe/Jr2rdvlkTXGrNxNjEgVcPrnZi6n0FhGYJAHPCYklPaVlggCdkiL1vfDBzbDyVcj8Fbw1IIiBL26fG2HQ3RDV6tTfXBOjrMbLFa8up7DKfbikp4ER/NhT30MQwnhuVcAs2nlxzBvGjDADg9PEp+ty0LrWh9Ofq6iQkVtBdkkN7RIMgxiD5klxtYdP1hxgfXagnN1hlkhrEcHNQ9vStVV0/U9gcE6SVVzDSz/tZuHWAkQhsN9GioFKlS3bCli2twirSeKukWncPbJ97aG9qgq//EszIyWr8PJqL38eVYfhhOyDjXNg+O9PwzsLnyYtpqJsJs2SS72ZGL+snhXZmOEdWmAzSdR4gk0owi0ts6kyt3SwBT1ek/3LYMFjUJ4Ffi+ox72utxpWvgIrXoH2Y+CKFyEm5RTeXdNBUVRum7X6tAgpAbBbJNp3XcIhpQx33WaSqIqEqtipOXAf41/czvg+5Tx6SZej89MMDAxOHUVRmbEsE9cJ5X16+nMVVeXdlVk8Nf7U56oYGJxJtuZV8PLiPSzdXQRQKzu7Yl8x89bnkJ4Qwe8u6MS4Xq2MoJ7BUdbsL+XOd9fg9irIIXqkajwyNR6Z1xbvZeGWAj68ewixDkvgztx1UH1I83GPDbfw/AoP9w+yEGsL8Z3zu2DlazD0ARAbL2HSpBVGapwDn0aPkL4mf0iOtZ8V0+glUWDyyHReWbwHt0YpSv2lZRAvKSQ8MoWc3r2Jv+N2HENCzM7a9DHMfzjwRQ2F/3DR297F8OZIuGMBtOyu5y01SVbuK2FfUY2mkKqv3CcUFkkEAYakx/PU+O60S7iQL/Z8wextsylxleL2u+G4LJUqWwEBb+kwfGUjUeUIQOHLjbks3nGIj+4ZSrfWRqTQwKAh2FFQSfWJKX/09ef6ZJWvNuYZYsqgWfHNpjz+9HkGHr+iGbxWVHD7FHYUVPHovM0s2V3Ivyf0Drmn2ltYxczlWfy88xDVHj+iIBBjN3NV32RuHZpmBALPIjJyy7l91ppAJioM3H6F3YeqmPjWb3z1wAjsFgl+eyNQrqfBwGSJ89JMTFvp4ZkL6vjeeKsha1kgsN9INGkxFWE1cXmv1ny9OR9Zqf0rDzcTY7dI3DMq/Uwu+7Ry4+C2vP7rPtAQU/VhN0s8fHVvOv11MRXffEPBP59BMJuJv/12oi8fh2g5HCnYtbB+IXU8qgyuMpg9DqYug9hU3WtrSry1dF/Iwcj1lfsAmCWBFpHWo9G9KFvge3zL0HYkH2fRP7HrREa0HM+4t97Ha12HYK4EFJAj8Vd3wV/VA6hdxiorUOb0MfGtVcw3Gt4NDBqE4mqv5ubwZPpzDQwaCkVV2Fi4kfzqfFx+F5HmSDrHdaZjXMPY8C/Yks+fPs8I23DF5ZP5dvNBFBX+e13vWoHYTTnl/O2rrew+FDxqpsrtZ8ay/byzbD9D28fzzNW9SI03rl3NGY9fDimk6go6+2SVrJIanvpmG89f1xsObUXbyivAP863MmJWDQ8OsYRejOKHkj2GmKqLyaPSWbD1YJCYgvAyMYqiMmHA2VF+BhAfYWHW7QO5ffaasE+AEBBSV/VN5tr+KQiCQNwNNxB73XXUrFhB6bvvUfTCC8RNuonYCeMxfTY5SEiF1QzoqYQv74U7FzTY+z3TFFS4WbO/NOh2PeU+flmle3I0M28fVO/r3fvBemoq2iKrtQVo4GT0SMgMWLXHz13vreXHh0cbJRcGBqeIlksq6O/PVVQVVVWN36TBKVHhqeDLPV/y/vb3qfHVAAFhJYkSsiLTNrotk3tO5uJ2F2OWTm4GZm6Zkz9+ullzH1HXZtjlk/ku4yBD28dz3YDAdWvRtgIe/GRjnXsS7+Hg4rI9xVz+yjI+umcoPdvEhDzeoGnz/dYCTbOecILOHr/C15vy+MsV3Yj21tT5Oj2TJK7obOK55V66JYYo45P94K485fd0KjR5R4YeyTH0TonFIum/ONnNIjcNbkt0Exi425AMaZ/AzNsHYZYEwvlU7GaJGwal8Ow1vWpd5AVRJHLUKNrOfIfUme/gzcuj+P4LULxezec50gwYElWGvPVQul/nO2o6bMmrCJTknYCech8V2JxTXu9xW/MqyCyqCaozrlzzJaWLZxAz9AZSfjeHNvfNJqr/OFx7Vh89RlEhr8zFpjBex8DAoG5iHWbN4Ojx/bnhYDNLhpAyOCVW5q3kks8u4fVNr1PkKsLpd+L0O3HLbmp8NbhlN7vLdvP3VX9n7BdjyanMOanXeXdllmaQOpzrj8sn8+rivaiqysp9xfUKqeNRVKh0+7lpxm9kl9S9kTZoukz/NbiC50jQOf7i+3B0GY5osSFIJhwdhwS134iCwOfrcsFSf4by7+fZmLHBS15liAyWZAJr4443avJiCmDGbQNJirZh1iGobGaRfm3j+Mvl3U7jyhqPER1bMKFfG/q1jcVulnBYakdNTQJYVJnB6fG8cUt//n5lzzqHSto6dyb5n/+k5UgrouDTPOax4RamrfRQ7q7DlEFVAgPVmilVbh+KRuG43nIfrTLBE5m5fP/RaN0R9JyMPH6ZGcsyw1qPgYFBaLq1jsan1N2f69y9CsXnRpX9uPato+yXWbWOFYBBaXXPijMwqIufD/zMH375w1HxVBdOv5MiZxETv5tIVkWWrtfx+GU+XpMT1Bes5/pTWO1hbVYpU+esD5ndOvjeQxx44TpyX7uVQ58+hTt327H7PX4e+GiDrnUbNA0yi6rJ0hDCeoLOLp/M7JVZ0KIz1JMW6BgvMrGHmVfWhAjmiyaIa9x2niZf5gcQYzfzze9GHp35U99G1WGRGNWpBa/c1A9TM7dDrwtJEpnQP4UJ/dvw7eZ8NmSXU1rjJcIqkRJpYuBzf2TUX79CigzTqrd4D4KzOOTdYTUDyl7Y/DFc9u+TeEeNj8UkNowdfz3C3+tXWLDlYFBWSs/JSFHhp+2FuH1ybatRAwMDXTgsJq7pl8K8dTn4T6E/994x7c/ksg3OInaV7uLxpY/jkcOdZggKCtXeau5cdCffXv0tkZbwxp4s3lGIVipWz/XH7ZP5z/e7kDWMmsIdJ7C3sJpdBVV0adW4WQUDfeSWuTBLYpCI1ht0Lqxyw5D7jo3dqYO/jbHyQYZ2oB+TBTrU3fJzumkWYgoCvUJfPzCC77bk8/z3uzhU6UYShaORFaspIJr6tY3l3tEdOK9L4llfblHj8RNpNeGwmJg4qC0TB7Wtdf+BL9tRs3w50ZddGt4TVh8KKPw6CKsZ0F0RmB3QDD//VtE2zVZIvXb8LSLqmIsAlLu8mrEYvScjSRQoqfHS5jhjCwMDA/1MHpnGlxtyg8QUhNefG2M3M6x9wulansFZzmsbX9MUUmXLyiheVIy30Itkk4geEE3L61oiRQSuESoq1d5qvt73NTd3u1nzueUaH/4iJ4pbRjCLFOdW4tHIJum5/qgqbMurCHIW1tNf7PMrzFyeyfPX9an39QyaDk6vrOn8qDfo7PUrkDYSbLFBYirrodoCOzVGxP0XDQdjky0gyMLcM50umo2YgkDW4Jp+KSzYUsCNg1KxmSVKaryYRIGESCuXdG951jvEKIrKuuwy8std7C6oomWUlZ5tYuioMYg36qILqVq8OHwxFUZELKxmQFUNlPsJzS9b0r9tHDZz8CwvPXb8drPELUPb1fk6bq+iWXap92QkiuAKo6TQwMCgbjomRXFJj1b8sL1Al7kPBMrKn76yx1kfwDM4PRS7ilmZvxL1hFBe8cJiihYWkXJ3CpHdI/GV+cj/IJ+saVmkP5mOeDiI7JbdvLv1XSZ1nXT0O6iqKt7sSqqW5uLeXYZgOna9Hu2VeUtx8CEeluLniAel3uuP1ogWPdktWYVvNucbYqqJ4PHLLNhykFnLs8grd+HxydgtEp1aRjFlVHvGdE5EFAWibSbNWLneoLPNLAWC7qMfg0V/1hzcWy+CBAPu0P+4BqZZiSkI9LT8tq+E/93Q56wzlqiLshovc9flMHPZfpxePwiBTfT+khre/y2bTklRTB3TgUt6tMR8uLQx8oILKHzpZVSfD8Ecxmdli6Uui8oj/P08G/3fquaRYSGyLyZLo0cJThZRFLh7ZDovL94TtKEKt9xHUVWurcdBMspmwq9RHqH3ZCQrKtG2ZvczNjBokky7vg+3vONmc265plOVFjazyKOXdOHSHq1O8+oMzlbm7ZoXJMRll0zhV4W0mdyGqN6BKL0l0ULq/ansfmw3FSsriBsdd/T4Sm8l6w6tY1CrQcjVXopnb8Nf5ET1KaCC6j8WdDMBXZB4HDt/ROURnOxC0X390UJvdYXXr+Dxy1hNtY/3+GUqXX4clkBPuBGoOH34ZIUXftzN+6uyQVWpOS5AW+OVKa4uISOnHJtZ4sELO3FJj5ZB/d6gfwZseovDLSgD7oDslbDz25AzpzQx2WHiBxCZdDJvu0FpdruwH7cfYkj7hHNKSP2ys5D7P9wAqLhO2OAf2fBvyavgT59v5t8LLXwyZSgpcQ58cQn82PNC7vvPTxR4Bbx+BatZpF1CBFNGtWdsr1a1T2CJXQMWk/VwfDNgrySN7FTrfqfydhudiYPa8vLiPZr31VfuY5YELu/Vmhh77e/n9vxKPlydzb6iGpzeQHmmlh+I3pORRRKJj6ij5NLAwCBsLCaR9+8azMjnf0Z2Burztcr+IOAWq6rw7NW9zqrxGwZnntUFq4NK/Jx7nCg+hegBtUubJJtEVO8oqrdV1xJTHtlDRlEG/SP6cOjVjSjVvkBjUh1EIAACrxHBozjZrOP6IwrUarU4uj6d2S1JPLw3MUlUuHzMW5fDO8v2U1jlxiyJyIqKWRK5dkAb7hqRTvvE8PrCDMLD6fVz28w1bM2vqDMjX+OVqfHK/HvhTjLyKhiUFs/yvcE99uEGnSOsElPHdAj8QxDg6unwjQm2f1V/hkqQwGSF69+FjhfqfcunhWYnpr7LOMj4PsmNvYwzxvyMfB6dpz0L4kRqPDJur5vLX1nOJd2TmJ9RAElDcVUfe6zfI7M9v5Inv9zCk19t5e6R6Tx4YadAyZk1EnrfABs/BLVuURWyGdASCSMf0v0+mxLxERb+d30fHgnzcz+CJAq0irbx9FWBBltVVVm4tYBXFu8hq6QGn6zWsqINZa4Y7snIYhK5ZWi7s9pkxcDgTPPp+hxS4hzMmdyL91dl8+XGPCRROFrWIisqUTYT94xqz/UDUolxnDuBPYPTQ6UneEaOXC1jijQhaJgZmWJMuLJrR/BlVabcVU7RjAyUai/oqFS1I/A8DiZTQ26Y1x9REFA0xJre7JZfUbFKIk9/s42P1xxAFISjg2CPZIf9iswna3KYty6X3ikxvH5zf5KiQphgGYSNX1aY/O46tuRVhJ2JPzJnbFiHBMxSsJiG8HpMRUGonc2XTHD1G9D5Ulj2AmrBNkBGOL5ayuwItJB0vwpGPQKJXcJa85mgWYmpCpePNftLefmm5p35CJeM3PKwhdQRZFUNRHfW5x2+RXujfSSN+/bSTDLyynnrloFYTCIMvR8y5oK/tpgKvxnQCp0uCXu9TZXLeydT7fHz1Dfbwvr8LZJA61g7c6cMI9pmRlZU/vrVVr7cmKc5IRzqDhqGczISgFuH1d2bZWBgED5ZxTW8+ONuPrtvOB0SI3l2Qi/+ckU3dhyspMLlwySKJERa6N462ig7MmgwLFJwdYEUKeGv9qPKapCg8lf4MUUGb99MJQpyuUdTSK3JzeDZX6azuzgLURTplNCOpy78PX1bB8bH2IC7sfI0rrCuP4qqal7D9FZXdEqK5I5317LxQFmdG3q/ouJXVDYcKGPsS8v4/L7hpLUI06nYQJMPVx9gU472517f4OafdxZiN4soKprzyurCbg5kpSymE/anggA9roYeV5N7/QW0Gp+O2VQdyFTZ4qDtMOh7E9hjT+VtnxaalZj6YVsBwzsmEGltVss+af67aFfIjXxdX3Q9uHwyq/aV8PDcTbw2qR9CUteAGNrzI/h11K4CiiLB6L8iNtN+qROZOKgtqXEO/vXVOvaVePCpAjK135vDIqGqcE2/NjwxritRNjOqqvKXr7byVR1C6lSxmUQu7tGS1jGGi5+BQUMgKyqPztvM7y7oRIfjSokcFhMD2hnzowxOH8mRyWwr2VbrNkdHB4JJoHJ9JTGDY47eLrtlqjKqaHldy1rH2wUzkftUVG/wnqHKU8Odn/0f/7rkj4zvej5e2c+a3M1YjxNxEgKjMBEFVIWx5rr2z+FWVzjMEjazxIbsMk0zCy1kBcqcXia+tYqFD402ytxPElVVeWvpvqDWEQjP2h4C7tm7Cqopd/nCFlR2s8h5XRO5/7wOIY/x5eXhyvNguvltkJrHfrJZqZLvthxkQv9zozb9YIWLNftLNe8L94t+PHWJL7dP4Zddhfyw/VAg7TphBsweC4U7whZUqtmO09uTwue+IfXNsZhbNfNmbFWFjXMYvmwa33kK2WVOYKb/MlYp3anBjgmZFmIlN8fs4eorriSiS6+jD12wpYCvNmkLqYYQwVaTSKeWUUy73nBAMjBoKGYuz0QSBe4cntbYSzE4x7iu83WsyFuB03+sV0RySCRdnUT+nHxEm1jLzc8cbyZ2eO3ovKzIDD+oXfaUWZoDwNXdLwLALkqMSR8cdJwKXIGFjwkxHFUH4WS3nD6ZbXkVaFSK1XmtVFQorfHywo+7eObqXsEPDpeKPFg7A3bMD4x0EQSwx0OfG6H/beA4e4Moq/eXUu4MbtXQY22/PruceVOH8cCHGyiu9tYdPFZV7GaRK/ok8+9retWZ2a9asoTIUSMRmomQgmYkpspqvKzPKuP1Sf0beylnhA9WZWv66un5oh8hHPHl9Mq8uWRfQEyZbXDnQph3B+xfAn53oE5VC8kKAggX/JWIofcTM2sWWTfeROr0N7B163aKn0IjIfvgy6mwa8HRRsguopPnLTOCj60CPvsYzn8Shv8egFcW79G0Kz8ZEXwiDotE/7ZxzLhtYJD7kYGBwcmx51AV03/dx9cPjNQcWWBgcDoZ2nooEeaIWmIKIHFcIlKERMHcAryFXkS7SHT/aFLvTUU0HyuREhEZGTmCKBLQcuRtH5+KKIg8/N2/uLLrhfRr04NYW/CgXBsCIzCFLaZOJThoN0skRlk4UBocsA3nWulTVD5fn8eT47pjt+i8FhbugO+fgAMrA4FT+bj3W30Ifn0Ofn0WOo+FS5+FmDb6nr8ZMHdtjuY+RY+1PcC6rDIWPTyaLzfm8eav+yiu8eL1yUcF8pEZsP1NNdxUs4vx146tt0S6eskSYq+6St8bamSajZj6YXsBozq3IOIcKfFbua9E03pS7xddj/janl9JZlF1wC3HbINJn0Deelj5OuyaHxBOqkJAPQGCCAPvhkGTIaYNApAweTLmNikcmHw3yf9+lsgxYzTXVe708vPOQkprvMiKGhh42SGBdgmNXAOtqvDVfbDzu/DLHH0u+OVZMNnYlnIDB0qDnWj0imCbWTz697dIIirQNzWWqWM6HJ31YGBgcOr4ZYVH5m3m0Uu70Dbh7J5TaNA0EQWR23vczmsbX8Mtu2vdFz8mnvgx9WRIBBgq9UczxQNEWSP44ubXeGP1R/zp+/9SVFPK+R2G8PxlfyIxovZzxwhCLT0WymTgVIKDdrPEpT1asmDLwaD79FwrBQG+3ZzPDYNS63y9WmQugY9vOhwoDVGaduTav+Nb2L8U7pgPLfW1UDQ05e5yfs75mWJXMV7ZS7Qlmn5J/ejZoudJ9W/mlbs0370ea3uPX+FghQuHxcTNQ9oxaXBb1meXsSarlNJqLxaTSFKUlct6tibJJLP30pfx7L4OW5fQxhGKy4Vr3Xra/Pe/ut9TY9JslMn8jIPcOKhtYy/jjFHp0nDKQ/8MB10D9BSV+ZsP8oeLOh27sc0AuH4WOEshbwO4ywMzpBwtoO1QkIKdrKIvuxRTyyRy//AHWtx3H/GTJh29b3NOOTOWZfLj9kOHbVUVVDVwwlZU6NUmhqljOnB+1ySkxhAMGXODhFTaS1U4fbD/wUgiLIE1vbPBy5wMH7/ecVj8+Zzww1+Zk975lEWwzSwytmdrOiQGnjvGYeG8zoln/UBqA4PGYPqv+4ixm5k0+Ny5vhg0PW7udjO/5PzClqIteBV9ZXaKqvBC5Zt8mJ7AU7lTae1rEXRMpxZpvHj5nwHYW5LNH+Y/w9OLX+X1K5+qdVxClI2pfVtS5vQRaTPx7eZ8Cqtq27afTIUMBK7zoiBwz6h0kqKtLNpWECQA9VwrnV6Zzzfkhi+m8jbAxzeGPxxWlcFVCrPHwdRlEHvq5whZUdlxsJIypxdFhVi7mW6to4PNGA6ztXgrs7fO5tecX5FECbffjYqKWTRjEk0k2hOZ3GsyY9PHYjeF30PtCVGSp9fa/vjSPkEQGJgWz8A0bfHfYso9FL30MqnT3wj5fDWrV2Pr1g0pWsPgrAnTLMRUSbWHTTnlvH3rwMZeyhkj1A9L7xddj/jyKyoHK93adzriodNF9T7H0cP79SPto4/ImXIvvpxcWjzyCE/P38Fn63Px+OWg5tUjs1zWZZfx4Ccb6dwqivfuGnxm54mpKix5XvNEK6vw8movfx4VYlAxgOxlb1Y2spoQfJeOv4Pbp5ASZ+d3F3Sq91gDA4OTZ1t+Be+uzOLb34803PkMGhWTaOKNC9/g3h/vZWfpzqAMVX248XDAepDfp/+b57Mfpr0ndH95x4R23NDzMuZs+ibovj01bj74rZxOLaNoFW2jzBks7PRWyIhCIBs1cVAqdwxPp22Cg9d+3qPpIqc3YFxaE6bwVGT4aGLQ9T2sYKmnEj69Hab8Et5raVBS7eHjtQeYtTwLj08+Wl2iqoEM2y1D2nHrsHYkx9oP367y4voX+Xjnx3hlLwpKLZdGn+LDp/g4UHWA59Y8x9sZb/PuZe/SKiK8fvVYh7Zxhx5rewFoEVHHnujE15w4kZLZ7+IKvo3UAAAgAElEQVTcuBFHP21X7uolS4g8T7uiqSnTLAbULNp2iDGdE/XXxTZjWsdoz1A4/oseDseLr3AIFa04GSypqaR9/BGurVuZ8vgsPlufg8sXLKROpMYrsy2vkqteW0G1p/4hwg1G3gaoCi47AHhsuIVpKz2Uu+tYvCrjrNH2QdL7dwiVmTQwMGgYPH6ZRz7dzBPjuh3dwBgYNCYOs4NZl83ilu63EGGOQBL07XlUQaVGdPF/bV+myHTMwGpvSTZvrfmEg5WFAORXHuLrHYvpn1y7dE2wiFx8Yw9W/t+FPHl5t5BOeXoFjyQKXNazFXnlLh7/PIPLX1nG28syNfcCeq+VPjnM0TF7fgiZkToSLA2JqkDh9kCv1Ukwc3kmw5/7mdd+3ktpjZcar0yV20+V20+1J/D/mcv3c/60X/nXdzuQZYV//vZPPt71MW7ZHRBSdeDyuzhYc5Abvr2BQzWHwlrTyI4tsJmDJcDx1vbO3atQfG5U2Y9r3zrKfplV61iHRaJ/u7ig5wiFaLWS+MD9FL34Eqqqsr+4hndX7OelH3fz0k+7eW/lfvau2hiyPaQp0ywyU99tyefWoWmNvYwzyqQh7Vizv/ToPKgj6J3hoHeAXouo8KMM4SDFxjL/tj+z4qfduuZleWWF/HIX976/jg/vGdqgawrJ2ncCZhsaDEyWOC/NxLSVHp65IPSwwEi0+6z0/h0Mu1cDg9PLK4v3kBLn4Nr+Z19zuUHzxSyaebD/gwxtPZT7froPWQ0WFWXLyiheVIy30Itkk4geEE3L61oiRUggQI3kZFbSVzyeH9gPRFgcbMrfwYy1n1LpqSbaGslFHYbx5Pn3135iUcDePQGHJDIoLZ6cUidfbcrDJ9deg94KGVWFwenxRNvMRNvNRNvMLN5xiOlL9gVlp/ReKw+UOun7jx/okBhJh8QI2idG0iExkvaJEbSNd2A+MtR+xcvgrdZ8jseGW3h+hYf7B1mItYXIUMt++O0NuPLVetd0PP9esIP3V2XXOxTXe1gUzvktm/Wl35EjfKsrO6moCpXeSib/MJmvr/oa6fDfZVt+BdvyK6ly+7GbJVrH2hjZsQU3DExl2g+7NJ8rbGt7q4mRHYNLSusi8sqr+ObTn/ni+UXsqAqo6SOtEVZJQOl9J/1/LuZeJYYxnZpPf3iTF1NFVR4ycis4r0tiYy/ljHJB1yQsJjFITEH4X3TQJ74irBJD2zesFajbJzN9WRbuEBG2utyAPH6FDQfK2J5fSffk8Otnqz1+vt6Yx5qsUsqcXhxmE23jHVw/MIVOLYMdjI5Sui+0ayHwj/OtjJhVw4NDQgudPsJe1tMF/wlJX71/hzrXaWBwClR4KliZv5JSdymKqhBtiWZwq8G0jmzd2Es7Y2w8UMbctTkseHCUUd5n0CSZs30OfiW4MqN4YTFFC4tIuTulll161rQs0p9MRzSJKILKyqjNVIlOohQHraMSmX713+t+QZNA5LBkBOnYtcthMSFq/D70Cp5Im4mJJ/S8R9pMTF+yL+hYPddKu1ni0Us6c2XfNmQWVbOvqIZ9RdWszixhX1ENBZVuUuLs9In389+cdSE3vGEFS1U/ZHwK418J1OWFwZzfsnl/VbaueZMun4/d3s8RTMFCqk4RDciqTJGziF8OLKWytBPTf91HbpkLQQC/rCKJIIkiogC3DUtjTOdEftpxSDNDWJ+1vc0kcs+odF1ip9rjZ/K768joeDmusuDPxC2rIJr4bX8pGXkVDEmPZ/otA7CZm35VWpMSU16/wrb8CsqdPgQB4hwWNuaUc0HXpGbxYTYkkigweWQ6r/2yVzOjE84MhyOEK75sJokxnZMaZP1H0HLrOUI4bkBeWWXm8kz+d0Pfel8rq7iG6Uv28fWmPERBwHmcEJVEeH9VFp1bRnH/+R24tEer4E2Ut+6m1J5JEld0NvHcci/dErUrZG+VfuQ9dRxaQahw/w6iIHBRt5bBT2BgcApsL9nOe9veY/GBxZgEEz7Fd7SR2a/66ZfYj7t63sXQ5KGIQrOoAD8p3D6ZR+Zt5ukre5AUFTrLbGDQWBS7ilmZvxL1BL812SVT+FUhbSa3Iap3IOBmSbSQen8qux/bTcXKCuJGB8quBAR+jFnFhLIL639BEUxxNqLOr23k0D4xQnMYq94KmTQNl970FhH0SI5mw4HyoPvCvVYqqsp1A1OJsZtJjLIypH3tfmW3Tya7xEnh3vXIuRZMcujy+XCCpSj+QKmgpX7XYY9f5rmFO3XPm5QidoMQXHIYjogGcPqdPPLDy/hzpwYH42U40nj19rJMQMUkiZqmWXUhCNoCuS7cPpnr31xJZlENnjBezumVWbmvhJvfWc3H9wwN6SPQVGgSYiqv3MX7K7P4cPUB4JjoV9TAH2Bcz1YUVrnPuQvflNEd+GnHIbbnVx1NAZ8s9Ykvq0nkrpFpDe6g9+aSfZrZtXDdgGRFZX7GQZ6+sgdRdZhRLN9TzJQP1uE5br7B8cgKyIpCRl4FD8/dzLheh/jPtb0xHReFwxYT/MAT+Pt5Nvq/Vc0jw7TLIVPFEvrFuvmtVPtkW9/fwSKJ3Dq0XZM/cRg0HxRV4dnVz/L13q/xKl4UVcHDMXcunxLYYKwuWM2W4i30SOjBaxe+hsN8drpHTlu0i26to7mid3JjL8XAQJMVeSswiaYgVz/nHieKTyF6QO1KDckmEdU7iupt1UfFlEf08kvM2vrFlElAirGSOKU34gl96Z1bRpEab2f3oeDyuHAFT4RV4p5R7TVfeuqYDjw0d1OtwOcR6rtWigKM69WaGHvofYHNLNGlVRRd/FFgkg6LCW3CCZaqgojf48Ichpj6fmsBqqrfUt6SsBRBqv131yOiAWRzNk61CAhdaXREQJnxYxbAp4aXsBAFiLCamHvvsDo/+xN5eO6mgJDSIdw8hxMsf/16K/+5tnfYj2sMGlVMqarKvxfu5L2VWaiqijfEjIQfth/ih+2HePCiTtw3psM5U5ZhMYm8P3kIt76zml0FVbh1Rg/CRRAg2m7mliFpDfq8qqqyr7BG8z49bkBmSSSr2EmvFG2xszqzhLvfXxt2T5bLJ7Ngy0G8ssIrN/Y79n1KGYiauwZBDt2I2jFeZGIPM6+s8dIrSeOEa4ng4cER3L5Y1NUjdgSzJHDbsDTdjzMw0EJVVf609E8syVkSVv290+9kc/Fmbl14K3PGzdFltdscWLO/lG825/P9Q6MbeykGBiEp95QfDXIcj1wtY4o0IUjBeyBTjAlXdu2e3SpTDYJFRPVqXItMIqBi796CuAkdEW3a28H7zuvAX77cqhkUDadCRhQELumhXWlxYbeWdG8dTUZehe7sSKTVxCOXdA7vYFt0wOa8HuoLlqqyj97P/UZcpJ2UeAepcQ5S4uykxh/7f6toG5IoMP3X4EByOEFk0VoQ9Lp6RHRgoSZE6yFkX/1tGz4kLKKAXRQRBDSF7REirBJRVjOfTBlKWovwZ4LmlDr5eWehppCqb/Cz26fw5cY8Hru0Cy0iG7anvyFpNDGlqioPzd3ED9sO1atUj9z/6uK9FFZ6ePrKxh2ediaJtpn5dOow/rdoNx+uzgYI+oEKgN0iEeewMHFQCm8uyazzB3E8ohA4KX0yZSgxjoa1Iff4laAyhSPocQMSBKh0a6fny51eJr+3TlO41PUjdfkUftpeyIe/HeCWYe04WOFivut8bpNfp76f69/GWPkgI0S5gGRmyKhLedycw3++36lLUNnMIjNuH0irEE6OBgZ6eSvjLZbmLtXVyOyVvWRXZvPYksd47cLXTuPqziw1Hj+PztvMv67pZRi8GDRLpEgJf7UfVVaDBJW/wo8psvaWToq1EjegE5VLcvAXu8CvgiggRpqJHJZMxKCWSJF1/xbG9mzNM/N3aIqp+rCbRe4Z1f6YCcSJ70cUmH3nIK6dvpLsEmdYWQtBCLjIzbl7CClxYWbPY9sS2CnVTX3BUjEujYzfjaOgwk1OqZPcMhc5ZU5W7C0mp9RJTpmTshofLaOt5JQFm1GFE0QWxOBgrl4RDSqCeOycX59g8coq/dtEMaF/G95ckklpjRdRCFT0SKKAX1HomBjJ1PM6cEn3VrorZ95blYVyElm6I4jAJ2sONOlxMY0mpl74cTc/bDukszFPZu7aHNrGO7hrZPppXF3TwmqS+PPl3fjjJZ1ZuPUgL/64m4JKDzaTiMNiokdyNFNGt2dwejyCIHBx91bc8s5q3D65zhNghFUi1m7h43uG0jah4Ut6LJIYar64LjcgRVWxhfjxzl2bg18JPgGH8yN1+WT+9+MuluwuZE1WGVf3TcbfZjDWvBW1nivrodpmEKkxIu6/aBhimGww+F6QTNw5Ih2zKPDMgh34ZFWz7vwIFknALAWE1PAO+pxxDAxC4fK7mLV1Fi5/8EW9vkZmj+zht4O/kVmRSfsY7RKd5sa/F+5gUFo8F3c3+hENmjYx1hjMojkoO+Xo6EAwCVSuryRm8LFKDdktU5VRRcvran+3o63ROPol4egX6IVWFRVBZym/zSzx4T1DmPDGyrCDtBAwKBjSPoHfnd+xzuOibGa+emAEf/h4I8v2FKMoKr4Q18sIi0RchIV37xxMx6TI8N+EyQr9b4e1M6COyhOoI1hqjoARD2GWRFLjHaTGa++ZPH6ZrXkV3Pj2b/jkE3vewggiqyagtvGIXhEdeJ5AcDxcwbItv5IXJ/bl5iHtyMitILfMhdPrJ8pmomNSJB2TTs4UyycrfLImJ+iz0DP42e1XmLUiiwfO79hkK9MaRUyVO728vTTzpFJ+Lp/MtB92MWlI23POlMJmlrimXwr55W4q3T6eGNtN87huraNZ9cSFLNpWwJtL9rGvqBqzJB4dDuf1yXQsyeah+6/kgp7JDd4ndQRRFIixmyl3Bp+Y9LgB1Xhk7np3LX3axtG7TQy9UmLonRJDUqSVd5bvD8r+6PmRlrt8pMZH8NKN/YiwmuDA3+D9q0FjA1ovkgUGTT76z1uGpTEoPYEZyzKZn5GPgFAreBBhkRBFgVuGtOO24e1oHXN2lVQZNC7f7/8eQSMaG24js6zIzNk+h78N+9uZXnqDs2xPET/vKGShUd5n0AwYnjwcvxrs5Cc5JJKuTiJ/Tj6iTaz1+zXHm4kdHnv0WJtkY2za2FqP1yukjtC1VTRzpwzj5pm/4fLKQRvjE3FYJEZ3SuTlm/qG5fbmsJh45/ZB7C+uYfaK/Xy2PheBQImggorPrzKsQzz3ju7AsA4JJ7ehHjIF1r0TdHPYwVIU6H19vS9jNUl0TIzSzIOFE0RW/NFIUu1KAr0iGlQUX6yuvZCiqry7MounxvegT2osfVJjaQgKqzyawWS9g5+r3D6qPf46e+cbk0YRU3PX5mg6S4aroAHmZxzkugGhJ3yfzWSX1NA3te5BaRaTyPg+yYzvk0xmUTUHSp3UeGQibSbat4hA/v0UEoq7IImnd8bKTYPaMnN5ZlA/nB43oN5tYphx+0C25FaQkVvOx2sO8MQXFciKqjnUV8+PVFUhq6QmIKQA2g6Fsc/Bwv/TJ6jMDrjlC4is7YbYpVUU067vw1Pju/NdxkEOlDqpdPuIc1jo2iqai7u3NMwmDE4Ls7fOxumv7VCpp5HZr/r5dt+3PDbosabdO+VzwY5voXgPuMrAFgstOkH3K8Fsp9Lt4/HPMnju2t66GqYNDBqLJEcSg1sNZnne8qD7EsclIkVIFMwtwFvoRbSLRPePJvXeVMTjhrAqqsKEThMabE29UmL46eExzFy+nw9XH0BV1VqVL5IAZpNI+8RI7hvTgSt6t9YtetJbRPCPq3ry53HdOFjhpsrtw26WSIqynXobQlwadL0Cdi7QHyw1O2D478Ny8QOIspk0jbDCCSL7SocjJi2oZUKhR0QDqHIEirsNnrwNYe+FfLLKmv2l9R53PF6/wrqsUoprvMiKQozdTN/UuFpl1FVun2bAXu/gZ5MoUuk2xNRRFEU95WyC0ysz/de957CYcnJV3/BFUPvESNon1k6Jl42/gopvvyXqggsaenm1uHVYO2at2A8aBX/huAFFWCSmnteBltE2Wna3cdHhEh1VVZm+ZB8v/bg7yKBH7490X9EJTkUD7ghkmeb/MWCFqtEIfBSTLXDsLV9A6qCQh0XZzNw4OHwbUQODU0FVVbIqs4Ju19vILIkSOVU5dI4Ls9H7TFK6H36bDpvmAELtgZyWSJj/MPS7mTfKzue8rqmM7nx2zipUVZVV+0p4f1U2OWVO3D6ZKJuZgWlx3D4sLWQ5kkHT5s4ed7L+0HrNMt34MfHEjwltLiAgcF7qecTaGia7cISkaBtPjAu0HHy/tYCV+0oorfZgNUu0ibUzoX8KXVqd+oxEm1kiXYfBQdhcPR1mj4VD28MXVGYHdLoEznsi7JcRRYELuibx045DHN8qFE4Q2VfRD2vL74KeM1wRrSoWvCVjAEH3XqjKHRyc1iK/3MUHq7KZszobVSXQG68GMp9ev8KFXZO4Z3R7+qXGYjdLmv1Segc/y6qKowlXo51xMXWoKhBtOBG9Kb+s4sBF41wr9YPAxO+2p3iBjLr0Ugqn/Q+5uhopUkftsU6SY+0MbZ/Ayr3FmnXQ9VqFm0TNHgdBEEL2ZOn9kWrWgfedBKlDDm/WPgRBDETAVRlEU0BEmaww9AEYcDtEGL1OBk0Hl9+FJEhBpUJ6G5kFBKq8Vad1rSfFjm/hiykg+7SDHYeFlbJ2Ng+q78GEt4FeZ3aNpxlFUXl/VRZvLsmkyu0L6o/dll/BB6uy6Zsay2OXdmFgWsMOZDc4vQxqNYghrYaw6uAqPLKn/gccR6Q5kj8O/ONpWlmglO2qvm10BXWbBCYr3PEdfHobZK0IzIwK1dktSIFAaZ+bYNx/wx7Ue4Qpo9uzYm9x0P6i3iCyasVX3h9z7HoEsfb5uz4RHXi8gK+yH6B/L2QNo0rmnWWZ/HfRLlQI6cC4aFsBS3YXMTg9nv9d3we/RppO7+BngYDrdFPljIupCpcPkyhyZHDYEfQqaLNJoMLlO+fElNsnU1LjJTn21MpuTHFxOAYPpurHn4i95upTXpfXr7A2q5Ti6kB9bIzdTL+2gXTvCzf0YezLyyiu9mhO2g6F3Swx+87BIZ2AIqwmTKKIT659stL7I42whPgOJXSAy6fBxf+AXQugPBs8NQGb1aTu0PFCCPP7amBwJjFLZmQNK+CTaWS2SU3MXXLbV/Dl1LAiy6Lqww7wzVSQgB6nfq5rCrh9MlPnrGd1ZmlIE6dAX4vK6v2l3DJzNU+P72Fkx5sRgiAw7bxp3L3obnaW7gzLkVNAwGF2MOOSGbSJbGZC50xhtsOkT+HAKljxCuz7GSRzoAoFIRAsVXzQ/WoY9gC0Prn5RgPbxZEQacFZGnyeqi+I7Cm8AkvkAUy2Yvx1VcacgKqYceXeBkrAk1jvXqhNPfvK/y7ayazlWfW6LipqIEi9al8Jt89ew+jOLVi8s1B3lu4IkgBX9T19/f0NwRkXUyZR1LTL1qugVRVMTfiDPV3kljlpE2tvkC9VzPgrKJ/32SmJqVDpXlEQ8MrH0r2fTurB9a8uodwaha8eQXXE6n3GbQPpW0cTZPsWEWh9DHp+pALQqWU9ZQkWB/S6ru5jDAyaEGbRjMPsoMZXe86b3kZmr+wl0dGEyuMKd8JXtYVU2ktVOH2w/8FIIiyBE8I7G7zMyfDx6x2HS4X8roAAS+wCSdrGPc0FWVGZ8sF6VmeWhD0A0+1TePrbbdjNElf1MzbZzQWrZGXWpbP4+6q/szBrIagEDfKFgIiymWwk2hN5/cLXSYtJO/OLbU4IArQbHviv6hBkLQN3eaACxR4PHc4Hm/Zcy/BfQuD1Sf2Z+NZvulyrAWwmK88Nf4P3Mv9MZkVmvZlJAQFFMeHKuxnZ2eHo7Xr2QgDZpU7eWZbJFb2Tg0a0fL4+l5kaLTp14fEr7DlUjc0kYTdL+rN0hzGbRCaPbNqusmdcTCVEWDRTg3oVtF9Rm3TK73TRECV+R4g8/3wOPvU0vsJCzElJ9T/gBGYszWTaD7tQVfDK2j+w77cV8MuuIvrU5DEroZxPul/CFxvyNIfDWU2Bsr2RHVvwxNiu9YqcwenxRNnMmvbv4f5I7RaJyeeQzb7BucPVHa/m012f1rJX1tvI3DW+K0kO/eeG08byF8GvMYdFhZdXe/nzqDqmxMnewOMnvH0aF3j6mbU8k7X7S3W74bp9Co9/kcGAtLjw5/MYNDpmycwzI5/hwf4P8umuT/lo50d4ZA8mwYSCgl/xM7LNSO7seSd9E/s2WevoJktUy9MWLO2dEstbtw7g3g/Why2obGaRZ6/pxeU9Urio6wfM2jKLD3d+iE/2BRkKWSUrqqoyLHkYlI7le6fAiWWL4e6F4hxmnrqiO99tOcirPy+lS6soxvdJZlzPVsQ6LDy7YEdIIVXXecfjV9iSV05SlA23zxVUnVRflk7CTxfhEF0Kv4ekCU22GuiMi6m4CAvdW0ezObei1u16swljOiWGLP86m8kuaTgxJdpsRF1wASXfLWD3iMspqvLgkxWi7Wb6tY2t06r7+e93MntFeOlel09mg6kFT0R15rMruvPk5d34elM+n67LobTai6yqRNvMXNQ9iVuGtiMpKryyIkEQmDK6Pf9dtEvzRBXOdPZYh5kh6UYvgcHZx81db+az3Z8F3R5uI7PD5OCuXncFPb7RcFfA9q8CfYsn8NhwC8+v8HD/IAuxthCbSVWG7V/D2OfB3rCN+WcKRVF5a2mm5vkuHDdcWVH5YFU2T4xr3tm5c5FERyIP9HuAqX2mUuoupcpbhdVkJc4ah8NsiOOmyujOicybOoxH520mq6RGc+6kIATaGlpEWnluQi+Gdwz0YFslK/f1vY97et/D0tylfLb7MwqdhXgVL9GWaEYkj+D6LtfTwt6CvYXV/JSxDP9J9KbbzCKTR6VzXtckzuuaxDN+maW7i/l2cz7Pf7+TtvEOTedkCO+845dVeraJodzlC9vkAkBAIRonbwnPwrce2PwRTPwwUC3UxGgUa/T7zuvAI59uDsoo6MkmTBnTtFN+p4vsEiftGmjAbl65i1kdL+Tj7VWIWRtQVBVVDUy89skKg9LiuXdMe0Z0aFFrVsRn63OZvWI/Lh3pXq9oYk9hNb//eCMzbhvITYPbclMD1O9fOyCFF3/aDeGXFR/FZha4sI+HL/Z8gc1ko1VEK/ol9UMUzj2RbnD2kRqdSq8WvdhUuCnIiCKcRmarZGVMypjTuUR9bJ4bKMPRYGCyxHlpJqat9PDMBXUEYwQRNn8CQ6eepkWeXpbuKcKlkYkP1w3XJ6t8uPoAf7ykM1ZT04zwGtSNJEokOhKbVvmtQZ30bBPD9w+NZnt+JTOXZ7Jo2yGcXj8qARE1qlMLpoxuT/+2cZqZRZNo4oK2F3BB29Duyx2TIvnTpV1DBpdDYZEEeibHcO/oY+WBVpPExd1bcnH3lji9fsa/ulwzcB7ueUdW4ZedhXx4zxDufm8dlW4fIYqZjmLGRyzVzLX8k1ZCWWCPl70SPrgG7pgf6HNrQjSKmLqoW0vMJhE0LgrhZBPiHJZzNptwoNTJ8A4Jp/w8763cz7MLdgZK9CQLaEQdlu8tZsOBMrq0jOLduwYTYzcjKyrPLtgRUkjVl+5dtqeIXQVVDWKfChBjN/PeXYO5ecZqXScQUfShRmTwQ+kCFpUqiAR6+ewmO7f1uI0JHSc0uK2sgcGZ5j+j/8N131xHuadcs1c1FDbJxhsXvYFJNOH2yXy7OZ+vN+VTUuNBVSEh0sL43slc2TcZh+UMXUbyNx5239LmH+dbGTGrhgeHWEIeg88JBzedhsWdGd5bmaVZ1qzXDfeXnUVc1rNVQy/PwMCgDronR/O/G/ryPwIjDYAGLcu8a2Q6NR4/r/+6N6zeJptZpHvraN69K7TRl8NiIq9c2+xHz3kn8PwCCx8czYs/7uLrdfsRkXFyLPiVO/0u8LvpcO8bXG9fw8Omz/lyYwl3Hel/9buhYDMsehLGPV/va55JGiUEb5JEXp/UH5tZ/8vbzRJv3Nz/nK0Lzi6poV3Cqc1feGXxHp5buAuPXwnZ63QEp1dmW34FV722nEq3j193FeLxa4uWyjVfUrp4BjFDbyDld3Noc99sovqPw7Vn9dFjfH6FmcszT2n9J9K/bRzvTx5MpNWEWcPyOQjBgxSzFnOrz3D6a3D5XdT4a3D6nZS4S5i+aTqXfH4JS3KWNOg6DQzONEmOJN4f+z4JtgQkof5MhICAHYnX5DiSzWk89fVW+v/zR576ZhvL9xaz42AVOwuqWLG3hH/M386Af/7Ek19uobCqfqexU8ZdVufdPZMkruhs4rnlwT1VtXDV/TxNmQOl2mJSjxuuV1bID7E5MjAwODMIgnBa9rG/v7ATr93Un45JkdjNkqZJV4RFItpmYsqo9sy9dxiR1tABMVVV8YQQZrpcuAWodPloFWPjP/3LWBf1MH82fcgQYTudhBw6CzlY8WFXXUzaNIVnzO+SIGiM5fC5YMP74GlaIzsaJTMFMKJjC6Zd34dH520O2x3EbpZ489YB9KnD4e1swSN7WJ67nEPOQ3hlL5GWSLrFdye3zHVKPVPfZeTzRphRiyN4ZZX8cjd3zl6LgEqN5+TLTGQVvtmUz9/G96jzB6yXQWnx/PDwaN5ZlsnctTkAtSK4Jgn8ih/JfgBLwq+YIneHfK4jFrSPLnmUfwz/B2Pbj22wdRoYnGnSYtL47MrPeGH9CyzKWoQoiEGDQC2iBQQY3Gowj/R7COmbaVw2bRFlsl1zPhwcM5CZuzaHhVsLmDtlaP3OmCeJoqg4VTv1TcT7+3k2+r9VzSPD6jCisJyeNZ4JQvWo6nHD9cuKbncxAwOD5sNF3VtyUfeWbMmt4J3lmWzOKafa48dmlkiOsXPHiDQu7t4yLN8BQRAQRSGozwv0nQVV9UIAACAASURBVHcUVeFoZfGKl4n0lXKLaTG3mBYfPSZNqGLqcAvPr3Dy4OCo0P2vggAZn8KgyfWu/0zRaGIK4IreySRF2fi/zzPIL67Eg4hK7Q9PEgXMkkBaQgTTru9DzzanZlfZ1MmtyuWjHR/x+Z7PEQQBv+xHVmVMogkVsKbF8GOOj0vTLsUq1bFh0EBVVf713ck5snhlhR0HK/GFyGTpTffuPlRF/7ZxutZfH8mxdv42vgd/uqwrC7YcZH12GaU1XmxmgWWHvsTjWAbmkrCfzy27+dvKv9Emqg29E09u1oSBQVMgwZ7Av0b+iycGP8E3+77hq71fUe4pR1EVIi2RXNj2Qm7ofAMtI1pSUOFmXM71lPk8YZUG+hWVshov1725iu/+MPKUneJkRWV/cTVb8yrZklfB1rwKtudX8juTjbswY66jQbJjvMjEHmZeWeOlV5LGRkGyQItOp7S+xiRUAEqPG65FEom21X3pr3L7WJddRoXThyAESusHp8efc3MdDQyaM71SYnj5xn6n/DwxNhOlzuDzrp7zjtPn4aFlkxm2w85L+5eEFB9h9b/6nLDyVUNMHc/g9Hh++t1Qvhs/iYUTH2FpTjVOT+CPFmEzc2mPVkwemU631tGNvNLTz7xd8/jP2v8gqzJ+pXYPk18+/G+zm3/99i9e2/gasy+brWsw32+ZpZS7tDci4TiyuLxyyK2V3nRvRYh1NAQ2s8SE/ilM6J8CwNd7v2bF6kWaQz7LlpVRvKgYb6EXySYRPSCalte1RIoIvA+37OalDS8x69JZp229BgZnikhLJJO6TWJSt0khj5k6Zx0Vbj/qCVXgdQVbVKDa7Wfyu+tY9PDosNfjlxX2FgWE09bDwmnHwUoSIq30ahNDjzbR/O6CjvRIjiHe3xNe+RzqSar8bYyVDzJCnV8E6HdL2OtravRJiWXPoSrkE07EetxwRVGga4jr6Y6DlbyzLJP5GQcxS2Ktvg5FVbl+QAp3jEgnvcWplZobGBg0H64bkMrslfsPDwI/hp7zTkpMNN9O+oD8HV+ibFkCSugTeVj9rxW5p/y+GpJGF1MAVT/+RJ+28Yy/ewQA5V99TfWKFaQ83bQazE4n7297n1c3vlrvcDYAp9+JW3Yz8duJzB0/N2xBNWNp5ik5QdUVo9Y7dNl6Bm3tZ26dGVTWBFC8sJiihUWk3J1Sa95O1rQs0p9MRzQF1ri5aDN51XnGRHmDs56dBZXsLKgKKukIy3ZbVTlQ6mRzTrlmKbbXr7CnsOqwaApknXYVVNE6xkaPNjH0ahPNJd1b0iM5hhiHllNTG0gbAft+rnVr1kO1y/ZSY0TcfwkRfGs3HGKa7+/4tuHt+GJjLvKJago982QsDGxXuyrALys8/nkG3205eNS6Wauk8KPVB/hkbQ53jkjj8cu6nrO9ywYG5xK3DW/He6uy0NoFhnPecVgkpo7pQKwtltjIVBAtQOg+2+P7X7slhtgrKj6Q/SA1CRnTNMRU+WefEXfTjUf/LZhMCHL4XvTNneV5y3l146tHe3XCQVEVqnxV3Pn9ncy/Zj4WqQ4Ff5iNOWWagkiXE5SqBupVT0BPutcvqyRG6StRPFm2lWzjYPXBoNtll0zhV4W0mdyGqN6BzZgl0ULq/ansfmw3FSsriBsd2HCoqsrHOz7m0UGPnpE1Gxg0FjOX7cd3wiY63GALgMcvM2NZJv+7oQ+7CqqOiqZt+RXsPlRFSpwjkHFKjuaK3q3pnhxNlE2Hxe2oR+HAqkATsl7MDhj1iP7HNRFKqj088ulmzTkyR6jPDVcUwO3zs2BLAWN7tjraC3HXe2tZu7+s3l5an6KCovLeymyKqjxMu76PIagMDM5yUuIcDGgXx+r9pZq9U/Wdd1QVru53OIhlra/zNUC9/a+iqckIKWgCYsqbnY1n716iLjjmny+YTKi+c0dMvbj+RU0hVV8JmqIqVHgq+CH7B65ofwVVbh/lh+taYxxmok/YpDg1slKgr0TPbBJRFPWUykwSo6x0TArvB3WqZBRlaPZ9OPc4UXwK0QNqR7Alm0RU7yiqt1UfFVM+xceagjVnZL0GBo2FX1b4ZnN+0G9bT7BFUeG7jIP8uL2A9BaR9EgOZJyu7d+Gbq2jiThV05m0ETDiIVjxcp026UGYHTD8D5A+6tRev5Eoq/Ey/tXlFFUFrOlPBgGIdVj494RevPrzXqYv2cufLu3KD9sKWLu/VNfcQJdPZsGWAtonRvLA+R1PbkEGBgbNhhdu6MvYl5dS7vTpGLIRsF9/9aZ+x879se0gjAqsevtfo1rrWMXpp9HFVPlnnxNz5ZUIlmOZFcFsQpXPDbehXaW7OFB5IOj2cEvQnH4nL619i/9n77zDqyjTPnzPzOnpCUmAhJbQe5MqRQVFFEVXxbKooNi77ufu6u66urq71lVRQUVUsGAvWFF6772HEBIC6T2nzsz3x4GQw5kkZyANmPu6vLw4Z2byvuck78zzPs/z+73xfST7csqxHHvd41PolBjOXaNSGdezJVaThEkStOykdJXoWSQRt6qgdUcPPd2b0mi7mWWeMrxycP+EXC5jCjchaEipm6JMODMCd77LveUNNkYDg+ZATf2UuvohAYtJ5If7z6djAyn7Mepxf739qtdDy1CZHTDkHhj954YZTwOjqiqT31tDXrm7RmXFUHBYJT6ZNoQuLSMY2z2Rn7Yf5clvtpFZ6NR8OKqtRw78AdX0hfuZMrx943mNGRgYNAkto2x8dudQJs1cTYnTE7TppoXNLPLcxF6M6Z544sXIVpB8HhxcXuf5Nfa/mh0w5F4do294mnQFVL1eir/5mnbvvx/4hiSh+hpOoKA5MWfnHLxK4Fz1lKABHK08TGXhHhSlNb5q2addR8r469fbeOLr7fz76p7EOixUuIMfPvSU6CFAt1YR7DwS3FcBoZkuV6V7GwGLaEEURRQlcNdVCpfwlftQZTUooPKV+DCFB/5pmMXm5bZtYFDfOD0yosYmh95+SJMkEHqO4xQQBLjwCWgzCBb/B3K2+4Or6uuoaEIRJHYqbenyh39h7jquIUfUoKxKK+BAXkVQ8zfUHfAcJ9pu4tM7h1aZpQuCwPherdiZXcKMJQeCSgdD6ZHzXwe+35LNpPPaNsDMDQwMmhOdEiP46aERPPnNdpbuzQOC7RokUcAiCbSNC+OpCT0YmhoXfKHhD/lN2D0VAS+H3P+qKtD3htObTD3TaMFUfrmbjIIKylw+wqwmkmPshK9fiaVNW6ypqQHHCiYz+M6NMr/NeZuR1cAsnJ4SNABUEG3ZKO7WQdc/7gn1py+2MjQllrxyd1BdvJ4SPYfFxIzJA7j8teUUaUhl1sbxdG9j7mLGO+KxiJYgdURHRweCSaB0QylRg07I7csumbKtZSRekxhwfKIj8N8GBmcbETYTPiU4DNK12YLfEyqiDunteqHTWP9/+ftg3TuQs9Nv5GiNgITuiIOm8cJ3RVxS0pKatQubPzOXpmmWaIca8Fgkkeev6UPXloH3E5+sMGf1oaBASk+PXKVH5q3FaUYwZWBwjpAYaeOdmweSV+bm4zWH+GJDJsVOL7KiEmY1cX7HFtw+ogM9WtdiY5R6Edjj/JUFqs6tN5Mdel8H9ublN9ugdzxVVVmTXsjMJWmsSCvAajpR9+jxKXRx5nH7mOtoo6hI1WyaBZMJ1XdulPlVeCuCXtNbgoagIEi1l7u4vAqr0gpr9IkKpUTPbhaZNqIDSdGOU073XtStcYOS0W1G84+V/wh6XXJIJExMIHtuNqJNDCilNMeaiR524g/VYXJwXZfrGnPYBgaNTqTNTJjFFFTup2ezBUASRVqEN47ADOD3jbpUW/n1gYsKefDTzVw7MDkkg8rmRk6pi9UHCoNe1xPweGSFuaszuLhHy4DXj5S48Ggo9ukSJAIOFVbi9slYTYYHlYHBuUJ8hJUHx3TiwTGn4NsninDzN/D2aP8GWKhdWJIF4jvXuN43JQ0WTOWUupg8aw1ZRc4qf6KTF+6t5jj+clDiuf8uZO5tg0mND2NrVgn7j3g5am1L+61H6JQYTueGqr1vRI6WuNiQUUSpy4tJFGgRbmVoahwWMViFT28JGoigWOos+XD5FEyigElQ0biH1q3IAlw3sA3gT/f++OAI/vZt7ene44bLT13RgyEpGuneBibMHMb4DuP5Lu27oAxg/Ph4pDCJo/OO4sn1INpFIvtH0ubONojmEw9eJtHE6DajG3nkBgaNiygK3DKsPTOWpAX9LYcqu22WBG4a3LbZBC4D2sXSJsbBN5sOc+2xtetMYueRUswmMej70BvwbD1cEvRaidMbsIl5HL09cmZJpNTpIz7CCKYMDAxCJC4VbvsV3r8cxVWCqNRR6WR2QMtecNMXYK7BzLcJaZBg6nCxkyteX16V+quNCo9MhUfm0leXEuuwUOb2ISgKcswATF/6ZWDbx4Vx9+gTQgpNicsrk1FQSYnTi9UkkhBppVWUXfNYVVVZmVbAzCVprE4vxCKJ+BQFAQFJFFBVlYQukUB2wHl6S9BQRQqXbKJ46bK6Sz5MIhZJpNTlRU8vs80s8vJ1fYh2nAj+WkYFpnu/3JhJceXJ6d4UurduWsPlyd0n81P6T8gaoiaxo2KJHRVb47lWycqNXW/EJBoN1gZnN6qqkhxj1/QXgtD6IUVBYPLQdg0xvFPm/os68sTX27m6f7Jm8NCcKXP5qoxzq6M34NHyF7SaRE2lU709coqqVgkfGRgYGIRMQje4ZzUFC17CseUDwiwieKqLfQn+ICqshb/Pqv9kkJpn/3q9PyGWu31cN2OV/6Fah4arV1bJKasmlyia4Vi/z+6jZfz1q208+8MuPrljCKnxjSOrXZ2MggpmrzjIZ+szEQSqGrU9PoXUhHDuHpXKJT1aVt1Uylxepr6/jh3ZpVX17lolFdmH+mNOTEeQTsz9eAla5luZHJ59mC4vdUGukMmek41oFilaXhQQAMiVMkW//xZyjXvreBsOi0ROqSvkEr2nJvRgfK/gniw4zXRvI9ApphN39L6Dt7e+rcvLyyyaSY1KZVrvaQ04OgODpmfP0TKenr+DnFI3wzvGsSGjbs+hk7GaRC7smkByjKOBRnlqDE2JIy7Mwvyt2VzZ98wy7LWaRAROXxREK1MYF27VvCfp7ZEDiDhdyXsDA4Nzk/B4Coc+wc0HxvDTxcWw71eoyPcHTRGtoM/10Gawpr9pc6LeV8DP12VSWOHWDKRCVR7SosIjU+mVuXL6Cr66Z1ijlf65fTKPfbaFX3fmoCiqpjTtzuxS/vzlVp74ZhvvTB5Ij6QoJr6xgswip+bNKuD6Jd0xJ34V9Hr8+Hjyf8hHrpTZdd8upHCJyP6RhPcMp3RDadVxqiJRuiUZ1bc15JKPjMJKvrl3OFe/uRJJ9ZfjOb2BO5d+RRaRdnEO/n55d4Z1bBHStZsrt/e6nXJvOR/v+jikgMoqWekQ2YG3L347JENkA4MzkcIKDy8v2MNP245y/4UduWlIOxRV5boZq9h9tKzGLNXJWE0iHVqE8cqkvg08Yv0IgsD9F3Xi2R92MqF3a8QzKDvVMtKGonEv1RvwxIUHr2GxYRZ6J0ezIaMo4HU9PXKiAON6tDyjPlMDA4PmhUkScKsm6HWN/78zkHoNplRVZeayA5rmf6EqD9V+fahw+7j+7dUsfHRUQMlZQ+Dyytzwzmp2HSmt86Gi4lj26ZbZa2kXG0ZWCIGUHxOewqFY4pYhiIGKc6JNpMWlLcj/KZ/O/+6MFCZRuOTkZmQBT14b3TXus1ekc3X/JP4yvhtfrM/ik7WHKKjw4FMUwi0mhqTGcfv5TV+iV18IgsDDAx6mS0wX/rfxf5S4S3D6nEFlLg6TAxWVP3T6Aw8NeAir1IiN9AYGjYRXVpizKoPpi/YzoXcrfntkFDFhJ9bTT+8Yyh1z1rMho6hGs+/jOCwSPVtH8d6U87CZm2ffzMhOLXjJLPHLjqNc2qt5mT3WRu/kKCJt5qDvQE/AYzdLTB6iXXp516hUHpq3qUr19Tih9shZTRLTRqbU02wNDAzORUyiEKQqeqZRr8HUqgMFlGkYP+pRHoLaM1gqUOn28dGaQw3qvK6qKg98soldR0p1lbu4vAp7cso036tpXnARpvD9iNZsBDHwpmbvYCesaxj5P+eT+IfAXilVMeM6MhHRgu4a99925rLwsdFE2sxMPb8DU8/vEPIcz2TGp4zn0g6XsiFnA+/veJ9dhbuo9FZikSwkOBK4seuNXNrhUmym5tfgaGBQHyzak8u/5u+kdbSdT+8Yopnlt1skPpgyiIW7c5mxJI1txwQMjm8qWUwiItC1VSR3jUplTLcETM1EdEILQRC4/8JOvLJgL+N6tmw00/DTRRAE7hiZwgu/7AmqHgg14FFUtUbxjQu7JmAzSUHBFNTdIycASTF2eic3L4liAwODMwSfG/b8SHTmDu5x74bFayA2FbpdDmZtLYLmSr0GUyv2F1RlaKqjR3kolAyWy6fw3vJ07hqV2mANxZszi1m2L18zkDqVcsW65lV56DYcbd9FtB4NylAlXJXAgWcPEDf2hBqeqphx51yKr3QA1qQKXSUfbq/CI2M7Ext2bpavCYLAwJYDGdhyYFMPxcAgZFRVxaeop6yUtz+3nH/9sJOMgkqevKwbF3ZNqDWoEEWBMd0TGdM9kYP5FSzcnUtRhQcViAmzMLpLfJP0r54qY7ol8PKCvSzcndvoFg2nwx8GJPPCL3s036sr4LGYRMb3akWUXbtpWxIF3r55ADe9u0Z3j5zDIjHjj/11nWNgYGBAcSasnQnrZwMQ6angelRYPB8sYfD9g9DvJhhyN8SeGZnveg2m8svdmq+HqjykJ4Pl8sos2ZvLhV0b5qb4ztIDuDW8rk6lXDGkeSk2KjPuwhK3EHvscmzCib4eW7KNiD4RFPyQh6WVHRQrzsxbkCv9mTm9PjAWk8ht558Zv6AGBucy+eV+pcy5qzPIL3ejHutxbBfn4M5RqUzo3Rq7pfZ1taTSy/9+38u3m7O5Z3Qqb09ur1t9rX2LsDM+ey0IAvdd0JHXFu6vM5BsTkTZzUy/sR/3frxRV8AjCJAQYeXpK2svox/QLpa3/jiAe+ZuDMp+1XTdMIuJD28bRMeEM9+2xMDAoBHZtwA+uxkUH8gegGoSO+oJNb/1s2HTHJg4A3pMbIqR6qJe6zLMNWSJqisP1YaeDFaFR2ZDRvEpjbMuCis8/L47N0g6/HhQFDv2bhxdhiFabAiSCUfHwUEBS3VCnpdqwpN/Mb69/8czeQVYVZUEn49En4+Bl8dQsriQ/gfN4GpVFUgdJ3LQ1cRceBslq+aR9fpNZL11K2Ub52PvFPwzGzKjZ2BgcPqUubzc89FGhv9nIW8s2k9umRtF9Xu9+RSVtLwKnvpuB/2fWcALv+xG0ag398kKc1Yd5KKXF+P2KSx4eCS3j0g5p2WsL+3Zkgq3j2X78pt6KLq4qFsi/726Nzazju9OhdxSF499voUNGcHGv9W5oEsCn905lL5torGaREwa9wcBv9DI0JQ4vrtvOP3bxuichYGBwTnN3l9h3mTwVlYFUjWieMHrhK/vgm1fNM74ToN6zUy1jLJpNpKFqjyk1zujoIZM2Omy7mAhZun0jRKPo3deLuxcXOkkTpZ5Nr+QMZl+9b5p3U18tuYgcnw3zfNC8YEBGNjeuAkaGDRX8srcXDNjJUdKXLWK2BwXJXhv+UF2Hy1jxh8HVJUALt+XzzPzdxITZubDqYPPGiGZ00UU/dmp6Qv3M7JzfFMPRxdX9ksiOdbOsz/sYkd2KbKs4KulZ1sFPLLKrztzWLo3nxsGteHJy7rXqLzXKzmKb+4dTlpeObNXpLNgZw7lbh8iAmFWE/nlLn544HwjG2VgYKCfgjT4/BbwOfWd53PCd/dBfBe/aW8zpV63KC/t1Uoz41G9DK1y7yoUrwtV9uFMW0/Roveqjgs1g3WchlKOKnF6NXd69QZFx9E7Lwnt40aNi6LMC6IlH4Q6ovpamPbhehbuzjnl8w0MDBqGSo+P699exeGQ1UDB6ZVZsT+fP32+hfS8cm7/YD1/+XorD4/txCfThhiB1Elc3rsVuWUuVh8oaOqh6GZAu1i+umc4c24bjBRi75yq+n9HPlmbyeNfbtU0Aa5Oanw4/5rYizV/HcOOf45j2z8vYfVfL2JQhzjS8yvrYxoGBgbnGitfD8pGtf9fGQkvlFHhObEmvbvRw+j3KwLP9blh6YuNMcpTpl4zU6nx4XRrFcnmzODyu1CUh/R4Z1hMIknRDaP2IQoCGj6Juo0Sj6PXEyQGf83owYdO7ACut1n5X7t4erzbEgB3wS94ci8B9ItIuLwK9360iS/uHkqP1lG6zzcwMGgY3lqcRlaRU1MmtjbhG5dXYf7WI/y2K4d7LujI9Bv7NVuZ8qbGJIncM7ojry/cx5CUuLpPaGYoisqfv9qqW0rY6ZWZv/UIPZIiuXWY/v63C7smsGhPLmO7nzniHQYGBs0ATwVs/dTfJ3USsgqvrvHw1xG12NCoCuz5CSoKIKx5rtn1btp716hUHvlss6Y3SV1laHqEFATg8j4N4xcSF2bxB1QncSrO8KBvXlbcXC8tDDhfBf7aIg63eGIn0hq3AsUXTvESH6Vrv8FXfBTBasfRaSgxo27BeWB9rYqDTq/Msz/s4uNpQ3R+OgYGBg2BV1b4cFWGpqddKMI3PkWlR+so7hndcJYRZwtX9U/i1d/3sSGjiAHtzqyy5yX78sgpcSHrDLjBv+6/+ts+Jg9pr7tvdnSXBGYtX4OqqmeMeIeBgUEzYPuX1FQI96dhFp5f4eae8yxE22pZVwQRNs2F8x9smDGeJvUeTF3cPZHhHVuwbG8erhDLVKoTqnfGoA6xtIpqmMzUkJQ4zRuVXtW86oQ6LxWRm0y/B7y23malRKOko+j3NIqXL6bFZQ9ja9cHubyAgl/fJPu9+1B9HuIuua9WxcENGUUcLnY2WIbPwMAgdH7flYNPDl4z9aicbs4s5kiJs8HWxrMFsyRy9+hUpi/cx+wpg5p6OLqYsThN04IkVKVZj09h4W79GabU+DDMksjWrBISIq1UemQirCZiwyzN2mPMwMCgicneDN4KzbcGtpYY3d7Eiyvd/OvCWvw9fU7I3thAAzx96j2YEkWB6Tf2Y8rsdWw6VIRTp3cF1J3Bspsl7hqVejrDrBW7ReLagcl8vOZQUClFqEGRFnXNyyQKjEgyk1DsBRx+xRPg/agInCftBHpLIila9DNxlz6IPWWA//yoROLG3c/hN28lvO+ldT54qarKhysP8pfx2oIWBgYGjccXG7JO26cP4JftR7l1+JktY94YXDswmekL97Mtq4TurSPZdKiI/HI3Xlkl0m6mT3IU0Y7m5cWXVVSpWUavJ+Cu8MjMXJJWFUx5ZYUFO3P4dvNh8ss8qKjEhlkZ36sl43u1qioX3XWkDLMkcPVbKzFLApIgIKsqkiBww6C23DKsPW1iHQ04ewMDgzMSZ1Gtbz99gZXh71Xw4OA61ltnwyh41wf1HkwBWE0SH04dxH9/3s3c1YcQBDTL/kSBIPnxurCbRf7QP4nhHVvU02i1mTK8A/PWZWrWpYfiDC/onJsoQGyYhedvGQHmvbg3fEz+Ly/QylLBSrsd9aRgqmKn3f+AVe3GCeDNywAEZFd5nT/TI6ss3J1rBFMGBqdIdrGTLzdmkZ5XQYXHR4zDQu/kKK7sm0SYVd/ymlN6ej59AG6fQkH5qYvTnEtYTRI3Dm7LQ/M2UVjhwSufWLAFwZ/Bubh7ItNGptA7OboJR3qCfbnlWEynrzS7L7ecMpeXGUvSmLMqA1lVqXAH3qNXpeXzt2+2c0Wf1mzNKiYtvwKPT0FRCarc+GDVQeaszmBEpxa8en0/3b/7BgYGZzHW2hVAeyZIXN7ZxH+We+gWX0uW29p8TeIbbMUzSSJPXNadh8d25ptNh3lnWTqZhZX4FAUJSIiy88chbVm8J4/th0tDMgu0myXG9WzJP6/s2VDD5kiJkzmrMvh4zSHN3oVQiHKYefCiTvz3590hmSyaJYEou5n/Xt2bw8VOymxmtjom8HmbPswcJ6EuvN3fgFcNX5mi+YAlO0sRzBZUZ1lIYy11eUOfmIGBAQAr0/J5c1Eaaw8W+n0Gq5Xnfbclm6fn72Ri3yTuHJVKhxZhIV3Tp2ivFXqFb7w1XMcgkI/XHOKNRftrXed/2HaE33blMiTFb2zb1KIeZS4fioYan16l2UqPj/GvLiO3zF3j/I9nST9Zl1nn9fyBqMqyfflcMX05X90znCi7OaSxGBgYnOXEdQSTDXyuGg/552gb/WeW8+jQGoQoRDPEdWqgAZ4+Db595LCYuHFwO24c3A6A7BdfxhIeRou77gTgzpGpvPr7Pt5feRBFY3cMIMwqYTNJPHBRR24e2r5Bml9LKr08/NlmVuzP9/tznEIgJQoQbjPxybQhdGsVSc+kKJ7+fif7csrwKgont0PYRPD5fITZ7ZQ4vdz/6SYEQFZVPD6FEZ1akGNPQUAEAk+WwqyaD1iSPRLV60Gwh+YFYhKNWncDg1BRVZUXftnD7BUHa9wAOp6F/3xDJt9uzubNm/pzQdeEOq8dU0NJmR7hG5Mo1HgdgxO8sWg/0xfuq3PDTDkmK74yrYBrZqzki7uGNWlAZTdLiBpSs3oDbp+skl2DiMXp4PYpZBZWcvOsNXx+17Bz2iDawMDgGL0nwcJ/1XpIx1iRST3MvLbWQ68EjXVDlGDAzQ00wNOn0XPxljA7SuUJrwqTJPLoxV24/8JO/LLjKLOW+zNYbp+CXYKkw/t58OFrGNU5oUazwdMlp9TF1W+uILfMHVDqESomUcAkCnRrFclrN/Srqhs/r30s399/Pvtyypi1PJ3l+/Mpd/kwSQJWk0RhhRuzT6XY6c8OeeVA2cgV+wu45JV8rJ2CZaJZuwAAIABJREFUb/j2jmH+B6w9KwnrNqLqdXN8O0BFsoWWDo0LNx68DAxC5d8/7WbOqoyQMumyAk5F5u6PNjBz8kBG1WIS683NZVhFJhsUM24xcEdfj/CNWRIZmto8pWObCz9uy+b1hftCqho4jtunsD+nnHs/2sisW89rwNHVTlK0HVkjM6VXaVYluFSvLiXAUI/1yCr7csv5YVs2V/VLPuW5GhgYnCVEJELqBbD3F/yrjzZ/H2VlztYaqqVa9YXYlIYZXz3Q6MGUYLejFAY3o1lMIhP6tGZCn9YBr+8b/R/a2S9vsECq3O1j0sxVHC11n9IuncMicVW/JKYM70DHBO0AplNiBP/5Q++qfy/YmcP9n2z038zFmkshjpcOmSs7IDrSAt6ztjhI1PBJFP42E9HqCFDzk8LjqNy7EnuH/rU+eDksEpPOa6N7zgYG5yK/7DhaYyBVlwfU3XM3sPDR0bSMOqFW5MvLo3TBAsp++hnXnj2MveAiXjcNPzkJDYQufJMUbWs2/T3NEVVVefr7XTUGUrV+jz6FlWn57MguaTJ/vm6tIkiIsHKwINA893SUZiF0JcBQj630yLy1OM0IpgwMDPyc/wikLwGvs+ql6l6qAG2iRFxPapjMmx0w8rGGHuFp0ejBlGh3oDhDd1G39+tL5aZNWNq1a5DxvLloP0dO0bPDbhZ586b+jO5SdwnPcTZnFp8IpELElT8Ke3IWgniiQV2Q3MRdnIJo+yNFi2b5faYsDhydhtBiwp9wpq2rW4ZdhYl9k0Ieh4HBucwrC/ZqBlKhekB9uOogjwxsQdmCBZT+9DOuXbsIHz2a2Cm3EjZ8OKLVyvh5m/l282FN8Zq6hG8cFom7DI+pWlmVVkBZDX2ioXyPHlll1rJ0Xp7UtzGHXYUgCNw1KpWn5+8MEnU6VaVZPUqAeo7NLHSy/XAJPZMMY3gDg3OetoNhxGOw7KUqpeqQMDtg4FToNLbhxlYPNH4w5bCjVjrrPvAYjr59cW7eTPTEifU+Fq+sMHf1qZtkOr0KM5cc0BVMPfn1tlPYFe2GqpgDgikAS+xyIvreQ0SfcUHXquvByywJTOynX3HMwOBcZGd2KQcLgn0yQn249PgUPly4i8v+/l+iR44g9pabCTv/fERrYLPtI2M7s2BnDuXuYKf42jCJAq2j7Vzeu2GMzM8WZi7V9mgK9XuUFZUfth3hqSt7EGlrGoGFK/q25tkfd2m+V9e6r4UeJUBdx/pkFu3ONYIpAwMDPyMeBUWGFa8EZKhqxOyAAbfCxbX3WzUHGr07VLTbUZyhB1P2fv1wbtrcIGP5dUeOZkbq+I01duzdOLoMQ7TYECQTjo6Dg8olNh4qIrMwtCh7b04Z+/O0JctL135N4e/vEDXkOpLvm0vS3bOJ6D8e5741gIjryDWoSuDNW7IdxRy7EgR9UsiiAPHhVh4f10XXeQYG5yofrDyo2U+p5+FSMVs4NPMzkl58gYiLLgoKpADaxDr4YMp52JH9qeMQMEsCLSKsfDJtSJOrzTV3NmRo+5To+R4tJpGd2aX1PbSQcVhMzLrlPGxmfbfvmrQg9CgB6jlWUSGvXFvu38DA4BxEEGD04zBpLiQNRJFseDlpLREkvIKVo45OcPU7MO7f/vOaOY0eTAk6gylb1654MjORy+v2TdLLlxvrxyRzwc6ckI6btTwd78mSfoQWvMnlXXHnXBYUUFnjf8YctTnkgMosCcRHWJl359BmZ0hpYNBc2X20VHPjRZcHlAoHS+q2Iuiw9Ade2/8VLcIthFlrvq4g+Ev7urWK5McHRhAfUYOkrEEVTo31HvR9j6oKpc6mtZQY1CGWtycPxGGRQnrOsJlEUhPCsWpEVNWVAOtCz7EGBgYGmnQcA9N+Z/7QeayKuRLaDoPEntB2CPS/mYIbf2ac6zkqUoKrrporjVrj5ZMVclUrB30WKKggLtxKeB1lZoLFgq17N1xbtxI2bFitx+olr6yeTDIrQgtkFu7KDZJHh9CDN2/xEKxCNLakzxAFqPRVIghgbfkVovUo7rwxoIqg2oLOPX4TPb9TC164pg+xYUYgZWAQKjWV3emRpJaVun3dypcto+Cddxj1ySesbtWK33fnMmNxGjuPlAbITHt8Chd0TeCOkSn0axPdIHYRZyOSKCBrZBj1fI+CAOZmIPk9snM83947nJd+3cOiPXkIgOukkvUwi4TFJDJleAduGtyWwc/9HnQdPUqAeo6VBEgwAnwDA4Ma+OFIOONGPgMnCdW0BIZ02MCXG7O4eWj7JhmbXholmDpuhDt3dQYer4zYahy8ugyvrDCiYzx3jEphcIfYGh8IrH36sm71DtTYjsiKSrTDTM+kqNMuaalJvU+vZ4ccoklmTQ9kunZFK3vw1x7zsEbtZNb2WaSXpGMRLZC4BU+LjVhcA5GKx5FfYsHtUzCJAtEOCzcObsMfB7cjITI40DIwMKidmnoLdT1cihBhq3nJde/fT/bjfyZ5+utYkv3CMJf0aMklPVpyuNjJkWInlR6ZcJuJDnFhxBgbIrqJsps1S8/0fI+yohIf3jyChE6JEcyYPJCCcjefrDvEkj15lDi9mCWRVlE2rj+vLRd0TUASBVRVJcJmoqgyMKDXowSo51iLSeTCromN9lkYGBicOfhkhVVpBTwzsafm+1PP78DjX27lj4PbNZiad33SoMGUx6fw+Jdb+XHbkUAjXMkCx8otFu3JZXV6AS3Crbx360A6JpyQSswrc/PxmkPMLu2G2+3F9OkmUP0q9Yqqcu2AZKYM70D7FmGnNL6YMO0G4oYyyaxp81jXrihgkayMTxnP+JTxFLoKKXYVI6syUdYo4u3xVUGpoqhnxC+hgUFzp0vLCHYcLg3y+NHzcGkzQ/tjHnQn4ysqIvPue0j4vz/h6N8/6P2kaDtJ0fb6ndQ5yLUDk3l3WXqV7cRx9HyPkTYzPVpryPc2IXHhVu67oBP3XdCpxmMEQeDWYR14c/H+INElPUqAoR7bLi6M7s3sczIwMGgebMkqoXW0nYQI7Q3+89rHEGaVWLQnl4u6JSIrKmUuLwICETZTs3u2bbBgyuWVueGd1ezKLq3VZV7F70mRWVjJldNX8PG0IfRpE82s5Qd4/uc9ALh9+P2YXIGZnY/XHOLTdZn8oX8yz0zsiaTzwx3XoyUbM4qD5I71mmQO79gieF6qSl6Zm2KnF1GAGIeFSJspSM4W9BouCsQ4TgSBsbZYYm2xmkc2t182A4MzlVuGtmf+liOa0ug1PlwO+wOmyE1Y4hYjWvMQUPjLZpH/7orm2s7XMqnLJOId8ageD1n330/kuHENolpqcILJQ9vx7vJ0zfdCCRLMksDU89ufsWWVNw5uy5uL92u+p0cJMCSZ/lGppzRGAwODs59l+/IYWYuRvX/zpz0v/LKHNxbtZ3NmMSZRBFR8isr5HVtw56hUhqXGNYv1uEGCKVVVueejjeysI5AKOAeo8Mj8cdYaruzTmi83Hq7zXK+igqLy9aYsjpa6eOfmgboCqqv6J9coMRvq7lubWHuA9Gu528fXG7OYueQAeeVuTJJ/PF6fQpjVhCTAySX7eoI3FZX+7WJCnqOBgcHp0zMpiuRYO/tytIVwAh8uVSxxi7DEfQ6AIJ3oqVRUhUJXIbO3z2b29tkMbT2U+xbZCYuOJv7hhxp6Guc8raLsDOkQy4q0As0y77qCBFlRmbcuk55JUQxLDd5Ea+7ER1gZ36sVP207EtRfVV9YJJGuLSO4zJDpNzAwwL9ubjpURG6ZG49PIdJuYtHuXB67pGZF6R+3ZfOv+bsocXo5vlJXryhYui+f9RlFRNrMvDKpL0NT4xp4FrXTIMHUqrQCVh8o0AyG6jLCLXP5+GTtoaCAozacXn/t5T+/38HTV2rXX2oRbjVxRZ/WfLkhS/PnhbL7dvfoE7tv7y1P5/lfdiMKQlUGqnqblKey5ubzUHdFbxjU1pA/NjBoAh4e05lHP9uimZ06gYKt9aeYInYhiDX/vXsUf4C1InM5O5IEPnrwSwSx6UUNzgWev6YP419dRlGlBx23GWxmkRf+0BuzSeJPn2+lb9tonhjfjdZnWPnlv6/uxf7ccvbmlIW82RkqVpNIuzgH708dhFkyfp8NDM5ljvdyzl5+ELdPBvy9m4IA5W6Z95anYzVJnNc+JiC79O6yA7z4654aPVmPU+mRqfTITHl/LS9e24fLe7du4BnVTIMEUzOXpmmWs4VihAvBmZvj1BaIOb0y89Zlcu8FHUnUIbLw4JjO/Lz9KKUu/SaZsWEWlGOZsZ+3H2XJ3rw6v/zaqCt4EwWBW84QZRMDgzOF7YdLWLQ7t0qYIDHSxkXdEujaMrDfY3yvViydv5SvC624pRr6LRO/rTOQqo5PkCl0iExdeg+fT/icCEtE3ScZnBYto2zMu3MIE15fHnJ2xmYWeXJ8Nyb09QuDjOocz1tL0rjstWXcPiKF20d0wGo6Mza5bCaRT6f0ZdrH29icVaJ5rw443iwiIKCoCj5F1VSktUgCgiBwQdcEXrmuL3bLmfFZGBgYNAxfb8ziL19vAzVYZfQ4i/fksSa9kB6tI3nv1vOIsJn5fsvhkAKp6ri8Co99voW4MGuTZajqPZg6WuJi1YHCoNdDdZiviVADsTmrMmpNHZ5MUrSdubcP5oZ3VlPplkPaqRQFf9qyqMLD37/dgUdWNA096xO7WeL2ER1oU0MDu4GBQeh4fArfb8nmrSVpHC5y4vbJHK/6kkR4feE+UlqEc9foVMb3bIlJEsl7801uWzofx5Sn+WRbfnCvpe0Q5uiNmoFU0bIi8n/Jx5PrQbJJRA6IJPGaRKQwCQWF3Mpc3tz8Jo8Perwxpn/Ok55fQbTDQqeEcNYcLESAoCyNKIDVJJEYaeWpK3owuktC1Xt2i8QjYztzTf9knvlhJ5e8spR/TOjBBV0TaJZ4XbDja1jxKuTvJQyYo6j8bB7NDOv17HNG4FPAV630URSgQ4sw7hndkcv7tOJQQSWzlqfzzebDqOoxmXlFxWISuWlwO24e2u6My9IZGBjUPx+sPMi/f9pVZ0B0XDNhS1YJV0xfwbw7h/D4l9s0z6urqs3lVXho3iZW/+WiJumhElS15iBg4MCB6vr163Vd8IsNWfz92+1Bu13OAxvI/eKftH3s65AkwKujuCvIeuMW4sY/VKdAQ6TNxIa/jdVdYrA/t4zJs9ZS6vRqGvmCX0lPVdWQ3Zjr+vJDxW6WmNivNc9d1atZNNoZnNkIgrBBVdWBTT2O0+VU1ieA4koPk2etZX9ueR0le/5S3h6tI3nBtQH5t19oN3s2pvh4lu7Nq2qKVVQVr6xiS5qLKWIHghC4pub/lE/eT3kk355MePdwvEVesudkI5fJdHiiA+IxzyKHycHS65dilZqH7PbZSmZhJVe9uYK3bx5I/7Yx5Ja6mLsmgy/WZ1Hs9CIrKg6LxJCUOO4YmULfEHy8Fu3J5envd5LSIoy/T+hOu7jQFWZVVaW40kuJ04skCsSEWer0X9RxcVj2Eix/GRDAo9HzJ1nYpyaxIPxK8tpfzqpDTkqcXl6Z1JchKcG7vF5ZoajSg9MjE241Ee2w6BZ/qo2zYX061bXJwOBMZ/GeXO6au0F3lZbFJNI6ykZemTvoGbymZIo7c0eApkCYReLtmwdqisLVB7WtTfWemSqu9ODVqAPQ46V0MqGa2oI/Y5RRUBEgsR4KHRMiWPH4hSzdl8fMJQfYkFHkN7oV/PejSo8PAZAJ7aZRVyYtlEBLFMBmlnhoTCemjUgxAikDg9Ok3O3jqjdXklVUGVI2udIjs/lgAbe5Ivl2lj+QAr9h6sjO8RwqqOSz9YfYm5/Dau8e1JMCKdkpk/tNLkm3JRHR278mWeIttLmnDXv/tJeSlSXEjDwhKPPrwV+ZkDqhHmdsUB23T+bejzdy9+iO9G/r/9wTIm08MrYLj4wNvaLhZC7oksCw1DjeW36QiW+s4KbB7bjnglQclppvsRVuH99sOsyMJWkcLXVVbQB6fAp920Zz18jUKo+oU0JR4KtpsOdH8FbWfJzsoRPpdCx/g4qdc3g44r/Me2giUXbtUlazJNYoZ2xgYHBu88z8nTUGUrU993p8ChkFlUHVYXqq2io8MjOWpDVYMFUbjWLaC/qNcKujJxATRYESZ2j9Clrnju6SwOguCeSVuTla4qLS4+Pn7Uf4ZG1myPX1dX35oZYsSqLAp3cMoXdy9CnNx8DAIJD7Pt5IdrEzKJCqbZH3IpIVnsBfFmby+g2BUq5t4xw8dklXvt63i81rTTh9gWtP5b5KFK9C5IDA/ivJJhHRO4LyHeVVwVSlr5Kv9n1lBFMNyLM/7KJ1lJ2pw9vX+7WtJr8g0cR+rXnux92MfXkpT1zWjUt7tgzYCFNVlXeWHeCVBfsQBKqqOLzyid3Y9QeLeDB7E1azxPQb+52acuDPj9cdSFVDULzYlSJm+v6GqF4EaFtuGBgYGGixJbOY7GKX5nuhPPdqbW/qSaYArNifj6yo9ZotD4V6D6ZiHBYskhhwYwC9XkqB6ArEVOpFRSg+wkp8hBWnR+bW2es0A6maHsBUj6vGL19PlK0oKnNXZ/D8NUYwZWBwuqTnV7AqLVhlNJRF3u1T+GVHDjmlLk2Bm0JXIV45eBNHLpcxhZsQpOCF3RRlwpnhDHitwFVwOlM0qIXvt2SzZG8e399/foNm+VtF2Xn9hn6sSivgqe928NGaDJ6a0INOiRGoqso/vtvB5+uz6iwxrfDIVHhkpr6/jpeu7cNlepSqsjfBpjngPfH71f5/ZVR6If3BcMIs/vm/u9HD3K1eFt/qL0uUUKAiH357Cq54TffcDQwMzl3eXXbgmGpfIKejmaC3qs0kipS5vEQ7LPoGf5rUu3bpyM7xAU2sVT+ompdS5d5VKF4XquzDmbaeokXv1XrN6oFYXXhkhbjw+us5+H5LtmaLVOnaryn8/R2ihlxH8n1zSbp7NhH9x+Pct6bWL19XyaIK327OxtNAfiAGBucSs1eko5y0Nh1f5GPH3o2jyzBEiw1BMuHoODigFhv8PZNzV2doXltWZRQ1+O9UCpfwlftQNUoKfSU+TOGB+1myUvsDtsGpcSCvnH98t4M3buxPpE27fK2+GZoaxw8PnM+YbolMens1/5q/k//9tjekQKo6Lq/Co59vYc0BHYH2yungcwe9LKvw6hqPxgnVULyw9TNwa3uqGRgYGGixOr0Qjcd/3dml6lRPpoSCIKA5hoam3oOp+AgrIzvHa3YWRQ66mpgLb6Nk1TyyXr+JrLdupWzjfOydav+A9QRibeMcJNWjotBbS4Jl3ut6AKvty9cbZQv4+9AMDAxOHY9P4YsNWX6j72roWeTdPoUPV2UEBWQAkZZIzBpy6Y6ODgSTQOmG0oDXZZdM2dYywroHChUY0uj6KazwsPpAAb/uOMrSvXnsyymjurCS0yNzz0cbefTizgEG642BSRKZMrwDvz48ktwyF6/+vl8zkKrYuZgjHzzEoZevIWv6ZHI++weurB1V77u8Cv/35VZqE4yqwlkEu+eDRnD/p2EWXlzppthVx3UEAbZ9XvfPMjAwMDhGpVvbYuh0NBP0JFPAL5ATaWu0DqYqGuQn3jEyheX7gqWDoW4vpZoIxdQ2zCJx96jUWq6iD39DXEXQ63U9gNVW0qi3d0wUhTp9QAwMDGonr9yN1nOo3kW+0uOjzO0Las4f1HKQ5vGSQyJhYgLZc7MRbWKAmp851kz0sBMlvFbJyug2o0Oe07mMqqpsPFTEzKUHWLwnzy8WpAIC+GSVllE27h6VyoQ+rfnHd9vp2jKCGwe1bbLxtgi30jkxAot0FM9JWcpQe2jzytxsPFTMgHYxJ18+kH0LQNS+tQ9sLTG6vYkXV7r514W1iEh4K/1lggOnhDxHAwODcxuTJALBz6uno5lQPZkiiBK2Dv0QRBOug5txHdoaVEHSOzn62DgalwYJpga2i+HCbgn8vitHtzyieEw9T2vfrM5ATPAba9YXpS4vZkkM6rGo6wGsti/fmb5BV++YrKiEN0GUbWBwNlHu8qG1vupd5E2iSLlGMJUSnULH6I7sKNgRdE78+HikMImj847iyfUg2kUi+0fS5s42iOYTg1JVlWs7X6t/cucYJZVepry/lt1Hy3B6ZVSVoFLo9PwKnvp+B3/7djtxYRYWPDKqSdVQZUXlvRUHgwIpPb0ETq/MO0vTGDC5DtXwinyQa65mePoCK8Pfq+DBwXX0FFQa/XsGBgahExtm0RSAOx3NBAgtmQIQZpW4qx4TKnpokKd0QRB45bq+TH1/HRsyCnGGEFCJAoRbTTx/bW8embdFdzbGZhZ5/YZ+2Mz157xuMYkoGtvZoTyA1fblmyLjQ46yzZJIdA0StQYGBqHhsEiaddR6F3lZUXHUsMZM7TmVv634G5W+YPW02FGxxI6qWR1NQOD8pPOJszeNe/uZQlGFhwnTl5NT6qpT2v74PSS/wsPKtALGdk9sjCFqsjenDLdGpYaeMlNVhUV78th1pPSYL5WH4kovxU5vwL+H5x5gkixT012jZ4LE5Z1N/Ge5h27xtezgGv17BgYGOrhhUBteWbA36Jlfb3ZJi1Cq2syiyJhuTWOc3mApD4tJ5IOpg3jm+518su4QgoBmlkpUFSySSIfESN6ePIA2sQ7iplq59b21VHpkzQzVydjMIs9N7MWFXev3Zhleg0dIqA9gNX35tuRuIUXZFkngpsFtmyRlaWBwNtEi3FqnME6oi/zenDK6tY4MEjK4oO0FtN3ahv0Fe/GJ+jpg7SY7Dw54UN+kzjF8ssIfZ60JKZCqjsencP8nm/jirqGN3jN1nKJKD6KGVK/eMlO3T+GhTzcR7bAQ7TATbff/P8phpm2sg2iHmc6HUxHXW0EjqD/OP0fb6D+znEeH1iLWZGuaz8rAwODM5LqBbXjp172a74WSXbKaRAa2j2FjRrEukR7wxwHPTOzZZM/LDVo/JokCT13Zg4fGdmLeukzeXZ5OcaUHk+jP+KjA2CiZq9OXcNFzz1edd177WL65dzh/+mIru46U4vN4kU+62Rw3tE2IsPLcVb0Y1gAmXaIocGnPVszfmh2wq91YUbYgCEwe2u50p2FgcM5jt0iM69Ey6G8ZQi8hAFBRuf3D9Xh8Cpf1asXtI1Lo3trvISV5ZJ76OYKH+lgpcMh4ldD87mySjdcvfJ2UqJTTnueZwJHyI3y6+1N+TP+RMq9fLCLMHMbI5JHc3P1mUqK1P4eFu3M5mF+hGUjVZYLu8so89+MuPp42pEHnphe9ZaYC8MMDI2p/YEi+FNY/Vet1OsaKTOph5rW1HnolaNW/WqHLpXWOx8DAwOA40Q4LF/dI5OftRzXX6VCee1++tg9vLknjs3Whq57azCIPj+3MhD467CPqmUZpxol2WLhzVCp3jEyh3O2j1OXDZhKJspsRPW72X/BvvIcPY05KqjqnU2IE39w7nG0/LOLdH3ezJLEH5S4fquov2RnRqQXTRqbQt010g9bCTxuRwoKdOUFfqp4HsFPBahIZ3SWe5BhHvVzPwOBcp6a/ZQhdGMcrq3hlv2LRt5uz+XH7EYakxDF9YlcKHryf6MREvpj8Kw8seYidBTvxKB5NyXQAh8mBRbLw1pi36Nmi5+lN7gzgUOkhnln9DBtzN6KqakCwWemr5Jv93zD/wHw6RnfkySFPBn0mM5akUaFR/h2qgMOGjCIOFzvrVe01VGIcFk0VSL1lplazWPfOa0x7aN0fDq2s9bC/j7IyZ2stAf/A2+ocj4GBgUF1nr6iJ+vSi8gtc+mSKLebRf4xoQeJUXaemtCDpGg7L/26F1EQagyqHBYJVYVnr+rJ1f2T62kGp4ZQm9TqwIED1fXr1zf4II4+9xyi3UHMAw+QUVBBidOHWRJoEW6F5/+FrUtnYm+5pcHHURNjX1nC/pzykEoO6wOLJNIuzsG39w3HUUOpoYHBqSIIwgZVVevoYm/+nMr6NP7VZezJKUOuRyMKq0kguaKAd8LS6PDPvyFI/gzDjvwdfLDzAxYeWohZPFES6FW8pEalMrXXVC5se2HAe2cr2/K2cceCO6j0VqJQdw+tTbLx/MjnuaCtP8BNz69g3P+WBokBKe4Kst64hbjxD9UZjFgkgSnDO/CX8d1OfSKniKyonPfsbxRWBAtDlK79ipI1XxF3yb21VjmIAlzcvSUzJg+o+wfu/QW+mAKeYDXauhGg01i4qWmk0c+G9amxnp0MDJojmYWVXDNjJYUVnpBKsm1mkYfHdObOk8QjSpxevtyQxdtLD1Ds9Fe1gV+MJ8wi8eRl3ZnQpzV2S/1pJdRGbWtTs3hSd192NW8+P5cfn1mAoqj+2nLVb8DbqjSV+9oMYIJPxmpqnA/sZN66aQBXvrGcCrf+hlyLJCII/lp3MQQzMYdFomvLCN6fOsgIpAwM6plZtw5k/KvLKHF6683Yz+1TybRG8/f245gjnsga9GjRg+dHPk+pp5SMkgzKPGXYTDYSwxJJCk+q5YpnF+kl6UxbMI0Kb+gP9i7Zxf8t/T/eGvMWA1sOZG16AaJGBYIeAQePrLJoT26TBFOSKDB1eHteX7g/KCAMtcrBapaYNjLEUtCOYyCxJ2RvBjnYvLdWzA4Y+7S+cwwMDAyO0SbWwU8PjuTv327n1505iBqaCYIAdrNEXJiFv13enYt7tAy6TpTdzNTzOzBleHuyipyUOL0IAmQUVPLmov1cd16bxppSnTTp07qsqPzj2+18viELJXkQXlew4dfB8ET+sTiLfy45zPSb+jOqc3yjj7NjQjhzbxvMzbPWUu7xafrVnIzdLPG/SX3o1zYGQRBwWCR+25XDW4vTyCiowOP1IR/zTJYEv2BHu7gw7h6dyvherTAbohMGBvVOqyg7X90znEkzV1FUGdqu2XFq68vxILLhUDEbMooY2D5QtS/SEkmv+F71PZVzHoP/AAAgAElEQVQzhseWPEalN1gMoWhZEfm/5OPJ9SDZJCIHRJJ4TSJSmH/TzCW7eGjRQyyatIhSpw+fEpzR0ivgUKZxj2ksrh/UltcX7td8L5Qy04QIK/3bRtd6TBWiBDd9Ae9eBEUZoQdUZjtc/xEkNH7AaWBgcPYQG2Zh+o39KarwMG/dIT5em8nhIieSKBBuMzGwXQx3jkqh/7Fn5NoQBIE2sQ6Oh04dE8L50+dbOFLiRFZUTKJItMNcr2reemmyYMonK9z2wXrWphf4d+qEmj+E43Xyd85Zz3+v7s2V/Rp/V7df2xi+v/98nvh6G+szilBVNcgzxCQKmESBrq0ieebKnvRKDlRDurJvElf2TWJndikLP5tOXlRPiE0hPsLKhV0TqxrZDQwMGo4OLcL45aGRzFiSxkdrDqGqqmYvTnVC6ctxeWXeXnYgKJg6l9lVsItDpYdQTyqSzv8pn7yf8ki+PTnAyPjgiwfp8EQHRJN/M8mreFmcuRiT1PHYDTfwOvp9wprOa6pFuJUHx3Ti9d/3n5JS1QvX9NHXH2yLhGmLYN4fIXM1+Dyg1vBzLeEgWfwBWHIIZYQGBgYGIRATZuGu0R25a3RHrpi+nGeu7EmfNiFuCmng9Mh8vzUbWVUZ/p+FWE0SKipeWWVQ+1juHJXCyE7xmuqpDUmTBVNPfL2dtemheVAdx+VVePyrrSRG2RiS0vh+LO1bhPHRtCFkFzv5cFUG320+TJnLh4JKuNXE2O6JTBnegdT48Fqv0711JN2tP8LF4yDp7G88NzBobsSEWfjL+G48cnFnft5+lPlbj7Axo4gCjZ6WUI1VVRWW7Mkjv9zt7/c04MOdH+JRAj9T2SmT+00uSbclEdE7AgBLvIU297Rh75/2UrKyhJiRMYBfmOKdVa9yZfpETJ4WeE7qL9Mr4NDU38vdo1LJKXHx2XodSlWSysvX9WFQh1MI0q3hcPM3cHQ7rHoDdnwNUrXPUPb4s1DDH4KulwW+Z2BgYFCPlLt8hNtOPex4f2U6z/+8BzhRNlh9HV11oICtWcU4rCam39CPwY0YJzRJMLU/t5xvtxzW9J2qW+JW4clvtvPbI6Mae9hVtI628+dLu/LnS7ue2gVUFYoz/KpLBgYGTYbVJFVljAc/95vmMXr6csySyMaMIs3673MNVVX59eCvQWqGlfsqUbwKkQMCM/GSTSKidwTlO8qrgimA/ZWH6GrPRy5rxcnaFXpsKhwWiRsGta3/iepAEAT+eWVP2sQ6eOnXvQgCNRrUh1kkbJLKdNMrDG05/fR+cMuecNVbMP55KEgDV4m/pC+iFUQ3n74DAwODs4vsYicLduaQX+7mSImT7zZnc3GPRHq01udj96/5O/loTUadCZgKj0yFR+aW2Wv536S+jOvZ6nSGHzJNEkzNXpGOT6NXIVSJ28NFTrZmFdM7+dRThU2Ks8j/f3tM7ccZGBg0GjUJzOjpy1FUlRJnaP5SZwMur8zREhflbh9hVhMJEVbCrP7bitPn1JSFl8tlTOEmBCm4DMMUZcKZ4Qx4zWoLI3ziVUz43cXXmw4HKTGGKuCgqjSpD0l1bh+Rwo2D2/LNpsPMWHKAIyVOzJKIqoJXVujfLoa7RqUwqnMC0saj8PmtMO13fwB0OlgjoHXfepmDgYGBgRaqqrJifwEzl6axJr0QEXAdE955c/F+Zi5No22so0ojoC5xuXeWHuCjNYd0V7I9NG8zc8OtjVJ63+jBVKXHx1cbD+M76YYYaikN+B3t312Wzms39GuUMdc7xRkQ3c4vZ2JgYNAskGqosdbTlyPgF5M529mbU8as5el8u/kwoiAgCgKqquJTVC7unsi0kSm0bQGiIJ7c5oQULuEr96HKalBA5SvxYQoPvi15ZA+3j+jA/K3ZmrL2dQk4WCSBawYkN5qEbig4LCZuHNyOGwa1pdTpo8TpRZIEYhzmQCXXAbfCweXw0+NwxWsB1yh2FVPk9vfwRlojibPFNajvooGBgUFteGWFR+Zt5vfduZpZd79Xo8renHKe+Ho7M5Yc4OPbBxNXQwl2YYWHF3/dE6SCCqFVsj36+RYWPza6wdfFRg+mtmSWaDYB6ymlkVWVpXvzGmJ4DYaqqmSVZ1HqLkXIWkV0dCtaqapx4zMwaCZEO8yaWSU9fTmCAHFhZ2+/VEmllzvmrmdLZjFen4KWGOIP247w265c2rew4Y0O/jwdHR0IJoHSDaVEDTpR6iG7ZMq2lpF4TWLA8bIqE2mNpF1kJLcN78B7Kw7qEnCQRIHESBt/Gtcl9Ik2IoIgEOUwE+WooV9JEGDC/2DmKNj2BZ7uV/Bbxm/M2j6L9JJ0zKIZAQGv4iUxLJEpPaZwWcplOMyG4buBgUHjISsqt3+wnjXpBZptPCdT6ZE5kFvOFdNX8OMDIzTXwHnrDqH1lBxqJVtuqZtNmcX0b9uwlWCNHkyVOL2a5rd6JW4rPE0ncauHUk8p3+7/lve3v0+ppxSTaALZjVeRif9qPFN6TuHylMuNG5+BQRNz7cBkpv++v6oc4Th6+nJAODWhgDOA/HI3V05fQW6Zq1ZJeUX1NwXvzanE4UgAS07A+5JDImFiAtlzsxFtYoCanznWTPSwwPJtk2Cidbi/PO+xS7pQ7PTy1cbDIQVUZkkgIcLKZ3cNJdJ2BosrWCPg2vf57bM/8LetL6DgF+cAv+LhcTLLMnlh/Qs8v+55Hh34KNd3vb6JBmxgYHCu8eIve1ibXhhSIHUcr6KSW+Zi6gfr+PLuYQHvKYrKu8vSg+7JeirZXD6Zd5Ye4K0/NqxKaaMHU6KgXd2mV+JWy8CxufHhjg95bdNriIg45WN9ANXu/1nlWby4/kVeWPcCTwx5gokdJzbNQA0MDLjhvLa8/ru2D1AofTkWSeCmwW3PyjI/l1fmxndWk1PqCirRrglZUXHnj8Ta8jsQA32O4sfHI4VJHJ13FE+uB9EuEtk/kjZ3tkE0n/j8LKKFSV0nYT6m4icIAs9e1YsuLSN46ftteGWFSiH4NmYz+/uPLuqawL+v7l1z1ucM4tOSnbwYE47bF+zZVR2nz3+veWn9SxypOMLDAx5ujOEZGBicw1R6fLy/suaqgdpK8ryyys7sUrZllQRYCu3NLdO8np5KNlWFRXtyT31iIdLowVRcuEXT9FavxG24tUn9huvk+bXP88W+L3DXYZZ4/Mb37OpnyavMY1rvaY0xPAMDg5OIC7dyQZd4FuzKQdbYWKurL0cQBG4e1r7hBtiEfLkhi8xCp2YgVauZcWlvLInfapZpxI6KJXZU3Vm8SV0mBb02eUBrBv1lKvsf/y/vpXvYfbQMl1fGJInEhVm4aUhbbjivbY11+Gcaiw4t4qX1L+FWQ6/IcMkuPt71MYmORG7sdmMDjs7AwOBc55tNh2uUAQilJM/tk3ln2YEALYSiCq9mL7PeSja3T8EnK5ikhtvobPSIpE9ytOaHo6eUxiwJzUaVSYs5O+bw+d7PccmukM9xyS7e3vo2rcJacXnq5Q04OgMDg5p4emJP1h0sorDCo1mOXBN2s8QjYzuTFH2aamvNEFVVmbEkTXOHsM6bpGrGnT8GW/wCEPWpHNokG5elXEbLsGCZ+ZL5P+BIac/ll57H2b5ayorM31f+XfN+UrSsiPxf8vHkepBsEpEDIkm8JhEpzP+Q4ZJdvLzhZa5IvYJwS+3+hwYGBganyttLD2gKToRakqeo8MuOo5Q4vUTZa68k0FvJ1hg0ej2KSRK5eWg7rBqlMJGDribmwtsoWTWPrNdvIuutWynbOB97p8BUnigITBnevpFGrI9yTzmvbXqtxhvfvif3seOOHex+YDfZH2QjV5z45XPJLp5b81xADbyBgUHjkRBh47O7hhITZiHUTSy7WWTK8PZMG5nSsINrIjbUYWYcO/ZuHF2GIVpsCJIJR8fBAZtf3sIRqOX9MYuhZ4lsko0+8X14csiTQe+pikLBrFnE3X77qU3oDGPZ4WV45ODPP/+nfI5+fpSW17Wk+5vdSflbCp4CDwdfPIhSrcdAFES+T/u+MYdsYGBwDqGqKocKtcuP9fo0Hio4cZ1ohxlFoxqieiVbKFgksUGzUtBEPlOTh7Tj7aUHNN+ru5QGeiVF0S4urKGGd1p8n/Y9gkZRS/5P+eT9lEfy7ckBDdcHXzxIhyc6IB4LLmVVZuGhhVzS/pLGHrqBgQGQGh/OTw+O4LHPt7A2vRBVBY9G3V+YRcJmlv6fvfuOk6o6Hz/+ee7Una0ssEtbepOmAtKkiIoaS2LUYG9YUGNiSWJ+STQxiUa/iemJXbCgEY2KLXYFadIEkSa9s5SF7Ts77fz+mAF3d2Z3Z7az+7x9zcvZW88dds+d55xzn8Mvzj2BS0Z0a4aSNo1PNxygLEaLY/w3SaF4z4UMycgm1/YBQRMkEIo9XM0mNhyWgzN6nMFlPX/GPz7ewoGi8FDpjqkuzjghmz4bV2C53XjG1n5zbg2e+fqZY8kmjgqWBTkw5wBdb+hK6rBUAJwdneTclsPGn22kYFHBsYmPywJlzFw7k8sGXqbZY5VSDS5W2vKjEhmSJwJF3m87E/pnp+K0W5RUuf8kMpJNgIn9OyZ+UQlqlmAqK83NXVP68fePNyeU4hbCs9g/fPGwRipZ/RhjmLl25rfJJiISufGVBkqZ8fUMDaaUakbZaW5euGE0e/PLmPXFDmYv30VhmZ9AyJDksDG0azq3TOrDxP4dq52fqrU4WFTeABlYhTTvBTx22S38Z/1/eG3Ta99+sTfh5838IT9Tup9NN9s5vLEkxJsfLqM8EORow6QAT8/fRvuSI9x0/jS6BEO1TvZ4vAuEAqw+uDpqeemmUkL+EGkj0iott7ltpA5LpXht8bF7CsBh72FyS3LpnNK50cuslGpbnDar2mHxiQ7JqzgXoM0Srj+1F//+bHNUwBbvZO1JThs3N8GokWbL4jB9Yh/yiv3M+mJHXAGVEA6knp82ir5ZLXPs98Gyg+R586KWJ3rj23BkA76gD6fN2ehlVkpVr0tGEvecM5B7zhkIhBtM2lrrfnWZU+uSgTUnNYd7Rt3DHSPuYOWBleR78wmaIOmudHqlDOKHs9YxJ7c45j3BEE65vtuZxsNbLV55dBEv3jiaDE/rrSeLfcXYLTu+UOVhfsHiIPYUe9SkxwD2dDtlOyo36NktOwW+AjqjwZRSqmFZltDO4+RwjOHgiSSX8wVCdKny3PEVo7vz789iZ9mtbSQbQPtkJyN7NO4cU9AMz0wdJSL86rwT+NV5J5DqtpNczcz0NhGSHDb6Z6fy5u2nMqJHy53DpaC84FgK34pqu/EFiisPeXFYDgp9hY1WTqVU3bS1QAogK81FrM63RMetd0z99pkpl83FmM5jOKfXOZzX+zyGdxzDdc98zdq9hXE1roXnsSriwn8vrDQspK2wpdgIFAcwMeb7ChQEsKdEt5PGGn6ulFINobppQSoOySvduJiQ34sJBijbspwjn82otO3Qrulkp7krLeuQ4uL2yX1JciQ+CsHtsPi/S4Y1yX272fOLXzWmB1NH5vDhulzunr0Kf9BgtwmhENhswnlDO3PD+F4M6Zpe+8GamSUWsfo6K974qgZUsW58BqM3PqVUi3D24E48PX9bVJCTyLj1ZKeNC06svlfk7tmr2Hm4NGoy4NrmJtlb4OX2l1by3LRRDXvRLUSqMzXm82Wevh7ELhSuKCR91Lf3xqA3SNHqIrIvya60fSAUIN3V8u+hSqnj01VjevBENbkQ4hmSl+y0cetpfWLuf/vpfckt9MY9WTuEA6mHLxrGuD4dEr+YOmj2YArAabcY0aMdyS47y355JqWBIE6bhbsOkWhzynBlRA3HgLrd+NKcaVUPo5RSTW5I13S6tUti04HiqHXxjlt3O2xM6p8V8/h788v47JuDUWPi45mbxBcI8cXWPHbklbTYpET1YbNsjMgewbL9yyov99jIujCLvbP2YrmtSkmNHJkOMsZlVNo+y5NFtqfyfUYppRpKdpqbSf07Mu+bA/hi9JjXNCRPCM8de9qA2PcIEeGBC4fQrV0Sf/1oE0FjCFYzeXyyy4bDZvHPy09mQr/GTzxxVLMHU0efQfhiax6je7XHbrdIi9FVeDxon9SeHmk92JxfeXxnoje+kdkjcdhqzrOvlFJN5dbT+nDvnDUx5xGpbdy6224xbXzPahN1PLdoO6bKTO7xzk0CEAoZZizczm+/OziRSzpuTBs6jbV5a6My+nU8tyO2ZBu5s3PxHfBhJVmkDU8jZ3oOluPbe6jH7mHakGltcoiqUqrpPHLJiZz7j/nkFnqrDXZicTssnr9hdLX3CH8wxCvLdvHSkp3YLPD5Yx87O9XFz84ZwIUndW30VOhVNXkwZYJBiufNI++pp/GuX4/xesFup4c7mUtO+w7+s7rhyD5+W9CmDZnGA188UO8bn1JKtRQXnNiFF5fs5Os9BfhqSIMbizcQ4l+fbuarXQXcPLE3I3q0O/bFPhgyvLhkZ1RLZiJzk/hDhleX7+JX554Qc8z+8W5cl3F4HJ6oewpA5qRMMifV/BxxyIQ4t9e5jVU8pZQCIN3j4PXbxnHpE4vZV+CtMWV6RWX+EOf9Yz7tPE4uG5XDVWN6HHt2qsjr57qZy1gXx/O0R0r9/OF/GxjWLYP+2an1vp5ENOmd58jsV9h06nj2/uweylauDAdSAIEAScUFdP3gdbZMOYudN08ncOhQUxatwZzV86xqn3fKnJRJvwf7MfipwZzwjxPoel3XYzPVH+VxeBjbpW3Mn6KUOj44bBbPXn8KvTokx5xwvTZl/hAfrd/PNTOWMvFPn7FgU7h+LyjzxwzOEku7DiFjyCspT7hcxwNLLB489UHcNnftG1fhtrn51Zhf4XF4GqFkSilVWXaam3d+PIEbxvcKJ5dzxVeHB0KGg8XlPPn5Vib+8TOmPbuMffllXPHUEr7enR/Xs1K+YIgjJT4ufmwR2w+V1PdSEtIkwZQxhtw/PMT+hx8imJ9PqKSai/T7MD4fJYsWsfXC7+PbubMpitegXDYXvzv1d3W+8T084eFwIgullGpBUt0O3vzhqUwZlI3TbiUcVBkDpb4guw6XcePzy3h1+S5KygPYY2Q5rZh2PR42Syj2xp4IuDUY13Uc9425D5fNVfvGEW6bm5uG3sSFfS9sxJIppVRlKS4795wzkBX3TuH/Lh7GpP4dGdQ5lYwkR8zMsBWVB0KUB0LM33iQyX+ey8b9RTGfwSpZN5d9z93Jzr9cwu5/Xc3+V36Dd/daDFBSHuCqZ5YQSmCoYX01ybf2Q489Tv6rr2LKvPHtEAgQzMtjx1VXEzh8uHEL1wjO6nkWd4+8O6GAymVz8dtxv2V059G1b6yUUs3A7bDxryuGs+CeyUyf1JuMJEedJi32+kPcN2cNy7Yfjjm2PtG066EQeFzN/ghwo/pu3+/y98l/J8OVgcdefU+Tx+4hxZHCfWPv4+YTb27CEiql1Lecdovzh3XhuWmjmDKoE+WBEPHGN/6QwesPxRwqWLj0DQ5/8hTpY6bS7fZZdL11JqnDz6Vs0xIAQgaOlPiYt+lgQ15OjRr97uPbvp28J57AlFcegnHmls3kBYOVorn3evcmyx5JvGAMgcOH2f/ww3T94x8bu5gN7vKBl5Ptyeb+RfdTHiyPOd4dIjc+Zwp/GP8HDaSUUseFrDQ3d08ZwF1n9ueMP89l66HY9VtNqc29gXBAFSsUSyTtOoSH+bVPbr2T9x51atdTmTt1Lgv2LGDGmhl8dfAr7JYdQfCH/PTN6MsNQ2/gzO5nahIjpVSLsPlAMY/P2xIzMKrpHhFLvMmJSnxBnpi3hcnVZAhsaI0eTB1+4QVMMPZQjX937ca45BrS2QYCFH3wIcF778WWdvylCj+9++lM6jaJhXsXMmPNDFbuX4nNsmEwhEyIMZ3HcP2Q6xndabRmWlJKHXdW7y5gX0HsZ5XiSW1ugkEG+fL4SjIIVnk+Kt606zZLOH9Y5+NuKo26slk2JuVMYlLOJMqD5RSUFxAyIdJd6STZk5q7eEopVcmMBdtijkCI5x5RVSLJiVbuzGd/oTdqIuDG0KjBVKisjPzX34BAPcayW0L+G3Nof+01DVewJmSzbEzsNpGJ3SYSMiGK/cVYWCQ7kjWAUkod156ev5XyQHRjWbyth6VBOJiUgSPkIBij1bK2tOsADptw44Te9biK45fL5iLL0zQtr0oplahSX4A3Vu4hEKr79BcVJZKcyGm32HW4tEmCqUZ9Zqpk8ReIrX6thabMS8F//9tAJWpellikOdNIcaZoIKWUOq4ZY3h/bW7MMfCJtB4eCDkY2DkVR4xEFLWxW8LATqmc0Pn4G7mglFKt3ecbD8V8rjaRe0RFiSYnKi5vmsREjRpMBfIOVTvED+BHe3YzetNGRm/ayO17dld/nCNHGqN4Siml6sjrr/5h4oRaD20WP5kygMxkZ0LJLCwJz2vy5DUj495HKaVU08krKScQqv/0F0clmpwo1d00iYka9yzBYDgfbjX+WdszUxWPo5RSqsXwh0JYArFq54qth7XeLCWciW/OD0/l0ie+YH9h7ZM9Ou0WHZKdzJ4+lqzUxh/CoZRSKnGBoIkZBiR0j6ggkeREvkCInMymmWOvUYMpW1oa1HOYH4CVktIApVFKKdVQUl32mA8VQ+XWw+SB42s8TjBkSE9y0Dk9iXd/PJ4ZC7Yxc+F2/KEQJeWVQ7Vkpw2bJVwztic3TehNukcz1imlVEuV4XFgt4SqaYoSuUdUFW9yolG9Mpussa1Rg6mk4cPrl3wCwG4nefypDVMgpZRSDUJEGNIlndV7CqLWJdJ66LBZ9Gwfbj1MdTu448z+/HByXz7dcIB3Vu/jYFE5BuiQ4uS8oZ05c1A2DptObK6UUi3diB7topJPQOLTX1RVW3Iij9PG9Il96l3+eDVqMOXo1AnPyJGULFxY52OIzUbmNcdnJj+llGrNbjmtDz979StKfNGD/eJpPXTZLa4d1wN7leDIbrM4a3AnzhrcqdGvQSmlVOPo1s7DiB7tWLQlL2pdvD1MibKJkJ3mZlyf9vU6TiIa/cms9jfeQOnKlZjSypM6ftynb1z7uwYOxNWrV2MUTSmlVD1MGZRdY9KIeFKbXzW6R0MXSymlVAsxfVIfVu3KpzRGo1s89wiRGtMvVGJJOOnErBtHYyWQ0Ki+Gn2shGfMGNyDBiHOxGenF7ebTr/4f41QKqWUUvXlsFnce/4gkuowYW6Sw8YVo7uT1QRzgCillGoeE/p24ITOaTjtiYccLrtwYrcMPHHcY1x2i46pLub88FS6ZjTtBOaNHkyJCDmPP44jJwdxueLfz+2m88MPkXTSSY1YOqWUUvUxdWQO08b3TCigSnLYmNCvA/edN6gRS6aUUqq5WZYw8/pT6NYuCVcCAZXbYfH3y4bz2q3jeOjioQzolEqSw6Jqh1Oy00b7ZCc/OqMvH941iZ4d4sgS3sCaJAG7LSWZXq/MZtePfkTZylWY8nKIkXceQDweBOj6j3+QooknlFKqxfvZ2QPpmOLiofc2IIC3mtTmDptgifCDkd34zQWDm3QYhlJKqeaR5nbw1u3juem55azalU95IFjtPIUepw0RePLqkZzatwMA3zupK987qStr9xbwzlf7yC304g+G6JDiYkK/Dpw2ICuheQobWtPMZgVYycn0mDGDsrVrOfzscxR9+CHiiKS1FTCBAPasbDrcdCNp552HldS0XXRKKaXq7rpTe3HBiV2YvWwXzyzYhtcfPBYsGQPGGC4f3Z1rx/Zssrk/lFJKtQwpLjv/uXkMq3fn8/T8bXywNjc89M8AEp6TqlO6m1tP68MFw7qQ5Iwe7TC4SzqDu6Q3feFrIaaGp7pGjhxpli9f3ignDhYWUr55M8HCQiyXC3tWFs7evRHRlkqlGpOIrDDGjGzuctRXY9ZPqn6CIcP6fYUcLvERNIaMJAcndE7DXYdnq1Tb0hrqJ62blKpdQamfTQeKKPT6cdttZKW56ZvVcueVraluarKeqapsaWl4hg9vrtMrpZRqJDZLGNK15bUeKqWUahnSPQ5G9sxs7mI0CJ35UCmllFJKKaXqQIMppZRSSimllKoDDaaUUkoppZRSqg40mFJKKaWUUkqpOtBgSimllFJKKaXqQIMppZRSSimllKoDDaaUUkoppZRSqg40mFJKKaWUUkqpOtBgSimllFJKKaXqQIwx1a8UOQjsaLriKKWaQA9jTMfmLkR9af2kVKt03NdPWjcp1SpVWzfVGEwppZRSSimllIpNh/kppZRSSimlVB1oMKWUUkoppZRSdaDBlFJKKaWUUkrVgQZTSimllFJKKVUHGkwppZRSSimlVB1oMKWUUkoppZRSdaDBlFJKKaWUUkrVgQZTSimllFJKKVUHGkwppZRSSimlVB1oMKWUUkoppZRSdaDBlFJKKaWUUkrVgQZTSimllFJKKVUHGkwppZRSSimlVB1oMKWUUkoppZRSdaDBVCsiIqeJyO6m3repiEh3ESkWEVtzl0UplRitn5RSx5Pjtc4SkfdE5NrmOHdbpcFUDSI3xqOvkIiUVfj5ykY873UisqCxjt8QRMSISN9GPsd2ETnz6M/GmJ3GmBRjTLAxz6vU8UDrp+pp/aRUy6N1Vs0kbKuIrEtgn/tFZFbFZcaY7xhjnmv4Eqrq2Ju7AC2ZMSbl6HsR2Q7caIz5uOp2ImI3xgSasmxKqbZN6yel1PFE66xaTQSyALuInGKMWdbcBVLx0Z6pOjjafSsiPxeRXGBmrJaPiq2jIuISkUdEZKeI7BeRx0UkqQ7nvl5E1otIUaQFY3qMbX4pIociLadXVljeIGWIcb77ReQVEXk+Uq61IjKywvr/JyJbIuvWicj3q+x/U4VrWiciw0XkBaA78Hak1eoeEekZ+UztInKpiCyvcpy7ROStxrxWpVo6rZ+izqf1k1ItmLNM4vQAACAASURBVNZZx1wLvAn8L/K+YhkGi8hHInI4cq5fisg5wC+BSyP10FeRbeeKyI2R8uWLyJAKx+ko4R7BrMjP54vIqsh2i0RkWD3K32ZpMFV3nYBMoAdwcxzbPwz0B04C+gJdgV/X4bwHgPOBNOB64K8iMrxKuTpEjn8t8KSIDEi0DCLyqIg8mkC5vgu8DGQAbwH/qrBuCzABSAd+C8wSkc6R8/wAuB+4JnJN3wXyjDFXAzuBCyJDZ/5Y5XxvAwNEpF+FZVcALyV6rUq1Qlo/Vab1k1ItW5uus0TEA1wCvBh5XSYizsi6VOBj4H2gS+Rcnxhj3gf+AMyO1EMnVjymMaYceB24vMLiqcA8Y8wBETkZmAFMB9oDTwBviYirunKqahhj9BXHC9gOnBl5fxrgA9wV1l8HLKiyjyH8Sy9ACdCnwrqxwLZqzhV1rBrKNQe4o0K5AkByhfWvAPfVVobIvrsT+DwM0Dfy/n7g4wrrBgFlNey7Cvhe5P0HR8tf02ce+bln5Lz2yM+zgF9H3vcDigBPop+3vvR1vL+0foo6r9ZP+tJXC35pnRV13quAg4Qfv3EDBcD3I+suB1ZWs9/9wKwqy+YSHkIJcCawpcK6hcA1kfePAb+vsu83wKTm/v043l76zFTdHTTGeOPctiPhm+gKETm6TICEsz6JyHeA3xBuDbEix/26wiZHjDElFX7eQbglo8HKUI3cCu9LAbdExj2LyDXA3YS/bACkEG7pAcgh3DJcFy8BfwZ+R7jVd44xpjTSfd2Y16pUS6f1U2VaPynVsrX1Outa4BUTflYsICKvRZa9Qf3qoc8Aj4iMBvYT7kV7I7KuB3CtiPyowvZOwtenEqDBVN2ZKj+XEP7DAkBEOlVYdwgoAwYbY/bU9YSRrtfXCA85edMY4xeROYT/gI9qJyLJFf74uwNrGqoMdShzD+Ap4AxgsTEmKCKrKpR5F9Cnmt2rfsZVfQR0FJGTCLfc3BVZ3izXqlQLovVTfGXW+kmplqHN1lki0g04HRglIhdHFnsIN/p0IFwPXVbN7jXWQ5E67RXCddB+4B1jTFFk9S7gQWPMg/Upv9JnphrSV8BgETlJRNyEu14BMMaECN+w/1rhob+uInJ2DccTEXFXfBFuMXAR7goORFpUzoqx729FxCkiEwiPBX61jmVoCMmE/9gPRs55PTCkwvqngZ+KyAgJ6xv5ggPhP/ze1R3YGOMHXgX+RHis9UeR5c11rUq1VFo/xab1k1ItU1uqs64GNgIDCPccnUS4p2w34SDoHaCziNwp4aQSqZGeJgjXQz1FpKbv8y8BlwJX8u1zm0TKf4uIjI7Ub8kicp6En9FSCdBgqoEYYzYSHs7xMbAJqDqnwc+BzcAXIlIY2W4A1RtHuNWj6uvHhMfsHiE8dOStKvvlRtbtJfwQ4y3GmA2JlkHCWWker/mqa2eMWUd4qMtiwn/0QwmP2T26/lXgQcJ/4EWExytnRlY/BNwr4SwzP63mFC8RHhP8qqmcSjXRz1upVkvrp9i0flKqZWpjdda1wKPGmNyKL+Bx4NpIT9IU4IJIeTYBkyP7vhr5f56IfBnr4MaYJYR7+roA71VYvhy4iXBCniORa7mumjKqGogxtY1UUEoppZRSSilVlfZMKaWUUkoppVQdaDCllFJKKaWUUnWgwZRSSimllFJK1YEGU0oppZRSSilVBxpM1ZOIPCsiD0TeTxCRb5rovEZE+jbwMY9dS1Pu21RE5Jci8nRzl0OppqL1U/33bSpaP6m2Tuur+u9bHyLSXUSKRUQnEE9QmwimRGS7iJRFfkn2R35RUxr6PMaY+caYWtPbish1IlI1zWeDEZG5InJjYx2/vhr7+iPnOE1EdldcZoz5gzGmxX4uqm3S+qll0fpJqeppfdUyReoUIyI/T2Cf7SJy5tGfjTE7jTEpxphg45Sy9WoTwVTEBcaYFGA4MBK4t+oGImJv8lIppZTWT0qp44fWVy3PtcBh4JrmLkhb1JaCKQCMMXsIT1o2BI517/5QRDYRnggNETlfRFZFJmNcJCLDju4vIieLyJciUiQiswF3hXWVWhtFJEdEXheRgyKSJyL/EpETCE/ENjbSspMf2dYlIo+IyM5Ia8/jIpJU4Vg/E5F9IrJXRKbV9fpF5FURyRWRAhH5XEQGV9mkg4h8FLm+eSLSo8K+AyPrDovINyIyta7lqFKm7SLyUxFZHSnXbAnPTo6ItBORdyKf4ZHI+24V9s0UkZmRz+WIiMwRkWTC/8ZdIp9xsYh0EZH7RWRWZL/3ROT2KuX4SkQuasxrVaomWj9p/RTZT+sn1eJpfdUy6qtInXIJ8EOgn4iMrLL+JhFZHynHOhEZLiIvAN2BtyOf3T0i0jPyb2gXkUtFZHmV49wlIm9F3tf4Gbc1bS6YEpEc4FxgZYXFFwKjgUEicjIwA5gOtAeeAN6K/OI4gTnAC0Am4ZmnL67mPDbgHWAH0BPoCrxsjFkP3AIsjnSnZkR2eRjoD5wE9I1s/+vIsc4Bfkp4Bux+wJnU3XuRY2QBXxKe0buiK4HfAx2AVUfXR/5YPwJeiux7GfCoiAyq5vrzRWR8AuWaCpwD9AKG8e0s3BYwE+hB+A+/jPBs3Ue9AHiAwZFy/dUYUwJ8B9gb+YxTjDF7q5zvP8DlFco7KHKOdxO9VqUaitZPWj9FaP2kWjytr1pMfXURUEz4M/yAcC/V0X1/ANxPuMcqDfgukGeMuRrYSaSX0RjzxyrHfBsYICL9Kiy7IlJmqOEzbpOMMa3+BWwn/IuWT/iP8VEgKbLOAKdX2PYx4PdV9v8GmARMBPYCUmHdIuCByPvTgN2R92OBg4A9RnmuAxZU+FmAEqBPhWVjgW2R9zOAhyus6x8pd99qrncucGMcn0tG5DjpkZ+fJVxBHV2fAgSBHOBSYH6V/Z8AflNh3wfi/Peoev3bgasq/PxH4PFq9j0JOBJ53xkIAe1ibHfs36LCsvuBWZH3qZHPvEfk5weBGZH3NV6rvvTVkC+tn6r9XLR+0vpJXy3spfVVtZ9Ls9RXke0/Bv4WeX955LNyRH7+ALijhn/LMyv83DNyDfbIz7OAX0fe9wOKCDcO1fgZt8VXWxrTeqEx5uNq1u2q8L4HcK2I/KjCMifQhfAv2R4T+c2J2FHNMXOAHcaYQBxl60j4F3SFiBxdJsDRjCpdgBVxnLNGkdadB4EfRM4ZiqzqABRE3h/7LIwxxSJyOHL+HsDoo93oEXbCrUoNIbfC+9LIORERD/BXwq3C7SLrUyPXkgMcNsYcSfRkxpgiEXmXcIvQ/xGugG6KrG7sa1WqKq2ftH46Rusn1cJpfdVC6qtI7+Bk4BeRRW8CTwLnEe75ywG2JHrciJeAPwO/I9wrNccYUyoiWdT8Gbc5bSmYqknFP+ZdwIPGmAerbiQik4CuIiIVKoDuxP5F3QV0FxF7jArAVPn5EOHhIYNNeAxyVfsI/0Ec1b36S6nRFcD3CHdrbwfSgSOE/wiOOnYeCWfoySTcerQLmGeMmVLHc9fVT4ABwGhjTK6InER4SIFEypQpIhnGmPwq+1X9jGP5D/AbEfmc8FjtzyLLm+talYpF66dvaf2k9ZNq2bS++lZT1FdXEx5u/HaFwMZNeKjfnMi5+lSzb2310EdAx0i9djlwV2R5bZ9xm9PmnpmKw1PALSIyWsKSReQ8EUkFFgMB4Mci4pDww8CjqjnOUsJ/tA9HjuEWkVMj6/YD3SJjhjHGhCLn/Wsk4kdEuorI2ZHtXwGuE5FBkZbQ38RxHfbIOY++HISHjpQDeYRbFf4QY79zRWR8pGy/B74wxuwiPF65v4hcHbl2h4icIuEHQBtTKuE/2nwRyaTCtRtj9hEes/yohB8Ed4jIxMjq/UB7EUmv4dj/I9xC9DtgduTfAZrvWpWqjdZPWj9p/aSOF1pfNX59dS3wW8JDjI++Lo6cuz3wNPBTERkR+TfoK98mwtgP9K7uwMYYP+HnsP5EOBD8KLK8ts+4zdFgqgpjzHLCwyn+RbiVYTORh42NMT7CD/pdRzgF5aXA69UcJwhcQPjBvJ3A7sj2AJ8Ca4FcETkUWfbzyLm+EJFCwmNgB0SO9R7wt8h+myP/r81jhG/yR18zgecJd2nvAdYBX8TY7yXClcthYARwVaQMRcBZhIed7CU87OX/AFesk0s4O8yEOMpZm78BSYRbQr4A3q+y/mrAD2wADgB3Rsq7gXDL7lYJP7zZpeqBjTHlhP/9zuTbhyoTvlalmorWT1o/af2kjhdaXzVufSUiYwg3uPzbGJNb4fVW5NouN8a8Sng44kuEn3maQzgwAngIuDdSB/20mmt/iXAd9GqVXsFqP+O2SCoPV1VKKaWUUkopFQ/tmVJKKaWUUkqpOtBgSimllFJKKaXqQIMppZRSSimllKoDDaaUUkoppZRSqg5qnGeqQ4cOpmfPnk1UFKVUU1ixYsUhY0zH5i5HfWn9pFTr0xrqJ62blGp9aqqbagymevbsyfLlyxunVEqpZiEidZrxvaXR+kmp1qc11E9aNynV+tRUN+kwP6WUUkoppZSqAw2mlFJKKaWUUqoONJhSSimllFJKqTrQYEoppZRSSiml6kCDKaWUUkoppZSqAw2mlFJKKaWUUqoONJhSSimllFJKqTrQYEoppZRSSiml6kCDKaWUUkoppZSqA3tzF6A18wdDfLxuP09+vpXNB4vx+oM47Rad05OYdmpPLjy5Kx6n/hMopVRTKfGX8M6Wd3hx/YscKDuAL+jDZXPRJ6MP1w++nkk5k7BbWi+rtmvrwWJmLtzOe2v2UVwewBhIcdk5c1A2N4zvRf/s1OYuolItihhjql05cuRIs3z58iYsTutgjOHJz7fy7882EzSGkvJg1DYepw1jYOrIbvzyvBNw2W3NUFLVFonICmPMyOYuR31p/aQS4Q14+dOyP/HWlrcQEcoCZVHbeOweHJaD6cOmc9WgqxCRZihp29Ya6qfjtW76JreI//f6atbtLSQYMgRClb8f2gQcdou+HVP4w0VDGdYto5lKqlTTq6lu0mF+DSwYMtz+n5X87eNNFHoDMQMpgFJfkDJ/kNnLdvGDxxZT5PU3cUmVUqptKCgv4Kr/XcWbW97EG/TGDKQASgOlFPgK+Oeqf/LLBb8kZEJNXFKlmsfiLXl8/9GFrNyZT3kgFBVIAQQNeP0h1uwtZOoTX/Dphv3NUFKlWh4NphqQMYZfvv41n67fT5k/dhBVlTcQYsP+Iq6dsRR/UG/cSinVkHxBH9M/ms7Wgq2UB8vj2qcsUMbHOz7moSUPNXLplGp+a/cWcMNzyyj1xfe9BcDrD3Lbi1+yYsfhRiyZUscHHRjegBZuzuOt1Xsp80cHRSXr5lK4bA7+vN1YziQcWb1JHzcVd7fB+AIh1u0r5LlF27lxQu9mKLlSSrVOz617js35m/GHKvf+H5l/hEMfHMJ3wIfNbSNtRBrZl2RjSw4PufYGvby55U3O6XUOI7JHNEfRlWp0xhhunfVlzECqpu8tEO6luvn5FSz91ZnYLB0Sq9ouDaYa0BOfb6EsRoVUuPQNCpb8l/Zn/RB3r+GIzU7ZthWUbVpSqVJ6av5WbhjfS8fpK6VUAwiGgsxaNyuqR+rQe4c4+N5But3YjZRBKfiP+Nn7wl62P7KdXr/qhWUPD9rwBrw8u+ZZDaZUq7V8xxEOFUf32MbzvQXAGwgyb+MBTh+Y3ZTFVqpF0WF+DWRfQRlLt0V3d4fKS8hf8CKZU27FM2AcltON2Ox4+o6m3eRplbYt9gZYtCWvqYqslFKt2sK9C/EGvJWWBcuCHJhzgC5XdSF1WCpiF5wdneTcloPvkI+CRQXHtjUYFu1dxKGyQ01ddKWaxBPzohuBE/neUlIe5PF5W5uyyEq1OBpMNZD/fZ0bc3n5ng2YgA9P/7G1HqPEF+TlZTsbumhKKdUmvb7pdUoDpZWWlW4qJeQPkTYirdJym9tG6rBUitcWV1puicVHOz5q9LIq1dT8wRBzvzlI1VQTiXxvAVi58wgFpZpES7VdGkw1kNyCMsoD0c9KBcsKsTxpiBVf6vPcAm/tGymllKpVbkl0I1ewOIg9xY7YoodT29PtBIoDlZZ5g17tmVKtUkGZP+azTol+b3HaLPJK4kvuolRrpMFUA/HFCKQAbElphEoLMaH4suRUdxyllFKJCYQCUctsKTYCxQFMMDr1c6AggD0l+lFiX9DXKOVTqjn5gyGsGM9oJ/q9RUTwx/h7Uqqt0GCqgbRPcRErbYSr60DE7qB04+K4jpPhcTZswZRSqo1Kc6ZFLfP09SB2oXBFYaXlQW+QotVFJA9KrrTcJjYyXDo5qWp90pMc+GJMyZLo95ZAKER6kqOhi6fUcUODqQR4/UFyC7zsyS+juLxyi+fIHu1IckZ3iVuuZDLGX8nhjx6ndONiQn4vJhigbMtyjnw2o9K2SQ4bp/Xv2KjXoJRSbcX4ruNx29yVltk8NrIuzGLvrL0UrS7CBAy+gz52PboLR6aDjHGVAyenzcnw7OFNWWylmoTHaadrRlLU8kS+twAkO+1kpbqaoshKtUiaGr0WwZBh3sYDPD5vKyt2HMFhEwTBFwzRq72HW07ry/nDOjO2T3vS3I6YczWkjboIK7kdBYtnc+idRxBnEq7svqSNvbTSdiFjuHhkt6a6NKWUatUu6ncR/17176jlHc/tiC3ZRu7sXHwHfFhJFmnD08iZnoPlqNzG2N7dnpM6ntRURVaqSU2f1JsH310f9d0l3u8tbofFjRN6Yek8U6oNa9XB1JESH+v2FVJQ5sdlt8hKdTOka1rc8zgt2nyIH7+8kjJfkJJIRRMMfTsuePPBEn7z5hp+/eYafnnuCdw8sTd/+mBDzEl7UwZPJmXw5GrPZbOE84Z2Js2tXeVKKdUQMtwZTM6ZzEc7PiJE5Xo5c1ImmZMya9w/yZ7E9UOu17n/VKt14Uld+f0762Kuq+17C4AxcOkp3RujaEodN1pdMGWMYdWufJ6av5WP1x/AZbMwgABBY0h127lpQm9+MDKnxjG+b6/ay8/++xXeWhJCHA2yHnx3PVNHdsPU8RlMt93ix2f0q9vOSimlYrrt5Nv4fM/nlAXKEtrPEot2rnac3/v8RiqZUs0v2WXnR6f341+fbqbMH1/CiaNcDuGSkR1JS9InRlTb1qqCqZLyADc9v5yVO/MpDwQJmejseKW+IH/+cCOPfPgNf5l6IucO7RJ1nC+25vGz12oPpCoq8wd5fvEOBnROZV++l2JvgGCckVWSw8Yz151Czw7JtW+slFIqbr3Te/OP0//Bjz75Ed5gfFNP2MRGqjOVmefMxOPwNHIJa1dSHmDOqj28uXIPh0v9GGPITHZy3tDOXDyiG6k6okHVw22n9WH7oRLeWb0v/oBKfJC0hQ8K72fubAeXD7ycqQOmkuXJatzCKtUCtZpgqrg8wPf/vZCdh0tjzvdU0dHK4u5XvqKwLMBlo77tojbG8JNXvsIbY6geQMm6uRQum4M/bzeWMwlHVm/Sx03F3W0wBth2sIQ3bz+Vm59fQV5x+bGeq1iSnTbsNovnp43ixBzNFqWUUo1hTOcxPHXWU9z28W0ETTBqIt+KPHYPHZM68vTZT9MpuVMTljLawaJy/vLRRt5YuRtLpNJzLVsOlrBmTyEPv7eBC07swk/OGkCndHcNR1MqNhHhj5cMo2Oqi2cWbCNkTA2pzgMgBkfGchzZb1MeMpT7ypi5ZiYz18zkvN7ncd/Y+3BYGuCrtqNJgqlAKMCCPQvYkr+FEn8JHoeHbqndmJwzGZet/hlgjDHc8OyyuAKpirz+EPe/vZacTA+n9u0AwNJthzlSGntOkcKlb1Cw5L+0P+uHuHsNR2x2yratoGzTEtzdBgNgibBw0yE+++lpzP3mAI/P28Lq3QU47RbGgAgEgoau7ZK4dVIfzhvWGbcjvonxlFJK1c1JWSfxydRPeH/b+zyz5hn2l+zHbtkxJoT4SvA73AzrMIzrh1zPqV1PxZLmHbq0+UAxlz25mPxSP4FQ7C+2RxsGX1+5m4/W7+c/N43hhM7R6eCVqo2IcM85A7l8VHeeW7yd/yzZiYgQNMFvh8gawZGxAmfmQixnXqX9faHw96b3tr3HrqJdPDHlCZw2nepFtQ1iahiKNnLkSLN8+fI6H/xQ2SFe/eZVXlz/IoFQgPJgOQETwCY23DY3BsPF/S7mihOuoFtq3bPYLdmax/XPLouZSa+mnqSj+men8OFdkwCY9uwyPttwgKqfSqi8hN3/vpb2595J8sDxNZYnO83FF78449hDy7sOl7LlYDFF3gDJLhvd2nnon51a5+tVqj5EZIUxZmRzl6O+6ls/qbZt45GN7CveR5m/hJTXp9Pnhnl0bte7uYsFwL6CMr7z9/kUlPkTeg43zW3nnR9NoHv75h+aWFetoX5qDXVTeSDIe+s38Jv5f6Q86ENspdiSdiGWv9Z93TY347uO5y+n/UWTt6hWo6a6qdF6ppbnLuf2T28/FkRVFDRBSgIlALz8zcu8uvFVHhj/AGf3PLtO53ry862UxQik4ulJAth1uIw1ewoY0jWdzzcejAqkAMr3bMAEfHj6j621PIVlAbYdKqF3xxQAcjI95GQevzc3pZRqbfq360//dv3DP7iyIFD7l8SmcvPzyynyBqICqdoaB4vLA1z/7FI+vnuSfolV9eKy2/hg/78wyStxVPlWdGT+EQ59cAjfAR82t420EWlkX5KNLTk8ysYb9LJw70K+OvgVJ2XptAKq9WuUcQzLcpdx68e3UuIviQqkqvKH/HiDXu5dcC/vbn034XMdKPIyf/OhmD1J+QteJHPKrXgGjMNyuhGbHU/f0bSbPK3Str5AiKfnb8XrDxKqphkwWFaI5UlDrNqH5NltwpHSlnNjVkopVYP0HMjf1dylAGDd3kI2HyipNA0HhBsHD3/yFOljptLt9ll0vXUmqcPPpWzTkmPbhAzsK/Cycld+UxdbtTIHSg+wNHcppsq3q0PvHSL31Vw6Te3EoEcH0fu+3vjyfGx/ZDuhCo9ZeANenl37bBOXWqnm0eDBVG5JLrd/cnvcWZOO8ga93L/oftblxZ7voDqrdubjtEVfRiI9SUFjWLh6J/t+/wCEYj9zZUtKI1RaiAklljpUKaVUC5eRA/k7mrsUADyzYCu+YOX7UCKNg15/kKc+39qURVat0CvfvELVVupgWZADcw7Q5aoupA5LReyCs6OTnNty8B3yUbCo4Ni2BsP8PfM57D3cxCVXquk1eDA1a90s/KHoXpkj84+w6d5NrL15LRt+vIG9z+0lWFI5MCkPlvP4V48ndL5CbyBmb1IiPUkApcYi44T+WFbsj8TVdSBid1C6cXGtxwoEDRkezWSjlFLHhYzuUND8PVP+YIh3Vu+L6pVKpHEwZOCT9Qco9QUaq5iqDfhwx4fHkkocVbqplJA/RNqIyklObG4bqcNSKV5bXGm5Qxws3be00cuqVHNr0GDKF/Tx303/jQqm4u0WNhgW7llIXlle1UNXy2ETYo0MT7QnyeF20e7yyzm1X4eY6y1XMhnjr+TwR49TunExIb8XEwxQtmU5Rz6bUWnbFJedXu11ziillDoupHeH/J3NXQoKymIPD0+0cdBuE/KKY2elVSoeRb6iqGXB4iD2FDtii/7WZU+3EyiuHMAHTID8ch1yqlq/Bg2mPtn5CVWzAybSLQzh9JyvbXot7nN2SHHFfNA2kZ4k4FhP0vSJffA4Y9+w0kZdRLvTb6Bg8Wx2//NKdj92HUVfvkNSv29bC90Oixsn9MKy9OFfpZQ6LmR0bxHPTJX5glgx7meJNg6KhIf7KdWQbCk2AsUBTIw5qAIFAewplXOaSeQ/pVq7Bs3mtz5vfdRkiPF0C7eb2O7Y8vJgOasPro77nKf0zCRW3FKxJ0ksG+5eJyOWHe/2VXh3rq40ztztsI5N3Du2T3vS3I6YadYBUgZPJmXw5GrLYwxMHZkTd/mVUko1s4ycFtEzleZ2EIzx3G7FxsHapuYACIUMqW4daq7qLt2ZzqGyQ5WWefp6ELtQuKKQ9FHpx5YHvUGKVheRfUl2pe1tlo0Md0aTlFep5tSgPVNHvEeiliXaLQxQWF4Y9zmddourxvTAaY++lHh6kiAcAF0aCYBEhD/9YBhuR+IfTZLDxl1T+tMuWSeqU0qp40ZqZyg7DIGas882ejHcdpJd0W2ciQwzB7BZFu1T9D6k6u7c3ufisrkqLbN5bGRdmMXeWXspWl2ECRh8B33senQXjkwHGeMqB07+oJ8xncc0ZbGVahYN2jOV7Ix+Tqhit3DVgCpWtzBAkiMpapnXH6QwMp483ePAZf92KN41Y3vyzIJtMctUW0+S3RKmDMquFABN6NeRBy8cyq/mfI3XHzu7X3SZbVx6Sg7TJ7aMSR+VUkrFybKFA6qC3dC+T/MVwxKuG9eLR+dupjxQ+d6TNuoirOR2FCyezaF3HkGcSbiy+5I29tJK2zltwlVjuuOIkeVWqXhd3O9invjqiajlHc/tiC3ZRu7sXHwHfFhJFmnD08iZnoNVoRHaEoszepxBuis96hhKtTYNGkzlpObgsrkqzS2VcLew2OiR2gMIZzb6aN1+Hp+3hTV7CsK9TwZ8QcPJ3TO4ZVIfTh+YRad0N788dyAPvb0Wr4l/fK4IZCY7+e13B0etu3hENzJTnNz58ioCoRAl5bGH/XmcNkLG8JOz+nPjBA2klFLquJQRSULRjMEUwJVjuvPo3M0x19XWOAiACFeP7dnwBVNtSvuk9kzoNoHPdn1GyFQO7DMnZZI5KbPG/Z2Wk2sHXduYRVSqxWjQYOrcXufylxV/qbSsYrew5bZIGZSC/4ifvS/sjdkt7LAc/GDAD3h39V5+8cbXBEPmWCBTsZdoxY4j3PnySpx2iz9PPZELti1i6651zO55Kt5A7Il3K7JbQrtkJ6/eMpb2En7DpgAAIABJREFUKa6Y20wekMXye8/k43X7eWzeFtbtyccRCmK5XPiDhuw0F7dM6sOFJ3eNOTRDKaVUy3eouJz5/rEcWXYQ9m0jw+NgfN8OZKW5m7wsHVJcfO+kLrz11d64R0Yc5bZbnDEom64Z0aM7lErUPafcw9LcpTEz+9XEbXNzTs9zGNwhuqFaqdaoQSOAdu52TOw6kU93fkqIb28C8XYLA/RM78m8NTYe+fCrWm8kJb4gJb4gtz63jB9uWcKv/vJzTskTHnh3PUVePyUxkki4Is9WTezfkYcuGkqHagKpoxw2i+8M7cx3hnZm26NPcbiwhIwbbiQ9yUH7ZGfMTIJKKaVaNmMMK3Yc4cnPtzJv40HsZij+kMDXG3DYBH/IMKFvB26e2JtRvTKbtK5/4MKhbDlYwto9BXgD8QVULrtFn6wU/vyDExu5dKqt6JLShafPepobPriBUn9ppe911XHb3IzuPJrfjPtNE5RQqZahwbtTpg2ZxoI9C/AGvZWWx9MtnGRg+IFMHlnyTdw3EIByIzza7xyGFNs5/8RszhvWmcVb8nji8y2s3JlPmT+IzRLaeZxcdkp3rhjdnY6pNQdRABxYD4v/DRs/AF8RPfwButs82Fbsh9HTQXrGXUallFItgy8Q4q5XVvHZhgOU+YMYA+VEnsMNhjjaDvfphgMs3prH+L4d+OcVJ1d6VrcxOe0WL944mtte/JIvtuZVm132KI/Txsk5GTx17UjcjqYpo2obBrUfxOzzZ3PHZ3ewu2g3vpAvatgfACEHTrtw2cDLuGvEXViiz+yptkOqzgtV0ciRI83y5csTPujTXz/Nk6ufpCxQFvc+bpubc3Km8PoHoygJRad0LVk3l8Jlc/Dn7cZyJuHI6k36uKm4u33bjdzO42D5vVOw1XeOpz0r4J274eA3EPSBqXIjszlBLOgyHC74O3TsX7/zKdWERGSFMWZkc5ejvupaP6m2LRAMcd3MZSzfcTjuYXRuu8XQbum8dNOYJk3sEAoZ5m8+xF8/2siqXfm4HRa+QAhjwOWwMAZOzMnglkm9Oa1/VquY37A11E+ttW5ae2gtz617jo93fIwgWGIRCAXITMqkl/08MkLjeORizd6nWqea6qZGedDnhiE34A/5mfH1jKgeqliS7Emc3fNsTnTfyOv29VClFa5w6RsULPkv7c/6Ie5ewxGbnbJtKyjbtKRSMOULhvh0wwGmDMqueor4ffMe/Hca+Eur3yYYmVl+52J4ajJc9Rp01wpEKaVaut+9vY4VCQRSAN5AiK/3FPCrN77mj5c03TA6yxIm9e/I17vz6dsxhRO7Z5BfEr7/ZHgcTOzfkR7to7PoKtUYBncYzB8n/pFAKECRr4jyYDmpzlQ8dg95JT5Of2Quh8/xkanTw6g2plGCKRHh1hNvZXD7wTy66lE2528mEAoQrNDDY4mFy+Yiy5PF9GHTOb/3+Zz9t8+jnnMKlZeQv+BF2p97J54B444t9/Qdjafv6ErblpQHeXzelroHUzsWwavXQ9w9agZ8xTDrYrjxY8g6oW7nVUop1egOl/iYvXxXVNpxqH30g9cfYs6qvfz07AFkpTZtYor31uRy3/mDGNO7fZOeV6lY7Jaddu52lZZ1SHFxzpBOvLRkBzdP7MNXu/M5XOLDGEN6kpOTcjJIcuoQVNU6NWoKuondJjKx20S25G/hxfUvsuHwBor9xXjsHnqn9+bygZcztONQIHyT236oJOoY5Xs2YAI+PP3HRq2LZdWufLz+YOLjxoMBmH1lVCDV829FlPph2x0pJDvDQyie/tLHrNV+5l4XaRH0lcAr18DtyxI7p1JKqSYze9lOYg2Ei3f0gwAvfrGTu6Y03dDunXml7C/0ckrPmp85Vqq5XXBiF6Y/v5zH522NWhcMGS4e0ZVpp/aid8eUZiidUo2nSfJ598now6/H/rrGbY6U+nDYLHzByj1TwbJCLE8aYsUXHDlsQkGZP/FgauP7EPDFXBU08PclPn45obqkFSY82eOeFdB1RGLnVUop1eiMMTyzYFtUcqNERj+UB0I8t2g7Pz6jX/2fzY3Te2v2MWVQpyY7n1KJMsbw90828djcLeFn+qrJ+vfy0l38d/luLhrejd9fOER/p1Wr0WLSrVSXB8OWlEaotBATqjmbUTzHqtHCv4WH7MXws3FOHllUTr63hgMHvLDoX3U4sVJKqcZWUOanoMwftTzR0Q9ef5C8kvLaN2wg763J5dyhnZrsfEolwhjDfXPW8MS8rZQHQtT09SsQMngDId5YuYfpLywnFKrLlzWlWp4WE0xleBz4g9F/WK6uAxG7g9KNi+M6jj9oSE+KzgZYo5I82PdVtatHdrFxWk87jyyq4QZqQrD+bQglNsmiUkqpxlfkDcTMxJfo6AebTSiMEZQ1hr35ZezIK9FnpVSL9cyCbbz25R7K/PE3eJf5gyzcnMcD765rxJIp1XRaTDDVPtlJl4zoh3otVzIZ46/k8EePU7pxMSG/FxMMULZlOUc+mxG1/aDOaYk/5FhyAGw1zzv1u8ku/rnUx8GSGoIlESgvSOzcSimlGp3LbsVsCU909EMoRJPN5fT+mlzOOCG7SdOxKxUvrz/IXz7aGDOQKlk3l33P3cnOv1zC7n9dzf5XfoN399pj68v8QV5cspODRU3Xy6tUY2mSZ6biISLcMqkPv3tnXdQEhWmjLsJKbkfB4tkceucRxJmEK7svaWMvrbRdstPGLZP6JH7yoC8cCNVgSJaN8/vbeXiBjxM6VnNjEwuCTdNiqZRSKn4ZHiexRhVVHP2QPHB8rccJhEK0T45j0vcG8P6aXKZP6t0gxyr1BZizcg/PLNjGvgIv5f4QLodF90wPN0/szblDO+uEvyoh76zeF3N5vAldAF5aspM7zuzXaGU0xrB6dwH7Csrw+kOkuu2c0DmNLhlJjXZO1fa0mGAK4LsndeG3b8fu9k0ZPJmUwZNr3N9mCWcNrkNadHc6hAK1bvbb09wMf6KYn4yt5kYa9IWPpZRSqkVx2i3OGpzN/77eVymoqjj6QSwb7l4nI5Yd7/ZVeHeupt3kaZWO0zcrhVCdHsxNzIEiLxtyCxnfr0O9juMLhHjovfW8vHQXIlRqrCz1BdmQW8R9c9Zw35w1XH9qT+6aMkATA6i4PDZ3c1Tjd6IJXWYu2sbtp/dt8N+5Iq+f11bs5sn5W8kv9WOJYIzBsgRfIMSIHu2YPqkPE/p2aBWTXavm1aKCKY/Tzq8vGMRv316b0ISKAG6HxYPfH1K34RBp3cDuqnmiXqBvpsWlgx38Y6mPoVkxzpPZO3wcpZRSLc5NE3rzyfoDUcOS4h394LJbpLjsjHv4U757YheuGN2dEzqn1btcpb4AizbncbjER9CEn/vdkVfC5IFZuOx17y0qKQ9w1TNLWL+vsMZ76tH5HZ9ZsJ3Vewp46pqR9Tqvav18gRDbGmA6m3J/iL35ZeRkehqsbPM2HuTWWSswQJkv9vDdRVvy+GpXPt3aeZh142g6pup3N1V3LSqYArh8VHf25Zfx1PytlMUZULkdFndP6c8FJ3at20ltdhg1HRb8FYI1j9/99SQXL6yOMZTPkQzj76rb+ZVSSjW63h2TsdsEYlThtY1+EIFO6W5emT6W3EIvs5ft4vqZy+jaLokrR3ev0zC5LQeLmbFgG69/uQebJQRDBoPBblmU+AIM6ZLGkq15jOqVidQyFL2qQDDEDc8tY93ewpiTFMdS5g+ydOthfvyflTx+1YiEz6najkKvH4fNivrdSjihixWeziangcr13tf7uOuVVXE1yJf4gmw5WMx5/5jPOz8aT1Za007GrVqPFhdMAdx91gA6ZyTxu7fXRQ1LqCjZacMAD31/KN87uY6B1FEjp4WDqSq235la6eecdAvvvbFaIg0Mvqh+ZVBKKdUolm0/zN2vrGJS/47M/eYgxeW1D+2uyOO08cy1pyAidE5P4s4z+3P75L58uuEALy7ZyQPvrueik7tyxejutU5KaozhkQ++4ekF2wiGDIGoh7nCXwTX7Cnk+meXMaJHO568emRCyZVeWb6Lr3YVxAykStbNpXDZHPx5u7GcSTiyepM+biruboPxBkLM33SID9bmcs6QznGfT7UtDsuKOdy1YkKXeAIqEwwgBUcwXdLqHbyv3p3P3XEGUkcFQobDJT4ue+oL3r9jIk67JntRiWuRwRSEe6i+d1IX3lq1l8fmbWFfgReHLfyH5iv3083m5/bvjeK8YQ300GxqNoy4FlbOqnW4XxSHByb9HJwN102tlFKq/vzBEH//eBMvL9vFQxcNZcqgbNbuLeCKp5ZQ5PXHTEpRkSWQ7LIz64bR9M2qHCTZbRZnDe7EWYM7sTOvlP8s28nUJxbTPzuVK0f3YMqg7KgvZ8YYfvXGGt5YuafWHiNDuDFx6bbDXPL4Il67dVxc9ztjDE/M2xozy1o8yQFKfUEem7tFgylVrVS3PeacnokmdPH5gxRPv4FNvhKc/fri6tcPV9++uPr2w9W/H/Z27eIu00Pvbah2RFNNDQiBkCG3wMv7a3P57old4j6fUke12GAKws9QXTaqO5eeksPB4nIKSv2IgGPNKswzT9BzRAP3BJ3zMBzeBjsWxh9QOTww9BI49Y6GLYtSSql62XqwmDtnryIz2cn/7hhPVmp4GM/gLun8744JPPS/9Xy0bj8iRLVmu+0WBjh9YBa/+M4JdG9fc2NZ9/Yefn7OQO46sz8frM3lhS+2c//ba5k6shuXndL92DMhMxdu542Vic3LUx4IsflAMT/+z0qevGZkrdt/ufMIB4ujh6wnkhzgm9wiNh8ojgoglQKwLOGsQdm8vza3XgldBuVkMurzjwnm5VG+eTPlGzdR/s1GCt/9H+WbNiFOZzi4Ohpk9Q//35ZWeYTQ7iOlfLnjSMyyxtuA8PjcLRpMqTpp0cHUUSJCVqr72I0wNHo4G3+6nlB5OZarAR8atGxw+cvwzp3w9X8xAS9S3XzeNmc4FfqY2+D0e2tNra6UUqppGGN4aelO/vzhRu46sx9XjekRNYSoa0YS/7piOEdKfMxevou3Vu0lv+z/s3fe8VHUeR9/z8z2TUhIgzQISei99yqnqKhYADnLoYjY9bxHPfXOe/T0nrPcnV1s2BCwn2JDUem9ht6SkIQkpPftM88fK5Gws8luSCBl3q8XL2Cn7G92Z+f3+7bP1wkKhFn0TB8Qy7XDuxAZEtwcY9CJXDYwjssGxnG0oIolm7O4/OV1DEoMZ/awRL99eaB+77nDLbPmcCFHCypJjQlVPf4Un21XN9aCEQdwywrLd5/gj7/rGdiFa7Q75k9IZtXhwka3szHrJWYPT8ThljFFRaGLisI6alTtdkVRcBcUeA2so0ex7Umj7PPPcR49ihgS8puB1T2VRVXRqmmHwTgQ0ouqOJhfQa/OZy8q0xjc5Q5c+dUodjeCXkIKM6KPs2q1i62AVmFMnYlotWJMScGeloZl+PCmPbmkgyteRh46n4q/XU5Ykh1BZ6A2ni3g/ffQuTBiPnRMatr319DQ0NBoNEVVDh76NI2TlXY+XjCqQcOjo9XAbRNTGtejsAFSY0J47LI+PDitJ1+n5fHP7w9S7adWKxDvuVtWWLQuk39c1b/e980tt6mmYAUjDuCWFfLKtYaqGv4ZlBhOpw4mVVW/QNrZ2Fwenli+n//9aj+/69OJ+ROSGZgQVms8CIKAvlMn9J06ETL+t5RBRZZx5ebhOHoE59Gj1GzdympHX1ymSJ/3CFZdcFtm6Tk1phRZwXG0jMrV2TiOVyDoRGp9+IqCaNUTOjEBy+AYRGOrXLK3C1rtN2MZNoya7dub3pj6lcrtmVSIFxH+4L8hZyvYSr2RKEsEJIwAvab6oqGhodEs5O6CosPgqARDCER0g4ThDWYA/HzwJH/+bA9XD03gteuHtphicpNe4pqhCSxan6Ga6xCo99wtK3yx8wR/md4bi8E7fSuKQoXNTUGlnYJKBwWVdjJVFrcQvDiAwx14KqJG+0MQBF67fghXvrQWWyNvFfuvdYPf7c3j54MFJHQ08/YfhtebViuIIoaEeAwJ8TBpEgCOp3+GUpvPvsE4EFxumQq7itRnM+GpdFL41h48pXYUp/dzUM74zXmcDsq/zaD8u0wib+iNKTXwGjKNc0frNaaGDqH601dhb4y3vskYClE9IaZXk5y/dNlHRN4yD4whkFK/d0VDQ0ND4yxx1sDez2D981CR63VeyR5v+jUKmDvCmHtg4Bww1fUc25wenvp2P78cLOSlOYMZmezroW4JHMqvVH09uPQ7mXnvbsXulimocFBY5cCoE4kJNRIdaiQm1ITox+gMVhwgKsgUR432R/z+7fx9x2IeG3oDNrfirzCiQWTFG6k6VljF9JfW8tGC0UH1cNNJ6vd8UA4EAYorndQ43bXOiubCU+Hk5Es7kKvdNKSCc8rQKn5vPxFzemHu0zKfb+2Z1mdMOaog7WNC0v6NNSoH5astCIoMoggeN0SmePs99b4cdIYGT7cvt5yDeZVU2l1YDDriO5oZ5CrClZtLyMSJ5+CCNDQ0NNo5J/fBe5eByw4u9agKzmpY+b/w85Nw3SfQxVtbsSennHs/2snAhHC+u288HUz6czfuIHC4PShquXcE5z0XBYER3SIZ3z2KmFAT0aFGH8n0T7fn8Lcv99Y24609NghxAKtBYkyKtmjT8E/lqlXk/eUvXPr6QgbGJPHQZ2nsy61AlhVcDclk+kFWoMLuZs6vUuWdw+rPAlIUBceRI0TWlJOJ728/GAeCgMC3e/NYvPk4MR2MdI8JpXtMCN07ef9OjQnB2gSpdopbpvDNtIAMqTrHuWRKlh4k+vaBGOI0YZiWROsyprK3wOJrQHYjuKoRdIDzDE/fyb2w/F5Y8SjM/Rqiuvucxu7y8HVaHgtXHeVEmR1BAI+sIImCVy3QYWfOhfO4xe4hMqR1fUQaGhoarYq83fDOxV5jqSFOqax+cCWea5eyMCuBResy+NvlfVu8CpdeFFXrmCA477lBJzKhRxRDu0b43Wf6gFge+3Kv6rZAxQGMeolJPWPqvyiNdkvVmjXkPfIoia+9irl/f7oDn98xloyiat5dn8HXaXlU2F24POo3fX1iKwCVdhfP/XCI52YO9DnWmZND9caN1GzcRPXmzYgWC5ePuIQDYndqzlBGD8aBoJMEfvjjBMx6iaySGo4UVHHkZCVrjxTy9roMMoqqiLQa6d4phB6dQkmN+e3vkCCMLNueIjzlDlVDaktOGv/45TUOF2UiiiLdI7vytwvuZlBsb8BrUJV/n0n0zf0Cfj+N5qf1WAoZa2DJLHD55sT64KzyTsxvToZ5P0JM79pN6YVVXPvGJqocbr/NgEHPOzWRvPv0L7z8+8Fc0LtT01yDhoaGhsZvVBXCe1cEZkidjqsG5+LZHI5+leV3X0JcuLl5xteEiKJAqElHhd1XgCIY77nLI9cq2/rDpJeYNSyRDzcfV13MNiQOYNSJ3Dw2CUnUVMQ0fKlau47cPz9MwisvYx5Y19jpFmXl8Sv68fgV/ZjzxiY2phf7HB+I2IpHhuW7c/nbZX0wV5ZRvWkz1Zs3UbNxE7LDgXXUKKxjRhN9//0YEuJJdHt49u8rQUXgJRAHgiQKXDYgjtBfI9vJ0SEkR4dwUd/Otft4ZIXskhoOn6zkSEEVG44W8d6GTI4VVhFhMdRGsHp0CiW1UwjdY0Jqz1fn+ldl16bunU6lo5qbPv0zT114P5f1mozT42ZLzm6MUt0sK0d6GZ5yB1KYlobbUmgdxlTxMVg6JzBDqhbFmxL47nS4aytYIjhaUMWVr66nyuH26yE8hd2jgMfDnUt28K+ZA7l0QMv2empoaGi0Ora84dPTL+n5SmpckHFvCFaDdzH/1g4ni9NcrJprrd3PqLh4Pu5nhPCrz+mQz4arhySwWMXACcZ7nhRpre1ZVR+3TUzh8505uGzq6oH1EWrScf2orkEfp9H2qVq/ntyHHiLh5ZexDB7sd7/skhp2ZPn2fQpGqlxwu3hjwV+ZfuAXLMOHYx01isi5czGkpPjIhRt1EteN7MK7GzJVm2E35EDQSwLzxnfzux28BldSlJWkKCsX9v3tdY+skFNaw5GTVRwuqGRTejEfbDrO0YIqwi362ghW95gQeks6IkvsqudPL8kGYEafqQCYRYmJ3Uao7lu1MY+waUn1jlfj3NE6jKk1z/oYUoFNuIo3SrXtHSpH3MO1b2wMyJA6HbtL5k+f7KZLhJX+CWFNdEEaGhoa7RyPy2tMeXzltz0KvLDZySPj/XteRTxewYqL/+kVIGoFzB2bxJItWaBSph+I99xqkLh9UmAS7p3DTHw4bxSz39iIzekJSBhAFLwLxvhwMzrJVwkxo6iapZuzOFJQSbXTQ5hJR7/4cOaMSCSmg6Zw2xopqXby8bZs0nLKKLe5CTFIpMSEMHt4Il0jrXX2rd64kdwHHiThpRexDPFvSAGsOVKoKr4ZjNiKHYl1/S/g/kVPIUgN1xPeN7UHPx8sIKOoGncQtUhmvcSCCcmNlkSXRIGukVa6RlqZ2ue3TCZZVjhRZquNZG3NLCUnvYIrXQpGfD+c5IhEREHkj988xeW9LmBwfF/CTSrPNreC/UipZky1IFq+MWUvh31fgOKbkhfIhIvbDpte5VNhBtUOj6oh1VDursMl8+8fD/HOTeoeAg0NDQ2NIDn0HcjqUZMHxhh4Zr2DO4YbCDfVk2omCLD7IxhxSzMNsmnpGmllYGI424+X4lFZ7DXkPZdEgWn9Ovvdfib9E8L4751j+f2bm7A5PT6CFKdjNUiEWwx8eMtIXl9zjBve3sy7N40gzKznl0MFvPTTEVVxgTVHinhl1VHGpURx9wWpDO6iSTe3BvbnVvDSz0f46WABIr9JlAPoDxXw9roM+seHcdeUVCb1jKF60yZO3P8nEl58AcvQoQ2ev6zGhVMlQhSM2ApAOfqADCkAs0Fi6a2jmPX6Rk6U2lQjVD7H6CXmjEjk3qm+9fVniygKJEZYSIyw1JaLVPycRcWPx9X8KYQarXx+3cu8unkJD37/LIXVJUxOGckz0x4k2lq3RlJWSRfWOH+0jCYc9bFrqVciV4UHxhh4boODMnsDspIuO6//clC1I3zFli8o+elNwkbNIuGuxcTf/g6hQy7BdmTzb8cD648Vc7JCPTSroaGhoREk2Zu9mQMqDIuTmJSk47kNDTSNddVAxupmGFzz8cK1g+hg0qn4pevHpBd548ZhGHWBLSxP0aNTKBsfvoBnZw6kf3wYJr1IqElHiFFHqEmHUScyPKkjL84ZzJoHJ5MUZeUfV/ZnUGI4176+kce+3Msdi3ewI6sMh1v2UWlzuGWcbplfDhUw581NLN2SFeSVaZxrvtx5gqteW8+Kffk43XIdQwrA5VFwuGW2HS/l9sU7+N+3fibn/j8R/8LzZ93b83SxleYgKsTI8rvGMWNQPCadiFmv/nuxGiRiQo08fkVfHrusr0/aYHMhSCL1/fi7RyXxn0sfYeudn7Fy3rucrCrmf396yfc8Wj1ji6LlR6YyVvnk1J/i9An3ySn+Uww22rtQKXuAuj+qYHJ3AT7YeJz/uahnoy5DQ0NDQ+M0qgvr3fzEZCNjF1Vz78gGWlzYfOsyWjKxYWY+uW00s17fRLnNiadh5zlmvcRLcwYzqpH9s/SSyCX9Y7mkfyyZRdVkldRQ7XATYtLRLcpKQse6NViCIPDY9D5c9eoGPth4PKAUQQVvWvwTy/ehEwVmDkts1Fg1mpev03J56PM07K4Abjy8vZ+WHSpHuekJHh8ReHZOuEWPUSdiO+N9gu111tHacIubM7EadTx9zQD+Mr03X+w4wbsbMymsdODyyJj1En3iOrBgQgrjUqMQz7FRIobqEXSiqgDFmaRGdmVWv2ks3vWVzzapQ/Cfi0bz0fKNqZr6J8pAJtzdSgoOj+8PJpjcXadbZnOGryqNhoaGhkYj0NevwNcvRmJ6Dx3/XOekd3Q9SRT61lerkxoTyvf3jufx5ftZeeAkgoDP4lYvCoiiQP/4MB67rA8DEsKb5L1PFdA3xC+HCjiYX6lqSNWXGm9zyTz25T4GJYbTvVPrqGVrL2QWVfPAJ+qGVH3fqV0y8HGWi7H78rmwb2BpppN6xvDE8v0+rwcjtmLWS8wY1Hjxr1CTnhvHJHHjmKRGn6OpMfeOpPTzo6rbjhYf56djG7m81xRiO8SQW3GSLw/8xJC4vnX2q0HhR9FF/8wShnbteM6iahr+afnGlK5+6cdAJtxyxYpbJaMx6Nxdmyug/TQ0NDQ0GiC8K4gGkJ1+d3l8kokhr1fxp9F+5gFBhLAuzTTA5iWmg4lXrhtCabWTj7Zm8cn2HMptLirtbiwGicsHxnHT2G4BGT7NwQsrj/hNjW9I1trpkXlzbTrPXOPbI0jj/LFofQYulVBoIN+pzeXhhZ+OBGxMxYebGda1I+uP+TqhA+11pqBw5ZCERlxpy0U06zD3i8S2u9CnbspqsLAr9wBvbv2YCkcVHYwhTE0ZzaOT76izn9mow5bUgYc/34PN5eGKQXHMGBTfpM6L3dllfLwtm5xSG063TLhFz7jUKGYMjm+SxsVtjZb/iYQn4k0w9Z9o0NCEaxbdCB4F5YxE1WAaJYK3d4eGhoaGRhPQ/xpY/U+oJ9slNUJkdl89L25x0j9GxVmmM8KQG5pvjOeAjlYDt01K5bZJqYDXiHF55POaUn60oIpD+ZU+rweaGu+RFb7anctfp/dR7bOjce6xOT18si3HR+UumHKHY4VVHD5ZSY8AFu2Ky8W1NYfZ7g7FrvPNHGpYbAUuGxAXVDPc1kLo+ATs+4pRzogQxoZG89qMx+s/WCfSYXQsd1zQjdunpLI/r4Ivd+Vyw9tbiLAamDE4jssHxtM5LPiIvSwrfLYjh9dWHSOv3I7D7anTV3j14UL+/s0/fOHyAAAgAElEQVR+ZgyK587JqQG1aGgvtHwBisE3gL7+L+z0CVeNTkK5ahHi6bm7gRDfChpDamhoaLQKwrtAQsM1GI9NNFLt9ONMC0uEuPolmlsb0aFGCisbEN5oZj7cfFxVWjqY1HhREPgmLa85hqfRCFbsyz9rqXKXW+b9DZkN7mdLSyPjmpkMSFvDyJRIjLrglpoCEGYytNkadUN8CNaxcQj6IJfgIugiTYRO8UbjBUGgb1wYj1zSm/V/nsJfpvfmWEE1Fz2/hjlvbOKjrVkBZ1TZXR7mvb+Vx77cR3pRNTZXXUMKoMbpwe6S+WR7NtNeWMNmlYbM7ZWWb0wljgRrVIO71TfhXpzklVE/k9Nzd2sOb0R22VE8bmzHtlH6y6I6+1oMEteN1JoYamhoaDQZ4/7o4yzLvC+Uqcm/eaMTw0Tsf+lQp2Ev4D1u3P3nYpTnlOhQI0VV59eYOnyyStWYCiY1vsbpIbO4ujmGp9EIMouqqVGRxg/mO/UosPV4KYqfZp2eqiryn3yK7DvvJPKWeXR5601ev2UsfeM6YArQcBAF6GDWsWzBKDq14d5lYRclYRnaKXCDSiegizQTPb8/osH3u5JEgTEpUTx9zQA2P3IBfxjTlZ8PFjDunz9z2wfb+X5vHg63uoKiR1a45b1tbDharJra67s/VDs8zH1nKztVGjO3R1p+/FQQvBPuikfqqPpl3lc3zHxqwvVBbyFs4h1cZO7I12m5PpZ2oLm7oSYdY1Iap6SkoaGhoaFC6gUw4FpIW+ZXtVUVnRmSJ8OA2Q3v28qICjFQeJ6NqWqHeg+bYFPjy2q0OuOWgr8IRbDf6eH8Sq5+bQPv3jyCDqelcFauXEn+k09hHTuGlOXLkcK9gikmvcSyW0fz6Bd7+Gp3LoBq/ydRAKNOIinKwps3DvNRmGxrCIJAxxmp6GOtVPxwHMUtozh8DRnBIKIoYBkYTfjlKaqG1JmY9BLT+sUyrV8s5TUuvtubx7sbMvnz53u4qE9nrhgcx6hukbVKhi/9fITtx0tUv5f6xWY8/GHRFjY9cgEWQ8s3J5qT1nH1Q26E/V9C1kZvE95A0Vu8k23yZG41VfDD/nxVFZuGcnfNeolbJ6SccwlNDQ0NjTbPpc+RXVBMVPb3mAnAiNBboOtYmPkOiC0/uSJYokONFJ3nNL9QP3UqQctaWzT55pZCuEWvWn0e7HeqAHtPVHDFy+v58q6xmMtLOPnkkziOHCXu6aexjvRN3TXoRJ6dOZAHpvVk6eYs3tmQSY3Tg04UkBUFWYaL+3Vm/oRk+sWHBXQ9ZTVOiqqcON0yHcw6OncwoZNa3/MgZGQs1uGdsR8ppXJ1Dq68ahSXB0ESkUINWEfHYh3aCdHUuOV6mEXPtSO6cO2ILuSV2/hqVy5///oApdVOLh8UxyX9O/P2ugwfCXsITJjELSv8d+cJft/OM7dahzElSnDtElgyE05sB5et4WP0Fug1HS79FwgC/eLDuH1iCgtXpwcUxjyFUScyuEs4fxjdvm8UDQ0NjebgSGENs09cz9djh2Pe/SI4a9Sb+RpCQNLDmHtg7H1t0pACb9PRoioniqKcN8njfvFhbMooxnVGfnwwstZWg0SPziHneugafkiODsFikKg+I9UvmO/0FE6PzIlSGzc+8y1Pf/c0EXPmEPfcc4jG+tWXY0JN3Du1B3dP6U6ZzUWl3YVZLxFuMWAIoK7KIyusPlzAwtXp7MwqxaATERDwyApGnciNY7py/aiuxIS2rvRAQRQw94zA3DOiWd8nNszMgokpLJiYwuGTlfx35wluemcrVSqR6ECFSWqcHl5fnc6cEV3atUR76zCmAAwWuOFLWPV/sPl1QPE/4eqMMOFBGLmA0ysu77mgOzVOD+9vzFS1ws/ErJcYkBDGW38Y1io9HhoaGhotmUq7iwWLt/PwJb2JG3YhXHAnpP8MG16BwoPe1D+9GTomwei7oMc0kFrPtNUYTHoJo16k3OYi/DxFdq4b1YVF6zNQU9ENXNYaLu4Xe24GrNEgF/btxJ8/V9/m7zs1xHYn7737VFO8nB6ZQ1UKRc8upOe4AUGNRRQFIqwGIoJoyLsru4z5722jxumuNQhdnt8MQ5vLu6hfuDqdWUMTePyKfkhaNpFfenQK5cFpvdhwrJjS7DKf7cEIkxRWOdh7ooL+CYFFFdsirWtWknRwwV9h4oNwYDlsfAXKslDsVcgeASlpKIy9B1KneqNZZyAIAg9f0pt+8WE8u+IQRVUObC4PZ9ZSWo0SkiBw89hu3DUlVTOkNDQ0NJoYRVF48NM0RnaLYOawRO+Louh9fqdOPb+DO8+cEqE4X8ZUQkcLQ7p0ZKMfta6GUuP1osCsYYlaO5EWhFEnMWdEF97bkOkTcQTf7zSQFC+7qGfRMQdjG8gOPJhfwUdbssksrsbm9BBm0TM8KYKZQxMJszQsnb/mcCELPtjeYFbRqZqfz3bkcLykhkVzh6PX1m/1kl2iXqsajDCJKMDxkmrNmGp16IzeHiX9rwGgatUqSpcsocvjbwR0+GUD45g+IJYdWWW8uTadvTllVJwsIiQmiviOZm4e242pfTppP0INDQ2NZuKttRnklNr4z+xB53soLY6oECMFlQ5SY5quCWew3Du1OzvfKVWtM24InSQyb1y3ZhiVxtkwd0wSH27KqhPRUSPQFC8FWHukiMJKB9GhdVP8FEXh2z35vPLLUdKLqnC55TqqymsOF/LsikNM69uZu6ak+m04uz+3IiBD6nRsLpltmSU89Fka/56lPV/qw+5H4S8YYRJZ8S9a015oncbUGegiIvCUBCfPKAgCQ7t2ZGjXocg2G4dHjabX7l3NNEINDQ0NjVNsTi/m9TXp/PfOMVr0QgVvZEq9b+K5YlRyJP9zYU/+9cPhoBayJr3IK9cN1hp6tkASOlp49foh3L54e71GcjApXkadyN4T5UzuFVP7mtsj8+BnaXy3J9/vvXOq1GJ5Wi4/7D/JS3MGM7VPJ5/9/vrlXr/nqF9pTua7PfncMq6CPnEqSs8agLecpVpFRTAYYRJRENp9c+42EXqROnbEU9p4rXvF40HQtQm7UkNDQ6NFU1Bh555lO/nXrIFtXv64sUSHnP/GvQC3jE/mgYt6IgreRqr1oRMFLAaJV34/hCm9fBfFGi2DyT1jePW6IVjqkdgOJsVLlhUq7L/JriuKwp8+3s13e/ICMsJlxVvvdNfSHaw+XFhn2/HiavaeKFc9rmLLF5T89CZho2aRcNdi4m9/h9Ahl2A7srl2H6dH5u116Q2OoT2TGq0uEhNMH1aPLJMcbVU9T3uhTVgQZ2tM4XaDZkxptADyy+3klduwu2RCTTq6RVmx+pEp1tBobbg8Mncu2cHvR3RlYo/o8z2cFktLaNx7CkGA3p070C3Kyo8HTiII1IlqWAwSigLXDE3glvHd6BrZvhdVrYEpvTrx7T3jmfrv1arNmYNJ8RIEAcNpJRHLtmTzw/6TqiJf9UWS7C6Z2xdvZ9UDk2rV+N5Zn4ms0iA40DREj6zwzZ48/nZ53zo9sTR+45bxyew5Ue6j8giBi80kRljo1bl9R//axCpNNBsRBTtyYTpiaBQYQ+uo+NWHp9qFM78aMawL7iIbUkcjglYrpXEOcXtkVh4oYOHqYxzIq6iViFUUcMsylw2I45bxyfTsfP7qJzQ0moJ/fneQEKOOu6eknu+htGiiQgxkFFWf72GQXVLDiz8d4bPbx5AcHUJZjZMvd+WSXlRFpc1NR6uBXp1DmT4gDnMAzUQ1Wg5JUVbiws1kqQgQBJPipaAQ08FbL6UoCi/+fEQ1IhWIoIVHVli6OYt7p/YA4Id9+apiGcGkIepEka0ZJVzQW4uWqjG5VwxGva9k/ikCEZuZNy65uYbXamjdxlTuTtjwMsKBr0i5xIXw+hiQ3WCN8sroDr4OzB19DlNkBfuhEipX5+DMrkSQwDRgPidf3AGCQMioWKyjY9GFt65eBRqtjz055cx9Zwt2t6c2b/nMLuSf7zzB8rRchidFsPD6oVqkSqNV8nVaLj/sz2f5XeO0BugNEB16/tP8FEXh4c/3sGBiCsm/pgKFWwz8YUzSeR2XRtPx+xGJPP/TEZ/6qWB6T5l0EoMTveusTekllNtcnEmgkSSHW+adDZncOdmrolzpR9QgmDRERVEoq/Edk4YXSRS4c3IKz60IrjYSvKm/oijw8s9HiAs3Mb57+802aJ2rspJ0WHYdlGaC2w6K7FVCd//azLcyD35+En7+OwyfB797srbBozO7kqL39qG4PCgO7wNE8YCgM6E4vf+vXHeCyvUnsAyIpuNV3RECaCanoXEKm9NDSc2vndlNOiKsBtVmdpvSi7npna0NPsA8soJHVtiSUcLlL6/jizvHaikLGq2KowWVPPblPt6/ecR5k/tuTUSHmJo9zc/u8rAxvZiiSgduWaGDSc+QruHEhpkB+HhbNuU2F7doqnxtltnDu/CflUdUtwWS4mXSidwyvlutc+SdDRnYVCIcwUSSXG6Zn1ZuZzSl4HSiVtofTBoitNn+3k3GzWO7sTOrjJ8OqKdn+sNilPjvHWPJKbXx58/2MLJbBH+Z3ieo/mFthdZnTOXthnenexv2KvV86acMq23vQNERuHYp9owqit/bh9LQzfJrWLlmTxHuEjvRt/TXDCqNelEUhU3pJbyx5hjrjhahE0VEAVwehTCLnvnjuzFrWGLtQvJoQRXz3m3YkDodh1smu9TG3EVb+OS2MVpDQo1WQZXDzYIPtvPnab3oF99++5AEQ1SoodkiU9klNby3IZOlW7IQBAFZUVAUr4fa5ZEZ0S2CWcMSefq7g3w4f5TWZ7EN09FqYHKsgZXHq3GrGCUNpXgheA2yUxwrqFJp8xxcJMlls7N32X/pYy0nJHQKlRh99gkmDVEQBM2B0wCCIPD87EE8/MUevk7LUzWIT8eoE7EYJJbMH0X3TqF07xTKD3+cwL9/PMyF/1nDo5f2YsageFUncluldRlTZVnw3uXgqAj8GFcNZK7FtexRig9d3rAhVedYGWdOFcVLDxJ5fe92dWNoBM7+3Apu/WAbJdVObE4PCnU7sxdWOvjPj4f51w+HuXF0Vx6+uDdPf3+QmkbIvTrdMgfzK/nlYIGqjKyGRktCURQe+jSNYV0jmDU88XwPp9UQaTVSWuNElpUmTYl8Y80x/vXDYWRFUa1FAW/foA1Hi4npYCS+o7nJ3lujZfKg9STbRRPFgoSKFoVfTHqRp68aUCcK4c85GEwkyaPTY7rxJhInpzLju4O8vS4Dp6fxaYgeWWFkt4jAL6ydopNEnrl6AJf0j2XhqmPsyi5DlhVcp90UVoOEXhL5w5gkbhzdlciQ3wxdq1HHX6f34YpBcfz5sz18vuMET83oT5fIhhVbFUVhZ3YZmUXVVDs9hBglkiKtDEoMbzXr7tZlTK141MeQSnq+khoXZNwbgtXg/dDf2uFkcZqLVXN/VRVy2Sjb3xXF42tIbclJ4x+/vMbhokxEUaR7ZFf+dsHdDIrt7d3BLeM4XIozuxJjl/atVqLhy+b0Ym56dys1DXhyToXOF2/K4lB+JZvSi1ERKQqoSLfG6WHh6mOaMaXR4lm0PpPjJdV8etuYhnfWqMWgE7EadZTWOOssWM6GZ1ccZNG6TJ+aTDU8ikJxlZMZL6/ny7vGtvseMm0Z05F9vNu7P/Nywyiucvg1suscoxd59JLeXDE4vs7rVoP6kjKYSJJeEulg8p7nhtFdeXt9hup+gaQh6kSBq4bEY/EzLo26CILA5J4xTO4ZQ3ZJDV/uOkFumR2by02k1cjwbhFc0Cum3mj1gIRwvrxrLIvWZXDFK+tYMDGFeeO6oVc5psrh5osdOby+Op2SGm9fvdMdSJFWA7dNTGHG4PgWXyveskd3OtXFcHiFamqfR4EXNjt5ZLz6pONWonF4evi8Xumo5qZP/8xTF97PZb0m4/S42ZKzG6NUNySsuGWq1uRgvL5P01yLRpvgaEElNwdgSJ2OzeVhw7His5J7BdhzopzMomqSojQZYo1zj6IoZBbXUFLtwP1rKmtKdEidCXNrZgmvrTrKF3eM1RrzNoLoECOFVY4mMaa+2JHDonUZQdVDOD0yOWU2bnlvG8tuHdVqPMQawWE/cIDkmTP57po+PPrGT6zMcyEZDX7vFVGA2DAz4RYDTrdcqz4L0Cs2lGOFVT4RrmAiSZIo1AqexIWbGd61I+uPFauOpaE0REkUuGmsVvPXGBIjLNw1pXujjtVLIgsmpnBxv1ge/e8evtyVyz+v6s/AxPDafXZll3Hj25txy4rfNVSN08ZT3x7gmRWH+GDeCAYkhKvu1xJoPcbUjnf9yp0/MMbAM+sd3DHcQLjJd58q93TUWg6ml2QDMKPPVADMosTEbiN830AB28ESPNUuJKvmodPw8tf/7vP7EKgvVU+trwcEV6QrCrAxvVgzpjTOKdUON1/s9HoSi6qc6CTvc1VWFCRB4PpRXblhdFckUeDuJTt5duZAEiO0xryNITrUSFGlEzqf3XlkWeEf3x30uzhuKK04LaecXdllDO7iq4yr0bpRnE6c6RkYe/bEZNTxP78s5JE77+Vbc1deXXVMVQVPViCjqJo/f5bGI1/sYd64btw9pTuSKHDz2G6s3F+gmu4XaM8ii1HH6OTI2v+fTb3TlYPjSY1Rb0qr0fx0ibTw/s0j+HJXLvPe28blA+P404U92J9XwY1vbwmoZty7xvIw+/VNLL5lJEO7tsznUOsxpvZ86lXuU2FYnMSkJB3PbXDw5BRfOXObZzTg+4NMjkhEFET++M1TXN7rAgbH9yXcpN7LR5BEHOllWPq3X+lHjd/ILqlhR1aparFtIKl6agRVpOtRVCVoNTSai+W7T/Dgp3sQBH5zIpxxC761LoO31mXQ0aJn9rBEJveMOfcDbSNEhRgprFKf84Jh3dEiapzqEtOBPKscbg9vrknn1euHnvVYNM4vTrfMj/tPcvhkBaU1LixVZYT2m0KCqEdesQLBZCJqygTWf7ADRwML3VN9iV5ffYydWWW8ceNQBiWGE2MWOe7n2IYiSSa9yPzT1AH35JTz08GT6u9fjxPgFJV2bY483wiCwIzB8UzoEc1T3xxgyr9WUW5z+cjxN4TN5WHuoi38cP+EWsXRlkTrMaZspfVufmKykbGLqrl3pK/RpKDuvQ81Wvn8upd5dfMSHvz+WQqrS5icMpJnpj1ItPWMgkVZQa5Rn5A02h8fbDx+1ql6ZxJc13lvPriGxrng3fUZ/PP7gw1OgM5f63EKKx0czK9scgGF9kRtZOosWbj6WG0Pu9MJ9FklK7DyYAEl1c52KXncFsgrt/Hu+kw+3JyFgnLa/aBgjp/AS0/+yOT8vdz6h9v557JdbM4oDnixa3PJbE4v5p4Pd/D38s1ct2kn/+57JXYl+N+9XhKZNew3oZo31h6rfaacTqAOy5UHCiircWpqfi2ACKuBf80ayPz3t/Lj/gLVfRoykO1uD2+tzeCv01teyU3rMaZUYwC/0S9GYnoPHf9c56R39JmFbv6P7R6VxH8ufQSAo8XHuefrJ/nfn17ilcv/5rOvljOucYpv9+addWf2MwmmSNcgiXTUJgiNc8CP+08GZEidjqx4VeH+/s1+/naZ/2ishn+iQowUV1SDxw1S46fqPTnlqq8H86wy6kQO5lUwJjWq0ePQOD+sP1rE/Pe34fbIOH3mLAGboAOXzIqIXqzcUINAjcp+9S907W6Z1XtP8K2jjJveeIr8rcUs3ZIdVOsPs17ivdP60JXXuPhh30mf+qtgHJaCAB9tzWbBxJSAx6HRfNicHtYdUa9/C8RAdnkUlm3J4oGLera4OtzWY0yZwqEyv95dHp9kYsjrVfxpdN2CXVGoRFYaLlxLjezKrH7TWLzrK9+NooBobT0fl0bzUuEnfSCYVL0zCaZI1y0rTO6lpVBpNC+yrPDoF3v8GlL1LbBsLg9LNmdx64TkFpmW0SJRFMhYDetf4NaMtQiyG7YBehP0vBTG3AVxg4M6pb8FbTDPKkXx/8zTaLmsO1LELe9vDcgRIgsish8lv0AWunZJz8cpE5nbuTN/nd4Jo07k3Q3Hsbs89brCDZKAQSfx7k3DGXJaXd7ao4XoRIEzu60F4wSwu2Q+33lCM6ZaCMvTclWlD4LN6Plubx5XDk5ozqEGTeuxDvrMgPXP+62bAkiNEJndV8+LW5z0j/ktOmURV1HhmQVnNH87Wnycn45t5PJeU4jtEENuxUm+PPATQ+JUPKmygjGl5SqJaLQMgu3MfiaBFOmKAlzQO0ZLudFodtYfK6La0fh6G0XxpsQ+OK3XuRx26+TwD7D8XnCUg7OaOk8Plw32fQ6HvoWwRLjq9YCNKkkUVEVvgk0rNupalidYo35OlNm49YNtqoZUIPVGpwhmoZteVMWh/Ep6dg7loYt7M7FnDAtXH2PDsWIEqCPLbzVICILAdSO7MHdsko/DpbTaqXrfBuuwLFcR0dA4P2xOL1YV7QrGQK52eticXqIZU41m2M2w7j8N7vbYRCMfpNX98Vh131Phme2zr9VgYVfuAd7c+jEVjio6GEOYmjKaRyffUXdHAcwDohFNrefj0mheOpj0VNh8F5nBpOr5o6EiXaNO4pbxyY06t4ZGMCxcfay20Px0Al1gOT0yH2w6zn1Te9SRUNY4g62LYMXD9ToLUWRvE/qiQ/DOJTD7A0id2uCpO1oN5Jf7njeYZ5VHVogObZp+VxrnhnfWZ+BS6a0ZrEBSMAtdl1vm8x05PHyJt0/nqORIRiVHkl9u58tdJ8gsrqbC5iLMYmBY145cOiDWr5Hu8aN6G6zDUq22WeP8UFKtXgMarIHs7zznk9ZjHYR2gpTJ3l5TpwWNM++rq76XGCZi/0vd5rqSUI5J2oVdHlGnfCo2NJrXZjze4FsLOpHQ8fEN7qfRfrikXyzvbMjwqZsKJlUPvIL9wTzqzXqRGYPj66RDaGg0F1sz1YV/gllgKYrC4ZOV9IsPa+rhtQ32fwUrHqnfkDoTVw18dD3M/Rbih/hs9sgKWzNL+HZPHhU2l+pzJphnVbhFT984rWl9a8Hh9rB0S5bP/NQYgaRgFroexRsRq/OessLRgirWHS1i/dEidKKIIMCyLVm8tyGT2yelMLV3J59GsB2tBtU0v2AdlmFmrZ1NS8FsUL+HgjWQ/Z3nfNJ6jCmAi/4Bx9eDozK443RmwpMPczJjDIpKNKFe9CKWYZ3Qd9b6+Wj8xo1juvLexkzUTKFAUvUkAab1i2VLZonfdAY1HG6Znp20vhkazY/bI6t6tiG4BZYgCFRoMv7quGzw39vAXXcBmvR8JTUuyLg3BKvBW2Tw1g4ni9NcrJpr/e3Yz+bB3TtAEPDIClsyvAbU9/vyiQoxcmn/zrx/8wiue2tznRSrUwTyrDLrRW4dn6wJMLUivt+br+qla4xAUrAL3dMl1TceK+a+j3ZSaXfXpnc5T3um7M4p508f70YniTw5ox+XDYyr3TY6ORKXyrwYjBPAqBOZ1u8sG7VpNBmJHS1IIpw5rQRjIOtEgYSOLa93YesypiJT4LrPYPGV4KwO7Bi9GWIHofv9i0QXuCh8Iw3F4QkoHCDoRUy9Iwi/TCte1KhLQkcLQ7p0ZFN6seqt1FCqnl4ncs8F3YkMMXDVK+vJKrX53fd0ZAWe/v4QJyscPHSxVoei0XyI9Syeg11gSZo8ujp7P0etoTx4vfwvbHbyyHj/6XVKZT57Nv/Ix/mxfL/3JDGhRi4dEMvHC0bT7bSG3hN7RvPzgQJVp01DzyoQuHpoy6pP0KifQ/mVqum5jRFICjYSFBHivV+/3p3L/3y6u0Hxi+pfm7I+8OlucststWIRMR1MjI6zsuZ4BcoZz6JAGwADXD+qa4BXqtHcXDM0gfc2ZOKR694TwRjIkihw9ZCWlynWuowpgC4jYd6PsGQ22Er8GlUKOlBkhL5Xw2XPg6THEGek092DKf7wAO5CG4pHBpXfuWAQQYGQCQl0mNpF88hpYHd52HOinNJqJ4IgEGHV89hlvbnmtY2qk1Z9mPUSMwbH07NzKB9vy6ag6sxEhvqxuTy8uyGT+HAT149OCupYDY1AEUUBs15SLRgOtt5GE0vxw/rnwVmluumBMQaeWe/gjuEGwk3qc5DsslH107+JG/syn942mqQo9QyKf141gItfWENhpcNHaro+THqRl+YMJtSkpUq1JvzVlDRGICmYha7VIDEmJZKNx4oDMqROx+6S+c/Kw0SHGrm8i4mi1xYyfd0etgy5HptKz6qGnAAC3uhWpw6mgMeg0bx07xRKaqcQ9p6o8NkWqIHcJ7YDydEtLzun9RlTAJ36wn17IGMNrH8BMteAZABB9Pbk0BlQBt9E5jM/EHftXZik3yYCXaSZTvcMwZlbRdW6E1Rvy0XQ6bwSaR4FqaOJ0IkJWAbFIBpbXl6mxrklq7iGdzZk8NHW7FpPvSB4VcoEAab27sSK/fkBTxpmvcS41EienNEPu8vD41/ta5Taks3l4alvD3LV0AQshtb5M9Zo+Uzr25kvd+f6FIMHs8AKM+tJjWl5k995pzwHyrL8bh4WJzEpScdzGxw8OUV9QSihMMazlTETU1DVHP6VCKuBT28bw8yFGymudqj2yDsTk07kH1f2Z2qfTg1fi0aLooNZfU5orEBSoAtdQRC4qG8nJj27ulHtFOwumUc+2UnKqmfodNnFzFjyKt9+5VUDVEtTrQ+LQeIvLbC5a3vnzkmp3P/xbtWWDQ1Hyb31UgfyKugd27JqOFvvKkwQIHmi94+9AmqKvDnkpnAI6YQo6Qg/mUrhiy+R+NqrPocb4kIIu7AThU9fR+raDSCDaNIhaIpTGniLZv/+zX6WbM5CVhS/i4/v9ubjlmV0ooBeErC7ZNW0P5NeRFHg+lFdePji3oiiwFe781TPGajakiDA8t25zKFNi48AACAASURBVB7epSkuWUPDh3nju/Hd3jxsKuGMwOptJBZM0OptVKku8joB6xGeeGKykbGLqrl3ZAORPWc1GOs3WBMjLHx773ie/Ho/3+zJQxQEnwWNJIJeEkmNCeGvl/ZhZHJkwJej0XLoGmnFrJd8vt9gBZJOp8HUdckrc77jeBmVfnqSBTK3CbLM3sdfZuCFAwF47fpwrn1jIwfzKwN2WloMEm/PHa45cVog0/p15vu9+UE5oU9nU3oxV766nuSoEF64dhDdO4U2fNA5oPUaU6dj6uD9cwbhs2dRvGgRtrQ0zAMG+Gy3HzyIsVcvJIuWgqLxG4qicPeynfx8oKBBb9ipYlqDTmBMShSyrLD2aBF6SUQUvB27wyx6bh2fzMxhCbXd3QFeW+UrOx2M2lKN08Nrq44xa1iitljVaBb6xoWRoPdwxIlq5KOhBZaiKFyl1duoIzecHtwvRmJ6Dx3/XOekd7QfR58gghyYsFKE1cC/Zw/ib5f35bPtOSzdkkVpjRO3RyHEpGNMSiS3jE+mRwtZoGg0jukD4nhi+X7VbcHUGwWDxSAxb1w3Hvh091m1U7AJOt7cXcLvf6cgCAImvcRHC0bzwCdprNiXX69z02qQsBp1LJo7XFMPbaEIgsBzswbiXCqz6lCh36bi/pAVb0rogbwKZryynndvHsHwpIhmGm3gtA1jyg+i0UjUbQsofPElurz1ps92+/4DmHr3Pg8j02jJPPP9IX4+UBDUj9zuktlwrJj7ftedV68fSmmNE6dbpoNZT0eL3sfYqXK4SS/0rfcLVm3pRJmN0hqXVpOi0eR4ysvJf+LvPJBZwH29r6XGHVy/FpNe5Mkr+tFBq7dRxxwekBH0+CQTQ16v4k+j/QhRyC4wBpfyEmbWc/O4btw8rltQx2m0DsLMeqb168zXu/PwqPRZCiSdKhgsBokPbxlFVIiRdUeKVfcJZm7LK7eTV24nLtzbyNeok3hxzmCyS2p4f2MmS7ZkoygKkiCgAE63zNCu4SyYmMKE7tGImuBNi0Yvibx63RAWrc/g1V+OYXd5gq49V/CKl8xdtIUv7hhDUZWT19eksy2zBJvLgygIdDDrmTEojpvGdiMxonkVANu0MQUQftVVFL/xJmVbt7NK15mlW7IoqHTg9siYSsyMi+3H/DJb7Y9Wo31TWu1k0foM1YhUIHVMz/94hBtGdW3wfiqrcaKXRNxneKeDVVsySCLlNs2Y0mhaqjdtJveRhwmdcgFTn3qSt09UM++9rapiFGqY9CL3Te3BNcMSm3mkrZiOSV61WVdNvbulRojM7qvnxS1O+seoRKdiB4Oopadr1GXBhBRW7MvH42q+prVWg4RJL7Fk/ih6dg6l3OZCFMGj8pgIZm7TSyKlNU6feTQxwsKjl/bhgYt6kVtmo8LuwqSXiA4x0lGbA1sVgiAwb1wyc8d047s9edy9dKdqiURD665qp4dLXlyHQSfWmZ9kRaGk2skHm47z4eYsBiaG86+ZA5vNqGrzxlSNIvL+ZXfz6SfZCKbCutavEMrxAljy3CpGdIvgoWm9tNBwO+ejrVmqddxNXcfkT3Y6WLUlBa92ioZGUyA7nRS+8AIVy78m9qknCRk/HoDRKSY+u30M9y7bSXaJDadbVvV4n1pcPX5FX6YPiPPZrnEaogQjb4e1zzXYsPexiUY+SFOpQzGEwLg/NtMANVozfeI68MjFvfm/7w4GnUrlD7NBQicIuGWFrpEWbp+UwrR+nTHqvHOVoih+hP4bMbfVYwMadKJf5UqN1oUkCuSV2zHpRWxn1FAFuu5yywpuP44+b0qowrbMEi59cS1Lbx1F37imX+e3aWOqsNLBtW9sJKfUiEOUQeXDdsqALLP2SBFbMzfy0pzB/E5TL2qXyLLCW+syfIoiG1PH1JAxFW7R12leeIpg1ZZcHrlOHZaGRmNxHDnCiQceRJ8QT7f/foEuom4eeu/YDvzwx4nsPVHO2+sy+G5vHg63jAAIHg9DkyK4fUoPJvSI1vpKBcrQubD2WZ+XM++rW7OUGCZi/4tKKp+og56XNNPgNFo7N45JwiXLPLfikM9C9UwaigCIAsSEGnnkkt4kR1lVC//NtircfuqZgpnb3LJMmFlLD24PyLLCm2vTfe7PYNZdAb2PAhV2N3Pe2MQ394xv8ghVmzWmqhxuZi7cQE6pTbVRoRp2l4e7l+5g0R+GMyY1qplHqNHSKKh0UGn3rWEIto4pu9RGjdNdr2S5xaBjYEIYO7LK6rwerNpSz06h2qSjAXjrBn4+eJKMohqqHW5CTDq6RVmZ0isGveQ/DUyRZUoXf0jRa68R86f7Cbv66noFTfrFh/Gf2YP4z+xBuDwyHlkh95qrib3xKcy9Yprj0touIdEw6k7YvLDBdD8f9Ga4+GmQ2uw0rtEEzBuXTN+4MJ5feZjtmaW4VNZDgUQAZAVOlttJibaSGuM1pNylpdRs20bNlq3UbNmC68QJ+ky8m716XxXIYOa2cLOBhI5a6UV74NDJSqocZ7/uOkVDToEqh5sHPt3NsluDO29DtNmn8F//u5fccruqIdVQn4P5H2xjyyNTsRrb7MejoUKF3YVeEjjzdx18HZNAuc3VYP+n2yelct9HO6l21I2YBqq2ZDVI3D4pNaAxabRd8sptvL/hOIs3H0dWwOHy4JYVdKKAUS8iCQI3jO7KjaOTfBpYuk4WkPfoo3gqK0hathRD165BvbdeEtFLYExJwXnsKOZ+fRs+SKMuFzzm7Td16NvADSq9GcbdDwOvbd6xabQJRiVHsuzW0cx6fSNbMkrqbAsmAuDyyCxcto7/qdzpNZ5ycjAPGYJlxHBi//4Epj59uPdwMfct26kqKBBoO4X547tpCrXthJJqp2omQ7DrLgjcKbAzq4zskpomjU61SWuhvMbFt3vycKqICATyYSsKfLnrBL8fGdzCQqN1oxMF1TztYHO9ZYU6kYCD+RX8dKCAokoHChAdamRyzxim9IrBIIlU07jmdZIocGFfLSW1PfPLwQLuWLIDj0fxSRt1ywruXw31N9dm8M76TBZeP5QJPaIBqPjhB/Kf+Dsdr72WqNsWeJuXNxJDagqOo8cafyHtGUGAq9+Clf8Lm17D6ZExoN6nB70FFBku+j8YdtM5HaZG66bc5mJnVqnP68FEADwKfHXCxUMp0cQ+8TimPn0Q9HUzI6b0isGgE/2qszU0t8mKognXtCPObAh/iqDXXUE4BWRF4b0NmU3a1LlNGlOfbM9WLfAP9MOucXp4fXU6c0Z00bwj7YhIq7FJ6pg8soLVIPHV7lwWrjpGelEVLo9S+9CQRHj55yN0ibQya1gi723IxB5kd3eTXuRfswbVm76l0bb5cf9J7l66I6DGh063jBO49YNtvHJ1X3p/8jo1W7eR+PJLmAcNOuuxGFNSKf/qq7M+T7tFEOB3j5PZ/QZ+fO//uMW4EkF2e3tIgbcnlSkMxtwDg+Z4ZdU1NIKgsNKOQSfi8pydgqyi02G6cS5mPy0PJFHgH1f2548f7wq6KatZL3Hf1O5a6no7IsysV3ViB7vuCsYp4PIofLHzhGZMNcT7G4+rqtcE82EXVjk4kFdJn7jg+ndotF7CLHr6x59dHZMAjE2N5Ia3t7A/r0JVStojg02WOZRfSXZJDSa9hMMtq8qCqmHSizw2vY8mlNKOOVpQyT1Ldwa9WLG7ZO78cAfvG80M++JzRGvTKGIZU5JxHj3aJOdqz7yb5sAy8gGE3y2EwoNgK/MaVJZIiO6p2jhZQyMQbE4ZQUVrL9gIgCQI2FweQuvpH3dx/1jyyu08s+JgwM8os15i1rAEbp2QHND+Gm2Dnp1DUVSsqWDrx4N1ClTY/UT/G0mbNKaKqhyqrwfzYUuiQH6FTTOm2hlnW8dkNkikF1VzssKhmmZ6JjVODzV4iAk1IisKNqf/5nVWg4TJIPHsNQOY0kszpNozr/5yTDWKCg0X4LokPV8MuYIRTWRIARi6dMGVl4fscCAa/TSX1aiXaoeb/+46wTf3jAdJD537n+8habQhOph1yCqL1mAjAE6PHFAj7pvHdSM61MBDn+1BAL/zmlkvISsK9/2uO7eOT9aygdoZJr3E7OGJfLDp+K8y5r8R6LoLgncK+EsvbCxt0phy+VlkBPNhK4rXk6PRvpjSKwajTvIxpiCwOiblV8Uj5xkPhYYWuOU2F+O7R3HD6CReX32MTenFSKLXj+iWFYYnRXDbxBRNdlqDCruLb/bkqU4GgdSEeoAvdubwl0t7N5nIjmAwoE9MxJmZialnzyY5Z3vji50nGJEUQbzWQF6jGegcZlJ9PdgIQHSoEZM+MO//ZQPj+V2fzizfncvC1cfIKbXVpqa7ZW9bjwXjk7l6WEJABppG22TumG58uDkLVPJzAll3QfBOAbMhcGGLQGiTxpRJL+Hy+EotBvNhCwKEmtrkx6NRD5Io8Np1Q/jDO1uCTqEy6UXcsuzjXQlkgetwe3ud/e2yviyZPwpZVmpl2kNNOkTNgNL4lc+355xVTSh4m0Y3tciOMSUFx9GjmjHVCBRF4YONx3nssqbL4dfQOB2jzhsBWHwWEQCv0l5waXgmvcTMYYnMHJbIyQo7JdVOZEUhzKwnPtysRaI06BJp4eohCXyxM6fBfmj+CLYcY3hShP+TNYI2aS30ju3gI/8JwX3YTrdMr86+Tek02j4jkyN54drB3LdsZ8A/bItB4oJeMazYd5LTvStBK8xszOQvl/ZBFAXCLJqnTsOX3dllZ10TWuP0kJZTzu+D733oF2NqCs5jmqJfY9icUYJblhmT4tufR0OjqZg7JokPNzU+AiArCjPPQmmvUweTT3sGjfaL3eXh67Q83lybTlZxdaMNqVMEU46xYELKWb3XmbRJY2rBhGT2nShvdJ8DgNHJkcRoP/p2y0V9O/Ph/FE8/PkesoqrcboVPGfkm0uiVwK9e0wo/3dVf+a8ucmnjiVYhZmlW7J5aFovTaVPwy+lNvXC2WALcEtrnE05LAwpKVSu+KFJz9leeH9jJjeOTtK89BoNklNaw7d78sgts+Nwy0SFGBieFMG41Kh6MxjcHpmv0/JQ8PagU+vBWR9mvcjs4V00pT2Ns8btkfnXD4d5b2Om/3o6RfEruGPSiX4VkANxCoRb9IxK1iJTDTKpZwxGvdToPgcWg8StEzVFmfbOkC4dWXHfBPbllvPkN/vZmlFa+9sOMeqY1i+WeeOSSI0JpaTaiUPlxx3sAtcjy5RUOzXvnYZfrH6aQQdbgNvUTcmNKSkUpWuRqWDJL7ez/mgxT1894HwPRaOFoigKa44UsXDVMXZklSIrSp1UPatBwmLUMX98N1WDJ6Oomj99vAuzQWLFfRO4c8kO0gurVecsNUx6kaFdI/hrE0pJa7RP7C4PN7+7lZ1ZpfVHok4zpCQBLAYdblkhKsTAbZNSWHO4kNWHCxtVjvH45f2a3HHVJo0pSRS4/3c9eOqbA6rpMPWhFwWSIq2MTtbSLdozlXYXqw8XUlzlxOWRCTHqmDUskX9cpa6wVWl3oRMFzvT1By07KwpU2t100kQkNfzQLcqCXhJ86h6CqQk1SCJJkU2n5gdgSErClZWN4nL5NPLU8M+Szce5fGBcvVLTGu0Xt0fmoc/S+HZPvt/1TPWvKrD//vEwb63N4KMFo+kWZUVRFBZvOs5/Vh7hnimp3Dg6CVEU+PS2Mcx7bytpOeWq7TtOIeBNiZrcM4b/zB6kiR9pnBWyrHD74u3sOF4aVG9NURBIjQnh2ZkDSIkOQRAEZg5NZM6bm9iXWx6wQWXSizxwUa9maSvTJo0pgOtHdeVgfgWfbT8RsEGlEwUiQgx8MG+Elm7RTjmUX8lba9NZnpaLJAq4PQqy4m24K4oCh05WctvEFKb0iqkzsVgMOlV1tWAVZmTZGxnV0PDHzGGJvLk2gzPrHoJS5RLgmqEJTTouwWgkM7kfOev3Q0xnwix6esd2IKSJI2BtCadbZsmWbJbOb8LiNY02g6Io3L1sJ6sOFgRUT2J3yTjcDma8sp5Fc4fx/MojVNjdfHLbaFKiQ2r3sxp1LLllFKsPF3Lvsp3YXB50oojT7QEBDJJXrnxMSiTzJyQzOjlSWxNpnDXL03LZnFGiakjVp3jskhUOnqxkz4lyUmO8WgYGncjS+aO4/+NdrNx/EpfLjUdQL48w60UUBf5xZX+uGtK0894p2vQs9/cr+hFm0vP2+gxcHqVeXXmLQSI+3MyS+aOIDNH6pLQ3FEXhXz8c5q116bjcMh6VW0X2KGw/Xsp9y3aSHB3C4nkja0Uiwv2IRQQrOysrChFWQ5Nfn0bboWuklQEJYWzNLPXZFmhN6IikCOKaSIK73Obik23ZvLU2g/Kes5B+zEHQ5QPgkmUuHxjHvHHJ9NQEfXz4bm8ePTqF0L2T9tlo+LJw9TFWHSxUNaTqW3xW2F3MXLiR+y7ozh2TU9Gp1OCKosDkXjFc2LczyVFWwix6ympciIJAR4ueyb1itHRzjSbl1VXHVCOhgSge25weXlx5hN3Z5fx04CSVDjeiIBBu0TO7fxTZ361kQ8Ig9Lrf7nVZVggx6Zg/PpmZQxObVdSrTRtTgiDwwLReXDogjrfWpvPNnjxvKpZHRv410uCRFXrHduD2SSlc2KczBp1W+N8e+euX+/hse05A4eJqp4eD+RVc9vI6lt89jjCzHr0kMmNQLJ9uz8Gj1PXgBbrAlQSBS/vHBtzDQ6P9cufkVG5fvEM16t5QTahZL3HH5KZRMvpq1wke/DQNQRC8YxF03kZWp7Wm+Gx7Dl/tzmVSzxhe+H/2zjs+ijL/4++Z2d3sppJCSICQhCYdpPeiYgHROxsWUFBEwX6eeud5evqzt7PcKVhAbOjZQMEu0kvoSJckpEEK6dm+M/P7YwkSdjbZJSEQdt6vV16YmdmZZ+LsM8/3eb7fz+fafoQZ9Oe7lvfX5XDrqPTT3QydMxC3rPDm8kzN73hDg09V9UqSd0uO1gykjic5xoxLVrihCW0SdHROZNehSnJLrT7bg1E8zi61kbc+p454SpnVRUFJNUpSL8Z0SeCqAe0xiCJGg0h8hIkeydHNYi1zVgdTtfRoG83Lk/vxr8t7smxPMUdqnMcCqndXZ7PkrpH6EnYI88G6g3yxOT+o+jq3rFJY6eDm9zby+a0DEX77jFtyP2AxtyPjO/sRiMKMySByiz6w0gmAseckcv2QDny8ITeo59ZilJg+Io3hnRIa3YYFaw/yzHd7Gsx9l1WQ3QrL9xZz7dz1LJw5VJ8wAHYWVHKows4F3Zs+f1+n5fPT7iIfBVkIfPBpd8vMWZHJhT2T6r1OcoyFHfkVTddwHR0NFm0t0BQ8CUbxGNBUoXTinTBYtreYzTnlfDJzWLNnQoREMFVLtNnIn85td+x3VVX5aEMuB0ttpCc0bTG2TsvAIyu89NN+/4W99aRSuGSFPfmlbH3pT/RPkuj654fp8Z3EjvzKoGVnJdFbYNmzbUxT3JZOCPCPCd3xOF18uv4gDqnh9AWLUWLK0A48cFHjTXV/3VvsDaSCUFJyeBT2HK7i3k+2MWfqgEa3oaXzwbocpgxNbXDlQCc0eWtlFlZn4/zkdh2qIrfURof4cN/zeGS++62Q99Zmk1dmY/WBI0SbDYzpmsjUYalNlgasowNwqNKB1rAoWMXj+lBUKLe5uWbuOr69ZxTtmvEZDqlg6kQEQWBQWhwbs8v0YCpEWba3GLesPSAMJI/XIcPbsffx5k3jAZgz1cGEV1cddXkPrA2iADEWI+/cNLBJ7kknNBBFgdt3LKKLFMmCxEHkl9txeZQ6s9mSKGCSBDrER3DfBV24uFdyo6+rqiqPLt7pN5CqbwLC4VFYvr+Y3Yeq6NE2dCUrK2wuvtt5mGV/HXu6m6JzhpJVUqO5PZjBp8kgkllSUyeYsjo9vPLzfj7OyAX1D4+f/HI7AAeKrcxbk82gtDj+dkk3erXTJ/h0Go/bTwZDsIrHUP87BqDG4eGuj7fw5ewRTdb+hgjpYApgcHocG7LLuGbQybt667Rc5vqZ/Qs0lUJBZFmuhzKri7gIE4lRZr6aPYLJc9cdTSetP6IySgJxESY+nTlML/bVCYrq5cuxrlrFdYsXMSUqip0Flby39iD7i6qxOj1Ehhk4JymKm4anNemK5+acckqt2oa/gUxAuGWVd1dn8dI1/ZqsTS2Nzzblc373NiToYkc6fvA3WRHM4FNRVaqdf9QvllQ7ufatdeSX2/16TNUaz68+cISr56zj1Wv7NZgqqKPTEP6EtYJVPA7kHSOrKrsOVZFZUlNHxfJUEvL5BYPT49h4sOx0N0PnNLHncJXm9mBSKUwGkd+Lqo/9nhIXznf3jOaWkR2JMhuI0JA6jzBJRIYZmD48ne/vGU2avjKqEwSesjIK//koyc8+gxTlzQ3v1S6GF6/uy9d3juSX+8ey+M6RPH9V3yZPHZ27MkszLbZ2AiJu/CzCzxmOaDIjSAbCOw+po1opKypLdhymyuFu0na1FGRF5YP1OUwdphf86/jHnxjW8YPPhhAF4dj7p8bp4eo5a8kptQVs1mt3y9z9yVbWHDgSeMN1dDQY2SWBiDDfsdDxise2/etQ3A5U2YM9cxPlv86rc2yg7xjw9rPz12Sf0ns6npBfmercOpJqh5vCSgdJMfrKQKjh8FMrFVQerwpVDk+dTTHhRh66pBv3je/KD7sK+Xr7IcpqXKhAfKSJS/skc3GvJF3ZTCdoVFXl8KOPEn3ZJCIGD27262dkl6FRFx/0BMTOgsomEcJoaazYX0yMxci5Ka1Od1N0zmDatjKzv8g31S8Yuw2PrJIS503xe3TxTg5V2DXreetNzXUrzPxgExkPX0CE7hmn44fiagfL95ZQZnOhqCqtLCZGdUk49vxd2COJv3/5m+ZnA1U8DuYd41FUvtpSwJN/6t34mwuAkP9miKLAwLQ4Mg6WcVnftqe7OTrNjEEUj6U1HE9QebyC11lbC5NBZFLftkzSny2dJqLyy69w5+bR7uWXT8v1bS6P5vZgJiBUFarsobky9f66HG4clqoryOrUyy0j03n8m92avjyBDj7bx1ro2iaKSrubpTsOa6adB5I2paqwaFuBLp+uUwdVVcnILmPuyixWHzjitR7yKKiomCQRRYV+Ka24fUwnxnRtzdShqbyzOhuXxspoIIrHwYpV2NwyblnB2AwiPyEfTAEMSfeKUOjBVOgRF2GisMrhsz2YPF6PrJKk1zvpNAOu/HyKX3yRDu/NRzSdHnNnURAA30FZMBMQAjTLC+5M4+ARKzvyK5kzRVcz1Kmfy/q2419f7/a7v6HBZ4RJ4vYxXj+5zzblHf3e1iXQ2mCbS2bO8kyuH9xBnwTQAbw+aH/53zZ+3lOMwyWjAsdX0toVb8C0IbuM3woq6dM+huev7MNHq37HpQpwEs9RsGIVkiDg9DRPMBV6bzMNBqXFkZGt102FItcP6UCYRm56MHm8STFmOic2T5GjTuiiyjKHHvob8TNmYD6n8fLmJ0uMRVuGPZhaDgXVb0Hy2cyH63O4emB73WdLp0EsJonrBqdg9lM71RBGSWRiH6965wfrcjTrHINJmyq1utjtp8ZYJ7SQFZVbFmzkp91F2I8GUvVhc8lsza3glpe+4+m9XxBpkk4mlgrqHQNeIQqtmvVTgR5MAT3bRlNQYafCpq1QpXP2cv2QDn73RQ++gtjzbqFy3afkv34D+W9Oo3rLEixd/njxhJskZo3ppM/W6ZxySufNQxBF4qbddFrbcWX/9pgk3+c9mAkIs0GiT/vQqhmyuTx8sSWfKXqqlE6A/O2S7nRLjvIrRuEPi1Hi/VsGHwvaS2qcmscFkzYliQKHK3yzOHRCj6e/3cPG7PKgfAadHoUcj5FFE29n8d2jSIwKCzrQCeYdA9AlMbLZxmZ6mh9gkETO7dCKjQfLGd9Dd6MPJRIiwxh3Tmt+2VuMWyOfvKFUCkFAr4fSOeU49uyhbN580j//DEE6vasaU4elMm9NNlqpfoHUcpgNIreMTEcSQ2sCYvG2QwxIjT1WkK2j0xAmSeDDiw1MX+RmZ5mAXa4/qJIEsJgMzJ8+qM5khT8vxWDSplRVxeHRFmzSCR2qHG4+XJ+jqQjZkP+TSzSwMqucR0SB1Q+dxy97injjxz3sK7FhkD2osowgiXgkIxFhEpVOhROHZYHWC0aYJGaP7XzK/g4nogdTRxmbLKOufxMKVHDbISIBUoZA+uiTyu3UaTk8e2UfJry6iqIqZx3D04YwG0XenjoQSzMtI+uEJorTyaEHHyTxoQcxtmt3uptD21YWBqXFsS6zVPP70tAEhApcO9j/inBLxiMr/LynmO93FnKkxokgQGJUGJf2SWbB2oM8PKH76W6iTkvAUQnbFsLaV4l0VPGxKvGRNJK3lIupUCOwY0I9LrHIYpRQUbmsb1vuHNeljklv7X637CscE0xtsCAIRJm1U3x1QocvNuVr1t8FImQCXu+ztz9fy52lG+mychUvWmuoGDGe0n6DkDt1JSo2mrT4cFQVLnplJfJJilUgwCW9m88fTQ+mslfBmleYnr3KuzKRW5vqJ4LJAuYYGHY3nHsDmKNPa1N1Tg2twk18Pms4E15bRaXN3WD+L4DZKPHatf0Y3jn0pJ11mpeSf7+CKb0jMZdffrqbcoznrvJOQFQGqchnMYr867KeZ129VKXdzburs3h/bQ5uRfExAl/622FcHoWdBZUMSI3VJaZ1/JO1HD65AVQF3DYAjMA0cSk3mZayQe3GZ54xHBLa4BLNxKb2YnSPVK4Y0J5IP89V9+RoNmjUhQcjs+7yKHRLijoVd6zTgnhrla/PYKBCJuA1bf8i08rs9DjavfwSYd26IYjaK64XdG/DL3uKcAToi1aLxShx9/ldmtV6JnR7dEWBb/8K2xeC246ISl0vegVcVu/Psidg7asw/VuI63iaGqxzfPKOoAAAIABJREFUKqkdFE7qm8xPu4sRBHwkaY2SgCgI9O8Qy8MTutO7fdOaoeqEHk6PzIHiGqrsHoySQEJkGKnx4cfyvK3r11P17bekL150RtXltWtl4dPbhjJ57nqqbU4UoeGaDrNR5J4LujB50Nm1KpVXZmPy3HWUWl1+zVBrawte++V3Pt+Szye3DiVRVwDVOZF938Nn08Bj19wtCDBU2MtQ096jG0Q4Eg3nLIN6AvTbxnRkZ0El1kbIrA/tGEcb/Zk9I7G5bSzJWsLiA4spc5ShohJliuL8DudzVderSLA0zaSvwy1TpKF+HIyQCYBoNuO8eirmhIh6j3t5cl8mz13HnsPVARtNW4wSk/omM3NU847VQzOYUlVYNAv2fH1s5qde3DbwOOCtcXDbSojVC4jPJuwumbsWbuWfE3tw5YD22Fwevt52iE825nGkxolHVokyGxh7TmtuHJam1zzoNJq8MhsL1h1kYUYuAsKxTGK3rJAUbWbW2E5M7BjF4YcfJvmpJzHExp7W9mrRLSmaj9sV8chON7sjk1FUVbPuMMIkEWU28tikHlzSO/k0tPTUUVzl4E//XUO5zYWGF6oPDo9CbqmNP7+xhm/vHk1MuJ42pXOUol3w+TS/gZQmqgLOKnhvAty5CcK0V47GdE3EbJQ0gykITGb9ttGdAm+XTrNQ7ijn1S2vsjRrKYIgYD/h2cmuzObtHW8zot0I7htwH+kx6Y26XrXDg1ESfQKbYP2fREGg2tFwVkOYQeKTmcOY/dEW1meVYnfLmobxAAZRwCAK3DQ8lYcu7tbsk4+hGUxtmOsTSKW9Uo3NDdn3RBJh8v5PeGeLiw93uFk+LeKPTmvBpXDXVpBC8093NvLEkl30ahvNFf299SjhJgPXDu5w1tZ16Jw+PLLCPxbtZNHWAr/Bx8FSG49/s5t/OV08PvJyuowadRpa2jCu3FyM77zBp58spCg6kffXHeTLLQVUOdzIiorFKDEgNZbbx3RiWKf4M2plramYsWATlXa3TyBVXyG2R1EpqXZy18ItvH/LEO0T64Qey54Ed91Z/wbHJeAdmziqYNvHMOQ2zVNLosBfxnflyaV7NCXS68MoCnSID2dYp/jg70nnlJFXlcdN399EubMcj6JtpO6UvSqOy/OWs+HwBv57/n8ZmDTwpK9pMUnIGrNGwfo/KXjfD4FgNkq8e9NAtuSWM3dlFiv2lWCUBGTF61coigKyonJF/3ZMH5F+2mxqQi8iUGRY+ZzmipSswqsbXDw8Kkzjg3g7LVsZ7P8eul96ihuq0xws3XGYtZmlLLlr5Fk52NM5c/DICtPmb2RTTlmDKQs2lwyCxKNKV9SNuWdcapyqKBz+56PEz5yJKS2NFOAfE3vwj4k9TnfT/KKqKhXOCqpcVUiCRKw5lghj/Wkm9bGzoJLfi2vwnDC4CKQQ2yWrbMguI6/Mpq9060BNMWT+gpZCZoPjEvCOZ9a+BoNn+hXMumFoKnsLq/l8c37AAZVBFIiLNPHhLUP092MQlNY4Wbgxl6+2FFBhc6OqEGUxML5HG25qguyWI/YjTPluChWOChQaTn9TUbF5bMz+ZTYLLl5A9/jghHCO1DjZmF3GhuxSzWAqGCETALdHpXVUPc/zCQiCwIDUON6aGseRGicZ2WVU2t1IgkBchInhneMJN53ecCb0gqnffwSPtp/UA8NNPL/GyexBJlqZ/XQcrhpY84oeTJ3hlFtdfLoxl4Ub8yi3uvAoKuEmiSHp8dw6uiN928eQX27n0cU7mTdtkK5SpHPK+duXv7E5JzhvDodH4bGvd5EcY2F019ansHXBUfG/z1DsduJuuvF0N6VBrG4r32R+w/yd8ymxl2AUjaiouGU3fVv35ebeNzOi7QikAFNUanl7VRauE4LiYAqxFVXlvbUH+eelZ24AqtNMbH4P7zy7LwGNSwBs5ZCzBtL8D2afuLwnMRYj76zOwi2rmgPjWsweJ+0TY1g4ayTxkYEPfEOZQxV2nvhmF7/uK0EQqNPXl9lcLFh7kA/W5XBuh1Y8Nqkn3ZNPTtTswRUPUuWs8gmkyleVc+SHI7iKXUhmiegB0bS5qg1ShLdvs3vszPp5Fr9c/Yvf/k5VVfLL7WRkl7HxYBkZB8soqXYyMDWWQelxXNA9kV/3ldSZRApGyARgSMc4WoWfnAhRQmQYE87AdPHQC6bWvOYNiDQY2FZibJqBF9c6efK8egotC3+D0kyI13OIzzTKrC4eXbyTH3cXIZ7QmdlcMt/tPMyyvcUkxYQhCgK3jelI35TQMg/VaX4OFNfwzfZDJ+XN4XAr/HPRTpY/MPaMmB12Hz5Myauvkvr+gtPueVUfqqoyb+c85myfU6eewK38kau/uXgze1bswWww8+KYFxmUNCigc9tdMt/vLPSRhg+mENstq3ySkcs/JnRHDDHPLZ0TyFrurcvWIOBxiccJ+RvrDaYEQeCvF53DhN7JvLPiAEu25GI0h+FSVBRVxSh5hWQ6tY7kBvvvDCtYReuo8xtzZyHDnsNVXPf2eqo00n5r8aZ1q6zPKuPKN9cyZ8qAoCfJ8qry2H5kOx61bmrfke+OUPJdCe1ntCeyRyTucjeHPjjEwRcPkv6PdMSjxs92j51VBasYmzIWAEVROVBSQ0Z22bEAyi2rDE6PZXBaHFOHpdItKfqYL+CB4hpW/X7EZ0U+GP+ns7H+LvSCqZK99e5+YlwYI+ZZuWdIPVGzZIIj+/Vg6gwjr8zGVXPWUmZ1adaiACgq2N0y2UdsiAIY/Ehy6ug0JfPXZGvOAgfqzVFS42RLbgUDUk+vEIWqqhx+7DHibpxKWJcup7Ut9aGqKk+se4Il2UtwyNqD1FpsHps3Bebn2Twz6hkuSL2gwfMXVzs0TYeDLcR2yQrVTg8xFn1lPKSxV9S7O6BxierxliEEQI+20TzepoJbir7hwF+eoKTaiUtWiLEYGZgaxzlJUSiOQWRdOomaNWuIHDEimLsJOWoVPasc2rVLWthcMjM/2MTHtw6lf4fA+/WP936Mop4oACFTvKiYdre0I6qPV4TE1NpEyuwU9j+wn8q1lcSO9l7D5rHx+ua3+D27AxkHy9h0sIwos5FBaXGM6BzPfeO7knacouyJdE6MpEfbaLbnV/q80xoSMhGAmHAjw8/C+rvQC6bc9Svl9EqUuLSrgWdXu+je2s9Au7bgU+eMoczq4qo311JS4wxIVQu8gdULP+ylVbiRK/q3P7UN1AlZbC4PX24p8JnJCyYlzO6WeWtlFnOnDmiWNvujcvFiPCVHiJ8x47S2oyHmbJ/DkqyGA6njccgO/r7q78Rb4jk38dx6j7U6ZU3jymALsQ2iiFUPpnQM9ac8BTQuATAGXotT/cOPJI8fR89+2kbgotlMm4f/TtGTTxGxeBGC6ezyhmtK7vh4CzVO7UCqvswDh1thxoJNZDx8PgYpsIndRQcW+QhO2H63obgVogfUTRuUzBJRfaKo2VVzLJgC2F+xh47KISb17cj/Xd6LpJjgJO/fnDKAS15ZRbnd5VddT4twk8R70weflSvxoTctb2g49/fxsWbe3uKioMrPUyIIYDr5wmWdpuefi3ZSatWWJ7buXs7hBfeS+/JV5P9nKkX/ewxH/i4A7G6Fv3/5G6U1zmZusU6osOtQFQaNl0cwKWGqCmsPHDkVzQsYd3Exxc+/QNunnkQwnrmD/2JbMe/89o5mIFW+qpzfH/mdXTN3sffuvRxacAjZ+kcxvkN28OiaRxu8RiR2FNl38HR8IXYgeBSFSHPozWnqnEBMSoOHNDguMVggqk1Al1PdbmqWLydqfP2rsJHjxmHskELpggUBnTcU2V9Uzf7Cas2xR1XGV5T98jYxQ6+h/Z0f0m7WfKL6T8D++4Zjxzg9Msv2Fgd0Lbfixuq2+myXa2QMkQYEyfc9Y4gx4Kmp21dFmsxMGxPLZX3bBh1IAbSJNvP5rGEkRIRpvttORBQgymzgwxlD6Nrm7DR+Dr1ePKYdOOpfUu8cJzK5p5HXMlz0TtSINxUZWp1Z6lqhTJnVxc97inxm/iGwNCoB+HRjHrPHdW7mluuEApU2bT+NYFPCbEFKGjclqqpS+MQTtLrmasw9zmzBhP/t+5/m9kBrCgqthew8spNeCb28H/S4oHgXFGyG/M1QsJnEikIU+TWgblAZbCG2xSgRVY/Zqk6IMOAmr5qfy3egXEuD4xJVge6X+2yusLn49rdCDlfYsbpk4iKMdKoooEtaGsakpHqbJQgCSf/4BwevmUzMpEkNHh+KvLs6G7fiWwsbaOaB1SkzZ0UmF/b0/7dVVZVKu5usslJEQUI+oV5KipTw1HhQZdUnoPJUejBE+vYxTk/jJpA7to7ku3tH8cIP+1i8rQARwecdZTaKqCqM79GGBy/qRof4s1e5NPR68aGz4dsHQSO6P55Hx4TxwQ4/pmJRSZDU+xQ0Tudk+CQjV1MHKdDOzOFReHd1NreN6aRZB6Gj0xj8PVPBpoSdzkez+vvvcWUfpN3LL5++RgSAW3GzcO9CXEpdxdZgagpcsov31vwfL5rSoWCT10w1Ng3a9YeUQTB0FubE7vxp0R4+25yHfMI4KtBCbJNBZOqw1DNCVETnNNPxPG+2Sz3BFPgfl6iCiND1Ioj4oxblt/xK5q7M5KfdRYiCcEwOXRTArMqYO1/PjF8PcN3gDsRG+E/hM3XoQOz111H03HO0//e/vRsLd0Lxbq/3pjECWqVAh+EQYjXIqqqyeGuBTx8AwWUe/FZQyc+7C3F4FIqqnBRVOSisdFBY5aDo6I9REmkTbUKOk32EH8M7hyMYBKo2VxEzOObYdtkhU72jmjZX1V2xVFGJNDXejykhMoznruzDo5f2YNHWAr7Ykk+5zY2iqsRYjEzonczkgSn1Pl9nC6EXTPW8Ar59wGfzwXvrLj2mxIg4HtGQrTRFwIh7/Xo56DQ/CzNycWiopAXTmTk9CltzyxmYFncqmqgTwsRFmFA0/GOC9eaIPEUrGGVHbQRW/n6ESpsbk0GkbSszkwd1YFTnBJTKCgqffpqU119HPMPrJvaX70dWfFfwgqkpUFBYVbEXOg2D8x+Dtv0gzDc15ZaR6SzaWoCsMSvdUCE2eMdDU4emBX5zOmcvogjD7oJfnwbPH3XdgY5LnKqRgi430wnvAP8/vx7gv78ewOVRfNLPFBVsSNhkeG3Z77y1KouPZwylR1v/Mt3xt95K9qSJOD5/EnPRN1CRA4IEisf7ryCAKRyG3gn9p0J4aLxHHW7Fr9hVMJkHHlnlhR/30TEhkjbRZpJizPRIjiYxOoyko7/X+ihN/DKF3OrcOp+XwiUS/5TIoQ8PIZrFOivvxjgjrYbXVSyWFZn2kU1XJx4RZuCGoancMDS1yc7Z0gi9YMoUDgOmw6Z5dTqtgBEk6H1V07dL56Qpb6I0qiN63ZTOKaBXuxjCDBJWZ91BfjApYUZR4DI/heIny4Hial7+aT8/7yn2sRHYlgcr9pUQbjJwTeUurpswCUu/fk16/VNBpaNSc6WnoZoCe07dd4EdFXXEPfWuGnVpE0W/Dq3YnFPud0DljzCDyPndEk+qXkHnLGXoLNi7FA5tAVnbC1MTYzi5addx3VIPN1ceoMbh4b21BwPys3O4FRxuhavnrOWL2cPplqQdUInWAtLPy4LtW0Hyc15XDSx/BlY8C5M/gM4Nq2K2dFweBVEEWSMDO5jMg0izgWeu6BOQqt/0XtN5fuPzx6weamk9oTVShEThp4W4il2IFpHo/tGk3JaCaPxjxdAgGJjUaRLhQYiV6DRMaK3J1nLBv6BND5CCNKIzhsMNn+viE2cYHo2ZYajbmTWEqqqaHkA6Oo1FEgVuHpGG2ejb3UYPvoLY826hct2n5L9+A/lvTqN6yxIsXequpoqiwPThaU3WphX7S7jsP2v4fmchLo+iOfCyumRKapy8Qyp3Rw2j2uEn7bkFcHxNwYn4qykIhLlTBpIYFVgRdi1GSSA1PpyXrjnzg1OdZkQywg2fQVIfr5hEIBjDod/1dL3+Jb6+aySLtxUwd2XmsZS+QLG6ZK5/ewNWLUW6kv3w9jgEVwWiv0CqFo8d3Db4ZArsWRJUG1oikWYDHj8TKcGI0SiKSrQ5MFGfCekTUP1I6MWNiaPLU13o+XZPur/WnXbT2h0z7K1FEiWm9JgS0LV0Aic0gymDCaYu8qZvGBvutBRBxIaZisvfhw5DGjxep3mpXf4+kWA6M1EQiNbliXVOEZMHpaB4tAc4kT3HkXzTK3T4yxek3PkhiVf/C3P77sf2C4J3dSstoWkmcTZklXLbB5uwueSAbASckpGdh6uZ8s4GnH7u4UwhxhyjOdA4vqbgeGprCiJ61P3bWgyWgGqZYsKNLLpjJOkJEVg0guUTsRgleraN5rPbh2MxnbmGxzqnCXM0TP8OBs8EU6T3RwtThLd2+5LnYOJLIAi0a2VBVlS/3+n6VG0BHG6ZRVsL6n7IUQnvTQRnNYJGqrJfPHb48lYo/C3wz7RAJFHw2y8fn3lg278Oxe1AlT3YMzdR/uu8ugcLkBIXWAAdbgznxh43YpaCX9U2SSYGJQ2iY0zHoD+rUz+hGUyBt9OathTGPQJRSSjGCHzmXAwWMIQh9ryChf0+4K4N0SiBmhjpNBsDU2M1S9iC6cycskKvtjG+J9HRaSSesjIcf72X24s3YDYEX2sZYTLw/FV9mqQtVqeHWxZs0l6Jqmew5ZJV9hVV8/z3+5qkHaeKc2LPwSD6Tq4cX1NQvaMa1aPiKnGR90aeT02BKIiMbj864Gu2jgrjm7tG8o+J3WkfayH8hCBJFLxBVKfWETx+eU/+d9tw3VdKxz8GE1z4BDyQ6Q2UkvtBRIK3bi+yDXQ6HyZ/BPftgf43HvvY9rwKDlVo+6oFItFtc3lV5epMRmz9EFzVcFwglfZKNYkvVGN1/bHtnS0uxr53gniG2w6/PtW4v0ULYNaYTj7f+VoCyTwwSgLXDe5AmCHwyZU7zr2DEe1GBBVQmUQTKVEpvDTmpYA/oxM4oVczdTySEYbfCUNns/KHz4jevZD+rWzgcYIlFjqfD+dOAUssN8kKP7y9gTeWH+DO87qc7pbrHMfM0R1ZfeAINpfvrHkgylqCAGO6tqZ1VJBpnzo6DVCzZg2H//4wMZdfxn133YWw4iBvrcwKKA1HELyB1Pu3DKZT68YrLwEs2lqAorFyE4iFgMOtsDAjlwcuOgez8cxcVTGIBq7rdh3zd83HKdetgQy0psAkmpjWc1pQ1zUbJaYMTeOGIalsyinnlz1FvLUyi8v6tiW5lYWLeybRN6VVwyfS0anFaIa+13p/AmDe6mzNleNgzMFLrS625JYzIDXuqLnd696g6ARkFV7d4OLhUfW9M1U4sAxqiiEyMaB7aIlM6tuWx77e5Xd/Q2I0oiBw07C0oK4pCiIvjXmJx9c9znfZ3+GSXSi+ywHHsBgsdI/rzhsXvKHXSp0iQjaYKq5ysOtQFVUON2EGif8VdWTc8Ffo7+ehNkgir17Xj0mvr2FQWhxDOsZrHqfT/AxIjSU+woTNpS0o0lBnZjFKzBytL3vraOCywa4vYftCsB7xerlY4qDH5dDverBoD5BVl4viV1+laslS2j73LBHDvDOR943vSmp8OE8s2Y3bo2DVmAAQFRmTyUinxEheu+7cJgukVFVlzopMn0mHYAZbAN9sP8TVAxs2GT1dXHPONczbOU9zX9yYOOLG1K80lhyZTM+Enid1bUEQGJQWR/8OscxdmcXL1/RD1O0WdJqBnYcqNVP8gjUH31dY4w2msld4pc81eGC4iefXOJk9yEQrcz3PtyDApvkw9qFAb6PFYTFJzB7biTeWB1+rZjaKXNgjiZS44AMcSZR4YsQTXNn1St7b+R4r81diEA24FTeqqmKQDCiqQu+E3tzc62ZGtB2BFKAYl07whFQwpaoq67JKeWtFFuuySjEZRBRVRUDA6vSwPa8Cp0fh6oEpmmkYyTEWXri6D/d+uo0ld40kPlJfyTgTEASBf17ag7s/2RqQgtHxmAwivdrGMDC1YRUdnRDCVuaVKd72kXdAcKL/S+EO+OVxb1B13iN1TLxdOTkU3P9XDAkJpH/1JYa4uoP3K/q35/J+7Vi2t5g5KzLZnleBoqqoeAP78+0F3NSpFQOnX9Kkt7SzoIpSq69KWDCDLZtL5t3V2Wd0MNU6vDUz+8zknZ3v4PBopz35wyyZeWL4E41ug83lIdwo6YGUTrNxolpoLcGo2rplhRrnUaGZgs3g1v7+DGwrMTbNwItrnTx5Xj2pZh4HHFwJnL3BFMCd53XmQEkNP+4qCjigMhtEuidH8+LVfRt17b6t+/Lvcf+mzFHGqvxVVDgrkFWZGFMMg5MHkxJ15vbVZxMhE0xV2FzcOC+DA8U1x2ZmT1RvK7W6eOnH/bz0435eu+5cxvdo43Oececkclm/tvzlf9uZP20QHkXlx92FvLUyi6wSKw63jMkgkhxj5uaR6fypXzsidIf7U86FPZO4c1xnXvppP36EbnwwGUTax1p4d9pA3ThT5w/KsmH+BO9KlOJHotht8/772+ew/3u4cTG0PZfKxYspevY5Eu64g9gbrvf7XEmiwPgebY71MQ63jEEUMEgi1cuWUTpvHky/pklvq6DChqglGx6khUBhZXAByulgZp+ZFNmK+CbzGxxyYO01S2aeG/Uc/RIbr7Jnc8mE6/2+TjOipRYKwUl0G0QBS62gk60MVP+BwRPjwhgxz8o9QxrwnrNX1L//LEAQBP59TT+eXLqbjzNycXsU/LklCAKYDRIjOsfzn+v7YzI0jXRBnDmOyztf3iTn0gmekBCgKLe6mPjaavYcrtKsqzkeu1vG7pa5a+EWvtySr3nMXy88h2qHm2nzMxjw5E889MUOduRXUuP04FFUbC6ZzBIrTy3dw8Anf+axxTvPeBWslo6sqGzLq6R/h1aEGUTC6umgBCDcJNGvfQyL7xhBVICSpDohQHURvDseagr9B1LHo8rgqESdP5GiB2/lyFtv0+G9+cRNuSGoAN1slDBI3mc2cuRIXAcycRcUNPCp4PCq9/m+4YOxEADfSagzEUEQ+OfQfzK732zMkhlLPVLT4YZwEswJzBk/h/NSz2uS61udHiJ0tT6dZiQ1XltVLhhVW4PknWAEvLLr9dArUeLSrgaeXd1AP2kIDS81URR4dFJPvr5zJFcNTMFsFIkMM2AxSliMIpFhEmEGkfE92vDhjMG8fePAM7b2VCd4zvqpM1lRueGdDRRXO4IyVnS4FR7+6jdS4sIZlFY3TUcAosxGVuwvqfcctYHbp5vy2JpXwUczhugD91PE09/uwer0sPDWYVTYXSzckMv8tQePekCoqKq3s3N6FEZ1SeC20Z0YlBarr0jp1OXz6WAv99ZGHSXtlWpsbsi+J5IIk/d5eWeLiw93uFk+7egAxmUl3vI9rT/bhxjeuAJfwWQi6sILqfz2WxJuvbVR5zqeyDCD5srU8YOtiG4jGzxPS5H0FgSB6b2mM/mcySzNWsq8nfMotBViFI2oqopbcdO/TX9u7nUzw9sORxSabm7R5pL9Wjbo6JwKpg1PY9PBMp86zGDMwQ2iwKjOCd5fopO9AVXtKrwGj481039uDfcPq6fkoVVopZl1bRPFc1f24Z+X9mDNgSOUW10oKsRYjAzpGEeCXh5yVnLW9/a/7i0mp9SqGUhZdy+nauMi3KX5iCYLxsSOxAy/po561ZNLdrP4zj8GGKqq8sDnO8jILgu4DQ63wt7DVUybv5FPZg7FKIXEgmCTUbuq50869P11B/l1XzFfzRqBySCSGGXmngu6cse4zmzKKaek2olbVog2G+mTEkNiVGjMlOkESWmmt05A8TWubEi9ShDAYJKhcCN0HNPopkRfOpGip55u0mCqa5so3LLvqlIwgy3veZpGEKO5CDeGc/U5V3NV16uoclVR5arCKBqJCYupd8WqMVidHiLCWkbQqXN2MLpra8xGSVPUJhBV2zCDyI3DU4+tkNP9cvj+b/Ves3OcyOSeRl7LcNE7UWNcY4qEAdMbdV8tlcgwAxf1TDrdzdBpJs76YGrOykzNziUQKWCAfYXVHCiuoXOidwCxYn8JP+wq1CwyrC84c8kquw9V8uG6HKaPTD91N3wWoKoqGdllzF2ZxarfS5CPShRJosCozq25bUxHBqfHIQgCv+4t5vVlB/ji9uHEhNdd9TNIIkN11UWdQFn/JvhJdQtIvcplhTWvNkkwFT5wIHJlJY79+zF37dro8wGkJUTQPTmabXm+NQyBDLYAIkwSt43u1CTtaW4EQSAmLIaYsFPvJ6evTOk0N5IocNuYjvz7p981xyeBSHRPGZL6x4aIeOh6Mez5ps5K/Yk8OiaMD3a4tXeGRUN64J5tOjotlbO6t88ptfJbfqXP9mCkgD2Kyvw12Tz1594AzF2RpVl3FUhwZncrzF2VxbQRaXp6mR/WZZby18+2U25zYXfJdTzXFVnl133FrM8uJTbcxB3jOvHij/t5+8aBdIjXvRN0GoGqepX7FO1BQcDqVQdXewu3w+uX324IQRSJnjiBqiVLMf+laYIpT2kp1xZtZp+nPXaD7wpbQ4Mt8NZ2je7auknaczZjdekrUzrNz4yRHVmbWcr6zFIcQdQ2mo0ir113LonRJ/Rtw++B33+s4zV18N6oOoekxIg4Hon2PanRAiPu8S7b6+ic5ZzV+WZbcssxaEjTBiMF7FFU1hw4AkBemY0tueU+x9QGZ3HjZxF+znBEkxlBMhDeeYhPikyV3c26rNKTvKOzm6+3FTD9vQwKKuzYTgikalHxzvoWVNh5+KudTOqbzABd1lynsbhtINdfSP3EuDBez3BRYq1nkGIwQU1RkzQpZtIk9v24kmeW7ubGdzdwxRtruPHdDTzz7R5yS/3XMZyI4nBwZO5bZE28lJGmGuLjo5FOYoCbCwBlAAAgAElEQVRjMUrcfX4XJF3uu0FsTn1lSqf5EUWBOVMGMKJLApYAxA0EVcFsFHnhqr6a6sW0HwBD72hQjMIHKQzaDYJBM4L7nI5OC+Ws7u2r7F51vRMJVgq4xumtofhu52FNNaxgfVo+25TP8E4JAV07VFhz4AgPfrEjaJ+oTzfmcUH3Nozqos+W6zQClw1ESbNeqpbj1au6t/Y3DyXUW7AdKGsOHOHV5WVs6zMNZXU2nuO6nXVZpby39iB9U1pxz/ldGNFZuy9RFYWqJUsofuUVLD17kfbJQkxpaSwss3Hp66updrg1TT61sBhFJvVN5sZhqQ0frONdmWohQh06Zxdmo8TbUwfyycZc3lieSZnVhd0t17EMCTOIqED/qgLu7hHP8L5t/Z/wvEe8KcxbFgTUt6mSGaFdf7j+E5DO6iGmjs4xzuon3SiJmivMwfguABhE78DpcIW2ImCwwdnhFuDT0pzIispdC7UNdwMRCbl74VY2PTJenzHXOXnM0SD7yfs/jgbVq1QFGlGTo6oq/1l2gDeWZ3rrHkQDJy7Ruo8qVGZklzFjwSZmje3EXed1rpM6bN2QQfFzz4HBQLsXXiB8wIBj+1Liwvn6zhFMnrueSrsLez0TGALewdm1g1P458QeenpygOg+UzqnE1EUuH5IKtcN7sCmnHI+Wp9LXrkNh1sm2mJkUFosU4akEpVzgLzZs5CvvAgp0o+wjCDAJc9C23Nh2f+Bvcw7+XRix2SKRHF7qK7qSszD3+iBlE5IcVY/7a2jwo4GQnUHC8FKASdEeU3p/PmrBBucuXTPqTos21uMU6NgNlCREJdHYdneYu00BR2dQDCEQVQSVB2q97AG1atUBWLan3Qz3lyR+UcgFQB2t8ybyzORRIE7xnXGmZVF8Qsv4ty/n8T7/0LUJZdoBkCp8RH8fP8Yvticz9yVmVTY3Djc8rGVKrNRRFXx2giM6eRjD6FTPzZ9ZUrnDEAQBAalxfn//vbuReSIkZTOfYvE+/9S/8n6ToY+10DOWlj3OhT+5g2qDGav/PnQWZA+npLLr0Bav4HIESOa/oZ0dM5QzupgalSXBM20vGCkgMNNEtcP7gBAfKS203ewwVmr8AYcw0OMOSt8FReDEQmxumTmrMjUgymdxjHsTlj2ZIOpLH7Vq0Qj9LsejCcnvb/pYBmv/3IgaKVQu1vmP7/8TqflX9Pxx8+JnzGDdq++gmiqv5+JDDNw0/A0bhyWyobsMjbnlFNudWE2SrSJDuPiXsm0jtI9UU4Gq1MmPkL/2+mc+bS+7z6yL7uMVpOvwdS+gYkgQYC0Ed4fDUSgzUMPUvTMM0R89RWCUffV1AkNzupgymyUuHpgCh9vyPFJzwtUClhV4U/ntgNgSHo880zZjTLFsxglxupqWMewOj1s15BqDqYODWB7XgU1Tg+RemqNzsnS7wb45XGfzQGrV4kSDJl10pf/768HcJzkCq3dLbNATeaDb5diiA1OkEUQBIZ2jNdtBJoQm67mp9NCMLZJJG7aTRS/8CLtX32l0eeLPP98yj76iPJPPiVu6pQmaKGOzpnPWT/yvHlEGp9k5OKT30vDUsAmSeCK/u2OqTIN7xRPpNlw0qZ4ACoqVww4+TSgs41ymwujJOI5wd8n2Do0oyRSYXPpwZTOyWNpBf2nwdb360gBB4QhDNLHQELnk7p0UZWDtZmlPr1UwCu0gsAWoRWlkgV9ffb0Y9V9pnRaEHHTp5M5YQK2jRsJHzTIu7GmBGyl3tRlS6w3DTqAmklBEGjz97+TO206kRMnsKbEw7urssksqcHulrEYJTq2jmDGyI6M7tpar3XWOSs463v71PgI/jK+K6/8rG1k5w9JFEiKsfDQJd2ObRNFgVtHdeTFH/dpiiU0FJxJAlzWt60+4D8ORdHun4OtQxMEjpn76uicNBc9DSV7IS8DPAEGVJIJYtPh6vknfdnPN+dpbg9mhVYF/rcpj7vO63LS7dBpGmxOfWVKp+Ugms0k3n8/Rc8+SdojkxHWvgplWd6+DbwqpxGJXt+ovpMhLKre84V16cJ3509lwfMrcZrMdSagK3BzuNLBttwKzEaJO8Z1YvqIdF3cRqdFc1b7TNUyc3RHbh6ZhsUY2O2aJJHkGDP/u20Y0ea6Ob+TB6UQYzFyMpMpFpNBH+icQIzFiFv2DUyPr0MLBLesEGPR87N1GolkgBs+hy7jwRih6XVWB1MEtO0Ht/zk/e+T5PeiGk2Bm2BWaF0ehQPFNSfdBp2mQ1+Z0mlpRKc6Se2xCr65zzuhJLvAVeP98TigMhd+ehRe6ALr5/g9j0dWuGvhVuaQShlGzUwe8H5HSq0uXvhhP3d8vEVzHKCj01IIiWBKEAQeuKgbz17Rh6RoM+F+VJYsRpEwg8glvZNYevcokmJ8C8mjzEb+d9swoszGoIwvzUaR96YPIiUuSPO7s5xoi4F2rSw+24+vQ7PtX4fidqDKHuyZmyj/dZ7P8W1bWfRgSqdpMJjgmvfhuo+RY89FUUQwWvB2l4JXvcpghvaD4M9vwbTvvNLqjaDWy+5Ejl+hDYRqh3+fLJ3mw6vmpwdTOi2E9XMQFs1GlGQEpR7rFrfVu2L/y+Pww8M+u1VV5aEvdvDLnqJ6LReOx+6WWba3mAc/34GqIRimo9MSCKne/vJz23FZv7asyyxl7spMtudXYnN6cMkqCZEmLu3TlunD00hNqH+GOTU+gqV3j+T6tzdQanVidfof6ESYJBRVJTU+nD7tWzX1LbV4BEHg9jGdeGLJbmwnzGAFWocWbpKYNaaTniag03QIAnQcS1HBYMK7X0lsr3Cvv4qqgrkVdBoHCU23yuxvIiBopVB9QuG0oaoquw5VUVLtpKTKyb6iKjrEhRMTrv8/0TmD2f01/PyvwNOawat4umkexByVRD/KT7uL+G5noWYgVZ8iqcOt8MOuQr7f2YZLeic3wU3p6DQvIRVMgXfwPjg9jhqnh2pHJrsOVSGgYnfLfLEln48zchnTtTW3je7IgNRYvwP09rHh/PrXsazYX8ycFVlsz6vAZBBRVRVBEHDLCqnxEcwa04mLerbhzo+38ux3e3l0Uo9mvuMzn8v6teXxb3Zr7muoDg2849vL+tXj4K6jcxIodjs1vy6nzYPfQkLCKb1Wr7YxfPvbYZ9BSDBKoWajSM+2jVsh0wmeSrubLzbn89bKLKodbkRRoMbh4bGvd/GPr3Zycc8kbh3dkV7tTt7MWUfnlKDI8M09PoFU2ivV2NyQfU8kESbvGOidLS4+3OFm+bSjk81uuzcI63fDsZX5N5Zn+kyKQmCKpDaX1zNPD6Z0WiIhF0ztK6xm6rsbsDo9dXJ5j19d+nlPEWsOHCE9IYIFNw8mIVLbL0QSBc7r1obzurUhv9xGZomVGoeHcJNESpyFzol/FGm+dE1fLn19NQPTYpmgdxZ1qHZ4aB1lIr/cTrAaEhajxIMXn6PXJ+g0OTUrVmLp3QvDKQ6kAK7o355nv9+ruS8YG4erBqSc8rbq/MHSHYe4/7PtCAg+Ake175Rvdhzix91FDE6PY86UAVh0M1+dM4X9P3hrozSQVXh1g4uHR9XjlyaIsH0hDLmNzJIa9hyu8jkkGM/I/UXVHCiurjN20tFpCYTUCHRbXgXXv70eu0uut7BcVb2zJPsKq5nw6iq+uWskbaLrN+JsHxtO+1j/9VCtwk28cUN/ps3fSLekKDq2jqyz3y0rVNndmAwikWGGlpey5rLCb5/D5vegpsir/hMW7U2HGnI7xHfS/NjaA0e499Nt3DgslTKri4UZeQGrLlqMEtcNTmH6iPQmvBEdHS9VS5cSPWFCs1wrJtzIxT2T+GbHIc0JhYZWaEUBLuzZRk8pa0Y+3pDDE0t2ayq7Ho+ieutC1meV8uc31vDl7OH65I/OmcGaV7wCExo8MNzE82uczB5kopXZz3jEbYO1r8HgmXy5pUBTUTcYRVK3ovD55nz+dkn3oG5DR+d0EzI9en65janvbtBcgvaHR1Eps7q49q31fHv3qEbPKPZp34r7xndl9kdb+Gq210H8m+2HeHNFJgdLrZgkEeVoAebE3snMGNUCUkMclfDTv2DHQkD0FqjWUlME5Qdhy/uQ3BcufApSvB4WiqLy5opMFqw9yL8n92NEZ+/sf1KMmZd+3I8g4HeQYjaKqCrcN74LM0drB2k6Oo1Brq7Gum4dyU/+X7Ndc9a4Tvywu7DBwbkWJoPI7LEn53GlEzyrfz8SUCB1PE6PQvYRKzPf38wHtwxueRNmOmcXsgfyN/rdPbCtxNg0Ay+udfLkefVMJluPQGU+eWU2PBrBVDCKpLICeWVBevzp6JwBhEww9fqyA9hc2kpX9RVGehSVwioHi7cVcO3gDo1ux5QhHdiYXco1c9ZyoMSKIHAswDteGvmb7Yf5YVcRHeLDefOG/j4rWWcEVYdg/iXef/2kCqC4vT95G2DBJPjTG1R0vJT7Pt1GtcPD13eOrKOaOHN0J64ekMKnm/J4d1U2VpfnmKmfrKhEmAzcMiqdyQNTiI0wNcdd6pzlHKlxsubAESrtbgBiw0303p9B+ODBSDHNN5lxTusI/u7Zy9NKR5xi4F2z2Sjy3BV96J6s10s1F499vdNvIFXf+8TpUdiSW87WvAr6d4ht5lbr6ByHoxJEI8hOv4c8MS6MEfOs3DPE/7tWdrgpmHo1R9pcAwnn+OwP1jMyGD9QHZ0zhZAIpmqcHhZvK0DLxiCQwki7S+bNFZlMHpTSJLOJoiCw81BVvamGsuoVxdhfVM1l/1nDRzOG0DflDFIDtJfDuxdBVQGoAXZ+HjvKV7P4PzGTTn0n8tAl3TBKvur8sREmbh/TiZmjOpJ1pIZyW+0g10jHhEhE3TFdp5GoqsrmnHLmrsxi5f4SDJKAR1YRAIMk4nKIjOs4gdl5Fc3yvVPdbg499DfGlpURO/tyHvxmH25Z0eyzapFEMEoiL1zVh0l9253yNup42ZFfwaEKbfnoQN4nDrfM2yuzeHPKgOZsto5OXQQBGnDS65UocWlXA8+udtG9tbaTjmgxk/zkU7TbGQW/lfjsD1aRNE6fJNVpgYREMPXVlnxEjSAomMLIkmonW3IrGJDauNnEZ7/byw+7iho2Az2KqnqDwSnvbmDJXSNJjT95Y9AmZcn9UFNYJ5AKRAFIlB08x0sYxt/uHQ3WgygKeiGqTpPj9Mjcs3AbK38vwe6WUVWoY/PkUUCQ+KkUVr61nkt6J/H8lX0wNPC8niyK00nBvfeBqpIydw6pYWH0TE/knVXZfLXV23cdn54cftRu4c/ntmfGqHQ6nYmr1mcxb6/KwunxnUAK9H2iqLBsbzGlNU7i/Ygb6eiccswx3trmBnh8rJn+c2u4f5j2syqoHoxp3Rgmq3y7r8zHpDcYRdJwk8Tg9LjG3ZeOzmkgJIKpn/YUa9ZKBVMY6ThaQNyYYCqrpIb31h6sk85XS32pIQBWp4dHFu3kg1uG+Hy22bGVwd4lmql9gSgAGUQBtn8Cg289la3U0fHBLStMfSeDHQUVAQsHfPdbIeVWF+/cNOhYymlToVit5N15J1KrVrR7/nkEo1dAolPrSJ65ojePTOzO0h2HOVBSQ6XNTUy4kc6tI5nYJ5mIsJDovs841mWWaoqEBPM+MRpEduRXMq5b4ilooY5OAIgSpI2C7BX1HtY5TmRyTyOvZbjonagxoRTdHqLbMqG3wiOLdmqeI1BFUoBJfXSbE52WR0i8jStsfqQ/gyiMVFQorfGfWxwI89cc1FS7CSQ1RFEhI7uMw5V2kmMsjWpHo9nyvlcSVYOAFYDWvAqDZhxNNdDRaR7+8dVOfgsgkDqeWiW2p7/dwz8vrd8nTlVUPMU2lKOpqWK4EUNiOIJGECZXVZF32+2YOqaT/MQTCJJvPxQRZuCaQbrc+ZmEPxGjYN4nqqIeq9HT0TltjLgHCjb7VfSr5dExYXywQ+N5NUbAyHtBEDAbJSYPSuHD9Tm4Zd9xTkOKpEZR4KoB7XXrAJ0WSUgEUwY/s8nBFkbWpvlkldQwf81BftxdSI3Du0weZTEysXcy04ankRLnK5Fud8l8vjnfR+0mmFRDFfhgXQ4PXtytwbaeUjbN8+uWHrACkL0MDm+Htv1OUSN1dOpSVOVg0bYCXCexMmx3K3y4Poe7z+9CjMVXfly2urFuLKRmdQGqS/ZqlQMoKoJRJHJEOyIGJyFFeusBPGVl5M6YQfiAgbT5+98QxFOTQqjT9Eh+JoCCeZ8IApr1ojo6zUrHcWCK8AmmDt5bN70+JUbE8YiWwI0Kva469tutozry+aZ83HLD6YMnEmYQmTm6Y9Cf09E5EwiJ3tyfR9TxhZENYTKIKIrKn/+7hkteXcXCjFyKqpxYXTJWl0xhpYP31x3kgpdXMHnuOg4UV9f5fMbBMs0UoWBSQ1wehcXbDjV43CnH6ltkejxPjAvj9QwXJdZ6Zv8FCaoLm7hhOjr++XB9DlrD4KqMryj75W1ihl5D+zs/pN2s+UT1n4D99w11jhMFgc825fl8vnplPoef2UDVL7koNW5Ul4LqkL0/LgXF6qHq11wOP5tB1a+5uAoLyZl6I5GjR9Pm4b/rgVQLo5UfL69g3icgEB+pF9rrnGZEEa54Cwwnke1isMDEl8H0x+Rx21YW3rt5MBZjcKtLZmSezl5CEo3L/tHROV2ExFv8yv7tidBYOj6+MNK2fx2K24Eqe7BnbqL813l1jlUVlQ835LA1rwKnR9H0U3DLKk6PQkZ2GZf/Zw0Z2WXH9pVbXagashPBpIYAVDtOPjWkoMLO55vzeXd1NvPXZLN4W4HfFMh6aaBo9XgFIL+oqt/VLR2dpsYjK7y/LsenXrF2ZThu/CzCzxmOaDIjSAbCOw+pUxgN3nS/t1dloap/fI8rlmRR9VMOeFSoL3XQrYJHpeqXXA49/DHRl11O4r336l5DLZCrB6YQZvB9dQbzPpFEgYGNFDPS0WkSOo6Fy14LLqAyWmDs36DfdT67BqTG8snMocRYjIQ3kLIXbpKINhtYOHsUQ/umkXvjjbiLi+v9jKKo2F1ynX5YR+d0ExJpfuO6JWIyiD4qMxB4YaSsqrjdgX15VcDqkpk2P4MvZw+nW1K0t1ZK4+PBphoG238oisqqA0eYuyKTzTnlSKKAW1YQBAHjUTnoC3u0YeboTvRuH6CnjtHi31fqKA0pACEIXjUhHZ1moKTGqanAFszKMEBpjYsap4cos5HqtQVYNxxGDcZk16NiaDMAc3fdbLqlcv2QDvz31wOa+wJ5n4QZRG4clnrK1CF1dIKmzzUQ0Rq+nOmtafZTQ6VKFhSnC+GyfyP29w2kaumb0oo1fzuPRVvzmbsii1Krd7ygKOoxa5O4cBO3jenIn/u3JzLMAH/9K2JUNDlTptJh3ruY2rc/dr7iagcfrc/lw/U5lNlciAioqLSPDefWUel/nENH5zQREk+fJArcPDKd/y47gEOjXqKhwkhAU72poToLm0vmtg82s/yvY4mNMGrOQgfrwRCMgpfV6eGWBRvZkV+pUTStUuthvPS3w/y8p5grB7Tj8ct6NahYpib1RTi4st5jGlQA8jghsWfA96Kj0xiqHR4MogjU/f4HuzJslESqHB4iRJGq7w9qBlIZ+Tt4+tc32X/kIKIo0iU+lcfOv4t+yd29B6giVT/lEDEkCVEfALQ4EiLDGHNOa37ZU6TpAxbI+2TK0NRT1DodnZOk0zi4fx9kLoM1r0DOWpCOprTKbkjqgzDyXgpeW0rEdifx/es/XWSYgSlD07hhSCrb8io4WGqlxuEhIsxAWkIE56a08hkTJdw2EzEywhtQvfsOrnYdeODzHSzbW4wAxzIL5KMz07llNp75bi9PfbuHG4ak8vdLuumTFDqnhZB5k88Y2ZEl2w+TVVKDWysy8oNRElAUlRPFaQJR4AOvP9WPu4rILKnB6vRNjwvGgwEg2mJgbeYRhqTH1xv02F0yV765luwjVk0p9uOplYD+YnMB5VY3/7n+XM3Ar6DCzmeb8sgtGMOTZBCOtnFlLX4VgBC8HXdUm3o/r6PTVIQZRBSNZd1gV4YVVcVsELHv0K4brHZamf7533jqwr8wqds4XLKHjPzthEkn1McIYNtaTORQXQa4JfLkn3qx+WA5ZVZXwJ6BABajxD8mdPNbx6ujc1oRRehygffH4wJHJagKWFqBwZtl0ubB3uRMmUrMny7HENtwqqogCJzbIZZzOwSW1hp3ww1IkZFsu3U2D5z3FwodiqZoUC21E8Ufb8hhz+Eq5k8fRJhBVwTUaV6E+vJOBw4cqG7atKkZm3NqKa1xcvXcdeSX2+v9ctZiMUqYDKKPhK3itJL/35uIn3BvQKtJRlHg8nPbkV9uIyO7THOVq2bXr1RvWoy7NK9Oaoi5ffdjx5gMIlOHprIhu5TCSgcX9kxiYu9khqTH+czGzFiwkVW/H2kwkNK659vHdOSeC7oCXoPTn3cX88nGXH4rqGRSn7b/z959h0dVpQ8c/547NZOeQAgl9A4qIkiRKva+9rpi19VVd9ctP1fX3XWLq26zgohY17a2FbErTToqVXoLJCG9z2Ta+f0xk5BkJpmZkARI3s/z5DGZueXcwfveeU/lijE9Gfn6yVDVfN/mJlnj4eq3oe8pLdtfHBal1Fqt9ZgjXY7DFUt8qqzxcuIfPwuZsjdwL/+Y9HN+Ft297PfyackCbN0uRXtDvxCvy93C1W/+nE33Loh4LFOqjcxfjZVxU8eonQWVXDZrOWVOd9gWqsbsFoOfTBvI3TMGtX3hjmEdIT51tO9OjeX98WFQiqRf/4b/fZ/DB9/nUBzsypcWb+XCUT24YFQPHNaW1ddXu72c/7fP2FPhxRdlrwEAu9lg6pCuzLr2JImrotU1F5s6TcsUQHqCjQ/vmsRv3l3PZ5sOoiBstz+H1YTWcNW4LF5Zvjfk/VjHWWjgrxcfx9a8Ci6dtSzsGjfRdA0Z0T2pbp2bvUVVLNiQx98+2cKBEidnjOjGOcd1Z3z/dPYWVbO0iUQq8hTQPmYv3sW0IRl88H0OH3x/gMHdErlibBZzfjwGe+0sPac9BAt+GehfHQuTFboOhT4TI28rRCtJsJmZOKALi7Y1bFGKpWXYUHDmsAy6dL+M8oXhZ3Trn5aFoQx+9tGfuWDoDE7sOYIUe2LYbX3lbvxVnrrp0sWxZUDXBD65ZzK/+u96lu0qatANqT6H1USi3cyD5w7nvBOkJVJ0ADfeysP/N5sv//gZhmGEDCNYt7+UP3y4mYtH9+Te0wbTNbGJsdNNmLN4N/trVNhEqrnvMC6vnyXbC/lqSz4zhknPF9F+OlUyBYExR09eNZqiyhreWJ3Ny8v2UFjpxq81JkPRO93B7VMGcP4JPcgrd/Hmqmw8voaBItZxFiZDUeb0MLJnMmcMz+SzzXkxLRoKgRajhy8aWfd3n/R47pg2gDumDSC7uJqPN+by+GfbyC6uJjXOgjtMVWm0XRNdHh/XzV3Jjyf05d2fTKRPenxogU68FvJ/CKw5FW1CZbJCYiZc+44s1iva3W1T+rN6T3HIgz/aSWhsZhO3nT4cR7yNim/Whh0vlWiL591rnuKZlf/hV588RkFVMdMHjOPRs35F1/i0Btsqk4Hf6ZVk6hiWkWTnxRtPJq/Mxasr9vLWmmxKqt14fJpEm5lRvVO4feoAJg5Il5py0SHsyK/kinnrKes5Gq9XA6ET+9TG2DdXZ/PJxjzevG08AzPCVyo15vNr5i3bHbZiIprvMNVuH7MW7ZRkSrSrTpdM1UpPsHHn9IHcOX0gWmu8fh2yiGLtrHeNxTrOwlCqrlvh3y8/getfqOG7faU4PaFBKBy7xeCZa0Yzsmf42e+y0hzcOmUAt04ZwI78Cs7+95KQroSxLA7s14HP574zhzRfsDP+BHGpsPixwHTpzU2Zbo2HtAHw4w8C+wjRziYMSCfVYcXpdoaMc4nUMmwo6J3m4LheyXjLml8LZVCXvvzz3PsB2FG0l7vn/4nff/kkT1/wUKMtNSrCZC/i2JCZbOe+M4dw35lD2LC/jN+8u56P7p58pIslRKs6UOrkkmeXUe70oMOu2teQ168prnJzybPLWXDPZHqmRJ5+/csfDuIJUxkcy3eY9fvL2FtUFb4iWIg2INOeQHCa8NCPIsluwesPvaljW5wxkJQlxQW6BVlMBi/feDIXjuqBzWxgDbNeSa14m4lUh4VXbxrH9KEZUZ3L69dhB1/G2jVxT1FV5HUclIIp98HtS2H09WBxgDUxsF6FyQaWeDDbIWs8XDIXbl0IjrTmjylEG1FKMXfmGOIirH0STrzVzOzrTgLA5DCjG89I04SB6X24fORZbC3YFfKe9mqM+PDdBcWxyzDCz/4qxLHulpfWUOnyhlRGVW1eSO5L97LvH5ey/6nrOPjWQ7j2bwICwxwqXV5ueSm6MWTvfXeAqprDW8bCrzWfbsqL6nxCtIZO2zIVjYxEGw6rGZen4ZpKsc7Al5lkb7BosNlk8Mglx3PPaYN4ZfleXl2xF69fY1IKDbi9fkb0SOKOaQM4dWhGTFN9Vri8YXvQxdw1USmq3L7o1m7oMgjO+wec8XBgWtXK/MA6VPYUyDoZ0mVNHXF0GJqZxCs3ncz1L6ymyu2NuG6boSDBbuY/N4+nb5dALaeymLD2TcK9qyxk+x1Fe/ly53IuGHoq3ZMyyCk/yAc/fMnoHqHLAFh6JmDYJQR3NIZSsqCo6HA2Hihjd2EVvkb/b0fT9c6nNbsLq9h4oKzJHja1DpaHb/mP5TuMx6cpqGi+B4EQrUme5M0wDMWNk/ry1Jeh61NFO87CYTVx+7T+YbsLdk+O41dnDeVnpw8mp9RJudOL1WyQnmClS0JsAzZr2cxGqywO7NM6cKxYWONh2Pmx7SNEOzupTxof/nQSD8/fzNIdhWEnDrCbDTQwfUhXHjhvOL1SHQ3eT5zai7tlTQ0AACAASURBVOL9FWh3w/3irQ6+z/mBOavforymkiRbAqcNmMBvp/+kwXbKZiJpWi9Ex2MoFXYafiGOZXMW7wqZBTmWrndur585i3fx76tObPY8jZO1WrF+h/FG2XtAiNYgyVQEV43tzRNfhl/tPrrFfjUXndj8lyaLyWi1vr2ZSfawk0/EvDiw1Ry266MQHUG/LvG8MHMs+RUuXluxj4825FLu9KAUJMdZuHBUT64cm0V6E5Ua9kGpKKspJJnqntiVZy/6Q8TzK5PCPjS9Va5FHF0MFRhEL0RHUeP18fHGvJBEJ5audz6t+XhjHo96fc2uA5XmCN/1OZbvMIaCdJnYR7QjSaYiSE+wce243ry+KjvqCSNqxVlM3Dalf3Rd5VpJRpKd4d2T+C67tMHrsXRNtJgUl54kteai48tItPOz0wfzs9MHx7SfMhTpVw+lcN6msLP6NctikHbVUJRJJp/oiAxDRew+KsSxpKTKg2EQMnFfrMMHDCNwrMzkprc/bXg3Vu4OnXU1lu8wNrOJiQO7RH19QhwuSaai8MC5w9lbXM2yHYU4o/ziFGcxccaIbtxzWvsv0Hj7tAH8/M3vqWrhFNCGUtxwSt92LLEQxx5b/xRSLx9MyVvbok6olMUg5eJB2AfJjJYdlXTzEx2N0+PDaIWZjU1KRayUvmhUTx6evznse9F+h+mWZOPErJSI5RGitUgyFQXDUDx33RgeeH8j7323H4/P3+SK92ZDYTIUV4/rzW/PGXZE1haZMTQDu8UUkkxB5K6JZkNxQlaKTCkqRBQcx3XFlGCl+O1t+CvdgaSq8fdoFUiijHgLqZcOxj5AHvIdmaFkNj/RsSTazWG7rsY6fMDr1yRGmHQn3mbmRyf24u012XjDnDPSd5g4i4k7pg2Qdd1Eu5JkKkomQ/HXi4/jxlP6Mnfpbt7//gBmw0AHvzkpFD6/5rIxvbjhlH7063LkkhGzyWDeDWO5YvaKmLomKgVJcRaejDBAVAhxiK1fMpm/HIN7XwUVi/fj2lpCXW2LSWEflErilF5Y+ybJA74TMJSSMVOiQ0l1WLGZjZCJemKd2dhmNkh1RB7LdPeMgSzYkEuZ0xNTOQ0VqBD+37ocPvg+hy4JNmYMy+Dskd2bXYZGiMMlyVSMBnVL5JFLjueB84azZk9x3c2e6rAytm9ai9awaQvH90ph7swx3PzSGpxuX7gJ/howG4pUh5U3bxtPtyR7u5RRiI5CKYWtTxK264YDBNeh0iiZxKXTCYyZkmRKdBwmQ3HdhD48v2R3SEIVbdc7m9ngugl9MEWxUHn35Dheu3kcVzy3nOqayN9famkNFTVevtlRVPfal1sO8tv3NnLNuN7cOqV/k5MKCXE4JJlqoQSbmWlDoltI90iZOKALH9x5Cn9Z8APf7AwEl9qpTQ1bDubETVislWj8DErvwc8nXXhEW9SE6CgCk0tIK1RnUu7y8OG6HNZnl1Jc5ebh+ZsZ0i2Rc4/vTnw7TkIkRFu4bnxfnl+yO+x70cxsXHuMxvx+zdIdhSzfWURRVQ0Wk0GPFDvnHteDD+6cxHVzV1Lu9IQdtlBHawiu09lY7QLAL3yzm/+u3c+bt41nYEZixLIKEQuJ8B3coG6JzLvhZA6Wu3h5+U4+2LGACttn+E1FoLygAuFnt1fxy6Wf0SWuCzeMuIELBl6AzSQ1OEII0ZyteRU8t3gnH23IRSmFM/ilb+7S3TisJh763yYuOrEnt0zuR/+uCUe4tEK0TGaynbNGZvLpxryQdTcjsfo8TE9VZCYf6vVS7vLwxqp9zFmym+oab4NkyWwonvhyByN7JvOHC0ZgNimeW7yL7/aV1nXX035dt4+Oovu0x6cprnJz8TPL+OjuyWSlOSLuI0S0VHPdEcaMGaPXrFnTjsURbaXSXckdX9zB1pKtOL3OZreNM8fRK6EXz5/5PGn2tHYqoWgvSqm1WusxR7och0vikzjS3l6TzYMfbMTj082OkzIZYDWZ+Pvlx3POcT3asYTHno4QnzpqbHJ5fFw2aznbDlaEdPdris1sMDDFyuOLniD5hOPIfPBBsiu9XDF7BaVON64IM6E6rCamD8ngn1eMotTpJqfURbXby6yFO1mxuzhkIWGAqs0LKV/9Pp6i/RjWOCwZ/UmeeDn2XiMwFPROc/D1fdNkDKuISXOxSTr0dwIur4uZn8xkc9HmiIkUgNPrZHfZbq7+6GrK3eXtUEIhhDi2vLl6Hw9+sBGXxx9xwgmfPzC99M/fWsdH63PaqYRCtC67xcSbt41nTJ9UHFGMD3dYTYzpk8rbd09lyBv/wV9Zxarrb+OCJ5aQX+GKmEgBVLt9fLnlILe+vIb0eBujslIY3C2RlU0kUuWr3qP4yzkkj7+cXne9Ss875pE4+hyc21cCgZk28ytqWL6zKGRfIVpKuvl1Ar/75nfsKd+D2+9u8HrJkhIKPy3Ene/GZDeRdFIS3S7thinehFd7Kagu4Gdf/4y5Z849QiUXQoijz/r9pTz0v01hvww2Vyvu8vj5xdvrGJKZKOM2xDHJYTXz8k3j+HpLPrMW7WTDgTKAupYqW7Ab3nE9k7l96gCmD80ITDphNdP973/nqt9/SIXTjb/RulSR7puVu4v51xfb+MUZQ3h91b6wZfPXVFG69DXSz7kXx5CJh8o8cByOgePq/na6fcxevEsW9hWtRpKpDq6guoAv930ZkkgVflxIwccF9Lq5FwnDE/CUeMh5JYc9j++h32/7YZgN3H436wrWsat0F/1T+h+hKxBCiKPLk1/tCNvNqXzVe5St/C/pZ9yJvd9olMmMc/danNtXYu81AgCP18/sRbt47LIT2rvYQrQKk6E4bXg3ThvejT2FVSzcmk9pdWBm4xSHhWlDMugbZjKrxTsKKVY2/EbDySSiuW+cHh8vLN3NndMH8tKyPWHvv5oDW9BeN47BE5otvwaW7yqiqLJGZvcTrUKSqQ7ura1vhbzmc/rIfz+fnjf1JPH4QO2otauVrJ9kse2X2yhbVkbqlFQAvH4vr/zwCg9NeKhdyy2EEEejwsoaFm8roPFw42hrxX0aPlyfw+/OH06i3dJexRaiTfTtEs/MLv2i2nb2op0hs/JFe9/U+mh9LkVV7pDXAXzOcgxHEsqI3AXRajLILXNJMiVahYyZ6sC01ry+5fWQVqnq7dX4PX6STkpq8LrJbiLx+EQqN1XWvebTPubvnE+Nr6ZdyiyEEEezt1Znh3092lpxCCzy/r/vZeyU6Dxyy5x8t6805PVY7psqt4/Zi3diNDFxhCkuCX91OdrfzDTqQUpBZY03csGFiIIkUx1YlaeKKk9VyOu+Sh/mBHNwLZyGzMlmvJUNA4yhDPKr89usnEIIcazYcKAsbBejWGrFnR4fm3Nlch/ReewqqKqb1ry+WO4bgL05xfh94ZMlW8+hKLOF6m3LIx5H68B6oUK0BkmmOrAqTxVmIzRYmBJMeCu9aF/oDFTeMi/mhIb7GMoIm5QJIURnU+EKX5sdS604UDfGRIjOoLLGS7hVdWO9bzxmK5kp4deIMmzxpEy6huLPZ1G9bTl+jwvt8+LcuYaSr19osK3b5ycrVdaaEq1DkqkOLM4Sh0+HBijHQAfKrChf27Bm1OfyUbG+gvjhDQeO+rUfh1mCjhBCxNvC16DHUisOkBQn46VE5+GwmiBM77xY7xur2eCmSf2Is4S/D5NOvpjUU2+ibPmb7H/yGvY/O5OKb+cTN+hQN0JDwalDMkh2yD0oWoe0cXZgCZYELIYFj79hDajJYSLjogxyXs3BsBsNZvOzpFlImZjSYHuv30uXOJlCVAghBndL5Kst+XgatezXrxVXhgl7vxNRhhnXnu9x7VtP6vQb67a1mw0Gdg2d7UyIjqp3mgOPL7R7bCz3DUBmkp3LTsrisU+3NnmuhBHTSRgxvcn37RYTt0yRGYpF65FkqgMzlMElgy7hja1vhCRUXc/piineRN6bebjz3RhxBkmjk8i6LQvDYjQ4xqm9T8VhkZYpIYS4YmwWzy3eRbg+S0knX4wRn0rZ8jcpnP84yhqHrdtAkiZc0WA7DfxodK/2KbAQR4E+6fEMykisW5eqvmjvG4fVxE2T+pHssHDJ6F68++1+XGHGLzbHbCj6psczundK5I2FiJIkUx3cVcOu4q1todOjA6RNTSNtalqz+9tMNmaOmNkGJRNCiGNPr1QHJ/VJZdnOorDvR6oVNxScOjSDtHhrWxVRiKPS7VMH8Kv/rguZHh0i3zcAfq3rKiF+f8EIfsgtZ3NuedgJYcIxGYoUh4UXbxyLamJGQCFaQsZMdXBZiVmclHESFiP2vsEmZaJ3Ym9GdBnRBiUTQohj012nDmxyzEYkVrPB7VMHtHKJhDj6nTGiG3HWlt03drPBpaN71c3AZzUbvHbLOMb0SQ2Mx4q0v8WgZ0oc/7trEhmJ9haVQYimSDLVCTw69VHS49IxqeiDmIFBkjWJp2c83YYlE0KIY8/EAV24bWr/mBOqOIuJX581lBOypIuR6HwsJoNXbx4XVfJTn9WkGNQtkQfPH97gdYfVzMs3jeOvFx/HsMxE7BYDo1GDU7zVRGaSnV+dOZSP75lMj5S4w70MIUJIN79OINmWzGvnvMaNn9xIXnVexAV4rYaVFHsK886cR7f4bu1USiGEOHbcM2MQaJi9eBdOT+RpneMsBr84YzA3nNKvHUonxNFpaGYSr98ynuvmrqTa7cPrDzNfej1xFoPhPZJ58Yax2MyhSZjJUFw4qicXjurJlrxyFqzP5WB5DR6/ny4JNqYO7srEAenSrU+0KUmmOokMRwZvnv8mr2x+hdd+eA23z021t7rBNg6zA0MZXDn0Sq4ffj0pdqk9FUKIcJRS3Hv6YE7qm8qTX+1gXXYpfq0bzPJnNRmg4OS+adw9YxAn92t+jKoQncEJWSl8cu8Unlm4g3fWHkApqK43jkoBcVYTKQ4Lt07uzzXj+2AxRe5INTQziaGZSW1YciHCU1o3XSswZswYvWbNmnYsjmgPPr+PJQeW8PHujyl0FuLXftLj0jmtz2nMyJqBxSRrL3RkSqm1WusxR7och0vikzia7Cuq5rWVe9l2sIIKl5ekOAvDuidyzbg+0rUoBh0hPklsil6128uH63L4eksBxdVuLCaDHsl2LjmpF+P6pUmLkjhqNBebpGWqEzIZJqZlTWNa1rQjXRQhhOgQeqc7+L9zhh3pYghxTHFYzVwxtjdXjO19pIsiRIvJBBRCCCGEEEII0QKSTAkhhBBCCCFEC0gyJYQQQgghhBAtIMmUEEIIIYQQQrSAJFNCCCGEEEII0QKSTAkhhBBCCCFEC0gyJYQQQgghhBAtIMmUEEIIIYQQQrSAJFNCCCGEEEII0QKSTAkhhBBCCCFECyitddNvKlUA7G2/4ggh2kEfrXXXI12IwyXxSYgO6ZiPTxKbhOiQmoxNzSZTQgghhBBCCCHCk25+QgghhBBCCNECkkwJIYQQQgghRAtIMiWEEEIIIYQQLSDJlBBCCCGEEEK0gCRTQgghhBBCCNECkkwJIYQQQgghRAtIMiWEEEIIIYQQLSDJlBBCCCGEEEK0gCRTQgghhBBCCNECkkwJIYQQQgghRAtIMiWEEEIIIYQQLSDJlBBCCCGEEEK0gCRTQgghhBBCCNECkkwd45RS05RS+9t73xaeb6FS6ub2Op8Q4siR2CSEOBpJbBKtTZKpRpRSlfV+/EopZ72/r2nD885USi1tq+MfLqVUX6WUVkqZG73+olLqT0eqXEJ0FhKbwpPYJMSRJbEpPIlNnYc58iadi9Y6ofZ3pdQe4Gat9ReNt1NKmbXW3vYsmxCi85LYJIQ4GklsEp2dtExFqbZpVyn1a6VUHjAvXK1IsBZiYPB3m1LqcaXUPqXUQaXULKVUXAvOfYNS6gelVIVSapdS6rYw29yvlCpUSu2pXxPUWmWIspwzlVJLg+crUUrtVkqd3cS23ZVS65VSvwz+vVAp9bBS6pvgdX6mlOpSb/sLlFKblFKlwW2HBV+/QSn1Yb3ttiul3q73d7ZSalTwd62Uuj24TalS6mmllGqLz0KI9iKxKapySmwSop1JbIqqnBKbOgBJpmKTCaQBfYBbo9j+EWAwMAoYCPQEfteC8+YD5wFJwA3AP5VSoxuVq0vw+NcDzymlhsRaBqXUM0qpZ1pQvvrGAVuD5XkUmNv4xlNK9QMWAU9prR+r99bVBK4vA7AC9wW3Hwy8DtwLdAUWAB8qpazB40xWShlKqR7B/SYE9+sPJADr653jPGAscDxwOXDmYV6vEEcDiU2RSWwSov1JbIpMYtMxTpKp2PiBh7TWNVprZ3MbBm+EW4Gfaa2LtdYVwF+AK2M9qdb6I631Th2wCPgMmNxosweD5VoEfARcHmsZtNY/0Vr/JNbyNbJXaz1Ha+0DXgK6A93qvT8c+JrA5/hco33naa23BT/btwgEMoArgI+01p9rrT3A40AcMFFrvQuoCG47BfgUyFFKDQWmAku01v5653hEa12qtd4XLMcohDj2SWyKTGKTEO1PYlNkEpuOcTJmKjYFWmtXlNt2BRzA2noVDAowxXrSYJPvQwRqSozgcTfU26REa11V7++9QI/WLANQ28/ZUu/32r899f7Oq/1Fa10dPG9CvfevAXYA/w1zjrx6v1fX268HgWuqPa5fKZVNoLYIArUs0wjUIC0CSgkEhAnBv6M5hxDHMolNEpuEOBpJbJLY1OFJy1RsdKO/qwjcdAAopTLrvVcIOIERWuuU4E9y/YGa0VBK2YB3CNQqdNNapxBorq3fBJyqlIqv93dvIKe1yhCUS+Dm79vo9X7Uu2Gj8Ptguf6jlIo2OOUQ6CIA1NVeZQEHgi/VBoXJwd8XEQgKUwkNCkJ0RBKbJDYJcTSS2CSxqcOTZOrwrANGKKVGKaXsBP6HBwK1AMAcAv10MwCUUj2VUs31NVVKKXv9HwJ9WW1AAeAN1racEWbfPyilrEqpyQT6t77dwjKEFWx+fgf4s1IqXSllUUpdRaD5+eMYDuUBLgPigZeVUtH8P/gWcK5SaoZSygL8AqgBlgXfXwRMB+K01vuBJcBZQDrwXQxlE6KjkNgksUmIo5HEJolNHY4kU4dBa70N+CPwBbAdaLzewa8JNM2uUEqVB7cbQtMmEqgRafxzN4Ebo4TAYMP/NdovL/heDvAacLvWekusZVCBGWtmNVO+nwDFBAYm5gN3AedqrQ82s08IrbUbuJhAn+AXIgUGrfVW4FrgSQK1M+cD5wePU/vvUEkgGKC1Lgd2Ad8Eg5kQnYrEJolNQhyNJDZJbOqIlNaNW2CFEEIIIYQQQkQiLVNCCCGEEEII0QKSTAkhhBBCCCFEC0gyJYQQQgghhBAtIMmUEEIIIYQQQrSAJFOHSSn1olLqT8HfJyultrbTebVSamArH7PuWtpz3/ailLpfKfX8kS6HEO1F4tPh79teJD6JzkRi0+Hv214kNkXWKZIppdQepZRTKVWplDoY/J+31Vdw1lov0Vo3N4VnbXlmKqUaTwfaapRSC5VSN7fV8Q9XW19/8BzTlFL767+mtf6L1vqo/VxE5yTx6egi8UmIAIlNRxeJTUevTpFMBZ0fXMF6NDAGeKDxBkopc7uXSgghJD4JIY5OEpuEiKAzJVMAaK0PEFh5eiTUNfneqZTaTmABOZRS5ymlvldKlSqllimljq/dXyl1olLqW6VUhVLqTcBe770GGb1SKksp9a5SqkApVaSUekopNQyYBUwI1vaUBre1KaUeV0rtC9YAzVJKxdU71i+VUrlKqRyl1I0tvX6l1NtKqTylVJlSarFSakSjTboopT4PXt8ipVSfevsODb5XrJTaqpS6vKXlaFSmPUqp+5RS64PlelMFVjFHKZWqlJof/AxLgr/3qrdvmlJqXvBzKVFKva+Uiifwb9wj+BlXKqV6KKV+r5R6Nbjfx0qpuxqVY51S6uK2vFYhmiPxSeJTcD+JT+KoIrFJYlNwP4lNYXS6ZEoplQWcA3xX7+WLgHHAcKXUicALwG1AOjAb+F/whrUC7wOvAGnA28AlTZzHBMwH9gJ9gZ7AG1rrH4DbgeVa6wStdUpwl0eAwcAoYGBw+98Fj3UWcB9wOjAIOO0wPoKPg8fIAL4lsPJ3fdcADwNdgO9r3w/eZJ8D/wnueyXwjFJqeBPXX6qUmhRDuS4HzgL6AccDM4OvG8A8oA/Qm8DK5k/V2+8VwAGMCJbrn1rrKuBsICf4GSdorXMane914Kp65R0ePMdHsV6rEK1F4pPEpyCJT+KoIrFJYlOQxKZwtNYd/gfYA1QCpQRu0GeAuOB7Gji13rbPAg832n8rMBWYAuQAqt57y4A/BX+fBuwP/j4BKADMYcozE1ha728FVAED6r02Adgd/P0F4JF67w0OlntgE9e7ELg5is8lJXic5ODfLxIIWrXvJwA+IAu4AljSaP/ZwEP19v1TlP8eja9/D3Btvb8fBWY1se8ooCT4e3fAD6SG2a7u36Lea78HXg3+nhj8zPsE//4z8ELw92avVX7kpzV/JD41+blIfJL4JD9H8EdiU5Ofi8QmiU0NfjpTP9eLtNZfNPFedr3f+wDXK6V+Wu81K9CDwM1zQAf/Dwna28Qxs4C9WmtvFGXrSqCGYK1SqvY1BZiCv/cA1kZxzmYFa3z+DFwWPKc/+FYXoCz4e91nobWuVEoVB8/fBxhX27QeZCZQu9Ea8ur9Xh08J0opB/BPAjUvqcH3E4PXkgUUa61LYj2Z1rpCKfURgZqTvxGoabkl+HZbX6sQjUl8kvhUR+KTOIpIbJLYVEdiU3idKZlqTv0bPBv4s9b6z403UkpNBXoqpVS9oNAb2BnmmNlAb6WUOUxQ0I3+LiTQBDtCB/olN5ZL4H/+Wr2bvpRmXQ1cSKCpew+QDJQQCD616s6jArP2pBGoUcoGFmmtT2/huVvqF8AQYJzWOk8pNYpANwMVLFOaUipFa13aaL/Gn3E4rwMPKaUWE+i//XXw9SN1rUKEI/HpEIlPEp/E0UNi0yESmzpxbOp0Y6aiMAe4XSk1TgXEK6XOVUolAssBL3C3UsoSHHB3chPHWUXgRn4keAy7UuqU4HsHgV7BfsRorf3B8/5TKZUBoJTqqZQ6M7j9W8BMpdTwYG3DQ1Fchzl4ztofC4Hm2RqgiEBtzl/C7HeOUmpSsGwPAyu01tkE+jAPVkpdF7x2i1JqrAoMCm1LiQSCZalSKo161661ziXQj/kZFRhsaVFKTQm+fRBIV0olN3PsBQRqUv4IvBn8d4Ajd61CRCLxSeKTxCdxNJLYJLGp08YmSaYa0VqvIdBk+RSBmocdBAf0aa3dwMXBv4sJ9A99t4nj+IDzCQyI3AfsD24P8BWwCchTShUGX/t18FwrlFLlwBcEahXQWn8M/Cu4347gfyN5lsCNVPszD3iZQDP3AWAzsCLMfv8hcNMVAycB1wbLUAGcQaBpN4dA0/LfAFu4k6vALDCToyhnJP8C4gjUQK0APmn0/nWAB9gC5AP3Bsu7hUDtyS4VGNDZo/GBtdY1BP79TiNw3bWvx3StQrQXiU8SnyQ+iaORxCaJTZ05NqmGXViFEEIIIYQQQkRDWqaEEEIIIYQQogUkmRJCCCGEEEKIFpBkSgghhBBCCCFaQJIpIYQQQgghhGgBSaaEEEIIIYQQogWaXbS3S5cuum/fvu1UFCFEe1i7dm2h1rrrkS7H4ZL4JETH0xHik8QmITqe5mJTs8lU3759WbNmTduUSghxRCil9h7pMrQGiU9CdDwdIT5JbBKi42kuNkk3PyGEEEIIIYRoAUmmhBBCCCGEEKIFJJkSQgghhBBCiBaQZEoIIYQQQgghWkCSKSGEEEIIIYRoAUmmhBBCCCGEEKIFJJkSQgghhBBCiBaQZEoIIYQQQgghWqDZRXtFdHx+zddb8tmcW05JlZsEm5meqXGcfVx3kuMsR7p4QohmbMkrZ+HWAooqawBF10QrM4Z1Y0DXhJiOU1btobCqBrfXT1KchcwkOyZDtU2hhRBCHHXKnB4WbMglp8RJZY2X1HgrI3okMW1IhjwPOjBJpg5DUWUNr67Yx4vLduP2+ql2+9DB9xxWEw/9bxPnHNedWyb3Z3iPpCNaViHEIR6fn4835jFr4U52FVTi9Wu8/sDdazEp/v7ZNoZmJnLHtAGcPjyzyYegz69ZtC2fWYt28d2+EqwmA6UUPr/GZjaYObEvV4/vTUaivT0vTwghRDvalFPGnMW7+HhjHoahcLp9ABgK4iwmrGaDG07pxzXjepOeYDvCpRWtTWmtm3xzzJgxes2aNe1YnGPHuuxSrpu7khqvnxqvv8ntTEphMSvuO2MIN0/u344lFCI8pdRarfWYI12Ow9XS+FTm9DDzhVVsPVhBdfCB1xSH1cSorBSev34MDmvDuqfv9pVw68trqXZ7qWriODZzoCf15WOy+P0FI6RmUogIOkJ8ku9Onctzi3fyj8+34fFqfM18p7aZDWxmg1dvHsfxvVLasYSiNTQXm2TMVAus31/Klc+toNzlbTaRAvBpjcvj5++fbePpr3e0UwmFEOFUu71c8uwyNuWUR0ykAtv7WLu3hCtmr6DGe2j7xdsKuHrOSgoqa5pMpIC6ypb/rs1m5rxVeHzNxwshhBDHjie/3M4/P9+Oy+NvNpGCwPOg3OXliudWsGF/WTuVULQHSaZiVFrt5rq5q3B6In8Rq8/p8fHEl9t4a002Gw+UsTWvgsLKmjYqpRAinLtf/47s4mrcMSQ1NV4/2w9W8Jt3NgCwOaec219dG1MMcHr8rNlTwq/fWR9zmYUQQhx9vt6SzzMLd8b+fdDt49q5Kymr9rRRyUR7kzFTMXprTXaDGur6qjYvpHz1+3iK9mNY47Bk9Cd54uXYe40AoMar+fU764kPdhdye/0MyUzk9qkDOGNENywmyW2FaCt7i6pYsr0wbGtyj7uIVgAAIABJREFUpHvX5fXz0YZc/u+cofzug41Ntmo1dxynx8eCDbncPEnGUAohxLHusc+2NplIRXqmuL1+3l6bLcM/OghJpmLg92vmLNmNyxP6Zax81XuUrfwv6Wfcib3faJTJjHP3WpzbV9bdPABaQ2WNt+7vDQfK+NV/13H/ewZPXz2aSYO6tMu1CNHZvLhsD/4w3TCivXcV8OzCnWw4EL57RjTH8fg0c5fu4u+Xj2qTaxRCCNH2tuZVsKugMux70TwLnB4fzy3exY2n9MOQsbTHPGkKicGynUVU10uEavlrqihd+hppp9+BY8hEDKsdZTLjGDiO1Ok3RjxuldtHmdPDzS+vZv76nLYouhCdmtvr583V2Xh8DZOpWO7dGq+f/6zchy9MF8Foj+Pza+avz6XcJd07hBDiWDXvm90hzxOI7ZlSVeNlxa6iZs+zt6iKZTsL+WrLQVbvKaa4yt2q1yFah7RMxWDrwYqwA8hrDmxBe904Bk84rOO7PH7ue3sdGYl2Tu6XdljHEkIcUlBZQ7ixwbHeu01NOBPLcSwmgzV7ijl1aLeozimEEKLt7SuqZn9JNdVuHwl2M/27xje5rMX6/aX4/KEPlVieBR6/ZuvBCiYObNgjye3188mmQ0t3WMxGg/emDu7KrVP6c1KfVJSSVq2jgSRTMahweXCHqYnwOcsxHEkowxTxGBHHZnj83P/uer74xbTWLr4QnValy0u4IYmx3LtAoJ9umIdXLMfxa02pDDwWQogjzuPz89mmg8xatIPt+ZUNxq67vX4m9E/n1qn9mdA/vUHiUlkTfqxULM8Cj9dPpathb6d12aXMnLcKt89PVfAcrkaVeJ//cJClOwoZmJHASzecTGq8NerrFW1DkqkYOKwmzIaqW9yzlikuCX91Odrva/YGinZsxoFSF+v3l8o6BEK0EofVRJhKxKjv3TpN1ALGchytNYbUJgohxBG1KaeMH89dhcvrO5S4NBoTv3BbAav2FJOV6uCVm0+ua6mKs4SP87E8C8wmRZz10DYrdhVxw7zVEWcH1DqwbMcPueWc/cQSPrxrEl0TZSHgI0nGTMWgV6oDmyX0I7P1HIoyW6jetrzJfWMbm+FjzuJdrV5+ITqrLgm2kEoQiO7ejUYsx3F6/Pzjs2188P0B3BHWqRNCCNH61u4t4bJZyymqctclUk2pdvvYWVDJOf9eQm6ZE4C+XRxht43lWWA1G/RKjQNgV0ElN70YOZGqz+PTFFbUcPWcFU3OMi3ahyRTMTh1aAaEqd02bPGkTLqG4s9nUb1tOX6PC+3z4ty5hpKvXwBi60fr17BoW0FrF1+ITivOauL0Yd1oPGlSNPduLbOhGN49EasptFUpluMA7Cup5v53N3D+k0tlvTkhhGhH2cXVXP/CqqgWbq/l9WtKqj1cMXsF1W4v147vg80c+hU6lmeB0+1jT2EVhZU1/P2zbc1Os5770r3s+8el7H/qOg6+9RCu/ZvqynWg1MnHG/Ji+AREa5NufjGwW0xcdXJvXlq+J2QWl6STL8aIT6Vs+ZsUzn8cZY3D1m0gSROuAGIfm1EVw00uhIjslin9+WpLfsgDK9K9W8tsKB44bzg3zFtNuFqVaI9Tq8rtY1dBJec/uZQFd0+Wfu9CCNEO/vXFNqrdoTMzQ/Pj2n1+TX6Fi1++vY7vs0vDLrUB0T0LrCbFGSMy2ZZfybTHvqba7QvbFT2a4SHVbh/PLtrJRSf2PPwPR7SIJFMxun5iX15duTfslJgJI6aTMGJ62P1iHZshIyqEaF0n9Eqmd5qDHfmV+Bo9BOvfu4YtB3PytxiWNShjOdobj3YOZHjyFCYO6MLYvml8s7Mw7OyAzcWAcDx+TWFlDdfPW8UHd54iMzMJIUQbqnB5+GhDbosTF5fHz1dbCnjt5pPZlFPOXxZsCduiFOlZYBiK35w9lF6pDv79xTae+mpHSHJWOzwk/Zx7cQyZWPe6Y+A4HAPHNdh2b1EVm3LKGNEjOZaPQ7QSSaZilJXm4NFLjudX76wPu3hvU+r3o40fOini9gk2+acRojUppXjhhrGc8+8llDs9jdqWNOakdVjTF2JYi0B5UareFkmbybbM5y8rL+LuM6/k++dLGyy+HUlztZ0en2ZHfiVr95Ywpq8siSCEEG3lnbX7w1ZaxZK4KBUYjnHt+D4s3VHIom0FMX0ftFsMHrv0BHqlBsZdLdleiOcwp1n3+2HFrmJJpo4QGTPVAheM6smfLhyJPcxkFE2JpR+tyYCzRma2drGF6PR6psTxzh0TSU+wYqkdQKW82Hu+ir37u5jseSjD0zCRAjBqcPmcvL31be5Zcg2/uzSeeKspqhbk8lXvUfzlHJLHX06vu16l5x3zSBx9Ds7tK+u2cXp8PCeTzgghRJv6aks+zjDDKGJJXFweHyt3F6OU4smrRjN9aEaTs/s1ZrcY/Pmi4zj/hB51r5U5wy+VEcvwELfPT2m1LOh7pEjzRwtdOiaLARkJ/PPzbazYXYz267A1C/VFO6bCYhjcNKlfWxZfiE5rYEYCn9w7hae+2sFba/ZiZL4KcTtQRuSWJq/2Uump5PH1P+cf1z3Do/+rJLfM1eRA5mhrO7UOTMFbVFlDeoJMcSuEEG2hpIk1/mJbKxCKghMHWc0Gz1w9mrfXZPP01zvJr6zB5fE16AZuNRsoYHz/NO49bTAn9k5tcDxT45mRal+PYXiIAqzhFlMU7UKSqcPQO83Byf3TWL+/jKoaD0aw6VcRdtI/ILoxFQMzEhjULbHVyyuECOiSYOP3F4yga9YS5m7YjUc3TKRKlpRQ+Gkh7nw3JruJpJOS6HZpN0zxgQea0+vk4TX38tndn7Etz83Fz3wTtg9+LLWdNrPBzoIqSaaEEKKNWMLMxgqxj2u31pvJTynF5WN7c9mYLL7LLuW1FXvZW1RNtdtHcpyF0X1SuG58XzKT7WGP1TXRxpa8ipDXYxkeYrcY8uw4giSZaqEFG3L4+VvrgNBF3ppvn2qew2risctOOIwjCCGi4fF7eO2Hl/HohlOTF35cSMHHBfS6uRcJwxPwlHjIeSWHPY/vod9v+2EEH6Iev4dP93zKRQMvCjsZBcQ4i6cODI4WQgjRNrolhU9oYklcrCZF18TQ4yilGN07ldGNWp4iuWJsFt/uKwlZ76r+8BBlmLD3OxFlmHHt+R7XvvUN1ij1+eG04RkxnVe0HmkTbIG312Tz87fW4fL4Yxp0GInDauL5H49hWPekVjumECK8r/d9jU83fHj5nD7y38+nx7U9SDw+EWVWWLtayfpJFu5CN2XLyuq2rfZW88LGF1BKYY6itjMiRdT97oUQQsTusjG9iLeGxtmY1gpUqlXHtZ8xPBOjiZlck06+mNRTb6Js+Zvsf/Ia9j87k4pv5xM36FBvBwVMGdyFjDAJnmgf0jIVo1W7i3nwg42tlkSp4Beorgk2nr5mNCN7ykwsQrSHlze/TLW3usFr1dur8Xv8JJ3UsELDZDeReHwilZsqSZ1yqNbxYPVBthRvoUuCjdwyV8g5Yqnt9Pj8dE+JO4wrEkII0ZypgzOwWUxh1/KMdlz72D6p9GzFWG01G1w3vg/PL92N2xv63TLS8BC7xcStUwa0WnlE7CSZitFfF/zQZCLV3PTHcGjtKKvZQGuNX8OMYRncOqU/o3unyhozQrSj7IrskNd8lT7MCWZUmJYmc7IZ515ng9cMZbC3fC/Xjh/Ak19ux9XoQRhLN40+afH06xLfSlcnhBCiMZOhuHFSX576ckdIvIbIiYvDauK2qa2fuNw5fSCfbMpjb1EVvhjq6uMsJi4Y1YOxfWPrWihalyRTMdhVUMnm3PKw70Wz2JsGhnVP5OUbx2E1GSTYzU3O4iKEaFsub2hLkinBhLfSi/bpkITKW+bFnNAwZPq1n2pPNVed3Jsnvtwe9jzR1HbGW03cMU1qFoUQoq3dPKk/C9bnsT2/Ao8v+lHucRaDM0dkMnlQl1YvU7zNzBu3jueyWcvJLXOFbaFqzGJSnDo0g7/86DipjD/CJJmKwbxv9uALM2VXLIu97S6sorjKzZBMma1PiCPJbraHdPNzDHSgzIryteUkn3yoy63P5aNifQXdLu3WYHtDGcRb4kmLt3La8G58vukg7jDVipFqO00mxdnHydpyQgjR1uwWE6/dPI7LZy9nX3E1NVEkLnEWE6cMTOfRS4+POnGprPFysNxFdY2PRLuZzGQ79mbGxWYk2pn/00n85t0NfL75IArCli3eagrOHqv57blD6yrlfX7NwXIXFS4vVrNBlwQriXZLVGUVh0eSqRgs31WE9zBXqVbA99klkkwJcYT1SepDsau4wWsmh4mMizLIeTUHw240mM3PkmYhZWJKg+19fh99kvoA8NeLj2PD/jIOlFbH3E1j3syTsZll8gkhhGhTPg+4ykg1zHxw5wQe/GAz73+fE3grzPc7R3Bx9lsm9+fuGYMwIvQm0lrz7b4Snlu8i6+3FGAxKZRS+LXGrzUXjerJTZP6Nbn8TaLdwtNXj6aosobXV+/j1eX7KKqqwePT2MwG/bvEc/u0AZw1MpNZC3fx2/c28rdLjue1lft4adkeXF4fZiMwlMTt8zO+fzq3TRnAxAHpEcsuWk6SqRhUusIv6hnL9Mdev6bcGXlxUCFE27p+xPVsLd4a0jrV9ZyumOJN5L2ZhzvfjRFnkDQ6iazbsjAsDSdA7ZnQkyFpQwBIslv47+0T6mo7I6zhDQQqVx44dxgn9ZH+7kII0SY8Ltj8Piz9JxRsBbMVtMah/Tw66CwKbeMZPnYG/1ufR0FlDW6vH7vFoG96PLdN7c/ZI7s326JUa39JNTfMW82BUifO4MK9jee5eHttNu99d4AxfVJ59rqTSGqi5Sg9wcZd0wdx1/RBAPj9OiQZunVKPyY+spuJj3yFyVD1WrEO1eYt2V7It3tLSIqz8Pz1YxjRQyY5awuSTMWgNRZ7M5TCZpEZ6YU40qb2morFCP8gS5uaRtrUtGb3t7vhRxsduE/KxpqVBQRqFeOtZqKt/9PAnxf8wODMRMb2bf58QgghYrRyNnz5x8Dv7srAf72H1hZUWxcwW32BfccL/Pr6l6D78WitYx6DtCO/kkueXUaly4uvqYUHCawH5fP7Wb2nhPOeWMr7d55CWrw14vEbJ1Ien5+bXlpDtduH16/D9pqqVeX2UeX2cdms5cybOZZx/dOjvzARFflWH4OmVq+uP/1xJGaTIiNRVqkW4kgzG2ZuHHkjcaaWTXEbl5jKaanj2HPpZeT96c94Cgu5/dW17CioJIYxzVS7fcyct4pdBZUtKocQQohGtIaPfw1f/D6QRLnDx1cDP3btguJd8MKZsHtxzIlUUWUNVz63nHKnp9lEqj63z09umZNrnl9JjTeKdQgb+flb3/PtvpKoxnvVqnb7uPHF1ezIr4j5fKJ50jIVg2vH92FzTnnI+gSxTH/s9wfWORBCHHkzR85kfeF6vjnwDS5f6Ox+TXGYHTx/5vP0TB2M98prKZw1m3evuYtVo6+hRofWUUVaNqHa7ePPC35g7vVjW+3ahBCi01r6T/j2ZfBUR962lqca/nMl3Pw5dBsR9W7PLNxJmdNDuDSqudjv8Wn2Flbx4bpcLj2pV9TnW5ddyheb88Mu0xPNs+YPH27mlZvGhewrWk6SqRicNTKT+9/bGPa9aKY/NhuKy8b0Ii7M6ttCiPZnKIPHpj7GA0sf4Ovsr3F6nc1ubzEsxJnjmHPGHAanDgbAnJ5O5m/vZ37qIpx7K2jcxy+qZRN0oG97foVLVrEXQojDUZ4Lix5p0J0PoO+/Kqj2wO57Eoi3BgL189+6eXW9h4Uzg2v8earggzvh1oVRncrl8fHGqn1hp1iPJvZXe3zMWrgjpmRqzpJdYVuzol2iZ+XuYnLLnHRPlkXiW4skUzGwmU1cM643Ly3bE7ZpNdL0x2ZDccMp/dqyiEKIGFkMC49MfoSvsr9i7oa5bCvZhs/vw6sPTRTjMDswlMHlQy7n2mHX0tXRtcEx8spcrMqphkbdQ2JZNkEBr6/cxz2nDW79ixRCiM5izQtNvuXT8O+Vbu6f3Mxwi/wfAhNVdB0S8VQLNuSGfT2W2H+g1MX6/aUc3yul8WFClFS5+XzzwZAJjmI5HxpeWb6XX501NOL5RHQkmYrRz08fzKJtBewqqIxxsTcT9505mH5d4tuwdEKIllBKMaP3DGb0nsHust3M3zmf3KpcXD4XqbZUxmaOZUbvGVhM4Ses+GZHYdgFuGNZNqHG62fBhjxJpoQQoqV8Hlj1XEirVK1fTrTy6Dc1/GSslRR7E2OjfF5Y8Syc/6+Ip/t4Q17I0A+INfb7WLytIKpkauXuIiwmI6RCP5bzuX1+PtqQ26JkKr/cxa7CKipdXhxWE71SHfROd8R8nI5GkqkY2S0mXr9lPFc9t4Lt+RVRTX9stxjcNrU/N03q3/YFFEIcln7J/fjp6J/GtE+p04M3TOVKLMsmAJS7PDGdVwghRD2568Df9PIzY3qYmNbXzOPLavjTqU10qdbewFTqjZIprTU1Xj9Otw+nx0e120d2SfgxWbHEfr+GjzfkUh5m+Z3G6d4PueW4PKHJW6zPmoomlvoJWz6/ZtnOImYv3smq3cVYzUagv6ACt9fPkMxEbp86gNOHd8Ni6pzz2kky1QJp8VbuO3Mw977xPX4d6NlT3ahmQqlAa1SXBBv/d/ZQzj6u+xEqrRCirRmK0KcesS2bACG9BIUQQsSiughU81/o/zjdxikvVHHPuKanJPc4KzjnH4uodvtweQLJk8vjw2IyiLOaiLMEfnLLwk9cFGvsj7OaSW80RXq4unqH1dwqz5po5ZW5uHbuSnJLnXUtcI1bxdbvL+OX/11HnMXEKzeNY1j3pFY7/7FCkqlwvG7YMh+2LoDKAjAMSMiE4y6B/qdS7fXzhw838/Q1oxnXL50P1+Xw4rLd5FfUUOP1E28zc3yvZG6b0p/RvVNjnmZTCHFsSYu3YjEp3I0q++ovmxA/dFLE46Q6Iq83IoQQogk68lThIzNMnDfYzCNL3QzrGj7xMit4+prRgaQpmDzZLaaQ7tx3vLqWjzfmhewfS+w3G4pTh2Vw29QBEcv+6aY8lu4opNLX8GET67MmOS58l/X69pdUc8FT31Dm9OCL0A2rqsZHVY2PS55dxms3j+PE3p1rIXpJpuqrLoZlT8Lq5wM3ZON1CX74AKzxfJNyGRN7X8i0IYEpzi8fm8XlY7OOQIGFEEeDKYO6hu3mF8uyCXEWEz86sWd7FlsIITqWuFTCt+k09IdpdkbPruQXE8JPRKFs8QzulhjxOBeO6sHi7QVU1bR8yRyzSXH6sG4RzwUwYUA6Hl9owhjL+QCcbi+zFu3kzBGZYcfyV9V4uWL2Csqqo187CwK9tH78wio+vmcyvVI7z1gqSaZqFe2EeeeAswR84Qcu4q4CdxWTKucwvctXUDUf4ru0bzmFEEed1Hgrpw3rxscbc0PGUUazbAKAX2suGyOVMkII0WLdR0XVOjUwzeCKERaeWOXmuIxGrVPKgIGnR3W604Z1w2wYQOg4pmhj/6CMRAZFkbgBJNktnHtcdz5YlxPSWhTt+Wxmg/vPGcbK3cVcPns5qQ4LZ47I5MwRmYzokYRSirfXZFNcVRM2kYq8lpWXp77awSOXHB/VNXUEkkwBlO2H52eAs5RoajTicEPxDph7Oty6COydr3+oEKKhW6b056st+TjDDA6OtGwCwORBXaLqeiGEEKIJFjuc+GP06udR/uYn9PndVBuvrA+zjdkOE++K6nRmk8GPJ/ThucW7WrRkjsNq4rapsU1OdvPk/izYmBu2612k8xkKThnQhQtG9eSCUT15+MKRfJddyqeb8rjzP9/i9WlOH57Bh+tycYZZFDiatax8fnj/+wM8eN5w4m2dI83onNNu1Kc1vHoJuMqpn0j1/VcFGY9VUOU+9Nrz37qZ9mJV4A+/B8oOwPu3t3OBhRBHo1FZKZx3QnfiLLGFVUXggbp2TwkffH8g4vabcsr49xfbuP/dDfzmnfX8/bOtrN1bgo6hK4YQQnREWms+ib8Ad5jGqT33JnJa/0Nf7rOSDVwPJB1asLdWci/ocWLU57x1Sn8yk+1hl8dojs1sMCorhbNHxjZB2fAeSVw0qidxltgnmoi3mfn9BSPq/jYMxUl9Uvm/s4fy1c+nMnfmGCpdPkqq3SH71q5llXb6HTiGTMSw2lEmM46B40K6ERpK8d53+2Mu37Gqc6SMzcleCaXZoMNMNRlpcTdfDWz/AspzIKlHGxdUCHG0++uPjqO0ysPSHYVhW6gaMylFgt3M27dPwO31c/cb37FoawF/uHAEifZDrVRen5/563N5dtFO9hVVUeP113UnVMDzS3eTkWDj9mkD+NGJPbG34CErhBDHsp0FlTzw3kYqanyMGX41tu1vgyf81OVNMsfB+f+OaZdEu4U3bh3Pxc8so7CyJqo1SO1mg8GZiTx//ZiYkzCAP//oOIqr3CzZHt2zpvYM950xuG5dqIPlLl5Zvpc3Vu+jpDowyYTNbJAcZwm77E8sa1lVu318taWAa8f3jeGqjl3SMrXsiSZvtl9OtPL4shpKXRFujNVz26BgQohjjdlkMPu6k7h+Yl/sFqPJmkOTEVh/bmTPJBbcM5nB3RIZ2TOZ+T+dhM1i4twnlvLdvhIAKlwerpyzgvvf28DWvAqcHn+DB50GnG4fe4ur+eOHm7jwqW8oqmxi3KcQQhxlPD4/CzbkcuOLqzj3iSWc9a/FXPnccuYu2UWZM/Laey6Pj398vo1Ln13G6cO78f5PTqHLpf+CATPAEsMkCOY4uOhZ6DMx5mvonhzHgrsnM65fOjazgcUUPkGyWwxsZoNzj+/O27dPCEx13gImQzHr2sCzxmZu+lljBJfp6d81nn9eMYonvtzBNzsKufHF1Ux59GvmLNlFYaW7rstgjddPfkX450esa1mFa93qqDp3y5SzNNCy1MQ4qagWd/PVBGb/m/Fg25VTCHHMMAzFb84eyl2nDuS9b/cza9EucsucWEwGPr/GZCh+dGJPbprUL2TQscNq5q8XH8cnG3O55eU1XDOuN59tPsjO/ErcUdR2Oj1+dhVUcuHT3/DR3ZNlDJYQ4qhVVePlmYU7eGX5Xnxah8yIty67jEc/3crZIzP5+elD6lpU6lu6vZAH3t/A0MxAxVT35LhDb17+Mnz+IKyeAyjwhl8TCmsCGCa47CUY0PzY1uakxlt59eZx7Cuq5sVlu3lzdTY1Xj9mQ+Hxa5LjLNx4Sl+uPLk3XRKa6PEUg9pnzU+mD+CdNfuZvWQXhRU1WEwGfq3xa80ZIzK5dXJ/TshKAaCyxsN1c1diKIU3wnTnjcW6llVnWsC3cydT5TlgtjY9ex/RLe5GTXlgbSqzrBEjhAhIsJm5bkJfrpvQlxqvjwqXF7vFRLzVFHHtubNGdueE/2fvvMOrKNP/fc/MaUlOeiEhARJCkyZNOoIVRVRUFHsFFbu76+73p9t0V9fddXfVVVGRVRGxCyoW7PRepUMggRBCQnpOPzPz++MAJjmTnHMwkIS893VxeTEzZ+Z98ZyZed7neT6fTglMfH4p5U4vDduhmlJT8mk6JdVu7nprLe/eGbocQyAQCE41JTVurnt1JYUVLkPhBuB4+dqnm4r4dnsJb95+FoO7JAFQWuPhyc+3sSa/gicu78N5RtLisgzjn4RRD8G6N2DVS+Bzg2wCdPB7IK03jH4Iek0EpXkWnzonR/PHS/vwh4m9cXhVnF4/sVYzNrN8UnxH42xmbhudw62jsqn1+Kl2+4+X69UNaCqdXl78IQ9dB/8J9NhG6mWVHtdIEuI0pH0HU14HhlbSdQjH3A3ZFPCkMiU1/xgFAkGbx2pSsNoj62NSZIkajz8okApHTcmr6mzcX8nuwzVhS+4KBALBqaDa7WPyjBUcrHRiYJkUhKZDrcfPTbNW88FdI9hYWMm/v97F5CFZfPOrs0OXytlTYewjMOZXUHUgUJWkWALWNva05pmUAZIkYbeasJ8iRTtJkoi1mev129bl6S93cKTWY1iLFUruHCLzsoqxKEwenNXcU2y1tO9gymoPy48glLkbqg+s4oVFIBA0H2+v3B+01HNMTSl5wkNE9/y5rj+62zCiuw2rd6xP05i1dF+78voQCAStn//7cDPF1W7DQKqpl3qnV+XyF5fRLzOOt6cNo1d6hLY0sgKJ2ZDYLNNoUzg8fj7ZWGQojhHOAt0xwvWyiraaGN2t/fiwtu9gKj4LNH/Iw5o0d4PA6kYzpYcFAoFA13XeXJEfVP4SiZrSMa+PP1/WR6j7CQSCVkFpjYdvd5TgNSjtC+elXpEl7jw7N/JAqp0zf8NBjCoMI1mgO0YoLyubWWbamBzkE1ApbKu0n+4wI6yxcMalIIV+0fjjWGs9z6njmKJg2PSTMDiBQNBecXhVat3BCz2RqinJksTh6kaargUCgeAUM3dVcMYdwvcw8vg1Xlm899QN+DTh3TUHcHqDJdQjWaALB5Ms0SMtlltGZjfL+doK7TszBTDiPtjxeZA8ev5D9cv2jpm7BaFrMPiWkzlCgUDQznB4/JiUYLWlSNWUZEmixiAoEwgEgpZgzqoCQ8GJSF7qtx+q5mCli8yEqJDHCgKUOZpH7rwpzIqEpsPTk/thNbWvaojWHUzpOjjLwV0JJitEJ4M58h/PzuIavvzpEMXVbvyaTlqsldHdUxjRNRmp44CAmsuhTaCF9jOohykK+kwKlPkJBAJBMxFlUY77ftQlUjUlXdeJOUXNzwKBQNAUuq436oEXyUu9xSRTJIKpiDB6nkDkC3RGRJkVNF3n8gGZdE2N4cF3NvLh9JFB1hx7S2s5VOXG6VWJs5nolmYnuRkk4lsDrfMp6yyHDW/BiheOqq6YA4GV6oPcc2HUA9BlFIYFoEdRNZ0vfjrEjEXjTEipAAAgAElEQVR57C2txefXqNt398byfOKizNw1pivXTJ6L6dUxyK4jmAntJA2AYoWUHjDx2V84WYFAIKhPrNWExSTjU+vfjyJRUwKOLx4JBAJBS+PX9EZcPSN/qTcqWRM0TpzNzOHq4EA2kgU6iyKTFGOh0uXF49Mwm2Q6xFq5fVQOVw3JIu6oimBxlZvpc9bxxm1D0fSf38UPlDsDUu06IAVKNsf2SOXOs7sypEviSZGNP1W0rmBK0342WJNk8LkC29U6Lsq7v4aCpYEs1bXvQHrfoNO4vCrTZq9l/f6KRn9wTq+K06vy94U7eHmxmTjvE3wa9zRmZzH4XU2P0xwNGQPghvfB3H509AUCwalBkiSuPaszs1fkB6kvhaumJEtwQe8OIjMlEAhaBWZFRpEkQ4+jSLPusTZxX4uEc3ulkV/mCHqeRLJAJ0nwwd0j6JQUja7rjQY/f5jYm7veWse02WvYsL8SVdNxHH0Xd/vql3h+u/0wy/YcoVuanTdvG0piTNv0a20930ZNhfdugL2LAkZqjaIH/KG8Dph1Idz4IXT5WYHE69e44bWVbC2qbtQIri5un4bb58FlS6T21u+J2v0BLH8eXBVHfaiOIsmBUsPEbBj1MPS9CpTW888nEAhOL24dmc2clQVgsJYbSk0JAi8uU8d0PUmjEwgEgsjpkhxNXqkjaHskL/Vev0bXlJhTOew2z00juvDG8nyMnifhLtAN6JRAp6RogCazSIosce1ZWUx7a12QT2JDdD2Q3Nh+qJoJzy/hs/tHk9IGS/9aTzSw4OFAINVACKJJfA54+2qY9gOk9gDgic+2su2QcSAVyr/glre28vkDdyINvRMKlsHub6C2JOCiHZsRUP7LOLO5ZiwQCASN0ikpmqE5SazcW2boDdIUshQodX539X5yUmKCatcN0XUo2Q41hwLVALb4QD9pVMIJzkAgEAjqc9fYXP786VbDqqFwXuplKZBlSYhumxmMliIrMZrBXRJZnldmuD/UAl20ReHusblhXWv34RoeeHdjyECqLj5V50iNhxtmruKz+0djMbUtsfHWEUwVb4Gf3v+5rA/IfrYGpw/2PWgnxhKIgF9b72XOZh8/3lpnRcLrgC9+A7d8SpXLxwfrCg0DqVD+BX5NJ7/Mwfr9lQzukgjZowN/BAKBoIV44bpBTHh+McXVnkYbiBsiSxAfZeb9u0fw+rJ8xv9nMY9f3ofxfdKNP+CphZ8+gGXPHl08qvNYUL2BRaQR90HHAc0wI4FA0J65tH9H/vTJ1kb3h/YwUph2tsi4nwiPTjiDq19egcsXWb+ZRZHp2SGWs3ukhnX801/twNVIi01TSQ2fpnOgwslXW4u57MyOEY2xpWkdod/Kl8DvDdqs6vDcquDt9dHhwEqoPMBH6wqRDVKP4foXuHwqM4V/gUAgaCXER5v5+J5RdE6KxhrGSp1FkUi1W5l3zyi6p8Xy1BX9eO7aAfz9yx3c8/Y6SmoaeE7t+hqe6QELH4OK/EBlgKf65z9+N2z5CF6/CGZPCgReAoFAcIL4NY2sxChDr6lQmBWJ7ml2BnYS2fIToW9mPC/eMBCbOfxXfwmIsZp4846hKA1MeCudXvYdcbDviINKZ+BdvaTGzdLdRwyFRqpXz6P8u5nED7+GrPvmkDn9dWIHTcC1e9XxY5xelZd/zDuR6bUoLZ+ZclcHHtZ6cBT7yEgL/1jm4Z6zLCTYmvjp6Tqsnsms9WMNI+5w/Qt0HX7YWUKl0ytSyAKBoFXQIc7GgvtH879l+3h9WT4en3q8mfcYMRYFWZa4aXgXpo3pWq+Jd1jXZL54cAzPf7ebi59dwu8u6sXVQ7KQfvoQPr0/tOCOrgWqBvavgNfOg6nfgdXerHPMq8xj7o657C7fTa2vlhhzDN0Tu3Ndr+vonti9Wa8lEAhahs2Fldz/zgZG5iaTlRjFyr1luHyhe9shYAabbLfyxm1D27TqW0tzbq8OvHnbUKbNXltPGKIhsgRWk8IZGbEcrHSxeFcpE/t3xO1T+XLLIWb8mMe+Iw4sSiAw86oa2ckxdE2JQTeo7zuW1Eie8BDRPX/WOYjuNozobsPqHbv3SC07iqvplW7g7dpKaflgqmA5yGbAHbRrSEeFcdkmnlnu4a/nNqGap3rRt86jqGqY8e4I/AvMisyBcpcIpgQCQashxmri/nO7c8+4bvy4s4RPNhZRWutB13SS7RYu7pfBhb3TG60zt5kVfntRLyb278jvPtrMtlVf88fKR5H9wffdRvG7oXwfzJ0Cty5o0poiXBYdWMRLG18iryoPv+ZHrbOo9tORn/g071Ny4nO458x7OKdz04IbAoHg5FNa42HboWqqXT6sJpkOcTb6Z8U3GeBoms7/lu1jxo95PHF5Xy7pn4FP1fjth5v5aktxyLKzaItCZkIUc6cNb7Nqb62JYV2TWf3Y+ceDov3HJMuP4vVrXNQnnaljutIvK56tRVXcNGs1e0pqeG1JPrr+cxBW175jd0kteaW1GFWkR2LKjA6r95WLYCoiXOWGWaljPHGOlVH/c/DgsKZ/QLVlZciahioFv0xE4l8gSVDjidC8VyAQCE4Biixx3hkdOO+MDif0+d4d45h3z0gqnv1NUCAVVp+q6oFDGyB/KeSMOeF56LrOc+uf4+3tb+NWjQM6VVdRVZUd5Tv47eLfcl2v63h48MNiVVogOMXous6a/ApeWZzHkt1HsCryMasgVF0nzmZm2pgcJg/pFCR2U1br4dcfbKLS6WP+vaOOq8GZFZl/X3Mmkwdn8cqiPFbtKwc43vNuViRMskxWYhR3j83lkv4Z2MwnZiorCMZmVrhiYBZXDMxi3xEHxVVu3H6VOJuZbmn2ev8f+3SM54LeaTz77Z6Q522stTeSpIbHr1HlbFvv4S0fTIWgb5rCxB4mnl7q5YzUxus8o60ymt94f0T+BTrYhS+LQCA4TTGV7SLVtc9w37E+1UfHNCFN63UG7CN+QTD14sYXeXtH44FUQ9yqm3d2vIMiKTw4+METvq5AIIiMKpeP215fzY7iGlxeFZ1A5qIuTq/KM1/v4p9f7+T5awdy4VGxm+V7jvCr9zdxxaBMfnVBj3rZDwjIa4/qlsKobikcqnLxzbbDHKnx4FN1EmPMjOiaQr+s+FM11XZLTkoMOU1IzX+68SCfbCz6RdeIJKkhy2AWan4REpUEUtP/sI+PszHolVp+PaLxB7wSm0iKZKW0NtijKhL/Ao+q0TEh6sTnIxAIBK2ZlS+B6jfcFV6fqh6wsag+BHEZEV9+TfEa3tz6ZlAgVbGkgiMLj+At8aLYFOIGx9FhcgeUmMDzwa26mbN9DsM6DmN4xvCIrysQCCKjyunj0heWUlzlwhvCnuFYqd4D727gicv6cKDCxXtrDvCva85kTPfQKnAZ8VHcPCK7OYYtaEY8fpVH520JMtuFppX5GhJJUsNqUkhqY+WcLR9MdRkJWtPpvG5JMlP6mHl+tZd+aQbRqmKB3ldws6kLL3y/x1AaPVxTsmE5SW3SMEwgEAjCIu8H0I2DqbD7VBULHFgFfSZFfPmZm2cGBVJHvjxC6ZelZE3Nwt7bjq/CR9FbReQ/k0/OYznIR1cp3aqbmZtnimBKIDjJaJrOLa+vCiuQqovbp/G7j3+iX2Y8nz8whtRY8T7VlvlqS7GhoEQou6GGRJLU8Gs6559gKXtL0fLBlC0O+lwJm99rsnfqj2OtvLW5kaBLkmDoNK4zpfHC943XdIbyL4iJwJRMIBAI2iSe6iZ3h9Wnqqvgroz40sWOYtYdXldvm+pSKZlfQuYdmcT2jwXAkmqh0z2d2PXILqqWV5F4duLx4zeWbqSotoiO9rblQyIQnGrUKg/+cjeaR0W2KZiSolDiwlvxX7y7lN2Haw0DqVAZCV0PBGMikGr7vLJob5DiXyTKfHUJ15T5/DPSRGbqhBhxL2ydV0+iN/+h2HqHdIqXcf/eSNlDgk7DIaETKcCFfTrw9dbDhtmpppAlSLZbGZmbfAITEAgEgjaC3PRtP7w+VSmQnYqQT/Z8ErTNuduJ5tOIG1z//q7YFGL7x1K7tbZeMKXrOvN3z+eegfdEfH2B4HRH13TcuyqoWXQA74EapDq9J7pfw5odT+zYLKy5CUhy42IuRi/REH5GYk9JLbsO19CjQ2zQOQRtg1qPn12Ha4K2R6TM14BQSQ2rSWHamLZnytw6OrzS+0L/a8AcHflnLTEw4Z/H//r0lf3JSozCrISv+CQREJ146w7hXyAQCE5zokMvGD0+zsbM9V4OVjdS3iPLEBO6D6Ih+dX5eLX6RuxqrYrJbkIyuGeb4k34a+uXJPo0H/nV+RFfWyA43fEfcVH8jzWUz92Bd181+HV0t3r8D34dz55Kyt7aRvG/1uKvMBaAKaxwsn5/RdD2YxmJpAumE91zJLLFhqSYiO42rF6ZFoBP05i11FjoRtA2qHR6g0RDIDJlvkiIMitcMySLgZ0TQx/cymgdwRTAxP9AztjIAipzDNzwAaT2PL4pxmriw7tH0j0tNiyXZ7MMiTEWPpw+ki7JjauZCAQCwWnBwBvB3LTITt0+VUN0HbIjV/Nz+BxB2xS7gr/Wj25QTuSv8mOyB2fSjM4jELRnfMUODr+wAbXKg96IEesxdK+GWu7m8PMb8B0JNu3esL/ScEE6koyEqgXU/ASnH3WV+ZqLKLPCRX3T+dOlwf1WbYHWE0zJClw7F8/AW/FLFlwYl5BoOjixUW3NwHnzlwEBiwYkxlj4+J6R/ObCnqTFWomxBEfPMRaFGFnjKvc+Fj50tkhFCwSC9sHAGwPBUAj+ONaKwxt8nC6bYdDNYG5CoKIR4i3BMsfR3aKRTBLV6+r3cqlulZrNNcT0Dl7kirO2HTNHgeBko9Z6KX11cyD7FK5WhA6620/pK5vRGnj6VLt9+A0MgyLNSNR6jIVuBG2DhGgLPjW4ZaauMl+4WBupFouxKMTaTDx8QXf+fc2ZyE2UnrZmWkfP1FEKq9xM2XwBPt8gLte/Z6rpCxKoxYcJGR0TfhZr/XnVP5HN/t5kvFfBe3e6SYsLfqjbzApTx3Tl9lE5LM8rY96GQg5Xe1A1naQYMxf0Tmd890QKL76I2APnQu/eLTBjgUAgOMVEJ0GvibBtPmg/v+yE26fq0eAL8yVM0vSIH3x9UvqwsGAhrjr9sUq0QtqkNIrmFCHb5HpqfuYkMwkjE+qdw6bY6JvSN6LrCgSnMzWLD6J5grMEqws389QPM9h1JB9Zlume3IU/nXc/AzLOCBygg+byUbvqEHHndD7+ObMsIxP8247EKwjAZFAiJmg72K0meqXHsqWo/kJXJMp8AN072LlyYCazVxRQVuvFq2pYFJme6bHcPTaXC/t0MCwnbEu0mmCqtMbDFS8up9zhRdWjmclEZqqXkEgNcZITr26mAjtujqrD+HUOlDuZ9OIyvnzwbOKjzYbnlWWJ0d1TGN09xXB/8rRplL7wIp1eevFkTU0gEAhaFxc9DflLoLaE8JeyAXM0tYMf4O1dMu/sXsHTV/UnN9Ue9scndp3IM2ufCdqeOiEVJUah+L1ivCVe5CiZuEFxdLqrE3KDcm0dnUtzLw1/zALBaYzu13CsOhRw3K5DjcfBbR/+H09e+Csu7XUOXtXP6sJNWBsKx/h1apcWETu2E5rTgWvDRszLt4EnFeT6x0biFQSQ3MYU2QTB3D0ul999uDlIjCRcu6EYi8J953Tj8gGZTB/XDYAH3lnPub06MGlg5imbx8mm1QRT02avpcLpRa1XfiJRuG19oxKcfk2ntNbDPXPX8fbUE/MdSZhyDWWzZuHasJYofTtsfBscpaBrYEuAXpfAoFvAHnmztUAgELRK7Klw25fwv4vAVV4vQ9Uo5ig4ayopF/wfH+jw1soCJs9YztQxXbnz7K5NrixWu318vK6Q99cWoprPhOi1INUvH0kam0TS2KQmhyBLMhd2uZA4iyjzEwgAXFuPGJbt7i0/AMCk3ucDECUrjM0ZangOzeGi4JZHcG/9gajevRkyeAiYLdAg2RVJRiLKrHD90M4I2jYX9k7n/8k/Ge4LpcwHIEkSF/VNr/85m/m0KwFtFcHUtqJqdhZXB9XohiPB6VN11uZXUFDmOCEBCVl10GlSLJZ5F4LVBg0bm0u2weJ/QLcL4PzHIaXbCc9TIBAIWg3JuTB9GXxyL+xbFEhQqZ7g4yx2MNnggscD/VYErCRuGZnNeWek8ei8LSzYfIh/Tu5P38z6PVFltR6e+mI7CzYfQpYkXD4V2TKa6JwNSFJk9hUAFtnCHf3uOJHZCgSnJe4dFeje4N9S16ROyJLMw58/yWW9zmNgZh8SbI31hpuIGXcVnWc9hWwJZJOuW7CV2SsK8DXIeIWbkdDRuWLQ6ZN5aK9YTDL/uKo/D7+/Ebcvsnu2zSzztyv7YTUdLQlV/bD3R8ZWriC+xgtKF0jpCdmjA36xbZhWEUzNWro3yBguElMwTdd5fVk+f74sQhWQigJ4/WKsjhIkWQ0OpAD8R6VDd34Be3+E69+H7FGRXUcgEAhaI/a0gCJqdRGs+R9smA2uCtB8YIqCjAEw+iHodn5AJKgBWYnRvHnbWXy8/iC3vr6ayYM78dD53bGZFfKPOLjmlRWUO7z1Fso0bxrug9diy3wXSW7EiN0Aq2Llb2P+Rm6CMFYXtB9q3D5Kajy4vCqxNhPp8bafX04B1WH8G4q1xvDxDS/w0qq5/Parf1LqKOec3GH846LfkhoTnAGWoxKOB1IAt47MYc7K/RiVAYfKSJgViUv7dyTWZtx+IWhbXNwvg8PVbp7+akfYAZUsQfdUO5f0ywiUk699HVa9DKqPc/0eZE2F/eaA76E1DkbeDwOuh6iE0CdvhUh6E6pOQ4YM0deuXXtSB+D2qQx4/GvcDUx2XXvXUfLh43T+zbywGh2jLQo//Xk8SrgN0Y4jMGMUOEoCJX3hYo4OlMd0HBD+ZwSCVoQkSet0XR/S0uP4pZyK+1O7RdcjXiksrfHw58+2sq2omt9d1JPH5m2h3OltVDhQsW8nKnMuoAUWsxrBLJsxySb+PubvnNO56ZISQdvndLg//dJ7k67rrC2o4NXFe1m0sxSzIiFJEtrRH9PkwVncNiqHnJQYjryxFfeO8pDn3FNWwAML/kpOYhYvXvanoP3RA9NImtKz3rbXluzlX1/vwuULXwJbkSTS46188UDjveyCtsmXPx3idx9tRtV0Q0NnCPRIybLEny/twzur93NF7HauL/gDkq7+nJwwwhwNihlunAdZg0/SDH4ZTd2bWjwzVVrjMVSEilSC06/q1Lh9JESH2fD48TRwltULpLKfrcHpg30P2omxBMb02novczb7+PHWoyWEPie8PRl+tQOUFv/nEwgEgubnBEouUmOtvHj9IBZuLeb+uevxqbqhtIVj2491+mCtWLPiSLsimpjuMSB7kSSQkIgyBbywJveYzPVnXE+mXZQMCU5/9pc5ue2N1RyqcuPyqeg6NHxvfWfVft5bc4BRuSk8HRsPEiF1ZLold+GavhcxZ+OnBns1lD3vwAufQ1RioFd84E1MHdOVyionMxfvwSOHDozMikSq3cr7d48UgdRpyMX9Mji/dwe+2XaYGT/msf1Q9fFeWZ+q0TM9lunjcrmwdzoWk8xF5g2YPv4/JBrxK6yLzwk+4M2JcPOn0OmskzuZZqbFowGnV0U2eHBHKsGpyBK1Hn94wVTlAShYFihlaYCqw3OrvDw6xtr4530u2L0wcMMRCAQCwXF6Z8SBJGEUSjXWB1v+/U8o9qHIllIsFjfn98rhwh59ObfzuViVJu7FAsFpxM7iGq5+eTm1Hj8GNk/H8Wk6aDrL8o7wx1gnvzeZoEH51Z6yAr7LW8Flvc4lIy6NourDfLL9OwZ1DG6HkPAS5ZkHR/IDG4o3ww9PofecwI2rzKRaejPD3o8at88wI2Ezy+g6nN+7A09O6hv+oragzWFWZCb0y2BCvwycXj8VRz3KEqLMxFjrhBQl24n5bBqEE0jVxeeEOVfCfWsgNj308a2EFg+mYqwKqsFdI1IJTr+mEWsNcyVk9auNmlY+MtLCP5Z5uOcsCwm2RlZnvbWw9D8imBIIBIIGvLkiH6Py8VB9sP6awN+9wC4pjv+MH3NqBiwQtAJKqt1c++oKqt3hq5x5/BrfVzm5S7LTUG84xhLNxqLtzFzzPtWeWuKsds7PHcFj59wTdB5FKsIi5/+8wXfUB27rPNKiZG669RNu7jqaFXllvLwoj02FVbi8KiZFIinGwo3DOjPlrM4kCin0dkW0xUS0pZEwYtE/wV9f0Cis6i8IfG7VK3B+cDlqa6XFg6nUWONVx0hNwaymgItyWKyfDapxtDyko8K4bBPPLPfw13ODzYCPc2gzVB+CuIzwrikQCATtgA/WFgYpgAF4Du5A93uJ7jEi5DnySmoprHCSlRh9MoYoELQ6nv12NzWNBFL1S2PrW8R4NY05Wg33S9GYpJ+reDJiU5kx6fGQ15VwEWf6sJF9OpKiwrtT4I6vGdmtLyO7GXt2CgTHcZbDzs8N9QjCqv5SPbBmFpzzaKCPqg3Q4pbDVpPClYMyMRn0TcUNvZLEc++gasV7FP73Bgpn3ErN+gVEda//MLYoMjcN72LYe+X1a+w6XMOa/HI2HqjkYHktuKuaHNMT51j572ovpY4mhClMVqgpCm+SAoFA0A7QNJ1qt7G6WCR9sGaTzOFqA5l2geA0xOn1M2/DwSB7GAiUxpZ/N5P44deQdd8cMqe/TuygCbh2rzp+zHxZZ59VASXSXkcPI16+lux/f47D+/O1X1vvZdwbddSNfQ6YcxVo4QtRCNoxG99utO/2kZEWnlnuodIdoslPV2HHgpMwuJNDi2emAG4fncNH6woNbyThmIIhwU0jutTbVFjhZPbyAuau3o+OfrwvS/E7WWuSMDXRrdk3TWFiDxNPL/VyRmoT8eaxVLhA0FxU5AfKULd+Ap5qQA/Ihp4xEYbdDUldW3qEAkGj+DSt0V74SPtgPREoiAkEbZn5Gw4avnuGaxGjAr9Wa5mXnIpS7gF/OArFbizSHmSqwssWeGth9zfQ86LwJyZonxSsaPT9OOzqL28tHFwPfa44SYNsXlo8MwWQm2pnUJdELBGvqoBF8zNCKyP9aH+TX9V45MNNnPevRby+fB+1Hj8Oj0qN20+N20+l34zUhBz8MR4fZ2Pmei8Hqxs79uhLrkDQHJTuhNcnwItDYfVMqC4MBFOeGqg+GPAAemkEzBoPh7e29GgFAkOsJgWpkRXJun2wodB1iItqG+UdAsEv5fOfDuE0EHaIpDS2Bp1vByUSMzgNTDKS2fj1TsKFhBu78hWplscAPbxsgbcWlj0b7pQE7RlXRZO7w6r+goCFURuhVQRTADNuHExanM2w3K8xLIpEp9RY/uDeRP41U6jdk8fN/1vNgk1FePyaYd0+SOTroRVCuiXJTOlj5vnVjSiRqD5Iygl7rAJBo+Qvg5nnQsHyQOOlUT+f5gt4NBxYCa9dEDCQFghaId3T7Ibb6/bBOnetQPO50VU/rry1VPzwv3rH+lWN7JQYw/MIBKcbFQ7j94xISmN9qk6p00viFd3p+PthxF2UjZJorVP650PhIPGmWWRYbyTB/BqSFAjg6mYLmuTgeqgpjmRqgvaIOarJ3XWrv5rE0naeAa0mmIqPMjP/3lF0TY0hyhz6xhFlVuiVEcfH944h91//IPGmG7n/yQ9Yv+8IrhAOzS+rl+LQQ8vt/nGstV4d8XFkE/S9CqyxIc8hEDRJ8RZ4++rAql8oo5Bj+BzwznVQtOGkDk0gOBHuHptLjMX4Hh5OH6wiwWUDOmK3tooqdIHgpNNYNrduaWw4vLx4L4P+8g3/WZyHs18yGb8bStaTo8n86wiybFeSYbsLu+krZCnYPDXsXvFq0SsuCEFSDkhNv8eHrP5SrJCY3fxjO0m0qqdVit3Kp/eN5qN1hcxYlEe5w4vLqx5/xZQAm1mhQ7yVe8Z2Y9LATCymQDxYPHo8KzYuxW2QjWqohDMjLZvxY32cX7/NivyH6gdHneJl3L83KOWTzTA8WF5UIIgIXYf3bggER3UI2zz6nevh4a0gt5o1EYGAi/ul89j8LY3uD9UHazbJTB0jegMF7Ydku7GkeKQWMQDlDi+vLt7LK4v3MrZHKs9cfSbxsgdkBbTGZddFr7ig2Rh0S0CEoonvSt3qr35pjXzf+k0+SQNsflpVMAWBYOmG4V24flhn1hZU8N32w5RUe5AkSIuzcWHvDgzolBC0kjNr6T58BgFuYyaRT23/mJGdS4iWIlSMUqzQaSik9/0FsxQIgP0rwFFquCushmBPFexbBLkhBFoEglOI1aTwwLndePbb3bgiFJGwmmSGd02mRweR9Re0H64YmMm6/IogQ9xILWKO4TkqQPHjzhImPLeEj6ePoIMe+rf4+Dgbg16p5dcjGnvu6GATveKCEGT0h8QcKNnW5GF/HGvlrc1G6q8SdB0nTHubA0mSOCs7ibOyk0IeW+P28dmmoiDz36aUcPZ1O4tF2nOMlTcRLYXp0KxYIKETXPt2RHMRCAxZ9jx4nYa7wjOPdsCy50QwJWh13Hl2V3YdruGLn4ojCqhS7VZeumHQSRyZQND6uLhvBr9vJJsbN/RK5JhEqla8x5EFzyBZorB26EbciCkhz+tTdYqr3Ux5dSUL4ntir9zR5PEhswWqr02VXglakDG/hk/vD1TRHCXs6i9zFIx+6GSPsFlptcFUJGw8UIlFkY+vxhyjKSUcHZl7fQ/ylOk1LlNWYMWLIjXes6L5ZaT0Pki3zBe9UoJfjt8De76lsT6psOVD85eCpxasxk3/AkFLIEkS/5x8JgnRFt5eWYBP04MWu+oSbVHITo6muMpN/hEnvTuK1W9B+8FmVrhuaGfeXJ5vKJwVqjS2KVNfVdM5VOXmvxnT+Z3+a9zwbQ4AACAASURBVGSp8VI/aCJbICkBmWrx/iMIh75Xwa6vAl5RkZSGmqMDNjBdRoY+thVxWgRTVS4fmsFLaSglHA2Z//PfyfvqOKaZPucceSM2iyXgvqzrYLIE/pven8oD6XhLupMelXCypyNoDzjLA0ImmrHBKQQagkf9z8GDw4zr6YGAO7izTARTglaHLEv8YWJvJg/OYtbSfXy2qQiTIqGqOjpgkiV8ms5Z2YncPTaXUbkpfLHlEHe8uYaPpo+kY0LTilACwenEfed04/PNhzhc7aaJdYcgGmtlcO1ehS2rDxAo+5uzL5ZfxZiwUj+YCjtbYLLAiHsjnpegnSJJMGkGfOSH3QvrZagaxRwNg2+F8/540ofX3JwWwZQsSUgEl0KFaxK5Xu/BdF8PMi0Olk3wBl5ONT/YEiBnLKR0I66igr0TLyXh6quxnXHGyZyOoD2gehp1CD9GWA3BkmwspS4QtBLOyIjjmavP5E+X9mbxriOUOzz4NZ2EaDNDc5LJrBM0TezfkUOVbm59fTUf3D2SeOE1JWgnJERbeP+uEVzx0jIqnT78YURU4Zr6BjbG8E2Px5iY/7fwXmzrolghcwik94vsc4L2jWKGq1+HNbNgyb8Cfd7e+oJbGhKyOSrQH3XOY21KdKIup0UwlRTTPEo4flsSDDzfcJ8pMZHUBx6g+C9/pcuct3D7dRZsLmJLURWVTh92q4muKTFMGphJsj207Lrg9EbTdJblHWHhlmJKawMiJ2mxNi7ul86IrslItvgms1LHCNkQrPnAFt+cQxcITgqxNjOX9M8IedzUMTkcrHRx11trefP2oVhNoa0yBILTgU5J0Xz54Nnc+/Y6NhVWoWp6k0FVJKa+Dq/KB55hTMweA/mLwy+9UiwQlwnXzg13GgLBz0gSDJ0KZ90REMxaOQPK8sDnwKfE8H1lGhfe/DhS1pCQC8ytmdMimBrUOdFweyRKOBZFYtKAzCavkzD5KrbO+4oXn13AggoLkkQ913KrSebvC3dyTs9U7h6by8BGxiU4fXF6/by9cj8zl+zF4fEHqTN9tL6QuCgzd43J4broDGw1+5s8X8iGYGs8RKc05xQEghZFkgLlgfe8vY5HPtjMs1MGIEdg5i4QtGVSY628f/dI9pbW8vqyfD5cV4jHrxqW/kVi6gtwpNYLt7wNn9wL2z87GlA1kQGzxEByd7h5vlDxE/wypKMKfV3HHd9k0nV+/9R39I7pQ6c2HEhBKzLt/SVYTDI3Du+MRQmeTjgmkQC618v4VfPxHjjQ6HV+3FPGtNzJfFwMLp9aL5CCQF2y16/x9bbDXD9zFa8sykPXIyh+FrRpDle7ueT5pfzrm52U1HiCAikIBN/FVW7+vnAnl7uf4IjSIeR5GzWPNkUFatiFz5TgNEORJZ67diCFFU7+sXBnSw9HIDjldE2185dJfdn6+HhuH5VteEykpr6qpgdKr654Ba5/H7pfECjhM0cHSsaRwGQL/MkYAJe/CFO/hSixMCxofiRJYlDnBDYcqGzpofxiTovMFMBNI7J5fVm+4b5QSjiyBENyksnyRJF/9TXEjBxJ8rSp9XqjftxZwvQ56wKmwCFWgXQ9EGw9++1ufJrGfed0P6E5CdoOFQ4vl7+4jCM1nrBq3d0+jTzVwpX6/2OB5f8RJ/1cchF2QzA6DLr5lw5dIGiV2MwKr91yFpNnLCczMYqbhv/ssr6npJbZK/LZcrCKWo+faIuJbml2bh7Rhf5ZQiRIcPogyxLp8VGYFSlI6S/SVoaEYz2IkgQ5YwJ/aoph55c/94pHJQZ6xdN6nYzpCAT1GNg5kfUFFVx2ZseWHsov4rQJpjITorj/vG68+H1exEaRMVYTT08ZRFryGJKnTaXyvfc5cNfdWLt3J3naNCp79uOet9fj9mlBn21KktTlU3nx+zzOzEpgTPfU5pqqoBVy11vrKKsNDqSa+n74NThIIvf5HmC25e+RXdAcDUNug+jQPmwCQVslKcbC67edxdUvryA9zoZZkfj3N7vYWVwT1E+yubCSzzcfomOCjQfO685lZ3YMMncXCNoiI3KTUeTgYCqSVoYos8IFvQ0qIWLTA88SgaAFGNQ5kSe/2N7Sw/jFnDbBFMC947pRVuvl3dUHwgqoJAliLCbm3DGMLskxACh2O8l33E7iTTdS/emnFD/+OK92HosvqW/Q58ORJHX5VP7zzS4RTJ3G7D5cw+bCyqAHXTjfDxUzq/Q+HFA600ltun/qOObowMrhBX9t7qkIBK2OLskxvHrTYK6buRJNJ8hP8Bja0YqAvFIH//fRT6zIK+PJK/qhiH4rQRunT8d4OidFs+twbdC+cE19NV1n8pBOp2rIAkFY9MuMZ1dxDW6fis3cdsWGTqtgSpIk/nRpH3JSYvjnwp1omm7Yt6JIYDbJdEu18/x1A+maGuzRI1ssJEyejO2yy1nw+Nc0jM0ikSTdWlTNviMOclJimmeiAkPcPpVKpw9Zgrgo8yn7Yc5aug+fVv8FL5LvhyabeSPhXv5Q9SdQ/aA3YqoomQL17v2vgUv+LXqlBO2GVfvKUTXwqsaBVENcPpVPNhahyBJPXiHknAVtn3vGdePReT8F9WpDeK0Ml/TLEFYDglZHlEWhW5qdrUVVDO7SdittTqtg6hg3j8jm2rM6s3BrMS8vymP7oWpkSULTdSwmmUv7d2TqmK70TA/t5P3tjlIwKaA2EJuIQJJU03Rmr8jnT5f2OdEpCRohIPhRzIwf89hRXINZCaxC+1SdAZ0SmD42l3N6pZ201WmPX2X+xoM0fMeL5PvhU3XeLc3m0XsXoax+GTa/C7IZ9KMnlaRAkNX/ahh+D6QJnzNB+2FzYSXPfrvLMJAKVWb98fqDjOmewkV9Q0uyCwStmQn9MpixKI+9JbX4InH1JdDK8PAFPU7SyASCX8bAzgmsL6gUwVRrxGKSufTMjlx6Zkd0XcfpVTErMhZTZKv5ew7X4vQErwRFIknq03S2FVVHdF1BaD7deJDH5m9B03UcR/8fqXUeMusKKnjw3Q1YzQr/uuZMzumZ1uxjOFLrNTSMjlSy1qfpVNm7knTpszD+SShYAa7ywM6oROg8HKyhg3+B4HTj5R/zDEv7wi2zfuH7PSKYErR5LCaZuVOHcdkLyyipcQeVlTdGtEXhzduH0ikp+iSPUCA4MVLtVuauLgi0S2g6KTEWRndP5fwz0jAZqHS3Rk7bYKoukiQRYz2xqVa5fIYuDHUlScN5Ya5xN1K6JTghXl6Ux7Pf7jIUBamLw6vi8KpMn7OOv1zel6ubuWbc6fEbVttF+v0wyRIOjz9gQG2Jge7G5tECQXui3OHlux0lQR47kZTR7impZdfhGnp0EIsRgrZNst3KFw+M4fY317CtqLpR/ymAGItCjNXE7DuG0itdeEQJWheqpvPppoO89GMe+8ucePwa+444j++ft/EgZlnm1pHZ3DIym8QYSwuONjTtIpj6JcRFGf8TRSpJareJf+pjuLwqu0tqqHL5MCsyKXYruakxYStvfby+MKxAqi5un8YfPtlCaqyVcSEyVJqmI0mENZ5oqwnNYBiRfj/8mn7CAb9AcLry2aYijH6GEZXRahpzV+3nz5eJMmtB2yc+2sxH00eyubCS15bsY+HWYsyKfPx34vVr9MuM5+6TXOIuEJwoLq/KtNlrWb+/wrAHEDhabaQyY1Eec1YV8N5dI8g10Dc4hk/18d3+71hetJwKdwUmxUR6dDoTu06kT8rJv/eLt7cQdEmOIdqiBP0Pj0SSVJGge1rjX4L2Ql1Hd0WWjt/8/apOaqyVu8Z2ZdKAzCaDCpdX5ffzt0QsUw+BgOrX729i9WPn13vAqJrOjztLeHlRHpsLq/D6NSQJYm1mrhyUyW0jc+icbFwikWK3oBvkLiP5fgCYZUk0BwsEDcg/4jD8rUdSRqtqsPdIsApaODi9fuZvOMiS3UeocHixmGQ6JkRx9ZAsBnVOFNLrghajf1YCz183kCqXj/1lTmrcPqIsCh0TougQZ2vp4QkEhvhUjZtmreKng1WNKrPWxePX8Dq8XPHiMj5/YExQuWqZq4zZ22bz/s730XQNp//n7JaExIe7PiQjJoPb+93OpV0vRQmz9SJSRDAVgov7ZvDYvC2G+8KVJDWbZG4a0cXwHO0Br1/jkQ828dXW4iBvmGPsL3fy5Ofb+euC7Tx77QDG90k3PNdnm4sMt4fTPwEBxb/Fu0o5p1cgO/XRukKe/GI7Hr96vO8KAsbLVS4fc1YWMHfVfs7slMC/rzmTrMT6P2SrSWHSgEw+WHuAhiXsYX8/FIlrh3YWK4gCQQNqPcbl0ZGW0Rr1vTZFYYWTl37IY96Gg0gS9RbTJAk+3VREit3K3WO7MuUs8dsVtBzxUWb6ZcW39DAEgrB48vPtbCkyDqSaWhCv9fi5fuZKFj1yDvLR++2uil3csfAOnD4nXs0bdD4dHbfqZl/1Pp5a9RRf7P2C5859jihTVLPPSwRTIYiyKFw1OJN3Vx8wDAJCSZICdE2xt9uaZa9f48bXVrH5YGXIVYhjLywPvruBxy/rw5SzOgcd8/KPeUFZwkj6JxxelZcX5XFOrzSeWbiDWUv34WqiXDDQ5KuzLr+CS55fynt3Da/3//JQlYtqty8okDpGON8PWZK4dWR2k8cIBO2RhGjjOvlIy2jjIsj6riuo4Nb/rcblUw3v+boeuFftL3fylwXb+XzzIWbeMoRoi3icCgQCQWM4PH7eXbPfsNog1IK4pgd6aBfvLmVczzTyq/K55ctbcPgchtVBDXH5XawvWc9d39zFrPGzMMvNWwnUNmQyWpg7RnfFpJzYyqMiSdx5dk4zj6jt8Ov3N7L5YGXE/U1/+nQrS3aX1tteUuOmsNIVdHwk/RMAa/LLmbl4L7OW5jcZSNVF1XWqXD6ufXUlRZUuyh1e/rpgGxc9u4ROSdEM7pxwXJY9EiwmmRG5yUJpSSAwoG9mHDGW4MxT3TJa564VaD43uurHlbeWih/+V+9Yq+6nV1Uh/tLSoPM0ZMvBKm6atYoaj98wkGqIy6eytqCCm2etxhemB5ZAIBC0R+ZvOIhsoH58bEE86YLpRPcciWyxISkmorsNq9cS4fCqvLJoL37Nz7SvpxkGUhVLKtj9+91svXMrOx7YQdGbRaiOwAK8R/WwvWw7/1n7n2afm1hKC4OclBj+cVV/fvvR5oiCgiizTKekaP77/R76ZibQrZ31Te0oruab7YdPuL/pD/O38MNvxh3vS6hw+LAoMt4GGa5IZchNssw/F+7Aa5BOCjWuGncg1Vzp8nFJvwy+fvhsOsTZqHB4mfD8EkprPGG9hEGgTyo9zsbz1w0M63iBoL1xUd90Hp33k+G+cMtoddnExQVLyLvkaaw5OdjPPZfY887Fkptbr+fJ7VO5adYqw4bopu4LHr/G1qIq/rlwJ49OOH094HyqxrfbDjNr6T4OVDhx+zSiLQo902OZNqYrI3OTRQ+ZQCBolNeW7sPpC76/RrIgvn5/BZ/sWES1tzookDry5RFKvywla2oW9t52fBU+it4qIv+ZfHIey0E2ybhVNx/s+oD7Bt5HtLn5FrFFMBUmlw3IxK/pPDbvJ9x+DT3E+3KUWeGGYZ15dMIZfLiukCmvrOCvk/pycb/243cya8k+Qy+McPubDld72HigkoGdEwEaTeVG2j+h6jpGSaRwxqVqOvvLnXxw94h6BnOJMRY+uXcUU15dyaEqV8ig22aW6ZIUw9xpw4izCeEJgcAIq0nhurM68+aKfMN7SagyWkmCs3ul0e+Wp9G9Xhxr1lD73ffsn3YnksVM7LnnEXveuUQNHMiXWw4FLdRAuH5WGnNWFvCrC3pgM5+cBueWQtV0Xvh+N7OW7kOt4+kHgb7SQ1VuVu8rJ9Zm4uHzezDlrE4iqBIIBEEcrAiuLILIFsQtJpk3Ni6oJzQROIdKyfwSMu/IJLZ/wAbDkmqh0z2d2PXILqqWV5F4duBdUpIkPt/7OVf3vPoXzuhnRDAVAVcOyqJHh1ie/343i3aWIknUe2k2yxKyLNGnYxz3n9v9uMjBNWd1oldGLNPnrGdjYSWPXNizzRiRnSg1bh+fbS6qZ6ILkfU3efwqMz5dz99Ty/Ds3EHt7gN40y8BpX7wEWn/hKppqA1SzZGMy6TI/LCjNMitOy3OxucPjGbuqv28ungvDo8fR93GdQI9ePFRZu46uyvXDu182r14CQTNzW2jc5i7ej8+NTIRCQCrSeaB87oDIFks2EeNwj5qFB3+8Hs827dT8933FD/5FP7iYv479iEcxNT7fCT3BYAFmw8xeXBWxONsrbh9Kne8sYb1+yuaLIl2elWcXpXHP9vGpsIqnpzU93iTuEAgEOi6jreRUuhIFsQ1TaOg8jANbtU4dzvRfBpxg+vrEyg2hdj+sdRurT0eTLn8Lt7a9pYIplqSvpnxvHrTEMpqPby75gCbCyupcvmxWxVyU+1cO7QzOSkxQZ/rn5XAZ/eP5oF3NnDz/1bz3+sGkmy3tsAMTg0/FVZhlmXc1P/xRJLO1XRYUVCJq2gTtp496HXBhaT9WEthlafecZHKkMu6jtZg5TSScXn9Gh+vL+Q343sG7Yu2mJg6piu3j8pheV4ZX28rpqTGgwSkxVq5qG8Gw7smiZVbgSBMMhOieO2WIdz+xpqIy6yfnNSP/lkJQfskScLWuze23r1Jvf8+tv2UR9Hc7TRMfkdyX3B6VV5bsve0CaZUTeeut9axtqAiLAljCPSQzd9wkGiLwh8m9j7JIxQIBG0FSZICbRoGAVUkC+K6rxaTXoOvwXa1VsVkNyEZlB2Z4k24CupnxYocxsrQJ4oIpk6QZLuVe8/pFtFnkmIsvHn7UP719U4ue2EZL90wiDM7BT/oWws1bh8Hyl3UevxEH/WvSArThbrK5TMsyou0v8ljiabjk08e//t0qYAnP98eVHcbbv+E1SQjIeP+hX1XVa6GP+X6yLLE6O4pjO6eEtb5BAJB44zMTeH1W4cy9c01+DW9yZd7syJhkmX+Mbk/l57ZMazzFyl2zBYz7gZS7JHeF4oMBHJailqPn3nrC3ljeT4lNR68fo0oi0LvjDjuGpvLmG4pTWaP3luzn9X7yiOWMHb5VOau2s8FvTswvGvyyZyiQCBoQ2Qk2CgocwZtj2RB3CdHY49xBwVTil3BX+tHV/WggMpf5cdkrx/ueNVgKfVfggimTjGKLPHbi3rRPyuB299Yw2/G9+S6ocES4Eboun5KMho/FVbx6uI8vt52+Lizuq6DV9UYnpPEXWNzQzYb1zXlrbc9wv4muUE15KSBmfzl822Gx4YjQx7wgwkO806k70ogEJw6RuQm88NvxvHWygJmryjAr2m4PX5UJGQp0KeqA9cM6cRto7LpkhxcIdAYDq8frRnuCw0XaVoCp9fPXxZsY96Gg8iSVE9Qw+PXWJ5XxqYDlURbTPxmfA9DCwpd15nxYx4ug2bxcHrI3D6VVxbliWBKIBAcZ+roHJ76YofhfSXcBfFu6SYqLbU0jKaiu0UjmSSq11UTP/Rn3zXVrVKzuYYOkzvUO96qNG9lmAimWoiL+qbTvYOdu95ax8b9lTx+eZ+g/hld11meV8Yri/aytqAcl09FkSRibSYuG5DJbSOzyTYoKTxRyh1e7nhjDTuKa/D4VTSdoFXJxbuPsK6ggmS7ldm3D230+sl2q6FIR6T9TbENxBlirCZ+P+EMnvxie9iy5sewmWVuG5XN7OUF0KD8MNJxxQhPGYHglJMWZ+PXF/bkgfO68/2OEjZ/9h3VXo0Oo4fTKSma8X3ST6gP0W41IRus/kR6X7CZW7YXttzhZcorK9hf7mwye+fwqji8Kn/+dBvbiqr582V96i2OrSuooMwRvHIbbg+ZDizLK6Ok2k1anK15JicQCNo0VwzK4skvtje6P9SCeIxFYeqYLvxtS3BlkBKtkDYpjaI5Rcg2uZ6anznJTMLI+lVgGfbmFYM7vVUQWjm5qXbm3zuKGo+Pa15ZwcE6JSILtxxixN++Z9rstSzeXYrTq6Lr4Nd0Kpw+5q4qYPyzi5k8YzkFZY5fPJaSGjcTnl/ClqIqXL5AINUYDq9KYYWTS19Yys7iGsNj+mfYUTR/0PZI/GHMisQVAzKDznHjiGxuG5VDVAQvTTazzG/H9+LOMbmGNbuRjEuRJEZ1E+V7AkFLYVZkxvdJ52a5iIc6erjv3O5cPiDzhAVdclPthj5RkdwXAHIiyIY1N26fynUzV7LviCOiHqf31xbyr6931dv+3poDhqvHkfSQScBnmw+FNQ6BQHD6Y7eauGZwpxNadJKlgIn7pP496Z1s3I+ZOiGVDld1oPi9YrZN30beX/IwJ5nJ+W0Ocp1rRpuiubXPrSc6DUPE8noLY7eaePH6QcxcspfLX1jGs1MGsKO4mme+3tlks3VAJlhn/f4KJv53KXOnDqdfVnyjxzeF26dy7asrORKBR5KmBzyXrn11BQsfOvv46qO/tJTKjz6m8v33uSL3HN5JGYBHr7/iG246V5Ykbh7ZxfD6v72oFx3jbTz5xQ4kCUNvGAisZCDB36/qz8T+gf6J8X3SWbC5KChgDHdcFpPMnWd3DevfSSAQnDy0mhrkzuGVSTdFdkoMPTvEsqmwKmhfuPeFGIvCnWfn/uKxnCj//X43+UcchvfwUD1Ory3dy8X90unTMfAMOVDuNKwsiKSHzOPXWlUPmUAgaHl+P7E3Px2sYtuh6rAXfSQCVUlzpw1DkSVu63MbO8t3BsmjAySNTSJpbFLwSeqg6RoX51x8IsNvFBFMtQIkSeLOs3PpmxnPnbPX4vZpEQc117+2ks/vH0Pn5MhNyD5aX8ihSnfED2EINDm/9OMeHkmrpeLd93AsX07c+PFk/vd57u3cjXf+/j0Y/GBCpXNlCQZ1TiQrsfH53DgimysHZ/HJxiJm/JhHcZUb89HGQ5+qkZMSw/Rx3bi4XzpW088P/2ljuvLNtmLDMsFw+q6yEqPom3ligatAIGg+1NoalLjYZjnX9HG5/PqDTfV8lI4Rzn1B1lQu7NOhyWNOFj5V460VBYYvJ+H0OPlUndeW7OM/UwYAweXdx4i0h6yxRS6BQNA+sZhk5kwdxh1vrmHTgSrDDHi94xWJWJuZd+4cfrwP9uyss4m3xuP2u9GIsN1DsTGl5xSiTFEnPAcjRDDViuiTEY9P1U8oqHF4/Pz6g418cPfIoM82ha7rvLzoxBuNfarOu0t2M2XXW3SYMpmMvzyBEht4uYkCHptwBn/7MvL+JrvVxD8m9w953J6SWrYfqqZbmp30OBvRFoUh2YlcMSiLzATjH0sPbzkDy/JYF9cZD5GVBdnMMo9f3if0gQKB4KRR4/ZRVutlv1MiyxyDXdN/sa/R+Wd0IM5mxuVtuszZiCgFJhcs5/BvviX90UcxpaY2frDPDVs/huX/hfJ9oLpBsUBsRxh2Nwy4HmxxjX/egK+3Hg7y9IPwe5xUTeeLnw5xz7hcCsqclDuNla4i7SFLCVP9VSAQtB9irCbm3DGMjzccPL4Q7var9bLhMVYFRZK4cXgXbh+dQ0odKyFFVnjtwteYsmAKDp8D3VA7OhirYqVvSl8eGvxQc09JBFOtiffXHsDofSCcoEbTYXNhFQVljohUrNbvr6Cs9sQbjQFkq5Wtf3iOM4Z0CjrPzSOzKXN4eXXx3pArEBDISNmtJuZOCzSUG6Fp+tEf4R6KKt3HxTKOsTq/nBd/yGPy4CzuHpdbL6iq/mohxU88wb8feJC7y5LYcagmbAUum1nmL5f3ZWSu6JcSCE41mqazeHfpcUEeiyKjx49B/bEW+8pvuW10Nted1fmE/ftMisw704Zz6QtLqXX7w3w8B+4Lo7un8thjj1E+42X2Xj6J1AcfJOHqyUh15Ug1FX54ClbNCPzdW6fX1e+Bin3w3ePw7Z/gzOvgoqfBHJ54wzurC+oZhB8jkh4nj19j0ovLGJydRMd4G0WVrqPl5D8TiYRxjEVhYJfWa/0hEAhaDpMic82QTlw9OItNhVX8ZcE2ymo9dE21k2K3cHaPVMb3ScesGPdXdY7rzJwJc7h94e3Uemvxak1LnUeZohicNpj/nPMfTHLzhz4imGolaJrOzCV7gzI4kQQ1mq7z+rJ8/nxZ+JmTJbuO4PqFD2GnX2fh1sNMNgimAB6+oAfZydH85fPteHyq4UPfJEsoskSfjnE8O2Vgo+WKbp/K9DnrWLWvvNESkmPb31m9n/kbDvLG7UMZ1NFOyTPPUPPd93R69VWi+vbhPb/KQ+9u5IedJfhUjUbMuYm2BLJXz107kAt6t0wZj0DQntl0oJJps9fi8PiP3z98qgqyBVTwOLy88N0env9uDzcN78JjE844oUxVdkoM8+4ZyZRXVlLr8Yes6Y+2KFzQuwPPXH0mJkUm7de/Im7iRIr/+EeqPvmEjCcex9qtG/i98O51ULAcfMF1/sc5tm/Tu3BwPdz6GdhClxQXV7kNt0fS42SSJe4/rzt3j82l0ull2FPfYWQjEW4Pmc2sMLZHWsjrCgSC9oskSQzolEDvjDhyU2O4dVRO2J/NTchl/uXzmbt9LnN3zMWv+ev1UcnIWBQLnWI7cXu/27k4+2KUMD0DI0UEU62E7cXV1HqC1e8iCWp8qs4nGw82GkzpmobucqG5XGhuN5rTSXFBUbOY61Y0UhZyjCsGZXHZgEwW7Srh5UV72XigEp9fg6OZqEkDMrltVDZdU+2NnsOvatzxxhrWFVSElU3yazo1Hj83zlzJcwc+5wybn5wPP0BJCKyWWk0KM24czJ6SGv63NJ95Gw5ikqVAtyPgV3VSY63cPbYrkwZmEi3k0AWCU86S3aXcOXtdyMz2sXvC3FX7yS9z8OpNQ476ykXG/2fvvuOrqu/Hj7/OOXdnB5KwRwgzLAHZS0TqrkXFgQMnaLXaob8Oq7X229pqrdaFdDaO5AAAIABJREFUCxcO1NZRUeuoLNkgIhtCWCFAFhk3d5/z++MCJuTc5N6YcYPv5+PBo+29557zuUl6znmfz+f9fudkJvHFLycxf+Ve5n21B18wVCuP6viDn0GdU5gzqRdn9s+sVVbc0bcP3V9/jbIFC9h79TWkXjaDjKy1KHu+gmCUBRmCHijaCvMvhlkfgaX+5XKBCOsSY8lx0g2D4LEnSqkuG9MGZLHw20LTJY8N5ZA5LCrXje/RqJ+/EOKH50BZNRP71LM8OoI0Rxo/Pe2nzB4ym0X7F7H84HJKvaXYVBsdEjpwbva59Evv1wwjrk3uDuNESZXf9MITa1BT6fax96qrw8GSx4PuqcaoDgdPhs+H4nCgOp2oDgeKy4m3y2RI7V9nP7E31234oqmpClP6ZTGlX3h2JxjSjzX3je6C++jnO1m/zzyQqrdaVVDnrs7TWHnfOWh2a53P5mQm8efpg7jn/P5sP1RJhTeIVVPITLLTKyOxRRolC/FDEUvz8S0HK5j9asOBVE2eQIjlu4r53bvf8uDFDeddmkl12bhtSm9umZzDou1HWLm7lBK3D7tFo1OqgwuHdKp3ObWiaaRfeSVJZ55J+V9nY5QsRtG+O2/1eLSS6gDk35FIgi38s3h+vZ/5GwMsmnVsvyE/HNoEa+fB6DkRj1XlCxLppxlLjpNVU0l2fnd+vGNqb77YdiTmIhIK4LJZmDnSvBKrEEKcrOCohy5pjS8KYVEtTO0+landpzbhqGI4fqscVdQR1HWzFRUxBzW6otLuttvQElyoTieKw4nqcob/u91eew0/0PvLXVg+21Gn6EWsicaZSbHnKVgirIU14wuGeGn5HtNCFtHklAUtNhZuOswlw7tEPIbLZuG0bmkxfw8hRGQhPcRXB79i3rfz2FyyGV/Ih6ZopNhTmN57OjP6zqBDQgfTz97z3rcRb+brL/et896GAq4f35M+WY2v9qepCmf2z+LM/o1b3mvNyqJ9bjXsr3veChnw2Co/v51Qz7kz6IHlj8Go2XAsAC0s97BmTxnr9pSydm8Zu4vcpDgtqAp1ZpFiyXFSFYVhNc5/OZlJPDVzGHPmr6u3TUdNCuCya7x+8yjSpPiEECIKhmFQUOah8/cIplqbBFNxIsVZd8YEYg9qHDaNxNGjGtzuuB/lduDxL3bWCaZiTTT+yWl1m+s2pU82HcIwaXwSbU6Z2x9i7qJd9QZTQoim9d7O93hk3SP4Qr5aa9mDRpASbwkvb36Zlze/zOkdTudP4/9Ee+d3xV3yi91sPlhhut/oKo3qvLAsn782cnaqSZTkQeE3pm/dNdbG377ycevpNlIdkWfqQtVH+Xzh2yys7M26vWV4AiGGd09jRPc0/ji0EwM7p1Dq9jP5oUWmOV7R5jh1Ta/b8mFy30zmzTqdm15Zi65T7wxhgk3DadN48+bR5GQ2Tbl6IcSp72h1AFVVSHaY3we3BRJMxYn+HZNNS6LHEtQowOk96m9WdrKczET6dEhi4/doVmm3akzue1KicdF2WPkUbPsI/JXh0dmTIfeicPnf9OiTDAGeX5r/vatVFRz1srWwgv4dYys7LISIjWEYPLr+UV7f+jrekHlxBOBEBaZVhau4+IOLeeWcV+ieHF4eNm9ZPvr3KvcN728o4PfnDyDR3kqXui3vhav4mRjRSWNyDwsPL/fxpymRq/YpwWra7foX48c/zJ1Te9OzfUKdZZIdU5yM6pnOkp3FpvtoKMfJZdO4ZbJ5w+Gxvdqz/Ndn8q91B3h2yW4qvQEAQoaBpijoRrj33pxJvThvcEcc1uZJ8BZCnJrCS/xi75EaTySYihMum4WfnNaZt9burxNURRvUOG0asydlx3zsWyaFm1WaLadp6CJsR+e6Md2+y/cqWA8f/hyKtkEoCEaNohqBaljzAqx7CTqdBhc8Bhl9oxrj/jLzClix5JRpqsLekmoJpoRoZi9uerHBQKqmoBGkzFvGtR9fyzsXvkN7Z3s+3XLItLBCLA9QrKrKur1lTGpEYnOTKC8APRDx7T+eYWfcPDd3jIq8JE4FRqR5GBGhWupxd5/djzV7VsSUXwbh82Jmkp1zBnaMuE2K08r143ty3bgerNtbxv6yaty+EEkOCzmZieR2kibmQojoeYNevj7yNUd9R9mwv4yEVB/uwDASrNG39oknEkzFkevH9+Df6w+YzlA1FNRA+II3JrtdzMf9UW4H3lyzn5W7SxosBVyTRYWO/iqmvvxnAoP+hrV8A7x9LQTqqVilB8L/9q2E56bAzLehe8ONhn0R1uzHklNmGAZuk4qJQoimU1BVwFPfPIUv5KvzXtnSMor/W4z/iB/NoZE8PJmsS7LQEjQMDMp95Ty4+kEenvSwaXVTiO0Bim4YHG2g0mizCtV/7IGZGuf3sfDgMj/9M+rJIW1gPwADO6fw2OVD+dmbX0ed42RRFVJdVt68eUxUM0qKojCiRzojYlwBIYQQAPsr9vPattf4985/oynh874/qGMAkxY8ybk9z+XqAVfTO613aw81JtFXABDNLicziR/ldsBhjf3X4rCq3H9hbqMqz6mqwjNXD2dQ55Soj22zqHROc/HW7y8iY9wYCm84B+PNq+sPpGoxwF8F8y+Bw5sb3DrSuGrmlDVEURQSHfL8QIjm9PrW19GNujfzxR8Xc+jtQ3SY0YEBTw0g+/fZ+Ev87Hl4D/qxhzhBI8ii/Ys46j0acf81H6A0qLULcSY23Gfp/skOnlvvp6CinjbBrugekk3L7cDz15yOy6bhbCA4SrBp9GifwEc/m0CHlOiaAwshRGMYhsE/1v2Dn3zwExZsW4An6KEqUIU74CZgeAjiwRfy8UHeB1y58Ep+v+z3BPW28/Bbgqk489ClQxjSJTWmgMphVbnrR/2YlmteESu6fWi8cfNoLhnWBbtFjXghtllUbBaVyX0z+PD28WSmOGl/8010HVuMotd+Et3j0UoyH6rE7f/uJuH59X4mv+T+bqOAGxZcDSbFJQxdx7N5MyXPP0+n8kOm46mZU1a9YwV6wIsRCuLJW0vZl/NqbRvUdXpltM0pZCHaAn/Iz792/ovASUvbQp4QR947QqerOpE0OAnFomDLsNH11q74i/2UL/8uZ1MxDP697H6ScJ+8eyC2BygqCmmuVqwq12MC2Oo/5+Skq1yWa+WfqyPMPtkSoPe0qA85vnd7Vv32TH59Tj86pzpx2TQS7RYSbBpJDgt2i8qY7HSenDmMT++cSGayBFJCiOZjGAb3fnUvr297HV/IR9CIHCSFjBDekJdP9nzCbV/cRiiah2ZxQB7TxxmbReXVG0Zx1zvf8N9NhwjqhumyPwCnVcMwDP580SCmN0GVOqum8qefDOKus/vxzroDPL90N4fKvVg0haBukOq0ctXo7lw1ujtZNS/AOz9FwfwPPqryv5WFcGAtdD0d/4EC3Mu/wr1iBdUrV6GlpZEwZgw3DB/J77fppkUoos0py26fKFWmhGhGaw6tMX29emc1ekAneXjtfEXNoZE0OImqzVWkTQyX5fbqft4tWMJ5nYbzyl4HJ69Yi6UoT1A3Yi7K06R6TgoX3vGbB4bH3TvJzqsbI+RWGQYMujSmwyY5rFw7tgfXjOnOtwXlHDzqxRsIkey00Ccrqc0newsh2o55m+bxyd5P8Aajy6EF8Ia8rDu8jr+s/gv3jL6nGUfXNCSYikM2i8pjl5/GriNVzFuWz7+/PoBVVUEPoXt94HSS6LBw04RsLh3RNWJZ9cZKcVq5YXxPbhjfE103cPuDOK1a5L5Qyx4NL9kzEU35XyPgwfv8HApWtEevriZhzBgSJ0wk6+67sXYMJ0VfENL5wwOfRRxzQzllCXaNORGqVQkhmkapt9S0hUGoKoQl0YKi1T0HWFIsePbWXh581O5i1gXTmf/IYtAbV+7boipcPLwzTlsrVpdTVRhzG/zvT+GeUcfsubP2Q52uKSree0wK46gWGHIF2BoX/CiKwuAuqQyWjhBCiFZQHahm7jdzTYsR1ZdDC+GA6t2d7zJnyJxabTPikQRTcSwnM5E/Tx/EPef3Z2thBcW793P06acY+OjfGNAxuVH5UbFSVYWk+mr/V5fCwfUR346m/K+CgUPJp8tj87H362f6vayaypzJvXj8i10xV6tSCFdLPPt7LIMUQjQsqAcxTLqPa4kawaogRsioE1AFy4NYEmtfikJ6iK7pLoZ2TWV1fqlZP/MGH6BoqsJ142JrwdAshl0TbhNRWQgmuWT1siXAhF80z7iEEKKZfbLnE9N7uuKPiyn6uIguN3YhcUAigbIAB189yJ6H99Dzdz1RLeGH94qi8Pb2t7ll6C0tPfSYSM5UG+CyWRjePZ0p/bMYdmgbuZ1SWiSQioq7CLT6cxL+eIadx1f7KXJHvpFQNAuO7ln1fq85E3sxoXd7nDEW6HDZNd64aRQ2i/y5C9Gckm3JqErd/5+5clwoFoWKdbWb8Ia8ISo3VpIwoHZe0fHyuP/3k4G4GjGz5LRqzBzVjV4ZiTF/tsk5kuG6j8CRCkoM38WWAFe/CykyrSSEaJte+PYFPMHaKw9iyaH1hXy8tvW1uC9GIXeXbYiamIjurn/tfYsL+aGBwK5m+d+IFBWCdUsp16SqCk/OHMa03A5R3WDZLSppLiv/umWs5EoJ0QKGZA4hEKqb+6O5NDIvyuTg/INUbqzECBr4i/zsf2o/1nQrqWNTv9tW0RjTKdxDKicziRevGxlTQOUwQkztn8k95w34/l+oqaT1gDlLIb0numE1nWk7wZYICZlww+fQeXgLDVAIIZqWL+TjQNWBOq9Hk0N78n4OVx9u1rF+XxJMtSGqy4Xu8WCY5BC0GkcqRPHEoMHyvyE/OFPN36vBqqk8etlQHr1sKMO6pWK3qFjU2sFcgl0jxWll9qRsPv/FJPp1kCa9QrSE9s72jOk0BsWkJnnGuRlkXZzFoQWH2HLLFvIeyMOabqXn3T1Ra8w2W1QL1wy45sT/HtkznX/fOpZu6S5cNi3isxunVcNuUbn0yHr+6NyHqsbJ7P1xKV0ITH+fgrUdMLqOA4s9XJzi+D+LIxw8XfQ0/GILZMVRMCiEEDGq9FdiVeumiTSUQxusqn1Pqakalf7KZhtnU5CcqTZEUVVUhwO92oOWGCclvpM7gdXZYH+pmuV/B2WaxPCp3cL7iYKiKEzL7cC03A7kF7t5/+sCCso9eAM6GYk2RvZMZ2r/rMgFM4QQzWZW7ixWH1pdZ2kHQPqkdNIn1V9dLzslm5y0nFqv9euQzOK7JrNubxnPLtnNl9uPoCoKqqIQCOlkJTuYPTGbnwzrjHVvL/bNmoVryBDs2XGQM1VDyQvzsJ8xC/WGX0HFQSjZBb7K8JK+1O6QHl/jFUKIxrKqVtOeg7Hm0GKATW3FFhdRkGCqjQkv9atq+mDq0Lew+jko2go+N9iToONgGHkztK+nE7WqwahbMJb8HcWkWktNEcv/WhNg/M8bNeye7RO486w+jfqsEKLpDc8azqD2g9hQtAF/qJ6lvSbsmp1fj/y16XuKojCiRzojeqTjD+qUewL4QzrJDguJdst3+ZZ9+5Bxx88o+OUv6fHmG6j2etoytKDA4cOUL/yIXh8tDL+Q3Cn8TwghTkFJtiTTgkQ1c2hTRqaceP14Dm3WJVm1tvfrftIdrdjiIgoSTLUxakJC0+ZNbfkAFj0IZbsh6AejRqW8grWw/hXIyoXJv4XeU+t83DAMKks6k+j31cmtjrr8LzoMuqQJvowQorUpisLjUx5n5kcz2V+xH59efy7kcQ7NwR/G/oFhWcMa3NZmUclIihwkpV52Ge7lKzjy0MN0uOd3J153+4JsLaygwhvAoqq0T7TTv2NSixT0KXnhBVIvughLu3bNfiwhhGhtqqJyZrcz+WzPZ+h8N0NVM4dWdai1qvmdnEMLkNsul1RHw2kgrUmCqTZGTUxErzLv6RQTXYdPfg1fvwqB6gjbBMP/CtbBW1fBuJ/DpLtPFJzwbt/B4QceIFTtxnnRxaj7P4q8r0isLph4V3iZixDilOCyunjt3Nf4xaJfsO7wOvy633S5B4DTEl7e+/Ckh5nYZWKTHF9RFDo+8EfyfzKdyjGjOTRwJC8sy+f9DQfDOZbHYqeQbpDstDJ7QjYXj+hCcn1tIBrgC4b4ZNMh3lyzn6JKH0FdJ9lhZUrfTC7rnUDF+x+Q/Z8PmuT7CSFEWzArdxZLDiyps+w749wMtASNQwsO4T/iR3WqJA9LpuvsrrVyaBMsCVw/8PqTdxt3JJhqY9TEhKYJpj79bf2B1MkCHvjqUdCshIbcSNHjT1CxcCEZt99G6owZ4Se7b14B+Uui36fVBbnTG73ETwgRv1xWF3PPmsvWkq28suUVPtv72YlkZAWFoBGknaMd1w+8nvOyz8NlbVxj2ki0lBTa/+0hbn36f6zp6COgG4RM4rlqf4i//Xc7f/3vNh66ZDAXDOkc03HcviD//GInr63ah2EYuP21++BtP1TJU18EGTn1Z9xrOJFFyUKIH4qB7QfSMaEj+eX5dZb8RZNDa9NsTfaQrTlJMNVGeAMhluwoYmtqf5QtFWRa9jGwcwoDO6c0/OGT7fgvrHu5VtDT49FKqgOQf0ciCbbwY9vn1/uZvzHAolnHZo0C1Rj/+wsFD87HOvwCshd+iCUt7bv9Xv46LPwFfPPmd7NaZjRbuBT6qDlw5r0NllYXQrRd/dv15y8T/sJvRv2GXWW7qPRXYtNsZDgz6JXaq9mW2PmCIa5f7WFzZl98wXqLkZ9oBH7XOxspqw5wzZgeUR3jSKWXy55ZycGjHnxB85m38OsqX4VSuOjJr5h71XAm9smI5asIIUSb9cjkR7hy4ZVUB2NbueTQHPxzyj/R1Nh7DbY0Cabi3P7Sal5evoc3Vu9DURT8tn7ou4JY924BoHOqgzmTczh/cEcc1ij/4JY8ZDp7FDLgsVV+fjuhnoRt3U/nCzuhzb6/7nuqBhc8BqNvhZVPwcYFoFrBOHYjoxD+78OuCRe2kMpVQvxgJNuSo8qHaip3vrmBLQcr8BnRV/X0BnT+/NFWuqa7OKNvZr3bVnoDXDp3BQfKqk1nvE5mEJ4Fm/3qWubfOIrh3eM7oVoIIZpCr9RePHPWM8z5fA7VgWrTohQnc2gO/j757wzNHNoCI/z+JJiKYwvW7OPe9zejGwaB0PGAJBwwBY89Sd1V5Obe9zfx90+389bsMXRNb2CpTPGucOU+E3eNtfG3r3zcerqNVIf502JFAa1oNVQVQWKEp6sZfcNB1bT/gwNrwFMWnolypUOXkWB1NPzlhRCikbYWVvDl9iN4TWaL3FsWUbHmPQIlB1BtTqyZ2aSMnYGjSy4QDqjufW8TS+4+o95Zs9+++y2F5V7TQKq+Y3gCOte9uIbVv5sa/QMwIYRow4ZmDuWN897ggRUP8E3xNxiGQUCvXd3ZoljQVI0+aX343ejfkdsut5VGGzsJpuLUi1/l87dPtkdcOlJTtT+ENxDi/MeX8eHt4+sPqNa/DHrI9K0RnTQm97Dw8HIff5pSX8CjhGedxt5W/8DsidDrjAbHL4QQTemFZfnfPYCqoWL1u5Sveod2036Ko+cwFM2CJ38dnp2rTgRTACVuP+v3lUWcPSpz+/l082H8JufnaI4R0g0++raQ6cO6NNE3FkKI+NYzpSfzzp5HQVUBb2x9gw/zPqbUW47TqpFgTWBy18lcNeAqslOyW3uoMZNgKg4t3VnEXz/ZhjcQxdqRY3QjvOxkxjMr+PJXkyM/8SzZCbpJr6dj/niGnXHz3Nwxqp4GaUEvlOZFPTYhhGgpld4A//nmICG9djCl+9wcXfYa7c69E1ffsSded+WMwpUzqta2nkCIZxbv5tlrzIOpBWv2maZ6RnsMtz/E04vzJJgSQvzgdE7szK9O/xVj0mbxxP928cas0a09pO9Ngqk49ODHkQOp+paP6AZUeAIs3FjIxcMjXKR99feoGpipcX4fCw8u89M/o55cA29ltF9HCCFazKaCCmyaWmdW31ewDSPox9VnTIP7MAxYubsk4vuvrNxreo6O5RgHSj3kFVXRKyOxwW2FEOJUU+L2k55Yz4P7NkSCqTiz/VAleUXmpc+jWT7i9oeYuzivVjBlBAL49+zBu30H9gNFNJSxdP9kB8OeqeKXY+opROGS5GkhRPwp9wRM05tDngpUVzJKlJWhqr0BSl54AcViAYsFxWpFsVhRrBaKy80fNMVyDIumUHjUK8GUEOIHqaTKR7sECaZEM5j3VT4Bk3X4sSxR2V9SxbInXyZ73xa823fgz8/H2rEj9j590Lp0wwjloRj+iGPISVe5LNfKP1f7GZRpctNgTYCOQxr/JYUQoplYNQWzshGaMxm9ugJDD0UV7GgYBItLMIJBjGAAIxCAQBAjGCRgmWza0iHWYxwvyS6EED80pW4/6RJMieawdk8pJnnTMS0fIRBkS5GHgSNHknbV1dhzeqE6neH33MXwjwEQoQXUcfdOsvPqxki5VQbk/qThcQghRAtrl2hHN5mbsnfuh2KxUr1jBQn9xje4n+QkJ1n/727T9xy//8Q0EIr1GEkOuQQLIX6YStx++ndMbu1hNAk5k8cZt8/8SWUsy0dCVivKlGmkTjSpiJLQHnr/CLZ9CMZ3M2B77kyqtVnXFBXvPSZ/5KoVhs4EWwMl2IUQohUM6pyCw6LVOZeq9gRSx8+k9LO5KKqGo+dpKKoF754NePdtJO2M609sa9MUpp/WOeIx+mQl8s2B8jqvx3IMX1AnJ1OW+AkhfphKq/yyzE80D5vFfC1+LMtHNEXBYa2neMTEu2DXZxDwxD5AzQZjfhr754QQogVoqsJ143vwxBe76vSZSh45HTUhjfIVCyj+8GEUmxN7Vg7JYy6rtZ2iKFwztkfEY8ye1Iu73v4Gt7/uw69ojqEAE3u3p31iPXmpQghxitlT7GbR9iMcrQ6w8cBROqQ4GNgphW7t2vYDegmm4kynVAf7SqvrvB7L8hGLppKVXE+ZiY6D4YLH4T+3xxZQWZxw2SuQ3jP6zwghRAu74vRuPP7FLtP3EnPPIDE3cv87TVEY0T2NzqnOiNucNSALTY3c0LehYzhtGjdP7BXxfSGEOFWEdIP/bTvC3MV5bCoIz+j7g+HF2K+v2ssbq/cxpGsqcyZlM7lPJmo959Z4Vc/0hWgN14zpQYK97sxTzeUj1TtWoAe8GKEgnry1lH05r872k/pm1H+gwZfCRXPB6gwv3auPZgNbAlz+GuRMjeXrCCFEi2uXaOf+C3Nx1jdDH0Giw8JfLxlc7zZWTeX2M3NwRurnV+9nFXplJHB6j7SYPyuEEG2J2xfkqudXccebX7Nubxm+oI4v+F1Wqz9k4AvqrM4v5bbXv2bWS2vwmMz4xzuZmYozZw3IQjPrBkl0y0esmsIVI7tit0Rxkc+9KDxLtWoufD0fUMDvBgxADedFKSqMuAFG3gQpkXMIhBAinlw+shul1X7++cXOqBqga4pCosPCGzeNpktaw0tObhyfzbbCSj769lDUVfksqkL7RDuvXD8KJcJ5XgghTgXeQIgZz6xg15GqOn3/zFT7Q6zaXcIVz61gwewx0d3HxgkJpuKMVVOZNbYHzy7dbXoD0NDyEVVRuGZMj+gPmJ4N5/wNpt4PWz+Esj3gPQrOVGjXG/qeC5ZTI0FQCPHDcuvkHHq0c/GHD7bg9gVNc5ysmoKqKAztmsrDlw6ha3p0a/cVReGhS4aQ7LDy5pr9+IIhdLMGV8e4bBpd0py8ftNo0k6RpGshhIjkV29/Q16UgdRxvqDOtkOV/OZf3/LIZUObcXRNS4KpOHTblN4s3VnM5oPl+M3qpEfgtGr84cLcqG8GarE6w0v/hBDiFHLuoE6cnduR5XklzF2cx7pdh/GioqkqSQ4L04d1YdbYHo06b6qqwn0X5vLj0zrz3JLdfL71MJqqhAMrHax6ENVqJTsziTmTe3F2boeIRYaEEOJUcfCoh8+2HDYNpNxbFlGx5j0CJQdQbU6smdmkjJ2Bo0suAN6AzoffFvLrc/uRmVRP/n8ckWAqDtksKq/cMJLz/rmM/aXVJh1T6nJYVe4+uy+Xnd612ccnhBBtiaoqjO/dnvG927PvxptInXklyWdEnuGP1dCuqTw5cxhlbj+fbz1Midsfbr7+2ceM7N2RUbPObbJjCSFEvHt1xV4Mk5vXitXvUr7qHdpN+ymOnsNQNAue/HV4dq46EUxBuOLp6yv3cedZfVpu0N+DBFNxqrDcS6U3wMxR3fjgm4OEDKNO3xRNAatFJScjkV+f05/xvdu30miFECJ+bTlYwaaCciq8ASr97egbTGZqUG/yWaK0BBuXjvjugVZpwQp8+Vua9BhCCBHPgiGd+av24g/VnpXSfW6OLnuNdufeiavv2BOvu3JG4coZVWtbX1DnpeV7+NmZvdtEdT8JpuKQxx/ip6+t5zfn9mfGiK7cd2Eun285zCsr9nKw3IMvoJPosDC8exo3jO9Jn6ykhncqhBA/IL5giI++LeTpRXnsL/WgAAFdR80ahWXpEZRlnzFzVDeuHtM9qoITjWHv14+Kjz5uln0LIUQ8KnH7CYTqLu/zFWzDCPpx9RkT1X6q/SEqvAFSXfGfYyrBVBy674NNDOqcwqXDuwDhohTnDOrIOYM6tvLIhBAi/u0vrebyZ1dytNpft+iEZsPn1wGdF7/K5+Xle3jgooG1ZpSaiqNfP/SCTRi7l6DoAXCkQma/cKsJIYQ4BVV6A6Z9+EKeClRXMooaXZU+i6ZQ4QlKMCVi9+7XB1i7p4z/3D5eSucKIUSM9pdWc8ETy6jwBOqtrgccK/Bj8Pv3N+H2BZk1rokakvurYdM7aMsepcfkA/DGFaCqgAGhIAy5DEb/FDLaRj6AEEJEy2HV0E0K+GnOZPTqCgw9FFVAFdINnLa2UR5dygrFkbyiKh74cCtPXDmMBLvfCHTKAAAgAElEQVTEuUIIEQtvIMTlz66IKpCq/TmdBz/ZxrKdxd9/EDs+hYdz4ONfQ2keqqajBKrAVwG+Sgh6wn39npkAC66GgPf7H1MIIeJE+0Q7IZMTsL1zPxSLleodK6LeV6rL2pRDazYSTMUJbyCcJ/XLaX0Y0Cm5tYcjhBBtzkffFlJWbR5IubcsovDlO9n3yCUceOJqDr91H94Dm0+87w3oPPjx1u83gI1vwVvXhJufB9yRt9ODEPTCzk/hxXMkoBJCnDIcVo1puVmcvNJPtSeQOn4mpZ/NpXrHCvSAFyMUxJO3lrIv59XaVlMVzh/cEavWNsIUmf6IE39auIVemYlcObJbaw9FCCHapLmL86g2acwbbTneXUVV7Dhc2biiPnu+gg9uDwdJ0Qp64cgWeHsWXPlm7McUQog4dNOEbL7YegRPoPb5OHnkdNSENMpXLKD4w4dRbE7sWTkkj7ms1nZWTeGG8dktOeTvRYKpOPDhxoMs3VkseVJCCNFImw+Ws7/UU+f1WMrxBkI6LyzL568XD459AB/fVSeQ6vFoJdUByL8jkQRb+Nz+/Ho/8zcGWDTrWBGKoBfyF0PBOug8PPbjCiFEnBnSNZVu6S52Hqmss1IgMfcMEnMj9/nTVIWcjMQ2tUpLgqlmdLjCy/yVe3l/w0EqvAEMA5IcFs7O7cC1Y3vQNd3F3hI3972/mZeuG0myo22sDRVCiHizqaDc9PVYyvGGdFi7pzT2gx/6Fkp3m+/TgMdW+fntBHvkzwe9sPwJuPTF2I8thBBx6PlrR3D+40up8ASJNoVVUcL3yc9dO6JZx9bUJJhqBvtKqrn3g02syCsBws3Hjiv3BHh5xR5eXbmXwV1SOFod4LYpOQzqktJKoxVCiLav0hskaFJCKtZyvFW+YOwHX/EkBP2mb9011sbfvvJx6+k2Uh0RVh4YOmxbCNWl4EqP/fhCCBFnuqa7eHvOWC5/diXl1X5CDURUFnRSExy8efMYOqY4W2aQTaRtZHa1Id/sP8p5jy9lyY4ifEG9ViB1XCBk4AvqrNlTxu4iN70zE1thpEIIceqwW1Q0k2XSNcvxRrefRpTizfsfGOb7H9FJY3IPCw8v99W/D4sNDqyJ/dhCCBFH9Brr+vpkJfHx7eM4p2wbDhVcJqXOXTYNp1Xl7IMbeO/cDuRkJlJU6eO9rwt46at8Xvoqn/c3FFBS1cA5tBXJzFQTyiuqYubzq2J6shkyDG56ZR0LZo9mcJfUZhydEEKcurKSHVg0FU56gFWzHG9Cv/EN7qdjiiP2g/ur6n37j2fYGTfPzR2j6mk+aejgORr7sYUQohV5AyE++OYgzy7Zzd4SN4GQgUVVyEiyc+2Y7pxzdDu/LF7Jg4/+jPc3HOQ/Gw9S5g4AkJZg5aKhnblwaCe8/6pg2RPzuG/SlSzbWYxFVQgeC8wsmkIgZDClbyY3TcxmWLfUuKoxIMFUE5ozfx1uf91Ayr1lERVr3iNQcgDV5sSamU3K2Bknqkh5AiFufHktK39zJqpJ12ghhBD1m9gnA8Oou46kZjleRdVw9DwNRbXg3bMB776NpJ1x/YltE2wa14zpEfvBlfpnswZmapzfx8KDy/z0z4i0IESFKJciCiFEa9N1g79/up0Xl+8BqFVJNagbFJZ7efSLnTziCzJt8s38TVWYObo7M0d3r7MvXzDEbwPZLGpvx7f1CAbgq/V++D8/3XKIxTuKmDogk0dmDI2b0unxMYpTwDf7j3Kg1MPJ1/KK1e9S+sVzpIyeQZfb5tP5lhdJGnYunp2ram3n9gVZvLOoBUcshBCnDodV44qR3bBqdR9IJY+cTtqUGyhfsYADj8/kwNOzqFz/Ic7etYtSaKrCtNys2A/ubHhVwf2THTy33k9BRT2JAwntYz+2EEK0MH9Q5/qX1zDvq3yq/SHTlhQQ7t/nVy18dijIhU98Ram7bm5pIKRz9fOrWbSzGK9qrbdYhW6EJyA+23KY615cY9ocuDVIMNVEnl+6G1+w9h/T8ZK86WfdgqvvWFSbA0Wz4MoZVetpKIDbH+KZxXktOWQhhDilXDu2B2qEpR+JuWfQ8dpH6faLf9H1tvlkXvoHHF36n3jfYVG5ZkyPxj3pHHIlWOpfHpiTrnJZrpV/rjYvVIGiQLex5u8JIUScMAyDn7+1gZW7S/AE6tYFMOML6uwtcTPz+ZV4T+o9dc97m/i24CjeKPcF4SBt3d5SHvjPlpjG3lwkmGoCId3gk82H6tTSj6UkL8C6vWWUVweaYYRCCHHq65ru4v4Lc3FaY1suZ9NU+nVM4mdn9m7cgUdcD1EU/713kh2332Q7zQYjbggXoRBCiDi2aEcRX247Yhr8uLcsovDlO9n3yCUceOJqDr91H94Dm4Fw8bX8IjcvLMs/sf2RCi/vfl1gGpTVty8AT0DnjTX7KDOZ7WppkjNlwjAMvjlQzqrdJZRV+7FqKhlJdqYN6EAHk+TkCk8ABYWTL6axluS1aSrFbh8pLuk3JYQQjXH5yG64/UEe+u/2qJ50Oiwq2RkJ/Obc/nxbUE6yw0LXdBeOWAKypCzIPgN2flarqt+eO5NqbdY1RcV7j0kjSkWBkTdGfzwhhGglzyzOM13WV7H6XcpXvUO7aT/F0XMYimbBk78Oz85VJ2oEeIM685blM2dSLzRVYf6qvZitJYhmXxA+dS5Ys485k3Oa6+tGpUWCKe+OHZTNfw3vli3objeqy4ktO5v0mTNxDBkSNxU5vIEQH2w4yNOL8zhU7iUQ0k9UEnFYVP60cCtjstsxe1I2Y7LbnRi3P6SjqsBJf1s1S/JGE1ApioIvhmlOIYQQdd0wPpuczCQe/Hgr+cXh6lInr613aRDy+0lNTmJXkZsbXlqLooBuGBgGXDK8C9eN60nP9gnRHfS8R2DuOPCUxTZYqwsm/RpSusT2OSGEaGH7S6v5el/dqqPH01ranXsnrr7fLVd25YzClTOq1rbeQIglO4qY2CeDV5bvrdNCKLZ96Ty/LJ/Zk3q1aizRrMFU5aJFFD32T/z5+RiBAIS+iza8W7dR+fkXWDIyyPjprSRfeGGr/iAOV3i5/NmVHK7wmkbc3mO/7CU7ili9p5TzB3fkLz8ZhEVTSXZYCQTrLt2ItSRvSDdIdspkoRBCfF+T+mQwqU8GWwsrmLcsn/X7yqj0BrFbVJKcVvaXVmMEAhyqCNeM8p90QX9j1T4WrNnP5L4ZPHb5aQ3PVKV0hlkfwUvngrciYt+pWqwuOP1GGPezxn5NIYRoMZ9sOlSn0BrEltbi9od4e91+cjsl4wnUPU/GmiJT4QlQ4Qm26qquZsmZMgyDosefoODOn+PbuhXD660VSAGg6xgeD4F9+yj8w/0U/vZ3GCdv00KKq3xc8Pgy9pVWR6xIcpwBePwhPvzmILe/8TWGYeCwqqa9SWqW5K3esQI94MUIBfHkraXsy3l1trdZVDokN6LHiRBCCFP9Oybz0KVD+OKXk1n9u6n89ZLB5Be5qfQG8aiRL74BPdxcfdH2Ii5+ejnVJm0v6sgaALOXQrfRYHFgKBEejtkSwZkOZz8I0x4Ir1URQog4d7jSiz9UdwVVrGktRyp8VHgDWEyqr8a6L4umUuFt3XoDzRJMlTz7HCXz5oWDqCgYHg8VH39M4b33mvYJaU6GYXD1C6sodftjKrHoCYQvsk/8bxcEAlydUokjVPeXGW1JXrtF5Zox3cNNJ4UQQjS5nYcrufHltaZPQyPxBXV2Hani5lfWRXd9Su0K130Et67AmzAGHTsoKqCEC010GQnTn4Vf7YTh1zb+ywghRAsLmARSUDutJdr92C2a6SxXrPvSdQO7tXXvnZt8TZln40aKn366TiA1NW8XJaFQrejt4+xsMi3hJ4OG10vFRx+TOGEiyWf/qKmHFdGq/FL2llSfyI2qKZpmu3M/38aUB+YwKac3/8i8AEz+zhJzzyAx94x6x2EAV5s0MhNCCNE0/rRwK54Iqw/qO9/7gjrr95WxPK+EcTmRe0GVewIUlnuo9odItGeg7+xGp4t/RvJZZ4EeAk2WcQsh2q72CXZUhTrVq2NNa0lz2WiXaDMNzmLdl25AqrN1K6E2+Zm95LnnMXw+0/ee7NyFsQmRk3kNj4fiuXNbNJh6dvFu04trtJVEdF1ny10Pcul5p3Pzp9t5bml+TE89AZxWjYtO60SmLPETQohmcajcy4rdJaYFzKM531cf6wV4cjBlGAZr9pTx3NLdLN5RhE1TTxSyCCRP45xDKdxcWEFup5QW+JZCCNF8xvRqx9OLtTopMTXTWhRVw9HzNBTVgnfPBrz7Ntbqreq0aZzZPxOXzcLE3hn8b9uRWuflWPalKjAtNwub5RSamQqWllK1ZDGm83ZR8u/Zg3f7Dhx9+zR6H8VVPvKOVFHpDeK0aXRKdZpWZCqq9LEsr7jOxTWWSiIexcK8XV4uBX5+Vh/yit38b+uRqAMqh1XltG6pPPDjgbF+TSGEEFF6deUe0xK8sZzvV+WXcqjce6JFxoGyaq6dt5rCci+eQAjDOKmQhWblw60lfLpzBf07JvHCtaeTliC9pIQQbdPw7mm0S7BR7ffUeS955HTUhDTKVyyg+MOHUWxO7Fk5JI+5rNZ2hmHwk2Hh6qU3T8xmxe6SOsFZtPuyWzRumpDdxN8ydk0aTFUs/AhML1fRMwIBjr61gA6//31snzMMVueX8uyS3SzbVRyOUo3wcAIhnR7tErhlci/OHtgBuyWc1JZXVIXdotap4hRrJZH84iogXNr88ctP4w//2cyC5XkEFY1QhJ+HqoT/CKYOyOSRGUMlV0oIIZrRZ1sO1ynBC7Gd7y2qwvK8YqYP60JeURXTn1pOlTdIqJ4HiCHDwBMIsamgnHP+uZQPbhtHZpKsQhBCtD2KojBnUq/wkmmTSYOG0lo0BX48tDOJ9nD4MbJnOu0T7eHqqjHuS1WgS5qTIV1TG/VdmlKT3sH79+2LuMQP4PaCA4zauYNRO3dwW8EB841CIfx79sZ03JIqH+c/vozrXlrD/7YdwRfUqfQGqfQFqfQG8QZ0th2q5Lf//pZR//cFX+8L9wGp8ppXZ4q1kog3oJ9ITFZVhbsSD/PUrne4ZFgXHFaVRLsFl03DZdVItGvYLSrnD+7EW7PH8PgVw7BKICWEEM2qwvP9z/cB3aDcE6Ckysdlz6ygwhuoN5CqyR8yKK70cfmzKyPmbQkhRLybPqwLmcl2NDX2yROX3cLtU75rsKsoCi9cOwKXLYYm6cck2Cw8d82ImD/XHJp0Zkp3V9X7/uMN5Eyd2E91ddTHPFLp5YLHl1Fa5SfQQDU+tz8EhLjyuVW8cO0InBF+ebE22w2vkQ//Uel+P4cf/Cuj/nAfZ44byu9/PJCVeSWUVfsxDEh1WRnVs12r1sMXQogfmkjVx2M53yuAqig8u2Q35Z6A6Yr2+gpZBHWDwqMe/rV+P1eN7vG9v5MQQrQ0p03jzZtHc8HjyzhaHTAt4GbGZdN45fqRdElz1Xq9d1YSr900mqtfWIXbF6xT3OJkqgJJDguv3TiaHtE2VW9mTToloqWlNc1+UqJL1PUFw4FRSRSBVE2eQIgbX1lLSNcbrCQSjfZJ362BL335Zey9e5M4bhwAiXYLUwdkcemIrsw4vSvTcjtIICWEEC0sPUKuUizne83QSQj6eG3VPgKhutecitXvUvrFc6SMnkGX2+bT+ZYXSRp2Lp6dq05s4wnozF28u8XbgAghRFPpmOLk4zsmkpOZiMum1Zvgk2DTyEiy8+9bx3JaN/M4YWjXVBbePoGpA7KwW1TsJgUllGP/stsn8vJ1IxnYOX6K+jTpzJRjwACUhAQMt7vR+1DsdpxDhkS17cKNhRw86mlUWXNvIMTry/PprvnZHqj9Y4ilkohVU/jJaZ0BCBw5QukL8+ix4M3Gfn0hhBDNYPppndld5K6zzj+W830wGOLAQ38n1P8cUGsHZ7EUsih1+1m7t4zTe6Q3wzcVQojml5Fk5+M7JrA8r4QHP97KpoIKIFzDQFEUFAV6tk/grh/1ZWr/rAZrA3Rr5+LZq0ewv7SaX761gbV7yzAMTuRSHf/P/WXVXPbsSsb2asfdZ/ejf8fk5vuSUWrSYCrprLM4dO99pqVno2YYpM64FIBthyp48as9bNh3lCpfEIdVpWu6i2vH9GBinwyeXpxXpwIIRFfmVjfgi82H+Lm+gyedubiDtUcdbSWRQMjguSX5fL3vKJfkLWHsxRdj6y79ooQQIp5cMqIrf/vvdtP3ojnfa4rCOUO6smPgzXg2H66zj1gKWXj8IRZvL5JgSgjRZoV0gye/3MXzS3cTMozv7v0VBYNwYe9D5V7ufX8zpe4AV4zseiIlJpLiKh/XvbSG/aXVEZf7HS8ktGh7ESt3l/LUzGGc0S+zyb5XYzRpMKXabKReNoPSV+dDIFDrvc975UT4VA2KQsKECSwtDvHQ/KXsLq4iEDII1fiJ5hW5WZNfikVTcfvqJhTH8nRQs1sJTv4xtuV7cAcDJ+8qqma7AP6QzvK8EtaHetPX3o5XPAFSnLKUTwgh4kWK08o5Azvw4cZC09UMDZ3vbRaVGydk838Lt5i+H0shCwMoqopcrEkIIeKZNxDixpfXsnZvKd5A3XSZ49z+EG5/iAc+3MKG/WU8OH0waoTCFVW+IJfOXcH+0uqo8rAMwmk7t7y2jpeuG8no7HaN/TrfW5OXkUu/5hpUW+P6aOhWG4tGnc8t89expbACb0CvFUgd5/aHKPeYJ73F8nTQGzRYsrOYV64fhdMaeyWROvvTbGw9XMWFTyyj3FM3OBNCCNF6fntef1Kc1pgbeDitGhcP68ygLikRbwRqFrKIhtbAE1ohhIhHum5wy/x1rNlTfyBVkycQ4j/fFPLHDzdH3Obe9zdRUE/qTuHLd7LvkUs48MTVHH7rPrwHwvvyBnRufHkt1X7ziq0tocmDKWuHDnR5+ikUR4x9NOx2nh4xg/u36VH/cszEWtb8qCfAoC4pvHz9SBLs2ve+wPlDBoVHvVz/0mpJMBZCiDiSmeRgwezRpLqsUZf1dVo1JvfL4P5jjdUzkuym28VSyEJVICvZfD9CCBHP3ll3gJW7S0379tUX9HgCIRasOcDyXcV1PldeHWDhxsI6fV8husI+umHwn28ONuG3jE2TLvM7LmHkSLo9/xz759yCEQjU23sKqxXFoqH85j4+22gnGOGXU18xiZpiLWtu1cIX1JE901l4+wT+8dkOPtl8CAUDbzByMFTfmPwhna2FlZJgLIQQcSYnM4mP7pjAba9/zaaCckK6Yfok1GXTMAy4aWJPfj61z4m1/hcN7cznWw/j9jW+kIXNojItt0PzflEhhGhihmHw5KJdpg17o6lX4AmEmLs4j7E57Wt99u11+1FNJjOiTd2p9od4elEeM0Y0nJfVHJolmAJwjRhBr/9+QtmbCyh79VX040FVMAgWC6o9/FQudcalpM28intXlRDU99fZTzS/nJpqPh1M6De+wXHW7ETfo30Cj11xGuXVAf76yTbeXLPPNAEuqj8Yf4hnFudJMCWEEHGmY4qTf90ylryiKl78Kp9/rSvAFwyhqQpB3aB7uos5k3px4dBOuGy1L5MT+2TgsGh1gimIvnBRz/YJcVGBSgghYvH1/qMUVdadIImlXsGq/FIKyz10THGeeO2VFXtNA7RYUneOVPrYfriSfh1a/tzabMEUgKVdOzJ+eivt58ymaulS/Hl5hCorURMSsHXtSuKUKag2G1W+IO9t2MjJLZ9i+eUcF8vTwQS7xowRXevsI8Vl5UCZeSWRaMdkAEt3FlNU6Yu4LEQIIUTr6ZWRyJ8uGsSfLhqELxjCG9BJtFvqXQKoqQo3jO/JP7/YiddkJUVDhSxcNo05k3o1yfiFEKIlLVizH+/3DHoMw+DDbwq5aWL2idfMAjSILXVHUxUKj3pPvWDqOEXTSJo8GSZPNn3/sy2HTHOVYvnl1BTt00FNUZiWm2W6j28OlJu+HsuYbBaVLYUVTErKiGn8QgghWpbdomG3RJdre924nrz7dQH5xe6oqk59dwyVIV1SOW9Qx8YOUwghWs2BCCXLYwl6/CGDgqOeWq8FTp5NOSaW1B3DwHR2qyW0SDDVkINHvaaRbqzFJGpqsMytpnDV6O5YIzQR85j0r4p1TIZhUCFV/YQQ4pTitGm8ftNoLn56OYcqvKZJ0yezWRS6pbv4/fn98Yf0BhtYCiFEvDErOgGx1yvwBEIcqfRytDp8j2y3qARN7rtjSd1RFEhytE5YExfBlDcQImQS6cb6y4lFyAj/0Msj9ISyaApm8VQsY1IUBbtFLphCCHGqyUiy8+HPxvOLBRtYurMYA+oNqkI6FJZ7uWTuCkK6wbQBWdw0MZvBXVJbbtBCCPE9pLrMe6jGWq/gv5sKeXd9AVaLgoJCdYQJjFhSd/xBnT5ZSY37Yt9TXNzpJzusJ6rq1RRLqdlYhXSDf36xi5H/9zmPf7GzThnz9ATzXlmxjEnXDbKSYywRL4QQok1Idlh5/trTWXTXZG4c35MUpwVVAatJzlVIN6jyBan2h/AFdRZ+W8hlz6zk3MeWcqjc2wqjF0KI2Izt1d60L2vNoKd6xwr0gBcjFMSTt5ayL+fV2f6oJ4g/pOP2hajyBalvsXTyyOmkTbmB8hULOPD4TA48PYvK9R/i7F073WZ0drtWu+eOi5mp3M7J2DSVQKjxpWZVJZx8ZtNU3BEi3JMdX1v51KI8DpR5ePDiQSdKKs4c1Y3HPt9Rpzx6LGNKclgZ3CWlUT8TIYQQbUPHFCd3n92Pu8/ux6EKD5c8vYIjFT78EfIAAPRj6/u3H67knMeW8O9bx9GzfUILjloIIWJz8fAu/PWTbabvRVuvoDEaSt1JsGnMnpQd8f3mFhfB1JjsdiQ5rKZBULS/HAvw0W3jeG/jIZ78cle9Ue7JPIEQH3xzkI4pDu48qw/+PXuY9OmrPBIYBlrdH1E0Y3JaNW6emN0q9e6FEEK0PF8wxHUvruVwhZeA2dp1EyHdoNwT4PJnVvDxnRMjrooQQojWluK0cu6gjnyw4SAho+45rqGgpzlYVIVu7VyMyW7XosetNYZWO3INiqJw88RsHvrvdtNKHNH8cvp7DsP1V/CvkbdiYB7A1Ndo1xMI8dSiXUz55CUcyxeRMXMmZyVm8dnOUtOLYjRjunh4l3rfF0IIcep4f8NB9pS4Ta8Z9V1/dANKq/08sziP35zbvxVGLoQQ0bl9Sg7/3XwoYp5TtOo7J0bLoiq0S7Tx6g2jWnXyIi6CKYBLRnThiS934Q2EYppVAnBYVX43+8cUbMrh6NJSUOsmyEXV/NfvZ2HmIO7+7LdoiYn8udrP148t5UiFzzQCr288j10+1LSwhRBCiFPT3EV5ppVgo7n+BEIGr63axy+n9cUmhYuEEHEqOyORZ64ezk2vrMUbaLiSqZmo7skb4LJpdExx8MbNo2mf2Lr9XOPmjJ3ssPLmzaNx2bUI80rmHFaVP1yQy8ie7XilxInXJJA63mg3/axbcPUdi2pzoGgWXDmjalcC0aws8LXDcLoASHXZeOeWsXRIsZsWyIg0nj/9eCDTcjvE8C2EEEK0Zd/sP0qhSSGJaK8/EG6n8cnmQy01ZCGEaJQJvTN4+bqRJNotpgUp6hPLORFAITwDpSrhEuoOq0r/Dkn8ZfogPr5jIplJrV/oLW5mpgD6ZCXx7q3juOLZlXgCoXqnEK2agqYq/HX6YH58WmcAlu0qNp3ViqXRbiCks+NwFQM6hTsod0518tHPJvLAwi18+M1BFEWpsxRRU8GqquRkJnLP+QMY3YrrNoUQQrS8RduP4AvWvWbFcv1x+0N89G0hFw7p1BxDFEKIJjMqux3LfzOFf687wDNLdof7qirhPFBNVQjphunMVSznRABVVfjpGb2wqCopLisjuqefuEePF3EVTEE4oFp89xm89/UB5i7eTanbj64bBI41ObRoKgrhantXj+lOl7TwLFIwpEfs8RFLo11VUTjq8dd6LcVl5eFLh3DfBQP49/oDvLF6PyVuP8GQTqLdwtic9twwvmer1bcXQgjRug5X+tBNnubF2ny+pMrXxCMTQojmkeywMmtcT64d24P1+46yv7Qatz9Iot1CstPKrfPX15mAiPWc6LCoXDi0M70yEpvjKzSJuAumABLtFq4a3YOZo7qzbm8ZWw9VUukN4LRqdEp1MrlvBnZL7V+CoijhucAmaP6rRkhiS3JYuXZsT64d27MxX0sIIcQpyqS1FBD79UcqwAoh2hpFURjePY3h3dNOvFZY7jHdNtZzIkS+L48XcRlMHacoCiN6pDOiR3qD22qqgsOimVYDjKUzc0g3SHNJaVohhBDRy0xyoClwciG/WK4/ABmJcv0RQrR9qU6baa+9WM+J/pBOmiu+C7rFTQGKpjClX6bp08FYOjMn2Cz0zozfqUQhhBDx56wBWVhNqvDFdv3RuHBo55YashBCNBunTWNw55Q6r8dyTgTIyUgkNc4nOeJ6ZipWN03M5n/bjpjOTkXTaNdhVblhQk/USOs1hBBCCBP9OybTs10CWw9V1nkv2ubzVk3lzH6ZLTVkIYRoVnMm9+IXb23A7at9Xx7tOTHBpjFncq+WHHKjnFLB1JAuKXRMcbC72G36fkONdg0DLhvRtbmGJ4QQ4hQ2Z3IvfvPvb00r0TZ0/bFbVK4Z2x2LdkotGBFC/ICd2S8Tq6oCsZ8TIZzuc/bA+G81dEqdtRVF4eEZQ3BYY/9aTqvG/zu7H2kJ8T2VKIQQIj6dN6gjuZ2SsUXZl/A4TVXISnZw04TsZhqZEEK0PIum8vdG3pc7rCp/vXhQnYJz8eiUCqYAhnVL44krhsX0i3NaNa4d253rx0uVPiGEEI1j0VRevG4kvTITsZvkT5mxagoZiXYWzB5NkiO+k6yFECJWZ/bP4r7zB8R0X+6wqvy/s2XX/dQAACAASURBVPtx3uC20XPvlAumAKYOyGL+DaPokOLAZYsc0bpsGgk2jd+d159fn9O/BUcohBDiVJRot/DureOYlpuFzaJGDKqsqoLdonJ6j3Q+vmMCHVOcLTxSIYRoGVeM6s6TVw4jzWUloZ778gSbRorTyj9mDOW6cW1nguOUypmqaUSPdFb8egrL80p4ZnEeX+WVYFEVFAX8QZ1emYncOrkX5wzsiMMa/1OIQggh2gaHVePxK4ZxpMLL/FV7eWX5Xtz+IBZVJajrWFSVGad35bqxPejRPqG1hyuEEM3uzP5ZrL3nLL7cdoS5i/NYv+8oVk1BIVz+fHCXVOZM6sXU/pltLnf0lA2mIJxDNS6nPeNy2hMI6VR4AgR1gxSnVQIoIYQQzSoz2cEvzurLz6f2ocoXpMoXxGWzkGS3SNVYIcQPjqYqTB2QxdQBWfiDOkc9fgBSnNY2kRsVySkdTNVk1VTaJdpbexhCCCF+YBRFIclhlZwoIYQ4xmZRyUxytPYwmkTbmkcTQgghhBBCiDghwZQQQgghhBBCNIIEU0IIIYQQQgjRCBJMCSGEEEIIIUQjSDAlhBBCCCGEEI0gwZQQQgghhBBCNIIEU0IIIYQQQgjRCBJMCSGEEEIIIUQjSDAlhBBCCCGEEI2gGIYR+U1FKQL2ttxwhBAtoLthGBmtPYjvS85PQpyS2vz5Sc5NQpySIp6b6g2mhBBCCCGEEEKYk2V+QgghhBBCCNEIEkwJIYQQQgghRCNIMCWEEEIIIYQQjSDBlBBCCCGEEEI0ggRTQgghhBBCCNEIEkwJIYQQQgghRCNIMCWEEEIIIYQQjSDBlBBCCCGEEEI0ggRTQgghhBBCCNEIEkwJIYQQQgghRCNIMCWEEEIIIYQQjSDBlBBCCCGEEEI0ggRTQgghhBBCCNEIEkydQhRFmawoyoGW/mxLURSlm6IoVYqiaK09FiFE/drS+UhRlEWKotzYUscTQsSPtnSuagy5d2p+EkzV49gf3/F/uqIonhr/e2YzHneWoijLmmv/TUFRFENRlJxmPsYeRVGmHv/fhmHsMwwj0TCMUHMeV4h49P/Zu+/4qOv7geOvz61c9oAECIQRhkzZIihDq1VRsQXEuqpYd63a/qq2ah11tY5qW6viQgUnuBVwi2xkKLJHCARCICE7d5e7+97n98cdeEkuyV0g+/30cQ/vvvPzPXKf7/f9mZIfhaaU6hnIjyzVlr+ilHqwudIlRHsleVXt5NmpbbLUv0n7pbWOO/JeKZUNXK21/rL6dkopi9ba25RpE0K0L5IfCSFaA8mrRHsjNVMNcKRaVyl1h1IqD5gdqkQkuARCKRWllHpcKbVXKXVQKfWcUiq6AeeeqZTaopQqU0plKaWuC7HNnUqpgkDpxKVBy49LGkKc7z6l1DtKqdcC6dqklBoVtP4vSqldgXWblVK/rrb/NUHXtFkpNUIpNQfoDnwcKM26PbgEWil1kVJqTbXj/FEp9VFjXqsQLY3kR2Gl80ql1NLA+YqUUruVUufUsm0XpdQGpdRtgc/fKqUeUEotC1zn50qpjkHbTwnkecWBbQcEls9USn0ctN0OpdS8oM85SqlhgfdaKXV9YJtipdT/lFKqMb4LIZqL5FU1zifPTm2EBFMN1xlIAXoA14ax/T+AfsAwoA/QFbinAec9BJwHJAAzgSeVUiOqpatj4PhXAM8rpU6INA1KqWeUUs9EkK4pwFtAEvAR8HTQul3AeCARuB+Yq5TqEjjPhcB9wG8D1zQFOKy1vhzYC5wfqJ5+tNr5PgZOUEr1DVp2CfBGpNcqRBsg+VH9xgDbAul5FHipesCilOoFLAae1lo/FrTqEvzXlwbYgD8Htu8HvAncCqQCC/A/xNgCxxmvlDIppdID+40N7JcJxAEbgs5xHjAaOBGYAZx1jNcrREskeVVV8uzUFmit5RXGC8gGzgi8nwS4AXvQ+iuBpdX20fj/GBVQAfQOWjcW2F3LuWocq450fQDcEpQuLxAbtP4d4G/1pSGw774Ivg8N9Am8vw/4MmjdQMBZx74/ABcE3n92JP11feeBzz0D57UEPs8F7gm87wuUATGRft/ykldre0l+VOWcVfKFoOWvAA8GXcPOoHUxgX06Bz5/C/wr8L1eXO043wJ3B32+EVgUeP834J2gdSZgPzAp8DkHGAH8BngeWA30x/8w91G1f5tTq31Pf2nuvzN5yetYX5JX1TivPDu1wZf0mWq4fK21K8xtU/H/oa4NKghVQMQjqwSaptyLv+TAFDjuT0GbFGmtK4I+7wHSj2caapEX9N4B2FWgPbRS6rfAn/D/oMFfInukmUwG/tKXhngDeAL4O/6SlQ+01g6lVBqNe61CtDTtOT860ufCGvT+yGdP0OejeVQgnwB/XnTEpcBOYH6Ic1TP347sl47/mo4c16eUysFfmgv+2qlJ+B8MFwPFwET8DyiLwzyHEG1Je86rQpFnpzZAmvk1nK72uQL/HyEASqnOQesKACcwSGudFHgl6qBOmuFQSkUB7wKPA5201kn4m5UEN1VJVkrFBn3uDuQerzRESinVA3gBuAnoEEjzxqA05wC9a9m9+ndc3RdAaqDfwcX8XE3dLNcqRDNqz/nRAfxBU89qy3sRFOiE4b5Aut5Q4Q8hnIu/uRIAgWaDGfhrp+DnYGp84P1i/MHURGoGU0K0B+05r4okzfLs1IpIMHX8/AgMUkoNU0rZ8d+YAX9pJf4fxZOByB+lVFelVF1t4pVSyh78wt/mPgrIB7yBkpZfhtj3fqWUTSk1Hn8b4XkNTMPxEIv/h50fOOdMYHDQ+heBPyulRiq/PoFMBOAgkFnbgbXWHmAe8Bj+NthfBJY317UK0VK0m/xI+4f7fRd4SCnVQSllVUpdjL/JzMIIDuUBLsSfZ72mlArn/vgOcK5S6hdKKSvwf0AlsDywfjFwGhCttd4HLAHOBjoA6yNImxBtVbvJqyIkz06tiARTx4nWejv+KtMvgR1A9bkO7sDfhGSlUqo0sN0J1G4c/hKC6q+b8d/Ai/BXz35Ubb+8wLpc4HXgeq311kjToPwjuDxX91XXT2u9GX918gr8P/AhwLKg9fOAh/CXjJThb8ecElj9CHC38o9u9edaTvEGcAb+TC+4iU+k37cQbUY7zI9uBArxD+hwCH9p7rla64N17FOD1toNTAU6AS/XF1BprbcBlwH/xV+qez7+jt/uwPrtQDn+IAqtdSmQBSzTMueLEO0xrwqLPDu1Lkrr+moDhRBCCCGEEEJUJzVTQgghhBBCCNEAEkwJIYQQQgghRANIMCWEEEIIIYQQDSDBlBBCCCGEEEI0gARTx0gp9YpS6sHA+/FKqW1NdF6tlOpznI959Fqact+mopS6Uyn1YnOnQ4imIvnTse/bVCR/Eu2J5E3Hvm9Tkbypfu0imFJKZSulnEqpcqXUwcAf73GfhExrvURrXe8QkkqpK5VS1Yf/PG6UUt8qpa5urOMfq8a+/sA5Jiml9gUv01o/rLVusd+LaJ8kf2pZJH8Swk/yppZF8qaWq10EUwHnB2ZxHgGMAu6uvoFSytLkqRJCCMmfhBAtk+RNQtSjPQVTAGit9wMLCcwkHajy/b1Sagf+CeNQSp2nlPohMOHZcqXUiUf2V0oNV0qtU0qVKaXeBuxB66pE9EqpDKXUe0qpfKXUYaXU00qpAcBzwNhAaU9xYNsopdTjSqm9gRKg55RS0UHHuk0pdUAplauUuqqh16+UmqeUylNKlSilvlNKDaq2SUel1BeB61usfp5RG6VU/8C6QqXUNqXUjIamo1qaspVSf1ZKbQik623ln7UcpVSyUuqTwHdYFHjfLWjfFKXU7MD3UqSU+kApFYv/3zg98B2XK6XSlVL3KaXmBvZbqJS6qVo6flRKTW3MaxWiLpI/Sf4U2E/yJ9GiSN4keVNgP8mbQmh3wZRSKgOYDKwPWvwrYAwwUCk1HHgZuA7oAMwCPgr8YG34Z5meg3+m6XnAtFrOYwY+AfYAPYGuwFta6y3A9cAKrXWc1jopsMs/gH7AMKBPYPt7Asc6G/gzcCbQF/+s1Q21MHCMNGAd/pm+g10KPAB0BH44sj7wI/sC/6zZacBvgGeUUgNruf5ipdSpEaRrBnA20As4EbgysNwEzAZ6AN3xz2T+dNB+c4AYYFAgXU9qrSuAc4DcwHccp7XOrXa+N4GLg9I7MHCOTyO9ViGOF8mfJH8KkPxJtCiSN0neFCB5Uyha6zb/ArKBcqAY/w/0GSA6sE4Dpwdt+yzwQLX9twETgQlALqCC1i0HHgy8nwTsC7wfC+QDlhDpuRJYGvRZARVA76BlY4HdgfcvA/8IWtcvkO4+tVzvt8DVYXwvSYHjJAY+v4I/0zqyPg4wgAzgImBJtf1nAfcG7ftgmP8e1a8/G7gs6POjwHO17DsMKAq87wL4gOQQ2x39twhadh8wN/A+PvCd9wh8fgh4OfC+zmuVl7yO50vyp1q/F8mfJH+SVzO+JG+q9XuRvEnypiqv9tTO9Vda6y9rWZcT9L4HcIVS6g9By2xAOv4fz34d+AsJ2FPLMTOAPVprbxhpS8VfQrBWKXVkmQLMgffpwNowzlmnQInPQ8CFgXP6Aqs6AiWB90e/C611uVKqMHD+HsCYI1XrARb8pRvHQ17Qe0fgnCilYoAn8Ze8JAfWxweuJQMo1FoXRXoyrXWZUupT/CUn/8Rf0nJNYHVjX6sQ1Un+JPnTUZI/iRZE8ibJm46SvCm09hRM1SX4B54DPKS1fqj6RkqpiUBXpZQKyhS6A7tCHDMH6K6UsoTIFHS1zwX4q2AHaX+75OoO4P/jP6J77ZdSp0uAC/BXdWcDiUAR/szniKPnUf5Re1LwlyjlAIu11mc28NwN9X/ACcAYrXWeUmoY/mYGKpCmFKVUkta6uNp+1b/jUN4E7lVKfYe//fY3geXNda1ChCL5088kf5L8SbQckjf9TPKmdpw3tbs+U2F4AbheKTVG+cUqpc5VSsUDKwAvcLNSyhrocHdSLcdZjf+H/I/AMexKqVMC6w4C3QLtiNFa+wLnfVIplQaglOqqlDorsP07wJVKqYGB0oZ7w7gOS+CcR15W/NWzlcBh/KU5D4fYb7JS6tRA2h4AVmqtc/C3Ye6nlLo8cO1WpdRo5e8U2pji8WeWxUqpFIKuXWt9AH875meUv7OlVSk1IbD6INBBKZVYx7EX4C9J+TvwduDfAZrvWoWoj+RPkj9J/iRaIsmbJG9qt3mTBFPVaK3X4K+yfBp/ycNOAh36tNZuYGrgcyH+9qHv1XIcAzgff4fIvcC+wPYAXwObgDylVEFg2R2Bc61USpUCX+IvVUBrvRB4KrDfzsD/6/Ms/h/Skdds4DX81dz7gc3AyhD7vYH/R1cIjAQuC6ShDPgl/qrdXPxVy/8EokKdXPlHgRkfRjrr8xQQjb8EaiWwqNr6ywEPsBU4BNwaSO9W/KUnWcrfoTO9+oG11pX4//3OwH/dR5ZHdK1CNBXJnyR/kvxJtESSN0ne1J7zJlW1CasQQgghhBBCiHBIzZQQQgghhBBCNIAEU0IIIYQQQgjRABJMCSGEEEIIIUQDSDAlhBBCCCGEEA0gwZQQQgghhBBCNECdk/Z27NhR9+zZs4mSIoRoCmvXri3QWqc2dzqOleRPQrQ9bSF/krxJiLanrrypzmCqZ8+erFmzpnFSJYRoFkqpPc2dhuNB8ich2p62kD9J3iRE21NX3iTN/IQQQgghhBCiASSYEkIIIYQQQogGkGBKCCGEEEIIIRpAgikhhBBCCCGEaAAJpoQQQgghhBCiASSYEkIIIYQQQogGkGBKCCGEEEIIIRpAgikhhBBCCCGEaIA6J+1tK1wegyKHG49XkxBtITHailKquZMlhBCNSmtNmaeMksoSrCYriVGJRFuimztZQogIaa0pcXoodXqxWhTJMTbsVnNzJ0sIQRsOprTWrNh1mFnfZbFsZwFWswmlwGP46BgXxbXjM5k2qhsJdmtzJ1UIIY4rh8fBp1mf8vLGl8lz5GE1WdFovIaXEZ1GMHPwTMalj8OkpHGCEC1ZidPDvDU5vLAki8IKN1azCa39zzIT+qZy7cRMxvRKkQJiIZpRmwymftpXwnVz11Ds8OBwGwB4fcbR9QdKXDz62Tb+sWgr103I5I9n9pOMSAjR6mmteWXTKzzzwzMopXB6nQB4fd6j26zOW83Ggo3EWGN4YuITjOg0ormSK4Sohc+nefzzbby0dDcmpXB6/M8wHuPnZ5lvth1i5e7DdIi1MevyUQxMT6jzmFpr1u4p4rvt+eSXuzEp6JRg56xBnTmhc3yjXo8QbVmbC6aW7Szg6lfXHM14anNk/QtLdpN92MFTFw3DZJKASgjROmmteWDlA3y862NchqvObR1eBw6vg+u+uI5/Tvgnp3c/vYlSKYSoj+HT/P6NdSzelk+l11frdhpwuA0cbifTn1vO7CtHMyazQ43tXB6Dd9fu47nFuyiocONyG+jAOrOCZ77dSWbHOK6f1Jtzh3TBLM9CQkSkTbXx2HKglGteqz+QCub0GHyx+SAPLdjSiCkTQojG9fyG5/k4q/5AKpjLcHHHd3fwY/6PjZgyIUQk7v94E4u3HYroWcbhNrjqle/ZcbCsyvJDZS7O/c8SHvx0CzlFTpxBgRSAocHl8bH5QCl/eXcDl7+0CofbixAifG2qZurO93462qyvuorN31L6/Qd4Du/DZIvGmpZJ4rgZ2LsNwukxmLtyD78d24MeHWKbONVCCHFs8h35PL/hedw+d411RUuKKPisAPchN2a7mYSRCXSa3glzrL/zustwcc+ye/jwVx82dbKFENXsPFTGO9/n4ApRI1XXcwz4A6q/fbiRt64dC0BRhZsLnl5GflklXp+ucbzqHG6DtXuKuPj5lbxz/ViiLDLAhRDhaDPBVFZ+OZsPlIZcV7r6fUpWzafDL3+PvdcIlNmCc/danDtWHc2EfD7N7GXZ3DdlUFMmWwghjtk7294J2e+zYGEB+Qvz6XZ1N+IGxuEp8pA7J5fsx7PpdVcvTBZ/44Tc8lw2FWxiUEfJ/4RoTi8tzcbjqxlIhfMco4H1e4vJKXSQkRLD1a+toaC8ZiBVV1BW6fWx7WAZf/tgI49OH9oUlyxEq9dmmvnNXpaNEaLkxVdZQfHS10k58wZiThiHyWZHmS3E9BlD8mlXHd3O49O8syYHVwTV6kII0dy8Pi9vbH2DSqOyynLDaXDog0OkX5ZO/InxKIvClmoj48YM3AVuSpaXHN3W7XPz6uZXmzrpQoggFZVe3l+/D6NaLBXucwyAT2teXZHNptwSNuWW4DGqPheVrn6fwq9eIPHkGXS7aS5db5hN/IjJOHesOrqNy+Pjwx9yKaqoWdMthKipzQRTizblhazGrty/Fe11E9NvbL3HMCnFuj1FjZE8IYRoFNuKtmH4ahYCOXY48Hl8JIysOsKX2W4m/sR4yjeVH13m0z4W5yxu9LQKIWr3fXYhFlPNx7JInmM8hmbBhgO8tGQ3Hm/VZ6JIgjKl4O3v9x7bBQnRTrSZYKrcFbrDpOEsxRSTgDKF1/a3yOE5nskSQohGVeIqCdnEzyg3sMRZUOaa6yyJFrzlVfNMl9eF1vX3qxBCNI5ihwdNzd9gpM8xpS4vn/50AKPa7zmSoMzl8fHS0uywzidEe9dmgqlQGRCAOToBn6MUHaLkNtRRGnx+rSlxeMgpdFBQXomnej29EEI0IXOcGW+5F23UzNe8JV4scW2my6wQbYJGh3wMiew5BgynE7OnZhO9SIOywxWVuOsYml0I4ddm7qbxUVZcnsoay6O69kdZrDi2ryC2/6n1HEWRFGON6LyHyyt58/u9zF6aTanLg8Vkwqf9od15Q7pw9fjMeifSE0KIhkq0J4YsTIrpE4OyKErXlpJ4UuLR5YbLoGxDGZ2md6qyvd1il8nLhWhGSTG2kL/ByJ5jIDrahturoVrsFRyUhRNQWcwmKiq92Cy2sK9BiPaozdRM/XJQJywhJpozRcWSdOqlFH7xHI7tK/B5XGjDi3PXGoq+ebnKtobXy/BuiTWOEYrH8HHnez8x7h9f8/TXOzlc4cZjaJweg0qvD7fX34Fz6rPLOPc/S9hf7Dwu1ymEEMH6JffDpGpm5eYYM2m/SiN3bi5lG8rQXo07303OMzlYU6wkjUs6uq1JmRjfbXxTJlsIUc3onil4Q4zkF8lzjNWsmDgwHSw1y8qDg7JweA0f0TYZHl2I+rSZmqmrTu3F/LX7Qg5CkXDSVEyxyZSseJuCTx5H2aKJ6tSHhLEXHd3Gguasgs3knvcUSdOnkfjrqVg7pYU8V6XX4PIXV/PT/uI6Zyc3tMbwaLYeKGPyv5cw//qx9O0Uf+wXK4QQAVaTlUv6X8LsjbNrzDOVOjkVc6yZvLfzcB9yY4o2kTAigYzrMjBZfw7AbCYbVw66solTLoQIFhdl4YJhXZm/JofqrXPDeY4B/0BaV4/vxacbDtQ4fnBQpkxm7L2Go0wWXNk/4Nq7ocYgFHFRFuxWCaaEqE+bCaZ6p8bRv3M8P+4rCbk+btBpxA06rdb9zRYzNz94I13yJlM8bx5Z559PzMiRJF14IXETxqMCpTxaa256fT0b9hWHnFQvFENrSp0eLnp+JYtuGU9agj3yCxRCiFrMOGEGszfODrkuZWIKKRNT6ty/S2wXBncc3BhJE0JE4Hen9uLDH/ZjeH5+vrDiJZpKTIMm1vkcAzA0I4lB6YmM79eRr7ccqtEAONygzGY2cdnJPY7XZQnRprWZYArg4alDmP7sCpwRzhUVbTVz0ehuZKbGQepgoocMptMdt1O6aBGHZ80i7/77SZz6a5KmTWO1086yXQURz06ugVKnh8c/3yYT4Qkhjqu0mDSuGnwVr256FacRWZNiu9nO30/5eyOlTAgRiX6d4pk6vBur1v/AhXoRF5u/Jg4nBiYsGBzQHXjBOJd3jQmUEVNl3xibmQcu8BeKXDehNyt2Hcbhrvk8VF/hMgAKLh/rD6Y8ho+1e4r8EwAbmsQYK8O6JZEcK32phIA2FkwNSk9k1uUjufaVVbh84XWkjraaOa1/KvecN6jKclNsLEnTppE0bRqu7dspnj+f7OkX8u+x1+CI6lTjOOHMTu71aT76MZd7zh9EXFSb+uqFEM3sxmE3ku/M59OsT3EZrrD2sZvtPDL+EYalDWvk1AkhwuIo5KGK+/BYvkP7fEQp/xQG5sBoEl3VYW5Xb/MXy5u8YZzOQ97LMDATbTXz4hWjOKGzvyvB6J7JdEm0s7ugghC9H+pkM5sY36cjCsVjn21lzoo9gWNotPY3Jaw0fJwxII1rxmcyLCNJBq8R7VqbGYDiiDHmUh5bM5tOMWZi6+g4GW01EWUxMfOUnvzvkhGYQgxecYS9Xz8633knsR8t4kd7zX5UkUyEZ1KK99fta/gFCiFECEop7h17L9cNvY4ocxTR5uhat42xxJBiT+GZM57hjB5nNGEqhRC1KtkPz52Kyv4Om3YfDaSqi1GV2JWH35i/4fWof9Ajwcy868cyrnfHo9sopXj1qpOIt1uJJMwxa4OOPiejeiYz8bFveHHJbkpdXsorvZRXGlS4Dcoqvbi9PhZtzOOSF1Yxc/b3OEPUgAnRXrSpYMrncLD/1ls59ZrfsOLus3ju8pGc2qcjVrMi2uoPrmwWE50Sovi/X57A6jvP4Paz+4ddovJddgnmECPkRDIRnsNt8P763IivTQgh6qOU4uohV/PtjG/548g/0iW2C1aTlRhLDNGWaCzKzEi3wWPj/8nXF37N6M6jmzvJQggAVwm8MhnK8sCoOUdUKDHKzWjLLr7NnMPgEFOwdEuO4d0bxtEhzoY1xOTd1dktJnqlxnNq4Q7+vWgzlV5fnYNs+TQ4PQYrsg4z/bnluCLsYiFEW9Fm2ppprcm7/+/YBw4kafp0lFKM75vK+L6pVFR6Kaxw4zF8JERb6RAbei6H+hQ53Li9NTOLSCfCK3KEl1EKIURDxNniuHjAxfym/28oqiyitLIUi8lCsj2Z2OdPB1M8hJlfCSGawJInoTQX9M/PGD2fKsPhgd23xBFr8z+zvLjOzdwNHr69MhYAs+GCXV/Dzi+h75k1DtsnLY7Pbp3Ac4t38ebqHLTWVFSrRYq1mbFbzcw8tSdpcVHcU+wKe4AtgEqvj12Hyvn9G+t46QopoBHtT5sJpkreew/npo30euedGoFSbJSF2OPQR0mHnpw84onwhBCiKSilSLGnkGIPGs2v31mwfSF0G9l8CRNC/MzrhjUvhayRMjT8e5WbO8dH1b6/xwHL/h0ymALoEBfFXecO5M9nncCijXks3JhHYYUbs1Kkxkfxq+HpTOyXhgLGPPxVrYFUXYNsubw+lu0sYHNuKQND1JIJ0Za1iWDKtW0bhx5/gh5z52CKial/hwZKjrFhs5hweapmNJHOTp4UY22sJAohRN36nQMLb4PT727ulLQ5hs/ApEzSGV9EZuvHoEMHMLeNs/HoskpuHG0jyV7H31XOaijaA8m1D2ceZTFzwbCuXDCsa8j1i7fn43CH7qcVziBbHkPz0tIsnpghA9qI9qXV95kyysvZf8utdPrrX4jq3btRzzW+b0d0iKqpSGYnj7aaOf/E9EZNpxBC1KrbaH9H95L9zZ2SVs+nfSzfv5xrPr+GkXNHMnzOcIa+NpSxb4zlvuX3sat4V3MnUbQGP7wJ7vKQq0alm5nU08LjyyvrP862hceUjOcW76rRBBDCH2TL8Gk+2XCAUpfnmNIhRGvTomumPIaPvYUOSpwebGYTHeOi6Jz484S3WmsO/O1vxJx0EolTpjR6ejJSYhjePYmVWYU11oU7EZ5Pa6aN7NboaRVCiJDMFpyZZ/PJ54vZYh9KkcNNvN1KZsdYpgzrSorMHROWL/d8yUOrHsLhceDwOqqsK/eU88HOD/gk6xN6J/Xm3kunYAAAIABJREFUkfGPkJmY2UwpFS1e+aE6V//9tChOebmCW8bU8ds0KqGi4JiS8WNOccjlkQyyZTOb2JZXxuiedU8ULkRb0iKDqbwSF3NWZjNnxR68Po1ZKTT+4KpXx1humNSbswd3puLtt3Bn76HnW282Wdqum9ibDftKGjQRnknB5CFdSIyWZn5CiKa397CDF5ZkMf/Hc1E+Lw5f9tF1dquJRxZu5RcD0rh+Ym9O7JbUfAlt4V7Z9ApPr3+aSqP22gJDGxiGwZbDW7j4k4uZdeYsmc9LhFZLE78jBqeZOa+fhX8sdTMgtY4GRb7QTfTC5axlNL5IB9kqdUrNlGhfWlQzP8On+dsHG5n42De8EJjbwBGY06C80kul18fWvDLufO8nRt7/GYteX0i3p57EFFVHx8zjbGLfVEb2SCbKEvlXF2+3cttZJzRCqoQQom7fbD3EWU99x5ur9+L0Khy+qoU6Lo9/GORFG/OYMWsFLy3NaqaUtmwf7fqI/63/X52BVDCNxuF1cP0X15NVIt+pCCEmud5N7p9k54V1bvaX1jIDr9kGsR2OKRmWWubbDB5kq14KbA14PhKiNWsxf/GGT3P1q98zf+0+Kr0+3HUMy1nhNij3aO4fdjELC5u2cs1kUjx/uX+WcXuYGYby+YizmnjjmjGkJ9U+kaYQQjSGb7Yd4obX1+L0GHh9tTyMBfi0P7B6/LPtzFosfX6ClbnLeGDFA7gMV411RUuK2HH3DjZdu4mtN28l99VcjIqfHz4dXgd3Lb2rKZMrWosBF4C17sGz+qSYuGiQlf+srmVqFWWGzEnHlIzaWs0ED7JVH8On6RjXdAXcQrQELaaZ390f/MTKrMM4PRHMbeCDO97bQKdEOydnHluJTCSibf7Zxm+bt4FFm/JAg9uomW6lwG41k2bSPLD+NfonTGiyNAohBMC+Igc3zl1XYxRSqHuoY6fH4MkvtzOkayLj+nRshpS3PB/s/CDk8oKFBeQvzKfb1d2IGxiHp8hD7pxcsh/PptddvTBZTGg0O4p2kFWcRWaS9J8SQYZeBJ/XP7rmPROjmLOhliZ0KZnQeUiVRVprVuw6zKzvsli7pwinx8CsFAnRFqaN6MZvx/Wka1AB74UjM3hpaRZuo2qBS/AgW8pkxt5rOMpkwZX9A669G6oMQpEcY6N/5/gILl6I1q9FBFNZ+eW8t25/yJm267rZg78E9e4PNvLlnyY2aZqjLGb+c/FwcoudzF25hzkr9+A1fJhNJnxa4zU0k/qnct2ETEZ0T+bA31aSd999pD/2qAybK4RoMrOXZeP11cxbwxnq2OXx8dRXOySYwv9g+sqmV2rUShlOg0MfHKLr77oSf6L/IdKWaiPjxgy237adkuUlJE/wN+Py+rzM2TKHe8fe2+TpFy1YVDwMmeYf1S9o0t7sW6sGJRmJJlx315zDyaetMPqGKk2NFm08wH0fb6bU6anSx9tAU1DuZvaybF5Zns2onsk8Nn0o6UnRXD62By8v202oGTXDGWQr2mrm2gmZ8owj2p0WEUzNXpaNEaLpSTg3e4D9RU427Ctulg7T6UnR3H52f/50Zj8Kyt2UujxEW810iLMRY/v56+18111kz7iI4vnzSb7wwiZPpxCi/XF5DN5avRdPtZLmI0Mdd5h8KzEnjDu6PKbPGGL6jKmy7Y85xew97KB7h8abw6812Fm8kzJ3WY3ljh0OfB4fCSOrPuSa7WbiT4ynfFP50WDK0AaLdi+SYErUNOmvsOUTcIUeUa822mTF601k792vkf6PwcQMH86sxbt48svtIWujjzjSmmZlViGT/7OEt68dywmd4zmpVwrLdx0O+UxW3yBbAFNHhJ7DSoi2rNn7TDndBvPX7qvRjj/ceQ0AKr0GLy7Z3ZTJrsFiNtE50U6/TvFkpMRUCaQATNHRdP33U+T/60lc27Y1UyqFEK2G4YEQNUqR+GxTXsjlkQx17NOaOSuzjykdbUGhqxCzqjmamVFuYImzoMw1S+MtiRa85VVHWHN4HOhQExaK9i2xG/z2Q38tFWHW7JhtqMSu2P7yPWl/uo19f7iZFx9+iafqCaSCGT5NscPDRc+v4ECJk8emDyXRbg03BUfZrSaevGgY8XYZrVi0P81eM7VhX3HIEWQiu9n7Z+5u6aIyM+n017+w/9Y/0nPePMxxsc2dJCFES+EzYPtnsOwp2L/u52GObXH+JkAn3wipkY0GuvNgechJOCMZ6thjaDbllkZ03rbIqGUkM3OcGW+5F23oGgGVt8SLJa7qbVYH/lMRP66KNi99GFzzDcz5NTiLap3IF2X2j97XbRT85nWwJ5Jw1i9x9R/Mo8+swa0i7zJR5vJyx/wNvPa7Mbxz/cnMeG4lJU43Rhhxv91q4v4pgzl7cOdjuXohWq1mr5kqdnpCtM6NfF4Dh/vY5ldoKolTphA9cgR5990npZNCCL8f3oTH+8B710DOKvB58Pdb0OAug/VzYdYEeP40OBz+CHvFtcz3EtFQx/gftNq7hKgEdIi7VUyfGJRFUbq2asBpuAzKNpQRO7BqoZnNbMOkmv3WK1qqjn3hlg1w0RzIPA3MUf7aqqgEsMWDxQ5DfwPXfAVXfgL2xKO7zstyYLLWrBkqXf0+hV+9QOLJM+h201y63jCb+BGTce5YdXQbw6dZtbuQAyVO+qTFs/DW8Zw+oBNRFlPIqWBMCqIMD73jzbx0xWguGp3RON+HEK1As9dMmVTo8rngm304AVVr6vAo/aeEEEd9/SCseBo8ztq38Xn9rwM/wPMT/c2Buo6s99AJYQx1HNv/1HqPExfV7LeKZtc3uW/IAjBzjJm0X6WROzcXk91UZTQ/a4qVpHFV+/KOSBvRVEkWrZXJBL1P978qCqA0FzwOf0CV1B2i4mrsYvg0Ly/djavaQF6R9I/UwJwVe7j97P50SrDzwm9HkV9WyRur9jJvbQ4lDg9enyY2ysy43h25yLGd7ms+J6PP2Y3yNQjRWjT7HTIl1hayZirSm318K7rZH+k/tefSy4g+8UTsJ8hEvkK0S6tmwYr/1R1IBdM+qCyD134F1y32D4dcC4fbS0WlF4tJ1eiTGslQx2YT9O1U8+GtvYkyRzG171Te3vY2Hl/VGr/UyamYY83kvZ2H+5AbU7SJhBEJZFyXgcn6c6l+jCWGmYNnNnXSRWsW29H/qseGfcW4vDVrmiPpMuH2+nh33T5uP7v/0WWp8VHcckZfbjmjb43tfa4B7Hz5v2xes5nXcwzW7S2i3OXFbjXTPSWGK8b15NQ+HTHVMhmwEG1Fs0cgQ7slYg7xQ4vkZm81Kc49sUtTJvuYSf8pIdq58kPwxd/AW1llcc+nynB4YPctccTa/Hnji+vczN3g4dsrA/mEuxw+vAlmLqiyb06hg6+3HuKrrYdYt6eIQekJ1FZpH85QxwBWk4nLT+5xfK65lbtkwCXM2z4v5LqUiSmkTEypc/84Wxwndzm5MZIm2rmCcjemED/2SLtMlNTSNDiUZTllPHz6n8iavxOvyUzwdJs7DpWzMuswMTYL103MZOYpvUI+6wnRFjR7MGUxm/jt2B48/11WjXmmwr3Zm0yKmaf0aspkHxeJU6ZQsXq1zD8lRHu05mVqG7XL0PDvVW7uHB8Vel/tg/1r8BTsZm1pAl9vPcTXWw9R7HAz6YQ0Lh6dwf8uGU683cpf3/uJd9bsJcS84mENdRxlNWO3hvcg1tZlxGdwXuZ5fJr1aY35pupjN9u586Q7JZ8XjcJr+AjVDTvSLhOhhkQP5cUlWTz++TZcnkAeFSJ/qXAbVLgNnvh8G0t2FDDr8pGSl4g2qUX0gq2r1DNu0Gl0ueIpuv/pXTJumkvahfdh7zbg6HqlYFB6Ar06ts6anc533UXltm0Uz5/f3EkRQjQVw+tv4ucN/UB+2zgbjy+vpNhV+4ON12vw5tN38/CCLURbzTxx4VBW33kGj184lHOGdDk6RPHV43thNTcsq4+2mpnYN5Xzn17Kg59sptjhbtBx2pK7T76bkZ1GYjfbw97HbrZz8/Cb+UWPXzRiykR7lhBtDVkLHdxlIhzVp3UJZe6KbJ74PPzh150eHyuzDnPdnLVhB2tCtCYtIphKS7DzxzP7Ed2AEosYm5lHp5/YCKlqGjL/lBDtUO46MGoPTEalm5nU08Ljyytr3caCl0tjVvHRTafyxzP7MTQjKWTfhN6pcTz86yHYrZFl99FWM5eO6c5/LhnO57dOoMJtcPoTi3lxSRaVIfpmtBcWk4Wnf/E052aei81sw2qqfV6daEs0UeYo7h17L5cPurwJUynam8FdE3F7awY3wV0mHNtX4PO40IYX5641FH3zcpVtlfYxKG8bBx95hPJly/C5a+ZRu/LLeXDBFpyemnlAxeZvOfDqrez913T2PX05B9+5F9e+TQBUen2s3l3Iayuyj8v1CtGSNHszvyOum5DJ4XI3c1fuCfkjrU7hD6RenXkSfdLiGz+BjUj6TwnRzlQUQD3DY//9tChOebmCW8bYat3G7C4L63RTR3TD8Gn+9uFGKj2+kIP+BIu2mrns5O789Rx/K4C0BDuPTB3CzFN68o+FW3l1RTZ3nN2fc4d0CbvZ2p7DFby1Ooes/HLK3QaJ0RaGdkviwlEZpMTWfo0tkcVk4b5x9/G7Ib/jzS1v8u6Od1GGB4VGWaLxai+JtkSuHHwlU3pPId7Wuu9RouVLjLYyeUgXPvohF6Nae79wu0xER1n5/YxxmLeupuC/T1O5cycxJ51E3IQJxE0YjzU9nZeX7sYbYvKp0tXvU7JqPh1++XvsvUagzBacu9fi3LHq6FxWTo/Bc4t3ceW4ntLcVbQpqq65jkaNGqXXrFnThMmBuSv38M9FW/FpTUVlzaDKrBQ2i4mMlGievmQE/Tq1nZtU7t13o12V7ab/lNfnpbiyGIfHQYw1hsSoxDpLecXxoZRaq7Ue1dzpOFbNkT8dN5s/gg9/D5U1J8Pt+VQZL06J5oxMC5e+56BzrIkBqaaqA1AcYbLCPQVhn3bDvmL+89UOluzw7xPcT9VqUphMikHpCfzh9L6c1j+t1uMs31nAQwu2YDWbuOvcAYzuWfvAC99uO8R/v97Jxv0l+LTGE/QgZrea8Gk4Y0AaN53Wl4HpCWFfS0tSaVSyc+GfKImKwTL4QlLsKfRO6t0u8vHjrS3kT82VN23KLWH6syvCKpAOJSM5mu9uP+3o3623qIiKpcso/+47KpYswZPWhWkDZ+Kq1qjJV1nBvv9dQYfJt9Y7+nKszczzvx3FKX3qH6FQiJakrrypxdRMHXHZyT2YMSqDzzbl8ey3u9iSV4rFpPD5NCafweRhGVwzPpPBXRPrP1gr017mn8opyzlammtoA7MyY2gDheL83udz2cDLyEysfchnIVq96OSwNrt/kp0Rs8r5v7G1DERhi6wW+8RuSbx4xWgOlbl4+/scftpXQonTQ2yUhT6pcVw8pntY/U/H9enIxzedygc/7OeWN9czpFsid5zdn8zUn4dQ9/k0jyzcwtyVe2t9uDvS52LRxjy+3nqIR6edyJRhXSO6ppYgyhzFoMpKyJgAXU5q7uSIdmpQeiKn9unAkh0FNeabqo/dauL+CwZVKQCwJCeTeP55JJ5/HtowePeTVZhWFlC9ajuS4dcr3AavrciWYEq0KS0umAKwWUycPzSd84em4/Npyiq96OwsCv70R/r8c0H9B2il6pt/SmtNpdeH1WxqlUOMVngquP2721l5YCVa6xpztQC8v+N9Ptr1EYM7DubJSU+SbA/voVOIViV9OBj1D0HcJ8XERYOs/Ge1myFp1ZoFKhP0mtig06fF2/nD6TXnjYmEyaSYOqIbk4d04eVlu5n27HKmDE3n5l/0pUNcFA8t2MIbq2oPpIL5tD+wuv3dDZhNinNPTD+mtDWLsjyI79TcqRDt3H8vGcFFs1awLa8s7IDKbjXxl3P6c3r/2v9+ldlMQVInKlVRjWAq0uHX9xY6wtpOiNaiRQZTwUwmRWK0FV/vXuTt3492u1G21tW+PhLV+0957XYW/pTHs4t3seNgGQqFD01clIXpI7ox85RedO8Q09zJrlexq5jLFlzGgYoDuH21d7z3ai9ew8uG/A1M/3g6r09+nc6xnZswpUI0gag4GDIdfngTtLfOTe+ZGMWcDSECL4sdTrm5kRIYPrvVzI2T+nDRqAz+89UOzvjXYiadkMqijQdr7aRe+v0HeA7vw2SLxpqWSeK4Gdi7DcLl8fHneRsY3DWRHh1aWd/R8oMQ17LzKq/h48sth5i9bDc5RQ5cHh8xNjOD0hO4ZnwmI3skS9PEVs5uNfP2dWO5+c31fLcjH4+hax1BL9pqQkPYNcJOtzfkFAuRDr8e7iiAQrQWLT6YOsJks2Ht0gX33r1E9enT3MlpVIlTplC+ajWP3/8Sr8X0A+2vGvfzZ4plLi9zV+3hjdV7GdY9if/8ZjidEsIfqrcpuQ0313xxDbkVuSFro0Lx+Dwcdh5m5qKZzDt/HnG2uPp3EqI1Gft7+Gk+eKsGU9m3Vu0HmpFownV3iL5E8V2g68jGTGFEOsRFcf8Fg7liXE8ueHpZyEAqnE7qXsPH7GXZ3DdlUFNfQsQKK9zsPFROmcuDvSiJ9Mp4WuKMh17Dx7Pf7uKlpbvx+HxV+iMXVsD+YidLdhTQIdbGbWed0CqbWoqf2a3+fklbDpTy0tLdfPxjLjbLzzXbhk8Tb7dwzfhMLhyZQWJMeH2V4+1WrGZVpd8jVB1+vb4+U/7jtJpHTyHC0qr+om2ZmVRmZbX5YEprzb8G/5pP1udQGWIQjiP8GZpmTXYR5zy1hHk3jKV3assLOj7e9TF7SvfUCKSKlhRR8FkB7kNuzHYzCSMT6DS9E+ZYf8mWoQ3ynfm8vuV1rht6XXMkXYjGkzYABpyPd9NHWHyRTQCLJRrOfYKQE8s0M7fhw+OrWfLsq6ygeOnrdJh8KzEnjDu6PKbPGGL6jDn62ePTvLMmh7+c079FTvCptWbd3iJmfZfFt9vyibKYQGuovAbP8xvonrKTGyb15pzBXVpE+p1ug9+9+j3r9xbhrKVGQGtwuA0cbid3vPsTG/aXcNfkAVJL1coN6JLA4xcO5d7zB7I5t5QSpwerxURqXBQDuySEnEqhLgPTE7CaTXiMqs8lwcOvK5MZe6/hKJMFV/YPuPZuIPm0q45uazUpRvaQ5vuibWlVwVRU70zcWVnNnYxG98iCrSzYnE9lmCPbGT5NkdPNjFkrWHjLeNLiW04Nldaalza+hNPrrLK8YGEB+Qvz6XZ1N+IGxuEp8pA7J5fsx7PpdVcvTIFStEqjkrlb5nL1kKsxh9keW4jW4sUOf2YEmxlm2YWp2m+kVpZoOPdx6H1a4yaugeas2BNyvptIOqkr4LNNeVzQwmpIih1urpi9mu0Hy3F5DLQm6FpjweNj+8Fy7np/I/d+tInZV45mZI/aRzpsbIZPc+2cNazdU1Rl5Ma6OD0Gr6/cS6zNwh/P7NfIKRRNId5uZUxmh2M+ztjMDiTYrTjcNQt5wx1+3WRSXDmu5zGnRYiWpFUFU7ZemVSsWtncyWhU2/LKeG1ldsg2xXX1NdAaShweHvpkC/++eHgzpDy0H/N/pMBZdehmw2lw6INDdP1dV+JP9DdpsqXayLgxg+23badkeQnJE34uuXIbbpbsX8KkjElNmXQhGtX/vtnJ/LUHOPemhZgW3wGb3gOfAbU1hbUG+hBNewH6n9t0CY3QtrwyQnXRiKSTusNtkF1Q0Qipa7jD5ZWc/9+l5JdX1mjmVN2Rh83LXlzF878dxfi+qU2RxBrmrtzDmuzQgVRd9xOnx+D577I4rX8awzKSmiHloiVSSnHthEwe+2xbyGa8cYNOI25Q3YU8mamxra8/pBD1qHvWyBYmqncm7l1tu2bqpaVZeELc+EpXv0/hVy+QePIMut00l643zCZ+xGScO1Yd3cbr0yzalEeJI7x+SU1hcc5iXN6qTZgcOxz4PD4SRlbtB2K2m4k/MZ7yTeVVt/c6+Dz780ZPqxBNQWvNU19u5711+3j72pPpkpIAv34WblwJJ10DtjiwxUNUgv9ljYaUTDjnH3DbjhYdSAFUuEMPqBHcSb0+Gihxtpx8zGP4uOTFVeSX1R9IBXN6fFw3Zy3bD4Y3ufLxpLXmucW7au27Vt/9pNJr8Px3u5oyyaIVmD6qG9E2Mw1pAGo1K3KLnTyycAuuBs6FJURL1LpqpjIzce/ejda6TbblLq/08tGPuVS/V4fb1wDApBTz1uZw9fiWMU/TIechdLVxVI1yA0ucBWWu+W9oSbTg3FOzudNh1+FGS6MQx0prjdNj4HQbxNktRFlC175orXnss218teUQb107ltT4oPmjUnrB2Y/AL+6Fgm3gLAazDeLS/MFUK8nz4qJC31Yi6aSugKSYljNq62eb8sgpdOAJUeVWVw0P+Pss/XPRVl66YnSTpnnV7sKQAWm49xOfhq+2HKKowk1ybMv5txDNK8Fu5Y1rxjDt2eU4Ko3qo6TXym41cffkAZw1uAv3f7yJc/69hEemDuHk49D8UIjm1qqCKXNCAio2Bm9eHtYuXZrsvCUOD/PW5vDJhgMUO9yYlCI51savh6fz6+HdiK3l4SFSS3fkYwnRITSSvgZOj8G8tftaTDAVijnOjLfcizZ0jYDKW+LFElfz+9Q6/NJgIZrKnsMVzF6Wzbw1Obi8Pswmhdfw0TEuiqvH92LGqIyjQYHWmocXbGHZzsO8ee3JpNT2gGq1Q5ehTXgVx9fALgms31uMt1rgEUkn9ZgoM5mpLacp0LPf7grZTySc0Qk1sGRHAYfKXE3an/WNVXtxhkhzJPcTk4IFGw9w6ZgeVZZ7DR9Oj0GszRLxIAai9evfOYF3bxjHJS+swun21jqwCfhro8wmxYMXDGb6qAwAnr5kBF9sPsitb/3A6QPS+Ms5/Umwh9dHXIiWqFUFUwBRmb2pzMpqkmBqX5GDxxZtY9GmPExKVW0uUVDBlgOlPPjpFn41rCv/98sTqpYyN8DhCneNBxCIfEK8Ykft8zg1tdToVBSqSu1UTJ8YlEVRuraUxJMSjy43XAZlG8roNL3mxIEdo2W2dNFyHC6v5A9vrmftniJ8Wh9t+nVkPpdDZZX864vtPPH5di4c2Y17zh/Iwwu2sm5vEW9cM6ZF1bocb78d15O3vs8JmZeF20ldoThzYMuYAHf7wTJ25ZfXWB5JiwEFvLFyL7c24YAO+4ocIWsNIrmfOD0+8kr8zbQPlbqYs3IPc1fuodjpwawUhtZ0SbBz9fhMpo3sRmK0PBC3F/07J7D4tkm8u3Yfzy/JotjhwfBp3F4fFrPCZjGhNVw0OoOZ42rOh3nmwE6MyUzhkQVbOevJ73jggsGcUcdvXmvN2j1FPP9dFuv2FuFwG1hMipRYG5eO6cGMUeEP8S7E8dbqgilbZi9/v6lTTmnU8/y0r4RLX1xJeaU3ZGdq+LmT8fy1+/hyy0HeunYsfdIaPjS5z6drzCwOkU+IV9sEfc1hUsYk3tj6RpXR/MwxZtJ+lUbu3FxMdlOV0fysKVaSxlXt8BxjieHMHmc2ddKFCGlfkYOpzyynqMIdstnXEUcGkZm/bh+fbT5I1yQ7c68e0+ZLYHunxtG/czw/7isJub6+TupWs+LSMd1rbSrZ1FbtLiRUxXgkNTyVXh9fbT0UUTCltWbV7kJeWrqbbXllONxeoq1menaM5apTezGxb2qdtUK1TYwa6f2k2OHh2jlrWLwt/+i1AHgDX0puiYvHPtvGPxdtZcaoDO45fyBWc6vqji0aKN5u5cpTenHFuJ6s3l3I1rwy/5xrVjPpSdGc3j+tzukBEuxWHpk6hBW7DvPX9zbwwQ/7uW/KIDrGVS2Y/mxTHg9+spnDFW6cgVE0jyh1efnXF9t4/PNtnD2oM3+/YLAEVaLJtbpgKiqzN5W7djbqOXbll3PxC/5AKhxen+ZwhZsLn1vOglvG0yUxOuJzegsLse3ajsnrAapmPpFPiNdyMpKhqUPpGN2RnLKcKstTJ6dijjWT93Ye7kNuTNEmEkYkkHFdBiZr1RuxzWxjQrcJTZlsIUIqcXq4aNYKDpe7McJseury+HB7KxmUnkCcrdVluQ3yxzP7ccPcdSEHP6iP4dP8YkBaI6SqYUocbjxGzcAk0hYDpa7wB9SYvzaHJ7/YQZHDjdNdtV9KTpGTdXuKiLaZuXFSb64c1ytkUFVbLVGkfdc+/jEXh8cIOdz9EUf+neevzWFrXilzfjemRcyxJZqGUooxmR0aPPz62N4dWHjLBJ76cjtnP/Udd04ewK+Hd0UpxbPf7uTfX+2otXAAONrMcMHGA6zZU8i868eRnhT5c5gQDdXq7uy2zF6UffFFox1fa83M2d+HHJGqvqHJS11ebpi7lg9+X/cNSvt8uHftwrF+Pc5163GuX4+3sJC+w8fgTfllje0jmhDPrFrUg4hSipmDZvLomkdrjOqXMjGFlIl1z8FiM9u4dMClMseUaBH+9/VO8ssqQwZSdeUPPg2rdxeyeHs+p/VvOb/PxjLphDSum5jJrMVZEQVUdquJaSO6ccPcdTzwq8FMHtJ0fWNrYzGbMCmFr9q/eaQ1PKH6w1bn82n+9uFG3lu3v87vrcJtUOE2eOyz7azaXch/Lx6BzVK1EOrkzA6s21tzWPRI7ifgDwLDbezg9PjYsK+E6+eu5eUrRkt/KhG2aJuZv04ewLknduH2+Rv48IdcRvdM5n/f7KozkArmMTR5JZVc+NwKFtwyXpqdiibT6oKpqN7+PlONZdXuQgrKK2s06wino7Hh02w9UMb2g2X06xR/dF9fRQXOn37CuX69P4D64UfMyUnEDBtO9PDhpFw1k6g+fVAmE+NeXs232/NrpCvsCfFUy5sQb0qfKby97W2ySrLw1DaHTghmZSYtOo3LBlzWiKkTIjyVXoN3LPzmAAAgAElEQVQ3Vu/FHWJo7HDyB4fbYNZ3u9pFMAVwyy/6YjWbePrrnf4JbuvY1mLy97H43yUjOK1/GjNGZXDTm+tYlXWYO88d0KxN/jrE2rBZTHirDeYQaYuB5DD6yT20YEu9gVQwp8dg8fZ8/vj2Dzx9yfCjo9yWV3oxKWqdqDfc+4mGkE0c6yo4qPT6WL27kM8353H24OYPhkXrcmK3JD7+w6k8/tlWHv98e8ht6vr7M7TmUJmLv3+8iSdmDGvi1Iv2qtUFU5ZOndAOB0ZJCebExPp3iNDzi7NqjIAUSUdjj8/HC59t5O6UQpzr/bVOldnZ2Pv3J3r4MJJnzCD94YexdAw9oMK1EzJZnV0YcuSocCbE69spjm7JMXVu09SizFG8+MsXuXTBpeRV5OH21T9AhkVZSLYnM/vs2cTZGt4PTYjjZeFPeSFHlYwkf1i/t5icQgcZKS3rN9oYlFL8/rQ+nJzZgWe+3cnSHf7Ju4Mf8GNsZrSGqSO6cs34THp29I/gNzQjiU/+MJ7b5//I9GdX8PQlw5ttos8zBnTi7g821lgeSQ2PWcFP+4q59MWVnN6/E2cMSKtxPct3FfhH4AsRSNX18Ojy+Ph62yE+/CGX4d2TeHX5Ht5bv4+xmR0Y2SOZdXuKQgay4dxPQgm34OC5xVkSTIkGsZpNmE0mrGZVY163cP7+PIbmkw0HuHfKoDbfR1W0DK0umFJKYcvMpDIri5jhw4/rsctcHpbuLKhx44mko7Hhgw9/OsgfXF8RO3wECeedi33QIEy28EbvGtu7A8MzklizJ/Ss9XWxmRU5hx38+8sd3Hha7xbVCTjJnsTb573Nbd/dxve5K/BpH54Qt3izMmMxWRjUYRBPnvYkKfa6mwEK0VTeX7+fimMcahrgyy0HmXlKr+OdvBZrZI9kXrpiNIfKXHywbj+7Dzsoc3lIjrExuGsC5w9NJyZEX7LEaCvPXTaSV5ZnM/WZ5c3W7C851sYZAzqxcOOBGs3dwq3hsdvMLL39dFZnF/LVloM8++0uEqMtnDGgE78Y0IkR3ZN47tvaJ9it7+HR6Ta464OfsJlNzBidwac3j6drUjSbckuY9uzysJtJ1SeSgoOtB0rZlV9O71QpDBOR8Rg+5q7cUyOQinTOzffW7uPKdpTXiubT6oIpCIzol7X7uAdT+WWVWM2K6s9LkXY01jYbiX//V4Pa6yqleOGKUUx7dgVZ+eVhB1R2q4lnLxtJ/87x3D5/A9OePci/ZgylT1p8/Ts3kThbHM9OfJKc/57I6yN+xfv7F2NoA7My49P+6zwv8zwuG3gZvZN6N3NqhaiqoLwy5PJI8odKr4/C8pYzdUFTSou3c+3EyH7XSilmntKLEd2Tm7XZ3+n901jw04GQ6+qr4bFZTFx8UneSY22cNagzZw3qjM+n2bC/hK+3HOS+jzaxr8gRcsCjSB4e3V4fr101hpE9ko8uG5SeyFMXDefWt9eHHVAp5X8QDTUqbCQFB16f5rNNedw4qU9Y5xXiiFVZhTX6KELkc26+vmqvBFOiSbScqosI+Oea2nXcj+v0GEfbnAcL7mgcDrNJ4WrAKFZHxNgsvH/jOMb37YjdYqKuCqbYKDOJ0Vbm/m4Mp52QRpfEaF676iRmjMpgxqyVvLR0t3/I9ZbixzfJ6DSUv0x6lKUXL2Xh1IW8dd5bfDr1U5ZfvJx7x90rgZRokULd3CHy/CHcUQDFz440+8srdTH92RXsOVzRJOf1Gj7++9UOHl6whVP7diQ6whHqzCZFeqKdW37Rt8pyk0kxLCOJP/3yBBbcMp7Lx/YIee+J5OHR54MP1u+vsfzswZ159rKRRFvN2C113/JjbWY6xNoY0Dl0IVwkBQden+Zgqave7YSo7lCZK2RfvUgLtg9XtM+CK9H0WmUwZeud6Z9r6jhLsFtDPjAFdzQOh8fQxNuPrdLPbjXz4hWj+eTmU5kxqjt2q4m4KAvxdgvxURaiLCYGpSfwz2kn8v1dZzCq58/N4ZRSXHZyD96/cRwLfzrAJS+uJKfQcUzpOS58Biz/L5xyCwBWk5XUmFR6JfYiLSYNq1naNouWKyU2dFPdSPIHq1mFNRCBqOlIs7+pI7oy9ZnltdYU1eAqhaJsKMwCR2HoERVC2JVfzrTnVrA6u5BPbj6VV2ee9P/snXd4VGX2xz/33mmZ9E5CEpIQeu+9KirYEBV1FQtIU1fZXcuuu2tb/bnFtmsXERXBBqiIXZQaOtI7IUAIgfQ6fe7vj6GFTCYzIZNMwvt5njw+zr0zcyZM7n3Pe77ne7iqayuvEyqdIpMQbuDT6YPqHFdxZuDphfiyeHSoKseK3V/nR3WIY/Vjo3jwsnZEB+sI1isYtQp6jYxRpxCkVWgbG8xT13Vh9WOja638+bxx4MasRSCoC5vD6dZB0tfvn93NSAOBwB80S5mfPj0dy+GGT6biwwxudwd9tZKNCtb5vINZGxlxoTw/oRt/v6YThwsqKTPZMWhl4sIMtK5jjkKb6GA+nT6I2auyuP71NTx2VQcm9k12+xkbhb3fQFAEtPHvwGWBwB+M7dqK346W1DCH8cmIQJYY3j62sUNvMXgt+3PYYf93sPoVOLEVFJ1Lv+awQngyDJkF3W4GXU0jEKdT5YO12fxv2QH+OKY9dww8VzV6aWIPOiWE8uovB3GqKpWWmos6g1ZGVeGyTvE8P6GbV3LvC02PzuCr/bo746IzRIfouW9UBtNHtGV9ViE5JSZMVgcheg3t40PplnTO0CnSi40Db2ZUXTh8VSDwhjCDFsWNrb6vDprB+ma5xBU0Q5rlN02XkoL9RB5OiwVZ33AXa63TzoQoKx8fV7FfcOPyttE4SCszdVhagycsRp2GLom+uxcqssSMEW0Z2SGWP366jR93neT5G7sRF2po0PjqRFVhzSuuRUxTJXMCwUUwvlcSz36zx+0xb68P6bEhdKhFQiXwHo9ufwd+hsX3gsMG1grXY87z+pGKsuCHv8D3f4bLn4IB088eOl5i4uHPtmGxO1h83xDSYqo77kmSxLThbblnSBo/7jrJWysOceBUOWabE60iER2s545BKdzWL4VoHxKJqJCLT14AoryoeiqyxOAM926yZ7iySzyZhwouauMgSKcwTGwcCOpBj+QIrG6qSr45aEoMSBMGVoLGodklU6qqsvZIKR8NuJ3y2ZkQZCQ6WM+YzvFc0SW+Xg52jvJySj77jKIP53F1Rhc+ix+HO98Hb6xknSpM7Jvscwz+pmOrML68fwiv/nKAcf9dzdPXdeHq7o3ojHUkE0wl0PHqxntPgaABCdFruL5naxZtzsHuRoNS1/XBqFOY6aMBg6B2LnT7e3Z8V8Y6lsPSP4Dd5PnJ1tM9Vz8/BaU5qGOeYeHmHJ7/bi/3Dktj+vC2bnfGz6BVZK7unnD2Gqqq6kVtoPVOiSRYd7SGW6RPyYtWoX8DLR6v7ZHIk0t2uT3m7cZBbIie3ikRDRKP4NIiMSKIfqmRrDlYWOOYt98/naIydXh6Y4UsuMRpNsmU2eZgwfqjvLMqizKTjaqozpBTBbg04j/tzuMvi129QlOHpdcqUzgf24kTFH04j9LFiwkeNoykN16nXZcuXPbRZpbtPeWzNXmQVuHmPklEBGhPhE4j86crOjC6Yxx/+mwbP+zK45nruzROvGtegcG/By8bRwWCQGTW5e34YWceJSbvh0+Dq1eqXVwIV3Vt5afILk3Ol/3N/Wgul1mfR6e6d110i60KdcNsPt1r533nWD6aMoDOiWH1iuNiGNM5HrmW5M3bxaNTVbmxT9JFxXEGw+l72fz1R+u1cRCkVZg+Ir3pJOWCZs/04W3dyqrBu43tJDWPLqsegNF/h9j2/gpTIACaiQFFSZWVG9/M5N8/7CWv1Oz647rgIl1pdVBmtvPuqiyufGUlWfkVtb6eee9ejj/6KFnjbwCnk7QvFtP6hf8Q1MU1s+OlW3qSEReCvg7no/MxaGV6pUTwxLWd6/chG5FeKZF88+AwooJ1XPXKKpbvO+XfNzy5C3K3Qo/b/Ps+AoGfSQgPYv7UAYTqNXgoXFRDr5FJijTy4eQBATX7rSXRIymcl7Rv10ikUl8pJ+4/5VRazyUE726xMvL9c26Akt3EjcXv8tW9XeuVSDUEWkVm0sA26Gq554R0GUXCXa+Q8sdFJD/wEXE3P4UhqdPZ44okMa5bQr3GcdTG/aMy6mWkpJElEiMMTOjdMImd4NJkaEYM7eJC0Cq+J+QGrcwTd1wFrfvA3Ktgye+hLNcPUQoELgL+zl5ltXPTW2vZf7LcqzkZVodKfoWFG97I5HjJOamHqqpUrF7D0clTODZtOob27cn46Ufi//JntImJ1V7DoFX4fMYg+rSJxKjzXEmRcMl3RnWI4/17+qNpJoulIJ3CU9d14aWJPfjrFzt5/IsdVLqZc9IgZL7q6kvQNnKflkDgB7okhrPk90NJjjR6vD4oMhg0Mv3Tovj690MJNwq3Sr9xeAWytdztIYcK/13v2SJZq9Gg377AH5F5zd1DUuu0Lq8NnUbm/lENO88pLszAx9MGEmrwfuNAq0jEhOj4eNpADA1kwiS4NJFliQ8nDyAhPAidDwmVQSvzt3GdGNapNQydBb/fDEFR8OZg+OkJMBV790IOO+xZCu+OgecS4OkI+EcsvNgRlv8TKvy8CS1oVkiqB5vYvn37qps2bWrEcGrywIIt/LT7pFvJXeXu5ZRt/BJbYQ6yLghtXDrhgydiSOqCIkmkxhj56YFBlH//PYXvzQWnk6jJkwm/ehySrm5pm9OpsmJ/Pm+tOMTWYyXIkoTF7kACdBoFp6oyKD2aacPTGdQ2utlKGsrMNp75ejcbDhfx4sQe9EttwKbN0hx4cwg8tBWCIus+X+B3JEnarKpq36aO42Jp6uuTqqqsyypk8vubsDud6DQysiThdKo4VbihV2smD00NqMHZLZZ5N8ChX2o8nPpKOTP66vj3GgtZD4USYZB4d4uVj7bbWH53dXMJQlvBH/aA3HQbYtuOlXDr7HW1uvu5w6CVeeuOPozsEOeXmLILKrnxrUyKK61u7aoBZAn0GoWOrUKZc3e/WscIeENLuD419bWpJVFqsnHXexvYf7Ick9VBbStWg0YGCf59Y3eu69m65gllua4kaO9SGPQADJjh1s0TgI3vwi/Pudw/rW5UThqDy1Sr3eVw3WtgFEYXlwKerk0B3TOVX26pNZEq2/AFpesXEn3F/RjSeiMpGkyHN2M6sB5DUhccqkpuQTlf3jSFPjFa4h7+E8FDh/qU8MiyxKiOcYzqGMexoipWHsinpMqGLElEGrWM7BBHq/DmX20JM2h54eYe/LT7JPfP38INvVrzhzHtG2Znce0b0OsOkUgJWhySJBEVrCcqWMc3Dw7lVLmFKquDUIOG1hFBYme+sVBVOLyy1sN9ExVGpmp4IdPCs6M9XK/NZVB8GKKbziSkR3IEn08fxB3vrsfqcHq0OjdoZRRJYvadfet056uL7IJK3s/MZsvRYsrNrvEbKVFGJg1MpXdKBHpF5vFxnVi5P591h4vQKTKS5PrV2xxOrugSz9Rh6XRPEoYTgoYlPEjLF/cNZm1WIe+syGJtViE6jYzT6TJ9UVExaBTuGZrq2UUzLBGu+5+rd/uXf8CrfWDEo9BrEiinl8KqCt/8CbZ9DDYPszntp4dRH/gJ3hoCk3+EiMAzHmvOWO1ODuVXUGqyoZElYkL0tIk2BmzRIqCTqQXrj7p93GmppGT1fKLHzcLYYfDZx40ZAzBmDDj7/2anxNdXTmbCA6MvOpbkKCO3D2hz0a8TyIzpHE/vlAj++sVOrnttNS9N7EnX1t7ZsauqSqXVgeP0wGJZllzl9K3zYeYaP0cuEDQN3+/M46qurYgw6gLWeKbF42nRc5pnRukZ8l4lDw3w8G8ka1xDfZswmQLo2jqc1X8ezZe/5fDW8iyKqqyggt3pRJIkbA4n0cF6pgxL45a+yV6ZLdXGuqxC/vPDPnYeL8XpVLGdV3rac6Kc1QcKkCSJpMggJg9J495h6RRVWsktMZ3bOIgMIqyOocQCwcUgSRKD28YwuG0MJ8vM7DxeSpnZhk5RiAvT0zsl0qP7ZjVi2sHED+H4Zpeb59rXXCYVna93VaPqSqTOx2GF8pMwdyzMWCU2jRuA4yUmPszMZv76o6ioyKeTJ5vDSXyYgRkj2nJ9z0SMusBKXwIrmgv4YG2226qU5fheVLsVY/tBHp+vShJr8iyUVFnFQsdLokP0vHlHb77cepy73tvAXYNTuW9kW7e9YKqqsvlIMe+szOKXvS79sCSB3anSLTGcGXG7GNNuHNpw0YgsaJl8f9oRUxDYdI1TuKa9hn+uttIpNvD7WkP0Gu4YmMrtA9qw5Wgxh05VUm6xo9fIPPP1bpY/PJLgephDnM+8tdk89+0ej73IZ6zaswsqmfz+Rt6a1IeoYN1FyfgEgoshPsxAfFgDKIJa94E7l7jkwT8/Bcufd82gc5zrr0x9pZwqGxx+KIRgnWtRX0MmrDpcCdXPT8G1/734uC5RHE6VJ7/ayeebc1BVFaujpqDzSGEV/1i6m38s3c2LN/dgbLdGHO9TBwF7V7E7nBRXuW8adpjKkI1hXk2E1ykyuSXmhg6vRSNJEjf0SmLpg0PZmF3EjW9mcvBUdd3wtmMljPzPcu58bwM/7TmJ3alid6rYHCqqCtuPl/LIbzH02X4dizfnNNEnEQj8x5HCSvLLLfROEbuRTYq2lr6HC3h6pIHZW6wcL6ul68JpD7jeB0mS6NMmion9kpkyNI07BrahY0Ioe/LKLup1P990rM5E6nzMdifrsgqZMW8zjtoapwSC5oYkQcZlMG0FBMdWS6TO4I2BDU4rbP/03Pw6gU84nCr3frCRRVtysNidbhOpM1RZHVRZHfzhs63MX3ekEaP0TMAmU1U2B5payrZKUBjOqjJUZ91NupIElVY/udS1cBLCg/hwcn9u6pvMzW9l8t7qw2dNOW59Zx1HiqqosjqozcOkEgNlFiePf7mDV37a37jBCwR+5oddeVzRJd57eYnAP0gSpI2o87SMKJlbumj534ZaFkaGCIgK/CGfPZMj2HqspN7PP1pYxd+/2uk2karcvZwTH8zi6Es3kfPaJE5+9iTmHNfwXrPdyfrDRbyfmV3v9xYIAhJbJeS4Nwx5ZLCOFzItlJjr2kSQYfvnDR/bJcDfvtzBuqxCTF5u7gCYbU7+8c1ufvX3aB8vCdhkKlinwV5Ldqpv3RFJo6Vq/9o6X0dVXZIJQf2QJIlJA9vwxX1D+GbHCa5/fQ3TP9yEyea925TZ5uTtlVnMXx84uwgCwcXy/c48ruoihvAGBEMeAl1wnac9MUJfbebUWbRGGPxAjfmFgUiPpItLpuZmHnZbXSrb8AVFy2YTPnAiSQ98ROuZcwntPQ7TgfVnzzHZHLy94hBOUZ0StCQOLoNalE7nG9h4xFYJWz7wQ3AtmyOFlSzectxtIuVpcwdca8u/f7ETT67kjUXAZhmKLBEbqudUec0vsKwPJmLo7RT99BaSrGBI64UkazBnb8V8dDuRoyafPdfmcJIYEdSYobdIUmOC+Wz6IPo/9zNmN31s4Nmq3mRz8MzXu7mmW6KYtyNo9uSVmskqqGRgenRThyIASBsOhvAaMpvsWdVt6ZPDZcx/czOYV3VAz9v9GWGD0TMlgpd/rl+l32xz8OnGY9gu2Kj01tQJoNJiZ82hAoa1i61XDAJBwFGZ75L51oJXBjZnXkfgE+9nZuN0kwzV5dh9hqIqKxuzi+mf1rQS7YCtTAFMHpKKQes+xLD+E4gcPYXStZ+S8+rt5Lx5N+VblhLU7pwphSzBZZ3iGnQq/KXMrtzSWq16vdnVlCWJzzcfa6xwBQK/8ePuPEZ3jENXzyGrggZGkmD8m6Cpx8aZNggufwaCmoetd1p0MGUmGwUVdeyUu2HZnlNui2/emjqBy5TiAyH1E7QkHDZQa5eYnW9g4xGnrYEDa9mYbQ4+87C5EzVmJsYOg5F1BiRFgzFjQLViCYDJ6qqWNzUBW5kCuKVfCi//fKDW4yFdRhHSZVStx/UahanDAl8D31x4d9VhLPaayZS3u5omm4N3VmYxeUiayzpdIGimfL8zj7sHpzZ1GILzSR/pctP6+iGwm7x7jtYI/afCwBn+jKxBkWWJHskRbM8pYXTHeJ+em1tiwupGWeCLqRPAsSIvf78CQXMgKBIUrVsDijM8PdJA77cr+NOgWuZYAejdVL0FtbL7RJnbuVG+bO6owNqsQrLyK5i7JpuVB/KpMNtRZInoEB239E3mxj5JhPp5fENAb6tGBuu4rkeia7K1j2hkifTYYHomN4/dxubAT7tP4k4q78sXv8Ji51C+m4niAkGAYrU7OVlm5nCBy70vv9zCjpxShrcXMqeAo8ctcOt8MEaDLqTW02yK0ZVIjfkHjHmmEQNsGHokRbD1qO99Uyabw20vsi+mTmdeRyBoMaQM9Cjzg7oNbFRFB+3G+CO6Fkupyea2Uu7r5k6V1cHY/67i4w1HOVJYRWGllVPlFvacKOffP+yj33M/8/Dn2yipxSG8IQjoyhTAszd0ZW9eOftOlrvdUXOHIktEGLW8f0//gJ2W3NxwOFXMtdxAffniK7JEUaX/vtACQUOxN6+MOasOs2RbLpIEiiRhP72b0DoiiMJKK61FP2bgkXEZPHwA9v8Aa15xDedUdIAEDisVwUl8IN/A/Q886pL4NTfsFq5UV6P+tgAOWVz9XkGR0Ok66Hmbq3esFkINGrQauca99HxTp+COQ+sMIfQiZ1wJBAFFVBok9oajnk3NnhihZ95291I+q0PlG+3VjLU6CNJ5lwRc6ii1rM/P39zxNqFyN5MWONua8tVvx1lzsIDPpg8iOcq7cRq+EPBXRL1G4ZNpA5n8/ka255TWuSNm0MrEhRr4ZNpAYkM9lGMFPuFUVZBw1VQvwNcvvjCCEgQyp8rMTJu3mb15ZdgcqlvnsyNFVYx+YTmjOsbx8sSe4uYZaMgKdBzn+rFUgKkInK6kQ6cNY87zy7iuXCU5sMZKecZcBiv/A5vn0sXpRLZVwvkjFHN/g5+fhM7jYdTjENmmxkt0TghDI0tcuJ3li6mTRpaE4kPQ8hgyC/K2VzOx8drABrDE9+H7HD3P/esXbh+QwqRBqRe3Bs39DfJ2uP7utUEQngxtR7nkiAGIqqpUWOzYHCphBg0apW5FWVSwzq0zqK+bO95gc6qcLDNz01uZfPfQ8AYfPB7wyRRAsF7D/HsH8NXWXN5acYgjhVVYHdWzUKNOIdSgYeqwdG7tnyLs0BsYrSKjVWruaIJvX3ynqgpDEEHAcrSwihveWEOpyXa2CuUOh9OVZP269xTjX1/DwpmD/K7JFtQTfYjr5zQ64NruCSzecpyHLm/XdHH5QulxeH8clJ0Ah8W9Pt9W5frvjs9h37cw6UtI6lPtlP5pUQRpFbdGQmH9JyAHR1K69lMKlr6ApAtCH59B2KBbqp2nkSXuGZLWQB9MIAgQ2o2BmA5wcqfH3im3aIMIu/6fvJPYi6z8CuasPsxlLy5nXLcE7h2WRkZcaN2vAWAzwc7FsPplKDvuesxpB0kBRQOSDP2mQr8pEJboW4x+YkdOKe+uyuLbnSdQVVdPp83hJCM2hBkj2nJ19wQMWvebjZ0TwgjSKVRecD3yZXPnQjy5SjtVKKyw8tii7cy+s2+D/h4kT/7sffv2VTdtcj/IrCl5+PNtZBdU0ibaiMOpEhOiZ2SHOIZkRAtZnx+5670NrNjv3vqzbMNiStcvJvrK+z1+8cOCNGz+2xi0XuxaCPyDJEmbVVVt2CtJE9DQ16eSKitX/XcVp8rMPlVPdRqZbq3D+WTaQPG9biZsO1bCg5/8xvKHRwb+PaOqCN4aCuV5Lkmft+hC4N6fIa4TqqqyNquQV5cdZPcJlyvrhQ5a3tI9KZwlDzTMbrE7WsL1KVDXToI6MBXDO6NdiYzDS7dMTRDcNAc6Xl3t4cIKCx+tO8q8ddn0SIpg6vB0BqRF1X69KTzk2jCxlNcY8VANRe9Kqsa/AV0nePnBGp6Dp8qZOX8LOUUmLHaH23tm8GnFxh/GtGfK0DS3n/3Vhet4beMpLFLNhKti16+Ub/oKW+Gxaps7hqRObmOqzU7dcmxXtXWoTiOz+rFRxIUafPrMnq5NzbJ8c6SwklmXt2dou5imDuWSYvrwdDZlF9XYRQDvdjX1Gpm7BqWKBacgIHlz+SGKKixubwqedrusdid7TpTx3c48rusRGLuFAs90TwpHI0tsOVpMnzYBrvX7Yrprfs15iVTqK+VU2eDwQyEE61wLlHe3WPlou43ld58eXmytRJ03gRVX/8prv2ZRWGnlvpFteb1TLy5/aWW9elcNWplHruzQIB9LIAg4giJh+nJYcCuc2Ap2c+2W6dpg10iGWxdA+ogah6ND9Dx0eTumj0hn8ZbjPL54B8F6DVOHpzOua6vqMriCgzB7lCuRctdLcT5nkrwv73MlXb0n1eujXgxbjhYz6d31VFkdHqM9s1Z88cf9HC6o5NnxXc8mVLYTJ8h/9TUGrdnAa4Mfcvux63LsPh9fZuVJwIJ1R5k1pr1Xr+0NzS6Zstgd7Moto2eK0Gw3NoPaRhNq0LpNpqDuL74K3DGwpo5fIGhqLHYH89cfxepmt96b4YFVVgdvLT8kkqlmgiRJ3NgniYWbjwd2MlV6HLJWuJUdOVT473orjw+rrS9DxVRRzI9fLWDSlbdwTfdElNMjKRZMHcCNb2ZSafG+0hWkVfjTFe3FsF5By8YQDvd86zKuyfwf7PseNOf9jTntLondkFnQ9UbQeTYzMGgVfjcghVv7JfPL3lO8syqLf323l3uGpLpaUjCfq0jVlUidj90E3z4C0RnQpm4n5YYiK7+CO+dsqHUd6A6TzcHiLceJCdHzYP94Ct55h9KFi1Y3m6QAACAASURBVIi49VZ6LV3ME7uLeHbpnotyCfXFVdpid/Lxhks8mdp5vJT02GDRE9UESJLE8zd2Y+ZHmzHbvHNWPEOQVuGuwW2ID/OtrCoQNAbf78xDdXMj82W3K6uggr15ZXRsJWaNNAfG92zNuP+t4slrO9eq6W9yNs6mtgXWI4N1/HuNhfv66YgwuJcOGVUTz8UtQ+r5x2qPd2wVxqKZg7ntnXWYbU6PixiNLKFRJP52TSduHyA2wwSXAJIESX1h4ocumW3+XjCXgsbgSqRifa/OyrLE5Z3jubxzPFuPlTB7VRav/3qQ51M2cIWlHPm8v3OvKs/gSqiWPQOTv7voj+wtj3+xg0qrext5TwoOk83Bm7/sp89zD9FuWD/SlixBGx8HwO0DQimssPLG8oM+ry3P4KuderGpYQcsNzu91cbsYvoG8k5iC2dUhzieuKYzBq33Xx2DVmZM53geu6qjHyMTCOrP0u0n3O7S+7LbZXOoLNtzyh/hCfxAYkQQXRLDAvvfbPP7tTbD901UGJmq4YVMz70dUs5GKD9Z4/GOrcJY8egoHrmyA63CDATrFHSKa+GmkSWCdQpBWoVb+iXz7YPDRCIluDQxRkGbwdBhrMtNrx6J1IX0TI7g9d/1Zsn9QxiQ+xHyGfOY8zhTea6T3C1QdPiiY/KGY0VV/Ha0BHdWC2UbvqBo2WzCB04k6YGPaD1zLqG9x2E6sP7sOarTyfJpT5Dwj2fOJlJnePCydjx/QzcijdqzvVYXosgSWkVCdrN35OusPLujfklbbTS78s6m7GLG9xJSmqbkdwPaEB2i54+fbQWVWsu9Bo2MzanSKszAK7f0CPxGb8ElS2GF+wWpL7tdDqda6+sIApMJvZJYvCWHq7snNHUoNXHYweR5MO8zo/QMea+ShwZ4sPnV6F0N9aHxNQ6FGbRMHprGPUNSWZtVyO7cMsrNdoJ0CgnhBsZ0jseoa3bLBIGgWZBcuQMoc3vMm8oz4Orp2vAOXPW8f4I8jw/WZrvG5FyAtwoOm6Tw2b4yHrM70GvO3VNLqqzkl1toFx/K4vuGsC+vjNmrDrPtWAlO1aUZMWgUru6eQL/USJ79Zg/l5urVMV/t1BtajdCsrpKqqrL5SBHPju/a1KFc8lzZpRWb/zaGb7af4K0VhzhaVIVOkUECu0MlSKcweWgqN/VO4u65G/lyay4Teic1ddgCgU/UZ3igoPlwVddWPPX1LvLLLYE3l9Bucs3LcrqX1AB0jVO4pr2Gf6620inWg1rAZvL4VpIkMbhtDIPbClMngaDRyP0NHO7lZudXnp8d7aE9wmGFI5l+CrA6X2/LdesC6ouCA2DzkWIGpkWz4kA+b684xJYjxWg1MhISTlVFkSUmDWzDa7/rRaRRd1pq7Lq+FVda+ftXu2q8pq926l0SG1aO36ySqUP5lQTrNbQKF303gYBBq3BjnyRu7JPEiVIThRVWHE7XHKnkKOPZZucXbu7BXe9tYEhGjOiZEgQk0SHuF9K+7HYpslTr6wgCk2C9hjGd41myLZcpQwNsdpI22Csr9KdHGuj9dgV/GlS7EQUG0ccnEAQc5jKPM628qjwDWNxXtxqaC6tBZ/C1X2n7sVJmfbKVSov9rLLJ6qh+rXt39WHmrD7MdT0S+b8J3c4+Hhms4/JOcXy/M6+G8663s/KC9QozRrT1KlZvaVbJ1KbsIvqlin6pQCQhPIiE8CC3x7q2Duf2gW34y+IdzLmrr5D7CQKOa7onkHmooEbflC+7XVpFYnTHuAtfWhDg3Ng7if/7dk/gJVOyDBFtoDjb42kZUTK3dNHyvw1WusW5qU45bBCZ6pcQBQLBRaDRg6yptfrsdeVZ0zib1LWNpfVFwWF3Onnpp/1Y6+hZstpdx7/ensuRwirm3dsfvUbBWVXFzbkb+cUej1mpmWR6Y6eu1yiM7NCw9+pmZUCx6UgxfdpENnUYgnrwwKgMTpSaWbg5p6lDEQhqcFXXVki4T/LD+k8gcvQUStd+Ss6rt5Pz5t2Ub1lKULvqkoa0mGA6JYgKQHNjUHo0xZVW9uY1zu6uTwx+ELSerZcBnhihp9LqZqUjKdB5POhD/RCcQCC4KMJa15kIPT3SwOwtVo6XebBND09u4MDcE6x3nyidr+CoC7PNWWcideH524+XMOvj3yj5agmHxo6j3clDdE+JQqfxPYUJ0io8emWHs8qphqLZVaamDktv6jAE9UCnkXnh5u5MmrOBoe1iaq1iCQRNgV6jcPuAFOauOex21lRdu11GncLMkRn+DFHgJ2RZYnyv1q7BmuMCLBnuPhF++GuNh7NnVU+OksNlzH9zE7tGB4Pu91d0AoHgYugwFpb83uMpdVWeTVIQO2In0NPurFdy4QujO8axaHMOF94ife1XcocnW3WzzcmvO3JYc/wHRr38MsbevZhjtnHda2s4XmzyOjkLOjPzq39KfT6+R5pNZSq/3EJxlY12cSFNHYqgnnRJDOeuQan8edEO1NrqxQJBEzFjRFsig3VubVc9odPIdGwVytiurfwTmMDvTOidxJe/HW9wu9yLRh8K/ad6VZ2qgaKDxN6Q0L3h4xIIBBePPsS1YSJ5rmvUWnkGJF0QLx9OYvA/l/F/3+4hK7/iosM6XmLin9/t5aY3M7ni5RVc99pqfr9gC/1To9AoF6fgcIc3tupWWcPSsfdi7N0LgFCDlq8eGELX1mEYa7FSP4NGltBrZGaObMvfru7kw2/Ce5pNZWrzkSJ6p0QgN3BpTtC43DeqLeNfX8Pnm3KY2K9xStMCgTdEBuv4bPogbngjkzKTDfuF3a1u0EkqqdFG3p/cH63SbPamBBeQERdCQkQQqw8WNLiW/qK5/Ck4ucvl2GX37Mp3FlkLoYlw28f+jEwgEFwsg+6H7Z+B/VzflNeVZ20QhmEP8fHQoRwuqOSTjUeZ+PZa2saG8LsBKVzZpZVPFuBbjhbz4g/72HSkGFVVq6k0duSUsmzvqRqmD+fjTb/ShXhrq+5E4pe9pyiqtBIV7OqVCjNoWThjMKsOFvD2ikNsPlKMRpZwOFUkSUKRXe6AN/ZO4p4hqaTH+q8Y02ySqY3ZxfQV5hPNHq0i88LNPbj93fUMbRdDYoSQ+wkChzbRwXz/0DCmfriJfSfLsdkcONz0UmkVCUlV6XtqP289cBthBm0TRCtoSG7s7ZL6BVwyJStw2yc4F0/Dsvs7DJhr6e47jTYYotPhziVgCG+sKAUCQX2I7QCXPwnLngE3w3trRdFD6z4w6AHA1bP7l7Gd+NOYDvy85yQfbzjK01/v5vqeidzWP4X28Z77JhduOsbfvtqJ2ea+Oq8CVbXMFPUG6fRrXIgvtupaRWJdViHjup2bCyjLEiPaxzKifSw5xVVsPlJMmcmGVpGJDtEzNCOGoDoqVw1Bs0mmNh0p5vGxHZs6DEED0CkhjHsGp/LnxTv44J5+wt1PEFDEhRn46oGh7M4t43//WcAvhmRkjYwiSdidKlpF5vYBKUwa1AbdBwcofe5ZQl9/TXyPmznXdk/kPz/so9xsIzTQkmONjo+SnuRoXj/+GvGTq0olSWA3u44rOpfZRHQGDJ0Fna5z9UsJBILAZ+BMsFTAqhe9qz5rgiCxJ9z2KSjVl/E6jcy4bgmM65bAsaIqPt14jDveXU9ylJHb+qdwdbeEGsnF19tyPSZSblFV1zXIC2TJdao7FbUvtuoOJ5RUuZ/LBZAUaSQpsh6S6AagWSRTJquD/Xnl9EiOaOpQBA3EjJFt+fGNTD7deMwvzYACwcXSMULDH1a+y0s//UyZYqDKaifEoCHKqDs7QNB5331k33gTZUuXEn7ttRwrquKDtdlsyi6m3GzDoFVIjjRyx8A2DMmIFglXABMZrGNw22i+3XGCW/oF1jUpv9zCK8sO8sm0O5Hi74fSHNj3HVQWuGyVjVGQNgJaiYH2AkGzZMQjrr/fn5+CkqNgt9ScM6cLcVmpD7wPhv0RFM+bPslRRh6+sgOzLm/HL3tP8cnGYzz7zW6u7Z7Irf2T6ZIYzvESE48s3OY2kfJkCnEmkdLIYK8jB3OquC9L4ZutOuBzT3Nj0SySqa3HSuiYEOqT9lMQ2JyR+902ex3D2sfSWsj9BAFGZWYmhu7dMEZFUNtel6zTkfD883z/yDMsOBLO9rxKnKpabUr8rtwyVh7IJ8ygZcaIdO4clCp6PwOUCb2TmLP6cMAlU89/u4eb+iSdk+qEJ7mMKQQCQcuhw1jXT+5WWPcGHN8C1nKXfXp4MgyYDu3H1qhG1YVGkbmiSyuu6NKK3BITn206xtQPNhEbqic8SIvTTSNU2YYvKF2/kOgr7seQ1htJ0WA6vBnTgfWuZOo0qiohodaWK9XJ+bbqwR2HejxXkSHCGJgV92aRTIlhvS2TDq1CmTI0jccWbmfelP5i114QUJT/8iuho0bXed4SUyhP9LgDS055redUWR1UWR386/t9/Lovn7cn9RGbQwHIqA5x/HnRdj5cm822YyUUVFhRZIlW4Qau75FI/7SoRr9Orc8qZG1WIT//cUSjvq9AIGgiEnvChHf889IRQcy6vD2/H92OZXtOMnP+FhwXJFPemkIAOC7SmdkXW3WbQ2VIRvRFvZ+/aB7J1JFifjcgsHYKBQ3D9OHp/LgrjwUbjnL7gDZNHY5AAIDqcFCxfDkx9830eN5Xvx3nySW7sHg5ZcJkc7A+q5Bp8zYz9+5+DT44UFB/iiqtzFmdRaXFzj+W7q5WXZSAL387TqRRx7Th6dzWP8XvM10AbA4nf/9qJ3+/pjPB+mZxuxYIBM0ARZZQAYNWptJSXU7oiymEOzzKA90Q1n8CcnAkpWs/pWDpC0i6IPTxGYQNuqVavNf2SAy8ftbTBPzV2eFU2XK0mJcm9mjqUAR+QHNa7nfLO+sY3i6W5KimaR4UCM7HtG07mpgYdElJtZ6TU1zFY4u3+6w1N9udbDxcyJzVWUwb3tafH0PgJYfyK7j17XWUmqxuhzafcbKqspr453d7WLwlhw8nDyDc6N8b+/trsokPM4gZZgKBoME5XmzCZq95vfPFFOJCvJUHXkhdtupaRWLK0DSf42ksAn4wyv6T5cSG6okO0Td1KAI/0S4+lKnD0nls0Xa32l2BoLGp+PUXQi7zLPH7MPNIDXkEeDeA0GRz8s7KLPF9DwByiquY8EYmBZUWt4nUhZhsTnafKGPi22sxXYRVcF2cKDXxxvKDPHN9VyGBFggEDUZJlZV3Vh7ixR/3YXVjsXe+KYQvnJEHRo2ZibHDYGSdAUnRYMwYUE2u5ytBWoWb+iTRKcHNrK0AIeArU5uyi+jbJrKpwxD4manD0vhhVx7zNxxl0kAh9xM0DiargyXbjrN0+wmKKq0ARIfoGLw5h1sfvafW51nsDhZsOFpNCga+ac1NVgcrD+QH3lyjSwhVVblzzgYqzHYulP57qi7aHCrZhZX8ZfEOXrm1p19ie3bpHiYNbENaTLBfXl8gEFxa2BxOnl6yi8835yBLro0hd/hiCnE+FysPdEeQVmFkh1ievi6wnUoDPpnamF3M0HYxTR2GwM+ckftNfHstI9sLuZ/Av+SXW/jfsgMs3JyDJNUcRrgxcRivLs7llmMKvx/djsjg6g5Cv+7Nd/u6vtxMKq0OPsjMFslUE7I2q5C8MnONJmpvpCoWu5Nvd57gicrORAU3rMPUyv35bD9ewotC3i4QCBoAs83BnXM2sON4CZY6vMx9MYU4H1/lgRIgS5JbEwujTsGpqkwZmsafrmgf8NX5gE+mNh8pZtbl7Zo6DEEjkBEXwvTh6TyycBsL7h0o7KMFfuHgqXJueXsdpSYb9lpkdiaNHqwO5q07wrc78vhs+iBSos8l+LklJrfyCF9vJkeLfJh4L2hw3lmRVUOq50t1UQY+2XCU+0Zl+PS+5WYbK/cXUFhpwe5QCQ/SMiA9iqRIIxa7gyeX7OKpa7sIx0eBQHDROJ0q9y/YwvacEsx1DYU6jTemEBfi68yoxIgg+qdFsXR7ruv5koTV4SQxIogZI9IZ3yuJkGZivBPQUeaWmDDbHELmcAlx77B0vt+Vx0frj3DnoNSmDkfQwjhWVMWNb66lzGTzai6GzaFyqtzMDW+s4btZw4gLNQAuVz5HHVpzb24mde0QCvxHUaWVzKzCGt8DX6qLZruTuZnZXidTe/PKmLPqMF9vz0Ujy9gdTpyAVpawO1V6p0SSEG6gbWwwl3WK9/1DCQQCwQUs33+KtYcK3SZSnuTMdZlCXIiv8sC4MD0v39KTF2/uQbnFjsXuIMygbZabSAGdTG06UkyfNpEBX94TNByKLPHCzT246c1MRrSPpU20SKQFDYOqqkz5YKOrP+aCY55uKE4VSk02ZszbzOL7hgAQZtCgVWQcF9ycfL2ZNJddt5ZITnEVekXGesG/oa/VxYIKC06n6rGSrqoqz3+7lw/XZWNzqKeNS869r/X0f9dmFQLQrXU45WZbwNoACwSC5sNbK7JqSNmh/s57teGLPDBIq3B9j0TX82SJ8CAt0HyvdwF9JxfDei9N2saGcP+oDB5ZuJ1Ppgq5n6Bh2HqshGNFpnr1x9idKrtzyzh4qpyMuFA6JoS5/V76cjNRJOieFO7fDy2olUqLwyXavwBfq4uyJGG2OzDq3N9OVVXlsUXb+XrbCbc2+u7Yd7KcG97I5Kv7h4j5UgKBoN4cLaxi27GSGo/7Imd2h4wTJxIXXkS9lQeqqNzYp/bRI82NgL5Kb8ouZnyv1k0dhqAJuGdIGt/vzOODtdncMyRwZwsImg+zV2Vhsde/P8bmdDJndTbPT+hG3zaRRBp1VFlNNd7H25uJViMzOYDnZrR0QvQa3Gk9fa0uOlWVIA+ylHdXZ/H1thOYbN7bDFvtTo4VVTFj3mbm3Vv3wkYgEAjc8fOek24fr6/zniJLaBWJCd1b8eWWo1SpNc136pIHKjJc2z1wB/DWh4BNpsrNNrILK+maKHZuL0UUWeLfN3XnxjczGdUhjlTRNye4CKqsdn7efYoL/SZ8uaE4nPDFbzn84/ouaBSZGSPS+b9v97pdJHujNU+PCaFjq8Cdm9HSSY4Kcmsi4quTVbS1gpPPPodx4ACC+/VDiYg4e8xqd/K/ZQdrTaQ8yUstdiebjhSxO7eMzonieyIQCHynqNLqtje3Ps57Oo3MtT0SuXdYGh3iQzlZbmXNgZOYVV+SIifIFu4Z3rJcugMnmaosgP0/QFUBOB3klWkYE5+IThPwc4X9hqqqbDpSzLI9pzhVZkaSIC7MwBWd4+mZHNHie8nSY0N4YHQ7Hlm4jU+nDRJyP0G9KaywolEkLpSN+3pDcapQbrYTGazjht5JvPLzAcw2h1dmFudj0Mo8clUHH58laEgijDqGt4/l5z0na8yY8ra6GKSVmdwrCW15OSWffsaJP/8FXZs2GAcOJHhAf5YZUnC6sf0F7+SlVofKnNVZvDjRP7OsBAJBy8bdYHnwXc7cKtzAj38YXq2a9Pqk/tzy9gq2Hy9FdVOhchMNyBaCU95l1qp3+PzazwnXt4yCSdMnU8c2wJr/wcEfQdKAwwKqSqqk5QVVhXffhiEPQfurQGn6cBsDs83Bws3HeGtFFkWVVkzWc4s1CXh/TTbxYXpmjmzLDb2SWnTCec/gVH7YmcfczGymCEmUoJ5UWR24y8V9vaEokkSVzUEkLpnYx9MGcsMba1z9N15ikFUevKwdo8R8qSZn2vB01hwscNuc7U110anC7df2J9w4hOh770W1WjHt3EnlunUUznmPV0OGUhmWWPN5XspLHU6VpdtP8NR1XVqUJEYgEDQOUcE6tIpUY8C8r3Lm9JjgGtcgg1ZBm/w22opOWEt7AiqcTqoqdvxM2YYvsJfkIemDMLYfQNSYsYSkLULVFVNg0vLoykd5e8zbDfZZm5KmW4U7HbDkQfjweti7FOwWsFWC0w6qA63TjFa1QM4G+GIavHclmEubLNzGoqDCwvWvreG5b/aSU2yiylp911vFZcucXVjFU0t2c9ObmZRW2ZoqXL8jn5b7vfbLAbLyK5o6HEEzJcSgwY2iq9oNxRvsTmc1B7728aEsmjmYqGCdx74ZcOnEDYrE5H0/ck+UmC8VCPRtE0lypBGlHlVvg1ZmfK/WhBvPLTAknQ5j797E3ncfbT78gCMR7nt+fZGXahWZrPxKn+MTCASCYe1j3F7fzpczV+1fi9NmRnXYMR3aRPGv71U716hTGNctocZr7CrYxZHyLPQJiwjO+Be6mF9BqaRsw0KKl88lcvQkkv/wEUkzJ6NaN5P/+V9AdjmW2pw2Np3cRE55jn8+eCPTNMmU0wmf3wU7PgdbFW67gM/HWgl522H2aLCUN0qITUGpycb419eQlV/hVbOyyeZgT14ZE95cQ6XF3ggRNg2pMcE8dFk7Hlm4vdaStUDgidgQvdvKlC83FIBgvYYwQ/UKecdWYfz68EgevqI98WF6gnUKWsX1ZrIEwToFg1bmpj7JfP3gMGbeP56c3z+I7eQpv3xWgfdIksSHU/oTHqR1+/2oDb1Gpn18KM9cX7t9sM3hxF6LxM8Xeakkue4NAoFA4CsdW4WRHhPi9lhY/wlEjp5C6dpPyXn1dnLevJvyLUsJald9k8epqm7N4D7Y/QFWp2uwg6ypQB/zK0FJT1Oa+SEJdyQSc8XPhHZ+lrCuP5Ly+1isBVZKM88VRVRV5eO9Hzfgp206mkY3t+KfcHDZ6UTKSxxWKDkGn94Bd37lv9iakPvmb+ZUmQXbBQmDpyZlm0Mlp9jEHz7byjuT+jZR5P7nzkGpfLczj7lrDnPvsPSmDkfQzNBpZG7pl8y8dUdqyB287Y/Ra2TuGZzmtlcxPEjLlGHpTB6axtpDhezKLaPMZMOgU0gIN3BFl1bnKlrxo7EcOEjOAw/QZt6HyAaD3z63oG7iwwx8df8Qbn1nLYUVVreDLc/HqFPonhTOnLv6odfUngxpZAkJ91uFvspLg3TNb4ilQCAIDGaMbMufF22vl5xZkWF8z9Y1RjRU2apYdnQZTrX69dJ0sBKnzUnEABlJOacmUgwKod1DqdhVQeTwSMBVnVq4fyEP93242XsANH4yZa2CzFdrJFKpr5RTZYPDD4UQrHP9Ut/dYuWj7TaW333ayc1hgWPr4eQuiPd9oFggc/BUBZuyi2u4S3nTpGyxO1mxL5/cEhOJEUFNEb7fkWWJ/9zUg/FvrGFUxzjaxrrfaREIauPuwWnMX38Ud8tbbye9/25AisfjkiQxOCOGwRmenYqip03Fsm8fJ/7+BIn//tfZG4nJ6uC7nSc4XFBJmdlGRJCWtnGhXNkl3uPCXXBxJEcZ+X7WcOavP8qcVYepstqpPG/hIUug1ygkRQYxY0Rbru+ZiEbxLOyQJNcgymI3Mmxf+hVsdifxoSLhFggE9WNs11bMWZXF7hNlNTYT6yLUoGXW5e1rPF5gKkAjabCeHTnuwlHhQBOiQVJqJkeacA2mI9XHiVidViptlYTomvearvGTqV2LcTspEXCo8N/1Vh4fpq/9+XYrrH0Dxr/un/iaiLlrDteQsPkyA0dVVT5ce4Q/j+3YKPE2BSnRRmZd3o6HP9/GwhmDXTrgU3tg26dQehTsZjDGQNpw6HQdaLxxlxFcKqREGxnePpaV+/PdWsV6wqCVubpbArGhHq5NPiBJEgnPPcuROyZRNGcOZeNvY87qLBZtPo4sUW0hH6xT+MtiuK1/CncPTiUp0tggMQiqE2rQMmNEW6YNS2flgXxW7Mvn2Ibf0BmDSOnVhWu6J9C1tW/OU7f2S2HO6iysFyxgfLFfT40JJiVa/JsLBIL6oVVkPpw8gMtfWkF+hcWr58iSS9a+4N6BtAqvuZlTZa9ClmpuKCkhCvYKO6pDrZFQ2UvtaEKqpx0aSSOSqXqx+hWX0YQbHhms499rLNzXT0eEoZaSn+qAnYtg7D9BH+rHQBsPi93B4i3HsV+QTPnSpGx1qMxff4RHr+zQoi3E7xjQhu+2n+CXRW8zpnABFOx3SUDV88rXOxfB0lnQZzIMuh9C45suYEFA8d9bezL+9TVkF1Zh9TKh0mtkOsSH8n8TujVoLHJQEEmvv8bcmU/yUlYCdqQa1wA4l1h9kJnN/HVHeeOO3sIJ0I/IssTIDnGM7BBH3s6v0CYmEl3PTapJg9owZ81h3FVDvZGXBusUZo5sW9+PIhAIBAAcLqzE7nTSOyWC3bll2J2q2/uNhEtWHBuq54N7+tc64zNYE1xD4gdgzDAiaSTKNpcR3v/c5pPD7KB8eznxN1Vfj9md9mafSEFjJ1OqCkWHaj3cN1FhZKqGFzItPDvag6xB0UBRFiT08EOQjU9BhdXt477OwLHYHC5pkLHlVmRkp405xldh189ALTss1tM63fVvwpYP4K4lLea7Irg4jDoNC2cO5u73NrA3r9ythrz6+Qq9UyJ5584+fpHZLc2181Ln6zA7oC4jHptDxeZwMPOjzbxxe29GdxSbBP5GkqQ6/ZE8kRgRxMC0KDIPFbpduNTdryBxVddW9Q9AIBBc8uSVmpkxbzP/urE7V3RpxcFT5by3OpsvfnMpIWRJwomKza4yNCOGaSPSGZAW5bGPKdYY6zaZUowKcePjyP0oF9kgE9I5BFuxjdx5uWijtEQMjqh2vlFrxKhp/pX3xk2mbCZqk/id4ZlReoa8V8lDAzwlBBKYyxo0tKak0mJHdiO/93kGjixTbra33GTK6YRPb8d4ZBW1JlLn47C6fuaOhSk/Q3xnv4coCHzCDFo+mz6IH3ad5K0VBzlwsgKHqp7Vkus0MhLQtugo90/oz1Wje9bLOrsudh4v5S9f7DidSFXHk+mM2ebk/vm/8d1Dw2rdNRQ0FBI1Jvr6yIsTezLuv6sorLTgixmpQSvzbh0mFwKBQOAJs83BtHmbmDSoDVd0cW3MZMS5lBZPXNuZ3BIT5WY7Rp1CXJiB8CDv5tkZqnuIagAAIABJREFUNAauTr+aLw9+iUOtfhOLHReLEqyQ92ke1lNW5CCZsN5hJE9PRtaeW+zqZB23dbyt2ZtPQGMnUxoDuMlkz6drnMI17TX8c7WVTrEeGny1zT+TPYNRp+CsYwaON0PVHKpabQZOi2PVi5C9GuznGhi9Mi6xVsGH18EfdoGmYXpeBM0bjSJzdfcEru6ewP6T5azYl09RpRVJcg05vKxTPCGffYBt+SKUy3v5JYbXfjnotnfLG9MZm8PJ7FVZPHdDw0oPBRcgXXwyFRuqZ9HMwdz8diZFlVavGsCDtApv3N6b/mlRF/XeAoHg0kVVVR5duJ3U6GDucyMXNmgV0i/CzGtS50kszVqKw1FzRzBqRBRRI+q+fk3sMLHe7x9INO6cKVkGQ90NvE+PNDB7i5XjZbXcdBwWCKs5QKy5EhOix+nmhu3rDByNLBHm5a5Cs8Nhc+sCCeeMS2pHdVVFdy/xX3yCZkv7+FCmDk/nsbEdefSqjtw7LJ20mGAib7uN8h9/wl5Q0ODvWVhh4Zd9p2qs08+YzkSNmYmxw2BknQFJ0WDMGFDNlMDuVFm85ThV1pY7Xy4gkGozN/eNlGgj3z00nPE9W2PQym4HPOsUCb1GZnDbaBbOHMSojqIvTiAQ1J83lh8iu7CSf9/U3S/Vn7YRbekW0w2t7Pu6U6/oGZUyijhjy7jONf7Q3t53guJZhpYRJXNLFy3/21DLArlVdwhL9ENwTYNBq3BN9wS3UiJvh6ppZYmJfZP9IkcKCPYurW4ycR6PDNbxQqaFErOHRY+1Ala/7KfgBC0RTWQkYWPHUrxgQYO/9uebjrm9+PpiOiNJsHTbiQaPTXAeDVCZOkNUsI7/3NyDTX8bw+PjOtIlMYzEcAMxWEnT2bl7cBrL/jSCBVMH0iXRN9dAgUAgOJ8fd+Uxb+0R3pnUF4ObzZuG4pVRrxAbFItG8l4VpZN1JIcm848h//BbXI1N42vC+k+FDW/XedoTI/TM2+5m6rsuBIbM8kNgTcuUoel8s+NEDXt08G4GjixL3DMk1U/RBQCZr54zlrgAr41LirPg1F6Ia7n28YKGJequuzhyxx1ET52KHNRwM9y25ZS6HQ7ri+lMldXBztxSJpLcYHEJLkBySWUakhC9hkmDUpk0KBWA/DfeQLVYibv6+gZ9H4FAcGmyN6+MPy/ewXt393Nra96QhOvDWXD1Aqb8OIXj5ccxO8wezw/SBNE+sj1vXv4mQZqWMxe18ZOpiBRoMwQOrwLnuWQpe1Z1m/PkcBnz38JqPl8bBO2v8neUjU7nxDA6tgpj5/FSt65PntAqEn1TI2kT3YKb0YsOezzslXGJrIHCgyKZEniNPj2NoB49KP1qCYYJN7H5SDHFVVZUINKopXdKZI3J8N5QanKzUYTvpjPuBsIC5Ffl89m+z1h4YCGlllLsTjsGjYGOUR2Z3HUyw1oPQ/HSJfRS5mLd/LxBEx2Dacd2/76JQCC4JCissDD1w008eW1neiZH1P2EBiA6KJpPrv6EJYeWMHfnXArNhZjtZtTTF08JCYNiICEkgcldJzMufVy9pIGBTNO4FUyYDW8OgcpTdRpSVENrhNsXuqzRWyCz7+zL2P+upKjCgrMO18MzaGSJ2BA9r/+ut5+ja2LsJo+HvTIuUZ21VrcEgtoov/kOXpi/nB93/XRWRquiIiFhdzq5oVcSU4amkRHnfSNvbQmYr6YzoRe8TpG5iCfXPEnmiUxQXdPlz2Cym/jt1G88tvIx9IqeB3s/yE3tb/I65kuSBpT51YYmNgZHQaFf30MgEDRvKi12ftp9ktxSEyargzCDls6JYQxuG322H8pqdzJz/hau6Z7I9T1bN2p8Bo2BiR0mcnP7m9mav5WVx1ZSYCpAlmRijDGMThlNl+gujRpTY9I0WUlwDEz50WVZXZnvsq/2iAS6YPjdp5DYs1FCbApiQ/V8elUrfjdvOyUhkVjqcH3Sa2QSI4L4ZNrAlmuHfgaN4bS1fu08PdJA77cr+NOgWhz7JNklExUIvMDpVHnu2z18tK4UR0xX7Db3PXufbzrGF1tyuLFPEs9c39WrvsWM2BB+lU/VqEKfbzojyQqGtF5IsgZz9lbMR7dXM6HQa2TSY89Vo3PKc7jzuzspNhdjV2s3pqiyV1Flr+JfG/5FVkkWj/R7pEVY0/qHRkimoqP9YnIiEAiaPwdPVTBn9WG+PD0Tymxz4FBdiiSdIhNi0DB1WDo390nin9/vI8yg4ZErOjRZvJIk0SuuF73i/OOCG6g0XYknsg3MWA2/PgdbFwAS2Cqrn6MxuG5k7S6Hy5+GmHZNEmpj4SgtRX38YRbd9yCLQjKYm5mNzeGk0lJ9EResVzBoFKYMTeOuwan1khk1OyLTwFTs8ZTzjUu6xbmpTjntEJXupwAFLQlVVfnDZ1v5cddJl325VLsk7swk+cVbjpNfbuGtO/og15FQ3do/mffWHHYr6Q3rPwE5OJLStZ9SsPQFJF0Q+vgMwgbdUu08h1Pluh4uI55iczF3fX8XhaZCnHhX7Tc7zHy+/3PC9eFM7zHdq+dccjSQm58nlOgY7IWiMiUQCKozb202z327B7tDrXGvODPEvdLq4MUf9/PST/uJC9Gz9KFhdd5/BA1P067CjVFw9Ysw5h+wcxFs/QgqCyiuMGPRhtKq343Q524IiW3SMBsD1enk+KOPEjJyBK3GX81DwP2jMvhl7ym+25lHQYUFW04OUQ4zN982huHtY1uuc587Bj8ASx6sU6ZXq3EJQESqGNwr8Ir/LTvAj7tOYqqlGuUOk83BqgMFPP/dHv56tefvWZvoYLq1DmfTEfcbBHWZzkhAkE7hd++uZ+qwNH4zvUWRuahGIlW8qpiCHwqwnrKiGBTC+oQRf1M8SrArOTQ7zMzeMZur0q6iTVgbrz/rJUNjyPxionEUFKCqqqgQCgQCAGavzOKln/ZjttW9OXbmPnWqwkJOcRUdW7nxGxD4lcAoaeiM0HuS6wdYuu4Iu3PLeH7EpTOQsuD1N3BWVhL/yCNnH9MoMld0aXV2anXp1/lU/LqG1h1vb6owm46O18LXNV0cvTUuMUtBWPo8gDAcFtRFhcXOmysOub2JVe5eTtnGL7EV5iDrgtDGpRM+eOLZYbomm4MP1x7hvpEZRAZ7lt4+MDqDmR9t8SlhO4NBK/PBPf0x2xy8uWonW6XvQKou7Sv4roD87/JJujeJkM4h2Ipt5M7LJfuFbNL+moascVVvHU4H83fP5/GBj/scR4vHD25+FyIbDEh6Pc6yMpRwcYUSCC51Vu7P58Wf9nmVSJ1PldXBbe+sY+Wjowg1tCyDh0Cn8edMeUF8qJ5TZZ7tFVsS5b/+SsnChSS9/DKStvY/AG1iArbcS3SujEYHg+53mZD4iAo4FT1X/hjJ3DWHsTt8u0AJLi2+2JKD5MYApmzDFxQtm034wIkkPfARrWfOJbT3OEwH1lc7T5Lg041H63yfkR3imDSwjdsBrp4I0io8dHk7ereJZHBGDFcNPI5eU31fzGFycOrLUyTekUho91AkjYQuVkfyfclYC6yUZpaePdeu2vny0JeY6jB5uRSRZNnvlSk43TclpH4CgQD4zw+1J1KVu5dz4oNZHH3pJnJem8TJz57EnLPr7HGzzcmizTmNFargNAGZTMWFGThVbmnqMBoF65EjnPjr32j9ystoYj3LGbWJidhycxspsgBk+COQMsjVS+cDki4Y473f8NGM4fy46yTXvbaGLUc9918JLk1UVeWtFVk1qkVOSyUlq+cTNWYmxg6DkXUGJEWDMWNANVMIcN3M3l19GKcXIw7+n73zjo+izv//c2a2ZdM7CQkJEDqE3gJIEVEBsWFBVFBsWE7OOz3v9PR+enffK56Hp6dgQWwIcvYuFpoEUJDeAwECpJGQsiW7OzO/P1ZK2E12FxOygc/z8cjjAbMzn/lMMjvzebfX+/fju3LD4HZBG1QWo8Ldozpy5wUdT2x7d9e71J3W28O+247m1ojpXz9Kq1gUonOjqd1aP11WlmRWH14d1BzOLyQIsVXFmaAkJQkRCoFAwJ7SGnaX1vj9LBiHnsOt8uKKvc0eURfUJyyNqdQYMyXnQWRKs9spuvc+ku69B2vfwMonhuRkPJWV6K5A6ofnKLICU96GjhcGF6FSTGCJhWmfQJue5KREs+D2wdw5sgMz31zHw+9uotIW2u/SVufhYIWdPaU1lFQ7/TZZFrReauo8fp89dYd2oHtcWDsPDW4cp4fy2sAOIUmS+OPE7jxzfR96pMdgMcjIWn1DzihLmA0y/bPieenm/tx3Yad6tTWVdb6OAbVWxRBlQFJ8I2yGWAOe2vopgaqmUuGsCOrazivOQs0UgCEpCVVEpgSC8555K73CY6cTikPvmN3N2n3ieX42CY+aqdNIijJTYXOhavo5K7Kg6zpH/vgYlh49iJ8yJahjJIMBQ3IS7tJSTBkZzTzDMMVghuvehK3vwcp/Q0UBeFygn1yA1skRgIR50K0w9F6ISTvxmSRJXN6nLaO7pvD0V7u46N/LePDiLlzTP7NBBRxd11mzr4IXl+1lxZ4yjIqMJHnV1CwGhZvzsrhxcBYpMc3baVzQ/FQ73BgVGc9pBo3qqEa2xgTVSBfAoEhUOdxB3xPHayPXv/sFr684wJFu/amt8xBtMdI9LZqbhmbTPsl/U25V8625UqIUPLUedFX3Mag8VR4MUfUf/To6bq0B4ZbzGB0ptF6IZ4ghMRFPmYhMCQTnOz8drMRfJUIoDj1V09lRXMPgDonNMEOBP8LSmDIqMnFWE0dr687ZBWrl669Tt28v2QsWhKTgdDzV77w1pgBkGXpN9v4Ub4FNC+HYQfA4wZrE0YSBTFnZhu8uurhBAynGYuRPk3owuX8Gj36whUU/HOTJK3rSI71+AfiukhpmzP+BozYXDpeKDrjVk4tXp1tj7rK9zFm2lyv7tuXPV/TEqIRlwFcQBCZF9pseoUTEoNmr0TU1KINK18FkaPw+KK+tY+EPB1i//xjVDjcRJoWUvUe4rkssQ+8MLgIGEGWK8olOWXOsSAaJ6nXVxA46eU+rTpWaTTWkTk6tf32SQozJV7jFo2oosnReqcxtPVzFyyv2sWRbCfa6LHQg4rEvGNEpiTsu6EC/dvFN9/vQNNi3lGjLj8gHv4FP1kNMBvS4AhI7Bj5eIBCcU9Q6/fcIDMWh5/Jo1DiFc+xsEpbGFEBKtJmS6nPTmLKtXUv5iy+RvWghsiW06zOmned1U6fTpie0+XO9TemAed0y1h2oZGB2QqOH92wby3sz81j040GmzVvLZb3TeeCizkRbjKw/UMlNL6/B/rMR1RB1Hq8b6aMNhygst/H6jEGYDaGJCgjCg1ir0W/vJ3PbrkgGI/Zd+UR2HR5wHLeqkdCAmt/Ww1U8881ulu0sA07ePwCynsSnRQrZs5dz75gcJvRKC7hwz0vP491d79Zr1KtYFVKuSOHwm4eRLXI9NT9jgpG4vLh6Y3g0D/1S++FWNZZsK2HOsgK2H6nGo+pIEsREGLm2fyY352WRER+6CExrYHNRFb9dvJEDFXZcHg1V1+FnIRK7S+WrbSWs2F1OcrSZv12Vy9COv8Dr66yG9a/BqufAVYPVZfOe6cefQDbC8n9Am1wYPgu6jP+535VAIDjXsTRQPxuKQ8+gSA2OI2gewtaF3irrpir2wue/g5fGwHMDYM4IWHQz7Ft+Iu/eXVLC4d/8lvS///2MokvGtDQ8R85TRb8QuLxPWz7ccCiofWVZYsqgdnz165HY61TGPr2MV1buY9ora7EFMKROxeHW2Fh0jFkLN4jiz1aK2aAwqkuyj5afbI4kbvhUKpbMwb4rH83tRFc9OAp+pPK7eT7jDMiK9ytN+/HGw1z9wiqWbPM2Az7VkALQJBmn6k3ReHDxJn67eGNA9ckbu92I4uflmjw+mdSrUyleVMy2mdsoeLIAY4KR9g+1RzbWf/T3S+3HV5uc9H9yCQ/+byObiqpwq7pXCVP35uDPX7WPC/+1jKkvrz7n1FaX7Srj2rn57CypweFWfzak6qPrXqNq/1E7t8xfywc/Bfd88aFiHzw/BL79C9QWw3FD6jia2xtlL1oL794Oi6d5U5kFAsE5T1aif2fVqQ69QJgMMulxEU09NUEjhG1kKjXGQklNK3lhF66Eb56EIxtAU70vw+MUb4KCr8ESizb0fg7N/o74qVOJGj7sjE5lTE/DuXVbE0383OWy3HSufP57Hr+sR9BpdwmRJv4+OZd1+yuYNm8ttXX++/801mvI6dZYuquMdfsrGRAgKiYIT+64oCOrCo5id9X/+8cMugo5Mp6q/EWUf/IUkikCc2oOMUOvq7dfhOriiuVvU92uiuiLLkJSvIbOl1uLefB/G4PuHeJwq3y2+QiaDk9f27vBCFV2bDZd4ruwqXyTz2cJIxNIGNn4fWhRIqD8Wv5vxXYcjczNpeqAzuq9R7n0mRUsvmsoHZKjgrqWcGbjwWPc9ca6kPp9Od0aD7+3ifhIEyM7h9BUvqrI62xzHguuFsttg11fwdvXw9TFXhEegUBwznLz0GzyC45iO+39c6pDT5IVLO37IskGnIUbcB7YVF+EQocxXVPO8szPb8LWmEqJNlNa3Qrk0de+DF89Co31aHHZvD9fPkJqh1Qst756xqczpqdTs+TrMz7+fKFdopV2iVZW7ilndJfQHirtEiJ/Xjj6Ur32farW/I/Ecfdgad8PSTHg2LcOx+41Jxq3Ot0qLy7fK4ypVsrA7HgSo8zYK+w+n0X1GE1Uj9GNHh8dF80ll1xOxUsvU/rvf5N46wxsoy9m1sINITcCdrg1vthSTF7HRK4ZkNngOR8a9BAzvpxBnRraM1PGSHXReJbXenyiZA2halBhd3Ht3Hw+v/8CkqPNIZ0znNB1nXve8t84OVCDZqdb474F61n3x4uCc9hoGrx+OTir6hlS2bNrsLth3/1RRJq8BvPL6128ucnN0umR3nfLgXz47i9w4WNNc+ECgSAsGZ6ThNVs8DGmIDiHnlGRuG5gpkjzO8uErzEVY2Hr4arAO7YkP70FSwIYUqcg48ESeRRp8XSYstArpBAixrQ03CLNLygm9U7n4w2HQzamFqw54Kdl60lp0sTxs7B2yTux3ZozGGvO4BP/13VYuquM8to6kqJa70LzfEWSJObe2J/Jc1b5RKcCYTHKzL2pP7Ht4okZMwbHunWUv/QSz328GU/GYDjtzgrGOHe4Vf7z7W4m989oMDrVO7k3fxn+Fx5d+ShONbiIvkk2kah0Y1/NIL9SvI0ZE/rPqX9PfLyVZ2/oF9LvKJxYu6+CCrtvCl0wfxfwqmZ9tbWECblpPmP4sPc7qCmupzx6Yhwdnlnj4g8jGnheuO2wZo63155RpO8IBOcqsixx18iOPPXlTr9OnkAOPUWWmJ7XvjmnKPBDGNdMWSgJ58hUxV749DfgPmlIZc+uIeWfNdhcJ6MaL693MWq+7cT/JY8TClfADy+Ffs66Wozly4hL2Iz+/kz44vfw46veYmaBDxNy0/h6ewnOENJ3AD7dfNivlz4UaVKjIvH9HiF13Frpnh7DvOkDsZoUv4a1P6wmhbk3DaBvu3jAa5RZBwygzX9f4LP2Q3GfNlIofUOO1rpYt7/xRtMXZ1/M7NGzsRqsRBgaXnDLkoxFsdAnpQ/9LL/x20YpmOaQHk3nq20lVNlbr2rU3OV7cbjOvEGzzaUyZ9me4E72/TPgqvX70YN5Jp5aVccxZ4Bay63vB3cugUDQarklL5vBHRKwBFCEPZ0Io8xfr+hFuwbqrgTNRxgbU2ZKw7lmavUc0HwlLI97GBvFbff2SApWpKB8D3x8PzyVg/ztYyTmVCFtXACrn4cv/wBPdYL3Z0Lp9jO4kHOXlGgLuRlxfLO9NKTjqhz+F4ehSJN6VL3BcQStgyEdEvngnmEMyI7HbJAx+muA+3ND3b6ZcSy+a6jf+plvd5Si+7lnQjHOHW6V+asKA+43rO0wvrv2Ox4a+BCZ0ZlYFAtRxqgTPybZxLiscbx6yas8N/pFPvypxEe9MBRjQpbgnR8PBpxXOOJRNZbtKvMRmAm1QfPO4lqOBmrQXH0EDq5u8OMB6Qqjsg08taqRcVw2+P4/Qc1JIBC0XmRZYs6N/clMsAbtzLMYZR6Z0J2r+p/HbXNakPBN84sO48iU2wEb3qwvNPEzD+aZ+Mf3ddw90EScpZGvQV017FsGHUY1fq6tH8AHd4Hq9mu84f65rmPTIq/XcsLT0PeGoC/lXGdS73Q+2ngouDScAITaa0jQ+umcGs3iu/LYf9TG/FWFfLb5CDU/9wGJMhsY1yOVW4e1b1SI4UCFDZfHT2pXCMa5rkNBmf+oxulYjVYmd57M1Z2upuBYAaWOUlyqi2hTNDlxOcSavX2nvt5W4rcpemhGnsY7Px7k9gs6BDW3cKLa6UGRJdTTjMlQGzQbDRIVNheJjaX0Ht0Nihk8Db/TnhhtZtg8G/cP9i+pD0BlYVBzEggErZuCslrKa+u4d0wOi38sosbp9qmjMioSsiSRmxHLb8d1EU16W5CwNaaSokxU2lx4VA1DuDVB3fk5p9c+HOdUD+OfxzTSQ8plg7UvNW5Mbf0A3r8ruJosXfXu99kDgAZ9bwx8zHnAxT3b8OQn26hyuImN8JWq9kdchMmvIR9KryGDLBFnbWRRJGhVZCVG8vhlPXj8sh6Bdz4Nu0vF7UfQJFTjPNT6LUmSyInPISc+x+/nR211aH6i46EaE5V+ao5aAx5N8/sUD/XvIiH5/fvWoy6wIdwzRWFiZwN/W+miW3ID7zxPGGdrCASCJqG2zsO9C37iT5N6cHmftjxwUWdWFRzljfz9HKy043CrxFiM9GsXx7S8bLISI1t6yuc9YWtMGRSZ+EgT5bUu2sSGWePeqoO/3MMIjXsZy3bCBzPrGVIBVZ/AGzX79LeQ2gPS+wZ7RecssRFGhnZM5MutxVzbiBraqUzqk86z3+zGeVrdVCjSpB5NZ3hOUpNei6B1EmU2YFR8F9yhNgKOMjf8uHarbr458A1vbX+LI7YjuFQXVqOVnkk9mdZ9Gj2TevqIV3g03W8/tFCNCS04EcCwI8Zi9Cu8EerfxaNpxFoDOGpMwdUw/L9RFvrNreU3QxuIchmEg0YgOJfRdZ1H3t/M4PYJXN6nLeB1jA3LSWKYWFOELWEW8qlP2NZNuR1+U/yOc6qHsfFxfKWXT/D9M34NtqBqsjxOWP5U4/ucR0zqk87HGw8Hvf+UQe0abNQbM+gq4sfMoCp/EUXPTqXohenUrP+EiE4nU6JkCS7slkJCpFj4CCAnJQqTn0LiUBoBK7JEj/QYnzHcqpv/rP8PFyy6gMdXPc6Gsg2U2EuorKvkUO0hlhQuYcZXM5j4/kS+2f9NvWNjI4zIftL8QmkOCRBlCVufXKNYjAo5Kb7pmaE2aI42G0mLCeDwi88GNXAELydB5roeRv6ztoF9o9MDjiEQCFovi344yI4jNWeUBSFoOcL6LRi2dVPmGFBMjb4cA3oYj4/jD2c1bHnXr4RucDVZOuxeArVlEBVCQ8lzlAu7pvL79zZTVlMXVE+chEgTY7qm8NXWElQ/nvtA0qRmg8LtI1pfDYmgeRjRKRmLQcHmpwl0sI2A/cnd1rpquXPJneyq3NWgHLqGhsPj4EDNAR5e8TC3VN7C3X3uBmBQdoLf9LRQIrBGWQqtaW2YMXNURx59f4tPLUKwfxeLUea2Ee39GqX1iM/2ZgscWhdwTo+NNPPGJj/OOqMVhswMeLxAIGid7Ciu5h9f7uSdO4cQYRJ12a2JsDamUmPMlFSHYWSqTc+AxtSpHsZeKX4CgLIB2vb3f/DmxSD5DxoGXZMlSfDTmzDi141dyXlBhElhbLdUPt10mOnDguu/8MiEbnxfUE61w4/oR2PnMspc3COVPplxZzJVwTmIIkvcMjyb577Z45M6CsE1AtY0nae+2sn9F3aid2YcbtXNXV/fxY6KHbi04GqWnKqTV7e8itVgZXrP6SRFmclJiWLbYd/WCsEaE7Iscevw1tvT5NKeaTz6/ha/nwXzd9F1gk4fZtj98MHdPvLohbOi6/0/M1bG+agfR5uuQe8pwZ1LIBC0Kmx1Hu55az1/GN+NnJTowAcIwoqwNqZSoi2UhqMxlT3CG1VqoGfIcRr0MILXmGrIy1i8qdEUwKBqsjxOKN7Y6PzOJyb1TufZb3cHbUxlxFtZcNsQrn9xNXaXBy1IFfsubaL55zW9G2yuKjg/uWFQFnOW7gU/xlQgLEaZp6/pQ7mtjjvfWEf39BiyOuSzs2KnjyFVuaKS8i/LcZW6UCwKMf1jSJ2cihLp9XI6VSfPbfgvHlsXFn3vQtN1zAbZb1+1YIyJHukxtE9qvcXPFqPCw5d25a+f7fDbILMxIowKM4a3Jz7YdN4u48EU6RUfajCRuAEMFuh9PVgayGYQCAStmsc+3EqfzHgmC2nzVklY10ylxJgprQnDND9Jgrx7vWkXp1A4K5qxHU7ap8c9jCfEIU4ltQckdfI/vuNYo6cPuiYrwDjnE8M7JVF41M7Bikbq1E6jZ9tYPr5vOB2To4gwKjSUyWMxypgNMhd2TaG4ykmlrXWqmwmaj4RIE6/dOpAIY2ipGxFGhduGd2B8bho3D81m6YOjGNk5gf/tecsnta/883KKFxfT5to2dH++Ox3+2AHXUReFTxWinWIs1XlczNv0Go9M6MaSX19A/yxvH61QsRhlHjsH8vpvGprNjUOyQvrbRBgVxvVI5TfjOgd/IsUIN38YtBjFyeNMkNQZLvl7aMcJBIJWwf/WFbGx6BhPXtH6n6fnK2FtTKVGW8IzzQ+gz1Svt/BMMETAmD82/Lk5cIj3/42y8NJ6F4eqG/FwBjHO+YJRkbm0Zxs+CkGIAqB9UiRLHhjJwjuGML5XGkbF26TValJTl0rJAAAgAElEQVQwKhLJUWZ+PbYza/5wIa9MH8iUQVnc9eY66vz0FRKc3/TPSuD1GYOIMhsw+WkAfCqS5F2w3zO6Y70Fu8Wo0C7jANbTSv9Uh0rpB6Wk35hOdG40kkHClGwi8+5MXOUuqlZVnTK4hjtiHQM7RCLLMi9PG0D7pMiQDCqLUWb2dX3PmXTWRyZ049cXdcJskLEYG/49mAxex8m0vCxmX9cn9Ah0SjeY9glYYr3ZCYEwRkBab5j+CRjDTNVWIBA0jssO61+HhVNh3qXw2iRvu5t9y705wsDukhr++tl2/ntDP6ymsE4WEzRCWP/lUmPCVIACICIOpn0E8y5Gd9mC7lKN8WdDqmMj6TPJXb2GWiM9RQLWZCkm7ziCE0zqnc7jH23lntH+++40Ru/MOJ67oR8uj0aVw43TrRITYSTGYqi3oLpvTA5bD1fxp4+28X9X9WrK6QvOAQZmJ/D1AyOZv6qQt9bsR9P1esIUxw2a0V2SuXNkR/q2i/cZ451d72D31I+w2nfb0dwaMf3rp4EpFoXo3Ghqt9YSf8HJsWRJ5ruD3zGxw0SsJgPv3z2Mu99aR/7eo7hV3aeR7XEiTQqSJDH3pv7nnEzvHRd05Kp+GSxce4B53xdS51aRZQnN6QRZQTYZuWlIFjcNzSItNuLMT9S2H8zMh5VPw4YFaHUuZPnU2kzJG72yxEHer2DArUISXSBoTVQdghX/go1vez1jLlv9z7d/BJZYXIPv4/41nfndJV3o0kY4v1szYW5MhWma33Ha9EKf9jnaCxciGSVkvZH0LtkAshEu/qv35dgYva+Hb58IePpGa7IkCfrdHHCM84mB2QlUOdzsLK454weXySA3qggoyxJPX9eHK//7PW+t2c/UwVlnOl3BOUqbWAsPX9qVBy7qzFfbitlyqIqjNhfRZgOZCVYm9U4nMarhe+yI7YjPNrVWxRBlQPIT8TLEGnDsr9/426W6KLeXn/h/hEnh1VsGsaO4mldW7OPjTYcxyDLH/QRuVSMz3spdIzsyITcNS4jpiq2FpCgz947pxMxROWw/Us0xu5uK998j1igx9L4ZfiXuz4jYtjDhX+hjHqd06gBSruiD7KnxOsFi0qHPFG9trqi9FAhaF4fWwxtXeKNSDbXQcdm8P1//if+aO5Ld8/OzO0dBkxPWxlRilJljdhduVcOohGdGYsXnP2A7NI7M24ZC/nNeUQrV5f2RFG8kSlchdwoMvbvhOqlTiUyCnItgx6ecWqgctOoTQOYQiAtSZeo8QZYlLuudzkcbD/Fgm+aL2kWZDbx08wAmz1lF59RoBmYnNNu5BK0Xk0FmYm46E3ND6x3k8qMiqkQpeGo96KruY1B5qjwYouo/6lVd9Sun3rVNDP+8pjePT+rBntJaqh1uzAaZlBhLqxaaCBVFlujZNhaAysxonDt2Np0hdQrOPfuxu7sg3/xOk48tEAjOMiXb4LWJvpGoBjDpdWS79yC9dhnMWCJSeVsxYW1MKbJEQqSJ8tq6X5ZW0UzU7d7N0Vfmkb14MVJGWxh6DxQuh+LN4KjypmrEZkDXCV4Vp1AY/gAUfONtEBwqxgi44LehH3ceMKl3OjPfWsdvx3VpVsW97KRInrqmN/e8tZ4P7x0WlvevoHUSbfKNqlpzrEgGiep11cQOij2xXXWq1GyqIXVyar39TbLJ7zjHiTIbzpl6qF+KITkZz/IVzTK2fc0arIMGN8vYAoHgLKJ6TkakTiF7dg12N+y7P4pIk3fN8fJ6F29ucrN0eiSS6oLy3fDFw3DZ7JaYuaAJCM9wzymEa92U7nZz+HcPk/zrWZgy2no3yjJ0GAV598GFj8KIByD32tANKYCM/nDh4z6KgQExWmHYr6H9BaGf8zygR3oMRkXmp4PNr3Q4qksKtwxrz51vrMMZouyyQNAQA1IHYJSN9bYpVoWUK1I4/OZhajbVoHt0XGUuDj5/EGOCkbi8+oaRLMv0TOp5NqfdajGkpOApLW2WsW1r1mIdPKhZxhYIBGeRnZ822PZA1eGZNY2UgXgcsHEhOH17/glaB2FvTKVEm8Oy11T5iy+iJCYSd801zXeSITPRhj6ApkrBdSUxRngNuZEPNd+cWjmSJDGpdzofbQhN1e9MuWtkBzITrDzy/hZ0PcTeMgKBH6Z0neI3qpo8PpnUq1MpXlTMtpnbKHiyAGOCkfYPtUc+TaEuOSKZ3KTcszXlVk1zGVO6241j/XqsAwc2+dgCgeAss3J2g71HH8wz8dSqOo45G1kDSJJXsELQKgl7Yyo+0sTOkhpKqp1h4913bttG5VsLSPvzk83enPXIF2VUKpOR2o8Eg9lboHwqstGr/NcuD65fAKP/IIqWAzCpdzqfbj7SoGJZUyJJEv+cnMu2I9XMX1XY7OcTnPtkRGc0aAgljEyg01860eOlHnT7TzfaTm97omHvcSIMEdza81bRWDpIDImJeI4dQ1eb9v3j2LwFY7t2GOJ9FRsFAkErorIQSrc1+PGAdIVR2QaeWtVIlpXbDqtfaPq5Cc4KYVkz5XCpfLzxMHOWFbDvqA1Fkvjvd3twqzpDOyRy58gODOuYhNxQF9VmRHO5OPzw70n93UMYU1MDH/ALqP78c5wbN5H2/ntgtcKxg/Djq1CyBZzHwBIDyd1gwC2Q0KFZ53Iu0SE5ijYxFvILjjK8U/PLO1tNBl68qT9XPr+KLqnR5P0sKe3yaHy5tZjlu8o4anNhUmTSYi1c0bctvUW9iqAR7u93P7d9dRt1amgp0DIy0aZoxncY30wzO/eQjEaU2Fg8R49iTEk543GKKu1sP1JDjdONxahgXfETXQeKFD+BoNVTWeh1dDfSzuaJ0WaGzbNx/+BG2hzUnJ2MGUHTE3bG1LyV+3jqq50A2F1eT6BH1/H8HEVYuaecnw5UEm0x8t+pfemfdXaV0sqf+y/GzExiJk1q1vO4S0oofvLPZM55Adn6c91UXCaMfaxZz3u+MOlnVb+zYUwBZCZYeeb6Pvxq4QZemTaAL7YU8+aa/Wiajs110uMtAQt/OEharIW7RnXk6n4ZKC3gNBCEN31S+vDY0Md4Mv9Jv6p8/pAlmShjFPMvnk+EQQiihIIhJRlPaVnIxpSm6SzbXcbcZQX8dOAYJkVG1XVkSUJ1JBAdkcpdK/dxdf8MYiOMgQcUCAThh8uOv1qpU+mZojCxs4G/rXTRLbmBpDBPI3VVgrAmbIwpXdd54pNtLFx7EEeAdD6bS8XmUpn68hqem9KPsd2bN0J0HMfGjRx77z06fPB+s6bI6JrGkd//nvipNxCRK+oamoOJvdO4ZPYeZgyvZk+p7YS3OC3WwsDshGaJeg7LSeKqfm256vlVKLKES9V89tEBh1tlb7mNxz/cyoc/HeKlaQNEZ3SBD5M6TsJqsPL7Fb9HR280SmU1WIkxxfDqJa+SEZ1xFmd5bmBMPl431SPoY0prnNz48hoOVTpOOEzqPKd852Ujjjr455c7eeqrnfx3aj9GdznzyJdAIGghzFF4XaGN8/9GWeg3t5bfDG2gj6BBSKO3VsJmhTZnWUFQhtSpON0a9769ngW3D6Ffu1+ed15WU0dJtROHWyXaYiAz3kqk2fsr0pxODj/8e9o8+giGpOaNZlS++RaazU7SnXc263nOVzyqxsaDx/CoGhOfXYnZIKNqXjFGdG8D0xnD23P9wHbERzYSkg+RvWW1LFhzAFXXUdXA9VoOt8qP+yu56ZW1LLxjSNj2WhO0HGOzxtInpQ+Ldy7mrR1v4dE8aLqGpmkosoKma7SLacetPW/loqyLMJ1ecykIilBFKEqrnUx4diWVNteJrIqGOP7Om/nmOv51TW8mhNh3TCAQtDCJOeAJnHKdkyBzXQ8j/1nroleKn/d5fHbTz01wVggLY6q8to7ZX++u77X7Gdu2pVT/8AHuo0XIpgiMKR2IzbsWS4bXQ+h0azz0v018/cDIMzq3quks31XGnONpGAYZSQJdB4+mcVluOjNGtCf+1eexdOtKzCWX/KJrDUTdnj2Uv/AC2QvfRjKExZ/nnOJIlYPrX1xNeU3dCW+x+7TCcptL5ZlvdvOfb/bw7JS+TRL5VDWdqS+vweby+HzW2D1e59HYeriKv362nccvC94rLjh/SIpIYmafmdyeezs/FP9Asa0Yh8dBtCmaLgld6BzfuaWn2OoJxZhyeTSuf2l1UIbUqTjdGr9ZvJGMeKuomRQIWhMx6dBuCOxbFnDXx0aaeWOT2/cDUyQM+1UzTE5wNgiL1frbaw/43V699n2q1vyPxHH3YGnfD0kx4Ni3DsfuNSeMKYBDlQ42FR0jNyO0F9CO4mqmzVtLrdNzYmF9eurVe+uL+GhDEd2PRjPvTzNCvLLQ0F0uDj34kLd3VVZWs57rfOTQMQcT/7OCaqcnoJKf0+29D+59ez3/uDqXSX3a/qJzf7ujlGqnm9PV0YO5x51ujYVrD/LgxV1Eup+gQQyygaHpQ1t6GuckhpQUnFs2B7Xv51uOUFzl9GtIBeMc/NvnO3j7jiFNOn+BQNDMDLsfDq3zkUcvnFW/OXpmrIzz0Rj/Y/S4qrlmJ2hmWnxlpmo681bu84lKaXU2jq18i8Txs7B2yTux3ZozGGtO/Y7xdR6Vl5bv5dkb+gV93vUHKrnx5TU4XGqjZYOqDqoKWxOyuerNrXxw9zBirWdQKKxpsPc7OLQe7EfBZPV6M7pfCZGJAJQ9+xzGtLTm7V11nuJ0q1w3N59qhwc1hH5PTrfGQ+9uIjPBSt9fkEo6Z1kBtrr6EbBQ7nFJgg9+OsQNg4WRLRCcbQwpybiDjEzNWVZwQjzpVIJ1Dq4/UMnBCjuZCSE2bBcIBC1Hh9Fo5gRw1npLBkLBaIXBM8EoaqZaKy1uTG0/Uo3LT3pf3aEd6B4X1s6BPa2aDt/sCD6f/WCFnWmvrPX7wmsIlwaHKu3cPG8N787MwxBs/YrjGKx/HfKf83bHdtmBn6/XGAFf/gE6XYwzYRzHPnifDh98IPq/NAOfbDpChc3l15AKxlv8jy92nrG3uKjSzpZDVT7bQ7nH7S6Vl1fsE8aUQNAC1MQksrvKTfWhKqItBtLjIvzWMO4ormZfuc1neyiOE03XeT2/kEcmdG/y6xAIBM1DzbffUvaJhewLo0C1g+67rvWLIQLaDfX2CBW0WlrcmDpmd/tVTlMd1cjWGCRZ8XOULw63iqbpQamw/fPLnX5rV6DxhbVL1dldWstX20oY3yst8KRKt8P8iV4jyuPw/dzt3abv+AST5xOypl+NIeHsSr2fL8xZuqfFvMX7j9oxGWSf6Guo9/jhKj/3kEAgaBZUTWfZrlJeWFrAhgPHMGRMRHlxNZquo8gSNw7O4qahWaTHnZSZX7e/0u9YoThO3KrOyj3lTXYdAoGg+dB1naNzX6Ry4UIynn0VOSMKXh0P9gr/675TMVqh0zi46iUIch0gCE9a3JhqCCUiBs1eja6pwS02dSg6ZqdtnLXRvjxVdjdfbi3GX8lMMAtru0tlztKCwMZU2S545SKoqyVQ/wFJ15AUMBd/6o1UXfJ/ga5WEAKbi6o4dMy3F8/Z8hbb6jx+b4FQ73F/Ai0CgaDp2XjwGLe9/iP2upP1tG6DGepOOuFeWbmPed/vY2JuGn+7OhejIlPt8OD20/IgVMdJjdO/s08gEDQfuq7jcKvUOj1EmBSizIZGM4U0p5MjjzyK68ABshctwpj6c2uDu1d7M5JWPQuuGq9D/cQ5ZCSjCdJ6w7BZ0OVSbx6/oFXT4sZUnNWI5if1yty2K5LBiH1XPpFdhwccR5LgurmrqbC5yIiPoH1SJFmJkWQnWslKjKR9UiRpsRbe+fEg/mytUBbWu0pr2FNaQ05K9OnDeHE7YP6EoAyp+sfZYd2rkN4Xcq8N/jhBo2w4WOn3HgvVW7yq4OgZnT/KbPDbgiLUe9xsENLoAkFz8/2ecm577ceAbTqOixV9uvkIByscvHHbIIyK5G3Ie9pzP1THiWiDIBCcPcpq6liwZj/zVxVS4/RgkCU8mo7ZIHPdwEym57WnXWL9rBR3SSlF996LKSuLrDdeR7acUu9kiYG8e2HI3bBvKRSuhNpSPFU2qlf+RML/vQuJHc/uRQqalRY3prq0icYgy0D9F5dsjiRu+FQqlsxBkhUs7fsiyQachRtwHthE/OhbT+4rwbjubZhzU3+cbpX9R+0UHrWx/6iN7cU1fL6lmP1H7ZTV1iHh38MfysJa12Hl7vKGjakt7/3siTj5Qs2eXYPdDfvujyLS5F1Zv7zexZub3CydHnnyWLcDvnkSel0jvBVNRLWzZb3FWUmRfusCQ7nHAdqekk4kEAianp3FNdz+emBD6lScbo1Nh44xa+EGLunZBpMi+7RbCNVxkhLTQFNPgUDQZDjdKr9/bzOfbT4CnFwbHlfi9LhU3li9n7fWHGBgdgLP3dCXOKsJx+bNFN33K+KnTCHxjtsbjl7JMnQc4/0B5IJVqJ9MQc9/HklRIKoNdBkPKV2b/2IFzUqLG1NGRebmoVm8uHyvj5ETM+gq5Mh4qvIXUf7JU0imCMypOcQMva7efhajwh0jO5z4d5c20XRp42voON0q4/69jAMVvnmsoSysXR6NY3Y/fQKO8/1scPsWIas6PLPGxR9GBHhR2o/CgXzIymt8P0FQmBQZRZbQ1F/mLTYpZ2bcto2LIDcjlh8Kfespgr3HrSaFOy7ocEbnFwgEwfGnj7Y2KEzUWD2t062xdGcZU/uk4nF7OD0UHYrjJNKsMFUIzQgEzYqtzsO1c/MpKKttNIXereqAzpp9R7n0mRW81sUJ//wLaU8+QfTYsYFP5KmDzf+D72cjVxWR1M2O9OPL3s9kAyz7OyR3heGzoNskUTvVSmlxYwrgpiFeY8ofUT1GE9VjdKPHp0Sb6RtEk0OLUcFi9H+jhrSw1nXsy5ZSXvAdhjapGNukYWyTiiE1FfnYLqgq8nvYg3km/vF9HXcPNBFnaWRh7rbDqueEMdVEpMSYm8Rb3Cb2zCNDd43syLbDP52ovziVYO5xXYdJvX9ZryuBQNAwByvsrD/gX0AiuH5wHuY+/TZj46L5wpTJ6d/0YB0nEhKX9GjTHJcoEAgAj6px6/wf2F1a6zdrxB9uVae0ys4ty6v48MWXie4VRP207Si8Pgkq9nrXdZyWcKR5vD9HNsAH98CPr8KUt70NfAWtirAwplJiLNw1siMvLt8bUnoFgMUo8/erc4OWE0+KMrOrpNZneygLa7MikZQSj2Y7gG3lHjzFxbhLSvCUlBDb0UVqzzq/zoUB6Qqjsg08taqOP49prJ+ADiXBNYgUBObCbqmo+iaf7aF6i28c0u6M5zCqSwpxVqNXdTKEMjqACKPC1MHtiDAJj5VA0Fy8nl/ot7Yy2HpaHYk1aT1487YhfPPKGlQ/i7RAjhOzQWbq4HaYRH2kQNBsfLalmM2HqvwaUo1FoFVkyiMTeLPYwKxeAU7irIKXx0DVYdBcgSfltsHB1d56+1u/BINI9W1NhIUxBTBrbCeKq5x8tPFw0AaVxSjzz8m9GdwhMejzXN0vg40Hj/lECEJZWOuyzOXTLyMlpr5BpOs62tLZSCv+Cpr/a3hitJlh82zcP9jU+ERdvmmCgjMjymxgUu903l1XxGmZfkF7iw2yzNhuqWc8B0WWWHD7ECY+u5JapydoWRKzQSY3I5aHLxU51QJBc/Ll1pKfU3rqE0o9rcmgUGGr4+7RObywtCAk56BBlshOjGTW2M4hzVsgEITGnKVn3li7TtWZn1/IvWNyGu83uugmqK5vSAWsnffUQekO+HgWXPlCk16zoHkJG2NKkiT+dnUvMhMiePbbPciaC4fmf3qRZgWjIvPfG/oxLCcppPNMyE3jjx9u8ftZsAvroR0SSY3xjSxJkoQSlwKKsUFPRM8UhYmdDfxtpYtuyY18EQ2iE3ZTctuIDny08TCqO3RvscUgMy0vK/hGzQ2QlRjJ+3fncd3c1dTUeQKmF1hNCoPaJzDnxv6/+NwCgaBxapz+62BDqadVNZ1Ku5v7xuRQ43Tz5uoDQRlUZoNMZoKVBbcPFhFogaAZ2VFczd5y3+ykUBSd3arGtztKGddQOm7ZTji4BlTfdWDA2nmPA7a+C+OehMjQ1reCliNsjCnwGiP3junEzb2jeffZ3/Gi6UaO2j0YZQkd7w3cMz2Wu0Z15MKuKWe0wLQYFa4bmMmbq/f79UIGWlgHFAKIaxdQhe//jbLQb24tvxnaSBg3XhQgNyWdU6O5b0wnnvt2T0jeYqMi0Sk1intG5zTJPHJSolnywEheW1XIa6sKcWsatrr684kwymQlRjJzVEcuy00PqhG1QCD4ZTSUKh6qUI0secd6ZEJ3urWJ4e9f7qDW6fFbL2k1KWi6zpV92/LHid2xmsLqlSwQhC3H7C7e/+kQe0prqXF6iI0w0i0thkl90r3tSBrg620luH+horOtTuXDDYcbNqZWPw+qf/Xf4GrnJVj3Glzwm4BzEYQHYfnkjtnxDrf0iWT65WMpr3VR5XBjkCXiI03ERhh/8fizLuzMV1uLKa6qQ/WTI98QFqM31Wtox0bSCrOGgTECXL6ej+PkJMhc18PIf9a66JXixyA0RcHgmUHPSxAcd4/qSG2dh1e/34fTT4TqdIyKRKeUaN6cMQSzoem8xQmRJn59UWfuG5PD19tLWbm7jHKbi2qHm71ltcybPoju6TFNdj6BQBCYuAgjFTZfT3Io9bSKLBFvPZnCfVX/DK7s15ZVBUeZu6yALYercbhUjAaJlGgz04Zmc2W/jEYXfwKB4CRbDlUxZ1kBS7aVIEnUe5dbTQpPfLKVSb3TueOCjuSkRPkcX1zt9En3h9BbpZTV1vn/wO2ATYtA929MBVU773F6DTJhTLUawu8Jruuwbj5cOQdJkkiONpMc3bSFeLFWI4vuHMrVz6+iwubCHYQiQIRRYXCHBP51be/GxS5kGYbc45W79PhKsB/nsZFm3tjUgLy6bPD2HhA0KZIk8btLutI9LZq/f76TCrsLh+v09poQaVLwaDoGWeKt2wYRa/3lBrw/DIrMJT3bcElPr3ertMbJxf9eLgwpgaAFuLxPW55fusdHJjmUelqPppF3Wuq5JEkMy0kKOSVdIBDU5838Qv782XZcHs2vkNPxOqh31x/i441HeOqaXCbkptfbpyH/eagRaL2hgaoPg9T48UHVzjsqvYaZUfSXbA2EXyFG4QpvvVDGwGY9TUa8lc9nXcCQjomYDTLGBnoIWU0KFqO3ZuaVaQOD60zffxpI9fcrnBXN2A4nbdfMWBnnozH1G/aC94sz9B5Qws/OPVe4rHdbVvxuNPNvGcSF3VJIjTETbTGQFGWiR3oMf5rUg42Pj2Ns91ReWrHvrM0rOcqMR9P9escFAkHzcsPghtU6YwZdRfyYGVTlL6Lo2akUvTCdmvWfENHpZEqQIsFluY2nGAkEgjPj9VWF/OWz7Tjd/g2pU1E1HYdb5TeLN/LppsP1PkuJNuMvc/7UCHQwJEQ2YAjVVfus/07n1Nr5BlGMUFcT1FwELU/4PfXXzYf+0wPWHTUFCZEm3pgxmKJKO2/k7+fttQeoqfNgkCU8mk5mvJU7L+jAFX3bEhnKC9KaANe9DgtvbDQ65YPBAm0HwvAHQr8YQUhIksSg9gkMap/Q4D6PjO/GJc+s4Kp+GX7TBZpjTjkpURSU1ZIQ2fC8BAJB05McbWZEpyS+3VHqd7EWqJ7WqMjMGNG+GWcoEJyfrNtfyV8/3+43PT9QM+3fLt5ElzYxJ97h3dJikCXJpw1CSK1STArje6X5n6wxEvTAZQQBa+c1DxitAccRhAfhZUzZymH31zDh6bN62ox4K78f343fj++GW9VwulUiTYZfVvifMxaufhneuw3cTggkhm20QuZguH6BiEqFCSkxFu4ZncOfPtrKGzMGBd3L7JfQMTmKPaW1DMwWxpRAcLb506QerC2soNrhv96hISKMCtcOyKBrG5GiKxA0Nc9+s9uvIRWMlLnLo/LC0j3kdUzi3fVFbD1cRYRJocbp+x0PurG2JJ1Iz/chug2oDZRwnELA2nmDWTTvbUWE16p949vQdQJExLXYFIyKHFwqXzB0mwgzlsDS/4M9XwOSt7DwBBKYrBARD3n3w8AZ+O32K2gxpg3NYvGPB/l08xEmnpZ73RzkpERRUNqweIlAIGg+MuKtLLhtCFNeWo2tzhNUg+0Io8KF3VJ4/LIezT9BgeA8o6TaSf7eoz7bg5UyV3VvDVVZTR1TB2dxYbcUPtpwmMc/2upX2TdwBFpi6uB2DYtSWWKg00Ww41MCOdEbrJ1XTNBv2lnJ0BI0DeFjTB0Xnrj8+ZaeSdPSppc32lRb6pW6LFwBjmNgtEBsJgy4xasAKL40YYlBkXnyip786u2fGNUlpdnrITomR7Haz4tDIBCcHXq2jeWT+4Yz88317Cuvpa6BYnerSUHX4c4LOnD/2E5nJXItEJxvLFx7wO/2UKTMLUaZS3ulMSHXm5o3qU86877fR0FpbVACZMeRJIizmrhzZMfGd8z7FRR8B25bvc2Fs6Lr/f947bzfEw26I+h5CVqe8DGm9n8PshEyB7X0TJqHqBQY+aD3R9CqGJidQF7HJP7zzW7+ML5bs56rY3IkBWUiMiUQtCRZiZF8dv8Ith2u5tb5P1BeW4ckgSxJuFWNjHgrd47swBV9QqynFQgEIbHlULWPwiaEJmXudGvsLD4p5mAxKiy4fQiXP7eSkuo6XGrgGidFloi2GFh0x5CGxSeOkzkI4jKhfDfowfe19J7IBO3yIEHUX7YmwuctcBaFJwSCUHn40q5cPHs5k/tn0Dk1OvABZ0i7BCul1XU43SoWo0j5FAhakg7JkdTWeVj7yIUYFBmXRyPaYmjSvnMCgaBhapz+649ClTKvtNdXzkuINPHp/SO46411rD9QiVvV8GdTyRKYDQrtEsX5ObMAACAASURBVKy8estA0uOCkCqXJJi6GOaMAGcVAWvmTxynQGQyTJ4X3P6CsCE8pNFtR2HXV5B7bUvPRCDwS3K0mVljO/HHD7Y03F+iCTAoMu0SrOwtswXeWSAQNCs/FFbQOTWKhEgzMRYjSVFmYUgJBGeRKIt/n3+oUuZxEb79ImMsRhbcPoRP7hvOtQPaEWFUsJoUoswGrCYFRVNJjzaRmxFL1zZRLFx7gMLyIN/Nce28NfORSd6sq0Ao5pPHWIUAVWsjPIypjW9Dl0vFDSQIa6YOzqK2zsNHGw8H3vkX0DE5SqT6CQRhwLKdZYzqktLS0xAIzls6p0b77QN6qpS5fVc+mtuJrnpwFPxI5Xf1IzsRRpmOjbQ3yUmJ5v+u6sVPj13Eq7cMZFSXZDRdxyhpFFW7WLOvgg83HuH5pQVcPHs5Vz3/PV9vKwnsWE3uDDPzqe51C27FilvxI3VuigJLHAz7Fdy5HGLbBvV7EYQXLZ/md1x4YtKzLT0TgaBRFFniict7cvdb6xjdNYUYSxDepjPgeK8pgUDQsizdVcbT1/Zu6WkIBOctUwa145WV+/CXKheslLmmw+W9Axspm4qquHX+D7g8Gm5VB6n+O96j6Xg0nfUHjnHfwp+4vHc6f7myF4qfNjqqprN0ZylzlhWwqWgMEfIILtLzuZjVJEo1aJKMOa4NSUNvpM3Aq0RLnFZOy//19q/yyoG3G9LSMxEIAtI/K55RnVOYvWQ3j13WvVnO0TElkm+2lzbL2AKBIDiKKu1U2lz0TI9t6akIBOctmQlW+raLY/XeCr+fB5IylyW4pEcbYq2NOz/XH6hk2ry1fuXS/eFwqXy44TB1Ho2nr+1dT83zmN3FTa+spaCsFrvLO14dCosZzmKGn9jPUCZh+FTibts+7huTIxRBWzEtn+YnhCcErYyHLunChxsOsf1IdbOM703zEzVTAkFLsnRnGRd0Tv5lzdsFAsEv5ldjOhFxhoJMJoMcUMq82ulu0JCybVvKkddmceDpyRQ9dxMl7zyOs2grAA63yhdbillwinx7lcPNZc+tZEdx9QlDqiE8mo7TrfHC0gKe/GTbGVydIFxoWWPKXgG7voTc6wLvKxCECYlRZh4Y15nHPmweMYqOyVHsK69FC6H/hUAgaFqW7ixjVJfklp6GQHDek5eTxJ0jO4RsUEUYFR6/rAfd0/30cjqFd38swqP6vm+r175PxTcvETvkWjLufZO2M18lut94HLvXnNjH4VZ59ps96LqOruvc8upaSqqc3jTBIHG4Vd5ee7DBnlqC8OespPkVHCvgre1vsb1iOza3DavBSofYDkxxG+nV5RIhPCFodVw/sB2LfjjIe+sPMb5XGh9vPMw7Px6kvLYOTYeYCAMXdUvlhsFZJEebgx632lXNhwUfYm63kIvffQaDLBFnjuPi7Iu5stOVxJpFypFA0JS4XC62bNnCpk2bsNvtAFgiIigtlBk8qWsLz04gEADcf2EnAOYu2xswFU8CzEaZR8Z3Zcqgdo3uq+s6L67wHVOrs3Fs5Vskjp+FtUveie3WnMFYcwbX27fa6Sa/4Cgmg8yO4hpcfgwp27alVP/wAe6jRcimCIwpHYjNuxZLRg/Aa1D988udXDMg80QNlq5767Nezy9kb5kNh8tDlMVIbkYs0/Oy6ZDcsKiG4OzSrMbUiqIVPLfhOQqOFeDRPKinNC/bfnQ7S3SNVGsqd+39hAntJ4h8UUGrQZElfndxV2a8/gOPfrAFScInpL+7pJbnlxZwQedkfndJV3IaURMqsZUwe/1sluxfgoyMZnJQ7F3XUVRbxJ5je3huw3OMaTeGX/f7NWlRac15eQLBOY/NZmPp0qVs2LABALe7fj+b3pLCy/+dTe/evRk1ahTR0c3XX04gEDSOJEnMGtuZQdkJPPvtHtYfqETT9XoRILNBRgeG5yRx35gc+raLDzjuj/srqXb49rKqO7QD3ePC2nlowDHsLpWXV+5DkSW/hl712vepWvM/Esfdg6V9PyTFgGPfOhy715wwpgCcbpXlu7wR8f+tK+K5b/dQVluHw61yahLMlkNVLPrhIN3SYvjtuC4M75QUcI6C5qVZjCld15m7aS6vbH4Fp+r0u4+GhlOC/Y4Snsh/grVH1vL40MdRgmi+JhC0NEWVdh58dyN1Ho2GMv2Od23/ensJ3+8p5+VpA8jr6PvQ21W5ixlfzqDGVVPP4XAqx79HX+77ku8Pfc/L416mW2K3prkYgeA8o7y8nPnz52O329E0P506AQUVjwd++ukntm/fzvTp00lJETLpAkFLkpeTRF5OEgcr7Cz84QC7S2qpcXqIjTDSPT2G6wdlkhJtCXq8fWU2/GXUq45qZGtMUA2BAXYW11BeW+ezHgglwmVzqbywdA8fbDjEV1tLGozAHVcV3HDwGLe9/gO/Hts5YF2YoHlpFmNq3pZ5vLKlYUPqdBweB5/v+xxZknl86OMiQiUIaypsLq5+fhVlfh6c/tB1r+dqxvwfWXjHEHpnxp34rKimiOmfT6fGXRPUuTU0ql3V3PLlLSyauIismKwzvQyB4LykqqqKV155BYfDEXDfpUuXUlFRwVVXXcW8efO48847iY8P7O0WCATNS2aClQcv/uVpuLV1HlQ/DhUlIgbNXo2uqUEZVNXVNoySTN1p20OJcAGsP3AMw6EqnG7/Tp7Tcbo1Zn+9G5NB5pZh7YM6RtD0NLkxtaV8C3M2zvExpCpXVFL+ZTmuUheKRSGmfwypk1NRIr03qVN18tm+zxjWdhgXZV3U1NMSCJqMB97ZQIXN5deb1VhetMOtMv3Vtax9ZCxGxav9cv9392Pz+Cr3Bfq+ONwO7v3mXj664iPhfBAIQmDRokXU1dVf8mzevJn8/HzKy8sxm820adOGESNG1Nunrq6OBQsWcPfdd4vvnEBwjhBpVlBkGdT6USBz265IBiP2XflEdh3ewNEnMWgqLl316U0VaoTreNTpdAKtLf7+xQ6GdEikW1rjYhuC5qHJjalXNr+CS3XV21b+eTlln5eRcVsGUd2jcFe6OfzGYQqfKqT9I+2RDd6FpcPj4MVNLwpjShC2HKlysKrgKG4/D7tg8qJdHo2vt5Vwaa80th7dyoHqA2h6fQ9UMN8XDY0SewkbyzbSJ6XPWbl2gaC1U1xcTFlZWb3Uvvz8fFauXMnEiRPp2LEjiqKwZ88eduzYgclkOrGfruscO3aMQ4cOkZGR0RLTFwgETUx2YiT+uh/I5kjihk+lYskcJFnB0r4vkmzAWbgB54FNxI++td7+GW2TKDxqA6en3vZQI1z+CGZt4VZ1Xlq+l6evE+uBlqBJpdErnZUsP7QcjZMvKtWhUvpBKek3phOdG41kkDAlm8i8OxNXuYuqVVX1xthXtY/dlbubcloCQZPxRv5+f43YT+RFJ1w0E2uXPGSTBUkxYM0ZXO+ha3OpvLCsAIDXt76OS6vveAjl++L0OHlt62vNc6ECwTnI6tWr8XhOLnacTiffffcd48ePp1u3bphMJhRFoUuXLowbN87neI/HQ35+/tmcskAgaEYGtU8g2uK/oW/MoKuIHzODqvxFFD07laIXplOz/hMiOtVP2Ys0KVw/MBPVj5P11AjXmRDs2kLVdD7dfIRqp6+YhqD5aVJj6rN9nyGfNqR9tx3NrRHTv37oUbEoROdGU7u1tt52j+Zh8a7FTTktgaDJeHvtAVyqby5zKHnRO4tr2F9Rzdf7v/aJSoXyfdHRWVa0DIcncO2HQHC+o6oqW7bU7w1XVFSEx+OhW7fgxFx0XWfHjh24XK7AOwsEgrBHkiRuH9G+wR5WUT1GkzZtNu0eeJfMe98k5Zo/Ycmo/7yIshi4fmAmiZEmn+NPjXDZd+WjuZ3oqgdHwY9Ufjcv4PxCWVvIksRnm44E3E/Q9DSpMXWw5qBPrZRaq2KIMiApvnFUQ6wBT239kKiqq+yv3t+U0xIImgRN0znmR0IVQsuLNsgSy3YX4u2Gcdo4IX5fDLKBSmdlUPMXCM5n/AlO2O12rFYrshz8q1CW5RP9qAQCQetn8im9nUIlwqhw7+gcFEXmrpEdsZp81wDBRrj8EcrawuFWOVQpnKstQZPWTNlcvoX0SpSCp9aDruo+C0RPlQdDlO8UHG5xMwjCD5eqIeE3yy+kvGiHW2XO8u24k3UfeyrU74uEFLRqpkBwPuN2u32EI6xW6wl59GANKkmSRGRKIDiHiI0wMv+Wgdz0ytqADYFPJcIoM657KjcO8arqXtG3LX/5bLvffaN6jCaqx+iQ5xZqzVVNnSfgPoKmp0kjU/EWX8lYa44VySBRva663nbVqVKzqYbI7pE+x8SYhRqJIPwwG2QkP9EkCC0vOtJk4KnJQ/GXVRDq90XVVaKNopmoQBAIs9ns01MqIyMDg8HAjh07gh5H0zQsluD72AgEgvBnQHYC86YPJNKkYAgiSmU1KUzMTedf1/Y+4aSJNBt44cb+WIyhLa0VScJk8H9MKGsLCYi3+q//EjQvTWpMdU3oitVgrbdNsSqkXJHC4TcPU7OpBt2j4ypzcfD5gxgTjMTlxdXb36yYyU3KbcppCQRNgiRJdEj2Nf4htLxol6rRo00qkUbfsc7k+5JgSWi6ixQIzlEsFks9db7j20aNGsVnn33Gjh07cLvdqKrK7t27WbJkid9xFEUhMtL/c0AgELRehnZM5ItZF3D9oEwijIpPyp5BlrAYZHpnxPLv6/rwj8m5GJT6y+iRnZP51zW9gzaojIpEWpyF6wZkYlJ8jwllbWE1K3QV0ugtQpOm+Y3NGssTq5/w2Z48PhklUqF4UTGuUhdyhExMvxgy78xEPu2G03Wdqztf3ZTTEgiajLtGduSxD7dgc/mmAsQMugo5Mp6q/EWUf/IUkikCc2oOMUOvO7GPIktM6JVGbISZG7vfyIubXqROrd/zJtjvi0k2MaXrFJQzlFsVCM4nZFlm8ODBrFy5sp6iX15eHlFRUSxfvpz33nsPk8lEeno6I0aMoKCgoN4YiqLQv39/FEV85wSCc5HMBCt/vqIXfxjfjY82HOaHwgoq7W4ijAqZCRFM7p9JTkpUo2NMyE0nPS6Cv3y6nc2HqtA03aeditWkoOtwZd+2/O7SrthdHt758aDf8YJZWwAYZZkLu6b8sl+A4IxoUmPKpJiY3GkyC3YswK3VL9RPGJlAwsjGPegSEsPaDiMxIrEppyUQNBkTctP444dbGvw8UF60SZGYMcLbpfzqTlczd9Ncv/sF830BuLbLtQH3EQgEXvr378+KFSt8tufm5pKb65sRkZmZWe//kiQxaNCgZpufQCAID6wmA9cPasf1g9qd0fF928Xzv5l5FJbbmL+qkBW7y6hxejDIEolRZv4/e/cdHlWVP378fe6dmUw6BJJA6CEUAbGAdARE0bWXtYNir7u6xd39ubvqurrfXdddd+2KCgo20LX3RkeaItJLCC1ASG+TTLnn98dMMMlMkpmQRvJ5Pc88mbn13MnMmfs59cpRvbjgxB7ERvlvwxOj7Yzql8SS7bkhj9fQvUWUzeCacX2CaspEy2jySXunD5nOgm0LgoKpcESZUdx6wq1NnSQhmozTbnLbpP48vXBnRB1VARw2gxN7dWZoWiIAXaK7cH76+XyY+WHEg0g4TSdn9DmDlBgphRIiXPHx8QwbNoyNGzfWqJ0Kh81mY+DAgXTq1KnhjYUQAujbNZYHzh8a1rZ3nz6Q1Vn5VHiCp19piN00jgyEIVpek4ew3WK78eTUJ3GakXXQdZpO7h97P0O6DGnqJAnRpO48LYMpg5PrnJciFIepSEt08vw1I2osv3fMvQztOpQoMyrsY0WZUQxMGshfxv0l7H2EEH7nnXceqamp2GzhlyWapknXrl256KKLmjFlQoiObESfzjxw3tCIB7CItpu8fP0ppMTLwDitpVnqA0/pdgrPnP4MsfbYBm8S7YYdp+nkofEPcW7/c5sjOUI0KaUUT1x5Mhed1INou0lDA//EOEyGpCXw3p0TgmZatxt2nj/jeSb0mEC0LbrBc0fbohnVbRQvTnsRuymj9ggRKZvNxrXXXkufPn2w2xv+DjkcDnr16sV1110X1vZCCNFYV4zqzd8uOh6n3cDewM1FlM0g3mlj3o2jGdFHBqJqTar6bPC1jRw5Uq9Zs6bRB8915bJg6wJe3fwqXu2l0luJV3sxlYnTdKLRXDLgEq467ip6xvds9HmEaC3f7SnghcWZfLUlB9NQuL0WWmvsNgOtYUj3BG6b3J/TBqfU25ZZa83qg6uZvXE2qw6swjRMPD5/U1m7YcenfYxIHcF1w65jTPcxQfPlREIptVZrPbLRB2gjjjZ/Eh2bZVns2LGDpUuXkp2dDXCk6V9VrVW3bt2YMGECAwcOjGhiX9F47SF/krxJHK2s3DJmL9vFgrX7UPjnp7S0f0TBKLtBlM3k+vF9uXJUb7rEhd+yRTRefXlTswZTVbyWl6X7l7KzcCcl7hLiHHH0jOvJlN5TImreJERblV/m5ustOeSXVeK1NInRdsamdyE9uf5Rf0LJKc9h2f5lFFYWApAYlci4tHF0i+3WJGltDzcrIDcsounk5+ezfft2ysvL0VoTExPDgAED6NJFBkNqae0hf5K8STSVCo+PzzYeZF+Bi7JKLwnRdgalxnPqwGTMMObDEk2nvrypyQegCHkSw8bkXpOZ3GtyS5xOiBaXFOvg5yOapnY1JSaFiwZI3wwhWkpSUhKjR49u7WQIIUQNTrvJBSf2aO1kiAZIuwUhhBBCCCGEaAQJpoQQQgghhBCiESSYEkIIIYQQQohGkGBKCCGEEEIIIRpBgikhhBBCCCGEaAQJpoQQQgghhBCiESSYEkIIIYQQQohGkGBKCCGEEEIIIRpBgikhhBBCCCGEaASlta57pVKHgd0tlxwhRAvoo7VObu1EHC3Jn4Rol475/EnyJiHapTrzpnqDKSGEEEIIIYQQoUkzPyGEEEIIIYRoBAmmhBBCCCGEEKIRJJgSQgghhBBCiEaQYEoIIYQQQgghGkGCKSGEEEIIIYRoBAmmhBBCCCGEEKIRJJgSQgghhBBCiEaQYEoIIYQQQgghGkGCKSGEEEIIIYRoBAmmhBBCCCGEEKIRJJgSQgghhBBCiEaQYEoIIYQQQgghGkGCKSGEEEIIIYRoBAmmhBBCCCGEEKIRJJhqR5RSk5VS+1p635ailOqtlCpVSpmtnRYhRP2OpfxIKbVQKXVjS51PCNF6JG8STU2CqXoEbtyrHpZSylXt9dXNeN6ZSqmlzXX8pqCU0kqpjGY+R5ZS6vSq11rrPVrrOK21rznPK0RbJPlRaEqpvoH8yFZr+Ryl1EOtlS4hOgrJm0KTvKnjsDW8SceltY6req6UygJu1Fp/WXs7pZRNa+1tybQJIToWyY+EEG2R5E2io5OaqUaoquZVSv1eKXUQmB2qhKR67Y1SKkop9ahSao9S6pBS6lmlVHQjzn2dUmqzUqpEKZWplLolxDb3KqVyAzU7V1db3iRpCHG+B5RS85VSrwTStVEpNbLa+j8opXYG1m1SSl1Ua/+bql3TJqXUyUqpuUBv4INA6dbvqpfyKKUuV0qtqXWcXyml3m/OaxWirZH8KKx0zlRKLQ2cr0AptUsp9bM6tu2ulFqvlLon8HqhUuqvSqllgev8XCnVtdr25wfyvMLAtscFll+nlPqg2nbblVILqr3eq5Q6MfBcK6VuDWxTqJR6SimlmuO9EKKlSN4UVjolb2oHJJhqvG5AEtAHuDmM7f8ODAROBDKAHsB9jThvDnAukABcBzymlDq5Vrq6Bo5/LfC8UmpQpGlQSj2tlHo6gnSdD7wBdALeB56stm4nMBFIBP4CzFNKdQ+c51LgAeCawDWdD+RprWcAe4DzAk37Hql1vg+AQUqpAdWWXQW8Fum1CtEOSH7UsNHA1kB6HgFerH1ToJTqBywCntRa/7PaqqvwX18K4AB+G9h+IPA6cDeQDHyMvwDIETjORKWUoZRKC+w3NrBfOhAHrK92jnOBU4DhwGXAmUd5vUK0BZI3NUzypmOd1loeYTyALOD0wPPJgBtwVls/E1haax+N/4uogDKgf7V1Y4FddZwr6Fj1pOtd4K5q6fICsdXWzwf+3FAaAvvui+D90EBG4PkDwJfV1g0BXPXsuw64IPD8s6r01/eeB173DZzXFng9D7gv8HwAUALERPp+y0Mex9pD8qMa56yRL1RbPgd4qNo17Ki2LiawT7fA64XAvwPv65W1jrMQ+FO117cDnwae/xmYX22dAewHJgde7wVOBq4AngdWAYPx3/y8X+t/M6HW+/SH1v6cyUMekT4kb6pxTsmbOshD+kw13mGtdUWY2ybj/4KsrVbYoICIR6ULVP/ej7/UxAgc98dqmxRorcuqvd4NpDVlGupwsNrzcsCpAu2jlVLXAL/Gn7GAv9Sjqiq6F/6aq8Z4DfgX8CD+0pl3tdblSqkUmvdahWhrOnJ+VNUHw17tedVrT7XXR/KoQD4B/ryoytXADuCtEOeonb9V7ZeG/5qqjmsppfbiL8kGfwnwZPw3iouAQmAS/puzRWGeQ4hjmeRNkje1e9LMr/F0rddl+L+AACilulVblwu4gKFa606BR6Ku1mkzHEqpKOBt4FEgVWvdCX/VbfXq4M5Kqdhqr3sD2U2VhkgppfoAs4A7gS6BNG+olua9QP86dq/9Htf2BZAcaNt7JT818WuVaxWiFXXk/OgA/huTvrWW96PazUQYHgik6zUV/vQL2fibLwEQaJrTC38JMPx0wzIx8HwR/huWSQTfsAjRHkneJHlTuyfBVNP5ARiqlDpRKeXE/+EH/CUC+AOKxwK1Jiileiil6mt3qpRSzuoP/O1ao4DDgDdQ8jItxL5/UUo5lFIT8bd1XdDINDSFWPyZ6eHAOa8DhlVb/wLwW6XUCOWXEQjAAA4B6XUdWGvtARYA/8TfJvuLwPLWulYh2ooOkx9p/1QJbwMPK6W6KKXsSqkr8Tc3/iSCQ3mAS/HnWa8opcL5fZwPnKOUmqqUsgO/ASqB5YH1i4ApQLTWeh+wBDgL6AJ8H0HahGgvJG+SvKndkWCqiWitt+FvbvYlsB2oPffB7/FX036rlCoObDeIuo3DXzpS+/FL/F+SAvxN296vtd/BwLps4FXgVq31lkjToPyj1zxb/1U3TGu9CX9TvBX4g6PjgWXV1i8AHsZfq1SCv11zUmD1/wF/Uv4RZH5bxyleA07HnwlWr0aP9P0Wot3ogPnR7UA+/k7TOfhrws/RWh+qZ58gWms3cDGQCrzU0E2L1norMB14An/J8Xn4B81xB9ZvA0rx36igtS4GMoFlWubLEx2Q5E2SN7VHSuuGWlIJIYQQQgghhKhNaqaEEEIIIYQQohEkmBJCCCGEEEKIRpBgSgghhBBCCCEaQYIpIYQQQgghhGgECaaOklJqjlLqocDziUqprS10Xq2UymjiYx65lpbct6Uope5VSr3Q2ukQoqVI/nT0+7YUyZ9ERyJ509Hv21Ikb2pYhwimlFJZSimXUqpUKXUo8OFt8glctdZLtNYNDr+tlJqplKo9HGiTUUotVErd2FzHP1rNff2Bc0xWSu2rvkxr/TetdZt9X0THJPlT2yL5kxB+kje1LZI3tV0dIpgKOC8wg/XJwEjgT7U3UErZWjxVQggh+ZMQom2SvEmIBnSkYAoArfV+/DNPD4MjVb53KKW2459ADqXUuUqpdYHJYpcrpYZX7a+UOkkp9Z1SqkQp9SbgrLauRkSvlOqllPqfUuqwUipPKfWkUuo44FlgbKC0pzCwbZRS6lGl1J5ACdCzSqnoase6Ryl1QCmVrZS6vrHXr5RaoJQ6qJQqUkotVkoNrbVJV6XUF4HrW6SU6lNt38GBdflKqa1Kqcsam45aacpSSv1WKbU+kK43lX8Wc5RSnZVSHwbew4LA857V9k1SSs0OvC8FSql3lVKx+P/HaYH3uFQplaaUekApNS+w3ydKqTtrpeMHpdTFzXmtQtRH8ifJnwL7Sf4k2hTJmyRvCuwneVMIHS6YUkr1As4Gvq+2+EJgNDBEKXUS8BJwC9AFeA54P/CFdQDvAnOBJGABcEkd5zGBD4HdQF+gB/CG1nozcCuwQmsdp7XuFNjl78BA4EQgI7D9fYFjnQX8FjgDGACcfhRvwSeBY6QA3+Gf+bu6q4G/Al2BdVXrA1+yL4DXAvteATytlBpSx/UXKqUmRJCuy4CzgH7AcGBmYLkBzAb6AL3xz2z+ZLX95gIxwNBAuh7TWpcBPwOyA+9xnNY6u9b5XgeurJbeIYFzfBTptQrRVCR/kvwpQPIn0aZI3iR5U4DkTaFordv9A8gCSoFC/F/Qp4HowDoNnFZt22eAv9bafyswCTgVyAZUtXXLgYcCzycD+wLPxwKHAVuI9MwEllZ7rYAyoH+1ZWOBXYHnLwF/r7ZuYCDdGXVc70LgxjDel06B4yQGXs/Bn2lVrY8DfEAv4HJgSa39nwPur7bvQ2H+P2pffxYwvdrrR4Bn69j3RKAg8Lw7YAGdQ2x35H9RbdkDwLzA8/jAe94n8Pph4KXA83qvVR7yaMqH5E91vi+SP0n+JI9WfEjeVOf7InmT5E01Hh2pneuFWusv61i3t9rzPsC1SqlfVFvmANLwf3n268AnJGB3HcfsBezWWnvDSFsy/hKCtUqpqmUKMAPP04C1YZyzXoESn4eBSwPntAKrugJFgedH3gutdalSKj9w/j7A6Kqq9QAb/tKNpnCw2vPywDlRSsUAj+EveekcWB8fuJZeQL7WuiDSk2mtS5RSH+EvOfkH/pKWmwKrm/tahahN8ifJn46Q/Em0IZI3Sd50hORNoXWkYKo+1b/ge4GHtdYP195IKTUJ6KGUUtUyhd7AzhDH3Av0VkrZQmQKutbrXPxVsEO1v11ybQfwf/ir9K77Uup1FXAB/qruLCARgYPVtAAAIABJREFUKMCf+VQ5ch7lH7UnCX+J0l5gkdb6jEaeu7F+AwwCRmutDyqlTsTfzEAF0pSklOqktS6stV/t9ziU14H7lVKL8bff/iawvLWuVYhQJH/6ieRPkj+JtkPypp9I3tSB86YO12cqDLOAW5VSo5VfrFLqHKVUPLAC8AK/VErZAx3uRtVxnFX4v8h/DxzDqZQaH1h3COgZaEeM1toKnPcxpVQKgFKqh1LqzMD284GZSqkhgdKG+8O4DlvgnFUPO/7q2UogD39pzt9C7He2UmpCIG1/Bb7VWu/F34Z5oFJqRuDa7UqpU5S/U2hzisefWRYqpZKodu1a6wP42zE/rfydLe1KqVMDqw8BXZRSifUc+2P8JSkPAm8G/g/QetcqREMkf5L8SfIn0RZJ3iR5U4fNmySYqkVrvQZ/leWT+EsedhDo0Ke1dgMXB17n428f+r86juMDzsPfIXIPsC+wPcDXwEbgoFIqN7Ds94FzfauUKga+xF+qgNb6E+A/gf12BP425Bn8X6Sqx2zgFfzV3PuBTcC3IfZ7Df+XLh8YAUwPpKEEmIa/ajcbf9XyP4CoUCdX/lFgJoaRzob8B4jGXwL1LfBprfUzAA+wBcgB7g6kdwv+0pNM5e/QmVb7wFrrSvz/v9PxX3fV8oiuVYiWIvmT5E+SP4m2SPImyZs6ct6kajZhFUIIIYQQQggRDqmZEkIIIYQQQohGkGBKCCGEEEIIIRpBgikhhBBCCCGEaAQJpoQQQgghhBCiEeqdZ6pr1666b9++LZQUIURLWLt2ba7WOrm103G0JH8Sov1pD/mT5E1CtD/15U31BlN9+/ZlzZo1zZMqIUSrUEo1ahb4tkbyJyHan/aQP0neJET7U1/eJM38hBBCCCGEEKIRJJgSQgghhBBCiEaQYEoIIYQQQgghGkGCKSGEEEIIIYRoBAmmhBBCCCGEEKIRJJgSQgghhBBCiEaQYEoIIYQQQgghGkGCKSGEEEIIIYRoBAmmhBBCCCGEEKIRbK2dACE6Aq/lZdG+RXy480NyXbn4tI/Ozs6c3vt0zup3FtG26NZOohAdktdn8fWWHN75fj85JZX4LE1SrIMzh6Zy/gk9iHaYrZ1E0cFtOVjMvG93syOnjHK3l3injeN7dGL6mN707BzT2skTosOTYEqIZlTuKefljS/z6uZX8Vgeyr3lNdavObiG/1v1f1zQ/wJuGn4TKTEprZRSITqWcreXWYszmb08C4/PoqzSV2P9t5l5PPD+Ji4Z0YNfnDaA1ARnK6VUdFSfbjjA41/tIDO3FI/Xwqd/WrdqVz6zl+3i5N6dufv0AYxO79J6CRWig5NgSohmcrj8MNd/dj0Hyg5Q6asMuU1VcPXWtrf4NOtTXpj2AoOSBrVkMoXocHJKKrjy+W/ZV+Ci0muF3Kbc7Q+u3li1lw/XH+C1G8cwJC2hJZMpOijL0jzwwUYWrNmHy+MLuY3HpwHNisw8vp9dwO/PGsx14/u1bEKFEID0mWrzLG1R7C6m3FOO1rrhHUSbUOwuZsYnM9hbsrfOQKo6r/ZSWFnIzE9nsqd4TwukUIiOqbjCwyXPLGd3XnmdgVR1XktTWO7hsudWsCu3rAVSKDq6hgKp2io8Fo98upV5K7KaNV1CiNCkZqoN8vg8fLXnK17a8BJb8rdgM2xY2sJQBlN7T2Xm0JkM7Tq0tZMp6vHHJX8kpzwHn675Y1iwpIDcz3Jx57gxnSYJIxJI/XkqZqy/X0a5p5xbvriFjy/+GKVUayRdiHbtN/N/4FBRJV4ruHCqbNNCile/iydvH4YjGntKOonjLsPZcyhlbi8zXlzJ4numYBjy3RTN47ONB+sMpOr7fLo8Ph76eDMj+iZxXHepQRWiJUkw1ca8te0t/rXmX2itKfP6S0E9lgcAn/bx+e7PWbh3IWlxafxz0j8Z2Hlgaya3SbjcPj5Yn82P+4ooLHcTG2UjPTmWi07qSXJ8VGsnL2IHyw6yPHv5kf9bldxPcjn8yWF63tiTuCFxeAo8ZM/NJuvRLPr9sR+GzcDCIr8in5UHVzKm+5hWugIh2qdDxRUs3nYYty+4Rqp41TsUrXyLLtPuwNnvZJRpw7VrLa7tK3H2HIrWUFDmZvnOPCYM6NoKqRcdweNfbQ8ZSDX0+QR/079ZizP59+UntnSyhejQJJhqQx5b+xivbX6NCl9FndtY2qLCV0FmUSbTP57O01OfZmS3kS2YyqazO6+MWUsyeXvtfpT6qY8CQJTN4NHPtzFxQFdun9yfEX2SWjGlkZm/dX7QMp/LR867OfS4oQfxw+MBcCQ76HV7L7bds42i5UV0PrUz4O9HNXvDbAmmhKhFa83qrAJ25ZZSVukjLspGv+RYRvbpHFZN7twVu0MutyrLKFz6Kl3OvpuYQeOOLI/JGE1Mxugjr8vcPp5bvFOCKdEsth0qYefh0qDl4X4+fZbmox8P8MAFQ0lw2lskzUIICabajJc3vtxgIFWby+vijq/u4NWzXyWjc0Yzpq7pfbM1hzte/Q631wrZ3KaqL8PXm3NYviOPO07rzx2TM46Jpm/zt87HbblrLCvfXo7lsUgYUbP5hek0iR8eT+nG0iPBFPhH+SuoKKCzszNCdHRFLg9vr93H80syKXF50IDXp7GZCgUkRNu5eWI6l4zsWe9N5Ksrd4fsJ1W5fwva6yZm4NgG07JyVz55pZV0iTv2as1F2zbv2914jvLzaRiKD384wFWjezdHEoUQIUgw1QbkunJ5/PvHcfvcQesa6mPj8rr48/I/8/o5r7d0shtt8bbD3DZvLRWehjt/a8Dl8fHU1zvxejV3n9EEzRo9FVBZDLYoiEqAJgzQPJaHYndx0HJfqQ9bnA1lBp/LlmjDtdtVY5nDdHCo/JAEU6LDW7s7n5mzV+P16aDmT1WV2WVuH498tpV/f7mNl68fxcm9g783Xp9FocsTtBzA5yrGiElAGQ3PKeUwDQ4UVUgwJZrcjpzSGsOfV4nk8+ly+9iTLwOlCNGSJJhqAxZsXYC/fLWmcPrYaDTbC7aTWZRJemJ6K6Q+MgeKXNxaRyDVUOfa5xbv5ITenZgyqBFzMVUUwbrXYcUTUHwATDtYPjBMOP5SGHsHpBx31NdX4a3ANEy8lrfGcjPOxFvqRft0UEDlLfJiiwv+Krq8rqBlQnQk32bmMXP2airCGNXM5fGBB66etZKXrx/FqH5JQetNpfCGGBXVjE7AKi9GW74Gb1hrN0kWoqnU9bmK5PMJUOzyNriNEKLpyNDorcxreXl1y6tBw2dX9bFJm55G/PB4lE0d6WPjznVTtLzop20tH/M2zWvytGmt2ZRdzBebDvHeuv18veUQ+wrKG96xHnOWZeGto/N3/lezSBxzGT3vnEeP22YTf/LZuLavPLKNy2Px2BfbIjuhZcHnf4ZHB8JXf4GifaB94K0Ay+P/u+41eH4KzDrNv/4oxNhi8FnBP4gxGTEom6J4bc1aK1+Fj5L1JcQOia2xXKOJtddcJkRHsje/nBteDi+Qqs7l8XHdnFXsL6xZGBHrsOGrY3qJqB6DUTY75dtWNHh8rSEuSsohRdOLd4b+XEXy+QToHOtoymQJIRogvwitbHPe5qBaDIisj41Xe/l89+fcN/a+JkmTy+3j/R/288zCnRwqqcRmKCytMZTC7bU4oVcnbp2UzqSBKZgRDBHs9lq8unIP7lrtGMLtXAuw7WAJO3JKyUiJa/iEPi+8cRVkLfEHTXXRPvC6IHsdPDsBrvsUUgaHfV3VmYZJakwqB8sP1lweY5JyYQrZ87IxnEaNmkZ7kp1O4zrV2N7lcfHcD89hM2wkxyQzuedkRqSOOCb6jAnRFJ5fnEllHU2B66vFBn9eM2txJg+c739tlZVRtmIFKbqSQwQ3zzOiYuk04Wryv3gWZZg4+52EMmxUZK2jYs96Ok+5/si2Hp9Fr6ToZrhi0dEN75nIysz8oNEmI/l8xjpMBqaG8fsohGgyEky1soLKgpBN/CLtY1Pmbpo20quz8rl+zmp8lq6zycGqXfls3F9ESoKT124aTffE8G4svtp8CE2IwSYi6FzrtTSvrMjiwQuG1b+h1vDenZC1GDxhNpfTPnAVwpyz4bblEN8tvP1quWboNTzx3RO4fDXPm3x2MmasycE3D+LOcWNEGyScnECvW3ph2GtWEltYfL77cwAUigVbF9ApqhMzh87kogEX4bQ5G5U2IY4FLrePt9buCzk4TbhDRM9ftYeb89bgXbII17ofcJ4wnBnHn8lT+U5c3uDjJoy6GCO2M0Ur3iT3w0dRjmiiUjNIGHv5kW0MBWcN7Ua8jJQmmsH0MX14YcmukOvC+XwCoBRnDWvcb5cQonEkmGpluo5mJ5H2sQkVpERqyfbD3PzKGlxhDAxR5vaxJ7+cs/+7hA9+MYGenWMa3Gfn4VJcIQK0SDrXei1/08MG7VsNm9+rEUj1/U8J5R7YdVccsQ7/e/rCd27mrfewcGZVkzoNriL44s9w8ayGzxPCBRkX8N/v/htyXdKkJJImRTbMu0ZT7i2n3FvOv9f+m/nb5vPCtBfoEt2lUekToq374IfskOPCRFKLrStcfLy9kMuuvJIejz+BGRfLNRUennjoS6gjv4wbOoW4oVPqTFeUzeSmU9t+31RxbOqeGM2ofkks2Z4bcn1Dn0+7qbhyVC+ibA3/lgohmo70mWpliVGJIZdH2scm2nZ0zU525JRwy9y1YQVSR9JiaYpdXi5/bgXl7oY7vBa6PIQoaK7RuTYcJRVhdK5d/njIGimfhv+uDB41sQbthU3v+wetaIQERwIXZlyI02z62qMKXwVZRVlc/fHVIUcNFKI9WLkrL2TNeCS12C5bFDtGTCHhjDMw4/z5ZYLTzhWn9CLaHvlPn91UDO4ez7AeofNsIZrC3acPwNmIzyeA3TS4bny/Jk6REKIhEky1ssFJg7F0cABTvY9NyfoStFfjPuxm79N7g/rYGBiMTWv45qI+//58W50dvcs2LeTAy3ez598/Z9+TMzg0/34q9m0EwKc1BeUe3vl+f73HtyyNL1QkReSda/cVlPOrN9fx7y+28fbafazOyienuOKnWr6yPNj+OaFKn+8Z5+DR5ZUUVjRQk6cM/+h/jfT7Ub9nSJchRJmNGz65YEkB2/+0nY03b2TLL7eQ/XI2vjL//8ervRwuP8xvFv6m0ekToi0rKD/6IcwB8suCC07+dO4QhvfshNMW/s+fqS26xkUxe+YpYe8jRGOM6JPEn84+LuKAymk3eG7GCHp0kv58QrQ0aebXypw2JxcNuIj5W+fjsWreQITbxybKFsXMoTMbnYb8MjdfbckJWWsUTv+EcrePZxfu5KpRvVFKYVmarLwyftxfxIb9RazfV8Sm7GLsNoVpqKCgKpLOtaZSjEnvwviMruzJK2PJ9sPsXlnOnrxyyt0+eiVFc5ljOTMsI0Q3cxiZZjK5r41Hl1fy0Gn11Bx5yuGH12HMrY15S7Ebdp474zl+vfDXrDm0JqJhzsMZEt9tufk+53uyirLom9i3UWkUoq2KtocOliIdIjraEbyN3TR4+fpR3PHad6zYGboGrHZaUsqLeNaWSaeYqYC/efaa3QW8tXYf2YUu3F6LpFgHpw5M5oIT04hxyE+raLzpY/titxnc//5GPF4r5NxTVUzlLzacdc1IJg5IbrE0CiF+Ijl+G3D14Kt5a9tbIdeF08cmJTqF47se3+jzv7l6T4ghMCLrn3CopJK73lhHTkkFG7OLSXDaOb5HIsf3TOSOKRkc3yORaIfJyX/9IuTNS7ida+02xW+mDWJIWkLQMUorvezJK0etXIGZW3dTwAenRDH+pTLuGt3A8LGu/PrXN8Bpc/Lk1Cf5Zu83zF4/iy25G7BMO54QozdWqRoSv8cNPYgfHg9wZEj8bfdso2h50ZFRHH2Wj1c3v8ofx/zxqNIpRFvTOykGm6GCBqCoXosdO3hCvcewGYpedfTldNpNZs0YyZebD/Hc4kw27C/CsjSewPkM5e8flZoQxe2TMzinbwwHrryCvMEZfJo8jGcW7SS/zI3L46N6t9dF2w7z4AebuOikHtxxWobUEohGu/yU3pzUuzMvLMnk/R+yMZSq8dsZ6zBRSnHFKb1YvP0wB4vqGbFWCNGsJJhqA3ol9OKCjAt4f8f7VPgiyxCdppP7xt53VENmL9p6mApvcFPDSPoneLwWheVubpvsD5yS6pjn4rKRvXh15W48IYraGupcC9C3S2zIQAr8c78MSUuA1Dj/3VAdBc7DUkzOHWjj70vdHJdcT1OKOgYHiYShDKb2nsrUglx2F7v5avi55Lpy8fg8vL397aDayEiHxH9v53v87pTfYTdldDHRfvx8ZE9mL9sVFExFVIttKC4Z0bPOcxiGYtrQbkwb2o1duWV8vvEgh0sq8VgWXWOjmDgwmRN6Jh7JWzv/53Gue+JrtndVuOqoKqi62Z2/Zi8f/JDNyzeM4uTenY/27RAd1MDUeB75+Qncd95QPlqfze68coorPHSOcTAwNZ5pQ1OJsplsyi5mxosrmTggmW6JMtKrEC1Ngqk24t5R95JTlsO3B74NO6CKMqO4b+x9jOo+6qjOXeg6+v4JGhjcPYFJA+tvZnD9+H68uXovHl9kE3GCv7nN3acPCGPDzmDYwVf3QBN/mezk5OdK+c3Yevo0OTvVvS5Sm9+nz5Cfc/2wqwDIr8jnfzv+F7RZpEPia63Jr8gnNTa16dIqRCvrnxzHoG7x/LAveBCYcGuxh6Yl0K9reBNf9+sayy2T+te53u21uGlxPluS+gbNkxeK19KUVHqZ/sJKFtw6lqFpMmiFaLy4KBuXn9K7zvVD0hKYPqYP977zIy9eO1LmIxSihckAFG2EaZj897T/csmAS3AYDhxG3U3QYkwnsfZYHpv8GOf1P++oz203Q38MIh1lL5wO3b27xPCvy4ZH3Lk22m5yxSm9OGtY94Y3Tp8M9TSlA8hIMrh8qJ3HV9URcNmcMOT8iNJYRVsaX5kHb34FvjIPuqIUdi2GgWcd2abMU4ZNBZdlVB8Sv7ZQQ+IbyqDM0zRzjAnRltw+JYOYEH2ewF+L3f3a/9D712/T6855pFz6AM6exx1ZH+MwuX1yRpOl5R+fbmHTgWLcIeKo+gboKXf7mPHiKiq9kRceCRGJO6ZkkF3oanAwKCFE05OaqTbEUAZ/GP0Hbjj+BuZvnc9rW17DY3kwlf+GwmN56I2d6zsNZ9rPnsBhNtDnJ0zdE538uD+4BDiS/glOu0FyQnjNC84+Pg2fpbnnrfW4vVbIgS+qi7abXDW6N388+7j6N6yS2AP6jIPMb+rd7L5JUcxdH7pWzuP18UPXCxihddilfJ5cF6XL9lO+9hDa8u+nLY0yLGIT7iauMgZboAtHjC0Gnw6+wao+JH7iqJ9Ks6uGxE/9ec0aKEtbxNgbnuNLiGPNtCGpnH18dz5anx3RlA3RdpPzTkhj6nEpTZIOl9vHayv3UBEiDeEM0FPp8fHphoNccGKPJkmPEKE4bAaPXnoC1760igkZXUkJ8/dYCHH0JJhqg5JjkrnjpDu45YRb2Feyj2J3MTbDRpIziW6F2fDWdVRaBptziil2ebCZBslxUfTu0rib6itG9WLpjtyggSEi6Z+gNZw1NPxZ1887oQf9k+N56psdfLH5EIaixs2K3VAYhmJYj0TunJLBlMER3hiNvwv2roJqtTZZd8fX2KRXokHFn4L7X2llcChlIr/9OJvkJXncMSWDSQOT6wyqrHIPea9uoXJ3EVhQFR1WTaSsfYrSgpGU/msNzoGdSbp8MIlRiRgquHau+pD4htOoMZpf7SHxq86R5IxsEmAhjgVKKf5+8fG48gv4ckcBlWEUHkXbTc4cmsrfLjq+yZo6vf/D/qOaQLjM7eOZhTslmBLNbliPRK4a3Zt739nArGtGSHM/IVqIBFNtmM2wBQ17va8ygZdLzuW1Bz8Dw8QIZJYen0WPTtHcOqk/552QhrOOoYVrO1hUwdLtubjqGB44nP4JCjh1YDLJ8ZHNqTQkLYGnrj6ZgjI3C9buZd3eQgrLPcRF2UhPjuWKU3rTN8w+D0HSJ0PvMbB7GXgjG9RDOWLpedmjfNmpHx/9eIC/fbyZf32+jTtPy+CM41IxjJ9+oHzFbnKeWoev1E2949daBliaim0F5DzxPSm3n8CF/S9kwbYFeHXNJonhDolvKpNz0s9pshpKIdoao8LFr975BydcfDsvHI6irNJLWYi8KsZhEu+08YvTBnD16N5NehP54tJdRz2B8O68cnbklJKREtdk6RIilDtPy+D8J5bx/g/ZEsAL0UKUrmfEspEjR+o1a9a0YHJEXbw+iz++u4F3v9+P5fPi0aH7HMU4TBTw+JUnMfW4ugcl2JVbxnOLdvLJhoNccnJPHDbFnGVZIUf1a0i03eSVG0ZxSt82VkPiLoc5Z0POFgh3nidHLEx/B3r/VLJsWZovNh/iya93UOn1cceUDM4dnobyWOQ8+T3evAoabKtYnamwd4vFdeZhfr78t1Sqxo0aGGVG8cY5b5DRObK+IUqptVrrkY06aRsi+VP7l/37P4Bpkva3h7EszbKduby4dBc7c0pxeXxE2036p8Rx44R0xmd0aZaS+GH3f0ZpZXAfzNKN31DwzYv0unNeg8eId9p4/IqTIq9h74DaQ/7U2nnT+n2FXD9nNR/fNZGUeGnuJ0RTqC9vkpqpY4DHZ3Hd7NWs3V1ApdeivnFDqkpQ73jtO/524fFcXGto4I3ZRTy9cCcrduYxfUwfvvntZJJiHVR6fSzfmcfmA8Uhhy2vS7Td5PJTera9QArAEQPXfQrv3gpbPgZ03SP8OWL9o/ddvQBSh9ZYZRiKM4d2Y9qQVBZtO8yTX+/gsS+28VC3rvQtCA6kVu1bz9++eYZtuVkYhsGALn24f+ovOLF7oM+XT+M9UEjym29w4uAefFdxKGiI9IbYlI0hXYZEHEgJcawofOddXBs20G/BfMD/PZw4ILnFJyata/CISCYQ1lqHDMiEaA7De3bispG9+PO7G3h2ur+534b9RXy64QAHiyuxLE1yfBSTB6UwJj1JmgMKcZQkmDoG/OHt9azdnR9RJ+wKj8W97/5It0Qn4zK6smpXPk99s4MtB4u5cUI6/7hkOHFRP/37o2wmc68fzZWzviXzcGlYNVTRdpOzj+/GfecObXDbVmN3wqVzoCALVs2CtXP8HbwME7TlD676jPP3seo3GYy6A1WlFJMHpTBpYDIrM/NIfGlz0FxWJZVlXPfWH3h42q85b/AU3D4vq/b9QFStpnjaslHsuIVHzxvEpR9eSm55blBzv7qYyqSzszP/mfKfiN4KIY4VlZmZ5DzyCL3nzMGIiawvaE5JBa+v3MPnmw5R5PJgKkVSnINLR/TkwpN6EOOI7GfPaTPx+IK/m5EM0KOUIs4pP7ei5fxy6gDOeXwJf35vAysz89lX4KLS6ztS9qeAud/uJjHazi2npnPZKb0i/m4IIfzkm9PG7cgp4aP1B0IGN2WbFlK8+l08efswHNHYU9JJHHfZkVGkKjwWv56/jp6dYzhcWsktp/bn+WtGEGULXYqaGGPnf7eP46GPNvPW2r1BM65XiXWY2G0Gd00dwMxxfY+NUq3OfeHMh2Hq/VC8HyoKwR4DcSn+eakioJTiBK9Bvmmia82XlZm/F4ALh5wOQLRhMqlf6HnArFIv0TkGr539Gtd9dh0Hyw5S6aus99xRZhQpMSm8dOZLMvCEOCZtzC7i0w0HOVhcgWVpusZFMWlQMmPT/c30rIoK9v/q1yTffTfOQQPDPu6OnBL+75MtLNmei4JALb7f7vxyth4s4a8fbuaik3pwz5mD6FzHxOK19e0aG3K000gG6HF7Lfp3lf5SouV4fBYO02Det3tCrtf4W7KUu33849OtvLxiN2/cPIZUGQVQiIhJMNXGvbg0C4/VuCF5AQ4VV3L1mD7cNqk/tjrmk6rOaTd56MJh/L+fDebddfuZvWwXh4oqqfRZRNtNBqTEccuk/pw2OAXTOAaCqNpsDkjqd9SHca0/jA4RaKYn9cJQBr/66GHOHzyVk3oMpZMzPsQRQHssXBvzSP5ZP+afO583trzBK5teweV1Ue4tr7FtjC2GKFsUM46bwVXHXUWsvZEDcwjRCrw+iw/XH+CZRTvZnVcWNCVCVQn5zRPTOfWLuTj7p9PpskvDPv7ynbnc+PIaXB4fdXUDrioYWrBmL19vzWHBLWPplVR/rZdlaZLj6g66IplAuLGjrQoRqQqPj8uf+5adueHNQejy+NibX875Ty7lk7tOJSnMggYhhJ8MQNGGlVV6GfHQF0Hzm1iVZex76lq6nH13g81LDAVnDevG01ePaM6kdjiHX9pA5baCkOu252bx9MrXWJq1lsNl+UzpP5pHzvodybHBNUkxJ6WQdPmgI68tbbE8ezlf7P6Cw+WHAeji7MIZfc9gfNp4zAb6ZoSjPXTwBsmfjhVllV5ueHk16/cVhazprs5pQFJ5IQt+dxY90rqEdfwf9hZyxfPf4vKEPzGuoaBrXBQf3zWRrnGhRyH1+ixumbeWZdtzGzUwT5XYKJPHLjuRaRFMHdGRtYf8qbXzpl+9uY5Pfoy8RYvd9E9H8s7t41sh1UK0bTIAxTHqx/1F2A2DCmpmiJEMyWtpWLYjr7mSKEIY0LUvj51zLwA78nbzyw8f4oGvnuCp8+9vcF9DGUzoMYEJPeoPkoU4FlR6fVw561u2Hiyp0eyuLhUWHIruxEWzv6830Kni9VlcN2d1yECqvptGS0N+mZu731jHvBtHhzgy3PvOjyzfcXSBlKGgU7Sd02QUP9FCcksr+ejHA7hDfG4batHi8Wm2HChhw/4ihvVIDHF0IUQoEky1YUUuD6HqDX2uYoyYhAZHkKpS7pZRpJqaGWcPa7uMLn24bNhZzFv3fsj1Rrw0pxByEGMIAAAgAElEQVTt15/f3cC2OgKp+oKd/DI3189Zzft31l+o8OXmQ1SGCKTCaQbttTSrs/LZV1BOz841m+Bt2F/E+z9kB7UKaCjd1SkFcVE2Xr9pbFhNrIVoCq+t3EOoBvjhTjLt9lq8uHQXj11+YgukVoj2QXL4NsxUKmSmWH1I3nAYx8IAEceY6GFdUVHBweyOvN08t+oNDhTnAJBdfIj3Nn/FyWnBIx4qh0H0EBlEQrRPBWVu3luXHbJmp3jVO+R/NYvEMZfR88559LhtNvEnn41r+0rAH+hsP1TKD3sL6z3HMwt3Bk3iW3XTmHTGbcQMGofhcKJMGzEZo2sMCgFgac3cFbuDjvvCkkw83uCirIbSXV1SjIN37hgvfaVEi5qzPCtk4UW4LVp8WvPxjwcok6H8hQibBFNtWFKcI2TNVPUhecMhQ/I2PefgJJQZHKTGOmJYl72Z8+beysB/T+P8ubcxqGs//nzaHUHbGrF2HH0SWiK5QrS4N1fvIVQ5TrjBTqXXx6wlmXUe/0CRiy0HS4KWR9IM2uPTvLF6b41lRS4Pn2w4iK9Wf+JIgjRDwcUn96B/sozgJ1pOhcdHYXnouRQjadFiMxUHisKc6F4IIc382rLhPRKxhRgxL5Ihee2m4vwT0loy2R2CMhRx43tQ/M0eqFaC3T0+mWcu/EvD+9sN4if1PDaGlReiEV5alhWymVy4wY6l4YtNhyiu8JDgDG5We6CoAofNCCqFj7QZdHGFB8vSGIG8dsXOPOxm8HEj7av65eZD/PGcIWGlQYimUFbpDfnZhcgmmTaUoqRCaqaECJcEU22YzTSYOa4vzyzaGZQ5hjskr6EUM8f1bcFUdxxx49MoW3MQX2ElIasQ62KAmeQkdkRqs6VNiNbk9VkcLg09Z1pEJeSWj01z3yLDKsEqK8MqLz/yd487Git+FBg1+x1GctMIoCyLTdPOItpuoqKi2NFlCN4uo8Co+fMYaZBW5JKbUdGyYqNseHyhB0yJZJJpS2viouT2UIhwybeljbtqTG+eWbQz5Lq4oVOIGzqlzn2VguE9E+nTReYkag6G00byLSeQ89T3WGUeCGfQL1NhJjhIvul4lP3ohzkXoi0qc/uwGQqPL7iUIaJgx+ejcF82VgIYsbHYunbBiI3FiIkhzetErXTVqBmGyG4aATAMBsx+ETxudGUl8Zvy4btCqNUlNdIgTYiW5rSbxEXZKA5RqxRJixaPT5OaKJP3ChEuCabauJR4J787axCPfrYtonlUAGIdNv5xyfBmSpkAsHWKIvWuk8l7ZROeA2Von6bGbKRVDMAwiOqbQJerj8OIlq+eaL9iHCbeUN8DIgt2lNNJ71tuIiUtuG/hkEov3lVfULtaOJKbRoB+XWOJ6tXzyOtU3wFs69eDr+YNaaRBWoL0VRWtYPqYPrywdFfIodHDadFiKJg6OCVk01ohRGiS2x8DbpiQzsrMfD7fdCis7ZXyB1JzbxhFunSAbnZmnIOU20/Ec7CMkiX7Kf/hsL/jvaGocPuwm4r4kd2IG5+GPVlG9hLtn9006BRtp6DcE7QushJyi7ROoUvI46JsnDc8jf99vx9frcAt3GbQMQ6T2yZn1Fg2Nr1ryKZSkaTbYTM4Z3j3ht8oIZrYjLF9eHHprjrXN9SixWk3uenU9OZImhDtlgRTx4Af9xWxZncBd00dwJzlWXgti7LK4FoqU4HdZtA/OY7HrzxJRpJqYfZusSRdOpDOF2bgK3OjK3088vV2+vTuxPTxfVs7eUK0qBlj+/DcosyQneHDCXYUMCGjK51i6p6L7YaJ/fhgfXZQMAUN3zRWObdW0JMYY2fakFQ++vFAUCVzuEGaAmaM6dvguYVoat0Tozl1YDKLth0OWTtVHxMfvRwuTuoR30ypE6J9kmCqjTtUXMHNc9fwt4uGcdaw7vzitAy+3JzDMwt3sGF/MQAWGqfN5LwTunPDhHQGdZOMsDUpu4EtUJqe3COenfllrZwiIVre9NF9eHZR3UObNxTsRDtMbm6ghHxwtwRO6ZvEql35IYO2+kTbTW49tT/OEH0Xbzo1nS8354RsWt1gX1VgdL8kukmfE9FKHrnkeCb84xtCD5IemqEg3ulgTrfXUHPnwCUvQHy35kqiEO2KBFNtWIXHx82vrOGqUb05a5i/9NRmGpw1rBtnDeuG1poKj4XNVNhNmTKsLUrvGseyHXmtnQwhWpzTYdIl1sHBooqIBrsE/4Tl3RKcjOrX8KTWz04fwflPLmVvgSvskvhou8mUwcn8YmpGyPXDe3bizKGpfLrxYMjh3esTE2Vy//nBk3QL0RK8Pos/v7eR43t2Iq+0kr355SEnzq7OYSo6xzh485axdE86Axb/E56bBBc9C/0brt0VoqOTO/A2SmvNPW+tp0+XWO48LfQPvlKKaIcpgVQblp4cS2ZuaWsnQ4gWtSYrn7P/u4RJA5Pp0yUGe4gJruuilH+i8VduGBXWPGyxUTbeuWM8w9ISiHE0NH+OP5C66KQePHHlyfUe/5+XnsCofklE28PPX2McJnOuGyVNrEWr8Pgs7npjHSUVXuZcdwof/GICv542kOT4KGJDfDdio0wSnDZuPrU/n/3qVPp2jQXDhMl/gIufh3duhW/+BlZkg18J0dFIzVQb9cTXO9iTV8abt4yViV2PYb2SYjhUXEmFxxeyOZEQ7YnXZ/H41zt4beUe/u/i4zljSCr5ZW6ufuFbduWWNVjLYzcVnaIdvHHLGHp2Dn+wlgSnnQW3jmPh1hyeWbiTH/cXYSj/EM9K+QfE8FqaM45L5caJ/Tipd+cGj2k3DWbPHMWDH2zkjdV7URCyhF/hb5KYFOvghWtHMrhb8MiDQjSW1ppSTykV3gpi7bFE26JD3hN4fBa/fP17Kjw+npsx4sjvzc2n9ufGCeks25nL++uyOVRSiWVpusQ5OGtoN04fkhq6QDZ9EtyyGP53I7xygTT7E6IeEky1QZ/8eIA3Vu3h3TvGyw34Mc5uGvTsHM3uvHLpyybatb355dz1xvfERtn4+JcTSEnw9xlKinXwzu3jmbtiN7OWZFJW6aXMXbOkO9ZhYhiKGWP6cOPEdJJi6x50oi6moZh6XCpTj0tld14ZK3bmUejyYDMUSbEOpgxKoXOExzUNxV8uGMYvpw7g9VV7eGlZFmWVXmyGwtLgtSwmZHTl5lP7MyY9SQq+RJPJKsri1c2v8t7O9/BYHkxl4rW8dIrqxIwhM7h4wMV0dvoLBdxei1+8/h1en+bZGSOIstW8bzAMxcQByUwckBxZIuJTYca70uxPiAYoretuzT5y5Ei9Zs2aFkyO2LC/iGteWsUr149iWI/E1k6OaAI3vryGS07uwc+ObxtDJSul1mqtR7Z2Oo6W5E9tx7vf7+evH27itsn9uX58PwwjdFBhWZrlO/N4d91+coor8FmaLnFRTBuayrQh3XDY2naTZcvS5JZVUuzyEmUzSIp1EBslZZJNqT3kT0eTN+W58vjNot+wIXcDPsuHVwdPwOs0nVhYnJ9+Pr8d+QfufuNHtIanrj4pKJBqMpkL4X+3wIhrYdLv/c0BhehA6sub5FegDckpruDmV9bw0IXDJJBqR/onx5KZKyP6ifanuMLDfe9u4Mf9RbwcRgGQYSgmDOjKhAFdWyiFTcswFCnxTlKkklk0g/2l+5n+8XQKKwpDBlFVKnwVAHyY+SGfbl/HEO7hmavHNG9hRPpkf7O/t2+QZn9C1NK2iwHbC62h5BDkbIHDW6EsN2iTCo+Pm+au5YpRvTm7jdRgiKaRnhxL5mEJpkTblFNSwecbD/LW2n28t24/y3bkhpy0tra1u/M55/ElxETZ+PAXE6UASIijUFRZxMxPZ5JfkV9vIFVdha+CcnZjT5tLc1VI1RCfCte8B33G+5v9ZS4Me9cKj4+ySi/1tYYS4lglNVPNqbIUflwAy/4DxQfAtPuX+9yQlA4TfgVDLkTbovjdW+vpnRTDL+oYuU8cu/p1jeON1XtbOxlCHKG1ZuWufJ5flMnSnbk4TANLaxT+UUKVghlj+jBjbB+6J0bX2Nfrs3jymx3M+3YPD180jDOHSum0EEfr+fXPk+fKw9I1CzIKlhSQ+1ku7hw3ptMkYUQCqT9PxYz1R08WHr7LWcvifYuZ3Gty8yfUMGHK/4M+YwPN/mbCpN8FNfvTWvPDviJmLc7k800HsSz/SJ2W1pzUuzO3TurPaYNTMOtoEizEsUSCqeay+kX4/E+AAk+gVsJX+dP6w1vgo1/DR7/hs/R72Z13gozc105V1UxpreX/K1pdaaWXG+as5sf9RbjcPjSEnJ/phaW7eHHpLu49ezDXjusH+AeZ+NWb64iyG3z0ywmkJsjEtEIcLbfPzdvb38ZjeWosz/0kl8OfHKbnjT2JGxKHp8BD9txssh7Not8f+2EEmvW5vC5mb5jdMsFUlfTJtZr9veivucLf9/uXr3/PgaIKKr0+rKrKqMDftbsLuPuN73HYDB68YBjnnZDWcukWohlIM7/m8OWD8PkfwVP+UyAVirsM3KVM3vIAc4d9JyP3tVNdYh1orckvi2Q+eiGaXlmllwufWsa6vYWUBwKpuri9FpVei79/spX/frmN99bt58KnljFtaCpzrx8tgZQQTeTz3Z8HNX/zuXzkvJtD2vQ04ofHo2wKR7KDXrf3wp3rpmh5UY3tN+RtYG9xC7eAqNHs71TIXMiS7Ye59NkVZOaW4fJUC6RqKXP7KCj3cM9bP/D0wh0tm24hmpjUTDW11S/CyqfB4wp7FydunEv/Bsm9YeiFzZg40RqUUqQnx5GZW0aXuKjWTo7ooLTW3PTKGvbml1MZoiaqLi6Pj8e/2k6XOEdYg0wIISLz4c4PKfeW11hWvr0cy2ORMKLmvGWm0yR+eDylG0vpfGq1+dI0LNq3iOlDprdEkn9SrdnfhvkPcnPxr3BZ4ZfTV3gsHv9qO11jo7jslF7NmFAhmo8EU02psjRQI1UzkOr7nxLKPbDrrjhiHf5mXi9852beeg8LZ8b6N/K64IO7YPC5YMq/pb3xN/Ur5ZS+Sa2dFNFBrdtbyPd7CkMGUmWbFlK8+l08efswHNHYU9JJHHcZzp5DAfBpf6HA0DSZkFaIppZfkR+0zFfqwxZnQ5nBTcNtiTZcu2veZ7gtd8jjtJj0ydxt8+GyKkKuri+PqfBY3PfeBs46vhsJTnvLpluIJiB37U3pxwXU1XLSp+G/K93cO7GemgnLC9s+gePOa570iVbTP1AzJURrmbUkk0qvL2h58ap3KFr5Fl2m3YGz38ko04Zr11pc21ceCaYASiq8rNiZx7iMY3NYcyHaqtqDTgCYcSbeUi/ap4MCKm+RF1tc8O1bqOO0lPX7Ctlf5Am5Lpw8RhmK/63dx8zx/Voy2UI0Cekz1VS0hqWP1dlH6p5xDh5dXklhRT29FNylsPQ/zZRA0ZrSu8rw6KL1FJS5+XJzTlD/BauyjMKlr5J0xm3EDBqH4XCiTBsxGaPpPOX6GtuWu308t3hnC6ZaiI6hs7Nz0LKYjBiUTVG8trjGcl+Fj5L1JcQOia2x3G7YQx6npcxaHLqwJtw8xuX28fziTBk6XRyTJJhqKiUHofRQnatHpplM7mvj0eWVdW4DQPZ34JWBCtoTy9LklVWyNiufBz/YyD8/28rcb3fLgBSixWw6UEyUGZzdV+7fgva6iRk4NqzjfL+nsKmTJkSHN63PNGJsMTWWmTEmKRemkD0vm5L1JWivxn3Yzd6n92JPstNpXKca2xvKYHza+JZMdg1fbD4UcrCJSPKYgnIPO6XQURyDpJlfU3EV+OeR8oZuLwzw4JQoxr9Uxl2jHXUfx3T4jxUYYlQcu4rKPby+ajcvLN1FudtHudvHS8uyAIi2Gzz04SamHpfCzaf258Reneo/mBBHodjlCTlyn89VjBGTgDLCG0nU5QkueRZCHJ1z0s/hkdWPBC1PPjsZM9bk4JsHcee4MaINEk5OoNctvTDsNQtH0hPTyejcOvNUenxWnYPaRJLH2ExFQbkUMopjjwRTTamB2ulhKSbnDrTx96Vujkuup1JQ5iI65m05WMxVs1ZS7vZS4Qn+kXEFln264SDfbDnMTRP78aszBso8VKJZOGwGoT5ZZnQCVnkx2vKFdbMjE2wK0fRi7DGc1/883tn+Dl7trbEuaVISSZPqH7go2vb/2bvv+CjK/IHjn5nZ3exuKgkkhBBKCITeO1IFxYJ6FtRDFBVF1BPU8353ljvP655dT0CKDVTsWLArKD3SCT2UNEJCSN8+M78/VpCwk2QXUuF5v168XmbL7BMYZ+f7PN/n+7VxW6/banxNfdJ1kDC+BQr1GqNWV0tdEJowkeZXV+xxoNY+o/LXMVbmb/KQW1bNBUP1gFWsUjRn+46Wc82cNRyv9BgGUqfSdP9s//yfDvKP5bsaaITC+SY+0opmsBchLKkrksmMY+/aoI7Twl7DqrogCGfsrj53EWGJCPl9FtlCakwqF7a7sB5GFeQYTHK1Ey2hXGM0XSfaJqr5Cc2PCKbqSmQCxLSr9WWpsTLX9zDzwoZqAq/koWASNyzNldOjcsP8dTjcgelQlTtXcOT12WQ9cy05L03l6Lt/wZWT4X+fV2XJuiw+2ZLb0EMWzgM9k6KIsAYmIshh4cRcMIXj38zFsXctmteFrvpwZv5M8Q+Lqrw2zCRzw6Dar3GCIIQu3h7PwosXEmmORA7y1ixMCaNNRBvmTpiLWW7cIKS6th+hXGMUWSI1PvSAUhAam0jzq0sX3A9fPASemjdQ/nl0GG9uMyghaomAC2bX0+CEhvDJ1lycHjUg3SGY0rBOr8p/v97DpD5tRLqfUKckSeKOkSk8/fXegH1PUYOvRg5vQenapRz77Ckki42whFSihl0fcJzfDhHBlCDUly4turD08qXM/G4mBY4CXD4XukHynIwCyAxqPYinRz+N3WwPPFgDu2t0J7Zml1DpCZxIDOYaE2aSuWVYB8wGhXIEoakTwVRd6nk16ue/5/Ss4EOzI6v8nBwt43rUoPmlyQqp4+tvfEK90nWdOSsycZz2ZXKiNGzcpbOxpw0/+bg9dQj21CFVXltU4WFTVjED2ovmvkLdum5gMk99vcfwuYgeY4noMbba95pkiVFdWtEqsoY+eYIgnLXkqGQ+vepTNhds5rWM11iduxqTbEKSJDRdQ9d1Lku5gk9+7MTsiyc1iUAK4ILUloSHmQyDKaj9GqMDNw1tX0+jE4T6JYKpOqJpOvNW55GrTuevyisoavVV/QyZrPCbeRBkVS2h6dmeW8rRssDS96GUhj2xf0oEU0Jdi7aZeWZyH+57ewu+EDZ5SxK0CLfwr6t71ePoBEE4QZIk+if0p39Cfyo8FRQ6C8nIK2HzQRWH24z3qEz3lpU88VkGS6YPbezhAiDLEv/8TS/ufXtTrXuFT2czK0wd1p6EKGs9jU4Q6pcIpupAQbmLB9/disur8tysh1EyWsEP/wSfM7gDmKxw2TPQWaxKNWd78ssNCzGGUhpW12FnXlmtrxOEM3GoyEGM3Uy5y1dtKeNTmWSJ2HAL7901jJYRYlVKEBqST9X4bmcpc1Zkc6ioEq9PQ/1lHkSW/AWMLvjP99x3YWeu7NuGMFPjTsaO757AHy/pyr+/2B10QCVLMK5rK/44sWs9j04Q6o8Ipn7h9qkcLnJQ5vRiUmRaRYaRFGOr9X0r9xby0HtbuWFwO+4bl4pJkWHEfRCVBJ/fD5oKngrjN1siwBQGv3lFBFLngHKXz3DGP9TSsA6Pr9bXCEKoFvx0gHfTs/n8vpFkFlTw+KcZZB934vFpqKdV+rOaZXQdJnRP4K9X9CBOBFKC0KDKXV6mvZrOriNlAanjwMkGuTnFTv7ySQavrT7EkulDaBHeuAWspg3vSFx4GH94fxuShOHYwb9HSgcSosJIiLIii7YLQjN23gdTuSVOXl99iLc2ZKGjI0sS6OBRNdrH2Zk5phOX9EzEaq56E+zxaTz99R6WbcnjuRv6MrxTy6oH7nUNdL8Cdn+Ovuo51LytKGaLv9eL6oXkwTBiNnSeIFL7zhF2i4JisDR1amnY8K4X1Hqc0881QThbr685xOtrD7H0zmEkRFlJiLLy9f2j2ZFbysJVB1mTeYxKt4pJlmgRbuH6QclcPzC50W/MBOF85PKqXDd3LQeOVeIJYgXZ6VHZV1DOVS+v5rPfXUCktXEr+03q04YLu8XzyZY85q7MJL/MdbKwhKrpWEwytw7vwI1D2mFRZCa9tIp+7VowqU+bRh23IJyp8zaY8qkaj3y0g4+35KLrOh41cEVh79EKHv1oB499nMH/pvRndJdWAGQVOfjd25toGRHG8lkjia3uhkMxQ4+rOJQwgdsX/Mj3d/cFSQZbjH9FSjinJMfakQ0KEZ1aGlaSFawd+yHJJlyHtuDK2kaLsVWbLSa3aBobioVzw5L1h3nlxwO8c+dQ2py22t4zKZpnr+/bSCMTBMHIQ+9t5aBBIFW5cwVl6R/jLcpBttgwx6cQPXwy1rY98Ko6R0pd3LNkE2/cPqSaIzccu8XEDYPbcf2gZHKKnRRVelA1jWibhQ5xdn8Wzy/mTBnAzYs2kNY6ki4JkTUcVRCapvMymPKqGrcs2sDmrJJa9w2cqEwz482fefKa3iBJPP5JBveOTeXWER2CKmG972g57RPiICqxTsYvNE1DU+IIMylUGvSYCrb8dHiYwq0jOjTQiIVz3bvp2bz0/X7euXMoybEiSBeEpu5omYuvdh4NCKSCaa/h8WmsP3icA4UVpLRqGv2aJEkiOdZe4/WnZ1I0j1zajRlvbmTZvSOIauSVNUEI1XkZTP3h/W1szirGGULFGZdX4/6lW4mPCuON2wbTMyk66PfuK6gQsy3nAUWWuG1EB176fj8ugyC9ttKwAGZFZlzX+PoaonAe+XBTDk9/s4e37xhK+7jwxh6OIAhBWLzuMKdP0YbSXkPVdF5dfYi/XdWzAUZbd64Z0JYt2SU8+O5W5t00QOyhEpqVZhdMubwqn207woKfDpBT7MTtU7GaFNrF2blzVAoTe7ausaLNnvxyvthxxLDSTE1L6ACqrhMeZqoxkHJ4fCzbkseiVQfJL3XhVv19IdrE2OjXrgXju8VXWd4Wzi03Dm7H/37IPKP3Ws0yd4xMEeeHcNaWbcnl31/s5q07hjSZGWpBEGqm6zpvrD0UkDETSnsNn6bzwaYcHr28W6NX9wvVY5d35/pX1jJnZSb3jE1t7OEIQtCaTTDlUzWe+noPb6z1z9qc2hjOq/rIyCvj4Q+388hHO7j9go7MurCz4czGwlUH8BqsGgSzhA6QW+xkR25pQEDl9qn8c/ku3k3PMaxgc7jIwYPvbcEky9w7NpXpIzsGlSIoNC9xEWG8fFN/Zi7eGFKvjTCTzMD2sdw1ulM9jk44HyzffoS/f76LxbcPITVerIgLQnNR7vYZVr8Lpb0G+FtsFJa7advM9t9aTDJzpgzgipdW0bttNCM7t2rsIQlCUJrFFLjTozJlwXpeW3MIh0ettsN2pUelwu3jlR8zue319ICc4wq3j0+25HF6rYkTS+ixE2ZiTxuObLEiKSbsqUMCigN4fBoLVx2s8liZy8t1c9aydEM2Tq9abSnQSrdKqdPLM9/s5b53NqOG0DhTaD7GpsXz3PX9sJplw75Tp7OZFYamxLHgloEoIrVBOAtfZ+Tz52U7eO3WQaS1FoGUIDQnFS4fJoMqRqe21wiGjs5fP93J1IXrmTJ/Hfe9vZnPtx3Bq4bWTLcxtI628vwN/bh/6VZyih2NPRxBCEqTD6ZUTWfG4p/Zkl0S9Ey/06uxLrOIWe9sRj+lf8rW7BLDFKpQltBVXWfl3sKTP3t8GtMWbWD30XLDfTLG41P5dmcBj368o8r4hHPHxJ6t+XDmCMZ3SyDMJGPRqvaOktCxWxSSYmw8cllXFk0bJEqiC2flh90FPPzRdl6dNpgebYLf0ykIQtNgtyj4tMD7iFPbawTD5dX4ZudRftp3jNWZRXyyNY8/vL+V/n/7hie/3E1xpaeuh16nhnWKY8aoFGYu3oTLG1wAKQiNqcmn+X24KYf0g8WGVfdq2uPk8mms3FvIlzvyuaSXv4peqdMLBAYvoS6hV7p/vTF+Y+0hdh4pM+wFUdP4nF6VjzfnMql3IsNTWwa8V2j+ureJYv7NAzmaV8j/7vs3+y+6hlKXisVZSVxuJnfceQODO8aKdE/BkKZrbDy6kZzyHBw+B+HmcFKiU+jVslfAOfPj3kJ+/95WFtwykF5tRSAlCM1RlNWMSZbxqlUDiFDbaxg5kdGz4KcDvL8xh6UzhtGxZdMtTDN9ZEe25JTwl2UZ/Ofa3o09HEGoUZMPpuauzMRpMDMRzB4nh0dlzsrMk8GUP4Mq8Mb11CX0YAIq+ZcbGU3TeeXHA4YrZsGMz+lVmfdjpgimznH2XduYHlVCuzv9VZh8hYUcmPRXOne8WwRSQoBSdykf7fuI13e+jsPrQEdH0zTkX9J/4qxx3NbzNi5LuQy72c6a/ceYvXQLr0wdQL92LRp59IIgnClZlrhuYFveXp+F97RtAMG216iNR9U5VuHmNy+v5otZI0mMttX+pkYgSRJPXtObK/+3mnc2ZHHD4HaNPSRBqFaTDqa2ZpeQV+IKeDyUMqF788vZsTKdNod3o27PQqMbKFUb5p66hB7e9YJaxxVp9f+1rc48RoXbF/B8KONbd+A4+aUuWkdba/1coXmqXL8e+5Bf/92Vlv7gWT12DFMrscFW+FV6fjr3fncvmq7hUk+79v0yZ5NTkcNTPz/FC5tfYFb3p/jnsjJentKfgR1iG37AgiDUqVtHdGRpejYY7KkOpr3GqWrKjil3+Zi2KJ0vZ49sspN64WEm5k0dwOS5a+mWGEWf5JjGHtJ56UjFEd7Z8w6bCzZT4akgTAkjOTKZyWmTGZAwoMmeP5fJiukAACAASURBVA2pSQdTH27Oxe0LXJUKZY+T1+PlncVfcU+Sj8GDe6Nst4Gn6kpSKEvoZkXiyr5tAHj352zDYhOhjA/81bduu6BjUK8VmgevquH2aYRbFBwb0kl8/C8nn5MkibDOnXHv2yeCKeGk1bmrmf3D7MAgyoDD58Dhc/DXjXfz8KQXGZoS1wAjFAShvnVsGU6/djFsPFyM9/RqWSGoLTtG1XSyix1syiphQPumu6LdqVUE//hNL+5esolP7h1BXERY7W8S6sTWwq28tPklNh3dhI6OV/OefC6jKIMVOSuICYtheq/pXNvlWmSpyZdhqDdNOpjKK3YYTc6EtMdJlRUcEy4ncXJfAKZG7mbBqoMBe5yCXUKXJYlbhncA4IjBqlmo43P7NI6W1X7zJDR9BworeHX1IT7cnIPToyJLEqqmk9z2Cu7ztOByr3qyyMSJYCp8+PBajiqcDw6WHuT+FfcHFUhVIXt4addDTEhbRkubSBcWhHPBy1MGcMnzP3KswnNGVX+DzY5xelXm/3iAAVMH1Mm468vEnq3ZmlPCfe9s5o3bhoiqtw1g2f5l/H3d36v9TtLRcfqcOH1Onkp/ih9zfuTpMU8TppyfwW6TDiPd1ZTxDLVM6KnFK6YOa2+wa8ovosdYEm95jnYPfEDyvYuJv+5xrG27nXxelqBfu5iTvRs8dTQ+l8Hqm9B8HCl1cu2cNVzy/E+8vSGLSreKpvubJ+pAVkQ8f/lsF/3/9g0vr9iPruuEdemMa9++xh660ETM2zoPt+oOeLz4p2L2PbqPjDsz2H3fbvJez0OtrHq9cPlcvLP7nYYaqiAI9Sw23MLH94wgKcZGmCn027Rgs2N0Hb7fU0CZy1vj65qCByd0AeCpr/c08kjOfV8e/LLGQOp0TtXJuiPreGDFA2h60y+/Xx+adDAVF24xfDzUMqEtTzlOYrSN2eM7YzuDMtThFhP/vvrXqjIxdvNZj0+WIC78/IzkzwX7C8q55Pmf2Jxdgtun4atmFrHS4+8/9uJ3+3nw3a2YU1Nxi2BKAMo8ZXyT9U3Al9CxL46R/14+rSe3pvvL3Ul5LAVPkYdDTx1CO2WCyKN5eHv321VSMARBaN4So218ft8FzBidQozNTLgl+HuWULJjzIpEQTPIjjEpMi/c0I9PtuTx5Y78gOc9Po3Ccjf5pS6c1fT6FGp3pOIIj61+zDCQqmlyz626Sc9P561dbzX0kJuEJp3md0FqK77KOBqwLymUPU7hYQrDOlVNf7lrdCeKKj0sWZdlWCnwdJLkD6TenD6EDqeUEh3VuRXpB4sDjhHK+KxmhYEdmm6+slC9gjIXk+eto9ThNSi4b8zpVfliRz4tzAlM3rcfXdOQDJo0CuePZfuXIZ82r6U6VQo+LiDp9iQie/ub71paWUi+O5m9D+2ldE0pLUb9et3waT5WZq9kfPvxDTp2QRDqT6TVzAMT0rhvXGe+3VXAp1vzKKxwo+s6kVYzP+4tNJzAC6VCsYxEhbt5BB9xEWG8PKU/t72WTueECFJahrP2QBGvrDzAqv3HUGQJSQKvT6dDSzt3je7EpD5tRA/HELy9+21UPfB8OPbFMQq/KKTt9LZEdI/AW+wl7808Dj11iI6PdEQ2yTh9ThbtWMRvu/32vNs/1aSDqct6J/LYsh2GzwW7x8ksy4zvFl/lMUmSePSy7nSIC+fJL3ejqhqVBuXNFVnCrEh0ahXBizf2I6VVRJXnrxuYzH+/Ml5yDnZ80TYzw8Tm8Wbpb5/tpNTpMQykausxtmRzPkPiO9Ix7wiWtkkNPnah6UjPTw+YBXTsc6B5NaIGRFV5XLEqRPaOpCKjokow5fA52Fa4TQRTgnAOMikyE3u2ZmLP1icfc3lVuv/5S8PXh1KhWNN1IsLq7lZQ03TketzT1Cc5ht9fnMa0RRvwaTqlTi9Oj4oOVQLLzMJKHv8kg798ksFDF6dx6whR5Ks2HtXDu3vfDchyCGVyr9Jbybq8dQxPOr/2gzfpYMpqVrh+UDKL1x02rGpTW5nQMJPMzcPbY1KMI+Sbhrbn+kHJfPTGchZklLLfHIMkSei6jsUkM6l3G24f2ZGuraMM3x9tMzOxZ2s+23oEVQ99fDazzIxRKaKsZDNU4vDw9c6jGG2bC6bHmNensazrWIbt2yuCqfNcqbs04DG1QsUUYUJSAq8NpmgTzsPOgMeLXEX1Mj5BEJoeq1mhhd1CUaUn4LlQsmN8mk7iWbRmOV7pYWl6Fq+vOcyxCjc+TcesSLSLtTNjdCcm9W6DLYQUxWAkRlvJK3XVWpzjRKPiJ7/cQ9ZxB3++vLu436rBT7k/YTQ7HMrknsPnYMnuJSKYampmjOrEBxtz8KqB/ZxqIuHvUXDLsA41vs6syIzY/j0Thg0j5tpLcXhUzIqMJchNn7PHd+GbnYGpiLWRJYi2Wbh2YHJI7xOahqXp2Rhdk4OtoqTq8K21Hcf3ZBI5Nvi+IcK5x2oKvJFRIhR8FT50VQ8IqHylPkwRgZduu8leb2MUBKHpmTaiAy99v79Kka0TgsmOUVC5VPmZ8J82QN8p0DI16M+udPv404fb+TIjH1kC1ynZPV5VP7ky9PgnGdw2ogMPTEirkxWrbTklzFy8KaQqh06vyjsbsomPDGPmmOB/x/PNkYojhntvQ53cyy7PrrcxNlVNMpjSf1nlkSSJ1tFWFk8fwg2vrAs6YDmxx+ntO4bW2pNA93ioWLWahD/9CUmSCA9xubtjy3AW3jKI215LD2r/FYAiSURaTbw7Y1idLq8LDefDTblVvjxOCKXHmEmWWJ1ZRPv6GKDQbCRFJCFLcpUCFPZUO5JJomxjGdGDo08+rrpUyreVk3BtQpVjWGQLSRFihVMQzic3Dm7HS9/vr/b52rJjLGYz06+9AvLeh1cvgdiO0Pe30ONqsBpn5IB/NerauWvILXYGtJk51Yl7toWrDpFxpIxXpg7EXE2mULAeem9btfdataXXP/ftPq4Z0Jb4yDNfiTuXOX1OfFrgwkWok3suX9MvaFLXmsSdvK7rbDh4nHk/HmDdgSKcXtW/smQxMbFna24f2ZH3p/bmt3NW4bFH4PBVPyMRblGItJp4646hAXucjFSmpxPWseNZNU8d1imOt+4Ywo2vrMOratTUZy/cohAXEcZbdww5WWJdaH6KHYGpFRBaFSUvEu+7Ylj/zmZMskRitI1JfdqQ1jqyrocrNGHXdL6GTzM/rbJvSrErxF8VT97iPGSrXGXDrznWTMzwmIDjTOw4sSGHLQhCI2sZEcalvRL5YvsRXDUENUbMikS3xCh69OoHvfrB+Mdh/7ewZQl8/WfocjH0mwIdRsEpRZJcXpUb568ju8iBN8jVIadXZW1mEQ+8u4UXbuh3xql2O3JLOXy80vC5YNLrAd5en8Ws8V3O6PPPdRGWCMyKOaBNR6iTexHm2u+9zzWNHkz9uLeQP324nWKH5+QmQvCnbZa7fXy4OZdPt+WR7C3j2chsyq/+LXNXZHLgWAVmRUbX/StRXlUjLSGSmWM6cWG3hKBnPyq++56IceOCem1BmYsvduRTUO7G7VWJjbAwqEMsA9u3oE2MDatZZtb4zry/MYe8EhcmWUJDR5YkPD6N3m2juWt0J8akxYumc82cZrBHDkKrouTRYFN4WzZtyfO/V4IFqw6Q0iqCmaM7cVmvxHrdyCs0DT1a9iAxPJGDZQerPN7q0lYo4Qr5S/PxFHiQbTJR/aNInpGMbK56fRvYeiCtw1sjCML55V9X92JPfjn7C8rx1DSTewqTLBEXbmHBLYN+fVAxQ9ol/j+VRbD9Pfj6UXCWQJ8boe+NEJvC/J8OcPhYpWEgVdPKkMur8d2uAlbuLWRMWnzAe4OxcNVBvAaT6cGm17t9Gq+uOcS94zqLezADnWM6G1bhC2VyT5EUesT1CDjGua5Rg6ml6Vn85ZMMw3SpE1RNR9V09unh3KP1YEGEheWzRpJZWEFWkYNyt49Iq4mOceFVypYHQ9d1yn/4gXbzX6nxdWszi3jlx0xWZxYh8WsTYEWCMLNCbLiFluH+/U8zx6Qyc0wqu/PLyDnuxOFViQwzkRofQXKsWIk6V0RaTRyrCFydCqWK0ulUHVSvxs68Mv7vg218sCmHuTcNEGVdzwPTe0/n7+v+jtNXNfc8dnQssaNja3yvzWTj9p631+fwBEFooqxmhXfvGsatr25gR25ZrdsNrGaZNjE23rljKLHV9PIkPA6G3uX/c2QbbHkLFkxAbdmVRYfuwuULvOEOZmXI4VGZt/LAGQdT3+48aljsK5T0eq+qkZFXSu+2gav757sBCQOItkQHfA9B8JN7ZtnMTd1vashhNwmNFkx9s/NorYFUFZKEw6sx/fWfef+u4XRvE0WnINL4Tsg+7mDxusNsyS6h3OXDblFoZ/ZyYXQSqZ06Gb5H1XQe+3gHH23OxeVVA4qcqLr/4uDwOMkpdlJQ4Wba8A60bWGna+uoaqsACs3f+G4JvLbmUECVyVCqKNXE4fGnRdyyaANLpg+ptiKlcG6YlDKJ77O+Z3Xu6qC7zgNYFSvXdL6GwYmD63F0giA0ZRFh/j3in2zNY+7KTLKPO3H7VE5dPLJb/NX/7hyVwuSBycFX2Evs7f8z4Ql++P4rPPt9QNUgLNiVIYBNWcVkH3eEPLms6zqVHuNCZKGk18uSRLFDNDg3IkkS03pO47mNzxl+DwUzudcuqh1psWn1NcQmq1GCKZdXZfY7mw0DqZqWicF/k3nf25v49sExQX3WhoPHeeabPWzOKkHT9So3v5vQWd51Ms88s5LfjevMlX3bnMzl1XWdB5Zu4eud+TiDDPiOlrqZ9OIqls8aSWK0Laj3CM3TLcM78MbawxjVEQ22xxjUfL67fRrbckr41xe7eezy7g3wWwmNRZIknhz1JA+seIBVOetQcdf6Hqti5fJOl/PQoIcaYISCIDRlJkXm6v5tubp/W3bklvJVRj5Hy/zlw1tFhjE2LZ7BHWPPvDS4ycL7R9tQqeUHPBXKypAOfJWRz01D21Pq9P76x+Gt+rPTS9kp/13i8FDdFq1Q0uvh1yJnQqArO13JnK1zQprUO8GqWJnVf1Y9jKrpa5Rg6rNtRwwfD3YDYW6Ji63ZJfRJrnmZ9s21h/jH8l3Vrn5pSLiQyCys5E8fbuenfYX855remBSZ+T8d4OudR4MOpABUXafM5eO389fz3QOjxX6Xc1jbFnb6t2vB2gPGvX1qq6IEwZ3vTq/GkvWHeWBCl5ArTQrNi0WxcH27P7M243ki43/Eo3pw+BxVXiMhYTPZiLJEcXffu7kq9SrRN0UQhCp6JkXTMym69heG6Gi58Q12KCtDHp/GP5fv4smv9hBtMxv+ibKZSY61Bzz+m5dXG1Z1DiW9Xtchxl5NeqNAhCWChRct5ObPbsShef19hoJgVaxM7z2dUW1H1e8Am6hGuTubu2L/yWZqJ4SyTOz2qSz46QAv/rZ/tZ+xND2rxkDqdE6vyvLt/iDvX1f34n8/ZJ5R+U1V0ykoc7FyXyFjzzAvWGgeHru8O9fMWRN0SfxThXK+y5LEx1tymTJEFFE/l1W6fTz80Q6eumIWo7o8warcVby5802yy7Nx+pzYzXZSY1K5pcctDEwYKIIoQRAalK+aAhehrgzdNLQ9T1zZM+TPH5PWii935AesUIWSXi9J0D1RbMGoSVrBfl4vcjA9IRaX5gmo7ncqRVIwy2Z+1+933Nzj5gYcZdPS4MFUbomT7OLAzW2hLBNrOnyZkY+u64Y3FPsLKqrdj1VzHwKN5dvzibKa8KnGQVgwqwmVHpV5KzNFMHWO694mirlTB3DXmz+HtIIJoZ3vDo/K/B8PiGDqHPffr/YwqH0sY7v6rxujk0czOnl0I49KEATBr4XdbPh4KCtDsgQJUWfW5+mOkSn8sLvQcAIzmPR6iyIxdWh7LCaxB7la2enwye/oOuU9Pm+Zygf7PuCNjDdw+py4VTeqrp7MkFB1lYkdJnJzj5vp0uL8Ljff4MFUUYUbiyIHdOwOZZkYQNP8HbeNNlEuXHUQr0EwFFxalcpbG7INO4qHspqwOauE3BInSTFi79S5bHSXViyePpTpr6fjUTUq3cGtUoV6vueWBE5ACOeOnw8dZ/n2I3x9//mZIiEIQtM3vlsC6YeKA4KZUFaGwswKQ1NqLmJQnb7JMbSOtnLwmHGvqdrS6yXVw1TrKtA6Q5DfveeVokx457dw1RxIGkA0cFvP25jWYxrr8tax8/hOytxl2Mw2WttbM6H9BCIs519PKSMNHkx5Vd0wBzPUZWJZBq+mYaPqayvdPj7anMPpsVRoaYTGqwyhrCZYTDL7CypEMHUeGNC+BRseGc+3O48yZ2Ume/JKMfk8yHY7XlUzXCEN9Xz3qjqapot9eOcgl1flD+9v44kre4hcfkEQmqzf9E/in1/sMnwu2MJL8RFh9G/X4ow+X5Ik/nNNb25etD74StC/sJkVpvWNIfHgfNjzJlz+LCRVv1XkvFNRCIuvhnGP+Bs2n0KWZIYnDWd40vBq3iw0eDAVYzejGpRkCbU/j0/TibAEDv/LHfnIBql/oQRC1QllNUHXdSpcxmU8hXOPWZG5pFcil/RKZO97n5C1bitRt9xHVlElj3+aQcVpK1ahnu9mRRKB1DnquW/30S0xiok9Ext7KIIgCNWKtJqZ1LsNH27KwWj7VG0rQzazwozRKWe133Nwx1j+e20fHnp/a9ABlc2scHGPBP5wdV/gM9j6Drx9A3SbBOMeA9t53nPKUwlvTYZek2HAtMYeTbPU4Imj7WPtWAx65py6TOzYuxbN60JXfTgzf6b4h0UBr++VFG14c5lT7MBpUO0l1LQqI6euJtRGQsIebB8H4ZwSW5RH7zaRDO4Yy4jOLQN6UUHo5/uZ5pgLTdu2nBLe35jN41ecfx3jBUFofmaN74z9DCrLKrJEQlQYv+nX9qzHMKlPGxbeMogYm5nwsOrvs2xmmTCTzB0jU3j2+r7+IE6SoO+NcPc60FT43xDY9p6/zN85xuHxsSe/nJ8PHWdHbilFFQaFJFQfvHcrxHeDsQ83/CDPEQ2+MmVSZG4e3p55Kw8EpNMFu0wcHqZw9xjjRrvlbp9B55/Q06qMhLKa4NN0klqIFL/zkTc3D2v3bgAkRtvo3Taa9EPFAa8L9ny3mRVuG9GxQcYuNByPT+MP72/j0cu60yoyrLGHIwiCUKu2Ley8cdtgpixYb1im3IhJlmgRbmHpjGHBNwuuxYjUlvz86Hi+3eVPr9+ZV4ZFkZEkf1p8tM3MHSNTuG5gW+P0aXssTHoO+k6Bz+6HzW/CZU9Dy851Mr7GtCe/nAU/HeDTbXmYZP/fia6DR9UY2L4FM0Z3YmRqS2QJ+PwB0Lww6Xl/oCmckUYpjX7TkPbMXXnA8Llg+vOYZJnx3RIMn4uxmZElAkpnhppWZXSMUDZZtomx0iUhstbPEc493iNHiBx/4cmfZ47pxM63Nge0A4DgzndN17l24NnP5glNy8sr9pMUY+PKvm0aeyiCIAhB69euBR/ePZybFqzH6VWrLbwkATaLQoe4cN64fTAtI+p20sikyEzsmcjEnomUODwUVXrw/RJIxUeGBZcanzwI7lwBG16BhRfBoOkw8gEwN7/J8HKXl7ve3MjGrGK8qv7LlpqqixZrMovYml1CtM3MG713kJq3CW79AhTjSo1CcBo8mFI1nfgoK1OHtuet9Vkh9+ixmWUeu7w7JoNUQYAuCZHYLErA/9yhBEJmRUKrptV2MKsJ4RaFu0Ybr5wJ5z5vXh6mxF/3v4zuEk+0zYzDq4acSWA1y1zVN4koq7jQNRdun0qZ04fFJBMZZjL8Qt91pIw31h5m+X0jRb8oQRCana6to1j3pwv5dlcB81ZmsvNIWZWS4x6fxsjOLZkxuhMD27eo9+tcjN1y5gV8FBMMuxt6XAVf/hFeHgqXPg2dx9ftIOtRqcPLlf9bRV6pC081RdROqPSoODw+rlzVlrenvU7vMDHxf7bqPZhyeHx8vDmX+T8dJPu4A5+mo8gS8REW2sXayDruCLpHj82scOeoFK4dUP0s/biu8ZhlGTizPgTgb5J6Se9EvsrIN6zsV9tqgqJITOojZpvPR7qu4z1yBHObpJOPKbLE4ulDuOKl1VS4gy9KYjHJpMZHiP00zUCl28fHW3KZt/IAOcUOzIqM9kvkfGmvRO4YmULPpGgAfKo/ve//JqbROlrshRMEoXnyrwy1ZmLP1uSWOMktduLw+Ii0mugQF05cHa9E1buoNjD5Ddj3LSx/EBL7wMR/+x9vwryqxk0L15Fb4jTco21ER6JSt3LT2/tZPiuRti3s9TzKc1u9BVM+VeM/X+5m8bosJIkqubWqpnOkzE2xw4PHpyPrGiaTgqeak8BqltF1ePjSrkwd1qHGz61pTxYEl1bVu200T17bm30FFewvKA/65AR/wPfarYOxmkXxifORWlyMZLGgRIRXeTylVQTvzxzGja+so8Ltq/WcspkVureJ4rVbB4lzqQnTdZ0Xv9/PnBWZVa5zp157Ptt6hK8yjtI+zs7cmwbw5Y58om1mJg9MbqxhC4Ig1KmkGNu50wqm83h/gYqfnoE5I2DUQzD4Tv8KVhP05Y58MgsrDe8rKneuoCz9Y7xFOcgWG+b4FKKHTz7ZW7XC7ePZb/by9OS+DT3sc4qk15B3NHDgQP3nn38O+aBun8qtr6azOas4qFUniwThNjMeVUNCOrkHTtN0rGaFXklRHK/0UOFWkWWJlhEWrh2QzOW9Ew1vNI8WljLuqRVUSqGf+FazzKvTBjOsUxylDi9TF65nb0F5rSU4Zcl/Azz/loEM79Qy5M8VmqcKt4+PNuWweF0WhRVuvB4vVkc5QwZ05o6RKfRJrlpytaDcxYIfD/LWhix0Xa+yj+pEfnlsuIUZo1K4YXA7zNWks54NSZI26ro+sM4P3MDO9PpUVzRN5/53t/B1xtGg0pWlX64RigTLZ40iOVbMBArC6c6F61NjX5uEOnRsn79Ig7MYLnvWv8eqibn0hZ/YmVcW8HjZho8oXf8+cRfdg7VjfyTFhPPgRtzZGVUbKZtk0h8dL7YT1KKma1Odh9mapnPvks1sOlyMq5a8zRM8OuhuH2kJkfzlih44PSqF5S4+2pzH+oNFrDt4vEows78AtueU8udlO5g8MJn7x3ch2u4/CTzZ2VT+7j6e6jyA+60DQmrsZjMrPHRxGsM6xQEQbTfz3sxhLFp1kAU/HcTlC9xoaTPLaDpc1COB+8d3IaWV6AZ9Pih3efnH57v4eEsusiRVWXktN0ewfPsRvttVQGK0lUcv78a4rv6CKfGRVh6+rBsPXtyFL7bn8+ncd3H16IMlMoLW0Tau7p/UIPnlwtn7++c7gw6kwF9NyeFRsZrkc7EKryAIwrmnZWe4+RPY/j4svQnSJsKFf/FXA2wC9uSXc6CwIuBxzV1JyaolxF06G3var8127alDsKcOqfJaWZL4cGMO00TV4DNW58HU1zuPsjrzmGEgVdNyo1fV2V9YwdbsEvq1i+HuJTtxeHxoOsZLl7/cvC5Zf5hvduazdMYwYrb/TN6fHqblXXcx8aYptDpczLRX0/GqmmHK3wmKBGaTzKOXd2PKkPZVngszKcwck8qMUZ1Yua+QpRuyyS9z4VU1om1mxqS1YvLA5DPf+Cg0OwVlLibPW0teibPa1FRNB6dX5cCxSu5esonfX5TG9JEpJ58PMylc2SeRrutep8uza5DtYpWiOdlfUM5bG7IMJ2tqS6vwqBp//mQHr906OPgPVH0gySA3eGtAQRCE85skQe/roPME+P7v/gIV4/8KfW4IqZy4rus4PCqyJGE1y3UyafrTvkI0g9tbd+5udJ8He5dhtR7D6VX5cke+CKbOQp0HU3NXZhr2HqhuudG5b/3JmwyXV+PlFZk4Pb6gi1J4VZ38Uhe/eeob5qydQ/cXnsc+YAAAAzvE8v3vR7N43WFeX3MYn6bhcKsn+1DZLQqarnNFnzZMH5lSYylzWZYYmxbP2LT40P5ChHNKucvL5HlryS52oAa56Onyajz99V4irSauH9Tu5OO+wmPIkZEikGqGFq46hNfgBAjmOqfp/vK0R8tc1Tdj1jTI/A5WPwfZG0D1+h+3hPsrTg27199kURAEQWgYthi47Cl/09/P7ofNi+HyZ6BVWrVv0TSdlfsKmbcyk/SDxSd7PkkSjO0az4xRKQw4i2yUUocXj8F3keosQ7ZHBd1XtdjhPaPPF/zqNJjKLKxg15HAvM1QlhuLKz2GTXdrmu1VdSj2STx/3SO8+UsgdUJ8pJUHJqRx37jOfLe7gD355ZQ4PERaTSS1sHNpr0QizqCbt3B++ufy3eSVOA0DqZrOUadX5c/LMhjdJf5kBTdvTjaWtqJ/VHNT6fbx0eacgHMglOucBCxed5gHLzL4Et7xEXzxEHid4DktfcNTAVvehu0f+L/Ar36lxi9yQRAEoY4lDYA7foD0BfDqJdD/Fn+RCkvVidFvdh7lTx9uw+lRf90ffeIGV4dvdx1l9f5jtIwI49nr+zCgfeipg9UlKyi2KDRHGbqmBhVQiaSHs1OnUcQX24/80iSsqlCWG40CqWBme32SwvrcCnJLnIYVZUyKzMU9WnNxj9Yh/16CAL8Um9icY5jaF8w5CvDmukM8dHFXwL+/z5wsKro1N+sPFmGSZU5vhhjKdc7t01i2JS8wmPrpGVj5JPic1b9ZV/3PH9kK88fBTR9CuyHVv14QBEGoW7ICQ2ZA9yvhq4fh5SFwyX/9e6qA19cc4l9f7Kpx3/6JfbRZxx1MWbCeF27ox0Uh3qNGWs2YFSlgO0xYUlckkxnH3rWEd72g1uPUdUPl802dxqL5pS58BsFUqMuNpzox2xs7YSb2tOHIFiuSYsKeOqRKNRIAdJ031hw6w9ELQs0+3pyDbLAUH+w5glewmQAAIABJREFU6vZpvLn28Mn0MG92DpZksTLV3Byv9OIzSFIP9TpX5jotrWLTG/BjLYFUFbp/pWrx1VC4N8j3CIIgCHUmsjVcuwgmPe8Pqt6Zwudrt9caSJ3O5dW4753N/HzoeI2v03WdzMIKFq46yNSF63n66z2G991yWDgxF0zh+Ddzcexdi+Z1oas+nJk/U/zDoiqvDbcoNfZvFWpXp8GU1+AfFKouN4YqlNlej6rz9oaskD9DEIKxeF2W4X7AUM5RTddZd6AI8Kf5mduKlanmptTpMfySDPU6V6Win7MYlv+S2neKDs+VE//fcio9v754wSYPY16r/PVFnkr4aEZIv4MgCIJQhzqNg5lrcLXqw0PL9lZbnOjI67PJeuZacl6aytF3/4IrJ+Pk8/6AaguntyxyeHx8t+soj328g1H//YEp89ez72g5U4a0Y/0j4xn+SwXq00UNvpoW426ndO1Scl6cQs6caZRv+gxb59PuVSSJiT1F1tbZqNM0v+qWCUNdbjxVqLO95S4fqqajyKK0tFC3jlW4DR8P5RzVdCgo8x/Hk51DzLViNqi5eWu98YSNr/wYoJP1zLUo1oiAKn6nMysSLq/q75W3eXG1VaFUHZ5f7+HhkdWlYehQsNO/OtWqyxn8RoIgCMJZM1v5NPpGMG8Hb9WAKNitACUOD+sOFNEq0sqKPQWs3FvIpsPF9GobzZi0eObfPJC0hMgqBSvuGt2JzVklhpO9ET3GEtFjbLVDtigSNw5OJswUeuaY8Ks6DaZGdIrjtdUHqzQiharLjZKsYO3YD0k24Tq0BVfWtsB0vVOEvonOf4MSLopKCHXMqEQ/hHaOarp+svKOV+yZanb2F1SQUxyYhnfiizK8x1gcmRtpMd6/UuTYsxbd7Qy4zsmSvwhF/799Q6/ESBYef44Ir3F630PDLTy52s3dgyzEWKuZJNJ8sG4OTHr2rH9HQRAE4czMXZmJ47RAKpTiRA6Pyq2vptMi3MKYtFZMGdKel6f0J7KGhroXpLZkTForvt9dEFJqoSJJxEdZ+d24zkG/RzBWpxHHsE5xRFrNAcEU+Jcb5fAWlK5dyrHPnkKy2AhLSCVq2PU1HjPUVS1N17FbRIQt1L3wMIVSZ2D50FDOUZMsEWU1ozmdqKWlmOJFqf3m5NXVBwOK7Jz6RRne9QIqMn6gbP37eIuykSw2fMdzAq5zZkXmg5kjiIuwsGvLGixfVVKdgW0UxnQw8dQaN38fV10pdR9kfCCCKUEQhEaSX+oynGwLZSsA+HsRrvrDWBQluJ04kiTx3PX9uO21dDYePh5UayGzIhEbbmHpjGFE26oP1ITg1GkwJUkSd45K4b9f7Tb8x6xtudFIqKta7WPtddIITRBON6h9LJ+VHkE9LZ85lHPUq+r0bhuNNzcXc1ISkqhH2mxoms4Hm3ICNvue/kUZzHWuZ1I07eL8ZXQHtlTBbAZ39YUnnhgbxohFlcwaUkNzcHf5rw1MBEEQhAZVVOnGrMi4fVXvf0PdrqLIEhUelWhb8PcHFpPM67cN5p/Ld7Fk3WEkCcP7cLMiIUsSgzvG8vwN/YgNr+E7RQhanefCTR6UzKLVBzlS4gq46ayNWZFQNZ3T61gEu6pltyjcOSrlbH8FQTB0x6gUvt55FKf3zFde+ybHkBxrp3xrNmZRya9ZObEf83RnUq30rtGdfv1B89X6+p7xCpd3MfHvVR66taruCza0660gCIJQd6q75Q11u4qEhFZNQbeaKLLEY5d3Z/b4znywMYf5Px0kr9SJSfaXTreZFaYObc/UYe1JjrXXfkAhaHUeTEWEmVg6YxhXvLiKUqfXsGTj6STJHwj9+ze9eeiDrYY5n8HM9uo6XNUv6YzHLgg16ZkUTdsWNvYVVBg+X9s5Gm5RTt5Ee7NzsIhKfs2K06uiyIH9PEL9orSZlaq98GwtCCYQ+usYK/3nVfDgsGoKUZhsYlVKEAShkUTbzIZtM0LdruLVNKLOIvUu0mpm2oiOTBvREVXTqfT4+NfyXaQlRDJtRMczPq5QvXrJMUqKsbF81khSWoVjtyjU9PUeHqaQEGll2T0jmNS3Df3btSDMFPqwbGaFW0d0wG4RhSeE+vPIZd2wmkM/P02yhNWs8Mw3exjz1A9cu9vG/fTgyx35+NTgN4wKjSfSasJnUITk1C/KYEiS/1gnJfQAvfZzIDVW5voeZl7Y4DE6KrQfbvC4IAiC0BCSYmxEhgUGQaH0fALonRRdZxWplV/2aXdsGU7W8WB7GAqhqrcNGwlRVr6aPYqFtwxiTForLCaZSKvp5J8wk8zA9i14/vp+rP7jOFLjIwGYN3UAbWKsmJXgTySrWWZ4ahy/vyitvn4dQQBgTFo8f5zYNaSASgJ8mk6Zy8v23DIOHXNwQLOx1hHGg+9tYeDfv+XZb/biMkgfFJoOu0Ux3Kgb6hclQHzUKatLZhv0uwnk2mci/zw6rErPqZMsdhgxO+jfRRAEQahbsiwxfWRHw/uDYHs+hYcpzBzTKeD9Z6tdbDhZx6svdCScnXpdxpEkiWGd4hjWKY5jFW4OF1VS7vIRHmaibQsbidG2gPdEWs18fPcF3LxoPfsKKgzr5p88PmCzKFzcozX/vbY3sugtJTSAaSM6EmU186ePtgMEbDY93YlbX6PS6pVuFVCZtzKT73cfZfHtQ4m2i8o6TZEkSdx2QUde/G4frtP+zYPdM2dGZXLC0cCeHkPugo2vg1a1WuSh2ZFVfk6OlnE9GhU4OGsMdAith58gCIJQtyYPTOaZb/YaPhfMdhWTLDO+W0Kdj6tdrJ2s4446P67g12A5cS0jwqpt6nu6aLuZD2YO55udR5mzMpO9+eXogMenIUkQZlJQdZ0LUlsyY1QKgzvGigp+QoO6ekBbLujSkiXrsnh9zSF8moYOqC43ismEBwmfple7IfV0Lp/G7vxybpy/jg/vHu5v5Co0OTcObscL3+0zfC6YL0rZZOLWsnlwZAAk9vn1idgU6D0Ztr8H1fSbqo5usiJd9ozYLyUIgtDIWoRbuGdsKnNWZBoWq6qJ1SzzxBU9MAVZEj0U7eL8wZSu6+J+uR402Q1GJkXmkl6JXNIrkf0F5azeX0Sp04si+2vjX9gtnvjIanquCEIDiI+0cv+ELvxuXCobDh7naLmLgvc/wtYmkX/mhxuvRO1cQVn6x3iLcpAtNszxKUQPn4y1bQ+8qk5mYQX/+mIXf72iZyP8RkJtYsMtXN0/iY8254bUHBEgzCQzsnNL2veZBcvugTt+ANkEuZv8PaJcpWCJANUbVIU/AI8UxtsRt3N9ygTE1VAQBKHx/W5cKnklTpZtyQspoLqqbxJX1lMRtYgwE+EWE4XlbuKjxLdFXWuywdSpUuMjT+6pEoSmxqTIDE9tCUDB9x7eqXQjERHwurINH1G6/n3iLroHa8f+SIoJ58GNOPetx9q2B+BPGXw3PYc/TuyGTTSfbpKeuLIne3NL2J59HE8Q+5zA3wOkfZydF27sB+aB/hWo926Fwl1QlvfLalQIpXDNdkBHmfQS6TtSWLF4I3OnDghMHxQEQRAalCRJ/OvqXiRGW/nfikxkCEgNP8FuUdB0nalD2/PR5lz2F5TX2/1uuzg7h487RDBVD0THUEGoQ3JUNG8WhwfMRmnuSkpWLSF2wkzsacORLVYkxYQ9dUhA42lJgk+35jXksIUQKKqPf214lYEWF/YgAl67RaFXUjQfzBzurzbqqQBnCez+FIr2g9dBzYGU5F+xCovyB1Ex7eCif8Dv96H0vpZnr++LxSRz71ub8YrKkIIgCI1OkiRmje/C2j+O474LOxMXbsFuUU4WYbOaZdq2sPGnS7ry86MTeOSy7vzfxK5MezWdgnJXvYypXaydrCKxb6o+NIuVKUFoLnaYWlCpB+Yju3N3o/s82LsMM3hXVQ6PyqLVB5k8SPShamp0XSf/scewx0Ty5uPXsXL/MeatPMCW7BJ0txuP4l+psphkZCCtdSR3je7EhO4J/jx4nxteuxwKdoX4yRJc8RLEd4OWnavsjzIrMi/e2J+Zizcy653NvHBDv3rJuRcEQRBCExcRxt1jU5kxuhOHiiopdXqRJYkWdjPtYu1V9i9dNzCZ3BInt7/2M0tnDK3zVj/tY/0rU0LdE8GUINShQlM46C5Ob66mOsuQ7VFBNXUFOFpWPzNTwtk59tL/cB88RPvXX0M2KYzrmsC4rgkc3J/De3/4D6a77kXX/ZuQR3VpRWr8aemenz8IhbtBdZ98qMNz5Ti8cHBWBOEW/4mzYJOHxdu8rJgWDuj+1asf/gH3rDcsNGExyfxvSn/ueONnHnxvK89M7mvcp0RT/XuzVA9Yo/1l2QVBEIR6pcgSnVoFpv+fbtaFnckpdvK7tzYzb+qAOp0YS461syazqM6OJ/xKTF8KQh1yh9nRDDK2FFsUmqMMXQtuM6qnlnLrQt3TdR29hvKLJR9/TOmyZSS//D9kW9UgJL4gi+vDS3ngojQevDiN2y7oGBhIOY7790r5AgNlVYfn1xs14z0xOBVKc+Dw6mpfYjUrzL95IAVlbv7vg21op56IuRv9e7T+ngBPp8HzfeCfSfBcb0hfBO6K6j9bEARBaBAn9lt5VI3HP82o8TspVJ1khejsCio3HsW54xjefNF3qq6IlSlBqEPRMeHI+rGAx8OSuiKZzDj2riW8a+39gETxifqn6zpbskuY/+MBVuwtPLnPLdxiYny3eKaPTKFnUjQAles3UPDfp2j/xuuYWrYMOJZrz17CunSp+QO3LKm2fPlDwy08udrN3YMsxFirKVvrdcCaF2vsJ2U1KyycNpBbFm3g0WU7+McFFqSlU6E02x/E6acF6SWH4etH4auHYcQsGPNHUWJdEAShEZkVmZen9Oe6uWuZ9+MB7hp95k18dZ+Gc8cxylZkE3/MxY0+jZJl+/3XeU1HaWElcnRb7L1bIRk0GxaCI4IpQahD3dq1xEd2wONyWDgxF0zh+DdzkWQFa8d+SLIJ16EtuLK2BRSh6NEmuqGGfF7acPA4f/hgK0dL3bh9apXVxAq3j0+3HuGrjKMkx9r4x7A4oh98gKSnnyask/GXmnvvXmx9+9b8oeterraH1MA2CmM6mHhqjZu/j6uu0pIOmT9AZRGEx1X7MXaLiUXTBvHE3DfwbH8Ui+ZEqqnAhfeX2ck1L/gLYlw9H2TxpSoIgtBYIq1mXr11ENe8vIY2MTau6NMm5GN4j1ZSOH87ukdF9/gn0sKRTv43gK/AQcmy/ZR+foCW03thaVN7KqIQSHxjCkId8Pg0lm3J5YlVeUinz/7/Imrw1bQYdzula5eS8+IUcuZMo3zTZ9g6Vy1KEW5RmDEqpSGGfV76bFseNy9az6FjDpxe1TAtU9V1nF6VvUcruOmD/eyZ/hDhQ4dUe0z33r1Y02pYmdJ1KD9S47ieGBvGixs8FFbWkOJpCvOvJtUisjKLJ51/IUxz1BxIncrrgD3L4etHgnu9IAiCUG8So20snDaIv36SwfoDoe118uRVUPDyVrQKb5XgyYju0dAcPgrnbsWTXX42Qz5viZUpQTgLZS4v//t+P0vWZ6GjU+lW/Te81YjoMZaIHmNrPGak1cywTtWvPAhnbvX+Y/z+va0hNdx1K2YeOiCTlFVM/3YtAp7XVRV3ZiaW1M7VH8TrBEn2732qRs94hcu7mPj3Kg/dWtUwz+UJYn/T8t8jearmw9de6AJ/QPXzIhhwK7SqJW1REARBqFfdEqN47oa+3PPWJt65c1jgXlwDarnHvyLlDr5hMPiDqsKF20m4fwCm6OrvY4RAYmVKEM5QXomTy57/iVfXHKLC7fMHUmfJZlZ46OK0KuVShbrhUzXueWuTYSBVuXMFR16fTdYz15Lz0lSOvvsXXDkZJ593eTXuXrLJcDOw53AWplatUCLCq/9wkzVwv5KBv46xMn+Th9yyGlaTLLV8mZbmwqHVhp9Xa6ELAM3nT0kUBEEQGt3Izq34v4ldufW1DRSWu2t9fcXqXHSv8f3IhpxtXPXmTLo/ewk9n7+M3yy+my1Hfm3VoXs1ylcGblUQaiaCKUE4A8WVHq6es4a8UledVd6zmRV+O6Qd1wxoWyfHE6r6dlcBXoN/q7INH3H8u/lED51M23sXkzTzVSL7X4pz3/qqr3N6WWtQVta9d0/txSdkGSJa1zrG1FiZ63uYeWFDNQGPzw0x7Ws+SPqCap96aLiFp9a4KXHVEKxpPtj2DnhEpSdBEISm4LqByVzdry23v56Ow+Or9nW6qlGx7gj4Aq/x5e5Kbn3/j0wbcA3bZ31G+t0fMnvENMIUy68vUnUcPx+tNhgTjIk0P0E4A7Pe2UJRhRvVYMNN5c4VlKV/jLcoB9liwxyfQvTwyVjb9jA8liJJmE0Sd45KYfb4GlLFhLMyd2UmlZ6qXxCau5KSVUuIu3Q29rThJx+3pw7Bnlp1j5TDozLvx0z6t2/BD7sLyC1x4vSosOkonTv0pq2u17yiOGQmrPgX+IyLUJzw59FhvLnNa/CMBB1H1Vh8AoAdH1TpY3Wq4ApdALLJv7rV5aKaP0sQBEFoELPH196DyplRRHXbZA8c9684XdV9PAA2WWF0x8GGr3VsO0b4gIS6Gfh5QARTghCi7OMO1h8swqsGXrHKNnxE6fr3ibvoHqwd+yMpJpwHN+Lct75KMCVL/jLWqqZzaa9Epo/sKCr41aNSh5eMvNKAx925u9F9Huxdhhm8K9CP+47R/29fI0sSbp+GqumYtBaYTArR//6eO0emcM3AtkRZzYFv7j8VVvwz4OFDsyOr/JwcLeN6NCrw/Wa7v3x5bVyBv+epnhgbxohFlcwaYqn+RboGzuO1f5YgCILQIE70oLrttXQe/zSDv13ZM2ACz51ZUu1eqZTYZGRJ5v7P/8EVXS+kX1IPYqyRAa/TPRru/SUimAqBSPMThBC9ufaw4d6ZE6scsRNmYk8bjmyxIikm7KlDAkqft4u187cre7LhkfE8e31fEUjVs6JKNxaDWTzVWYZsj0KSg+vrpevg8GhUuFW8qo6mg0dScKhwpNTFk1/tYdSTP7AzryzwzfZY6HmNf/9UiHzIuMNb19hj6pRR1vjsqYUuaj6MaBwtCILQlFhMMi/f1J+fDxXzyo8HAp5XK42yGvwiw8L5cMpLgMQfvvwvfV+4gls/+COFlYETZzUdRwgkgilBCNHb6Vl4DFalQlnlyCt1MbZrPNE2gxUMoc6pmg4GGXiKLQrNUYau1U1+uNOrUuLwcu3cNcYB1WXPQMsuoNSwKnQ6SUYzR/Kbst/zZUZ+7a8PC5xpPF2thS4kGWyBlQsFQRCExhX1Sw+q19Yc4tOteVWek801Twx2btmBZy97mPR7PuDb21/jaEURj3/3YsDrRAPf0Ii/LUEIgcenUek23vwZyiqHRZE5Ulrz3hmh7kTbzPgMAuCwpK5IJjOOvWvr9PMcHpXfLlhHqeO02T2zFaZ9Dm36+tP2amMKg4h4LHf9wL9vu5QnPt3Js9/sRTNqjnVC18tBrjlIr7XQheqFdsGlPgqCIAgNKzHaxsJbBvH4JxlsOPjrypISExb0nX1qXHsm95zInsLTVrikX44jBE0EU4IQAqdXRZGNiwyEssohSf/P3n3HR1VmDRz/PfdOTS+kQQIk9N6rVBV7b6hYEBDUVVd31y3uvu67zfVddd1dXVGxIboq6toVO6CANKV3QiAQIAkJaTOTaff9YwIkzCSZSaGE8/18+EBm7tx5JiT3Pucp5xBIXiBOiJRYK4lRwbNBmjWahDFTKPniGRzbluH3uDB8Xpw7V1H6zYsNnrPxdOo+3li5J/iFtjiY+gmc/1dI7FwTVB33M2WJCcwMnXU/3LkMkrvQPzOB9+4+i+92FHPHq6uprCeoZ8QsCCOgf2i8lSp3iKBM6dD7crAnNHoOIYQQJ0fv9nE8MXkgd722mh2FgfqDUYNSA9ljQ9hxaDfPrniD/eWFABSUH+T9zV8xuH3d5FjKpMl+qQhJAgohIhBjNYWc4YC6sxzRPRve22IYgeK84sRQSnH72Gwe+3wbzuNSvsYNvwotOpGyZW9S/NFjKIsda1pX4kZNrvd84SQacXn8PP/dLm4fm4N2fACum2HoVBhyK+xdCRvehvL9gbTk0SnQ7TzofgHodS/RqbE2/nP7CH7//kauenoJc24ZSqfk4+pbJeUEZr72fF/n4bATXZgsMOruej+7EEKIU8O47in8sqYG1X/vPIuU1CjM6VF49gYXd4+2RLGmYDNzVs6nvLqSOGsM53YZxW8n3lXnOD3JhqV948WBxTESTAkRAV1TZCTYKDjsCnqu9iyH0nRs2YNQmglX3hpce9bVSULh9fvJTLSfyKaf8a4ZmsXfPtsa8rmYPhOJ6TMxrPNElE692suSncWM7ZYS+mRKQdbwwJ8wWU06f72qH/O+383Vs5fyj8mDGNOt3bH2+Q3W9v0dvfKvwWYE/5w2yGSHXpdCRv/IXieEEOKkuG5oFvtKnUyfu5I3Zo4kbkIWJW9uxjhulXlGbAqzr/hDg+dSZo24CVmt2Nq2SZb5CRGh28fkYLeEXkYVN/wqEs+eTtmyN9n75BT2zp5KxQ8fYe92bP+Jriku7d+eaKuMZZxI8XYzd03ogr2RDbqNiSTRSLXXz8ZQiSiaSSnFLaM68+QNg7nvzTW8+N0uDMOgsMLFBf9czJSPqri9+j4cRvjr3l1Y8WeNgMufbvH2CiGEaD33nduNbqmxPPjqQqzLrsdmLAEiHUxTWLsmYB9Qz+CfqJf05oSI0NVDM3lkwZZ6n29slsOsK6aPzW6NpolG3HtON/aUOPhk/YGg5X7hiiTRiNdvBCehaEGjuiTz7l2juf2VVazeXcr3uw5R5vDg9Rt8S39ucP+WOZbHiaKaGBX6xurTrGgKvo+axOLEX/KQLstPhRDidKKU4q+TUqh46ir8vsMk6T9yyP9rqv0DMWi8HIcya1g6x5F8Yy9UPfvCRf1kZkqICMXZzNw4oiP2JqQOtZg0BndMpGd6iL0qotUppXjs2gFMH5ON1aRhNdX/f6jqqbMUSaIRTUG0tXkzYY3JSoriP7eP4JuthRyqdOOtlelvrdGVkdX/5m7PPXzv60m1YaLSsFFh2HEYVkqNaGZ7LyXvpmUMuvNlvthyiE/W72/V9gohhGhhbgeWVy8jyTiMCS9KeUk2/4VY/Q0UlSgcIV+mcKCUg9hRybS7ra+kRG8imZkSogl+e1Evth6o4Ic9pbg84RU3NeuKjHgbz948pJVbJxqilOIX5/fgppGdmPd9HvOW7cZnBAIfCOw5Mps0EkoPkmdNCnp9JIlGbGad9PjW3xv35abCep+r2LSY11e+xyuH9qJbbMSlpdBt1Hnomf05QBKG0tm05DBP39SFp28cwq0vraBneiw5KbIBWQghTgs/vgrle1H+Y1les/9ZjsPzErn3/hfNNJJK75W8smY972xczDs3/gVdHSDW9B5282qU5U7QHjqJH+D0JsGUEE1g0jVeum0Y977+I99uL8bRSJpzTfOQkWjhnTtGSBa/U0R6vI0Hzu/Jfed2Z/P+cg47PCgFiVEWeqbHsmj+Au7+0Y3zuJpNkSQa8RsGF/RNb9XPYRgGTy/aEfJnsL6sg5u3byQx85yaE8CXWwopqXLTLzOen03qzl2v/cC7d51V795AIYQQpwjDgKX/BE9w7UqfAf9a4eTBsUuJ0peSYHZj1Ty0t005dpAfWPk8TPxNINOsiJgEU0I0kdWk88xNQ3h7zSYe+uRbnJXpgAGGmUDdIC8oP5qlGEvyIlyJO5n86eM8f97zZMfLnqlThVnX6J8ZXFNp4jWTiF3xNk5r8M0lnHTquoLLB3YgppUTjazdW8bBsuqgxyPJOqgBb6zYw10TuzJlREdW5pXw0PsbePTaAUePcbp9fL/rEKVVbvwGJNjNDOucRHyU3HzPFHtLHby5Mp/thZVUubzE2c30y4znuqFZJEUH13ETQpwAu5eAozTkUw+MtvC3JdXcNcxCgq2BvVCGD7Z8BH2ubKVGtm0STAnRDLllufx98wwsmQ40Txzeit74vdGAhqZXoUfvQLcdAMDpB5ejihs+voFXLnyF7ondT27jRYM0s5lpnXT+sd+LSwVfKhtNNGLSmD6m9YPmlbtK8PqDl5pGknXQ5fXzzdZC7prYFaUUD1/Zj8v/vYT5K/MZ2jmRF5fs4p3V+9A1hWEYGICmFB6fnwv6pHP7uBz6dohvhU8nmsIwDEqrSymvLsekmUi0JRJtjm78hfVYsqOYp77ewQ97SvEbBp5atfa+2nyQJ77Yxtk9U/nJxK7ycyDEibb9S/BUhXxqaHudCZ1NPLa0mj+f3UAiCnclbJZgqqkkmBKiiUpdpUz7bBpVnioMDDTzYSxJSxt8jYFBlaeK6Z9N593L36WdvV2Dx4uTa8aMi1h43xzWpPXAVU+x5lDsZp1fXdCD7mmxjR/cTIednjqd2yMiyToIUOY8lnUw2mri6RsHcdlTS/AT2EdWO7FFbR+uK+DzTQc5p1cqf79uIJYGknqI1lXhruCDnR/w8oaXKXGVYNJMGBh4/B4Gpgzktr63cVb7s9DD/JkwDIMnvtjGnG931Zv90uUNBPKfbTzAN1sLefiKflw1JLPFPpMQohGVBxt8+o8TrZz1YhU/HdHI7HFVcQs26swiwZQQTfTa5teodFcSGKc/pvTbUoo/K8Zd6Ea36cQNiSPtmjT06GMdmCpPFfM2zeP+Ifef6GaLSNijuCQnhg3lHsK9XNrMGvee05WpZ52YpZwWXaGA40Od2lkHwwmozPqxIMgwDJ5dnIvfMHA3EkT6DXB6fHy5+SC3vricedNHYNIloDqRDMPguXXPMWf9HDSl4fQG9k64/e6jx6w6uIpNhzZhN9l5fMLjDElrPBHO459v44Xv6g+kavMb4PL4efC99WgaXDFIAipBiLq0AAAgAElEQVQhTgit4ett31SdS7qbeOQ7N71SGji2kfOI+sl3Togm8Pg9vL7l9TqdFYDiT4s58NYB0q9Lp/fTvcn5nxzch9zkPZaH3+uv8/r5W+fj8bVeDSLRPIECuN/ysDOTysYCKcPApKBP+zieuWkId07oemIaCbSLtWILkc62dtbBcKTFHVsC8vTCnXyy/kCjgVRtLo+fNfmHefDdDWG/RjSfYRg8tOQhnt/wPNW+6qOBVCgOr4NDrkPc8cUdfL376wbPu3BrYb2BVNWmheyfex97/n4Ne5+6mYPzf49r70Yg8HPwm/+uZ0dhZfM+mBAiPLEZoBruzv9hgo05P7jZV97ANT0mrYUbduaQYEqIJliUvwhvrRSkAD6nj8L3Cml/U3ti+8eiTApLioWsu7JwF7spW1pW53i/4efLPV+eyGaLMBVVVHPJv74jr7gKRzip75VC1+CifhlM6JHa+g2s5bze6YRagVc766Bj2zL8HheGz4tz5ypKv3mxzrHRVp3Jw7KAQKKJp77eUe9sREMdaafHz/tr9rG/rP4OvWhZT/74JJ/lfYbLG7oocygun4tfffsr1hatrfeYf365PeTPQPmKdyn5ag7xI68j8+5X6XDnS8QOvgjn9uVHj3H7/Dz/bW5kH0QI0TQ9LwGTtcFDuiZpTO5j5l8r3KEPsMRAv2taoXFnBgmmhGiCZQXLcHjrFsFzbHfg9/iJG1K3IK9u04ntH0vlxrojtQ6vg6X7Gt5jJU48n9/gxjnfU1LlrnefUCjVPnjy6+18vvFAK7YuWEqslXHdU1AhEjXFDb+KxLOnU7bsTfY+OYW9s6dS8cNH2LvVTUph1jTO6RkIAj9cWxDyXBBeR9oA5i3b3VIf77RnGAZOt4+qai+GEf7PUzgOVB1g7sa5OH3BwWvpt6Vs/912Ns7cyJZ7t1AwtwBf1bHgyOVz8fslvw953tyiSjbtLw96/EiGyKRJdxLVYzSaxYbSTUR1HVGnJIDPD++t2UdVtTfoHEKIFtZ+ICR0avSwh8ZbqXKHvgaV+cxsiR7a0i07Y8ieKSGaoKS6JOgxX6UPU4wJpQf3RE3xJpy7gzs8oc4jTq5vthRScNgZMpCq2rSQ8pXv4Tm0F81ix5yaQ/zo67Bl9gECS5z+8slmJvVOQ9UXkbSCmeNy+G57cciZhMayDlpNGlNHdz66z2n2op0ha1aFm2rd7fUz7/vd3D+pe519WK0pv8TBjsJKKqq9RJl1spKi6JHe+sk/6mMYBj/sKeW5xbl8vaUQvx+UCtQdG9IpkTvGd2FCj1R0rXk/I29seSNozyYElhsXfVpE5oxMYnrH4Cn1UDCvgLzH8sj+bTZaTZKQfZX72HRoE72Te9d5/WvL9+AL8fMfSYZITSk+Xref62pmPIUQrWjM/fDR/eA5Nsibd1/da2BWvIbrd3HHvxLDZGNzp5u5+4VVDOucyL3ndKNXRvBxon4STAnRBGYtuLaOHqPjrfRi+IyggMpb5sUUE/zrZtGkNsup5plFO6mKoACuc/vyo8EUQGF5NT/mH2Zwx8QWa5Pfb6A10PEe1jmJa4Zk8vbqvWElCzjCjJecGI07JnQBwOPzk3codIrdSDrSPr/BvlInnds1PR13OO/xzZZCnlm0k/X7yrCYNAyjpsKb36B9go07J3Tlkv4Z2MwnrvjwD3tKuf+NNRRVVuP0+Dg6GVXz98q8UjYV/IjVrPOXK/pyYb+MJr2Px+fhza1v4vHX3Xd5ZLlxh+kdiO0f6EwdWW687YFtlC0tI3Fc4GfT7XfzysZXeGTcI3XOsfVARcjBhEgyRDrcPnKLZd+UECdE32tg9VzYtxp8wXUH66WZUYnZjJz8GxZj4bXv93DLiysY0jEQVPVuL0FVOCSYEqIJMqIz0NDwc2w/TVTXKJRJUb66nPjhx2qt+Fw+KtZVkHZN3c2dmtJIi5YNn6eSPYccrN9XFvR4JAVwq70+nl+cy9M3NZ4trT5lDg9vrc7npSV5FFa48PgMLLpGVpKdmeNyuGxAB+yWuh3aP1zWhwqXh882HgwroLKaNDrG2fmP+jW2dSUw5FYqXF4suka1N3ifWCQdaV0pyl2tl1xlb6mDG+cs51Bl9dHA9/g27yyq4qH3N/CnjzYxd9pwBmYFF2ZuaV9vOchdr/2Aq5F9dlVuH1VuH/fPX8O+w05mjM2J+L02l2wOOSsVznLjI8GU3/CzMH8hOworOOzwUOYM/NlVTxAUaYbIww5JsCPECaGbYMp8eOkiKN4G4eyh1C0Q1wFu/RAsUUQBt4/L4aaRnXht+W5ufWkFgzsmcO853ejTXurHNUT2TAnRBBdlX4RFrzurpEfppF6RSsGrBVSsq8DwGriL3OQ/nY85yUzC6LqdOatu5dIul57IZotGrN9XFnJpWiSzMn4DFm4r4vHPtvLJ+v14fGEksKjh8vh44K21DH/4Sx7/fBv7DjuP1pBy+/zsLKriDx9uYvCfvuCRTzfXWYqlaYonJg/k5+d1J95uJtoSurNrN+tYTRqXDWzP+/edTeK0t2Dxo7DsaSwmLeTyLqjbkW6M12/QwtuDjtp9qIpL/vUd+0qdIWcQa3O4fZQ5Pdzw3Pd8n3uodRpU44c9pWEFUrW5PH4e+3wr7/+4L+L3O1x9mEBS/LoaW27sray7j6nK4+T2eav4yyebmff9bhZtKwqZ0AQizxCZGCUz70KcMNZYmP4F9LkikJDCVE+RXt0a+NP1XJi1GGJS6jxtt+jMGJvD4gcmMqxzEre9tJKZr6xiQ4iBxnBVVXvZVFDO8txDrNt7mMKK8BPmnA5kZkqIJuiR1IOs2Cy2H95e5/GUi1LQo3UOvHkAd6Ebza4RNziOrFlZaMelr06PSqdPch/EqaPC5QkZTERaANfh9vHkNzuItupoSnHzyE7cOrpznfTjxytzeLjuuWXkFVeFnBmqfW6AuUvz2LCvnBemDsVqCrRLKcWMsTncOrozX20+yDOLctlRWInL48Osa6TEWrltdGeuHppJnK1mqWpyF6qmfMgPLz/Aig1R+PwpId+3dkc6uueYBj+/y+Pj2meWkpUURa+MOHplxNE7I46eGbGkx9mavJ+swuVh8nPfU+7y1NvhD8Xp8TH95ZV8fO/YVll6aBgG972xpt5AqqG9di6Pn1//dz2T+qQRZYnslhwqYI10ubGuKb7+2YQ6/yd//WQzLy7ZFVQMunaGSKXp2LIHoTQTrrw1uPasq5OEItqi0zU1JqLPI4RoJrMNrnwWzvtLYNnf8tngOES1oWPCh2GJwTR8BgybAfEdGjzVkaBqyohO/GfFHqa9vJL+mQncd243+nYIb6Zq8/5ynv82l4/X78ekaYHkRgZU+/wMykpg1vgcxndv/v7Rk02CKSGaaHq/6fxh2R+C6rokjU8iaXxSg6+1m+xM6zetwWPEiWc1a4S6pke6vOmIqupA4PP8d7uYuyyPl6YOZ3h28M9GtdfHzS8uJ7eoMqgDWx+nx8+q3SXc+/qPzJ4ypM6eKrOucUHfDC7oG3o/TmGFi++2F7Myr4RVeaXsKKykT9rdDCv8nJGJFpYfjg8KViLpSA/qmMAbM0exs6iSzfvL2by/nBe+28Xm/eX4DINe6YEAq2dGLL0z4uiaGhPWvqa3Vu2lzOEOGUg1mhzE6+dfX23n75MHhvHdjczq3aUUV4bepxDOXjul4P01BdwwvGPIc5Q5PGwrrGDbwQq2Hahg68EKtpRuw5fqCVpfEuly4yhTVFBwO2VEJ15emkdwKehAhkgtOpGyZW9S/NFjKIsda1pX4kZNrnOc3wiUChBCnATR7WDczwN/vNVc+shH+EzRPDttTMSDHHaLzvQx2UwZ0ZH/LN/D9Lkr6dchnp+e051+maGDqgqXh1nzVvPDnlI8PqNmkLLuYNPyXSVs2FdGjM3E3GnD6Zl++u7PkmBKiCa6MPtCPs79mBUHVlAdwYZPi99goNPJpYl9W7F1oinS6pk1iWRWJhS314/bC7e8uJxXpo0ICqjmLdvNtoMVIQOpxmY1vt1ezJebD3Jen/SQ720YBrnFVazKK2FlXimr8kooqXIztHMSQzsn8tClvenXIT4QzDj6s/H5WVxTdjNOIzi4CacjHW3VuWN8Fywm7eisVO22FFVWs3l/BZv3l7NkRzEvfLuLvENVdEqOomdNkNWrJshKibUe/f8wDIPnFufiDDH7E07A4vMbfLx+P/97eZ9js3It5NnFuSH3qYW7187h9vHMwp1cNiCDHYVVbD1YwfaDFWw9WMm2AxVUuDx0S4ulR1os3dJiOLd3Gl1S+nHNpy9S4a577am93FizaXWy+R2/3FhTGhOyJgS1u2NyFAOyElixK3S20cYyRJo0xTVDMoP29QkhTgKTlXItgSqnl+Topi+9tZl1po3J5sYRHXl9xR5mvLKSvu3j+em53eifeey6ctjh5op/L6GgzIW7gVUWcGz/6FVPL+XVGSNaNHHTiSTBlBBNpCmNv0/4O3d9eRfri9bi8tdTDK8Wm99Pr2o3/yw8hD7nbJj+JaR0PwGtFeEYkZ2MOcRek0hmZRoLfm57YRlf3T2K9PRAQOX3B4KEUEvEwgkSHG4fzyzaeTSY8vj8bCworwmeAjNPVpPGsOwkhnZOYsbYbLqnxobODhiVRJ+Zz9Pp4ffY6k7CCLGttvFU6zpn9wxduFgpRWqsjdRYG+O7H1tOWO31sf3gkVmsCr7dXsTm/eVoStEzI5Ze6XGYdMVhZ/DvWCTJQTSleGfVXm4bk11v+yNlGIGsgqGW3EWy125PiYNBf/qCrimx9EgPBE23jupE97RYOiTYQ/5/Xd/jeuZunIv7uGtPuMuNLZqFW/rcErI9953TjelzV0WUHfIIs64xvQW/x0KI5lEYODw+4u3NH0iymXVuOyubG4Z35M2V+cx8ZTW928fx05rsfze/sKLOft9wONw+bn1xBR/fM5aOyVHNbuOJJsGUEM1gM9l4bvRf+Ne8CbwRZUIBDi24A2r3+zGAa8srub/0MGYAVzm8fDH8ZDlENbwsUJwYuqaYOjqbpxfuCNq3FM6sTDjBj8ft5am7/8K06u3YBw1ibfYgKl3Bo4WRBAnr95Xx0Psb2FFYydr8w2QlRTG0cyIX9cvgoUv70CHBHv43wRbP03dcxGVPfUelP7JRTJvyMueidkdrVoXLatLp2yG+zjp8wzA4WF4dCLAOlPPej/tCBpyRBCxOj4+F24rqDaYMI7AcxXvkj8+Px2fg9fvx+kI85jeodHnx15NtI5K9dlEWnVenj2BQp/BHZif3mMzcTXNDPhfOcuOs2Cx6JvUM+dzoru2455yuPPnVjogCKptZ4x/XD2zVtPhCiMbtKKzkpSW7WL27lIMVgRnsqS+t4JZRnZnYs/n7lGxmnVtHd2bysCzmr8pn1rzVJMdY2FnPcvXGlmJXVXt5/Iut/PP6Qc1q18kgwZQQzWRa/gw/KynhzmIPC6KjmBsfy6e/2YXP7af/37qRqRtMLSvnwOLDzF/n5pdTj3QyDHCVwYrnYMKvT+pnEMfcOKIjsxftCPlcQ7MyYRe11Ux81O98fnXFVDxr1/De+mIcWnpg40wtkQQJHp/B1gMVzBqfw5COScRHNW/0Mad9Cv+ZNYYpz35Hld+EP4zEr3aqedr0D4Ys2AZRc6BX8zJVKqVIj7eRHm9jYs9UDpa52HYwOGV3pMlBlu4sZvRfv8JTExgdDZL8gSBJU2DSNUyawqQpzLqGSVeYtCN/H3tM1zQ0QieCgMj22mmaIkRyvgalRacxve90Xt74ctDezcbYdBt/POuPDR5z14SumDTF37/YRrXX32CGRl2BxaTzxOSBnF/PklMhROtbuqOYRxZsYVtNvbjaNeMWby9m9e5SLCaNGWNzmDUuJ+LBr+PZzDq3jOrMdUOzmPjYwiavsvAbsGDDAcocnmbfw040CaaEaA6vG1a9AD43duDKyiqurKyis9dLhReueWsXD461AvB8iM3c+Kph+bMw9heBOhHipEuJtfLoNf154O11EaW5jij48fv53pzKpFtupmzO97AzOG13pEFC97RYzu7ZcnXL+ndK4eO7R/M/T8/le283DMBN3ZkqMx40DPqoPP5ofpm+Wh54gP/eDpNfg67ntFh7LKbQN/xIk4P0TI9j9k2DAwGRpjDpGmZdoWsKs6Y1WBw5FMMw6PLgJyEDjUj22vn8RpOW4Nw54E6KHEV8vOvjsAMqm27j8QmP07dd4/s2Z47rwtDOScxeuJPF24qAujW97GYdA4PLBrRn5rguksFPiJPo5aW7eOTTLQ3eu47sU3ry6+0s3lbEi1OHEW1tfv9jT4mD0qrmL8V+a3V+k2rvnUzSexOiOTZ/UO+w9AOjLfxtSTV3DbOQYGugg+Zzw7YF0OuSVmqkiNSlAzpQWe3jDx9uDDugiiT4cVV7+fH1D+jt2Eql0Q8s7YKOiTRIcEdQzypcHdc8xlzbKxzw2JnnPZcPfGexfPYv8XvdDJn1KOfbN3Cb/hlfrdnD3es8LDwy6+pxwvyb4WebwdYyxR7T4+1YTcEFhSNNDtK5XRSZiS23Jl8pxeCOiazaXRr0XCR77exmnU7JkS+NU0rx0KiHyIrNYvba2Sil6g2qokxRxJhjeHzC4wxMDT+r4eCOicy5ZShFFdW8++NecouqKHd5SIyy0DsjjssHdSCmBTpjQoimm78qv9FAqjaXx8+a/MNMe3klr80Y0ewZqiU7ign1zpEuxV6w4YAEU0KcUXZ+De7gpUcAQ9vrTOhs4rGl1fz57PrrC+GuhNxvJJg6xdwwvCOdkqP46yeb2X6wEo/fT0PxSiTBjw9FVVwiMcPGkbrTCge9QcdEGiS0i7E2ekxE3FXw46vgdZGuXDxgfosHzG/RWVVQYcDla+89OusakmHAmtdh5B0t0pyL+qXzfwu2BD0eae2j64eFTj/eHHeM78JP3/zxaCr82sLZa2czaUwbkx20h2HjoY28tuk1tpRsweF1YNNtdI7vzJReUxiaNvRopkOlFNP6TeP6ntfzUe5HvLThJQ44DmDWzBiGgdfvZUjaEKb1m8bIjJFoqmmdppRYKzPHdWnSa4UQrSe/xMFD728IGUg1tFep2utn3d4ynl2cy08mdm1WG8ocnpDZ+yJdZXHY6WlWO04GCaaEaI6q4gaf/uNEK2e9WMVPRzSykb+yqAUbJVrK6C7t+PCesWw7WMGL3+1i+a4S8kscddagHxFJ8KMpSBkygPizu3H+ynyWfbjxaDHeo8dEGCSc1SW5ZT70Eevfpr5NPGHNunocsPSfMGJW0H6wpsiItzMiO4lvtwf/zoVb+yjGZmJ0S3+fgIk9U7HoGlWETtTQWAZEwzDq1Jj6cveXPPnjkxRUFuD2u/EbxzoouWW5LCtYRpw1jpn9ZnJN92uOBlVR5iiu63Ed13a/lgpPBeXV5Zg1M/HWeGymBgZ0hBCntZeX5oUsOB/OXiWnx8fz3+Zyx/guzUpKoWuKmpq8dR+PcJWF6TQs4CvBlBDNoTe8x6Fvqs4l3U088p2bXikNjAabWnhWQbSo7mmxPHJ1fwB+Pn8N7/64r1lFbe1m/ehSs0sHtOf3H2wM+b7hBgmxdjOjWjpIWP4MeKpCPhX2rKurDPathsyhLdKkWeO6sGp3KU53cNDSWMBiM2vcPjYnZB2x5tI1xZ+v6MvP31ob0T47ALvu53bL5yQVmDG6nssTq5/g9S2v4/K5Qh5vYODwOnB4HTy68lFWHFjBw2MfxqwduxYppYizxBFnOX2LYAohwuPy+HhjxZ6gDHqR7FVy+/x8s6WQc3s3fd9tYrQFm1kLqgUY6SqL5Jim18I6WZq3QFKIM11cB2hkycwfJtiY84ObfeX1pMJSOsS1b4XGidZw86jOWE2hR9fihl9F4tnTKVv2JnufnMLe2VOp+OEj7N3qrhX3GxzNeGa36FwzJLPe0biYPhPJuPUfdPzZO2Td/Sqp1/4vtsxeR5+3mTVmtUaQULG/waf/ONHKkyvcFFU1EDwoDcr2tliTzuqazGUD2mM3R3brspg0+rSP59bRnVusLce7uH97fjapO7YI2mY361w8sCP333o9fPhTnv7vdQ0GUsdz+pwszF/IQ0sewmgo1Z4Qos1auLUo5OR/JHuVqqp9zF2W16x2nNsrLWiQEeoONDq2LcPvcWH4vDh3rqL0mxfrHBtdcz883cjMlBBNYBgGO4sqKWl3JX6Wk+A/RFe1D5MK7lh2TdKY3MfMv1a46ZcaoqNlskC/a05Aq0VLGJAZT3q8jV3FoWdtGpshMWmKa4ZkYrccC8jumtiFD9buo8wZvHeqIbpSJEdbuHZYVkSvC4u3usGnw5p1NfyBZBQtRCnFw1f2o7Kyiq8378dJ4zO6NpNG97RYXr5tGOZmbrBuzMxxXWgXY+W3725AKYKWbtZukwHcPjaH+yd1QynFmqtn8/LCe3Adt0im9NtSij8rxl3oRrfpxA2JI+2aNPTowM+Py+fiqz1fsSBvARdmX9iqn08IcerZX+bEHaKuU6R7lfaVNu9anR5vY2ROMou2BW9bCHeVhdvnZ0zX4IRMpzoJpoSIQJnDw1ur85nzbS4VLm9gfXH1PfgNMOPlNn0BN5q+AirqvO6h8VbmratnU2VSDqT3a/3GixahlOKn53TjN/9dH1Ex0yNMumLacUVjM+LtvDp9JNfPWYaj2hcqiX4QXVPE2828OWtU62RSs0QF9j014A8TbAx+tpKfj6onqFEa2Fp2qZmOwVPqbzzX7Xye3puD1+enKkTQEmXRMQy4bmgmD17cq97ZxJZ21eBMzu+Tzvtr9vHMolwKK1xHgzif38Bu1pk2Jpvrh2WRXCtpyAs73+b48LX402KKPi0ic0YmMb1j8JR6KJhXQN5jeWT/NhutJl280+tkzro5EkwJcQZyefz4/MEDuZHuVXJ5I7+fHW/W+BxW5pWEHEhqbKDRrCs6JUVz4T+/Zda4Ltw8qhM2c3jX7Q37ypizOJeluYeoqg70zRKjLFw7NJMbhnds+QRNx5FgSogwvbosjz99vBlNqeM60faj/3radzlP+y7n3p98yDmmd44+nhWv4fpdiE6lOQrOur/1Gi1axeUD2/PdjmI+Xrc/ooDKbtb5y5V9yW4XnAK7X2Y87911Fje9sJzKam/IzHBHRFt0MhLsvDp9BOnxrZRYIL1/IFtlAxqddfV5ILV32G+5eX85Ly3ZxZr8w1S6vNjMOllJUdw6uhPju6cGBi+W/hPlcTBr6gymofHV5oM8uyiXnUWVuDx+LCaNjHgb08Zkc/nA9kRZTvxtLtpq4sYRnbhheEfyS5yUONz4jUAdqc7J0UGbvIudxSzdtxSjVhjtc/oofK+QDtM7ENs/FgBLioWsu7LY9sA2ypaWkTgu8ejx+RX5bCnZQs+knifmQwohTgmxNhNmXQsKqCLdq9QSg3KjcpKZ1CuNzzYdiGj/qK4UaXE23rlrNAfLXTz22VZeXLKLn57TLbAMvp5VBUt3FPP7DzaSX+rE4/Xjq7XcucLl5amvd/Dk1zuY0D2FP1/Zl9TY1rlfSjAlRBie+GIbzy3ODapxc7zqmqKmc3wXs99I5m/m5+pPZGayQaezoO/VLdxa0dqUUvzf1f1RCj5aG15AZTNr/P7S3lw1uP714N3SYlnyq7P5ekshzyzaycaC8jrFat1ePyNzkpk1PodROcmtkkzhqNH3Qv7yQIr0BjQ061oS35v4hM40Nrb4zdZC/vbpFnYdqsLjM+pkpcotrmJVXglWs87M/iZu3/o0+syvQTdhBi7om8EFfTMi/HAnhlKKjslRdExuuK7Vgl0Lgv4vHdsd+D1+4obUHYTRbTqx/WOp3FhZJ5hy+9y8ve1tfjfydy33AYQQp7w+7ePQQtwLIkmKpGuBenLNpZTisesGUPaKh+W5h4KSUYRi1hXJ0VbenDWKeLuZeLuZ524Zyo97Svnbgq08tziXn53XnYv6ZtQpqv7mynx+/0HodPBHHOmzfbXlIKv/Ucr8O0bRJaXlC4tLMCVEI+avyue5xTvDuigc4cTGR/6RdPAWcZ/53aDnXVixtB+CNnkeaJIH5nSka4q/Xd2fsd1SeOrr7eSXOHAfFwhYawKhkTlJ/PTc7mHdrEy6xnl90jmvTzr7DjvJL3HgcHuJsZrp3C6q1UbWgmSPB2tcUDCVd19sna/rm3X1mqKZ7bmUL/++iFnjcrhycIeQS+1mL9zBP7/a3uANscrto8rt4x/LKvku4x/MiWpfaz749JdfkU+1r+4iP1+lD1OMCaUHd5JM8Sacu+vub/DjJ78iv1XbKYQ49QzMSiAl1sruQ8HLssPdq2TWNaYft/y8qcy6xgu3DuPRz7bw8tI8NKVCLvuz6AqlFCNzkvnH5IEkRtfN4jeoYyL/uX0E3+0o5tHPtjJ74U5+eUFPxnVrx2cbDzYaSNXm80NJlZtrn1nGgp+OJTWuZe+jEkwJ0QC318+fPtoUMpBqqBAeBAKq2b7LmWr6nARV0yG1RGMYBl9bL2Rb5i+4z9yWuoRnHqUUlw1oz2UD2rNhXxn/Wb6H3OIqHG4vcTYzA7MSmDKyIxnxTft/7pBgp0PCSfoZ0TQ4+3fwyS8iTyKhmTDFpfPgT+7lnN1lzF64kye+3Mb0MdncOKLT0eUkLy/Zxb++2hH2DdGFhZVFGre/soq504Y3qybKqcQRYm+aHqPjrfRi+IyggMpb5sUUE3z7dnpbLtmHEOL0oJTijvFd+NNHm5q0VwmgW2oM3dJiGzwmErqm+PWFvbjn7G689+M+nl2cS36pA10pfIZBrNXEDcM7csvozg3e45RSjO2Wwpiu7Viw4QB//HAjCVEW1u8rC1kguKF+mQGUOd088PY65k4b3mKfFSSYEqJBn208gL+JhfAANAzmeycw0/41JGTB6HtQfa9miEvnoX99x9geGQzplMj+Mifzlu1m4bZKhoAAACAASURBVNYiyl0ezLpGSoyVG0ZkcWHfjLA3YYqTp2+HeB6+qo0lEhl0ExxYDz+80mgyiqOUDrZ4mPoRSjcxMieZkTnJbNhXxrOLcxn7f18zZUQnxvdI4ZEFW0IGUg3dEKu9flbvLuGlJbuYMTanhT/wyZFoC56xjOoahTIpyleXEz88/ujjPpePinUVpF0TXA8m3hIf9JgQou27fGB7nvhiG06Pj0irJNjMGr+8oHX2WkZbTUwZ2YkpIzvh9xs4PD5sJq3ePVD1UUpxYb8MJvVO4+fz1/LD7tKgY8Lpl/n88H3uIfaXOZs8yBmKBFNCNOCZRTuDMoVFUgjPiZXno6Yx4zev1Vnrm2aBP1/Rl7v/8wPZ7aJZvbsUA+qMtOwqrmJjQRm/e3cDNwzvyH2TurdO1jYhGnLBI2CNhaVPgc8NRgP7w8zREN0ObvskqHZa3w7xPHnDIHYfquK5xblMmfN9UJFJCO+G6PT4eW5xLtPOyq7ze3W66tOuD9GmaKq8x5ZU6lE6qVekUvBqAZpNq5PNz5xkJmF0Qp1zWHUrA1MHnuimCyFOAVEWE2/MHMnl/15CZbU37IDKbtb4xfk9GNstpXUbCGiaanYfRtcUK/NKgjLeRtIvM4B5y3a3aAApmzWEqIfH52fz/vKgxyMphAdQ7vJQUBa8/MasKworqlm68xDVXn/oKeuavSLzvt/NpU9+R2F5eMU8hWgxSgWW+03/LFAPzWQDSzQoE6BAtwayUiZ3g4sehZ8sh/j6k2x0So7mwYt6oWmq3hti0qQ7ieoxGs1iQ+kmorqOqLNZGqCq2suSncUt/3lPgrOzzkYLUfw75aIU0q5O48CbB9h05yZ2/mkn5iQz2b/MRjuuOLCBwVXdrjpRTRZCnGJyUmJ4/ydnkRJjJcrS8GoWs65qkiL1YfqY02eGf/P+Cg47gxMeRdIvc3v9vLWq5YrJg8xMCVGvcmdgud3xGfwiLYRn0jTKnB4ya63kWbqzmJ/854c6yQoaUu31s6ekimufXcZH94wh1mYO+3MI0SIyBsBVz8GFf4Otn0DlQfC6A0v6soZBhyFhn+qLTQfRQ2SfiuSGWOX2MW/Z7hMyotrazLqZa3tcy7xN8/D463YUksYnkTQ+qcHXa2iMzxwfcrmgEOLMkZMSw8IHJvD+mgKeWbiTwspqDL8fj9eLrpsxmxSGAZOHZTF1dGc6JQeX6TiVFVVWh9wrG2m/rCxEQNYcEkwJUQ9dUyGnyiMthAeBgOoIp9vHzFdWR7xXxOeH/WUufvfeBv55/aAmfy4hmsWeAANvbNYpCsqcIQtERnpDzC8Ncx/XaWBKrym8ufXNoGAqHBbdwqz+s1qhVUKI002UJZDc4fphWfyYf5iNW7ZSvno+trH3kBFv4+yeqaftPmy310+oqvaR9su8IYocN4cEU0LUI9ZmrlMA7ohIC+G5vX4So4/NJH2wdh/+EOcNZ6+I2+tnwYYDHHa4SYiyBJ1DiNOBy+3DF+JeFukNMZKikKe61KhUnjn3GW7//HZcvvCX89p0Gw+PfZgeST1asXVCiNONUorBHRMZrJsh7wdoodTnJ1OczQQhtslG2i+zt3AwKXumhKiHrinGdWsX9HtbuxCeY9sy/B4Xhs+Lc+cqSr95Meg82SnRR2sDGYbB7IU7g9KXRrJXRFOB2ldCnK7i7GbMIeon1b4hhnUeW9saDxyYOpAXzn+BGHMMNr3hOigWzYLdZOex8Y8xqdOkE9RCIcRpx+MEU9sow9IzIy7k/vJI+2UDshKCHmuOtnUnEqKFzRrfheW7SoKCn3AL4UVbdO6a0OXo11sPVlBYUbc4J0S2V8Tp8fPKst3MHNel0WOFOBX17RCPWdfw+Or+XtW+ISpNx5Y9CKWZcOWtwbVnXZ2BBbOuGNKp7e0R6p/SnwVXL+Dd7e8yd+NcHF4HPr8Pj9+DSTNh1sxomsYNPW9gco/JpEalnuwmCyFOZR4ntJGalvF2Mxf2TefDtfuDVg5F0i+7Y3zL9p8kmBKiASOyk0iMsuBwB2fjC6cQnlKKC/qmH/16f5mrRTZPHqp0h3WcEKeiEdlJxNvNIQtMhntD1JRi6ujTf9lKKPHWeKb2ncotfW5h+f7l5JblUuWpIsoURWZsJmd1OAuzJklohBBhaEPBFMCMsTl8tvEgTk/TChRHWU2M6dquRdskwZQQDVBK8a8bBnHT89/jjHB/hs2s8cTkgVhNxwIklzt0Qb1I94p4Qm04EeI0oZRi1rgc/m/B1ibfEAdkJdAxOaq1mnhK0JTGqPajGNU+vDIMQggRpI0FU307xDMiO4lluYeCsi03xmbW+F1NaY6WJHumhGjEkE6JPHXjYGzm8H9dbGaN31/Sm0m90+o8HmszEyIjdMR7RU7XTDxCHHH1kEyiLHqovcSNspk1fnGeJFwQQohGeRxtKpgCmH3TEHJSorGawu+X2c06s8Z14fJBHVq8PRJMCRGGc3ql8frtI+mcHIXdrFPfoEaURSc9zsYzNw3hhhGdgp7vkR7bIpsne6bHNvszCXEyxdrMvDFzJNFWU0QBlc2s8b+X9mF4dsO1l4QQQlAzM9W2ZvHtFp137hzN8OykwKBcAzcRi0nDatL41QU9uH9S91ZpjyzzEyJMgzom8s0vJrAm/zDPf5vLl+v34MaCUoE6UqO6JDFrXBdGdUlG1fObnRJr5awu7fhma2FQqYRINk/OauHNk0KcDN3SYnnvJ6OZ/Nz3ON2+kHuojrDoCk1T/N9V/VtlZFEIIdokb9ta5ndElMXEK9OG88Oew8xZvJNvthZh0TX8hnG0D6YpuHlUJ24e2Zn0+IYzpDaHBFNCREApxaCOifz7onZwYAqe+zZhGIGRj3DNHJ/D97sOhew4hrNXxGLSOLunZPASbUPX1FgWPzCR99bs45mFOzlU5cZvGHh8fkyahklXOKp9XDk4k3vP6UaHhLbXKRBCiFbThlKjH0+pQFbXITcP5VBlNWvyD1Pm9GDWNZJjLAzrnIRZb/1FeBJMCdEUhZshtXeTfklHZCeRkxLN1gMVeHwhslE0wG7WuX9S95AZAYU4XUVbTUwZ0Ykbh3fkhz2H2by/nAqXF7tZo32CnSU7DhFnN0kgJYQQkfI4ISbuZLei1SXHWDmnV1rjB7YCCaaEaIrCjZDaq0kvVUrxyrQRXPyvbymqqMbrDy+gspt1rhzUgZtHBu/FEqItODrKeFz9qPYJdu54dTX3n9u9xbMwCSFEm9bGsvmdiiSYEiIMPr+PtUVrKXQW4va5iStYSq/sSTR1DCQp2sKH94xhypzl5Jc6GtwroiuF2aSYelZnfnl+j3r3YwnRVvVpH4fdrLN05yEMDEpqlgIm2C0M7pRIvF1qLgkhRG2eg1V4i5z4C9LRjCT0/VVYMqJPdrPaJAmmhGhAiauEd7a9w7xN86j2VaNQGBgojwPP1l0MLfuR2/rexoj0EREHOe1irHx07xg+33iQ2Yt2sKOwEgjUkNKUwqJreP0GF/fLYPrYbPq0j2+NjyjEKW9PiYM4u5mpL63AZtYxMMAIzGR5fH4u6pfBDPkdEUKc4QyvH+eGYsoX5uM75ApkYPAMgXwTrFiDnmAldnwWUQPaoaTESouRYEqIenya+yn/s/R/AKj2Vdd9UgF+D0sLlrKmcA3Z8dk8O+lZ4q2RdebMusbF/TO4uH8G2w9WsCKvhDKnB4uu0S7Gytm9Uomzyai7ODMZhsHjn29jzre5+PwGXr9BZbU36LgP1hTw6Yb9TOqVxt8nDzwhG46FEOJU4jlYRdGc9RhuH4a7dgkWC/gB/HiLnBz+YAdln+TSbno/LB1iTlJr2xYJpoQIYf7W+Ty68tHgICoEh9fBttJtXPfhdcy/dH7EAdUR3dJi6ZYm9aOEgEAg9eC7G3jvx32NVrn3GQY+j8EXmw9yywvLeWX6CAmohBBnDHdBJUXPrsOorn/LwBGG24/h9lP0zFra3d4Pa8e2n5yitcndRojjLCtYxqMrH8Xlc4X9Go/fQ5GziJlfzMRvNNzxE0I07vnvcnnvx304PY13Do5wefysyT/Mb9/d0IotE0KIU4ev0k3x8+vDCqRqMzx+il/YgPdw+H0dEZrMTAlxnPoCqdJvSyn+rBh3oRvdphM3JI60a9LQowPrjj1+D3lleSwtWMqYDmNOdLOFaDNcHh9PfLG93kCqatNCyle+h+fQXjSLHXNqDvGjr8OW2Qenx8/7a/Zx37ndaC+p1IUQbVzlkgL89SSxWrF3HQ9/M5ttxXlomka35E78/px7GJgRyEZsePxULNpL4uVdT2ST2xwJpoSoZUvJFvIr8oMeL/60mKJPi8ickUlM7xg8pR4K5hWQ91ge2b/NRqsp2uvwOnhxw4sSTAnRDB+v20996VzKV7xL2fK3ST7vJ9iyB6N0E85dq3FuX44tsw8AhgHzlu3mVxf2PHGNFkKIE8zw+alcVgDe4BIrFdVV3Pb2r/nLeT/j0p4Tcfu8rNi7FqtuOXaQ38Cx+iDxF2ajWSQhRVPJMj8hanll4yt4/J46j/mcPgrfK6T9Te2J7R+LMiksKRay7srCXeymbGlZnePXFq2loLLgRDZbiDblmUU7qQox0uqvruLwd6+RNOlOonqMRrPYULqJqK4jSJw47ehxbp+fV5fvxt3IXishhDidOTcdgnpKVeaWBAaGr+h9LrqmYzdbGZ89nF6pXYLPs66oNZvZ5kkwJUQtqw6uwmfU7cQ5tjvwe/zEDam7SVO36cT2j6VyY2Wdx82amfXF61u9rUK0RT6/wY6iypDPVe/bguF1E9V9VKPn8fsN8ksdLd08IYQ4ZVTnltW7VyonKQtNadz/8V/4Zuf3HHZVhDzOcPtx7Tjcms1s8ySYEqKWKk9V0GO+Sh+mGBNKD154ZIo34a2sm6rZ5/dR7i5vtTYK0ZZVuryYtdCL/HzOcrSoOJTW+HIUTVOUOz2NHieEEKcrf1X917hYazT/nfIUoPjlgkcZ+K/LuO2dX1NUVRLReUTjJJgSohaTFryNUI/R8VZ6MXzBc+neMi+mmLqvUUph0SxBxwohGmc2Kfz1LFvR7XH4HeUY/jDS/xpgMcktTgjRdilzw9e4bu0688TFD7LyJ+/w5fSXOVh5iP/96sng88i1slnkuydELcm25KDHorpGoUyK8tV1Z5t8Lh8V6yqI7h1d53FNabSzt2vVdgrRVtnNOlo9M1PWDj1RJjOObcsaPY/H5yclxtrSzRNCiFOGnmALuyffNbkT1/W9gK1FuXWfUKAn2lq+cWcQCaaEqOXa7tdi1+umU9ajdFKvSKXg1QIq1lVgeA3cRW7yn87HnGQmYXRCneMViuHpw09ks4VoM5RSXNIvAz1EQKVZo0kYM4WSL57BsW0Zfo8Lw+fFuXMVpd+8WOfYHumxpMZJB0EI0XZFDUwBLXRXfseh3Ty74g32lxcCUFB+kPc3f8Xg9n3qHKdMGtFD0lq9rW2ZpEYXopbLul7G46sfD3o85aIU9GidA28ewF3oRrNrxA2OI2tWFlqtaXazZmZyj8mYdfOJbLYQbcqMsTl8smE/vhDr/eKGX4UWnUjZsjcp/ugxlMWONa0rcaMmHz0m2qpz5/jgjFVCCNGWmFOiMGdE48kPTi4RbYliTcFm5qycT3l1JXHWGM7tMorfTryrznF6og1Lh5gT1eQ2SYIpIWqJNkdzcfbFfJj7YVCK9KTxSSSNT2rw9ZrSuL7n9a3ZRCHavN7t48hpF8PmA+UYIfZPxfSZSEyfifW+3qxpnNtbRlqFEG1f3IQsSt7cguGuWwoiIzaF2Vf8ocHXKrNG7ITM1mzeGUGW+QlxnF8M+wWpUanoKrICdjbdxgNDHyA9Or2VWibEmeOpGwcRbYl8vM9m1nj25iGYdbm9CSHaPlvvJGy9kqGRZBRBTApr1wSiBqa2TsPOIHK3EeI4sZZYXr7gZdKj0zFr4S3Xs+k2Zg2YxeSekxs/WAjRqJyUGF6bMYJYm4l68lEEsZt1nrxhMCNyghPJCCFEW6SUIum67ti6JTSa3e/oa8wa1s7xJN/YCxXuBVbUS4IpIUJIj05n/qXzuaDzBVh1KzY9eCO7QmE32cmMyeSRsY8wo9+Mk9BSIdquAVkJfHzPWM7q2g6rScMSYrZJ18Bm0uifGc8bM0cySZb3CSHOMErXSL6pN7HndETZTShr6JU1yqKjbDqx4zJpN61v2MGXaJjsmRKiHnGWOB4e+zC/HvFr3t/xPvO3zqfEVYLX78VusjMgZQC39b2NASkDUEpGdoRoDR2To5g3fQT7y5zMW7ab99bso8LpxY9BjNXEOT3TmDamM11TY092U4UQ4qRRmiJuQhaxYzNxbT5ExeK9eIudGB4/yqyhJ9mIHZuJvU+y1JVqYRJMCdGIOEscN/e+mZt733yymyLEGSsj3s4vL+jJLy/oebKbIoQQpyylK+x922HvK/UuTxQJTYUQQgghhBCiCSSYEkIIIYQQQogmkGBKCCGEEEIIIZpAgikhhBBCCCGEaAIJpoQQQgghhBCiCSSYEkIIIYQQQogmkGBKCCGEEEIIIZpAgikhhBBCCCGEaAIJpoQQQgghhBCiCSSYEkIIIYQQQogmUIZh1P+kUkXA7hPXHCHECdDJMIyUk92I5pLrkxBt0ml/fZJrkxBtUr3XpgaDKSGEEEIIIYQQockyPyGEEEIIIYRoAgmmhBBCCCGEEKIJJJgSQgghhBBCiCaQYEoIIYQQQgghmkCCKSGEEEIIIYRoAgmmhBBCCCGEEKIJJJgSQgghhBBCiCaQYEoIIYQQQgghmkCCKSGEEEIIIYRoAgmmhBBCCCGEEKIJJJgSQgghhBBCiCaQYEoIIYQQQgghmkCCKSGEEEIIIYRoAgmmTnNKqQlKqb0n+rVNfL+FSqkZJ+r9hBCnttPp+nXce3+qlLr1ZLy3EKJ5TqfrjvSbTg8STB1HKVVZ649fKeWs9fWUVnzfqUqp71rr/M2llOqslDKUUqbjHn9ZKfXnk9UuIcQxcv1qmArIVUptiuA1/6uUerX2Y4ZhXGgYxtyWb6EQpx+57oQm/aYzh6nxQ84shmHEHPm3UioPmGEYxpfHH6eUMhmG4T2RbRNCiIbI9atR44BUwKSUGmYYxsqT3SAhTndy3RFnOpmZCtORqV2l1K+UUgeAl0KNitSMQnSt+bdVKfWYUmqPUuqgUuoZpZS9Ce99m1Jqs1KqomZUdVaIYx5UShUrpfJqjwS1VBvCbOdUpdR3Ne9XqpTapZS6sJ5jM5RS65RSD9R8vVAp9Sel1JKaz/m5UqpdreMvU0ptVEodrjm2V83jtymlPqx13Hal1Fu1vs5XSg2s+behlLqj5pjDSql/K6VUa3wvhDiVyPXrqFuB94FPav5duw19lFJfKKVKat7rQaXUBcCDwOSaUfa1NccuVErNqGnfYaVU31rnSakZmU+t+foSpdSamuOWKqX6N6P9Qpw25LoTVjul39QGSDAVmXQgCegEzAzj+EeA7sBAoCvQAXioCe9bCFwCxAG3AU8opQYf1652Nee/FXhOKdUj0jYopZ5WSj3dhPbVNgLYWtOevwEvHP+Lp5TKBhYBTxmG8Witp24k8PlSAQvwi5rjuwOvA/cBKQQ6Qh8qpSw15xmrlNKUUu1rXjeq5nU5QAywrtZ7XAIMA/oD1wHnN/PzCnG6OKOvX0qpKOAa4LWaP9fXXENQSsUCXwILgPY17/WVYRgLgIeBNw3DiDEMY0DtcxqGUQ38F7ih1sPXAYsMwyhUSg0CXgRmAcnAs8AHSilrfe0Uoo05o687YZJ+02lOgqnI+IHfG4ZRbRiGs6EDa34RZgL3G4ZRYhhGBYGb8vWRvqlhGB8bhrHTCFgEfA6MPe6w/6lp1yLgY+C6SNtgGMZdhmHcFWn7jrPbMIw5hmH4gLlABpBW6/newDcEvo/PHffalwzD2FbzvZ1P4EIGMBn42DCMLwzD8ACPAXZgtGEYuUBFzbHjgM+AAqVUT2A88K1hGP5a7/GIYRiHDcPYU9OOgQhxZjjTr19XAdU17/8xYAYurnnuEuCAYRiPG4bhMgyjwjCM5WF+xP8c16Ybax6jpv3PGoax3DAMX80+q2pgZJjnFuJ0d6Zfd8Ih/abTnOyZikyRYRiuMI9NAaKA1bUGGBSgR/qmNVO+vycwUqLVnHd9rUNKDcOoqvX1bgKjqy3WBuDIOmdzrX8f+dpT6+sDR/5hGIaj5n1jaj0/BdgBvB3iPQ7U+rej1uvaE/hMR87rV0rlExgtgsAoywQCI0iLgMMELgijar4O5z2EaOvO5OsXBEaf59fs2fAqpd6peexdIAvY2cTzfgNEKaVGAAcJdDTerXmuE3CrUuqeWsdbCHw+Ic4EZ/J1R/pNZwiZmYqMcdzXVQR+6QBQSqXXeq4YcAJ9DMNIqPkTX3ujZjhqloO8Q2BUIc0wjAQC07W1p4ATlVLRtb7uCBS0VBtq7Of/2bvv8KjKtPHj3+dMyUx6gQRJ6EgHaYqAiNgb9i5iw7bWd13ffVf3dXV33XV3fX+2ta2gqIhrL6vi2lZ6ERDpNYSSENL79HN+f8wASWYmmUkmIST357pykZw558xzhplnzv2U+/F/+Ps22t6Peh/YCDwWKNd8pVSklVMB/psS4HDrVS8gP7DpUKUwJfD7QvyVwlSCKwUhuqouW38ppXKA04EZSqlC5Z+/cQVwfmCOwT6gf5jDG79uDR/0tya/h3+o37XA54HWbALnfaJe+VMNw4g3DOOdaK9BiGNUl613kPumLkOCqdb5GRiulBqtlLLhf8MD/lYA4FX843QPTUTOVko1NdZUKaVs9X/wt2LGAcX4W1PPA84OcezjSimrUmoK/iEr77ewDCEFbhg+BJ5QSmUopSxKqWvxdz8viOJUHuBKIAF4UykVyXvwPeACpdQZSikL8CD+oTLLAo8vBKYBdsMw9gOLgXPxz1H4KYqyCdGVdJn6C7gB2A4Mxt9zNBp/i/V+AgEQcJxS6gHln3yeFOhpAn9vU99m6qr5+IfVXM+RIX4Eyn+nUmqC8ktQSl0QmKMlRFfUZeoduW/qOiSYagXDMLYDv8c/cXkH0Hi9g1/j75pdoZSqCuw3mPAm4W8RafxzH/4PRjn+8fifNTquMPBYAf6J1XcahrE12jIof8aal5so3y+AMvwTE4uAe4ALDMM42MQxQQzDcOOfv5AFvNZcxWAYxjZgBvA8/taZ6cD0wHkO/T/U4K8MMAyjCsgFlgYqMyFEI12s/roReNEwjML6P8DLwI2BnqSz8NcthYHXY1rg2ENZrkqVUmtDnTwwv6oW/9CaBfW2rwZuA/4euMadwE1hyihEp9fF6h2Q+6YuQRlGkyMYhBBCCCGEEEKEID1TQgghhBBCCNECEkwJIYQQQgghRAtIMCWEEEIIIYQQLSDBlBBCCCGEEEK0gARTraSUmquU+mPg9ylKqW3t9LyGUmpgjM95+Fra89j2opR6WCk1+2iXQ4j2IvVT649tL1I/ia5E6qbWH9tepG5qXpcIppRSeUoph1KqRil1MPDmjfkKzoZhLDYMo6kUnofKc5NSqnE60JhRSv2glJrVVudvrba+/sBznKaU2l9/m2EYfzIMo8O+LqJrkvqpY5H6SQg/qZs6FqmbOq4uEUwFTA+sYD0WGA/8tvEOSilzu5dKCCGkfhJCdExSNwnRjK4UTAFgGEY+/kUVR8DhLt+7lVI78C8gh1LqQqXUOqVUhVJqmVJq1KHjlVJjlFJrlVLVSql3AVu9xxpE9EqpXkqpj5RSxUqpUqXU35VSQ/EvFDkx0NpTEdg3Tin1lFJqb6AF6GWllL3euR5SSh1QShUopW5p6fUrpd5XShUqpSqVUouUUsMb7dJNKfVN4PoWKqX61Dt2SOCxMqXUNqXUVS0tR6My5SmlfqWUWh8o17vKv4o5Sqk0pdTngdewPPB7Tr1j05VSrwdel3Kl1CdKqQT8/8c9A69xjVKqp1LqMaXUvMBxC5RS9zQqx89Kqcva8lqFaIrUT1I/BY6T+kl0KFI3Sd0UOE7qphC6XDCllOoFnA/8VG/zJcAEYJhSagzwGnAHkAG8AnwW+MBagU+At4B04H3g8jDPYwI+B/YAfYFs4J+GYWwB7gSWG4aRaBhGauCQJ4FBwGhgYGD/RwPnOhf4FXAWcDxwZiteggWBc2QCa/Gv/F3f9cAfgG7AukOPBz5k3wDzA8deA7yolBoW5vorlFKnRFGuq4BzgX7AKOCmwHYNeB3oA/TGv7L53+sd9xYQDwwPlOtpwzBqgfOAgsBrnGgYRkGj53sHuLZeeYcFnuOLaK9ViFiR+knqpwCpn0SHInWT1E0BUjeFYhhGp/8B8oAaoAL/B/RFwB54zABOr7fvS8AfGh2/DZgKnAoUAKreY8uAPwZ+Pw3YH/h9IlAMmEOU5yZgSb2/FVALDKi3bSKwO/D7a8CT9R4bFCj3wDDX+wMwK4LXJTVwnpTA33PxV1qHHk8EfEAv4GpgcaPjXwF+V+/YP0b4/9H4+vOAGfX+/ivwcphjRwPlgd+PA3QgLcR+h/8v6m17DJgX+D0p8Jr3Cfz9BPBa4Pcmr1V+5CeWP1I/hX1dpH6S+kl+juKP1E1hXxepm6RuavDTlca5XmIYxrdhHttX7/c+wI1KqXvrbbMCPfF/ePKNwDskYE+Yc/YC9hiG4Y2gbN3xtxCsUUod2qYAU+D3nsCaCJ6zSYEWnyeAKwPPqQce6gZUBn4//FoYhlGjlCoLPH8fYMKhrvUAM/7WjVgorPd7XeA5UUrFA0/jb3lJCzyeFLiWXkCZYRjl0T6ZYRjVSqkv8Lec/AV/S8ttgYfb+lqFaEzqJ6mfDpP6SXQgUjdJ3XSY1E2hdaVgqin1P+D7gCcMw3ii8U5KqalAtlJK1asUegO7QpxzH9BbKWUOUSkYjf4uwd8FO9zwj0tu7AD+N/8hvcNfSpOuAy7GUwtauAAAIABJREFU39WdB6QA5fgrn0MOP4/yZ+1Jx9+itA9YaBjGWS187pZ6EBgMTDAMo1ApNRr/MAMVKFO6UirVMIyKRsc1fo1DeQf4nVJqEf7x2/8JbD9a1ypEKFI/HSH1k9RPouOQuukIqZu6cN3U5eZMReBV4E6l1ATll6CUukAplQQsB7zAfUopS2DC3UlhzrMK/wf5ycA5bEqpyYHHDgI5gXHEGIahB573aaVUJoBSKlspdU5g//eAm5RSwwKtDb+L4DrMgec89GPB3z3rAkrxt+b8KcRx5yulTgmU7Q/ACsMw9uEfwzxIKXVD4NotSqkTlX9SaFtKwl9ZViil0ql37YZhHMA/jvlF5Z9saVFKnRp4+CCQoZRKaeLcX+JvSfk98G7g/wGO3rUK0Rypn6R+kvpJdERSN0nd1GXrJgmmGjEMYzX+Lsu/42952ElgQp9hGG7gssDfZfjHh34U5jw+YDr+CZF7gf2B/QG+BzYBhUqpksC2Xweea4VSqgr4Fn+rAoZhLACeCRy3M/Bvc17C/0E69PM68Cb+bu58YDOwIsRx8/F/6MqAccCMQBmqgbPxd+0W4O9a/gsQF+rJlT8LzJQIytmcZwA7/haoFcBXjR6/AfAAW4Ei4IFAebfibz3JVf4JnT0bn9gwDBf+/78z8V/3oe1RXasQ7UXqJ6mfpH4SHZHUTVI3deW6STUcwiqEEEIIIYQQIhLSMyWEEEIIIYQQLSDBlBBCCCGEEEK0gARTQgghhBBCCNECEkwJIYQQQgghRAtIMCWEEEIIIYQQLdDkor3dunUz+vbt205FEUK0hzVr1pQYhtH9aJejtaR+EqLz6Qz1k9RNQnQ+TdVNTQZTffv2ZfXq1W1TKiHEUaGU2nO0yxALUj8J0fl0hvpJ6iYhOp+m6iYZ5ieEEEIIIYQQLSDBlBBCCCGEEEK0gARTQgghhBBCCNECEkwJIYQQQgghRAtIMCWEEEIIIYQQLSDBlBBCCCGEEEK0gARTQgghhBBCCNECEkwJIYQQQgghRAs0uWivELFQVO1ke2EN1U4PNquJ41JsDM5KQil1tIsmhAjDpxv8vL+Csho3Xt0gxW5hZE4KiXHytSGE6DjKneVsK99GjbsGq8lKZnwmg9MGyz2GaDfyrSjahGEYLM8t5R+Lclm2q5Q4s4ZhgFLg9RlkJcdx59QBXDw6G7vVdLSLK4QIKK528c6qvby+dDcen45CYeD/7Hp8OhePzubWU/oxKCupReevcXnZW1pHjcuL3WKiR4qN7klxsb0IIUSnZhgGPxf/zNxNc1m8fzFWkxUDA4XCZ/hIjUvl5uE3M33AdBKtiUe7uKKTU4ZhhH1w/PjxxurVq9uxOKIzKK1xMfO1VewuqcXh9hHuHRZvNaEpxaszxzNxQEa7lrErU0qtMQxj/NEuR2tJ/RR7by7L44kvtwDg8uoh9zFpYDFpnD2sB09deQJWc2SjxTcVVDJnyW6+WH8Ai0lDKTAMcPt0xvZO5Y6pA5h6fHc0TVqTu7LOUD9J3dS2qt3V3PPdPWwp24LT68QIc5dhN9sB+Nupf2Nqr6ntWUTRCTVVN8mcKRFTxdUuLnhuCdsLq6lrIpACqHP7qHF5uXnuKr7ferDdyiiECPbMN9v584KtuLx62EAKwKeD06Pz9eZCZs5ZiccXfl+ASoeHq19ZzhUvLePTdQW4vDo1Li/VTi81Li9ur86K3DLueXstk//yPTsOVsf60oQQnUSVu4prPr+GjSUbcXgdYQMpAIfXgcPr4FcLf8UXuV+0YylFVyPBlIgZt1fnuldXUFLjwqM3FUY15PTo3P32T2wuqGrD0gkhwvl47X5eXrQLh8cX8TFOj866/RX89wfrw+5TXuvmwucX89PeChweHV8T9UKt20dhpZNLXljKz/sqoiq/EKLz0w2du765iwO1B3Dr7oiPc/qc/G7Z7/ip6Kc2LJ3oymTOlIiZrzYVkl/hwBvihql28w9U/fgJntL9aFY7lsz+pEy6ClvOcAAcHh9//Worc285qb2LLURYhmHg3LwZz549+Gpr0eLjsfbpi234sE4zudmnG/zhiy04PaF7mJr67Do9Ol9uOMADZx5Pn4yEBse5vTrXz15JYaUTjy+yxhUDf1B1w5yVfHn/FHLS4lt7eUKITmJ5wXJ2VuzEo3uCHitfXE7Jv0twF7kx2Uwkj0sm64osTAn+Odkun4u//vhX3rngnfYutugCJJgSMfPyD7uocwe3bFet+pjKlR+Qcfbd2PqNRZnMOHavwbFj5eFgCmB5bikHq5xkJdvas9hCBNEdDqq+/JLSV2fjOXgQpWkYPh/KZMLQdcyZ3cmYNYuUCy9Es9uPdnFb5YdtRbjC9EhF8tnVdYPXl+bx2EXDGxy7YOMB8kprQwZSzTWu1Li8PP3Ndv7vqtExvlohxLHqtY2vUeetC9pesqCE4gXF5MzKIXFYIp5yDwVvFZD3VB79HumHFpjXuaN8B7mVufRP6d/eRRednAzzEzGxtbCK3JKaoO26q5aKJW+TftZdxA+ehGa1oUxm4gdOIG3aLUH7v7k8r+0LK0QTnJs3s3Pa6RQ+8SfceXkYDgd6bS2G0+n/1+HAs2cvB//8JDtOm4Zj46ajXeRWeXnhLmpDNIJE+tn16Abvrd6Hs1FA1lTjStl3r5Jy8lXk3DOP7LteJ2ns+Th2rDzy3AZ8vv4AVc7gFmghRNdTUFPAuuJ1Qdt9Dh9FnxTRc0ZPkkYlocwKa3crvX7RC3eJm8pllUf21X3M2zyvPYstuggJpkRM/JhXTqjEkK78rRheN/GDJjZ7DpdXZ+H24jYonRCRcaxfT96MG/BVVGDUBbeA1mfU1aFXVrLnhhtwrAv+kj9W/LyvMuT2aD67mlLsOHikMWXLgSp2l9YG7RdN44qmFB+t2R/l1QghOqN1Reswq+DBVHU76tA9OsnjkhtsN9lMJI1KombTkXrJa3hZXrC8zcsquh4JpkRMVDk8IbN6+RxVaPHJKC2ytaSqHN5YF02IiHgOHmTvrbOaDaIaMxwO9s66Dc+BA21Usrbj9elhs/FF89lVPh9Fm7fjys3Fc/Agizflh0w2EU2A5vD4+GpjYfMXIYTo9Krd1fiM4J5uX40Pc6IZZQqew2pOMeOtaXhPUesJbuQRorVkzpSICYtJoSmF3qh7ymRPRq+rwtB9Ed2UWUJUiEK0h7K5c9GdzqDtZ+7aSanP16DlaUH//mSaLYf/1p1OSl97nR6PPNwOJY0dTSlQECq7cDSfXd3tovqNuewv241eW0tujwl4+k0B1bC9LtrGlQqHDPMTQoBZM6MIvj8wJZrw1ngxfEZQQOWt9GJObHiba9bktlfEnvRMiZjolhgXcvHOuOwhKLOFuu2Rda13T5LkE6L96W43Fe+9D57QN+8vZOewZtDgwz/1AykAvF4qPvwwZDDWkWmaIjEu9M1FNJ9dPc7OyGf+woAFX3L8ooVkzro1KJCChgFaROXrJBkThRCt4/Mm4PUF1wfxA+NRZkXVmoZLq/icPqrXV5MwrGGW0TRbWpuWU3RNEkyJmDhjSFbIYT1aXAKpp1xP2TcvU7d9ObrHieHz4ti1mvL/vNZg3wSriWtP6tVeRRbisOp//xsj1KS/KFV99VUMStO+LjqhJ2Yt+CYlms9uZlIc/bsduWnJSIzDZml940q3JGuUVyOE6GwWbS/msXc9eEI0wpjiTWRekknBvAKq11djeA3cxW72vbgPS7qF1EmpR/bFRE5iDiWOkvYsvugCpL9TxERKvIVzh/fg8/UH8DW6KU0+6TK0hDQql79LyedPoax24rIGkjzx6gb7KaU4d0SP9iy2EADUrVwV9Vypxoy6OupWriL1kktiVKrWcXl9eH0G8VZTk2ti3XJKPz5csz/k+nCRfHbjrSbumNq/wXOcNSyLP325Jeh89QM0pZmw9RuD0sw489bh3Lu+QRKKhDgTV4yTxhUhurJlO0u4/a3VOD0acRXjsaSuQGkN53l2P787pgQThe8W4i5yo9k1kscm0+uOXmj1GnV8+FhSsIRzPjiHydmTuXXkrZzQ/YT2viTRCUkwJWLmtlP78+/Nhfg8wTdlicOnkTh8Wthj48wa10/oTZw5srkUQsSSt6K8ycfvzd+PORAsnBgfz9+zc0Lu5ytv+jxtbcuBKmYvzuWL9Qdw+/TD8xiH90zhzqkDOHt4FhZTwx6jAd0TGZRqYUOxAyPE0LzmPruGAZeMyW6wrWeqnfF901i6szRo/4gbV1CcO1waV4ToqkpqXNz25urDC4q7yyZjSfUvoVCxvMK/SO8BN5pNw9bbRs8be5IwKKGpU+L2uQH4Yd8PLC9Yzt2j7+amETe16XWIzk+CKREzI7JTuHFiX95cvgdHmEVAQzFpkJNm574zjm/D0gkRnhbX9Fy957NzmJTQ9Jc0gLIdnTl/u4pruHf+T+SW1ODxGYeH3B5KCLMhv5L//uBnfvOR4jfnDeHaCX0AMAyD8rfn819fvsO9E++iJvKPLQA2i8Yz14wm3hr8VXLn1AH8tLci5FpTzQVo1kDjSqh5mEKIrmH+yr0Ne8wNM2Ci+KuDlHxRTM8be5I0MgllUlRvqKZ6bXWzwdThU2Hg9Dl5Yd0LGIbBzSNvbpuLEF2CBFMipv7nvCEcrHLx+fqCkMOGGrOaFD1T7bxz+8kkhJkIL0Rbs+Rkg9kM3lak5jeZsGRnN79fjK3dW87MOauodXtDrvV2yKGFeX//+RbySuv478k9OfDb3+IpOMDk11/gHVMq181eQa3LSwQfXWwWjccvGs45YXqPThnYjWmDM/lu68HDLcuRMGmKHsk27j59YMTHCCE6F59u8PrS3bi8gbpDebH3+Qe+OhfFHxeRfWs2KeNTDu+fPCaZ5DHJIc9Vvrjc34tV5MZkM5E8LpmsK7IwJZhw+py8+POLjOg+ghN7nNgelyY6IWn2EzH1/dYi/r0p8rVhJg7oxuf3TSFTsviJoyjl4otRrRxiqiwWUi+5OEYlikxucQ0z56yixtV0IFWfw+PjzaW7+ctdT2Dp2ZM+78zH2qcPI3NS+OLeKZzUL504sxZymQJN+edI9cmI59WZ47n6xN5hn0cpxdNXj+bEvunYLZG9thaTIjMpjnfvOJlkm6X5A4QQndKiHcW4662BZ07aiGaqxpFbE7RIb/nicnb8dgebbt/E1vu2UvBGAb5af+NRyYISCt8vpMdVPRj24jD6/29/3KVu8p7KQw8Eak6fk5d+fql9L1B0KtIVIGLmq40HeODddVG1Qq/aXcbOohpG90ptfmch2khc//7EHX88zg0bW3wOa+Ac7enXH66n1h26N6128w9U/fgJntL9aFY7lsz+pEy6ClvOcBw+gzf6TmXW3WegWY9kzOudEc8/b5/IvrI63liWxyfr8ql2etENg3irmYn9M7h9an/G9EptMqnFIVazxtybT+LJBVt4a8UeNKVCDvuzmhS43ZzcK5VZZw3j+e92sr+iDpdHJy3ByuQBGVw6NidsGnchROdQXO2itNbFsp0luL1H7iWsGT+gTO6gRXpLFpRQvKCYnFk5JA5LxFPuoeCtAvKeyqPPL/tQ9Im/FytpVJL/PN2t9PpFL7Y/tJ3KZZWknepPlf7TwXXk1+STndj+owvEsU++mURMbDlQxX+9+3NUgRT4W8lnzlnJdw+eRvekuDYqnRChGV4vNQsXUvX1NxgeL2ga6A3fw98OaH64mbLb6Xb7bQ227Syq5tN1Bewvd+Dy+shIiOPk/hkhk0C0xJ7SWtbvrwzZI1W16mMqV35Axtl3Y+s3FmUy49i9BseOldhyhvt30jTmrdzDL88aHHR8r/R4fnvhMH574bBWl9OkKR65YBj3nzmIT37azysLc9lf4cCkFD7DINlm4bqTemHdsJYP97i54y0nLq+vwVDDRduLeeLLLVw8Opu7TxtI74z4VpdLCNExeHw6X286yMsLd7GtsBqrWcPt0/H4/JWAFncAzepPZlN/kV7drTcZLBX/qzioFwvAZDORNCqJmk01h4Mpj0/nvs9f4v2r/oAWYqkIIZoiwZSIiWe/3Y7TG3r2elMt5ABOr86by/N48Ozgmzoh2oKvupqyN96kfN48dLe7VWnRXSYLjlEnMvicczAMg682FvLSwl1sL6zGqxsN5g5+9NN+fvORYsbJfbh5cr9WNSDMXZZ3OMFEfbqrloolb5Nx/gPED550eHv8wAnED5xwpNxenTeW7eG+04/HHIPgrjmJcWZmnNyXGSf3RdcNat1e7BYTXt3gznlrWFmdjkPTIUTymkO9WR+s3sfnPxcw56YTObl/RpuXWQjRtpbsKOHu+Wvx+vTD8zrrD+8DMNn3Hf798CK9a6vQ4rQmg6W6XXUNerHqM6eYcexxHP5baT62lm3l4Y838OfLRkbU8y7EIRJMiVYrrXHx/bbiFreQu706by7fw31nHB+TFnshmuLJz2fPDTPxlpZiuFytOpey27GPO5H7+l3OdQtz2ZBfyX+2FYUcygZQ6/Jvn71kN/NX7WX+rJMZ1jP0pOnm/OvngsMtt/W58rdieN3ED5rY7Dl8usHP+ysY1ye9RWVoKU1TJNkseH06t879kdV7yo9MNG+Cz/An0rj59R+Zf9sExvROa4fSCiHawhfrC3jw/QhGtGhOwF93muJNZF2aRcFbBaROSMWUaAIDqtdXU7ullh5X+xPimFPMOPc5D/diHQqoDqVUd+51okyKvP/Lo/v07iQMSkBXDj5dV8D4PmlcMV7WuBORk2BKtNo/f9xHqDacSFvIAbw+ne+2HOTcEce1cWlFV+YtKWH3VVf714PSoxuSWp+WkICyWMiYdSvpt9zC/EonZz+9CLdXjyiLpdur4/bqXPnKMj69ezIDM5OiLkO1M/RcKZ+jCi0+GaU1n/RBKSitcUf93LHy/Pc7Wbs3dCDV5Jwvj48bX1vFiofPCJmWXQjRsf2YVxZZIAVgmPDnS/Pv2+28bphTzBz85CC+Kh9b/2sr9n52MqdnHj7EW+nFmmnFVeiiak0VKSelUPJVCcVfFNPj2h4UvlNI1qVZmNPMR1Kq6zYcHh9Pf7uDy8flSO+UiJh8C4lW+zGvLOTNUDQt5LVuH+v3V0owJdrUvl/cja+yMiiQOnPXTkp9vgbpTRf070+muV5GOZMJU1oa1r59yLjpJhJPOw1l9lehn/yUj24YIQOppoKCOpePa19dydJfnx71mkqhhvgBmOzJ6HVVGLqv2YDKMMKfp625vTqvLd2NI8TNVCQ92l7d4NN1BVx7UviMgkKIjunhjzaEDaQa15nWnt3JulyROORIfZw6KZWk0UlsfWArPW/oScpJR9Kk+5w+qtdXk3VFFvb+dgrmFYCCgx8dpMdVPahcUYkl3ULqlFQ0i0bymGQM3YTPmQVAeZ2bVbvLmCBDiUWEJJgSrVbl8ITcHk0LOUBZ7dFrIRedn3PLFlzbt4ddS+qFZhbmVWYzfee/jbV3w5t3j0/nlUW5IW8MmgsKDKDO5eWrTYVcdELPsM9tGAb5FQ62Hqhma2EVWw5UH16Yt7G47CEos4W67ctJGHJK2HMCKCDZfnRSkP97UyF6iGuItEe7zu3j5R92cc2JvaQFWYhjyIb9lewvd4R8LHSduZLqdXPwllv960UdcKPZNGy9baROTqVgXgGaTWuQzc+SbiF1kj9YMiWYKPxnIYbboOizIpLHJtPrjl5olvoNWApPxckAONw+Zi/ZLcGUiJgEU6LVbGHWkImmhRyQRXtFmyqd+waGJ3TgHwlD1yl7ax49Hnm4wfavNx1sVVBQGwgKDgVTdW4v2wqr2VpYzdYD/sBpS2EVdouJocclM+S4JM4alkWNy8viHcVBC+xqcQmknnI9Zd+8jNJM2PqNQWlmnHnrcO5dT9q0Ww7v69WNiJYlMAyD5btK2XygimqnF7vVRHaqnTOHZmG3tmx9rjlLdh+ecF5fND3axTUuNuZXMTInpdl9hRAdw+zFubhCJKwKX2dOpm7L9xyYv5qeN/YkaWQSyqSo3lBN3bY6si7PovDdQtxFbjS7FhQspU9NR7NoFP6zkCHPDQlZJp8zB8PjD54MYMfB6thfuOi05O5VtFqfjARW5JYG3dRF00IeZ9bISbM3uY9P97E4fzFvbX6LvKo8nF4nNrONfin9mDlsJpN7TsYUYS+Y6Fp0p5Pqr74CX+jEEBHxeKj48EOy/ufXKNOR99mri3NbHRRsP1jNDbNXsr/CwYFKBwMzExnSI5khPZI4Z3gPBvdIIiOxYea//t0TWLW7DEeI7HfJJ12GlpBG5fJ3Kfn8KZTVTlzWQJInXn14HxMGlwzv3uSco0qHh/dX7+PVxbnUOL2H0xWbNUWcWUM34MrxOdwyuR99u4Xv1Qtlb1noDIrR9Ghryn8eCaaEOHb8sD24EQjC15m6q5aSBevIntWLlPGJh7cnj0kmeYw/gU/61KaT6NRPqd44u5+hW3CXnNFgW7gkQkKEIsGUaLXrTurNJz/lB93URdNCDjA9zDAnwzB4c/ObzNkwB5fPRZ33yE1YlbuKoroiNhRvwGa2MWvkLGYMnSHDfkQD3pJS/xpSTbg3fz/mwPvmxPh4/p6dE+JEXtwVlZjT/FnkFJBXUhvyfNEmghhyXDKPTh9Gv24JEaUqH5WTSs9UG7uKQz9/4vBpJA6fFvZ4Mzpn/OMxSmsuJm3GDDSbrcHjmwuquG72ClwePeiz7dUNvIGbjXdW7uW91ft4/KIRXH1i5BmwXCGCQIiuR1vXodYVetimEKJjCveZDVdn+oMsD3E5V2Lo36K06EcY1E+pnnLikcYXQ7fgLj0VX23DBdcrHR7e+3Ef00/o2eLed9F1SDAlWm1kTkrYm7pIWsiVgqmDu9MtMXjNHY/u4aGFD7EsfxkOX+gx1gB13jrqvHU8t/Y5fi7+mSenPIlZk7e38NPralGaRlOpFp5vZs4UgMMHp/7+C4rj0zEg5HIAh0QTFFhNGsN6JnF8VnRZ/f54yUhunrsq6sWy7RYTF4zKYcrtz1L89DOUzTuP7vfeQ8rFF6PMZjYVVHLly8sjap316AYe3eB3n23E5fExc1LfyMpgNYXs0YumR1vTFIk2+ZwLcSwJ19YZrs48FGR5q6bitnqwZiyMOqCqn1JdaYqE4YkozUrFit7UbNxHWqN2J5dX57F/beKxf23i8rE5/PKsQaQlWKO9VNFFyLeQiIl7Th/Iwx9tDDnkqLkWcpvZxB2nDgjabhgGDy9+mKX5S3H6nBGVw+lzsnDfQh5b9hh/mPwH6aESAJgSEjBakQr9ELtmsOz3F2FKOdKyecLjX1MZIglLVIkglCKhBSm+Jw7I4M+XjuQ3H4fPjNWY3aIxrk8aT142ErNJI+f553CsW8fBp56i9PXXibvnAa5fpUc9zMXp0fnTgi0cn5XExAHNT9we0D2RkpqyoO3R9Gj7dIOBmYlB5xBCdEyGYWC3mPD4gnunwtWZ9YMsd8mZ6K4s4jIXoMw1oDwo1bBVy9AtoALnNzSU5q/Lup3XDVOylaLPinG9ko9mjceaaW7QuFvfoTrwnz/u5ZstB3nv9on0zoiPxcsgOhkJpkRMXDI6m2+3FPHdloNRtZLbLSbuOLU/4/oEL7759Z6vWbh/YVAgVb643J/Rp8iNyWYieVwyWVdkYUrwt2Q5fU6+3vM1Z/Q+g2m9wwdxouswde8ek/Momw0tqWHvUe/0eDbkVwbtG21QEO2co0MuHZtDaoKVe+avhcCitqFYTQqU4uLR2fzxkhENhhLaR4+mz1tvUbNwIS/M/RZH5okQome3qTTv4A+onvp6Gx/eNSno2MZum9KfjfmVIcsbSY82QJ+MeAZF2ZsnhGh/1U4PH67Zzz8W54ZtqAlXZ+puByh1OMjyVo/EWz0Czb4Xa/oiTPb9KM2FYZgxvMlkGmdQWDAYl8+FOWU1JlsByuTA0OOIH5SJJfNEDG/ziXcO8fgMiqqcXP7SMr68fwrdk4JH0YiuTYIpERNKKZ6+ajT3//MnfthWHLKHqjG7xcTMiX24/8zjQz4+Z8McHN6GQ/tKFpRQvKCYnFk5DdKg5j2VR79H+qEF1upxeB3M2ThHgikBgGa1knLRdCo+/ChsavRmWa2kXXsNqtHcq1lT+vHwRxtaFRT0Sre3KiiYNjiTNb89iy83HOClhbvYW1p3eN0qXTfQNMX1E/pww8Q+ZKeGTvSilCLh1Kl8tNSLq9oV9Hgkaz8BbMyvJK+kttngcNqQTKxmLWzw11yPdoLVxF2nBfdoCyE6li/WF/Dg+z+jKdVsj3e4OjNp9LlBQVbtxmJK93pIm/abw8fHmTUuOnUAcwpyMXxmPGVTiXRAYFONRbrhX3/qvnd+4p3bT27FqyE6IwmmRMxYzRovXj+Wt5bv4cUfdlHt9ATdKGkK4swmeqba+NXZgzlvZOhFeneW7yS3MrfBNp/DR9EnRWTfmk3SKP+Np7W7lV6/6MX2h7ZTuayStFOP9HBtLdvKnqo99EnuE+MrFcei9BtvpPKTTzFCBFPfDhjY7PFKKdKuuy5o+7kjevDIxxvDHtdeQYHNYuKysTlcNjaH/eV1lNa4cft0UuwW+mYkRLQo8PLc0pCTwyNN8w7+4O2N5Xn8bvpwmmLSFPecPpCn/r09osaX+hQQH2fm3BE9ojpOCNG+3l6xhz98sTmqESvh6kxrzyHNNkwZwIyJvZk4ICOq+aSRLhS+dm85e0vrZLifaECCKRFTSilmTurLDRP7sGxXKa8v3c3uklrKav03dmcMyWLWlH6Mymm6i/3TXZ/i1Rve1NXtqEP36CSPS26w3WQzkTQqiZpNNQ2CKZ/u47Odn3Hv2Htjd4HimBXXvz/20aOp++kncEe3QLSKiyNhyhQsPYJv3uPMJm6e3JfZi3NxRJkIAvyNEOeNCN2o0FI5afHkpEX/Zb+5oAq3L/gaoknz7tEN1u6tiOguQ6YoAAAgAElEQVT5bpncj3X7Kvh288GoXrv4OBPv3DaBOLNk2RKio1q4vTjqQKopzTVMKWDKwG5kJtnITLLx6szx3PHWGlweH74mkgVF1VhkGMxdtptHpw+H6kKozAdPLcQlQVo/sEc+fFB0HhJMiTahlGLywG5MHtgNgBW5pfy/r7fz3LVjIjo+vyYfn9GwtdpX48OcaA5aIwLAnGLGsafhkECv4SW/Nr+FVyA6o5znn2P3pZfiOVgU+XA/iwVLdjbZf/1L2F0eOHMQP+2t4Me8MlzeKIICq4m3Z50cduHr9lbl9OAJcdcRTZp3gOoQCTlCUUrxzNVj+M1H6/l8/YFmhwDFmTXirSbeuf1kBmbKXCkhOirDMHj0041hA6nm5l+2hM1i4hfTjvTyTzm+O1/cN4WXftjJZz8XAIQsT1SNRT6D91bl8WjJQ7B/NZgPzZ8ywOeGwRfApHsge1yLr0MceySYEm2quNrF/JV7eG/1Pg5UOjnh8a9Jspk5c2gWN0/uS5+M0PMqnN7g7H1NLbrnrfRiTgx+O4c6j+i6TMnJ9H3vPfbceBOe/fsxnE2/P5TdTtyAAfSeMxstPnxPj0lTzL5xPL94ey3Ld5U2O2zNYlLYLSbevHUCw3omN7lve7JbTZgUQa240aR5P3SeSJk0xV8uH8X5I4/j5YW7WJtbggF4ODIsMcFqwmLWuHlSP26Y2Id0SVEsRIe2dm8FRSHmXkLTQ+p8VcUtCrLsFhP3njGQcX0aLt7br1sCf73iBB6dPpwX/7ODVxbtxtdoxeBoG4vqPDquvJXEKS/4Gl3j5k9g+1eQNQyuex/im15MWHQOEkyJNpFf4eDRTzeyeEcJCg631lc6PFQ6PLy9Yg/vrNrLiOwUHr9oOCOyUxocnxoX3FV+eNG9NVWknHRkf5/TR/X6arKuyAo6Js0WnCVQdG3mjAz6ffA+FR9/TNns2XjLKzDqjiwEjVIomw1zZibdbptF8kUXoVmbv3m3WUzMnjmeT9bl89LCXewvc+Dy+qj/vZ1gNWEAV5/Yi9tP7c9xKaGTQRwt2al2bJbg9Z+iSfMO/gyH0VBKcdrgTE4bnMmyq2ay6qJbKU5Ix+nxkZEYx/g+aZw+JDOixYyFEEffq4t24QzRqNTUkDpvWQFl373abJKbxuwW/7zTu6aGn3uaGGfmjKFZvLViL9XOhqMSom0sMuHDSRxxhBjdYOjgqYMD6+HlU+D2HyAxs9lzimObBFMi5jYVVHLdqyupcXrxhVnV1KMboBus2VPOlS8v58UZY5k2+EiFc2KPE/lu73fUeY/c5JriTWRekknBvAI0m9Ygm58l3ULqpIYBWLw5nnFZ0tUugmk2G+nXXkvaNdfgWLuW6u+/x1tcAkphyexO0llnYRs5Mup1yjRNHU4CsTG/kg/X7qegwsHSnSWc2DedC0b15MJRx3WYYX2NnTUsi//5cEPQ9mjSvMdbTdwwsWVJX7wlJWTkbeP+ayajzPL1JMSxakVuWchFzcMNqYtm3hKAWVOYNMWI7BTuP+N4Th3U/PIXSTYLeohCRdtY5MVEInVN7+RzQ00RzL0Q7lgEFluz5xXHLvm2EjG1t7SOa/+xgipn5OmnHR4fd81bw9uzJhzuoj83ZTB/DjFEr/v53TElmCh8txB3kRvNrpE8Npled/RCszRstVZKcXafs1t3QaJTU0oRP24c8eNiH3SPyE5hkLeCyn/9h282rGPYgTh67D8Od9E44s44AxVBb1d7i7eauWxsNu/+uA9vo6EwkaZ5T7aamNi/+UV7AXYW1bCzqJpqp5d4q5n0DavImTBBAikhjnHh5j+GG1IXzbwlgIGZCbw8Y3xU6/P1Sosn1Nrt0TQWAWRTwoBnq6jzwO77E0mw+hvdZq91M2+9hx9uCpRJ90DlPtj4AYyZEXE5xbFHvrFETN09fy01IVIrQ9MTTp0enVlvrGbVTRlYVjyLPW8JF/c/gQ/q8vAaDc+XPjWd9KlNj0O2aBYuP/5yrKaOd8MqOjfDMKj5zw+Uvvoqzs2bMXw+xgSSXZQDFR9/jPrfR0m77jrSb5iBOUYLCsfKLaf048O1+4OCKWg+m5YNnct/+oyKDypIvfzyoDW5ANxena83F/LyD7vYWVyDWdPQDQNNKbxORbe0s7hr5R4uGZ1NQpx8RQlxLNI0IEQ8FW5IXbTzllLs1qgXOrdbTVwyJpv3Vu+lcdLSSBuL7Di5w/w5/4t/bumzK908PKWJRXw9dbDkGQmmOjkZgC5iZlthNTuKqglxD0bVqo8p++5VUk6+ipx75pF91+skjT0fx46Vh/dxO2r5bt5fodfJcP96bj73BeLMLVtpPM4Ux8xhM1t6KUK0iOH1UvA/vyH/wQdx/PQThssVlDXQqKtDr6mhbO5cdl04HcemTUeptKEN6J7IoxcOx26J7uvBZtaYPLQHd//xF1S8/wF7bpiJa+fOBvvkldRy6t/+w68/WM/GgiqcHp0al5c6t48alxenMrHfrfHEF1s4+c/fsTqvLJaXJoRoJ8k2S8jt9YfU1Vc/yIpErbtli6/fekpfLCEaecDfWHTcjc/Q+5cf0uueeWRe+Ri2nKEN9jFQXGpaAsBDk6w8tcxFhbOJvOsAVfmQv6ZF5RXHBgmmRMzMWZKLJ0Ra6ENjodPPuov4wZPQrDaUyUz8wAkNus9rDSsvJd8DE38BcYn0TOzJi2e8iM0U3Vhju9nOK2e9QlZCcEIKIdqKYRjkP/QQ1f/+N4bD0fz+bjd6ZSV7bpiJc9v2dihh5K6b0JuHzhmCLcKAym4xMXlgN164bizxQ4fS9535JJ9/HntumEnRs8+iu1zsLKph+t+XUFTlDEpw0Vid20e108uMOStZurMkFpckhGhHl43NxhpiGZP6Q+rqti9H9zgxfF50twOUCgqywkkKE6w1Z2BmEhP6ZxAXwSLmjdlxcYPpGxKVfwrC+J4mTutr5qllobMWHuZ1Q+7ClhRXHCMkmBIx89m6gpAL40UzFnprYQ1FVUfmSo3NGstr57xGkjUJu7npzGd2s50Uawpzz53LqO6joi6/EK1R9sab1PywMCjd+pm7djJm+zbG1fsp8h5Zh8moq2PvTTehN5Omvb3dcko/5tx4ImN6pRJn1jBrwTdG8VYTPZJt/Prcwbw6c/zhxBrKZCL9+uvp98nHuHflsv7SK7j6hUXUuLwhe67DcXp0bntzNbnFNbG6LCFEO7hhYl8Ik8An+aTLSDv9ViqXv8v+569n/0s3UbvhW5JGnxsUZDl2rab8P681ON5iUoxoxZISL14/lpw0O9YosoPacDFB28z/mN9psP330+J4fpWb4tom1hc0vFArjUKdmQxIFzHh9PhCLvYJ0Y2Ftpg0iqpdZCYf6Y0a2X0k31zxDV/kfsFrG1+j3FkOgG7oaErDwKCbvRs3D7+ZC/pfQLwlurTMQrSW4fNR+sorYXukXsjOYVJC+PH9hstF1YKvSL30krYqYoscWng7t7iGucvyWLe3gmqXF5tFo3d6PDdO7MvEARlhsx5asrLIee5ZPpzzb2q2OjBMwa3JzS3e6fLoPP/9Tp6+enSbXqsQInaSbWa6J8aRXxG6Tgw3/9Lac0iz85Y0pbjh5L4tLltCnJmP757MTa+tYmthdZOLhSt07Lg5S1vNU5ZXMKmG9zkjMk1cOMjMk0vcDO3eRHAWou7r0rwu8HnAmhA26D6WSDAlYsLl0dE08EUx4TQUpcDdeGYokGBJ4KrBV3HloCvZULKBPVV7qPPUkWBNoG9yX4ZnDI86jbUQsVKzaBG6293i4/W6OkpffbXDBVOH9O+eyO8vHtGiY3XdYN4BE64QNxNNLd55KJjyGQZfbjjAYxcNJ8UuNyRCdHSFlU6u/sfyBqNMItVckhuAE3ql0jujdY2myTYL7985iT99uYXXl+4O2WNeu/kHPD++y8HSfN6LM9jeQ+ORKXGc0rvhrfPjp9kY+0oND04MM8fbZJW1pgAK1sHyv8OWf/lTxyvNvy5Xz7FwygMw6DwwHZthybFZatHhJNrMIbN/QXRrOOi6EXbiKvhTWY/qPkqG8YkOpez1uRi1ta06h+fAAZxbtmAbOrT5nY8hi3eWUBdisng068poSvHB6n3cOqV/m5dXCNFylXUeLn9pKYVVLnzRjOmNkM2i8cAZx8fkXCZNsWRHSdikWZUrP+CMc8/hrUH/Id3s5KudXj7d6g0Kpgama1w93MJzq9yMzAzRO6UUDLkgJmU+Jh3cBB/cAhV7/PPHjECruxFoOM9fDR/f6e+9O/dJOOGao1fWFpJgSsSESVP075bAruLgG8qo1nBQil7pTc+NEqKjceXmNvn4vfn7MQd6Tk+Mj+fv2TnBO2ka7ry8ThdMrckrC5lwIpq5lA6PjyU7SySYEqKD+9UHP1NcHTqQam5Ib3PsFhMPnHk8kwZ2i0lZtxyoYm9Z8OK79Rt6dh0/Eat1MRalmD7YwvTBoRt7H50ax1vrPSEfI3s8pPWNSZmPOXlL4e0rwdNMY6M7MC/28wegbDdM+03bly2GJJgSMXPn1AH87rNNIccfR7KGg8WkuPakXsSZI1tnQoiOonHSicaeb2bOFAC6jq+6Ooal6hhKakMPf4x2XZmKujA3KkKIDuFglZOF24txh5g/HcmQ3nA0BVazxq/OGcStp8SuQWVTQVXI6Tr1G3p8mHjDew53mv+FTR2pg/IeSGpwTK8UDedvQyTFsMTD5AdiVuZjStEWmB9BIFWfxwHLnoOE7nDSrLYrW4xJMCViZvoJPfndZ+HXzGluLLSmFDdO7NsGJROibSmbDWpal3FOaRqmxMQYlajjsFtCB0vRzKUE/82UEKLjmrdiD6FmLkczpLc+m0XDMOCsYVncceoARuakxLS81U5PyB60xg09r/gu5ELTCvpwEIuKbB0sAMx2GHgmHH9WrIp8bPn4TnA37Pnr+0w1dR7YfX8iCVb/u2X2Wjfz1nv44aZAg6OnDr5+BEZcBvHp7V3qFpFvJxEzNouJe08fGPbmqeljNc4bcRy90iUTnzj2WHv3avU5DF3H0qt3DErTsRyXYgu53ky4xTtDUUDPVBn+K0RH9vbKvbhCrDUZzZBeTcH5I3tw06S+/Oa8oaz4zRn8/bqxMQ+kwN/Qo4Xommq8gLADG9e6f8sBIwOXEWEfhCUeek+Ay+d0imx1USvaCsXbgBDBqgHPrmwmYZPPA4v+1jZlawMSTImYunPqAC4YdVxUAZXNrDG8Zwp/vUKSSohjU/pNN6GaG8bXDHO3bthGRDZ34FhywajjQmbaDLd4Z6h1ZexWE9ec2PqAVQjRNnTdoLyu9UN6E+LMzJrSn8cuGs6Nk/qSlmCNdVEPOy7VTqilpkI19BSTygXuP/G9PganYcFhhEmUZUnw90iNvwVmfATmtit/R+Xy+nAseg7DF3po9kOTrDy1zEWFs4kEJYYPVrwIb1wMro4//F2G+YmYUkrxtytGkZ5g5c1lebh9ethFOjUFcWYTpw7qxnPXjpFhPOKYlXT66SiTKUQbHHw7YGCzxyu7nYxZszplev/jUuxM6JfOoh3Bi1ZGMpcSIC3eykn9jo3hHkJ0RW6fjiJUP0T0Q3qdTaz7FEuTB2Rg1jSg4fOFS5pVlLeZa/amMGjas1xn+o4bzN+QTjWG8t/PqNQ+MOl+GHUVxHW+IdtN0XWDpbtKeGVhLstzS1lr+Ri7Cs7iCjC+p4nT+pp5apmLP55uC7nPYXuWwj9Og1nfgT019gWPEQmmRMwppXj4/KFcNb4Xt8z9kQOVDuLMJnTDX81qSuH26Zw1NItZU/oxuldqp7yJFF2HMptJv+WWJhfubfJ4i4WU6Re2Qck6hjunDuDHvHIcnuCbpObmUtotGref2l/qCCE6sDizhgoTTkWzPAoGJDWxPEosmU0aMyf14ZWFuUHDE5tq6CkmlWd9l/Os73KUVkv6wD+iWaz8fvIfOK//ee1S9o5kdV4Z98z/iWqnJ5C51SCRppMy/X5aHJNfq+X+Cc303OkeqNgL8y6DW77usOtQdcxSiU4hI8FKeZ2bRQ9NY0thFWW1HnTdICXewkl909u0+16I9tbt9ttw/LSWupUrMZyuiI9Tdhu958xGi++88wUnDezG5eOy+XDNfhye4DkV4cSZNcb0TuP6CZ1vLpkQnYlSij4Z8eSWtG55FLdPp0+39qsLZ0zowz8WhV7aotkFhJUbS9oK3CYDdBePLnuUlLgUJmVPCn9MJ/Pt5oPc885anI3qdRWyj/KIEZkmLhxkZsBzNRjA6B5HeiyDElL43FC8FbZ9CcMuivUlxIQEU6LNfLIun9OHZHJcqp3jZPK46OSUppHz/PPk//JBapcubb6HymxGi4uj1z9ewT5yZPsU8ih6/KIR1Lp8fLWxMGQPVWM2i8bI7BRm3zgec6iJDUKIDuWOqf15/F+bW7w8iqbg3OE9SG6nnimAzGQb/3flCTz4/s9BAUFjNRu+pWrVx3grClFxdhKGjiD75orDjzt9Th5c+CALr16I1dT5G4vX7i0PGUiBwoGVBJpuVHz8NBvvbKjBZob8qmYa2dy1sPQZCaZE12IYBu/+uI9HLxx2tIsiRLvRrFZynnuWqi++pPTVf+Deuw/D7Qb9yBeFio8HwyDlkovpNmsWluzso1ji9mPSFP/vqhM4ISeF57/fidPjC7mYb4LV30I5c2Iffnn2YCwSSAkREcMwWLW7jNeW7mbHwRrq3D7sVhPHZyVy6+R+nNQvvU2Hy150QjaPfbY57OPN9fTEmU3cdmr7Lsx9oNLByt1lIVOk11e16iMqV35Etwv+C1vfoejuLVR8/3/secpDv0f6oQXmfOuGzrd7vuX8/ue3R/GPqoeaCED/uL4b36zcxtYSnaQ4xegeGo9MiWuwz8B0jQSLP7vfvkqDCqdBqq2J9+fBTVC8HboPiuVlxIQEU6JNrN9fSZ3bx8n9M452UYRoV0rTSJl+ISnTL8S5ZQsVH32MpyAfw+nClJpKwsSTST7//E49rC8cpRQ3Te7HDRP7snB7Ef9YlMvOohocbh9Ws0aFw8NvLxzGZWOzZfFuISJkGAYfrNnPM9/uoLzOjcPtazDIKq+kliU7SkiLt3LfGQO5anyvNgmq7FYTt0zuy2tL8yLqfa7PalIM75nMiOzYp0APZ92+CmbOWUmd24c3EEwpSxGGNw0MxaFbZN1VR8WS+WSc/wvsxw/Bmr4Ia8YikoZms/2h7VQuqyTt1DQA6rx1zNk4p9MHUz/vq6CgIvS8qKpVH/P0yr28dkEiFw80sJrgq51ePt0anJAixQZFtZBqI7KEFLsXSjAluo53V+/jqvE5aJpMGhddl23oUHo8MvRoF6PDMWmK04dkcfqQrAbbr3hpGb3S4iWQEiJCPt3gNx+t518/HwgbwBhAndtHndvBY59t5se8cv5y+ShMbfD9/ODZg9l6sJplO0sinh9p0RSZyTbm3HhizMsTzpYDVVz36opGQxINEvo/h+FNxl0+EV9dPwyfDfe+dRg+NxnT9mBO+Qal/IGXyWYiaVQSNZtqDgdTAHlVeRTWFtIjoUe7XU9MGIY/WNnyOdQUAgoSs2DodOh3aoP1smYvzsXlDX6/HVmg+X6mDHmfBM2fxXX6YAvTB1v429kNgyWzpvjyejs9ElXzCSm8LnCUx+RSY02CKdE65Xtgx9dH3uD2NJx9pvHF+gP8+4FTj27ZhBDHlHF90li7t5xTju92tIsixDHhd59ubDKQaszh8fHF+gLizBpPXBr7uZqapnh5xjgeev9nvtxQiNvXdEAVbzXROz2e+bedTEp8+8yVcnl9zJi9Mnhul/JQsaKU0q+34j6wBM2mYettw97PjjlRw5K6Lehc5hQzjj0N58daNAtlzrJjJ5hy18GaubDseXBV+ucn1ffzfLClwMT7YNyNYI3nu61FIZe9ObJA8yQe8SbziuVp7KqZBXo5kpDiySVuhnYPM7RbKVAdc9i3BFMieroOu77zTwbcvxpQ4A1095rjsOgG71uPp0fhbyHpHNA65ptfCNGxjOmdxj9/3Hu0iyHEMeE/W4v4cG1+yECqdvMPVP34CZ7S/WhWO5bM/qRMugpbznAcHp2PfvIniDpjaFaIM0emqNrJ/BV7WbCxkCqnB6X8a8JdOiab/71gKCt3l9EtMY5tB6vRFLg8OgZgNWloGvTLSODO0wZw3ojj2nWdya82FuIM8ZpV/fgpVasO0PPG40gamYQyKao3VFOxpAJvjRfDZ6BMDXvzvJVezIkNb6UVCo8eesHaDqf6ILwx3Z9+3BsmaZK71v/z3eOw9nX0mf/CEWYtsPoLNC/ST+Ax70weM78ZUUD1+Gk2xr5Sw4MT40LvYLZBfMecOiLBlIiOxwHvzoC9y4NbLwC8TkzAINdG+PBWyB4H1/4TrAntXlQhxLFlbJ9Ufv3henTdkCHCQjTjhf/sDBlIVa36mMqVH5Bx9t3Y+o1Fmcw4dq/BsWMltpzhADjcPl78YVeLgqldxTX86YstLN5ZgoIGazQVVDjJLa7lyQVbSY238OrMcYDimy0HKatx4fUZpCZYmTQgg6HHJbf00lvlpR92BSW/0V21VCx6l+zbepIy/ki5ksckkzA4geoHqqlaU0XKSUfmdPmcPqrXV5N1RcPXUDd0kq1H59qi4iiH2WdA9QHQQy+w24DXAaW7YPZZwJ9D7tJ4geZ3fadTbiTxvOV5rHhpaqrewHSNq4dbeG6Vm5GZIYJr3QeDzo3s2tqZdBmIyHnd/haMvCWhA6nG3LWwbxW8fh54ml7ATQghMpNsJNnM5JbUHO2iCNGh7SmtZUN+ZdD2Q3NW0s+6i/jBk9CsNpTJTPzACQ3WcwLYmF9JXoh1oZryY14ZF/19Cd9vK8Lt1YMWuwX/UEKvblBW6+b855ZQ7fRww8l9uP/MQTx4zmBuPaXfUQukdhbVkFcafM2HhqclDj8+6DFTvInMSzIpmFdA9fpqDK+Bu9jNvhf3YUm3kDopteH+ykROUk6bXUPMvHcj1BxsEEj1faaazL9VU+s+MoZv9lo3p80NvGa6F63mADYVuuet/gLNh3ytn8hJrhdwR9B/8+jUuAbPfYSC/lMh+bjIrq2dSTAlIvev+6Fw45EhfZHwOqF4G3z6i7YrlxCi0xjXJ421eyqa31GILuyDNftDpvM+MmdlYrPn8OkG76/ZF/FzbiqoZOacVdS6fBhNZxIHQDegvNbNFS8vJ7+imXX32snestqQyy0cGp7mqTgdwxc8zKz7+d3JujyLwncL2XzXZnb9YReWdAv9/rsfmuXI+ayalWuGXINFa7+1slqkZAfsW+lfELcRnwHPrmxiWJ7PzRRtfciFeesv0Fy3fTm6x4nh81K4axuXfN0dt9EwuVDe/2fvvMOjuK4+/M7MNu2qV5pE75jeu4HYuHfjAu69O4mTL44TJ07iOHbsuOACtrHjTtwbYBtM772ZjgQC9V62z8z3xyIhsSPtLkiIct/n0WM8O7N7d3d25p57fud3HolhUqejQVZ6nIz7idijDXtrMEfByAcje48nESHzE4RHZT5s+xzU+k3YOrxYidMHmQ9H47AE8rdB3av97oA7TPlhiDs7euoIBILjY2BGwITi2iHpLT0UgeCU5WCJs9bOuy51a1ZC4dd0DhY7w3o9TdO57d21Eddn6UCV2899H6zn6wdGh/VazUm1R0UziARr5Gm+sh7YGvCNSByXSOK4xJCvcW33a090mM3PqtcDsjkDHhtp4dnlHu4bYmmw79Pdyjcs0/rg1IMDz4YaNO8ecQOFvEuaXoJJCiMar8EUBT0ugg5jwj/mJCOCKUF4rH+nwYdqVjEeH9NA0SAAOqx9Eyb9palHJhAIziAGtoth57IvYd36QI2mNQbSekObATQquBcIziKcHuMal2NrVkI+TwNGAseyZE8hVe7g1wynPkvVdXblVbInv5KuaTFhvV5zEW01IRtcR47K09ZiSZyMNW0OkhyZiYRNsXFp50tPuoufruus2l/C+6sOkFVUjdunEmMz0T89nptHdqBTSnT9A1QfbP4YGjDJGNxGYXwHU6N9nwZKe0imnIOkGj5u1KBZBaZ4M/jM8iSJeiUWKYxzz2yH9qPg8tdP6eu/CKYEodE0WD0jKCtVQzirGKheWPs2nPsEKOK0EwgEx1CZD+tm0WfNDP7odKPPA0lXQT5yvYhtA6MegT5XgeXsa3gsENTFZ+RLTf2aFUeP0JmgBEcjfX3qMGPxfmPThmUfknThI9i7j6zdbu8yDHuXYceMV+PtZZk8c1XfsF6vueiU4sBrUOdVV54myfcTO3Qg1qT1VO8spXpHNa2mNB4g2RQbg9IG8fiwx5tr6EFoms5Hqw/y2uK9lDl9Qc2at+dU8MnabHq1juWx87szssuRlhPOYtAbt6x/6lxro32fJAmetc3iFt/juA0+z4Y4pKdwpfYcHye+SfvqLYFxGEgNMdsDfa+G3AmTnoQwFgZaElEzJQiNu6xRw4m6qxiNonrBWdTEgxMIBKc9O76Dl/vBsv8guUqJllxIflfgmuFzBv6K98K838NL/aBwd0uPWCBoMX76JZ/le43vpQ3VrLj2raN04ax6+9otCoPaJxg+T13KnF7WHygJ2h5ZfRZ8ufEwejjFVs1I+yQH3VsZZ8dih15JwoTbKV85m/1//oBdv95N8fxSYgY0/BmZJBNWxcoFHS9g+sTpKCdp0u/1a9z74Xr+MWcHOWVunMcEUhCQcXr8Ghuzy7jtv2uZtSwz8ICn6ugiVQPU7fvUEMPN+3h2cho2c/ihRJRZYXS/7mQ88iPcvwaG3gXW2ED/KMUMSBDfHs77Ozy2B8576pQPpEBkpgTh4C4P/PAayExB6FUMIPAc7nKIOU0a2QkEguZny6fw7YMBSV8ovNWBBpNvTYA7FkBK9+Yfn0BwCrEpu4wHP96AT204KGmoZiV2xJR6++k6XLUfhmkAACAASURBVNa/TcjXLKj0YDEpeNX6Mr9I6rMgYHhR7VWJtrbs1PPecZ357WebqfYEy8yOladJpnJMCavQ1ZVQJ1yJtplRNT+XdL6Eqb2m0imu08kYOhDISD3w0QaW7CnE7QsvK+T2aTz3w05sZoUbejrCskIP2fdJU7m0X1sS0tJ58KON+DTN8DOFQBCl6Tr3je/MAxO6IEkSJLSH8/8R+PN7A9brlujTIng6FhFMCUJjtoPeuLY1rO7VuhZ4LoFAIAA4vB6+ebDhZpGG6IGV1XcvggfXgy0u9CECwRnCn77a2uAE2sgIIvXaO2trl+pikiWuHNgWuyX0NNDtUzES8Edan6XIEm5fywdTk3qlEWMz4/KqNKCWrEX3x+EtPB9v4SSUqGysVhfn9U7lxiE96Z3UG3sLzGk+XnuQpXuKDM+DUM2an/puO0MzhtNFCp1NCtn3SZLBnsiYrmbWPTGJ+TsKeGPxPrbnlGNWZCQJ/KpOosPCXWM6cdXgdsTaGnA5NFkCf6cpIpgShMaeCAYWmMcSehXDD47kph2bQCA4fVnwt6BAKiyHUPRAlmrTxzD8npM8aIGgZdiTX8mefOMebOEYQdTFYpK5e2znsF431mZGNZDnRVqf5VO1hifTJxFzVS4fDzvApT8nU+WX0Q1DxWNRMPs7Mbx9Mi9cNLjFmorrus5rC/cdd7Nmv6rz5He7uNUygXG+uZhpfKH8z+OsvL/FwKhCNkP/G49I88CkyEzu04rJfVpR7vRR4vSiahqxUWZSoq2BTNQZjKiZEoRGMUOvK0BqfOWp7ipGEJIM3S8M9AoQCASC8sNwYIXhQyH7nECgjmrFS4TV8EYgOAN4e1kmPi04GxFJo14Am0nizZsGk5EUXlalTXyUYbgRSX0WQOu4KCymFpp26jrsWwjvXQ4vD6Djisf5Qv4/EqnASohrDYH6sok903hj6qAWC6QA1mSWUOoMHm+454Bf01m5vxjfkLsxmYID27D7PskKDL/XcIxxdjMdkx10SY0hNcZ2xgdSIDJTgnAZcT/s+DpkXUODqxgmG4x8qJkGJxAITjvWvt3gQ2E5hEKgBjNrKXQc2wwDFAhaHo9fpaTai8ur8tMv+agGCr9IjCAAHv1Vd0Z1CV8lYjHJXDc0g/dWZgXVaoVbnxVlVrhn3MmrK6qH6oMv74FdcwKLMEfoKh9mvvUx3vNP4h31AnyYqObogq9JllBkiZ6tY7lnXGfO753WpIFBtcfPt5tz+CW3gjKnj2ibiS4pDi4f0I7EBlwWP1h1AJeBnX0k50CUWcEV2wkpfSgcXGXsptcYigXSh0NSeJnNswERTAnCo3VfSO4Gedvq1U9lPVLfFadmFaMekgzxGdB24MkYqUAgOB04sKxBU5tw+pwA4PdAzkYRTAnOOLYdLuftZZnM2ZqLLIEsSUHW5DVEYgRhtyi0imvkN9UAt4zswAerDmAk+TfqKXQsOjpXDGwX8eueMJoKH18PWcsMazMTpCoeNn/F/aZvWKj1Z5XWkyISsFrMtBl6BZcO7hzcp+kE2V9Yxcwl+/lq02FkSarX68tmlvnXvF1M7JHKPeM707ddfL1js4qdhkUXkZwDLp9KbrkbrvkvzBgDlXlhGVIAASOx6DS49t3w9j9LEMGUIHyu+wjeGA2uMsKpoarFGgM3zD6lG64JBIKTjLui0YfDcgjV/OAsbeKBCQQtR265izv+u479hdV4/ZphrdKxRGIEUROYRUp6op1JvdJY8Et+RH2FIJAJuXlkh5YxnljwFBxYHtLkxiRp/ErZwK+UDYENsgVyV0LKD006nB+25/HIJ5vwqqphlrHGVGLe9jx+3lXA7yf34NZRHWsfN6qVgsjOAU0PZMWwJ8IdP8N/L4ay7NBGQKaowML4Ld9BVGhL/bMJUTMlCJ+4dnDbjwETCTmMIlLZBFGJcOs8SOjQ7MMTCASnEaYGjGqOEE6fE5DA4mjkcYHg9GF/YRUXvrSUnXmVuHxqWIEU1DeCCI3UoIQsFM9f049uaTFYI6h7ijLLjO6SzO/Ob4E2Bt5qWDOjnrQPAiY3qc9VUu09+vm+tcHL+Hfr9NNUvZC7JZD5biJ+3J7Hw59sDHy3IeJRTQ8EVs/O28WbS/bXbo+1GQekkZwDJlkiLurIHC4mDe5aBBOeCLStsRhk4SzRENMaJv4J7loI0akhX+NsQwRTgshI6Qb3roABU/HKNryygaGE2R5Yweh3fWDftF4nf5wCgeDUJr59yF3+Ot7Gmxu8HK5oYFJptkNs6yYemEBw8imu8jBlxirKXD7UUH7dxxCJEYTXr9KztXHT2lDYzAqz7x7BsE5JmJXGs1uyFMhIXdyvDa9PHdgypg1bPqWhaW5YJjd+N6x8tUmGcqC4moc/2dSgnXnufx/h4AtXc2j6NPL/9yTuQ9uBQCbq+Z92sWp/MQADMhIwGXyWkZwDVpNMj9Z1yjEsDhj5ADy6A6Z8AINvhx4XB/4G3w7XfQi/3hGonReLV4YImZ8gcqJT8V7wAudunMjXY3NIzvwanIEfOlGJ0Ocq6DclIO8TCAQCIwbfCvsWgNfY6hnC6HOiq9DzkmYcpEBwcnh5wR7KnF5Dc8rGegfVEK4RhKbDqGcW8rvJ3evJx4woqHCzI6+SSrcPm0mhdbyNXq1jefWGAYx6ZiF920Wz7XA5iizhV3UkKZD18Gs6v+qVxh1jOtE/Pb7R12hWVrwEvmrDh8IyudE12PFtwOjmBPvZvbU0E59BOiocO3O3T+OFH3dzSf82LN9biL+BYDtsMxCLwhgjAxJZhs7nBv4EESGCKcFxsWBHPu3SkkkefwmMv7ulhyMQCE43Oo4LLLg0EkxBIw6hkgK9LhdNewWnPW6fyqfrD+EzmCSHnGzLbiRTJZLkI6Z/P6J7jwUarpnxazp+TeXZebvILnHyp4t71XOo0/WAdfbMxftZub8Yi0lG13UkSULVdJIcFrqmxjC2axLTbxxEXrmb5XuLKHf5kCVIcFgY2zWFhOOUEjYZug6lWQ0+HLbJjWKBkkxo0/+4h+Lyqny2/lBQEFRjZ5504SPYu4+s3W7vMgx7l2H19l2TVYLDpvB/F/Rk+s97WXfAuFY0lBmIzSxzx+iOLWrvfiYiginBcTF7XTZThqS39DAEAsHpiiwH2iX8/FS9lgthOYRCYJIz8oHmHqVA0Ox8uznHcHvDk+2hxPRNxJL0Nop9P+gKIIGkgS7jLR2Gr3Qkur/hrJDLp/LxmmzSYm3cPS5gcV1S7WXa26vJLKqudZjzHGM24fS6yC51EWVWWLaniNFdk7lqUAu49IXC54IQzXjDMrkB8ATMcnRdZ3VmCd9vySW/wo2m6yRHWzmvdxrjuqWiNBCgzNmaa+i/FYmduVmR6NkqlvHdU7GZFW55Z42hZDAUVpPCdUMzIj5O0DgimBJETG65i40Hy3j9xkEtPRSBQHA6M/Qu2D0PslcH6hPCxWyHsY9Bq3Oab2wCwUnim8059eyxazCabMvWPKLavYukOEH2Hpmk1z/WkrgcS+IK/JW9KZ6fSsXabw0lgi6fygs/7ebawemous7FryyjuNJjmCE7FpdP5Y731vLSdQM4v3erE/0Imh6TLSDTa4S6Jjc9Uxq2EPBIVj5ensnMpfspc/pwedV6fsbfbsnBalK4bVQHbhrZgVhbfYOu3fmVht9vJHbmPlVne04gqBveKYnHL+jJ03N3RBRQ2S0KH94xjHh7C2cNz0CEAYUgYj5ff4iL+rYmyhL6AiAQCAQNopjg+k+g3RAwG5jZGGGOguH3w+hHm3dsAsFJorjK2Ajh2Mm2bDuIvf1rSOYyJMXbYLcRSVaRZD9lKxZRuvBV4kZcSbsHPqDtve8QM/BCXHtW1+7rUzWm/7yHG95cRVGYgVQNbp/Gw59sZOuh8rCPOWnIcsD6OwShTG50v4c7vsrlmXk7ySlz4zwmkAKo9gQaK7/y814ueHEph0rruweWuQxkytS3Mw+HCvfR57lpZAeeurQPNrNsaEhRF5tZJi7KzP/uHkGftkIW3RyIYEoQEZqm8791h5gyWEj8BAJBE2Cxw7SvYPRvAr1LjKx5JTmQjUrqAlfMgIlPiL51gjOeer2DzMXYM95uNIiqi+pUKfwql9bT2pA0cS+yxYqkmLB3GUbCubfV7qfp8PbyLA4UOw2NDRpzmoNAQPXPuTua5P02OYNuA6XxFgx1TW6M2Kp3ZlWhJawMkMevkVfu5rJXl1NQeTTT3hR25gAxx/TpunZIOnMfHssNwzKIMis4LArKkVm9WZFwWBRSYqz8+lfdWPzYeBFINSNC5ieIiFWZxUSZFfq2Ez9KgUDQRCgmGPdYINu0ex6sfgPKDqL7XByokmnXczimUQ9BOyEtFpx5JEUby67qTraTJmaCHDzhL11aStEPRXgLvCg2hdhBsaRdnYZznxPNpxE32AFSFop9H6qzS4NjOLY2CsJzmgNYf6CUw2Uu2saHmV0+WQy5HVa+HHK3hkxuXFIUr/ouwqfWDzIbc1dUdZ1yp4/b3lnLdw+NAaBzSjQ2sxwUkNW1M5dkBVvHAUiyCXfWJtwHt9QLek2yRNe0YIfkjskOnrqsD3+4oCc//pJ3JHvmJy7KTPdWMYzqnCzMJk4CZ1wwtatkF+/98h5LDi2h2leNhITD7GBS+0lM6zWNjnGNW4EKAlS4fczZksvBEicVbh8Jdgs9WsXy4/Y8rh2SXs/9RyAQCJoExQQ9Lw78ESgfv+3fi5gxfpDhREIgOBO4tF8b1h8oDaqrOTrZfp2o9ARizrGDIlH1SxXVO6oxxZoonFtIuzvaEd0rGl+pj5z3c8j6dxaJExIxRZuQFAld92JJWkzRukMhLdZriMRpTtd13luRxR8u7Nk8H9DxEtsaOk0ItGBQjwai4Zjc6EClZuUnX79628MJMP2azr7CapbtKSSz2MkX6w81mNkK185ckSWmDm+4N1+UReGy/m1DfiSC5uGMCaY2F27mryv+SnZlNj7Nh6ofvSh5PV6+3PMl3+z7hm4J3fjryL/SNaFrC4721GVnXgUzl+zn+y25yJKEy3f0c3RYFKq9KvEOM7nlLlrHnWKrUAKB4IwjPdFOdqlTBFOCM5ZL+rXhyW+2Gz4WO/RKLGm5FH67iEMz3Sg2BVsHG8nnJXPwlYO0vb0tMX0Dvw1LioX0+9LZ/dhu3Nlu/FV+dFVHUiRKl66mZMECks57oNEsUw2ROM15VZ3FuwtPvWAKcF/8Kvobo7G68pH18GqTAHxyFLd7H0erUw0TSYDp8qnc8s5aJvdpxf3ndmHutly+2pRj2JA5lJ05QO82sXRMFg1zT1XOiGDq5wM/8/ulv8etNuwG5df9+FU/W4u2cuOcG5k+YTpDWw9t9rG5fSplTh9ev0ZslIm4KPMpm9V5f2UW/5izA5+qG/7gq4+smn28OpvP1h/irZuGMKJz0kkepUAgOJtIT4ziYLEz9I4CwWmKzaxwzaB2fLT6oKEBROJ4PymTO9fbVrmlEs2nETuofkZFsSnE9I3BV+JDMklUbKggunc0hV/lkXLl1dgyGg8CaojEaQ6g0u0Pa7+TRVZRNbOWZ/LZ+kO0kv7MO/yFVhRjlYzNIGqRZHRLNLd5fsdWf33L90gCTABZlnjmqr5EW020T7Lz/dZcw7lVKKLMCg9NFAmAU5nT3oBibd7akIHUsbj8Lh74+QF2luxsljHpus7arBLufG8dfZ78gfH/Xsjkl5Yw5B/zGfr0AmYu2UeZ07jYsaWYtSyTp+fsxO3TQv7YvapGtUfltnfXsGp/8UkaoUAgOBvJSLSTXeoKvaNAcBrz0MSuxNsthuYSkhJ8/qtVaq2M71hMcSZUl0raFWnkvJ9D0bwiNJ+Go1c7XPvWUbpwVsjxROo0ZzYYR0ug6zov/rSb819cwserD+L0quz3xHKR5++8rU6mQo+iSjdo0mu2B+zUz7kGz+2LWenpHLRLpAGmWZbIKw/MTbumxfCPy8/BZo5s2h1lDliuj++eGtFxgpPLSctMaZpOhdtHlceP3RLI0DTU4Cxc/JqfXy/6tWEg1VBRpuII/AhcfhePLnyUOVfOadJM0e78Su58bx2FlR5cPhVdp55DTmGlh//8tJvnf9zNtOHtefzCni1eHLhyXzHP/rDTUNPbWKGly6dx+3/X8vNvxpMW20gHcYFAIDhO0hPsrM0qbelhCATNSlK0ldl3D+eq11dQ4fbXX9TUg+cISrRST8ZXF3+5H1O0ieQLkjHFmcj/Ih80OPCvGVhSewbV4xhR1/zC0WN0yP1TYhp3zTtZ/Pnr7Xy2/lCQoUYVdp71X89//NdwvryWq5XFJEvlmCUNrzmOHuOuxzxoKkTFU1XlwaTsQvXXX1iu564YRkAlSxLVnqMZu6sGtcOvaTz5zXY8fg09RJIqyqxw++iO/Oa8buF/AIIWodmDqbxyN++tzOKDVQdw+1QUWUbVdRRJ4upB7bhtdMfj1oEuPrQYnxqcsi2aW9RgUWbHP3ZENgVWBordxWwq3MSA1AEn8hZrWX+glJveXm3Yh6AuriNBy4erD5JZXM2MqYMwKS2XJPzP/N2GgVQ4hZY+Vee9lVk8dn6PkzxqgUBwNpCeaCe7RMj8BGc+nVKimfPwGO787zr2Flbh8+uouo6uRgH1FxTsXewBGd/6CuKGHnXXVd0qlVsqSbs6DYD4kfEo0QoHXjxA+9/9FrV6SFhjicRpzmFRuGFYxol/ACfIf1dk8dn6Q/VqvY/Fh4nvtBF8px2V6llVmXH7U5g5Oh6AaKsJvxo8i4s0wNR0nehjbNGnDMmgR6tYXvl5D0v3FAH1nRTNioQsSfRPj+eBCV0Y0zUl5OsIWp5mC6ZcXpXffLqZBTvy0QFvzcmiHj3JP15zkP+ty2ZARjyv3TiIREdkXZlnbZ1Ftb+63jbVpVLwVUGDRZnlK8pJGJsAgNvv5t1t7zJgwokHU5lF1dw8a01tXVE4uHwqK/YW8fiXW3n26n6hD2gGDhY72ZxdFrQ93EJLr1/j/ZUHeGRSN8wtGBAKBIIzk4ykQDCl6/opW28qEDQVreOi+O6hMWzPKeetpZnM2ZoLVQPRrYVI8tHFY8WukHp5Kjkf5CDb5HoLx+ZEM/Ej42v3tXexgwaZf3+Wdve/j2wJKEkqN/9A9faFtLrhGcOxhOs0B3BBn9ZN/ElEhk/V+PePuxoMpBpT2Xj8Got3F7K3oJIuqTF4VQ2rWW7EXTF0gAkBVVIrA9VOv/R43rp5CIWVHmavO8i2QxWUu33EWE10SY3m+qEZpCfam+7DETQ7zRJMVbh9XPP6SrKKqw17F9Tg13T8ms76A6Vc8NISvrp/VNgOcWXuMnaUBDeKc+5xNlqUWbW9qjaY0tFZfGgxfs2PST6xj+Iv32yj2mtcgBlKKvfN5hxuHtmB3m1Ofu+m91dloRnkmiMptFR1nQU78pncwhdTgUBw5hFrM2NSZEqqvSRFnxpSIoGguendJo7/TOnPM1edw6GywVw790e8x0ynUi5MQXEo5M3Ow1vgRY6SiR0YS/rd6ch1anPkKBNylBnN5ab4h+kkTX4ASTbhLczCX5bX6DjCcZq7qG9rbObw6oiaix+356M1UO8dlp25qvF/n2/FbjWx4UApqTFWDpW6ghoZh21lLklM7t0Kh7XhuWVKjJUHzhXGEmcCTR5M+VSNW2atIbOoGq8aumN04BidoiovU2as4vuHRhNjM4c8psRdgkW24NPqy/xCFWW6DtQv5JQlmSpvFfG2+KD9wyWv3M2q/SWG+tdwpXJvL83khSn9j3sMx8uWQ+VBDekgskJLp0dld34Vk/s0xwgFAsHZys68Ct5emkm1x8+YZxdiVmTi7WauGdyO64dkiOBKcMZjNSl0Tk7hV+1/xbysefXavgAkjkskcVxi40+im5AtMdja98a5Ywmu3SuRrFEo9jjkqNjGj6XxBWGAlfuLWzxz/MbivYbKoHBVNqoOGw+W8u9r+/HajQNxHrnmHBtMQXgBpsUkc+fYTifwjgSnE00eTH23JYcduZWGgVSjXaM1nfwKN28vzeSRX4UutvNpvkBHx2MIpyizLpJfJfe1l1Hj2qEkJmJKTEBJTERJCPxbtodOtb6/KsuwRirsH7Gm8/3WXP5yWW9iwwgkw0XXdZxelWqPnyqPn2qPSqXHR7Xn6LbMomrDYyMptNThlHMnFAgEpy9rMkv4yzfb2V9UVduqwe9VAZVyl4/pC/byyoK9nNs9lacu701qjDDAEZzZ/Hrwr1mRs4JST2RmLLpmxl/ZG10/THS/89E1P+akDBLGTquV+dVgwoeMjpejJRfhLAgXV3lZua+YkV2Sm+bNHgc7cisNt0eisomymOiSEkO01US01cSknmnM35HfqMLKCJMs0S0tmj5tT77aSNAyNHkw9fqifYaa1XB+kB6/xrsrs3hgQpcgQwZd1/EdzsGzezee3btwZm7G180Z9A7CLcqswS/rxDmS8RcU4N61E7WkFLWkBH9p4L9IEqbExECAlZiAKaHOv48EXd+u9RytCatDJD9isyKzcl8xE3qk1gt+qjw+qjx1A6LAX2Xtv9V626vqHOv0+rGZFRxHLgwOq4LDUvNvE9E2k6ENK0ReaBljOyNalgkEghbmq42H+b8vthia4tTgPnK9/WlHHmuzSph99wi6pEafrCGe1eSWuzhc6sLpVYm2meiQ5Ii43lkQOan2VGadP4ub591Mla8KTQ89wdc1M6qzI+6ca4AfAYgffSN5H/yO2MGX1ttXQSVDKuAO+TueV6+jiihcHl9YC8JOr8qMJftaLJjy+jXDcgWITGUjSYEylRqeu6Yvl01fzoHiarwGCh4jFFkiwWHh7VvCM/oQnBk06Qx42+FyskuC+yFE0jXap2rM35DFWIpx796NZ/eeIwHUbmSHA2v37li7daXDiPNwVG/B66+od3wkRZkA6bEZtLn5PsP3o+s6WrUTtbQkEGCVlASCrdIS/CWlePftx19aQlnseaAEr4xG8iOu8vi578MNALWrIg5rnUDIEgh+arbHRZlpGx+FwxIIjGJsRwKkI8c4jhwTyn7+95rOZ+sOoR5zIYqk0NJuUUhPFJ25BQLBibFwZ0HIQKouqgYl1V6unbGSuQ+PES0amglV01m4s4A3Fu9j6+FyLKaji51ev8aYrsncNbYzQzokCJOQZqRLQhc+veRTHl34KPvL9+PTfEGyPwBZk1GR8ZYOw1twIXVbilpSOhDVZQjlqz7FnNSOaFy8an6RifJGrPiQJLjetJgDeiq3Z44mO8wF4VX7Syiq8pDcAtJbkyw16KAcqZ25tc65bbeY+OyekUybtZo9+VWNugQC2EwyqbE2Zt89vEU+B0HL0aTB1IId+Xj8wSdbJBmaao/Kx9P/Rw/3JqzdumHt3p3Yiy7E1q0bSnz9QGjaliJmbpkZ1Gcq3KJMNCt9o6/A7VMNiyclSUKJdqBEOyA9vcExK0/9CM5gi/ZIfsQ2s8yfLurFDcMyTurNaNrw9ny9KQfV4CIRbqGlrsMFfVqdrCELBIIzEKfXz/0fbYi4350OlDu9/OZ/m/ngjmHBTyw4IfbkVzLt7dVUuv21NSnHyp4W7Cxgxb5i2ifZ+e9tQ4XsshlpE92G2ZfMZnfpbt7f/j5zMuegoyNLMn7NTxIyN5cU8EvZJXztm4QEQYFG/OgbKXj3QW4dHs1GuZzJchGKdHQvSYL2FHCD7wu22mGSaRMLtEGNjstikjlc6mqRIEKWJWKsJircwSZgkahsfKoWdO7G2c18es8IPlt3iDeW7KO4ylvbQ7QGh0XBbjVxx+iO3Di8PdGNmE4Izkya9BvPq/BgZKYSaddo98hxdLj7/0Lud1W3q3hjyxuGj4VTlKmj89WyFL5e+hOv3DCACT3SGt2/IWJtZkoNgqlIfsQmWSI5xnrSV/X6tI0jPSGKPQVVho+HKrSUJbh8QNtGHWsEAoEgFF9tPGy4PRyJuKrDmqwSDpe5aBsfniOsIDTbDpczZeZKnJ7GeyfqekDqtSe/igtfWsq3D44O25lXcHx0S+jG30b/jb+O+iuV3ko8qocYSwy25dORDj4HykdcIW1ghv9ilml9qJu7OS8pD3rDB2tLOCdVrhdI1SBJ0Nbup8Sp8aLyMv+SbuQD9bxGx1S3Qe3J5qpB7fhg1YEgQ61IVDbpiXYykoLr5K0mhRuHt+eGYRlsOFjK91vyKKh0o2k6yTFWJvVMY3SXZOQQSiDBmctJmQFHmmYNT5kKibZEbuhxA7N3zsalBssLG30NzYyncBI+jwKo3PfhBp6+/ByuHNQurOM1rxfnmrVULVzIiN1ecloPwneMvXokP2KfqjO8U1JE76GpeGhiV3732ZaQKWwjNB1UTWswuycQCASh0HWdGYv3B/V1iUQijq7z3oos/nBhz5Mx5DOe3HIXN7y1impP+PcFv6ZT6vQxZcYq5j0yBrtFLLI1N7IkE2etY3QwYCoseQ6AYfJOhll2UqDH00uq5lJ5OcNNm7lJmU/eODMfb/E0+twj2ilYTfDDLid/7PkRRXoc87QGsr86QQ1qTya3jOzAR6sPYjSDDEdl47Ao3Duuc6OvIUkSg9onMqh9CPdEwVlHk575KdEWZImg7FSkZgbJ0eEXsz466FGyK7NZfnh5kNyvIXTNjK98AL6SMbXb3D6Nx7/aSqt4GyM7GxdR+ktLqVq8mKqfF1K9ciXWzp2JPvdc7rl4FF9+fhAMTCjC+RErksQFfVoRF9V0Tn6RcHHf1izfW8TXmw7jCrNWAQLSxN+d34O1WSVc8soynr+2H33bHb/FvEAgODvZW1BFQWXwxC4SibhX1fl0/SERTDUR03/ei7OBQCqUM29hpYfP1x9i2ogOJ3fQAohtDR3Hwt751AQWqVIZRY+YgOW1u6XHybifCNiid3ixEqcPMh+OxmEJZFfe2uDlgy0+nhpv5f45bkwy/K3TcxkveQAAIABJREFU6/ysnkN51i9BC8JeVSM9oeUazbZPctA/PZ61WSWGCqlQKhtJkrjwHNErU3B8NGkwNaFnGm8uzQzKcETaNVqRJEqrvSSE4RAkSzIvjH+BZ9Y8wxd7vsCvqai6caq5dEUFxT8U4sn1I5uzMKdurtcrwe3T+PNX25j/m/FAYLXUu28flQsXUrVwEZ7du3GMGE70+HNp9eSfMSUFMknJwNANVSzbU2SYVQv1I7aYJO4Y03L9CCRJ4h9XnIMkSXy16TAug14Nx2IzyzxxUS+mDm/PraM68M3mHG57dy3XD83gwQld6xUoCwQCQWMUVHowKRIco5aOVCJe4QqWWwsix+n188WGw4Y9dsKRXbp8KjMW72fq8PbCkKIlGPtbOLAMfOErdlQdXlrt5fEx9WuefjPSSqtomb8v8bDjCxdYbkFP61VvQVgCxnZLCWvOBoHFk282HeZwmQuvXyMp2srwTklM6pka5OQcCdcOTmdNZknEx9nMMjOnDRLqGsFx06TBVL92cbSKsxn2LgrXzADgh+15LNtbxEd3Dqdn69AN5WRJ5vFhjzO151Tu/fYlDngXgV7/Al70Qz5F3xeTOPkuUjLGGN4AAA6XuVg9dxkdNi2lcuEidK+XmAnnknzvPdiHDkW2GhdX/uXS3lw6fVlEkgiAKLPChee0avF+BIos8fQVfRjdJZmXf97DwSNWoGqdm6lFkZEkGNIhkUcmdWVwh0CqW5IkLuvflhGdkvjDF1u57NXlPH9NP3q1Cf3dCQQCgVFrCYhcIq5qeos3Dz0T+HpTjmHbjEhklyVOL2sySxjWQvL1s5qM4TD+cVj0dNgB1WMjLTy73MN9QyzE2+p/+Tf2NXNj34ByZq/Whknev9R7XAeirQoVbl+DvTJ1XeeH7fm8tmgvu/Mq8Wt6vWD90/XZmGSZm0e2Z9rwDmw7XM7stdnkVbjxqRrxUWbO7ZHKNYPSibMHv8ZHqw/yn/m7efG6/vz1218od3kxaHcaRJRZ5j9T+rdojyzB6U+TBlOSJHHvuM48+c12w/qbcLpGQ0Cu4XX6uPr1FXxx3yi6t4oJ6/UTLK3Zt2MSbnUMSlQ2kuIEXUKtlin48k+BG0CXozJDoxuAx+Nnxncb+VfvONq99CLWHj3CujF3Tonm3VuHcvOsNUG6/4aIMisM65TIv67qG9b+zY0kSVzUtzUX9W3N9pxyPlx1kP1F1Ti9fmJtZvq2i2Pq8Pa0aaDAOzXWxls3D+bT9YeY+vZqbhvVgXvGdT6hlSaBQHDmExtlfCuKVCJuNcsikGoC5v+Sb3gfi0R26fKqrNxXLIKplmLUQ4H/Lvon+Jwhdx/cRmF8BxP/XuHh7xMadmNsJxXSmmJyCXyviizRPsmOIklM+Pci7hzXhtE9FNxaNXaTnTRHGtGmeB6dvYmFuwoanB8FFqJVXl24j1cW7MVmloPKDjYcLOW5H3YxuXcrHvlVNzomO9A0ned+3MXcrbl8evcIOiQ7GN4pib98s50FOwuQOdqXrgaTLKHIEj1bx/LkJb0YkJEQ8vMRCBqjyasFLxvQhndXZLGnoDLIVcWIxrTXTq/KDW+uYtnvJxBlCb0queVQeWDi7jejOo/K5lz714d9A9BkmY1p3Um5r3HXGiOGdEjk83tHcud76yip9uLyGjsg2UwyOjBlSDp/urhXyF5QLUHvNnE8feU5ER8nSRLXDk5ndJdkfv/5Fn76JZ/nr+1Hl9TwAmKBQHD20S0tBp/BMnKkEvFzjmT4fZqPA+UHKPeWY5JNJFoTaRfT7qwKtPYWVLJoVyEl1V4AkqKtTOyRSofk0D0BS51ew+2RyC51oLCqcYMDQTMz6iFoNwSWPAuZy0Az/l5reOpcK6NmVfPwsIblel5MJEoV5OpJKJJEXJSZj+8cTrFvP9Vx83h130Km71WwmRQUBbyqlyi1KyU5I3F5O1G355URNWoYo/rtmm3fbslh/o58Xp86kE/XH+ZwqZMv7htV2zw6LdbG61MHUVLtZfbag3y6/hBlTh+aphNtMzGuWwq3je5I5xTR6FvQNDR5MGU1KXx4xzAuf205ueXuBuUbEFp7rRPQXn+7JYdrBzfc56mG8gb08pHq7qvDzCwZ0bN1LEt/dy6r9pcwc8k+lu0twiTLyFLA6SjaauL20R25bmjGGd01vk18FO/dNpQPVx/k2hmruGdcJ24f3emUDBwFAkHLEmMzc3HfNny58XA9aTGELxF3WBWuGxHHyxte5pNdn6BqKrIUmLj5NB+p9lRu7X0rF3W6CLu55QrlmxO/qvHjL/m8sWgfu/MrUXW9dlHTYpJ5dt5O+rSN455xnZnYI7VBK+eGrtORyi4tQpXQ8rQfAdO+hPwd8MYoMGjyW0OfVIWLu5l4ZpmXnikNf3cmVGxmmSSHhTdv6cVvl93JzpKdeDUvuqSB5MOtAUemf15+QWm1F0eKA2f2bejeFKDxxfTG0PTAPO2Wd9YyrGMiH9053LDeKdFh4d7xXbh3fJfQn5NAcAI0i49lgsPCdw+O5v4PN7Jyf5Fhhipc7bXTq/L6on1cM8h4VVHXNPyFRfiyD+LefBjdKwH1f1SR3gBOdL4vSRIjOicxonMSTq+fkmovXr9GbJSZRLvlrOlFIEkSU4e3Z2zXFH772WZ+3J7Pv6/pF9bKqEAgOLu4fXRHvtuSExRMQTgScQ0leQ5Pb1kFgNdgBT67Mpvn1j3Hc+ue428j/8b5Hc9vqqGfElS6fdzyzlp25FYYSqlqFjbXHyjl4U82Mqh9AjOnDTZUfbSKM5Z5RSK7NMsSqbGiee8pQ0p3aqObRvjreBsDZ1TxmxHG9eEmNGRHIo+O6cbkftHcOX8aRa4ifFrD5i+SBChekH04OkzHeeBuypauDmlkAo0HXJoOv+RWRvpJCARNTrM1BYixmXnv9qFcP3MlK/cHu6tEor3OK3exftV2ursL8WVn480+dOS/2fgOHUKOjsaSno4tvTeY+gUdH6nuvqECyuPBbjGd9b02MpLsfHLncN5ZkcUVry3nkUndmDa8fcig0q9qrDtQSlGVJxCM2sz0TY8L6lAuEAhOf3q2jmVk52SW7y3C04iiIRgNR7sPMcXuNQyi6uLyB4rxn1j+BCWeEq7vcf0JjPjUwen1c9XrK8gqdjaqBjm6v8qazBKuf3Mls+8egdVUP6A6v1crftiWh/cEGqDKcqDlh+AUQZahzSA4vK7R3bokykzpbeblNV7OSQ3OTplsDj5/7Hp8ksqU76ZQ6CzE34CD8rFIko4ue7Akv0HZsu0kXfhoo4vp4ThH+lWNOVtzuXJgeD1CBYLmoNln+dmlxk4yEWmvXS7WvfwWqTFOLO3SMaen4xg+DHO7dCzt2iI7ApmOtqqG6e/z4Ri5XyQ3ALMicXn/tifwjgVGyLLE7aM7Mr57Cr/9dDPztuXx7NV9SU8MltsUVnr4aPUB3lmRhf9IHYWugyxJeFSNMV2SuWtsJ4Z2TDyraiAEgjOdV28YyJWvr2B/YVXYAZWj9feYHTvxNSJfOha36ub5dc/TxtGGcenjjne4pwyPfLKJA2EGUjV4/Bo78yp5/IutPH9tf3RdZ2N2GR+uOsgP23OPXFuPrwEqQJ+2cUKFcKox+hH48h7wVjW625/HWXl/i0GmyWTDMvoBMJmYs+dbcqpyDAOp0qWlFP1QhLfAi2JTiB0US9rVaSgOBUkCV1Yputr4Ynq46qXqI+olEUwJWpJmD6YaavoXifROt0Vhf+gRMkI0ADQrMjeNaM+MJfuDbirh3gBkSeKmke1DvzHBcdE5JZrP7hnJzCX7uezV5Tx2fneuG5JeGxT9b102f/pqG0CDk6mfdxawcn8xfdrE8vYtQ4hpwkyiQCBoOaIsCp/fO4K73lvPhoOluHwqegM+RhZFQjKXY05Yg2oQSDU2oQPwqB6eWvUU89vNP60XZQ4WO1m8u9DwehmqJsXt0/h2Sy7d0mL4elMO1V4/Nw7L4I8XTWD22oO8NH9PkBMahJZd2i0K94zr3HRvUtA0dLsAlOD7ZdYj9Q2i6jb0rYeuw8BbAJi1bVZtprcuRXOLKJxbSLs72hHdKxpfqY+c93PI+ncWHf/YEdkko1X5MEUrSIoKuvH8LxL10qFSF3sLKoXRlaDFaPZgymo2LmKMRHqnyBJRYTZTmzq8PTOX7Dd8LNQNQJZgYEYC7Vqwi/fZgCJL3Du+MxN6pPKbTzcxb1se/7qqL99uzuGFn3aHXJHWCchUNh0q57JXl/P1/aNEQCUQnCHYLSbev30o6w+UMnPJfhbtLsSiyFR5/DgsSm3gc8PQDHxx3/F1ZqDhaF3CmdABVHmrWJO3hmGthx07jNOGd1dkohlEnOFIpCBQS/XpukM8eWkvRnVOrpVf3zqqI99uzg3bmbcGm1lmTNdkJvVMPfE3J2haFBNc8Bx8+2BEDX0BMNthxP3gSGJb0TbynflBu6gulYKvCmh7e1ti+gYCG0uKhfT70tn92G7KV5STMDYBJVrBX+XHFL0Bf+Vww5erVS8pKpKlEN2TSkNOgCZFIrvUJYKpZqbIVcSnuz5lyeElVHorMUkmkqOSubLrlfyq/a8wGwTqZwvNHkxlJNrJLXcHbY9EeieBoRzMiLRYG789rzsv/LTbsNdVYzisJp65KnI7cMHx0b1VDF/eN4rXFu5j0vOL8KhaRDdtr1/jUKmL2/+7jtl3DT+tV5cFAsFRJElicIdEBndIpKjKw6p9xTw8exP/uOIckqItDO2YiCxpjJ39RVDhe7gTOgCn38msbbNOSjDlUzVcPhWHxdRkrqZev8bstdlB181ImusCFFd7GN0lud411GYOOPNeM2Ml2SXOsGSXUWaZQe0Tefn6AeJ6fKrS9xooO0Dpktcp9lrQkIiTqkmlzLBRMxAIpHpeAuf+EYCVOSvxqsH1ic49TjSfRuyg+lktxaYQ0zeGqu1VJIxNwN7FjmSScGX+hDnZOJiqVS/pXmxJ81HsB3DnXIfqDM546nqgr5mgedhTuoeXNrzEypyVSJKERz3a8mBf+T62Fm3lb6v+xrXdruXufndH5JZa7vLx+fpDfLslh9JqL5IkkWA3c1n/tlw5sO1ps1De7MHUbaM7si2n/EhDtvqEK72zW00M7ZAY9mveObYTJdVe3l2RFVZAJUvgsJj48I5htE8SGu+TiVmReWhiFz5ac4CqCuOeJI1JVbx+jW2Hy1l3oJQhEZwjAoHg9CA52srILsnE2ExcPuBoPesvxbvQDWp6wp3Q1bAur/GC/BMht9zF+ysP8OHqg1S4fZhkCb+q0ybexp1jOnHloHYnZHhUWOXBwPwwIokUQJXHT4XbT1xU/bEkOCx888AoHv9iK3O35SFJAWngsdgtCroON41oz+8m9xAtME5RfKrGj9vzeWPzMHY6e2DR3EiAD4VEKrjL9D1XKUuIlY5krUw2QIfh98GEJ6iJtkrcJYbSWrVKxRRtQlKCv39TnAnXgcDzKnaFtCvSyP/fRhInrTRcTK9VL+1ah6NTPrK5kqj0d3HnXIO/sm+955YkiLEdM53VVNg7H7bMhorcgCW8PQl6XAS9rwSLUCCFw4rDK3hk0SO4/W7D6y0EFqUAPtjxAT9n/8w7k98hOSq50efNKXPx3A+7mLM1F1mS6s3VM4EduZX8c+4OLunbhsfO737KO4M2ezA1sUcqZkUGjIOaUNK7KLPMnWM6Rmwn/vsLetAh2cE/5+zAp2qGvaMUGcyyTLdWMbx83QBRLNtCbDhYSqXb2A0oHKmKy6cyc8l+EUwJBGcoTq8f+zFS7wpvBRLB94VwJ3Q1+DQffs2PSW6622G508ejszexbF8R6OA9YqRTk0E6XObmX/N28c+5O7l+aAZPXNQz0HA+QirdPsPAJdLeiiZZptLtCwqmICC7fPG6ATx5pAHqOyuyKKr0ouk6iiyRnmjnrrGduKx/m7PeufZUZuGuAh7+eCOqrh9Z3JbwEVX7eC7JPOu/jmf81/OQ9Vvui1qINPxuGHQrxKTVey5FMj6vauR7uqoH/f785X5M0UfPj+QLkpGi0iieZ7yYflS9NAN7egLRfaKRFC++orcoX9uB+NG/rn0ur1+jS+qRBrxeJ6x6HVa9Bn53sNlG5mKY8xj0vxHG/AZiW0f6UZ41bCzYyMMLH8atBqvLjPBqXrIrs5nyzVQeO+cNUqPj6JwcTZy9/nVle045N7y5mkq3z3AxCKgNrr7ceJgFOwv45K7hdEs7dWWczX7lMykyd43pxCs/741YdgcBuUc4DXuNmDIknasGtmX+jgLeWLyPbYfL0fRAbG23KFzary23j+4gdLYtzIwl+w3PjXClKroOS3YXUlTlITnauDeGQCA4fXF51aB+SIqkGAZTkUzoaqhp7tsU5Fe4ufK15RRUehqVLddc82avzWZXXgXv3jY0yKI8FHazybBeKtLeiqquhwyEEhwW7hnfhXvGd0HXdfyafmShVHCq89n6Qzzx1VbDrGJdXATun9PVqznc9WH+Ma6voVwzOSoZs2wOktjWyPcq1lcQNzSudrvqVqncUkna1fWDstiBPbCkTau3rWrrfHJ+mI6/LA/JGoWtrZ38r/LJnpGNYlOwdbCRfIHjSJYkMLZB7RNoHRcFVYXw30ugNAsMzDEA8FYH/rv+Xdj2Gdz8LbQS5R3H4lE93L/gfsNAqjFzH1VXKXDm8djCP0PBjXhVjUk907hzbCf6p8eTVVTNdTNXNbiAfix+Tae02ss1b6xkzsNjaBsfFfqgFuCkLCPdPa4zK/cXsyazJKL+ITazzMxpg4m3W477tU2KzOQ+rZh8pN+F26diVmQhQziFWJdVaujYFYlUxWKS2XqonHN7iKJngeBMw2kQTCVYEwxtmSOd0EWZoposmKry+JkyYyV5FR7D5sNGuHwqGw+W8cCHG5kxbVBYKgzN7ca1aTOsWoPP0x6OyRRE2ltRAmKPlUk1tr8kYTbI/AlOPZbuKQwrkKqLy6/z5aZcWsXbeWhi16DHJ2RMYPqm6UHbFbtC6uWp5HyQg2yT65m/mBPNxI+Mr91XV634KwbUO75izReUr/6C5Isexdb+HFTnIcqX/hHdJ9PztZ61xjG6asZ1aD+qszMOi8Ld4zqDpxJmnQdl2dBIA+FaNB+4SuGdC+CuxZAk3Cfr8s3eb6j2VQdtD8vcR1bR7Vup8peD6mDutlx+3llAz9bRlFT7qPYEX7cbK+fQgSq3n7veW8f3D405Ce8+ck5KMKXIEm/eNJj7PtzAyn3FITNUsgRWk8KrNw5gdNfGdZeRYgvTFVBw8nB6jVcoIpGqaJpOhTuMC6hAIDjtcPlU7Ob6t6vO8Z2JscQE2TNHMqFTJIVJ7Sc12ThfX7SX3HK3YSDV2GTB7ddYvq+IhbsKmNgzLehYzevFtWkTzjVrca5ejWv7dmxdu2IfNozxrW0syK8vl4nE4EmR4YoBbY9LZig4tdF1nd9/tqXBQKqxc9LlU3l14V5uGJYRpPjIiM2gZ2JPNhVuCnrOlAtTUBwKebPz8BZ4kaNkYgfGkn53OnI9d2cZf2XP2v/TPE7Kln1E0gUPE9VpEEgeYvt+S2zfdkHGMcheLEmLcTs7k+CwMKZLMnx2C5QfrhdIdXixEqcPMh+OxmEJBP9vbfDywRYfi245UtbhqYL3LoeHNwcaGwso95Tzz9X/RNPrnzeRmPuAhDluHb6ScWh64Bq++VA5mhbcvS6ccg5V19lXWMX2nHJ6t4njVOOkCZxtZoW3bhrM7HXZvL5oH0VVnqAeIjazjK7DpF5pPDyx6ymtjxQ0HSZZBoIv9pFIVSRJwiImAwLBGYmRzE+SJG7pfQuvbHgFl1o/oAp3QmeWzdzU66YmGaNP1Xh/5QFD9UU4kwWnV2XGkv1M7JmG7vXi2rqV6tWrca5Zi3vLFiydO+MYNpSku+4kasBAlOjAZPC+g6Use3N10CJluAZPZlnmjjEdm+QzEJxarNpfQpnLeJExnHNSkuCTNQd5YEJwduq2nlP5fcEmXAYJysRxiSSOa7iGWddMeEtGAEd/057DOwJKlO5DQXZhz3gbxZYLBBvHSBIojj3YzBLVHpX7Z85lesEcFC3YYVDV4aXVXh4f01AJgA6uEti/ELpMbHDMZws+1cdNc2/CpwefN5GY+0iyD0vCSnwlR5uiqwYxfSTOo26fxnPzdvHubUNP5C02Cye1WlSWJa4fmsF1Q9LZcLCMT9YcJLvUhdunEh9lZnjnJKYMTifBcfyyPsHpR6LDQpVB2jcSqYqOTkqMqJcSCM4Yqgpg3TuwbwGDy4ro6gE+7AwDb4Zuk0ExcVmXy3hpw0uGh4ea0EFghb17YvcmGe5Pv+SjGuiVI5ksbMoqZtUdD5KwYQWWDh2wDxtG4i03Yx88GCU62vB1+6fHk54Yxd6CqqBi7lAGTyZZolebWFE3fIYyc8k+Q8vwcM9Jt09j1vIs7h3fJag0YlzmOoa6PayymvEcyejs+s0uNK9G9393R7YGtpUsLqFsRRmd/tDpyIvL6L54vMX17+mquxjZHoslaR2WpMXI5orax4yMYyQkvn5oCB0Sktj32ROoeTpGS66PjbTw7HIP9w2xEG9rQJrqrYLlL53dwZTPBdu/Yk7m9xyqyDTcpWJ94Dv55d5fgmql6n5HZSvLAjVVuV4k07SgZuF1idR5dNHuQj5cfYAbh7U/zjfaPLSI9Y4kSQxqn8Cg9gmhdxac8UwZks4rC/bgPmZFNxKpitWkMCBDnE8CwWlP/i/w899g74LAErTfTQwQA7BnHxxYAYoZht5N7OhH+cPQP/DMmmfCdpyqIcoUxdOjn26yYX+x4ZBhC5BIJgu6rrN+1CXc88I/UGJjQ+4Pgfvp2zcP4aJXllLp8jdgXmyMSZZ46rLgCY7gzGD53mLD8yGSc9LtU9lbUEX3VnUCbtWHvGYmz3vKuCctlW1WC+4aiZwGRT8WkXpJcP2yWdOI1/yUqQreuE1o/njQTUhKNVFtfqHYVYY15ZuwjGNMskzrOBsWk0zP7P8Bxhm4wW0Uxncw8e8VHv4+oRF77YOrAgs40WdZ3XXpAVj5Kmz6AIC3U2LxWoJdPYvmFlG+thx06PlKT/wV/nq1UjXfUdG8Igq/L6TNzW2I7hNL9d5/GDYLryFS51GAp779Basic/VxmtM1B8LHVNDiXD80g5cW7DF8LBypikXzcUOSjlzH3ed0R9f12uJPh9khGmAKzg72zIf/TQuskjYUFtRYHS9/EXZ9z1U3fUOpp5QZm2eEHVBFmaKYPmF6k2WlAAoqjfvkRTJZ8CFT1Soj7ECqhvREO5/fM5IpM1dR4fLhD9P8QtV1rp2xihnTBjGma0pEryk4tSl3V0PMauwJy5DNZSCpoJtRPSm49sWEfU4qskSZ8xj53K65oPmx6vBmXgHPJcbzeUw0EgHL86K5RSRNSEJxBJ5fBiyazkiXm/tKy7mpjYzN/k29p7QkqBEZx2i6hsPsAE2D6qJG38NT51oZNauah4c1onoyWaE8+5QLpjRNZ8meQuZszaWgwoMOpMRYmdy7Fef2SD0xM7W9C2D2NFC9oPnYbrGQZ+AoWlMr1Xpqa3Lez6FycyVxQ+Nqa6VKF5dSuaWSlEtTKPgsUFMVNzgOXbMgKaYGm4VDw+UcjdXzefwaT3y9jf4Z8adMVl0EU4IWJ9FhYVKPVH78Jd9wEhBKqiKZzExcMpvMn2aR+utHcYwefVoGH5qusTp3Ne9se4fVeauRkUEKbB+cNpjb+tzGiDYjmtTGWSA4ZchaBv+beiSQCgO/Gwp2wrsXcccdC2jjaMMza57Bo3pqm0jWRZZkLLKFttFt+dfYfzVpIAUYWpRD5Dblfi1817W6dE2LYd4jY/jtp5tZsrvxyWUNPlXHp6rc+d463pg6iPHdT62JpCByvKqXF9a/wBd7vsCSqiIpdQMhPyb7QWzt3GiuMkyxi/BXjCPUImTQ/XT9O7WLGibgDyVlPFBaTidNo226BVd3OyVzC0m9Ko0oTcPuV5lzKIc0VUUDYjTtaCbrCJEYxwAMSBsQuBf6qgPGEVrDxmZ9UhUu7mbimWVeeqY0cv/0Bl83Wgqn1897K7J4e1kWTq8/qFfq3K25WM0Kt47qwC0jOxATafPvfT/DJzfWs5BfYrfhMZg71dRKxY+Mx1/pr/cd2bvZKfymEHOiGUuSpV5NleYOfT0xKucIp57P59d4c2km/7qqb2NPf9IQwZTglODvV5zD+oOlFFZ6GmziZoTNLPP0FefQf8CFVP70E/lP/xNTSgqpv/k1Uf36Nd+Am5iVOSt5YvkTVHmraieCGkdtb9bkrWFb0TbsZjtPjXyKMe1OTXtQgeC48FbDx9fVC6TCcuLSfFCyH354nAsveZHzO5zP0sNLmbVtFlsKtwR20TUsioWJGRO5qfdN9E5qHllbQgMtPCKp/VRkiUTH8dd+llb7WJNZEvFxbp/GfR9uYO7DY2ifJJrXn65Ueau488c72VO2B4/qOdYxvxZHNzOSScJz+DMc3XNx50whkD8KRtV0Eo5pukpFTtB+/8/eecdHUef//zkz27KbSgohhRJqKAHpLYCoiAVRPEHBjv1rL1f8nXfe6XlNPc+znA1RUc+KBXsB6S1I7wkBAoSQXrbPzO+PNSGbnc3uhKCUeT4ePHxkd8pn3Z3P5/Nur3ecquJQFP5eXkHqCIFxc8r5qo+bzyr9zPP76SgHjAERuKamjmeSEkIMqmiFY+wmO9f3/ynN32wHNbID4k8TbAx+oZ77RrXyfFnbFuVQFJVlheUs2hHodykKkB4fw/kDOjEgS7/yXFmdmyteXMmBKldI+UMjDV6ZBq/MM9/v5t01+3nn5lFkRNuDqeYAvHNlkCHV9ak6yuV6ujwRj2gNzLeN9W4dxndoaoTe8jtCAMkh0e3X3ajbUNd0nCpb8DYTnwhHy3IOS0Yvqpe+SdyQKXgObm+q6WuucC59AAAgAElEQVQZ3ZJV+Hj9AR66sC+x1l/elPnlR2BgQCA69d7No7nsheVUNnhbbXbZiM0s8pvJfZg2OAuA+EmTiJs4ker58ym58y5i8vJIvedurDk5x3v4x8SCwgU8vOJhPLJ2mlAjTr8Tp9/JPYvu4cERDzKt57SfaYQGBseZje8F0nVaEFmJi0CEasP/YNIjSNY4JmRPYEL2BFRVxS27MQkmzJJOr20bmNwvnYLiSpy+ttd+miWR/GNoB/Lswt34/NpzZ2tpMwBev8KLi4v4yyVGA9OTEZ/i45Zvb2FH1Y6QZrotkewSHS/pyKF5++l01QosHTvir83HXbwh5DcZazXRPbWF+IkcqprXnIE/RYIeDxMJuri+nv8kaRsZ0QjHWE1WxmSMCfwhCJCQDdX7Wj2nRweRGf3MPL3ay4C00DGpsgchSZ+oQb3Hz9ur9vHSkiIaPMHRI1GAucuLyUy0ccuEHkwdlBFVk+sap49Lnl1Gaa1bU/2uJR6/wsFqN1OfXcYXd+WHyNhrsvpFkEN/I0qYereWjdCbf0clL5Wg+lUkhxR0XED6vm/ksRBczuE9shdkH97DRSSMvrzV80RB4LONB5kxrHNU9zmeGMaUwQlD52Q7X9w1joc+2sy32w4jCIT0x2jsQdYxwcofL+wX0qRXMJlIuuwyEqZMoWrePPbOupK4s88i5fbbMXcM7d/yS7P8wHL+tOJPEQ2p5nhkD4+teoxkWzLjsyN7fgwMTmhUNVD/pNEgMiolLghsqDa+C8NmN3tJIMYUpaf2GPGXlzN25afI7qyAOEYLopUpz06KoX9m23qo1Dh9fLWlVFNRMJq0Gb+i8uG6A/y/C3KxW4ytwcnG29veZkdlqCFVtaQqoKxW5g1SYEs+tyOq1JeyD/bgq3gWwTJH8zfZuYOd8noPafHNxBuskX+jrUWCEhSV/6tp4LkOSbg1Gm+3hk2y8fCoh5Gap8yOuh2+fRh8rafp/WG8lTc2hhoRMgLf+Qez/JuD/GqIGNUzeLDaxYwXVnCk3qPZx6uxt9LuIw089NFm3lmzj1evGx4xinL72+soq29AkWoQzR5UxYLqj6Nhy/KwzhBZValq8HLj62uZf9uY1gfu98LaOZoGcY8JPdn65e6gejf4qRG6KLDztzuRa2VEm4its43kc5OD6tkaG6bXFNRh63whaOoratNYzlG/ZSFVC1+h4/Q/RTzH6ZUpLj8xUjONGdPghKKDw8KzswZT1eDlnTX7eGv1PsrqPHh9MgIqsTYLPdPimDWiM6N7JIe9jmizkXzDDSRedhkVL7/Mnoumkjj9MpJvuAEpIfrNyqaSGt4v2M/+Khcev0IHe0DC/+JBmTiOMbQsKzK/WfIbzaL5cAtg4wTnkT38bunv+GHGD5jF4+91NzA4bpRugvrDmm9FrcTlc8LK54OMqZbsr3Syp7wBp9eP3WKiW4qD7A52zWPr3D4+XHeAr7eWUtXgQxIFUuOs/GpIFuf07djkYXbv3Enla69R9823xJ9/Hhf1O4P5O2vaVPtpt0jcMr47qqri8St4/ApxVhNilAXm7xfsR9Sod9AjzS4I8OmGE8PTaxA9iqowd8vckLWk/ItyjnxxhKwbsoJqkPb8cy/pV/4DW+cMOl3Teiua9furmfTUYt66YSR9M34SRuk+Acq2tBqhihQJurbBTWnemcw/tDSk8XY4bJKNe4fey8TOE4PfGHgFfPOHkOOL7w5O28tOEHH/PlTcRTLHkHfxQ2w5aObmNwqIjzHzqyFZXDwog2SNSM+ROg8XPbOUKqdPs0F3S1w+mY0lNVz+4grev2U0NrO2kbF07xYK6l/C2n0dVgRQBUCl/OtSqhZWkHz+tVgzzwnrDNl2qJbtpbX0SW9FwGbnF5ppkSoCNR3Oxt67lPIvy+l46VHnc9XiKhBBrpfJuimL2H6xVK+q5uDc4Ho2yS6RdklHDr1RSoezJWzd3AiiierFr1O/+XtUn1szKt4cvTWm4fqo/dwYxpTBCUt5vZeKei+SIKAgAAI1Lj9r91ax7VAtv/94C1cMz+b/JvQI25tMSkgg7b77SLrySsqfeZbC884n+frrSLrySkSb9uZMUVQ+Wn+A5xYVcqDKhdevBHl7v9texiMLtnLJGVncNqF72A1ZJJYcWIJXYzEKtwA2SpCKpsDCJCsyC/ctZFLXSW26v4HBCUFNCYjhl6KolLgA6ktDXvLLCt9tL+O/iwrZeqgWiynQGF4QAmltfTvFc8uE7pzVJw2TJFJS5eSpb3fx6YaDiIIQ0gh31Z4KJEHg8kyRaSvfQ9q5naRZM+n+1ZeYkpK4t8bF108toUbnAm8WBdLirCzedYQH52/CL6uIYmCD1CstllsmdOe8/p3CbsIA1pfUhIwX9MlgO72BTd+MYbqGb/ALs+rQqib110YaFdgyZ2cSlxcwKiypFrJuyWHn/UXUFWwnNq9rxGv7FZVqp4/pL6zgk9vHkJMai3/IbITlz0aMO4SLBAEICdn89swn6LTlNZ5Z/wyiIIY1quymwBr75zF/5tyu54YeYIsP9J/78fXoBWwaEc2Q0pv0vqO5p5/AXWf1ZOWeCt5fW8JT3+5kVE4ylw3NZkLvVMySiKqqXPvqaqo1DKlICnS7y+r5/Uebefyy4HruWm8t9y26j9WH1iEm+BGEo8aO7JQ58lEpGddnkTBsNYqnGFfJNZrOEJ+s8vKSPSHXD+LIzkCNagtcWInxdSDlgj7sfXwNyecEnNWqrFI2v4ysG7KQXTKHPzzM/v/ub6pnS78sPaieLeXcjvgbLqdmxbuUL3gCBAHV7ydhzEzih16kaQg2R0+NKRBaz/cLYRhTBiccxeUNzHhhBZXO8LVTjbnJry0v5tMNB3n35lGtFk6bO3ak0yN/psN113LkqX9TOfk8Uv7vNhIvuQTBdPQxcPtkbn9rHct2V2huTCCw4QB4b+1+PtlwgFevHc7wbq3neGsxZ/OcENWxcAtgowRpzfKapg7jTr+TOZvnGMaUwcmNz9lqAXnUSlz+YMdE0ZF6Zr68ijq3r6n/k6dFMfeP+6u59531xMWY+cOFffntBxtp8PgJV7LZeJ1Xd/n5Kvsi3vnHE6QkH/V+d0qI4c0bRnD5iytp8PoJI/AXhPmnnjqHalzsr3I1bdB+qtdnx+F6fj9/M7+fv5l7J/Xm+jFdNdVKQ+Srf0JvH5eqMNcxOHGZv2t+yFrSqMDWqKzWiK/qUmJyFuIq3khs3uSg91ozBhq8fq58eRX3n9ub/3xfyFOm/uR51wXpAEYbCcLsgLH3IAgC1/a/lkt7XconhZ8wd/NcqjxVmEQTqIE6sM7xnbm+//VM6joJq9RKPdC5f4HDm+FAQaCOMhpEEziS4cr3Ax4WQBQFRndPYXT3FOrcPr7YVMqLiwv53YebuHhQBv0z4yk60hASfY4mldbtU/h0w0H+3/m5TQ7gSnclMz+bSZmzDAUfLR9t5+7A95gwNA5B9CFaD2Hv9m+cxf+H6gved8iKyoINB3n04v5NjhdVVXF6ZaqcXqqdPhIPHSBLo+1Eg2pDwEqc7ULiBu7kyGdHsGZYUdxK0++osV4qHKoi4a0YT2y/s4ntdzaKp4GSZ68hZcr9QYZRa1LpempMHRaJXh0NaXQDgxBKa9xMe2451S5vVKp+PlnlSJ2HS55bzpd35QfndWtgzckh6+l/49qwgbInnqTy1bmk3n0Xceecg6LCTW8UsLqoIqyCTnP8iorfI3PNnNW8fdNIBmUnRjynEZff1aQ21pxwC6Bkk4jLi6N+S32TMQWwo2oHtd5a4i36+tIYGJww2BIggtx/NEpcsmQDRUUSBbaX1vKr51dEZdA0eGWcXpnb3lwX9ZB9ookDXoFfvbyGz+7MJyHmqHe0f2YCH98+hqtfWU210xsiadyIJAqYRAGfrESc6xqv8fhXO9hT3sAjU/uFGFSOMHVOetNmTgRlLAN9lDpDo7JyvdykrNaI6rfjr8tDcmzCW7o76PhIxoCqwqFaN88vKuTRi/uTZ38S4dXz9EeCBAliU6HfJU0vxVnimJU7i5l9ZnLYeZhaby0iIom2RFJiohRkkcxw5Yfw7tWBNgsaNZhBmO0QnwHXfgYO7XvE2cxMH5bN9GHZFJc38H5BCb/9YFPI/kBvKu07a/Zxy4QeeGQPN3x1A4cbDuMPUzvW8nsURAUEJ/YuL9BQdBcowZkxPlnlihdX4PIpVDm9TWnKSXYziXYL1/v9TCNUu1H56RV/bR4pk/PY87clpExOQZXVkN+RFqoioPrj8VaOa3pNT1S8OdHWmKrA5P7pTX83+BrYWrGVOm8dJtFEckwyuR1yf5Z2MsasaXBCcePra6lx+0I2F615zBQVal0+bnx9LR/fHjksDBAzcCCdX5tLw9KllD3xJBWvvMKH59/Cmj1uTUOqtfu7fDJXv7KKlQ+eFXXhdo2nBrNoRpaDN1paC2AjpgQTrr3BC5dZNFPjrjGMKYOTl9Q+4G9dgCVS/QXAZn8mV/7pa3I7xbPxQLVmUXg4wtkyrT33fkWlrNbDA+9t4MWrhwad1z01liW/PpNlheX894dC1hRXYZFEBCGgt+GTFcb3SmXp7nI8OurvXT6ZDwpKSI21cNfZvYLey0l1YBKh5fSlJ23GYhLpmmJIo59saKWLt1RgA/BWDwVU5IZKRPvRNSNaY0BVA3XNY3qkACkw7SX44MYgie1WEU1gSwwYMOZQcRhBEEh3pJPuSNc4OQrMNrjif7Dj84CoTemmQP+pJlEOIWBEOVJg7D2QNwMs0aXpd01xcNP4HF5cUhTynh6jwe1TeGVZMbdM6MEnuz9hf93+EEOqec20YBJQnAr+Oj+muMD+QhBUkBqwdFiKtzw4M8UsCVw+rDMDshJJcphJsluC04M3HoIF8482P/+Jo3OghN9zB/FDdlHxzWFMSaaQ31FLVEVElR04994MylGHtt6oeHMi1ZiaJYEZw7KxmSV2Ve3i9S2v80XxF0E15IqqYDfZubrf1UzrOY2EKIRT2ophTBmcMGw5WMPusrqQPORolah2HK6LXHzZDEEQiM3PxzFmDBWffs6cJZW4TKGRrWjv/9GPB5g5IjppVZ8cxlutsQA24q/xY4oNfmQFhEA/KgODk5XEbMgeFvAmt0Jr9RdYYhl46R/5IetMfj9/E95wvVkiyIM3J5rn3isrLNp5hMO1bjq2iIqLokB+z1Tye6ZSUe+htNaNyysTazORlWTnptfX4g6TShzJefPcokKmDc4iu4Mdf1UVdV99xegvFvFy2uQQNUE9aTMCcOlPrSYMTh4SraFZEY3KarUFtSQMD2wifVWjUTwyrqICEsdd3XSsHmNg/f7qo7/33Clw+Tx456rAm62p6VliIS4drvk0EBE6Xogi5F4Y+Fe+G7Z9EqinlP3gSIXuZ0L2CELy6aLgQJULiySGzC96jYYjdR7OemIhZQnPgrl10RD3ITeFfyik8JFCej7Ws6lmWhD9WDqswFt+Fs1V8wRBYGyvVDLD9ZzqcyF8elfIy2Jzl5JiwzHwYWpW3YJokxBMIrXrakkYFmyMqKoIqojsysZ9YBaqHCyhrzcqrgdBELhqZCb3LbqPH0p+wK/4kVU5RBnZ6Xfy7PpneXb9szw08iGm9pjaruNoxDCmDE4YXlmyB2+LYgU94XOfrPLKkj38s7XiSw0EUWR1zhAo2AAtUnKivb/TK/PCD0VcMbxzUPpNjdNHYXk9RUcaKDry03/L69lbWYmluzek8bzWAgggu+UgCdKmz6z4jKiUwcnPmLtRDvyI2Cw1J+r6CwCTFXqeQ5wqsGR3uWbaXDTGUSO60naAeSv3ct+k3mE/XnKsNUgVrKTKScHeqjaPU1FVXpr3PTds+RTXuh+Jzc9n4DUzGLDdTMG+mpBrRpM2IwCjcpJDjEKDE58xmWNYd3gdLvlohEiyS6RdnMbBeQcRbSKO3Fh8FS4qvv4XprgUYvsdVcTTYwxYTAGhlqbfSY+z4d5tsP5tWP40eGoBIVAHKYqBfkadBsHYu6HnJGjnTXWrpPSA/Hvb7XINHr+mDabXaJBEgRvPlnhikxN3sy2HVs10THYMaZekUTa/jNK3S0mfno4gCdRvrad+6xESx27DX9e/6RoqKimxrYj1WOwwaBYUzG0WsYO/3H4R/5X70miKmOIy6HzfJwimGuo3PMfBN1aBIBHbLxFBUqnfWkfdBgcJY+9C9aZq3qoxKu7buYTMPoPwYKIOO2qYBtHNiej4Evz8dsVt7KndFbG1TOP7j658lBpPDVf3u7rV49uCYUwZnBB4/DKfbToUEpXS4zGTFZVPNhzkL5cMwGLSlyP7ypI9mrUNeu5fWuvmT59uweVVKPrJgHL7ZHJSY8lJdZCTEssFeZ3ISXXQLcXBzM9fo7CmMOgaLRfA5mp+zSVIG+nk6KTplTQwOJlY4Mylny+RzoIbSdWO1oTFbIcJvwVR4uuNh1A0LBQ9xhHoe+49foU3V+1r1ZhqyRsr9qJqFHNFO06frPL+AT93nD+FrCefRHQEUvNuzyzjtnnrNMVzIqXN2MwSt07oHvVnMDhxmNpjKv9e9++Q11PPT0VySJS+UxpIGTM/gL3nSBKnPIBgOhrB1GsMNAqxNBGTCKNuhZG3QMlaqD0QqKWyxUNaX+jQ7Zg/44mA3WLSrMHUq0CnqCql/h/xtJCyD1cznXZRGvWb6qleWU3Vkiokm4Stq420KWmY4jY1GVMmUWDaGVlYTRG+w5G3wvp5QcbULNN3PC9fFHKo6k/A0e93qHzP4fc/ouSFrT85Y3oSP2pGWEPKjpup9pVI42N56dsnuN8ay6TuJiyiyj925zCvOAl5wj2a50bjUDJ1fJudVTuQiV4wxy27efrHp8mKzWJil4mRT9CBYUwZnBBUNYSq2ID+8LkgBFStIglRtGRvpXaxqp77+xWV4nInZ/ftyNQzMuieGktanFVTeQtg9oDZPLry0RAVppYLYKMEafbN2UESpHaTnev7Xx/2+gYGJzr1Hj8Pf7KFgr1VPHvFfKSPzwdXFURrUJnt0P9SGHYjAEt2HTlmpwi0TQFPUdSoe0J9vOFgSBRe7zhFWwy7+g0h1XG0xunM3mlcMTybt1fvD6tGqkWMWeKmcTmMyAnfu8/gxCXeEs/Znc/mi+IvUFooY3YY36FJga1u21/QaqSqyxhQIdYWZusoCIGUXU5Nbf3MxBi8cmgKsZ5UWoAku4UKdzlqi2rN1mqm7T3tCBaBbg8EG6b+hqO1TyZR4PqxURiuyd3homfhk/9rEhBJE2rIFzfxvXKGZuQott/EoGhmS6w/ObA9fpk7pPncZvoEGYHYMT5y42z8c0kD189XiLMIDO60kZfyHWRZHuBW390UqplN14nGoSRaSxEd25EJTfuOpkfnX1f/lTM7n9mueyfDmDI4IXB6/UgaP2zd4XNBaJIu14MnTLG6nvtbJJHJ/dO5fHh0DS8ndZ3Eoysf1Xyv+QIYDkVVOK/beVHdy8DgROPHfVXc/c56RuUks+COsYEm2Df/AHMvgPqy1usvBDGQ2jf0Ojjn0ab6h4p67XQPvcaR3nlHlBW2nD2JGElAMJsj/qsxjwMhdPnVN06VSg0Z899f0BdZUXl3bUlUBlWMWeLa0V25++yeUdzT4ETl1kG3snD/whDnXHMEcxWqL1S5To8x4JUVurSxt+LJToLdTH7PFL7fVhYiWhOtAp3VJHLNqK5UElojqrdmujkWSWBwlyR6pMWGPSaIAZcGIlOf3v2TlLzKb03/Y4W3H070OaPNkkD31FgOVtbzd55lklRAjHB0bpqVZ2ZWXst+UCqyepCPLQ9xpfdB1qs9gOgcSuakpSCEKvdE26Oz1lvLmtI1DO80XNfnbA3DmDI4IYi1mUL6NoD+8LlfUcN7zVrBZpGo05DV0nN/SWzFY6eBVbLy62G/5m+r/xbSuT7ieCUb9wy5B7v59FzUDH55FEVldXEl+yqduLwyDquJHmmxDMxKaNXjJysqzy3czWsrinn04v5M7t/p6JsJWXDbStj8ASz9F9QeDBSOK15ADKh1KTL0ngyj74SsYBU9S5j0Fr3Gkd55R5VEes97Dfx+VJ8v+J/XF/Iay1W0dGP0jFMFzZorURT409T+DO+WzFPf7aSkMrTxuCSA2SSSkxLLPef04py+HUMvZHBS0SW+C8+c9Qy3fXtb2PXE0mEJnrLzQQ1tMRBtXd3YHilB9X+nGzeP686KwgpNp22kVNpGZo7ozDu7UxEQgqJTemumAVR/HCZRIDXOxvOzhuj7MAMvD6ipLv4H7PqWnqYKnlKf4VbfPcgRWzIHsEgCafE23rxhBO4FvyVhawF2ofUapkYkQSUWN69b/soU71/Yq6ZHdiiJHswJGwKKhs3Q06PT5XcxZ/Mcw5gyOPXoYLdglsSQppp6w+cWk0iSvZXiyzD0SHVwpC50AtBzf1kJSCLr4dJel3Kw/iCvb309aoPKJtm4vM/lzMydqeteBgbtQbXTy7tr9/Pykj00ePwoaqAGQPopxS01zsot47szdVBGSKuAkion97yzHrMksuCOfNITNDyg5hg448rAv5IC2L8SnFVgtoIjLaBG5dBOR8tItCEKoUaGXuNI77yTnWTHkhG9Qlnshm9xacw3esYpIgT1t2rJBXmduCCvE5sP1PDa8mJ2ldVT7/ETazWR2ymOa0Z3jVr51ODkYFDqEK7r9jRPfbcevzcRRTEhiF5EaymW5CWY4n/EU3ZB2PMjGQMxlkA66OnMsK5JZCTGsKe8IaTGOxJWk8g5fTuSGmflzOwzeXXzXLzK0XlAb820KlsQGwbRLcXBWzeOJMEefj4IS8YguPwtqD9C2fI3eXRpOkLYZhGhJMRY+OyOsSQ498KuN6GZIdX1qTqcPthzVywOS2B9eHmdl3kbfSy69mh6sgM3j5he5Wrf7yI6lCTrQVBD0xD19OhUUVl/ZH3UnzEaDGPK4ITAJInMGtGZV5eFKvpFGz63SCJXjezStKnTw+yxOWw8UBNaWKvj/llJMeR20r85uWPwHaTZ03h87eMICEGKTM2JMcWgqAp3Dr6Tq/pepfs+BgbHyvLCcm58bS2KquIKkxq7t8LJIwu28o8vt/PWjSObnomP1x/gz59u5aZxOdyYnxNdfVHWkMC/KJk6KJM3Vu4LkRzXaxxB9M+9vQ0bzAm9U/mwoISWZVN6060Gd44sPtM/M0G3wqnByYWsqDzz/S5eWboHWVXxeo6mmquKHdmfgNvVDUH0YLLvQ3Z2RVX1bf9MokB2kp3h3VpPPz/VEQSB168fzvlPL6HWFdoTM+x5BFT8/jy1Hwu3l/HQR4fxJCcgWMqCjou2ZhpAxMz94y5mxrCuxFiOTSWx3pzEjI0DOeBzoqdQos7t45WlxdzreR6U0OweWYV/r/LyYH74aKYkqAwXt5NOBQcjOZQk7f2R3h6drmh7o0WJYUwZnDBcNaoLc5cXo9VCM5rwuSAErtEWzuyThkUSaQgzjUS6v8Miccv4tithzegzgwtyLuCTwk+Yu2UuVe4qTGLg8fQrfhKsCVzb71qm9phKnCUuwtUMDNqfRTvKuGVeQVTNcJ1eGadX5tLnl/PqtcP435r9bCip5rXrh9M/8/g1TuyfmUB2Ugy7yupD3ovWOGpONPOOqsLFZ2S2ekxLZo/txqcbDiJr/L+MZpySKDBlYAZxtjZ4og1OKdw+mdlz17BuX1VYBwcIqIoVVbFiIQGzSWjl2FAkIRAFfX32cEPwCMhIjGH+bWOY8cIKql2+sH3tGrGZRbqnxpKRGMNl/13BgSoXbr+CSR2PLf0jBDFYSCGammmzaOGmQddy7cD2UeB84usdHKh2hzh4oHWZcrdf4cXFhVxgWUxvDWPqgdEW/rHMw23DLCTaWv/tXGn6lseZ0apDKeXCMYT0lEF/vZkktK9Ev2FMGZwwZCXZObN3Ggt3lIWk+0XCahI5K7cjnRLCNKqLgCQK3D6xB49/tVOXChYEHmubWeKCvE4Rj22NWEssM3NnckWfKzhQf4AaT6BfTLw1nqzYLGMRM/jF2F1Wz21vrovKkGqO0ytz+UsrueSMTBbcMTYk7a+9qXH56OCwIKDlkom+piFaYswSV43qEhDP0EGf9HhyUmLZeqhW8/1I4zRLArOjUe0yOKWRFZWb3ihg7d6qqNdMt19FEFQyE21UOX0RBZtsZpGUWCvv3DzK6EHWjG4pDr6+ZxyvLN3D6yv24leUoMwWgUDUOj7GzE35OVwxojNfbynlnnc2NNUv+msGoSSuRLQdRBCj33dIgkSGo1O7Zai4fTLvrNmvaRRGI1PukxVe8U3iH9LukPOHZkhM6Gri8eUeHp0Y/vdjE3xcJC7ncWa06lAKNAYOHafeerP2dkobxpTBCcW/Zgzi4meXsaeiIaK3pxGrSaRbioMnpx9bKsv1Y7rx475qvtt2WJfXLsYi8daNI7GZ28fTIQgCWXFZZMVltcv1DAyOlf98vyus4mWk5oomIZAedLwNqS83H+KPn2zhrNw0zJLAmuLoN5htwWYWGdo1id9M7tOm8x+9pD8zX1qp20C1mUUm90tvU0qxwanF/9bsY82eSs3feWvPpapCZYOPW8fn8O32MnaW1uFT1KAaoOaGwPRh2cTqdBicDiTaLdw3qTd3ntWT77Yd5uuthymv92ISBdLirEwdlMnInA4IgkCd28dvPtgUJAQDEs79s/Ed/iOV3xXhLfUg2kRsnW2kTknF0csRck+zaCY5Jpk5k+fgMIe+3xY+23hII9YTfd87WYVPfMP5gziHWCG09vvPZ1oZM6eBu0a0Xs8eLxxVogznUFLcCqpiRZCClUz11JuZRTMX5lzY6lj0YjwdBicUMRaJ928dxbWvrmHbodqIXjO7RaJfRjyvXjf8mI0ZQRB4asYgfvfhJj7bdCjivS0mkRizxFs3jqB3upF6Z3BqUuP08eXm0mcFDSAAACAASURBVBabgABReS0VldeWF3PHxB6YJH3NtKOhrM7NHz/ewo7SOv5zxWCGd+uA2ydzzZzVbCypiVoePLdTHHVuPweqXVHNO2f2SeNf0we1qUYTYHDnJJ6aMYi731kftUEVY5Y4o3OiUQN1GlHV4OX9ghLW7aui1uXDbjHRPc3B9KHZPL+oUPP3Hc1z6fbJrN9fwye3j2V3WR1fbC6ltMaNV1ZIjbUytkcKo7onGxkRUWCWRCb37xSsTNqCDwpKNHtp1q78gppV+0m/YggJQ2sQJJH6LVXUrasLMqZURcJqMjG44xk8Pv5xEqztly791up9x9yfT0JmoTKIKdLKkPf6p0lc2MvE35Z6yU0NvwZEJ3wh4q3Mx5r6TUh6ZLT1ZqIgtruAl2FMGbQ79fX1bN68mcrKSjweDw6Hg/T0dHJzczGbI+f4x9nMvHPTSL7cUsrziwopPFKPX1abpNNNqEiqQs/MJG6Z0J3J/dLbbZNmkkT+8as8zhuQzvOLCtlYUoOiqviaJRI7rBKSIHD1qK5cM7orqXGnr0yswanPewX7ETV2AdF6LSGQBvLttjIm909vt3Gpqsp7a0v4+5fbuXx4Nv+aMajJoWIzS7x5wwj+8dUO5q3cC6BpINl/Ktq+cmQXfn1ub2RVZcGGQ/z3h0JKqlz4ZKVp3mlsSjmkSxI3j+/OuJ4px7zRnNy/E3NizNzyRgGyqmoK4EBAXEcQ4KJBGfzl4v7HxSg1OLHYdqiWZ77fzbfbDiMIBBnci3YIAbEJDfWDaJ9LFVhaWE5ZrZseaXHcMdFwCB4vVFXlxSVFIXNQ8+/KnDIWV0kN5qSVxOWtIG6gD1URCaS0ScjVwxjf5TKenHRWu49PS8kY9PW98yFRroY38P40wcbgF+q5b1T4/VK9Gl2ZhtQwHEunb9HyQUWqNxMFkYGpA8mM1VfnGgnDmDJoN0pKSli6dCm7du1CEAT8/qPFiBaLhQULFjBkyBBGjBhBYmLrKlQmSeTCvAwuzMtgR2kdC3eUNTXkTPQ66fXfxzj7bx8cF6+ZIAhM7NORiX06sreigY/XH+RAtQuX14/4yYecc9e1TBqWg9nY0BicBqwsqtD0fuvxWjZ4ZTbsrw4ypvZXOtl8oIY6tx+rWSQtzsbwbh2iivTsq3Dyu/kbqXH5eH32cPplhC7iJknkwfNzuefsXny68SAvLS7iQLULj0/BahbJTIzhxnE5TMnLaFLCMgGXDsni0iFZbCqpYVlhOZUNXsySSLLDwqR+HclKat/ebqO7p7D29+fw9daA82h3WT3mn4wnRVERRYGrRnbhypFdyEhsW02owcnFgo0HeeC9DXj8iqZanK8VCTk9z6UALNh4iOuN+rvjys7D9VQ7fSGvt/yuVH8C3iPn4j1yDogeBNGDqlhAsQEi39XKMK39x+dXtCPjevreKYh4Ce8s79FBZEY/M0+v9jIgLXTv5FVNfKUM1TjzKCZRwGISeW5mPrWmP/LIikd09+iMNcfyyJhHdJ0TDYYxZXDMqKrKkiVLWLJkCT5f6IQB4PUG8ltXr15NQUEBV1xxBd26RTeB906PC0qj236olqc+Hc7v/vwVTjmwIMTHmLl4UCZXjWrfDUeXZAd3ntWz6e897z9GekyDYUgZnDZobQJAn9cSoLzBg6yo/LCzrCnqa5ZEFFUN5OsLYDVJXDemK1cM70yKRmNQWVF5ddkenl24m1vGd2f22G4RozQxFonpQ7OZPjQ7qnE2MiArgQFZx095sDkW01HnUWmNm/J6D15ZId5mpnMHOxaTMd+cLny5uZT739ugu5auET3PpcevcLCmfSWiDUKpqPdoOonCf1ciKDGoSvBeps7tQ1XVdncix1nNQKhRoqfvnVmEhGY1T1r8YbyVNzZqrycmk8S2jJlY9gfmuuY183aLhKrCtMGZ3DQuhy7JDuAiajw1PL3u6agMKlEQiTXHMufcOWTERt8TMFoMY8rgmFm0aBHLly8Pa0g1P66yspJp06bx1ltvMWvWLLp27Rr1fdbtq+KhjzZTeKQeb1o/FNdRb3mDV+aVpXt4ZdkehnftwGOXDKBzcvt6kAEsXbviLS4mZqBRs2BwehCuFlGP1xICTo/z/r2YA1Wupvz8loXzDR6ZZxfu5tmFu/nnrwYyZeDRRW97aS2/eX8jMRaJ+beNoWtK+xRfn2ikJ9i0mxkbnPLsrww0tdYypCIJvTSi97mMVB9ocOyEiyTq/a5UNfCvvRNyxvVKoai8PqicAfT1vVMFM0PNxUFCe8V3B6eOZieIuH+vLZwjZpzBP2dfzP21bj5YV8Lecid1Hh9Jdgt5WQlMGRjaBP6qvleRGZvJX1f/lVpPLS6/C7VF3ZVZNDel9j0y5pHjYkiBYUwZHCM7d+4MMaQ2bdrEihUrKC8vx2q1kp6eTn5+ftB5Pp+Pt956izvvvJPY2NiI9/li0yHuebf5IqMRJpYD7y0vLOeC/yzhzRtGkJcVuamlHixduuDdu7ddr2lgcCKT3cGOKBCSbqTLaykJfL7pEG7f0RqkcDQ+4w+8v4EGj59LBmfyzPe7eXPVPh44tzeXD8s2iuINTkleXbZHM+UqGkGJRvQ8lwApjtYV1gyOnXibSbNXg97vymoWo2t2rpNrRnfl9RV70RpktP35enWKp0fGUNh6CPQ2xDXb4fx/AtAx3sZtE3pEferEzhM5M/tMCg4X8OqWVyk4XIDb70YUROIscUzJmcIVuVe0e41USwxjyuCY+P7774MMqRUrVrB06VIuvPBCunfvjiRJ7N69m+3bt2OxBE/aiqKwdu1aJkyY0Oo9lu8ub2FItY6iQp3bz8yXVvHpHWPp1o4ebEvXLtQvWtRu1zMwONG5fFg2H/14IKRuSo/X0i+rKIqs2RAyHG6fwh8+3sLT3++if0YCX9yVb/S5MThlaez10zI6oEfoBfQ9lw6LxBmdk47fhzIAAn3ltJxIer4rgDOyW/+uPH6ZBo+M3SLpUjfOSrIzuHMSK4oqNN+P1PfObpG4dXx36Psc1JfB/lXgi9KgMsfAjDegU17U422JIAgMTR/K0PTWa66OJ4YxZdBmysrKKC8vb/rb7XazcOFCpk6dSm5ubtPrvXv3pnfv3ixqYYT4/X5WrVpFfn4+kqT94Hv9CrfMK2hT2oPT6+f2t9bx2Z35Iee2lUCanxGZMjixKat189mmQxyqceP2ySQ7rAzpksTo7sm6PJuqqlLt8qFoyKJD9F5LFTQNqUjPsFdWMIkCL1w1xIhGGZzSfLP1sObregQlGon2ubSZJcb1Sj2mcRtEJsYicemQTP63en+IURXtd+WwSNwyoXvItY/UeXhr1V7mLi+mxuXDLIn4ZIU4m5mrRwWEa6JxQj14fi6XvbBcd62eWRTI7mDn7L4dQRJh1gfw+f2w4a3AAX5tpUAssYGI1BVvQ9YvZwS1F4YxZdBmVq1ahSwf9VaXlJTg9/uDDKlIyLLM7t276d27t+b7X21pe38bRYXCI/VsL62lT/qxNbhUVZWVRZW8uKqedV2n433oS8ySQHKslatGduHSIVkkxESWfTcwOJ6sLKrghR8KWVZYgcDRmiSBgPfQYTVxY34OM4ZnE29r/fe6fHc5T36zk2qXj8uGZvFBQWh0CiJ7LQUhkOffkmhTl8pqPewuq6dnR0O62eDUZW9Fg+bzFUlQIpxDItJzaTOJzB7brc190gz0cf2Ybry3tkQzQhXpuwKQJIH8HilNf7u8Mr/5YCNfbi5FEI7O9Y3/rXH5eHFxES8sLmJi71SemD4Ih9WE16/w9dZSPt1wsEkSvYPDyoV5nfjnrwbywPvRi5+YJYGUWCtv3TDiqCiXZIIpT8GE38LaObDqBZC9IEqBhUD2QuZQGHs39Dg78PopgGFMGbSZQ4cOoTbbJTmdTux2O6IYvfKU3+/nyJEjYY2p538oDOm9oq+/jcorS/YcU5PLBRsP8tjn26h2+nB5ZVRLLPhkXD6odfv551c7+PuX27kwrxMPX9SPuAibVAOD9kZRVP68YAvvrCnB7ZNDMt9VAiItDV6ZJ7/ZwYtLinjnppHkpIbWK67eU8mT3+ygtMbN3Wf3YsrADAQCRs3inUdw+6P3XNpMIl5ZCRmPrmdYUXhl6R7+dmnb00AMDE506tx+TRn01kQK9NRSNUcA7FYTM0d0budPYRCOnNRYzuufzpdbSnVHfyySiEkQ+M0HG3loSl9UFaa/sILi8oamWnEtGg2rhTuOcP7TS5jYO43315WgaPS0W1FYDoLAwKxEVu2pJMYshW14LhCItuWkOnjj+hEkadXdxaXDmQ/CuF9D7QFw14DJCo5UsIfvA3WyYhhTBm2mUe68EbvdjtPpRFGUqA0qRVHYsr8CYXc5CTFmEmLMxMeYibOaKKvzUFhWH3KOnrQHWVH5dMPBNhtT//pmJy8sLmx18muccD7dcJC1e6t47+ZRpBm1HQY/E6qq8psPNrJg46Gwi19zXD4Ft9/D1GeX8entY5tU8dbtq+Jf3+xkT3kDd57Vk2lnZAbJjv9n5hncOm8dKwq1+061JMYsMSg7kfX7q3C1eH70PcOBiJuBwalMfIxZl9CL3lqqRgKGlMTbN44k0W6IT/yc/POygRyqcbOhpDpqgyrGLHL7xJ5cM7orf/18G+c++QN2q4n9la5WDanmePwKeyuczF1erKWDAdCksLpqTyUJMSauH9ON99eVUFHvRRIEVGiKgI3tnsJN43MY0a1D5PRryQRJXaIa58mMYUwZtBmzOTgCk5WVhclkYvv27fTt2zeqa6gI7DziYtH3u6lx+Zr+uXwyMRYpqNdAI3r723hlBY9fxmrSF05+eUkRLy4uinrS88oqJVUupr+wggV35hNrNR4vg+PPvJV7wxpSrdUkNXj8XPHSSp6bOZinv9/FjtI6bp/Yk18NydLsa2Q1Sbx89VBeWlLEf38oxOtXmhbgRho9lgkxZu45uxcNXj/r9lWFXEvvM1zv8Uc+yMDgJKZ7qoMYsxTyTIUTKahd+wmqz6OrlspuFrFbTfzvppH0SDPSZn9uzJLIvBtGcP+7G/hqayk+WSGcPWQzi6gq/HFKPy4fHogg/uWSATxilpizdI+mURSpBjVa/Z9at5+5y4v57M6xVDv9lNa6cPsCfe96pceSFmc4i1ti7PYM2kxqaiqlpaVNqX42m40JEybw+eefI4oi3bt3RxRFioqKKC4uDjG+AKwWM9PPGkC/fsEpCX5ZYcmucm5/a13I4qK3N4MkCnj9ii5jam9FA//8akdIHxxofcKSFZWD1S7+9sU2Hr14QNT3MzBoC4qi8tR3uzQNqUgpQIoaEKq4du5q7pvUm/9eNSTiMyKKAjeP784N+Tks2lHGK0v3UFzRgNunYLdI9O4Yxw35OYzMCXgs31y1F62SDL3PsMVokm1wijOxT8ew4jBaIgWSPQnR5ojaISEIcNfZvbhmdFddSm8G7YtZEvn3FWew63Adc5YV89GPBzCJQsATRWBOj7FI3DC2GzOGdQ5KoVNVle+2HdY0ivSmfLa2j1HVgEF1w2sFfH5XPn0zjq3m/HTAMKYM2szw4cPZtm1bkDT66NGjiY2NZfHixXz44YdYLBYyMjLIz8+nsLBQ8zpa9VImSSQrKUbjaP29GfyKisOi76c+d3mxpoJZNBOWV1b5oOAAD56fG9JkzsCgPflh1xHcGk03o00BktWALO7Vo7rquq8kCpyV25Gzcju2elxanA1JFAnq5Ij+Zzg1zqprfAYGJxsWk8hVI7vw8tI9mhkZLUUKXEUFlL3/p6gdEjaTxAV5nQxD6gShZ8c4/jptAL+/IJctB2upcfkwiQLJsRb6ZyRoGtbr91dzuDZUHU9vymc0+xhZUdlT3sCmkhoGZCW04yc/NTF2egZtJjMzk/j4eCoqgusZ8vLyyMsLLRbPzs4O+luSJIYMGYLJpP0z7JLs0MzH1duboXfHOF1y0O3R70MQAjVUM4YZBb4Gx48XFxeFRG5BX01S4ZF6dpfV0yMtcvNsveT3TAkSqWlEby+cK0ee+jn3BgZXj+rK3OXFeCMfqtsh4ZMVo0bqBMRhNTG8W3SCDF9sPoRbIwtBz3yvZx/j8cu8tKSIp684I6rxnc4YuRMGbUYQBMaPH6+ZvhcNoigyfPjwsO9bTCKzRnTGLIUaQvHDp5E0cTY1K96h5D+zKHn+WurWLSCmZ/Bk4rBI3KrRm6E1luwqR9Qw4vRMWE6vzLyV+3Td18BAL7sO12m+rqcmySSK7C7Tvs6xYjNLzBiWfUzPsApMGZhxXMZnYHAikZ5g44WrhmAzR96aNXdIOHeuQPG5UWU/rsK1VC2cE3L8wOxEo473JOdwrUczxU/PfK/L8FLhyy2luDQcdgbBGE+WwTExYMAA9uzZw+bNm4PS/SJhMpm45JJLSEpqvaP3VaO6MHd5MVqlk9H0ZkCRmdw/PepxQaAJnqwcu/BFeX2YZnUGBu1EuEVOT02SrKjUuY+fwMN1Y7rx1qp9+NrwDFtNItOHZhupSQanDfk9U3nhqqHc8kYBPlnR7EvUiJ6GrzePyzneQzc4zshhfgu65nud+xifrHCgykkPo89fqxiRKYNjQhAEpkyZQl5eXtQRKrPZzNSpU6NS/MtKsnPRoAxiovDUtSRGgqt3f0/lo4+gNDREfZ7XL0fs9xENvihlSw0M2ko4wYjmKUCRkESOa21fdgc7UwfpjyyZRIGspBh+PVm7B52BwanK+F6pfHl3PpcPzybGLLW6/sX2O5NO1zxF53s/IPv2eaRd9jC2rNygYywmkYl90o73sA2OM2lhakd1zfc69zGqCrNfX6OZrm1wFMOYMjhmRFHkwgsvZNq0aXTq1AmTyRRS62QymTCZTPTp04frr7+eAQOiV7p77JIBDMpOjCr1oZEYs8Slw7pw/4t/RPV4KZp6MQ2rV0d1bnyMOaCu0wI9ExZgpFQYHHcykrQlavWkACkqYcVe2oO3V+9j4Y4jXDWyS9TPsNUk0jnZzjs3jzJEXAxOS7okO3j04gGse+gc/jS1P2fnpmkqY0bCZhZ5+oozgnrGGZycnNknDYcl1IGmZ77X2sc0bF3EodfuZt+Tv6Lkmas4/O4fcZdsaXp/b4WL33yw8fh9sFMAY5UyaBcEQSA3N5fc3FzKysooKCigoqICr9dLTEwMWVlZDB48GIfDofvaZknk9dkjuO/dDXy9tRSfrIYNd5slAVEQuHl8Dned1RNBEMj462PULVzIwfsfIG7yuaTdcw9ijPbmsbi8gSU7j2gW9espmjeJAqO6p+j+rAYGepg9thu/n79Z8/cabQpQB4eFvOOg1qSqKs8tKuR/a/bxzs2j6JbiYGyPZB77fDtH6j24fDItnZ12i4Siqlw6OIsHz8/FYTgkDE5zYiwS04dmM31oNp+sP8CvP9gYde9Dm1nk79PyyO+ZepxHafBzMKZ7Cg6r6Zjm+5b7GF9FCTVr5hM36Dysmf1IOvM6TUn19wtKGNKlAzOGZePxy3y5uZRXlxVzqMaFx6/gsJgYmJXADeNyOCM7MXIz31MMobXQ3dChQ9W1a9f+jMMxMGidrQdreWVpEQs2HsJiElFU9af2DIEO3TOHZ3P1qK5kd7CHnOuvquLwo3/BvWULnf76GPYzjirUHKnzcPvb61i/rxpFUfG1kqdev2UhdWs/xlexP2jCap5aYTOJLLgz/7gopB0rgiAUqKo69Jcex7FizE8B5ckhj3yjubhGg90i8eD5fbhyZNd2HZeiqDz62TaW7S7n9dnD6Rh/NIKmqirr9lXz0uJCftxfjdMjYzaJpMZZuXpkFy4+I9Mwok5jToX56XjOTcsLy3ngvY1UOb24vHJIJWJj4+xEu5nHLxvIaMOpd0rx0uIinvhmR9QGdTjqtyykdvV8fGVFCFYHpoQ0VL8Xua5Cs+EvgFUSuHJUF/63pgRUNbTBtABWs0RanJXfndeHyf07HdMYTzRam5sMY8rgpKTW7WPd3ipqXD5EQaCDw8LQrklRNeat/eprSh95hMSLp5Jyxx0caJC55LllVDt9rRb76iEvK4FPbo8sV/tLcCpsVsCYnxr5+xfbeHV5cZsW11iriZUPnhVVSqrHL1Pn9hNjlrBbpLCeR5+s8Ov3N7K/0skr1wwjwd42tU+D05NTYX463nOTqqqs3lPJC4uLWLLrSFOmhiQK5PdI5abxOYzo1uG0iw6cDtR7/Jz/7yUcqHaFzdCJlsZeZYnjrqF2zYchfac8+7cEZdwIgCQJ+OXI97WZRW4e1517zul1TGM8kWhtbjLcfwYnJfE2MxN6t62gNv7cSdiHDqH04T+xYfpMbh1yI5UubdGJtmAzi/y/83MjH2hg0A7cO6k3q4ur2FRSjTeKRa4Rm1lkzrXDWjWkat0+PlhbwotLijhc68YsiciKilkSuXRIJteP6UZO6tHoq8src9ubBQiCwBuzRxCjkd9vYGBwbAiCwIicZEbkJAM09R4yVC9PfWKtJt65eSQXPbOMqgbvMTmAZVctYkwcNcvfjqrvlApRGVIAbp/Ci4uLcFglbhqnrz3NyYhhTBmclpiSk8l8+t88//RnlB/woIihj0LD1kXUrvkIX0VJ2LB3S2xmkUen9m9a5AwMjhWfrPDN1sN8vaWU8novoggd42xMGZjB2B4pmCWROdcMJf8fC1FUOeLiKgqBTdcLVw0J2yxSVlQe+3wb81buRRQEXD9t1jz+QPTLr8j8b/V+3ltbQl5WAs/OHIzFJDL7tbV0Sbbz90vzMBsF7wYGPwuGEXV60Skhhi/uymf23DXsPFyPJ4wCsSiARRJx+7WzFqSYeBRXHahqVH2ntGhtn+TyyTz5zU4m9kmjR9qpLa1uGFMGpy0ev8L8ChN+DYmk2tXzqVn1fkjYu2VRZiMxZhEVeGrGGbr7WhkYaFHt9PLSkiLeWLEXWQnNT/980yHsVhM35nfjcK2HvKwEJvbpyEtLiqhx+ULqKWxmEVWFSX07ctfZvcLW8/lkhevnrmFtcWWT8aSFX1HxKyo/7qtm0lOLSYqxcFZuGg+en4vYFtkxAwMDA4OoSIm18vHtY9lUUsPLS4r4ckspkggqAn5ZQVZUunRwML53Kt9vO8y+KlfINayZfUAUEUyWqPtONSeafZLPr/Dykj08eEEun288xL5KJ7VuH0l2C33S4zmnb0csppPf8WYYUwanLQs2HkJry6d4Gqhe+mZUYW+AZIeFm8fnMH1oNol2y3EcscHpQnF5AzNeWEGV0xs2da/BK9PglfnnVztQVfjy7nH0SIvlujFdWVlUyf/W7ONgdUBpKTHGzNieKRF/o6qqcu8761lTXBl1DZZfUal2+vDJCreMzzEMKQMDA4OfiQFZCfz10gEMyEoIONKcPvyyigrsqWigZJUTRQ3UO7VcSUSrg9h+E6nf+DUNO5YRkzMkrEJxS6LdJ8kqvLt2Px+uK0ESxaYsBwCHVUL8UODKEV24ZnRX0hO0W32cDBjGlMFpy/wfSzRV0DwHtqP6vVGFvS2SwB1n9eDa0d2OxxANTkMO1bi45Lll1Lh8UdXx+WQVSYDZr61hwR1jibOZGdU9mVHd9aeartpTyXfbyzQNqUhpr16fwpPf7OKxadH3kDMwMDAwaDuHalxc/uJKymrduDTmbV+EGqekibOp3/I91QtfpeKzJ8NKqrek+T4p0tqgqAScgnLwfqvBE/j7laV7eH1lMa9cM4yRJ2mJhGFMGZy2VNR7NV+XXbWI9viowt4+WaXG6W/voRmcpqiqynWvrqHW7Q8xpFpbsGQVDtW4uffdDbx0dduF0F78oQiXhoMhqnQORWX+jwf4/YW5RqNdAwMDg+NMWZ2bKf9ZSpXT12ZlP9HqIGnc1dSs+pCUC+8P6p9ZtXBO2OhU4z6pbu0nukoitPDKCl4Zrnt1Na9eN/ykNKiMFc/gtCXc5CPFxKE4a8lUD+EVbNQQixdteWcVUNRj6/dgYNDIj/ur2VfpDPltRmPMeP0Ki3ce4VCNi04J2k2pW6Os1s2ywvKQVBA9aa+CAB/9eICZI7rovr+BgYGBQXSoqsrVr6ymWsOQ0iueFW3D3+ZIMfEozlqqlswj5YJ7oiqJiDQ2l09h9mtr+P6+CUG9CU8GDGPK4LQlyRFcO5JAPdOlRUzvtoBBJoVbi+7j0r4WJBS+VQbzkv8CflR7QLNKK4tJJCHGqJMyaB9eWlwUlFMO+owZFXhjxV5+PbmP7nt/v70MUaMvjZ60V6dX5sN1hjFlYGBgcDxZvaeSfZXOEPVWveJZjcT2O5PYfmdGff+AeIUEUa4N0Y7NL6u8vqKYByZkQvEycFYETo5Jgi6jISYx6jH+nBjGlMFpy/n909l0oAaP18eDpje5UvoWBQG72cufJ1i574s6YiUbk7qbmCisxrN3NSnFEkXj/0ahmgmACIzrlfrLfhCDU4I6t4/vtpfRso+6HmPG61eYt7JtxlSl04tXDk3x05P2ClDl1E6fNTAwMDBoH15cHJqSrVc861gQrQ7svUbj3LYY1+7VQemBWuIV0Y6ts7yPnBUvoq5ZgSCZQPkp80cUQfZB36kw6nbolNeun+dYMYwpg9OWaUOy+McXW3jZ/Dgjxa3YBF/Te/eNtpIeK/LoYg+zPnQRZxEYkiHx27FWzrA8xEzv/2Oj2p1e6XFhJaYNDPRQWuPGLAl4W5Tg6TVmGjwybp+su/eMqhJiyMHRdA5VkaMaQ1tz9w0MDAwMIlNR72Hp7tCUbD2Ot/Ygtv9EnNsWU7P8fxHTAyONTUDhYdNrTJd+wKz6EfwKaJWjb3oftn0KuVNh6jMgnRhmzIkxCgODX4B4m5nXkl+nX/VW7EKoN31WnplZeVq1Um7mWR7jV+rfuXX84OM/UIPTggavrJlmp9eYMUkC9R6/bmMqIcaMxSSGKPlZM/sgmMw4d67A0WdsxOsY7QEMDAwMjh/FFQ1YTGJIH0C96YuTvAAAIABJREFUjjcAh0XSVDVufE8lkL6thTWzD4LZQvzIyyKuDa2PTeU/5v8wUfyRGI29WPChMvicsO0jaDgMs94PpBv+wpz8nbIMDNrKgQKGNCwOMqS6PlVH2j/raPAe9fm8vM7LhLkNQac6cPOo7U3O6dvxZxuuwalNrFVC0QgNNTdmosEnK8Ra9fvJ8numaEamRKuDxLGzqPzmvzh3rkDxuVFlP67CtVQtnBN0rM0scv4Ao2m1gYGBwfGizq2tINzc8RYtf5uWx4DMeEyigN0iYbdImCWBnmmx/Hlqf9Y9dA7p8VbNc/WsDa2N7V7pPSaK6zWd2mHxuWDfSvji19GfcxwxIlMGpy/Ln0WUPSEvyyr8e5WXB/O1JxAASVAZJq9HcB6BOMOgMjh2OiXE4NfoCdJ8wRJEKWJueqLdojsqBdAl2cGAzATW7q0KeS9atSdVhelDs3Xf28DAwMAgOsK1ntCbRWASBaYMymDKoAxq3T5qnD4UVSUxxkKC/WhWzh0Te/Lwp1s0e1ZFuzaEG1ssTm4yfRZUZgEBx7bTB3vuisVhCWRsvLzOy7yNPhZd6wgc5HPCujdg/G8gNi3i5z2eGMaUwemJsxJ2LAANWfMHRlv4xzIPtw2zkGgLTbtqRFFBKngVJvz2eI7U4DTBYTVx/oBOfLL+IHKLEFHUC5ZJ5NrRXds8huHdOmgaUxBZ7UkUYFLfjkaan4GBgcFxYH+lk3kr97J2bxX1GtEpvY63lNijDuN4m5l4m3YLmIvPyOTRz7bh0xAoguiUAMONLW//G/xun5N/TQq9dzSObQQB1r4KE37T6v2PN4YxZXB6smcxiBYgNDI1NENiQlcTjy/38OjE8L0OJMWD68d3iTGMKYN24ob8bny5uTREHh2ilK6VvcxcPwsKKsBsh465AeWjrvmBRScMh2vd/PHjLeworaVPehx7yhtC8vEj4bCa2qQiaGBgYGAQnpVFFTz17U5+3FeNoqqaEaJGonW82cwi146OroWFw2ri8uHZvL16X0hNrR60xna4k4fbx2mvTVE5tv1uWPVfGHf/L1o7ZRhTBqcnrkpQtPOOAf58ppUxcxq4a0TrXvaGmgrKK51kd7ADUO/x89GPJXy1+TCVTi+iIJAaZ2Xa4Ewm9U3HYjLKFA3C0y8jgdxOcWw6UNPqgqmFDS+ThVWk1O04+mLNfiheCrb/z959h7dV3Q0c/56rYdnydpy99yJkkZBFEjZhlhEg7A0FCrTQ9oW+LW0ppS19KaMByghl7713CJkkIQlZZDrLcWI73ta+5/1DSrCta1tSvBL/Ps+jJ/Kd5yrSufd3ZiZMuxNGXVQrqDJNzStLd/CPT37kovE9+dcFIwmamnNmLyCvOLaASgEpThvPXTX+wO9ACCHEwXvq2y3845Mf4wpiYil40xrOP6pnzMe8c8YQfthZFp5OJs6CtvrS5sbDx0nX4lDWNV6xFmwT8EDlHkjvmnC6DpY82Yn2yaqnfQ3DO9o4baCd+75tuENkitPG1f9dysY9FfzmjVWMvecz7v1wPfM2FbEmv5wfdpXx5fq9/OaNVYy55zP+9tF6KryBBo8p2rcnLzuKbLcTu1F/TVJdTvz0Vfnc53iyzhoN/ioo3wUf3QHv33Zg3o4thZVc+MQiXv5uBy9eM55fnTgIl8NGapKdt2+cxDEDc0myGzgaSIfbaaNLpou3b5zEyB5tczJFIYQ4FD27II/7P9lwULVBVpw2xSnDu5Dljr1JtsNm8NxV4xnXJ5uUBPrkWkmnmgANH+tP05N4eImfwqoGPgPDBt6yJklToqRmSrRPKdlgNPz1/+M0F6Mfr+RXE+pvr5uclk2/bDenPDgPjSZUz++9yhcueXnq2618uHo3L197NF0ykhNOvjh8ZbudvHvTZC74zyIKyryWTf5qSsbLMJXHHOc/ojrx1hKohlWvELK7eMx1NU/O28LNxw7gsom9sdUJmJKdNp64dCxbi6p4Zv5WXlu2EwUYhkLr8OTA4/tmc90x/ZjYLwcjjsBPCCFEw9bml3PvR+ssA6mqtV9T/t3bBIp3YjiTcXTsS8bEmbi6D4vp2O4kO/eePTzuNCU7bTxzxTjeW5nPo3M3sz3SeqHm1ILuyGiAPbJT2Li3ssFA0I8dI2q2rNpqFmwPya2v/keDrXX76kowJdqn3sdAqOFap/7ZBucPc/DQEj9HdLT4EduT2NL7Ar5ZWkQwxolK/SGTnSUezpm9gA9vmSKd9YWlTukuPvjFZF79bgf/mbeF0uoAHn/owG3HQJOEn+5qL9fb3+cMYwEDHiyNafSjwOKnKe/Yj3dvuqjRZnl9Orj545nD+Z8ZQ8gv9VDhDZLstNExLUm+u0II0Uwe/2YzgWD0c0X5krcoW/w6OSfeiKvPaJTNjmfrMjwbFx8IphoLtjyBEN9uLOLEYfFPY2EzFGeN6sZZo7qxJr+M91bmU1DmxR8y6ZCaxJQBuRw7ODyy3l1v/cA7K/LrLRAsw41B47VujRZsh/zg7hD3tTQlCaZE++TOgUEnh2fSthjRb7/fT03iuVXWpf2mqbhs9ZFU+aODsoYys5CpKaz0cesrK3jminFNdkni8JLitHP5pD5cNrE3i7fu48v1eymq8GEzFJ1+fIGT/Z8w3MirtU8sox8l4eO3aR+jsq+JOS0uh42+uamJXooQQogYlXkCfLy6IGpUV9NXRem3L5Az41ZSBk08sDyl/3hS+o8HYgu2vAGT+z/9MaFgqqZhXTMY1jWj3vV/PfsIpgzI5eEvN7KtuAp/SBOqUfBs2Jx8ocdyolqCrYEaqkYLtntOBFf96WgJEkyJ9mviL2DjZ+HmTxF5t6bV2qRHhoH3d+kWOyvmdzyfkvxQVPerWDKzQEizcHMxu0o9dMuU5n6ifkopju6bw9F9c8ILdi2DDc+BUR21bSyjHymAvHlQnt+qHXaFEEJEe29lPobF6Ku+XevRQT8pAydY7hdLsLXf9n3VrN5VxvBuzReEKKU4dUQXTh3RhbX55byweBubC6uo9gdJdzkY0T2Dsb3+gO3Nc2s9h1mpt2DbmQqTbmmmK4idBFOi/eo2BgbNgPUfQNAT375JqTyuf0aVv3YGEE9mprXm2QV5/M+MIQlfgmiHlj8XHg7WQsyjHwGsfgMm3twMCRRCCJGovKIqy6ZxIU85Rko6qp4hwBsLtmoKBDVz5m/lnzNHHnR6YzG0azp/+dkR0Su0hvRuULwJatROxVywnZQGfRuZMqQFyGh+ov1SCn72GPQYB45Ya4cUON3sO/s1luyMfqCNJzPzhzQvLdkeZ6JFu1eS12DT1JhGPwr5oFS+e0II0dZU+KynbbElp2NWl6NN6z5IjQVbtbbVmjX55QeVziahFFzwAjjd8e/rSIELXwKj9UOZ1k+BEK3J5oCL34QjZ4E9Cez1leaHgygye8JVn1OQNsxyzqh4MjMIz0sVqG8IQCGsNFKLGuuw/vgbblYhhBCi5WWlOCyXJ3UbjLI7qN6w0HJ9Y8FWXZX1BG0tLncQXPJ2uJaJGEeGdbjhwpeh66hmTVqspJmfEDY7lSf8nbdTr+ClBT9S7PMS1AZu5WGqfQ1X2D+lT/8h4Xa5vSaBUni27cOiSXOtzCyWgMpmKKr9ITKSpVxDxCg5q9FNYhnWH3duEyZKCCFEUxjQMQ2nTeGvM3G7keQmc/JF7PvsMZRhw9VnFMqw481bgXf7KjImnn8g2HIPntzoeZKbaL6oJtHjKLh2Lnx4O+TNDy8L+WpvY3OCMqD7WDjlH9BpaMunsx4STIl2raw6wF8/WsfbK3ZhKEW1PxkIN/kr0hm8FOrMK/oEhpWl8ztjKKMjEVRqkgNtMfFvzZKjWDKzYEiTmiQ/QxGH3lNgy9wGO+w2OvqRMxW6j2nGRAohhIhXhTfA84u2RQVS+6WPOxvDnUXZwlcoev9+lDOZpE79SZ9wfqPBVtb0K2sdq3dOAk3rmlNOP7jkLSjbBd89Fe7X6y0Nr3Olw5AzYNw1kNW7VZNpRZ7iRLu1q9TDzMcWsLfCR6CejCtgAqbJ8u2lzHpiEfefdySnjehK96xkghb7xJuZdcl0RU2YKkSDRs6CL/7U6GYNDeuPYYeBpzRxwoQQQiTKGwhx3mML2VJU1eB2qcOmkzrMetCFhoKtmlKcNi6f1Lupkt60MrrB8b8Pvw4REkyJdqm02s+5jy5gb7kvai6H+ngDJre/tpI0l4OpA3M5dUQX3l6RX2veBIg9M0t22Lhmct8muybRTiRnwZDTYfWboH9qGx/z6Ee2JBh/Hdgk+xdCiLbiN2+sYmtRFf5gdD/qxibiramhYGu/dJeDif1ymizt7Z3cTUW79Ns3f6Co0jqQaijT8gZMbnh+Gd/ddTxXTe7LBz/sjgqmILbMzNSac8Z2b7JrEu3IMb8OD+nfyNwclhwpcFTsE/YKIYRoXoUVPj5aXWAZSMUydyXEHnAlO2zcMK0fyqrjt0iIBFOi3Smq9PHV+r2WTftizbTe/n4XFx3diyO7Z/L99lL8cY7Il+ywcf5RPUh3WY/aI0SDcgfCzOfglYvjmyPN6YbL3oFUGXxCCCHaihcXb7ccxy7WuStjfXYxFBwzsAOXTujVnJfT7kgwJdqdl+uZ2ynWTKvaH+KxuZuZNb4n/7l0LKc9PI+CMm+9/a7qctkNRvbI4HenymS9on6mqfl2UxGPf7OZVTvL8AZC2AxFttvJrHE9uWDcFGxnPY/z9Ytw2RS2UANBlcMNDhdc+i50Ht5yFyGEEKJR/12Yh8+iViqWuStjfXbZ784ZQ6RWqolJMCXanRcWb08409qvuMrPut0VDO2azjs/n8QpD81jT7mv0f3shmJS/w7Mvng0dpsMhy6svbV8J3/9aD1VviBV/p/6RQVCmvxSL498tYmHv9xEerKdmSPf4Y6O38GCR8BfGZ5RXgfDg0xoILUjTL4Njjg3sYkRhRBCNBtvIERptfW8gLHMXRnPs4tNwWtLd3L7SYMSTq+IJsGUaHf2VSWeae1nMxR7KrwMNtP48wfrKKuuZ9Q0C9v3VVHtC5Fkb0NzPIg24+8fr2fO/K14AvU3HfVG1hVW+HljbRUXHHMVPY6+EbYvgJI88FeFJ0DMHQRdR2M5KZoQQohWV+UL4rAZloW8scxdGc+zS8CE9QXlB51mUZsUjYt2J2gxYATEN3u41uALhLj7vTV8vLoAr0UmWN+584qruejJxXgDsc1SLtqP/3yzmTnz8xoMpOraW+Fl5uMLKfUGofdkGHVxeLS+kbOg2xgJpIQQog1zJ9kJ1NPvuubclfWJ59kFoMwTe+GviI3UTIl2x+UwqPJFZzrxTLirFOwu9fLa0p14LIKihkbVCYQ0WworeXzuZm45fmCTXZc4tO3YV80/P91gWTrZ0PfJ1FBU4ePP76/lnzNHtkLKhRBCJMrlsJHqslPuCUats5q7svrH+ZQteJlg2V5syenYO/QCwxbTswsgA181AwmmRLszvGsGi7fui1oez4S7Hn+IT9YW4A1GB1KxjKrjDZrMWZDHjdP7S98pAcCc+VsxLYbqj+X7FDA176/azR/OGCY3SiGEOMRcPL4XT3671XJo9JpzVxa+cx+YJs5Ofel43p9w9RiGZ+syys1QTM8uDptiUOe0qHOIgyPBlGh3rp/aj9W7ymp17N8vlgl3FZDjdrJoS3RAFs+oOoGgyZfr93LisM5Nd3HikOQNhHjlux1RI0LG830yDMWby3Zy+aQ+LZJmIYQQTeOSCeFgqj6pw6aT0n8cO/99GTmn31qrBmr//aByzVcNPrsAGEoxa3zPZruO9kqCKdHuHDMwF5fDZhlMQeMT7iY7bZw5qhvPLYzu2xLPqDpV/vADtARTYuHmYsuhauP5Pnn8IV5askOCKSGEOMR0yUhm2sBcvv5xL/56pllp7H7Q2LMLwOieWXTPSjno9IrapH2RaHdshuLXJw8m2RH/aHpOm8GQLukkO4wDI6rVFM+oOgAF5d640yAOP4WVPkyLgVHi/T4VVzU+PL8QQoi25//OH0m3rOR618d7P6gr2WFw83H9E02eaIAEU6JdOv+oHswa3zOugMppU3TOcDHniqMIhDRWZUfxjqpT3wg+on0JhExMi29UvN+n+kaqFEII0balJtl544ZJGPUMwBrv/aCmZIeNG4/tz8R+HQ4ylcKKBFOi3frdqUP4+bR+uOwG9vpyr4gUp40hXdJ576bJpLscZKY4cNii94llGNOaMpJlsAAR/h7YLL6D8X6fUpOk5bYQQhyqst1OOqQmWa6L936wX7LDxi+O68+N06RWqrnInVe0W0opbj5uADNGdGHO/K28sWwXNkMRNE20BruhCJqakT0yuX5qP44ZmHvggfeo3tnYDYNAqHYJUTwjArrsBtMHdWzRaxZt0+ieWVGDT0B83yebAZP7S6mjEEIcyib2y+G9lbsJ1RndNZ77AYS7JUzol82N0wcwrk92S15CuyPBlGj3+uWmcs9ZR3DnjCHM21hEcaUffzBERoqDsb2y6ZEd3VlzZI9MOqUnkVdcHbUulhEBATRwwTgZVUdA18xkjuqVxfzNxVHrYv0+OQyDqybL4BNCCHEou3pKXz5Zs8dyDstY7wd9OqTw0jUT6Jzhaqlkt2sSTAkRkeK0c1KMI+sppbhhWj/++N5aqi1GBWxsVB1DwbGDO5LtdiacXnF4uW5qP77fUZrQ9wlgQKdUBnSS+UOEEOJQNrxbBt2zk9m4p9JyfWP3gxSnjfvOHiGBVAuSPlNCJOjMkd3ITU2y7OvSGJfDxq9OHNgMqRKHqsn9OzCoc5plX7zGuBwG/3vasGZIlRBCiJZ29+nDcDnif0RPshsc0S1DmvW1MAmmhEiQy2Hj5euOJjPF0egAFjUlO2z855Kx9O8otQjiJ4ah+O+V4+ielYLTHnvW7HIY3PuzI+TmKYQQh4lJ/TvEHVAl2Q16ZKfw1OVHWc5bKJqPBFNCHIQuGcl89Isp9Ongxu1seJh1t9NGusvO81ePY/IAGShAREt3OXjnpkkc2T2DFKeNhu6HyU4byQ4bD184mrNHd2+5RAohhGh2F4zryf/NPBKXw2hwGhdDhQtpx/TK4p0bJ8morq1APnEhDlLHdBef3HoM8zcX8djczSzNK8FpN9AalIJgSNM108UN0/pz2oguuBKYLFi0H+kuB69eN4Hl20t54pvNfPVjYa3vU8jUZKY4uHZKX84Z0500lwyvL4QQh6MZR3RlUv9cXlu6gyfmbaHSG8Qw1IH7gT9ocvyQTlxzTF+O7J4hNVKtRIIpIZqAYSimDMhlyoBcdpd52LS3kgpvkGSHjW5ZyQyUgQFEHJRSjOmVxZhLxlJc6WPd7grKvQGS7Aad0l0M65ouN00hhGgHMpIdXD2lL1dN7sMPu8oorPDhD5qkJzsY2iWdLBnIqtVJMCVEE+uSkUyXjOTWToY4TOSkJjF5gPUkjkIIIdoHpRQjume2djKEBekzJYQQQgghhBAJkGBKCCGEEEIIIRIgwZQQQgghhBBCJECCKSGEEEIIIYRIgARTQgghhBBCCJEACaaEEEIIIYQQIgESTAkhhBBCCCFEAiSYEkIIIYQQQogESDAlhBBCCCGEEAlQWuv6VypVCGxrueQIIVpAL611bmsn4mBJ/iTEYemQz58kbxLisFRv3tRgMCWEEEIIIYQQwpo08xNCCCGEEEKIBEgwJYQQQgghhBAJkGBKCCGEEEIIIRIgwZQQQgghhBBCJECCKSGEEEIIIYRIgARTQgghhBBCCJEACaaEEEIIIYQQIgESTAkhhBBCCCFEAiSYEkIIIYQQQogESDAlhBBCCCGEEAmQYEoIIYQQQgghEiDBlBBCCCGEEEIkQIIpIYQQQgghhEiABFNCCCGEEEIIkQAJpg4jSqlpSqmdLb1vS1FK9VRKVSqlbK2dFiFEfCR/EkK0RZI3iYMlwVQDIl++/S9TKeWp8fdFzXjey5VS3zbX8ZuCUkorpfo38znylFLH7/9ba71da52qtQ4153mFOBRI/lQ/yZ+EaD2SN9VP8qbDk721E9CWaa1T979XSuUBV2utP6+7nVLKrrUOtmTahBDtm+RPQoi2SPIm0d5IzVQC9lfrKqV+o5QqAOZYlYjULIFQSiUppe5XSm1XSu1RSj2mlEpO4NxXKKXWKaUqlFJblFLXWWxzp1KqKFI6cVGN5U2SBovz3a2UelUp9WwkXWuUUmNrrP+tUmpzZN1apdTP6ux/TY1rWquUGq2Ueg7oCbwXKc36tVKqd+QztSulzldKLa1znNuUUu8257UK0dZJ/hR1PsmfhGgDJG+KOp/kTYcJCaYS1xnIBnoB18aw/X3AQGAk0B/oBvw+gfPuBU4D0oErgAeUUqPrpKtD5PiXAf9RSg2KNw1KqdlKqdlxpOsM4GUgE3gXeKTGus3AFCAD+CPwvFKqS+Q85wF3A5dGrukMoFhrfQmwHTg9Uj399zrnew8YpJQaUGPZLODFeK9ViMOQ5E+1Sf4kRNsgeVNtkjcdDrTW8orhBeQBx0feTwP8gKvG+suBb+vsowl/GRVQBfSrsW4CsLWec0Udq4F0vQ3cUiNdQcBdY/2rwP82lobIvjvj+Dw00D/y/m7g8xrrhgKeBvZdAZwZef/J/vQ39JlH/u4dOa898vfzwO8j7wcAFUBKvJ+3vOR1qL8kf4o6r+RP8pJXG3hJ3hR1XsmbDsOX9JlKXKHW2hvjtrmEv6jLlFL7lykg7pFVlFKnAH8gXHJgRI77Q41NSrTWVTX+3gZ0bco01KOgxvtqwKUi7aGVUpcCvyT8gwZIJVwCBNCDcOlLIl4E/gn8iXDJytta62qlVEea91qFaOskf6pN8ich2gbJm2qTvOkwIMFU4nSdv6sIfwkBUEp1rrGuCPAAw7TWuxI9oVIqCXiDcLXuO1rrgFLqbcJf9v2ylFLuGplCT2B1U6UhgTT3Ap4AjgMWaq1DSqkVNdK8A+hXz+51P+O6PgNylVIjgQuB2yLLW+VahWhDJH+KLc2SPwnRsiRvii3NkjcdQqTPVNNZCQxTSo1USrkIV98CoLU2Cf8oHohE/iiluimlTmrgeEop5ar5ApxAElAIBCMlLSda7PtHpZRTKTWFcBvh1xJMQ1NwE/5hF0bOeQUwvMb6J4HblVJjVFj/SCYCsAfoW9+BtdYB4DXgH4TbYH8WWd5a1ypEWyX5kzXJn4RoXZI3WZO86RAiwVQT0VpvIFxl+jmwEag718FvgE3AIqVUeWS7QdRvIuESgrqvXxBuy1tCuHr23Tr7FUTW5QMvANdrrdfHmwYVHsHlsYavunFa67WEq5MXEv6BHwHMr7H+NeAvhKudKwi3Y86OrP4r8DulVKlS6vZ6TvEicDzhTK/mEKvxft5CHLYkf7Im+ZMQrUvyJmuSNx1alNaN1QYKIYQQQgghhKhLaqaEEEIIIYQQIgESTAkhhBBCCCFEAiSYEkIIIYQQQogESDAlhBBCCCGEEAmQYOogKaWeUUrdE3k/RSn1YwudVyul+jfxMQ9cS0vu21KUUncqpZ5s7XQI0VIkfzr4fVuK5E+iPZG86eD3bSmSNzWuXQRTSqk8pZRHKVWplNoT+fKmNvV5tNbztNaNDiGplLpcKVV3+M8mo5T6Wil1dXMd/2A19/VHzjFNKbWz5jKt9b1a6zb7uYj2SfKntkXyJyHCJG9qWyRvarvaRTAVcbrWOhUYDYwFfld3A6WUvcVTJYQQkj8JIdomyZuEaER7CqYA0FrvAj4iMpN0pMr3RqXURsITxqGUOk0ptSIy4dkCpdSI/fsrpUYppZYrpSqUUq8ArhrrakX0SqkeSqk3lVKFSqlipdQjSqkhwGPAhEhpT2lk2ySl1P1Kqe2REqDHlFLJNY51h1Jqt1IqXyl1ZaLXr5R6TSlVoJQqU0p9o5QaVmeTDkqpzyLXN1f9NKM2SqnBkXX7lFI/KqVmJpqOOmnKU0rdrpRaFUnXKyo8azlKqSyl1PuRz7Ak8r57jX2zlVJzIp9LiVLqbaWUm/D/cdfIZ1yplOqqlLpbKfV8ZL+PlFI31UnHSqXU2c15rUI0RPInyZ8i+0n+JNoUyZskb4rsJ3mThXYXTCmlegAzgO9rLD4LGA8MVUqNAp4GrgNygMeBdyM/WCfhWaafIzzT9GvAOfWcxwa8D2wDegPdgJe11uuA64GFWutUrXVmZJf7gIHASKB/ZPvfR451MnA7cAIwgPCs1Yn6KHKMjsBywjN913QR8GegA7Bi//rIj+wzwrNmdwQuAGYrpYbWc/2lSqnJcaRrJnAy0AcYAVweWW4Ac4BeQE/CM5k/UmO/54AUYFgkXQ9orauAU4D8yGecqrXOr3O+l4ALa6R3aOQcH8R7rUI0FcmfJH+KkPxJtCmSN0neFCF5kxWt9WH/AvKASqCU8A90NpAcWaeBY2ts+yjw5zr7/whMBY4B8gFVY90C4J7I+2nAzsj7CUAhYLdIz+XAtzX+VkAV0K/GsgnA1sj7p4H7aqwbGEl3/3qu92vg6hg+l8zIcTIifz9DONPavz4VCAE9gPOBeXX2fxz4Q41974nx/6Pu9ecBF9f4++/AY/XsOxIoibzvAphAlsV2B/4vaiy7G3g+8j4t8pn3ivz9F+DpyPsGr1Ve8mrKl+RP9X4ukj9J/iSvVnxJ3lTv5yJ5k+RNtV7tqZ3rWVrrz+tZt6PG+17AZUqpm2sscwJdCf94dunINyRiWz3H7AFs01oHY0hbLuESgmVKqf3LFGCLvO8KLIvhnA2KlPj8BTgvck4zsqoDUBZ5f+Cz0FpXKqX2Rc7fCxi/v2o9wk64dKMpFNR4Xx05J0qpFOABwiUvWZH1aZFr6QHs01qXxHsyrXWFUuoDwiUnfyNc0nJNZHVzX6sQdUn+JPnTAZI/iTZE8ibJmw6QvMlaewqmGlLzB74D+IvW+i91N1JKTQW6KaVUjUyhJ7DZ4pg7gJ4fkcTfAAAgAElEQVRKKbtFpqDr/F1EuAp2mA63S65rN+Ev/34967+UBs0CziRc1Z0HZAAlhDOf/Q6cR4VH7ckmXKK0A5irtT4hwXMn6lfAIGC81rpAKTWScDMDFUlTtlIqU2tdWme/up+xlZeAPyilviHcfvuryPLWulYhrEj+9BPJnyR/Em2H5E0/kbypHedN7a7PVAyeAK5XSo1XYW6l1KlKqTRgIRAEfqGUckQ63I2r5zhLCP+Q74scw6WUmhRZtwfoHmlHjNbajJz3AaVURwClVDel1EmR7V8FLldKDY2UNvwhhuuwR865/+UgXD3rA4oJl+bca7HfDKXU5Eja/gws0lrvINyGeaBS6pLItTuUUkepcKfQ5pRGOLMsVUplU+Patda7Cbdjnq3CnS0dSqljIqv3ADlKqYwGjv0h4ZKUPwGvRP4foPWuVYjGSP4k+ZPkT6ItkrxJ8qZ2mzdJMFWH1nop4SrLRwiXPGwi0qFPa+0Hzo78vY9w+9A36zlOCDidcIfI7cDOyPYAXwJrgAKlVFFk2W8i51qklCoHPidcqoDW+iPgX5H9NkX+bcyjhH9I+19zgGcJV3PvAtYCiyz2e5Hwj24fMAa4OJKGCuBEwlW7+YSrlv8GJFmdXIVHgZkSQzob8y8gmXAJ1CLg4zrrLwECwHpgL3BrJL3rCZeebFHhDp1d6x5Ya+0j/P93POHr3r88rmsVoqVI/iT5k+RPoi2SvEnypvacN6naTViFEEIIIYQQQsRCaqaEEEIIIYQQIgESTAkhhBBCCCFEAiSYEkIIIYQQQogESDAlhBBCCCGEEAlocJ6pDh066N69e7dQUoQQLWHZsmVFWuvc1k7HwZL8SYjDz+GQP0neJMThp6G8qcFgqnfv3ixdurR5UiWEaBVKqYRmgW9rJH8S4vBzOORPkjcJcfhpKG+SZn5CCCGEEEIIkQAJpoQQQgghhBAiARJMCSGEEEIIIUQCJJgSQgghhBBCiARIMCWEEEIIIYQQCZBgSgghhBBCCCESIMGUEEIIIYQQQiRAgikhhBBCCCGESIAEU0IIIYQQQgiRAHtrJ0CIQ8G24irmzM/jkzUFVHqDAKS57MwY0YXLJvSmR3ZKK6dQiOa1t3ovr6x/hQ+2fkCZrwyNxu1wM6XbFC4deil9M/u2dhKFOGx4AyHeW5nPnPl55Jd58AdNkh02BndO45pj+nLMgFwMQ7V2MoUFrTVLCpbw1qa3KKgqIBAKkJGUwcSuEzmz/5mkOdNaO4miiSmtdb0rx44dq5cuXdqCyRGibfmxoILfvf0Dq3aWYWpNIFT79+KwKQylGNUzk3vOGk7/jm0/k1RKLdNaj23tdBwsyZ9axq7KXdyz6B6WFCwBDX7TX2u9XdmxGTb6Z/bnrvF3cUTuEa2UUnE4OBzyp4PJm3zBEH//+EdeWrIdBVT5Q1HbuJ02Upx2bjthALPG9zrI1IqmEjADvLL+FZ5Z8wwV/gqqg9W11ifbkzG1yYm9TuT6I6+nZ3rPVkqpSERDeZM08xOiHgs2FfGz2fP5Lq8EX9CMCqQAAiGNL2iyeMs+znxkPku27muFlArRPNYVr+O8985jQf4C/CF/VCAFENRBfCEfa4rXcOUnV/Ll9i9bIaVCHPoqvAFmPraQFxZto9ofsgykIBxgFVb6+PP767jzzR9oqFBctIyqQBVXf3I1Dy5/kD3Ve6ICKQBP0IMv5OODLR9w3nvn8V3Bd62QUtEcJJgSwsIPO8u46r9Lqa7nZlaXJnyDu3zOEtYXlDdv4oRoATvKd3DlJ1dS4a/A1GZM+3hDXn79za/lIUGIOPmDJpc9vYR1u8vxBmP7vXkCId76fhd//XB9M6dONMQf8nP1p1ezumg13pC30e1NTKqD1fz885/zQ+EPLZBC0dykz5QQdZim5ppnl+IJRAdSVWu/pvy7twkU78RwJuPo2JeMiTNxdR8GQLU/xLXPLmPuHdNQKvH27FW+IKWeAACZyQ7cSfJTFS3rjm/uoDoQXbpaMq+Eok+K8O/1Y3PZSB+TTqdzO2Fz2wDwhXzc8tUtzJ05F4fN0dLJFuKQ9PT8razdXY7fogVEQ/cdTyDEc4u2cdLwzozpldUKKRf/WvYvNpVsiqq5byyv9Ia8XP/59Xxx3he47K7WSLpoIvKEJkQd8zcXUeENRC0vX/IWZYtfJ+fEG3H1GY2y2fFsXYZn4+IDwRRAUaWPpdtKOKp3dlznDYRMPlu7h8fmbmZtfjkOm3Fg+aDOaVw/tR8nDeuM0y4VyuLgbSrZxHPrnuOr7V9RFag6MKDE9B7TmdZ9GptKN2FSu4S86KMiCj8qpPvV3UkdmkqgJED+c/nk3Z9Hn7v6YES+myEzxBfbv+DkPie3xqUJcUgxTc2T87bgDUTXSMVy3/EGQzzxzWbGXHJIdzU7JHmCHl7f+HpUjVSseWXQDPJJ3iec2f/M1ki+aCLyVCYOSQVlXj5bu4c3l+/k/VX5fJe3j5DZNO3GH5+7JaqtuumrovTbF8g+4QZSBk3EcLpQNjsp/ceTNf3KWtt6/CH+882WuM758eoCxtzzGXe8vpJVO8sImhpPIIQnECJoatbkl/PbN1cx5p7PeH9V/kFfo2i/1hWvY+Z7M7nggwt4Z9M7lPhK8Jt+AmaAUl8p725+l9u+vg1fyFdrv5AnxN6399L14q6kjUhD2RXOXCc9ft4Df5GfsgVlB7atDlbz1OqnWvrShDgkfbOxEI9Fk/JY7ztaw5c/FlJc6Ys6hmheH2/9GEXtViiSV7Y/UjMlDhmmqVmwuZjHv9nMkq37cNoMTK1RKtxnyWW3cdXkPlwwrifZbmdC56jyBVm8tThquW/XenTQT8rACY0eQwNfrd+LLxgiyW5rdPv/Lsjjrx+tsyyVrJ228M329tdWsrvMwzVT+jV6bCFqmr9rPrd+fSveYP3t+kPaup9g9cZqzIBJ+pj0WsttLhtpI9KoXFNJ1jE/NTPaWraVgqoCOrs7N03ihThMvbxkh+VgE/Hcd2wKPl5TwEUyul+Lenbts1GDTcSbVxZUFbChZAMDswa2SJpF05OaKXFIKK32c9a/53Ptc0uZt7EIX9Ckwhekyh+i0heiyheiuMrPQ19uZOJ9X/DeysRqb/ZV+Q80r6sp5CnHSElHGY0HRwB2m6KsOrqpYF0fry6IKZCqyRsw+eenG3h3xa6Y9xHih8IfuPWrhgOphoQqQ9hT7ShbdF9Ae4adYGWw1jKH4aDYE10wIYSoLb/MY7k8nvuOJ2Cytzyx37ZIXEFVQdSyePNKm7Kxq0Lu54cyqZkSbV5JlZ/TH/mWPeVey+HJa9oflNzx+koqvEFmjY9vHgdf0MRq2AhbcjpmdTnaDMV0YzOUwtfIiEyBkMmvX19pGUg1NtCFN2By51urOWl455hqv0T7ZmozXCNlMdJUY52k97Ol2ghWBtEhHfWQECwLYk+Nvp1YDaUuhKjNX8+9It77jieOQjnRNKzyuHjzSlObeILWAbU4NEjNlGjTTFNzydOLYwqkavIGTP70/hrmbyqKeR9tmrj27CIQCEatS+o2GGV3UL1hYUzH8vsC+J96nLL3P8C3aRM6GH3Mz9buIWQxP0j5krfY98UTZBw9k+43PU+3G+aQNnoGno2La6dXaz76IbpUTIi6Fu1eRKW/Mmp50UdFFLxWQOeZnRk6eyh9/7cv/mI/effnYdZ5wEvpn4KyK8qX1R76P+QNUbGqAvdQd63lWmvSHG1/EmshWlt6svWol/Hcd+yGIislsebtInEuW/QofPHmlYYySHWmNms6RfOSminRps3dWMjWwirLQCqW2pt7PljLR7ccY3nsUGUl3lWrqF6xAs/3K/CsWoVKT8c9+gb8qvZNyUhykzn5IvZ99hjKsOHqMwpl2PHmrcC7fVXUIBQ5LhtpKUlUfPophQ8/RHBvIUl9+5I0eBCuQYNJGjyIRxd4DvSD2m9/h+OcGbeSMmjigeUp/ceT0n987ev3h3h07mbOGtUt9g9UtEvPrH4mql3//k7S3a7qRtqIcNCzv5P0hjs2ULagrFa7fluKjY5ndST/+XwMl1FrhCpHtoPMiZm1jq/R9EyPr2ZYiPZoyoAOrNxRGtWaIZ77jtNuMLpnZt1Di2bWJ6MPKwtX1loWb14ZMAP0yejTkskWTUyCKdGmPT53s2XH3FiHKd9aVMW63eUM7pxGYNu2nwKnFSvw79iBa+gQUkaOJOv8mXT9673YO3Tgmq828dAXG6MmTkwfdzaGO4uyha9Q9P79KGcySZ36kz7h/FrbJTtsXH/CIDpOnnFgmVlVhXfDBnw//oh3/XoKPvqUtT3OgzpNN+LpcAywpbCS0mo/mVIiKerhCXosJ9GNt5M0QO6MXGxuGwWvFODf68dINkgfnU6P63pgOH5q6OAwHJw78FycNvleCtGYC8f15JEvN1mui/W+k5XiZFyf+KbjEAfv8mGXc9e3d0UVVsWaVwIMzh5Mj7QeLZls0cQkmBJt1o591Xy/vTRqeTy1N4FAiAf/8SK/XPgsKtlFysiRJI8cRea55+AaNAjljH7Yu2BcT/71xUbLNKUOm07qsOkNplujOWds91rLDLeblFGjSBk1KnwNpR6c/5wbNTFwvANdOGwGpdUBCaZEvcp8ZdgNO8FQ7aamjXWS9myzbsOfPTWb7KkNP7QppZg1ZFbiiRaiHemQmsTUQbl8tnYPFi2/G73vJDtsXD+170FNFC8SM63HNOyG9aN0LHml2+7myuFXNriNaPskmBJt1oodpdhtCl+d7kbx1N6EUKxK60Gft9/C0alTTOfNdju55OhevLh4e1Sw05hkh40rJ/cm3RXdBj4YMlm3u4KSaj9FlT5Mi3mx4u1wLERjAmbA8iErkQElYuGyuTih1wl0S5Xmp0LE6lcnDGLehqK47zmGgoxkB2eP7t74xqLJ2Q071xxxDbNXzMYTim8QCUMZZCRlcEx3664I4tAhwZRos8q9AcuJeOOtvam2OWIOpPa7a8YQ8oqrWLCpKOYRkpIdNqYPzuX2EwfVWr63wssLi7bz3wV5BEwTQym0Bl8o+rg1Oxy7B09u9JyBkElminXnZSEA0p3pBM3oAVBqdpLOGJdxYPn+TtKdzo3vNwPhQGpozlD+OOmPB5VmIdqbQZ3TmH3RaG54YVnMU2XYlCLVZefV6ybgTpLHudZy2bDLWF28mrk75lqOmGrFwCDVkcpTJz1Vb82WOHTIaH6izXLajKiZxaF27U0srOaNaoxhKP5zyVjOHNkNl8PAoiXUAXZD4XIYnDe2O49cOPpALYDWmke+3MiUv33FY3M3U+oJUOULUeENUlm3um3/eWt0OK7esBAz4EWHgng2L6Xkq6ejtu+d45YmfqJB6c50cpJzopbX7CRdsaoCHdT4C/3smL3DspN0Qwxl4LK5mNJ9Ck+c+AQO49AL8L2BEDtLqtmwp4JdpR58wfhqCIQ4WNMHd+Tpy4/CnWQj2dFwYaE7yUaXTBfv3zyZnjkpLZRCYUUpxd+m/I0ZfWeQbE9udHsXNnKSc3jp1JfoniY1iocDCYdFm5WbloTNiI5i4q296ZCaWLBhMxT3nTOCKyb14alvt/DuinzsNgMdadSulCJkas4a1Y2rJvemf8efhoHWWvP7d9bw+rKdjc43VVesHY7dThvXT+2X0LWJ9kMpxeXDLufB5Q9GzWUSayfpJFsSw3OGs7ZoLYZWmKEgKDAMOwEV5Lhex3HZsMsYljOs7unbvPUF5Tw5byvvr8xHKYVhgBn5yZ4zphtXTOpDv1wZtli0jIn9OrDof47jzeU7efybLZRWBzAUB+4jhlIM7pzGDdP6cdyQTgkVFoqmZzNs3D3hbk7qfRJzVs9h+d7loH+ah0qhSLYnk2pP5tK9uzj71FdJk9FODxtKW/V2jBg7dqxeunRpCyZHiJ/4gyaj//yZZS1O+ZI3KVv8Jjkn3djgcLEpTht3zhjMxUf3Puj0VPqCfJe3j7LqAACZKQ6O6p1t2bziyXlb+OenG+Ju/x6PFKeN5f97Aq5GSjDrUkot01qPbaZktRjJn2JX4a9g+qvT8YV8ce/rMpOYUTaFa6vPx+fz8IN9A+W2KkxMUnUKw3z9yeqUS/rUHriG5qAsCkDaoqJKH9c8u5R1u8sJhLRlk2K7obAbijG9s5h90Rgy6pkPSDSdwyF/aqq8SWvNyp1l7CrxMH9TIT/uqeQf546grwT3bV5BVQGf5H3Cnuo9+II+sl3ZjO40mqO7HI36/G7w7IMzHm7tZIo4NJQ3Sc2UaLOcdoOLj+7J099uxV9nnqlYa29MrTlr1E/V6N5AiPdX7ealJdspqvQRMjXpLjvTBnXk0gm96ZwRPQHffqlJdqYP6thouqv9wXoDqcbmxopVssPgz2cOjzuQEu1TmjONq4ZfxZw1c6JqpxqSG8jiH9t+RY6ZiQ4FcGJnjG9o1HaBHZXse/VHkvplkjNrCMrRtkvLd5V6OOvf8ymp8hO0CKL2C5qaoKn5bus+Tn1oHu/cOImc1KQWTKloz5RSjOyRycgemWQkO9hatEkCqUNEZ3dnLht2mfXKybfCw2Ngws2QO7BlEyaahQRTokUVV/rYXebFEwiRmmSnR3YKqQ10nL1kQm/mzNtiua6x4WIdNsWZI7uRmmSn3Bvg/z7dwKtLd6Agau6qTYVVPPXtVsb3yeZ/ZgxhSJd064PG4L2V+ViNUBvr3FiNcTkMbjp2AOeMkbbWInbXH3k92yu288W2L2IadSorkM5DW39DesiNEUP3Wu038W4soeiZ1XS48gjLIdcPVmGFjxcXb+ON5bso9fgxzXDfkWMG5HL1lL4M6pzW6DHKvQHOf3wh+yr9hBpomVGTP6QpKPMy64nFvHPTJCnEEC0u2+1kX5W/tZMhmkJyFky8Gb66B2Y+29qpEU1AginR7ExTM39zEY/P3cKSvH3hgSUUaB0eje7UI7pw9ZS+DO1aO4DRoRDOZ5/g2o2beWLgiXjjaDFnMxSdM1zcdeoQCsq8zHx8IQVlnqgarv38kfbo32ws4rvZC5h98eiYaqGsPPr1ZqrrBGvxzI1VH7fThgb+fOZwCaRE3JRS3Dv5Xh5IfoAX17+IqU0CZsByW4dy8Ncdt5BmpkYFUkt2ruLerx5lQ1EehmEwIKcXfzjuZkZ2GQJBjX97BWUfbSXztL5NlvaCMi+/e/sH5m0sAqjVD7HSF+TN5Tt5b1U+fTuk8qczhzG2d/1zu8z5No/CCp9lINVQzXHQ1GzfV80by3Zy0dG9muzahIhFh1QnxRJMHT7GXReundq1DLqNae3UiIMkwZRoVpsLK7n0qSWUVvsP1Ab56wzI8M6KfD5cvZsR3TN54tKxZCQ7CO7bR/7td6CDQW6ZfT+u1WU8/OXGmIaMddoMOqUn8dp1E9EmnPPoAgrKvZZ9Iqx4AiFueH4Z/71iHOP7Ro+C1pAKb4CdJdGl/vHMjbVfssNAKUUgZNI7x831U/tx6oguUiouEqaU4pdjf8n5g8/nxXUv8sbGN8LLa4yaaWqTG7KuoNfGrtSdQbTCV8UVr/+Wv5z4S04fPB1/KMiSnStJsv00yIsOmFQt3k36Cb0wkg7+u7phTwXnP76Qck+AespCCGkIBUzW7i7n4qcW849zR3D6kdHzXIVMzZwFWy0HhYml5tgTCPHY3M3MGt9TJkgVLSrL7aS02o9paoxDpF+iaIAzBab+Gj7/I1z2bmunRhwkCaZEs1mTX8b5jy+iyh+0nNV9v5DWhAKaFdtLOPWhebx8TDqeO39NxmmnknvLLSi7nRun59Iv181fPlhHcZUfTyAUdcxkhw1Ta04b0YU/nDGMdJeDnz+/jL0V1oFUQ6XQ3oDJ1c8u5bu7jo8reCnzBHDYDIJ1hm2Pd26sFKeNRy8aTa8cN5kpDhn+XDSpbqnduOOoO7hl9C2sKlxFmb8MgAxnBkfkHkHFsxvxBkqi9tuybwcAZw09HoBkw8bUPuOiT6CgesVeUsd3Oah05pd6mPn4QkqrrWvQrHgDJne8vor0ZCdTB+bWWvfl+r0ELOZ3i6fmuLjKz7JtJQ3WfgnR1Bw2gxSnjTJPgCy33A8OC6MuhgUPw+Yvod+xrZ0acRAkmBLNYneZh1lPLK53PiUr/pCmoKSaS17I4/W77iLn+ONqrT95eBdOGtaZZdtK+M83W1i+vYRqfwiHzSDH7eSSo3txztjupLvCI24VVfr4Yv1eAhbF2bGUQpum5oNVu2NqUqe1JlRaSmDNZnQoCHXmx6o5N1asAVXvDm565bhj2laIRDhtTsZ2rj04UajSj3dzqeX2fbN7YCiD2z74C2cMPo5R3YaR6Yrup6T9JhXf7DzoYOrWV1ZQ4bXOQxorDPn588tY+rsTSHb+9Hv7ePVuqnzR7YXjqTn2+EN8tnaPBFOixXVITaK4yi/BVFtnhmDT51C0AXyV4HRDdh8YcBLYa/zf2Rxw7O/CtVN9poHRtgfuEfWTYEo0i4e+2ESlz7o0ucF+CSj2ZHZhbvYgzrbYVynF2N7ZMT3IvLRku+XyWEuhq/whHv16U61gyvR48G/fjn/rVvx5efi35uHPy8OXlxeenKZvPwK9Z1F3BIp458byB025YYpWESz2ouwG2mLS2rQkN29e9AizF7/Irz/+B4VV+5jebzx/P/nX5Lpr/yZDJfEPw17TtuIqVu4otaxVjqUwRAPvrcpn5tgeB/YrrLROUzw1x5rwQBiNCYZMFm4ppqDMiy9okuayM7xbhsxZJRImg1C0cZWFsPRpWPIYBP0QirwMB9iTQBkw9ioYdzVkRJ4rhp4F8x+Ede/AsJ+1bvpFwiSYEk2uyhfk7e93YdGaJqaHoOqgyaNzN3P26IMbZOGFRdst+0bEUwq9s6iSRXf/ja7bf8S3NY/Qvn04enQnqU8fnL17kzLuKDJnnoezTx9sWVkopTj6yUXM31Rc6zhGkpvMyRex77PHUIatwbmxAI7skXmghk2IlqQtam5qGtChNw+ceicAm4q38Yv37+HuLx7m32f8ofZxTM2Pu8rJTk8iK8WBPc7JRZ9ZkIdp0T441sKQan+Ix77efCCYMv1+zKpqy3PFXXPcQJeVveVenl+8jf8u2EbI1JhaY5oam00RCmkGdk7jhqn9OH6oTLgq4pPtdlJcT4GAaGV58+HFmWAGIeitvc4MgD9SuLzo3+Fg69w5MOiUcG3U8X+AD26HwaeFa6vEIUeCKdHk3l6xy3Jo8Hj6Jezc52H1rjKGd8tIOB3FVQdfCm3XIUo69mT4cRNx9umDo0sXlK3h/a6f2o/vt5dGjegX69xYbqeN66f2azRtQjQH5Yz9Ab9/Ti9mDj+Z51dEd6A2gRtf/p6Saj+lngCpSXZy3E6y3E6y3U6yU5xkpzrDyyLvs1Mi69xO3ly+y7KJbjyFIbuKKph/wy/ptHkNwYIC3OMvhpzoebLiqTlWQKc06/no3l2xi1+/sQqtiS7IibRWXLWzjNtfW0luWhIvXzuhwbnthKgpJ9LMT7QxW78JB1KBGObwC/khBLx2BZz9OAw9M9xfKqM7fP88jL2i2ZMrmp4EU6LJfby6ICqQgPgeggIhk3nr9zAkFczq6p9eVTXeV1cdeK8t1gc7nBXV3A7iK4VWSS4c044mdWinmK9/Ur8OpLnslp9BY3NjASTZbRw7OLFh2YU4WPYsF9qiRhfCNVFfbF7IGYOPpUt6R/LL9/DOui8Y3TV6njRnhpPPfxUuJAmZmnJPgOIqP/tqvEqq/eSXelmTX05xlZ+SyPKiSi++oPWoNfEUhjgMCBx3Cj3uuAln9+5csKWUb19cHjXPXDw1xy674sRh0fnBi4u386f318Q04miVP4S3xMOpD83jg19MkYBKxCRHmvm1PaXb4aULYgukagp64K3rIKc/dBoWrp16+WIYcX54pD9xSJFgSjS5kmrrzD6eh6Cgqdn46JNs2vY1RkpK9Mudgqr5d0YG9i5dMFLcB5YlfVxmOTdVXP2XFKS54vuZGIZi9kWjufjJxXhieLCqyeUweOSiUdhk6FvRSmwZSTi7p+HPK49a53amsCJ/HU989yrlvkrSk1I5vt8E7pr+81rbKYeBe9JPQ5PbDEVWpFYqFsGQyYC7PsIqnIqrMMTpxHbkSJL6hkf1O2ZgLi6HLSqYgthrjrMrisn8zY2UXXQR6SeegHI6WbC5KOZAar+QqSn1BLjgPwv59LapOO3S5E80LNvtZFtxVWsnQ9S0cDaEavcP7/2vCqoDsPWWVNzO8L38yeV+nl8V4OvLawwqFfTBN/fDeXPCc011HwtLHofJt7XkFYgmIMGUaHJGPfOvxNsvodNVVzDolL8nnI5BK79l5c6y6PTFUQrtC5oM6BQ9WlljxvTK5pFZo7nxxeUxP2C5HAb/N/NIJvbrEPf5hGgKWmu+3VTEp1UVXKo0ybr2b7lLWi6PnvXHmI6TOjb22ty67DYDu01ZNvOLqzBEU6vvoc1QXDW5Dw99sRGvRe1bYzXHLrvBzRdOIaeiMyUvvMiev91H1nnncW/oiHp/5w0NuBMyNXsrfHyypoDTj+za8LWIdi8n1cny7dFTFohWEvDC98+Gm+7VEdLw4GI/d05Jqn9/bcKPH0L1PkjJhuN+D0+fBGMuh+Ss5ku3aHJSFCaaXMc068yj5kNQY5x2g5zUBjKhGFw/tR/ueiYNTR93NlnHXkXZwlfY+fBF7Hz0ciqWv0/ygJ+aIBoKjh3UkewER9U7bkgnXr52AgM6ppLssGGzCDINFZ4fq28HN89dNZ4ZR8gDlYhDeT6sfReWPwerXoUtX0eVksZqweYiZj6+kD+8s4ax03vjTk9qcKCFetkVKSNyMVIOriP1sK7W/SVrFoZUb1iIGfCiQ0E8m5dS8tXTtbYNmCYDOtUePe+yib3pkpmMPYHaX1/IZGV+OUUjJ9DrmTn0mjOHTft8bMiPLrSB8IA7+754goyjZ9L9piPdFTsAACAASURBVOfpdsMc0kbPwLNx8YFtqv0hHp27Oe60iPYnx50kzfzakrXvUF8mecdEJ/cv8FHqbWCSTQjvv+KF8NsOA2DwqfDtv5o0maL5Sc2UaHKnH9mVeRuLojpgx1MjpIATh3Y+qHQcP7QTdsMg3NszWqOl0A4b1xzT96DSMLJHJp/9ciqrd5XxxLwtfLFuL9X+cE/0ZIeN6YM7cs2UvhzZI/OgziPaEa3DQdP8B2HbArA5wyWcSkVeBhx1LRx1FaQ3Ps/Toi3FPPDZBvaUe/nFcQM448iu2G0GgR7Z7H1kRaOj+9ViU9hzksk8q3/i1xdx/dR+/Oq1FZbzQsXSJM+mTY7buZLqp/JwnHMOjo7hfojuJDsvX3s0P/v3fAorfZa1X/XRGl5bupO3lu/irlOHcMmE/rw74mRC322nbpvEeAbc2VJYyY8FFQzqHH8tuGg/wqP5STDVZuR/D/5Ky1Vju9qY1tvO/Qt83HNsA30igx7YsRi4Ofz31N/CY5Ng/HWQLoWrhwoJpkST2bGvmheXbOfV77YTtJgbBmLvl3Bkj0x65hxcJ0yHzeCW4wbwj09+xBOI44EQcNoUgzqlMbpn0wQ5w7tl8OAFo4BwEygIz5klRFw8pfD8OVC4DvyRvhMhi1ErFzwECx+GU/4WbjJi4bu8fTzw2QZ2lni4+dj+/GxUt1rDlztyU8i9bgSFT/wQDqjq+U3/tIOBo2MKuVcNx3DGNjF1Q44f0hHHQRSGOBx2fn7NDAKfvM2W007HPWECWRecT8rRR9Mp3cWHt0zhxheXszSvhFDIpJ7xLqIETU3Q1Nz74Xqq/SGWbC3GKh6LZ8AdhWLljlIJpkSDclKdMppfW1Jd3ODqP01PYtLTVdwyvpHWLZ4ak6RndINRl8Dcv8PpUkN1qJBgqp0o9wZ4Y+lOvtlYSGl1AIfdoGtGMjPHdmdCv5yEH+xNUzN3YyHPL9zG8u0lnD26O69dP5E3lu/kyXlbLed5auwhKMVp47qDrBHa74pJvflxTwXvrsiPOaBy2BQd0108c8W4Zgl4JIhqX0KmZsOeCkqq/SgUWW4HAzqmxT/IiKcU/jMNyndZttGvfdJIgPXxb8FbBpNuObBq2bYS/vX5BrYWVfGLYwfws9Hd6p3vyNk1lc63jab86x1UL90DgPbX/k0rp4FKspN2TDdSj+6KcjRN63G7zeA3Jw/iT++vi7swxGU3mDqoI8MmjoSJI+l4x+2Uvfsue+79KzoQIPP888k460xeuPpo1u8u52ezFxCM8xyeQIgHPt+Ay2EdOMY34I5JuTex5pmi/chKcVJa7cc0NYYMUtT6khqegHt4RxunDbRz37d+huQ2kC863bX/nnwbPDwGJtwEHQ6+ll80PwmmDnN5RVU89MVGPvhhN4ZSUQ8ln64tID3ZwXVT+nLxhF4xTyK5r8rPa0t38MLi7aQn27n06N48Mms0yZES6Rum9eeDVbvZUeIh1FiJdg0uu8Gk/h2YPqhphgZXSnHf2UeQ4XLw7KI8AkFNyGIi0P1SnDZ657h54erxZBxknw/RvhVV+nhx8XbmzN+KP2geePgxTU2Sw8aVk3tz4VE9Y+sbqDW8cG5sgVRNAQ989VfI6c/3KRN54PONbN5byU3H9uec0d1jGkHOlp5E1hn9yTylD9UrC/Gs34dZFUAZClt6EimjO5LULxPVDA93F47vxeaiKl5YtD3mgCrJbjCwUxoPXjDyp2tITSV71iyyLrwQz/crKHn5JYpmzybt2GP5btxplvPiQcODRwB4A2a9+Vs8A+7YlCKpkf8L09QUlHsp8wRw2Aw6pDrJTEmsP6c4NDntBilOG2WeQMwjY4pmlN0XbEnWrQMi/jjNxejHK/nVhHryeWWHnDrzSqZkw4Qb4at74Lxn0FqzubCS4ko/poaMZAcDOqXKpN9tiARTh7FFW4q56pnv8ARC9bbQqfaHqPaH+Psn6/lw9W7mXDGO1CTrr4XWmhU7Snlu0TY+X7uHE4Z25qELR3Fk94yo2pbUJDsvXzuBcx6dz96K2PolJDtsHNkjg0dmjWrSUjelFHeeOoQzRnblqW+38uEPu7EbCn/IROtwc0BTawZ3TuOGaf04fkinWs2dhIiH1pp/f7WJh7/cBFhM3kp4nqGHv9jEQ19s4pcn/D975x0mRZX14beqOvfkCAzDBILkHIdsFrMuJkRJCkZQ191v1V3TuqsrrnlFRFERXdacMaBIBhUByTCRgYHJuWN1fX80aaa7Z7phEnDf55kHpruq+nZPV9X93XPO73Rl5pjODUcsc1ZC4fY6Qip4+10bhR/cy+3yy9xxdlcm3jQIoy70NDxJr2Ad3A7r4JOrZQyVByf0IMZq4Pnv9+DRtIDXElny9mgb0TmW/0wa6DdiJEkSloEDsAwcgLusjPIPP+aVlbnUmnzTeSs3fEzF+g+IPf8OTGkDkRQdtuxfse1Zf1RMAQFTmkNxHdQdjob7o7TGyXsb8nhjVTY1Tjc6WUbTNJyqh94dIpk1rjPndE8Q16wzhCONe4WYagP0mQjLHm9wky4xMtf20vPCBid9Evyco4oOBt7s+/jw26h4LoMPv1rJa5udlNe60CnHFuQUWeLG4SlMHpFC+0hzU7wbwUkgxNRpyqZ95Uxd+HPQq7k2l4fN+8q5ccF6/jdzRJ0Va5tT5fPNB1i0LpcKm4sbh3firxf3bPRi3i7SxFd3j+GeJZtYlVkMGjhV34mlWa/g0TSuHZLMQxf3aLZJQe+kSJ69tj+PXNaLH3cWUlztwO3RiDTrGZIaQ5eEhkP2AkFjaJrGXz/Zyocb9/sVUcdz5Pnnv9/LoUoHD1/q2/j2KKufB2etz8NB2e8C0VoFP11vwZCW0vibaGNIksTt47pwad8OvL02h/c27ANAQ0PTvFEdp+rhnB5eM5f+yVFBpdLqoqPJv+AqynLXQb3rZCjmERKgyFD/zx2K4Y6mwdhu8XXH4NF4aulO3lyTgwTHWbkfe6Hf9pVz75JNGHQyr9w4iOHpsY2+b8GpTYxFz/4yG2lxVtGPsLUJbwfp42DPt/g40BzH38YaWbQlQBpvQg+IP8vn4SWbS3m4/AmklSXYtMNZMvUO8fqqbF5flc3kESk8cFEPkfrZiggxdRpid6nc/MYGv0KqobQVp6qxs6CSp5bu5K+X9CSrqJrF6/P4aGM+g1Kiue/8bozpGh/SCRtp0fPG1CEcKLfxzrpcFq/Po9LuQidLuD0aHSJN3DI6nasGdazTD6Y5iTTruWJAUuMbCgQhMn9FFh9u3B9SjY/NpfLfDftIjrYwbVSa7wZVByF7Jf5u1vdnGPjXage3DzEQZQp8XupVO6x/GdIyAm7T1kmOsfDgxT354wVnsSG7lNIaJ27VuxgyMCX6hFoYbD9Qiebncw3FPMKjBZ5GBWO4o/O4uUwuQVddCdHe3jIej8Yd725k+a6iRkV5jVOlxqkyZeEGnr+2Pxf0btzBUXBqUVHr4n+/7OP1VdkcrLQz9c0NaECs1cDNGalcP7QTcSfZSkRwgoy6B3JWeFOqD5Mzp66RTHKkjP2hCN999RYYfZ/Pwy//uJeXftiL3SMBgedFR64Ni9flcaDcxkvXDxSCqpUQYuo05PPNB3D5iQAFk7Zid3tYtC6XHQWV7D5UxTWDk/nszlEkx5ycs16HKDN/urA7f7qwOy7Vg82lYjXoxMqa4LSh2uHm2e93+23e2ljtjc2l8vQ3u7h+aKejdYdHObAJdAa/eflB2++iHbbfPfUx6hRGd41vfMMgqLK7cPtJGwzFPALAoldwe7QTMtxR9Hqu9uwn66IJRN80mZibbuaxH7JZvqsoJFFud3mYs2QTiyNMDOwkGn6eDjjcKg9/uo2Pf9uPJHH02nIks7S42slLP3hTii/olchTV/fFYhDTuhYlZQRk3A1rXgSXb/ZAQPRm6HstdL+kzsOfbzrAiz/sCdgE3B82l8qPO4t48uudPHBxj+DHIGgyxFl3GjLvp0xqnSeetuJWPaTEWFg4dcgJ1VY0hl6RReGk4LTj4435yH7Sy4KtvZEk+Gzzfq4d0qnuAewV3j5SAQjafveIlbrgKCa9gnI4Sn48oZhHAMSFG7HoFfYWVYfUt8qsV3hgQncGj5iAc/oNFL34Eisvv5b3hs7Cqfl+lxoX5R4e+Oh3ls4ZE/QYBG2TGoebSQvWsfNgVYPRySPPfbvtEDsLVvP+rBHCmKSlGfcXcDtgw/zgBJXeAr2ugouf4Xj3G9Wj8bfPtgYUUg2d/zaXyptrc5gxJo2E8IYW1gTNgZjRnmbsPFjJgXK7z+Ohpq1s3FfeLEJKIDgd0TSNV3/KCriIEXPebVjOykA2mJAUHZYuw+rUzIDXDOaV5ZlomkZFrYst+eV8vvkAS3eVYm+gCdLx9rsNogh3yvp0iDIfLeo+nuPNI4IhOdrCu7cMp3N8WKOufEcw6WXuOqcLk0ekAmBISSFp7tMsu+n/8Pgxtajc8DGly14jcvg1dLzzHZJuW0j4wAnY9tSNOOaW1LD9QGVQYxC0Tdyqhxlv/cKOgqqgIxQOt4eckhpufH099hBt/gUniSTBeY/CVa9BQk/QmUGqN3+SZK+Iik6DS56Fy1+Cegs1P+ws9FtXDsGd/xLwzrrcpn53giAQkanTjLySWu/koF6hYqhpKwXltsY3EggEAJTVuiis8k3DC2URAyCnpJZ+j36LR4NOMRZSYi1kGKJQGjFUaNR+F8ASF9QYziTGdItDwvezDcU8wmpUuDkjlWirgU/uGMnT3+zivQ15SHjrmeoc97DjYIcoE3+5qAfn9kys87zdpfJhZjVuqa4gCyWzwKlqvL4qi2eu6Y/g1OSLLQVszi/37wTaQHTCpWpkFlbz3vo8pvqrvxQ0Lz0u8f4UbIH1r1KU+Suys5rY6BiI6wrDZkHHwQTqxTDvp0xqHL5CONjz3+H28NaaXO4+u6tw92xhhJg6zfDaoPuuaoaattJY0bNAcLqzJb+ctZkllNY4MSgyceFGzu+V6NeGtrzWiV6RqDd3DnkRw6iTWTxjGL2Tjms34OkPcx+E2sBpeo3a7+rNMHhqUGM4kzDqFCYN68Qbq7N90vOCMY8A0LmcjEuLBLxpg3+9pCf3X3AWX24p4M01ORRW2XG4PYQZdQxIjmLG6HT6JftasQNs3lfuN1U0FFGuejSW7SgM9iMQtEFe8ZOqD8GlDNtcHuavzGLKyFTRIL61aN8XrniZRd/tRpZgzrndGt1F0zR+yyvz+1wo579L9UYouySEN7qtoOkQYuo0I8yo83szDqXnCXhz+QWCMw27y9sGYN5PmRwot+NSPUfraYw6mSe+2sGwtBhmjunMyC6xRycrOln26+gW6iKGLEnEhRvrToJkBYbfDiueBrdvCu8RGrTf1TQYcGOjr38mclNGKm+tycHl5y/YmHmESSdxTc1eci99kYQ//pHwC85HkiRMeoWrB3Xk6kEdA+5b63Tzw85CDlZ4xVaESUdtgPSsUEV5jdMd1HaCtse2AxXklfgunIQSnaywuVibWUJGFxGNbk2q7W46RAVXv2RzqciS5HcxPJTzX5ElKmwB7gOCZkOIqdOMLglhuPxElUJJWwFIj7f6HEMgOJ0prLJz/fx1FFTY/a4KH4nWrtxTzK+5ZZzfM5GnJ/ZDr8hEW/U4/Zx3oS5iuFQPUWY/xeODpnjF1HEEbb8rG6DHpWAWDm/+SIoy84+r+vDAx7+H5KBl1Mn07xTFn6fPwfFzBof++SSl7ywi8S9/wdwrcM+wzKJq3liVzUcb9yPLEk63iurRMCgyHg1cHt8xhCrKG4tIuFQXB2sOUumqxKSYiDPHEWmMbPxNC5qdr38v8JsZEkp0wuZU+XTzASGmWplqh4twU3D9KwMJKQj9/FdkkeLX0ggxdZqREmulR/sIfttX7vNcsGkrVoPCrWM6t9SQBYJWp6TawaUvrqKk2unj7OaPWqfK0m0HqXa4mT95MOEmPX07RrIxr+55F+oixuDUaF9rdABrHFz2Iu5P70KnBo5O+SApEJ4AE55ufNszmKsGdsTmVHn8y+1BCSqzXqZfchSv3zwEnSKjGz6ctI8+pPzDD9k3axZho8cQP2c2+oSEOvu9sSqbf32zE7eq+XzP7A2kVocqysOM/m/t+VX5vLfzPT7Y/QEaGoqkoKHhVJ0MTBjItN7TGN5hOLIkJmOtxcFKB/4uQaFEJzTgUGUI1wlBs1DtcBNmDM74x6RX0MmyXwOKUM5/l9tDjHBzbHGEmDoNuWJABzbnl/u9IDeWtgKgU2TO7ZHQ4DYCwemCpmlMWbiB0iCF1BHsLg+r95bw3Pe7mTg4mXCT/8tpsIsYJr3Mhb3aUWFzEWn2vQF/wSi2azfyR907yA2k+x1FNniF1NSlIioVBJOGp9A5IYynv9nF1v0VeDwarnrfB6tBwaRXmDE6jVtGp9cp8pYUhehrriHioosoefVVsi+7nJgpU4iZcjOyycQLy/bwyvLMkKJfRwhFlOtkiYt6t6uzv0N18OCqB/lx349omobL45sGtP7gen4v/p1wQzgvn/MyZ8WcFfI4BSePGsDNLdTohL/+aYKWpcruDnhf8Mf5vRL56vcCn7lbKOd/u0gTyTG+db2C5kWIqdOIWqebF5btZcnPeYSb9FTZXX4FVUOY9Qp3jO8snGAEZwwb88rILKrxmThDcM12X16eyVtrc7h+aCc251dQXus7UQ1mEcPh8vD0N7v459c7ObdHIreMSaf/YaOCZTsO8chn21h0y8PINRfBNw9A+T5vDVX9HlR6i/exnpfDRU8JIRUCw9Nj+fC2DLKLa3hzdTa/5JZRZXdj1Ml0jDZz04hUxnSLb7DZuBIeTsIf/0jUtddS+PRcsiZczLYp9/KfPfIJCakjBCvKFY/KzX1ij/5uc9uYsnQKmeWZONWG7fNr3bXUumuZ/PVk5p83n/4JwhGwpYkL9+/IGWp0MjZMRCdamyq7m7AQxNQto9NZtqPQb7PuYM5/i0Fh1tjOwnikFRBi6jRA0zS+2XaQx7/YwZDUaL6ZMwaH28MlL64KSVCZ9TLn9EjgltHpzTtggaAN8eqKLL83r2Cb7coS/PH8s5g8IpVRXeKZ8fbPJzRp1jhmpf311gJ+2FlI18Qwbh/bmQc/2crrU4bQo30EcC50ORcO/AZrXoLcNeCs9vaRssTCoGnQ/3ow+3eMEzROWpyVRy/vfVLHMCQn0/GF56nZsIGp7+3BbvT/92hMsB9PMKI8XaqFG6+m+OabiJ48mXvW/pHM8kwcqq91fyBsbhuzvp/FkkuWkBKREvR+gpNnbLcE3l2f52urH6Jd/7k9EusfWtDCVDvchAdIufVHv+QoOkSZyCzy79za2PmvaXBZ/w4hj1Nw8ggx1cqo1U6c+dVoNjcoEkqYAUNKBJKfRpL+yCmu4ZHPt7G/zMbcif0Y0fnYiuQnd4zk2lfXUml3NTq5sxgULu3XgX9c2UesagjOGMpqnCzfVUT9ut9QnLNcqsZba3O9YqprHE9f3Y/7P9x8UlEIj+aNem07UMFt727kP5MGHo1SHaXDAPjD6yf8GoKWYW/7bhSHlYCf70Owgj1YLAaFZ2+/iNRZwyl64UU+mn4uGyc4cdRvPAiUrSyj+JtinIVOFJNCxKAIEv+QiGL1ppDVump55pdneOHsF0J/04ITJqNzLGEmnY+YguCjk7IkcUGvdj77C1qW6hAjUwDPXzeAifPW+l3gawiTDE9P7IvFIKb1rYH41FsBTdNw5lZStSIf++4yJJ3M8a68kiIRltEB67D2KOH+Q/V2l8oryzN5e20OM8d2ZtrINAy6uql5aXFWvr9vLP/7eR/zV2RR43DXuUAbdDISMNSP1bNAcCaQVVyDUSf7OPGF2mw39zgr40v7dyA+wsifPthCcbUDm0v1EWvBonpAkST+tXQX47ol+DenELRpFqzM8uvOFopgDwaLQWHBzYPp3i4CiCDpmbl88/EkbBVbqN+XuPjrYoq+LqLjjI6E9QzDVebiwKID5MzNIe3BNGSdjIbG6v2rKbGVEGuO9fuagqZHliVuGZ3O3G93+V2QaSw6YVAkJg9P8ZkPCFoerwFFaNPsXu3DedL+K3+We2OXgtvXpMD0nUsZXx0NiMhUayDEVAvjcaiUvLUNZ34VmssDGmjuuisQGlC5PJ/K5flEXZZO2ND2dZ7/cWchD3+2jV4dIvjy7tF0iApcbBhh0jNjdDrTRqaxJrOEn3NKKK52YtIrtI80MaFP+wb3FwhOZ6od/vvxhNrXx6VquFXP0VrD4emx/HT/ODbmlfHqiix+2lXUYCPshlK9VE3jYIWdjzbmM2m4SLk61di0z78ZUKiC3R+S5K1zjQsz8p9JA+mddMze/GDNQX6r3ukjpFSbSuEnhSRNTyK8r9de3xBvIPn2ZHbfv5uKNRVEj4k+fHyJD3Z/wMx+M094jILQuW5oJ95ck0NBuR01hJUYSYIIs57po9KacXSCYPB4NGqdbqwhRIo0TePQP5+k34EdfPD3W7n3k+3sK7XhdHv8fg+sBgWrUcc/ruxDRm0i+XfdTdLzz2EdOrQp34ogCISYakE8DpXClzfhLrWBu5EL5OGJV8XnWWg2N+Fjk8kvq+Wxz7ez+1AVj1/Rm7Hd4oN+bVmWGNU1jlFdRd8JgeAIgZpTh9zXQ5J8TAkkSWJQSgzzJ8dgd6kMeOxbbCeY6mVzqcz7KZMbhnUS0eNTDH/pWhCaYNfJEpFmPeW1LvQ6CU0D1aMxvnsCM8ekMygl2ud7sa5gHYqsQL2vXO2eWjwuDxGD6vYkU0wK4X3Dqd5WfVRMOVQH3+V+J8RUCxNm1LFk5gguf2kV5bWuoFxGZQnCTXqWzBxBbJh/EwtBy1F9WEjJDZjV1Kfk1fnUrl9PyjuLUCIi+PaesWzdX8GClVl8s+0QdpcKkvd6MDw9lpljOpPROfbwaySS9Oy/2T/nHjq+9CKWgQOb780JfBBiqoXQNI2St7cFJ6SO38/loeL7PL4tKOfx3QVMG5nGC9cPwBRgEigQCIInKdrcJM1248IMDYqczzcfQPbzfCipXiU1Tn7JLWNIakyj4xG0HYwB0q1CEewGncyfL+zOZf07UGFzoVdkIky6Bl1XKxwVuD2+kVe1WkUXpvNbl6uL1GHLtdV5rNJZ2eDYBM1DUpSZr2aPZsobP5NTUhMwXVgCzAaFxAgTb08bSnKMpcXHKvAl1HqpsiX/o/yDD0h5dzFKxLGFjt5JkTx33QDAu4Di0TT0Ac576/DhdHjqKfLvvIvkea9g7tv35N6EIGiEmGohnPuqcO6r8iukNuRv4R8/vsLu4hxkWaZrbAoPn3MX/dv38G7g8tBtazmfzRlJpzhrC49cIDh9SYoy071dOJvzK+o8HopzlkkvMzmj4fS777Yf8huhCCXVy+ZUWbWnSIipU4z2kSYKKnz7goUi2GXJ2z/GpFeCXkgL1HhXCVNwV7vRVM1HULkr3OjC6k4LpPp5goIWIyHcxJd3j+K3feXMX5HFjzsL0SsykuStm/ZoMKZrHLeO6czw9BgRtW5DhFIvVfnNtxS/9BIp7yzyafR9PIosoTRyPoaNHkX7J/7OvttuJ3n+q5h7hW5kIwgdIaZaiOoV+d4aqXpUOWqY+sH/8cT593Jp9/E4VTcb8jdjVOoaT8TIMnFVbhBZegJBk3LbuM7c9/5mahx1xU6wzlmaBtcP6dTga5TW+O/vE0qqlwYUVTfcJ0jQ9rg5I5VdB38/KatrnSyT0Tk0E4goYxR6We/ToNfSxYKkk6j8tZLIocdqrFS7StWWKhL/UNdSO8okLPZbE0mSGNgpmnk3DqKi1sW+slqqHW5W7ikmr7SGF68X6VxtkWB7TNWsW8fBRx+l04LXMKQ0TU1s+PjxaI88zL6Zs+j0+gJMZ4kG3M2NEFMtgFrjwraztI5j3xGySvcBcEXPcwEwywpj03yLBzWnh6oV+RjTIn2eEwgEJ06fjpEBV9+Dcc46t2diozUKgZq8hlqbpQ8h/17QNriwdzse/Hir3+eCEexGnczNGakhN1IflTQKt+ab5qdYFBKuSODAOweQTXIdNz99jJ6ojGPiyaSYuDT90pBeV9B8RFr0RFq8cwCdLPH4F0WtPCJBIIKJTNm2bmP/vfeR9OyzmHr2bNLXjzjvPHC5yJsxg5SFCzF26dKkxxfURYipFsC1vxpJJ/u49gGkxyQjSzL3fPkEl3U/hwFJvYgyhfs9jjNX5K4LBE2Bpmmsyypl/opMVmeWcCImwjpZIjHCxCV92vPppv3oZJnYMAODUqJ9ctoTIvyLrVBSvXSyREKE6QRGKmhNjDqFycNTWLg6G7uf+rzGBLskwaThDUc+/RFtimZM0hh+yPsBTz0XivgJ8ShWhYNLDuIsdCKbZSIGRpA8MxlZf+y7q6FxeZfLQ35tQfPTNTGcvYXVeDxaSCYHgpahyu4iwqQP+LwjO5t9t82i3aOPYB3WPO57ERMmoLnd5E2fQac3F2JMEy6PzYUQUy2Ax+b2G5UCCDda+WjSS/xn/bv8aenTFNWUMr7zMP514Z+It9atjfAEcIUSCATBY3ep3PnuRtZklmBzqoFOzQbRKxJGncyhSjv3f7AFDe1wbEtClmDyiBQmD0+lXaRX/Fw5IIkfdhb6pBKGluolGnGeqsw5rytrMovZUVCFUw2+mbNJL/PMxH4khJ+YiJ7Sewqr9q/CrvrWbMWMjSFmbOD6O0VSOC/lPMIN/hf3BK1LpFlPmEnH/nKbMJ1oKxzaButegf2/Mq6qnCEePbzbA4beCunjQfYuVLgOFbJvxi3E3323N4LUjERedhmay0XetOmkvPUmhk6hL8wIGkeIqRZA0jW8atQ1LpVnL34AgL0ludz9xd95ZNmLvHzZw3WPI1afBIKTwun2MGnBerYdqPDbELMxLAYFp9uDpkGNwyvEnKpvKtWCldksggXuEQAAIABJREFUWJnNny/szrRRaYztloBJp/iIKQi+NuusduF0SQgLecyC1seoU1g0Yxg3vb6enQergvruGVUXDw1L4OK+J96Es29cX85LOY/vcr/zK6gCISERYYjg3kH3nvBrC5qfbonh7CmsEmKqtdn9LSx7DEr3guoCj5swIAxgdy7krAKDFUbdg9r9OvbNmEHUtdcSPXFiiwwv6uqrvYJqylRSFr2NPimpRV73TEKIqRZADjc0vtFhusSmcE3vC3ln02e+x7EEDhkLBILG+ctHW05ISElAhygzOlmioMKGU204nnWkQe/T3+yirNbJ3ed0pXu7cFZnlvjdvrFUL4tB4bZxnUMas6BtEWHS87+ZGbyyfC9vrM7B7fH4iGu9IiFLEn2SIrkrvoYOLz6Ee9yH6KKjT+g1JUnisZGPUeGsYEPBhqAElSzJhOnDWHjhQuItwfcyFLQ8XRPC2X2omrO7Jza+saB5WDEXVs4Fly3wNs5qcFajff8ojg/+jXXEzcTeMqPlxghEX3cdmtNF7hFB1U5kOTQlQky1AIaO4Uh6Gc3PqvTeklyWZa7lsu5n0z4igQOVh/h0xzIGdqhnZ6mTsA4RX36B4EQ5WGHniy0FR4XO8dRsX07lz5/gKslHNpjRJ6QTmXHN0aa5GlBQYUMvS40KqeOxuVTmr8jig1/z6ZoQRnqclbzS2qCacB7BqJMZmhbD+T3F+X+qY9DJzD63G3eM78L3Owp5Z10uBRU2HG4P4SYdQ9NimZqRSurhFhiFuVs48Mf7SZ7/KpJyYr0FdbKOF89+kWd/fZb3dr6HLMnY3L4TP0WT0OkMdI3uyrPjnqWdVXzf2jrdEsPYkFPa2sM4c1n7cuNC6jgktw2T1Yk5ZQuSpnkLIluQmJsmo7mc5E2ZSqe332rQhl0QGkJMtQCSLBE2KonKZXlQb0XcarCw6cAOXvv5f1Q6qokwhnFu5xE8OP52n+NYh4mbm0Bwoixal+P38coNH1Ox/gNiz78DU9pAJEWHLftXbHvWHxVTAB4NHH6EVGNCzOH2UGV3s3DKEMpsLibOW8uBcptfUVcfs16md1Ik824cJIrMTyN0isyFvdtxYe+Gr+nxs2eTN206xS//h/i77zrh15MlmfsG38fMvjP5LPMzFm5bSGFtIXpZj+pRkZEYu9fAHfctpnOUiICeKnRNDGfx+rzWHsaZyaFtsOxxOG5hIvW5KmpdkD07DKvBe71esNHJO1tcLJ/iXSCRJRVyVsKvC2HI9BYfduz06WhOJ3nTppHy1lvoYuu2XKhxuPlpdxHF1Q5cqkaEScfQtBhSYkWP04YQYqqFsA5p5xVT9WgfHs8rVzza8M4ymLrFoIQFny4oEAiO4VY9LFqb6yNgPI4aylctJnbCHCxnZRx93NJlGJYuwxo9brBCTENjxd5ixp+VwOd3jeL+9zfzw85CAL+iyqxX8GgaEwcl89dLewbseC84vZF0OpKemUv2HyZi7teXsLFjT+p4YYYwbuhxAzf0uAGH6qDKWYVJMWHRWcgcfzYdp8og2kqdMnRNDBOOfq3F2pdB9e37p2rw/HonD4xuoF2GqxZW/RsGT2vx6BRA3G23eWuopk6j01tvoouOZs+hKl5flc0nm/ajyBJuVUPTNHSKjOrR6J0UyayxnTm7e0LAVh9nMkJMtRCKVY/znI5oS3MxhtJRXgLZqif6KtEjQCA4UYqqHX5d1Bz7d6K5nVi6jQj5mKEIsRqHyrzlmYw/K4Ewo45XbhxEYZWdd9fl8fa6XMpqnMiyhEfT6BBpZsaoNK4e3LFBa13BmYEuPp6kfz9D/t2zSV2yBEPHpikeNypGjGbvhE/TNLLHXsKTi39llyGbGocbk16hQ5SZKRmpXNCrHQadEPRtjQiTniiLnvwyG51ihQlFi2GvhK0fguZbunF/hoF/rXZw+xADUaYG5nq2MshdDakNt8RoLuLuussboZo+g0+mPcL89fm4VA/1b5NO1fsef80tY/Z/fyM9zso7M4YRZRGL+8cjxFQLsXlfOdNX7eH5ge1J/b0cLZgCeNkrpBJm9hNRKYHgJKiyu9H5WU1TbZXIloigGubWJ1Qhtjm/vM7vCeEm5pzXjTnndcPj0bC7Vcx6BakVVioFbRvLoEHE3jKD/bNnk/LuYmTjsVXvKruLgxV2qh1uwk062kWaG20WejxfbjnAv5buolDtgd3hQZOOpC25KKiws/NgJQ98/Ds3j0hl9rldRZS0jdE1MZzdh6qEmGpJdnwOAe4ZgzsojEvVMXeNg7+f3UBLA2ct/Px6q4kpSZKIv+8+5la8xWcrMrHLjS/c1TpVdh2q4pIXV/HlXaOJPElTNJfq4dtth1ifXUJxtQOTTqFjtJkrBiSRHn9qOdcKMdUCrM0s4c53N/Lk1X0Z2TMRW79Syj/PxFPhRHN7fHtQ6WRAw9w9hqgrugghJRCcJEadjD/PB8Ucgae2Es2jhiyoQhViDpcnYDqOLEtYDOJyLAhMzM03Y/ttE4f+8U/aPfIwm/aV89qKLJbtLER32AXQo2m4VY3zeyVy6+jO9OkYGfB4mqbx9De7eGN19jF3S8lXKB1xHFywKou1WSW8NW1oSGJN0Lx0SwhjT2E15/YUjn4tRuUBrxgKwGPjjYx8o4bZwxqau2lQltPkQwuF9zbk8TmJ2OXg3W1dqsahSjtTFm7go9szTmjxr6jKwcLV2Sxal4vHo1FzXA9VnSzx6ooserSP4LZxnTm/Z+IpscAorogniMej4TmcT9oQ328/xJ8/3MKLNwwgo3McAOazYjB1i8aVX03Vinwc2RVoThVkCdmixzokEeuQdkJECQRNRGyYEZefND9jUncknZ7a3Wuxdg9thTBUISZLUmukxwtOEyRJov0Tf+eX66cw5fEvyHPpcbhVrzFKvVZnX24p4PvthXRNDOONKUOIC/Ot3/jP8kwWrs4Juk2A3eXh9/0VTF24gXdvGS4iVG2EbonhrMvy33JB0Ew4q/FdBT9G7wSFS7rpeHKVkx7xDZwnQboANgeqR2Put7uxBTj/GzJWcqkauw5V8WtuGYNTAzf+9sfW/RVMWrAem0vF6ade2O3RcHs0Nu0r554lmzivZyJzJ/Zr89cbIaZC4EiB3hdbCqhxeu9eekVmQHIUs8Z2Zmy3+Dqrzh//ls8TX+5k4dQh9O1Yt6pXkiQMyeHETurRou9BIDgTCTPqGNUljuW7iurcAmWjlahRkyj9bh6SrGBKG4Ak67DnbMKet4Xo8dMCHjNUIRZh1p0SK2yCtssBp8ysQTOorHaiyr71GkfwaF5b/u0HKrno+ZV8esdIOkSZjz6/o6CSF3/Y41dINTSJcro9bN1fyesrs5g1TtTxtgW6JobxdgCnUkEzYY4CSQeab8P2Izw6zsTAV6u5b0QDRhSmwJHj5mb5rkIcLv/XkGCMlY60/QhFTO06WMW1r66tE4lqiFqnyrfbDnKnS+WVSW3b0fbUEFMuG2z7GHZ9DbUlIOsgogP0uw7Sxja7G0pmUTX3LNnE7oNVuDx1C/Scbg/rs0vZur8Ck0HhkUt6cWn/Dry1Jod5P2Xy3i3D6JoY3qzjEwgEjXPrmM6szy6ltt6FPGLoVcjWaCrWLqH4i7lIBjPGxC5EjLi2weOFIsT0isTVgzo2y/sSnBlU2Fxc8+paKhwqniBTS90ejdIaJ9fNX8dXs0cfTc9bsDILlx+b/2AnUa+tzObWMZ3b9OTmTKFrYjiZhTXC0a8ladcX9KbDESr/dImRubaXnhc2OOmT4CeqohggeWgzDrJhXl2R5VfUBGuspGkctVD3F/muj92lMmnBOp/7LzS8gGNzeVixu5j5K7OYNbbttm1o22KqsgBWPwu/veP93VlT9/kdn4MxAjLu8vr16xr/g4bKb3llTH59AzVON1oDfTZrnCo1TpX7P9zMkp/z2Fdu438zR5AcI4pCBYK2wPD0GGKtBmxOm0+CRliv8YT1Gh9w3yNTlPr7BSvEZEni5hGpJ/sWBGcwb67OobTG6bf2r6HJiOrRKKy08976PG4Zk06l3cUXWwpQ6x0oFHdKu0vlpz1FjD9LNP1sbcKMOmKsBvaV1YpeQC1F+ngwWBoUUwB/G2tk0RaX/yclCYbe0gyDC47tByr9Ph6KsZJBJ7P7YBVxXRqfe3/1ewE2p+pzDw12AWfe8kxmjEprtLSmtWi7YqpgC7x9GTiqwRPgy+is9v788Bj8/j5M/gjM0U02hKyiaia/voHq+gnpDWB3eVidWcJfLuouhJRA0IaQJIk3pgzh8pdX+10da4gwkw6n2+O3J1RjQkyRJQZ0ihLXA8EJo3o0Fq7J9vv9C2YyYnd7eG1lFtNHpfHdtkN++8SEMomqcaosXpcrxFQboWtiGLsPVQsx1VLIMoy4E378Z52mvTlz6mYhJUfK2B+K8H+M5OEQ1ak5R9kg9gApfqEYK2kaVNoDzM/rMe+nTJ9IWCgLOC6Ph2U7C7mgV8ONzluLtimmivfAmxPAURXc9i4bHNoKCyfALT+A3tz4PkHwx/c3H62Nqk9DK4Ea8Mx3u5k4OJloqzCREAjaCl0Tw1k8Yxg3vbGBGofb7yr/8ciSt5fLu7cMJ7+slrv/+1vQBftHCDfp+Pc1/U9i1IIznR92FuLyI6RC63XmZnVmMQcr7X4nUqG6U+4vb73ieUFduh22Rz9POPq1HAMmw6rn6oipoNGZ4eyHmn5MoQxBkXD7uQGGYqykVVdR+ugj5Jrt6BISDv/Eo09IQJeY6P09Pp7d5U72lfp+TiEt4DhUXluR1WbFVNuLl3nUYxGp40h9roqEp6uocR774y/Y6GTcm4dT/1QnlGbBF/c2yTCyiqrZdqDSb2pf5YaPKV32GpHDr6Hjne+QdNtCwgdOwLZn/dFtJAmW/LKvScYiEAiajgGdovnirlGc3zMRo07G5KcZqVkvY9TJXNynPV/NHk3PDhGc36sdj1/eG5M+uMumLEG0Rc+SW0fUKf4XCELlu+0H/dY3hBpN+uqr9ZT/8hsePze24ydRweAIcVFB0Hx0TQhjz6EgF58FTYMlBm7+DAwhRgN1Zrj4mVatlwKICbDQf7yxUmNoFiu9/jSbuNtmETZ6FEp0FK79B6j85lsK5z5D3s1T2D1kKKtm3I1k97WSD3UBJ7u4pvGNWom2F5na8523u7Qf20lVg+fXO3lgdID8TLcdtn0IF/7jpNP9Fq7O8ckph+BXAu0uD6+vzObW0emiKFQgaGOkxFqZN3kwxdUO/rshjy+2FFBpc4EEUWYDVwzowDWDk326vE8cnExStJmHP91GfpntqDX18Rh1MhowqnMcT1zVm/aRQkgJTo7iaqffx0OdjBTmF5JgcaPHTP3knNDdKU+uYaeg6eiWGM7C1TmtPYwzj3Z9YNq38Pal4LKDK3DvKRQDSApc/jL0ubrlxhiAG4Z14qVle7HXi3iHYqwUF26i7+AeDbrUaprG5pW70b7NBHfdm2Wo7UVsAVIT2wJtT0ytfi5gUd/9GQb+tdrB7UMMRJkC/fFkr2FFxl0nNYzPNx/wGwINbSXQzfaCSnontZ79pUAgCExcmJE7z+7KnWd3DXqfjM5xfHfvWH7Pr2DByixW7S2m1qkiy14hdvWgJG4clkJChKkZRy44kwi0HhfqZCRs0ABGDU/htTc24KoX6QplEmVQJEakx57UexI0HV0Swsgqrkb1aH7r4QTNSLvecPdmtI3v4P70b+jCdUho3mwpWecVUUgweKrXcKIV66SO5/ohnXhh2V6/zwVjrGQxKMwcm95ouw9JkgiPjkBRFHDXLZsJdQHHpA9u0ag1aFtiqmI/HPgt4NODOyiMS9Uxd42Dv58dYKLitsG6eScnpjweqgKYToSyEqjIEqU1/lcUBQLBqU2fjpE8f/2A1h6G4AwgIcKEhG++RiiTEQlICDcyOCWaGKuBWqdvDUOw7pSSJDF5RMrJvSlBk2E16oi1GskrrSUtTphQtDimCKqdPSnJOYfUp+6Gwu1gr/CmAEYmQ9fzmsVt+mSIDTNyXo9Evt1+0G+bhGAcbq/onxTUa6XEWlD9pBaH2ucxpQ2bOLUtMVWe61XxbnvATR4bb2TkGzXMHtaAsUPVAa+gctV6j+Wq9ZpUuGzH/b/2cFi2/mM2NLcTzbMIfyVlIa0EauD2iLxygUAgEJw4l/btwCe/7fdxoQxlMmLSK0zo0x5Jkpg1Np0nvtzpN22msUkUwODUaFEH2MbolhjG7kNVQky1EhWffkrkFVdA6kjvzynAP67sw6Z9ZRysdPgtawmESS8zb/IgrMbgJETP9hG0izD5rXkKdgHHalCYMTo96DG2NG1LTDkbyDc9TO8EhUu66XhylZMe8QEKwTUPlOzxrgroLWCN9/6rM3md/vSWAP96fySdCdPD3/i1Tw5lJVADIkVeuUAgEAhOguHpMUSa9X7vScFORhIjjPRPjgLgygEdeXHZXr81f41h0svcf0H3E34vguahW2I4ew5VtVm3s9MZtbKSmpWraP/ww609lJCItOj54LYMJs5bS2GlHaefCFV9THqZZ6/pz+iu8UG/jiRJ3Da2M498vs3vNSyYBRxFlji/V9t1q2x2MeV0e/h6awGvrcwiu7gGh8uDQSfTPtLEtFFpXNE/6Zi6NYYFdcxHx5kY+Go1940IEDaV9V63lJNgeHosP+4s9EmrCGUlUPVo9Gwv6qUEAoFAcOJIksTMMek8tXQnNj8ueo1NRsx6hVljOx+tb7AadSyZOYLLXlpFdRAtAo5g0sv86+q+R0WZoO3QNTGcFbuLWnsYZySVXy/FmpGBEnXqnRftI818NXs0T329k4827keS8BE8ekVCliT6JUfxwIQeJ3T+X9qvA//4age1hG4iYdYrTB+Vhr6NNuyFZhRTHo/GSz/uZf6KLDQ0ahzHPkC3UyWzqIYnvtzB419s59rByTxwcQ+MMZ3B7Wj02F1iZK7tpeeFDU76JPj5cJugwG/mmHTWZZWc8EqgTpaYOLgjZkPbLZgTCAQCwanBdUM78b9f8tlTWOW3xiEQBkWme7twrhrYsc7jqXFWPr9rFNfNX0elzeXXev0IRklD1ul44foBopdRG6VbYhgLVma19jDOSCo+/ZTYGdNbexgnTIRJzxNX9uHBi3vw2aYDLF65h8L8QoiPJ9ykZ0zXOKZkpNEp9sRrlswGhUXTh3Ht/LV+59WBMOJhSFpcSCZRrUGziCmX6uG2d35l9d6SBq0Mj3ygS37Zx6Z95Sy+ZThhaWPQ9n7vdUNpgL+NNbJoi5/Oy3orZNx9UuMHGJoWQ7TFf5EuNL4SqMgSU0emnfQ4BAKBQCAw6RUWzxjG1fPWsL/MhsNPE9/6GHUyqbFW3pw2FIOffmopsVaW3z+OpVsP8sryTHJLapAlbzNPRZaQJdCjcUXmSm5/+SHiI9tuAfiZTpeEMHJKanCrHnRteAX/dMOZl4czJ4ew0aNbeygnjcWg47qhnbhMLqTw32+Q+uziJj1+n46RLJo+jClvbMDuUnE1EhI362UGFO7lmVjavEtlk4spTdP44/82s2pvMfYgm/rZXR52FFQydeEG5g6eQULmKsxaXRGTMye8zu/JkTL2hyL8HM0DfSee6PCPIkkSf7+iN7ct/jXo93EEs17h4r7tRSGoQCAQCJqMaKuBz+8cxX3vb+bHnYUAfkWV8bBwuqBXO566um+DGRJGncLl/ZO4vH8SOwoq2Xmwkiq7G5NeISnKzPD0WPbd8DbGdavggvOb540JTg5Nw1LwM08a3sDx9mvodBpY4rwucj0vb3NOcqcTFZ98SsTFFyPpT5/6eHdxMbq4uGY59qCUaJbeM4b5P2Xy/q/5QN20QlnyXpOSY8zcPq4L51vSyZ86DWtyEpaBA5tlTE1Bk4upH3cV8t2OQ34FSM325VT+/Amuknxkgxl9QjqRGddg6tgLp6qxMa+cC/fDmogOmGy5SB7/9uQB0Zth8PTQO1IHYHz3BB6Y0IN/fLUjaEFl1isMSonmyav6NMkYBAKBQCA4gtWoY96NgzhUaeeddbksWptLpd2FTgK3qhFpNTAlI40bhnUiPjy0SXSP9hH0aO+7SBkzdSqlCxcSccH5HCi3kV9mo9bpJtykIzXWSmyYmKy3CqoLNi7y9uesKeIy1Yace9xq/66v4It7YNAUb8ZOuEjRbEo0j4eKTz8l6fnnW3soTYq7qBhdXPP1kUuKMvPo5b35y4QefLb5AGv2FlNa68Sk8y7gXD2oY53+rB2eepL82bNJffddDMnJzTauk6HJxdQryzP95kNWbviYivUfEHv+HZjSBiIpOmzZv2Lbsx5Tx16A17AhMsxE1MwvkOaNAls5aEHmVupMkDQEzn20Kd8ON41IJcZq4P73t/gtzDseo07msn4deOLK3iLMLhAIBIJmIzHCxH3nn8V955+F3aVS+vt2Kh9/hO4ff9jkr2U5+2w+mf8h9zz9Ldsr1Dopg063h5Fd4pg5Jp2haTGNNvEUNBH2Snh3IhRs8bZ1wU8zF2e199/182HTu3Dz594ms4ImwbZxI7LFjKlXz9YeSpPiLi5GaabI1PGY9ArXDE7mmsENC6Sw0aOJmzWLfTNnkfrf91Ai/GWltS5NKqZyS2rYkl/h87jHUUP5qsXETpiD5ayMo49bugzD0mVYnW0r7S7WFxsZfutyWDgBaoq9jXgbQm+F9LHwh4WgNH0Z2CV9OzDurAQ++S2feT9lUVLtRKd4bxia5k1tVGSJf17Vh4v7dmjy1xcIBAKBIBAmvUJ8UgI1xU3v5rbnUBU3vr6eqrMuobbEW6dcP7Xwx52FrMsqoVOMhbenDyUh3NTk4xAch9sBb10KhTtAbdy0C48TbKWw8EK4ZTnEdWn2IZ4JlH/yCZGXX37aLSC4S4ox9+3b2sOoQ8ykSTizc9g/Zw7Jr77a5tIqm1R5fPl7AR4/BWWO/TvR3E4s3UY0egybU+WDX/MZPrEf3L4WNr4Da14ARwU4j2v4Jeu8P4m9qB16F9+ogyhaux+n20OEWc/ATtF1woQnS5hRx43DU5k0LIXMohpKqh24VI1Is56uiWG8uz6Pr7ceFGJKIBAIBC2OLjoad1kZmseDJDdNZsTv+RVcd9h9S/PTxP4IGt6sjb2F1Ux4fiWf3TlKNPVtTr7+PyjaVUdIpT5XRa0LsmeHYTV4J/cLNjp5Z4uL5VMOlz44quHty2DO7yALp+GTwWO3U/Xtd8R//llrD6XJUYuK0cUF30eqpUj8vz+z7447OPj3J2j3yMNtSsQ2qZgqKLf7dedQbZXIlgikIE5eDSgoPxyJMobDiNtg+CzI/gmyVkBNISh6CO9AVsI5vPy7zBfvF6CTt+Nwe/BoGgZFBkkiOdrMbeM6M6FPe0z6prlwSJJEl4QwuiTU7Yl19aCOPPf9bg5V2kmMEKtyAoFAIGg5JIMB2WLBU1nZJP1uCipsTFqwrkHL9Pq4PRpltS6um7+OpXNGYzE0eyvLMw97JWx+F9x2n6dUDZ5f7+SB0YFq2DSwV8Ce7+CsC5t3nKc5VcuWYe7TB33i6VeH5i4uRhff/Gl+oSLpdCQ9829yb7iB0jffInbqlDrPa5pGUZXDW0Mqy8SEGYgwtUwEq0mvdM4AVq2KOQJPbSWaRw1KUDnVeseRJEgf5/3B+4E9v2wP897NxKVqqB6N4wPd9sPj2FNYzUOfbOX5ZXv4763DaR/ZfCtlkWY9l/brwPfLvmWS9gXk/+LNV1aMENURht4K3S/xCkGBQCAQCJoYXUwM7tLSJhFTL/2wN2CNcENmUqrHO6H58Nd8Jo9IPelxCOqx+T2Q/EcJ788w8K/VDm4fYiDKFGDV3lntNawQYuqkqPjkUyKvuLy1h9EsNKeb38mihFlJnvcKOdddjyGlE+Fnn02FzcX7v+zjtZVZlNe6vGU4mldL9OsYxayxnRnfPaFZ7dWbVEzFhhn8Pm5M6o6k01O7ey3W7qMaPU60xf9xjvDEVztYvC4vKIe9WqdKfpmNS15YxVezRzdf1GjXUv6W91c8ZTlokop0vHFGRZ63SPSz2TBsJoy5H3QNv0eBQCAQCEJBiY1FLSmB9PSTOk6t081HG/fj9pNpEoyZlM2lMu+nLG4cntKmUnFOC9b956jhRH0Gd1AYl6pj7hoHfz+7gbnOgY1QkQ+RHQNvIwiIq7AQ2+bNdHz+udYeSpOjeTzeBZnY5nPzO1n0HTrQ8aUXyZs5i8/uepoF2yuRJeloX1vHcUbgv+SWMfu/v2HUK7wyaSDD0pvnfTWp5dyw9FisfvpZyEYrUaMmUfrdPGp3r8XjsqOpbmyZv1D24xt1tpUAh1tlS345muZ7If/o13wWr8trsBlwfVSPRrnNxfWvrUNtpEnYCbHyGXh/Csay3Zhx1hVSR3BWe+u+1rwIb17sDdULBAKBQNBE6GJjcZeUnPRxPt10AH8a6IiZVMx5t2E5KwPZYEJSdFi6DCN6/LQ625bVOtmQXXrSYxHUo/JAg08/Nt7IixucFNU0sNisGKEst4kHduZQ+cWXhJ97LrLl9GtirVZUIFssyIa2veBv6tOHFyY+xOu/FeJwexrUBDVOldIaJzcv3MA32w42y3iaNDI1ukscFoPOb451xNCrkK3RVKxdQvEXc5EMZoyJXYgYcW2d7Qw6mc4JYcz+7yYcLpXze7Xj/F6JDE2NQZElnvpmZ8APrbHUg0MVdpbvKuScHk2Y47puHqyY27jj4BHcNijYDIv/AFO+FGl/AoFAIDhpNE2jKiaB4gNlJFbYiLYYTrhW+Pvth/ym+IVqJrU2s6TZVoLPSDQNVGeDm/ROULikm44nVznpEd/AermjqokHd2agaRoVn3xC4oMPtvZQmgW1Daf4Hc9TS3fyQ5mMXQle9NldHub89zfemTGcQSnRTTqeJhVTsiwxY3Qaz363+2jd0vGE9RpPWK/xAfdXZLhqYEf+dkkv/nodEsJQAAAgAElEQVSxxp7Car7ZepB/frWT/LJaeneIpLzW5XffYFIPapwq837KbDoxVbwXvn+4TiFoUI46qsOb9rfqORh7f9OMRSAQCARnHGU1Tpb8nMeCVdlU2Puh26PB3J9wezyMOyuBW8ekMzglOqR0u7Ja/xP2UM2kiqqDsO0WBI8kgWJoVFA9Os7EwFeruW9EA82UTW2vV8+pgGPnTjzV1ViGDG7toTQLbble6gj5ZbUsXJ3j06IBGg6qANhcHv7y0Ra+vWdsk46pya12rhvaiddXZeOsdhBqRp1Fr+OO8Z0Br2tet8RwuiWGc9c5XdlfbuOm19f7/fBC6WO1Jb+CfaW1JMc0QXh23X/A4/Z5uHFHHbwRqnX/gdH3CotSgUAgEISE6tF44svtvLM+D1nicA2xhEuT4HD2xvc7DrF6bzHx4UZeu2kw3RLDgzp2oELtUM2kDKJ5fdMT1QlK9ja4SZcYmWt76Xlhg5M+CX7+Bm4HRKc10wBPbyo++ZSIyy9rsvYDbY1TQUwtWpvrtwwomKAKwL5SG1v3VzRp+6Qm/zZEmvX8b+YIwk16QjHOsBgU3po+lI7R/kVOUpQ5YFQqlNQDgyKzo6AJ6pWcNV5XHT9i6v4MA3PXOCi3N6ImVSfs/ubkxyIQCASCMwa36mHGWz/z3oZ9ON2egGZMmuY1YcorqeXKl1fza25ZUMdvF+nfvOB4M6nG0MsSCaJNSNMz4i7QWxvd7G9jjdQ4A8xBOg2HiPZNPLDTH83tpuLLL4m87LLWHkqz4S5q22LK4VZZvD4Pp1r3ux1KPafT7WHByqwmHVezSOvUOCtf3DWKpCgzVmPDq1dWg0KM1cD7s0YwsFPDOYyBaqVCST2wu1WWbj3Ih7/ms3RrASt2F/FLTinbD1SSW1JDUZWDGofbb/PhOuxeGtCe9HhHnQZxVsPPCxods0AgEAgER/jzh1tYl1UatBGThjfN/eY3NpBdXNPo9lcP7HjSZlKyLDGhT7ugxicIgb4TAV/xnDMnnHPTjyUbJUfK2B+KOFZecASDFUbObuZBnp5Ur1qFISkJY9rpG9VzFxejxLXdOset+/0HQ0IJqqiaxg87C5t0XM3WUS85xsLy+8ezfFch837KZEt+BQadjKZ5035dqoe0OCuzxnbmwt7tMOqCSxmoxffmEWrqQVG1g9V7i6lxuql1qtQ6VWocR/7vpsahYnermPUKFoMOi0HBYlCwGo/9f0LVL1zstAf8AB8bb2TkGzXMHtZIcVxFfqPjFQgEAoEA4Pf8Cr78vcBvNKqxeoEap5tHP9vGm9OGNvgaY7rGYzYoJ2Um1ScpkpTYxiMoghAxWGHwNPjlDXAFaXx1GM0DqtOAkjwSYVgfOhWffkrklVe09jCaFXdxEcauXVt7GAGpsDn9Oo2GElQBAvbQO1GatT25Ikuc0yORc3okkl9WS2ZRDdV2NxajQqcYC53jw0I6XkKEkXKbb6pfKH2sDIrMPed1azQK5vFo2FyqV3A5vP/anCo1TpVah5vkbXrkwsDWo0E76vjpYi4QCAQCgT8WrMzC5T6xegFNg7VZJRyqtDfYc1GWJWaMSuO5ZXv8irbGzKQsBoVZYzufwLsTBMW5j8GBzbD/l+DnEJIM5nAK943BedMUkp57Fn17keoXLGplJTUrV9H+4YdbeyjNilpc0qbT/ABvqL0eoQZVmpoWq6DrGG1hbLd4Lu7bnvFnJYQspACmZKRiOcnUg3CTngHJjXeHl2UJq1FHQriJ1DgrvTpEMjg1hrHd4rmoT3t6pycjN9J499FxJl7b6GR/ZQMpg8JRRyAQCARBUFHrYum2g6jaidcLgLeAuzGmjEwjPS4MvRJaDMOklxnTNZ5zeiSEtJ8gBBQd3PgBpI0Jqn4KnQnCEpBu/ZH2L71J2LnnkD3xGqpXrgzu9TQNnLUhR8JOJyq/Xoo1IwMlqvH546mMu7gYXXzbFVNRFoM/LRVSPSeA1di0saRTyo7kigFJ+DHwALypB9FnT6di7RLyX5xE/itTqNr4Beaux/InzXqFW8ekN01H9g4DA9ZMHeF4Rx2/yHro1Hh+p0AgEAgEa7NK0PtxyAulXsDh9vDFloYbvwKY9AqLZwwjJdaKURfcVMGsVxicEsPz1/dvmvusIDB6M1y/BC5/Cdr1BZ0ZpHqLzYYwsMbD2D/D7esgtjOSLBN3yy10fPbfFDz0Vwqffx5N9ZPy5PFA5g/w9hXweBw82RH+kQSPxcG710HuGgJOyE5DKj79lMgrTu8UP2j7bn59kiL9mtuFElRRZInzejZhv1maOc2vqbEYdFw9qCPv/7LPr0V6Y6kHkgRXD+rYNINJGgjh7aE0s8HN/jbWyKIt/l0INVlBGjaracYjEAgEgtOa8lonqh9zpFDrBSrtvi60/oi2GvjszpE8+PFWvvy9AMnpxCH7ThssBgVNg5tGpPCnC7sHtFYXNDGyDL2v8v4c2g5bP0I9sJPadWsIv/QaSB8PXc71237FMmQIaR9+wP4/3k/e9BkkzX362CR6z/fw2R3exr7OeoYlmuo14MpeAdZYuPJVSMnwOf7phDMvD2dODmGjGy4jOdXRXC7UykqU6KZtaNuU6BWZG4ensGBVNs56OiDYek694u2J25ScUmIK4MEJPfg5p5SsompcavCrIia9zH8mDSTSrG+agUgSjJoDX/8fuI5dbHLm1O3jccRRxx/btVTKSiMZ1XaNUwQCgUDQRgh0x2vOegGLQcez1/bnL0PjmH//XL4ceAmF1U48Hg2dIpEcbWHm2HQu7dcBi+GUm1KcPiT2hMSeqPn5HHpzCuEvP9noLrq4ODq9voCil14i++o/kDT3aSzydvj6T43UYmneeU95DSy6Cq6cB71O36hNxSefEnHxxUj6Jpo/tlHcpWUoUVFIStvufTp5hFdM+aOxoApAelwY3ds1bYnNKXflMxsU/nvLcG5YsJ6somq/ESqfffQyz1zTj3FnNXEOd+8/wI//8Dbg1Rofx/FoOjPVIx/igY9/p2tCGA9c3OOE6shaizJ7Gd/lfkdBdQG17lqiTdH0jO3JyA4jUUQTYoFAIGhyoi16v1GfUEyYACJMod/6leXfM62blYceOBfw9rrSiaa8bQ5Jr0dz+c+G8bu9opAwezaWgQMpe2IG5v77kQh+f9w2+HgWWOMg9dSO3PyeX8E763PJKqrG5lIJN+np3zGSUV//wMC5f2/t4TU77uKiNp3id4T2kWZuH9uZV1dkBd0e4ghmvcJTV/dt8jGdcmIKvKkHH9+ewYs/7GHR2lxUTaPGUfcDNSgySDAkNZo/X9idvh2boWjQYIGpX8H88eCoDF5Q6c1IFz7FsEGX8t1olbfW5PCHV9Zw5YCOzD6nK5GWtrv6sbV4K29sfYOf9v2ELMnYVe/qlYSEWWfGqBi5seeNTOw2kWhT2w0VCwQCwanGiPQ4XKrvfeb4egFJVjClDUCSddhzNmHP21LHhMKok7m4b8MubkVVDvJKa6l1urEadXSKsVD15ZfE3zPn6DZCSLVNQhVTRwgbMQzr0BIkZ919U5+rotYF2bPDsBq8Qn7BRifvbHEd62HltsGH0+GeHd7Uw1MITdP4bPMBXvxhL/vLbDjcKsdn0v6aXcKCPlMYuqaKe/+fvfsOj6pKHzj+vfdOT09IgYSSQg29Ezp2VFRA1EVWwYKiq7jW1f2t64q67qKsHSyoiF3BtmJdkA5Kb1ICISSQ3pPp9/7+GEHCTJKZkJB2Ps/D7uPUMxom9z3nLcFFDOoc2XSLbWTuwhbQye83d5/fldxyG59tO+53QGXSy7xy/UD6JIQ1+HpaZDAFnuLY+y/qwdzzu/H93lze25RJTpkNh0sl1KxjVEo7ZozoQny4uXEXEpkEt66CtyaCrcwziLcmitHTtGLSS9BnCgBGncKtY5KZPDCBBd8fYMIzq/jThBSmD+/ss9C4qWiaxsKdC1m8azEOtwP1jKGBGhpVriqqXFW8uvNVluxZwusXvU6PyB5NtGJBEITWJcyi56LUOL7aeZwzS6f8rRcAmDG8i9dtqqqxPr2QhT+l83NGEYbTmk44nG76RI7mnnYppGmaaC7RjEk6Xb2CKfZ9gVRDIqlbg+c2OXh4tLHm59sr4PBKSDkv8PduIi63yoOf7uTrXTk1XpA7VEDWsS69kC2vb+Jvl6Xyh2Gdzu1CzxFXfvNuPnE6SZJ48qo+dIkK4j8/HESSap4dFWRQCDHreWX6QAbUMRapvlpsMHWSXpGZ2Kc9E/s04byEyET40zbY+zmsXQAlRwEJVKen8FPWeYKoIbfAkJsgtIPXS7QLNvLEVX2YMaIzT/x3H+9sPMpfL+vF+IZOTaynBVsW8MH+D06dRNXG7rZjd9u5YcUNLLlkCd0ju5+DFQqCILR+t45J4vu9uT4v/vxpwjQ8KYq4sOozpo4VVTHjjU3kl9tPDeo9M4X+l8gkbnl3G3GhJpbePIwOjb1RKdRLfU+mWPefGjeD708z8K91duYMMRBuqiGQdlTAuudaTDClaRoPfLqTFbtOYPUxS80Xm1Pl8a/2YNTJDdfMrBlp7m3RzyRJErPHJjNjRGc+25bNop8Ok1VUiV7SQKfD5dYYmRLF7DHJjEiOatRNoBYfTDUbehP0u8bz58ROyNnpOanSmyE0HpLHg1J3+l6PuFCWzBrKyv15PP7lXt5cl8FfL+1Jt9iQOp8LcKLUyrsbM9mZXUK5zYXFoCMlOojrh3emq5+vcabvMr7j/V/f9wqkitcUU/BtAY48B4pJIXRQKLFTY1GCPDVTVa4qbv7uZr6d8i0WvaVe7y0IgiD8rnd8GJf0iQvoIvCkIIOORy/vVe229PwKJr+8jnKby+u0qzqJKoebo4VVTHx+DZ/NGUmXdn7MOBLOKUmvR3P5163xlMpCyN9f492DOyiM66Jj/no78ybUPOyZjLXgckAdMzibg8+2Z7NiV47Pv0OVe1dR9vNnOAuzkA1m9DFJhKVNw5SQitWp8shnuxjQKZykFlTn7g9XQQH6eO/N/ubOYtDxh2GduW5oJ9L/7zHs3XoRfeUkQs16TPpzU8MvgqnG0L6v5089SZLEhB6xjO4azTsbjnLdqxu5pE8c95zfjahg38fsW44W89wPB9h0pAgNqrWM3Hi4gA9/PkbX2GD+NKErF6bG+b0WTdN4fuvzXoFUwYoC8lfkk3BzAsG9gnEWOzn+znEy5meQ+Egi8m8pIna3na8Of8W07tMC/xchCIIgePnXlL4UVjjYfKTIr3oBCU/zprdnDal2AVhU6eCaRRsos7n8Hhnk1jTKrE6ueXUD380d26xrfNskRQFVRXO7/e/KVlUIigHcNczEBP4x3sjIxZXcPayWQEnRg60UgqMDXPS5pWkaz/1w0OffnbLNyynd9AlRF96BKXEgkqLDemQL1oObMCWkAuByayxed4R5V/Y510tvFFaHmxKrg9yCUjr06dfUy6k3SZIw5ucQO2EMIaG1BP2NoPkU5Qhe9IrMrFGJ/HjvWHSyzAULVvP6msNevfWXbjzK9Nc3svpgAXaX6nW/WwWbS2VXdhl3f7Cdv3+xG7X2LchTdhfsJs+aV/31rG7yPsujw/UdCOkbgqSTMEQb6DinI44CB6XrS0891uqysnj3YrQ2NNxPEAShMekUmcU3DuGaIR0x6mRMNQzVlfDMgEqItLB8zkiv4vk31x2pMZCq3LuKE2/PJfPZqWS9OIPcjx7FlrUHAFWDkion72zMaOBPJpwtSZI8dVOBnE5pdQfkvWMULuum459raw64AFADPBVrAjuySskts3vdrtorKVn7LpEX3I6lexqywYSk6LCkDKvWxMWlany6JZsqR/P/rDVxuFS+3HGcS59fQ+9Hv2XC/FVco09jxAaY8sp6vt+b63OmXXPnys9HF33ug3kRTLUA4RYDf5+UykezR7A+vZALF/zEt3ty0DSND3/OZN5/92LzM93D6nTz4c9Z/OOrPX49fsneJdhd1b90qg5WoTpVQgdV79OvmBRC+oZQsad63nWxrZgd+Tv8ej9BEAShboos8fdJqWz4y3ncfX5XooIMGHQyZlyYZA29IjGhZwxvzRzK6vvH0T2uepq3063y9voMr8038OzOF/34GmHDp5Fw51Lib3+TkIETsR7cdOoxdpfK4nUZLfKCq7ULuG7KFO6p8a7DY+NMvLbVQXZZDf/N3U4wN0Ln5Aa2ZEMGdpd3AGnP/hXN5cDSbUSdryFJ8O2enEZYXeP7fFs2g+Z9z0PLdrLneBluTcPqVLHLOtyaJ9Np7ofbGDzve1buz6v7BZuRpgqmRJpfC5ISE8ziG4ew+kA+8/67l5dXHmJfTrnPX4a15/x6AqrhSe24uHftKX+/Fv3q1bnPXeFGF6xDUryL+XRhOqxHrdVuUzWVQyWH6B/Tvx6fWhAEQahJZJCB28elMHtMMkVVDg499W9Ce3YncdpkzIaa07x+2JuL28eR1Mnd+aiJc7F0Tzt1uyVlGJaUYdUea3e6WflrHuf3im24DySctYCDqZA4sLSDsuxaH5YSKXNNqp7nNzvoE+NjLz6mh6dOvJk7nF/psz7QbS1DtoT6Nfja6nCTVWSt83HNzSurDvHcjwfr3ICvtLupxM3tS7fwjyt6M21wx3O0wvrT3G5cxcXooqLO+XuLk6kWaEy3aL6+azSyLNV7V9HqdPPC/w7W+V5Vriqv25RgBVeFC83t/W3kKnWhC64eo7s0F5XOSn8+miAIglAPsizRLthIZ72L9oqr1kAKYOX+PK/5jBDY7nylw81PB/LrvWahkRgCDKYkCdLuAj8aRf1trJFKh49IxBAMI+8JYJFNp6b0PMUcilpVhqbWnfaoAeX2lpXmt2xLll+B1OlsTpW/fb67RZxQuYuKUEJDkfTnvo5TnEy1UFanm30nyrxuD2RXMT2/gv055dXTP1wOT2v3wkNQeAiTrdzrPSwpFiSdRNmWMsKG/j78zG1zU76znNip1XcpdZIOs67571YJgiC0dJKiQ/ORwnSmggrftS+B7M57Xse79kRoWpJOD4G2R+9/Hfzwd6+bM+ZWTw/tGCZj+2uo1+OQJOg1KbD3bCKhJt8X28b4Hkg6PVUHNhDUY1Str6FIEGZuOc1XrA43f/18t89AqrZMJvAEVPd9tIPNj5yPIjffGXOu/Hx0MU0zTkgEUy3U17tOIPvomR/IrqLTpbL0ixU83mnbqeCJ0mzPHKyoFGjXlQRzNJmVmdWep1gUYq6M4fjS48gmuVo3P32knvC06jnTiqzQIbjltdsUBEFoaSSdguaue8dc7yNNG6rvzvsTUDWn4fKCR71mTZnCYNxD8NPT4PTOSKmN6pZRhz2ATlfLUN9mpE9CGDuOleA8I9dPNgYRPmo6Rd8vRJIVTIkDkGQdtozt2DJ3VmtCYdIrdK/nuJkauZ2ekTqyDMYwz/83kC93Hvd5uz/dCwFsTjc/HchjQo/mm9LrzMtrknopEMFUi5VRUOVz2nMgu4puDQ4V2CC1PXQZ7QmgIrpUmxExPWs123+63yvdL3piNEqQQs6HOTjyHMhmmdCBoXSc3RFZX/0LQCfrGN5+eP0+qCAIguA/RQd+dHKLCzUhSXh18gtkd16WIDa0ZVxAtyX1mjUFMPJuKD0G29/zP6DSm7GFjSH7sWUkxJ6HOTW17uc0sT8M7cQ7G476vC906GTkoAhKN3xIwVfzkQxmjLEphI64ptrj5Ipyur4xn8opV2EZPhypvoGP2wn7vvQMTT6x09OiHg1Ut2cActpd0GWU5+TvLCxcle51zRhIJlOlw83Cnw4362CqqZpPgAimWqxym+9dp0B3FStDk2Bkzb8wR3YYiUln8lk7FTk2ksixkT6e9TujYmR6j+noZPGjJgiC0Ng8bbHrTvOb1D+ej7dkeV1gBbI7b9QpXNE/vsE/g3B26nUyBZ4L9onzITTec0KFBK4amizoLaCpcOETWIbcRGzv7zl2y620f/IJQsaNO5vlNxqb080HmzN5dfVhLAaFMpvvgDM4dTzBqeNrfB2jTmbmqO6ElDjInT8fd0kJ4VdeRdjkqzAkJPi/oO3vwYqHPK3pHb91QXafljZ78Hs4ug7MkTB1MXQc6v9rnyavzEZWifd/x0AymQB+ySjC4VIx1DCKoamJYEoIWITF9+C8QHYVAUJqyB0+SZEVZqbO5KUdL2Fz2Wp9rC8SEld3vzrg5wmCIAiB8zfNb2CncKKDjRwt8t4o83d3Pj7CTO/4MK/nC02r3sEUeAKq0X+GQTfC1iWw4SVwVMLJzVnVBeYIz4lJ/+s86YFA6AUXoI+J4didd+KaM4eI665rmA/TAMptTt7ZeJTFazMY2Cmcl68fRJXdxay3fw6oGcNJekVmxtjuRIb0JfL66dj27aNk2XIyrp6GsXt3widfRcgFFyCba6kVX/VPz2mUs7aOgJrn372jEpZMgqlvQfeLA15vcZUTgyJxZt+NQOsj9YpMmc1Ju+DmeRrtys/HmJzSJO8tgqkWqmtcCEEGhcqz2FXUKxJ9/PhF+MfUP7I5ZzObczZjd/tfbGxSTMwfO5925nb+fzBBEAQhYKqqsSu7lMOOYFxulYT0QnrHh9a4YSZJErPHJvGPz3Zh07xTiOranbcYFG4bm9xg6xcajqTT1T+YOskSCaPmeoKmwkNgLfYEWuZIiEr2mXZm7tePLu++y7FbZ+M4lkXMfffWP/2tARRVOnhz3RGWbjzK2G7RvHvzsGoNt+69oDvPfn8Aq7Puk9yTTHqFt2YOITrk94DC1LMncY/0JOb++6j43/8o+XQZOU88SejFFxM+ZTKmPn2QTv/39cubfgRSZ3Ba4ZMb4Y9fQsch/j/vFO//XoFmMoF3WnBz4srPJ2iEf6dsDU0EUy3URamx/GWZ7xxav3N+JYnrh3eq871kSWbB+AXct+o+Np3YhNVd+xeAhIRRMTJv1DzGdhzr/4cSBEEQAlJc6eCDnzN5Y+0RrA43kjMaDQ35nV9wuFQu79eBm0cn0iOuegc2d3k5ae8uIEXqy4HgOBw+Rl3UxKiT6ZcQzpX9RWOh5shzMtVAbbtlGaK7+f1wQ6dOdPngfY7deSfZf76XDv98Ctlkapi1+Cmn1MZraw7zyZYsJvZpz2d3jKRzVJDX424Zk4RRJ/Pkin043VqtA6iNOhm9IvPmzCEM7uK7vEE2GAi9+GJCL74YZ04OpZ99TvZ99yMZ9IRPnkLYpMvRBevhm4fgjEyfLv8pp8oJR+4OJsjgubZ7fauDpTudrLrxt7U7rbD8VvjT1jprqDSnE/uhQ1h378a++wAOa284o9wi0Ewmp1tt1h0MRZqfEDCjTuG6oR15e30GTh+/BOvaVQQY2CmChIi650qAp/bpuQnPsezgMt7Y9QaFtkJsLhsaWrXHaJpGWoc05vSfQ8+onoF9KEEQBMFvH/9yjL9+thtJ4rR0pd9OAn6rB1m+LZuvdh5nXPcYnru2P0adgm3fPrLmziUoLY13517D9CVbOZRbgc3H3MIzmfQyPeJCeP2GwehEJ79m6azS/BqAEh5Opzfe4MRfHiZz5iwSXn4JXUSE1+Mq7S6+2nmcg7kVlFidhJn1dI0J5rJ+HQg2Bn55erSwkoU/HebrXSeYOiiBb+eOIS6s9kDuj2ldGJoUyWurD/PVzhMosnSqjlACLEYFnSwzY3hn/jiiMzGh/gWG+rg42t02m6jZt2LdsoWST5eRPvFSYsYEEx6q+Tgn8jQFe26Tg4dH15JGV54DWb9UO506GTjZ9uzBumcPtt17sB86hD6+A+bUVOJ7pRKTaSSrsv6ZTAB9E8Kbbb0UgCsvH120aI0uBGjmyETe25SJ0+3/EfVJJr3M3PO7BvQcWZKZ2m0qU7pOYUf+Dj7a/xHZFdnY3DZCDaEMih3E1G5TRVqfIAhCI1u0Op0F3x/AXkcA5FY9O+6rfs3j2kUbWdjuOKX/WUDsww8TdvllAHxyWxqPfrGHz7ZlI0tg9VFHYtYrqJrG5AEJ/H1SarO+qGrrmjqYApCNRjrM/zf5/3mOjGuvpdOiRRi6dAE8My5fW32Yz7ZnI0tStSYoFoPC37/cw6R+Hbh1TBIpMXW3H9+fU84rqw7x04F8rh/emZX3jSMyyHdduS894kJ5Zlp/Hp2Uytc7T3Cs2Eq5zUmkxUCP9qGc1zOm3iMAJEnCMngwlsGDcZdXID3XG8nlu1zi/jQD/1pnZ84QA+Em3ydPmsuG+u0TlEfM8ARPu/dgP3gQfYcOmFJ7Ye7dm7DLLsPUowdy0O+ncbdvPMoTX+/zajjjbyZTUDNP69U0DXdBAbroprn+FMFUCxYfbmbRjEHcsuSXgIoodbLEAxf1YFhSVL3eV5Ik+sf0p39M/3o9XxAEQai/b3efYMH3BwL63re5VPZmFnL/7hO8tvQdjMm/XxiZ9ApPT+nLw5f05OMtx1i89gi5ZXY0NCQk4sJM3DwqkSmDE2oceCo0H5K+AWqmGmIdskzMn+9BnxBPxvUzSHj+eX5UYnjgk1043G7cPn58T17sf7oliy93HOefk/tyxQDfHSN3HCvhxZWH2JZZwk2jEnn8yt51NtWqTahJz7VD6y59qC/FfgKkmuvOB3dQGNdFx/z1duZN8H0CJmkqUuYqKvd3w9w7ldCJEzH26IkS7J3GeLorB8Qz77/7fN7nTyaTXobzezbNqY8/3CUlSBYLsrFpmmOIYKqFG901mldnDOa2pVtwulWfKX8nKZKETpEINioY9WJXURAEoaXRNI1Hv9hbYyBVuXcVZT9/hrMwC9lgRh+TRFjaNEwJqdglhc1RXckIiqG7j+eGWfTcPDqJm0cnoWkaTreGXpGqF88LzZ5nzlTTB1MnRUybhr59e9557GUWpF6JzY89ALfmOSF9cNlO3JrG5IGeluOaprHhcCEvr0znSEElt45J4vlrB2A2+NdAoUlVFXrVLZ3pH+ONjFxcyeZVFd0AACAASURBVN3Daj5ZkxWJ+KfmVZsJWpcgo45HJvbgia9/DajhBoARlbu2Lce+KRRdWlrdTzhHDuSW886GoxzILaeirBL9kJkM+e9eZgzvQqco/0pYGooIplqBMd2i+e6eMby1LoP3N2cCVOvydzI9Y2Kf9twyOgmLQWHqwg10CDMzvkfz3WkQBEEQqtuQXkhZDXMGyzYvp3TTJ0RdeAemxIFIig7rkS1YD27ClOAZpup0q7yx9jD/mtqv1veRJAmDTgRRLVFzSPM7U1bXfizofRU2Hxu+tW0A2JwqjyzfTfe4EHJKbby08hAlVU5uG5fMlf3jW1a6qVZ3FNk7RuGybjr+udZBz+iaPpvk12ud6foRXThRZmPx2gy/AyqTXubBi3sx7Yp2HH/wIcKmTCb6jjuQdE0XPny3J4fnfjxIen4FTpfKqR+p4AT2rM9gyYaj9E0I457zu5GWcm7S/kQw1UokRFj462W9uO+i7ny7J4f9OeUUVzkIMenpHGXhsr4dqnVhWTRjILcs2cKSWUPFnBBBEIQW4tXVh73qHgBUeyUla98lauJcLN1/3z22pAzDkjLs1D+7Nfhix3H+dnlqvYr8heYrr9zG//blkaF0Rs7UiNuQwciUdiRHBzf10nhlZTq+DlP92QCwudxcu2gjCZEW7hifzCW926PILTDQN0d4BvTW4bFxJgYuquDeETWkrEky6OvXIfH+i3rQPszME//dhyTh87sEPDVSSPD0lL5c1rcDkEjisk85/sADZN44kw7PzEcfG1uvNdSXqmo88fU+3tuUWWMw6MnO0vg5o5hZb//M3PO7ctvYxp89Jb5JWxmT3r+J9IM6RzLvyt7c/PYvfDonjfjwWobLCYIgCM3CzxlFPm+3Z/+K5nJg6Vb3nBW9IrP3eBlDE323eBZaDk3T2HykiEWrD7P2UAGKJGFXO6EdB8N/9yEB3eNCuH1cMuf3jG2SDoylVU6+2ZOD+4whRf5uAGiaJ6B696ahRDbTgbF+adcNdEbPEN5apETKXJOq5/nNDvrE+Pjv1aXuNua1uX54ZyYPjOeL7cdZ+FM6x0ts6BVPcOpwqyRHB3Pb2GQu6ROHUfd7+qSuXTs6vv46ha++ypEpU+nwxDyCx5678TdPrag9kDqTzany3A+H0MkyN49OatS1iWCqDZvYpz3ZxVZmvfkzH98+QhQWC4IgNHM1XUi4rWXIllD/hm9qUGptXmlgQuBcbpX7P9nJN7tzsDndpw0q8VwYn+z0uCOrlD9/tINuscG8PWvYOZ8VtHxbFrKPurtANgAUWWLZtuxGvyhuVLICw26HNc94zZk609/GGnlnp4+/o4YgGHn3WS/FYtBx7dBOXDOkI4WVDkqqnMgShFsMtXZClGSZdrfdhmXwYLLvu5/QSycSM3cukr72n6nsEisnSqxYnW5CTHqSooMCuuZc+WseSzf6DqRqSxO1Ot3M/24/QxMj6ZsQ7vf7BUoEU23czaMTySyq4valW3jzxqGn8o+PFlby1voM1h8qpNzuRK/IxIWamD68MxenxrWsPGVBEIRWQpElVB91J4o5FLWqDE111x1QSZzaiRZaJlXVmL10C+sOFfjV1bHK4Wbv8TKufGkdX9w58qy63gVq9/EynxfBgWwA2Jwqu7NLG2N559agmbBmvtfNGXOrt4DvGCZj+2uo1+MwhkJiw50GSZJEu2Aj7QI88bMMHkzi8mWceOgvHL1+Bh2eeQZDQvWsKKdb5fu9uSxclc7+3PJq140Ol8rFqXHcMibJr1KT53886PNnyJ80UYdLZeFP6bw8fVBAnzEQ4oq4jZMkiUcv74VJp/Dw8l1sySji6oXruXDBapZuPMr+3HKOl9g4WljFpiNF/GXZTgbN+56nV/yKLcCOMIIgCMLZqWk31xjfA0mnp+rAhjpfw61qAV88Cc3LM9/tZ/2hwoDa4zvcGtklVm5buqURV+attMr3KejpGwB+vU5rOE0NjoaR94C+Ht3mdGa4/HmQm8eluy4igoRXXibkoovImDaN8h9+OHXfjmMlDH3iB+7/eAc7s0uxu1TKba5Tf+wulS93HufqhRuYtmhDjU11AA7nV7D3RJnX7SfTRCMvuB1L9zRkgwlJ0WFJGVZt4LCqwQ/78iiqdDTsv4DTNI//IkKT0ikyL/xhABvTC7nm1Y38nFGM3eW7zXql3U25zcXidUe48qV1FDfiD6cgCIJQ3VUD432eKsnGIMJHTafo+4VUHdiA6rShuV1Y03+heOXiao8NMeno1d7HrrfQIlQ5XCxeV3NHtsq9qzjx9lwyn51K1oszyP3oUWxZewDPLv2Wo8X8muN9cdpYQky+k6AC2QDwvE4rKUUY9xD0nhxYQKUzw8VPQbcLG29d9SDJMlGzZtLxlZfJfeqf5DzxJOv353DtqxsprnJW6yx9JlXzpC1vzyzh8hfW1hh0f7D5GG7V+3o0kDRRWYIvtmf7/8ECJIIpAYC1BwvIr7Dj8vED64vdpZKeX8G1r27EWstfFkEQBKHh3DCii8/6E4DQoZOJmHATpRs+JOuF6WS9ciPlW7/C3PX3iw2zXuaW0UnILbEbmgDAZ9uyqWn0V9nm5RT9+Bphw6eRcOdS4m9/k5CBE7Ee3HTqMU6Xyhtrjpyj1UJKbDBGH6UBgWwAGHUyKTFN35WwQUgSTHoRRs71NKTQ19IAzBDs+TPldRg889ytMUDmfv1IXPYp6bll3PTGhoBmWTncKsdLrFz/xiZcPiY5H8wr93ltGmia6JGC2ht/nA1RMyVQXOng7g+2nypWPV1thX1Ot0ZGYSVPrdjHP67o3QQrFwRBaFs6RloY2CmCzUeKvLqjAQSnjic4dXyNz9c0uHpQx8ZcotDIFjVAe/wvdx7n75NSCToH7fGnDkzgPz8c9Hlf6NDJyEERlG74kIKv5iMZzBhjUwgdcY3XY6cNbkU/t5IE4x6EobfAtqWw4UWwl/0+1NflgKgkT0pgryvq3Qr9XFLCwnhr8FRse3N83l/X9WR6fgU/7Mvl4t7tqz+vhg37gOpEodZUwrMlgimBD3/ORMP7l7I/hX12l8rHv2Tx0CU9sBjEj5MgCEJj+/fVfbn0+bUB15CY9DL/nNyHMEsrSZdqg1RVI7Ooyud9gaQ96WSZzKIqep6DdM+YUBOjktuxcn+ejyuNujcAAIYnRREX1vwDioBZImHkXTDiTig9BrYST0BliYKQuKZeXUAKKuysPpCPhvexqT/Xk1UON6+sSq8WTLkKCgiq9N145PQ00aAedbeLj7A0Xp2oSPNr41RV4/W1R7yKWP0t7APPBsuXO46fy2ULgiC0WQkRFt6/ZThhZj3+ZuuZ9DIPXtyDKwckNO7ihEZV5XSj1JDjF0jakyRBuc3V0Mur0R0TkjHq63fJadYr3DG+8QevNilZhojO0L4fxKa2uEAK4L1NmT5vD+R68tfjpfzy/Gscu+NODo4dR/qll5F8eCdGvDOnAkkTtRgUerYP8XqNhiKCqTbul6PFPnNbA9nhqnK4eXNdRiOsThAEQfClV4dQ/nvXKMLMehRZwuCjKYWE5yIiPtzMi9cNZObIxHO/UKFBmXSyz/ROCLA7nuYJUs6VQZ0jmXt+N8wBjlUx6xXunJAiBky3AD/sy/VZLhLI9SQuF5tLJcIuv4zOS9+h28YN3PrUXaDznfnkT50oeNKbL+vboV6fyx8iL6uNO1Fq9Xl7QAMggbxye0MuSxAEQajD0cIqQkx6Pr09jQ9/PsayrdmU2Zy4VQ2zQWFol0hmj01mSJcIpJo6Fggtik6RCTHqKPNxqhRI2pPdrRIbdm7b498yPIHC997nnch+2LS6gyqzXmHOuGTmjEs+B6sTzlZNaceBXE86FR3aqHGEnnYSGRNiYnTXdvy4r35pojpZYuqgBMyGxts8EMFUG2dzuvG1yRVoYZ/Dx26EIAiC0Djcqsa8/+7joUt6kBQdzF8m9uQvE3s29bKEc+C6oZ1YvO6I1/iS09OeJFnBlDgASdZhy9iOLXNntZSqfglhxIScuxokTdPImzePP+qLGT9rNi+sTGdrZjGqplX7HHpFQpYk+ncM567zujIypd05W6NwdnQ15BwHcj0pSxKKj9e567yurPVzQPWZ9Iqng2ljEsFUGxdi0vtssxtoYZ+lESN+QRAEobpPt2YRZFC4pHfLq60Qzs6MEZ15a30G+Nin96c7XpBR4bax5/a0p/i997Bu307n9z+gY3AQaV2jOVZUxfubM9l3ooxym4tgk44ecaFMH9aJjpH1GGorNKmoICPp+d7txwO5njToZCKDDF63900I5/ErevN/n+8OKKAy6WVenj6QTlGN+/Mkgqk2LrVDKE4fff0D2eGSgD4JYedw1YIgCG1Xpd3FM9/tZ+H1g0T6XhuUEGFhWGIkG9ILcfqYv1NX2lOQQce47jGNucRqKjdtpuDlV+jywfsowUGnbu8YaeGBi3ucs3UIjWvakI7sPl7q1bY/kOtJt6pxfs9Yn69/9eCOKLLEI8t34XCr+Lh0PUWvSOgVTyB1Ln7WRTDVBpXZnBRVOHCpKqEmPakdQtmaWeL1OH/nP5gNCrPHiJxmQRCEc+HV1YcZmhjFgE4RTb0UoYksuKY/E59bQ36FHR/xVI0sBoW3Zg71mUrVGBxZWWTfey/x8/+NoWMrmhMleLmsb3v+9vlun/f5cz0pSzChR4zPk6mTJg9MoF/HcN5Yc4Tlvw2vPj14CzIoIHlSYW9M60JCxLk54RTBVBvhVjVWH8hn4U+ePGW9IiNJ4HJrmPQKelmq1w4XQITFwJAu4pe6IAhCY8sts/H2hgy+vLPu9Guh9YoKNvLpnDSuWbSBvHK7V/3UmWTA5LLz1vRB9OrQ+LOlANTKSrLuuJN2t95K0Ag/OrkJLZpJrzBtcEfe3XTU589jXdeTRp3CzX7UNiVHB/Pk5D48cmlPvt51gqOFVZRanYRb9HSNDeGi1FiMunNbeiKCqTZgy9FibntnC1UO16lJ0k7375G83aX6GLHmH5NO4qFLeohUE0EQhHNg/rf7uXaIqCkRPOl+X981hmd/OMDHvxwD8EqxMullNA0u6BXLzMzVRD//DdqihUhy407G0VSV4395GFNqKhEzrm/U9xKaj3su6MYP+3I5UWKrsYW/L2a9wlUD4hnU2f+N+SCjjqsHN4/TThFMtXIrf81jzrtbsNZRsHfqR17TPNP8/GCSNCZnbOACcysfpicIgtAM7Dleysr9+fzvvrFNvRShmQiz6HlsUip/uaQHX+08wSdbjlFY4cCtaoRZ9FzUK5ZrhnQiIsiA5uzD0ZkzKVy0iHa3396o6ypYuBBXbi6d5v9bbLa2IWFmPR/OHsGUV9ZTWFH3iSl4AqnxPaJ5/Mre52CFjUMEU63YzqwS5ry7tc5Aqho/vvR0soROlph7QVeuLYXMWTfR8ZWXMffpcxarFQRBEGqiaRpPfr2Pu89LIdSkb+rlCM2MSa8wdVACUwcl1PgYSa8n/plnyZg6FXP//o2Welf+44+UfPQxXT76ENlQc/2L0DrFh5tZcddo7v14B2sPFQC+x+cE/dYF+raxydw5IaVFB90imGrFHlm+C6vT9yT0yr2rKPv5M5yFWcgGM/qYJMLSpmFKSAU8Hfr0Os+8B5dbQ5Yk9DoJVYWrBsQza1QiKTHBQApKcBDHZt9G/IIFBA0beu4+oCAIQhuxcn8eOaU2rhvaqamXIrRg+tgYOvzraY4/8CBdPvkYfazvzmn1ZTtwgBP/9zc6LlqIPubcdQwUmpeIIAOLbxxCTqmNpRuP8t7mTEqqHIBnllRSdBC3jU1mYp/2mPQtf7SOCKZaqQO55RzMq/B5X9nm5ZRu+oSoC+/AlDgQSdFhPbIF68FNp4IpnSxxcWocw5PaUWZzYlBkYkKNTOgRg8VQ/ccmZMJ45GefIfuee2j/5BOEjBtX7f5ym5NPt2Tx9a4TFFc5kSWJyCADVw2I5/J+HRp1KrUgCEJLoTnduMudaA43kkmHEqxH0sm43CpPfv0rD0/siU5p3FoXofULGjGCiD9cR/af76XzW28i6RvmpNNdUkLWnX8i9sEHRKaKAEBcmIn7LurOfRd1R/ttQLNB1/q+w0Qw1Uq9sfaIz/lRqr2SkrXvEjVxLpbuaadut6QMw5Iy7NQ/O1WNH/bl8a+p/fzaNQgaPpyOr7zMsTl3oD78F8IuvZQTpVae+e4AX+44jixJXqdkO7JKePSLPUwdlMA9F3SrtR2mIAhCa+XIKqd8TTbWPQVIsuRJt1Y1kMAyOJYfLRAT4tnMEoSGEDV7NlXbtpH3n/8Qe//9Z/16mstF9p//TMh55xF2xRUNsEKhtZEkCYOu5aby1UYEU63UD3tzfQ40s2f/iuZyYOlWd660IklsyyxhRHKUX+9p7tePTosXc+yWW9iXZ+W2zFDK7c4aB6ud7Dr0weZMvtuTw4ezR9ClXZDvBwuCILQy7lI7BW/vwZVvRXOpoIFG9YLtyo0nGKhqDOoc6jmxMopf28LZk2SZDk8/zZEpU7AMGEDI+eef1evl/fvfIMnE3HdvA61QEFqO1nfWJgBQYXf5vN1tLUO2hCLJdZ82aUCp1RHQ+5q6d0N74TVm7ZEpsTpqnVB9klPVyK+wM+WV9eSV2QJ6P0EQhJbIVWAl97mtOHOq0Jwq1NT0SgUDEvrsCvJe2I5a5Tyn6xRaL11EBAkLFnDib4/iOHas3q9Tsmw5Fat+Iv7ZZ5AUkbYvtD0imGpjFHMoalUZmuq7McXZ0jSNW7/LxqY3whnTqyr3ruLE23PJfHYqWS/OIPejR7Fl7QE8GS2lVie3v7u1UdYlCILQXLgrneQt2olqdXm+/Pzh0nAV28hfvNtziiUIDcDcrx/tbr+drLvvRrXbA36+dft28ubPJ+Hll1DCwhphhYLQ/Il8gVYq2KjD7vI+VTLG90DS6ak6sIGgHqPqfJ1wS2B1TJuPFJFXbve6PvCn6YVL1diTXcqhvHJSYkICel9BEISWonzVMc8J0xnfk5uzdvLkylc4UJCBLMt0jerMo+f9if7te3oe4NZw5VZRtTOfoIEN24VNaLsirp9O1ZYt5D7xJO3/8Rh2l5sVu3JYtPowGQWV2F1u9L81oZqZlsjUwQmEmvQ4c3PJunsu7Z+YhzE5uak/hiA0GRFMtVIX947jw5+P4TojqpGNQYSPmk7R9wuRZAVT4gAkWYctYzu2zJ1EjJ916rGaptG/Y3hA7/vq6sNYz5jA7m/TC/AEVIvXZvDkZNEJSBCE1kdzqVRuzoEzhlmW2yuZ+clDPHHhn7m8x3gcbhebs3ZgVKpvaGlOlfKfskQwJTQYSZJoP+9xDk+9mqdf/IJ38o1omkblab/L7S6VY0VW/v3tfp7+5leu6BvHrE/+Rbs//IGQ8eObcPWC0PREMNVKzRqVyCdbsryCKYDQoZORgyIo3fAhBV/NRzKYMcamEDrimlOP0SsS1w3tFFD//3KbkzUHC7xS/wNpeuFSNZZty2Lelb2R5dbZ9UUQhLbLursANO/v5cNFnpqVK3t5GgGYZYWxib7n9rmLbDiyKzDEBzfeQoU2RTNb+OekB1l7pBi74rvmGjjVlffzLcfY0uUKPp0x+VwtURCaLRFMtVLJ0cH0bB/K9mMlPu8PTh1PcGrNu0myJHFDWpeA3jO/3I5ekTjjYCqgphcAblWj3O4izNwwsy8EQRCai6qdBWgO75qnpMiOyJLMPf99gkk9zmNAfCrhJt/pzppLxba/SARTQoPQNI37P9nJuhNW7Ip/qf12SSFTF8qMNzbzye0jMOpE4wmh7RLBVCv21OQ+TH55HVZnYMXKZr3CjWmd6RhpCeh5VqcbSfI+TTq96YU/AZUiS9icbhFMCYLQ6qiVvrvxhRiDWDb9RV7e9B4PfPNv8iuLGJ88jH9d/ADRQZHVH6yBu/z3mthSeynLDi7jm4xvKLWXIiERbgxnYtJErky5khCDqEEVarbmYAHf7snB5uNaoXLvKsp+/gxnYRaywYw+JomwtGmYElJxujUO5pXz5toj3DYupQlWLgjNgwimWqkqh4ttmcWEmvRYnf536DHrFSb2ieOBi3sE/J6hJj2qj/SVQJteuNwaISbxoykIQtvStV0XFlz6MACHCo9y11fz+PuPL/DSpEe9HyxBdkU2C7YsYNWxVUhI2Ny/j5bIqsgivSSd57Y+x4WdL+TugXcTGyTqrARvC39KPzX38XT+NI6yOVVeX3uEW8cki9R8oc0SV6yt0Hd7cpj74XYAn1+QvlgMCqqmMWdcMndOSPF5wlSXmFCjz9sDaXoBEG7RYw6gVksQBKGlkIP9O3FPierMtN4Xs3T7Fz5eBA4ajzL3y/+j0lGJiu/sA6vbCsDXR75mTfYaFl+0mK4RXeu9dqH1ySquYsvRYq/bA2kcZXW4WX0wn3HdYxp9vYLQHIk5U63MR78c464PtlHlcPsdSBkUiQcv7s4vf72AP53XtV6BFIBRp3DN4I7oFe/nhw6dTMSEmyjd8CFZL0wn65UbKd/6Feau1ZtSmPQyN49KrPcaBEEQmjNLv2gkg/ev3kOFR1m0+QNOlOUBcLwsl8/3/cjADqlej80y5XFX/sOUO8prDKRO59bclNhLuPGbG8muyD77DyG0Gt/szvE5LzqQxlGVDjcf/5LV8IsThBZCnEy1IusOFfC3z3f7zHuujSxL/JxRzA1piWe9hhtHduG9zZl4DVCh7qYX4Glydc2QTme9DkEQhObInBpF8TLvzaIgg4Xtx/fx2s8fUWavINQYzPnJI3hk/Jxqj9PQeDT+Jap+O3U6XfGaYgq+LcCR50AxKYQOCiV2aixKkOekv8JZwd3/u5tPJn3SOB9OaHFyy2w4fAyBDrRxVG6Zre4HCUIrJYKpVuT/Pqs5kKqtiNTmVPlhXy6/5pTRIy70rNbQOSqIMd2iWX0gH7uPL+jamPUyVw1IICIosEHBgiAILYGmaXy5O4cjqoMrJAXdaXtO7UOieeXKx+p8jb3BBymWC9HO2LAqWFFA/op8Em5OILhXMM5iJ8ffOU7G/AwSH0lE1smomsrRsqPsK9xHz6ieDf3xhBbI4fb9ezrQxlE1vY4gtAUiza+V2HGshBOlvneGyjYvp+jH1wgbPo2EO5cSf/ubhAyciPXgplOPcbhU3lhzpEHW8vy1A+gSFYRB5/+Pl0kn0yc+nMeu8E5pEQRBaOn2nSjj2lc3snBVOqOmp2IMM0Kg2cyKxKcdfsImVw+k3FY3eZ/l0eH6DoT0DUHSSRiiDXSc0xFHgYPS9aWnHutQHSzZu6QBPpHQGkQFGfHVN+L0xlH+CLeITVCh7RLBVCvx2prD2F3eNVIni0gjL7gdS/c0ZIMJSdFhSRlWrfGDqsEXO45TbvPdtjcQZoPCp3PS6JcQhsVQ946WxaAwqms7ltw0FL0ifiQFQWg9SqucPPr5bma8sYnL+3Xgyz+NYlD3aKJn90UONvj/W1gn44qR+UW/2+tUqupgFapTJXRQ9cwCxaQQ0jeEij0Vp25TNZXvMr7DqZ79d73Q8g1LjMTko+HT6Y2jqg5sQHXa0NwurOm/ULxycbXHmvUKE7pHn6slC0KzI9L8Wom1hwpQfVSRBlJEalBkdmWVkpbS7qzXE2zU8cGtI/hhXy4LV6Wz90QZwKncbIPqQtPpGJwYxa1jkhjbLVo0nRAEofk4sQPWvwgHvwNHJUgSGIKh1yQYfgdEd6v16W5V46NfjvHMdwe4uHcs398ztloKsy7CROzdAyhcug9nVgWaqnJmL4lsfR6fR/2P7UEHsBrsyGY9bqv3ppm7wo0uWIfko/mPLkyH9egZ9VUSlNnLiDJH+f/vQ2iVhiZGEmExUOXwrsELHToZOSiC0g0fUvDVfCSDGWNsCqEjrqn2OFXTmDIo4VwtWRCaHRFMtRJVdt+d+wIpItWAEmvD7VYqssRFqXFclBrH4fwKVh/Ip6TKiSxL6DasYZTFTupNE3GrmgikBEFoHo5thi/ugpIMcDlAO+271VoE25bCjg8gpidMehHienu9xNbMYh79fA9GncxbM4fQOz7M51spwQZibuuHM6+KinXZVG3NQ9M09pkP81rUp6Qbj6FKLlzSbztl3te7v72OgqvChebWvAIqV6kLXXD1X/WKpGB3+z9/UGi9JEli9pgknlrxK1an93VEXY2jZAku79eBEJN/Lf8FoTUSwVQrIcuAj3gqkCJSCdA10tC9pOhgkqKDAU/twMI9HXg+y4794a+RJE/gNahzBLPHJjO2a7QY/icIwrm393NYPhucNUQtAKrL8+f4NnjjQrjufUgaC0BeuY2nV+xn3aECHrqkB1f07+DXRpE+xkLEVV0Jn5TCl/u/4PGtL1QbwFsXS4oFSSdRtqWMsKG/B25um5vyneXETq0+rNelugg2BPv9+kLrNmVQAq+uOcyJEhtuzVej9JpZDDruPk/MLhPaNhFMtRKhJj02p/dO4+lFpEE9RtX6GhoQFdx4RaQHcsuZ+8F2DhdU4HSpuGXPj5+mgerW2Hi4iJ1ZpVgMCo9NSuXSvh0abS2CIAjVHFkNy2aDq5ZA6kzOSnj/Wpw3rOCtw6G8vOoQ04Z05Id7xxJsDPzX66rsVTy+dV5AgRSAYlGIuTKG40uPI5vkat389JF6wtPCqz0+1BBKiD4k4PUJrVOQUceHs0cw6YW1lFiduH3VDPhgMSi8PWsIHSMtjbxCQWjeRDDVSlw1IJ7F647gdFf/Ejy9iFSSFUyJA5BkHbaM7dgyd1ZrQqGTJfolhJ/50g3i54wibly8mSqH2+eAwJNODhu+9+MdZBVbmT02uVHWIwiCcIrbBR/d4BVIdflPOVVOOHJ3MEEGzwnT61sdLN3pZNWNQQBoziry3riGtQlL+OT2NJKj63fiU2ov5YHVD/gMpOqaHwUQPTEaJUgh58McHHkOZLNM6MBQOs7uiKz/vcuFLeJCAAAAIABJREFUSTHxx15/FKnVQjXx4Wa+vns001/fxIkSK5UO36UDAEEGBaNeYcmsoTWmsApCWyKCqVZixojOvLU+A1/Dcv0pIjXqZG5I64KuEbrpHcgt54bfAil/2ZwqC344QGSQgasHd2zwNQmCIJxyYAW4HT7vcmvw3CYHD482+rxfAmLlMt66QEOqZyAFsPzgcp+3+zM/6qTIsZFEjo2s9X1UVCZ3nVzvdQqtV2yoie/mjmF9eiGLVqez9lABJp3iSf+XwOlWSY4O5raxyVyUGhfQ+BNBaM1EMNVKJERYGNQ5go2HC3129auriBRgeHIkz/1wgJwyOy5VJTrYyOiu0QxPijyrXcx7PtyOtYZAqq5hwv/32W4uTI0jzCyKWwVBaCRrF4Cjwudd96cZ+Nc6O3OGGAg3+f4e1LmtsP556DS8Xm+vaipL9i7xOpU6OT8q/qZ4Qvp60vJOzo86cP8BSteXEjEmwu/3MSkmrki5gnBT42QgCC2fLEuM6tqOEclR9P37tzx7TX8kPLVRCRFmurQLauolCkKzI4KpVuTpKX259Pk1lNlcAT1Pr0hEWAzMfPNnTy3TacHYW+szCDPrmT0miasHdyQowDqA/TnlpOdX+EztK9u8nNJNnxB14R2YEgciKTqsR7ZgPbgJU4JneK8kSXy6JYtZoxIDel9BEAS/VBVBzq4a7x7cQWFcFx3z19uZN8FUw6M0OPCtJ11QCfzX6q6CXVQ6K72X5sf8KH+DKZNiom90Xx4a+lDA6xPanvT8CqJDjFyUGtfUSxGEZk+c0bYiHSMtvHfLcEJMOp8TzX2RNBVN08gps2FzVg+kwFPDdKLUxtPf/MrE59ZwojSA4mzgjbWHcbpUr9v9HSZsdbp5dfVhtAA7DAmCIPilMh+U2hvv/GO8kRc2O8iv9P4uO0VSwF5WryXkVub6PP2va36Uq6LujTMJCbPOzOiE0Sw8fyE6WeyhCnXblllM/47iBFMQ/CGCqVamd3wY//3TaPp3DMeok2tsdW7Re+5TZAkfsY4Xq1Mlq8TK5S+sJb/c//kkX+/K8QrQILBhwmU2JwdyfafgCIIgnBW30zOQtxa9YxQu66bjn2t911UBnteooe6qLna33eeG0enzo87ka34UgE7SoZN1mHVmDLKB0Qmjeem8l3hm7DPoFZEuLfhn+7ESBnTyP4VUENoysUXVCnWKsrBszkjS8yt4c90RPt9+nAq7ZwdTL8v0SQgjxKhj45FCbE7vi4ja6phKqpzc+OZmvvrTqDrrqFRVo9Lhe+c0kGHCiixRWGkHRCtfQRAamCnMk55Xh8fGmRi4qIJ7R/huRIHbCfWsRQoxhPj8Pg10flSQLojb+t2GJEmEGcMY2WEk0Zboeq1JaNu2ZZZw7ZBOTb0MQWgRRDDViiVHBzPvyj7Mu7IPqqrh1jT0ikyZzcmQeT9g93EkVVcdk0vVOFJQydbMEgZ1rn3XSjv1P94CGSYM+D33QhCEtiWvzMa7mzL5cudxyqxOwDN3b2KfOK4f3oW4sJrqnH4TGg+GoDrnS6VEylyTquf5zQ76xPhI6ohMAn0d71WDHpE9cKpOr9sDmR8lITEgdgA39r6xXmsQhJMq7C6OFlbRs31o3Q8WBEEEU22FLEvIeHY+P/0lC9nHLujJOqaoiXOxdE87dbslZRiWlGGn/tnqdPP6msMM6jzo1G2a04nj2DHshw7hSE/Hnn4Ye3o6+pQ/4vCRWhLQMGEN0c1PEIRqjhZW8tiXe1l7qAAJqm0OFVQ4eG3NEV5dc4QRSVE8enkvkmpqWy7LFPe7haCNz2LQak9h/ttYI+/s9A56MATBqHvq/VniguLoH92fzTmbve7zd36UWWdmVu9ZXs8XhEDtzCqhV4dQ0fpcEPwkgqk26PW1h7E6vVuV+1vHpGnw494cDj33MqYjB3GkH8KReQxdbCzG5GSMKckEjRpJ5A03MHRDGWsPF3u9RiDDhDWge5xI8RMEwWP7sRJmvLGJSrvL5ygI+D24+ulAPuc9+xORQQaCDTpiQk1cP7wTF/eOI6/MzksrD7FpdzLf+zhGz5hb/XunY5iM7a817Nb3PrvZTbN6z2J3wW6qXFVe9/kzPyrUEMrg2MFntQZBAM/fL9F8QhD8J4KpNkbTNI6X2nzeF0gdk97tJNsh0f+88zDOvhVDYiKyyTvF5TZTAVuzfvE5sNefYcJ6ReK6oR0x6upekyAIrd+hvAqmv76RSrv/Q8A1DQorHBTi4GhRFbuzS7n3ox0oMtyYlsjy+6+g6oct6Lcuxoz/DXYA0Ftg/COgNwf4Saob0WEE8cHxHCk9gksLbLyFSTExd9Dcs5oHKAgnbcss4Yr+HZp6GYLQYohgqo2xOt3IkoTbV+eoAOqYZLMZ/eSrCUtuV+vjRqZEEWrS+wymoO5hwrIkcWNal1rfQxCEtkHTNG5d4ntzprbGOWc6eTIvyzIr9+eRltKOB3ZO4MOYA3Qp3gBO79Mhn/QW6HcdDJ9zVp8LQJZkXrvwNaZ9NY0ia5HfAZVJMXF9z+u5NOnSs16DIGiaxrbMEh69vFdTL0UQWgyRENvGmHQKag0zm06vY6qTBsF+DPCVJInHr+yNSR/4j5pZr3D1oAQSIiwBP1cQhNZna2YJJ8psnPkVVrZ5OUU/vkbY8Gkk3LmU+NvfJGTgRKwHN9X6eg6XSnpeBTPf3Mz/Xd6bLrd9AgNngM4Eci11mooRdEYYORcufabO1ur+ijJH8eFlH9I5tDMWXe3fezpJh1ExMqf/HO4edHeDvL8gZJdYkSSIDz+7k1ZBaEvEyVQbI8sSUUFGCiq8U1kCqWNyuFU6+Plle0GvWB68uAdPf/MrNqcfQ60As16jR3ImKV1zWbhjHUH6IJLCkhjefjiKH2mIgiC0Pq+tPoztjHpPfxrn1HZq5dY86cTr0wu5rF8HuORfMOw22LQItr0DkgycDJY0QIYhN8OQmyAsvsE/YztzOz6e9DGrjq1i8a7FHCw5iCzJuFQXsiSjk3SoqExKnsT1Pa+nS1iXBl+D0HZty/TUS4mUUUHwnwim2qA/jujMSysP+WyN7k8dE8DQxEjaBdcwb8WHmSMTiQoy8MCnO5Elqca0P52+FCV8A6Z2mzmhk3h+mwOX6kIn6zDIBow6IzN6zWBK1ylEmMRAQUFoTSrtLr7ZnUNWsZVyu5MIi4HusSGM6x6NqsGPv+Z6nUrV1TinrnEPAE63xrJtWTxyaU+CjDpPm/NLnobz/w7Ht4G12BNUmSMhfiA08vBbvazngs4XcEHnCzhccpgd+Tsod5RjUAxEmaMYFT8Ks06cHAgNzzOsVzSfEIRAiGCqDfrDsE68uPJQjffXVccUZFCYPSY54Ped1D+eCT1j+WxbFgtXHaaw0oFO8ex+aZoGlr3o4t7Drbpx4cJ1WsmAU3XiVJ1UuipZtGMRb+x6g1fOf4X+Mf0DXocgCM3LobxyXl9zhM+2ZyNLElaH23MGJIHZoKBXZK4elPDbSIfq0VRtjXP8HfcAnvrMz7ZlM314599v1JuhcxpNKSk8iaTwpCZdg9B2bMss5v6LejT1MgShRRHBVBvULtjIhb1i+X5vrud0SnKhC96LbCgA2QaqCdXRDldFL9Cq/4jIEkQFG0lLjqrXewcbdVw/vAvTh3UmPb+SokoHLlVlX+laXtn7Hna3vc5KPpvbBm645btbWHjBQgbFDqr9CYIgNFvvb87ksS/34HSruM84LFc1fuva5+bt9Rk43IE1zvF33ANAlcPNWxsyqgdTgtCGOFwq+06U0zchrKmXIggtigim2qinp/Rld+4X5PIjSthvRdqyA0nS0DQJVAPwKY7iYTiL09Bc4Uh4gqF3bhqKLJ9dPrUkSaTEeIZo7i/az8INT3gCqQDY3Dbu+PEOPr/ic2KDYs9qPYIgnHtL1mfw1Ar/ail9BVJQ+wDwQMY9AOSXBdgWXRBakX0nyugcZfGkugqC4DfxN6aN2pK/nsrop9C5nCBVr1+SJA0Uz0WFIXIdhsgNOE9MJ8jdlw9uHU7nqKAGXcuinYtwuB1etxevKabg2wIceQ4Uk0LooFBip8aiBP1+YeRwO3h337v8efCfG3RNgiA0rk2HC3lyxT6/m9LUpLbGOdaDG/0e9wCexjqC0FZtyywW9VKCUA8imGqDVmau5IHVD3jS5eo4YJJkT4qNOf49/ja8D91iQxp0LcW2Yn7K+gmV6hcxBSsKyF+RT8LNCQT3CsZZ7OT4O8fJmJ9B4iOJyDpPLqBTdfLRgY+4c8CdGBRDg65NEITG8+z3B2oMpAKZGQU1N84JGTwJ6+FffJ5a+WIxiE6hQtu1/VgJaXXMjhQEwZsIptqYw6WHfw+kAuDGwRM/P0LfuBSSwhquGHr5weVIZ0R0bqubvM/yiL8pnpC+nuDNEG2g45yOHLj/AKXrS4kY83snP03T+DHzRy5JvKTB1iUIQuM5VlTF9mMlPu/zp/ueLzU1zvF33IME9IkXtSJC27XtWAlzxqc09TIEocURwVQbs3jXYpyq0+t2f1LqnKqTN3e9yeOjHm+w9WzL2+ZVK1V1sArVqRI6KLTa7YpJIaRvCBV7KqoFU1WuKvYV7hPBlCC0EEs2ZPgcHh5I9z1/+TvuwWxQuLUeXUoFoTUoqnRQVOEgJTq4qZciCC2OCKbakApHBd9kfINbq14j5W9KnVtzsyJjBQ8OfZBgQ8N84ZY5yrxuc1e40QXrkBTvHERdmA7rUavX7UW2ogZZjyAIjW/7sRKcPhpKBNJ9LxD/z959x7dVnY8f/5yrYXk7ju1sZzmbhCwSEhJCgLbMUvZIWCGMFgq00PbbFmhpgVLKj10IlB0IDbvsGQJkb0L2Tpxlx4m3LWud3x9XTmRLsiXZTmzreb9eekXWXecq0tF9zjn3OY1N9wCQnmjjxD6ZzXpcIdqKVfnFHN8jo8nJpYSIRxJMxZGPtn2EoermHY92SJ2hDD7a9hGXDbyswWM53V4+X7ufLYUVlFa7SXfY6JmVzFlDO5NkP/KxCzXxpCXFgqfCg/bqoIDKU+rBmhL8sU22NW9SDCFE7FweH1+tL2Dj/jKKq9ykOmzkZiZy5tAupDlslDs9IbeLNvue1VD4tMYXOtFfxBw2gz+cMRCl5EJSxIeKys0cLJpDTU0RoNmbrxnXQ6YZESIWEkzFkR+LfqTaU7dXJ9ohddWean4s+pHLCB1M5R+q4sX525m9NB8FVLqO9IIl2S3c/f4azh/ZjekTetMnO4XctFwW7VtUp7csKS8JZVWULS8jfcyRexi8Ti/lq8vpdFHdNOh2i53uqd2jfj+EEM1rf6mTVxfuYOainfi09s8RZUqyW7jnf2s5Z1iXsNs3NGdUfQq4ZHR3Kmu8fLZ2PzaLgcfnizo7YKLNwjXje/KLEd2i2k6Itsbn83Cg6At27nyWysotaO1Fa3PYf45hoYvlDRYtfopePW8kJ+csDMN2jEssRNsgwVQcKatpniF1pTWlIfc/Z0MBN7++Eo/PF3IIT5U/sHpzaT7vrtjNQxcO48J+F/Le5vfweo9cdFmSLOT8Ioe9r+3FcBh1hh7aMm1kjK+XulUj90sJcYwt3HqQ6a8sxe3TuDzBAU3t9//9VXvRYbqSGpozqr5Eu4WRPTO5aFR3SqvcLNhaRHGVm3Knmxfmb6e82k11A4GVxQCbxeDW0/K4aZLcKyXaN4+nkh9WX09Z2Y/4fFVByy3KzNxbWbmRDRvuYlf+y4wY/jI2myRlEaIxEkzFkSRbUtBrzTWk7uv1Bdw8a0VErcIen8bj0/z+ndU8cP5Quqf0YEvp5jrrZJ+VjSXZwv7Z+3EVujASDdJGptHjxh4YtiNDFRWKk7qdRFaipHMV4lhZtO0g015e0mDwUsvbwJi86q1LMRwpFP3vQQ59loy9Sz/STrwEPK6g7Hs+rTnzuM4ApCfZOHPokR6va0/qzedr9/PMt1vZdqAChcLt9WEYCpvFwOfTnD+iG9Mm9CIvp3mnexCitfF6a1ix8nIqKjajdfCcjkHr+6qoqNjAsuUXccLo97FaZRi9EA2RYCqO9Errhc2w1cnmF+2QOquy0SM1t85r24squWXWypCBVEPzxTjdPv7wzmqS0iZj6bwLL3Wz+mVOyiRzUsM3hCdYEph23LQG1xFCtJyCMifXvbI0ZCAVzXxRgSnRvc4Kyld+jDN/Dc4960nsMbRO9j1DwS+GdyM5IfRPmN1qcO7xXTn3+K5s3F/Oil3FlFW7sVsNslMTOHVgTp17N4Voz9Zv+COVlVvR2sWUK3ZRXOzFMMBqVQweksDtt2eTk1P3+6C1i+rqPaxZexvDj3/+GJVciLZBfk3iyHl55/H8mrqVYrRD6jw+zYufZrNn+4+cPbQLJ/bJ5LnvtuH2Bl9IRTJfjNurGZQxmry+BXy+4/Oo5r9yWBxMO24aw3OGx/BuCCGawysLdoQc1hvNfFGhUqKnHv/TsMfUwPSJvSMq34DOqQzoLL1PIj7V1BRyoPBTfAE9Un+/rxOjRiXhcvl4/PEinnqyiL/9vXPQth6Pk+LiBVRV7SApqddRLLUQbYsEU3Gka0pXRmSPYPH+xXVej3RIHcDYLqO459xz+GTNPh76fAO7D1VR6vQEDd2JZr6YlfklPH75n/BpH1/t+iooSUYoDouDKYOmcNPxN0X7Ngghmonb6+O1RTuD7pGKdr6oaFOiG0rRMTmhaYUXIg7s2TMLlDJbIOqx2w1OPjmFp/99EICH/lmIPUFRWOBh9Won9/6tE6NH28nf/QoD+v/lKJdciLZDgqk4M33YdH448ENQD1AkQ+ocFgfXD7ue3I5J3DSpLzdN6sszc7fwyJeb8NZbN5qLIwW8u2Iv90+6nxO2nMCzq5/lkPMQTo8THfALYCgDu2GnR2oPbh5xM6flnhbpaQshWsDX6wvwhph8N9rgKNqU6A6bwcHKGjok26MqrxDxRGtN/u6Z+Hw1IZc7nT7mflPBoMFHGibmzKnggQe6cN/9CXg8Gq3d7Nv7Nv3y/ohhyPdNiFAkmIozJ3Y5kamDp/L6+tcj6gGqlWhNZMqgKYztUrdVeV+pM+QQn2gujpweH6t3l6KU4vx+5/OLvF+w6sAqXl/3OtvLtlPlriLZlszAzIFMHTyVgZkDIy63EKLlbNxfTlVN/aaU6IOjaFKig5l4xuVp4uRSQrRzHk85Xm9l0Ot/uacAi0XhdPrIyLDwjwePJG8ZPz6Z445zAGC3m0mpNBqXqwiHo+vRKbgQbYwEU3Ho1hG34vF5+O+G/0Z0j5LD4uDSAZdy64hbg5YVV7lDbBH9xVFp9ZH9KKUYkTOCETkjGt1OCHHsFFe5Q40eivr7H01KdDAzAqYlys+XEA3xeitQynp4Lqla9/7NvGfK69UsWFDFHb/dywsvmnM15mQHf6+UsuDxlB+VMgvRFhmNryLaG6UUd4y+g3+e/E8GdBiAw+LAqPdRMDBwWBwM6DCAB09+kDtG34FSwXNRpYbJphV4cRSJlDD7EUK0XqmO5vn+GwnJZEyYwqEvZ1C1aSE+txPt9VC9dRnF37wYtL7DZtAlPbFJZReivbNYEtE6uOf4yHLFxInJGAas+dHfsBr8M4/WPiyW4KlVhBAmuYKNY6fmnsqpuaeyqXgTb6x/g80lm6l0V5JsS6ZfRj8uH3Q5/Tv0b3AffbKTSbAa1NS7AT3w4kgZFhy9R6AMK84dq4Lmi7FZFH2zU1rkHIUQLSc3M4kku+XwhLy1ovn+10obcwFGcgdKF86m6KOHUfZEEjrl1UmJDpBgNbj2pN5YjBBXfUKIw6zWNJQyCHFbI2DeU7VgQRXl5T5ye9pZtCh4Ml9zRQ92u8zlKEQ4EkwJ+nfoz1/Gx5ap5xcjuvHQ5xtDLov04shQisvG9Ijp+EKIY+fMoV24+39rQi6L9PsfKGXIZFKGTG70uFeMzW10HSHinVIWunS5gD173gQ8h1+/+64CDMNM8tepk5U//CGbXr3CJZcwyMr+CRaL9AQLEY4EU6JJslISOKV/Nl+uLwjZ+hXJxdGw7un07CgzrAvR1qQkWDlveDfeXpZPiDw0EQdHkUq0Wbj0hB5kpUhadCEi0aP7tezb9y4+nxlMvT4rfEPE7/+QE/SaYSTQM3d6i5VPiPZA7pkSTfbLU/qSYI3to5Ros3DLqf2auURCiKPl+om9scX4/Y9Gos1gXN+O3H3O4BY/lhDtRXJyH9JSh6JU9G3nSllITMwlLW1YC5RMiPZDginRZCNyO/Cb0/uTaIssDXKtRJuFa8b3ZFL/7BYqmRCipeXlpHLvz4fgsEX3c+KwKgZ1TiXBamCzhL//yWKYCSfOG96N/1w1Wu6VEiJKQ4c+hc3Wgegu+RQWSyrDj3++pYolRLshw/xEs7jh5D54fJon52zG6fY1un6izcKV43ry+zNkzigh2rpLT8jF7fFx3yfrI/7+nze8K/efP5Q9xdW8vGA7s5fm+zOGmuMFFQqPT3P+iG5Mm9CLvJzUFj4LIdonuz2L0aPeYcWKy6hxFaG1q8H1lbJhs2UwcsQsmVtKiAhIMCWahVKKmyfnMbxHBo9+uYkf95Ti07rOhL4WA2wWg345qdx2Wj9OH9zpGJZYCNGcpo7rxcAuaTz65SaW7SxGa42rzvdfYbcocjsm8+tT8zh7aBeUUuR2TOKec4fw+zMGsmJnMcVVbnxak5FkY2RuB5Jl2gQhmiwxsRtjxnzMzp0z2L1nFuALmtDXYjHvXe7a5RJ69foldnvHY1BSIdoe+ZUSzeqkvCxOystiR1ElMxftZP2+MsqcblISbPTvlMKVJ/akXydpYRaiPRrdK5PXrz+RPSXVvL5oJz/kl1Dm9JBkt9AnO4WpJ+YypGt6yG0dNgvj8yT9shAtxWZLIy/v9/TpczsHDnzB/oIPcbmKAI3NlkmnnLPJyTkLi0USvAgRDQmmRIvolZUsN4oLEae6ZSTKEF4hWinDsNOp0zl06nTOsS6KEO2CJKAQQgghhBBCiBhIMCWEEEIIIYQQMZBgSgghhBBCCCFiIMGUEEIIIYQQQsRAgikhhBBCCCGEiIEEU0IIIYQQQggRAwmmhBBCCCGEECIGEkwJIYQQQgghRAwkmBJCCCGEEEKIGEgwJYQQQgghhBAxUFrr8AuVOgDsPHrFEUIcBT211tnHuhBNJfWTEO1Sm6+fpG4Sol0KWzc1GEwJIYQQQgghhAhNhvkJIYQQQgghRAwkmBJCCCGEEEKIGEgwJYQQQgghhBAxkGBKCCGEEEIIIWIgwZQQQgghhBBCxECCKSGEEEIIIYSIgQRTQgghhBBCCBEDCaaEEEIIIYQQIgYSTAkhhBBCCCFEDCSYEkIIIYQQQogYSDAlhBBCCCGEEDGQYEoIIYQQQgghYiDBlBBCCCGEEELEQIKpNk4pdYpSavfR3jbG481VSk0/WscTQrQebamuioVSKlcpVaGUshzrsgghwmtLdZFcN7UNEkzV4/8xrH34lFLVAX9PacHjXqOUmtdS+28qpVQvpZRWSlnrvf6yUuq+Y1UuIeKV1FXh+euqvBY+xg6l1Om1f2utd2mtU7TW3pY8rhCtjdRFocl1U/ywNr5KfNFap9Q+V0rtAKZrrb+qv55Syqq19hzNsgkhRC2pq4QQrYHURSLeSc9UhGq7dpVSf1BK7QdeCtUqEtgiqpRKUEo9rJTapZQqUErNUEolxnDsa5VS65VS5UqpbUqpG0Os8yelVJG/tXRKwOvNUoYIy3mNUmqe/3jFSqntSqkzw6zbRSm1Win1O//fc5VSf1dKzfef5xdKqayA9X+ulFqrlCrxrzvI//q1SqkPA9bbrJR6K+DvfKXUcP9zrZS6yb9OiVLq30op1RLvhRDHitRVQcf7q1LqTaXUq/5yrVVKjQ5Y/n9Kqa3+ZeuUUufX2/76gHNap5QaqZSaCeQCH/pb338f2AqtlLpUKbWs3n5+o5T6oCXPVYjWROqiiMop103tgART0ekMZAI9gRsiWP9BoD8wHMgDugH3xHDcQuAcIA24FnhUKTWyXrmy/Pu/GnhOKTUg2jIopZ5WSj0dQ/kCjQU2+svzEPBC/S+eUqo38C3wlNb6XwGLrsA8vxzADtzpX78/8AZwO5ANfIJ5EWP372eiUspQSnX1bzfOv10fIAVYHXCMc4ATgGHAJcDPmni+QrRGUlfV9XPgv0AG8AHwVMCyrcBEIB24F3hNKdXFf5yLgb8CV/nP6efAQa31lcAu4Fz/0L6H6h3vQ2CAUqpfwGtXALOiPVch2jipixon101tndZaHmEewA7gdP/zUwAX4AhYfg0wr942GvPLp4BKoG/AsnHA9jDHCtpXA+V6H7gtoFweIDlg+ZvA3Y2Vwb/t7giP2ct/btZ6r78M3BdwDlsCliX5t+ns/3su8Ij/fb283n7mAncF/P0r4DP/87uBNwOWGcAe4BT/3/nASOAy4DlgCTAQs4L5oN7/zYR679P/HevPmTzk0dSH1FVBx9VAnv/5X4GvApYNBqob2HYVcJ7/+ee15W/oPff/XaeOBF4D7vE/7weU++vEqN5vecijLT2kLqpzzDp1QsDrLyPXTe3qIfdMReeA1toZ4brZmF+K5QENDAqIOtOTv8v3L5gtJYZ/vz8GrFKsta4M+Hsn0LU5y4BZ8QDYAp7X/u0O+Ht/7ROtdZX/uCkBy6cAW4C3Qxxjf8DzqoDtumKeU+1+fUqpfMzWIjBbWU7BrIy/BUqASZgV4LcRHkOI9iSe66pQ6n/vHcp//4ZS6irgt5gXPmDWCbVDZXpg9lzFYhbw/4C/YbYev++vE3No2XMVojWJ57pIrpvihAzzi46u93cl5pcOAKVU54BlRUA1MERrneF/pOuAGzUjoZRKAN4BHgY6aa0zMLtrA7uAOyilkgP+zgX2NlcZ/PZhfvl71Xu9NwFf2Aj81V+uWSryFMJ7MYcIAODv/u6B2coCRyqFif7n32LozJh3AAAgAElEQVRWCpMIrhSEiAfxXFdFU+aewH+AW4CO/jKvCShzPtA3zOb13+P6vgSy/fceXM6RIX7H5FyFOEbiuS6S66Y4IcFU0/wADFFKDVdKOTA/8IDZCoD5I/2ovyUSpVQ3pVRDY02VUsoR+MAcy5oAHAA8/taWn4bY9l6llF0pNRFzfOtbMZYhJG2m+30HuF8p1VEpZVNKXY45ZObTKHblBi4GkoFXlVKRfAbfBM5WSp2mlLIBdwA1wAL/8m+ByUCi1no38D1wBtARWBlF2YRor+KmropSMubF3gH/Ma8FjgtY/jxwp1JqlDLl+QMwgAKgT7gda63dwFvAvzDvGfnS//qxOlchWoO4qYvkuil+SDDVBFrrTZhDOL4CNgP15zv4A2bX7CKlVJl/vQGENx6zRaT+41bML0Yx5nCRD+ptt9+/bC/wOnCT1npDtGVQZsaaGQ2U71fAIcwbEwsxW3PP1loXNLBNEK21C7gA6AS82FjFoLXeCEwFnsRsnTkX88Zvl3/5JqACszJAa10GbAPma5nzRYh4rKsiorVehzkUbyFmcDQUmB+w/C3gfsxepXLM+y4y/Yv/AdylzAxXd4Y5xCzgdMyLtMBhPtG+30K0C3FYF8l1UxxQWjc2UkEIIYQQQgghRH3SMyWEEEIIIYQQMZBgSgghhBBCCCFiIMGUEEIIIYQQQsRAgikhhBBCCCGEiIEEU02klHpZKXWf//lEpdTGo3RcrZTKa+Z9Hj6Xo7nt0aKU+pNS6vljXQ4hjhapn5q+7dEi9ZOIJ1I3NX3bo0XqpsbFRTCllNqhlKpWSlUopQr8H95mnyBRa/291rrR9LZKqWuUUvXTgTYbpdRcpdT0ltp/U7X0+fuPcYpSanfga1rrB7TWrfZ9EfFJ6qfWReonIUxSN7UuUje1XnERTPmd65/BeiQwGrir/gpKKetRL5UQQkj9JIRonaRuEqIR8RRMAaC13oM58/RxcLjL92al1GbMCeRQSp2jlFrln4xxgVJqWO32SqkRSqkVSqlypdRswBGwrE5Er5TqoZR6Vyl1QCl1UCn1lFJqEDADGOdv7Snxr5uglHpYKbXL3wI0QymVGLCv3yml9iml9iqlpsV6/kqpt5RS+5VSpUqp75RSQ+qtkqWU+tJ/ft8qpXoGbDvQv+yQUmqjUuqSWMtRr0w7lFJ3KqVW+8s1W5mzmKOU6qCU+sj/Hhb7n3cP2DZTKfWS/30pVkq9r5RKxvw/7up/jyuUUl2VUn9VSr3m3+5TpdQt9crxg1LqgpY8VyEaIvWT1E/+7aR+Eq2K1E1SN/m3k7ophLgLppRSPYCzgJUBL/8CGAsMVkqNAF4EbgQ6As8CH/i/sHbgfWAmkAm8BVwY5jgW4CNgJ9AL6Ab8V2u9HrgJWKi1TtFaZ/g3eRDoDwwH8vzr3+Pf1xnAncBPgH7A6U14Cz717yMHWIE583egKcDfgSxgVe1y/5fsS2CWf9vLgKeVUoPDnH+JUmpCFOW6BDgD6A0MA67xv24ALwE9gVzMmc2fCthuJpAEDPGX61GtdSVwJrDX/x6naK331jveG8DlAeUd7D/Gx9GeqxDNReonqZ/8pH4SrYrUTVI3+UndFIrWut0/gB1ABVCC+QV9Gkj0L9PAqQHrPgP8vd72G4FJwMnAXkAFLFsA3Od/fgqw2/98HHAAsIYozzXAvIC/FVAJ9A14bRyw3f/8ReDBgGX9/eXOC3O+c4HpEbwvGf79pPv/fhmz0qpdngJ4gR7ApcD39bZ/FvhLwLb3Rfj/Uf/8dwBTA/5+CJgRZtvhQLH/eRfAB3QIsd7h/4uA1/4KvOZ/nup/z3v6/74feNH/vMFzlYc8mvMh9VPY90XqJ6mf5HEMH1I3hX1fpG6SuqnOI57Guf5Ca/1VmGX5Ac97AlcrpX4d8Jod6Ir55dmj/Z8Qv51h9tkD2Km19kRQtmzMFoLlSqna1xRg8T/vCiyP4JgN8rf43A9c7D+mz78oCyj1Pz/8XmitK5RSh/zH7wmMre1a97Nitm40h/0Bz6v8x0QplQQ8itny0sG/PNV/Lj2AQ1rr4mgPprUuV0p9jNly8k/Mlpbr/Ytb+lyFqE/qJ6mfDpP6SbQiUjdJ3XSY1E2hxVMw1ZDAL3g+cL/W+v76KymlJgHdlFIqoFLIBbaG2Gc+kKuUsoaoFHS9v4swu2CHaHNccn37MD/8tXLDn0qDrgDOw+zq3gGkA8WYlU+tw8dRZtaeTMwWpXzgW631T2I8dqzuAAYAY7XW+5VSwzGHGSh/mTKVUhla65J629V/j0N5A/iLUuo7zPHb3/hfP1bnKkQoUj8dIfWT1E+i9ZC66Qipm+K4boq7e6Yi8B/gJqXUWGVKVkqdrZRKBRYCHuBWpZTNf8PdmDD7WYL5RX7Qvw+HUuok/7ICoLt/HDFaa5//uI8qpXIAlFLdlFI/86//JnCNUmqwv7XhLxGch9V/zNqHDbN7tgY4iNma80CI7c5SSk3wl+3vwCKtdT7mGOb+Sqkr/eduU0qdoMybQltSKmZlWaKUyiTg3LXW+zDHMT+tzJstbUqpk/2LC4COSqn0Bvb9CWZLyt+A2f7/Bzh25ypEY6R+kvpJ6ifRGkndJHVT3NZNEkzVo7Vehtll+RRmy8MW/Df0aa1dwAX+vw9hjg99N8x+vMC5mDdE7gJ2+9cHmAOsBfYrpYr8r/3Bf6xFSqky4CvMVgW01p8Cj/m32+L/tzHPYH6Rah8vAa9idnPvAdYBi0JsNwvzS3cIGAVM9ZehHPgpZtfuXsyu5X8CCaEOrswsMBMjKGdjHgMSMVugFgGf1Vt+JeAGNgCFwO3+8m7AbD3ZpswbOrvW37HWugbz/+90zPOufT2qcxXiaJH6SeonqZ9EayR1k9RN8Vw3qbpDWIUQQgghhBBCREJ6poQQQgghhBAiBhJMCSGEEEIIIUQMJJgSQgghhBBCiBhIMCWEEEIIIYQQMZBgSgghhBBCCCFi0OCkvVlZWbpXr15HqShCiKNh+fLlRVrr7GNdjqaS+kmI9qc91E9SNwnR/jRUNzUYTPXq1Ytly5a1TKmEEMeEUmrnsS5Dc5D6SYj2pz3UT1I3CdH+NFQ3yTA/IYQQQgghhIiBBFNCCCGEEEIIEQMJpoQQQgghhBAiBhJMCSGEEEIIIUQMJJgSQgghhBBCiBhIMCWEEEIIIYQQMZBgSgghhBBCCCFiIMGUEEIIIYQQQsSgwUl7hRDRyS/L5/s931NSUwJARkIGE7tPpEdqj2NcMiFEXKo6BBs+gooCcNdAYgbkngjdRoFSx7p0QrSIbQcq+G7TAYqrXBjKoEOyjVMH5tC9Q9KxLppohySYEqKJfNrH97u/58U1L7L24FoAXF4XAHaLnUeWP8KQjkOYdtw0JnafiKGkQ1gI0cL2rIAFT8LGT0AZ4HGC9oHFDoYVUrvASbfB0IvBLheYou3z+jRfrS9gxtytrNtXBoDL4wMgwWZw/8frGZnbgRsn9eHkftkYhjQmiOYhwZQQTVDlruK2b27jhwM/UO2pDlpe460BYEXhCjZ8t4Hjs4/n8VMfJ9GaeLSLKoSIB1rDV3+FJc+Cp8YMoAJ5Xebj0Fb47I/w3cNw7SeQIb3nou2qqPFw7UtLWLu3jCqXN2i5021+DxZuO8jCbQcxFCTZLGSlJjD1xJ5cPKoH6Um2o11s0U5IE7kQMXJ5XUz7fBorClaEDKTqq/JUsbxgOdd+du3hnishhGhWH98BS54Dd3VwIFWfuxLK9sBzk6B0z9EpnxDNrNrl5cJnFvBDfknIQCoUn4YKl5cdB6v4f19sZMwDX3H7f1dRWu1u4dKK9kiCKSFidM/8e9hSsgWXL/LAyOVzsaVkC/fMv6cFSyaEiEvLX4Yf3gB3VeTbaC9Ul8Ar54AvsgtRIVqT22avZEdRJS6vjmn7arePGo+Pj3/cy9lPfM++0sYbR4UIJMP8hIhBQWUBX+78MmQgVfx9MUWfF+EqdGFxWEgblUanizphSbYA5tC/L3d+yW+rfktOUs7RLroQoj3y+WDO/UGBVK/Hyqlyw/bbUki2m/eIPL/CxWur3cy9JtlcSXuhohA2fwEDzjzaJRftjNPj5MeiHymtKUWhSE9IZ1j2MOwWe7Mfa9fBKr7deIAaT3AvbOW6uZQtfR/3wd0Y9kRsOX1IH38Jju5DQu7L7dXsK3VyyYyFfHzbRNIcMuxPREaCKSFiMHvj7JCvF31axIFPD9B9endSBqfgLnazd+Zedjy8g95/7o1hPdIZ/ObGN7llxC1Hq8hCiPZs65ywPVJeDY8vdvGniQnht3dVwLzHJJgSMdtVtovX17/Oe1veO5xoSaHQmD1GF/W7iCsGXUHXlK7NdsyXF2zHp4N7pMqWvEfp4rfp+NObcfQeibJYqd6+nOrNi/GWHQgbZHl9mv1lTh74eD0PXjis2cop2jcZ5idElLw+L//d+N+gXilvtZfC9wvpOrUrqcNSUVaFPdtOj1/1wFXkonRB6eF1XT4Xb2x4A68MqxFCNIcFT5gBUQi/G2/n4QU1lDgbGQa1bxUc2t4ChRPtmU/7+Mfif3DBBxfw5sY3qfZUU+mupNJdSYW74vDzWRtmce775/L48sfRIQKgaLk8PmYvzcddb3ifr6aSknmvk/mTX5I0YDyG3YGyWEnKG4sluQOHvv4P6SdeQvdbXqPbL18ideRZVG9efHh7t1fz/qo9VNZ4mlxGER8kmBIiSsU1xSETSFRtrsLn9pE2Kq3O6xaHhdRhqVSsrXuh4/Q6KXWVIoQQTVawNuyi0V0tnNLLysMLahreh8UOheubuWCiPfNpH3d+eyfvbnmXGm8NHh0+AHH73Li8Ll5f/zr3zL+nyQFVQZmTUHuo2bMB7XGR1H9c3bI2EGR1mDytzrqGUry/UpKyiMhIMCVElCpcFViUJeh1b4UXa4oVZQmeu8KabsVTUfdHxqqsVIRpSRZCiKi4Kxtc/LfJCTy5xMWBygYy/Gkv1JQ1c8FEe/bEiieYt2ceTo8z4m2qvdV8tvMznv/x+SYdu9zpwQgx8bS3ugwjKQ1l1P2dDhdkhVLl8vLa4p1NKp+IH3LPlBBRclgd+EKkHLakWPBUeNBeHRRQeUo9WFPqft282ovD6mjRsgoh4oQlwUyHHsZxORbO6W/lwXkuBmWHaUdVBthkDjwRmRJnCTPXzYwpEZPT4+S51c8xZdAUkmyxTRqdaLeE7N2yJKbhqypD+7x1AqpwQVY4BytkChMRGemZEiJKHRwdQgZTSXlJKKuibHndll2v00v56nKSByfXeV1rTUZCRouWVQgRJ9K6NLrKvac4+M8KF3vKwgyv0j5I797MBRPt1btb3kWF6Bkq+rSI/W/tp/MlnRn89GD63N0H10EXOx7egS8g655Sio+3fRzz8bNTE4LulwJI6DYQZbVRtWlhndcDg6xIuL2NzNMmhJ8EU0JEKcGSwGk9TzucraiWJclCzi9y2PvaXspXl6M9GtcBF/lP52PLtJEx/kjgZCiD03ue3iKpYoUQcWjMTWBLbnCVvEyDS4fYeGJJmBb3xEzoOrIFCifaG5/28eraV6nx1r0PL5pETNWeal5c82LM906lJFiZNCCb+vGckZBMxoQpHPpyBlWbFuJzO9FeDz5XNSgVFGQ1tH8hIiGfFCFicPWQq/lm1zc4vXXHiWeflY0l2cL+2ftxFbowEg3SRqbR48YeGLYjwZfdsHP1kKuPdrGFEO3VsIvh8z82uto9kxKYudodvMCWBCfdRtCVqRD1ON1e/rt8LQery6DexyWSREwdTu5w+PX9lfup8lSR3EhDQDg3nNyH+VuKqHLV7W1KG3MBRnIHShfOpuijh1H2RBI65ZE6/AwOfTkDZVhw9B6BMqw4d6zCuWt1nSQUFgXj+naMqUwi/kgwJUQMhnQcQs+0nmwu2Rw05C9zUiaZkzLDbmsog55pPRnccXBLF1MIES/syTBiKqx4FQKSAey4PbXOaj3SDZx3pdXf2rxfatilLV1K0Yb5fJpHv9rEC/O2o2wFqG4G9XMxNZaIqXpn3fv6rBYrZTVlMQdTo3t2oHOag+0HK6nfwZUyZDIpQyYHbWPvOjAoyEobV/ezb7MaXDehT0xlEvFHgikhYvTEqU9w8YcXU+4qPzwpYWMUihRbCk+e+mQLl04IEXd+8nfYuRAObIQQSQHCsiXCZbPAESLIEgJzTqcbZi5j8baDVLt9KA3JIX73ok3EpLXGZrHFXC6lFE9NGcnZj38f8TbhgqxAfbNTGNA5tcF1hKgl90wJEaOuKV159cxXybCnYYlgzLdVWeng6MDMM2fSJaXxm8WFECIqNgdc8yF0GgwRZArVYAZSF70EfSa1ePFE26S15jezV7HIH0gBaG8KqOBEDtEmYvL6vKTZmxbEL9t+CLs19OVs5bq57HvldnY9chG7n7qSgjf/gnN3+DnZABw2gz+fPahJZRLxRXqmhGiCvul9eNeVxtNpnfmwaidKKao9dYcxJFoT0Vpzbt9z+dXwX5GVmHWMSiuEaPcSO8B1X8D8x2HRM/gqyzCMehOpWh1on4/KwgSS7nwPo9cJx6asok34en0h32wsxOkOGNLuc+CtzqVi9UozBfo+F4bDwJHrIOOkDPa+thfDYZAyOAV3sZu9M/cGJWJSKE7qdlKTEjFprXn2u23UeIIz75UteY/SxW/T8ac34+g9EmWxUr19OdWbF+PoPiTk/hw2g3t/PoTxfeV3WkROgikhmmLFq2Q5K7ln6gfc6XPx8baP+WT7J5TUlACQkZDBWb3P4uw+Z8c8l4YQQkTFmgCTfo/3+Onsu2I83S4bjKoqBI8LEtOhz2TU6OsouedBnJ+vJOtGCaZEeM98uzUowQNA4YdQOn8/Xa/uQurQVJRFUf5jOVUbq+h0YadGEzElWhO5Zsg1TSrb0h3FFFcFD2n11VRSMu91Op51O0kDxh9+PSlvLEl5Y4PWtyjzPqlHLjmes4Z2bVKZRPyRYEqIWJXtha/vhas+AIuVJIuViwdczMUDLj7WJRNCCCrmLUT3OAV11YyQy7NvvZWdV0yhw6WXYMmQOe9EsB1FlazZUxr0uq+mkuKvv6Tbdb1JH51w+PW0EWmkjTCH7TWUiAnMxsZRnUY1qXw/7inFE2I+qJo9G9AeF0n9x0W0n5w0B5/cOpEOyTJdiYie3DMlRCy0ho9+CydMh87HHevSCCFEkIo5X5Ny2qlhlyf07k3q6adz8IUXjmKpRFvyyZp9+ELcE1wbrNiyb0H7ok8g4bA6eHTyoyEn/Y1GudONK8TEvd7qMoykNJRhCbFViPLYLBJIiZhJMCVELNa8A8U7YOKdx7okQggRRLtcVMybT+rkhrOWZd38K0refAt3QeFRKploS/aWVONuIFjxufpSveeKqAKqRGsiT576ZLNMD5JgtWAJcSVrSUzDV1WG9gUPTwy9H7kcFrGTT48Q0ao4AJ/9H5z3b7BKS5YQovWpXLKUhN69sWZnN7ierXNn0i+8kKJnnj5KJRNtSZ2kEwECgxVvxSCqdt6At7or2mdD6xC9TdogwZLA0KyhzDxzJid2ObFZytcl3YHDGtz7lNBtIMpqo2rTwoj2071DYrOUR8QnCaaEiNanvzcnt+zetLHeQgjRUswhfqdFtG7H66dT/tnnuHbubOFSibbC69M88sUm/rdqT8jl9YMVn7MHVTtupWrHzbhLR+HzJKN9VrTPis+TTCc1ibfOfYtZZ89iQOaAZivn6YM74Q0xDNFISCZjwhQOfTmDqk0L8bmdaK+H6q3LKP7mxTrrJtstTD2xZ7OVScQfSUAhRDQ2fAz7Vpm9UkII0QpprSmf8w25Lzwf0frWDh3IvPoqDjzxJN3+38MtXDrR2jndXqa/sozlOw+FHOIHdYMVZVhw9B6BMqxUrtuNc1cZHSbffXhdu9Vg+s+H0Ds9t9nLmpJg5bzh3Xh7eT7181CkjbkAI7kDpQtnU/TRwyh7Igmd8kgbd2md9ZLsVk7u13APrhANkWBKiEhVl8DHd8KFz4Nd0pwLIVon59p1GAkJ2Pv0iXibzKuuYssZZ+Bcvx7HIJmwNF75fJqbZ61g2Y5DOEPM3RSoNlgpnvM8nrJCMzGTxYo9py/O3WsPz+Xk8vg4bVCnFivz9Am9+d/KPXh9weVNGTKZlCHh7xt02AyuP7k3htG0RBgivkkwJUSkvvgzDDgTep10rEsihBBhVcyZQ8ppp0WVKc1ITibrhhspfOwxcp99tgVLJ1qz//2whwVbD4YMpCrXzaVs6fu4D+7GsCdiy+mDLaMzPreT7PP+GHZiXIsBn63Zx5XjerVImft1SuV3Zwzg4c83Ue2OLOEEmEknRvTowLSTerdIuUT8kHumhPCr8XgpLHeyv9RJlctTd+HWObDtWzj9r8eiaEIIEbHyOXNIbSAlejgZl16Ca8tWqpYta4FSibbg6blbqQ4xQW/Zkvc49PV/SD/xErrf8hrdfvkSKUNPo/yHz8j8yS9JGjAew+5AWawk5Y2lw+Rph7f1+uDZb7ehQ9zb1Fyum9CHmyf3xWGL7LLWYTMYmZvBC9eMxhoqHaAQUZCeKRHXfD7NvC1FPPvdVhZtO4TNUKDA7dX0yUrmpkl9OXtAGo4Pb4NzHgNH2rEushBChOXavQdPYSGJw4dHva1ht5P161sofORRer7+GgqgaBNUHgCvGxIzIHsg2CTzWXu0Zk8puw9VB73uq6mkZN7rdDzrdpIGjD/8upGQDFpHNDHuwUoXWwor6NcptVnLDFBa5eat5fl8sa6AlAQrHq8br08TKnRLTrBgsxhMn9Cbmyb1lUBKNAsJpkTcWrL9ELfMWkFljYdKf0uc13ek+t1cWME9/1vD3Z4a/pR7BVP7nR7zscqdbuZvOcihShderclItDG2TyY5qY4mn4cQQtSqmDOHlFNOQVkim6y0vvRzz6X4pWepmfUHHIWfQlURGLWXChq0D0ZMhbG/hEwZHtWevLtiNzWe4F6p2gl66wdN0UyMa7Uoiipc9GvGW6f2lVbz4Kcb+GzNfgylQg7xU5gJMLKS7fTKSubq8b04dWCOBFGiWUkwJeLSZ2v2cfvsVWHn0KhlBllW7t8znPxP1vPHs6K7MXvD/jJe+H47H67ei9VQZmuZBouhcPs0E/KyuPHkPozpndnkmeCFEKJ8zhwyr5wa8/Zq/Qf0GrUcvWERWMLUj0tfhOWvwHEXwrlPgEUuJdqDPSXV+EJ054QLmgLnmookoHLXT7fXBOv2lnHFfxZR7nQTJuEgANp/3AqXlz+eNYjjuqU3WxmEqCU1oIg7y3YciiiQClTt9vHqwp10SnMwbULjrbE+n+a+j9cxa/Eu3D5dp8cr0DcbClm07SBjemcyY+ooHLbYWpOFEMJbWopzzRqSx49vfOVQlj4Pn9+F8tWgGqqKfG7zsfZdKN0NU98Biy22Y4pWoyaCCXoDg6bAuaaSB05ocN9aQ3pi83xGdh2s4rLnFlLm9DS+MuDTUFrt5vLnFvHBryfQOyu5WcrRVmmtWbGrhE0F5VQ4PSTaLXTvkMiEvCzpsYuRBFMirmitueOtH8IGUqGyFaWPvwRH9yFUu73887MNXDCyGxlJ9gaP8ds3V/H52oJGU8tqoMrlZdHWg1zy7ELevHGcBFRCiJhUfPcdSWPGYCTGcE/Tpi/g87vAE3zPTFjuati9BD74NZw/I/pjilalQ3Lo37VwQVO4uaacO1bh3LW6ThIKn9b0b6b7pa5/dRkVNcGBVEO/3wAVLg/XvbyUr++YFJcjQSpqPLy/cjczvt3GoUoXWoPH58NiKCyGwmYYXDO+F1ecmCu3IERJgikRV1bsKuFAeU3IZWVL3qN08dt0/OnNYVO8KgWzl+Zz46S+YY/x9NytfL62IKoUrU6Pj037y7njzR/495SR0Z2UEEIA5V/HlsUPreHj3wYFUr0eK6fKDdtvSyHZbl58Pr/CxWur3cy9xt+6766Gte/BxDsgq19TT0EcQxPysvhi7f7D9xDXaiho8laW0OHU6xqcGNdqKC4a1Z1Ee9MbCn/IL2HXoaqg4YiR/H5rDftKnazYVcKonh2aXJa2ZMP+Mi5/bhE1bh9V9a5NjkzM7GXGt1t59rttPHn5CE4f3HJzg7U3EkyJuPKf77aFDHLCZStKyhtLUt7Yw3873T6en7ed6yf2CTnJn9Pt5d/fbAkbSDXUcub0+PhqfQE7D1bSs2N8D0MQQkTH53JROX8+ne/6c/Qb71wA1YdCLvJqeHyxiz9NTGjg4F5Y9Ayc80j0xxatxtnDunD3/9aEXFY7QW+ooMnRfVCDE+N6fBqXx4fWusk9Qs9/vy0oSUakv98ATo+X/3y3lVFXjm5SOdqStXtLuXjGQqpCpLyvr3Y0zS1vrODhi4/nnGFdW7p47YIEUyKuzN1YSKipLsJlKwqlssbDlgMVIYcsfPDD3rDbRdJy5tOal+bv4K8/HxL5SQkh4l7V4sUk5OVhzcqKfuMFT4KrKuSi342389D8Gn51gp0MR5gLYZ8bfngDfnof2JOiP75oFRw2C5ee0IPXFu0M6K04ImXI5AaDpob8b9VelFI8cP5xMQdUNR4vn68tCOqViub3W2v4ekMhTrc3LobUF1e6mPL84ogCqUBOt4/fvfUDvTomS9KOCEgwJeKGx+ujJkw2oWhSvFoMRXGlK+SyGXO3hqy0Im05c3s1by7L5//OHBgXFb0QIko+L2z+Ala8CmV7weuCxAx8OyHtlJ/Fts8d30PIWXlgdFcLp/Sy8vCCGu47tYH7KAwL7F8NuSfGVgbRKtx4cl/eWb4btzey5A6RqnZ7eX/lHgZ2TuXq8b1i2kdJlRvDAOr9xEbz+w1gNQwOVrroltH+50ubtWQXzjCBVGP3mDndPh79ahMvXH3C0SxymyTBlIgbDc29Hk2K1wqnh9+/vZpuHRLJSLKRkWQnI9FGeqKVbUWVITzkSkUAACAASURBVLeJpuVMKcg/VNUikxuKJirbCwVrwVlqTlya1g26HG/+pwnRktxOWPgULPo3eFzgqqizOMWrUAXLYPYymHwX5AyMYt+he6Vq/W1yAie9WMltY8Mn3gEF1SWRH1O0Sp3THcy8biyX/2dR1L0ZjV2cV7u9PPrlJqaMzY0pa5zT7cUIUddGm6LdUFAd5bm1RV6f5sV520MmworoHjNg3uYiCsudkpCiERJMibhhsxjYLAauEBVLNCleHTYLv/vZADKS7JRUuyipclNa7WZviROlCDmMMJqWM0MpypzuiM9LtDCfD7Z/C/MfN+8tsSaYE5cqZf6b2BFOug2OvxQSJAAWLaDqEMw8Hw5sAI8z5CqGRYPPBRs+hi1fw6WvQd5pke1fGeZnOYzjciyc09/Kg/NcDMpu4CJY5ptqF47vkcHbN41n6guLKa12EWpAR/3ASSUk460sJuvM28JenIM559PXGwr52ZDOUZcr1WHDE2KakWh+v8G8hystsf1/Vj9fuz/ktUQ095gBzFq8i9tP79+iZW3r2v+nSYgAE/pm8c3GwqBeqmhSvCbaLZw5tAuWegkonG4vMxftDNkDFk3LmdaQYJUhfq1CxQGYeR4U7wCXv9fRWy8bpKsSvrwbvrrHvIDtG0M2NSHCcVfDK+dC0SZzSF9jtM/saZo9Baa+Bz0b7w0nIS1sAopa957iYOSzFdwxLkwiCu2FpBju1xKt0uCuaXz06wmc/NA31B/XUb9XQ3tq2P30NST2HtHoxXmly8uMuVtjCqYyEm2kJFg55Kn7PYjm9xvMBtGOyQ0kVGnjdhdX8fhXm3lv5Z6QwWc0I2VqPD6W7mi4bhAgs3OJuHLDpD5h07OmjbngcIrX3U9OYfcz11C+4iMS+x2pcBKsBtNO6hUUSIFZQYcLggJbzhrj9vrISW2/FX2bUV4Az06AA5uOBFLhuKvMdd64AtZ/dHTKJ+LD53+Cg1vqBFK9Hisn51/lVLqOXCg9v8LFKS8HfE7d1TDrksY/uwDHXQhGwxOq5mUaXDrExhNLwgR0tmToPKzxY4k2492Vu7HW+62r7dXI/MkvSRowHsPuwLV/C/i8ZJ8fWSbJjQXlMZXHMBTTTupFgjX40jWS328wf8OvHt8z5G94e/BDfglnPv49767YHTKQgujvMSutlpEyjZGeKRFXxvbOJCPJFnYseCTZii4bkxt22fkjuvHmsvygSiyalrNBXdLISZPxyceUx2X2BlQeNDOVRbxdNbw7Ha79DLoOb7nyifhQUw6r3gg5tC/ilOVr3oGRVzV8nBN/CStnNvpZv2dSAjNXh1jHmgjjbsbMDiDaA601L83bEXS/TahejWgvzp1RzMFY32VjcnlyzpaQyyL5/dbA1LE9Yz5+a7Zxf3lE97pFe4+ZQ0bKNEpqPhFXlFL866Ljcdii/+gn2gxunpxHVkr4i5dpE3oHteTViqTlLDnBwk0NTAgsjpL1H0DZ7joXlxH1BoDZI/DVX45WSUV7tvrNsMlNfjfezsMLaihxNpBax10J8x4NfSNnoI59zUQq1D3WjttTOb3PkTbXHukGzrvSjkzYe5iv8YBNtClOt4+SED0SoQKnwIvzSMSSfKJWVkoCF47sTmIMv+EOm8HPj+/aLhsrvT7NVS8uDplYo3LdXPa9cju7HrmI3U9dSdmS98CwRDRSRinokSnTHTRGeqZE3DkpL4sLRnRj1pL8iLdJtFk4f2Q3fn1qXoPr5eWkMKhrGqvzS/GGuIBprOXMbjE4fVBOxOUSLWT+YyGHR0XUGwCwcyGU5ENGjxYqoIgLi2eEzbQXccry8v1QsAY6D234WL94Bp6bZPaGRcOWCGf8E5Iyo9tOtGrlTjc2Q+GtN8oiVK9GtAkgOiY3lBWycfeeN4RtBypYtbsEpzt84pRADqvBkK7pPHB+I9+DNmruxkIqnJ6ge7bDZe3TPk9EI2UcNgtTxoYfjSNM0jMl4s53mw7wxboC7j//ONITbSQnhO/CTrRbSLAa/Pq0PO7/RWSTDT51xUhSHNG3UzhsBi9cc0KTWu1EMyhYB0Whh5FE1BsAgIYl/2n+son4UhZ+EnAwU5Y/ucTFgcoGLigNqxnYN6ZjX7jyfX9GygjvJ7ElwsTfwairI1tftBlJCdZGM+fVChzGXrVpIT63E+31UL11GcXfvFhne4fVYOqJTRtmZ7MYvHLdGCb1zybJbmn005pkt3BSvyxenz4We4j7rdqDGd9upbJer1So+9uUxUpS3lg6X/EgSQPGU/Thw+Q/chG7/t/5FH38CJaMTnX2YVEwqmeHo3kqbZL0TIm4smZPKb+ZvYoZV47ihF6ZXDK6B1+tK+CZb7eyfl8ZdouBUubkuR1T7NwwsQ8XjupOqqPhm7MDdctI5K2bxnHZs4vMtLKNXXdjVvYzpo5iZK5UWsdc/qKw15IR9wZ4XbDtm5Ypn4gfjWTviyhleW12v0h0Hw3Xz4V3r4fCdeDzmI/67ClmIHXmQ3DcBZHtW7QpyXYLdquBp94Ferj7f20dc7F3zqN04WyKPnoYZU8koVMeaeMurbO9Bi5v4L7jSCVYzd/MxdsP8czcrXy76QCJNgsen9mwYLMYeH2aE/t05MaT+zCub8eIGkPbooIyJ6t3lwa93lDWvrIl71G1cT5Z59xRp8eqJn9tnfWqXV7KnB7SEyO/BopHEkyJuLHzYCXTXl7K/ecP5YRe5pAUm8XgzKFdOHNoF0qqXByqdOH1adITbWSnJsRc+fbvlMont03k7x+t46v1BShF0HAEq6Hw+DTj+nTknnMHM6hLWpPPUTQDZ6mZgCKMyCYw9e/nGPH6NIu3H2RviZNqt5c0h5VBXdLoLxNBty22xEYDqkZTlivDTH0eqaw8uOEbOLARFj0Da95G11SA1ihrAnQ/ASbcDn1Pk4QT7ZhSiktG9+D1RTtx1+uhShtzAUZyh6DAKX3cpTi6Dwq7T5tFcfqgTmQ2cZif1poql5cql5fhPTK4aVJfCsqcTJvQm9Iq8z6v9CQbJ/fLpnN6+7s/qr49JdXYrQY19ZKFhEsMEs08UzaLwTvLdzNtQu+WO4F2QIIpEReKKmq4+sUl/Pq0fpxxXOj5LTKS7GQkNa2SD9Q53cG/p4ykuNLF7GX5vLN8NyXVbnz+CQN/Orgz3206wHUTeksg1ZoYNjAs4A19M3XkE5ge/Za8gxU1vLF0Fy/N20GNx4tPg09rLErh1Zo+WSncdEpfzhjSud0Od2lXcgbDroZvEg9MWT40J8T/qdcFOQOjP3b2ADj3MTj3MWrWrmXvXX+mz3vvR78f0WZde1Iv3liyC0IM94skc159bq9mS2EFH/ywN6Y6KP9QFa8s2MEbS3fhdPvMBkmvxmZRDM/N4NSBOQ0miGqvQiWdgPBZ+6KZZ8rp8fH8vG0STDVCginR7lXWeLju5aWcM6wrVzZxrHYsOiTbuWlS35BZ+vp3SuWVhTs4fXCn4A3FsZHSCSz2BnsEGu0NqN3PUfTlugJufWMlWuugdMa11u0r44/vrObBT9cz+4ZxkqWptRt/K+z/EVwVDa4WNmU5mEP3Mpo4rMpqgzCfKdF+9eyYzKieHVi64xDuSMarR2BjQTl/fHc1f35/OVedXsL8A++wt3IvLq8Lu8VO1+SuXD3kas7ofQaJ1kQASqpc3PrGKhZvP4hP68NlqU2O4fVoVuws5qQH53DOsC48cMHQuJr4PiXBWn9eZSB8YpBoU9nvK3WitW63wySbgwRT7dz+yv28t/k9tpVuo9JdSZo9jcEdB3Ne3nmkJ6Qf6+K1OLfXx82zVjCgcyp3/LT/sS5OkLOHdeGBT9az7UAFfbJTjnVxBED/n4a+TyRAo70B9hQYdW2jh6rttVyw9SBl1W7sVoPczCQuO6EHo3p2iPjH68Mf9vC7t1dHlNmq0uWl2u3lnCfn8eEtE8jtKAFVq9X/Z2ZgX8+O2+sO16xNWR7EngLjb2tyMZTVgg7TUyvat6euGMmZj39HUbkrZIba6Hlxp32CPXMhr2xSKONIo1W1p5qtpVt5cMmD/GPJP7h0wKVcnncjFz2zmAMVNQ0GdC6vBjQfr97HpoIK/nvDiSQnxMclbq+Oybi8wXV/uPvbPCUF+CpLIp5nylCKareXJHt8vJ+xkHemnVq6fynPrX6OFQUr0GjcAfPlfL3ra55Y+QSTe0xm+tDpDMgccAxL2nK01vzx3R9RwP3nD22VrSoOm4WLR/fgtUW7uOfcwce6OALAkQ5Dzjfn+NHhLyAb6g3wuT3Q92dh06VuLijn8a838+W64Pvplu44xCc/7qNjsp1fntKXS0/IxRJm7jKAVfklEQdSh8unzdTHlzy3kDl3TJIfydbKsMCkP8DX90aeROLwtlZI7Qx5pzW5GMpiQXsbbmAQ7VNmsp33fnUSlzy7kMLyGlxN6aFUbhJ7vIQlMR9lhJ8guspjftb/u+G/vLZyHmUV1+D1RlZHOT0+NhWUM/2VZbw2fWyDdWd7kZ5k4/RBnfh0zb6gEZmh7m+zZ/cGS+Sp7H1ay8S9jZBf0HZGa82zq5/lhR9fwOl1hlyn9vUvdn7B3Py53Dv+Xs7qc9bRLOZR8fAXG9lcWMEb14/F1orTjU8Zm8u5T83jzp/1l4va1mLcLbD2ffBUH34p0t4AbdipKOtN4c8vIPv220k7+yxUwI3632wo5Fevrzh8T1PQ9hr/zdXV/P2j9XyxroAZU0fhsIX+MXvosw1hA6nKdXMpW/o+7oO7MeyJ2HL6kD7+Ehzdh+DTUFbt5v2Ve7hi7NEf/ioiNPZG2L8a1r4XeUClLGajwNUfmgFZU1mt4JGeqXjVNSORT26byIy5W3lt0U68Ph2UhjucOnVQoiYx1072z7NI7l9/4udgTq8Tbd2BveurVOdfQ+BsPg3VbTUeHz/sLuGTH/dx7vFdYzvpNub6k/vw1fqCoCQUEPr+trIl70Y0zxRAVnICRhwEpU0hV27tzNOrnubltS+HDaQC+bQPp9fJXxb8BUMZnNH7jKNQwqNj5sIdfPLjft6+aVyrD1B6ZCYxumcm76/cyxUyOV7r0Pk4GH8LLPx3dD0ChhXVsTdp//c1llVrKXzoXxx65RVyfv87kseMYf6WIn75+vKIe5Gq3V4Wbj3I9FeW8cq0MUGtrHtKqlm2szjktuEma6zevBhH9yGAGbQ9++02Lh+T2yp7bgWgFPz8KXP+pxWvgruakDdI1LIlQXI2XPMxpDXPhaTZMyXBVDxLc9j4/RkD+c1P+vPF2gLeX7mHoooafFpTWF7D/jIn9UcBBtZBKcdbSOz2IRVrD1G+ovxwMFX8fTFFnxfhKnRhcVhIG5VGp4s6YUk2GwGU4cGStB1r2g94ykYE7behum3Gt1vjIpjaUljBvz7fEFWvYbiMjPVT2TusBlePl8a2xrTuq0wRlfl75ocMpBqrrJxeJ3fPv5uBmQPpld7rGJS8eX22Zh9PfbOFt24cT8c2ktnnqnE9eeCT9Vw+podc1LYWk/9spjdf+VpkAZUlATJ6wNUfgz2Z5DFj6PXmbMo++ZR9//dHXAOGcEPWGTjdwRfCjbWyLt95iKe/2cKvT+tXZ7uZC3cSdAVDdKlvD1TUsGJXiUzM2JoZBpz5TxhyASx4ArZ8afY++QMrnxeUIwmV0gkm/AaGXgz2ZrwXTob5CT+bxeDsYV04e1gXwEzwNPq+r4Kqofp1kKPLoxh2D2kj0kgbYfboF31axIFPD9B9endSBqfgLnazd+Zedjy8g//P3nnHR1Gnf/w9M9uy6YEUAgmk0HuXJsXesTcQKTbuLHfqeXfe6RXP83fnnWdv2PFU7HoqRQUE6UWQFiAECIT0tsn2mfn9sQSS7GyyG7IpMO/Xy5cvZ2dnvrtuvvP9fJ/n+TwZD2cgHnf7E0QPpi4r8FYPD2luyy2pYW+R7bRuCbEhr5zZb2zA7pab2mLRJBhHRgW4oRX6gp3udNzcJ52QeXHbi35CqvSbUgo/LCTluhQGvDCAzD9m4i5zc/DJgyj1djG8ipe3d73d1kNudTbklfPwpzt4bdboTlVYPzG7Ky6vEjDKoNMOCAJc/E+48AmITEIVA/QrMVrBYIHB18DtKyAq8eQlRJHYSy8h85uvWZo1Hq/T5ff26g2fUv7dq8SedR09frmQ7ne9QfSIi3HsW3/iHIdHYcHqPDyNiozX5pYeL7xuSCjWt15Z5af8ymbP0+kApI+FG96F+3bCBX/z1VNN+BXl+el4zn0B7tkKI2e1rpACBD3NTycAP+VXYtBIAas/B4mWo4im8gavyw6Z4s+KSZ2RSvSQaASDgCnRRNr8NNylbqrWNOzTJ5rKEc1HQ5rbPLLC0p2Fp/YBOzC7Cqq59Y0N1LZASAWD2SBywcDkM9JuPlT0yNRpwqHqQ+wp39PgWN1k1X1ud6KH+HZm6iarvQ/upWpNFfFn+3ajvaqXL3O/5IFRD2A1dh4RUp+9RTbmv7uZ/9wwjEHdO5dToSgKzDyrJ2+vPcSQHrFsy6+iwu5zOYq3mhjSIzZgzYxOmBk5C4bPpPS+K4nvWYzBU+CLCEhGiEqCMXfAsBt9NSqBMJr4r6srLqmhmApll9WrKHy3u4gLB3U7cazaqR0tCMX61i0rVDsCF4PrdECiEmHUyboGx8dFmG3RmMIU1dbT/HQCUWn3oGos5evPQYaYn0BoOFfZ99lRPAoxIxvWnUoWiegh0dTsrDmxPgFA8GKI/QnZYQ16bpMVKLb5b2CdDqiqyry3N2LXqF1rKtMhWIySQI/4CP7v6iGtOezTFl1MnSYsylmErDT8owp1shIFkaWHljI9e3qbjLk1Kah0cOvrG/jDJQOY1Dux+Td0QMZldeGJb3bz3e4iREFAwFcZIeBz07lhTDq3ju+l9wZqB+TaWsp/PErC4ysgKnQL+02HKqh1+QufUHZZa10yry7bxYSSPXiOFuA5ehSxLB1E/xSWQM0atZBEMBv1JIXOjCkrE1fuAaLPDdMNJAN49TQ/HX98QSl/EV9/DhINlQhCQ8El18gYogwIkv97DbEGHIccDY4JgopoqESKSAl6bgNQtFx+TgPW5pZRafffBAumnqw5LEaRrMQoFs4d2+FrzjsK+rd0mrC3Yi9eteHDLtTJyu61c7DqYDiHGRaq7B5ufWMDt07oxfTh3dt7OCGjqir/XJLDa6vz8Coq7gA7wG+vPcjCdYeYOa4nv7+ov+6u04bUrFiBdcwYpBYIKYD8crtmGkaozRMP5ZdSsfs7jN27Y0xNpbsQTW65/3mBmjVqYTZIJEcHSGHU6RSYM7Owb1jf/IktRO8zpROIhEj/PmjQcA6KSPP/7UhREt4aL6qs+q1RvFVeDFEay1PRG9LcJgqQGH16pqi9/EOuX1SquUyHYCJWSdFmbj87kxln9dSzYUJAF1OnCbWeWr9jLZmsqtxVfsc6Mk6PzG1vb2JidiK3Tcps7+GEjKqq/Oaj7fxv+zFNS9P6eI43JXx33WGKq108fcMw3ayijahesoToC85v8fvtbq/mDmkoESQAb1w86f9+5cR/z9pdxOb3tvrZFAdq1qhlfauoKucPTG7xZ9Npf0yZGVS8917Yrq+n+ekEYnh6PFqPofpzkKnLCOLGKgiSQM2uGmp315J4WSKCQaB6czWxY06mSMtOGdt2G8nX+M9JqjcqpLnNbJCY3DcpLJ+7Pal1eVmTW+Z3vKlMh2AiViZJ5M6zM5nTCddS7Y0upk4TIo3+PRus2daQJ6s4U1xYx9mayIrKfe//RFKMmT9c0r9TCovnl+/nf9uP4fAEv1BxeGSW7Sri38v2cv/5p2fD5Y6EUluLfd16Uv/2txZfI9pi1GweGcouK4DV1FBwTembhNkoafZ8Ccb6VhIFLh2SSrTF2IJPpdNRMGdl4T5wAFVVwzMPShJ4veG7vk6nxWQQmTG2J6+tPuBnhlM3B5Utfp9jC48hWUQsvSwkXZaEZJVImp5EwcICRIvYwM3PmGAkbnzDtYgqm5Breze4bnO23imxFob26Fz108FQXuvGKIl4Gm1wBMp0CLY21y0rvLI6j9kTM/S/8xDRxdRpQt/4vmwu2oxHOZlDG+pkZTVYO4Q1+u5j1fywt4TyWjeiINAlysR5A5Lp2eWkYFRVlT9/uZMqh4c354zulClvNS4vzy3fr9lzqLlwvMMj88oPB5g3KZPYCH0hHE5qfviBiOHDkWJb/lDOTorSbNAbyi4rQN/kRvWPosDciRk8+/0+zd9Rc9a3Rsn3fp3OjRQTgxBpxVtUhDElpdWvL4iiz55dUXzCSkenHjPH9eT1H/PQ6n/mm4MmEdXnMQSpodtw4sWJSJEShR8U4i52I0aIxIyIIe2ONES/Ok4Rr21go+sGntsijBJ3Tc46LUWByytrVKkFznQIpTa3yu7hcLm9wXpLp3l0MXWacH3f63k/532/46FMVoqqcH6vlqcynQoeWeHrn4/x4opcDpbVIivq8bQ2387XP5fkMLh7LHdOzmJavyReXJnLhrxyFt05DrOhcz7cP9t6BFFjog+2gFQUBD7alM9cPSQfVqqXLCXmFFL8AAZ1jyU1zkJuiX86brC7rFaTxG1n+wuf2yZl8t3uInYcrdK0SQ9EhFHil9Oy6d8tpvmTdTo85swsXLm5YRFTcDLVT9DFlE4jUuMi+OXUbF5YkRsgy8KAu+IsTAmrEcSGtd0JkxNImJzQ5PVVxfd+CO63ZxQF0rtYuXzY6dmwNybCiFdjdy5QpkMotbkGSaDC7qFnl1Yd8mmPLqY0UFWVbSXb2F6ynWp3NWbJTNeIrkxLn0asuWOGjNNi0hjQZQBbi7f6vRbMZGUQDFyRfQURhohwDTEglXY3M1/bQG5JjabNZ11X702HKrjn/a30TLBS7fTwyfwJxHTS9CRVVXlp5YGQC0jr4/DIvLLqAHP0kHzYUBwOalevJuXRR075WndOzuLRL3Zq/saDaZ4YG2FkXKb/E85kEHlrzhhmvraePYU2zQhVYyySwK0TejF/SlbwH0CnQ2PKzMCdewAmTAjPDQzHHf1M2oYDOmc2v5yWTWmti0Ubj2gKKk/5RIxxG0DwatZY1adybSWlS0pxH3MjWkTMaZFED5+OOQhtZJIEkmIs/Hfe2NPWQKFrpJlIswGX193geKBMB29lEUptZdC1uTqho4upetg9dr468BWv73idMmcZXsWLR/EgImI2mPnb+r9xTvo5zBo4iwFdBrT3cP24a+hd3PP9PX6Ne4PBIBq4ZcAtYRhV09icHq58YQ1HKuwnIlFNYXfL7C60MSg1JqCLUGeg2umlqNr//1Mo4XiACruHslq33lQvTNSsWkXEkMEY4uObP7kZLhuayv8t3oOjBQ0WI4wSvz6vT0DRHG0xsuiO8Ty1bC/vrDuEqqp+dVSC4LtOPB7mHv6B2eddoIvw0whzZhauA7lhu75uQqHTFIIg8OfLB9GrSyT/WrrXbw5S5Sgch2/H2vNFVNHtZ5VeR+niUkq+KiF1VipRg2IQBDNl34/HnrMdc+qogPeXRDBKIsPS4nh55qjTOv1dFAXmTOjFs9/v9zOu0sp0MCVmgBRcba5XVom3nr7fXbjQxdRxjtiOMHvJbKpcVTi8DS3DFZQTxxYfXMz3h79n1sBZ/GLYLzrUYmRc6jjmDZ7Hgh0LcHqDF1QWycLjkx4nPSY9jKPTZv67Wzha6QhKSNVnf0kNj36+k8evGhymkYWXaocnpALSQBhFgSqHRxdTYcK2ZCnR51/QKteyGCXeu+0spr/wI7Wu4BelEUaJq0Z059pRaU2eZzKIPHRRP+47rzeLdxTy5o8HKahy4PIqRJoMDOwew+2TMhmRHkf+LYuoeO99EmbcfKofS6eDYM7KxLZsWdiuL0gSqt5rSqcZZk/I4Kax6SzeUciLK3LJKbJhEAVfuZ03lfERf2Wr5wkc3lqERk3MZbtM8afFdJ/bnZjhSaiKGfvh27D0SsTSq+F9InDxqvFJCsQU3lQuZsiwMcyZmEGfZP++e6cjN4xJ59nv92u+ppXpUL3hk6Bqc2OtRtL1XpYho4spoKCmgBv+dwM2jw1FbTpFRlEVnLKTt3a9hd1r5zejf9NGowyO24fcjiRKvLztZVyyS7MzeR2SIGEUjTw28THO63leG47Sx/5iGxsPlp9I46tPcwYMTo/Cx1uO8JsL+xJn7XwRKqMkoqqnbpWt4rMz1Wl9FJeLmlWrSP7db1vtmr2To/n4rvHc+Mo6HG4ZZxN2+AI+ATbzrJ789qJ+Qd/DbJC4Ylh3rhgWuOdayp//xKGbZxB93rkYk3Vb9NMBU2YmrgMHwncDgwH0yJROENSfgxRFxebyYpJELEYRQRC45qUofipdi6nLCkTLMVTVgIBK7d5qFI9CZP8BOAum4a3pR6A6KS8ie9SezBWXcK3xR4TSfqA+B3TODdZQ6Rpl5vpRaXy4OR9HEKndwdTmRhgl7piU2aGCBJ2FM15MeRQPc5bMCUpI1cfpdfJhzof0S+jH5VmXh3GEoSEIAvMGz2Nk8kgW/LyAdQXrEAQBl3xyB8gomBFFOL/n+cwZNIfs+Ox2Getrqw/ikf2/82ANGAQB3t+Yz52TO1/dR5zVqBmNC9Uq2y0rxHfidMeOTO2Pa7D07Yuha9dWvW6/lBi+v38K7204zGur83B65QaRKrNBRAUmZXfljslZjMlout6xJZizsoi/6UaK/vY4PZ55utWvr9P2GJKTUR0O5KqqU3KeDISe5qfTEkRR8Eu5q7IreGsG4q0ZiGAsRzSVIYguHPk7ECOKcRy5u9nrepCoVCMRVBm8Mhz7CV47H258DzKnhOfDdDAevXwgB8tr2ZhXHpSgaq42V1FVrh7VozWHeMZwxoup5YeXU+Gs0BRSFasqfEWQxW4ki0TMyBiSr0lGivTtlDhlJ89seYbLMi/rcEp+eNJwnj/neUrsJXye+zl5VXnUeGo4WioQKQ7H7gAAIABJREFULfbkmcvmEW1qv3C40yPz2dajNNZSoRgwOD0Kr63O65RiymKUmJDdhZV7SxscD9Uqe1TPeKLMZ/yfcViwLVlC9AWtk+LXmPhIE/OnZnPH5Cx+2FvCT/kVlNa4cX20iP63XMdl43uTFG0Jy73r6HLHHRy4/HJs3y8nelrT5hc6HR9BEE5Ep6zDh7f+DQy+XlM6OqeKQTq5XlI9Ccge34aRILpQ7LagMjMEwCQ0+j167PDejTD7a0gNw99AB0MSBRbcMorb397Mqn2lyBrZLsESYZR46MK+ndbUq70541dhr+94HbvX7ne89JtSSr4poce8Hg36Mx188iAZD2cgGnypVTa3jY2FGxnTbUxbDz0oEq2JzBs878R/L9tVxLvrD7WrkAI4WulAqzVUqAYMZTUuXF65U9qj33F2FhsPVvi5uwVrlR1pkrijEwrJzoDqdmNbsYLEX/86rPeRRIGp/ZKY2i8JgANv/pHU9JuwhFlIAYhmM93+9CcKHn6YyLFjECP1viKdHXNmJu4wiSlBMuiRKZ1WISXGwu5jNr/jzWVmHHlxDqrXRfc7XsNqEuhCFQu2uFm43cOKW4/PXx47LJoF926jWdvATs6X2wr497K9FFY5UU5RSN06vhe3TtB7DraUM1pM5VXlsb/Sv4BPdsgUf+Yrgowe4hMdpkQTafPT2PvgXqrWVBF/ts/dy+F18MbONzqsmGpMv5Ro9mhMYm1NjdOr2WMpZAMGScTm9GKO6nxialxWF2IjjC22yraaDEzunRiu4Z3R1K5bhzkrC2NyUpve19C1K96SEujfv03uFzluHJGjR1Py3PMkP9Sx6j91QseUmYkrNzx1Uz4DCl1M6Zw6149OY0NeuZ/jaHOZGQAoCrbNX2AYdw3ni5v4XOsGtaVweC30HK/1aqdHVVX+/OVOPghgQx8skSYJFfj9Rf2YMa5Xq43vTOSMFlP7KvZhEA0N6okA7PvsKB6FmJENm1lKFonoIdHU7Kw5IaZUVHLKc9pszKdK97gIalxeKu3udjVusJokFA1zjFANGGRFxWrqfEIKfGk5L80YyQ2vrAt5QrQYRV6aOQJRK7ynEzyKArnfwe4voabYlzsSlYJzQwXR553b5sPxianS5k9sRZIeeogDl11O7GWXYhnQ8Vo+6ASPOSuTyg8/Cs/FDRLIepqfzqlzbv9kDJII+D/36jIzKr5fgLe6GFQVJAOmpCxUr5uYsVdRvf5jYoZfwFFTIqAxX3rs8OMzp62Y+vvXe1okpCRRwGwQ8cgKaQlW7pqcxWVDU0/bflxtyRktpmxubdMJuUbGEGVAkPwXqoZYA45DDa3TG1upd2REUaBPchR7Cm2cpdEAtK1IirHg8Z66AYPJIBLRiSeCoWlxvDxzJHPf2hi0PXyEUeL5m4czsmfrGxOcMbhssGEBrHsePA5w15x4SQUSVAGhZC2stcPI2WBqG6tYQ1Ii3tK2FVOGhASSfv0rjj36J3q9/x6C1Hn/ns50TJlZuPLywnJtPc1Pp7UwSCJXDEvl7bWHNF9XaitRPE4Sr/hdAxOq0i//hSmlN+b0wVRt+IxXp1zMcHZrXEGF/ctA9oB0etUArdlfyjvrDmkKqeZckA2SwB8u7s8lQ1NP6z5c7cEZ7akcYYhAwF8wSVES3hovqsbi1lvlxRDVUIOapM7lptavWww5he2b6hcbYWRidle/b79+mN++dy2Kx4kqe3HkbqJi+esNzjVKAtePTutw5h+hEhthxGqU6JsSjcUgouV0Lgm+aNSAbjG8f/tZTOun21m3mKqj8NIkWPkE1JY0EFLgC06JBhXBXgTf/RVenQY1JW0yNEPXrm0upgBir7oK0Wym4r332/zeOq2HKa0H3qIiFJer+ZNDRE/z02lNFEV787DOhCrhvLuw9h2PaLIgSAas2WMRLVEAxE28merNX/FNTTa1aoAei6IEzqpwDb/deH7Ffk0hVb3hU8q/e5XYs66jxy8X0v2uN4gecTGOfetPnOPyKHywKV8XUmHgjI5MdYvqpnncmm1FMAhUb64mdsxJi1nZKWPbbiP5moYL2WRr51rY9k+JZtex6nYdQ3mtG6tJ1OyCFawBgygI3Dq+V5uMN1zUuLzc+/5WHr9qCJcM6ca+Ihuvrc5j6a4iapy+lJpIs8S5/ZOZOymDfikxzVxRp0lqS2HBcXGkBrEw9DqgbD8sOAfuXAWW1recro/UtSven34K6z20EARB7z11GiAYjRh79MB98CCWvn1b9+IGg57mp9NqrNirvUEVjAmVKbEXEdmjqV73MXlJ3YB9GmeJILtbZ7AdhIJKBxsPVvgdD8UFOafQxv7iGrKTosI+3jOJM1pMDU0cSqQx0s/NT7JKJE1PomBhAaJFbODmZ0wwEjc+7sS5VoOVGQNmtPXQT4m+KTF8vOVou9zb5vSwYFUeb689yCWDu5EWH8HRSgeNN6maM2AwigIje8bTs0vndiB75LMdjMvqwiVDfMK+d3I0T1w9hCeubueBna68dyPUljUQUr3+Y8Pugbx7o4g0+aKcDRyiFA/YjsFHc2FGmOpRjmPomugzoGgH9N5TpwfmzEzcubmtLqb0PlM6rYnNqS3MgzWhipt4M4Vv3svRsWnaJ8husMRpv9ZJ+XTrUbR2oENxQfYqKos25vP7S9rG5OhM4YxO8xMFkVsG3IJF8rchTrw4keSrkyn8oJBdd+0i96+5GBOMZPwmA9HY8Gu7oFd4etGEi74p0ewrsgUMs4cDh1vm5ZW5TH1yBUcqHHzxy4k8duVg3pk7lsgQ+yRJgkBCpIlnb+zcfSQ+2XKE7UereOTSgc2frHPqFO6Awp994qgRsgpPr29iF1N2w8EfoOJg+MYHGBK7IrexAUV9utxxB86cPdi+X95uY9A5NUxZmbgOtH7dlC/NT49M6bQOgbLz65tQNYUxPpWYfuNZvCnAbz2+V5vVurYVB8tqcTduzkloLsheReVQeW04hhd2VFXF7vZSVO2kyuEJnCqqqKinYBXfEs7oyBTAlb2v5IVtL2i+ljA5gYTJgYv8zZKZa/tci1kKkLPbQYmNMBIbYSS/wh72yI7b68vRfe77fYxIj+e9286id/LJHle9ukby4Z3juPGVdVQ7vcjNCDyTIJMUF8UHd4yjS1Tn+t7rc7C0lse+2s3CuWOJ6KRuhJ2OdS9oCimAB8eb+MePLuaPNhFnCfCUV1XY8Apc8HjYhtheNVN16L2nOj/mzExqVqxo/QsbDKBHpnRaidgII5V2DymUMcPwLRPFn4nBTnkvgUkG6LN/IQf6zEAm8PMxacLV5O38Hr+4gDESJv4qvB+gHbC7tP/+QnVBDnSdjkql3c2iTfksWJVHWa0boyQgKyqiIDB9eHdmT+hFUbWLl1fmsvlQBW6vgiBAtMXI5UNTmT2hF5mJ4U1rPOPFVKw5ln9P+Te/Wv4rnLIz6PcZRSPZcdncM+KeMI4ufPTrFsOeQlvYxJSsqHz+01Ge+nYvGV2jePWWUQzpoR1y75cSw+L7zubpb/fx6dajCAJ+vZciTRKCADdLK5g//XJi4yLCMu62wO1VuPf9rdwzLZsBqXoNVJvgccKOj0HR3lkflSoxpZeBJ9e4eGxagIa5shs2vwXnPQZieIL6YnQ0qteLYrcjWttnV1Wz95SjEgq3+wq6JTNEJUG3oad9U8zOiCkzC9drrzd/YojoBhQ6rcnt2TbSbf9ktOBz47MIvo2uTCs8NsXI/y39iKeM33Kk5yW8pl5O1cFdRPabSESvYSeuIcYkU/lwV6xCY8MVFQadfrny8ZHaxhGhuiDHWQMbULhkF0sPLmVRziKK7cW4FTdRxihGJY9ixoAZZMVltXj8oeKVFf76v128vzEfQQCnxxeVO7np7ktZ/GBjPqJAg3IRVYUqh4f3Nhxm0aZ8BqbG8NT1w8K25j3jxRTAxO4TeWLSE/x21W+DElRmBPqKkbx0/qudzsmvjr7Hm/deMDClVa+rqipLdhbyr6V7ibMaefKaoYwNwoI9OcbC41cN5g+X9ufznwr4evsxyu1uJEGgS5SJK0f04MKBKZhyHPDdQ5D9A0id8+f7r2U5dI0yM6uTm2d0KmqKQGhaAP1lqpkJr9dy79gm/qZlNzgrwRoeW3pBEHzRqbIyTO0kpuBk76m4ib0xF37l68MlmfAl7Au+mjNzLIy/G4bdBBGnV21CZ8acmYH70CFUWW5dm3uDhKobUOi0Bru/5Kadt6MKDrRaJd4/3kxKlMhTq6rY/elCIkzv404ejHXcTSfOMeDlSmmVv5AyWmHK7067FD+A4WnxfLLlqN9mc3PNjuOnzjlxboRRYmTPeL9rO71Onv/peT7c+6Evna6el0Cpo5R8Wz5fHviSrLgs7h95P2O6jQnfB8W36Tz7jQ1sOVyBy+uf2lhHnX4KlNTkVVS8ispP+ZVc+uxq/jvvLAb3aH0jqc65Gg0D5/Q8h7ej3uaZrc+woXADqOBWGtZQWA1WzJKZmb2v4dZVr2LMXQH9L2ufAZ8i/VKiWbKzsNWup6oqq/aV8uTSHGRF5feX9GdKn8SQbcutJgM3jknnxjHp2icMmA6b3oCNr8JZd7XCyFuXQ2W15BTasDm9WE0S3eMjGNw99sT38MPeEj7fWsDX907q9JbunQp3LQhNLywHJUlc2sfAE6vd9E8MILxEg89KPUxiCk427jWlBSisbgMMURYyrgRpyRxUSUBQZfA22mhy18L3f4Xv/gJXvgQDp7fPYHUaIFqtSAnxeAoKWvU3JEh6mp9OK7D/O/j4NgSvo8nA9s1DjNw8xBdBcakSuaqHK91Z1EknAzJzpMUN32S0wpDrfZs8pyGXDOnGHz/foflasC7IKipXjezR4Fils5I5S+dwuOowLkW7rYKsysiyzK6yXcz/bj4PjX6Ia/te2zofrBGqqvKrD35i8+GKE9GoU0VRfaYnNy1Yx//untjqESpdTNWjf5f+vHjuixTVFrFo7yLWH1tPtbsak2giJTKF6/pex4TUCUiiBMkT4L/XQfIgSMho76GHTL+UaP61JIddBT6L9DirkW6xlhYt8DcdLOefS3IoqXFx/3l9uWhQCqLWdlNrIAhw8T/hjYtg4FUQ3f4Wzl5Z4bs9xby0Ipddx6oxSiKqqiIIAoqqEm81cefkTM7uk8iDH23jqeuGkRDZOSOanRZzVFBW6H+eYmHEyzXcPy5APZ7iBXO09muthJTYtd0c/QCfSHrtfAzO/QiioukedQLP8d3LT+8ERwWMmt0mQ9RpGnNGJq7c3FYWU6Ke5qdzajirYNFMX8uJ4zTrpgqYBZkMjvF7w7s86p1NBE5mSsvIFgt8F5FMvrXBhHth8kOnbfqxxShx3ag03l13CI9GKKY5F2RJFLhkcDdiLCfT/BxeB3OWzCGvOg9vgDT4xrhkF//Y+A8ijZFcnHlx6B+kGdYdKGd5TrGmkGquMXGzr7u8PPjRdhbd0bzzYSjoYkqD5Mhk7h5+N3cPb2J3o8dIOPsB+PBWmLsUDJ3DDMHu9vLFTwW8uDKXwxUOrnt5LQAeWSEx2sydkzOZPrwHUUE47O04WsW/luawt6iG+87tzZXDu2PQ6jjb2iT2hWE3w7eP+nbE25GjlQ5ufGUdZTUuao+H3huHpO1uB3//eg9/+nIXlw3pxvjsru0x1DObyCRfEnUzZCeIXD/QyDMb3AxO0uqebPalt4URnwlFO4kpVYUPZkLZfoTGkaim8Dpg8e8gLh2yzwnf+HSCwpSViTv3AEyZ0noXlfQ+UzqnyE/vac7DdW6qv58UeB0VIXi4VlrJ096rONuwi99ZvwIhGhBh9FwYPQ9iu4dx8B2DeZMy+HBTPh536BsbZoPIXVOyGxx7esvTHLYd9hNSFasqKF1SirvYjWSRiBkZQ/I1yUiRvgwPp+zk0TWPMqbbGLpGtHxNs6ugmrzSWmpdXqxmiV5dInl5ZS4Ojc9XveFTqtZ/RJfzf4ElYwSCZMCRtxnHvvVYegxs9nXwRai25VdyuMxOepfWSwXVxdSpMPZOOPQjLP2DL1rSwfloUz6PfLETOGnwUOM6+Qd0pMLB41/v4bGvdvPwxf2ZOa6X5nVyS2r497K9bMwr5xdTs3lp5kjMhjZ2pJv8G3huDBxeB+lnte29j5Nfbuey51Zjc3iRm1mo2493LF+8s5Dle4qZ2i+pLYaoU4fRAoOvOf4wb3pB+MhkM+9s13D9k0y+yEuYzCfqMCQmtp+j39HNcHhtg5S+YHaOAZ+g+uYhuHtTW49apxHmzCwcP29v1WvqfaZ0TglVhTVPn4xm1yMoN1VAReA6aQWFhp7kT3mK9B49oPtIkAIbKpxu9Ii3smDWaOa8uRGHJ/i/R4tR5Lmbhjdo1uvwOvhk3ye45IapfaXflFLyTQk95vVo0Gf14JMHyXg4A9HgewaqqHyY8yF3DQut5MLpkfnf9mO8tGI/RyudSCLICkgieGUVp0aNVHONiUNpXKyoKm+syePRy1qvLc0Z3WfqlBEEuPw52LcUdn7a3qNpkueX7+cPn+/A7pb9ihfrY3fLOD0Kj3+9m38s3tPgtSMVdh78cBvXvrSWQamxrHhwCrPG92p7IQW+VKvz/wpfPdAuu6V2t5frX15LtcPTrJCqj9OjMP/dLeQU2sI4Oh1Nxv1C07Tk4H3RnJt58nharIjzDzENhQL4/t7H3BbuUWLomojcXmJqzTP+tVEE0YerjuqjPkGm066YMjNwt3KvKcGgu/npnAJHN6M6qzRfqu+m2hSRgovp0hq2R01g+nexFMUNO6OEVB3jsrrw1pwxRJkNWAxNL+PNBhGrSeKVmaOY1q9hWcTivMUINBSvskOm+LNiUmekEj0kGsEgYEo0kTY/DXepm6o1J/8fumQX7+5+N+j0QIDdx6oZ/8T3PPr5DvaX1OLwyNS45BP/1hJS0Hxj4lAaF3tklc+3Hg16zMGgi6lTJSIOrn3Tt6gvy23v0WjyyZYjPPv9vpAK+RwehTd+PMjCtQcptjn50xc7ufTZ1aTEWlj+wBTumpKF1dTOgc1BV/u+/02tbwPcHJ9sOUKF3aPpIFO7awXH3rqPw/++hiPPzaRo0aM4j+w88brTK/PvZTltOFodAJL6Q/cRx13pQsRghsxpvjS2MGNI9BlQtDm1ZbB3Maj+88SD4008ucZFpbOZjQOvE9Y8F6YB6gSLOSsL14EDrdu4Uk/z0zkVqvJxyoHF0l+mmnl2g5uS2qbXKclCBQdKa7E5PUx44nu+213U4PWiaifLdhXx8eYjfLGtgHUHyvBqNLrt7IzJSOCH30zl3nN7E2U2IIkCZoOIKIDJIBJploizGpk/JYsVD07h7D6JDd5fWOXklZ/eaeDaB2DfZ0fxKMSMbNi2RbJIRA+JpmZnTYPjHsXjM20Lgp+PVHH1i2sor3WfKIsIluYaE4fSuBjA5mrduUxP82sNUofDlN/Ch7Ng7re+lKIOgssr88jnO1tUyOfwyPzpy138Y0kO14xM49tfT6ZrR2qUW2dG8eYlMPBKiEps/j2tgKqqvLTygGaIPZicXVWFFTkllNa4Otb3eSZw/bvw0kSfVXqwu2miEWJ6wNWvhndsx/G5+bVDzdTRzb6aMK//gieoPlzgE2IHV4VxkDrBICUkIAByWRmGrq1To6n3mdI5FY5VHiBG0f79LNnvYXGuTK0bev2nhtHdJQK5YZvwpWB7ZN9Gwe1vb+aJqwfRPd7KKysPsOZAGSZJRFFVBAEEwCCJzB7fi5vG9iQx+vR55iZEmrhrSjY/H62iW6yF9IRIbE4PUWYDPbtEMql31wZ17Kqqsia3jJdX5rI+rxxTRqGfCpBrZAxRBgTJP93SEGvAccjR4JjT6+RYzbFmx1psczLjtXVNZkY1RXONiUNtXCwH8lJvIbqYai1Gz4ODq2Hxb+Gy/7T3aE6weEeh5u5kMIv+OuZPzebOyW3XqC0kkvrD0Bvh2z/B9Ofb5JabDlVQXuuf8hRKzq4gwHsbDnP3tN5hH69OPawJcNv38OalUHWkgauUFqoxAiEhG2Z9EXYXvzp8BhTtEJlyVmlGpeoIqg8X+KzjddoVQRAwZWXhyj3QamJK7zOlcyp8V7SJKzSOV7tUFmz18N7VVnrGCoxZUMtRm0phjYJbVjE1WtTbaSiGZFXlwY9+xmwQT5g/uf2UmMwLK3J5YUUuf5s+iGtGtV/bidbGKyus3lfKsl9PJjkm8EZXld3DrDfWs7eo5oSgMSD7padJURLeGi+qrPoJKm+VF0NUQ9mgoLAoZxFX9b6qSTfo11cfDCikmtvYh+YbE4fauDiilTOrdDHVWggCXP4svDIZfv4IBl+DqqqszS3jrbUHOVRmx+GWiTQbGJoWx+wJveiTHP7F2Usrc/3CqaEs+r2KynsbDnPH2Zkdty/S5Ifg+TGQvwHSwttIDmDzoQqNyTq0nF2nR2HVvlJdTLUH0Slwx0rY/LavRshV5bMDP4GAarTirfEi952N5ZpH2zTaLHXtire8HFVREMJsdtHwxgYg8N94UH24wNeLS6fdMWVm4M47QOTYls+JFbVuFm3K55MtRylzDkXdJRH3rxVcMDCFmeN60i02ohVHrHO6UuOu4cOqXVzT6Hi1S6XSCb+fZOLCbN+8ccNAIx/v9uD0wsLtHuYMN/Hfnz38e62LPaUKZrMDe+Kjfovtphq71n/9D5/voNLhYd6kzBZ/nmp3NV/s/4JNRZuoclVhNphJi0rj6j5X0y+hX4uv2xK2HK4kLcHarJC69LlVFFU5ccv1NtcVE9Awzc+abUUwCFRvriZ2TOxJV78iN6pXxdrHilwrn3D1A9hfuZ8V+SuYmq5ty+72Kry7/tCJaGJ9gt3Y12pMbM/5kao17+OtKkaKiEGK7krZ4ueabVwsAGMzWrdXpP7Ua00sMXDtWyhvX8nCwnRe2Gij2unxU+M5RdV8uvUI2YlRPHBBX6b0DY+zm8Mts7fIf5c4lEU/wLFKJ2W17o6bkmaJgfP+Al8/ALctB1FCVVWKbS6qHB5EQSAh0tRqvZ0q7W68GiHiUHN2qx0ajnE6bYMpEsbdBWfdCXk/QM7XvtQ/BIhOQeh/OfatRVR++hk9b2zbtF3RZEKyWpGrqjDE+3eqDxuRSU1pKSCIPlwAEW04Zh0/VFVl06EKXrYM58B2Ffeh74k0GeiXEs3siRkMS4tr9hqFVU4e+2oXy3YVIQgcTxM3ggxlJbUsWJ3Ha6vzGJuRwB8vHUDvNtgY1Om87K/cT7HZym6TkeGuk1kda/JlVGBS+slnZp2bapcIgWUHvFQ6VZ5Y7ealSy1MzDZyn/c3fJvr0syiCQanR+HJpTlkJkYyrZfFZx5Wug+clRCR4Mt2GTAdTP622XlVebyy/RWWHVqGiIhDPpnZIAoin+3/jB7RPZg3eB4XZ1zcJhvQ3+8pZmoTa0hVVZn1xnp/IQXIznQEYxWCcPK4ZJVImp5EwcICanbUUL21mpTrUqjcUImn3INoFP1c/dyKm9d3vB5QTC3ZWYiisWYKZWMfGjYmLvn8CVAUTMmZJF37FyxpA3Hkbca29ZtmGxdHmCTuOLvlYloLXUy1Mq7Egdwd8Q9WfX8Uh6q9eJcVkBWFHQXV3LlwM/ec05v5jbz/W4NKhxuTJOJolKcc6qLfKAlU2j0dV0wBDL4WNr9J1do3+VidxqurDlBe68YoCagquGWFvsnR3Dkli/MHpGBqxgGnKSwB3AtDzdk1tUVPLp2mEQTInOz7pxExqR6Kn3kGx7ZtRAwd2qbDqmvc26ZiKm1ss1GlZvtwGSwwbEaYBqjTFKqq8tHmIzz93T7Ka9043BZfv+UK34JvX7GNpbuK6BZr4dfn9+HSIama18kptHHDK3VOpdr3qovMr9pXyhXP/8iCWaMYn6X3z9PRptpdDcDrsTH8vaSMqOOlB6V2heRIgQuyTzry1bmp/vZbJ+uPyDyy3MUbV0RwVX8jxZLI2trBWLMlzcV2Hc2ljaV7D+H+6GVU1iIIYkO7dmMkfHW/r3xg3C+gi6/EYfXR1fx6xa9xyS4UjXRoRVVwyk72V+7nz2v/zPL85fx94t8xhtltcEVOMX+7cnDA19fklrG3qMZPSAGUrzBj27wfd6EL0SJiSbeQeFkiiRcnIhgFCv9bCCIUflhIzIgY0u9MBwn2PriXqjVVxJ998vm0u3w3h6sPkx7jb9K0Pq9M03Ai1I198DUmtmaP4cjzs+hy2X0NUvoCibDGJFhNjNEjUx0XRVG5572t/FBixdlE7UF9nB6FZ7/bT4RRYvaEjNYdkOrz029MqIt+ALGDZvjVoQIvdH2YZ74sRjTuwXHccKO+YcuOgmoe+ng7v//kZ56+YXiLez11iTI3yM+uI9Sc3ZS4jmNUouOPYDTSZfYcyhYsoMezz7bpvU/Yo/fp03Y3lQww5nb48T+aJhR1BOzDBT53lVGzwzRAnUDIispvPtrG1z8XBuw9o6g+U6EDpbU8+OF2Nh+q4JFLBzTYPT9SYee6l9dSFWTUXMXXTmPum5v44I6zGNKj+aiXzpmH6biD6g/WCEolCbPXixHoahUptat4FRVDo0XGsRoVFXB64cr+BhyCwDNx8ai1Ta9Zmksbu0ZawV8Nb2L0eBEEjXWa53jK95a3Ydt7cPUC1sd25VfLf4VTDq6RucPrYEX+Ch5Y+QBPTX0KUTj1jVOPrPDtriLWHiijrMaN2SASE2GgoNLRZLT5pZW5mrVKvu/pE7rNzCR2hK8+yvazDdsWG5F9IjEnm0GAga8M9KudqnP1qy+mPB4jf/p6NZXlPam0exAFSIgyc/WI7hRXaz9PQt3Yr6MlIqwOi1HkL9MHtnrUUBdTrcgHm/L5YW9pyM55Do/M/y3ew7isLvRLidG4ctOoqkpJjYu9hTXsLbKd/KfQpplHHOqi3y0rxFtbJ0UuHKiqyu8//ZnPttpwYYImLOBrXb5J5a5aHdQpAAAgAElEQVR3N/PXKwZxbZCFqA63zKp9JSzbVcS3u4s0a6a0cnoD5exGmiRuHB1+m22dUyPu6qsoffFFXLm5mLPazoSl3Rz9Rs2BH59ucOjgfQ1TuOp2jv0QDJB9DkTpDanbkrr57+ufj53YRGoOh0fm/Q35WAwiD13U/8TxX7y7hRqntslEc8+wOW9uZP3vz0Xq6DtvOm1OUkQSsiqjCAJzuyXx4dFCYhSFcT0kzAb4ZLeX6waejODUuFW+2efl8r4G9pQquCWRT6Ij+dTa9NzSXNrY9dJyHjW8RYQQRM88xQOKh6pP5nFPWipOxf89J+qJit1IFomYkTEkX5OMFCnhkl2sLVjLu7vfZeaAmcF/WY0orXHx1pqDvLXmILKiNojw1P2pXfnCj9w1OYsLB6U0EAnHqhxsyCv3/2j1vidzjwQE0zsIgkrM8Bhihvvm9mBd/RR3PK6SC/DaBrJSFJHlypMnltSy42gVrgAbPI039oMxooCWizCLUeR3F/X367fVGuhiqpVQVZUXlu9vsV22R1ZZsCqPJ69tOp2ootZNTpGNfUU2cops7C3yCSiAPsnR9EmOYmBqDFcO706f5Gjmvb2JzYcqGlwjlEU/QHZSFPGtVG8UDp79fj+fbS0IqRu406Pwx893kBJrYVJvbUv10hoX3+8uZumuItYfKGNwj1jOG5DMPef05m9f7WLJriIaB/7q5/Q2lbNrMUp+fR90Oh6i1Ur8zTdR9trrpD7+tza7b7s5+kWnwAWPw9KHwdO002FDBF/Pt0ufCtvQdLRZsrOIL7YVaAqp5gTQm2sOMalPIuOzurK3yEZOoU2zCXkwzzCHW2ZFTjHn9G/9hYpO5yYjNoOuEV3Jt+VTbDBwXfcU3jhWREKEwqOTzdz9jZMYs8A5GRJHbSrzv3LSI0bg0r4G3vzJwxuRUbwY1wVPxagm79NUxGKIkMujhrew1hNSvf5jw+6BvHujiDT5RMOCLW4WbvecaNr+sUVC8TqhkRlQ6TellHxTQo95PYgaEIWnwkPBOwUN6okcsoPXfn6Nm/vf3KLo1O5j1dz0qs9OXGtjvK4MafuRKu59/yeykiK5eWw6vZOiGZORwLb8KoySfxZN/e9JrpFQPHFIpobrxGBc/WRHD+yH54JiBkRkjSVYU1bo9Tf25eqSoB2mQy6pON5/65/XDOGyod2bPb8l6GKqldh0qIKyU7DLlhWVL7cV8OhlA4i2GKl2eth3XCzlFNrYV2wjp7AGp0emT3LUceEUzYUDu9EnJYrEKLNm2PLOyVnc9/5Wv3zVYBf9kWap49qi4xM8zy/frznRNLfL4fQoPPTRdn787bQT392BkhqW7Spi2a4icopsnN07kUuHdONf1w4l1npy5+z2yVms3FuqKeCiBk4laqB2ISb4dkfmTszQd3A7CfE33UTuhRfhKbwbY0pKm9zTkJjYPpEpgNFzwVEBPzzZrHU84ItIRcTB7G98YkynTXlh+X4cAdN4mhFAHpmXVhxgfFZXXludh0fRWLAF+Qyrdcu8tDJXF1M6fgiCwJxBc/jHxn/g8DooMhi4sns3Lqy1M2dKNdFWkfuXOjlQoRBjFrikn5HXro5gpTUCjE7+76BITBcRT0XTaV1NRSzuNnyKBf81mqzC0+vd/H6Sf024ArwdG4OzkZCSHTLFnxXTfW53oof4IvemRBNp89P86okcXgdrCtYwsXvzWUD12V9cwzUvraU2yOaybllh9zEbj3y+E7NBJDbCxOhe8cgaf9ONvyfF2d1PTDV29TvxXqeMbbuNxOm9sR++7biQahl1G/tlS19EddvpctHdRGSPabCxr1UDFUp2lQDcclZP7j23N9GW8NWv6WKqlVi47pDmwjqU3E5ZUbnmxbXYnB4q7B56J0fROymavilRTO6bSN/kaLrFWkLK9ZzWLwmTQdIs/mtu0Q8gIHDhoI67QPrv+sOax4O126x0eHh77SGOVTlZtqsQm9PLuQOS+cW0bMZndcEcwGxiRHo8M85KZ+G6wyFFxEySSL+U6FOyZdVpWwzx8cRNn075m2+R/NuH2uaeiV1x7t7dJvfS5OwHICETljysYR3vQ5XMCALQaxJc8bwupNqB/cU29hbb/I6H4pK1Pq+Mgko7n289iqyRJRjKM2zbkSrKalx06chmRTrtwsUZF/PPjf888d8uUeTz6Cg+j45iYKKLGee4iVEU3IJAqSRytdWKTRJJvFKk4J1jKO5UxGgLgugNmEUTKGKRSCWTxJ81a78fHG/iHz+6mD/aRJyl4QmbLGacGm+y77OjeBRiRjZMd5Yskl89kd1r551d74Qkptxe5XhEyl9INbdJ7KuNVHB4nCzeWahpR974e5Jr+mKI3IsgnRSb9V39RIvYIPpmjDdiTn0IFH9xEmyqXh0xY67CW1uJbcMnlC1+jorlr2tu7NcnlOyqbrEWHr6kf9idFXUx1UocLrP7pXxBaLmdXkWlb0oUD5zfjx7xEYitELmQRIGnrh/KnQs3a9ZyNYXFKPLktUMCCor2RlZU3vgxzy8qFcpCwu6W+cfiPdw6oRf/um4YQ7rHBv29/+6i/tS4ZD7bejQoQWUxiPROjuatOWNPyU1Qp+1JmH0rB66YTtc770CKC3+Rfbul+dVn0FUw8ErIW+mrozqyyVecLRrwOgXktEswX/sYxGi7wumEnw825uPRUEChCCAV+GjTEc12DxDaM8wkiZToYuq0YWfpTtYUrKHUUYpBNNA1oitT06bSK7ZXyNeyGq08OflJfrXiV7jkhoYEO81mdpq1fzNdLuyKFB1JyZfFeMpuDphFA4EjFtdJywOOa1SqxJReBp5c4+KxaQ1NofINBrSe7MHWE9VxuFp70zcQS3YWUuv2+q0pg90krkNLSIH/9+SpHoo55Qu/8xIvTkSKlCj8oBB3sRsxQiRmRAzdZkzGXR4NasN1TKjjq8OUlIEYGUfaLxcG8e34CCa7KsIkcXsb9UjVxVQrEWgxHWpuZ3KMhfQu/v0NToUpfZP46xWD+OPnO4IWVHWFehcO6taqY2lN9hfXnHLzXPBNOA9eEHqjPVEUePzKQQxMjeGpb/fidMuaEUCrSUJV4YbRafz24n4dVpzqBMaYkkL0OedQ/t//kjh/ftjvJ7WXAUVjBAEyp/j+qUfF88+jlNtJ1oVUu3KwtFYzmhSKAHJ7FXZ9sxzB0A0E//NDeYYJEPKmnU7Hwi27+Trva17/+XWO1R7Do3iQVd9zzSgaef6n5+kb35c5g+cwNW1qSLVAk3pM4q/j/8ojax4JyhlPVUVU2Yqx2/10m9V8jXHAiMWhTTxy2MY/ztN20P3LVDMTXq/l3rENa8PtoojWrzmYeqL6OIJJl67HSytzT5hl1RFqT6am0Pqe3GXDcRV8j31PDSnXn8wySJicQMLkhjbi9kPnQqPWP6cyvpY4TEPz2VUCcPXIHkFf71TQxVQrEW3R/ipDye2UBIgLk2vetaPSSIqxcP+in3AEWPSDz2XOZBB54uohXDCwY6ftVNrdmlGkUJ1ePIqC26u0KFokCAIzzurJTWPSWbmvhFd+OEBOoQ2HW8ZkEEmJtTB3QgaXDU0lwqSLqM5Ml7lzODTzFrrMno0YERHWexkSE9s/MtUE1hEjKHmmbe3idfwJVNwd6uJETU3DWxLcLnZTKKpKTIBnoU7Hp9xZzrwl8zhSc0RTAHgUn2X+9tLt/G7V7xiZPJKnpjyFxRB8m4+LMi8iJSqFJ9Y/QW5VLl7Fi8cZj7t8ArKjJ6piBkFGNFYgWgp8phNK8E2htSIW33SDJ88O/LsclCRxaR8DT6x20z/x5DogUlHQ+utprp4o+ZqGdYNWY/Ab5LklNeQW1/gdb26TuCXpdQ2/JwvWTEi6vEuT41O8VmRHBtBwvXQqduWhOkwHg8Uo8vLMkWGtk6qPPuu1EqN6+pxT3I22CUPJ7bSYJAakhm6NHiyT+ySy4ffnsnJfCS+vzGVDXjmq6ksFVFSV4enx3Dk5i2n9kjqFOUKg0G1LdjlONQosigJT+yY12Ylcp3NjzsoiYsRwKj/+hIQZN4f1XlJsLIrdjuJ2I5o6npNmxJAhOPfsQXG5EAOk5+iEn/qmOPUJuedddjppcgmHy+1+r4Xq/to9PrwbDTrhocpVxY3/u5FiezFetXnTA4fXwcbCjdy29DZev/B1jGLwi9bhScP54LIP+Oin7TyxeDdVVRZUVYB60kV2JyHX9m3JR/GLWPzZ+CTjpS1NvufPUyyMeLmG+8ednM88rmQcqgA07LvWZD1RgpG48SdTwQUEsuOygx57XkktRknE2SjrpqlN4pam1zX+ngRzARE9X0JV3QHXRKo3BgQvqA3lQ0vtyiH0OaY5LEaRp28YHtCpORzoYqqVuPmsdF7/MU/ztaDtsg0SZ4f5f379Rb9XVrj2pbXcNTWLaX2TMEidq44n3mpE1sjzD3UhYZJEjJ3ss+u0D13nzePor+8n/vrrEIzh2/ESRBGpSxfshcVEpnVvk5zvUBAjIzFnZuLcsQPryJHtPZwzltE9E1i+p8QvzTzUnndD0+LokxLN377arRntCuYZZpQEbhyTrqcxd1LuW34fJY4SPyHVXC+lPeV7eGL9E/xx3B9Dut/baw7y+DdHcXqiWjTeSLOE57jgcAeoDapjr9qDyep2TEJgkZidIHL9QCPPbHAzOElkrdyfPzkfREz6D6JU6Xd+oHqitDvSEI0n1xMRhghuGXBL0J+r1u1F0SjAD7RJ3Jrpf6orlZp9D2Pu9jGGyBxEg0azXdWIr9IyuPEFS7DrZACzwd/u3WQQEYCxGQn85sJ+DOoe6/e+cKKLqVaiR7yVkT3jWZNbpvl6s3bZhra3yzZIIoqqkhhl7nRCCiArMYpIs8Hv4R/KQkIU0KNJOkETMWwYxh49qP76K2L7R8Cuz6GmyPdiVDL0vxx6nwcteJgAVDk8fLz5CK+tPsCxsb+GF7cB20iKtjB7Qi+uH50WtlTgUIkYMQL7li26mGpHrhrZgycW79F8LZTFycWDuyErKo/9L7CDZHPPMFEQmDW+V8ifQaf9ySnPYUfpjhNpfHUE00vJKTv5bP9nWGsvJa9Yxeb0EhNhoF9KNNeNSicx2j9y/eGmfP7+ze4W1dcZROiTHMOdU7K4cGAKb689yJNLc5q81vvyVOZK3zR77Ucmm3lnuwcnJuZ5HsCBBWPZZMzJXyOIHr/zteqJGhNrjmVkcvBzpNVk0CxfCLRJfCrpdZqoZlwFN4Klmv5DP+NobR4G8aRUcHgUXIK/MUdrpOoF4zAdZzXyt+mDeH9jPkXVTtxehZgIIxOzu3LLuF6kxAafctqa6GKqFbn//D5sWbC+RROE2Shx45j0MIyqaVwtrBXqCIiiwG2TMnhq2V6/hpWhNM+97WzdplwnSFw2Us6NwbBhHupOE0Jjy/Bdn4PBDGfNh7F3gDm4XH+3V+HPX+7ko81HEAXBF2kQxBMbgIXVTv7z7V7+vWwvlw1N5bHpg7AY2zcCYB0xnKovvmzXMZzpxEYYuWhQCl9uO6bZbLe5xYlRErhudNqJ39JtZ2fw6g95IbV7AF9azXn9k0lLaF3zJJ224Z1d7/gJqVB6Kbm8Kgt++gBH6aQT71+6s4hnvtvPpN5d+cXUbIan+87NL7cHNMNqru5HAOZPyebX559M/5s3KZPMxEj+8uUuiqpduLwyjRNWSg3d2K5mMUZouPFw8L6G83NarIjzDzH8yn0XXyi+TStP9QjMiUvhuJiqXFvpi9QdcyNaRCzpFhIvSySyT6Tf57EYLMwfNr/JzAJVVVmfV86KnBJKbE5qXF7NvnGBNokdh7aBwdiiaFBgBAQ5niTbb/jPFYkU1B7F7rETaYykm7UnF/1rOw655ZvYLSXCKHHPtN5cMiSVS4Z0LPMjXUy1IiN7JvDnywfy6Bc7QxJUVpPEu/PGEh/Z9jvOblnp1GkZ141K419L92q+FswuR2K0mRHp4be61jkNqD4Gb16MqboAwegBt/9OJe4a3z8//BO2LoRbv4LYpjuu291ebn51PbsLqzWbT9dRt2Hw5bYCdhZU8cEd44hpo+JaLSJGjKDwz39BVRQEsXNuyJwO/HJaNkt2FoUsgMCXGjN3YsaJ//7VuX3YX1zD8j3FfhtUgbAYRPomR/PkdUNDvr9O+2P32Fl8cPEJx74Tx0PopSSIHsS4VVBPTNXNZd/tKWb1/lJ+d1E/Zo3P4I0f8zTT84Op+1GBhesPc++5fRpk8Uzrl8zUvklsO1LFKz/ksvFgBXaXF4MkkhBp4uax6Qzs8hjuD2dgUpt2EaxSrXytjEWuq99SzNjz52Ht+TJlSwso+aqE1FmpRA+ORpAEbD/bsG2x+YkpixTBZZmXckXWFdrfu9vLoo35vLLqAJV2T0AzmfpobRIbohPB425Rel1T4tXpUVi1r5TPN8Zw77m+9MH8cjsfbjwSsJ4q+JIWAUn01YVp/Ra0iDBKXDw4hdkTeoX0GdsKXUy1MtePTsdikHjwo+1+ZhSNiTCKmAw+IdXW+Z11uDwK5k4amQKf++HvLurH41/vafb7bozFKPLv64Z2uHoUnQ6IvRwWTIOaYgQliI70XidU5fvec+caiNR2SPLKCvPe2sSuY00Lqfq4vAq5xbXMen0DH9w+rt0iy8bkZESrFXdeHuasrHYZgw5kJ0Xz3E3D+cV/t4S0iWcxirw2azQ94k9GkwRB4NkbR/DoFzv4aPMRPF5VM+IFvihBhEliTEYCL80Y2ak35c5kjtUewyAa/Ho/hdpLSTDYABka+d+pqs8u/4lvcpAV9XhvtIa/qVDqfspr3Qx8ZDFXjejBnIkZZCf5aq4EQWBYWhwv3BwopS4Tx9i7cKx7gQg0aoGO87H3bBoboivO7tTkzKD40wdInZNO7KiTdV4xw2OIGX5ScKoqoBoZ3+0S/nDWHzTXF0XVTm54ZR2FVY6gNy3qaLxJrLhqOfL8LSGn1wUjXh0emVd+OEDPLpF8vOUIPx+t4oqhqfzn+mHc895WP5MMrfFpoSLw2S8mMO/tTZTYXE0KSUHw+QncMDqNP146oMOu13QxFQYuHtKNfy3LoV9KDD/mliIg4PLKeGUVo0HEKApEWQzcNimTa0elERvRfrvLvshU5xVTABmJUYiCdlFiIOrcXkb2bDrfWUcHgEW3QE0J1BNSvf5jw+6BvHujiDT5JvgFW9ws3O5hxa2RvnNry+GDm2HOYs3LfrGtgK2HKzV/t03tGrplhT3Hqvlg42FmjusVlo8cDBEjR2LfskUXU+3MOf2TeXnmKO5auBlZUZucBy1GEYMo8sbs0Yzu5T//SaLAY9MHc8PodF5bncfXPx/DIAnIioridGGMsOBRVCZmd+WOszMZk5HQYRc4Os1T46lBwP//X6i9lFBFand/R/WGrzXnLIdH5olvcjTrwkOt+3F6FRZtyueTLUfo1y2G/2fvvOOjqNM//p6Z7dn0SgkkhN57l3YqJxY8+4lwVuzlbOdZzu5P7/TEXlAs2Lg7RbFgQQHpLUivCSGBAOlt+87M748lIclOkt2QEJB5v168XmR3duY7m+x3v8/3eZ7PZ9blA0lLCC6zq4918qNsK6wkI2sulgYCql/V/rgI7rtxZVeieMGSNgVVXgsoCJL32O0rEiAgO9PxFo8nveM5mh5cJQ4vU19dTlGVt0Gj7HBoTnldOMGrwyvz6i97uG1SN2bPGFpTEjwqI56VWcUhr7mqsRolrhzRiW7Jkfxw1zgWbMrnjSVZHC5341cUfLKKQKAFQ1FVJvZI5IZxGQzpHBv+m3MC0YOpVuCj1fvpHB/BW9OH4JUVftlRwMEyF06vjN0caMwclRF/UnwBeXzyKdszBbB8TxF3z/uNj28Ywa7DlTz93U5Q1UZ9tCItRl65cpDmQkJHJ4iiPXBgHSjBZX2yCi+t8fLgGQ3IgyteyN8IhbsgMVjm942lWZrlWaHtGiq89Ws2V43s3GZziW3wIFyZG4m99NI2ub7OMcZ3T2TJfRP4ZHUu76/MweeX8Xq8yJIBoyRikAQsRonrx6Zz+bBOxDVRVt63QzQvXj6Qxy7ow4q9RZQ4vBx+7p9k3HsnY/p3IimybRq9dVoWq8GKqqHOFq6XUtGPRyhZ9EGjc5ZXVhCP02S6Gr+i4ldUNh8o4/xXlvPxDSPo37GJkn1BQJ7wCLfsiuc28Qt6CzmIKJiEY3NwqaqtLlg9Rl/JFHwlkzFEbkO05iJKDlTFhOKLw18xENUfeK+KHNrB2vUfrKPYERxIhesTVZtwxGYg/ODVKIlcOKhuufpr0wZzwasryC1x4g1jE3tYeiwPTul19GeJy4amctnQVH7LK2N9TgmlTh9mg0iC3czkPsnE208N6w09mGphSh1eXv1lL5/OHIkgCJgNEuf0a9fWw2qQU7lnamVWEXd+tpE3rhrCkM5xDOkcx8VDOrJwy2HeWLKXPQVVmAwiqhqYeEdnxHPT+AxGnySBrM4pwuo36mSkanPfaBP/XOHhlmEmYiwN/E3Jflj9Opz/Up2Htx4s50BJsDFmuCUv6/eXttnGgHXwYEre/6BNrq0TTFKkhbvO6h7oo/rfz+xesRnjhRdhNxvokmhnbNcETaWwxoi2Gply9Dtsb8UuOnW2YdIDqd8NybZkfHLwRlE4XkqyU6ZwfgFxf7y/yTlLKxlzPLLaigqVHj/T3lnDN7ePpXN8wxmqlVlFXPf+elz+AfzCANKEQ1wp/Ux34QCRgosK1Uaxqu31WX+M/sr+UNm/wWtZNQSCth4sZ8ehyqAyx+b6RNUmlPK6asINXvcVO4Ies5kMfHHLaK6es5adhysbL9UjUBL8h55J/PvygZrZyYGpMQxMPXX71/VgqoWZtWg3U/q1o3ty6I7dbYWqqqesmt+a7GJu/2Qjr145mOHpxxaSZoPEhYM6cOGgDrh9MuUuH5IoEGUxnpL3qdPGyD7Y9GmDwdTQ9hIT0gw8v9LDU5MaWGCqftg8D875FxiOZQO+2ZyPxx/8BRTOrqHLJ/PlxoNtFkyZu3bFX1KCv7gYQ7x2X5jOiccgiYx0HmRIskzKpG4tdl7Rbkepqmqx8+m0PdHmaIa3G86KgyuCMlSheik59rhRfGqz5blDkdVuKnPj8Pi5/3+bmXej9hh2H6nk+g/W16kEyFHb8Yz/qhYbYzUmSSA5Kvj74J1l2UFZnJb0iQqVcINXj09BVdWgTegoi5H/3jSaRTuO8ObSLLbnVyAI1Nyj2SAhq4GN7JnjujCqy+93I1sPplqQPUcq+XrzIRbdPb6thxISfkVFFIQT6m3VEqzPKeGWjzN5+c+DGJXR8ALOYpTaXD5a5xTHWQJq4yUMT0w0M2aOgztHNFw2pSoKni1rUU2xIMuofpmD2cWau7Th7BqqaqCZua0QJAnrgAE4MzOJOuusNhuHTjC+vFxMaelNHxgGejD1+0NWVAZGXsgKdR0IwaVpoXgpyVUyojW8Mr3aNNX3I0XENpm5UVT4La+MvBKnpkT/o19t05Qch9BK7MLpTRIEgXP7161Icnj8LNx6OEjQpcV9okIgXE8oo0FsMAiSRIHJfVKY3CeF7MIqVuwtotzlQxAE4iJMTOqZpBlY/t7Qg6kW5Klvd3DrxK5N1qKfLHj8p574xIb9pdw4dwMvXj6QMV0T2no4Or93vFVNGvD2TZI4r7uBZ5d76ZWo/XlSvX4Kn3kMnxyFIBkQJAlH0liwpQUdG+6uYf2SkRONbchgXJkb9WDqJMObm4d9fMtu7In2CGQ9mDop8csBKevcEicOr59Is4GMJHuj2YA12cXc+kkmLq8CqXYEoxdBCG8+URUJwZiK4jzQrDK9ahrq+4kcegGF858JKXOjqCrvr8zhkfN613k8r8RJZm6pRmdYeCV2ofYmDU+Po120tc5jh8pdGEQhSPaiOf1ioRBhkrCaJcocvqD+rHBFK5I0jJe16JJop0uids/Z7x09mGohFu8sIK/UyYxRndt6KCHjPcVK/H7LK2Pmh+t5/rIBjOue2NbD0TkdMEeC0rT/x+MTLAx+q4p7Rml/6YhmI6lzPoLIlJrHOn+9DVbkBF8yzF3DeHvbbt5YBw2m8MUX23QMOsF4c3MxprasEbwUYUepCu6f0Gk7CirdfLR6Px+u3I9PUfDLKn5FwSiKiKJAZLVy8JBUom3HlIN/2HqIO+f9ViOnL+ReS0T6K6iiu0Efofqoiojqj0QwX49gWF9nzmqOoIJW348re0PImRufrPLlbweDgqm5q/ajakj8N6fErqneJKtJ4sZxweqmVR4ZUeONPZ5+MeCokrFEgt2EyycTYTKQkRTBpJ7J7C2oZO7qXM3XhRoYWo0SV49OC3tcpxt6MNUC+GSFJ7/dzsPn9sIonTrBiccvnzKZqS0Hyrn+g3X885L+TOyR1NbD0TldsMaB2PQ02TVO5PI+Rl5e66VfksZnShDBVrckdVy3RP6zLi9IeTKcXcMIk8Sknm37ebD274d7924UtxvR8vsv5zgVUH0+/IcOYezYuGF0uOhlficQnxu2fwVr3gx41vk9YLJBcl8YfTukj2fxrkJu/SRTUw5fVgI/O70yL/y4m5d/3sPc60YwIDWGDftL6wRSAKovHuf+m7B2ehtEN4KW7F4tVMWA4ovBlTsT0RRVZ87yFR+gfN18Igeeg7lDH2InXhO2oELNfYSZual0B/e3fr/tMF6NDH5zSuzMeBkp7iCOCiRBoVyNYL3SnVKisBolLhzYgTFdg9sPIkwSikZu7Hj7xRQ1EFDd/8eeDOoUw1e/5TN/40He/jWbCwd14KEpPXnhx92aCsehiFYoqsqlQ1NDfHdOX/RgqgWYu2o/HWNtp9wi/1TJTG09WM4176/lmT/14w+9kpt+gY5OSyEZYPAMWDtbUxq9Nv8Yb2buZo1jRCMMmg5SXT+5cd0TsZgkzS+5UHcNq+vV2xLRZsOckYF761ZsQ2RSQ5IAACAASURBVIe26Vh0AvgOHcKQmIhoatmspWi3ozj0YKpV8Xth8dOw7p3Az95a77enAioPQ+4q3JKdHxx/wukd1+QpXT4Zlw+ueHs1H98wgr99vlnT4FnxpFDw9TAqN8zDV1iCaBGxdLKQeH4iEd0DKnmqbAZEvCWj8JaMByWQja+es8pWfIqvIBvBHIHvSBZRoy5HkAzNFlQIN3OjaDSiVrg9SLYsRGMJiB5UxYziTQorUOskHOFq6QculxYjIyKiIqCiIGLEzy/KIN6Vz+XzTIXvtx5idEYCM8d1YcBRhbrkaAs+f/DYWqJfzOGV+dvnmzEbRM7qnczkPsmgBtSl/bIScqaxPmaDyPkD2repF+qpgh5MHSelDi+vLd7LZ0el0E8lvP6TXxZ9e34FV7+3jien9uXsNl406pymDJ8J6+cEBVM5d9VV7EyNFnE/rCGrK0ow4saghyVR4Pqx6cxatEfT+LCpXUOTJDB9VOeTIhtuHTwI54ZMPZg6SfDm5mHs1LIlfqD3TLU6nkqYexEc3gz+RoRlvA4sOPiHMIeehiwe818NGsa79XH5ZK56Z41mwAH1+oe69MUYuwNP3iLK1x7G3LETqj8Sf8Ug5MqeqBrLR3ufiUjWKAr+9zipd3zSIn1A4ZY920zHrlnuKeeLPV8gd5iNVfCCoAAqqELAc6rEQ/F35ag4EQgWrQigcof0BTcbFiChYhK0lV0ni+uYIG5ipdKbW513snCrj192FtAu2sKTF/ZlTNcEJvZM5MftR6hfcdjQ5pl9yAUUfRlav5jXrzC2awILNuUjCkIdqXJLMzbNjZJA53gbT07tG/ZrT0f0YOo4eXHRbs7r345up4AUen08fgXTCV6IVXmr+Drra349+Ctl7jIMooEkWxJTu05lTPsxSLUm312HK/nLe2t57ILeJ7VXl87vnLh0SB8H2UtB1jZibBDJDJ1HQ3xwDT3AtJGd+XDVfo5UuDWV/RpCUBXsqsq1Y1pWra252AYPoXz+/LYehs5RfHm5mFJbvjRHstvx5h1o8fPqELBh+PhSOLQp5HnGJni5VPoVB1b+5b8CaLpXye2TNecarf4huWoEhtgRRA4B936FGdJPjBO/5Va64W5g+RhuWZ5RFIi3myh3+TUNzMMpexYFGJ0REKZae2gtt/9yO4qqgMGtGWpG9FQQDAJK+d8xJt2P4gkui33c8AGXSkuxCo1XJkiCig0PY8StzDM9wWXeR3H5ILvIwXUfrOOJqX2ZOa4LS3cXamYFj7dfzK+o/LQjOFADcIdoqluNxSDSJdHOx9ePwGo6uTfcTxb0YOo42H2kkm9PISn0+nj8CmbjiQmmDlYd5M3f3mRhzkJEQcTlr2tWuvzgciwGC9N7T2d67+nkFXuZ/u4aHj63F+f1b39Cxqij0yAXvwtvj4eyvCbL/WoQjRDVDi59v8FDoixG5s0cxdTXllPh9iOHEFGJAtjNRl7c+R88/7cG9bHHEIxtW4ZhHTyIQ48+iqooCGLbZ8pOd7z7czF1boXMVITeM9VqrHkTDv1WJ5BKm1WJ0wf77rQTYQqEA+9kevlos48lVwfK7iIED9dI3/OTPJRf12wJSUJci8b6h0z4eN04i9HidmyCh4fVuTztn46LYMGdcMvyzEaJVX//Axv2l/LMdzvYmFcWcuamftmzqsLewiqu++gnMj3/xm9wNnptySaR/KckDn2SQ/sZzyFG3YXq7VQTqN1xZkculZZi05CMbwir4KMHebxofJ1bfXcC4PYp/OOrrVw2tCM+f6DsTivoqU+4gWko56zGbBCDKiIizBIWg8S1Y9O4bmwX3VomDPRgqpmoqsqT32zn1oldiT1FpNDr4/HLJyQztalwEzf9dBMuvwtZ1VZGc/qdOP1O3tr0Fgv2LuTwrhk8cM4Qpg5s2QZqHZ1mYYmC63+GDy6AkizwNf4ljdEWyGjN+Bos0Y0e2inexsI7xzH93TUcLHPh8sqaEr4QaGJOiDTz0XUj6GA+g4N330PejTfR4aVZSJFtlx03JiUhRUbizc7G3LVrm41DJ4A3Lw/roEEtfl5dgKKVUBRY+Qr4XEFPySq8tMbLg2c0LE9txseV/i9ZsHx1s81fG1q4Cyi8bpzFGHFrTXbmKsMviKg84Z+BFwMKx14TbllehFlCEASGpsXx3MX9Of/V5SFnbuqjAnsLqthbIINwI6L5MObEHzHY9zT4moRzEjBEGyj8Jh9P/l0IxkjMyd2IG3UJ9xpmBQVSjQW4M4eY+PcqDzuLFOzmRfiSDmEcNR1Lxz64fQpzV+fy9lVDuO/zzVS4fE1WIxyP0l9TGUoFFx0T/RQ4KpFVD4KhHCluM91SBbp3uQZJSgP0YCpU9GCqmSzeVcDBMhfTTyEp9Pp4/QrmVt552FWyixt+vCEoE9UQbtnNvvIsEtPe5o/9/tuqY9PRCQtbHFy/CDZ9CitmQVXh0aCq+htRCARREQkw5i4YeCUYQ1O3S4m28ONfx7FmXwlvL81mRVZRHXEYj19hRHocN47LYHRGPOJRo+2Or73KkWeeYf+0q0h9602M7dquHNY6eBDOzEw9mDoJ8OW2UmbKHoFcVdni5z3tyV4MXm3J+ftGm/jnCg+3DDMRY2nAOFVQseaHXhKmeY4GFu7TpEWMFrfXKXMLBBQL+Ob2LD4Qp7JIGUL5pp8o27aMlCufJWbsNEp/egOTqBKV3henGKlZlieJ1BHu6pYcyeQ+Kfyw7bBmQBXG3YAqobhTcR2YjilhEc4dX1H0QxHeAi+SRSJqSBTJlyQjRUjEjI4hZnQMqmLAW3wG3qLJTBFXI6E9Bq0A90CFwl3fu3nzPAuTMwwgSjy8O5I5tbKCFoPEhtxSrh6dxltLs3H7Gt44g/AD02oa9c7qnIal3f8wROyhUhQwJxz7varAtmJ4ZPkjiKLIrQNu5cpeV55yegBtgR5MNQOfrPDUNzt45LzeJ0Xzd3Np7Z4pn+Ljxp9u1AykSpeVNjixIciU+w7z9JqneXrs0602Ph2dsDFaYOg1MORqOLAOdnwDVYcDz9mToee5kDqC5sgnCYLAyC7xjOwST3GVhwOlLqo8fiLMBjrEWEnUME4UDAaSH3mEkvfeJ+fPV5L6xutYevU6zptsHrbBg3FtyCT2ssva5Po6AVRVxZuXh7Fj6/RM6T5TrcCG9+uq9tViaHuJCWkGnl/p4alJDW/OFLogwmZptuiD9sJd5WbD15plbrIKi9bt4bUzXqFUtXOX2J7FQhmTxbXEjIylNLIna1f/m53fKPhMkZpleUZR5Poz6vZ9Pn/pAI5UuNmUV4bruAKq6lswUfStg4p1JXScmYK9tx1fqY/8ufnkPJ9D+kPpiEc3rgTRjyluFd6iM7nJ8DV2QVsEpH6A6/Kp7CtT+exiKxf1qi65Vni61x4WdrmX6rO4fDJzlu9j6sD23DYpg9VZxSzbW9zg0MM114XGvbMienbHlvYSguRAEOUGQkVw+AOf8VmZs9hbtpd/jPqHHlA1gR5MNYMPV+0nNc7GxDb2dzleAmp+rRdMLc5djFtDkahoYRGFCwvpeH3HBic2r+Ll+5zvuX/Y/USbGy+T0tE54QgCpA4P/GsF4u1m4u2huc4LgkD8tddgbN+e3Guvo/1zz2If17RccktjHTyY4jnvnfDr6tTFX1CIGBGBZI9o8XPrZX6tRHnjoh5PTDQzZo6DO0c03FKQYpVxOV3NN3/VWLiPNOxhzZ4SVu9388+z6gZytQOKWEsVZ0hb2S/4eMs0K3DAIGCQFZdqYojnTZwEB4IWo0Sp04eqqjWLdaMkMve6ETzw+Wa+2XwIRVHxhdBL2lBZmykxnbLl84ifciu27psRDIcwJZpIvSWV3fftpnxlObHjYmudScUQuZ3e3v0NXqt+gJtVqqKo8KdedZfUMhIZwiG2qWk1j0miyI3ju5KRGMHcVdqGuvXvRYqIoXTp+8iN9ItV02Dvm+jElvYmgqESQQitucotu/lm3zdEmiK5e+jdIb3mdEUPpsKk5KgU+ryZI9t6KMeNp5WDqTlb59TscFQju2QKviygw3UdiOwf6PFoaGITEZm/Zz5X97261caoo/N7IeqPkzEkJXHgzjtIvO12Yi8/sRkic9euyGVl+IuKMCQknNBr6xzDl5eLqRVk0QFEe6QeTLUG/sYFDvomSZzX3cCzy730StT+zh6TKmEwiGGXhNWmvtDDTyaVqnYKD48LDuJCzZjJiJwtrudLJXhM5S4ff5mzlvgIEzdNyODPwzohigJGSeSFywZy+6RuvL8yh/+sz0MUBBxev6bIQmNlbarXfTS4OANPUTy2jp8AIFkkIvtHUrWtqk4wJUgeIi17EbyNBxy1A9wqr4pRBINYN3ujIhAlOKhdy2eQBIqqPBypcFPhDhYzauhePHnbNDNR9Wmo982c9D2CVBUUSDVaJQS4/W4+3fkp52WcR/fY7k1e/3Tl1K1RayNe/Gk3Fwxof0pKodenNU17D1QeYG/Z3qDHnXucKD6FqCF1/XhqT2zVuGU3n+z8pFXGp6Pze8Q2eBBpH31EyZw5FLzwAqrSSJlM+YGA3PvObyF7CZRp75KGiiCKWAcOwJmZeVzn0Tk+vPtzMXVq+RI/AMkegezQy/xaHGvT1RePT7AwO9PLwQrtRb7NLHHmuGGU/PQmzt2rUHxuVNmPK2s9pYvn1Dm2sYIte5+JtPvLLDrd/Tm/3t2T76bZGJ2qve/+xEQzr6z1UuhoeJ6x4KW9oF3KpgJOr0xeqYunvtnBtR+sw11LHj0tIYLHLuhD5iNn8fgFfTTbEqrL2uLOuhlbj9GIJkuNSXDsxGtrBRcG5Mp+7LpnDztu34HiUTBEG/BX+SlZWkL2/2XXnLOHcTtiE9mb2gGu3STgUwLy5PXv0KMGK636ZIW3lmbV8YIK5V7qc+CNa8l7ZRqK91gFkOfQHhRHGapS69yCF2P0RgSx7vWKFhZx+L+HSbkshd6v96bLI13wFnvJeT4HpZbSn0/xMXfb3Ebfj9MdPZjSwOtXKKryUFzlwScf+4PadbiS77Yc4s4/dGvD0bUcHr/capmp/Kp8jGLwJCJXyRjsBgQpeDqvnthqU+QqapXx6ej8XjF17kznzz7FuSGT/HvvRfHU2vVWZNj1Pcz5I7wyBOZdBfNvhHnT4dVh8O7ZsPO7wHHNwDZ4MK7MjS10JzrNwZuXizG1dTJTgs2G6najys37+9BpgPQJYGhcrKZrnMjlfYy8vNar+bwXI/ZhFxM76TrKV83jwCvTOPDG1VRmfoO127GSLwEY0jkGMYQWGAva16qmdkDREAZBwdpA71FtXD6Z1VnF3PDh+iCLCItRYs2+4jrrsWpql7U5ti/h0Ad3kfvvSzjw6nSO/OdRZEdpjbAGgCpbQYGiH4vwl/sx2IMDxb7kN2mBnDarku/3+Jmd6cVqCFhW3P6dmwnvH9tsMOKngJg6r1PVgCrriqzgALMxefoGURQqNyyo+dEQkwKCgHP3qmPjiN4YMCquRXWVUPur2hPZPxLBINRUCXmLvJSvLD92rCqzMGchVQ309enoZX41+GWFRTsKeHNpFpsPlNUIS/hkheHpcdw4rgvvLt/H7ZNOXSn0+rRmZqoh9T7JLuGv8qPKalBApTWx+ZS69dQ6OjpNY4iNpdN7c8h/4AFyr7mWjq+9ikEthQ+mgqvkWLN7/Z7GvDXwxQ0BOfcZCyAhPGU+66DBFLzwQgvdhU5z8OXmYp/YuIR0cxEEATEiAsXhQIqKavoFOqEx9BpY1vTn5h/jzczdrO1zV6baWaP2xN6nV6MS4hajRJcEO5sPlOOVg7Mvtft1Rpn9DEtReegMM2M7aS8XH59gYfBbVdwzSrvH06tKVKjH+vcak+x2+xXW55TyxpK93Dap7qb1viKHppR4deapcv0CzfI4d85vAWGNXSuJ6HUGqmIk/pwEir4rRBAFki9NrnM+k6LQXg5s6qo0nsVDEOgRL/B2po+0GIGPt/joFC3i9AXK/j7KsrEl65s6WSVZUUmKtGAQhaCgMRRfKREZMz6MyAhA5IiLqFjzOZGDpiBa7IhGM1JUcp3eN0PkGqq2FePY4SDl8hQgtCqh2uWPkiixPH85f0z7Y2PvyGmLHkwBC7cc4u9fbMGnKDg8gd2L2mZmq7NL2Jhbhl9WuHZsekOnOeUICFC0jjR6hDECQWMasnW1IRgEKjZUED38WGmD7Jap3FxJ8iV1JzazZNYDKR2dZiCazXR44QUK//1v8q+ZSurwbASfA9QmFLK8VQGZ5tkT4NofILlPyNe09u+HZ88eFJcL0Wo9vhvQaRbe3DxMqa1T5gfHRCj0YKoFsSdB1z/AroXUbrDJuatuO0FqtIj74eD33amaeEs+jyaW/liNElePTuPD1fs1A6n6/Tp/N88jed9CvtrpbTCYqp0x65cUvDnrxcgOtbPm+bVMhV0+mXeW7eOm8RkYapX11S+Jq6Za0r102UcknPtXTY8tKTKekkVvIZptqKqKZEtFNAQyQzGjYyhbWVbrjAJTqpxH/9c494028exyD24/9EuSmJgmMm+bj8R/VWI3idhSkrCOPJZlMogClwzpiMkgagq+NiRPn0AZPgxE4GaguJcbDN8SqTrphYoppRvmTv0oXzuf2HHTA9eJSsA+YHJN75toVbCmmUk6/5hoWlNVQq79dTfEZVmmxFXSxDty+nLaB1PvLt/Hv37Y2aSnQXVwdcvHmTx9YT8uHtLxRAyvVXH75FbzmUqLTsMjBzfVSjaJpAuTyP8oH9Ei1lHzM8YZiRldNyWeGtl6iwIdnd87giiSdMu1KK5Z4KlsenVQgxo4/v3z4NY1gcVeCIhWK+Zu3XBt2ULE8NZROtRpHG9eHsZWEqCAo31TVVUEF3HrHBcTHwr0LjZlCF4PvypQgY0v5DMaPc5qlLhqZCd6pNjRUnHQktT+lD/yY49fuKhn4xUsjWXMHFhYofRpVLK7vqmwT1H4eWcBk/sEsigHqw5S5S8ADUVAc4eeIErQSHlc9IhLEC1RlC5+F6WyiCP/qcLe10bl1koUV621n6oy1O0mpbqMVRCPvlfaAe7Q9hKT0g30ThR5apKFdzK97C1RWHJ1BJWqlaGef6FyrJJJEgWuGZNGlNWI1x+85tSSpxdQWGu+VbOHy3jUVDdm7DQOf3Q/UUMvqHmutslxRNdnEI0VdV4bbpWQgoJfqduGoXOM07pn6utNB0MKpGrj9ik89OUWlu4ubMWRtQ4+WeG7LYeY+upyejy8kNeWZDFr0W6GPrWIF37cxZGKpuuam8Lllfnqt4Pc+1k23qoumsckTkkk+eJkDs87zPabt5P1ZBbGOCPp96cjGo/9SdoMNq7pe81xj0lH57RmzVuIeOvshKbNqiTpX5U4ailWvZPprVPvDwSyVKteD+tytkGD9L6pNkIuKwNZRoqNbfrgZiJG6F5TrUJKX7hkDhhCz+iqgohitHOb6WlUU7AoligEgqguCRE8f+kAHjq3N28v24dDI8uj1a+TqyazWQn+Hs+5K5IzuxxbbFdnzJZcXVeO36mamO0/FxUxrH4ghyfgx7Ti4Aqu/f5apn45lSPKahCCAzbRHIGteyA4c+1d26DwRuSAs2l/3etIUUkkXXY1nW7rRNTAKAq/PbaWs6oq15XXCjqMNpAab+vQEuFwqSZu9P0VT61AyoKXyX1S6JJox2KU6Jak8fuqJU9fW0Rk4R4v9/8UvD6LEpyY8WJKTMPadRjlq/+rOUZVDg5Ca1cJ1aa6Siiid93fpUE0EGXWs9ENcdpmptw+mQc+36IZSDVW0xt4rcLd835j7UNnIoXSxdnGqKrK+ytzePGn3ciqWlPKCKCoUFTl4e1fs3nr12zGdk3gX5f0D9njBkBRVNbsK+GLzAP8sO0wA1JjuGhwB65NuJv7fr0Lpz94py1ufBxx4+OaPPfktMkhj0NHR6cesh/WvgUaWWJZhZfWeHnwjEY+67IX1r8LEx8EQ2i9otbBgyn74vPmjljnOAhkpVJbtTRatNtRHHojeqvQ4xz486cwb1ogI9JYlspkR7DGYPrLN/w3No1VWcW8tyKH7KIqXF6ZCLOBXu2iuG5sOgNSj1V87C/SDoQb6td53n8ZH4jPYRUaF6Ooj6KCBxP/kcc3ev6GyDyQzy2LnkE5KoIhxayE4jGax9r7TsK541fKV35GUQheTNV5hKQ/JZH1aBYJf0xAAi6prGK4u9ZcqSow8ErY0LB/Xn3ZegWBu3y3sFLpW3OMBQ99pTyev3RqzWM3T8jgoflbggLb+vL0gsnKK+38/GNccC7YgpcJ4kbW0J+YsdM49P6dRA37U9Bx/qqeiKbiOmp+4VYJyarMkKQhDb4PpzunbTD13ZZDmo+HUtMLgWDs192FJ71xr6qqPDR/K/M3HsTla1iBqbqMcdmeQqa8tIz/3Tya1Dhbo+feW1DF/I0H+HJjPpEWAxcN7sC9k8eTHGU5eu0OtLe3J6c8B78aXnrYarAyvfd0zFLoQZ2Ojk49dn8Psnb5TW3TzRhLI4tvVYEdC6DfJSFd0jZ4EIceeQRVURDE07r44YTjzc3F1EpKftXoxr2tTMZE+Ot2+O1jWPlKoNwWAp9DQQLFDwndYOxfoed5YDAhAKO7JjC6a9P+bh6N8jJouF9nrdqLJ/1X8bDhI2whBlSKCi4sTPM+SAX2Rs/fEH5FrgmkAERjBZJtH7KjK/WLqswdeiIYTUSNvDQkjy1BDPQDmZPNRA+PpvinYlLbGbi3pKzugVEdIG0sbJ7XaGD78HgrI96uYtrIZLYrEeQqgRJnCT9GZCaJG5kVvwCT4c6a15zTL4WHvtyqeb7aJXoA9xifZ6SobTlxpfQLXaRUPo2dREXPM6jc8DXGxM51jvGVjsIUtzLotYlTEpEiJA7PO4y3wItoFYkaHEXqjal1qoQA+sb3JTVKb7toiNM2mHpjSVbQjkA4Nb0Or8ybS7NO+mDq+R92NRlI1cYnqxRWebjsrVUsvPMMYmx1d6OLqzx8vSmf+RsPkl/u5sKB7Zk9Yyi92wenfwVBYPbZs7l4wcWUe8qR1dDGYJEsDE8Zzi0DbwnpeB0dnQbIXnxMua8eoZpu4q2CrF9CDqYMiYlIMTF49u7F0l03eTyR+HJbz7C3GvFoz5ROK2KNgVG3wshbAgqbZXngc4A5CpJ6Q1LPZp/aZBDxa5T5afXrVPOJfCZu1cTTxjlIKPxvq4t/r/Kws0gh0iwwMEWsUfxzqiY8mLjS+1CN8ERT59dErJtNL11WStHCf+At8CKYbdi6jSJ2/F8CCna1yuOqFewE0YA75zfcuZvrejQJKlLEfgRVxaKq9Dk3jjUrS+nsA5Fjm7dOLOR1u54e7QY2KNqjqFCpWtkWPZTYXsXMXrsVc2Ic7XDgxcj54iquM3xHL+kgpF5Y9/02SDw0pSdPfrOjyfXZ2/7zGG3aRgTBFQaioPIP40eMEHfy3NgJLNm2OKg11uC3o7o6I9iygvpmQ6kSshlsXNu3acPg05nTMpgqqvKwvzh4lyFcjf/1OaVHvZpaR8TheNl1uJJ3V+wLu5SxuvTv/xbu5LmL++P2yfy8o4D5Gw+wJruESb2SuPvsHozJiK+jtqNFgjWBeefN45rvr6HIVYRbbrwvy2qwMil1Ek+OfRJR0He1dXSOi6rGezufmGhmzBwHd45oooTPEbrf2/b8Cl4ZdBlrPtyFS8hGFCDSauTCge2ZPiqNDjG6yl9r4c3Nwzp4UKteQ4qwo1TqwdQJQRCg08jAvxaioc6EpgKSLyZeyzpvD9qvf5lfVm5m1rmRTMkQsUky3+yF/+6UyUiN5W3/eXwuj6OCiLDOX9eUVkE0H6seKlpYROHCQjpe3xFT+3E4946h+MfZHJn3CClX/RNBMmqWx2mV+qXe8jo90h+lm9PNjIpKhro9CA8FbwYbjUZu25xGx0MlvB7fC+uR4D5QUYBIwcWF0gos4ztx0VYnGUI+bxhn0U/MJkoIZMAUyYY4+rag1185ojN5JU7eX7m/0YBqndqDEjWKCKHufF5bDGOytJ7JievZ/WBX5sqT2KNspwoLUTjpY8hjzPjbuP+3Bxu0rWkIo2ikS0wXxnYIIQA+jTktg6lShxejQaD+5ky4Nb0GSaDC5Scx8uQMpt5dno1PI6UfSimjT1aZn3kQWQn4b/VuF8VFgzsy64pB2M3h/dmkRKTw+QWf83XW17y37T1K3CV4/B4UAmMziSYQYEDiAK7rex2j24/W5dB1dFqCJpqn69f7N0gI/VLrckp45Mut5BQ78AnJyH4BCEyyDq/MnOU5zFmRw/C0OJ75Uz86xTdeRqwTPt68XKKnTm36wONAL/M7ddlbUIm7AYlxx/YlOHb8iuJxUPjVs6CqiGYb5vY9awKS/W47K5buImXKPazoamGfUIxV9VCZYaOgSyqjPL1pTDI01IAHwYcpfjlwzFy2w3UdiOwfCWxEMKiItrs5+MbNOLYtxt7/bCC4PK4+RnxcL/zI3w4eaPyNMtowzviCb9sN5cNVOfz95z/wlLoDey3z4epA5pMtvqNZus3EWwU6mCpRD24iqpaUfIEQT0qHQL9RlcfPlxsPMHvZPvLLXPhkteYdkzS8p46+ITyszuRN9V9N9q51Fw/ypPh+nXth+Ezocyb/ijJy79J7m9zUrsYkmkiJSOGts95CCnFdfLpyWgZTASO24A98uDW9AoGepJORKo+fBZvyqW8lEZY8qaxQ7vLx3R1n0P44d5NtRhuX97ycy3pcRmZBJusPr6fEXYJJMpFgTeDMzmfSwd7huK6ho6NTj+gOgT6LRkpsmzLdVAURIarxz+bXm/K573+bamXBg+dXrxx4bmVWEee9soyPrx9Jv47RQcfpNB/f/lxMnVu/Z8p/WLvnWOfkZs7yHLRWLA1tsHryttXJGFVX7xi7n8F3SvMW100FPACC5ESy7gO0zWVNMb9h5gGAGgAAIABJREFUsOUS0asTrpwN2AeMA9UISICfQE9V3c0hAYUYqrjB8F0jFxYDaopXfAypwzEB149N54GcsRTumYsZL0bh2Ab1v1d5eHa5lzfPszA5w4BJgu/3+vlqp7/Gl8ulmnjcfTmPVbh5Y0kWn63LRRSEOr5Z1b+T6lnTIApYjRIIAYEvk0FkyJg/4benwE/3Q6jZJaMt0Fd35mMAjE8dz+tnvs4dv9yBoiqa4mAAoiBilsz0ie/DK5NewW6yh3a905jTMpiKsRrxyaFp/DeGV1aIsp6cbhtLdxUiiSJQ9z7DKWVUgcPlnuMOpGojCAJDkocwJFlXhdHRaXX6XRJQ8/M1/OXblOmm6ofyvQZsBw5g6hjsr/fr7sJ6gVTjKCpUuP1c+c5qvr5tLGkJEU2/SKdJFKcTuaICQ3Jy0wcfB3rP1KmJ0+tn/saDx7XBGm71TrMQvJgSf6yxcmjIXFY0lWBOOYwrJwdL+y9QvLGoiglB9CE705EdPY4dK4BJ9fNJ5CvECSrUT+4YrIG+qB7nBPy+ErpBznJY8RLq3p95UgFVEDCgoKiB85W7Vf6x2MN7U61c1OvYOvD8HkbO7xH42amaeNn/J76XB7PptRWUOL2NzpP+o1kpUYSkKDMPn9ebdtEWuiVFHlWO7gaxKfD5dYHxNtAPW3M/I2+BSQ9T2xdjWMowFl+2mB/3/8i7W94lvyofg3gsFPApPiamTuQvff5C34S+WmfX0eC0DKYSI80kRZrJK627wAivphf6to/G0kqmt8dLscODrBEwhjsZFjuCGx51dHROEVL6QWwXKNjW6GGNmW4S1QFXiYmCSy/D1KUL0eefR+TkyRhiY3H7ZG75OLNZFhMOj5/bP93I17frtfgtgTfvAMYOHVpdQVGy6z5TpyK7Dldi0GiYCmeDNdzqnXARRC+G6HWYYo71JzVpLhspYYzaXPOYqkh4ZBuyowfCUZ+tDjFWpo3szA3Lnubbc9xEbPsEKvNB9lFlSiQv5Swqu5yLLSqWlKKtJHx0EbhKwOtEQMVY67KBiiRYdUDG7Yc/9QpeRntUAwoCT/un8bH8BxAd5JfLQGjvmdevklfq4o3FWXx8w4i6Fjzdz4b7sgIKqytmQeGuQDm3IASUHk2RAfGSwTPApi0sYTFYuCDjAi7IuIB95fsocBbgkT1EmiLpEt2FaLNeMRAup2UwJQgCN43P4Olvd+D0Na3xr1XTG2GSuGlCxokcdlj4ZRWt/Y+w5Unrb2Pp6OicWoz9K3x9Rx1p39qNy3DMdDMIow3xzAdoN+gqUv7+d6qWr6Dim68peP4FbCNGsGzYFFQ1ePEeSl+mosKeI5XsOVJJt+RgA0ud8PDltb6SH+g9U6cqFW6/ZjtTOBus4VbvGCUBm0nC41NwNyDJDoE4wGKQiEjYgjv66zrP1TaXjR5+bJFfbS6bfEndTKyAiEm0IhhEJvZIZOa4LgzuFIsgCOQUObgj08nsGR+z83Al7yzL5tsthzDmigjrdqHKPrw+P0OF6dxo+IYzxC2agh2CAMVOlQSbgCwYkFUBGREZCQGVT+WJfKiewZGYbMyGBXgKpqAVSDW24eT1K2w5WM4XmQe4fFi9z7XBFKg66HcJVBUEBIIUf0AFMqpjILUVIunR6aRHp4d8vI42p2UwBXDhoA48uSA0jX8tJFHgrN6tW05xPERbjRhEISibHe5kGGU5bf9EdHR+H/S9OOCTkrMM/KE1HgNgsEDHYdD/CgAEk4nISROJnDQRuaqKyh9/YvbKMhyGuruYYfVlKgrvLt/Hsxf3b/796QABJT9jp9b3gREj9GDqVMTYgIxfOBus4VbvZCTa+WzmSD5Zk8u7y/fh9ss4PXJNj5DFIKIC47onctP4Lry39xuW1tOGCNdc1mI08Lcp4zm7y1iiLHXbMB6c0otL31rJuH8tpqjKg09WkRW1lveWABhZqfZlky+DaBx8aHqWrmJ+0HsRbxMocqrICnyvjmCD2oMjaiy/Kv2RbfuxdpyDSVBx5twKarCATygbTi6fzBtLsrhsaCNG3PakwD+dNuW0XCkrbjeV/3qeG7KKmN31LNyh2R/VYDGKPP2nvhibkAVvS4anx2mqwoQzGRol4aT30dLR0WkCUYTL58JHF0N+ZqP9UzUYrZDSH/78KUjBXxOS3U7lhMnkZy4Ff/P7MmUFFmzK14OpFsCbux9zRtdWv44UaUd26MHUqUZipFmz0iTsDdYwqndunpBBjM3ELRO7cuP4DH7dXcimvDKKHV4izBLtoq2c278dCfaA+I3PdBXrDq8LEkYIx1w2wmjjol7jNa1VHB4/hZUe8sua3lRyYGXXG7fS0+9mxe2dGGHJBeCdTC8fbfbx1RU2zAb4bpebKb3W85DnOlxYkCJ2Y+04F9UXjbvwLBRP8KZ7OBtOBZUefssrY1Cn2CbHrNN2nHbBlHvHDg7edx+W7t256+3H8a3M54MmNP5rY1b83Giv4vwBJ7fyXGqcjYGpMazZVxL0XKiToSgIzBiVdoJGrKOj02oYrTBjASx6FNbPAYSACWjQcbZAQ8Dgv8DZT4HUsMBOYZUHk0GstasbINy+TJdPxicrJ/Xm1KmALzePyEmTWv06ot4zdUrSNclOgt103L3iEKIinyDwx74pNT9LYmBztrEN2hEpI7Cb7Joqc6GYy1okCzN6z9AMpLx+hWnvrKGwMrw+cFVRmbqqDxsmOOggFNc8Hm0ReGKCmVu/c/OSYGRK5yV8bhiI9/BnFH8fT8y4W0GV0KqtDGfDye2T+XLjQT2YOsk5bYIpVVEoef8DimfPJvmBvxF1wQUIgsAD50TTPtrKMwt3BMlV1sZmCiwMHp+UzqCn7qQsw0zMJZecyFsIi0PlgQlTAE0p1FAmw0GdYkiN071gdHR+F0gGmPw0THwQtn4OK1+BsrxA6Z/BEpBRH3Ub9LsUzE1L4Wp52EH4fZmSIOjBVAvgzcvDmHoCyvz0nqlTkppe8e92BK1zQvZ/ChGLUWTW5QMxG8ITqRAEgZn9ZvL8+udD9kKqjSiIXNTtIs3nvttyKOCBp5Gda6h3CSBqxEUUrvmcZ4dczStRH9Z53T2jzaTYRZ5fVsX2+a/hNdswJfcgauTko1LtAaq2LKJi7Xz8ZYcRzFaM8Z0QrZEhzY+KCofKw38vdE4sp0Uw5TtyhPwHHkB1e0j773+C5H1njE7joiEd+XLjAd5cmk1hpefoF7uKV1ZJjbVy0/gMzh/QHotRwtP5bfbPmIEUF3dCdgLDwe2Tmf1rNu+u2Mf0EZ1w+2R2HKqs8XgJFatR4uFze7fSKHV0dNoMU0RA6WnwjOM6TZTViKLhsxdu2ZCiqgFPFZ1mo/p8+A8fxtSh9SsmRJsNxelEVZRWVw7UaVkuHNSBp77doflcKBusoWAxijx+QV/ObGZPebUX5S+5v4QVUFkkC6/+4VViLDGaz7+xNEtzs7yx3iUAU0o3zJ368cnqQzxzlpX6uurT+huZ1t/IL/IArvP/FbVef1TF2i8oX/MFCef+FUvnAchVxRR8+X8oznIUnxvRaGny3rSsfHROLn73wVTFDz9y+IkniJ12JQkzZyIYtG/ZbjZw1cg0po3ozJEKD2UuLwICsTYjSVF1/9jN6emkvv46eTNvRHrtVWyDB5+IW2kUVVX5YdsRnvp2O33bR/P1bWNJjbNxw3gfU19dwcEyF95G1HRqYzGKvPLnQfTtoMtj6ujoaNMlUdsfKtyyod7tohpurtYJCV9+PoakJARTcKN7SyNIEqLFguJ0Idl1j7BTiQizgVlXDOTOzzaG7AtXjSiAURLxyQGvpdoIgNUkEWMz8uxF/RnXPbHZYxQEgafHPs1jKx/jx/0/4mrCoFYSJEySiVkTZzEsZZjmMdvyy9lfHFya2lTvkmPncgBixk7jyEf388GIiVj4Oeg8btXI7b47UKn7+VM8Tkp/nYvBHk/hl/93LOs16jKK5j9D6S/vEj/51ibfk+qeMp2Tl1MimJIVlSqPH7NBDNnXSXE4OPzMMzjXrSf1tVexDhwY0usEQSAl2kJKdOO7BdZ+/Wj/z39y4PY76Pz+e5i7dQvp/K3B7iOVPP71NgorPTx3cX/GdE2oeS7KYmTBbWOY+eEGfssrw+OXgybCaiJMEqIoMHvGUEZ2iT9Bo9fR0TkVMRskrhjWiQ9X5QSVzoTXpN76ogm/SxQZSrLBVYp/y2asXWID/W4nIDAV7XYUR5UeTJ2CTO6TwuMX9OXRBVtDCqgkQcBuMfDpDSPxKwqzl2WzaHsBbp8MApgkkTEZCcwc34UR6XEtsjFiEA08OeZJJqRO4J0t77C3bC9+xY+sHsssWQ1WVFVlSpcpXNf3OjpFNWwLsHJvsaYgV6i9S6bENCxdh/HqinLuTQl+/mt5lGY7Renid0H2ETPhaqxdhtZkvTx52zCldKVqyyKs6YMb3XCKMEmM79F0cOqTFQyioG9MtREnbTDl9Pr5cuNB3vo1m9wSJwZRQFZULEaJS4Z05Jox6aQnaE/krk2bOHjf/diGDSX9iy9abcK3nzGW5Af+Ru7MG0n75GOM7dq1ynUaotzp48VFu1mwKZ87JnXlqpGdMWj0HURajHw6cyRbDpQze1k2P2w7jMlw7DiHx09SpIX7Jvfg3P7tTlojYh0dnZOLq0en8dHq/Wh1ZoZSNiSKAmf3OXktJk5Kqgoh8wNY/Tr43CBKWHxe2rX3w0v9Ycxd0P8yMLeed5dot6NUVkKy/rs7Fbl8WCodYqw88tVWDpe7NTdZTZKIIASUgf/von50jA30T7/y50AljqyoqKqqueZoCQRB4MzOZ3Jm5zPJKsti/p75HKg6gMvvIsYcw+CkwZyXcR4RxqbXdyUOr2avVDhiOTFjp7H3/Ts4aAs+z5vy+Tix1nlM8Tio2voLgtlWp9y5OutVuuR9VJ+3aSEwUWByn+AIzu2T+W7LId5cmkV2oQNZVRGAGKuRK4Z3YvqozrSLtga9Tqd1OOmCKUVRmbVoN7OX7UMQqKlxrf4gOL0yn6zJZd66PPp3jObVKweTfLQMT/X7KXr7bUo//oSUf/yDqMlnt/p4o88/H39RMbnX30Daxx8hxWjX67YksqIyb10e//5pN2f3Seanv44jPoQ0cL+O0bz850GUO33sLaykwuXHbBRZtruQCrefi4d0bPIcOjo6OtWkxtk4t387vttyKOyyIatR4v7JPXThiVBRVVjyf7B8ViD7VMszTIRArVVZLvz4MPz4EJz7bxh4ZasMRRehOPUZ2y2BxfdOYFNeGW//ms3KrCKcXhlJFIiyGLhocMdGF+SSKKDpAtwKZMRkcO+we5v9ekMLeGwZY9vTvtcgXl67mn5Jx+asbCWFg2pC0PGegztBkVE9fs3zy44SjElpJF5wf6PXndw7uc4cqaoqry3eyxtLsgBw1OoDU4ESp493lu/j3eX7GJ0RzwuXDSQuovXLf093TqpgSlZUbv04k6W7CxuVKvcrKn5FJTO3jHNmLeN/N4+io6eM/Pv/hmA2kf75/zCmaORiW4n4a67GX1RI3k030+m9OYjW1tsNWJdTwmMLtmEzSbx/zbBm9TVF24wM6XxMYjTKYuSOzza25DB1dHROE567uD8HS11sOlAWckBlNUr8eXgq03XrhdBQVfjqNtj2BchNSDv7jspKf3s3VBXA2LtafDiSPQJZl0f/XTAgNYbXprV933drEm83YdawcQhXLGf42DF8s3VVnceOqLEY8eOm7oa27KpAtEai+tw4d60kotcZNc8pXheu7A3EjGtYBEgAIswSv+wq4NmFO7nrzG4YJZE7P9vIzzsKGl0jV/fHL99bxDkv/crnN4+uySzqtA4n1ZbgQ/O3sHR3438ktZEVlVKXl0tfXsLGK68h8swz6fTuuyc0kKom6Z57MHXuxMG/3o3q99d5TlFU1u4r4YvMA3y0ej9f/XaQnYcrwjr/oXIXd3y6kTs+3ciN4zP4z42jWkwgone7KIqrvByp0OU3dXR0wsMoicy9bgRn9UrGYhSP7lhrU106NCI9jkfO09VCQ2bpc4FAyhfsv9MgPhcseRa2fN7iwxEj9MyUzqlDQ8qCtcVynLtXofjcqLIfV9Z6ShfPoePNc7CmBfrtI3Dxl9jNuB+OYsnVEciqgKqCByNaGTrJGoXiqiRq1OWULHoLV/YGVNmPv/wIhV89iyEyAXsfbTVoUQC7xcD8W8bww13j2V/s4NyXl3HLxxv4eceRkNfIPlmlsNLD5W+totzpC+3N0mkWJ01malNeGV/9lo9LY2ezIQ8AS8c+qCqUeRQ+m/4wL147vg1GHkAQRdo99RR5t97KoUcfpd1TT1Hh8vOf9XnMXpaNwxMIsGRVRRIFFAU6xdm4aUIXzunbcJ9SHanzkZ159uJ+2Ewt+2sTRYER6XGszi5m6sCT24xYR0fn5MNkEHnlysFsz6/g3eXZfLP5ECZJrNNJJQhw1YjOTOqZxI0fbSCrsIquSa3X1/O7wVkCy1+sU9YHkDarEqcP9t1pJ8IUWMy9k+nlo80+llx9tI/E74Lv7oXeUwM+Yy1EtQCFjs6pQMdYG4M7xbIquzjouVDFcgQUInAxzfN31qs98GIEVKx48QmmoLbR6qyXMbY9MeNmULr43aM+U2YiByTT7opIpMhHAwcrZnyVffCVjMWqdiDGZuKTG0bQOT7wOX7jqiG8vngv//xhl+b9NbZGVlQoqPTw/I+7ePLCvhRWesgrdeL0yESYJTrF2UJqE9FpnJMmmJq9LBuPPzwPAEvHPgDIosTCXCdPuH1EWoxB5zhRCEYjHWfNYv/V1/Dtc29zv7MzqqpqBogAu45U8vD8rTy7cCef3jCSLonHjDJrS533aR9VI3XeWozKiNeDKR0dneOid/soXrhsII9e0IeNuWWUu3wYRIFYm4khnWNrhG/uPqs79/xnE5/fPLrVGth/N2R+SEO9KbIKL63x8uAZjSyGZC/s+QF6nnvcQ9l8oIwPV+1nhzAQz2YDMQUr6dshmhmjOtf5/tLROdm4cXwXfjtQhkvDa6opsRwJHwoSt/vuwIGZY0VdAk4sWvo7dbJe8ZNvpcPMu7B0XIBzdy6OXWUYalvuSC6MUZmYozfRIaITz098piaQqmZbfgWiQJBQSChrZJ8c6LPfW1hF5v7SOgJkHr/CmIx4Zo7LYGSXllFjPB05KYKpUoeXn7YfCfojacoDoDaiIPBF5kH+MjrtBIy4YUSbjd33PcM9/92CR2w6Fevwyjh9MlNfW8H8W0bTNSmyUanz1mJkl3g+WJnT6tfR0dH5/RNlMTK+Ea+ZaSM68cO2w7y5NIvbJrWdrcRJj6LAqlcDGSYN7htt4p8rPNwyzESMpYFFkLcqIFrRzGBKVVUWbMrn5Z/3kF9WrfwWAR5gfymb8sr4dG0uvdtHcfdZ3TmjW/M9hnR0Wovx3ROZ0D2RxTsLcIfouQkgIKMg4ST8jfrqrFfFmvcp+vYQkkXEkmYh6fyk4OuICioKB5xZXPPD1cyaOIvR7QPr3lKHl0U7jm+N7JUVVmUFMnP1e8eW7Cpkzb4S2kdbmHv9CF0FsBmcFFuCS3cXaqqthOoBAAGVv/9tONAawwuLnYcruGvBHjxi6B88VYUqj5/L3lrNg19s4Yq3V3Nmr2S+veOMExJIAfRIjqTC7edQeeMGeTo6OjrHiyAIPHdxf95bkcO2/PK2Hs7JS+k+8DYs9DC0vcSENAPPr2xClOLgukBgFiayovK3zzfzwOdbyCp04PIFS2j7FBWPX2FjbhkzP1zPG0v2hn0dHZ3WRhAEXrpiEEPT4rAaQ1v6CsioiKjHsVSOHtaJrk8l0eft3vR8uSdpd6dh69Z4lZHL7+LOX+5kW9E2AH7ZWaDZixrOGrkxVAJr6H3FTqa8tIzc4jB6M3WAkyQzVdwCHgAApU5vSw8tbF78aTdujXJFaLr3q8ThZfuhipClzluS6r6pVVnFXDRYl0jX0dFpXdrHWPn7lF7c859NLLhtbJ3SE52juMtAbPxr+omJZsbMcXDniEbkj0UjeCrAGrp1h6qqPDR/C19v0u5l1sLlU3j5572YDRLXjk0P+Vo6OicCk0Hkg2uH88x3O/h4zX4EhAbFHEQUFES0SmwbW8vVRcbS4WMEMVj8oXRZKUU/FOEt8CJZJKKGRJF8STJSRGC965bd3L3kbr6/+HuKHR58cvBnMNw1clPIikq5y8flb6/ih7+OI6oN22ZONU6KYEpVVVSNotNwPAAAFFXLg/rEUVTlYcmuQrSGEUpdK0B+mYtYW9t4AozKiNeDKR0dnRPGxYM78P3Ww7z88x7undyjrYdz8iE0HWD2TZI4r7uBZ5d76ZXYwPGqCmEuuBZuPcxXDQRSjS0mXT6Zf/6wk+HpcS2mOKuj01KIAkzpl8LBUicb88rwygqyoiIIgBp4PjbCRJnTq5nMDXUtByBFZCEI/qBzFC0sonBhIR2v74i9tx1fqY/8ufnkPJ9D+kPpiEc3lko9pfw/e/cdJmV1Nn78ez/P1O27sEtdegfp0ixgN2rUmGA3ghprbNHExNcaY15fY37mjQZ7SfSNBWMDu0awoBJRQBYUpMPSWbbvTju/P2YWZ3dmd2cWdncW7s91zcXMPOWcZ9g585x6L9q2iEAoJ+79bbL3yHWaW7CipNLH/32+niumDUj4nAe7lGgKzPY64wZujI4BkIj2rkU/v3AD8ebu1Y1rzTvuCtIGT8FyeRDbQdqAieQedVG9fStrAyxYHbviTFuY3K8Tn69tn7SVUgcfEeGPZ4zg+f9sYPHGPe2dndST1jm8gEQz7pzm4bGvfGwua6xB0YAruQUiHvhgVdzJ+mULX2H3B4+RPelMev7yWXpc8RSZY0+ietUXe/fxBUI89tGapNJTqjUZY3jpy41M/dM8LnhiIe8s38a2slqCkXGrArgs6GJ2UlFZSbxpVcncy4VPWovY9b+/weog21/dTvfzu5M5MhNxCK58F4VXFuLb6aN0wQ/DnmsCNTxV9BQ5XhcuO7aylOw9MiT2/a0JhHj8k7WEGo7pVY1KicrU5P6d9v5BR2suBkA0t8PihOFtH18q2oLVu+IGrUxmXGu1P8iSTe1zUzGgIINqX5BNJTpeVinVNgoyPdxx6nBueHExNQnGTzloZPcMP5oxIM/irOFO/rowXsVLYMBxxG3pa8SKLWWsizNvItGbyZCBt4u2amwblRKCIcONs5dw62tFbNhdRZUvGDOCyGNquFaeo4QMakxsw3zl8nls+fv1GH8Nu99/hG0v3k7NpqKm063qH/Ne1aoqQv4QWeOy6r1ve2wyR2ZSUfRDyAGD4dPNn3Jo3+y4PVPJ3CNDcpXBGl+Q+St3NHl96gcpUZnqmZvGuN65cbdlTTiD3KMvpvSzF9j0wHlsemgG5V/NxTuwfsXEAOdN6tUGuW1caXX8H45kxrWGDOyqaJ+5XyLCxH6d9q74opRSbeGUkd0Z0i2L+xqJo3LQEoHDrgNnerO73jbVTaUvTkuyMw0OuyapZP/vi/VxQ5Uk0zBoiTBnaXFS6Sq1vxlj+N3LS3nzm62Nzo9y4+Ml1x18EBpLNZ6Y7XW9OZ5+45G0bHrG6c2JK+SJqbQFK4I4MhyIHdu44ch2EKioPyzQtmw6ZQUZ1i0rZn9I/B4Zkvv+VvqCfKSVqYSlxJwpgMum9mfxxj1UtSAGgAgcMaAzBZmxX4K25G5kAnWy41q9rv0zmbAlJvfrxGdrdjF9fGG75UEpdfC567QRnPiXjzh+eFcm9M1r7+ykjhE/hbd+E/P2uuvqBzwuzLaouSXODVd6J+iV3Gpfa3ZUxqzaB8k1DFb7g6zXVcFUO3t72VbmLN0StyJVN3fI2rWWaW6ozH+WzCln15v7FL38uOXyUvHVXBCJu/x4LIk8fvgy2Rk2gYoAJmhiKlSB0gCOjPq35RYW/qCfy6f15/oXFrfoHrnuWks+egZMiM2zZjSxaMYPdlQ0s0qo2isleqYgXBka1TOn0QpJU9KcNr87aWgr5Co5PXLjr82fzLhWr9Oia1b7RaOe3L8Tn6/ehWnnxTyUUgeXvHQXd//kEG6cvYTK2thJ2wctVxqc9jdwtCD2i8MLP30iqSF+QNwbNqjfMJiI8hod5qfa1wP//r7JuX/9Jh/Pphuy+d1V55A99kd7e5s2PXQRGx84j+r1S/f25vh2bwZMEnOUQlRvOp+q9RdTvelcancdjrdvHuIQyhaV1dszWBOkfGk56cPq90L7Q36y3FkcM6SAgQUZOOP0aDWn7lrTh00DhO6XPZ5Qz1pL0jpYpUxlyrKExy8cT9/O6UlVqLxOmydmHMqAgvaPvn7Oob1Ii9OrlMy41qCBkw7p1lZZjtG3Uxq1zlXc9el93PLJLdy+4HZmLZ7FypKV7ZYnpdTB4bhhXTi0Tx73vPVte2cltYw4A469I7kKldML05+GwglJJ5fliT9oJdkJ73np7bMyrVIQjvu5ZmdFzPvRc4duGLaLTBdstbrgGjCl/tyhUIiq5fP39sZaTjd2VpeE5yiBRbBiOMGqgQTKR+LbcTzVm28j96jjKX5mK+VLyzEBg2+Hj42zNuLMc5IzpX74gp6ZPfE6vDhsi39cPJEeOd6kwkhEX2v2xDMQp4vq1f9pfNGMupwLdM3S4L2JSplhfgDpbgevXnUYVz/3NfOXbyGEEIizxj9AusvG67J5euaElFl+dXL/TmR7nXFb9eoiYZd+9gI7596HuLy4uwwga/JZe/exBI4eXNDmMaYAqvxVvPL9Kzy97GmC+SXMXv1D966FxVPLnqJ3Vm8uPuRiju99PPZ+imuglFLRbvvxMH70l484fngXjhiY397ZSR2TLg8vRjH3OvBXgy/2JhEkPEcqowDOeAwKD21RUqN75fDF2t3UNljSLLphUCwbT98xiOWgZt1iajYsrXdzlu62GdrIPA+l2sJzCzfij7MsX931niulAAAgAElEQVTcoexBh3K2/SQuCVBFbMU/a+IZlC54ARPw7e2NdWR1JmPUCU3ey/2gwf2rCaeROfpKsPqz5bln8e/cgOW1yBqbReFlhVhRAYXTHGlcPOLiva+zvU7mXH04l/5jEYs37sEXCBFsZhRR9DwpseyEv79uh80po9qvYb+jSanKFIDHaTPrmG58/ORtfHDF73l52U4sKzyZ1ZjwkqujCrO5fGp/pg0uiBsVur2ICJcd2Y//efvbuLE5mhvX6nbYXDq1X2tmMa5tlduY+c5MdlTtoCZYE7M9RIiaYA3flXzH7Qtu5+VVL/O/R/0vac6mo3grpVSysr1O7vnpSG56aSlva+DI+oaeAoNPgtUfEJh7B/aeIsS2w3GkxIJBx8OUa8O9UUkO7Yt23sTePDI//tLmiTQMAgjS7ivsqoPbup2VBJuY+9fN3oNE5jN1lrKY/VxdB+LudQg1a7+q1xubyBylptlkjjyJzFHH4u3xTxyZ8XviDYYT+55Y771Mj5PnLp3Ess2lPPHJWt78Zgsu26I2GMIXp+LYcJ5jot/fwjwvw7unRkdFR5BylSmAnY8+xohTj+Pos8Zz6xnhpbpLq/14nDYFmR7yM9tvTlFzzp/Um3eXb2PR+pKYVr2meJ02M6b0Zmyv+KsatpZd1bs4541z2F2zm6Bpfhx8daCar7Z9xSXvXsLfT/w7TltvdJRS+9eRg/I5akgBd81Zzp+mj2rv7KQWy4KBx7HH8z2hjAoKrr0yXJFy7r8hOV2yPEzu14l5jazm1dzNpMsWzpvYK6nhSErtb1W++HMv6+b+ZYQqCEVmuxxqfceLwWlUUv97lDv1QrasW8zud2eRNuRwTCiECQbi9ubE01SAXIyL6s3nkt7/z1jO0nrHeR1erh59Nd5GhvaO6JHN/WeN5s7ThrN2RyWPzF/Nm8u2Nnqt0QugNff99TptrpgWu6y7alzKlXT+zZspf/tt8mbOAMI9VQMKMhnXO4/h3bNTuiIF4LAtHr9wPCMyDO5gYpNvvU6b6eN78psTh7Ry7uozxnDlB1dSUlMSU5Eq+biEVbesoujSIr695luK/15MsDK8jy/kY1XJKu7+4u42za9S6uBx80lD+XztLj5Ysa29s5KSqouK8AwfDq70/VqRqnPdcYPwOFt2i+B0WMw4rM/+zZBSSSit9uNsZORS3dy/jSuX7+2ZOs5ahE1sA7grvw9pg6fg6jaQ6tVf4ite0eTy49ESCZCLEXwlk+od53V4Oa3/aVww/IJmrzPL42RUYQ5Du2XFnRST7DxHAcb2yuHUUT0S2l+FpVxlaucjj5Jz9tk4ctu2h2Z/cmzbwh9e/yMXjcwj0+Mg3R07v0iANJdNt2wPd502nN+fNgLZh2EZLbFs5zLWlq4lYOq33ux8aydbZ2+l65ldGTZrGP1u7Ydvl491960jFOltqwnWMGf1HEprS+OdWiml9km628GffjaKm1/5hpLK9om9l8pqipaHK1OtZHRhDneeOiLpCpXXafHEhYfSLVsnr6u2FQoZPlm1k/Mf/4Lxf3iPL9eXxN2vbu7ft+8+z1vfVlHlNxAKMH7dk5R9+HjM/jmHn0ftpuVkHHIs7p7DKfzlsxRMvwNPz8ZXkU48QK4Tf8kkjLFxWS5clouLD7mYmyfenNS19+6cvs8LoEF40ZjHLhyfUlNoOoKUGubn37yZ8nffpd9bb7Z3Vlos5POx+fpf0eWyX3DTBUdyfSDEu8u38uQna9lUUk1tIESay2Zo1ywuObIvk/t1avNKVJ2ni56mNlA/jkCwOsj2V7fT4+IeZI4MxzFx5bsovLKQlb9eSemCUnKPDFd0LbF4ZdUrzBgxo62zrpQ6CEzq14lTRnbntteLeOCcMe2dnZQR2L2bUEUFzl6tG6j+rEMLcdnC7175hkDQEIgXfCrC7bBw2MKTFx7KxH6dWjVfSjX07dYyLnrqP+yp9je6tP8PQmRN/BFWeg63fPwoM14pJ9MlHNJtPnkTf0egQQ+VM7c76UOOoHzRHJz5vRPKTzIBcgULV80YZk4cx5mDz6QgrSChNKIdP6wLv/tX/G2JzpOyLeHpmRNIc6VU1aBDSKlPbOfDj5Bz9lkduldq+71/wtG1C7k//zkALofFKSO7c8rI7u2cs/rKfeXM2ziPUINCo2pVFSF/iKxx9Vdhsj02mSMzqSiq2FuZqgnW8MyKZ7QypZRqNb8+YTAn/fVj3li6hZNH6upSEOmVGjasTRrifjK2J2N65fL4R9/z0mdrsD0eqgMhQia8Aq3XZeOwLH4+uTcXTO5NQaan1fOkVLRF60u44IkvEqhERUgAy72TLj/5kPEnDGDWzo2kE56WsSz0L870DYs5JPuwc6go+rDZU9fNkfLtWA8Ytr/0+2aD41q4+cXQm7lizMDE8h+Hx2lz1qGFPPP5evxxVt1IZNGMwV0yOaSnLjrREilTmfJt2kz5e+/R/+232jsrLVb2zrtUzJtH33+91G69TYnaXLEZl+3CF6o/fCZYEcSR4YiJzA3gyHZQvb663ns7qnYQMiEsSbkRo0qpA4DHafPn6aP4xT8WMaFvXsrPm20LNUVFeIbF3vC1lj6d07kpawcXbHudFdf/gY27qyivDZDtdTKgIIOjhxTgtPU3QLW9dTsrufDJhYlXpACMi1BtV2o2n8eink+xa4/BEwAbGGGt42XXHZx7xV+oxUll5BBHVj69b3ylydOWLXyF0i9eotPxV4EIO175IxmjT6R61RdNVqaCIdhTue+Bymce1je8HHwwic8iwuu0ufbYllfmDnYpU5na9cjD5JxzNnZOTvM7pyDfhg1svfNOCh95GDs79Wv2lf7KuO/bGTaBigAmaGIqVIHSAI6M+n8ytmVT5a8iw9X+QZOVUgemMb1yOXN8T25+5RsevWAcGKhdvYeqr7YTKKuFoMFKd+Idkod3VD5WnLkDB5KaoiIyTzihTdMsnTOXglN+xOAxOjFdpY6731zR6Kp9Ta+k5yRY1Zdg9QCu6hLgueKtZERiNg2xNrLAfTVzgpN5OPhjik0nbEIEsQhi4cNJwxhSdXOkOp10HWmDpxCqrUScLggFm13xD+DJT9cyqX8njhqc/BC/OoV5adx/1miue+FrauKE52mM12lz7sReGspgH6REU5Jv0ybK33ufThde2N5ZaZGQz8fm666n8+WX4z3kkPbOTkLSHPFjRKUNSEMcQtmi+jEXgjVBypeWkz4svf77oWCjS3cqpdT+cu2xA9myq4oFzxex5Z6F7HpmOVVfb8e3uhTfujJqinaxZ85qttz1OSWvfk+gtLb5k3ZQNUVFeIa3Xc9UsKKCyk8/Jev449ssTaWas7Oilo9W7iDeVL7EVtJz4tt1JOtcTi7q1oVSS6irlnnEz3THR3zg/jWzXXfyR+cT3Op4liwqiQnGS+wcqWQXfvAHDVc8u4i5S4v36TM5cURX/jx9FB6nRSJrSHidNj+f3JtbTm58MQ3VvJTomdr1yCPknntOh+2V2n7P/+Ds0YPcC85v76wkrFt6N3zB2BWy7DSbgtMLKH62GMtjkTEsA3+Jn+JninHmOcmZUv//KMedg20d2K3ASqn25/SFeNCkEViym1DcRYDB+MKtsZULt1K1ZAf5F4/A1TOzLbPZ6gIlJQRLS3H1Tmwi/P5Q/t77pE2Y0GF/o9WB6Z9fbIj7fsNeojppAyYSqi6n+ImrCOzZiri9pA2aSI8ZaazIhp/16MY1u/dwXFU1TtuFHVmga4S1nhGsB2c6d5THb/RvGBwXEl/4oU6NP8SNs5fQI8fLmH2IOXryyO4MKMhk1off83bRViyB6qieKqctWCKMLszhqqMGcOSg/BanpcLavTJV1yvVUedKlb39NhUff9wh5klFy/HkMKHbBD7d/CmG+s06+SflY6fbbH1hK77tPiyvRdbYLAovK8SKWibXbbs5Z8g5bZ11pdRBxviDbH90KY5dNTgaqUjVEzKY6gA7HvuGgitH4eyS3vwxHUTN8uV4hg5FrLYbWFI2dy45Pz2jzdJTKhFvfrOF2kDscLbGVtIrW/gypV+8TOeTr8fTexTBil3sevdB1t+3gX63d2Grw8HNBZ3572CIa6zOnB10Q9VuEIG0TphDzsb3L2fcvMQLjguxCz9ULp9HyQePxB96SLhC9cc3VzD78ikxaSRjcNdM/vecMZRW+Zm9aCNLN5Wyp8pHuttB//x0zjq0F4V58UcoqeS1emUqFDLMX7WDxz5aw6rtFdT4gridFoV5aVx8eF9GvtBxe6V869ez9fd3Ufjoo9hZWc0fkGIuGnERX237iqpAVcy2vKl55E3Na/J4YwzTB09vrewppRRAeNjezhposErVwk1L+eOHD7Fy5zosy2Jgp97cfszVjO4WHrJiaoPseGIZ3W6aEHdRnY6oteNLNRTYuZPqpUvp+eADbZamUokoq/bHfT9eL1Gotoo9n/yTTj+6Fm+/cQA4sruQf9pvKH50Zr2wL+W2xZ+sCqZNf4Gu6T/MIxLA9dpb+OJU4KKD46YPOTx+fqMWqPD0HYvYDqrXLopZoGLpplLW76qkd6d9bwTKTnNyyRH99vk8qmmtVpkyxvDMZ+v5679XUe0LUhm10kp5Leys8HHT7CVQM5qL+w3k2pDpUEHCQrW1bLruejpfdSXeEW33w7Y/je8ynvy0fDaWbyRkEp+sCOFeqak9p9LZ27mVcqeUUhCq8lO1ZAcE6lekymsrmfnSb7n7+F/x4yFH4QsGWLhpCW7bVW8/Uxuk5tvdeIcfGLGPaoqKyDz2mDZLr+ytt8k4ahqWV+fGqtTS2GigeL1EtZtXhHurBtfv8bFcHtKH9qKiaNveyhSE42jO2ziPs4ecXW//Hjle1u6MXcAreo6UWDaevmMQy0HNusXUbFhK9pSzGh16mDZgYr1zhYzh6QXruP3HHfPe8mDUKuMEgiHDDbOX8N9vfcvOCl+9ilS0Sn+ISqeHx77YzM+f+IIaf/LLObaXbffcg6t3b3LPPbe9s9JiIsLDxz5MhjMDSWToTITTctItvRt3HXZXK+ZOKaWgctG2uDdNa3ZvBOD0YcdiWzZep5upfScwtKB/vf1MbZDy+RvbJK9tIbz4RNvdZJXNnUv2Kae0WXpKlVb5+X57BcuLy9i8p5pgI8Gic9PiD7mL7iWqE6+3CgAJ4sh2EqiovyJgbbCWkpqSmHNfcnhf0hpZLTRrwhnkHn0xpZ+9wKYHzmPTQzMo/2ou3oGTkwri6w8aPl61s9n9VOpolZ6p219bxlvfbKk34a0p1f4QX64v4bJnFvHkjENTvoeq7M03qfx0QYebJxVPz8yePPOjZ5j5zkzKfeX4Q/G7zet4bA+9snrx+PGPk+bU8bZKqdZV8WkxJs5vSb+8QiyxuP6Nuzl1yDGM6TGcHE/8xSZ8xZUE9tTgyOnYAWWDpaUEd+9us8UnfBs34tu4kfTJzd8AKrUvgiHD/JXbeXj+Gr7eUILLYSEIwVAIt9Nm5pQ+nDuxd704c6eP6cHqHZVUN2iIj9dLZLnTCVWWsvvfT5B39MVRe9sEq7bGhH0xGIImtoH/9DE9uOuN5Y1eR2PBcSuKPoxfmWtEeU3T92Iqtez3ytQnq3byr682x61INbXmf20gxMK1u5n95UbOntBrf2drv/GtW8fWu/5A4eOPYWceGKtE9cvpx8unvszTRU8ze+VsjDEx86jSHGmkOdO4cNiFnD3kbDyOjn1TopTqGILlsauOAmS603n5vAeZ9cU/+c3bf2JH5W6O6j+Re0/8Dfnp9ed7ikMI7u74lama5ctxDx2K2G2zgmrZG2+QdeIJiDN+D4BS+8NXG0q49B9f1psSEh14ttofYta81fxt3mrOHl/I7acOx7aE6eML+dM738U9Z8xKek4PiGC5oxuBQ1jOJVQs20OXn3Wpd7zLcpHjjp3Ln+52cM6EXjy/cEPCHQbQ+AIVjdEg2B3Lfq9MPTT/+5hWAkhs4l21P8hD81Zz1qGFKdnjE6qpYdN115N/zdV423CYRVvo5O3EDeNv4OoxV/P++vf598Z/U1JTgi02+Wn5nNz3ZCZ1n4Ql+gVXSrUNEzIxi05EG9i5D/effDMA3+9azzVz/8AdHzzA3069PWbfkC+5eaGpqC3jSxljKJ0zl253/b5N0lMHp/krd3DZM182G2S2btW+2Ys2sX53FU9cOB4MdM3ysH537CJaENtLVPrFS5QtfBV3t0GR1fy2sufjv8UN+yIiTOo2Ke55bz5pKN9sKuWbzaVxVxOMJ5EFKqJ1znA3u49KHfu1MlW8p5ov18WOMW1qzf+GE++2V9Ty9cY9jN2HNfZbotoX5PM1u9hd6SNkDDlpLib0ySM7akzutv++B3e/vuScfXYTZ+rYXLaLk/qdxEn9TmrvrCilDnJiCdjSZIWqzoBOvTlzxIk8u/j1uNstd8ePh1ddVETG1Kltklbtd99hqqvxjh7dJumpg09RcSmXP7Oo2YpUtGp/kIVrd3P+E1+wdmclRwzsTFmNn5Kq5ofFZU/8GZYni5IPnyCwZyuW20nWeBddr+xbL+wLQL/sfgzIHRD3PE7b4pmLJ/KLf3zJovUlcTsQGmpugYrcoy7au2+ay+a8iak7QkvF2q+VqblLizFxfvOSmXhX4w/ywn82tlllavWOCp76dC3/WrQZ2xKMCUddskXwBUMcP7wLlx7Rn16LP6Hq88/pcwDMk1JKqY7CkechsKM65v3vd63ng9WfceqQo+mWVUBx2TZeW/EBY7vHjhowgRCOzh1/Nbqa5cvJv/LKNkmrbO5csk4+uU3jWamDy62vLmu0ItLUtJBqf5Av1uzmvukj+em4QlZuK2f6w59RURMgGO8mNErmqOPJHD0NO2013sJnEImtyKU50rjokIviHP0Dr8vm7xdN4LXFm3lo3mo2lVRTGwjSyFoZQHJBfH88qnuT6avUsl8rU5tKqvEFY/8wG11FJQ5jYHNJ7A/n/maM4b/f+pZ/LFhHIGQINPINeGPpFt4v2sqE4uU88Of/h52R0ep5U0opFZZ5ZE/2zFmNaTBML92VxuLiFTz2nxcpq60gy53Bsf0n819HxVY23P1ysDNdMe93JMHycgI7duLq1/oxY0woROkbb1L48MOtnpY6OK3dWUlRcVncbYlMC7Es4fM1u/npuEIGdcnkjWsO54pnv2LV9nL8QRN/BUCpBSycuZ/hLngbkdh9nJaTPtl9OKZX8+EHbEs4Y2xPzhjbk2WbS3lt8Wa2lNZQUunj8zW74naoN7ZARR2XLUwf1xOPs+P3pB9M9mtlqrGlzZOdeNfaS6QbY7hx9hLe/GYrNc2Mdw0ZqA4YvugylEs/LeWZoSGdGKiUUm3EOyqfPa+vjnm/W2Y+D51+Z7PHi8si88ierZG1NlVTtBzP4MFtsvhE9VdfYWdk4Bk8qNXTUgenpz5dG7fCk+i0kGDIMGdpMbf9eBiZHic9c9OYc/XhfLe1nCc+WcOri4sJhUIEjB+MjTjKcOZ9hCtnEWLXxs2Ty3LRNb0rjx73KE4ruUVXRvTIZkSP7L2vH/hgFbPmrU5oCGAd2xJ65Hr59YlDkkpbtb/9WivolB5/wly8Nf+b0ljsgP1l1rzVvPnN1qT+yGuMsGTTHn738jetmDOllFLRLJdN+pRuiDP5n6sgYGW7cffPbnbfVFdTVIRnWNssPlE6dy5ZGltKtZA/GGJLaTWrtpWzqaQqbgP5u0Xb4o4ISmZaiMOyYubpD+6ayb0/G8W3vz+RRbcez4tXDaXfmAcpGPpX3J0WxK1I2WLjsT2MKRjDC6e8QLZ738uLXx49gPMm9cKbYA+Ty2HRI8fL85dOJsPdKlGLVCvar/9j43rnku62qaxtfs3/xibeeZ02hw3ovD+zVU+1L8iD/46/4iA0N043xJwlxfzquEF0z+n44++VUqojyD6hL/7NldSuL4NEJ6sL1Fhwt9fHn2oCZHs79vLeNcuXk37YYa2ejvH5KH/7Hfq89FKrp6U6jppADevL1lPuK8dlu8j35tMto1u9fTburuLvn63j+YUbCYZCWJYQCkHQGE4Y1oVfHNmPkT3Dq+ZV1AbipJLktBAMe6rjh06wLCHb62RCz8G8f+ZbfFb8GU8ue5LF2xfjsl17jw+EApzU9yQuGHYBA3MHJvGJNE1EuOXkYQzuksm973xHVW1g77Lv0bxOm5AxnHRIN+48bThZno5dTh2s9mtl6qghBbhsi0pi/2ASnXgXMoYzxrXekIw5S4ppbP2IRMbpGgPPfLaem36k3bBKKdUWxBI6zxjOrue+pXZVScz8qRhOCzvdSZ9fHEL3T9cw/eEFPDVzAj06cCNYTVERnS69tNXTqfj0U1z9+uHq2aPV01Kpb33Zev5v+f/x6upXsSQcSBfAH/LTO6s3F4+4mMO6HcVNLxXx4Xc7CBmDP85koTe+2cL7K7bTp1Maj884lMaW8UpqWogBK4EFwSyxOKzHYRzW4zB21+xme9V2qgPVZDoz6Z7RnTRnWrPnaKnp4wv56diefPL9Th6Zv5qi4jKq/EGcttA5w80Fk3ozfVxhvZWjVcezXytTtiXMPKwvf/vw+7hr7zc38c62hFNGdmvVmvnD81dTFad1INFxur5giGe/WM/1xw3C5dC5U0op1RbEYdHpvKHUrNhN+fyN+Iorw5Nao4YKictCPA4yj+xJ+qFdsNwObv/xMJ74ZC0/nbWAJ2aMZ3j3jjfkL1hRgX/bNtz9W3/xibK5b5B1ysmtno5Kbb6gj5s/vpl5m+YRDAUJmNiepJUlK7nzszup9t2Br/jn1Ab6Nnq+kAkva75yWwUn/e/HeFwW5XGmLiUTj8kSISctuYVl8jx55Hnymt9xP7Is4chB+Rw5KL9N01VtZ78PzLxgUm+e/nQdtYH4Xa9NcTssfnn0/utmbSgQDLF2V2XcbcmM0w2GDJv3VNO3c/r+zqJSSqlGiCV4h3fCO7wT/h1VVC/bRbC0FhMMYWe4cPfPxt0/p174ChHhkiP60S3by8+fWMj9Z43ucDc1NcuX4x40EHG07lyKUGUlFfPn0+W/bm7VdFRqqw3WMvPtmawqWUVtMP5iDXWqAlVggaP7k9ibzyFY0fS8vqAxlNX48TgsHBY0bHdPZlpIIGSY0KdtK0ZKxbPfS+bcdBf/vHQiP31oAVW1QZoPtRjmsYXHfj6+VSsoZTUBnLaFL06vWTLjdG1LKKtuPkCcUkqp1uHMT8N5VOLDc04e2Y2CLDdXPPsVvzlxMGeOL2zF3O1fNcuX4x0eGz9rfyv/94d4x47Bkac3qAcrYww3zLuBlSUrm61IRRPLj7fHc1Stv5RQTWGT88+NAV/QECeSDpDYtBCHJfxsXE+8Ll1CXLW/VmnmGtI1i9euOoyzH/2cKl8w7rC6OmkuGysY4O4VLzGpoOku3X3ldliEGoknlcw4XWPQIX5KKdXBHNonjxcum8SMpxayuaSa644d2CGCsNcULSd90sTmd9xHZXPnkq2r+B3UluxYwsKtC+NWpEo+LmHnOzvxbfdhe2yyxmXR5WddsNPD90xi+fF0eZ2ts7s3O/88GDIINNrg3ty0EIclzDyszz5erVL7R6vVCAYUZPLpb4/mjz85hEFdMvA6LTLcNl6nTbrbJs1l0zPHy80nDeHz205k4uh+FN/0W0wowZWaWiDNZTc6WTGZ5dt9gRCdM+IvA6+UUip19c/P4F9XTOHf327nNy8txd9Y83gKqSkqwtPKPVOBkhKqFi0i4+jmg5WqA9ffi/5OTaAm5v2db+1k6+ytdD2zK8NmDaPfrf3w7fKx7r51hKJG+5jQJvZ88ix5x11B2uApWC4PYjtIGzCx3hA9AI/Twt2Chmmv0+b8yb3pl5+R/AUq1QpadQC222Fz+pgenD6mB99uLeP77RVU1ARIczvolZfGqJ7Ze1sF0397E+svnMGuRx6h8xVXtEp+RMILXLy2uJigqd8eksw43eHds8jP1MqUUkp1RAWZHp6/dBJXP/c1Fz39H2adN5bMFF2SOFhRib+4GHf//q2aTvk775Bx5BHYGToX+GC1u2Y3H236CNOgvyhYHWT7q9vpcXEPMkdmAuDKd1F4ZSErf72S0gWl5B6ZC0DVmgpMMLH559X+EKMLs1m1rYIqX2LTQrxOmxOGd+HmHw1N+vqUai1tFhlsSNcshnTNanS7uFz0+MtfWDd9Op7hw8k48sh62ytqA5RU+ggZQ7bXSbbX2aLhGedml/NG0E/Qir30RMbpprttLp/Wuj9qSimlWle628GjF4zjtteLOPORz3l65qF0yfLE3dcEDYFd1YSqA4glWOlO7Fx3qwwRrKgN8MbSYlbvqKS0yk96+S7yx55ILyPstwWcAz6o2gm+KnBnQlonSufOpdNFFzV/rDpgfb39a5y2E1+o/gJiVauqCPlDZI2rfw9ne2wyR2ZSUVSxtzIVqvBjZzgSmn8OEAgaXrnqMC5/dhFb9tRQGwgSbzZGmsvGGLhian+uPmZAhxieqw4eKRVm2dmlgB7/789suvY6+jz/HHaPnvz72+08PH81Szbu2TtPyR8M0S3by2VT+3H66B6kJxAtuvqbb9hx//3kbN5Mn8OvYVUVcb+wzY3TddkWxwwpaPE1KqWUSg0O2+Lu00cwa95qzpi1gKdmHsqgLpl7twfLaqn4fAsVC4rDPxiWgAlXrhw5bjKn9sQ7Kh9rP0yC/357OY99vJbXFm/GQqiKCizvKZjMX+96nzPG9uCSI/q1fKGmbcvhs7/BspcAAcuGUAAjNlm2m4wRd+zzdaiOq6y2jJCJs0BXRRBHhgOxYyswjmwH1eur9762M2yCFf7E4kQBCAzqkskHv5rKkk2lPPbRGt5dvhURwRLwBwzdczxcPq1/wvd7SrW1lPurTBs/ns6XXcbrN/03fxxyOrWB0N6o0YGohSw27K7i7jdW8Ie5K7jxhMFcfHj8+AIlvuYAACAASURBVAa1a9aw4y//S/WSJXS+8kpyzvgJj5T6OOWBTxqNwN0Yr9Pi8QvH47B18QmllDoQiAhXHTWAHjleznn0cx44dwyT+3ai9I01VHyxJbxTILblLbCzmj1zVrNnzmryzhyMd0TnFufhla828btXvsEfCBEn3ik12OAP8sJ/NvLyV5u5/6xRnDiiW+IJVGyH58+Frcsg5IdQ/d8+AXL7VCOPHgb9j4YzHgO3zkc52NiNVH7sDJtARQATNDEVqkBpAEfGD7eSaQPSEIdQs+41vP3OaDbNTunhKRMiwujCHP523lgCwRDlNQF8wRBZHqeu2KdSXkrWCj4ZeTT/VXgiu6v8eytS8VT5glT7g9z3zrfc8XoRJmoelH/LFopvuYX1552Pd+Qh9H/7LXLPOhNxOunTOZ3nfjGJLI8DK8GeYq/TZtb54xjXW5eMVUqpA83pY3rwwLljuPafX7Ns1ldULtwarkTFqUjVMb4Qxhdi9wvfUfH5lhal+/KicEWqxh+/IhUtEDJU+4Nc98Ji3vymOLEEStbDQ4dB8dcQqI6pSNURCUGgBlb/Gx47Cqr3JHklqqPLdediSextYV0FqWxRWb33gzVBypeWkz7sh55SO82my0+6sOudf1C74UVC/hpMMED16i8p+fDJesenu2xOG909Jj2HbZGb7qJLlkcrUqpDSLmeqc/X7OI3/1pKrST+Bar2h3jhPxvpkuXmF6M6sevRxyh9+WVyzjqL/m+/hZ0dG/H+kJ7ZvHHNEdz62jIWrN4FEBN/ymkJliWM6JHNnacOZ0SP2PMopZQ6MEzp35kXBvYktHgH4YWbE2P8IUrfWIOd68Y7OPEGt2WbS/mvV8MVqYaaitNT4w9xw4tLGdQlkwEFmXHOHFFdAk+fFJ4fFWf4VlyBGihZB8/8BC56BxyuhK9HdWzju46PO8zPTrMpOL2A4meLsTwWGcMy8Jf4KX6mGGeek5wpOfX27/yjzjiyHex8Zza1xc8jrrSY+ed1TjokiR5WpVJUSlWmjDH8evaSuD8s0PSPS7U/yP1vr2D8rfdTeNw0+s55HWdB03ObCvPSeHrmBLaV1fDs5+t5bfFmyqoDhIwhw+PguKFdmHlYX/q0YiBhpZRSqSFQUoNr2S6IU5FauGkpf/zwIVbuXIdlWQzs1Jvbj7ma0d3Cq4oZf4g9L3+P57eHJjw5/m8ffk9tnCDyZQtfaTZOjy8Y4uH5a7hv+qjGE1jwIFTsqFeR6vOXcqr8sPbaDNJd4Xw+/pWPZ5f6mTcj8lsX9MGOb6HoZRh1dkLXojo+r8PLaf1P46WVLxEw9Xsw80/Kx0632frCVnzbfVheC3c3NyF/iBVXrYiJO5UzJYecKTkEawuoWvOrmLSctnDWoYV4nNrzpDq+lKpMfbm+hF2VvrjbEvlxEWP4/Ff3cuhPxieVbpcsDzccP5gbjh+8z9eglFKqY6r4rDhuFNHy2kpmvvRb7j7+V/x4yFH4ggEWblqC267faxOq9lO7phRP/5zYkzSwu9LHB99uj1kIKVRbyZ5P/o9OJ11H2uApe99PGzCRtAE/BO4NhgxzlhRz+4+HxV/WPeiH/zwGcYKvBg387xc+bj6iiRAf/ir45H6tTB1kzht6Hq98/wqBYOxw0LypeeRNDfe87nxrJzve2kHPS3rW66lad986+v5XX6zIgmGWswTLs4lQTc+957EE8tJd/PLogW1zUUq1spSaM/XoR2uo9sfOkar7cWkuCFyt2Dy9rIRgvGX6lFJKqUaYQIjKL7YSb+LSmt0bATh92LHYlo3X6WZq3wkMLagfJsP4QlR8tCmh9GZ/uTHuD3Dt5m8xgcTi9FiW8NrizfE3fvsGhOLPOf71FBf3LahlT00zv5V7NkDx4mbzoQ4cfbL7cOHwC/E6vI3uUxd3qvv53ckcmYk4ZG/cKd9OH6ULSn/YWQK48j7e+9JhCTlpLl64dDJ56TqEVB0YUqoy9dHKHZg4ZXsyPy7V/iDfb69ohdwppZQ6UPm3Vja6rV9eIZZYXP/G3Xy4+nP21JQ3um/tmtJGt0X7ZnMpNXGG+AWry7DSshJaVrraF6SouJG8FL0Mvvi/heO720zr4+C+BbG9VvUEamDl283mQx1Yfjn6l5zS7xTcdvyey0TiTtURMTgyVmALeJwWI3pk8da1R+j0CXVASZlhfv5gCF8w/lypZH5cbEvYUxV/qKBSSikVT6g6EG+qFACZ7nRePu9BZn3xT37z9p/YUbmbo/pP5N4Tf0N+ev0FJ0L+EDOf/IKACQ/FC4YMIRP+N2ggFDIEQob1u+JX3mxvFqGqsoTj9Gz+8BPWv3wvOB2I7UAcDsTpID/7c5oKyfP7o9wc9mQl105sonfAhKBiW7N5UAcWEeHWSbcCMHvl7JjtycSdAsDyR2Kk9Wdw1yYWTFGqg0qZylRT03WT/XFRSimlktLMmhEDO/fh/pNvBuD7Xeu5Zu4fuOODB/jbqbfH7HvB5N7Yto0tgmWBLYIdWR227vk9b33LJ9/vjDnW3WMI4nBStfIz0occ3my2Ow8fTOcx4yAYwAR+eNhLFzZZERpRYHPKIAf3fOJjaH5KDVJRKUBEOKbXMby55k0qA/Ur/snEnYLw3/+9PxuV8MIsSnU0KVOZctgWLtuKu7JRMj8uwZAhJ03H4SqllEqcle4iZjWIRgzo1JszR5zIs4tfjz2P2+booV2bPce43rl8sXYX/gZztCx3OjmHn8fu9x5GLBtP3zGI5aBm3WJqNiytN0/Y7bAYMrQ36RP7Nzw97HkWln/bZB7unOZh7CMV3DC5kYUoxIKMLs1eizow5bhz4jYyRMedyp7wQ8iYurhTXX5W/2/G4/BoRUod0FKqOerIQfnE+75F/7hUrfysySBwXqfNgAKN3K6UUipxzi5piDv+yIfvd63nkYXPs6VsOwDFZdt4bcUHjO0+vP6OluAd0Smh9KaP74nVyA1m1oQzyD36Yko/e4FND5zHpodmUP7VXLwDY+cN/2RszzhnAIafAa6mfwsH5FmcNdzJXxc2MjTe4YZBJzZ5DnXgGpw3GDtOzM/ouFPlS8sxAYNvh4+NszbGxJ2yxGJq4dS2zLZSbS5leqYALj2yH59+v5MqX+wKRFkTzsBKz6X0sxfYOfc+xOWNCQLncVhcfHhfbEtbQJRSSiVOLCHjiJ6Uv7ce0yDWYborjcXFK3jsPy9SVltBljuDY/tP5r+OurL+OWwh4/BGKjcN9MxNY1zv3L1B4xvKGH4UGcOPajy/Em6AzM9spFdpyMkw55pm83HbVDfPLPXH35jTG7qPbvYc6sDksBycM+Qcnlr2FL5Q/Qp3vLhTWWOzKLysEMv5Qzu9y3Jx4fAL2zrrSrWplKpMje+dS+cMNxt2V8Xd3tyPiwHOntCrlXKnlFLqQJYxvgtl766Peb9bZj4PnX5ns8c7OntxdUt8lbJfHj2ArzfsiRsSpDkeh80V0+IM76tjO+HQX8CCB+rFmlp3Xf0FAAqzLWpuyWp4NDjT4PDrk86XOrCcOfhMnlr2VNxt0XGnGtM9ozvDOw1vch+lOrqUGuYnItz7s5F4nMlny+u0uf64gRq3QCmlVItYaU5yTumHtOA3SFwWuWcmF/h9Sv/OXHZkP7zO5BZWqvu9G9srt5kEroaMgvDcp2TYLigYGh4qqA5qBWkFXHLIJXgcnqSP9dge7pzSfCOEUh1dSlWmACb168R900clVaHyOm3OnlDIZUc20UqnlFJKNSNjUjcyp/VMqkIlLotOFw5PqleqzrXHDuTSJCpUXqfNNccM4NJEfu+8OTDjDUjrDJYzsQw5PJDbB85/GRzaOKng8lGXc0q/U5oM5NuQx/bwP0f8D6MLdJioOvCl1DC/OqeM7E5emotfPvc1tYEglbXxh0CkuWyMgRtPGMzFh/dt41wqpZQ6EGUd0xs710Pp3DWYoMHE+w0SEIeFnesm7+whuLq3bOEjEeH64wYxrncuf/1gFd9sLiUUMvijVhZ02oIlwpheOVxzzECm9O+ceAK5veGKT+H582DrNxD0gwnE7me7whOx+h8NP30cXBpUVYWJCLdNuo2eGT15eMnDiAjVgeq4+6Y50khzpHHftPsY12VcG+dUqfYhxjS+FOz48ePNl19+2YbZqS8YMsz7bjsPz1/N1xv24LQtRMIBfrvneLlian9OHd2dNFdK1gmVSkkissgYM76987Gv2rt8Ugc+EzTUfLubsnkb8W8qD1c2jIms2teZzCN64Oq5f4OQrttZybNfrGfl1nLKawNkehwM7ZrF+ZN6U5iXtm8n374CPvsbfBMJxGrZEAqGe63GzwjPscrtvc/XsC8OhPLpQC6bKv2VvLHmDZ5c9iTbq7bjsBwYYwiEAozpMoaLRlzElO5TsJIdWqpUimuqbErpylS0ytoAe6r9BIOGbK+TLK9D4xYo1QIHws0KpFb5pA58xhiML4RY0qI5VSkl4IPq3eCrBHcWeHPBTo1GyQOhfDoYyiZjDGW+Mspqy3BYDrLd2aQ597Gyr1QKa6psSo3SMwHpbgfp7g6TXaWUUgcQEWk0DlWH43BBZvOBhZVqjIiQ7c4m253d/M5KHeA6ePOaUkoppZRSSrUPrUwppZRSSimlVAtoZUoppZRSSimlWkArU0oppZRSSinVAlqZUkoppZRSSqkW0MqUUkoppZRSSrWAVqaUUkoppZRSqgW0MqWUUkoppZRSLaCVKaWUUkoppZRqATHGNL5RZAewvu2yo5RqA72NMfntnYl9peWTUgekDl8+admk1AGp0bKpycqUUkoppZRSSqn4dJifUkoppZRSSrWAVqaUUkoppZRSqgW0MqWUUkoppZRSLaCVKaWUUkoppZRqAa1MKaWUUkoppVQLaGVKKaWUUkoppVpAK1NKKaWUUkop1QJamVJKKaWUUkqpFtDKlFJKKaWUUkq1gFamlFJKKaWUUqoFtDKllFJKKaWUUi2glSmllFJKKaWUagGtTCmllFJKKaVUC2hlqoMTkWkisqmtj21hevNE5JK2Sk8p1X60bFJKdQQdqaxqCRHpJSIVImK3d14OVFqZaiDyB1f3CIlIddTr81ox3Rki8klrnX9fiUgfETEi4mjw/tMi8of2ypdSBwstm+LTskmp1KJlVeMiZdWAVk5jnYgcW/faGLPBGJNhjAm2ZroHM0fzuxxcjDEZdc9FZB1wiTHm/Yb7iYjDGBNoy7wppQ5eWjYppToCLavUwUZ7phJU15UrIjeJyFbgqXitINGtDiLiFpH7RGSDiGwTkYdFxNuCtGeKyAoRKReRNSJyWZx9bhaRnZEWifOi3t8veUgwnzNE5JNIeiUislZEftTIvt1EZKmI/Dryep6I3CUin0au810R6Ry1/6kiUiQieyL7Do28P1NE5kTtt0pEZke93igioyPPjYhcHtlnj4j8TUSkNT4LpdqKlk0J5VPLJqXamZZVMendISIvisg/IvkqEpHxUdt/KyKrI9uWi8hPGhz/i6hrWi4iY0XkGaAXMEfCPYG/kajeexE5S0S+bHCe60Xk9da81gOdVqaS0xXIA3oDlyaw/z3AIGA0MADoAdzWgnS3A6cAWcBM4H4RGdsgX50j578QeFREBiebBxGZJSKzWpC/aBOB7yL5uRd4ouFNgYj0BeYDDxpj/hS16VzC11cAuIAbI/sPAp4DrgPygTcJFxSuyHmOEBFLRLpHjpscOa4fkAEsjUrjFOBQYCRwJnDCPl6vUqlAy6bmadmkVPvTsqq+U4HngRzgdeDBqG2rgSOAbOBO4FkR6RZJZzpwB/DzyDWdCuwyxlwAbAB+HBnad2+D9OYAg0VkYNR75wL/TPZaVRRjjD4aeQDrgGMjz6cBPsATtX0G8EmDYwzhP0ABKoH+UdsmA2sbSSvmXE3k61Xg2qh8BYD0qO0vArc2l4fIsZsSTLNP5NocDd5/GvhD1DV8H7UtLXJM18jrecD/i3yu5zQ4zzzglqjXVwJvR57fCrwYtc0CNgPTIq83AmOBs4FHgYXAEMIF5usN/m8Ob/A5/ba9/870oY9kH1o21UtTyyZ96CNFH1pWxaRrgAGR53cA70dtGwZUN3HsYuC0yPN36vLf1GceeV2vjASeBW6LPB8IlEfKxKQ+b3388NA5U8nZYYypSXDffMJ/nIuiGj8FSHo1lchwlNsJtxZYkfN+E7VLiTGmMur1eqD7/swD4YIGwBn1vO61P+r11ronxpiqSLoZUdvPA74HXoqTxtao51VRx3UnfE115w2JyEbCLSYQbgGeRrjwnQ/sAaYSLgTmJ5iGUh2Zlk1aNinVERzMZVU8Db/3HonMJRORnwO/IlwZgnCZUDfEuJBwz1VL/BP4M/B7wr1Sr0bKxAJa91oPWDrMLzmmwetKwn94AIhI16htO4FqYLgxJifyyDZREzMTISJu4F/AfUAXY0wO4aEk0cNTckUkPep1L6B4f+UhYgvhG5M+Dd7vS9TNRALuiOTrn5L4Mp3FhIcEABAZmlNIuAUYfrhhOSLyfD7hG5apxN6wKHUg0rJJyyalOoKDuaxKJs+9gceAXwKdInleFpXnjUD/Rg5v+Bk39B6QL+E5m+fwwxC/drnWA4FWpvbNEmC4iIwWEQ/hH2Mg3EJJ+Itwf6S2j4j0EJGmxsGLiHiiH4TH2buBHUAg0rpyfJxj7xQRl4gcQXhc8OwW5iEuE15S81/A3SLSSUScInIO4W7pt5I4lR+YDqQD/xCRRP4GXwROFpFjRMQJ3ADUAgsi2+cDRwFeY8wm4GPgRKAT8HUSeVPqQKFlk5ZNSnUEB01ZlaR0wpWiHZE0ZwIjorY/DtwoIuMkbECkAgawDejX2ImNMX5gNvAnwvPX3ou8317X2uFpZWofGGNWEu4mfR9YBTSMb3AT4WEjn4tIWWS/wTRuCuFWgYaPawj/aJcQ7pJ9vcFxWyPbioH/Ay43xnybbB4kvGrLw03k70pgN+FJ09sJt5icbIzZ1sQxMYwxPuAMoAvwZHM3LcaY74DzgQcIt5z8mPDkSl9k+0qggvCNCsaYMmAN8KnRuArqIKRlk5ZNSnUEB2FZlRBjzHLCQ/E+I1w5OgT4NGr7bOBuwr1K5YTngOVFNv83cIuEVwa9sZEk/gkcS7jCGD08OtnPWwFiTHO9gUoppZRSSimlGtKeKaWUUkoppZRqAa1MKaWUUkoppVQLaGVKKaWUUkoppVpAK1NKKaWUUkop1QJamdpHIvK0iPwh8vwIEfmujdI1IjJgP59z77W05bFtRURuFpHH2zsfSrUnLbP2/dh9ISK9RKQiiVhWSh1wtBza92Pbit47Ne+gqEyJyDoRqY78gG2L/PHu9yBkxpiPjTHNLiEpIjNEpOHyn/uNiMwTkUta6/z7qrWvP5LGNBHZFP2eMeaPxpiU/VyUqqNlVmqKlCtGRG5K4ph1InJs3WtjzAZjTIYuja5SnZZDqUXvnVLXQVGZivhxJIrzWGA8cEvDHUTE0ea5Ukqp+LTMSj0XEo5n9fP2zohSbUTLIaWacTBVpgAwxmwG3iISSTrSyniViKwiHDAOETlFRBZHAp4tEJGRdceLyBgR+UpEykXkBcATta1ejV5ECkXkZRHZISK7RORBERkKPAxMjrT27Ins6xaR+0RkQ6QF6GER8Uad69ciskVEikXkopZev4jMFpGtIlIqIh+JyPAGu3QWkfci1zdffoiojYgMiWzbLSLficiZLc1HgzytE5EbRWRpJF8vSDhqOSKSKyJzI59hSeR5z6hj80TkqcjnUiIir4pIOuH/4+6Rz7hCRLqLyB0i8mzkuLdE5JcN8rFERM5ozWtVKllaZqVGmRUpV34GXAUMFJHxDbb/QkRWRPKxXETGisgzQC9gTuSz+42I9In8HzpE5CwR+bLBea4Xkdcjz5v8jJVqK1oOpUY51CBPeu+UIg66ypSIFAInAV9HvX06MBEYJiJjgCeBy4BOwCPA65EvrItwlOlnCEeang38tJF0bGAusB7oA/QAnjfGrAAuBz6LDPXIiRxyDzAIGA0MiOx/W+RcJwI3AscBAwlHrW6ptyLnKAC+IhzpO9p5wF1AZ2Bx3fbIl+w9wlGzC4CzgVkiMqyR698jIocnka8zgROBvsBIYEbkfQt4CuhN+KakGngw6rhngDRgeCRf9xtjKoEfAcWRzzjDGFPcIL3ngHOi8jssksYbyV6rUq1Jy6yUKbPOACoIf4bvEO6lqjt2OnAH4R6rLOBUYJcx5gJgA5HWfWPMvQ3OOQcYLCIDo947N5JnaOIzVqotaTmUMuVQQ3rvlAqMMQf8A1hH+EdwD+Ev6CzAG9lmgKOj9n0IuKvB8d8BU4EjgWJAorYtAP4QeT4N2BR5PhnYATji5GcG8EnUawEqgf5R700G1kaePwncE7VtUCTfAxq53nnAJQl8LjmR82RHXj9NuNCq254BBIFC4Czg4wbHPwLcHnXsHxL8/2h4/euA86Ne3ws83Mixo4GSyPNuQAjIjbPf3v+LqPfuAJ6NPM+MfOa9I6/vBp6MPG/yWvWhj9Z+aJnV6OfSLmVWZP/3gb9Enp8T+ayckdfvANc28X95bNTrPpFrcERePwvcFnk+ECgnfJPT5GesD3209kPLoUY/F7130nuneo+DaZzr6caY9xvZtjHqeW/gQhG5Ouo9F9Cd8Jdns4n8hUSsb+SchcB6Y0wggbzlE/7xXCQide8JULfaU3dgUQJpNinS4nM3MD2SZiiyqTNQGnm+97MwxlSIyO5I+r2BiXVd6xEOwq0b+8PWqOdVkTQRkTTgfsItL7mR7ZmRaykEdhtjSpJNzBhTLiJvEG45+R/CN0e/iGxu7WtVKhFaZqVImRVplT8K+F3krdeAR4GTCbe4FwKrkz1vxD+BPwO/J9wr9aoxpkpECmj6M1aqLWg5lCLlUCP03ikFHEyVqaZEf8E3AncbY+5uuJOITAV6iIhEFQq9iP8juhHoJSKOOIWCafB6J+Eu2OEmPC65oS2E//jr9Gr8Upp0LnAa4a7udUA2UEK48KmzNx0Jr9qTR7hFaSMw3xhzXAvTbqkbgMHARGPMVhEZTXiYgUTylCciOcaYPQ2Oa/gZx/MccLuIfER4/PaHkffb61qVSpSWWT9oizLrAsLDZuZE3bR5CA/1ezWSVv9Gjm2uLHoPyI+UbecA10feb+4zVqq9aTn0A713OojvnQ66OVMJeAy4XEQmSli6iJwsIpnAZ0AAuEZEnJEJdxMaOc9Cwl/keyLn8IjIYZFt24CekXHEGGNCkXTvj7RGIiI9ROSEyP4vAjNEZFikteH2BK7DEUmz7uEk3D1bC+wi3JrzxzjHnSQih0fydhfwuTFmI+ExzINE5ILItTtF5FAJTwptTZmEC8s9IpJH1LUbY7YQHsc8S8KTLZ0icmRk8zagk4hkN3HuNwm3pPweeCHy/wDtd61KtYSWWa1fZl0I3El4qEzd46eRtDsBjwM3isi4yP/BAPlhAvo2oF9jJzbG+AnPIfkT4Ruw9yLvN/cZK5VKtBzSe6eD9t5JK1MNGGO+JNxl+SDhlofviUzoM8b4CE9CnkF4edyz+P/s3Xl8VNX5+PHPvXf2bCSBBEgCSUBAQGSTTRaxSl0rat0X3Kjrt1qt3dvvr63VWv1W27rgirt1ReuCS1X2HZFVCJAESEII2ZfZ772/PwYCyUySmZAACc/79eKlzNy59wxwb85zznOeA++3cB4duJDQgsjdQNGB4wG+BjYDpYqilB947ZcHrrVCUZRaQvn5gw+caz7w+IHP7Tjw37Y8TehGOvhrLvAKoWnuYmALsCLC594gdNNVAmOAaw+0oQ6YQWhqt4TQ1PLDgD3SxZVQFZgpUbSzLY8DTkIjUCuAz5q9fx0QALYCZcA9B9q7ldDoSb4SWtDZt/mJTdP0Efr7O4tDC75j/q5CHEvyzOrcZ5aiKBMIdRyeNE2z9LBf/znw3a4yTfMdQmlAbxBa8/QBocAI4CHgdweeQz9v4bu/Qeg59E6z0fgW/4yFOJ7Ic0j6Tidy30lpmsIqhBBCCCGEECIaMjMlhBBCCCGEEO0gwZQQQgghhBBCtIMEU0IIIYQQQgjRDhJMCSGEEEIIIUQ7SDAlhBBCCCGEEO3Q6qa9PXv2NLOzs49SU4QQR8PatWvLTdPsdazbcaTk+SRE99Mdnk/ybBKi+2nt2dRqMJWdnc2aNWs6p1VCiGNCUZRdx7oNHUGeT0J0P93h+STPJiG6n9aeTZLmJ4QQQgghhBDtIMGUEEIIIYQQQrSDBFNCCCGEEEII0Q4STAkhhBBCCCFEO0gwJYQQQgghhBDtIMGUEEIIIYQQQrSDBFNCCCGEEEII0Q4STAkhhBBCCCFEO7S6aa9omT9oUOsNYJqQ6LRgt2jHuklCCNFlGYZJnTeIJ6AT77AQZ9NQFOVYN0sI0Y0FdIMaTwDDNElyWqUvJ9ql2wZTm0tqeP/bIoqqPPiCBikuG5MG9uSCEX1wWNt3s+iGycK8MuYszGdtYRUWLfSDPqAbjMjswW3TBnDWyWlYNJnwE0KIaOwoq+fFJQW8v64I3TDRVIWgbpLksnLz6TlcOa4fKXG2Y91MIUQ3oRsmi/L2M2fhTtY068udkpHE7WcM4KyT06UvJ6KmmKbZ4ptjx44116xZcxSbc2QMw+Q/60t4asEOdle68QcNjMO+XpxNwwQuH5vFrdNy6ZPkjPrcC7aVcd/b6/EGdBr8esRj4uwaFlXloUuGc94pfY/w2wjRORRFWWua5thj3Y4j1dWeT6Kpsjovd7z2LZuKawgaJkEj/GeRw6JiAhePyuDPM4djlc5Nt9cdnk/ybDp+Lczbz71vf4fX33Zf7sGLh3P+COnLiZDWnk3dZmbKF9S56411LNlRjqeFG+TgjfP6yl28920Rr9w0jlH9kts89ztr9vD7DzfhDRitHtfg0wGde99eT2mNj5sm58T8PYQQorvbXeHmF0uU0wAAIABJREFU4qeWUuMJRAyiDvIGQ8/cD74rZkdZPa/dMr7dmQVCiBPbu2uL+N0HG6Puy933znr21ni5ZUru0Wmg6LK6xTCfbpjMfmUNi/P2txhIHS6gh3Lzr3l+JZtLalo9dsG2sqgCqcN5AwZ/+3wrH60vifozQghxIqhq8HPZM8uocvtbDaQO5w0YbCqu4fbX12JE+RkhRPdmGCYev05rGVYHLdhWFlUgdThvwODRL7bx0friI2mmOAF0i5mpf329ndUFVY2jmIdr2LKA2tUfEKgoQrU5sablkjTpchyZw3D7da59fiXLf/2DiKOdumFy79vrI958rZ0XQjfhr97fwIxh6bKgUQghDvjn19upbPATKSZq7bnqDRqszK9kYd5+pg9JO/oNF0Icc/vrfLy5ajevLt9FeYMPFQUDk75JTm6ZnMOlYzNJdFibfEY3zAPLNNrbl9vI2UN7y6y4aFGXD6b8QYMXlhTgCYTPSNWumkfNyndJnXEnjpzRKJoFT8FaPNtXNt4o/qDBJxv2cumYzLDPf7O1DF+wfecFwIRPN+7l4lHh5xZCiBONN6Dz1uo9BPTwSCqa56rbrzNn4U6mD0ljT90e3s17l53VO6kP1JNgTWBo6lAuHXQpaS4JtoToThp8QX753ga+2LIPBfAdGDzXCT1Liqs9/O2Lbfz1s61cMTaL3184tHGN5cK8Mrzt7CMCjX25S0ZLX05E1uWDqc83l0ZM+zB8DVQveZ3U8+7BNXhS4+uugeNxDRzf+PuGAz+cIwVTcxbtPJA7G/t5D5776QU7JZgSQgjg4w17iVTsPJbn6vqKVVz10VPk1WzBMA2CRrDxvWUly3h+4/OM6zOOW0fcysi0kZ31VYQ4bnn8OlVuP7phkui0kuiwdOltBiob/Px4zjKKqzz4I2QgHXRwmcc7a/ewpbSW124OrbF8esHOsGIT7enLSTAlWtLlg6kXlhRErMjiK96KGfTjGjSxzXMUVXnYWlrLkN6Jhz4f1Pl2V/URnRegsNxNRb2P1Hh7VMcLIUR39cG64iN4XpvYen6JJXUxmyoDEY/wG34AlhQvYU3pGu4dey9XDbmqI5ouxHHNMEyW7CjnmYU7WVlQiVVTUZRQue8+SU5unZbLzJEZxNm7VrfPG9C55vkV7Kl0R5zRjsQTMNhYVMNtr65lzrVjOqQvt6vCTXm9j57SlxMRdK27KoJdFQ0RX9c9taiuRBS17RxXDYNt326lf18bZjCIGQyyv86PVTHxNbt3YzkvgNWiUOUOSDAlhDjhldf7Ir4ezXPV1vNLbKmLUdTIgVRzXt3L39f8HU3RuHzw5e1qrxBdwdpdldzx+rfUe4ONgxVB49Cgxe5KN3/55Hse+Ph77jnrJH4yNbfLzFS9tKyQgvKGiIFUa+udfEGDVYWVfPBdMRZNQQ82/XysfTm/bnDmows475Q+3Dw5h5PSEzrk+4nuocsHUy1VZtGciRjuWkxDb/Nm0b0+St5+l1L3LhSLBcViodIWD6k/BKXpZ2M570HRVJoRQogTVVvPVS1ue8RAqmpxFeWfl+Mv86M5NBLHJJL+43S0uNA5vLqXR1Y/woheIxiSMuSofBchjqavvt/HnW9822aVOveBIOvx/25nV6Wbv8wcHnVAFTACLNizgJc2vURhbSE+3Ydds9Mnrg+zhs1iRvYM7FrHDxgbhskLiwsifrdo11i+tKwwYmpxe/pytd4g76zdwwfrijkpPYGHLx3B0L6JbX9QdHtdPphyWNWIxSfsGUNQLFbcecuJGzK51XNoLhc5P7+XnOG9G19LD+gE/vAZNIuDYjkvhMqwJ7msbR4nhBDdXbLLFvH1tp6rttSvwwKp8vnl7J+/n8xbMokfGk+gKkDJqyUUPlpIzm9zUC2hxed+w8/cTXN5eOrDHf+FhDiGvttTzV1vrIup3LcnoDPv22LSE+zcfdagVo81TIO5m+bywqYXMAyDhuChTCCf7qPWX8sDKx7ggRUPcOWQK7lr1F1Y1dj7Ow2+IB98V8zivHIq3X5smkpGspOBveJp8IXPRMey3il/Xx1B04RmIVWsfbmDdAN0w2BjcQ0/nrOM564fy+kDe8b2hUW30+WDqeyecVTtDs+HVe1x9Jh8DZVfzkFRNRw5o1BUC97C7/Du3kDy9Jsaj9UNk9xecU0+77BqDO2TyKaS2nafF6BPooNekuInhBCcP6I364uqG0fJD2r1uVq8ivjZTffs0z06ZR+UkXFzBgkjQuk2tl42su7IIu/+PGqW1ZA8NbQhu2Ea/HfXf6nx1ZBkTzo6X1SIo+AX766POJgMrafAeQI6Ty7YyZXj+pGe6Ij4+YAR4P6F97OsZBmeoKfFNriDbgDe+P4N1u9fz9NnPY3T4mx83zRNlu+s4IUlBWwvq8cT0HFaNQb2iudHp/ZlRUEFH3xXjKooTZ4LigKqoqBHKDAWy3on3TBI1XT2G02DvFj7chG/u19n9itrePvWiQzPkGfLiazLB1Ozp+Ry/7vrw6ruASSOuwQ1Lpma5W9R/vGjKDYn9vSBJE68oslx/VJdDIqQ/3rbGQP45bsbwhZMR3tel03j1mldJzdZCCE608xRmTzwyfcR32vpuZp6Xj+apwi4t7sxAgaJY5qm2GgOjYQRCdRvrm8MpgBURWV+wXyuHHJlh38nIY6FTcU17KmMHOREW/L7tRW7uG/G4LDPB/QAv1z8S5YUL8GnR17n2JxX97KpfBN3f303T5/1NKqi8u9Ve/jH19up9QRCm+sedvzuSjdfbytr8XymCXoLSyRiWe+kKypDctNw76kK6ydG05drax8qt1/n9tfWsugX06WvdwLr8sHU2UPT0Vr5Bxw/bDrxw6a3+H6cTeP2aQMivjdjaG9+rW5s13kh9DCYOSqj1WOEEOJEEW+3cNHIDN5bW0QwwohzpOeqI/MlFLVpJ0iv17HEW1C08Ge/JcmCZ1fTTqZX91JQU9AB30CI48Nzi/MjlgmPNgXOHzR4ZfkufvqDkxr3Y8qryuPVza/yccHHTbYcOKitNYo+3cd3Zd/xXt48Vm4YyKcbS1ucOTsSsa53SnZZUVpYut5aXy7aoLSiwc+aXVWclp3Sru8jur4uH0xZNZWfTM3lyW92xnzTKoDTpnHuKb0jvm+zqDx48XDuf3dDTDnJAE6rxm/PPxmX7dAfsWGYuA9McWuqjGAIIU489549iC82l1Lljq4qn6KGj4xr8RrB+iCmboYFVMGaIJb48B9tNf6a9jVYiOPQV9+XRZy5iSUFLmgYbCmppXeKj3u+uYcd1TsIGAF0M7wvFe0aRY/u4f9WzaF2+714ArEV32prFuigWNc7fbKptNVB90hiWZfl8es8s3CnBFMnsC4dTNV5A5TX+znz5DRW5FeyZldlTEGP06bxxuwJ2C0tj2xceGoGJdVeHvtvXtTntut+rh/Vh2sn9KfWG+C9NUU8tySfvTVetAM5wD1cVq6d0J9rJ/RvMWdZCCG6m/REB2/MnsDlzyynwRckwgRVE6YR/nx0DXShWBRq19aSNO7QWgXdq1O3oY70H6eHfSbZnhz2mhBdkWmauP3hM0cQWwqcqihsq9jJXUvvoc5fFzGICp0z+jWKAA16FT5tFwT6NTlPa8FStLNAEPt6J90w0ZtXE2tDLEGpCSzYtp+gbmA5MMsnTixdLpgyTZNlOyt4ZuFOludXYNVUVEUhqBsEDROLqkRMHzmcTVNx2jRev2V8xLVSzd06bQBpCXZ++8EmgLDF0we5bBqGafKzTIMzXnqA36t/4u31+1AVpXHWLHhgJKnKHeDZRfk8syifMwb14tHLTyXRIVX/hBDd38l9Evn4fyZz49zVlNZ68QR0Ii2PsGoKii8DJWE7Joc6j5pLI21mGiWvlaA61CYj5dYUKz0m9WhyHqfFyUnJJ3X21xLimIspBU6t4/Etf6E+UIPZSrAR6xpFlADWHmvwlR4KploLlmy9sqOeBToo2rXr7RXrPlSaqlDrDZISF7liqejeulQwtbW0lpteWk21O9AY0AT0poGNBRPFBFVVQAmVsTwozq6hKgrXTejPDadnk5YQ/YzQxaMzOfeUPny0voQ5C3eyp9KD7cC0dkA3SEuwc9u0AcwclYGmKlxZEuT71XvwqS3/EfsO5DsvyNvPBf9cwnu3T6JXglT+E0J0f/1T4/jqvmms2VXF0wt2sjhvPwYmCkooiFIULh+bxY/G3M3sbxbRfAyr13m90OI0St8qxV/mR3WqJI5OJOvWLFRr09Fh0zQ5J/uco/jthOg8iqLgslmo94XPTsWSAmf2+AJ3sC4skGq+NsqeYUeL06Jeo6goJqr1UJXltlLmPPlro54FOlw0a9fbK9Z1WYoS6guKE1OXCabWFFZy/YurWpwVOihIKIiyaAopcXZOy04hoBukxtmYMCCVGUN7NwZBsXJYNS4bm8VlY7Moq/NS7Q5gmCbJLhtpCXYURcEwTG55ZQ3fq4n41OhuLH/QoKTaw1XPreA/d53eZJ2VEEJ0V3n76nlnzR6W7SzHZlHxH+iMBA2TibkpzBiWzqi+qYzrPY4lxUvCPp8yLYWUaa2vU9AUjQtyL8BldXXKdxDiWJg+pBefbNgbliYbdQqc4oeEtehm04As0tqo3U/tRq/T0X06mr1pYNHSGkWUQ+dtK2Uu1lmgWEW7FutwMe8pGjRJckp20YmqS/Ta8/fXc8Pc1W0GUofzBU2q3H78QYOnrx3d4SUr0xIcEWe2vvx+HyvyKxpnnQ7X2g0dNEz2VLp5blEBd58l6ShCiO6r1hvgtlfX8u3uKgJBAz1ChtGi7eWs2VVFz3g7v754FqtLV0ddpvlwVtXKrGGzOqDVQhw/fjJlAP/dUhax8FY0KXD2HuvD+kUtrY3q99N+5P0sj9I3S8m44VCF4tbWKJq667Dzth4sxToLFItY1mIdLtZ1Wf1TXTisnRMMiuNflwimHvj4expaWGzZWoDiDRgs2r6fb3dXMab/0amyMmfBzohBXzQ3tC9o8NKyAu46c6BU+xNCdEtVDX4uenIppbXeiKWdD+f26xRVubnvVR+OlPOxpHxC0Iw+oHJoDh6a8hDZSdlH2Gohji+nZCaRkexkR1l9xPfbSoGzJK4Dpem91NLaKFuyDUd/B9VLqkkcndjmGkVTt6E3HBoUbitYinUWCNMM5dW1IZaKfKEyEk3PGcueoredEXmLHXFiOO6DqbJaL0t2lkdcnBxNgOIJ6DyzKJ9nr+v8YCp/fz1b9taGvR7LDe3XDb7ZWsZZQ8NHeoQQoivzBw2ueX4le2s8BCJNR0VgmAeK/lRO5N4zBvHMpsfxG34Ms+VATFM0bJqNhyY/xA/6/6Cjmi/EceXhS0dw5bPLo76XDme3e2i+OUFr+7fFD4vHCBhRrVFEMQnUjDp0rTaCpVhngeI0CBgGVkMHQ8erWtAjrE+PpSJf80Dq0PeObl3WhSP6RnEN0V0d98HUayt3RfwnHm2AYpqwcNt+yut99Izv3OIOX2zZhxGhkmAsN3SDT+f9dUUSTAkholLvr+ejnR/xWeFnVPtCi75THCmcn3s+5+Wcd1ytFfpkYwmFFQ0RO39trWvwBw3y80fw6nmvMnfTXL7a/RUKCl7d23gOp8WJaZpckHsBs4bNkhkp0a31cFmJONJMdPeT0mySqK3925xZTrJuz2q1TaapEqgeDeahqnbRBEvRzgLF2TT+9uNTGZ+bQnGVhwZ/kIc+3szGveEzdDFX5FOImHLcGqdV42dnn4TTJil+J7LjPpj6dENpxPVHsQQoFk1h6Y5yLhqZ0eaxR6Ks1ksgQjAV6w1dVhv7ugAhxIllX8M+nlj3BPML56Oi4tEPVdTKr8lnc8VmHl71MBcOuJA7R95JqjP1GLY25OkjSIMOGibz1hXzuwvO4uGpD1Pjq+HTgk8prCmkxl9DD3sPBiUP4pzsc46rAFKIzvLC4gIizc9Gcz/pQSeWZuPL7dm/LYxhw19xRtjL0QRL0cwCqarCjGHpWDW1cYA8zmkHwoOpWNdi9enhpKLeH3EdWiROq8alozO5ZXJuVMeL7uu4D6Zqvc0nokNiCVB03aTGE/k8Haml/a1ivaEj7WouhBAHbavcxs2f30x9oL7FjTY9wVBwNW/7PL7Z/Q0vnvMiOUk5R7OZTWwqrmF3pTvs9VjSoBUFPlhXzNXj+5NkT+KqIVd1eruFOB65/UHmrSumeTXuaO+nYO2paM5iFPVQ3yjW/duac1qcjE/4FZ+ZDgIR9q060lLmTqvGrVNzsTbbGLdXQuS9nWJZi6UAM0dm0OAP8sbK3Rim2WL6pM2iogC3nzGA/zlzYIcXOBNdz3EfTLUk5j0AjkKbesXbURXCSpXGurgy2SWbvgkhIttTu4cbPruB+kDkhefNBc0gFd4Krp9/Pe9c+A6943p3cgsjW7KjnGCEzkksWQZuv878TaVcPb5/ZzRRiC7j042lEWswRHs/BWpHY0//NOz1WPZvO8hlcWFVrTwz4xl6aLl8vX5h2B6gR8phVZlyUk/uOGNg2HvnDe/D11vLaPA1vWYsa7EcVo2zhqYzMqsHN52ew9ylBby1eg+KooT6j0ooo1JV4LqJ/bluQja9k6Lfq1R0b8d9MNXDZaWsLjztLZYAxaKp9DgKAcrpJ/XkqYU78fjbf0O7bBrnDD82nR0hxPHNNE3u/OpO3IHwGZ7mG20mjkkk/cfpaHEaJiZ1/jp+9s3PePOCN49By6Gi3hdx9j7WNOgqt7+jmyZEl7OroiFiymzU95NhJ1B7Krakb0FpOr0Vzf5tADbVRmZCJjefcjMz+s/AYQkFF3NvGMesF1dFnS7XFpdNY8bQdB657FTUCJWOzxqajtbC7FC0a7F6Jzk4NTOU2piV4uIPFw7jF+cMYWNxDdXuAKoCPVw2RmQmhc2MCXHcB1M/GtmXf321I2zdVCwBSkA3mHJSz05v66isHqQl2NlVEd7RifaGNk2pCiOEiGz9/vWUuksxmq2UiLTRZsmrJRQ+WkjOb3NQLSq6qbOjegfbq7ZzUvLR38vOokbugMSaZaBKSo0QLS6BiOV+8u+fgSPxewylIaZra4pGTmIOf53yVwanDg57f1xOCq/PHs+sF1cR1M1WgyqLCroBdquKN3DouWZVFVRV4dSsHtw2LZfpg9NaTKezairXT8zmucX5EdfYt5Ve6LRq3D5tQNj5HVaN07KPzrY6oms77oOpq07rxz+/2hHxvWgCFFWBGUPTj8rMlKIo3DZtAH/6aEvEh0eb+z6oCpeNzZSqMEKIiF7a/BLeoLfJay1ttJl1RxZ59+dRs6yG5KnJAASMAK9ueZU/nf6no9721HgbNosatrdUrGnQnV2VVYiuoIfTGvH1WO4nM5hITvDnlMX9nYZAQ9ggTcTza3ayE7N55dxXWi30MrpfMkt/dSbvrSni2cX51HoCGCbohoGmqqgKxDsszJ6Sy6WjM/muqJpV+ZWU1/uwW1T69HBw4YgM+qVGV0xm9tRc3v+2iNJab9hSi9ZYNYWcnnFcNEoGsUX7HffBVGq8nTMH9wqVHY9wg7QVoNgtGrdMOXqVVi4elcELSwooLG9osSBFSxIcFu6cHp4PLIQQ7oCbRUWLMJst7G5po03NoZEwIoH6zfWNwZRu6nxa8Cl/mPgHLBH2ZelMZw9N55HPt4W9HkuWQZxd45LRnVuVVYiu4KT0BOLs2hGtE7JqCmP7DmXW1Le486s7KW0oxRv0hj1jAKyqFUVRmJo5lQcnP9iY0teaRIeVGyfncMPp2awurKKgvJ46b5B4u4XsnnGMz0lpnA2aPjiN6YPT2v3nkeS08tatE5n51FJq3IGo+l82i0qfJAev3zIeu0UGsUX7HffBFMDvLhjKsvwKaj3BmD7ntGqcP6IPp2a1XoGmIzmsGm/MHs+PnlhKRb0vqs30VAXibBbemD2B9ERZ0CiECFfhrcCqWgkYTdN7Wtto05JkwbPL0+Q10wytn0p2JHdqe5vrnxrHKRlJrNlVFfZetGnQqqIwY6isKRXi7KHp/Oq9I1snpCoK103oT1ZCHB9e9CHr96/npc0vsahoERbVgqqo6IaOpmpcNugyrhpyFX3jY5/BURSFcTkpjMvp3JS5rBQX8++ewuyX17BtXx0B3USPEFQdTCGckJvKE1ePIsEReZZPiGh1iWAqM9nFG7dM4KpnV9DgD0Y1hevUYOKAVP56ySmd38Bm0hIczP/pFG56eTVb99bhC+ottjnOppESZ+OVm8eT0zPu6DZUCNFleIPeiGsG2tpo0xLf9DGvqVpYquDRctu0Afz03+siLpxvK8vAZlG5bkJ/bBZZ/C2E3aJxzfh+vLi0IOKgbTRlyEdkJtE/NdTvUBSFkWkjeTztcdwBN5XeSrxBL/G2eFKdqVjVrhFwpCU4+PCuyWwpqeWFJfl8vGEvqqKgqqAbJqqicNnYLG6clE229LlEB+kSwRTA8IwkPvqfydz62lp2V7hbDFCcVg1D17mg9Dse+t19WI5R1ZXkOBvz7jidjUU1PLc4n883l2LVVMyGehRXHH7DZHxOCrdNG8CkAamyT4EQolUJtoSIe0rFutFm0AgSb4vv9PZGcuaQNCYOSGXp9nK8ERaKt0RTFXonOrj9jAGd2DohupZZk7J5dcWudpUhd1hV7jlrUMT3XFZXl9/4emjfRP7v8pH85eJTqGzw0+ALEu+wkBpnlwEZ0eG6TDAFkN0zjs/vmcrGohqeX5zP/E2hfRZURSGgG/SMt/OTqblcOjqDul9+RuXTT5N278+OaZtPyUzin1eNotYboKjSw9bZt5P70J/Jys0gVRZSCyGilOpMRVPC8/pj3Wgz3hpPvPXYBFOqqvDk1aO5/oVVbCyuxhNoO6Cyagqp8XbeunWCpOMIcZi+PZzMuXYMP3l1TZNKeG1xWjXuPuskTh/Y+VWOjzWHVaNvD+exbobo5rpUMHXQKZlJ/OOqUfyfblDrDeIN6CQ6rcTZtMYZnrg//IH8mReTcPbZOE8ZDoRKiW4rraPGHcBuVUlLcDAoPf6ozAolOqwM7WvFWl9Mdi8nVgmkhBAxsKpWLht0Ga9//3rYuqloN9q0a3auG3rdMZ0Jd1g1Xp89nnv+/R2fbNyLw6JGnKWyagrqgbUW/7xyFMlxspm5EM1NHdSLZ64by80vrUY3IpWOOERVQumyP//hIG6efPQKcwnR3XXJYOogi6aS0sIPWEuvXqT/6pfs/c1vaHjiRV5Yvof5m0qxaSoc6EcEDZPUOBu3Tctl5qhM4u2d88dR1eDnnbV7+HZXNXuHXkHKhzsYkFnFleOyGJiW0CnXFEJ0P1cNuYo3tr4R8b1oNto0Mbl00KWd0bSYWFSFao+f35w7BFVVeG5xPmV1PqyqStAwcNksXDUui+snZpOV0rXTjYToKIZhEjTMsDS1Hk4rCQ4rPxyWzn/WlwA0WZfosKqYZmhz29umDuCUzCSEEB2nSwdTbbHMOIdfLyhn/ZzlBNDQTTNsQ7civ4cHP93KXz7ZyhNXj+IHJ6e3cLbYbS2t5Ymvd/Dlln0oCqFp+KR+UFjL4t11vLZyF4PSE7hz+kBmDE2XdVNCiFb1je/LtMxpLCpahE/3xfRZh+bg/NzzSXEc+00oF+TtZ2+Nlxsn52DVVG6ZkktAN2jwBXFYNRxWKVMsBED+/npeXFrAB+tKaPAFQQkNRpySkcRt0wYweWBPfvneBv7fj4Zy0cgM/t+PhvHpxr1s2VtLVYOfBIeVnJ5x/OjUvjK7K0QnUUyz5UnhsWPHmmvWrDmKzek4bn+Qi59aRuH+enxRlCeH0OjNgxefwiWjM4/4+vM37uXet9e3WsnvIKdN4+KRGfx55nA0VQIq0bkURVlrmubYY92OI9WVn09Hwhv0ct2n15Ffm49f90f1GbtmZ2jqUF744QvHvCpXUDc49x+L+cU5Qzh7aMcNXonuoTs8nzri2bS7ws09b61jS0ktwQMzUs3F2TV0wyQ7xcX8e6bKgKwQnai1Z1O3LGlimiazX15DYXlD1IEUhGaOfjNvIyvzK47o+l9u2cfP3v4OT6DtQArA49eZt66Y38zbSGvBrRBCOCwOXj73ZU7tdSouS9spcE7Fwvg+4/nX9DmsKajl0417+WTDXpbuKMftj23vvo7wztoiUuJsnHVy+zfoFKI721RcwwX/Wsx3e6rxBo0WN6Bt8Ol4AwYFFe6IG2ILIY6Obpnmt7qwinV7qsNS+gAatiygdvUHBCqKUG1OrGm5JE26HEfmMCAUUP3xoy18eveUdl27pNrDT99cF7GyTmvX9gR0/vNdCRNzU5k5KqNd1xZCnBhcVhfPnf0ci4sX8+KmF9lSsQWgcabKptkwMRmZPJgLtm5lY/wsJj246LCRaxMFhaBhcsnoDG6anMOAXp1f4a/eF+SxL/N4ftZYGUUXIoI9lW6ufm4Ftd7oBzp8QYO5SwtJibNxyxQpLCHE0dYtg6lnF+3EEwjfd6F21TxqVr5L6ow7ceSMRtEseArW4tm+sjGYAsgvr2draS1DeifGfO1Xlu9CN8IDqWiu7QnoPP7fPC4a2Vc6GkKIVmmqxhlZZ3BG1hnsrt3N0pKl1PhqAEi2JzM5YzJvLqvnVxXbMMr3ETAjJyK8tXoP764t4pLRmTzQyanGzy7cyaQBqYzI7NH2wUKcgH79/kbqfZEDqbYGZB/5fBsXntqX9ETHUW61ECe2bhdM7a/zsWh7Oc2z5QxfA9VLXif1vHtwDZ7U+Lpr4HhcA8c3OTagmzy/uIBHLzs1pmv7gwavr9yFv1lqYSzX3lfrY31RDSOzpLMhhIhOv8R+9Evs1/h70zT51Xsb+c/6EnxG68UcDq7H+GBdMfvrvDx73VjUTgioSmu8vLJiF5/8tH2z/kJ0dyXVHlYXVkZcHhDtYPBrK3Zx34zBR7HVQohuF0yt212FTVPxN0vx8xVvxQz6cQ2a2OZTwpJGAAAgAElEQVQ5dMNkyfbymK/99dYyjAhPwViu7QvqzF1awD+uHBXz9YUQAmDOwp38Z31JxBn6lngCOkt3VPCXT7/n9xcMjfmadd4A760tYt66Yiob/JhAktPKucN7c9W4fvzfF9u4alw/MmQDTSEienX5rrCBYIh+QNYXNHhl+S5++oOTsGrdckm8EMelbhdM1XgCGBGeRrqnFtWViKJGV3K3rs5N+XPPoVisKBYLitWCYrGAxYJitYZeP/Dawdd37PDhi9B5ieXahgk7yuqjaqMQQjTn9gf551c7IgZSba0Z9QR0Xluxi9vPGEDPKDcWL6vz8rfPtvHxhhIURcFz2P42RVUedu6v5x9fbUdB4f07JrVyJiFObPPWFeHXw5cJxDoYvG53NeNyjv0WCEKcKLpdMGXVVCIlqGjORAx3LaahRxXUWDDRq6shGMQMBDGDB38FMAOBiK+XxQ0l2GMENFvvFOu1G45BhS0hRPfw0fqS5o8gIPo0IYB/r9rNXWee1Oa1dpTVc8Uzy6nxBFqsOHawGI+CyZXPruDlm8Yxpn9y7F9MiG6upaITsQ4GVzbEtgedEOLIdLtgKjXeFrF4gz1jCIrFijtvOXFDJrd5nuSUBNLvvz+ma2ctysf6+VYCzdZMxXrteHu3+2sRQhwlTy/YidvfdFYqlnWbvqDBi0sLuf2Mga0Wo9hb4+GyOcuodgeIZkMHk1A1v+teWMm8O05ncO+EWL6WEN2e3sKARGwDspH3pBJCdJ5ul1Tb0tS2ao+jx+RrqPxyDu685RgBL6YexLNzDVXfvNjkWIdV5YrTsmK+9oC0OGwR8pRjubamKgztE3sVQSGE8AZ0dle6w16PJU0IwOMPUlrrbfWYO17/llpvMCyQatiygL0v38Puv/+YoieuY9/b/4u3aHPj+26/zg1zV0VcXyrEiSyuhYHUwwdk26KgkOQ8thtzC3Gi6XZTIHaLxtXj+vHSsoKwqnqJ4y5BjUumZvlblH/8KIrNiT19IIkTr2hynGnCFaf1I1ZTT+qF1aKCP3ytQrTXtmoKN0zKifnaQghR6wlgs6hh+9zFmiakqSo17kCLxSJ2lNXx/d7asJH0aFMJaz0Blu4sZ8pJvWL8hkJ0X5MH9uTjDSVh1fwOH5BVVA1HzigU1YK38Du8uzeQPP2mxmP9usGpUg1YiKOq2wVTANdP6s/LywshQvJJ/LDpxA+b3uJnLarCD05OIyXOFvN1LZrKjZOyeWrBzogbBrd1bYDs1DiG9pWZKSFE7KyaGrGscqzrNn1Bna+37qPBH6R/qote8fYm6dMvLCkk0GyhfCyphA1+nTkLd0owJcRhZk/J5cst+yIWj4lmQFZTFS4Y0YdEh8xMCXE0dctgKjPZxX0zBvHYl9tjKg2sAD1cVv74o+HtvvbV4/vz7OJ8Wthzr1VOq8q9Zw9q97WFECe2RKc1YvpcrOs2dcPk+721fL21jF0VbjwBnf6pcWSnuuiX4uKdNXtoXnQs1lTC1QVVVLv99HDFPnAlRHd0SmYSGcnOFiv6tjUga9NUbp6c21nNE0K0oFsGUwA/mTqAaneAuUsLowqoNBV6OG28fetEeiVEVxI4kl4Jdl6YdRo3zF0VlmrTGocR4FL/Xs4efHa7ry2EOLFpqsJZJ6fx+ZZ9TfariSVNCGBUv2SevGZM4+9rvQF2V7gprGggb18degdsP2G1KJTV+SSYEuIwf75oODe+FFv/AUJrvc88OU0yW4Q4BrpdAYrD/eKcITwwczipcVacwcilQm0WFbtFZfLAnsy/ewq5veKP+LoTclN5cdZpuGwaVq3lalgQmg1zWjV+Mv0kZu9fQ/F9P8f0+4+4DUKIE9PsqQNwWMIDmsRxl5B85s3ULH+Lon9dQ9HTN1D37cc4T2o6kxRn07ht2oCmn3VYGZ6RxAUj+nLdhOyIhXYOTyWMhoqCN4bMASFOBBMHpPLgzFNwWKPvnjmsKqdkJPHY5SM7sWVCiJZ025mpgy4dk8m0Hcv4eslm3jvpHDYW1+AJ6FjUUMWby8dmcd3E/vRJirzQur0mDezJFz+bygtLCnhr9R4UQusEDrJbVExgysCe3HbGAE7LTsE48ymK7/kZRXffQ8bjj6Ha2z9DJoQ4MY3u14P0RDuFFeFV/aJZt2m3apw5JK3F9xMclrD1UhB7KqFhmiTI2g4hwlwyJpMkl5X/eXMdQNhWBwdZVQVVVfjh0N48ctmp2CzdenxciONW9wqm9CBU7wJvNWg2iE/HjOtF7Ztvct5993H5lNAIrGmaEfei6miZyS7+98Jh/PKcIXy6cS+bimuobPATZ7fQP9XFzFEZpCU4Go9X7XYy//kPiu//BUW330Hmk0+gOjs2yBNCdG+KovD0tWO49OllLXbCWuKwqjxz3ZhW95dyWDV6JzkoqW5aOj3WVEKAvj0cYa8JIeAHJ6ez9ndn89H6Ep5auIPSGi/egEGCw4JphgYjLhubxQ2TssnpGXesmyvECa17BFO1e2HNC7Dq2VBApWqACUE/RkIOTqdO3ITTGg8/GoHU4RxWjUtGZ3LJ6Mw2j1WsVjIefYS9v/0de2b/hMw5c9Di5UEphIjeyX0SefSyU7nz9W+j2lAXQunGT1w9itOyI+/Vd7ifTMnl4c+2ha1HjWULiCtOy8IeIR1RCBHitGlcfloWl43NZPH2/dz/7gYeu3wkiU4rA9PicVjl/hHieNC1gylDh0/vh+9eC1VB18PXRWnVW0kfYkf5+2C4/FXImXL02xkjxWKhz0MPUvrHP7H75pvo99xzaInhi0o9fp2tpbXUeAJYVJWeCTYGpycc9WBRCHF88QZ0nlqwg1sm57BuTzUbi2swTJNAs733NDVUTn1gr3j+cvEpUe9Pc8mYTB6avzXie9GkEqqK7KcnRLQURcFhtZCZ7GLSwJ7HujlCiGa6bjClB+GNy2H3cmihuMRBqukDjw9evwwufQ5OvvAoNbL9FFWl9//7X/Y99BC7briBfi+8gCU5GYCC8gbmLi3gnTVFaKrCwdhJN0ySnFZ+MiWXS8dmyl4TQnRHVYVQtQv8DWCPh9SBkNi38W3TNPnVexvI7RnPb84/GUVR2FXRwNylhXyycS/13iAmJnE2C2cNTefmyTkMSk+IqQmJDivXT+zPayt24Ym16phF5cwhafRLdcX0OSFOKHoQ8j6DlXOgqpBTvW6e1m3w9niYeBdkjgUZOBXiuKCYEUrcHjR27FhzzZo1R7E5MfjgTtj8HgQ8sX3O4oRZH0HWaW0fexwwTZP9jz1O/Tdf0+e55/ndor18smEvQcMkGGl3TkLpOiYmD18ygotGZRzlFovjnaIoa03THHus23GkjuvnU0fTg7DtU1j6OOzbHFoT2vieD7ImwOk/hdwzeW5JIR98V8y7t03Caeu8NCDDMLn55dWsyK+IOqCyW1QGpsXz3u2TJEVJRNQdnk9H9GzSg7D0MVj+JOgB8Dfbc0pRweKAhN7wg/+FYTOPvMFCiDa19mzqmjNTZd/DpvcgeCiQyn68DncACu6OJ84WGq15/ls/r20IsOCGw9YcBT3wyb1w2+Ijb0dtCax+EQoWgrcKNDskZcKYG+CkGQfWbh0ZRVFIu/dnGA4HV/15HluT++MNtt5xObiO4Zfvb6Ciwc9NkyWdRoguq+x7eGVmqFN1sGMVbFr8gYKFULwWt70nH3p+zbN3XdSpgRSAqio8e/1YfvHuBj7bVIo3qNPK2Bwum8aorB48N2usBFJCROJ3wxtXQPHqlgeKTQMCbqjMhw9ug5J1cNb/k1kqIY6hrhlMHRyxaUY34R8r/fxmShslxcu3w74tkD60fdffux7++ycoPBCQHb5Wa98mKFwCFjtMuAMm/RQsR74p5d9SJ7I1sajNQOpw3oDB3z7fSlaKi7OHph9xG4QQR1nJd/DS+aGUvrZKSfjrsfndzLP9EqsxDsjt9OZZNZW/X34q14zvx3OL8/lm234sqkJANwjoJg6rimnC2P7J3DptAJMH9kRtpVKgECcsQ4d/Xw1FK9tcutAo4AkV3rLFw7T7O7d9QogWHbVgqtrtZ1eFm3pfEJdNIyPZ2aQseNR8dbDxXTCDYW/dP8nG35b6uOM0Gz0crfzA1v2w4km46MnYr//9R/De7CazYmEOjiAveiSUmnPt++CMbmF3JDv31/PJxr34zPDv1LBlAbWrPyBQUYRqc2JNyyVp0uU4MocBoYDqDx9u4qyT06QwhRBdSW0JvPKj8DSfVlgwIFALc8+HO1eAI6kTGxiiKApjs1MYm51Ceb2Phdv2U+X289SCndw4KZtLxmSS0UO2eBCiVaufhz1NA6moMm4Cbljyf3DS2dBXNu0V4ljo1GDKNE3W7qri2UX5LMjbj/2wDeX8QYMxB0Yrp8QyWlm0GjRrxGBmbF+NM7ItPLrMxwNnthKomTrkfR7r14Ed/207kDpc0AulG+HVmXDT56HZqnaYu7QAPcL6qNpV86hZ+S6pM+7EkTMaRbPgKViLZ/vKxmAKoMYTYEV+JRMHpLbr+sdC0AiyqGgRG/dvpMpXhdPipG98X87NOZeeTqlmJE4Ai/8eSvs5TFSdK9MATxWsfTm0juoo6hlv59IxoS0gPtm4l4kDUiWQEqItphlaDxkI32g7qoyboA+WPwGXPt+JjRRCtKTTgqn9dT5mvbiKgooGvIFQLr2/WYrasp0VrN9TTQ+XjdduGR/dxnOeKlpLzP/TdDunv9jA3ePbSK2LYbS38bpvXx/7Oi3dH1rz8NUf4YcPxnZNQuXP31tbHFZswvA1UL3kdVLPuwfX4EmNr7sGjsc1cHzYOZ5ZuLNLBFOV3kr+vfXfvPH9GwSNIA3Bhsb3bJqNx9c+zoS+E7hp+E2MSR9zDFsqRCfyu+G718FoZzpz0BPqXE28C1S15eM6UZLTSo0nvP1CiGYKF4O3JuJbUWXcmEYoa8ZTBc7kTmyoECKSTvkpu7fGw7n/WETevjo8/tYXJTf4dUpqPFz4ryV8v7e27ZMrKrQyiTU8TeOCQRb+usTf6mlMwyRYXk5r1QybWPd6xCDuYMemVUEvrH0p9sqDwJa9tVgizNr5irdiBv24Bk1s8xwmsKqwMuZrH21bKrZw4bwLeXHTi9T4a5oEUgB+3Y/f8LO4aDG3fXkbf1/z9+j//oToSja/3+KC8vsn2Xh0mY9qb9trqMj/uhMaFx0JpoSI0poXD6yLDHd4xk2rFBW2/KcTGieEaEuHB1Nuf5ArnllBlTvQYunu5kwT6n1Brn5uBWW13laP9Vh7ENRbP+8fz3Dw3Ld+imtbPk73muSffwF548ZTcMUVlPzyV5TPeYbaz7/Am5eH4TvswWUYsPxfEafgo+7YAGye1/YxzdS20BnRPbWorkSUKCsGegL6cR14bKvcxo2f3UitvxZfhM2XD2di4tW9vLn1TR5a9dBRaqEQR1HeF0feufI3QMGSTmhcdCSYEiJKVbtafftP0+38a5Wf/Q2tFKAKuEPrLIUQR12Hp/m9vXoP++u8Edf4tFUsod4X5OmFO/nfCw+t9zEMk80ltSzavp9FefvZVlzPck1vteEDU1SuGGbln6v8nJIWIV7UbFgm3sKgvzxIsKoKf0FB46+aDz/EX1BAoLgYS1oattwc4vtb6RGsihh5Rr1Oy98Q2nxv5NWYuo5eXY1eWUmwqgq9sgq96rD/r6xEr64iWFlFqZ6IMehCsDZdd6A5EzHctZiGHlVApSrKcVuAwh1wc8sXt+AOhgerVYurKP+8HH+ZH82hkTgmkfQfp6PFaXh1L/O2z2NEzxFcMOCCY9ByITqJp6LVt6NOZ24o68BGxUaCKSGi1Eb1vsMzbk7u1coYeCDyAIwQonN1aDBlmibPLs6PuIFjNMUSArrJW6v3cMOkbFYXVrEobz9LdpST7LIydVAvbps2gPG5p+FccAusfCa0HqkFf5hm59UNLfwgVxQYdysAluRkLMnJuEaPbvpdAgH8RUX4CwpDe1pVGC2mF0bbsQnu3kL++Ano9fVoCQloKSloKaHra8mh/7dlZaKdOgKtRzJaSjJDdQfmu9vBrzc5lz1jCIrFijtvOXFDJrd6XYAEx/FbBf/j/I8jzkaVzy9n//z9ZN6SSfzQeAJVAUpeLaHw0UJyfpuDalHx6l6e+O4Jzs89/7gNFoWImWpt9e2oO1da+4redIREh5W9Na1nGgghiKrq5h/PcDD6mXrum9jCPa1o4Dr+10UL0R11aA97ZUEl1e7wACamYgkBnXP/sZjpg9OYclJPfnnukPBqUON+Aquea/JS4T0JTX6flaTi/V1ihFYqkDEWkvu3+l0UqxV7Tg72nBzosQfmvweByMFZtB0bzWUjd/6naElJKFp06XmnGCbxHxXgbhZMqfY4eky+hsov56CoGo6cUSiqBW/hd3h3byB5+k2HrqvABSP6RHW9o800TeZumounWYVE3aNT9kEZGTdnkDAi9Hdr62Uj644s8u7Po2ZZDclTQwttK72VrCtbx+j00WHnF6JLSswgNHrTcmpum50r1QKJfY+oGaZpUu8L4vHrxNktuGxa1IMWSU4rW0vrjuj6QpwQsidD8dqme1Y202bGjdUBfaQ0uhDHQocGU0u3l4d1+iHGYgkmTBnYkyevaaVj3KMfnH43LIu8jqlV9ni48J8xfiYR2kila7NjAyj2BCwpKTFdWlUVZk/J5e9fbgub8UscdwlqXDI1y9+i/ONHUWxO7OkDSZx4RZPjTOCTDaWkuGxcO6E/aYnt2N+rk2wo30CFNzylyb3djREwSBzTNCDWHBoJIxKo31zfGEx5g15e2fKKBFPiuKQbJgvzynhmYT5bSmrxBHSsmkrPBBvXT8jm8rFZJLmazUSdeiVsmdfiuimIonOlWmDYzHa1eU+lm5eXFfLm6t34AgYWVSFomMQ7LNw4KZurx/enV0Lrs16JkuYnRHTG3gTL/tHmYa1m3NgTIWdaBzdMCBGNDg2m9tdHHlWJtVhCpbuN6ngAZ/wa6stgw1tRBlQK2OJCG+j2HBhVOxqlDwe99U5Bmx0bgN7DY7vuAZePzeLRL7ZFfC9+2HTih01v8bOqAqP6JfPwpSN4eVkhZz+2iOmDe3Hj6TmcmtX+jYQ7Sn51fsTX9XodS7wFRQsfBbckWfDsOjSTZWKSV5XXaW0Uor3eWLmLR7/IwxfUafAdGmgKGjp7Kj38/cttPPrFNs4d3oc/zxxGguNAUJU9OVTiuJVgCtroXKUPh54nxdTeGneA/3lzHSsLKjBMk8CBYj8HiwlVuwM8tWAnTy3YybnDe/PXS0fgsEZ+ric5rS0W0BFCHCaxD2RPgR1fcfhsdNQZN1bXMd0GQYgTXYfeeVoLG+8eXiwhGhYtimYpClzwGEz/bShIsrWwR5Wihoo39BoMt3wFWeOiakMTPQdC2tA2D/vDNDsN/hbScmxxMKl9G2gmuaw8MHM4Dmvsf11xNguPXnYqA9Pi+fPM4Sz6xXSG9U3izje+5ZKnlvKf9SUE9FYqBHWy+kA9eoR/F1q8RrA+iBmhcmOwJoglvuk4gDvWGUohOpFpmvzhw038+ePvqWzwNwmkDucJGPiCBp9uLOGCfy2hrC60xmhfnY83LTPx0nT2p/CeBM7KPfRv/2DnqnFfu4OscTD5ZzG1uazWy3n/XMTy/HJ8QaMxkGrOFwy1+bNNpVzy1DLqfcGIx0kBCiFicObvQql6MVNCwdSoazu8SUKI6HTozFTvJAeaqoRV8ou1WEJ6QpQPFEWBSXfBaTfD5g9g6WOwPw80W2gTOwU4+aLQMX1HteMbHWbyPfDBHU02+41+nRZgT4Kcqe2+/GVjs6h0+3nsyzy8EQp8NKcqoUCq+WbISU4rs6fmctPkHL7cso+5Swt48JPvuW5if64a14+UuDaqg7WiuNrD/I172VfrxR80SI23Mz4nhXE5KS2us3BZXKhKeJDoGuhCsSjUrq0ladyhxbm6V6duQx3pP05vcrzDcvykLgrxf1/k8c6aIjyB6AaQ/LpJUZWbK55ZwXUT+/HE1zu57rSrsTlWQel3rRbbCWNxwsAzYcj5UX+kwRfkimdXsK/WF/WWFt6gwc799dw4dxVvzp4QNgiW5JJgSoio9R0FFz0V6mcEo92TUgktXbjhE3A2zTSp9lbz/vb3eW/7e1T7qtFNHZfFxdjeY5k1bBbDUoe1cE4hRKw6NJg6d3gfnvxmR1gwFUuxhDibxiWjM2K7sNUJI68K/dKDoZ3ELTawxbe48WXMBp8XqpQT8IAZXQepSft+8PsjbsutUwfQL9nF//5nMw2+IA0R1qdZVAVNVRjaN5G/Xz6ySSB1OE1VOGd4b84Z3pvNJTW8tLSQMx75hnOH9+HGydkM6d1CUNiMaZos3l7OnIU7WburChPwB0PBnqqAw6qR7LJx69RcLh2TSZy96T+5zITMiMGU5tJIm5lGyWslqA61STU/a4qVHpOa/uDISsiKqr1CdLatpbU8vyQ/4qBHa9tD6AYUljfw9Df5vDF7fOge9L4LL18I+7eFNv9ug09xYO8/ES59IabnzYtLCyip9kQMpFprsy9osLmklo837GXmqNBz2zRNVhdW8eF3xZTVebnt1bX0SrBz1tB0pgzsidpCBoMQJ7zhl4T6C+/eFBoQbu2etx7IyLnhE+g1qPHlSm8lD658kAV7FqCg4NUPnaMh0MDnhZ/zze5v6BPfh1+N+xWT+k6KdHYhRAyU1jZyHTt2rLlmzZqYTvijJ5awoagm4nv1m7+hbs2HBCr2NCmW4Mg8ufGYXgl2Vv76B8fnD9zq3fDM1FCwZkaZGmd1wejr4dyHO6wZhmGybGcFcxbuZM2uSnwBA0WBBIeVH53alxtPzya3V3zM562o9/Hmqt28umIXuT3jufH0bH5wcnqL6ZtB3eD+dzfw+ebSiIVHDue0aqTG23jr1olNqjPqhs6Z75xJpbcy4ucqF1ZS8UUF/jI/qlMlcXQivS/rjRZ3aJ2Gy+LikWmPMDWz/TN/JxJFUdaapjn2WLfjSLXn+XQ0/Pyd9cxbVxw2qNTS9hC+PZvDBpTW/v7sQ2uRAl747Few/s1Q2nKklFZbHKZp8qpxDv0ve4hpQ3pH3V7dMBn7wJdURajEGm2bh/RO4IM7T+edNXuYszCfKrcfj19vUoswzqbhslm4ZUoO10zoT7z9+N2uQRw73eH5dMTPpvoyWPsSrHwagv5Qf8MIYqpWvIEgWlIGtjPug2GXgM3V+LE9dXuYNX8WVd4qgmbk9NvD2TU7vzjtF1w++PL2t1WIE0Rrz6YOD6bmb9zLfe+sb7NzHYnDqvKzswZx67QBMX/2qKkqhJfOB3dl64UvFBUsdpj4PzD9Nx03QxaBYZgoCh22z5I/aDB/017mLi2kosHHrInZXH5aFomOQxXHDMPkttfWsmj7/qjSDgE0RSHJZeXTn06hd5IjVLpx59c8v/A3PKO58baz+cmOZL657Bu0KAucnOi6Q2cFjs9gqs4bYOwD/8UXbHpPGL4Gip6cRep597SZ6hxn03jg4uFcPCqz6Rve2lBAteJpqNsb2ujT4ght8zDppzD8EpbuauDn76zns7unkuCwsGxnBS8vL2RXRQMev068w8LIrB7ceHoOg9JDacr/3bKPu/+9LmymO5Y22y0KGcku9lZ720xtdFhU0pMc/PsnE+iT5Gz1WHHi6Q7Ppw57Nhk6FCyE6j2h/oY9kX9sstJjwDhmTcpucmiFp4LLPrqMCk8FBtGvg3ZoDv58+p85J+ecI2+vEN1Ya8+mDh8a/OGw3ry9Zg/LdlaEdShaY9UUcnvGhz0gjjvJ2XDnKtj4Lix5DOr3gREMrWlQtFDnxtTh5AtD1XX6dv6+Dx09i2ezqFw0MoOLRmawbncVc5cW8q+vd3DRyL7MmpTNgF7xPPnNDhZvL485lanGE+Ca51fw5cUq6oIHoaGcSyf/lOe2PBFDnvghDouDG4fdKIGUOC58vbUMi6rQvK5pLNtDNPh13li5OzyYciTC+FtDv1pw+kAnZ5+cxnUvrGRfnZc6bzBsYCuvtI5564oZ0Cue+384mP+sL4mYMhxLm31Bk8LyBqJZbuUNGhRVebjoiaXMv3sKqfHHbmNhIY5rqgYDzmzyUq5awoffFYf1lf684s9UeavCAqmqxVWUf16Ov8yP5tBIHJNI+o/TG7M7vLqX3y/9PRP7TiTJ3vbmwUKIcB0eTKmqwtPXjuH6F1axobg6qlkLu0UlM9nJ67eMb7HM7nHFFgdjZoXS94pWh355qkMzUQm9Qwu/ncnHupUdYlS/ZEb1S6a0xstrK3ZxxTPLGdonkVWFlRH/bltKC/JsX3lgXYhJaUUVS999jSkzZsEpl5GsWXgsbSB3f3N3k/zuttg1O+N6j2PWsFkd+ZWFaLfyen/E6pixbg+xv67lzTtb4wvq7K70sLG4psXtfnUT9EBordNtr60lrYX9omJtc/NAqvX1YSZVDX5ufnkNH9x5egzfUIgT2+kDe/Kb9zfiDxrYLKH1xhWeChYXLQ5L7SufX87++fvJvCWzybrjwkcLyfltDuqBz5uYvL/9fW4cfuNR/z5CdAedkrTusGq8Pns8f/nke/69ejcKSsTUD4dFxSQ0m/XXS0/BZetiOfSKEiq13p5y611M7yQHP//hYO46cyB//Ggzvh3hHUbD10D1ktdJPe8eXIMPLWp1DRyPa+D4xt83GFaeSf45U0YeGvGelDGJR6c9yv2L7scX9LWZpuDUnEzsO5FHpj0SsYCFEMdCUDcizs4cvj1ENMFJS2XJW2MYJne+/i0r8itaDKSa8wZCs0SRxNrmw7U1qAIQMEy2ldayuaSGYX1lRFyIaCQ5raTG27j2+RX4dZOgYeCL+wK9WfdJ9+iUfVBGxs0ZJIwIpfTaetnIuiOLvPvzqFlW07jxvU/38cqWV5g1bJb8PBWiHToterFqKv/vR8O4b8Yg5iz5gs8AACAASURBVH1bzHOL8ymuDv3QNoGe8XZm/X/2zjs8jurqw+/MbFevbpKr3HvvlWYDoRhTDYRiwEAoCSEhQAgEkvARklAC2KaZjgPBOKZXd1lu2MY2uEhykWVbvW2f8v2xbvKOpF0h2bK47xM/edidnb272rlzf/ec8ztjOnHl8I4izeM0wmFV2FRQiVmpXeRpQRJr91RwqMpHm/hjluYTMyfy9rlvM2fzHFMnIhkZu8VOuiudWf1ncUG3C8TEL2hRxDutWBUZ9YTeadG2h4hzRD81v7N2Lyt3leIzSa+uL0pUV2petGM+QqSbKhCyhH95eT7/vLz506EFgtMZX1DjpeV5vLJyN9W+ILtLj9Vsx3RbgqzUbp/g2elBD+rED63tzKs4FOIGxFGzteaomIJQr8YtJVsYkDageT+IQNAKafZQUJzDyrVjOnPtmM4YhoFf1bFb5CYzSxCcfPKL3aaPR5MWZLPI5BbX1BJTAFlJWTw58clQj4xdH7D+4HoqA5U4FAcZcRnM6DGDfqn9muRzCARNzbBOSZiZ+kTTHsKmSEzskRbV+xqGwfPf5ppmAEQSJTIjmjEfTzS1Vppu8PH3B3js4n6nX2aCQHCSKHcHmPlSDnnFNaabJZIl3AxLq9GwxFqQlPC1liXBgndP7Yi0LMmUeEuabtACwc+Ik3r3kiTp9KiJEtSLXzV364omLcgwoMZXt3VroiORG/rdwA39zBdsAkFLpHubOLq3ieP7/eHtIeJHTEeOSaIyewElHz1Zqz3E8UiSxLVRGvGsyS+j3BPe2DeaKJEZkY75eKKttbLIEkVVfjqnCjElEJyIJ6By2dxsdpe660n/DRdYSqyCWqNiaEaYoFIrVSyxta83wzAI6FE0BxcIBEcRdy9B1FgV2dSpMZq0IElC7EQLWiWzJ3bj3vfN20PE9p1MbN/J9b5+WKekWr3YIuHNnD14f6IjX11EMubjibbWSpIkavwN98QRCH6OPLRoK3vLPKZC6mj6bnkeikPG0dFB2i/SiOkRgyvLhWSRqFpfRcKIYzWJmk+jenM1bWa0qXUuSZKIt8af+BYCgSACRMGJIGpOTM07wvFpQZ4d2ehBH4am4s1dR/m3r9Q6NqjpdEgSPWYErY+z+7YhM8mFpREtCxxWmfum9W74wBPYW+oxNZ2INkrUFBy/qRIJhmEQIxr4CgRhVHqCLN5UaLp5WbVmIWVfv0jCqMvI+ssv6PFkT5KnJFO9oRoAxaWQflE6hW8WUr25GkM1CBQH2Pf8PqzJVhLHJNY6X0AP0Cul10n5XAJBa0PcwQRRc/2Yzjzx+XbT+oxI04K6psbSJTXmZA1ZIDhpWBWZt24ayfnPrqC0xh+xM5/DKvP0FYPpnxG9s11djXKjjRLZLTKqbqBF0DDKKoOqEybioq21UnWDNvHChEggOJH31u9DNqkvPzF9V3O3wZa0j/jBEvGDj0WX0s5NQ4lROLjgIIGiALJTJn5IPJm3ZCJbj+2ly8hM6DCBZEfySflcAkFrQ4gpQdRcMiyDxz/7sc7nG0oLirErzJ7UrTmGJhC0CFJj7Xxy53iueTmH/BI33oBWp125y6YgAfOuHcbYrNRGvV+8w2r6eDSptxZZ4oZxXfj2xyL2lHrwq5q5zbsUEoyDOibSu108b63eS+CE3lqRbqooksS5/duJlF+BwITXVu023Sg5MX1X83TF0FxISnjNU/LEZJIn1i+S7BY71/W7rknGLBD8HBF3MEHUxDusXDS4Ax9+t980/aAhbIrM1L5tm2FkAkHLITnGxkd3jGN1XhnzluWyMrcUu0XGF9SQJAlFlkiNsTF7UjcuGtThJ6W6De2UxKaCirAoWDRRIrtFZnTXFH53Tk827qvgxeV5fPVDEVZZQpIkDMNAMwwuHNSBG8d1oUebOA5Uenk7Z6/pmCKptbJZZGaN79Lozy0QtGYO1dG8Ozx9V8JfdB6O9v9BkoNRvYdNttEvpR8DUoUlukDQWISYEjSKRy7oy8Z9FeQV10TVYFSRJF69fvjRzu0CQWtGkiRGd0thdLcUiqp8fL+/kjez92BIMGtcV0Z1Tcai/PRr4ZrRnZi/ajfhSXeRR4lcNgvjslKRJInBHZN4fuZQavwqh6p8ePwasQ4LbeMdOG3H0gXbJTiZ0CONZTuKo95YscgSPdvGiYa9AkEdBOu4pszSd9Xq/vhLSrCnfhOxoLLJNjrEdeDZM54V7WoEgp+AEFOCRuGwKvzn5tFc83IOO4qq8QXrX0jJUiidqXt6HC8uy+fpKxKaZBEpELR0DMPgu30VvLgsj69/KMI4LHhW55ZitypcO7oT14zqRHodxi6RkJHkYkjHJLLzSk2fbyhK5LDK3Di+C/IJphmxdguxabH1vve/Lh/EL55dQUG5ueOYGYoskRRj46VfDovoeIHg54jdKpveW+tK3w2WTgbNhb3NYgAk2dwlU0LCYXHQJ6UPz53xHDFWUb8sEPwUxGpW0GgSXFbeu3U0d5/RndRYGzG28AJ3p1XGbpGZPiSDT+6cwDs3j6Lar/Lr/2xC1aJPERQITid2l7g5859LufqlHD7fepCAphPUDIKagU/VqfQGmbcsj/FPfMs9/9lIoBFps0f47Tk9cFgbN6XbLQpXDM9s1Gtj7RY+uHUMWelxOCN4f4dFpkOik0W3jyU1VhhPCAR10SXVfCOjPufcooVbcef+jkDpBAzNgaHZcVlcuCwuYqwx2GQbEzIm8PwZz/PqOa8KISUQNAGSYdS9kzhs2DBj3bp1J3E4gtMVXTdYtrOY99cXcLDSR1DTSXTZmNwzjUuGZhB3XIG8L6gx67V1pMXZefLSgSiNsJAWNB5JktYbhnHahwRa+vy0tbCSK+atxu1XTY0cTsRhlenTLp63bxrV6ObmC9bu5U//29pgpPh4XDaF92aP/snpdr6gxrtr9jJveR4VnmCY6UaMTcFpU5g1rgszR3WqNScIBEdoDfNTU81NH363nwcWfo/bpIccQM3Wb6let4hg6b5a6buOjCPtFTSG9KjgpslpBLQA8bZ4+qf1J9XZOKMbgeDnTH1zkxBTglOCN6Bxw/y1tE908vcZA8LSiwTNR2tYrEDLnp8KK7xMe3o5ld7oisEdFpnRWSm88svhja5h+PC7Au774Hs03ag37c5pVbBbZd6aNbJJ65YMw2BNfhkff3+AvOIaNuyt4KLBHTi7TxsmdE8T17qgXlrD/NRUc5Nf1Rjy6Fe4G9nUOsam8OK1wxjTSJdQgUBwjPrmJlEzJTglOG0KL183jOteXcsfPviev03vLxZZglbDE5/9SLXPXEi5ty2hau2HBEsLkG1OrOldSRhzGY6MvvhUnZy8MlbuKmVc98YtgC4anMHQTsm8lr2bd9aEnPYCqo6qGbWMX+49pyeXDM0gwdm0ESJJkhjZNYWRXVPYV+bhyhdX89eL+zfpewgEPwfsFoVbJnTlhSW5dfaSqwuLLNEhycnobinNNDqBQHAEIaYEpwyXzcKr1w3nl6+s4YEPt/CXi/oJQSU47an0Bvl0y0HT1L6qNQupzHmflLNvx9FlCJJiwZu/Hu/OHBwZfQHwBDTmLstttJgCyEx28eB5fbj3nJ5880MR+yu8eAIacQ4Lmcku7n7nO2YMy6izP1VTEbKCF7WRAkFj+dXkLLYWVrJsR0nEgkqRJRJdVt6cNVK49AkEJwFhQCE4pcTYLcy/YQTbD1bx0P+2UF/aqUBwOvDeun3IJgsY3e+mYsVbJJ91K66eY5BtDiTFgitrZK1+TwBr8ss4UOn9yWOxWxSm9W/HrPFdufOM7lw/tgtn9m7DqG6pfLn10E8+fyTv71ej21EXCATHkGWJ564awrn92+IyMXk6EadVoUOik8V3jCM9rvEOoQKBIHKEmBKccmIPC6rv91fxyOJt5oJK9cPuFbBtEWz9EPKXQcBz8gcrEDTA4k2FpjvI/v0/YqgBXD1GN3gOWZJYtqO4OYYHwC8GtuOjzYXNdv4j2K1yoxp7CwSCY1gUmScvHciL1w5jfPdU7JaQS+4RZAkkoGOyiz+e35vP755AuwTnqRuwQPAzQ6T5CVoE8Q4rr98wgmtezuHRj37gj+f3DqUnlO+GNfNg/WtQa7dfAl2FgVfCqNsgNetUDV0gqEWFx7xWSvNWIbvijzbZrI+gptd5nqbgjN5teHDhFsrdAZJibM32PnaLTEDV0XVDpPAKBD8BSZIYm5XK2KxUCiu8fL71IMXVfvyqTnKMlReX5fPGDSPolCqszgWCk40QU4IWQ4LTyhs3jOSql1bzt09+4A/295BWPweGAVrA/EUbXoeNb8Hgq2HaExDBQlUgaFbq0AyKMx7dU4WhaxEJquYk1m5hXPdUPt96kAsGtafcE0TXDeKdVuIdliars5AkCZtFJqDpOMS1KRA0Ce0TnVw/tkutx3YVuVmRWyLElEBwChBiStCiSHBZefOGEax5ZiZBdTk23V//C/Rg6N/Gt6H6IFz2Bsgie1Vw6kh02dhTGp6Cau/QC8lixbMjm5he4+o9h9Uik+hqPnMITTfomOziL5/8wIMfbsGqyEhSKCKWkeTi1ond+MXA9jgjqNFoCIdFxh/UG907SyAQNMy4rFS++uEQM0d2OtVDEQh+dohVp6DFkfTd85ytr8Sm+yJ/UdADud/AVw8138AEggi4aGB7nCbCQbbHkDhuJmVfzsGzIxs96MPQVLy56yj/9pVax+q6wcQe6c0yvuU7ixn+l694c/Ueqn0qqm7gDWp4AhpBzSC/xM3Di7cy9LEveX3V7p/8fnargk+YUAgEzcq47qmsyi1Fi6RDuEAgaFKEmBK0LAJuWPZ/SGrtnf3OT1WT/vdq3IFjN4qXNgSYNN997KCgB3Lmgbv0ZI1WIAjjkmEZ6HW4UsaPmE7SlBupzF5AwbMzKXjhOqo3fISze21TilFdU2ib0PROXIu+289Nr6+jzB3AHahb4HgCIXH1t09/5K+f/PCT3tN+ODIlEAiajzbxDtrE29myv/JUD0Ug+Nkh0vwELYvN71GXxtcMeDonwP3j7XW/XpJhw2sw/jfNMz6BoAHiHVbO69+ORRv3o5loqti+k4ntO7nO17tsoUadTc2q3BJ+/8HmqPo+eYMab2TvoU28nRvHNW5MDmvj7dH3lXmYv2o3n3x/gGqfioFBnN3KOX3bcP3YLnQW9SECAQAVngAJTivXvrKGoKYT1HScVoW+HRK4ZUJXJnRPEyYwAkEzIcSUoGWx8ikIuk2funeMjSdW+rltuI1ERx03BdUL2c/B2LtF7ZTglPG7qb34dnsR5Y1w5BuYkcDobilNOh7DMPjtfzbVKaTc25ZQtfZDgqUFyDYn1vSuJIy5DEdGX7xBjSc+286MIZkkNKKOqzGNe3cVVfPAwi1s3FeBbhgEj1Olbr/G2zl7eXftPvp1SOCxi/rRu1181OMSCFoDVb4gDy7cwudbD2IYBoHjrpWgppKdW8rmfRU4bQq/m9qLy4ZlnsLRCgStEyGmBC2HgBsq99b59LD2CpM6W3hylZ/HptSTAhWogZqDEN++GQYpEDRM2wQH79w8isvmZFPjV4mkjMFplWmX6GTHoWo27C1naKdk0+P2lXnILa7B7ddw2RQyk11kpcfWe+6c/DIqvObCrmrNQipz3ifl7NtxdBmCpFjw5q/HuzMHR0ZfINTH5r31+5g1PvroVLSRqZy8Um6YvxZPQKOury2oG6AbrN9TziUvrGLeNcMY1z016rEJBKczh6p8zHhhFYeq/AS0ujcs3AENd0DjT4u2sv1gNQ+e17vJHDsFgpZCYU0h6w+tpzpQjUW2kOJMYXS70bisrmZ/byGmBC0HXyUoNtC9dR7y58l2xr7i5q6R9fTGkS3grRBiSnBK6dU2no/vHM+s19ext9SDX9VMRdWR5psXDurAoxf1Y+WuEm56fT3/uHQgk3uFTCg03eCbH4uYszSXLfsrsVlkDMNAkiSCmk7nlBhundSNqf3aYreEm1/MW5qH16RGSve7qVjxFinn3o2r55ijj7uyRuLKGnn0v71BnXnL8rhhbJeoU4Xslsgb9/5woIrrDwupSPEENG56fR0LbhnFgIzEqMYmEJyuVPuCXD43m8JKX8SmE95gKKqb4LRy5xndm3mEAkHzoxs6K/ev5NUtr7K5ZDOKpKAZGhISiqyg6RoXdLuAq/tcTZeELg2fsJEIMSVoOchWMOpfdPVLVzi/h4XHVwTonVZXGp8REmUCwSkmM9nF53dPYHNBBS8uz+PzLYeQJZBlCVUziLErXD+2C1eN7EhqbKgWcFLPdF68dhi3vLGOB8/rw9BOSVz54mrKjzONOFGc/Hiwmvs/+J6H/7eVN2eNpG/7hFrPL9tZbBrl8e//EUMN4Oox2uTZ2lT7VPJL3XRLqz8KdiKhNL+GxZFhGNzyxnpTIVVfGiKEFok3vb6O7PvOEHUhgp8FT36+ncIKr6mQaiht9/kluzi3f1uy0uNOwcgFgroxDIM9pR5K3X5UzSDBZaVbWixWJXy9VxWo4tYvb2VXxS48ang7Eg7fSj7Y+QGLchdxY78bmT1wdrNEZYWYErQcnImgN7zoemSSgyFza7hndB1GFFoQYpq25kQg+CkMyEjk2SuHEFB1KrwBfAGdOIeFBKfVdPE/tFMSb980ipkvrqbarxJQ9QZTBY+k8lw6J5s3Z41kSMckAHxBrU53Qc1bheyKj6iJsEWRqIi0BkzXIe9b2LyA+4q2k/aVApvaQc9zod90sDrDXrJuTzklNeF95SJJQwSo8amszC1hfPe0yMYoEJymeAMa760vqFUfdYRIrpegpvPyit38bXr/kz10gcAUt1/lw437mbs0j+JqPxYldF/UDQNFkrhmdCeuGdX5qMttdaCaqz66ikJ3IUG9/vuSaqiomsorW1+hKlDF70f8vsnHL8SUoOWgWKH7ObD9E6izWgKykmUu72vlmTUB+qebRKc6DAVnUvONUyBoJDaLTHpcZJbnbeId6AZRmzd4Ahq/fHkNn949noykUK64hITZNaU449E9VRi6FpGgapCgD9bMg+x/h2ogAzX0BPACpYQE1qf3wqCrYdyvIb7d0ZfOWZoblooYaRoihMTknCW5QkwJWj2LNxeaPh7p9aLpsPC7Ah48rzcxdrEMFJxaFm8q5Hfvb0aSOJaZcII+enF5Pi8tz+eqkR158Nze3PHNHRxwH2hQSB2PT/Xx/o736ZbYjRk9ZjThJxBiStDSGHsn5C2p09HvCA9NtPPGZpOLyBYbcvITCE5z3snZg9uvmj7XUNqbJ6jy/Le5/HV6fxxWBclcS2Hv0AvJYsWzI5uYXuPqHY+qGSTW5+bnKYPXL4CSXSFXTTMCh6/rda/A9/+BXy6Gtv3xqxpLt4enIkaThgiwZncZ1b4gcY7oXQcFgtOFd3L2mqbDRnO9KLLE0h3FnNu/XYPHCgTNxfxV+Tz+6Y8NbhoGDqe2v7tmHz8cOkiuYxsBPRB2XPnycko+LyFQFEBxKMQPjafNjDYoMaHNQp/m45kNz3Bx1sUoTbGBeBghpgQti8yREJsO5fm1Ht59d+3c7swEGd+DJnbIFgd0P6s5RygQNDu6bvDSinx8JsYNkaTxaDos3FDAHUYu2vKlDKrpxPqUbhhS7UiubI8hcdxMyr6cgyQrOLoMRpIt+HZvxLd3M0mTbzh6bKzDQpeUOvo6BdzwylQoy4NIdgr1IHjL4dVpcNMSKm0ZWBQJ9YRcxmjSEAGsiky5W4gpQeum2CQdFqK7XlTNME2rFQhOFl9tOxSRkDoeb1BjbV4AJeFM7G0X13qu5NMSij8tJmNWBrF9YgmWByl8o5DdT+6mywNdkA+bPfk1P8v3L2dS5qQm+yyiEY+gZSFJcPkb0BgrS6sTrngLmnC3QSA4FSzbWVyv+17yWbfi6jkG2eZAUiy4skbWEj4Ahs/LByt2EDthAnfd9gucdnOBET9iOklTbqQyewEFz86k4IXrqN7wEc7ux3a3HVaZm8bX4+S3+C6o2F1LSHV+qpr0v1fjDhwTSC9tCDBp/nFRZ38NvHEhgaB2OBWxNsenIUaCJEFAa1yDYIHgdEE16wZOdNeLbhhHd/sFgpONrhvcv/D7ensfHnjtbvb+cwYF/76GQ//5E76CrQBoukKgYgR68JjRkubVKPqwiPZXtyduQBySRcKWZiPztkwCJQEqV1UePdajenhlyytN+nlEZErQ8mjbH2a+B29d1mC63xF0ixN5xnzoOKp5xyYQnARy8sqOOvcdTzRpPD6Lnc0Dp3D79GGMMwwSPt9dp+V4bN/JxPadXOe59KDKjD51mLq4S2Hb/0AL3+XWDHg6J8D94+swi8EAbwXxB1ei6uE31WjSECG0yIwXUSlBKyfOYeFgVfjj0VwvFlkmwSmuFcGpYVVuKTV1pLFHZDpkSATKRuNo8xkAnp0e9KBO/NDaGUuKQyFuQBw1W2tImnCsln572fYm/TwiMiVomXQeB7O+gozhodQ92WTSlxSwOCmL782vnX8hmHX2yR+nQNAMlLh/ehoPQJk7lFMuSRJPXjoQhzX6Kd9plbmF3ZRdeRmedevCD9gwPxQSMuHeMTaeXOWnwlePFWGghvLP/oaZ4eDxaYieHdnoQR+GpuLNXUf5t+E7i7EOy1GLeYGgtTK+eypWJfyai+Z60Q2DoZ2EUZPg1DB3Wa7p5l7k2RcWghWjMIzQvVCr0bDEWpBMrgtLggW1prZw86m+Jv08QkwJWi5t+oQE1a2rYOh1EJMOFmdIXLlSYdBVcPO3JP06m6rk/vz7m12nesQCQZNgN+mpAdGnvR3fwHdsVip/u7g/Dkvk077TqjBzZCfuefwO2vzhPvb/5h4O/e1xdN9xN6I186COG9Ow9gqTOlt4clX9tRkdarZw/bBknNZwkRhJGiKAwyIza1z0TYUFgtON68Z0Qa5jAyPS66VPu3i6RtkzTiBoKtbkl5k+Hq3pkO5rA4ASq6DWqBgmKbBqpYoltnYinkVu2sQ8keYnaPmkdIPzngz9M0ECHr9kAOc+vZyz+rShX4cE0+MEgtOF9olOrLJE8ARDhmjSeCSgXUJtG/aLh2SQHGvnrne/I6jpuP3mosxlUzAM+P3Unlw3NtQ1Pm7KFJyDB3Po0UfJv3g67R//G84BA6CmqN5x/HmynbGvuLlrZN2NtC02B3cMgNc3mEewGkpDhJBZ4eXDO9Z7jEDQGuiY4mJgZmKdC9KGrpcYm8LsSd2aa3gCQb2oml5nvV502RcGhh7qWejKciFZJKrWV5Ew4rhaKp9G9eZq2sxoU+uVifbERo/fDBGZErQK2sQ7ePD83tzzn034VVGALji9OW9AO9MISzRpPE6bwqXDMsPOMbFHGuseOJN/XjaIwZmJWGQJl03BZVOwKhKdUlw8dH4f1v/xzKNC6giWpCQ6/POfpN15B/tu/xVF/3yyno5wIfqlK5zfw8LjK8JtbI8nXvJy84SuptGphnBaFa4e1YnkmLoFm0DQmrj/3N6NStu1KhKZyS7O6JXeDKMSCBpGliRM/IaA6LMvICTKFJdC+kXpFL5ZSPXmagzVIFAcYN/z+7AmW0kcc0w82RW76DMlENTFRYM68On3B3n6q538bmqvUz0cgaDRdEqJYUBGAmt3l4c9Fz9iOnJMEpXZCyj56EkkmxN7myziR19e67gkl43hnc1rIiyKzDl923JO37ZU+4JUeIJoukGC00qiy4pURwrR0TFMm4Zr+HAO/OlPkKrXeWM8wiOTHAyZW8M9o+upZ7LH85uzepBf4ubrH4rwBiO7mTqtMmOzUnjg3N4RHS8QtAYGZSbyj0sHcs97myK2lrYqEmlxdt6+aRSWOlKJBYLmRpYlXFbF1GQpKtMhQ0ayHDMpSzs3DSVG4eCCgwSKAshOmfgh8WTekol83MaDYRhc2vPSJvs8IMSUoBUhSRJ/ubg/0w6n+w3uKIprBacvsyd2Y2vhd6ZFug2l8TitCrMndm1QFAHEOayN6stkSU0l49//Rv/bFyiB+lP9spJlLu9r5Zk1Afqnhy/itKAPJakzkiTxzBWDefSjbby9Zi+aboT1njqCIof6Sl00OIPHLuonaqUEPzvOG9CeWIeVW99cj2FQ7waE06qQlR7L6zeMIElEcAXNjKYb5BXXUOENIkuhzb0uqTFH70lT+7Xlw42FaCfM79H0PpQUH7Kt9r0neWIyyROT6xyXRbYwrsM4Up2pTfhphZgStDLS4uw8fEEffvveJj6+czyORqQMCQQtgSm90jmzdxu+2HYwqqaGVl2lb7KDK0c0f/2QJEkoZ92H8fkDSKq3zuPe/j5Izn6VMq/Bqn0a095y88B4O+M6WtAMiW/VAfRTY2hLaNfyTxf05erRnXh1ZT7/Xb8fiyyhY4ARShEJ6joXDerAjeO60L1NXJ3vKxC0dib2SGPNA2eycEMBc5bmUe4JoMgShmEgSRIBVadDopOMJBev3TA8og0WgaCxlNT4eWfNXl5ZkY9f1VEOb3KpmkFSjJWbx3flkqEZ3DiuKx9/fyBMTEFk2ReKbKAkrajLSNYUWZJJcaTwyJhHfvLnPBHJMPOjPcywYcOMdWZWuAJBC+f2tzfQPsHBA+f1OdVDaXFIkrTeMIxhp3ocP5Wfw/wU1HRufXM9K3eVRpT25rDKdHfoPPrFP+n53DM4+/U1Pc4TUClzB46m9iU4G07tqxN/Ncbfu9cppv6Z7efxFQHmnO/gnG4WbAp8tktl2R6Nv5/twG3YmaXdz5DxU7n3nPD0XE9AZe3ucio8oZqrBKeV4Z2TibGLvcDWSGuYn07V3GQYBlsLqzhY6cOnasQ7rPRqG4fNIjPx70v4/O4JtD3BlEYgaAoMw+Cpr3YyZ2kuAP46DCacNgXDMHjkgr7MX7WbHw5UN/o9lditODu8iyQHGz5WspDmSmX+1Pl0iO3QqPerb24SdyNBq+TRC/sx9all5BXPKAAAIABJREFUnN23LcM71x3yFQhaMlZFZt41w3ju213MW56HrhumeeZH3PeuGJHJH6b1xj/Szr5bbiHz+edwDhwIhDrOL91ZzNyluazbXY5VkZGkkGBrG+/glolduWhwBrHRihR7HIVZV5D8w9s4pdoW6JU+g4e+9fPqhU6m9z6WSviLnlZ+0dNKwFDYY7QhW81ia/Ye7j6zB9YTajlcNgsTe6RFNyaB4GeIJEn065Bg6mh78eAOzF+1m/umiXpiQdNiGAa/e38zH20+UKeIOoL38P3r4f9tZWSX5J8kprSavnj23Iy9zUcojv2AjiTXfn9DtwEGwepBPDbxkUYLqYYQYkrQKkmOsfHoRf24971NfHLXeGRJYvGmQt7K2UtRtQ9VM4i1WxjdLYUbxnWhm+i3IWihyLLEHWd055aJ3fh860HmLcsjv8SNL6hht8i0T3Qya3wXfjGwPS5baEq3nXkmWCzsu/U2Mv79LFuSOnP7WxvwBNSjYkw9zi1pX7mXv37yI499/AN3n9mdWyZ0iypS9YjvCq7XVzNIzsUpHXPtyy7Q8Klwce/wW41qyFQSyy8D9wESmmGwdHsxZ/ZpE3asoHkorPBSUO7FE1CJc1jonBJDimh63Cq5cVwXLvj3Cn41JSv6DROBoB6e/monH20+ELFpEIA3qLNkRwkzhnTgky0HTWuDI0H3ZeLdcyvu7YupXvdfAsVlKA4Fe2YMqVOzsKacQ7BqCOh2/vrRHhbdHu5w2xSIK0rQajmnb1s+2lTIJc+vYk+ZB6DWBVtU7WdvmYf31xfQq20cD5zXhxFdRBRL0DKxWWR+MbA9vxjYPqLj4yZNQnriCRb86Rn+b8Cl+Bq4Vx25Np7+ahd7y7z85aJ+EQuqvFIf1wV/z7+tzzBG3krM4QhVqccg1SVhOcEcwmPYKDYSuTzwR4oJWdaqmkFBuSei9xM0HlXT+ebHIuYszWVrYRW245o4B1Sdcd1TuWVCN4Z3ThL1Na2IzGQXY7JSeXfNXmaN73qqhyNoJZS7A7ywNNc0IuXetoSqtR8SLC1AtjmxpnclYcxlODKOpZ9/se0QC24exT3vbWJfmZeAqqOZlB/F2BVUzSCg6mHtOKrWLKQy531Szr4dR5chSIoFb/56ypdtJWnysea/Px6oZldRDVnpTb95LsSUoNVSXO3n+/2V7C6te4GmHnYL21RQybWv5PDYhf2YYdKbRyA4Hdme2SciIXU83qDGwg37aRNn564ze0T4Gh0/Nm4O/oaz5XXMtiyml7SXBKdOiSd0jcmShAcHlUYMc7Tz+a82EQ/H6jeCmo4nip1NQfTsOFTNNS/lUBNQjzZsPnER9M0PRWTnltIpxcXrN4wkLU5EqloLt0zoyq1vbuCXYzqH0mnVAPirwOIAWwxRVfMLBMCCtXtNfzZ1CRzvzpxaYkozDArKvXzx64ls2V/Ji8vz+HzLQfyajkyoH9WQjoncMLYLv16wMUxI6X43FSveIuXcu3H1HHP0cVfWSFxZI2sdq+kGr67M5y8X92/KrwAQYkrQSqnxq1w2N5uC8rodxk7EF9R5cNEW4pxWzunbthlHJxCcHH73/qY6hVR9u4beoMbzS3K5ckRH0uMbLliPsYdcMw1kPtdH8HlgBF2lQia2XYlieZN7tvaia+8+ZOt9Wa33xqwxlVWRiRPpR83G5oIKrpy3Gk9Aq7fRskEoSrnzUA3Tnl7G4jvG0S7BebKGKWhGBmQk0j1BZ9uifzBwz+tQVQCKFXQNZAsMvAJG3QZpPU/1UAWnAbpu8PKK3WFus9EIHLdf44WluUzr345+HRJ4+orBQChKrhvGUUfmVbtKQr3RTtj88e//EUMN4OoxmoZQdYPPtx4UYkogiJRHF29jf4XXtEdNfYtIX1Dnrne/I/u+M0QvDsFpzaZ9FRRW+Eyfi3TX8M2cPfzmrIYXVn3bJ5Bb5K6VnpFntCfPeikx4xT+/dkHpBhn4+jSFUnWTHuGWGSJrHRhc14fBeUeVuWWUuUNIkkSKTE2JvZIa3CuKqzwcvVLOabmJXWh6gblniBXzFvNp3eNP1qPJzhN0XX4+mFeLplL4JABHDaLUQ//v67Chjdh07vQtj9c+hokNE+xvqB1sL/CS41fDXs8GoED8P3+SlRNr9VI+vj0Y4ByTxBMtoE0bxWyKx5JjqwNjtl4mwIxOwpaHdW+IIs27SdgksMb6SJywbp9zJ7Y7WQOWyBoUl5anodfDV88R7pr6Fd1Xlu1hzundK91kzPjxnFd+GzLQdMC5Eh6hgDEWmBU1+hrFlVNZ8n2YvJKaqjxqcTYLXRKiWFKr/SwG/LpiK4bLNtZzNyleWzYW44iSwQ1HYlQNC+oG5zVuw03TejKoMxE03M8+82uOgu869tc0nSDoiofH2wo4OpRnZvvQwqaF02FBVdD/lIUzUedcUZDBVWF/Rtgzli44XMRpRLUSYUnGFYPC9ELHKssU+1TI9jADn8vxRmP7qnC0LWI3k8yOUdTIMSUoNXxwYb9phdMpItIX1Dn5eX53Dy+K7LJRCEQnA58s70Yk8BsVLuGmm6w7UAVAzLMF+lH6NchgYwkJzuLakyfj+07mdi+k+t8vUPSuXjLVxTcvpjUW2fj7N9wGkZRlY83Vu/h9ew9qLqOP6ij6gYWWcJulZEliatHduLaMZ1O2zQ1t1/lxtfWsrmg0lQMBbTQY59uOcA3PxZx4aD2/OXi/kcbZR45x8LvCkyj9JFsLnmDOnOW5jFzZCcOVfl5b/0+cotqqPGrJDit9GufwPQhGSS4rGHnF7QQFt8J+UsgGGHau6GBtwJePRduXQVxwmFTEI4sm8WKohc4AU1n8j+WIEsSCU4L5w9ozzWjOtVKMU9yWU1lkL1DLySLFc+ObGJ6jWvwvZrLyVKIKUGr443Vu013yKNZRHoCKhv2ljNM9KgSnIYYhoE3YJ7OEM2uoSSFdh8j4Z6ze3L72xtMO9o3hN1hZ/bcP6N//D8K7rgTe1YWqbfOxjV0qOnxK3eVcPPr61B1I8xAQdUN1MPmCi+vyGf+qt08N3MwU3qdXgtCb0Bj+gur2F3ibrB3i26EjEMWbSykwhPk+ZlDjm4EfbhxP7JJhXg0dQ0l1X4uem4lPxwM9YQ5Pur/yfcHefyzH5nary23TupGr7bxjf7MgmZg31rYurCWkOr8VDWeIOTfFUuMLfTbeGlDgDc3B1lyXczhowzwVcCXf4Lpc07BwAUtnZQYO0EtfG6KVuDAsftMmTvAvGV5zF2Wx9huKTx8QV86pcQwpFOSqcufbI8hcdxMyr6cgyQrOLoMRpItdaaST+3XPPXwp38OhEBwAsXVAdPHo11EHqryN3icQHC6cfyuYVMyNisFvRFCymlVePPGkSQmxZF89Uy6ffE5cWefReHv72PPNdfizs7GOO4mumxHMTe+thZ3QGtQZAQ0HW9Q47a3NvDF1oNRj+1Ucsc7GyISUsfjDWos3VHMU1/tOPrY1z8UmUa1otlc8qk6mwoqCah6WPq0Nxj6O3y06QAXP7eSxZv2RzxewUlg1TOghtdOagY8nWN+rzyKrsK2heCrbKbBCU5n2iY4yEwKj/ofL3A8O7LRgz4MTcWbu47yb19p8Lz+w/PM0h3FnP/sCr7bW47DqnDZsEzTtML4EdNJmnIjldkLKHh2JgUvXEf1ho9wdq89tymyxPVjOzf689aHiEwJWh1mOyUQXej5yE6vQHA6IkkSTpsFt0mxbTS7hoZhkBhB+lZQ07n4+ZVRj9NhlVlwyyj6ZyQcfUy22Ui67DISp0+n6uOPOfjnR1Hi40m5dTZl/Ycz+831Ye5RDREyltnI/341lu5tWr7JRW5xDct3ljSqd4s3qPHi8nxmT+qGy2ah3P3TN5ciQTMMvEGDe9/fjCRJnD8gsn5ogmbEXQo7Pwcj/Hd07xgbT6z0c9twG4mOetLZJRk2vgOjZjfjQAWnK7MnZfGnRVvCzG0irZWtD92Aap/K1S/nsODmUQRU3TRdGRpOJQfo3S6ermlN32MKhJgStEJcNsV0JzaaRWQod1fUAAhOXyb3TOOT7w+E1U1FkxZBTQ3x//gzlVOmEDthPEpCAmb8/v3N5BW767XcNsMwDEpqzCPAksVCwoUXEn/++VR/8QXF//wXT7Ubhz+pj+nxDYmMgKrz/JJc/nX5oChHGTk7D1Xz8op8lu4opsavosgSSS4blwztwJXDO5ISG1nPpldX5pumS0ZqoCNJsGhjIVeO6Firfup4oq1rOJ6GHFHvfW8zfZpx4SKIkF1fhSzPCb/GhrVXmNTZwpOr/Dw2pZ72B0FPyOFPiCmBCecPaMefFm0xfS4SgQMNz90ev8ZFz63i3P5tuWpEJgu/K4x6s9tlU/jb9Ka3RD+CEFOCVsfAjES++bEobGEXzSIyqOn0aS9y/wWnLzeN78rXPxQ12mHPbpH55ahuJKpQ9emnHHz4YRz9+hF3xhRiJ0/Glhlqbp1bXMNHJqINGr5J+lWDP364lcm/T0eqo2GopCjET5uGMvkMvvjzl5gFniMRGZph8Mn3B3j4gr5NvlGyfk8ZD/9vKzsP1RDU9VpjrPAE+ffXu3jm611M6ZnGIxf2o009vbt8QY3/rt8ftgMbTY2TJ6Dx3KINjHv5r8SoWZDWO+x9GlPXAJF910FN5+UVzdMcUxAFnlLQ6q55/PNkO2NfcXPXyAZc1LxlTTwwQWvBYVV48tKB/Po/G6POGIDI5hODIyl6XRiYkUiNX+PLbYciFlROq8KL1w6jd7vmW9MJMSVoddw8oSvZeaWm0alIQ89DOyXRIfH0dAATCAAGZibSPtFBbrHb9PlIdg1/OaU3ifGDSZwxA93rxZ2dTfU331Aydx6W5GRip0zhxbhBBH9CG4JyT4A1+WWM7JpS71g+3nII2aLACdd1NCJDliQ+WF/A9eO61Pte0fDR5kJ++96mehcSvsPfz5fbDrFmdzkLbh5VZ7rh3jIPskk1c7S9W/ZrFuIumc7VMZms+3xPWBpOVBHKw0T6Xau6wQcb9vPAeb1Ff6pTiUl63/H0S1c4v4eFx1cE6J1WTwl9A+cRtHLK8iFvCXjLQ2mfrhTocQ7EpgMwrX87NhdU8MLSvKhOG83cHdR0XlyWx/NXD+XpKwbxr692MHdpHrIk1SmqXDaFRJeVuVcPq5VK3hyIWU7Q6hjRJZkklw1PwNwGtqFFZIxN4RbRY0rQCnhixgBmvpQT9Y6h06pwy4SutaxpZaeTuClTiJsyBUPT8G7eTMnXS1i4owJDrh3pieYm6Q1ozF2a16CY2rSv4icbKXiDGhv2VXB9g0dGxtIdxQ0KqePRDCh3B7hsbjaf3DXe1LK92hc0dd+LtsbJoshI4yczxW7B8U2BacPeaOsaovmuJQk+3nyAS4dlRjReQTPgTALZClrdRhOPTHIwZG4N94yuJwXV0bwLUUELRNdCaaIr/gWF34VElBoItXpS7PDxPdBtCoy9k6LEwSzaWMhtk7rx/voC3H41ogbh0cwnugFf/nCIgKpjs8j85qyezBrflQ/WFzBveR7F1X6sioxhGAQ1g3FZqdw8sSuju6bUmfXQlAgxJWh1SJLEH87tFdUi5wiyBBnJLsZnpTbT6ASCk8fQTsk8ffkg7loQeQqG06pw0eD23HVm9zqPkRQF1+DBFCd3QnoxB06ITEVzkzSAnPxS1PJydLcH3ePG8HrRPZ7j/nk5uFMBwtPjohUZFZ4GHMwixBfUuO0tczOM+tIbDaDKp/LrBRt59+bw78duUTBxAI66xknTDRxWGUWWuHFcF575eufRCNnxRFrXANF9156ARm6xed8xwUmiy4SQI189ZCXLXN7XyjNrAvRPN4lOWRzQ58JmGqCgReKvhrcvhwMbIXBCZoPBsd/Ujs8w8peyTR7JlUP/xh1n9+K3Z/dkxa4S5i7NJSe/DN0wMMB0Tot27pYliUpvkLS4kPCPd1i5bmwXfjmmM5XeIFVeFaslVKfqsDaNsU6kCDElaJWcP6A9Ow7W8OLyvIjzai2yhAFM6J7KSdjIEAhOCuf0a8drLhu3v70Bb0Crc8fQZVPQDYM7zsji1ondItrNq/AECerhC/Rob5Ief5DcaeeiuFxILieyKwbZ5Qr9czqRXS5cdMdMTEUrMn48WM3D/9tKx2QXmckuMpOdZCa5iImymePiTYWmC4SI6rd0g+/2VrCvzENmsgsAtawMz/r1yGs24vf2OGwccIxoa5ycVgW7JfR93DCuC4s3H2BXUTVBLXr7+iNE+11H2qNM0EwkZkKnMZD3bb2HPTTRzhub6/hbGQYMva7pxyZomQQ88PLZUJoLWkPtYQykoIfR0iomFv0R9AXIssKEHmlM6JEGhDad8opruHRuNm5/7XtPtPOJLEn41fD7lyRJJLpsJLoaqP1rRoSYErRafnN2D+IcFv7xxXY03SBYTw+cGJtChyQnz145hLve/Y6/fPwDD5zX+6SEhwWC5mZk1xRy7j+TZTuKmbM0l/V7yrEqMrIU6sXUJs7B7ElduWhwRlQd4svcAUy0VNQ3SUVR6Lk6u95jen+7i8+/3hnW5ygakWGRJYZ2TCQjycmeUjcrdpWwt8xDQbmHGJuFjGQXmUnOY0IryUXHZBftEh1Yldq79nOW5oalHUaT3qjrOvPe/IbbStfhWbcO9dAhnIMHkz5sGP1tTjaU1l7cRlPjZJUlZgzNOPrfDqvCW7NGctncbPaVeaLqXXU80Qq6SGz1Bc3M2Dth3xoIHosw7L67dr1eZoKM78Hw4nwDCSnrjKO1MYKfAe/fAGV5EQipY9gNP+xZCV/9Cc5+rNZzDqtCcozd1J002vlE1XXiW6jLshBTglbNTRO6clafNsxftZv/rNuHLEm4/SoWRcKqyGi6Qe928cye2I0ze6djUWTevXkU1726lvsXfs9jF/Wv01oYONpMVIguQUtHkSUm90pncq903H6Vck8AVTNIcFpJdFkb9Rtek1+GIkthznPR3iRjHQ3fii4e0oGnv94Z9ng0IkORJe6b1pvOqTG1zmEYBsU1fvaVedlX5mFfmYfv9pazaON+9pV5Ka72kxZnJzM5JLRibRb2lnnCxhJNemNQhw/3B/l1ny4kXnYpjp49kSyh7+G2bYe4693vGt27RT7sfHU8yTE2/versdz/wRY+3XIAScI0RVGRJXTDgND/Gv1du2wK3YQ1+qmnyyTIHA57V5s2760PAyuc+Qji7vYzoWRXKIp5wu+k81PVeIKQf1csMbbQr+GlDQHe3BxkyXWH59KgB9a8BBN+B46QMC+t8bNiVwlLthfhN5lrojXBSY6xERdlBsHJomWOSiBoQjqnxvDwBX25b1ovlu4o5sEPtzBjSAbd0mMZ3DEx7Iaf6LLx5qyRzHptLb/5z0aevHRgrV3prYWVvLw8ny+2HcITUDEIpdSM757KzRO6MaRjohBXghZNjN0SdVrbiQQ1nW+3FxHnsFDu+WlRlAsHNtzgtV2Ck5Fdklm2syTsuUhFRr8OCWFCCkKbIelxDtLjHAztlGT6WQsrvCGxVe5hdV5pk9QAuGUbSddeGzZfTO6VjtOmmKZkNlTjZJElBmaYf06XzcJTVwziT+4+zM/ezcvL86k5obGz2Q7y8UT6XRsGnDegXb3nEpwEZBmueAdenQrF2yMWVIbVRVH+YNTHnqf9E/+HbI+sR5rgNCbnhZDxhAmaAU/nBLh/fN2/A0OS2PPNK7yvTGPZzmLyi92M7JrMxB5pJLlsvJWzNywqHul84rQq3Dy+a4tdWwkxJfjZ4LAqnNO3LX/8cAvXjulk6qR1hFi7hfnXj+DWN9dz21sbePbKwewqquGe9zaxp9RNUDXQjltNeQIaX2w7xPKdJaTF2Xl8+gBGd6vfnUwgOF04VOUjv8SN26/islnomOJiy/5KOqXEcHafNvzrqx1hUY6fEkWpi1snZbF2d7lpHWRDIsNpVbh9cuNcOq2KTKeUGDqlhARKh0Qn3/xYRLWvthCJNr3RMELixaLUXiAossT860dw6ZzsqJtTxjksPHvVkHqPqfIFeXfNXoJmTbsiIBJBN31IB2GL3lKwueCGLwi+fxPSzs8xdA0rdRhT2GLBHoc0833Sk7pTeN997L3xRjKfe860affBSh+LNxVSUO7BE9RIibExuGMSZ/QKZXoIThOCXtj4NujmtXP3jrHxxEo/tw23keiooydg0INz3b/RRkzlD9N6M7RTEjZL6DdQXO3nrZy9pq+LxARHNwxmtGBnUDHTCX52VPtU4h0N5906rApzrxnGrxds5JIXVpFbXFOvI5phhETVnlIP189fw+PTB3DR4A5NOXSB4KSh68ZRV6Z1e8qP3hQBAqqO3SIzfUgHZgzN4J9f7jA9R0M3SVmqO1pkxuhuKdwwtjOvrNwdlchwWhUuH57BlF5tIn5NfcQ7reE5cESf3mhV5DoXnP06JPDq9cP55StroqpxsspSvSJpf4WXC59bSZU3aNpouSmwKjI3NmEvL8FPY3+Fl2e+3smibVfRUTqTy/RPuUIJmVLoSMiATQpyKHEQbab+HluPM0GWkYEO//gHRf/3BLtnzqTjvHlY24eiyNm5pcxZuovVeWUYhkHgOGOTGLuCVZb55ZjOXDO6E6mxIqrV4qnYF7I/r4Nh7RUmdbbw5Co/j02pu+l4G72E35/VDZTaa6y0ODuXDs3gvxsK8DaiVcc1ozs1ebP1pkSIKcHPgk37Kli2o5iDVT68QY3Xs3dzVp+2ZKXXn9Nvs8jMGt+FGXOyG0x/OR5fUOe+DzaTFGNj4mFXG4HgdGF/hZerX1pNUZX/aKrZiQt6v6rz3roCPttykF9NzuL5JblRR1FibBb+cenAqF7z23N6EtQM3li9J6L3c1oVLhnSgYfO79vgsZHSPT3W1MUw2hqAnvYgank5lqTw1EIIbfwYZqqtHkrcAWbMyebTO8eTFFPb3cowDK5+KYdqrxompOqzc48Gh1Xm75cOoKuol2oRbNxXwbUv5+AOqGg67CCVx7iGJ9QraC+VEI8HHzaKjER8JQl0/szB25kqyYd/O5Is0+YP92F5dT67r5pJhxee58ldOu+s2Vfn9RdybdOYszSX+at28/ZNI+nbXvSqatH4q+oVUwB/nmxn7Ctu7hpZj2ueYg1Zq7uSw556+IK+7C71sG5PWRStOmTGZaVy39ReER1/qhBiStBq8QU1/repkDlLcjlQ6cOvakcXEP/4YgdPf7WTHm3juHViN87u29bUaMIwDO545ztTIdXQ4sMX1Lnj7Q2se/CsWrv6AkFLZk+pmwufW0m1V62VymqGO6DhCWrMXZbH5cMzeHftvohukpIUElJvzBoZcVTq2Gsl7j+vN0M6JfGvr3awt9RNQDNqXaOyFOrX1D7RwV1n9uCCCGqyoiHGbuHCQR14f/0+TgwCRZre6FLgisot5J71KM6BA4mfNpXYM844KqyqfEHufOc7Amp0c49uhAq/71/4PS9cPbTW61blllJU5Qv7u0Zi594QiiRhs0g8MWMA5w9o2u9b0Di2H6zmqhdXmza7DmBlt3FCTZuqk1tUw4w5q1j8q3G16ipTrr8OS3oav/3b+3ybOQRvBPsmflXHr+pcOiebD24bQ6+24Y6BghaC1YlpuP04+qUrnN/DwuMrAvROq2NNo2uHzxWORZF59frh/Pa9TXyx9VCtNdmJKLKEVZG4cFAH/nJxf+R6jMBaAkJMCVolpTV+rnoxh71lHtPdM1U3UHWDzQWV3PPeJoau2cu8a4bhtNWuc1i7u5wyd3iTz0gXH5pu8MW2g2JxITgtqPQGuWxudlQpYIYBNX6VDzcW8tD5ffjXlzvxBFRT8wSLLKHIEr3bxfOvywfRJUohdTxT+7Vlar+2bC2sZP7K3fx4sJoav0qMXaF7ehzXjenMwMzERp+/IW4c14VFG/ejmUSoIqkBsNqsXP7n+5D9d1OzbDlVn33Gocf/D+egQcRPm8oHCb1NXxfJ3BPUDL7+sYjSGj8px6VYzV2aG/Z3icbOHUJ/w+PdG53WUH+yqf3actukLHq2jQt7jeDko2o617ycg9fkOqxPjAd1g/3lXu5f+D1PXzG41us+S+/Pt5mYCqn6zukJaFz1Yg4rfz8l7B4raCHEtgW1YTv0RyY5GDK3hntG15G6abHXKaYglAL89BWD2bSvgpeW5/HFtkNHnZUBFDk0f53Xvx03ju9y2kQ0hZgStDoqvUEufG4lhyp99faWOoInoLEmv4yrXlzNgltG14oizV2WG3Yzimbx4Q5ozFmSK8SU4LTg7Zw9VHrMhVRDkVi3XyW32E3O/WewfFcJc5bksmFvOQFVD0Wi7BYuGNie68d2aTC9Nhr6tk/g71GmCjYFPdrEMb57Gst3FOOLsm+T06rwu3N6huqlXC7ip55D/NRz0D0eapYupfLTz3hB9uN11F5IRDP3yMA7a/fyq8ndgVABeE5+WdhYorFzh1BPvsm90qnxqyQ4rfRrn8D0oRktup7h58jXPxbhPuw2ezyRiHG/qvPZloOUuQNH0/0Mw+Cpr3aaCqlIzhnKFNnP5cM7NufHFjSWmBTIHAm7l9d7WFayzOV9rTyzJkD/9BOiU7IVBl0V0dsNzEzk2auGUOEJsDqv9GiD70SXjTFZKRHVtbckhJgStDpufmMdRVXhQqq+xaBf1fnhYBUPLdrC45cMAEJRpSXbi8NuRtEuPnYcqqGkxi+KcAUtGl03eGlFvqkwiDQa8u6avdx7Tk8m9kg7Wiuo6Qay1Dp7sT175WBmvLCKnUU1EZtEOK0KV47IZOaoTmHPyS4X8dOmsaf/GNzzsuGEjZxo5h6fqvNOzr6jYmpvmRubRQ4bZ9R27gGNp06IWAhaHnOW5B6uXTpGNGJckmDB2r3cOikLgOy8Uiq94U5vkZ7TE9B4YUkulw3LbJVzQatg3N2oBeuxqOE99I7noYl23ths4vqYLHdDAAAgAElEQVQnKzDqtqjeMtFlY2q/07+FghBTglbFjwer2LSvopazEES6c6az8Lv9/H5qL5JibFR5gyiShHaCnIp28WG1SJS5A0JMCVo0S3cW4zNJCYo2DezjzQe4ZGjG0f+ur+n16Y7DqvDe7DHMfnM9a3eX4Q1qpv2nAKyKhCxJ3DapG7+aklXveQ9V+Uy/t2jnntIKNwf//Gd0n589fhe6pS9ItXd8o7VzV3UDVdOF7XULprDCy7YDVWGPRyXGgzrzV+0+KqZeWp5vmjIYzTmLqv1sLqhs1vRbQeM4VOXj0dWJPKQ5ScOLdNy6Z/fdtVN3MxNkfA+eUP8mW6D9EEhpXPuJ0x0hpgStipeX5xM8QUhFvRu3bh+zJ3ZD1Q3MNtCiXXxI1G9VLBC0BFbtKjGtc4pmseQOaCzZUVRLTLV2nDaF+dcPZ8PecuYuy2Pp9mKsiox+eP6QJAnDMLhiREeuG9OZzGRXg+cMaLqpKIta+CBh69IVyWEnOeBA2ibDCX/iaO3cLbIkhFQLZ3+FN4oopIYl7gcssduQLDUYhoShxqNWDaSk5pjo336w2tSeIBqBLwG7imqEmGpBaLrB2zl7+NdXO5k5siMJkxchvXYOBNxRnEUCRyJc+mqzjbOlI8SUoNXgCags3lQY5rwX7W7cyyvymT2xG/FOi6kIinbxoeq6qCcQtHhKasyLj6ONhpTVhBu2tHYkSWJop2TmXZNMSY2f9XvKqfQGscgSyTE2RnVNwWGNvPA+zmEx3ciJdu5xOawkX3M1AL2rfAR/+BaoPadFa+feJl5E2Fs6br95Q95aYtzix5a8AltyNqAhKceuW8MAa/wmDN3Ja1tKuaL35aZRqbBzNjBHqLpBTR1jE5x8thZWcv/CLdgUiQU3j6J7m8MRqGsWwZvTIVADRgMbwbIVXClw/ScQ17b5B91CEWJK0GoorGii1Jga/+GmpCFXsO2Hqms9H+3iI9ZuoX1C3e42AkFLoK5oQ7TREOvPPGqRGmvnnL4/bVHRu108AZMarGjnngEZxyIAbeIdDOmYRHZeadh5I7Vzd1oVbhgrmvG2dOLqKN4/Isa9eV+QNvU7JMWDJIeLG0kClACSEuDfG5/lo7zF2GzXgCf8+o9G4CuyhEu4+f1kDM3A90Mp7g2H0KoDoIPssuDonULMkHRkR/1Le09A5amvdvLf9QX8bmpPLh2aWdt6PHM43LIUvnoYtn8W6j+lemufxOoKqe4Bl8EZD0FMatN/0NMIIaYErYZqX9C0F0FjFoM1fpVki41bJ3XjgYXfh6U/Rbr4cFhlZo37f/buPD6q6nz8+OfcO1sme0IgLCEhIYDs+yaL4oai1hVcUXAB22ppq21/3dS21tZq+61aEFFEoVhwF1dUBEEUEARkJywhLElIyL7Men9/zABJZpLMJGELz/v1yotk5t65Z5Lh3Pvcc87zdDnrayQI0SE2Al0joG5SOBdLCkiOtZ26Rp4n2kRZGdstic+25wVM9wu174m06Ewbm17rsWlj09l8sDjodM5Q0rl7DYObBqc07U2J06ZLm8igCVE0ayRxY27k2NLZ2Dp0ILp3FOiK8m3lVGyvIHlS4E2Aak81WSVZaO2ehZIfg2ENfM0QA3wFdA5hmqsIzuv0ULbiIBWrD2N4DYw6CUYc2aWUfLQPe/8kYi7pjCk+sC/+Ynsef3xvK0O7JPDpz8fUv5Y7IR0mvgYVhfD9fNj2HlQX+wIreyL0uxX63AxWKc4NEkyJViTSagq6ziDsaXke48Tdsyv7JPO7d7cE3S6Uiw/DQFLBinPC1X3bM3vFnoC6SeFcLEVYdG4YeP6slzqV7h+TzqqsgqAFV0PpeyKtJi7MqH23eExmEvGRFqpcVSHXETvOZtK4bkBHmbJ8DkiItDAmsw1fbM+vs87JIPnmPOzp7Tm6JJ+Ds3PQbTq2NBsRnSPY/fvdOPOd6DadmEExtLupHXqkjtvrxqQXEdXpTcpzbg84XsiFqjWDwanxTXpPhmGwdt8xNuYUU1LlwmbWSY6xcUWvZGLtrf8z6SlzcnTOZtzHqiFIIW8AnL6+u3J9HlVbCki6pw+WFN/UvdySah5fspXtR0r5+419GZUZ4khSZCKMmuH7EvWSYEq0Gu1ibC0yNSbCop9Y32A16fy/8d154qMdQYv/NiTCrDN1VBrx/jodQpzNMttFk9kumh8OlQQ8F+rFUoLdwpC0pl0sidoGpcbTs30Mmw+W4AwzgU2EWeM3V/YIGBHXNMV/7x3G1c+torw6sAZRfSy6omu7KB67tlfjG4uzwv1jMli9p7BWMK5FHECzFBF/YSzxF56sYVbwcQFHPz5Kp3s7EdUzCleRi8PzD7P/6f10+V0XNJOG23Ci7NtQplIMd0zA8RoL8G3K4Ias5Ry4bR6JU6YSfdmlKL3xmSLlDjdvrT/Ii1/tpajSidPtxe0vt2Az6/zhvS1c0SuZ+8ek07vjuVHgNVzeajf5L2zCU1Rdd8ljcAYY1R6OztlM4gP9WLSvgH9/sZs7hnXmX5P6h7V+U4RGginRasRGmBmV2YYvd9S9Gxf6xaBFV9w+rPZI0h0j0sg+VsWCb7NDDqgizDqX9WzHw5d3b85bEuK0mj42g0fe3NSk0ZAIs29amdSQaRlKKeZOGcK1z63icHFVQLmH+kSYdaZcmFbvCGFqYiRvPzCSSS9+S3m1u9FALcKs06tDDK9MGSIXYeeQIWnxpCVGsiuvDLd/GNKS8BWo2vWBPFUe8t/Np+M9HYnu6xvFsCRZSPlxCrse2UXJ6hLix/hukOiaIiJhLZX5l4bdHt1i4oEXHkf7ZiXHXp5L/j//ScLddxF3/fVoEcHXFO8rqOCW2d9QWu0OOPd6DU70Ux9sPszSbblMH5PBzy7NbHV9UNFbu/GUOAICqbUHN/PXL2exq2A/mqaRmZjKo5c8SP/2FwDgdXrJeu57Puqss3jacLq2jQ7y6qIlSDAlWpVpY9L5dm9hk6fGKKW4c0RgMc3fTbiApGgLzyzdBVBvgU6LrqEU3Dkilf93ZY9W16mL1u3K3sn858usoDVqGmLRFT07xHDLUJnS2pJibGbe++ko7pq7ll15ZUH7teOO17H6+WWZ3D+m4Vovme2iWfrzMcxdtY/532bj9Rq11lEp5QuiEiMtTB+bwcQhKed9YpFzjVKKV6cO5apnV3KswomHSkxRO1GqdlBeubsSr8tLzKDao026TSe6bzTlW8tPBFMe3FgSvsF77PKgxb3rYzNrvDR5CHFRNrjsMmIuu4zKDRsonDuXguf/Q/wttxB/x+2YEhJO7JNdWMGPnl9FucPd6JRUr+HLxDv7q71UOj38dsIFIbftbOcpc1K1vTBgal+Zo4Ipb/6GJy7/Bdf0uBinx83ag5uw6idnwiggSlO8MqYbdgmkTikJpkSrMrRLAu1ibGQXVoS9JsCsK4anJ9IpPvgC2fvHZHDDwE78b+0B5n69H4fbg6YU1U4PmqawmDTuHJ7KHcNT6RAn2fvEuWfe6v0UVTjonxLLztwyqlyNXzDZTBpd20Uxb8oQzLqG0+PErJnlRkILiY0w8/YDI1mVVcDsFXtYl12EWVd4vAYaCk0Lv44V+JJc/Gp8D2Zc2o1Pt+by5c58CsudWHRFu1gb1w/oyMDO8fJ3PIclRVtZ8tNR3PLiN+RV54GhA7Wz93nKPZiiTCg98O9sijVRlV07i5tXVXBhu0pW59sb7R90pbCaNWbdMYgRGYm1nrMPHIh94EAce/dxbN489oy/kpgrryRxyt0YnVJ8I6chBFI1Vbk8zP82m14dYvjRgI6h73gWq1hzJOjje4/lAHBdT98oYYSmM7bL0IDtTG6D8hWHsPc8uUaqyl3FJ/s+YcexHRQ5iogyR5EWk8aE9AkkRiQGvIZonARTolVRSvHa1KFMeHYlZWGsCdA1RdtoK8/eMqDB7dpEWfnpuEweuKgr24+UUlTpZNG6HGIjzDx6TS8sJrl7K849hmHwz8928eHmI7zxwEiSY2w89elO5n+TjVIEHRGJMOt43Q6u7VzJoDEV3PbxDeSU5eD11yWJt8UzsdtEJnafSJI96XS/pVZF0xRjuiUxplsSh4qr2OCvY2UxaSRFWRmREV4dq5osJo1r+nXgmn4dWrjV4myQHGvj45+N4bnVDubvC5LtNkrHXe7G8BgBAZW7xI0pqvZlotkw+Ef5DL7QhzHTdAu5bl/mwJpBT4RZx2sYjO+dzIPjMunatv6Mb9b0LrT/0+Mk/ewhiv77X/bfdjurhk6gNKp/0ECqYttySte9i6vwIJolAnPbdGJHTsTWybeer8rl4R+f7uTa/h1a/EZAhcPNhgNFFFe60JQi3m5mYGr8KZ3+Wv7tkaAJJ9ITUtCUxs8/fIJre1zCgI69iLMFH31yHirDU+LgsMrn1W2v8l7We2hKo9JdeWIbq27l3xv+zaiOo5jaZyr9kvqdsvfUGikjWPozv8GDBxvffffdaWyOEC1jd14Zk178lrIqF65Gbm1ZTRod4mz87/4RtIsJP63zvK/3sedoBX++rndTm3taKaXWG4Yx+Ey3o7mkf2oZHq/BH97bwg8HS5g3ZQiJNVLlHi+E/eJXezlYVIXD7cVi0kiOsTFlZGcKnS+x+MASlNlOZd06JPhO0IZhMKrjKP486s/EWAIXrgtRU2von87Gvml30W7u/PhOKlwVtR73VHrYMWMHne7tROzQkwkcPNUedj2yi3Y3tSNh7Mnpd8owWL8/BzOAbmET3fhf6p/IcUZR7fIQZzczPD2RmwenNCnzo7eykiv+/hm7HYH3+kvXvkPJmjdJvPwn2LoMROkmqvatx5GztVYiKbtFZ96UoQztkhDwGk2RlV/Gy6v28873BzFrGsaJ27QKr2EwaUgKU0Z2oXNiy6Z9NwyDQ79dRX13hXcX7GfmmoWs2r+eoxXHuDhjGE+N/xVJkbXft7Lp7PhRKf9v2x9xeVy4jfoLJysUVt3K9H7Tmdp7qoxM19BQ3yQjU6JVymwXzaczxjBzeRaL1uWgIKC2SqRVx6xp3DUyjfvGpBNlbdp/h3YxtqCFMIU42zncHn6+aCNFFS4W3jcsoNin3WJi0pDOJ9L7G4aBUgqX18WML2ewNnct1UoFFnQ8/voeBwArD61k4pKJzL9yvoxSCXEGtI9sj9sbeBGt23XaXteWwwsOo9m0Wtn8zAlm4kbG1do+3uPlRC/hcdKPLfTLuQPu/gA6Dmt2O3eXesjxWqibbcHrqKB41X9JvGoG9u4jTzxu7zoMe9fax61yenjxqz3NDqY8XoM/vLuFtzccxO01cHsNqoOk01vwbTYL1xxg6qgu/OqK7i0XgDSSdCazTRr/mvBbALIKs3nog7/w2BfP8Z9rH6213XcRW/jzlhdxeB2NHtLAoNpTzezNs3Ebbqb1ndb09p9HJJgSrVZStJVHr+nFr8f34MPNR/hoyxGKKpxoSrH9SCkPX9adO0ekYmrmwuq2MVbyShvvpIQ4m1Q43Eybv55Iqx5ypjalfOtz/vD1H1h7ZC3VnuqQjuXyusityGXKJ1NYdM0iIs2RzW2+ECIMUZYoLk65mKXZS09MxT0u6aok9Eid3EW5OPOdaBEaMQNjSJmWwu7f7Mbr9NL96e5EmOGO0lJe2uBkwWYXy+/2/z92VcL862HGD2BrXnryHbml6EGK3DsO7cBwO7F3G9HoaxgQdhKdujxeg/vnf8fqrMJGk224PAZgMO/r/RSUOXjqpr4tElDlVzow8CWSaEzXxFQm9h7Pgo3v134N0zH+0nY2Dq8zYJ+ilUUUfFoQtLZYlbuKlza/RJ82fRjZYWTAvqI2CaZEq2cz69w4qBM3DjqZKvinCzcQE2FudiBV7fKwI7eMfQUVzPlqL9E2E307xdGzg0xnEmevYxVOpsxbR4920Txxfe+w/h+szV3LsgPLggZSDZ2cPYaHIxVHmLN5DjMGSQFIIU63u3vdzYqDK6gKMpKcMDah1nS+WrxQsLSAThOSuLGsgreDbeNxwcbXYfj0ZrWxrNqNJ8jUfE9VKZo9BqWFtj6p0hFeXci6nvhwG6uzCsOqL1nl8vDB5iOktYnkJxd3Det45Q43PxwsYdPBYjblFLP5YAll1S5eNdlJcgVun1WYzRd7vuHaHuNoH9OWw6V5vLf9CwZ2qF0L7v345bgJfA+h1Bar9lQzc+NMCaZCIMGUOC/17hjLlsMltQKscOQcq2Te6v38b+0BlIJyh4enPt2BSfNdlKYkRPDARRlc2bu91GYRZ5XDxVXc+fIaLuuZzK/Hhz8l5ZUtrwS9GAvl5Oz0Olm8czE/GfATzFr46ymEEE3Xq00vusR2YdexXQ2um6mrzZVtKPj4KLcNtZLgrWeUxlUJq5+FYdN8ufWbKMKsowXZX4+IwVtZiuH1hBRQmZ1VlH72GdaMDCwpKShz6P1NQbmDBWsO4AwyIhVKAoznl2Ux5cI07Jbgl9hOt5eduWUnAqdNB4vJOVbFBe2j6dspjit6JfOr8T1IS7RTsTaXkg/3YjhrtyXSYmfj4e3MWbeYUkc5MdYoLs0Ywe8u/vHJ4ygXH8WvxK3VDqbCqS2249gOskuzSY0JLBkjTpJgSpyXeneIZdmO/Cbt+/raAzz+/lY8huEf3vdxeQxcHl+ntSuvnN+9s4WnP93FomnD6023LkRzVbs8FFU6qXR6iLaaSIi01DvSlJVfzl1z13LXyNRGaxEFk1eRx7rcdQGPh3Ny9hgelh1YxhVpV4R9fCFE88y8ZCY3LbmJouoiPEZooy7RqVb0TDvWD/JgXANJmqpL4MC3kNr4VDyv16DM4UbXFJEW/cRNnY7xEUFjMWvHHiiTmcpd3xDZY1Sjr99OOSl56xMce/fizs3FnJKCNT0dS0a6L8BKT8fapQuaPfDc/PraA0Gn1tWXAKNq95oTwRT4Ysn3Nx7mlqGdMQyD/YWVbMopZqM/cNpxpIyUhAj6dYqjX0ocd41Mo1u76KDZgO0D2lLywd6Ax9tHJzHruseDvvcCUzHFplLWRW7FqwJH+cKqLeb18PqO1/nN0N8EPZbwkWBKnJeirDrfHyjiV29uwun20ibKyrD0RMb1aBt0vvZxL63cyzNLd4ZUsLDS6cHhqubq51bxwYOjJKASLWrLoRJeWrmXj7fkoimFpnzz/E26xu3DOjN5ZBoda9Q723ywmHte/Y5HrujOxMEpTTrmlzlfBh3JCufkXOmu5N2sdyWYEuIMSIxI5PUJrzPlkykUVBU0uu5RAWkuN0+MgMvmOvnlMEv9GxteKNhVbzDl9nj5fHses5bvYfOhEkya4nhC6XE92jJtbDpD/KnGK+pM09OskcSNup1jn72A0nRsXQagNBPV+zdSfWBzrWx+kVad+2+4kJR+NwPgdThw7t+Pc88eHHv2UrZsGc45L+HMzsaUmIglI+NEoGVKT2fuykIc7qYnwKh0enjy4x18+MMRNuUUE20z0y8lln6d4hjfuwe9O8aGnPBKs+hEjmxPxeojGA3U9XLh5uuYjbyRuJQcSy5mw4RLuXCpwIA5nNpibsPNjmM7Qmrr+UyCKXHeMAyDpdt8Hfn2I6W4PAaLvzt44vnX1x7AYtK4e2QXJo9IJT6y9klj+c58nl66k+oQCpke5zEMyqrcTJr9LcseHovVJFP+RPMcKani3le/Y+/RCpxuL5665S3cXl75eh/zVu/nkh5t+eek/qzPLuKh17/nyRv6cHmv5CYf+1j1sRMZ+moKt/BnQVVBk9sghGie5Mhk3rr2Ld7c9Savbn2Vclc5la6KE9PzlGFgMwyS3R7yvF5+WVTM0HQTV3cz8bdVTi5IqmeNpdcNjrKgT72z4SCPLdmG2+M9kVm35syOz7bnsSqrgKQoKxN6t2fx+pyAc23M0BvQIuMp+WYRBR88jbJEYG3XlZgRk2ptp1CMr9HPaVYrtu7dsXXvXms7w+3GdfAgjr17cezZQ9WG7znwzidUdroG9Nrn/3ASYACUVrm4ZUgKz0zsR9vo8Euu1BR7RRdcRypx7CuBINcf39t38ESnOXjxUqX7+mcX9U/jDLe2WLmzvFntPx9IMCXOC063l58v2siXO/ODFiAFX+r0CqeHmcuzeO2b/SyaNpyubU8WwXvy4x31BlINzaP2GAbFlU4+2ZLLj/q3jqrs4szYe7ScG2etprSeRdrHOf3Zpb7Ykc8lz6ygyulm5h2DGJ7evOr2wVIrQ/gnZ4+3eYvDhRDNYzfbmdxrMnf2vJM1uWv46v37KHRXoBkGbT0eLq+oopfTSZrXe+JC8fGLbAycXc4vR1iDv6hmAmtggd7nl+3m+S+zGrwRaRi+EZ3sY5Us/i4Hbz01UKN6XUxUr4vrfR2bWeOuEalBp8zVpUwmLGlpWNLSiB43DoCSI6WYX/gGh6N2XxduAgyrWWNwWkKzAykApSnaTO7JscU7qd5xrNb6qZXRG3i6w6s4tSBZKuph72pHmRSl60sDaouVbS6j3U3tam0v2VcbJ8GUaPW8XoPpC9azek9BSKNKDrcXp9vJ9TNXs+Sno0hrE8nWwyUcKKwIun0o86grnB5mLd8jwZRossJyB5Nmf0txlYsGaq3X4nB7OVRcRc8OMQxOjW92G+KscZg1My5v7RN3uCfnGKtkuxTibKCUYnj74QyP7A5ZnzW4bdcEjUm9zDy71kmftoHBiqE0VHyXWo/9b+2BRgOpuqrdXkweFxaTCacRejILi0mjV4dYZlzWLeR96jLrNYvynhRuAgyv4XutlqJMGgm39sCxq4jSFQdxHihjmy2LZ8IMpCC82mK60smIC3997flGginR6j23bDff7CkM2pk3NKJU4XBz65xvWfmri5m7ap//bn9t4cyj3l9YwfYjpVzQXi4kRfieW5ZFcaUzaCDVWIap/QUVLN2Wx1V92jerDUPbD0VXOi5qn7zDOTlH6BGMSxnXrHYIIVrY8AfgwGpwBr9peNwfx1qZvzn4xbuntIr82R8S+yMz9mHDKHN5eez9rUHXGDfWZ7l1M22jLJRWu3G4vEHCm9oizBq9O8byypShzQpi2kRZgmbxCzcBhtdrEGNr2UtspRS27gnYuifgLqrmwU/+jsMZ+LdoqETFcQ3VFtPMJ39/Zs3MrT1ubdH30RpJMCVaNafby5yV+4LWimhsRMlrQGm1i8+357M+uyjotKpw5lFrSrHlUIkEUyJs1S4Pi7/LwRXkMxjKyGil08MLy/c0O5jqkdCDTtGdyCrOCngu1JOzFy/XZV7XrHYIIVpY+sVgiQoIpvbPiK71c0qsRvXvg5zDTBGocQ9hy+1E/tPP4C4s5ONL7kQR2OeEmhWvrNrDX6/vw/sbD7N6byFAQKATadGxW0zcO7oLU0d1afZoUJzdQt9OcazPLqr1eDgJMDQFl17Qrtl1LBuyx8jmoOdIwOOhlKg4rsHaYn5psWlkxme2ePtbGwmmRKv2ydZcjCC38kMdUapweHhhxZ4TC2brCmcetdtjUFodem0PIY77YPORoKl6wxkZ3ZVfRlZ+Wa11gE0xtfdU/vztn8Mv/AloaFza+VJiLHJDQYiziqbB6Ifh80d9NaPC3l9HHzWdBHsCCZMnU7VzJwte3U5VnRtA4fRZDreHL3fkM2/qUHJLqlnwbTbf7i2kpMqF1azRMS6C24alMrprG7QGsvCGa/rYDGb87/uA836oCTBsZp37xqS3WHuCmb9tfsB063BKVITCptuY3q95RZjPFxJMidbD44bqYt+JwBoDtljmfLU3aCAUzojS9iOlxNmDF/wLZx61rvkWx4rzT1m1i8+355FX6sDp9hJtM9E/JY7+KXEhFc1dsulwsz/HHq/Bsh35zQ6mrki7gle2vMK+kn1hFf4E36L3nwz4SbOOL4Q4RYbe55vqt/MTCHKzpF6mCLhtEdhP3kg5FN+RMn0f1Ek2E06f5TXgs215ACTH2nj4iu6N7NEyxvVo60vPHqTPbSwBBkDbaCsDO8c1uE1zrTy0MqBOWDglKhpjM9m4KfMmLul8SYu1uTWTYEqc+45shm9mwta3fUUxlA5eF1hj2F/6f0BgkBPOiJLFpBFjM5NXGpgSOpx51LqmkRzT/Mw+4tyxM7eMl1buZcnmw+hK4XB78XgNLCYNXVO0jbbywEUZXNuvIxGW+j+LhRWBnz0I73Ps8hgcq3A2+b0cZ9EtvHTFS9zywS0UVBUE3B2tT4QpglmXziIlumk1roQQp5hScMMceO/HsP2DxkeolA4mG0yaD2m1z39FlU5MuqLO8sqws+I5PF5cHm+LJnNojK4pZt85iDteXhNW4gwAu0XnhTsHhXSTrDkqXIFr28ItUVEfm8nGpG6T+MXgXzS7necLuU0uzl3FOTB7DLx8OfywGDwOcDt8JwCPCyoL6y2uW3NEqTEOl5fCCkfQaVY151FX7voGr6saw+Omas93FH05t9a2SsHozKSmvFNxjjEMg/98mcWP/rOKt78/RLXLV1vF7fXliXK4vVQ6PewvrOTxJdsY98xyco6FP7UmnM9xS0qwJfDGNW/QM7EnEaYINFX/qcRuspNoS2T+lfPp37b/aWylECJsuhmufxGumwnJfX2jTqpO4GO2+x7vdytMXwldA0cv6ivdEG6fpaDeNOmn0uC0BGbePpAIc2hBn8K3fmvelKH0SD7105iD9bk1S1TUFaxERU0mZcKqW+mf1J9nxj7Dw0MebrBfF7XJyJQ4Nx3dCXOvgOpSMOrvlK24cBE4RS+cESUDg7tHpvHiV/sodwROawplHrVF17hjWGi1L8S57+mlO5m7an9IdzUrnR6qXR6ueW4VHzw0ik7x9oBtEiOD13UJ53Ns1lW9r9MUsdZYFly1gC0FW5i3dR5fHvgSS41Cl06vkwsSLmBq76mM6TQGkyanGyHOCUpBr+t9X3lb4fv5UJQNzkrfVL7OI6DfLWCrP2iIs5vxBgmows2Kp2vqjBW7HxInYgwAACAASURBVNejHS91d/Cn74rIjmqHy+vFU6dLN+sKTSn6dIrlyev7kNmuedOoQxVljgpYtxpuiQpd6WTGZ9I+sj1dYrpwY7cb6RzT+bS0v7WRs5s495TlwStXQVUxNJIwtaMqYKcR2DmEk5nHpGlc1acD1S4vc1ftCzra1dg8arfXyw0DpcbU+eCDzYd5edW+sKaHHM8cOWn2t3z58EW1gm7X4cNcUrKbtV47VZql1n7hfI51pRh3Qdvmv8E6erfpzdNjn6bEUcLBsoOUucqw6TaSI5NJjkxu8eMJIU6jdr1g/N/C3q1rUpQ/m13tm53h9FlArULj3koXFevzqFyfj6fSBYaBZjNhuyCBqJEdMMWFP42+pNLFhpwiSqtcaEqRGGlhUFo8VpOOp7iYpNlP8+7MWRxsm8rcr/fz+bY8yh1ulIJom4mr+rTn7pFppCae3sK2V3a5ktd3vF5rmnU4JSrAN/V64VULMevB14SL0EkwJc49y/7kSzRRI5BK+78yKl2w72dRRFp8E/Je2uDk4KbfY791NpVEBLxMqJl50hLtdG0bxc8uzWTV7gK255biCjKMXh+bWaN/Shz3vvYdz94ygH4pDS9M9XgNduWVUVTpBMOXqrVbu6hTmmZVtAzDMPj7xzvqDaQaqq3iNaC40snSbbmM72Ch9JNPKf3oI5x79zLm0sv5m2UoBMn3EOrnuHtyNBlJUafibQO+kapYa2zjGwohWj2TrnHXyFRmr9iLo84NyFD7rEiLzrQxGXjKnBR/uJeqLQUopTBq9K/eMhflxw5Tvvow1tQYYq/OwNK+8cBmy6ES5qzcyydbcn2Feg0DlG+6noHitqEpXPntO3S47HIi+vQmE3jyhj48eUOflvj1NNttF9zGop2LAh4PtUSFRbMwsdtECaRaiART4tziKIMf3gJv4FWlx4B/r3Hy29EnpzIlqlIKgq528mlsRCnSojP9Il/1b6tJZ/69w7jz5TXsyi2rdz1WTTazxm+vvIDJI9P4cPMRps5bx/SxGdwzqktAKtfCcgevrz3A3K/343B5Tjzv9RqYTRpTRqZx27BUkqJbbqqWaFnrs4sorCfJQyi1VSqcHp59dRmZXz1L1MUXkXj/fUSNHImyWJi0ZCsLvs0OGsg39jm2W3Smj5Uq9kKI0+eOYam8sGJv0OdCyYoXZdEZEhNB3r834K10gdc37T6Av0907Cnh6MyNJN7ZE1u34FnrHG4PM/63keU7j+J0e/EYRkCwBzDv63286u7NT0dm8KBhnPKEEuHqGNWRvm36sj5vPV5qtz+U+lEouKXHLaewhecXudUtzi2bFvnmcwfxyEgLT692UFx9srPVMLhD/wwbwbOhNcZq1hnf++RUpdgIM29MH8G9o9OJsZmIDJKBTdfAZtLo0zGGl+8awuSRaQBM6Nued39yIR9tOcLUV9dRUO5rk2EYvLAiixF/W8bzy7I4VuGkwumhrNpNWbWbCqeH4koXM5fv4cK/L+PZL3YHrZ0lWphhwMH1sP5VWP0crHsJdi31JTepx5yVe4MWiD5eWyXhsgewdx+JZrGhdBP2rsMCprVkW+PhrY/o+NRTRF90Ecrim9r34LhM4u2W+j7+9bKaNPp2iuXyXjLlTghx+rSNsXHvqC4hJ3GoyYqXBze9T8F/NuAt9wVSoTBcXgrnb8NxoDTgOafby+1z1vDlznyqXB48DZxHXV5waiZmrT7I40u2hd3+0+HRkY8SYQ6cddOYCFME9/e5n/ZRzSviLk6SkSlxbvlhcb3pWgd30LkozcTTqx38ZdzJudOPmBazwZvJD0Y6DixB9w3GbtFZcM+wgMWvVpPOw1d052eXZrJ0ax6vfbOfIyXVON1eomwmBqfGc8+oLkEXoqYk2Fk8bQT/+mwXE55dyT8n9ufz7Xn8b21OQGX3uo7fPZu1fA95pdX85breZ93dslbBWQlb3oRV/4KyXN9jXrcvo5Vm8hW3HHo/DL4HYmqfjH44VEKw83M4tVVMZp0dxxx0r7PULyHSwqJpI7hh5teUllfjCSG1sM2kkZ4Uxct3DUFvwaKWQggRikeu6M7h4io+3ZoX9EZTMDazxu+u6sUlK2JwHwus4bj24Gb++uUsdhXsR9M0MhNTefSSB+nf/gLAH1DN20r73w1D1Zge/8ibm9hyuCSs9axVLg+L1uWQkRTJnSPSQt7vdEiNSeXFy17k/qX3U+muDD5qV4dNt3Fd1+u4v+/9p6GF5w8JpsS5pbKwwaf/dLGVC+dW8LNhJ4Mms/LwmuXv3Of8BRuMTKpoeJGqSVN4DYOpF6bRs0P92YrMusaEvu2Z0De8uztmXeNX43swMqMN0+Z/R7XLi7ueNLLBVLk8vL3hECkJEUwf2zWsY4tGHN0J864GVwU4A+t4nPD1s/DN83D9bOj5oxMPVziCXyyEU1vF4zUoqw5eDLdLm0jmOb/lN1o39pljcXmMoCmILSYNBVzWsx3/uLkftibcGRZCiOZSSvGvSf15eulO5qzchwb1TpE/PtPj6Zv7cVn7OHI/PITSam9b5qhgypu/4YnLf8E1PS7G6XGz9uAmrHrtG6WGx6Bq2zHsfdoAkF1YwSdbcoNO6WtoLSv4zrlPfbqTSUM6n3UZefsm9WXh1Qt5ZMUjHCg9gNPrxGsEvke7yY5SiocGPMRtF9x2BlraukkwJU45j9dg2Y58Xlq5l70FFVQ7PdgsOmmJdu4dnc4lPdq2WHKF3m11ru5m4m+rnFyQdPI17crBa5a/8aZnLLM815KvtaXKq9UaRbD7O/KbBnVifO9kHlz4PZf2TKZ/IwkjmmpQajwerxE0kAqlc/+/z3dzx/A0oqzy37hF5O+Aly8FRzmNZYnE4/AlqXpnGriqfGmC8U2pC6ZmbZXGAipNqXqnxZR+8imRa7/mg7d/yc5SDy+v2ssHm4+gAE1TuL0GNpPGHcNTuXNEKu1jw58CIoQQLUkpxSNX9GDqhV14fd0BXlm1n0qXB5NSGIDL4yUlPoLpF3Xl6r7tsZl1it7P8qU5rWPvsRwArut5KQARms7YLkMDtjMcHspW5JwIpuat3h+0XlUoa1nBV+tq6bZcru7boSV+JS0qPTadt659ix3HdvDa1tf4dP+neAwPmtJwe91kxmcytfdULku9rFb5CtFy5CpMnDKGYTBn5V5mLd+D0+0rWnpcmcPN0TIH2w5vxKxrTBubzrQxGQFJGQJEBF9UWtPjF9kYOLucX46onahBVwaTTMuZaFvLxiF/543KARwprsLp9hIfaWFMtySu6duBCH9Q9eQNfXhgwXre++mFtI0OP+VqY5ZsOhx0ml6onbumFO9sOHjWTT04J1WXwLwJoQVSNbmqYMkMSOwKnQbTPtZGflng+rxwaqt4vAY7c0tZujWXkV3bnAiWXUeOkPvnP5MyayZ6VCQ9o+CZif158oa+FFc5qXJ6iLKaiLNbZEqfEOKskxhl5acXZ/LA2K7kHKukpMqFrikSIi10iKt946dyfX7QYCo9IQVNafz8wye4tsclDOjYizhb8NpOrtwK3CUO3HYTi9blBCTvOb6WNfGqGdi7jzzxuL3rMOxdh9XatsLhYdbyPWdlMHVcj4Qe/HX0X3li1BNUe6pxeV1EmaOk+O5pIMGUOCVcHi8PLvyeFbuONjhP2hdgeXj2iyzWZxcx8/ZBDQ6jl2ZcS8ThHzB7q+vdpmuCxqReZp5d66RP28DXUngYMHQsA+JTG3wPl/dKZsvhUn7y3w38997hJ9q15VAJi9blcOBYJdUuD/F2C8PTE7hhUCdibKGnGZ21Yg+Vztq/m3A690qnhxdW7OWO4amydqq5vl/oX4vXeLr9BZtdLL+7RupddxXGsj9TPeivXH90M7vcyVSZagfy4dRWcbi9vPZNNgvX5uD2erm2Xwemjkwl4te/JmHyZCL69q312haTdkqCfSGEOBV0TZHWpv705YbHwHAGv26Itkby9u3PM3PNQn71yT84WnGMizOG8dT4X5EUWTuDndI1vKVO9lU6gibuCWctK8D2I6UYZ2Fmv7qUUkSYIogIUhJGnBoSTIkWZxgGD7+xieW78kNe6Fnl8rAqq4CfL9rI87cNqNVZOd1elu3IY9G6HHZnp/Clavw1/zjWyvzN9WRd6zQEGgmkjptxSSZbD5Xwlw+2MiA1nv98uYdDRVU43J5aN81W7DrKkx/v4Oq+7fnJxV1Jb6SeT1m1i5xjgYk0wu3cj5Y5OFbhJDFK0qU3mWHA6meDJjYJlm4/6EtkfUXe69O57MpJ/MMb4UsFVUeotVXAv6bAP7f/rfUHeW/9Aa6OHcBTU6cGbCuEEK2J4fH6svbWk20vs00a/5rwWwCyCrN56IO/8NgXz/Gfax+ttZ23rIz9t09miwkYcAeYa990CmctK/iClGqX98TsFSGOk2BKtLhPt+bx2ba8oIFUQ2uBql1evtyZz4c/HOHqvh3Iyi9j0boc3vn+EOltopg4JIWrbh+I+YPr4Yc3wTh552r/jNrD/CmxGtW/D5I8whwJF84I+b1omuLvN/VlzFNf8nqQaQLHHR99e3fjYT7eksvsOwcxOjOp3tctrnRh1jXc3tp338Lt3M26oqTKJcFUc+xfCY6SoE89MtLCU187+PEQC3G2+u9GKl0n9ecXoy7/KXd8tJ1XV+8Pusg6lNoqdXkM8KDxUVwPnG9s5tlbB5z1d0aFEKKplFmrN5Cqq2tiKhN7j2fBxvcDntOio0n77yuUV1Wh/XcT1EkQFM5aVvCtmzLr0veKQDKRUrS4WcuzAqavgW8t0LEv5hA7fCKdfrqAjg+8QvTAq6javebENpVOD3/9cDs3zPyaW+esQdc0Fk8bweLpI7hpUCfsFhNc8kew1p9lr14mG6QMhYxLQt7F6zX49VubcXuNegOpmjxeg0qnh/te+45v99afedCk++qs11Wzcw+FAZg0+W/cLIc3gjt4HbKa6fYbogwXKucbAH5xeTe6J0djaaGkKsdVub18vj2fZ7/Y3aKvK4QQZxOlFKYke9Dnsgqzmb32fxwpzQfgcGke723/goEdegVu7DUwtYumbVIsziDn75prWUMRYdZbLFmWaF1kZEq0qKz8cnbmlgU8Hs5aoNzSau4fk8HtwztjDtZxxXaCu973JQxwlkOQNKABTDZI6gG3LPTVCQrRS6v2sjqrMGgNqMZG2e599Tu+/vU4Yu2B66ji7ZagwVk4iQrANwUyLjL0dVoiiKpiXx2pegRLtx9UtW90y2rSmX/vMCa/vIYduWVh1TQJJYvjCyv2ct+YdN+NBSGEaIWix3ai+L0sDGft/jPSYmfj4e3MWbeYUkc5MdYoLs0Ywe8u/nHtF9DAPrAtmkWnk8VOaqKdXXnltTcJYy2rSVP8qH/HU/Z+xblNzsaiRb3z/aGgqb7DWQukgMMlVcEDqePa94X7l8PCiVB6JCB5wAm6xTf3usfVcN1MMIU+Hc7jNXhhxd6gCTRCybjn8Rq8uT6He0anB+xvM+sMSUvgmzqjV+F07gB9O8WGlfRCBGGJwDdIHzzoqS/dfgDTycW+MTYzi6eN5IUVe5j79T5cHm+9NaiOCzWLo1Lw7veHuG1YaOv+hBDiXGPv24bi97ICHm8fncSs6x5vdH+laUSNOhn8PHBRBr9/Z0utrMIQ+lpWk6a4Z1Ra096MaPUkmBIt6mBRZdBgKqyipYbvdRqVmAE//Q5y1voSCOxeCpoZlAaG2/f9kHtgyL2+0awwLd+ZjyNIIBXqKFuVy8OLK/cydVSXoGtcpo1NZ/PB4iZ37pFWneljM8J+X6KO6A5gjvAV6q1Hfen2a4nrXOtHi0njoUsy+fFFGSzbkc/CNQc4VFxF1tHygOUA4WZxnL1irwRTQohWS5l1oi/uTNmyAxhhjO4DYNKwdovDXGOq4JW92/OHd7cG3byxtayagu7J0XRtGzwFuxASTIkWVV1PGvRwF3qGPDVKKeg8DDr/F5wVUHHUV/vHGgNR7UBv+kf85VX7AgIdCG+Urazazbr9RQztkhDw3JjMJOxWU9BjhJKowKJrjOvRttE2iPqVO9x8WN6X61wuGhqzbCzdPpYoX+AehEnXuLxXMpf3Sia7sIIr/70yYE1huFkcc4oqcbq9DZYREEKIc1n0RZ1w5VVQvbUw9IDKpDAnRZB4a49aD9vMOrPuGMh9r30X1tRrgEirieduHRjWPuL8Imdi0aISIoNfkoa70DMhsglVui2REJ8GbS+A2I7NCqQA9hcGH6kIN+PegSAp0MGXKfD5Wwdga8IFsc2s8fxtA2UxbBPllVbz67c2M/gvn/H4Z4dY4h6G22j4d/nHsVYqnMGTkBR5bNzwsYlrnlvF5JfXMO/rfZRWB6bmL61yowcZpQw/i6MW9PWFEKK1UEqRMLE7kUOTfRn+GjndKYuGNS2WpAf6ocyBfenozCSeubkfNnPo5027RWfBPUPpnBg8IYYQICNTooUNT0/g/Y2HAkZbwlkLFGnRGZ6eeLqbHqC+u1fhjLJ5vAYVjuDJDXKOVfLG+oO4veHdJbOZNf5xYz8u7NomrP2Ez47cUm598VvKqt0npqS+pCYwQV+DCeeJ7UJNt19lWPh39VVsyDmZXn1ddtGJumMzLu1GSoLvRGwxacFW9jUpRa9VRqWEEK2c0hRx12RgH5xM+aqDVG4qQOkKw2uA4Z/V7zWwpscRPaYT1ozYBktHTOjbgeRYG398byt7jpbjcntpKFGvgcFdc9cxeWQq08ZkEGmVy2YRSD4VokWN753Mb9/ZEvS5UNcCGcDVfdufhtY2LCLInS0IL+OerimignS+67OLuGvuWiqdboIsMau3PUnRVp66qe9ZEWyeiw4UVnLzC99QVl07wN1hdOZf7huZYXobu2o4DXpN1YaZ77zdeM1zea3Hq/w3E975/hCfbs3j1alDGJSaQFK0NWhmyHCzOCqCf66EEKI1srSPJOHm7sRdk0H17iK8FW7wGmgRJqzpseixoSeXGpSawIcPjeaDzYf5+aKNeBqIpqqcXqrwMnvFXj7YfIT/3TectjG2ercX5yc5G4sWZTXp3DY0hXmr9wdN/d3YWiCzrpg4OAVbPYHM6dStXTSHiqsCHg83415G26haP289XMKdL68JWourPgpIjLLw8c9GEWmV7H1NYRgGU+etDTpSWLFtOY+v+4LHC4uIt3oYkKzxu9FWRnWuv4usNCxs8GZyn+uXeOuZf+I1fOuy7nx5LW9MH0GvDrH0S4ll3f6iWtuF85nSleLqvu2lcK8Q4ryj2UzY+yQ1+3X2HC3nN2/9EFL9SACH20t2YQU3zFrNhw+NJjZCzsPiJAmmRIu7d3Q6i9bl4PLUX7unPjaTzv1jAlOJnwn3je7C2n2FQRNEhDrK5nB5WLr1CHERZtLaROLyeJn88tqggVRDNYYM4GiZgz+8t5V/Tux/qt5yq7bhQDGHS6oDRgLrpiS/zLyJIdmv8OaOQoanKEyq9g7lhg0nJl5yT2C252o8+AL/hv5+lU4Pd768luUT2nDjlk/YYutPVZ00/aF+pswmxT2ju7T8L0gIIc4DXq/B5JfXUuEMfmOtvn7c44X8UgcPv7GJOZMHn4GWi7OVBFOixbWLsTH/nmHcOufbsEZfIsw686YOpUNcROMbnwYjMhKJtpmDBlPQ+Cib3aJz3+gulDs83DhrNV3bRtEjOTpoxsNQagw53F4+3HyER6/uFbQQsGjYnK8Ca4YFS0n+BcP4Im0Yvbvs5QPPx/TW9hKlqnBgJcdowzz3eJZ5B9QajQrl71dVVsHiJ97iuhvH8vSeaKoqnNQVSoreLm0i6dUhtiV+JUIIcd5ZlVVAcaUzoERFKP240+Nlxa6j5JVW006m+wk/CabEKdEvJY43po/gjpfW4HB7Gwyq7BYdi64x/55h9Ol09lwkKqWYcWkmjy/ZFrRwb2OsJo1pYzOwW0z8enwPPt+ex6/e3BwQnIVTY0gpWPxdDvedJaN354pKp5tlO/IDTp4NpSTfYqQzw/2TRl875Lpjmpl3x97GvbeN4eWcYm558duwP1eRVhMv3DEorH2EEEKc9MKKPc07DwPzv8nm4Su6n47minOApIMSp0yvDrGs/s0lPHZtL7q0iSTCohNlNRFh9v1rt+ikJdp59JqerP5/486qQOq4SUNSuLZ/h3qTUdTHbtFZcO8w7Bbf/QqLSaNPx9igmfvCqTFU7fLyyup9YbVFQEGZE5Pe/JTkwYTz98vKryC7sIJ+KXHMmTwYuyW042oKYiJMvH7fcFITI5vcViGEOJ8VlDv4Lrso4PFw+nGH28uCNdmnonniHCUjU+KUirDoTBycwsTBKWw5VMK+ggrKHW4irSa6JEbSu2PMWb2QXinFk9f3wWbWWLzuYKMjCRaThs2sseCeYQFTsXKKKjHrWkDK9XAv6AvKAqeHiYZVuTxoQT5m4aYkDyacv5/ZpHGoqIrUxEhGZbbh7R+P5PH3t7LhQDFewwhYDH08/fmozDY8fm0vOsVLrRMhhGiqI8XVWHUtIKtquOfhkioXHq+BHuzEIs47EkyJ06Z3x1h6dzz7Rp8ao2mKx6/tzWUXJDNrRRbf7S8KuPCNtOroSjF5RBp3jUwjKTowTWtVPVMdw72gd3q8GIZxVgehZ5tomwlPkHJe4aYkDyasv59BrSmvPZJjeP3+EeQcq+S1b/bz6dY8yqpdKKWIjTBzXf+O3Dasc9DPkxBCiPBUOt2+eXp1hHse1pWi2uWRulMCkGBKiJCNymzDqMw2HC6u4r2NhzhYVEWl00NilIVBneO5tGc7zHr9M2frqwsU7gW91aRJIBWmpGhr0DuI4aa5Dyacv59SEGUL/BykJNj53YSe/G5Cz9DflBBCiLBE2UwBa2ch/POwxzDCnv4vWi8JpoQIU4e4CB64qGvY+6UnRQUt2BruBX2arJkJm1nXuGVoCq8GqX8Wakpyha+gdF3h/P0cbi/pSfL3E0KIM6Fzgh1XkGkK4Z6HO8RGoMkUP+EnwZQQp0lStJWRGYks33k04KI81Av6SIvOtLGSya8p7hqRxvxvsgkWEjWWkhxA1xQWXVHpCjwRh/r3G5GeSNtoSacrhBBnQrTNzPheySzZfDig5mCo/fjxsidCHCfBlBCn0f1jMliz71jQVPGhXNADXNWn/aloWquXkmBnXI+2LNuRjyPICGFDbGaNCX3a88mW3Hq3aezvF2nRuV8CYSGEOKPuG5PO0m15QRNKhXIe9ngNbhzU6VQ1T5yDJDW6EKfR8PQEUuLtmJowPSDCrHPv6HRsMk+7yf41qT9d2kSeyJIXCqtJo2f7GJ68oS/3jU5v0jx5k6boGB/BiPTEsPcVQgjRcnp3jOWC9tGYg5TLaEyEWefWoZ2JtplPQcvEuUqCKSFOI6UU8+8dSpzdHFZKVZtZY2TXRH52SeYpbF3rZzPrvPnASPqnxIVU48lu0RmRkcjC+4ZjMWk8dEkmozLbhBVQ6Zoizm5mwT3DJHGIEEKcBebePYSkKGtYNzZtJo3+KXH8fsIFp7Bl4lwkwZQQp1nbaBtLHhxFx7iIkC7K7RadK3ol88Idg2TBawuIsppYeN9w/m9SfwZ2jsNq0k6MVCk48fPw9AT+c9tA5t415MRooKYpZt4+kPG9k0MKxiLMOh3jIljy4CjaxshaKSGEOBvE2S28/+AourSJDPnG2uhuScybOgRTA1l7xflJ1kwJcQa0j43g0xljePv7g7ywYg+F5U6qXZ4TC2ItukIpxeDUeKaNzWB0ZhsZ1WhBuqa4vFcyl/dKZl9BBct35lNc6UShiI+0MK5HW1ISghfINesa/5zYj1VZHXlh+R7WZReh4MQ6LE35RsASIy1MH5vB9QM7YrdIVyuEEGeTNlFWljw4iiWbDjNrxR6OFFfjdHs4nvDVrCs0pejbKZZpYzIY16Ot3NAUQckZXogzJMKic/uwVG4b2pn12UWs2XeMwnIHFpNGUrSN8b2T6RgXcaab2ep1aRNJlzbhZWZSSjE6M4nRmUkcKq7iky25FJQ5cLg9JEZZGdolgcGp8RIACyHEWcxm1rl5cAo3D05hU04xq7IKOFbhxKQpkqKtXNazHalSjkQ0QoIpIc4wpRSD0xIYnJZwppsimqBjXAT3jJI0uUIIcS7rlxJHv5S4M90McQ6SiZ9CCCGEEEII0QQSTAkhhBBCCCFEE0gwJYQQQgghhBBNIMGUEEIIIYQQQjSBBFNCCCGEEEII0QQSTAkhhBBCCCFEE0gwJYQQQgghhBBNIMGUEEIIIYQQQjSBBFNCCCGEEEII0QTKMIz6n1TqKJB9+pojhDgNUg3DSDrTjWgu6Z+EaJXO+f5J+iYhWqV6+6YGgykhhBBCCCGEEMHJND8hhBBCCCGEaAIJpoQQQgghhBCiCSSYEkIIIYQQQogmkGBKCCGEEEIIIZpAgikhhBBCCCGEaAIJpoQQQgghhBCiCSSYEkIIIYQQQogmkGBKCCGEEEIIIZpAgikhhBBCCCGEaAIJpoQQQgghhBCiCSSYEkIIIYQQQogmkGBKCCGEEEIIIZpAgikhhBBCCCGEaAIJpoQQQgghhBCiCSSYakWUUhcppQ6e7n1PF6VUZ6VUuVJKP9NtEUI037naZymlPlZK3XUmji2EaJ5zqd9RSi1XSt17uo4nmkaCqQb4L9yPf3mVUlU1fr79FB73bqXUqlP1+i1BKWUopbqe4mPsV0pdevxnwzAOGIYRZRiG51QeV4hzlfRZDVM+e5VS28LY5zGl1IKajxmGcaVhGK+2fAuFOPdIvxOcUirNf61kqvP4PKXUX85Uu0TLMzW+yfnLMIyo498rpfYD9xqG8Xnd7ZRSJsMw3KezbUIIUZf0WY0aA7QFTEqpIYZhrDvTDRLiXCf9jjjfychUExwf5lVK/VoplQu8EuwOSc3RG6WUVSn1tFLqgFIqTyn1glIqognHnqKU2q6UKvPfYZ0WfU1T6gAAIABJREFUZJvfKqUK/CM7t9d4vEXaEOR4jymlFiulXvO3a6tSanCN53+jlNrjf26bUur6OvvfV+M9bVNKDVRKzQc6A0v8d7d+VfMuj1JqklLquzqv83Ol1Pun8r0KcS6SPuuEu4D3gI/839dsQy+l1GdKqWP+Y/1WKTUe+C0wyd8PbfJvu1wpda+/fcVKqd41XifJf2e+rf/nq5VSG/3brVZK9W1G+4U4Z0i/E1I771ZKrfIfr0gptU8pdWU927ZXSm1WSj3i/3m5UurPSqmv/e9zqVKqTY3tr/VfjxX7t73A//gUpdSSGtvtVkq9UePnHKVUf//3hlJqun+bYqXUf5RS6lT8Ls5lEkw1XTKQAKQC94ew/d+AbkB/oCvQEfhjE46bD1wNxABTgH8ppQbWaVcb/+vfBbyolOoebhuUUjOVUjPDaNe1wP+AOOB94Pkaz+0BRgOxwOPAAqVUe/9xbgYeAyb739O1QKFhGHcCB4Br/FP7nqpzvCVAd6VUZo3HbgMWhvtehThPnNd9llLKDtwE/Nf/dYtSyuJ/Lhr4HPgE6OA/1heGYXwC/BVY5O+H+tV8TcMwHMDbwK01Hp4IrDAMI18pNQCYC0wDEoHZwPtKKWt97RSilTmv+50QDQN2+tvzFPBy3YBFKdUFWAE8bxjGP2o8dRu+99cWsAAP+7fvBrwOzACS8N1AWuLv81YAo5VSmlKqg3+/Ef790oEoYHONY1wNDAH64uvfrmjm+219DMOQrxC+gP3Apf7vLwKcgK3G83cDq+rsY+D7j6iACiCjxnMjgH31HCvgtRpo17vAz2q0yw1E1nh+MfCHxtrg3/dgGL8PA+jq//4x4PMaz/UEqhrYdyPwI//3nx5vf0O/c//Paf7jmvw/LwD+6P8+EygD7OH+vuVLvlrjl/RZAce9AziKb3q7DSgBrvc/dyvwfT37PQYsqPPYcnxTmQAuBfbUeO5rYLL/+1nAn+vsuxMYe6Y/H/IlX6fiS/qdWsdMo8Y1S43H5wF/qfEesmo8Z/fvk+z/eTnwT//v9dY6r7Mc+H2Nn38MfOL//g/A4hrPacAh4CL/zznAQOAW4EVgLdADX2D2fp2/zag6v6ffnOnP2dn2JWummu6oYRjVIW6bhO8/yPoaNxsUEHZWOv/w76P47ppo/tf9ocYmRYZhVNT4ORvfndYWa0M9cmt8XwnYlH9+tFJqMvALfB0L+O56HB+KTsE3ctUUC4FngD/huzvzrmEYlf7pNafyvQpxLjrf+6y78F1cuAG3Uuot/2Pv0Lx+6EvArpQaBuThu5v9jv+5VOAupdSDNba34Ht/QpwPzud+5/j6MHON74//7Krx84nrJ/81DPiuk467HcgC3gxyjLrXXsf364DvPR1/Xa9SKgffKBv4RqcuwhfErgCKgbH4AscVIR5D+Mk0v6Yz6vxcge8/IABKqeQazxUAVUAvwzDi/F+xRo1Fm6HwTw15C3gaaGcYRhy+oduaw8HxSqnIGj93Bg63VBvCpZRKBeYAPwUS/W3eUqPNOUBGPbvX/R3X9RmQ5J/beysnp/idkfcqxFnuvO2zlFKdgHHAHUqpXP/6jZuAq/xrDHKA9Hp2b7AfMnzZRRfj64NuBT4wDKPM/3QO8ESN9scZhmE3DOP1cN+DEOeo87bfAY7gC5rS6jzehRqBTgge87droQq9NMxhfDdzAF8mU3w3jQ75HzoeTI32f78CXzA1lsBgSjRCgqmWswnopZTqr5Sy4fvwA747AvgCin+pk4uSOyqlGpp3qpRStppf+O5oWvFNVXH777xcHmTfx5VSFqXUaHxzXd9oYhtaQiS+zvSo/5hTgN41nn8JeFgpNUj5dPUHYOC7y1vfBQ6GYbiAN4B/4JuT/Zn/8TP1XoU4l5xPfdadwC6gO76Ro/747lgfxB8AAe2VUjOUb/F5tH+kCXz9UJpSqqHz5UJgEr47yAtrPD4HmK6UGubv3yKVUhP8a7SEOB+dN/2O/0bLW8ATSqlEpZRZKXUrvqUQH4fxUi7gZnzXU6810hcdtxiYoJS6RCllBn4JOIDV/udXABcDEYZhHARWAuPxre38Poy2CSSYajGGYezCN93sc2A3ULf2wa/xDdN+q5Qq9W/XnfqNxHd3pO7XQ/j+kxThm9r2fp39cv3PHca3yHq6YRg7wm2D8mWveaHhd904wzC24ZuK9w2+i5I++NYUHH/+DeAJfBcgZfjmNSf4n34S+L3yZZB5uJ5DLMS3ZuENo3bK1XB/30KcV86zPusuYKZhGLk1v4AXgLv8I0mXAdf427Mb34UG+G7YABQqpTYEe3HDMNbgu+PegRoXSYZhfAfchy8hT5H/vdxdTxuFaPXOs34HfOuYjuFL6JCPb5bOBMMw8hrYJ4BhGE7gBqAdMLexgMowjJ341ok+h29U6xp8Cb2c/ud3AeX4gigMwygF9gJfG1LLM2z/n737jo+qShs4/jt3SiY9oYTeIy2ICihNBFx7WXtBZcG6ll3L7rp91y2u7666q+u6CoqiggUromtHQZAmICC9hgCBFNIzfe59/7hDSDIzmZmQhIQ8389nIJnbzp2Z3DnPPec8RxlGtJ5UQgghhBBCCCHqk5YpIYQQQgghhGgECaaEEEIIIYQQohEkmBJCCCGEEEKIRpBgSgghhBBCCCEaQYKpY6SUekkp9XDw5wlKqW0tdFxDKZXdxPusOZeW3LalKKV+q5SadbzLIcTxJNesY9/2WCileiulquKYL0aIE45ch45925Yidafo2kUwpZTKVUq5gl9gBcEPb5NP4GoYxhLDMKKm31ZKTVdK1U8H2mSUUouUUrc11/6PVXOff/AYk5RS+2s/ZxjGI4ZhtNrXRYgj5JrVOgWvK4ZS6ldxbJOrlDrnyO+GYeQZhpEi6YdFayfXodZF6k6tV7sIpoIuDc5gPQIYBfy+/gpKKWuLl0oIIcKTa1brMw1zzpgfHe+CCNFC5DokRBTtKZgCwDCMA5iTKg6Dmibfe5RSOzAnkEMpdYlSal1wsthlSqnhR7ZXSp2mlFqrlKpUSs0DHLWW1YnolVK9lFLvKqWKlFKHlVJPK6WGYE4UOTZ4t6csuG6CUupxpVRe8A7QDKVUYq19PaiUOqiUyldK3dLY81dKvaWUOqSUKldKfa2Uyqm3Siel1OfB81uslOpTa9vBwWUlSqltSqlrG1uOemXKVUr9Qim1IViuecqcxRylVKZS6sPga1ga/LlnrW07KKVmB1+XUqXUfKVUMuZ73D34Glcppborpf6klJob3O5jpdRP6pVjvVLqyuY8VyHiJdes1nHNCl5XrgbuAU5SSo2qt/x2pdSWYDk2K6VGKKXmAL2BD4Kv3S+VUn2D76FVKXWdUmp1vf08oJRaEPy5wddYiJYi16HWcR2qVyapO7US7S6YUkr1Ai4Cvqv19OXAaGCoUuo04EXgx0BHYCawIPgHawfmA3OADsBbwFURjmMBPgT2An2BHsAbhmFsAe4Elge7emQEN/k7MBA4FcgOrv/H4L4uAH4BnAucBJxD430c3EcWsBZz5u/abgT+CnQC1h1ZHvwj+xx4Lbjt9cAzSqmhEc6/TCl1Zhzluha4AOgHDAemB5/XgNlAH8xKiQt4utZ2c4AkICdYricMw6gGLgTyg69ximEY+fWO9zowpVZ5hwaP8b94z1WI5iTXrFZzzboSqMJ8DT/FbKU6su01wJ8wW6zSgB8Chw3DmArkEby7bxjGo/X2+QEwSCl1Uq3nbgiWGRp4jYVoSXIdajXXofqk7tQaGIZxwj+AXMwvwTLMP9BngMTgMgM4u9a6zwJ/rbf9NmAicBaQD6hay5YBDwd/ngTsD/48FigCrGHKMx1YWut3BVQDA2o9NxbYE/z5ReDvtZYNDJY7O8L5LgJui+F1yQjuJz34+0uYF60jy1OAANALuA5YUm/7mcBDtbZ9OMb3o/755wI31fr9UWBGhG1PBUqDP3cDdCAzzHo170Wt5/4EzA3+nBp8zfsEf/8b8GLw5wbPVR7yaO6HXLMivi7H5ZoVXP8L4Mngz1OCr5Ut+PunwH0NvJfn1Pq9b/AcrMHf5wJ/DP58ElCJWclp8DWWhzya+yHXoYivi9SdpO5U59Ge+rlebhjGFxGW7av1cx9gmlLqp7WeswPdMf94DhjBT0jQ3gj77AXsNQzDH0PZOmN+ea5RSh15TgFHsj11B9bEcMwGBe/4/A24JnhMPbioE1Ae/LnmtTAMo0opVRI8fh9g9JGm9SAr5t2NpnCo1s/O4DFRSiUBT2DeeckMLk8NnksvoMQwjNJ4D2YYRqVS6n+Yd07+gVk5uj24uLnPVYhYyDWrlVyzgnflJwO/CT71PvAccDHmHfdewK549xv0GvBP4C+YrVLzDcNwKqWyaPg1FqIlyHWolVyHIpC6UyvQnoKphtT+A98H/M0wjL/VX0kpNRHooZRStS4KvQn/JboP6K2Usoa5KBj1fi/GbILNMcx+yfUdxPzwH9E78qk06AbgMsym7lwgHSjFvPgcUXMcZWbt6YB5R2kfsNgwjHMbeezG+jkwCBhtGMYhpdSpmN0MVLBMHZRSGYZhlNXbrv5rHM7rwENKqa8x+29/FXz+eJ2rELGSa9ZRLXHNmorZbeaDWpU2B2ZXv/nBYw2IsG20a9HnQOfgtW0K8EDw+WivsRDHm1yHjpK6UzuuO7W7MVMxeB64Uyk1WpmSlVIXK6VSgeWAH7hXKWULDrg7I8J+VmH+If89uA+HUmp8cFkB0DPYjxjDMPTgcZ8I3o1EKdVDKXV+cP03gelKqaHBuw0PxXAe1uAxjzxsmM2zHuAw5t2cR8Jsd5FS6sxg2f4KrDAMYx9mH+aBSqmpwXO3KaVOV+ag0OaUinmxLFNKdaDWuRuGcRCzH/MzyhxsaVNKnRVcXAB0VEqlN7DvjzDvpPwFmBd8H+D4nasQjSHXrOa/Zk0D/ozZVebI46rgsTsCs4BfKKVGBt+DbHV0AHoB0D/Sjg3D8GGOIXkMswL2efD5aK+xEK2JXIek7tRu604STNVjGMZqzCbLpzHvPOwkOKDPMAwv5iDk6Zjpca8D3o2wnwBwKeaAyDxgf3B9gC+BTcAhpVRx8LlfBY+1QilVgdk/f1BwXx8DTwa32xn8P5pnMf+QjjxmA69gNnMfADYDK8Js9xrmH10JMBK4KViGSuA8zKbdfMym5X8ACeEOrswsMBNiKGc0TwKJmHegVgCf1Fs+FfABW4FC4P5gebdi3j3ZrcwBnd3r79gwDA/m+3cORwd8x32uQhxPcs1q3muWUmoMZsXhv4ZhHKr1WBA8tymGYbyF2Q3oNcwxT/MxAyOA/wN+H7wO/SLCub+GeR16q97d+IivsRCtiVyHpO7UnutOqm4XViGEEEIIIYQQsZCWKSGEEEIIIYRoBAmmhBBCCCGEEKIRJJgSQgghhBBCiEaQYEoIIYQQQgghGqHBeaY6depk9O3bt4WKIoRoCWvWrCk2DKPz8S7HsZLrkxAnnhPh+iTXJiFOPA1dmxoMpvr27cvq1aubp1RCiONCKdWoWeBbG7k+CXHiORGuT3JtEuLE09C1Sbr5CSGEEEIIIUQjSDAlhBBCCCGEEI0gwZQQQgghhBBCNIIEU0IIIYQQQgjRCBJMCSGEEEIIIUQjSDAlhBBCCCGEEI0gwZQQQgghhBBCNIIEU0IIIYQQQgjRCBJMCSGEEEIIIUQjWI93AYRojfZW7OXt7W+zq2wX1b5q0uxpDO04lKsHXk3npM7Hu3hCCAFAlcfP/O/2s2RHMaVOHwlWjR4ZiVwzqhcjemeglDreRRRtgdcJG9+GHZ+DswQsNkjvAafeCL3HgnyOhIhIgikhalmyfwkzN8xka8lWAnoAv+GvWbYsfxmzvp/F6G6jufOUOxneefhxLKkQoj3bV+Lk6a928v66A2hK4fQGapZpChasz6dzagJ3TRzANaN6YdGkMizCKN8PS/4F6183AyZvda2FCja+B0kdYfx9MHKaGWQJIeqQYEoIwDAMnljzBK9vfR13wB12Ha/uBWDpgaV8e+hbfnX6r7h60NUtWUwhhGDVnhJufmkVbm+AgBG6XDfA6Q2w97CTP3+wmY82HmTmTaNItFtavrCi9dq/BuZcDj4n6P4wKxjgq4byavj8D7DxHbjxTUhIbfGiCtGayZgpISBqIFWbgYE74OYf3/6D93a81wKlE0II07p9ZUx7cRXVnvCBVH0uX4CVu83gyx/Qm7+Aom049D28fCl4KiIEUvX4nHBgDbxyGfi9zV8+IdoQaZkS7d7X+78OG0iVLiml+NNivIVeLA4LaSPT6HJ1FyzJ5t1dd8DNIysf4eROJ5OdmX08ii6EaEdc3gDTXlyFyxcIWVa9eREV387Hd3g/mj0RW1Z/0sddi6NnDh6/zvp9ZTz5xQ5+cf6g41By0ar4vfDK5WarUy19n6zE6YM996WQbDe7hc5a62XuBh+LpidDwAMFm+Hz38OFjx6PkgvRKknLlGj3ZqyfERJIFX9czKG3DtH12q4MfWYo/f/QH+9hL7mP56L7j97d9ek+Xt70cksXWQjRDn2wPh9fmNalilXvUbLwedLHXEvPn8ylx12zSR1xEa4dK2vWcfl0XlqWi8cfGoiJdmbrB+AP3wsjYMC/VzbQ8uR3wdpX6o2tEqJ9k2BKtGt7yvewvXR7necCrgCF8wvpflN3UoenoqwKe2c7ve7uhbfYS/my8qPrGgE+zv2YKm9VSxddCNGOGIbBM4t21kk0AaB7qilb+iodzr2LpEHj0OwOlMVKUvZoMiffErKPTzYeaslii9Zo6RMQ4TvrwXF2Hl/moczdUB9SDTa81TxlE6INkmBKtGtvbnuTgF63cuLc4UT36aSNTKvzvMVhIXV4KlWb6n4JaWh8kvtJs5dVCNF+bcqvoLDSE/K858BWDL+XpIFjo+6j2htg1pI9zVE80VYU7zQfEYzqbmFSXyuPLwv9rNXwVcOK/zZD4YRom2TMlGjXdpXtqpP+HCBQFcCaYkVZQlMJW9OtuPa66jznCrjIL9kJa16GDfPAeRgwILED5FwBp0wBR1rIvoQQIlb7SpxoYeb6Cbgq0JLSUFpsmfr2lzqbumiiLSnNNdOb+10RV/nL5ATGv1jNfaPtkfdTcaDpyyZEGyXBlGjXqn2h/b4tKRb8VX6MgBESUPnL/VhTjv7ZZAYC3FNazhV7HwNLQsiAXg6uh88fgmFXwuTfmZMgCiFEnKq9AXQjtOuVJTEN3VmBoQdiCqg8fsno1655q4CG00AOy7JwyUArf1/qZUjnCB2Y/A20XAnRzkg3P3Hi0HVz9va5V8FTp8E/B8N/RsHbt5jzaYSRYk8JeS4pOwllVVSsqajzfMAdoHJDJclDkwHo4/PxzoGDXFFZhV33hwZSYKaT9btg/RswY7yZjlYIIeKUkmAJ2zKV0GMwymrDuX15TPtx2GSuqXYtIQWIPoHznyc5eH6tlwMVEQIvq6NpyyVEGyYtU6Lt03VY8Sx88wT4XKEDa0t2wbaPIa07nP0HyLm8ZtGQDkP49tC3+HRfzXOWJAtZl2eRPzcfzaGRMjQFX6mP/Dn52DrYyBiXQZbfz5z8AtJ0nZiqJkYAXKUw+yK4YxF0HNAUZy6EaCeys1Lw66GtSlpCMhln3kjJ5zNQmgVHv9NQmhV37jrceRtCklAM6JzcUkUWrVGngRCIPk9UdgeN63JsPLXKy8lZYe67d+jfDIUTom2SYEq0bX4PzJsKuUvMVqBwDN1cdngnzL8T8r+Dc/4ESnHtoGuZs2VOyCadL+qMJdnCoXmH8BZ60RI10kak0evHvdBsGv/KLySlXiAVdY4OMAO9166Fn6yGMHeZhRAinOysVPp3SmHzwYqQZWlnXImWnEn58nkUf/g4yp5IQpds0sZeV2e9ZLuF2ydIJbhdy+gN3U6FfSuirvrHiQnM2eALXWBPgXH3NkPhhGibJJgSbZeuw1vTYc/iiHNmhPC5YNVzYEuCSb8iY2s+gwusbOjoDen50GFiBzpM7BCyi2yvl0FeH7Ywuz8yR8dvJySEP76hQ8VB2LcSeo+JrcxCCAHcOWkAv3l3A9We0LmiUnImk5IzucHtbVaNswdnNVfxRFtx5v3wzm0hvThy70+t83uvdA337yMkTxr6w+YqnRBtjoyZEm3Xd3Ng96I6gVTfJyvJeqySau/Rft6z1nqZ9FKt8Uw+J8aSf1Fw73Xk//JX3NHvJhxx9P++qbwSa5iB4BDjHB0+Jyx7KubjCSEEwAU5XUm2WxvVqJ1os3DPpGysFvnab/dOOg+SOoKK/7PgUortQy4Aa4Qbhu1ZyW74+FfmeO1HusPfusHjA+GDB6B4x/EunWhGclUVbZNhwNJ/he3aF3UGdwCfm4weh+j/8UdMvPKn3DviXhyW6AGVZhhcXF0dsUk3pjk6MMxEGR6Z6FcIETu7VeP1O8aYAVUc2yXazBap2yb0a7ayiTZEs8C0BZCQRizJKGpYE/H3GcsDej4Pr3gYbwxjr9qFwq3w4gXwzFj49gWoPAjearN+UlUA370CM86E5882M/yKE44EU6Jt2rcKqorCLoqldUhpkODdghYwW6ymDp3KAyMfwGFxoDVwty5V11ENZ5XlL5MT+M8qL0XVDaQgttigurDhHQkhRD0DOqfw7t3jyEy2k2CN/hWeaLdwySnd+ff1p6JknKY4IrMv3LYQUrJiy8xnS4ZBF5B60/u8cembHHYdZurHU9lfub/Zi9qq5X4Ds86GvBVmLxk9zBgz3W8uO7DGDLp2fNHy5RTNSoIp0TatmR0x4URsrUOYEdWmd2t+vWHIDbx84cuc1+c87Jo9pKUq0ZpIhpaA0hoealh7jo4Gj+2LPGmiEEJEMrBLKgt/NpF7JmeTmWTDUS+osmqKBKvGmP4deObGETx61XDp3idCdcqGe1bBpN9AcmczsURtms3sztdrNFw1C66eDVY7qfZU/jXpX1za/1Ju/OhGvsr7KuIhdhRU8ut3NjDmkYUMe+hThv/pUyb840v++dk2CitiHOvcWh36Hl69xmyFijJ3Vw2fE968CfavbtaiiZYlCShE21SaS0MXr5hmcPe7oKzuXbWhHYfy2MTHKPeU89Gej8gtz6XSW0lGQgYDOwzk/G7jsP9zaNTi/XmSgxEzq/j52Aj9yvVAsIuFEELEr9rrp9rjxxuchNeqgV83AyndMDijTwfuP2cgp/fNlBYpEVlihpmQYty9sGsh5C6FqkIziErvCTlXhJ3KQynFTUNv4uTOJ/Pg4gdZW7iWe0fci00zUzOtzSvlofc3saOwEl9AJ1Cro0aF289zX+9m5te7GT+gI3+9fBg9M5Na6oybhmHAvJtC5peMKauvzwVv3AA/2wqa3OQ4EUgwJdqmKK06Mc3gDqFzUgWlJ6QzZfCU0AWGAY50cBY3ePyoc3QoDVK6NLgPIYQI57mvd/HPz7ajGwa+QN2bSn7d/H35rsOs27eKU3tlMGvaKJLs8nUvGqBpcNK55iMOp3Q+hTcveZPfLP0Nt356K4+e9Sjr9hjcP28dbl/kru6e4E2Ar7cXc/FTS3nt9tHkdE8/plNoUXkroDr8UIOoWX3BbM3atTDu11u0ThISi7bJkRF1lagzuKMguVN8x1UKxtwN1sSoq/5xYkKdrII1LHYYOR2sDbSaCSFEGI9+spUnPt+Bx6+HBFK1GYDTG2DN3lKueGYZTq+/5Qop2pUMRwb//cF/mdBjApe9/kvufWNtg4FUbQHDoNzlY8pzK9hXEmGuyNZo2VPgDV/emLL6eqvgmyebqXCipUkwJdqmPuOiDpqt3ToUlj3ZnLwwCsMwMGqnQh85DQj9osi9P5Vz+h+9+3tkjo6apv0jlAZn3BH1uEIIUds7a/Yz+5tcXL7QeaYi8fh1couruWvu2mYsmWjvNKXxo6G3UJl3PeHi9urNizj48v3k/etq9j89lYI3H8K9f1PN8iqPn5++/l0LlvgY+L2w4zMiDTWIedz2vpXgLm/68okWJ+3+om0aOR2W/DPqahFncAezT3iYJnbDMFi1p4Tnvt7N8t2HcfkCaChSHFYuPrkbt07ox4BTpsCGefEnkbA6zDk+MvvEt50Qol3TdYNHPtoSMZCq3ryIim/n4zu8H82eiC2rP+njrsXRMwePX2fVnhI2HihnWI821JVKtCmfbDyEQgPqfkYrVr1H+cq36XjePTj6jUBZrLj2rMG1YyWOnjkA6AZsPVjBzsJKsrNSw+y9FXGXgWY1s/RFENO4bYsdqovNoQOiTZNgSrRNqV1gwNmw/RNq3x2KeQZ3q8PsrqdZ6jz91bZCfv/eRkqdXlzeQM2eA5hdEd5cvY931u5ncNereSx1P9lFC9EsMWbxsSRAx5PgyueOPmcYwXFbymwpk4HiQogwFu8owu0PH0jFUln1+nVeWLqHJ66L3hovRGM8u2gX1d66n1HdU03Z0lfpeNH9JA0aV/N8UvZokrJH11nXrxu8uDSXR648uUXK22h+T9QJj2Mbt62BzNV1QpBgSrQqm/LL+WJzAQUVHnTDoHNqApMHZ3Far4zQjFSTf4OxcyFKb8TFyJYII2+u89SrK/by1/9tbrCvt1838OsGG/ZXcFngJmZn2hhjfB2S0SeEPRm6nwZT5pl3o7Z/ZvaXzlsOykJNQDjgBzD+XugzXgIrIUSNmYt3Ue0JDaZirawGDIOPvj/In36YQ3qirUXKLNqPA2Uu9hSHfg96DmzF8HtJGjg26j78usH8dQdafzDlSI8pCIqe1dcX0/hv0fpJMNUWGIY5N4HfY6bTtpxYb5svoPPhhnyeXbSLvBInXr9OMCEVCnhh6R6yUhP48cQBXHFaDxw2C4ZhcHjBSnwbe9D15HxUIErf5NrsyfCj9yG5Y81TH32fHzWQqs0AXBY7t7qnM/+SWzhpy3/M/s9o5uR8YHYjBOh2Coy/Hwaeb/azXvAxEW3ZAAAgAElEQVRTs3vgkUyCRq1j7vjMTE2b1MFsweozDiGE2LA//NiKeCqrdqvG9oJKTu/boamLJ9q5wgo3dqtWk6XviICrAi0pDVWvF0gkLl8AX0DH1prnRUtIhdTuUJ7X4GpRs/ompEpW3xPEiVUrP9EUboUVz8CGN807GMpi3g3J7GtWzodfYwYGbVil28f02d+y5WAFTm/oXdcjGalyDzv5ywebeWlZLq/ePArf43/HtXEjvWZ8hipbA2/fYs7d1FBQZUsyu/dNWwBdj975cnr9/PzNDWEDqYbGIRAs230rkvno3g+gbB/s+BScpWbJEzPNrohH5uhY9Tx89gdzfquIDLOVq7wa5lwJV86EoZfF8EoKIU5kkcZKxVVZNaDCFWEMqRDHwOsPfyPSkpiG7qzA0AMxfUYtSrX+YEopGH8ffP4H80Z3AyKO264ZatCKz1PETIKp1qhsH7w1DQo2QcAPxpFBjsH/S/fAp7+FT35tdgmb9Js22SXM7QtwzYzl7C6ujnghrs3lC7C7sIpLH/6QFyqKGTRnDpaUZOhyIdy7Dla/CKtmQsBnDgzVA2YrntIgsaM5MeHwa827QbXM/+5A2JcvlnEIBrC7qIqthyoY3LUXnH5b+MJvfj+GQKoevwve/bE5M720UAnRrlk1FTYVelyVVWW2TgnR1NISbRhhhg8n9BiMstpwbl9O8uAzo+5HNwwSbbG1Yh1Xp1wHn/0+5OmYx20bBoyY1lylEy1MgqnWpmAzzL4IPBVgNJD+9sgYnWX/gaJtcPWLIckUWrsH315PbphAqqHWIJ9uUGxJ5OGRU3ktpVarXGoXmPwbOOtB2L0Iyvaak+I50qDzEOh1RtiA0zAMZi7eHdIqFs+gWa9f54Ule3jsmlPCn6jPDfPvDgmkYpop3e+Cd26FBza3yYBZCNE00hNtFFeFjtOIp7Lq181xqEI0tX6dktHDRFNaQjIZZ95IyeczUJoFR7/TUJoVd+463HkbyJx8S531B3dNCx0f3RolpJo3aL/5d9TWqRC2JBh1c52hBqJtk2CqNanIh5cuBndp7Nv4nOY4m49/CRdHTxXeWhRUuPl0U0FIIBVLa5APjbV5ZeFTqFqscNI5MZcj97CTwsrQroHxjEPQDfjf9wcjB1Ob50fcNqaZ0t0VsOdr6D8xalmEECema0f1YtbSPSHXzHgqqx2T7Qzq0srTTos2yWGzcO2oXry6cm9IC2raGVeiJWdSvnwexR8+jrInktAlm7Sx19VZL9lu4c5JA1qy2Mdm4q+gZDds+SD2gMqWBP0nwbkPN2fJRAuTYKo1+fR3IRO4xdR64XPCuldhxI/MZAdtwNwVe6l/7yme1iCfrjdJCtXDVR6sFgX1ujQ3ZtBsQDewaGHuqC194miyiXoeHGfn0W883H26nQxHhLtx3irz7pcEU0K0W1PH9uGFpXvCLoulsppkt3DnxP5t466/aJNuHt+X11flEW4y25ScyaTkTG5we6UUF+R0babSNQOl4PIZkJJljonWA+b49rDrWs2bvcOvh4sfl7FSJxgJploLZwls/V/Yrn0xtV74vbD8v3XnMArHMMwkFhb7ces2ZhgGryzfG5L1J57WoIAO7313gD9eOhRHmP7Vht9PoKKCQHk5enk5gZpHRc3PekU5+VVW9KRR5utRS7yDZg0D/jB/I51TE+iQbK95ZFFKdsnukMDxiNozpT98tiPyAXYvMsfPnWCZHIUQsemWnsiY/h1ZurOIQJghptEqq4YBl5/Ws4HlBh6/ToJVk4BLNEqfjsmcn9OVzzYdwh3DOOjaEm0Wfn7ewLY3pk/T4LyHzalWVs6EdXPxGwqvP0CSzWrWs3SfGUSNuQs6DzreJRbNQGpmrcXaOREngYup9cIImEkOLnwUEuvNW+Aqhe9ehRXPQmX+0efTesCYe+DUKaHbNKMqj59qT+jM4fG2BuHzsfEv/6BrZRGB8jIzQCorJ1BRge5yYUlNRUtPw5KegSU9HUtamvl/Rjq2Ht2xDB1Kdy0VvvWAv+6dtHgHzVqUYmCXFEqcPnYUVlJS7aWk2ktaxXb+FbCQ0sC2Mc2UrlnMcXRJktJYiPbqsWuGc+GTSyhxesMO9o/EYdP4z5TTSEmo+5VfXOXhjVV5vLx8L8VVnpqbPj0zk7jjrH5cflrPkG3KnF7yy9y4fH5SEmz0yEwMWUe0X49fcwpTylxsOlAeV0B1fk4Xbh7frxlL1sw6DoCLHoVz/8yKLz9k5948po/tY2b17T2mzWdeFg2TK2BrsWFexExvMbdeaDbY/RXkXGH+7nPDRw/C92+agVr9Pr3l++DLv8DCP8EpU8xAzNpAhb6JVHsCWC0Kv163NhBva5CmwN8vm7Rup6GlBYOmDDNo0lJSUDE0ow/366i1n4O/bnAX76DZ03pnMD3cF8GhFJhthQYytsc0U7pSZhcCIUS7lZXq4K07x3LtzOWUOX0h19BwHDaNR644mXOGHp3Pptrj59fvbOCzzQUANb0Ejuwtr8TJIx9t5eH/beHG0X349QWD+HZvKc8t3s2yXYexWzWUCk6BGNC5cFhXbpvQn2E90pv8nEXbYrdqvHb7aO59/Tu+3l6Mxx8g0sfUZlFoSjGmfwe+P1BOudNHelIbn1DalsjmxBEUds+Bk4ce79KIFiLBVGvhKmlwcUytF7ofnIfNn90VZjKL4h1HJ5EN50iAtf4NOLQBfrQAEhpqRzk2Tq+fnYWV+PyhV9d4W4MMm42uPzybtE6Nv+Njt2rcOLo3L36z55gGzd4VadCsI8PsnhdF1JnSA74WbT0UQrRO/Tun8PF9Z/HQgo18saUQTREyR54WTIHev1MKD106lNH9j2YNK6n2cvWMZRwodYV0ta7tSIbTV1fsZe6KvTXHMQBvvX6GH6w/yKebChjaPZUXpp1ORlLz35QTrVeC1cLMqaNYv6+MWUt289nmAmwWDSPYnKqUwjAMrju9F9PH9aN3xyT++uFm7pizmlduPYMEa9vKTFxfYYVHsma2MxJMtRZGw83hMbVeYATHRPng1auhaKs5PioWfhcc2givXwdT3z/msTmGYXCgzMWWg5VsOVjB1kMVbDlYycFyFwM6J4MiZIxqvK1BAd0gK+3YL1hTx/bhpWW5oQUitkGzCTYLkwZlhV+Y3tNs5q9seH6pqDOldzsFLG38jp0Qokl0Tk3gmRtHUlLtZd63eby5ej+lTi8ubwBNKS4e3o3bJvRjcNe689u4fQGmPL+CfYed+GJo1QJi6qoVMAxcvgAb9pdz4b+XsOAnZ0plUnBKrwz+c8MIypxeVueWUu7yYbUoMpPsnNGvQ53xzr+7aAh3v7qWX769gSevO7VNj9srqvIwtHuYuaXECUuCqdbCkQ5VBQ2uErX1QrOaLSHfvQqHvq8TSMWUFTDggQNrzW6Bp94Qc9GdXj/bDlWy9VAwcDpYyZZDFSTZLQzplsbgrmmcn9OV+88ZSP9OyVgtGn/+YBNzVzQ+haqm4LycLiTZj/0j3DMziatH9uTdtftx+eIbNOuwafzxkqHhs/iB2T1v3E/hy782fqZ0ewqMvz+ucgkhTnwdku3cNSmbuyZlA/D55gJeX5XH4xGmaXhx6R72FleHDaQamt8vVr6AQVGlhxueX8EHPz0zbHIg0f5kJNnrdDMNR9MUT15/KlOeX8Hjn23jwfMHt1Dpml5RpYes1AaGZIgTjgRTrcXgi2H5XjOgiSBa64XhcVK+voT04qdRYSruMWUF9Dlh6ZNhg6narU1bD1aw5ZAZOOWXuxjQOSUYOKVyQU5XBndLo0Ny5K4e08f15bWVjU+hmmC1cPuE/g2uE4+/XDaM/aUuVu05HHNA5bBp3DMpm8tP69HwiqfeYI5LqyfmmdI1i/n5EEKIBnRNc3CwPHy37oBu8MLSPWFbmmKZ36+2hgIvv26wv9TF/O8OcP0ZvZv8HMWJy2GzMOtHo7jy2WX0zExiShv9/BRVSje/9kaCqdZA12HY1WZq8ygitV4YKPyZp+Jd/zWGYx8qzDsbU1ZAgPI8XHtXs03LDrY0mV30jrQ2De6axpBuR1ub+nVKxmaJL51pn47JjOqbyao9JSGtU9FYNEW/TskM79l0Y4gsmuKFaaP4/fyNvPfdAXTdiNgNxmHTzFTolwzlxtF9ou88MQMm/QYWPxr/TOnWRLjwMeniJ4SIqmu6g4KK8MHU19uLcPtCk9jEM78fxBZ4uXwBnl28i+tO79Wmu2uJltcxJYHZ00/n2pkr6JbuiNyFvhUrlGCq3ZFg6niqKoK1L8OKZ8BTZSaQqCfW1gtlS8J2+cNkbZ6PsSr84WLNChjwuXn1hSd5r9OPg4FTKufldGVw11Q6pjTdBeK/N4zgwn8vobDSQyDG/vuagnSHjdk3n95k5TjCatH4+1XDuWvSAGZ/k8ubq/ehKVUzHZeuGyTZrdw6oR/XjepFZgMtbyHG3w9leWaij5hnSk+ECb+AU66Lvq4Qov0yDNi3kk4rZvC6fzX6kxqaPQW65MCYO6HHSN79bj/V3tBgKp75/eIJvIoqPXy3r4wRvTOP/fxEu9K/cwozp47gjlfW8MqtZ5DTve1kifT4Azi9fjIS5QZoeyLB1PEQ8MNHv4D1rwGq4Wx7sbA6oM8487Hiv6gwXeeOiCUroAWDW4YncNvVE46tXFFkJNmZf894rn9uOfll7gYzSwHYLYrMZDtv/ngsXdKarz9yn47J/OmHOfz6wsFsyq+gwuXDoik6JNsZ2i0NLdL4qIYoBRf/C9J7weJ/YL7vEZJS2JLMytGFj8GIqcd0LkKIE5hhwLrXzGtKdTHK52SQMqAsuLxoC2z9ENJ70s9/NTAsZBfxzO8XT+Dl9gVYsr1IginRKCP7dOAvlw3j1pdW8+7d4+iekXi8ixST4iovnVISGldPEG2WBFMtze+BOVdC/hrz52NlTYSsIXDdHLPCHiUNd2xZAUGLNQvgMeqS5uCDn07ghSV7eGnZHrx+PeTuaXKCBU0ppo7pw+0T+sfXInQMHDYLI/s0YUVAKZjwMxg53ZykecXTZoukFvwz1H2Q3NlsxRp+LSSkNrg7IUQ7pgdgwb2w6d3Ird2Gbi4r3s7d/JN062T+6p8KHK3oxTO/XzyBl26YWc2EaKyLh3fjQJmTm2d/y1t3jSXN0fpbe2S8VPskwVRLMgx45zY4sCZyq0SsLHazcj7oArhiJliDf7wpnaNuGjUrIEBKw5l3mlJKgpX7zjmJn5ydzZdbC/lgfT5FVR4M3aBjip0LT+7GeUO7YrfGNy6r1UrqAGfeZ2b5K90DrlLzvUzsAJl9QcYYCCEaYhjw4QMNB1L1OPAwxfIVbuw85r++5vl45veLd2L1iFlOhYjR7RP6s6/Exd1z1zL75tPjHp/d0gor3HRuwuEQom2QYKol7V0GOxfWCaRiSll+hGY1u/QBjJgGo+8wK9+1+DufgabeQjManxUQewpk/6BRp3gsLJri3KFdODdKCtUThqZBxwiT/QohRCRb/wffv1UnkIrluyRJebjZ8glL9OGs0IcC8c3vF0/gZdWUpIcWx0wpxUOXDuXOuWv4zbvf89jVw1t1UpOiKmmZao9ad4h/oln2VNi7iEdSlkelNPjR+/DL3XDBIzWBlO50Ur5gAXm33MKunz0DMWT2/uPEBKq9EcZWWRPgpPOi70QIIUTLW/LPRn+XOPByp2VBnefSzriSzLNvpXz5PPb/50b2PzudyrUfknhS3bFRtQMv5/bl6D43RsCPa9dqSr96sc66Vk1xfk7XRp6gEEdZLRpPTTmN7QWV/HvhjuNdnAaZc0xJMNXeSMtUM9lRUEl+uRuXN0Caw8rAFCeddn9FuHmVYk5Zrlkgbzn0HIWh6zhXfUv5++9TuXAhiaedSsbVV9Pz7GfQvnncDNxqzVkV85xGlgQYc7d5LCGEEK1L0TYo3Bx2USzfJZqCMdoWulBCAR1qno9lfj+IfWL1gV1Tyc5KiePEhIgsyW5l1rRRXPmMOQfV1SN7Hu8ihVVU6WFwVxnv3BK8fp1PNh3i040HOVzlRdMUnVMTuPzUHpw1sHOLdjOWYKoJuX0BFqzPZ8aiXRwsd2O1KAzDHAJzUeAr/mzVCNfpIdaU5fhc6KvnUrwqQPmCBVhS00i//HKyfvYA1s61xkqNuctMuV5dRLjgLSKlmeN5Tr819m2EEEK0nO/mhJ1GA2L/LlEYXGFZyozADxtVhGiBV5Ldwp0TpQuzaFpZqQ5euvl0rn/OnINqfHYnAAzDYP3+cl5buZfcw07zJnaijZF9MrhxdJ9mzf5bX2GlhwkndWqx47VHJdVeZi7exasr8zAwqPbUTVr2xZYCHFYLt57Zj+nj+5Jkb/5QR4KpJrJy92Fue2U1Ad3AeSQbXa25dVMsFSjdVzuJUh2xpCwHCBzYhZHpodczz+AYPDj8Sskd4eaPYNY54KkwMzpFoyzgSDO3S5RUtkII0Sod3h0xmILYvksSlJ8+WgGETjt1zBKsGqP6dpAufqJZZGel8vQNI7jn1bW8ettoNuVX8PRXOzlU7sbjD1B7ysrVuSXMWLybcQM68sA5AzmlV0azl0+y+TWvXUVVXD9zBeUuH95A+LpttSdAtSfAvxfu4N3vDvD67WOa/T2RMVNN4KtthUybvYpKt/9oIBVW5Fai2inLG2Lt2pUuv/5V5EDqiE4nwY+/NsdV2ZMbWFGZyzv0hzuXmv8LIYRonbzVDS6O9bukk81LckLk7txKQaLNgs2isMXYXcZh08jpnsbMm0ZKJj/RbMb078jvLh7C5c98w+/mf8+e4mpcvrqBFIDHr+P16yzeVsR1zy3n3bX7m71s5pgpSbzSHA6UubjymWUUV3siBlK1efw6ucXVXPXsMircvqjrHwtpmTpGWw9VcPfctbh9Db+xZUYqfqwkNHArMJaU5SqxQ8RlITL7wE/Xwp7FsPRJyFtmjok6wu+BfhNg/H3Qd4Kk5BZCiNYuKXrPgVi+S84ZOZSns0cwc/EuvssrqzP1hNevM3lQZ24/awDZWcn84q0NfL29CCO4rL5Em4ZuwFUjevKnH+a0+vTVom0zDIMvthTgDxj460dQ4dYH3D6d3773PQlWjYuHd2+2chVVeegkqdGbnGEYTHtxFVVuP0a9t7x68yIqvp2P7/B+NHsitqz+pI+7FkfPHPy6waFyNw+8sY4Xpp/ebOWTYOoY/ePjrbh94QOk2m/wf+wJbOzq5A9n2Tizd/iXPWrKcmsiDLsqvgIqBf0nmY+qQijLM7v+JaRBRp+Y5qUSQgjRSvQcDds/bXB+qejTXySjeo5i8qAsJg/KorDCzf4yFy5vgJQEK306JpGRdLSb4PM/GsWhcjdzVuQyd0UeFW4fVk3h1w2yUhO4fUJ/rhnZi/Sk1j+pqmj7XluZx1dbi8IGUg1VrN0+nZ+/tZ5Te2fSIyPxmMvh9et8uukQM7/exa7C6ppuhuc9sZhp4/rK30QTWrO3lPwyF4F6kVTFqvcoX/k2Hc+7B0e/ESiLFdeeNbh2rMTRMwcAb0Bnyc5iDpa76JZ+7O97OBJMHYOCCjfLdh0O23kv3Bs8Ie+3zN+6M2IwBWbK8jkbIjRHGjqMmNr4AqdkmQ8hhBBt06lTYOFDUVdr8LsEYOjlNT9mpTnIijJIv2u6gwfPH8yD5w/G69dx+QIk2y1YpRVKtCDDMPjvVztxhbmJHUvFWtcNXlmWy28uGnJMZXhm0S5mLN6FrhtU1xvesa/UxT8/285jn27jkuHd+Ovlw1okCcKJbObXu0Pec91TTdnSV+l40f0kDRpX83xS9miSskfX3YEBc5bv5ZcXRBki00jy7h6DuSv2hn0+0hu8ccB0Xhj8GNCIlOVKg8EXm9n2hBBCtE+JGTDkMtj4DhhHKxcxf5doNjhtKtgaP67DbtXqdAsUoqWs2F1CmSv0JkGsFWtvwODVlXn87LyBJFjjnwImoBvc+8Z3fLmlMGxAd8SRZR9uOMj6/eW8+eOxdEhuOMGYCK/C7WPxtqKQ7n2eA1sx/F6SBo4Nv2Et3oDO3BXNF0y1yNWw3OXj5WV7+OXb6/nxnNX8/M11zFi8i6JKT/SNW7FF2wrxhOk/HukNXqEPYak+DJfRiD+ohDQ498+NLaoQQogTxVm/MCdXbwxrAoy9p2nLI0QLeXlZLq4wib7iqViDwVdbC+M+tmEY/O6971m4paDBQKpOuYJJEG54fkXEISGiYYfK3WFv3gRcFWhJaagY50Wt8vib7T1o1pap7QWVPLtoFx99fxBNqTofvgSrxhOfb2fCSZ24e3I2I3q3vXTcFa7w6Wkjv8GKn/ruZa79EYaxhyTVcLalI9uQkAI/mg8ZvY+5zEIIIdq4zoPg6hfhrZvB74p9O2si3DBPvktEm5V7uDrs0Ip4KtYev87+0jj+boJW7inh/XX5YROORUuCsKe4mpmLd3HfOQPjPm57V+3xh82PZklMQ3dWYOiBmN53m0Wj2uPHYYu/RTKaZmuZ+mD9AS57eikL1uXjCfavrs3j1/H4dRZuKeTG51fy3Ne7jvmYHn+AokoP5U4fegwZXo6V1RI++13tN7g+LzZu8P6OBYGxuA0bbiPS4EQFtiToOADuWAzdT2vCkgshhGjTBl0I180BW3L0ViprIthTYOq70PfMlimfEM0g0vQzDdW76vMHDJyeyHO1RTJz8a6wLRsVq96jZOHzpI+5lp4/mUuPu2aTOuIiXDtW1qzj8evMXpZLoAXqpieaVIcVvX4fPyChx2CU1YZz+/KY9uML6KQ4mqcNqVn2+uGGfB58e0PUdOFgpqx0+QI88fkOdMPgzonZcR3L6fWzYF0+MxbvIq/Eic2iYRgGBnB+Tldun9C/WSZqMwyDlITwgVDtNzh5cOgXlx8rv/b/mH/6r+NH9q+4w/4ZCYYHNCu624myWPBmn4flzHux9h4tKcuFEEKEOulcuHctrH4RVs7E7fVixY9V94HFDhYbWB0w5m4YOQ2SOx3vEgtxTFISwldbo9W7arPhI2XJX2DrdrOVtubRy8xynNEbkjvXqXsVVrj5JkzCsXiSIPgCOl9uLeTcoV3iO+l2rlt6Iv5AaDClJSSTceaNlHw+A6VZcPQ7DaVZceeuw523gczJt9RZPzPJ3qhxcrFo8mBqV1EVv3hrfdzNoC5fgCe/2MEpPTMZO6Bj1OMcyaby9Jc7Uero3YraY5g++v4gC7cU0iMzkRk3jSA7KzXS7mLi8gZYvruYRduK+GpbIRUuf0162NpifYOLyOBp/SquvfvfrNy8nXeWbWJXhZ9KLQ33lkQCm0qYPHgNd5zVn1F9MlESVAkhhKgttStM/i2c9UsefvxJ7jlVo5sjAAmp0Ckb+k+GGMcUCNHandwjnW2HKqhft46nYm1LcDDwukegY5U5XUxZHpTthfy1R3/3OiG9Z02g9WnZcDQjC6hbD4tnrFa1J8C8b/MkmIpTcoKVi0/uxvvr8kNSo6edcSVacibly+dR/OHjKHsiCV2ySRt7XZ31HFaNW87s12xlbPJgataSPfjCzEwcS8pKt0/nqS93RA2mDMPgwbc38L8NBxscBKgbZqvXrsIqLnv6G165dTQj+8Q3Niu3uJqvthXy1bYi1uSWkNMjncmDsnj+R6Po2zGJkQ8vxB+muTiWN1hTkN05hclPLAHA6Q2WTQd08zX8YksB3+wspmOynWdvGsmwHulxlV8IIcSJzxmAt6ty+OMPzgfJtCdOUNPH9+X99QcIhLlhH2vFOsluZdzgXmYlrMvQ8AfyVEH5PijbB2V7KTzgwx2muhlvEoSCiradeO14uXVCPz7aeJCAL7SFKiVnMik5kxvc3gCmnNF8Y0WbNJiq9viZ/90B6sdS8TSDrt1byv5SJz0zkyIe59FPtkYNpGozgGpvgGkvrmLBT8bTv3NKxHXdvgArdh9m0bYiFm0rpNobYPKgzlx/ei/+M+U00hPrdu274YxevLJ8b9isftHeYAXsLKoKu21N2Q2z1c3pdXHNjOXMmjaK8dnSVUMIIcRRWw5WclJWqqQsFye0Id3S6NcpmS0HK8Muj1bvSrRp3D6hH5oWpadPQgpkDTEfgO/wFsjdHbJavEkQwjU2iOhyuqczuGsa6/eVhU1A0hCHVeOSU7o3a2r6Jg2m/rfhYNjhPfE0g+qGwWsr8yLmgs8trubFb3LDBiANdSMEqPb6+d17G3n9jjF1tttX4uSrbYUs2lbEqj0lDO6ayuTBWTx9wwhyuqc12L3u/nMGsnBrIXsPV4cEkQ2xoINu4InjU+HyBbj9ldW8dedYcrpLC5UQQgjTpvxyhvUIM6+UECeYX5w3iHteWxvTuPz6bBaNa0f1inu7zCR72GEd8YzVAshIjJR0TDRk7+FqypxekuwWPH495H2IJMGqcVKXVP52xbBmLV+TBlPbCyrDZlqJpxnUFzDYfLAi4vLZy3LDZuqLpRuhYcDavFJ2F1VxoMxVa+yTj4kDs7jitB7869pTyEiKPXpNTrDyxh1juGbGcg6WufCGGSRXn8Oq4Q2YvfnqixYQOr0BfvX2Bj68d0LMZRRCCHFi23igvFmSLQnR2vxgSBfunDiAmYt3x9xDCSDRZmHOraPjquMdcUa/DtgsGv562QLjGavlsGmcPSQr7mO3d6tzS7jr1bX89OxsLhjWlSnPrSC/zB31vU+yWxjeM50Xpp3ebIknjmjSYCrcrNQQfzNomTP8/EtuX4C3Vu/DVy+Yiqcbodevc/4TXzOspzn26d/XnUZO97ToTb4NyEp18OFPz+R3723k002HUIqwd0yS7RZsFo2Te6SxKrc0pHUtloAQYGdhFTsKKjmpy7El1BBCCHFi2HigghtH9znexRCiRdz3g5NItFl48osdePwBGmqoSLBqJFg1Xrl1dKNvOJzaK4MuaQnkHnaGLIt1rJZh0KhWsRNFabWXN77NY963+yh1+gjoBskJFsYN6H7xen4AACAASURBVMRtE/qF7XE1/7sD/PXDzfzz2lOYNMgMRD/86QTeXbufGYt3cbjai8sbqNP1z2HT6NcpmTsnDuDik7thtTR/1+cmDabqjyc6It5m0A37y8n54yd0SXfQNc18ZKU5qPb4wuaaj6cboQF0SEngvbvHR103HqkOG09NOY3Sai/zVu9jzvK9HK724PMbOGwaA7uk8uOJA5g4sBOjH1kYEkjFlV5T15m1dA//uGp4k56DEEKItsUf0PHrOruLqxjUVW6wifZBKcWPJw5gfHYnZn69i882FaApVae1IjnBgk3TmD6uLzeO6UPn1CjzsUU53l2TBvCnBZvDtohEG6ulKThvaJdGtYq1dcVVHh56fxNfbCkIaWyo8pjTG32y8RC9OiTyp0tzGJfdCcMwePKLHbyzdj+v3T6mzrUt0W7hxjF9uGF0b1bvLWXxtiKKqzxUuf0s2VnErGmnk9M9jUSbpcWyYDdpMJWdlUKS3RLS1S+ulJUWxS3j+3HP2dkUVrg5VO7hUIWbggo3Gw+U4wvTjS7ebCrVjZisLVaZyXbunDiAOycOCLt8bV5p2Dso8QSEAR0+3XhIgikhhGhn3L4AH244yMzFu9hTXE1AN1DKrOw9v2Q3U87oTaeUxlcahWhLhvVI5z9TRlDm9PLhhoPkl7lYuqMYh93CrWf24weDs5qsZeKyU3vw7KJd7Ct1xT35bqLNwgPnDmyScrQlucXVXDNjOaVOb8RxTgHDwOULsL2gilte/pbfXzyUVXtKyCtx8t7d4yMGwUopTu/bgVN6ZvDppkM8s2gn5S4/Nzy/At0wA9jzc7pyx1n9Gd6zebtAN2kwdcnwbvz5g01hl8XaDKopxY2j+5DmsJHmsNWZG6pHRiLf5ZVSXS9Yi7cboRGmdaullDm9YZN0xB0QehsOCEvdpXy05yPyKvKo9FaS6chkcIfBnNvnXBxWR2OKLoQQ4jjRdYN/L9zB80t2o6DO96BhmP88/eVOnv5yJz8Y0oV/XHUyqQ4Z7C7ah4wkOzeNMbu5ZqXuYXdxNefndG3SYzhsFt64YyyX/GcJZU5fzEkQEm0WZk07vcFM0ieiokoPV89YxuFqL7FWu90+nT+8v5ERvTN5444xOGwN14lfW7mX//t4K7phUO0xr4lHGl0C1J5v1sHTN4xgcNfmSdLTpMFUqsPGpcO78+53B8JG7bHkgu/VIZFuGeEr++lJtrBjm+LtRpjiaPLptWKm6xAur2P8AWH45zcVb+KFjS+weN9ilFJ4AkfnNEiyJvHwioe54qQrmDp0Kj1SejTyLIQQQrQUf0Dn7lfXsmRHcYODro90H/9iSwEXP1XO23eNJStVbp6J9iUz2U5pXlmz7LtruoOP7pvA1Fmr2FfqDJt07QibxayvvnHHmHaZHOZX72ygzOkLW19tKNmaYcDm/Aq8Ab3BYOr/PtrCy8v34o5hvtmdhdVc+cwyZk8/ndH9G57LtjGafFTWHWf1r/kANcb+UhcjH/6cj78/GLJsVJ/MsDn6a3cjdG5fju5zYwT8uHatpvSrF+usa9UU5ww5frNPZyTZwubIrx0QxiLJHvoBe2njS0z/ZDoL9y7Eq3vrBFIATr8Tp9/JvG3zuPL9K1l2YFljTkEIIUQLMQyDX76zIWogVZvXr3OgzMn1z61o1m7tQrRG6Ym2iInMmkJWqoNP7p/ArB+NYuLAztitGqkOKykJVlIdVhw2jcFdU3nkipPplu6gtBnL0loVVrhZurM4bOtdxar3KFn4POljrqXnT+bS467ZpI64CNeOlXXWe3fN/oj7f/7r3bwSJZCqz+kNcMtL37K9IPwcZceiyZtoTuqSyv9dcTK/ee/7Rs0B4PbpuH06D7y5joIKN9PH96tZlmy3ckrPDFbuKQnZLtZuhBZNccuZ/UK2bynDeqSH7WYYz7gyi1I1WU2OmPX9LGaun4k74I5aBr/ux6/7ue+r+3jq7KcY2z36OC0hhBAtb/nuw3yy8VDYQKqhu7sB3bw5+d+vdkact1GIE1Fmkp0yZ/js0k1FKcW47E6My+5EYYWbnUVVVLr9JNkt9MhIrOnS1yk1gYcWbOLT+ztG7bJ2Ipm7ci/hmlViTbbm8gWY+fVupo3rG5JEorDSzeOfbWvUfLNOb4AH31rP+z+J3ostHs3S3+2KET3RDfjd/O/x+vUGU1ZG4vbp/P2TrXRJc3B+Tlc+2niQf3+xA00p7BYVdj6nWLoRDumWxoDj2G/VYbNw/Rm9eGX53pBkGrEGhHarxh1n9a/5fXn+8rCBVOmSUoo/LcZb6MXisJA2Mo0uV3fBkmz+QbsDbu7/6n7ev/x9uiY3bd9iIYQQx+65xbvDdiWKZSoNr19n7oq9PHDuQGwtkB5YiNYgI8lGmavlWoOyghmnw5k8KIs3u+3j2UW72lUCije/3Rc22Ikn2Vq5y8em/AqG9aibMv31lXlh149pvllg26FKdhZWkZ3VdLFAsw0eumpkTwZ1TeXpL3fyyaZDYdeJFkG6fTo/e3M9vTK3k2i38NuLhzBpYGd+8dZ6Pvr+IK44W74SbRb+elnzzoIci2lj+zF3RR7hBk/FEhD2zEys8+H677r/hgRSxR8XU/RxET1v60nK0BR8pT7y5+ST+3gu/X7XD81qfrH6dB9vbH2D+0fef+wnJoQQoskUVLhZvvtwyPPxTKUR0A0+31zARSd3a/byCtEaZCTZKatu3papePzhkqFc/NQSLj+tB/06JR/v4rSI0ggtg/EkW7NoiqLKusNV/AGd2ctyj2l6Ib9u8OLSPTxy5cmxnk5UzXqraliPdC4a3o0ke+hhYu0z6fEH+MHQLObfM57Jg7JQSvGPq4Zzet8OJNpiL36iTeOZG0dwcs/QScFaWu+OSdw4ujeJjWjyddg0/q/WB2BvxV62lmyts07AFaBwfiHdb+pO6vBUlFVh72yn19298BZ7KV9WXrOuT/cxb9s8fIHWc+ERQggBnwUnga8vnru71d4Ab67e1wylE6J1SnNYcfkCYcfYHw/dMxK5a9IAHlqw6bhmk25JkTId1k62Fo3LZ167Xly6h/fXHWDJjiIWrM/He4wtXn7d4KMweRmORbO3+89YtAunN3wE2eHcu0gaNA7N7kBZrCRlj64zNgjMTBxLdhTX6TNptWjMvvkMrhrRE3twZutIkhMsdEi2M/e20UwenBVxvZb2+4uHMnlw57gCKotS/POaUxjVt0PNc69vfZ1AvQ+lc4cT3aeTNrJuCkiLw0Lq8FSqNlXVeV43dBbtXxT/SQghhGg2xVWesGOP451Ko/7dXSFOZEop0hJtlLtaz03im8f341C5i483hu+pdaJJilC3jSfZmkUpku1W9h6u5vPNBcxYvIt/fb4dV5huz/FeE6uaODFPs+YIr3D7wmbNiCeCBNh6sJJKt6/OnBkWTfHwFSdz7w9OYu7Kvby8bC9ev45FUxgY+PwGJ/dM586JAzh7cBaWMCnVjydNUzw9ZQSPfbqNF77Zg1ZvVujakhMsWDUzW0xJvabTLYe34DfqfigCVQGsKVZUmKyK1nQrrr2uOs+5/C52l+2GPsd4UkIIIZpMuEnqIf6pNPwR9iPEiSojyczo11omsLZZNB6+/GTue+M7zhrYmZSE4zdFT0sY2CWVNXmlIc/Hk2wN4PeXDCEjyV7z+8ItBdw/bx2V7rr13rinF2rkeUXSrO9mWbUPu1XDXy+KjDeCtFk0ypy+sBMQZqU5+Nm5g7j37JM4WO6m3OUjwarRMSWBDsn2MHtrPTRN8asLB3P7Wf2Z920es5buodrjx6ppGIaBN6CT090MCM8ZksX+UhdXz1jG4K6pnB5snaryVYXs15JiwV/lxwgYIQGVv9yPNaXu225gUO4pRwghROuRmWTHZlEhQVW8cyumJ8nkvaJ9yUi0NXtGv3id0a8D4wZ04t9fbOd3Fw9t0WMXVLj5+PuDFFZ68Pp1MpPtjO7XgZF9MkOy5TVWabWXDzfk887aA+wu+n/27ju8qfNs/Pj3nCPJkrxtbPACG7PN3hBmdtIMSAhp9iABMprRvmnfpvmlSdOkfds0zWpCErJJAmQvaDZ7771sDAaD99Y+5/z+EDYYSbZkvMDP57p6tZWOjh4ZnaPnfsZ912CQJb/L/YJJtiZLcHG/zvUCKThRXshPJBTqPTHQzFlTtWgwpQeI/UKNIIP5dzYoMmlxVtJCbWQ7EBdu4p5JPZg1IZPjld6A0KjIxIebiD0lIEzvFM4/rxvE/R9u4sv7xtEl2ozVYPU5n7WHFckgUbmxkuiRJ/eIqQ6Vqm1VdJ5Wv86WhESkKbLlPqAgCIIQsjGZ8RhkGbdaf0AylNFdi1HmojasrSgIbSHWagqYBKEt/fHyPlzy72VcOyyVPl2iGn/BGdB1ndU5Jby+NKcukU1t4gZFgjCjQny4idkTM5k6NAWrKfSQwOlR+Xl3IZ9tPsqanBIm907kwQt7cl73eMb94xcKAywxbizZWphB4e5TslbXykqORjvD8kKyBON7dQr5szakRYOpGKvJ7wbAUCNIl0cjpgOMrMmyRHKMheQYS8BjJvdO5JbR3bjng43MnzmazJhMthVvQ9NP/p0Vq0LilETy5+Ujm+V62fyMcUZixtavxG0xWEiNTG2xzyUIgiCErn9KNKmxFvYX+q5ACLaUhqbD9OFn4zCjIDRdtLVlC/c2VaeIMB6+qBePfb6DhbPGILfQFhRV0/nT59v5cku+3xp1qu6tuWRz2fnrt7t5dWk2C2aNIaWB/mctXdfZcKiMzzYdZfGOY/TtEsXUoSk8N31QvRVkf7q8L3/4bFvINWfDDDJDu8UyMDXG5zmzUWH68DQ+WNv08kJhBoW7x/sGameiRYOpaIuRjE7h7Cuo/0MQ6prJnokRfpf4dVT3TurB9qMVPPn1Lm6d+Gu+zfnWJzV6wuUJKOEKxxccx1XoQrbIRA2NIm1WGvJpWRB1dC7sdmFrfgRBEAQhCPdMyuSxL3b4rTXV2OiuLMGlWV3EMj+hw2mNwr1NdcPIrny8IY9PNx3huhYY6NB1nQfmb+bn3QVBlRCyu1WOlTu48qUVLHpgPF2i/dfMOlhcw+ebj/LF5qOEGWSmDk1h0QPjA04AXD0khZziGl5fluM3oPMnzCDTLd7K67cMC3jMHeel89G6ppcX6hwdxuA030DtTLT4Drh7JmXy2Oc7qDnthyDYCDLcpDB7UmZLN/OsIssS/5o+mCn/WUmUORHVFQdKvs9xcRPjiJsY5+cMJxkkA1dlXoXF0PhohCAIgtC6fjUwiTlLs8kpqgmYbjgQq8nAby/uOIVCBaFWrNVIWTucmYITCdSmDOCOd9ZzkZ99QWdqztJsft5dGFItVlXXqbC7uWnuGn54eGLdjFntPqjPNh8lr9TGlYOS+c+NQ+mfEhXUXquHL+pFrNXI3xbvQdfBFSBdvQRYTAqDUqOZe9sIwhtI0NEtPpzrhqXy6aajQQdptcxGmb9NHdhs+8RqtXgwdVn/JB77Yoff54KJICVJ4tL+XVqiaWe1iDADN45M4y/f7MYQORlz8kIkOfRRGINs4JZ+t7RACwVBEIQzFWZQ+PDu0Vz50gqKq50BM/ydzmJUeOeOEXSL7xhFQgXhVNFWE/kVlW3djIAGpEZz+YAu/OO7vTwztfmKx7pVjVeWZAcMMmp2LaFy/Re4S44gmywYE7sTPXY65tQsVE3neIWDX/YW4lY1Ptt0lNU5JUzqncgD5/dkfM9OGJTQKyrdfl4Gl/TvwvurD/H+mkPounf2TAdkScKtaozv2YlZEzMZHmRCjCev7s/xSgcrD5QEHVB5A6kBjMmMD/kzNKbFgymzUeGpq/vz6OfbQ143aTbKPD2lP2GG5s26cS5Yvr+If3y3FwBP1QDc5TkYYzaEFFCZFTNPjn2SblEiJ7ogCEJ71SkijEUPjOf2t9exv7Aah1sl0CRVeJiC2aDw3oyRZCW3fZF6QWgLse10z9Spfndxby58binXDUtlSNfYZjnn9zsL0ALcHCrXfU7F2k+Iv/g+zBlDkRQD9oMbse9fizk1C/AW+Z71/kaGp8dyzdBU/nXaPqimSoq28PtL+/DQhb1Yn1tKcbU3q2C0xcjgrjEkRvpfWhiIIku8dstwnvpmFx+uO4zEyeQapws3KUiSxEs3DmFy75apN9sqie6vGZrKsQoHL/28P+iAymyUeejCXlw9JKWFW3f2cbhV7pm3qd7f0llwJSBRs+c7Sr4vwHXMhWyWMXc1k3BlAuG9Thmd1CXMhjAeG/0Yl3e/vPU/gCAIghCS2HATX9x3HpvzynljWQ4/7CpAAswm72Cjy6PRp0sk90zK5IK+nTE2YQRZEM4VMZb2u2eqVrTFyKOX9+GxL3bw1f3jmqUe6mvLsn221QBozhrKV3xA/OUPYe09tu5xa49RWHuMqnesLEn8+/rBJEU3//YPk0HmvB7Nk0lPkSWeuCqLeydl1tWbdasn6s3q3lm69E7h3Dspk0v7d2nRiZlWqxp23+QeJEWZeezLHUjg9x8bvBEkwF+n9GfqUJFhzp/FO46h+6SGlCn6WqViXTFJN/UhaoiKZIDqHRVUbaoivFc4uub959ZtvXjq4t9xaY+Rrd94QRAEoUkkSWJo11hevXkY/7NwC9FWIxN6JWI1KSRFm0mN9S2VIQgdUYzV2C5To59uyuAU5q/LY96aQ9w2Nv2Mz5ftJ/MngPPoHnSPC2uvMY2eI8woc6CwukWCqZZwar3Zw6W2k+WFIkyt9hlatQTzNcNSuXxgEt9uO8acpdnkltRgOjF65lI1MjqFM3tiJpcPSMLczAW1ziWvLvEdeTh11MEQOw5HfiGGmLVYehRj7e3EXWVBs6fhLh+BrEeybLuFS3u20QcQBEEQzkhOcQ2/v7QPo7s3//p/QTjbxYab2nyZX22dp425ZZTWuAgzynSJMnP5gCQSo7zL2iRJ4q9T+nP962u4bECXkJe7nc4RYKmbaq9EtkYFVdtV16Ha4TmjdrQFgyLTPSGibd67td/QbFS4dlgq1w5LpazGRbndO3IQYzHWK1Ar+JdXauNwqc3n8dNHHTRXIq7CK/2eQwW+2HKUZ65pvk2PgiAIQuvQNJ29x6vo28JFPwXhbBVjMbbZMr9qp4eP1+fx+vIcKuzuenscwwwyzyzew3mZ8cyamMno7vH07BzJ9OFpPPPtbp7/9ZAzem+TImPXfFd+KZYoNFsluqY2GlBJkjeznhC8Vg+mThUbbhIBVIiKqp0YFdln71koow7grSvgUbUmZWYRBEEQ2k5emY1oi1HUjxKEAKwmBVXTcbjVVl3pdKikhumvrabS7vabmrw2ScKSvUWsySll2rBUnrwqiwcu6MFFzy1jVXYxYzND31OUV2pjVXYxgbZdhaX0QTIYse1bTXifcQ2ey6PqpMaeHUv82os2DaaE0LkCTOGGMuoAtekodUSiREEQhLPL7mNV9EkSs1KCEIgkSURbjVTY3a0WTOWV2rjq5ZVUOdwBs23W0vEOan+y8Qg1Lg//um4Qj1/Zj//3xQ4WPzgBk6Hhge6SaiersktYlV3MygMl2FwexmR24vy+ifywq8BnwF0OCydm3E2U/jAHSVYwZwxBkg04crfgOLyN2Ml31h2bFmehR2JkU/8MHZIIps4yUWYjPrknCG3UAQDdmzFREARBOLvsPlZJ3yTR2RGEhtQW7u0cdWb7kILhUTVueGON30CqodpOdrfK4u3HGZASze1j01mwPo+5K3K4d1KPeueodnpYd7CElQdKWHmgmKNldkZmxDG2RyduG5tO786RSJJElcPN9zsL/LYxauQ1yOGxVKxeQPE3zyKZLIR17kHUmOvrjgk3KdwzKbPZ/z7nOhFMnWW6J4T7yeQX2qgDQN+kyGavAC0IgiC0vN3HKrlqcHJbN0MQ2rUYa+ulR/9pTyFlNpdPIBVMbSe7W+XFn/Zz65h0nrgyi6v/s4JLs7pQWOVk1YFiVmaXsPtYJQNTozkvsxNPTx3AoNRov9s0Is1Gpg5J4fONh3Fqvn28iKzJRGRNDvg5jIrMZf2TzuyP0QGJYOosYzYqTB+Rxrw1h3Cr9a/aYEYdoHbkof6ohyAIgtA+qZrOweJqymxuZEli+9EKHrmkd1s3SxDatRiLkbIaZ6u815wl2dQ4A2dZbqy2k8uj8c7Kg7g1nSiLkYv+vYz+yVEMT49jUGoM3TuFU+30cLjUhnKwhLRYS11GwPpvqvGE9RM2y93IoQtBlnYFwGJUePfOkSKbdhOIYOosdPvYdD5cexjvqtv6Ght1AJBliYuzOrdQ6wRBEITmUFzt5KN1h3lrxUGcnpPFKKudHm57ax0zJ3Tn2mGpRJpFIgpBALC5PHyx+Shzlx8kt6SG73cVYFK2khRj5q5xGUwdmkpEWPN2fQ+V1LD7WKXP46HUdqpxqfzz+71cPzyN31/Sm6e/3U2k2cgHaw8hSxK2U8rhhBlknv9xP+dlxnPv5B6MSI/zPuGqgc9mYraXseDh+7j5w30cKKz22T91OlnyBlJzbxvBoLSY0D68AIhg6qzULT6cS/t34budxxu9SE5nMSr8/pLeGEUWP0EQhHZJ13X+/cN+5izLRuJkBrBT5Vc4+L/v9vL3xXt48uosrh/RtfUbKgjthEfV+PviPXyw9jCSRL3gw6VqHCqx8bfFe3h60W6uG5bG/7uiX6NJHoK1r6Dam2XZc2ZZls1GhSev7s8PuwoornGRX1Hs97jTMwLeMymT3wwPR5r/a0jMgmlvE2Mw8cnsTry54iBvrTiIw6361Ce1GGU0HS7N6sJDF/Uio1N4Ez69ACKYOmv9c9ogjpbZ2ZFfEXRAJUswfXgqt4xJb9nGCYIgCE2i6zq//2Qb32w7FjB7ay37ic7RE1/tpKjKyf3ni0rsQsfjcKvc+uY6th8t95uOvFZtgPXxxjy2Hy3ng7tGE94Ms1TVTjean5VCoWZZdrhVftlTyG8+2tTotQ8nMwK++st+1JXf8/Dkq2Hcb72FovAGZ/dN7sHsiZks21fEgvV5HK904FY1oi1GJvdOZPrwNFFioRmIYOosZTLIfHD3KB6av4Ule4tweTRUf2n+AJMiIUkSSdEtn9FGEARBaLoXftrPN9uOYXf7Ft4MxO7WePmXA6TEWJg6NLUFWycI7Yuq6cx8fyNbj5T7ncH1x+HW2H2sihnvrmfejFFnVG9T03SqHR5U1bf/FWqWZaMsc+8Hm/wOkDeYEdCj87p+IYMTRzLZT2IxRZaY3CeRyX0Sm/YhhUaJYOosFmZQePXmYew4WsHc5Tks3nEckyKj6TpIIOG9qG4a1ZVbxnTzZnn5z0o+WneYG0aKJSGCIAjtSVmNi1eXZPvtFDbUmQJvB/GJr3dxxaBksYxb6DC+2ZbPhtzSkK8Zp0dja14Fn246EvQSWbeqsb+gmp35FezMr2RXfiW7j1ViNsq4Vd/3DzXLskGRcHh8B1GCygioSjz/wz4m9xYBU1sQwdQ5oH9KNM//eghP2txsyiuj0u5GkSXirCaGpccSdkpl3rm3DWf6a6vJ6BTO6O7xbdhqQRAE4VQLNuThr2JFMJ0p8O4b+Wl3AZeK1MZCB/Hqkux6+6NqBZuS/NUl2UwfnuZTKqbG6WHP8Up25ley82glO49VcKCwmpQYC1nJ0WQlR3Fh3870S44i1mpk0rNLOFRi82lHsFmWrSaFGpfqk6U5lIyAe45XcaCwmh6JEaH9EYUzJoKpc0i01djoqET3hAj+ff1g7v9wM5/fO5a0OGsrtU4QBEEIRNN03lx+0GeJTyidqRqXyitLskUwJXQIu/Ir/QYwoVwzhVVOlu0vRpbwBk75lezMryC/3E7PxEiykqPonxLF9BFp9E2KxGry322ePTGTp77Z5TewCybLskfVMcjgOu3xUDICqprO+6tzefLq/o0eKzQvEUx1QON7JnD/5EzuencDn947tn6aUHs5HFnv/W9ZBmsn6DoaDGFt12BBEIRz3JEyO9VOj8/joXSmALYfrUDVdBRZFGUXzm1fb83H6WdZXCjXjM2lctc76xnSLZas5Cgm9krg3kmZ9EiMCGm57NWDk3lm0e6Q2l/LbJTp2yWSLXkVPs+FkhHQo+nsPl7VpDYIZ0YEUx3UbWPT2VtQxUPzN/PaLcNRCrbBqpdh91egmEDXOLHxymvYHTByJsSktWWzBUEQzkkVdjcGPwFQqOmVjbJMlcNNjNXU3E0UhHblWIUdzU/erVCvmYm9E5h724gzaovVZOCdO0Zw89y1DWYUPJ1JkejdOZKBqTFs9hNMhZoRsMrhDqndQvMQu1Q7KEmSePKq/tTY7ex/+Vp46xLY8Sl4HOCsBFc1uKrAeeI/a+fAy8Ng2bMQIGugIAiC0DSy7K8Me/3OVDB0dGQxKyV0AKfvL6oV6jXjCXCeUA3rFscbt47AalII5hI0G2X6p0Qz765RxFpN+HvJqRkBg9HcBYmF4IhgqgMz4eE95WnSS5eD2w56Azce1QUeJyz/F/z3j63XSEEQhA4gPjzMb0awUDtTmgYRAfZ1CMK5pFOE/9nXUK+ZuADnaYpxPTvx9W/G8asBSYQZZMxG3252uEkhMTKM313UmwWzxhBpNpKREI7V5DvzdGpGQNu+1WhuB7rqwZ69gbJf3qp3rCJD7y6RzfZZhOCJO25H9tX9GAu2YvTZ8tgAtw02vQOdesCIu1qsaYIgCB1Jl2gzabEWDhTV1Hs8lPTKEjCpd4KYmRI6hHE9E/hk4xFqTkv6EMo1Ex6mMKmZ04lnJkTw0o1DKbe5WLghj2X7iim3uwhTFJJizFw/Io3zMjvVu04vyerCHz/b7vd8wWYENCoyt45Jb9bPIgRHBFMdVWkO7PrSu6zvFOnPV2Fzw8EHIwg3eS/0uZtczNvmZsnt4d6D3Hb48S8w9DZQROVsQRCE5jB7Ug/+/OUOn85hsJ0pi0lh5oTurdlkuhL6AAAAIABJREFUQWgz5/dJxGSQfa4XCP6akSWJS7O6tEj7YqwmZk7IZOaEzEaPNRsVpg9P44M1h3D72QgWTEbAHgkR9OosZqbaggimOqq1r0GA9cSqDi+sdfHo+AYy+Okq7PkGsqa2UAMFQRA6lisGJvHnL3f4fS6YzlSs1cTIjLiWaJogtDuKLHH72AxeWXLAb9Hexq4ZoyJx86humAztY8fLnedlsGB9Hu4g93qdymJUeOCCni3QKiEY7eMbJLQutwM2vw+a/6wvj4w18ewqJ+WOBjZluqphxfMt1EBBEISOx2xUePa6QX73WTT+WpkXbxjiU3xUEM5lt43tRpTF6LfYdUMkIDLMyJ3jMlqkXU3RNd7Kc9NDv/4tRoVbRnfj4haaYRMaJ4KpjqhkP/jNG+M1PFlhUrqBZ1c5Gz7P8e0is58gCEIzumxAEn+6vG9IHSqLUeE/Nw5lWLfYFmyZILQ/MVYT82eOJjLM0ECvpj5JgvAwAx/NHE1CZPuqoXnZgCT+OW0QFqMcVEZAi1Hh9rHd+OPlfVq+cUJAIpjqiBwVIDX8T/+XyWG8tM5FUU0D9RIkvPunBEEQhGZzy5h0Xr5hKAmRYYSH+a8tI0lgNSl0i7fywd2juKBv51ZupSC0D5kJETx5dRbyiWuiIeEmhS5RZr7+zbh2m/nuykHJfH7feXUZAS2nDazIEoQZZEZlxPH6rcP4w2V9xYx0GxN7pjoipfGRmP6JClf0MvD3FS76JgQIvDTNW+BXEARBaFYX9uvM+X0SWXGgmDlLs1mVXVI3Um2QZc7vk8DMiZkMSYsRHSmhQ6uwufnnf/cy9/YRSMBrS7PZdLi83l4ol0djYGo0sydmMql3Iko7z3jZp0tUXUbATzYeYduRCsrtLjyqzr6CKj6/9zzS4qxt3UzhBBFMdUQRid66UY14cpKZoa9V87sxAYIvkxUU8RUSBEFoCbIsMaFXAoNSYxj795/Y8NiFyLJEmKHh0XdB6Ch0XedPX2zn4qwuTD6R4nxS70SOVzg4VFJDlcNDhNlA1zgryTGWNm5t6GKsJu4afzJDp92lMuSp7+kcZW7DVgmnEz3hjii2G8RmQNHuBg/rESdzfZaRF9e5GJB42uyUbIAB01uwkYIgCAJAdnE1mYkRWEQxXkGo58st+ew5XsU3vxlU7/Eu0Wa6RJ97AYfFpJAaayW7qJq+SVFt3RzhBLFnqqMa9zCYIho97PGJYdS4/CSZkA0w+t4WaJggCIJwquzCajITGr9fC0JHcqTMxlPf7OL56wdjNnac2dqs5Ch25le2dTOEU4hhro6q39Ww6H98Hs59qP6GzLRoGcdjp41+SDJ0zoKEXi3ZQkEQBAHILqohMyG8rZshCO2Gqun8buFWZozPoH9KdFs3p1X1S4piV34lDGvrlgi1xMxUR2U0w3XvgCH0aXC3Eg7Xvtn8bRIEQRB8ZBdV013MTAlCnTeW56DrMGtCZls3pdX1S45i17GKtm6GcAoRTHVkPS6AKXPAEOSmTElGNUVzO4/z5WGRxU8QBKE1ZBeJZX6CUGtnfgWvL8vhX9MHtfusfC2hdmZKF3U+2w2xzK+j6z8VopK8S/6KD3iz/Olq/WNqU6mnn4dyxfM87ozj5jfXAnD14JRWbrAgCELH4VY1jpTZ6RYv0iALgsOt8tD8LTz2q74dNjV4fEQYVpOBI2X2Dvs3aG9EMCVA19EwewUU7IRVL0P2T+CsAlkBcwwMvB5GzICoZAB6A/NmjOIWEVAJgiA0uwOFVSzfX0y5zUW53U1EmEJJjYuUszC1syA0p//77x56dYlk6pCO3e/wLvWrFMFUOyGCKeGkzlkw9dWgDu3dJZJ5d43i5rkioBIEQThTHlXj+10FzFmazb6CKnQdnB4NAFmC859dwvD0WGZNyGR8z06iUK/Q4SzfX8R/dxxn8YPjO/z3v1+SN6PfJVld2ropAiKYEs5Ar84nAypdhymnjhQ5q2DrAlj3GlQdB9UJRit07g/nPQiZF4AstuwJgiBUOtzc9tY69h6vwuZSfZ7XTgRWKw+UsPlwORN6JvDiDUMwGcQ9VOgYympcPPLxNp69bhAxVrFnOys5ik83HW3rZggniGBKOCOnBlQAU7Ji4bs/egMpSQK37eTBHifkLof8zWC0wOQ/wfA72qjlgiAIba/G6WHqf1aSV2rHpWqNHm9zqSzZV8jtb6/jvTtHYlBEQCWc23Rd59HPt/OrgUmM69mprZvTLvRLjuKv3+5u62YIJ4i7sHDGagOq/3y7jsqXxsOWj8Bjrx9IncpVDTVF8N2j8M1vQWSkEQShg7r3g00cKQsukKrlcGtsPlwmOlNCh/DppqPkFNXwyCW927op7UZarJVKu5tym6utmyIgZqaEZtIrzsDX0f+HXJILeIJ7kdsGWz8CcxRc+ETLNU4QBKEdOlBYxdqckrq9Uaeq2bWEyvVf4C45gmyyYEzsTvTY6ZhTswCwuzU+WneYhy/qRbTF2NpNF4RWkVdq45lFu5k3YxRmo9LWzWk3ZFmi74kU6WN7iNm6tiZmpoTmsfxfmCsOYjolkEp/vorEf1ZR4zo58zR3k4tJ79ScfJ3bBmvneJf+CYIgdCBvrcjFo/nOzFeu+5zSn94gevR0Uu+fR8o9bxM59HLs+9fWO06WJD7deKS1misIrUrVdB5esIXZE7vTLzmqrZvT7tRm9BPangimhDOnumHd6+Bx+D6lwwtrG5mG9ji9KdkFQRA6CIdb5fPNR32CKc1ZQ/mKD4i76B6svccim8xIigFrj1HETr6z3rF2t8rry3Nas9mC0GrmLM3GqMjcNa57WzelXaot3iu0PRFMCWdu7yLQfDNQATwy1sSzq5yUOxrYF6VrsOcbsJe1UAMFQRDal6PldmQ/2Z2dR/ege1xYe40J6jyFlQ7cIey3EoSzwfYjFby14iD/mj4I2d+FItAv2ZseXWh7IpgSztzGd71JJfwYnqwwKd3As6ucDZ9DUmDPohZonCAIQvtTaXf77SSq9kpkaxSSHNz+EKMiU+UIcp+qIJwF7C6VBxds5vEr+5EsClUH1LNzBIdKa3C4/Q9mC61HBFPCmavMb/Dpv0wO46V1LopqGhg99TihuqCZGyYIgtA+WU0Gv4lMFUsUmq0SPcBs/+k8qo5FbMwXziF/W7ybASnRXD04pfGDO7Awg0J6fDj7C/wPZgutRwRTwpnT3A0+3T9R4YpeBv6+ooG9U7rq3XslCILQAXSOCsPlJ4tfWEofJIMR277VQZ0nzChjNoqfcuHc8MveQn7aXchfru7f1k05K3iX+lW0dTM6PHEHFs5cWHSjhzw5ycwbm1wcrQywd0oxgbnx8wiCIJwLYqwmRnWP83lcDgsnZtxNlP4wB9u+1WhuB7rqwZ69gbJf3qp3rFGRmD48DUkSe0qE9k/VdHYcrWDpviJ+2VvI5sNlOD0nZ2BLqp3876fb+Od1A0W6/yBlJUeLjH7tgKgzJZy5HudDwQ5QA++L6hEnc32WkRfXuRiQ6CeGl2VIG9mCjRQEQWhfZk3IZNOhMmpc9Zf0RY28Bjk8lorVCyj+5lkkk4Wwzj2IGnN9veNkSeKO89JbscWCELqSaicfrT/MWytycXpU5BPBv66DDtwwMo3bxnTjqW92c/XgFMZmirpJweqXFMXi7cfauhkdngimhDM3fAasfLHRwx6fGMb72wIs5YtOg5ShzdwwQRCE9mtsZjyx4SZsbrvP/qmIrMlEZE0O+FqDLDGkawzd4sNbuJVCR6Wf+FKeyczn2ysP8vfFe5AkcLj975t+d1Uu76zMJcJs4IVfD27ye3VE/ZKi2HO8Ck3TRdbDNiSCKeHMRSVBxng48BPecSav3Ici6x2WFi3jeMxP4T1jOIz7bQs3UhCEjkzVdHKKqim3u5EliLWayOgU3qZL5GRZ4p07RnL1f1ZQ4ww+I5csQWy4iZdvFANQQvM6UFjFmyty+XZbPtVODzpgMSqMzohn5sTujMqIC/qa+cd/9/D2ylycfvYGnsqtevsNdpfKwwu28spNQ0VgEKRoq5EYq5FDpTYyOomBlbYigimheVz4BBxaCW576K8Nj4esqc3dIkEQBIqqnHy49jDvrDqIy6PVddI8qk5suJGZ47tz7bBUIs1ts0ejR2IE8+8ew01vrqHG4UFtoCQfePdJdYoIY+GsMXSKCGudRgrnvAOFVfx24Vb2Hq/Co2mcWrrM5lL5ZW8haw6WEGM18vTUAUzundjg+eavO8zbK3Oxh5C22+nRWLqviGcW7eaxK/o19aN0OLXFe0Uw1XZEMCU0jy4DYNrb8MkdwQdUsgEMZvC4oDQHOoubpyB0VMdrjrOlcAuVrkoMsoF4czyjkkZhNpibdD5d13nuh328viwHwO/ouL1c5R/f7eXv/93Dk1f15/oRaWf0GZpqQGo0ix+cwHPf7+WLLflIgEerH1WFmxR04PrhaTx4YU9irKY2aatw7tmQW8ptb63D5lIJFMvreIMqm0vlnnkbeexX/bh5dDe/x7o8Gk8v2h0wkKrZtYTK9V/gLjmCbLJgTOxO9NjpmFOzsLtV3ltziJkTu5MY2bRrv6PplxzFrmMV/GpgUls3pcMSwZTQfHpfBjcsgPk3gq6B2xb4WFMERCXDbV9D7gp47yqY/h50G9t67RUEoU1pusaaY2t4e8fbbCrYhFE24tE9yJKMLMlousY1Pa/hpj43kRYVfKCj6zr/8/FWFm0/3ugSI9uJ5A9PfLWD4mon903ucUafqalSYiz837UDWbqviOtHpLH+YOmJJYkS8eEmrhmayq8GJmEWNaWEZrS/oIrb3lrnkwSlIQ63xl+/3UVcuJHLByT7PP/9ruNomv+wrHLd51Ss/YT4i+/DnDEUSTFgP7gR+/61mFOzAJCAj9Ye5sELezXpM3U0WcnRfLD2UFs3o0MTwZTQvLpPhN/tgW0LYeXzYCsBSfam7ZFkb8a/1JFw3kOQeb43i9+AaWCNgwW3wFUvQp9ftfWnEAShhVW5qrjnx3vYX7Yfm8c78OLSfGvRLdizgE/2fcLsQbOZ0X9GUPs1/v3DfhZtPx7SEiO7W+Oln/eTEmNhypC2KRa6/EAxqbFWHrmkT5u8v9DxPDh/S92AwukamkFyuDV+t3Abk3onYjXV70q+uiTbb3CmOWsoX/EB8Zc/hLX3yYFTa49RWHuMqvv/To/G26tyuW9yDwyKqODTmH7J3mV+QtsRwZTQ/MIiYcQMGH4nHNsKlUe9S//CoiCxL8T4GWHOPB9u/gQ+/DXUFMGw21u92YIgtI5qVzU3fHsDx6qP+Q2gTuXRPXhUD69te41KZyW/Hd5wsprSGhdzlmX7LYjbUOcQvCPuf/5qJ1cMTGqTTtwnG48wbVhqq7+v0DHtPlbJweIav0v7gppBkuDLLfncMLJr3es0TQ9Y98h5dA+6x4W115hG2+byaBwutdE9IaJJn60jSY424/RoFFU5iQ83ieQdbUAEU0LLkSRIHuz9TzCSh8Adi2DeNVBVABN/7z1HA3RdFwUrBaGN1Tg9fLHlKF9vzae0xhscxYWbuHpwClcPTq43cq3rOvf/fH9QgdSpHB4HH+35iO7R3ZnSc0rA4+avO4x0ep5xguscAng0jR93F3Jp/y5Bt605VNjcLNtXxDNTBrTq+wod15srDuJSfQcdgp1BsrlU5izJ5tcjThaOrnJ4MMhSXYa+U6n2SmRrFJLc+FJVRZaosAcopSLUKatxMX/9YewuldF/+wlV0zEpMt3ircyamMkVYmlwqxDBlNC+xGfCnd/DB9dCdQFc/k845cZbYXOzcEMeb686SFGVE7eqYzLIdIsTNw5BaG2FVQ7+/cN+vth8FEnCZ7nQtiMV/OXrXUwdksLDF/UiITKM7cXb2VWyy28gVba8jOLvinEVulDMClHDoug8rTNKuPeadqgOnt/0PFf1uApZ8p05yi2u4bkf9vkkbwi2cwhQ41SZszS71YOpr7flM6FXAtHWtskqKHQ83247hupnb1MoM0gFFTZ2fraIpMoiPAUFVBYWoxkneJf1n0axRKHZKtE1tdGAStfBKJb4BVTj9PCnz7ezeMdxJIl6QbFL1dhfWM2fv9zB41/u4K5xGTx0YS8xY9WCxDdVaH8iO8Pti6BkP3x8O7gd2F0qv1u4hZHP/Mi/fthLfrmjbuTL5Tl54xj61A/86/u9ATe/CoLQPPYXVHHp88v5eEMedrfqd9+FzaVid6ss3JDHZS8s40BhNe/sfAenx+lzbPHiYo5/fJwu07vQ75V+dP9/3XGVuMh9NhftlCV7do+dVfmrfF6//UgFV7y03CeQgtA6hwBbj5S3+j3kk41HmDZULPETWodb1XB4/O+VCmkGye0ib8U63MfyMSR0ImHSRGTZf9cyLKUPksGIbd/qoNoXHyEyVvpTUu3kypdXsHiHN8FOoGLINSeyL76x/CAz39+A288spNA8RDAltE/mKLjpE5AVyt+9katfWso3244FdeOYu/wgM95dL24cgtBC8kptTJuzmrIal9/g5XQeTaek2sW1r67k55zNaNS/NlW7SuEXhSTfnEzkwEgkg4QpwUTavWm4il1UrKqoO9bmsfHWjrfqvf5QSQ03zl1DdYDCt6F0DgEMskSVwxPUsc3hQGE1R8vtjO/ZqdXeU+jYVE0n0DzFqTNIjZGtVmJ/8wBdHn2U+BkziLnqCi7un4S/SRA5LJyYcTdR+sMcbPtWo7kd6KoHe/YGyn6pf02nx4eTFG1pwic7t9ldKjfOXUteia3RTKV1r3GrrDhQzP98vBXdzxJo4cyJZX5C+2UIw3HVG9z0j085WFOFO8ivq92tsjqnhN8u2MKLNwwJak+VR9X4aU8hc5fnkF1Ug8OtEmZQSIuzMGNcBpf1T8JkEGMPggBw93sbqHZ4fDauN5TgQce7n0LOuwlL+gv1Xmfbb0Nza0QNi6r3uGJWiBwYSfXOamInxNY9vqtkV73jHvhoMzVO/8FPza4lVKz8CK2mnLyXb8F0WtIJf3QdFKX1lsR8uukI1wxJEZnLhFZjNirIkoTmp3N96gxSeJ9xDZ5H03WiLfWXpt49oTs/7yn0m00zauQ1yOGxVKxeQPE3zyKZLIR17kHUmOvrjgk3KdwzKbOJn+zc9vqyHA4V1+D2M4jVWPbFH3YVsOJAMeN7JrRBy89tIpgS2rW3Vx8i2xWNm9Aycznc3uBoyd4iJvcJXKld13XeWJ7DK79k41a1eulcbS6VMpuLRz/fzp8+38Ed56Xz0IW9UMS6Y6ED25pXzqESG+ppnbBgEjxoOmiOeFRHEor5WN1r1WoVQ4QByU8AY4g2YD9UvxC4w+Oo+98HCqvYW1CFvwmy2jbFnj+Dkv++TOz5dyGbzD5JJ06n4+3QtRi33VuLz2hF1eHzTUd5b8bIlns/QfBjcFoMGw6V+Tx+6gySJCuYM4YgyQYcuVtwHN5G7OQ7646VkMg8LePeoNRokqLN5BTX+H3fiKzJRGRNDtguSZK4bEDr7lk8G6iazturDuLwMyMVzP23NmGICKaanwimhHZL03Tmrjjod1lf0DeOpdkBgymPqvGb+ZtZsqeowXo0NSeWDs1dnsOmQ2W8efsIkeRC6LDmLs/Bedpei1ASPKAruErGYUn5uO4hJULBU+1BV3WfgMpT4cEQUf+nSkIhr9RGSoyFt1bk+l3Se2qbwvuMQ6spp+ynN4i/5D6ix92Irnr8dw4lOL9PYvNmCdV1OLwGVr0IB34ATfNWJkWiMmEEl5gvo1dC4M6lILSE2RMzeXD+Zr81oYKZQTIpEjeP7uqzakOSJF6+cSjXvroqpFpvAGajzIs3DCbMIH5jT/fznkLcfgKpUO6/6w+VcbTcTkqMWELZnEQwJbRbS/cX4TiDwn8AW/LKySu1kRZnrfe4ruv84dNt/LKnMOAerNPZ3RobDpUxe95G3rpthMiMI3Q4DrfKdzsLfGaBQkvwoOCpGoiuf4Ykea9vaw8rkkGicmMl0SOj645UHSpV26roPK1zvTNIWjjXzVlNlcON3a36nZU6vU3BdA4BLEaFmRO6B/E5gnR4LXw+E6qLwG2D2sWRJ/4rumA1jylb4bmX4aqXodclzffegtCAyX0SMRlkv8EUND6DhCRxy5h0v0/1S47izduHc9e7GwIWBT6d2Sjz1ykDOL9P58YP7oA+2Zjn998qlPuvBHy34zh3jstogRZ2XCKYEtqtTzceOeMbh6br/HfHce4+rXP0y95CFu847jeQamj5oNOjsTanlE82HWH6cD/FhwXhHFZa4/Iucz3tsgw1wQPo6B4rkrEKAMWqkDglkfx5+chmmYh+EbjL3OS/n48xzkjM2Ji6V4YpYdw14EZm334BpTUuhj71g9938NemRjuHQHy4ieHdYhs8Jmi7v4HP7vIu6wtABmTVBtU2WHgbXPK0t+i5ILQwRZZ4/NKe/O+n23FKoc0EWYwK1w5NaXCGY2xmJz69Zyz/8/FWDuSX40FCPS3thSR59291ijDxt6kDGSeSsARUUOmbBRVCu//WFvcVmpcIpoR263iFw+/jodw43KpOQaXveV5dku13tCyY5YN2t3fdsQimhI7G5lL9ZukKpX4MgCTpoNfftJ5weQJKuMLxBcdxFbqQLTJRQ6NIm5WGbDy5jEjXdab1mgZAeJiCLOF3ZirUNkHtEqPGk9Zomk6Vw0OV043FqBBtMfomjzi0utFAyofHDt/9CcIToN9Vwb9OEJrAU1bGkOce5dau5zHPlIk9yFUaFqPM6O5xPHl1/0aP7ZsUxdezRvLD5dP58dY/8N3BKmqcKjo6FqPCeT07MWtCd4Z2jW3epbXnII/m/98n1Hudv0LNwpkRwZTQbgVKuRzqjaPc7kLX9bob9aGSGrYdqfA5LpTlg8cqHGzNK2dQWszppxGEc1aUxeC3yGco2b8AdF0B2XeQI25iHHET4wK+ziAZGJcyjk4W7+h1mKF5MpIBmAwyr9w0lCFdA89KFVU5+XDtId5elUuN04NBltF0HUmCKYNTmDEug56dI717pD6d4RNIpT9fhc0NBx+MINzkvR/N3eRi3jY3S24P9x7kscMX93qX+xnCGm23IDSFKy+PvLvuJvLii3n04XvJWH+EJ7/eiSQRcOm7UZGQJYmpQ1J5akr/oJMxVX33HX3TE7jktrH8E+rSc4vgKTSxFv91t0K51ykyxIWL+l3NTQRTQrsVazX6fTzUTtK3246zaPtxusZZ6RZvpbTahcfPyEwoywedHm8hUhFMCR1JfHgYRoPsk00qlOxfAOEmAxaTjjOEAVIZmThLHE+MfaLe4/1TotmSV+57fEgZyeD9O0cyqnu83/d2elQe/Ww7X287hgR19V3c6snZ7Y835vHF5qP0S45izgQXiQ7fNgGoOryw1sWj4xsKlHTY+QUMur6BYwShaezbt3Pk3vuIv2c2cTfeCMCNo7pyWf8uLNiQx5vLD2Jzeer2Beu6NwC6fkQat4/NoGu8taHT+yibv4C422+r+/8iiGqaC/omsuFwGfbTVtWEcq8LMyiMDnCfE5pOBFNCu3VB30TWHiz1WY4Xyo3DalJ4+/YR9EuO4lCJjUMlNuYszUb1M+kVyvJBTYej5YGX76iaTqXdjVvViLIYRfY/4ZygyBK3jO7G3OUHfZaKBJvgIcwgM2Ncd/pkPsGfV/0Zh+p/Oe+pDLKBeHM871z6DrHm+jNHsydm8ruPt9Rl3Qy1TYqmMtVzmP5FCegZcT4dPZvLww2vr2FvQRWuBopkqhqomsa2IxVc9pGNz5Rwusm+qaEfGWviHyud3DvCRIw5QKfSVQ0r/y2CKaFBmqazKruE15dls+d4FXaXSphRJiXGwp3jMri0fxefrHhVS5Zw7I+PkvTXp4i84IJ6z8WGm5g9MZOZ47uz61glZTZvUe5oi5F+SVFN+h1z7t+POy+PyMkiW+WZumZYKn9bvMfvc8HefxMjwxjaVQwCNzcRTAnt1pQhqfz1291+nwv2xhFnNTHyRAepf0o0/VOi+WHXcbYf9V3mF+rywdOrj+u6zua8ct5YlsMPuwqQJQlZ8q5P7hpnZfbETK4anIzVJC474ex1y5huzF1x0O9zwSR40IGbR3UjMao3SRFJ/GP9P9hfth9VU/Ho9QvvWgwWNF3jkvRLeGT4I8SYfTsBF/ZNxCjL+GTFCLJNBpOR6/sm8uELH3DMshh1wBA6ZfWiV3IME3slMPO9jew5XuVzvQfi0XTKMDFdfZz/hv0vsVJ1veeHJytMSjfw7Confz3fHPhEJTlQdRwiRb0dwdeC9Yd57od9VDs89RI1VTmhuNrFo5956yPeOqYbD1/UC6MiU7ZgIUUvv0Taq69gGTw44Lll2ft72RzKFiwketq1SEb/K02E4EWZjVwxMInPNx/1u0+0sXudxagwe2KmmBlsAaJXJ7RbEWEGrh6cwicb8/C3XzKYG8esid19bhzxEf6X14S6fPDUZYh7j1cxe95GCiodOOpSNZ+82+WW2PjLN7t48utd3H9+D+6dJG5owtkpKdrCJVmd+WFXQdBlBWqZjTKX9U8iMcobRAxJHMJHv/qInIoc5u2ax+r81VS7qzHIBmLCYpjWaxpXZV5FpCky4DkNiswfL+/LE1/tDLmmjckgkRJj4baDElLPy7A5VfTjIB/bj1kG3WDArWr4i6MayvqpoVBKJC95pvC4cZ7Pa/8yOYzz3qrhwVEN7F1QTGArEcGUUI+u6zz2xQ4+23S04fqIJwKst1YeZN3BUp51bMCz+FvS338fU3p6q7RVs9up/PprMj77tFXeryN46MJefLezgGqnp/GDT6HIEp2jwpgyJKWFWtaxiWBKaNd+c34PFm0/RpUj9BtHp0gT1w5L9XluZEYc89cf9lkWFOrywXE9vJvg1+eWcttb67C7VPynzPCqXa748s8HyC2u4R/TBoqASjgr/XPaIKbNWcWIllMVAAAgAElEQVT+guqgZ2zCDDK9O0fy92sH+DzXPbo7j495vMntuX5EGrnFNbyzKjekgMrl0TlcavNJdqNJMjYdCBAsBpP1042RBepkfm9YgFly13t9/0SFK3oZ+PsKF30TZH9v4d3IpYUWHArnvr8t2tNoIHUqh1tja24xDzrCWfDBB5gSWi/1eOWixZgHD8KYIjrwzSUtzsq7d47kljfXBl2/yyBLxIabWDBrjNhy0EIC3MUFoX1IjbXy3p0jsZqCvwFIQIzFyIKZY/wuqbugT+2yIF9RI68h9vwZVKxewJGXbuLIq7dTtekbLD3rJ6VwqxpXDEziQGE1t7+1DlsjgdSp7G6Vb7Yd49nv9gb9mQShPTEbFRbOGsPwbrFBXZtWk+IdxJg5xmcPR3P5w2V9+O1FvQgzyJgMwf+0BcoaGkht1s+4i+7B2nssssmMpBiw9hjlk2gDYLE20u95npxk5o1NLo5WBnh/1QMWsbdBOGnjoVLeX3PIbyBVs2sJx959iMPPTePIy7dQsPDPOI7sBMAtKeyNSuajvVWt2t6yhQuIvV7s+2tuw7rF8snssXSKMBHewP1Xwnvv7dk5gkUPjKdzVAPLioUzImamhHZvSNdYPr1nLLe+uQ6b2+N3ozl4bxwWk4Ku60wfkUZygGKCBkXmtrHpzFma7XdUvbHlg4rkDdam/GcVSAQcHWpoGZDdrTJ3xUF+PbIraXGhZUYS2taRMhu/7CmktMY72xBjNTKxVwLpncLbuGWty2oy8N6MUfyyp5A5S7Pr9iE6PRoS1AU0g9JimDWhO5N7J9ZlB2spd0/ozlWDk3ngo82sPVjaIu8RStbPGix87hnHVGWlz3M94mSuzzLy4joXAxL9BH+WGIjynVkXOq45S3NweJpWH9Hh1nhtWQ63n5feKisiHHv24CkoJGLChBZ/r46oX3IUq/94AT/uKuDVpdnsPV5Vd8+1uW3ImJjQK5FZEzIZkS5qeLU0EUwJZ4W+SVGs+uP5/LS7gFeXeDMXmQyyd1uS5O3AnZcZz8wJmXSLt3DVy6u4oE8iw9P916y5dUw33ludG/QSpVOZTQqf3TuWDbll/PbjrX5npIL5cdN0nbdXHuTxK7NCboPQujRNZ9n+Il5bmsOmw2VIEjhPLAEzGWSeWbSb/inRzJ6Yyfl9EoOuv3K2U2SJC/t15sJ+ncktrmHpviLKalyANzPY5N6JIadRPlNrc0rYesR/WvLGNDQAUiuUrJ8AJUQFfO7xiWG8v83t+4TRCmPuhwAz6ELHU1ztZNm+Ik4vqRZKfcRKh5vV2SWM7dHyS/3KFiwgZto0JIPoZrYUoyJz2YAkLhuQRF6pjfxyOza3ynu7X+PinkO4IWtEWzexwxDfcuGsYVRkLu2fxKX9vTeOo+V27C6VCLOB9PhwEiJPJpb4+zUDeHD+Fr59YBwxVt9N3vERYXw0czTXvrrKu+k8yDZYjArv3DGStLhw3lqZi0GWcJ+WZz3YHze3qrNgfR6/v7SPWMfcjjncKvd+sIk1OSV+ZyFrA/KNh8p4cP5mBqZG8+ZtIwgP61i31/RO4W0+O6dpOn/5ZlfAxBgNBUvBDIBA6Fk/T5X7UP1EGmnRMo7H/ARbugZDbg7p3MK5bfH2Y/ibXAhlptTmUvlw3eEWD6a0mhoqFy2m+5dftOj7CCelxVnrVrnsssdR7PKfcVVoGWLYSzgrpcVZGd09nsl9EhmRHlcvkAK4sF9nLsnqwu8/2VZXbf10fbpE8eV959EpMqzBdcfgXXccbTGyYNZoRpyY7Vq0/ZhPIAWh/bhJksSmQ2WNHie0DbeqcfPctaw6UBzUZl+bS2Xz4XKmzVmNI8TMcsKZW7a/yKegZa3KdZ9T+tMbRI+eTur980i5520ih16Off/akPZBnZr1MxixhLhPxWiFCY+I/VJCPccqHH4HCUKdKc1voD5ic6n49lusw4Zh7CIyUbaFjKgMDlaIYKo1iWBKOGf94bLeHKtw8O6q3IDH9EiMZMUfJvPMNQPomxSJ2SgTEWbAalLq/jujUzhPXJXFmj9ewMDUkx2cQBkGQ/tx0ymz+VnmI7QaTdM5WFzD5sNlbD5cRm5xTV0A/v++2MGO/AocISwHdXo0coqq+e2CLS3VZCGA15bm1Ku5U6uxYCmUAZBTs37a9q1GczvQVQ/27A2U/fJWvWOtisqVYRsbPacbWG4x80lMLB/0HMNXyT05UHYg6M8tnPsCDc6cOlMajKYsbQ9V+YKFxF4/vcXfR/AvIzqD3Mrctm5Gh9Kx1qEIHUqYQeHlG4cw9ZVVDE+Po39KNLvyK3lzRQ4/7ymk2ulBkiQiwgxc1r8LL984FE3TyS6qpsrhISLMQNd4K1nJ/osXBpjwCnkZkBboREKLKre5WLghjzeWH6Ta4cEgSyCBR9WJsRq5aVRXPtt0BJef2cfG9tY4PRo/7SnkSJmN1FiRYKS1bMnzv1eqsWAp1NH9YIuGa7KRuF5jcO5bisloQHLb6j1foCgsiIpgflQkGqAqRjTnIQxrn0HTNbpFdWPGgBlc2PVCjIooetqRxYWbkCV8irWGWh8x2tKy3yP7jp2oZWWEj2u8LULL6BbVjcOVh1E1FSXEpchC04hgSjindYv3zird9e56oixG8kptuFQdte4XSafU42LB+jw+3XiE3kmR/N+1A+nTJfCm8VoRYQa/KWpD+XGTkIixik5Sa9J1ndeX5fDcD/uQJbD7WTpjr1B54af9fpdxBru3RtN13lt1iEd/1bdFP4/gpet6wNH7xoKlpuyDaizrpyyBqun8Zlc/Pp+xmb6F38KqF8FWCrKBry1Gnoy2okvgqt0Mo6ugq7g0bxKPvWV7eWLVE7yw6QXevuRtkiKSgmqbcO4Z0jUWi1HxmXkNpT6i2SBzXgvvlypfsICY6dchKaIT31asRiux5ljya/JJi0xr6+Z0CGKZn3DOCzPIFFW72FdQjd2tnRJIneTRdBweja15FVzzyipWZRc3et4L+iai+LmCQlkG5NF0hnaNbfJnE0L35Ne7eP7H/Tg9mt9Aqpa/QCqUvTVuVefDdYdxqy2/rEbwkgOk/21sKVSo+6AaIwExVhO9OkfSu0skt364hw+5DPWB7TB7BQsn/Ya/xMfilKWTgVQANo+NYzXHmP7NdI5VH2uW9glnn7GZ8USY/Y9/B1sfUQduGNm1xdqoVldT+d13RF9zTYu9hxCcjKgMcity27oZHYaYmRLOaWtySnhw/ma/AVQgNpfKXe9u4OPZYwIu8QOYMS6DLzYfRdV8O8vBLANSZJgyJKXDZX1rS3OX57BgfZ7fGcVghLK3BrwzE0VVTp+aZ6qm4/JomI2yqP/RTCRJIspi8LsHsbHZ4lBG9xsjSxBpNnDDyDTW5JSycNYYdh+r5KlvdvHe6lymjXPwWvbHOE7MPgVD0zUqXZXc8d0dfDXlK0yKb4ZS4dwmSRJ3j+/Ov77f5/f+FcxM6fl9EokLb7nvTuXXXxM+ahTGxMQWew8hOBnR3iQU41PHt3VTOgTRixPOWR5V4555G/1mQGpsz4vNpXLPvE0sfWRSwM5uz86R9OwcwfajlX6fb+zHzajIzBiXHvoHE5qk2unh2e/3Nun7UCvUvTWKLNUlKskrtfHuqlwWbsijyuFBliV0Xad7QgSzJ2ZyxcAkkSL/DE0dksL7aw75zCoGEywFuw9KAr+lFGTJu08zLc7C01MGMGveRj6ePQZFluifEs38maP5bmcB/7vmTlSjw+f1ZcvLKP6uGFehC8WsEDUsis7TOqOEe78Tmq5R5ijj+0Pfc0X3K5rrTyacRaaPSOO1Zd7CvaFutQ0zKDx8Ua+WaRjeZbZlCxaS+Mj/tNh7CMFLj05nX9m+tm5GhyGCKeGc9dOeQlx+llgFu+eluNrJxkNlAQv/Ajw9dQDTX1sTchpsi1Hm6sEp9EiMbPxgoVl8vumI32VgwX4foGnJRWwuDze+sYaNh8rQdL2uo187W3qgsJo/f7mDx7/cwb2TMrlvcg8xW9VEt4/NYN6aQ36fCyZYamwAJMwgc3FWZ37eXYime4NlTdfxaDoX9e3M3RO6MzAlipvfXMesCd3JTIioe60kSfRKtWEwF6KedrsoXlxM0eIiUu9KJaJfBO4yN/nv55P7bC4Zf8pANnjXE9s8Nt7a/pYIpjqoKLOR+TNHM+U/K6lxenySUQRiNsq8ctNQenVuud8bx7ZtaDU1hI8JbtZeaFkZ0Rl8n/t9WzejwxDBlHDOmrMkmxpn/V5LKNXi7W6V15flNBhMDUyN4ZWbhnDvB5sCFgr1p2uclaenDgj6eOHM6LrOa0tzfGpFhfJ9gNAzZ7lVjRnvrKfS4cHTQM+ndlP5f37JZu/xKl749RBkWQRUwfKoGj/uLuDVJdm4G/g7NxYsNcSiqDwwIY17Lh6Iy6NRVO2kyuHGajTQKdKE1eT9Of1g7SFqXCp3je/uc473d7+PR6tfUkG1qxR+UUjKjBQiB3o7u6YEE2n3prHvkX1UrKogdsLJfZV5VXnsLd1L77jeTfocwtktMyGCr+8fx6/fWEOV3e23FEAti1FBluGNW4a3eKHesvknEk/IYit+eyBqTbUu8a0XzkllNS525Ff4PB7Knhddh5/3FOJpJIHA+X0686sBoWXZOlxq46Wf9of0GqHpiqqdFFY7fR4PdQ9UKMlFJLzZGsvs7gYDqVPZ3So/7i7kia93BnW8ANuOlDPymZ/43cdb2XqkIuTlT8Gw4GS2/CWzN14JRzdhMsikxFjo0yWKrvHWukDqaLmdf32/j2enDUTxEwxvKtiEqtfv/Nr229DcGlHD6mcQVcwKkQMjqd5ZXe9xSZLYUbyjmT+hcDZJ7xTOskcm8/drB9Lb6CIMjYgwBYtRITxMIdykkBRt5g+X9mb1Hy9o8UBKrayk6scfiRGJJ9qNRGsido+dSpf/bQhC8xIzU8I5qaTGickg4z5tPU2oe15kWaLS4Wlw0+4Haw6xaPvxkNpnd2u8tiyb5FgL04eL1KUtrdLuxqhIuE6rsxzq9wGCrzEk6xqq6n9/TUN7tOxulYUb8rhuWBoDUgMnQBFgdXYJd76zvskJRU7SkdDRTxlfNODBgEovKY/fGj5hkrINnMA7V8AdiyB5cP0z6Dr/++k2ZozLoGeA5VTV7mqfx9RqFUOEAUnxDb4M0Qbsh+z1HnNrbqpcVfUeK7e5yC93YHN5CA8zkBprIdIsSi6cy0wGmSsHJdP/Lx9QOethjnTpTrXTjcVkICXGwtCuMS2yXLjc5qKwyonTrRFpNpAcY6H6y6+IGD8OQ3x8s7+f0DSSJJEenU5uRS4DEwa2dXPOeSKYEs5JblVHwveHJNQ9L7JEg6mt7S6Vpxft9tuZayypgd2t8eRXO7lqULJIPNDCpP/f3n3HR1Vmjx//3DI1PZCEkkAIoTfpRaooKhYUsWBXFLC767q74q7fxZ/rqrvuomvDgg2xi110QUWkIwgoIBBqqIGE9On398dQEubOZCYktJz368XrtZuZuXPj3Mx9zvM85xxFMY1qatNfCKKrnBVAMQ2kosnR8voMXp63iafGdo/6nBqajXvLuOX12gdSCuDARTpF3KTP4pdAK5599hV8Pi/nT5xEV/tubtC+4fuft/C3VV6+vzEOAMNbju/Vi/hlzDw65GQd/tt9b9l2iio8jB8cur3vEKsaOimjxWv4ynwYfiMkoPIV+9Djq9+mNUXDptswDIPFmwt5ce4mfszbh1VTUZTgirrXH2BExwzGD24tAflpzLt7N978fLqe1Zduev0N5wIBg3kb9zF1bh7LthRh0RVUFPyGgQKcv3Mbt44dU2/vL2rnUEU/CabqnwRT4rSU6LDgMylZHmvOi8fjw/3KVEo6dsDesQOWzMxqs32frdpp+rpYihp89csuLu2eGeNvKGKR6rSaFiOJ9XqIluVgAzK3r/p7Rpuj5TcMZv26m+IKL0nS1NnU47PWUREmkIqmOmN363b+youcoeRx6E/6faWUUgP6/jyZSYNsAHx/1LEVwOt18/WMJ7nWuIAre2cxsktTHp/1G2/d0vfwZ28mIy6D/LL8aj9z5jpRdIWSn0pI6nMk8PG7/JSuKiVjTEa15+uqjuptzPAn57K7xEWlx48BeI661r5YvYvZa/fSrkkCr9zQi0bxtrDnJU5Npd9+S/zQISj1GEit3VXCza8tpaRKftbRaVofp3bis++KOXvPcp68vJtMDp4kshOzJW/qOJGcKXFaappoJ84aeoOJJecFIMup4tAUij/+mK3XXc/6Pn3Zeu117H70UQ58NJPnv1kbtqhBNI1dyz1+nv8+r25/eREiJc5qWskq1uuhJpoKKU4LLVKdIYEUxJajpasK36/fG/M5NAQFpW5+WF9gmh9VsmQmhXNeIqnfFWTeOZ3mt71KQo+RVG5YXO15is9Fd/VIIHXI/QOs/GuBmwOu8MlXTtzcyGeUe3y8uWgrV0xdSFaqg3Y1VEu7qv1VOHVntZ9pTo30S9LZOX0npatKMXwGngIP25/bjiXVQvKA5GrP91am8vcPA2zZX07FwUDKTMAI5uD9uqOYkU/NY09JaDl2cWormz2H+OHD6+34y7YUctnzC9hV7IpY6MKn6rh9Aeas3cMVLyykMsJzxfFzaGVK1D9ZmRKnJVVVuGlgNs/M2YjrqEFttDkvTqvGHSM7kdb7gsM/8xUV4V67Ftfateyav5jt2gA4antYrEUNNhWUU+Lykig5DvVq4pDW/PmjVSEVHiNdD3FUcqb6C6lKKbpiUEo8K9SObPNUHzQ7LBoGBhd3a8bvz2nHmBcWmJ5DLDla3oBBYXn0jV0bkhmLzcufx1Kd8ZdASzYHmtBKrZ7v2KuZxtBsnX8tcPPIWfaw5xBPJf3VNSz0B1e71u8u5fYZy3nu6h5hKzEOzxrOZGVyyM/TRqahxWnsfnc3nr0eVIdKYo9EsiZkoVqqzHn6EqnYNh6X2xc2iDqaN2Cwr9zDVS8u4ou7Bx4uliFObf6SEipXriTzv0/Xy/E37yvnxleXhkwWRuLyBvhtTykTpi/j9Zv6SIuHE6xVUiu2lGw50afRIMi3qjhtje3dgv/O2Wj6WDQlkg0DLurWrNrP9JQU9AEDiBswgJJ95dienofPc2xFLqy6SnGFBFP17dxOTZg0c7XpY0dfD22UfMZpXzJKexwfGhr+YFNWq5WAz8NyRw9maKP4RetIcpyVi7s149IemcTbgl+p/jDV+2LJ0TIMI+xxGrr/rd1zzCt/KgEWBjqGBFMADw+zcea0cu7pG77wjI6fNko+CzmSAzn3twIem7WOSSM7mL7GolkY234sb6x5A7e/enXJ1CGppA4J34YBwL1/EB63+W070tZGf8BgV3El7y/L54YB2RHfQ5wayn6Yh7NXL9S4uHo5/qNfrqX86Io9B0W61ty+AMu2FLEgbz9n1nMVQRFZy8SW5Jfm4w14sagyvqhPEkyJ01ajeBs3DMjmzYVbY05Sd1g07j2nDQ5r+AGval7TIOaiBoaB9BQ6Dqy6yn+uOIM7ZywPWa08RCHAA/rbXKf9Dx0fFuWo53k8aEBv92J6W1ZDswFw5ZtgcVR7WqLdwq7i0G1VseRoWTSVJIfcAM2UVJoP8mKZyPCgU4z5QLRzusaFbXUe+9FDh7QjK0MzVnv590I36/YFSLApJKbPxNU/s0pRGT+vL9jC+ME5NA6TozSx20Tm75zPhqINeAPeGs/zEJsSh6d0IJgU1okmR9PlDfDiD5u4vn9LWTE4DZR9O4f4s+tni19N22hrutYqPH6mzs2TYOoEsxVtY2K5l4rPf0+SNR7i0iB3ODTtdqJP7bQjOVPitPbn89ozpG0ajhgSYh0WjUu6N2O8SdPNqlLirCFJ31B9wBwNrz9Asgyaj4uzO2bwt4s7YbeYffUZ/FOfyrXabByKJzSQOuq5eMthyzx4dST4qq8ynNMxA5se+h6x5Gj5Awb9cqTUsJlwcw9VJzJqohBcXQpn8lA7Ly33sKMkOKL890I3985yMWmQjT1/SGDDPcn07Nk1NBdLgXeWbAt7XKtm5cVzXqRNchtsWnRFIeyancEpd6MrofOfseRoFlV4WLK5MKr3FCevgMdD2Y/zSRhWuwbUNZmx2Pz6jeVaW7y5kF3FlabHEfUo4Ic1n8JLZ8ELA7lxz3aSlr8Bi56D7/4O086FZ/vBynfAJ9vI64oEU+K0pqoKz13Tg8t7ZWLTVSwRVoAsmoJNV5kwOIdHL+1S4+xtot1C+6bHXtSgS/Mk4myySHy8XNWnBc9f25P0BBtxVVYeJ2qfMlJbglMJbe4bls8Fe9fAR+Or/fi6/i3DviSxz2hSzhpH8cJ3yf/vNeQ/fyOlyz/H0ab61rQeLVLISnWGOUrDFq4yXSwTGVa8pCilYR/PTVW5spOFp5d48AXgoe/cPDvSzugOFuKsCmgW4nN7hwwiXd4A0+ZvibhFM8mWxBsj3+CaDtcQZ4kLKUoBoCoqds1O6+TWPDn0Sdwl7UyLAMSytbHC4+eH9QU1Pk+c3CoWL8HWujV64/pZ+fnql13HvI1WUxXmb9xfH6cnwvFUwPTLYOZE2PET+FxYjCqfY8AH3kooWAuf/x5ePhsqZHKlLsgITpz2VFXh4VGdufnMVry6YAvvL9uOejBQUgB/ZSWq1cJ1Z+ZwXf+WNEt2RD5gFbcNyeWPH6wMGeREW+QizqoxcWjrY/4dRWyGtUtn0QPDmZ+3jxfm5rEibzd36Z+EBFLZU0qp8MLme+KDA2jg5eUeplfpO4TPBetnwf48aBT8LDMS7fTPacTc9QWmW0FrytmLs2pMGBJ5ZbQhu7xnJmt3lYQkx1edyFBUDXur7iiqjmvLz7i2raoW+PjQOUv9OeL7PDTExpurvJS4DVw+uLTDkVtmAIXvAuZ9wCo9fnYeqIwYDNs0G7/r+TvuOOMOZm+dzYx1M9hbsRe33028JZ4ujbtwfafr6dioIwDPlS8yPU6sOZp7y2KYLBAnpdJv55Aw/Kx6O35Jpfn201iuNZ8/wIEKWfk4bnweeP0i2PNL8J5UE295MKh6aThMmAv2xPo/x9OYBFOiwchuHMfkizvxwPnt+Xn7AQ5UeIPbhb78hE4WF83OGxnzMUd0yuCBj8xXsKIpcmHRVIa3T4/5fcWxU1WFQW3SGNQmDePnGfCFBiZjCL8BTy32HO47ZCrgh8VTYeQTh3806YIOLNlSGFM1LACr4adTkyQGt0k7/DOX189nK3fyxsKt7Clx4fUHiLPpdM9K5tbBOXTNTI5wxNPPqDOaM/mzNaaPRTORoSpwtmUNqUetTG25t/pKc1aSiusviby1yst937jQD65suw2d6f6z8Ya5hWqqQnGll6wofherZmVkzkhG5kT+/tGi2NoYzSBXV2VDyqnMCAQom/MtLV5/rd7eI9yujNiuNeXwpKU4Dr74Pez5tVogVeNkoN8DJTvgvevh+o9P1JmfFiSYEg2O3aIdzkXJKyjjBWc2922toGLSVwQwcBx8fMLgHPq0So243c+iqTx5xRnc9fZyXN5IOTahbAT49xVd0as0+dxUUMa0+ZuZ9ctuytzBJPt4m845HZswbmA2uemR+9iI2lHmTwnO1Jm4f4CVJ+a7ub23lWR7mGsh4IUVb8I5k8HiwO3zs3ZXCY3irVQURp83YNNVmnnLmbzwFYzrz8Blc/DErHW8s3Q7QLXArKjCy84Dlcxeu5dmyXYeOL8DZ3fMCHfo04rDqnFZz+a8u3Q7Xn/o2l9NExk2XePWAc3gJyd4K2p8v0ZOhX0VBr6Aga4q+FF50zci4msiNe+tjfRE8zLtsRQ1URXISJDmvaeSEpeXvSUuKj0B4u06Kfl5qPHx2Fq1qrf3THFa2HEg9HsrlmtN1xRS4iQX+Lgo3w+r3wvJ3YUoJgP9bti2EAp+g7R29Xyipy8JpkSDtHFvGfe9/zO/7SrFFwjgU60cKl1U4fHz3bq9LNq0nxSnlUcu7cywduFXj87pmMFDF3bk4c/XRB1Q2XWVu3cvpP0bCzEe+X+s21vOAx+tZu2uEvyB4KDtEJfXw/vLtjNzRT5tMxJ49NIudG6edGz/AcQRgQDsWx/24Wj7DqFoULiZL/cm88cPVmNghPS0CkdTg4PvfjmNeHbscEoeW8Pqmyfyp77j2FLkMs1fgCONWfMKyrnz7eXcM7wNtw3Njeo9T3V/GNGOOWv3sqfERSwV5IMFZprTfcQIOPA/2Di7xm0x/TM1bDp8vM7HyA5x/N57G7sIXxzE4wuQGhe+rHptXNK9Od/8ujtkS3EsWxutusp5XZrU6XmJumcYBku3FDH1hzzmrd+HRVdQUILtErxeLugzljsKyshJi6+X9x/dI5O8gt9CquDGcq35AwZD2squi+Ni+WuEK4EQ3WSgDxY9DxdNqbdTPN1JMCUanJ+2FnL9K0uo8PjDNr40CAZVFZ5Kbpv+E3+5oCPX9gtfVODqvi1pkmTn/vdX4fL6w3aLj7NqOKwa/7y8G0OyhrD9jjv49P6/80BcHyoilG/3HQywVuUXc/kLC3n+2h4MjRDgiRh4y0HRwQhfpjqavkMoCq8t3cNji7fGtEqpqTCmZxbjBraibUZw5VGd9Beun/wJm/eU4osyF8blDfD0nI0k2HSu7Z8d9fufrLbuL+frX3ezt8SNL2CQlmCjf+tGdM9KRlEUkp1W3rm2K6OnfEexxYnPpGT40RwWjWHt03jkks7BZZoxr8KH44IBVYQVqiS7wuShNm7/0kXfwHmsbNkNRfWZDiIB2jVJIK2OV4AG5TbGadVNv1uizdHMaRxP+yaSG3Ey27o/2Cx3T4mLyoP3qOofucon7hQ+f2oe/Vs34tmre9R5AaPLemby+Kx1po9Fu412ePv0Op9QECYMIxgI+cx3QEQ1GRjwwap34Lx/hLT5ENGRYEo0KBv3lnL9K0vCBjtmXN4Aj3yxhi8skGsAABxBSURBVNQ4CyO7NAv7vLPaZ7DkwbP5YX0BL8zNY+mWwoP5CcFAqG+rVCYMac3gNmmH+0odeOgJ/vTCAlwx9MGq9PqZOP0n3r61H91bpET9OhGGZgMj8n//cH2HqprlPYPHFlXgMm+BFJZFU+mWmXw4kAL4z+wNbFOc+NTQoCxSw8xKr59HvljLkHbpp2QlQMMwmLN2Ly/MzWP1jmIChnF4G5+qBLfnpSfauG1Ia0Z1yUCZPInpLVrzWOYwFm0qxDAMPCbb/pwHqzZOHNyau4bnHtm6q1vhijdgxXSY9ySU7cHwVqJUmWZxGTqg0LNvNzLt2XwzfyHeT6+JXFRmSN0XlVFVhXGDWjHlf+tN+6TVtLXRKUVNTnrrdpdw+QsLKXf7Iq62+gzw+QIszNvPRc/8yMzbz6zTnnRJDgsjuzTl05U78JvMC0WzjfaWwXKtHRfeyhor8kU3GahCyc7DRZREbCSYEg3K3e/8HLYgQKRBqssb4L73VjG0XTpOa/g/G01VGNY+nWHt0/EHDEpdwdWOBLsF7aiy7IZhcPu7q3GZ9I6JdC4QDPAmvPkTix4YLg1/j5VuBd0RNmfqkMlD7fSYWsZ9/UNXHHyGyp8rrsVss1g0n+X/+3wNl3RvhtOq4/b5eWvxVtOtfdE0zPQbBq8t2MJfL+wY+3+LE8jt83PXjBX8uHGf6d/ooS2NW/dXMPmzNTw/cylTFBtnTPoDr+s6Ow9U8sbCrbyzZBvFlV4OxUs5jeOYODSXC7s2xW7Wb05RoMd10P1ayF+Gf86/+XXjb2iqn2LDyU+Bdkz3n81eUqAjNO14ZegxqtBUhRGd6id37cYB2cxcsYNNe8vwxrC30aardM9K5sKu4SeDxIm1p8TF2BcXURrDbIzbFyC/sIIbpi3m/YkD6jRP74Hz2zN3fQFF5Z6wOzjMOCwaI7s0oYdM9B0f7lLQLOALPyEYzWQgigbukno6ydOfBFOiwVi3u4RNBWWmN4ZoBqmKAh+v2MHVfcNv96tKU4NbkcJZtKmQIpPSsdGcC0C528e8jfsY0jYt5BgiRt2uhOVvBgtJhFG171CX9Oo3pDmBHqaV3aL9LKteW1+t3o3ZRXqoYWajkffibDfg8M+duX1x5vY9/P+9foN3lmzj/nPbmQcPJyGfP8BNry5l+dYi01WXo1V6/eQHLEzMHsVX7gCpOjRLdvDn89vz5/PbEwgYePwBbLpaY7+4Qwyfj32fLaXoza1UTJjEzXnxITkjNbFbVJ67pmedF584cnyNGbf0ZcwLC9l5oDJsLt3R59QuI4GXbugVMqEjTh5TZm8IG0hFmpDx+A3W7ylj1i+7uahb3QXL6Yl23rmlD5dN+ZZyRScQRVtSh0VjQG4jHr+sa52dh6iB1RmsJluDSJOBABgBsNZPDl5DIDVSRYPxyrzNppW/ou3qXuHxM3XuJgwjlnm68F78IY/Ko2bgY+kwX+7xM3VuXp2cS4PX73aIIjfpoSE2yj2hn/8L/lGUU30/eiyfZdVr682FW4+5OauiKKdUc9bHvlrHim3mgVT5mu/Z9fq9bPv3GPKfuY497/0frvxf8asahRVebnp1SfCJe9fCt4/AzNtQP7oF++xJKOu/jmqg4Vqzhs2XX0Hlzz/T6qMP6X/jGF65odfh7YHRsFtU/jmmGwPb1E8j1UMaxdv47K6BDGuXhlVXsenmt3G7JfjYqDOa897E/hFX1MWJVe72MXNFfrXCQ4eULJlJ4ZyXSOp3BZl3Tqf5ba+S0GMklRsWH35OhcfPC/VwL0h5/3Ve3vUVXZonY7eoYcvzOw5ea9f3b8lL1/WqVqFW1DNrPGg156ZVnQw05fdAfMOoBlsf5NtVNBhfrN6F3+RmFcsgdU+piy37K2jVOO6YzsXl9TNvw76QBYhYzgVg6ZZCytw+4us4AbnBadwGmnQNdo2vkj8Vru9QVR5DY2UgtExxrJ/l7hIXe0vd7Ck1rywXU8PMQIA9padGc9Yyt4/pYYp21LSy5/P7ydn7Pyr++yecxXnBAUHV/LcV00G3Qd/boPc4cKZWO37A42Hfc89x4L33Sf/j/SSNGnV4JWtAbmM+vG0Af/xgFRv2luL1B0LyRw7lcTVLtvPopV3omxO+wl9dirfpvHBdL/aUuJi+aCtvLtpKcaUXXVXw+YPFOm4e2Iore2WRIkUATnof/7zDtCdTtKvREGzz8dvuUto1qZv2GSWzZnHg44/p+f77fNKoEev3lDLtx83MXLEDf8BAUxW8/gBNkuyMH5TDZT0zSbBLKfTj7tA25aWvRNxZAUeakJscBNqcI417j4GMwESD4PMHwm7ZiWWQatFU9pe5jzmYOlDhRdeUkJnIWM7l0PkUlXskmKoLY6bB1EFQWRTTy4r1xlh8KkdXQY/1s9TcLlZdfQMVbUaDLXS7RSwNMwMBA3eMW9QA9pe5+Wj5DjbuLaPE5SXZaaVj0wRGdW9OYjQDpZ0rYP3XULo7eJOPbwJtz4VmZ4R9yczl+bUaSNrw8IzlaQYov+LcHyZw9JQF/837Fyx9CW78EhoHS8dXrlzJzkkPYm2VTauPZ2JJD62O2aFpIp/dNZANe0qZNn8LX63eRbnHh2EEe10NbZvGLYNy6JZ1YpomZyTauW9EO+4b0Q6fP0C5x0+8TZftfKeYL1fvMs0TjGVCxh8wmLehoE6CKdfateye/DAtXnkZvVFwgqBtRgKPXdaVf4zuQrnHT6XHT6JDx6afGluJT2t9J8BPr4UEU9FMBgJgccKAu+vxBE9/MgITDYLfMFAwTUWJsas7plsFY+X1B0wHkLGeiwJ4zMotidglZ8FNX8GrI8FVXGOFPyB4E7roWXjPB1T/HGL+LO02mv79EZJmbaOwKHR1KraGmSqJMVT3Wrn9AM99v5HvfitAgWq5OA6LxiNfrOWCLk2ZOLR1taqDAPg88MsH8ON/oHh7sHGkcfD1igrz/wPJLeHMe6HzZcGCH1W8NG9zzANJDT+vWP5JT3U9DiXybGzwHF1Q5oaXzyJwwzcUvPEpxZ99RpMHJ5Fw3nk15lW1yUjgH6O78I/RXWp+rxNE11SSHLK96lRUVG6+9SqWCRmv32B/WZgtXDHwFRaSf+ddNPnrX7B3DC1ioygK8TZdJvBOJqk5kNUXti6EQIzXgKJCUvPg60WtyTevaBBsumYavED1QWpNAoZRJyVoEx0WvCZBUCznAuAN1M35iIPSO8Bt86HDRcGS6bpJzw3NCrodsgfCzV+T1HE4PpMAO+bP0lBIb5VF39w0zFIOqjbMrFi/kIDXheH3UZm3jKLvplV7rs9v0DXKxs7TftzMlS8u5Js1e/D4AiFFDSq9fty+AJ/8vJNRz/zIZyt3VnmwCKaNgC/uCzY+9lYeCaQg+L+9lVCwDr74Pbx6HlQeOPKwYZBfZN7fKdJA8g/6e/RQN1QLpLKnlJL+z9JqOW0vL/cw9LVDVRoNDFcp/qcG49u7m5xPPyHx/POjLlAhRH0Jdw1WnZCJxrFWdjU8HnbcfQ+JF15I4siRx3QscZyNeRXiGgWr8kVNAVsiXPM+yPfgMZFgSjQY3VuYb8WJZZCqoNA6/di2+AEk2nXSE0Ib6MVyLgDJDgupESoGilpIbAZXvA73rYMhf4LG7SAuHZyNIaUV9LkV7lgCN34BTbti1VV6tAy9tmL9LDOTHaQn2hk3sBUW1fyrObHPaFLOGkfxwnfJ/+815D9/I6XLP8fRpvrqjaLAuNeX8eQ3v7GpoCzsr/ryvE388+vfcHkD1FRXxW8YVHoD3P/BymBA5S6Dl8+BPb9GbHh7mLcCdq+GaeeCJxjguH0BlDDNdsMNJO24uV77BqcSOgPrN+CpxeFnZhUC6HEKzSech56aGvZ5QhxPjePNK6zFMiFj1VUaHWN+3O5HH0VNSCDtHtnydcqJawS3zIakzOBEYE1UCzgbwc1fQ0p2vZ/e6U7WaUWDMXFIa9bsXGFaKS2aru4WTeGavi3qZI+4oiiMH5zD47PWhWxxiuZcIFhB6dZBOdJnqr44U2HQ74L/ahDu2or2s3RaNSYODTZLbJORQG56PL/sNO/5UVPDzDirxmOXdSEnLZ6Plu/giqmLyEp1MLpHJhd1bXq4XP+SzYU8+c1601zCmnqu/fGDVXRo9SG5B7YFiz4clD2llAovbL4nnjhr8Lp8ebmH6au8fH9jXPC5RVvgo/Fw1VtYNZVAmC424bY1XqSFH1jeP8DKE/Pd3N7bSrLd/O9C8btg/hRoOyLscYQ4ni7t3pxlWwpDvj+qTsgoqoa9VXcUVce15Wdc21ZVqwqqAMM7hOb9Ravo7bepWLaM7HfeQQkzmSNOckmZMHEe/DglWJDC8AdzRquyHpwMPuNaGHQfJEgFv7ogwZRoMIa2S8dm0UyDKah5kKoqCtf1D/aYKih1M2PxNhZt3k9JpRebrpGZ4mBsnxb0y0mNauvQ6B7N+cdXa2t1LhBsYnpFr6wa30fUv0jXVjSfJcBFVRqqThrZgZtfWxpVz6WqNFUhLcHGuZ2aYtVVOjVL4oHz2zNvwz4+XJ7PE7PWMTC3MaN7ZPLq/M2mgVQ0vbE8Pj8vbUrhcT208MOh1aFJg8LMjvpcsH4WrHgLpXkvUqwKhe7QgCrcQLLttneZvLWYJ84JXdnt1UxjaLbOvxa4eeSs0McP2/ETFOcHBx9CnGDnd2nCgx//YvpYtBMyXTOTaNkodNdEcYWX1TuKKXEFqz02TrDRLTO5WpGS8iVLKHjmWbJnvIUWL72GTmn2JDj7/2DYJFj3Oax6D8oLIBAIrl51vAQ6jwaLyRZ2UWsSTIkGQ1MV/u/Cjvzpo1WmZZgjsaswukcmJZU+bv1s2eEePlXzS1ZsK2L22j0kOSxMHNKaa/u1jFhVK8FuYeLgHJ77flPMRSQcFo2bzswmySn5UicDTVV48vJu3PbWT7FfWxaV/zeqM44qPY0G5DbmLxd25JEv1kR9PE1VSHZYeGd8f6xVeg/pmsqw9ukMa59OcaWXL1fv4r/fbmBVfnHIMaItxew34BNfP/6qvU68Ur1YRjSrQ0bABx/fiREw+NyfxhPqFXwR6BPS+NhsIPlh03IeGhz+un94mI0zp5VzT98IW540K+zPk2BKnBRsusbYPlm8sWALHpP8y5omZJy6woTBrav97Jcdxbz0wyZm/bobq6ZiEFy9CmBg1zXGDWzFVX1aEF9UwI7f30ezJx7H2jK6hvTiFKBZoNOlwX+i3kkwJRqUUd2bs62wgue+30hllINUu67QpSCPHoU+Rj+/A7fPb5pfYhBsnljh8fPYV+uYvXYPL17Xq9oguapSl5f5efvxx9gE2GHRGNY+jfvPbRfT60T9GtY+nb9d1Im/ffZr1AGQ3aJy11ltuKxn6KD+2n4tSbDp/OmjVRgGIYUhqoqzaaQn2HlnfD8yEsOvyCQ5LIzt04KdBypZu6skpDJlLKWYVQw+9/fjKv37aj+PZnVIAVACKBo00/byiPEyk3iTazwPssGo/t/i6IHkl7ZrUJXwfzOd0zUubKvz2I8eOqSF265kgLu0xt9RiOPlrmFt+Gr1LnYVuzBphxiWTTHovGcDvTYb0PECXF4/t7+1nIV5+3H7/ARMvjvK3X6enrOBp+Zs4J5t33LVrbcQf+aZdfwbCdFwyMZY0eDcNbwNf72wEzZdxW4J/ydgURVsusol3TMZd9VgHvzNwOU1D6SOVun1s2RzIeNeX4rPZNWpxOXl4mfmszK/2LSRcDiaAlf2zuSZsT2kCtlJ6Ko+LXjumh6kOC3E2cLn1sXZNOJtOo9e2oU7huWGfd6o7s354f5hTBiSQ5LDQrxNx2nVsFtU4mwadl2lW2YST17ejf/9bnDEQKqqNSaBFMRWirkCO+sN822mDw+z8d8lHgrKowsq4xUXjSlmpvUhOimbIz7XTc2rsZOH2nlpuYcdJeH+thTTXl5CnChJzuCqcuN4G3qUebB2i0rnFqm8eNfZ7PvPFLY++jhjnpvP/I37qPT6IwZlroOVO59qNogPWkVutSCEiExWpkSDdHXfFpzfuQnvLtvOK/M2U+7xHd6SZxjBEuhX9MripjOzSbBbGPj4t7i10EFcpER9ty/Aim1FPPf9Ru4e3vbwawIBgxunLWFHUWXM2/t0TeW8zk2l6MRJ7Kz2GSx98Gy+XbeX5+fmsSq/GKumAAoef4C2GfHcNjSX8zo1qbYdL5z0RDu/P6cdd5/VhgV5+9lT4sLlC5Bo1+ncPInWabEHBaUun+nPY+2NVWSYV7aMbnWoOlUBp+HiLeujjHA/wV5STJ+300iltbI74rFyU1Wu7GTh6SUeuqSbvL/fIxWsxEknK9XJl/cM4o4Zy/l52wECAQOvSURkt6gEjGDhiodHdcaiqSR88D7XPvwB6w8U4VGjH9q50fj3/9bTslEc53ZqUpe/jhANhgRTosFKibMycUhrxg/KYc2uEgrLPfgDBklOCx2bJmK3BAeTz3+/kYDJclQ0ifqV3gCv/LiF24bmYjnYPGhB3n7W7S41DaQiBWcQ3K7x8Odr+PLuQfX1n0XUAV1TGdGpCSM6NaHS46e40nu4R1lcLZtd6prK4LZpdXJ+4RpuxtIYGCBZKQ/72OShdnpMLeO+/sFCFDVW+uNIQHWL/gWP+q41Pe4r/pE8qLxFnBJa/KKqh4bYeHNVmIa+TbpIMCVOSo3jbbw7vj+b95Xz6vzNvL8sH68/gK4qeP0GKXEWbhnYiit7tyClSin0teUKK+Ka4zHZYlzTfcXlDfDwZ2sY0TFDdjwIUQsSTIkGT1UVOodpcBoIGLz84+aQHJhoE/UBfP4Ac9bu4bzOTQGY+kNeSDl0iC44A9hUUMZvu0tp1ySh1r+zOH4cVi1s3tyJ0r5JAvM2FIRs9YulFLMDF62VnUcf+jCz1aEaK/0BVsXP1dq3POcfwwGj+rZFp0Xla4bwN/Utjq6ovuXe6n8PWUkqrr8kmrxBPJx5b9j3F+Jk0KpxHA+P6szkiztR7vFT4fGRaLccnuQ72kvzNuH11W7SD6CowsOSzYX0zWlUb7+TEKcryZkSIoLFmwtxmZSPjiVRv9zjZ9r8LQDsKq5kyebCkOccCs5Sz7kNZ7sBqFY7iqbjzO1bbQAL4PUFeHneptr9QkIAY/u0QA0zAx1tY+AAKhdrCyK+z0NDbJR7jgzw7h9g5V8L3BxwRc4TdPg9TLV9ydlNLfTISqJbZhJD26bx4IUdmffXC7H2vgH02pT2PZgr1fa8WrxWiONPURTibcEm7+ECqeJKL7N+2R1SzCiW+0qlx8/UuXJfEaI2ZGVKiAjyiyoImKQ1xZKoD7C9sAKA5VsPYNHUkOpKsQRnfiO4VVCI2spKdXJGVjKLTQJ7qLkUs6bABdpSEpXKaj+PtDqUPaU06j5Qmh6gbzedvpeGaaw74u+wYznsXg3+yNv9qrHGwfWfgia3PnH6+HVHMVb92O4rBrBsq/n3gRAiMlmZEiKCYEWk0Fn0qon60ago2M/6QYPY+MhjeCsqQh6PNTgrc5sXEBAiWvcMb4MjzEx3Tay6xoQmG0CNvc9Z1JX+yveFf0y3wvWfQFZvsDhrflPVAo4UuOkrSJOWAuL0UuLymlaZjfW+Yrb9XAhRMwmmhIgg3qabNt6tmqgf1XHSUmn1wYdkjB+HZgvNF4k1OLNokiQsjs2A3MbcNrR1zAGV3aLy8KhOtLvhGXAkgxLbbaRqpb+ItPB5VUBwu951n8C5/wgWk7A4OdjB6ghrXDBHqs94uH0RNO0a07kKcSrQVRWzXbux3ld0ua8IUSuy10GICNpmJJjO+MWSqK8AHZsnY8lIp8kBULVd4Ku+shRrFbXUKlWchKitu87KRVXg2e/yqDTJDaxKUcCmq0y+uBOX9zrYX+qW2fDaBVBWENN2u6Mr/YW+mQpJoY2MQ2g69LoRet4A+Uth5dtQsjNY+tzZGHLPhk6XgF5DYCbEKSwtwWa6gyLW+0qSI/aVZiGEBFNCRNS5eRLNku3kFYSWgE7sMxo1LoXihe+y7/N/oVgd2DJySex/ZbXnOawatw7KAaB/60ZHFyEDYqyiZtEY27tFnf6eomFSFIU7z2pDr+xUnvl2I0u3FGIYVCvbb7eoGAYMaZfGncNy6ZqZfOQAKdkw8UdY8AwseQkMP3jKanzfGvtA6XbodlUsvwhk9Qn+E6KB6dI8CadFp9xdfUIklvuKVVcZ0zOKCQwhRAgJpoSowcQhrfnbp79SbrKfvKZEfYBUp5Xe2cEGpDZd4+o+Wby2YEtIWepog7OAYXBZL7npibrTL6cR/XIasfNAJe8u3c6GvaWUVvpIdFro3CyRy3tl0Tg+zOqOIwWG/xWG/hnWfQG/fgQbZoM3fA8qqKEPVFImNO9xjL+VEA2DqiqMG9SKKbPXh7TxiPa+ogDX9cs+fictxGlEgikhanBRt2Y8Mes3Krx+0y1/kTgsGn84t221Rog3DMjmzYVb8ZqsUdUUnFlUhQu6NCXRLtsxRN1rluzgd+e0rd2LNUtwS12nS+DXmfDxHdUCqqj7QFmcMPB3tTsHIRqoK3tlMWX2etPHoqnO2S+nEU2SwlfYFEKEJwUohKiB3aLx9vh+xFn1o9PbI3JYNK7u24JLuldfRcpMcXL/ee1iTvxXFWicYOOhizrG9DohjrsOo6D1sNh7Qel2aNEful5Z83OFEIelxFl59NIu2C2xD+uSHFaeGCPFWYSoLQmmhIhCbno8H90+gNQ4a403K1UJBlK3DGrFXy7oYPqccQNzuGVQq6hvfBZNIS3BxnsT+pPslOIT4iSnqjBmGrTsD5YoAyqLA5r3hCunQ5SlnIUQR4zukckfz20f9X1FUyDFaeHdCf3ISJRVKSFqS4IpIaLUNiOBb/8wlPvOaUdago04q0bVqukOi4pNVzmvUxNm3NqX+0a0q7a972j3jWjH46O7kpFow2k1HzzadBWrrjK8QwZf3TOYrNQoeuoIcTLQbXDNB9D/zmB5cmu8+fOsccF/vccHG+pa5RoXorZuHtiKZ6/uQWaKA6dVMy2Zbj14XxnUJo2v7hlMm4yE0CcJIaKmGBGSQHr16mUsW7bsOJ6OEKeGQMBgft4+Vm4/QGG5B6dVp0mSnZFdmsZcttwwDBbk7Wfq3Dx+zj9ApcePRVNJcVoZ2zeLsb1b0Chc8n8tKIryk2EYversgCeIfD+dQnxuWPMJLHgairaA1wUWOyRlwYC7oNOl0a9gidPa6fD9dDJ8NxmGwU9bi5j6wyaWbC6kwuNDV1WSHDpjemZxXf+WsholRAwifTdJAQohakFVFQa1SWNQm7RjPpaiKJyZ25gzcxvXwZkJcRLSbdD1iuA/IUS9UxSFXtmp9MpOPdGnIsRpT7b5CSGEEEIIIUQtSDAlhBBCCCGEELUgwZQQQgghhBBC1IIEU0IIIYQQQghRCxJMCSGEEEIIIUQtSDAlhBBCCCGEELUgwZQQQgghhBBC1IIEU0IIIYQQQghRCxJMCSGEEEIIIUQtSDAlhBBCCCGEELWgGIYR/kFFKQC2Hr/TEUIcBy0Nw0g70SdxrOT7SYjT0in//STfTUKclsJ+N0UMpoQQQgghhBBCmJNtfkIIIYQQQghRCxJMCSGEEEIIIUQtSDAlhBBCCCGEELUgwZQQQgghhBBC1IIEU0IIIYQQQghRC/8f3N5qxqo3NRwAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Take a batch of graphs from the test set, and compute predictions.\n", "graphs = next(datasets['test'].take(1).as_numpy_iterator())\n", "labels = graphs.globals\n", "logits = train.get_predicted_logits(eval_state, graphs, rngs=None)\n", "graphs = jraph.unbatch(graphs)\n", "\n", "# Create plot to visualize individual molecules.\n", "ds, ds_info = tfds.load('ogbg_molpcba', split='test', with_info=True)\n", "fig = tfds.visualization.show_examples(ds, ds_info,\n", " node_color_fn=node_color_fn,\n", " node_label_fn=node_label_fn,\n", " edge_color_fn=edge_color_fn,\n", " rows=10)\n", "\n", "# Update plot titles with true and predicted labels.\n", "for graph, graph_labels, graph_logits, ax in zip(graphs, labels, logits, fig.axes):\n", " label_for_task = get_formatted_label_for_task(graph_labels, task)\n", " predicted_label_for_task = get_formatted_prediction_for_task(graph_logits, task)\n", " ax.set_title(f'True Label: {label_for_task}\\n'\n", " f'Predicted Label: {predicted_label_for_task}')" ] } ], "metadata": { "accelerator": "GPU", "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: examples/ogbg_molpcba/ogbg_molpcba_benchmark.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Benchmark for the ogbg_molpcba example.""" import time from absl import flags from absl.testing import absltest from flax.testing import Benchmark import jax import numpy as np import main from configs import default from configs import test # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() FLAGS = flags.FLAGS class OgbgMolpcbaBenchmark(Benchmark): """Benchmarks for the ogbg_molpcba Flax example.""" def test_1x_v100(self): """Run training with default config for ogbg_molpcba on a v100 GPU.""" workdir = self.get_tmp_model_dir() config = default.get_config() FLAGS.workdir = workdir FLAGS.config = config start_time = time.time() main.main([]) benchmark_time = time.time() - start_time summaries = self.read_summaries(workdir) # Summaries contain all the information necessary for # the regression metrics. wall_time, _, test_accuracy = zip(*summaries['test_accuracy']) wall_time = np.array(wall_time) sec_per_epoch = np.mean(wall_time[1:] - wall_time[:-1]) end_test_accuracy = test_accuracy[-1] _, _, test_aps = zip(*summaries['test_mean_average_precision']) end_test_mean_average_precision = test_aps[-1] _, _, validation_accuracy = zip(*summaries['validation_accuracy']) end_validation_accuracy = validation_accuracy[-1] _, _, validation_aps = zip(*summaries['validation_mean_average_precision']) end_validation_mean_average_precision = validation_aps[-1] # Assertions are deferred until the test finishes, so the metrics are # always reported and benchmark success is determined based on *all* # assertions. self.assertGreaterEqual(end_test_mean_average_precision, 0.24) self.assertGreaterEqual(end_validation_mean_average_precision, 0.25) # Use the reporting API to report single or multiple metrics/extras. self.report_wall_time(benchmark_time) self.report_metrics({ 'sec_per_epoch': sec_per_epoch, 'test_accuracy': end_test_accuracy, 'test_mean_average_precision': end_test_mean_average_precision, 'validation_accuracy': end_validation_accuracy, 'validation_mean_average_precision': ( end_validation_mean_average_precision ), }) self.report_extras({ 'model_name': 'Graph Convolutional Network', 'description': 'GPU (1x V100) test for ogbg_molpcba.', 'implementation': 'linen', }) def test_cpu(self): """Run training with test config for ogbg_molpcba on CPU.""" workdir = self.get_tmp_model_dir() config = test.get_config() FLAGS.workdir = workdir FLAGS.config = config start_time = time.time() main.main([]) benchmark_time = time.time() - start_time summaries = self.read_summaries(workdir) # Summaries contain all the information necessary for # the regression metrics. wall_time, _, test_accuracy = zip(*summaries['test_accuracy']) wall_time = np.array(wall_time) sec_per_epoch = np.mean(wall_time[1:] - wall_time[:-1]) end_test_accuracy = test_accuracy[-1] _, _, test_aps = zip(*summaries['test_mean_average_precision']) end_test_mean_average_precision = test_aps[-1] _, _, validation_accuracy = zip(*summaries['validation_accuracy']) end_validation_accuracy = validation_accuracy[-1] _, _, validation_aps = zip(*summaries['validation_mean_average_precision']) end_validation_mean_average_precision = validation_aps[-1] # Use the reporting API to report single or multiple metrics/extras. self.report_wall_time(benchmark_time) self.report_metrics({ 'sec_per_epoch': sec_per_epoch, 'test_accuracy': end_test_accuracy, 'test_mean_average_precision': end_test_mean_average_precision, 'validation_accuracy': end_validation_accuracy, 'validation_mean_average_precision': ( end_validation_mean_average_precision ), }) self.report_extras({ 'model_name': 'Graph Convolutional Network', 'description': 'CPU test for ogbg_molpcba.', 'implementation': 'linen', }) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/ogbg_molpcba/requirements.txt ================================================ absl-py==1.0.0 clu==0.0.6 flax==0.4.1 jax==0.3.4 --find-links https://storage.googleapis.com/jax-releases/jax_releases.html jaxlib==0.3.2+cuda11.cudnn82 # Make sure CUDA version matches the base image. jraph==0.0.2.dev0 ml-collections==0.1.0 numpy==1.22.0 optax==0.1.0 sklearn==0.0 tensorflow==2.11.1 tensorflow-datasets==4.4.0 ================================================ FILE: examples/ogbg_molpcba/train.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Library file for executing training and evaluation on ogbg-molpcba.""" import os from typing import Any, Dict, Tuple, Optional from collections.abc import Iterable from absl import logging from clu import checkpoint from clu import metric_writers from clu import metrics from clu import parameter_overview from clu import periodic_actions import flax import flax.core import flax.linen as nn from flax.training import train_state import jax import jax.numpy as jnp import jraph import ml_collections import numpy as np import optax import sklearn.metrics import tensorflow as tf import input_pipeline import models def create_model( config: ml_collections.ConfigDict, deterministic: bool ) -> nn.Module: """Creates a Flax model, as specified by the config.""" if config.model == 'GraphNet': return models.GraphNet( latent_size=config.latent_size, num_mlp_layers=config.num_mlp_layers, message_passing_steps=config.message_passing_steps, output_globals_size=config.num_classes, dropout_rate=config.dropout_rate, skip_connections=config.skip_connections, layer_norm=config.layer_norm, use_edge_model=config.use_edge_model, deterministic=deterministic, ) if config.model == 'GraphConvNet': return models.GraphConvNet( latent_size=config.latent_size, num_mlp_layers=config.num_mlp_layers, message_passing_steps=config.message_passing_steps, output_globals_size=config.num_classes, dropout_rate=config.dropout_rate, skip_connections=config.skip_connections, layer_norm=config.layer_norm, deterministic=deterministic, ) raise ValueError(f'Unsupported model: {config.model}.') def create_optimizer( config: ml_collections.ConfigDict, ) -> optax.GradientTransformation: """Creates an optimizer, as specified by the config.""" if config.optimizer == 'adam': return optax.adam(learning_rate=config.learning_rate) if config.optimizer == 'sgd': return optax.sgd( learning_rate=config.learning_rate, momentum=config.momentum ) raise ValueError(f'Unsupported optimizer: {config.optimizer}.') def binary_cross_entropy_with_mask( *, logits: jnp.ndarray, labels: jnp.ndarray, mask: jnp.ndarray ): """Binary cross entropy loss for unnormalized logits, with masked elements.""" assert logits.shape == labels.shape == mask.shape assert len(logits.shape) == 2 # To prevent propagation of NaNs during grad(). # We mask over the loss for invalid targets later. labels = jnp.where(mask, labels, -1) # Numerically stable implementation of BCE loss. # This mimics TensorFlow's tf.nn.sigmoid_cross_entropy_with_logits(). positive_logits = logits >= 0 relu_logits = jnp.where(positive_logits, logits, 0) abs_logits = jnp.where(positive_logits, logits, -logits) return relu_logits - (logits * labels) + (jnp.log(1 + jnp.exp(-abs_logits))) def predictions_match_labels( *, logits: jnp.ndarray, labels: jnp.ndarray, **kwargs ) -> jnp.ndarray: """Returns a binary array indicating where predictions match the labels.""" del kwargs # Unused. preds = logits > 0 return (preds == labels).astype(jnp.float32) def add_prefix_to_keys(result: dict[str, Any], prefix: str) -> dict[str, Any]: """Adds a prefix to the keys of a dict, returning a new dict.""" return {f'{prefix}_{key}': val for key, val in result.items()} @flax.struct.dataclass class MeanAveragePrecision( metrics.CollectingMetric.from_outputs(('labels', 'logits', 'mask')) ): """Computes the mean average precision (mAP) over different tasks.""" def compute(self): # Matches the official OGB evaluation scheme for mean average precision. values = super().compute() labels = values['labels'] logits = values['logits'] mask = values['mask'] assert logits.shape == labels.shape == mask.shape assert len(logits.shape) == 2 probs = jax.nn.sigmoid(logits) num_tasks = labels.shape[1] average_precisions = np.full(num_tasks, np.nan) for task in range(num_tasks): # AP is only defined when there is at least one negative data # and at least one positive data. is_labeled = mask[:, task] if len(np.unique(labels[is_labeled, task])) >= 2: average_precisions[task] = sklearn.metrics.average_precision_score( labels[is_labeled, task], probs[is_labeled, task] ) # When all APs are NaNs, return NaN. This avoids raising a RuntimeWarning. if np.isnan(average_precisions).all(): return np.nan return np.nanmean(average_precisions) @flax.struct.dataclass class EvalMetrics(metrics.Collection): accuracy: metrics.Average.from_fun(predictions_match_labels) loss: metrics.Average.from_output('loss') mean_average_precision: MeanAveragePrecision @flax.struct.dataclass class TrainMetrics(metrics.Collection): accuracy: metrics.Average.from_fun(predictions_match_labels) loss: metrics.Average.from_output('loss') def replace_globals(graphs: jraph.GraphsTuple) -> jraph.GraphsTuple: """Replaces the globals attribute with a constant feature for each graph.""" return graphs._replace(globals=jnp.ones([graphs.n_node.shape[0], 1])) def get_predicted_logits( state: train_state.TrainState, graphs: jraph.GraphsTuple, rngs: dict[str, jnp.ndarray] | None, ) -> jnp.ndarray: """Get predicted logits from the network for input graphs.""" pred_graphs = state.apply_fn(state.params, graphs, rngs=rngs) logits = pred_graphs.globals return logits def get_valid_mask( labels: jnp.ndarray, graphs: jraph.GraphsTuple ) -> jnp.ndarray: """Gets the binary mask indicating only valid labels and graphs.""" # We have to ignore all NaN values - which indicate labels for which # the current graphs have no label. labels_mask = ~jnp.isnan(labels) # Since we have extra 'dummy' graphs in our batch due to padding, we want # to mask out any loss associated with the dummy graphs. # Since we padded with `pad_with_graphs` we can recover the mask by using # get_graph_padding_mask. graph_mask = jraph.get_graph_padding_mask(graphs) # Combine the mask over labels with the mask over graphs. return labels_mask & graph_mask[:, None] @jax.jit def train_step( state: train_state.TrainState, graphs: jraph.GraphsTuple, rngs: dict[str, jnp.ndarray], ) -> tuple[train_state.TrainState, metrics.Collection]: """Performs one update step over the current batch of graphs.""" def loss_fn(params, graphs): curr_state = state.replace(params=params) # Extract labels. labels = graphs.globals # Replace the global feature for graph classification. graphs = replace_globals(graphs) # Compute logits and resulting loss. logits = get_predicted_logits(curr_state, graphs, rngs) mask = get_valid_mask(labels, graphs) loss = binary_cross_entropy_with_mask( logits=logits, labels=labels, mask=mask ) mean_loss = jnp.sum(jnp.where(mask, loss, 0)) / jnp.sum(mask) return mean_loss, (loss, logits, labels, mask) grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, (loss, logits, labels, mask)), grads = grad_fn(state.params, graphs) state = state.apply_gradients(grads=grads) metrics_update = TrainMetrics.single_from_model_output( loss=loss, logits=logits, labels=labels, mask=mask ) return state, metrics_update @jax.jit def evaluate_step( state: train_state.TrainState, graphs: jraph.GraphsTuple, ) -> metrics.Collection: """Computes metrics over a set of graphs.""" # The target labels our model has to predict. labels = graphs.globals # Replace the global feature for graph classification. graphs = replace_globals(graphs) # Get predicted logits, and corresponding probabilities. logits = get_predicted_logits(state, graphs, rngs=None) # Get the mask for valid labels and graphs. mask = get_valid_mask(labels, graphs) # Compute the various metrics. loss = binary_cross_entropy_with_mask(logits=logits, labels=labels, mask=mask) return EvalMetrics.single_from_model_output( loss=loss, logits=logits, labels=labels, mask=mask ) def evaluate_model( state: train_state.TrainState, datasets: dict[str, tf.data.Dataset], splits: Iterable[str], ) -> dict[str, metrics.Collection]: """Evaluates the model on metrics over the specified splits.""" # Loop over each split independently. eval_metrics = {} for split in splits: split_metrics = None # Loop over graphs. for graphs in datasets[split].as_numpy_iterator(): split_metrics_update = evaluate_step(state, graphs) # Update metrics. if split_metrics is None: split_metrics = split_metrics_update else: split_metrics = split_metrics.merge(split_metrics_update) eval_metrics[split] = split_metrics return eval_metrics # pytype: disable=bad-return-type def train_and_evaluate( config: ml_collections.ConfigDict, workdir: str ) -> train_state.TrainState: """Execute model training and evaluation loop. Args: config: Hyperparameter configuration for training and evaluation. workdir: Directory where the TensorBoard summaries are written to. Returns: The train state (which includes the `.params`). """ # We only support single-host training. assert jax.process_count() == 1 # Create writer for logs. writer = metric_writers.create_default_writer(workdir) writer.write_hparams(config.to_dict()) # Get datasets, organized by split. logging.info('Obtaining datasets.') datasets = input_pipeline.get_datasets( config.batch_size, add_virtual_node=config.add_virtual_node, add_undirected_edges=config.add_undirected_edges, add_self_loops=config.add_self_loops, ) train_iter = iter(datasets['train']) # Create and initialize the network. logging.info('Initializing network.') rng = jax.random.key(0) rng, init_rng = jax.random.split(rng) init_graphs = next(datasets['train'].as_numpy_iterator()) init_graphs = replace_globals(init_graphs) init_net = create_model(config, deterministic=True) params = jax.jit(init_net.init)(init_rng, init_graphs) parameter_overview.log_parameter_overview(params) # Create the optimizer. tx = create_optimizer(config) # Create the training state. net = create_model(config, deterministic=False) state = train_state.TrainState.create( apply_fn=net.apply, params=params, tx=tx ) # Set up checkpointing of the model. # The input pipeline cannot be checkpointed in its current form, # due to the use of stateful operations. checkpoint_dir = os.path.join(workdir, 'checkpoints') ckpt = checkpoint.Checkpoint(checkpoint_dir, max_to_keep=2) state = ckpt.restore_or_initialize(state) initial_step = int(state.step) + 1 # Create the evaluation state, corresponding to a deterministic model. eval_net = create_model(config, deterministic=True) eval_state = state.replace(apply_fn=eval_net.apply) # Hooks called periodically during training. report_progress = periodic_actions.ReportProgress( num_train_steps=config.num_train_steps, writer=writer ) profiler = periodic_actions.Profile(num_profile_steps=5, logdir=workdir) hooks = [report_progress, profiler] # Begin training loop. logging.info('Starting training.') train_metrics = None for step in range(initial_step, config.num_train_steps + 1): # Split PRNG key, to ensure different 'randomness' for every step. rng, dropout_rng = jax.random.split(rng) # Perform one step of training. with jax.profiler.StepTraceAnnotation('train', step_num=step): graphs = jax.tree_util.tree_map(np.asarray, next(train_iter)) state, metrics_update = train_step( state, graphs, rngs={'dropout': dropout_rng} ) # Update metrics. if train_metrics is None: train_metrics = metrics_update else: train_metrics = train_metrics.merge(metrics_update) # Quick indication that training is happening. logging.log_first_n(logging.INFO, 'Finished training step %d.', 10, step) for hook in hooks: hook(step) # Log, if required. is_last_step = step == config.num_train_steps - 1 if step % config.log_every_steps == 0 or is_last_step: writer.write_scalars( step, add_prefix_to_keys(train_metrics.compute(), 'train') ) train_metrics = None # Evaluate on validation and test splits, if required. if step % config.eval_every_steps == 0 or is_last_step: eval_state = eval_state.replace(params=state.params) splits = ['validation', 'test'] with report_progress.timed('eval'): eval_metrics = evaluate_model(eval_state, datasets, splits=splits) for split in splits: writer.write_scalars( step, add_prefix_to_keys(eval_metrics[split].compute(), split) ) # Checkpoint model, if required. if step % config.checkpoint_every_steps == 0 or is_last_step: with report_progress.timed('checkpoint'): ckpt.save(state) return state ================================================ FILE: examples/ogbg_molpcba/train_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.examples.ogbg_molpcba.train.""" import functools import pathlib import tempfile from typing import Dict, Optional import warnings from absl.testing import absltest from absl.testing import parameterized import flax from flax.training import train_state import jax from jax import numpy as jnp import jraph import tensorflow as tf import tensorflow_datasets as tfds import numpy as np from configs import default from configs import test import input_pipeline import train def average_with_mask(arr: jnp.ndarray, mask: jnp.ndarray): """Returns the average over elements where mask is True.""" arr = jnp.where(mask, arr, 0) return jnp.sum(arr) / jnp.sum(mask) def get_dummy_raw_datasets(dataset_length) -> dict[str, tf.data.Dataset]: """Returns dummy datasets, mocking tfds.DatasetBuilder.as_dataset().""" # The dummy graph. num_nodes = 3 num_edges = 4 dummy_graph = { 'edge_feat': tf.zeros((num_edges, 3), dtype=tf.float32), 'edge_index': tf.zeros((num_edges, 2), dtype=tf.int64), 'labels': tf.ones((128,), dtype=tf.float32), 'node_feat': tf.zeros((num_nodes, 9), dtype=tf.float32), 'num_edges': tf.expand_dims(num_edges, axis=0), 'num_nodes': tf.expand_dims(num_nodes, axis=0), } dummy_graph_spec = { 'edge_feat': tf.TensorSpec(shape=(None, 3), dtype=tf.float32), 'edge_index': tf.TensorSpec(shape=(None, 2), dtype=tf.int64), 'labels': tf.TensorSpec(shape=(128,), dtype=tf.float32), 'node_feat': tf.TensorSpec(shape=(None, 9), dtype=tf.float32), 'num_edges': tf.TensorSpec(shape=(None,), dtype=tf.int64), 'num_nodes': tf.TensorSpec(shape=(None,), dtype=tf.int64), } def get_dummy_graphs(): for _ in range(dataset_length): yield dummy_graph datasets = {} for split in ['train', 'validation', 'test']: datasets[split] = tf.data.Dataset.from_generator( get_dummy_graphs, output_signature=dummy_graph_spec ) return datasets def get_dummy_datasets( dataset_length: int, batch_size: int | None = None ) -> dict[str, tf.data.Dataset]: """Returns dummy datasets, mocking input_pipeline.get_datasets().""" datasets = get_dummy_raw_datasets(dataset_length) # Construct the GraphsTuple converter function. convert_to_graphs_tuple_fn = functools.partial( input_pipeline.convert_to_graphs_tuple, add_virtual_node=True, add_undirected_edges=True, add_self_loops=True, ) # Process each split separately. for split_name in datasets: # Convert to GraphsTuple. datasets[split_name] = datasets[split_name].map( convert_to_graphs_tuple_fn, num_parallel_calls=tf.data.AUTOTUNE, deterministic=True, ) # If batch size is None, do not batch. if batch_size is not None: budget = input_pipeline.estimate_padding_budget_for_batch_size( datasets['train'], batch_size, num_estimation_graphs=1 ) # Pad an example graph to see what the output shapes will be. # We will use this shape information when creating the tf.data.Dataset. example_graph = next(datasets['train'].as_numpy_iterator()) example_padded_graph = jraph.pad_with_graphs(example_graph, *budget) padded_graphs_spec = input_pipeline.specs_from_graphs_tuple( example_padded_graph ) # Batch and pad each split separately. for split, dataset_split in datasets.items(): batching_fn = functools.partial( jraph.dynamically_batch, graphs_tuple_iterator=iter(dataset_split), n_node=budget.n_node, n_edge=budget.n_edge, n_graph=budget.n_graph, ) datasets[split] = tf.data.Dataset.from_generator( batching_fn, output_signature=padded_graphs_spec ) return datasets class OgbgMolpcbaTrainTest(parameterized.TestCase): def setUp(self): super().setUp() # Make sure tf does not allocate gpu memory. tf.config.experimental.set_visible_devices([], 'GPU') # Print the current platform (the default device). platform = jax.local_devices()[0].platform print('Running on platform:', platform.upper()) # Create PRNG keys. self.rng = jax.random.key(0) # Create dummy datasets. self.datasets = get_dummy_datasets(dataset_length=20, batch_size=10) self.raw_datasets = get_dummy_raw_datasets(dataset_length=20) @parameterized.product( probs=[[[0.8, 0.9, 0.3, 0.5]]], labels=[ [[1, 0, 1, 1]], [[1, 0, 1, jnp.nan]], [[1, 0, jnp.nan, jnp.nan]], [[1, jnp.nan, jnp.nan, jnp.nan]], ], ) def test_binary_cross_entropy_loss(self, probs, labels): probs = jnp.asarray(probs) labels = jnp.asarray(labels) logits = jnp.log(probs / (1 - probs)) mask = ~jnp.isnan(labels) loss_array = train.binary_cross_entropy_with_mask( logits=logits, labels=labels, mask=mask ) loss = average_with_mask(loss_array, mask) expected_loss_array = -(jnp.log(probs) * labels) - ( jnp.log(1 - probs) * (1 - labels) ) expected_loss = average_with_mask(expected_loss_array, mask) self.assertAlmostEqual(loss, expected_loss, places=5) @parameterized.named_parameters( dict( testcase_name='no_valid_tasks', logits=[[-1.0, 1.0], [1.0, 1.0], [2.0, -1.0]], labels=[[jnp.nan, jnp.nan], [jnp.nan, jnp.nan], [jnp.nan, jnp.nan]], expected_result=jnp.nan, ), dict( testcase_name='1_valid_task', logits=[[-1.0, 1.0], [1.0, 1.0], [2.0, -1.0]], labels=[[0, jnp.nan], [1, jnp.nan], [1, jnp.nan]], expected_result=1.0, ), dict( testcase_name='2_valid_tasks', logits=[[-1.0, 1.0], [1.0, 1.0], [2.0, -1.0]], labels=[[0, jnp.nan], [1, 0], [1, 1]], expected_result=0.75, ), ) def test_mean_average_precision(self, logits, labels, expected_result): logits = jnp.asarray(logits) labels = jnp.asarray(labels) mask = ~jnp.isnan(labels) mean_average_precision = train.MeanAveragePrecision.from_model_output( logits=logits, labels=labels, mask=mask ).compute() if jnp.isnan(expected_result): self.assertTrue(jnp.isnan(mean_average_precision)) else: self.assertAlmostEqual(expected_result, mean_average_precision) @parameterized.parameters( dict( loss=[[0.5, 1.0], [1.5, 1.3], [2.0, 1.2]], logits=[[-1.0, 1.0], [1.0, 1.0], [2.0, 0.0]], labels=[[0, jnp.nan], [1, 0], [0, 1]], mask=[[True, False], [True, True], [False, False]], expected_results={ 'loss': 1.1, 'accuracy': 2 / 3, 'mean_average_precision': 1.0, }, ), ) def test_eval_metrics(self, loss, logits, labels, mask, expected_results): loss = jnp.asarray(loss) logits = jnp.asarray(logits) labels = jnp.asarray(labels) mask = jnp.asarray(mask) # Ignore RuntimeWarning caused by MeanAveragePrecision calculation. with warnings.catch_warnings(): warnings.simplefilter('ignore', category=RuntimeWarning) eval_metrics = train.EvalMetrics.single_from_model_output( loss=loss, logits=logits, labels=labels, mask=mask ).compute() for metric in expected_results: self.assertAlmostEqual(expected_results[metric], eval_metrics[metric]) @parameterized.parameters( dict( loss=[[0.5, 1.0], [1.5, 1.3], [2.0, 1.2]], logits=[[-1.0, 1.0], [1.0, 1.0], [2.0, 0.0]], labels=[[0, jnp.nan], [1, 0], [0, 1]], mask=[[True, False], [True, True], [False, False]], expected_results={'loss': 1.1, 'accuracy': 2 / 3}, ), ) def test_train_metrics(self, loss, logits, labels, mask, expected_results): loss = jnp.asarray(loss) logits = jnp.asarray(logits) labels = jnp.asarray(labels) mask = jnp.asarray(mask) train_metrics = train.TrainMetrics.single_from_model_output( loss=loss, logits=logits, labels=labels, mask=mask ).compute() for metric in expected_results: self.assertAlmostEqual(expected_results[metric], train_metrics[metric]) def test_train_step(self): # Get the default configuration. config = default.get_config() # Initialize the network with a dummy graph. rng, init_rng = jax.random.split(self.rng) init_graphs = next(self.datasets['train'].as_numpy_iterator()) init_graphs_preprocessed = train.replace_globals(init_graphs) init_net = train.create_model(config, deterministic=True) params = jax.jit(init_net.init)(init_rng, init_graphs_preprocessed) # Create the optimizer. optimizer = train.create_optimizer(config) # Create the training state. net = train.create_model(config, deterministic=False) state = train_state.TrainState.create( apply_fn=net.apply, params=params, tx=optimizer ) # Perform one step of updates. # We use the same batch of graphs that we used for initialization. state, train_metrics = train.train_step( state, init_graphs, rngs={'dropout': rng} ) # Check that none of the parameters are NaNs! params = flax.core.unfreeze(state.params) flat_params = { '/'.join(k): v for k, v in flax.traverse_util.flatten_dict(params).items() } for array in flat_params.values(): self.assertTrue(jnp.all(~jnp.isnan(array))) # Check that the metrics are well defined. train_metrics_vals = train_metrics.compute() self.assertGreaterEqual(train_metrics_vals['loss'], 0) self.assertGreaterEqual(train_metrics_vals['accuracy'], 0) self.assertLessEqual(train_metrics_vals['accuracy'], 1) def test_evaluate_step(self): # Get the default configuration. config = default.get_config() # Initialize the network with a dummy graph. _, init_rng = jax.random.split(self.rng) init_graphs = next(self.datasets['train'].as_numpy_iterator()) init_graphs_preprocessed = init_graphs._replace( globals=jnp.zeros([init_graphs.n_node.shape[0], 1]) ) init_net = train.create_model(config, deterministic=True) params = jax.jit(init_net.init)(init_rng, init_graphs_preprocessed) # Create the optimizer. optimizer = train.create_optimizer(config) # Create the evaluation state. eval_net = train.create_model(config, deterministic=True) eval_state = train_state.TrainState.create( apply_fn=eval_net.apply, params=params, tx=optimizer ) # Perform one step of evaluation. # We use the same batch of graphs that we used for initialization. evaluate_metrics = train.evaluate_step(eval_state, init_graphs) # Check that the metrics are well defined. evaluate_metrics_vals = evaluate_metrics.compute() self.assertGreaterEqual(evaluate_metrics_vals['loss'], 0) self.assertGreaterEqual(evaluate_metrics_vals['accuracy'], 0) self.assertLessEqual(evaluate_metrics_vals['accuracy'], 1) self.assertTrue(np.isnan(evaluate_metrics_vals['mean_average_precision'])) def test_train_and_evaluate(self): # Create a temporary directory where TensorBoard metrics are written. workdir = tempfile.mkdtemp() # Go two directories up to the root of the flax directory. flax_root_dir = pathlib.Path(__file__).parents[2] data_dir = str(flax_root_dir) + '/.tfds/metadata' # pylint: disable=unused-variable # Get the test configuration. config = test.get_config() # Ensure train_and_evaluate() runs without any errors! def as_dataset_fn(*args, **kwargs): del args split = kwargs['split'] return self.raw_datasets[split] with tfds.testing.mock_data(as_dataset_fn=as_dataset_fn): train.train_and_evaluate(config=config, workdir=workdir) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/ppo/README.md ================================================ # Proximal Policy Optimization Uses the Proximal Policy Optimization algorithm ([Schulman et al., 2017](https://arxiv.org/abs/1707.06347)) to learn playing Atari games. ## Requirements This example depends on the `gymnasium[atari,accept-rom-license]`, `opencv-python` packages in addition to `jax` and `flax`. ## Supported setups The example should run with other configurations and hardware, but was explicitly tested on the following: | Hardware | Game | Training time | Total frames seen | TensorBoard.dev | | --- | --- | --- | --- | --- | | 1x V100 GPU | Breakout | 9h 15m 15s | 40M | [2020-10-02](https://tensorboard.dev/experiment/pY7D2qYQQLO9ZT5lA9PFPA) | > **Note** > It is possible to improve training efficiency through further optimizations. For example, CleanRL's PPO ([ppo_atari_envpool_xla_jax_scan.py](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpool_xla_jax_scanpy)) can achieve the same level of results in just 30 minutes with an RTX 2080 TI, 8 CPU, and the same hyperparameters — **a 1850% speedup end-to-end**. It achieves this by using [EnvPool](https://envpool.readthedocs.io/en/latest/), a library for fast parallelizing environments, jitting the entire rollout through [EnvPool's XLA interface](https://envpool.readthedocs.io/en/latest/content/xla_interface.html), storing data more efficiently, and `jax.scan`. ## How to run Running `python ppo_main.py` will run the example with default (hyper)parameters, i.e. for 40M frames on the Pong game. By default logging info and checkpoints will be stored in `/tmp/ppo_training` directory. This can be overridden as follows: ```python ppo_main.py --config=configs/default.py --workdir=/my_fav_directory``` You can also override the default (hyper)parameters, for example ```python ppo_main.py --config=configs/default.py --config.game=Seaquest --config.total_frames=20000000 --config.decaying_lr_and_clip_param=False --workdir=/tmp/seaquest``` will train the model on 20M Seaquest frames with constant (i.e. not linearly decaying) learning rate and PPO clipping parameter. Checkpoints and tensorboard files will be saved in `/tmp/seaquest`. Unit tests can be run using `python ppo_lib_test.py`. ## How to run on Google Cloud TPU It is also possible to run this code on Google Cloud TPU. For detailed instructions on the required setup, please refer to the [WMT example readme](https://github.com/google/flax/tree/main/examples/wmt). ## Owners Jonathan Heek @jheek, Wojciech Rzadkowski @wrzadkow ================================================ FILE: examples/ppo/agent.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Agent utilities, incl. choosing the move and running in separate process.""" import collections import functools import multiprocessing from typing import Any from collections.abc import Callable import flax import jax import numpy as np import env_utils @functools.partial(jax.jit, static_argnums=0) def policy_action( apply_fn: Callable[..., Any], params: flax.core.frozen_dict.FrozenDict, state: np.ndarray, ): """Forward pass of the network. Args: params: the parameters of the actor-critic model module: the actor-critic model state: the input for the forward pass Returns: out: a tuple (log_probabilities, values) """ out = apply_fn({'params': params}, state) return out ExpTuple = collections.namedtuple( 'ExpTuple', ['state', 'action', 'reward', 'value', 'log_prob', 'done'] ) class RemoteSimulator: """Wrap functionality for an agent emulating Atari in a separate process. An object of this class is created for every agent. """ def __init__(self, game: str): """Start the remote process and create Pipe() to communicate with it.""" parent_conn, child_conn = multiprocessing.Pipe() self.proc = multiprocessing.Process( target=rcv_action_send_exp, args=(child_conn, game) ) self.proc.daemon = True self.conn = parent_conn self.proc.start() def rcv_action_send_exp(conn, game: str): """Run the remote agents. Receive action from the main learner, perform one step of simulation and send back collected experience. """ env = env_utils.create_env(game, clip_rewards=True) while True: obs = env.reset() done = False # Observations fetched from Atari env need additional batch dimension. state = obs[None, ...] while not done: conn.send(state) action = conn.recv() obs, reward, done, _ = env.step(action) next_state = obs[None, ...] if not done else None experience = (state, action, reward, done) conn.send(experience) if done: break state = next_state ================================================ FILE: examples/ppo/configs/default.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Definitions of default hyperparameters.""" import ml_collections def get_config(): """Get the default configuration. The default hyperparameters originate from PPO paper arXiv:1707.06347 and openAI baselines 2:: https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py """ config = ml_collections.ConfigDict() # The Atari game used. config.game = 'Pong' # Total number of frames seen during training. config.total_frames = 40000000 # The learning rate for the Adam optimizer. config.learning_rate = 2.5e-4 # Batch size used in training. config.batch_size = 256 # Number of agents playing in parallel. config.num_agents = 8 # Number of steps each agent performs in one policy unroll. config.actor_steps = 128 # Number of training epochs per each unroll of the policy. config.num_epochs = 3 # RL discount parameter. config.gamma = 0.99 # Generalized Advantage Estimation parameter. config.lambda_ = 0.95 # The PPO clipping parameter used to clamp ratios in loss function. config.clip_param = 0.1 # Weight of value function loss in the total loss. config.vf_coeff = 0.5 # Weight of entropy bonus in the total loss. config.entropy_coeff = 0.01 # Linearly decay learning rate and clipping parameter to zero during # the training. config.decaying_lr_and_clip_param = True return config ================================================ FILE: examples/ppo/env_utils.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for handling the Atari environment.""" import collections import gymnasium as gym import numpy as np import seed_rl_atari_preprocessing class ClipRewardEnv(gym.RewardWrapper): """Adapted from OpenAI baselines. github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py """ def __init__(self, env): gym.RewardWrapper.__init__(self, env) def reward(self, reward): """Bin reward to {+1, 0, -1} by its sign.""" return np.sign(reward) class FrameStack: """Implements stacking of `num_frames` last frames of the game. Wraps an AtariPreprocessing object. """ def __init__( self, preproc: seed_rl_atari_preprocessing.AtariPreprocessing, num_frames: int, ): self.preproc = preproc self.num_frames = num_frames self.frames = collections.deque(maxlen=num_frames) def reset(self): ob = self.preproc.reset() for _ in range(self.num_frames): self.frames.append(ob) return self._get_array() def step(self, action: int): ob, reward, done, info = self.preproc.step(action) self.frames.append(ob) return self._get_array(), reward, done, info def _get_array(self): assert len(self.frames) == self.num_frames return np.concatenate(self.frames, axis=-1) def create_env(game: str, clip_rewards: bool): """Create a FrameStack object that serves as environment for the `game`.""" env = gym.make(game) if clip_rewards: env = ClipRewardEnv(env) # bin rewards to {-1., 0., 1.} preproc = seed_rl_atari_preprocessing.AtariPreprocessing(env) stack = FrameStack(preproc, num_frames=4) return stack def get_num_actions(game: str): """Get the number of possible actions of a given Atari game. This determines the number of outputs in the actor part of the actor-critic model. """ env = gym.make(game) return env.action_space.n ================================================ FILE: examples/ppo/models.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Class and functions to define and initialize the actor-critic model.""" from flax import linen as nn import jax.numpy as jnp class ActorCritic(nn.Module): """Class defining the actor-critic model.""" num_outputs: int @nn.compact def __call__(self, x): """Define the convolutional network architecture. Architecture originates from "Human-level control through deep reinforcement learning.", Nature 518, no. 7540 (2015): 529-533. Note that this is different than the one from "Playing atari with deep reinforcement learning." arxiv.org/abs/1312.5602 (2013) Network is used to both estimate policy (logits) and expected state value; in other words, hidden layers' params are shared between policy and value networks, see e.g.: github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py """ dtype = jnp.float32 x = x.astype(dtype) / 255.0 x = nn.Conv( features=32, kernel_size=(8, 8), strides=(4, 4), name='conv1', dtype=dtype, )(x) x = nn.relu(x) x = nn.Conv( features=64, kernel_size=(4, 4), strides=(2, 2), name='conv2', dtype=dtype, )(x) x = nn.relu(x) x = nn.Conv( features=64, kernel_size=(3, 3), strides=(1, 1), name='conv3', dtype=dtype, )(x) x = nn.relu(x) x = x.reshape((x.shape[0], -1)) # flatten x = nn.Dense(features=512, name='hidden', dtype=dtype)(x) x = nn.relu(x) logits = nn.Dense(features=self.num_outputs, name='logits', dtype=dtype)(x) policy_log_probabilities = nn.log_softmax(logits) value = nn.Dense(features=1, name='value', dtype=dtype)(x) return policy_log_probabilities, value ================================================ FILE: examples/ppo/ppo_lib.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Library file which executes the PPO training.""" import functools from typing import Any from collections.abc import Callable from absl import logging import flax from flax import linen as nn import agent import models import test_episodes from flax.metrics import tensorboard from flax.training import checkpoints from flax.training import train_state import jax import jax.numpy as jnp import ml_collections import numpy as np import optax @jax.jit @functools.partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1) def gae_advantages( rewards: np.ndarray, terminal_masks: np.ndarray, values: np.ndarray, discount: float, gae_param: float, ): """Use Generalized Advantage Estimation (GAE) to compute advantages. As defined by eqs. (11-12) in PPO paper arXiv: 1707.06347. Implementation uses key observation that A_{t} = delta_t + gamma*lambda*A_{t+1}. Args: rewards: array shaped (actor_steps, num_agents), rewards from the game terminal_masks: array shaped (actor_steps, num_agents), zeros for terminal and ones for non-terminal states values: array shaped (actor_steps, num_agents), values estimated by critic discount: RL discount usually denoted with gamma gae_param: GAE parameter usually denoted with lambda Returns: advantages: calculated advantages shaped (actor_steps, num_agents) """ assert rewards.shape[0] + 1 == values.shape[0], ( 'One more value needed; Eq. ' '(12) in PPO paper requires ' 'V(s_{t+1}) for delta_t' ) advantages = [] gae = 0.0 for t in reversed(range(len(rewards))): # Masks used to set next state value to 0 for terminal states. value_diff = discount * values[t + 1] * terminal_masks[t] - values[t] delta = rewards[t] + value_diff # Masks[t] used to ensure that values before and after a terminal state # are independent of each other. gae = delta + discount * gae_param * terminal_masks[t] * gae advantages.append(gae) advantages = advantages[::-1] return jnp.array(advantages) def loss_fn( params: flax.core.FrozenDict, apply_fn: Callable[..., Any], minibatch: tuple, clip_param: float, vf_coeff: float, entropy_coeff: float, ): """Evaluate the loss function. Compute loss as a sum of three components: the negative of the PPO clipped surrogate objective, the value function loss and the negative of the entropy bonus. Args: params: the parameters of the actor-critic model apply_fn: the actor-critic model's apply function minibatch: tuple of five elements forming one experience batch: states: shape (batch_size, 84, 84, 4) actions: shape (batch_size, 84, 84, 4) old_log_probs: shape (batch_size,) returns: shape (batch_size,) advantages: shape (batch_size,) clip_param: the PPO clipping parameter used to clamp ratios in loss function vf_coeff: weighs value function loss in total loss entropy_coeff: weighs entropy bonus in the total loss Returns: loss: the PPO loss, scalar quantity """ states, actions, old_log_probs, returns, advantages = minibatch log_probs, values = agent.policy_action(apply_fn, params, states) values = values[:, 0] # Convert shapes: (batch, 1) to (batch, ). probs = jnp.exp(log_probs) value_loss = jnp.mean(jnp.square(returns - values), axis=0) entropy = jnp.sum(-probs * log_probs, axis=1).mean() log_probs_act_taken = jax.vmap(lambda lp, a: lp[a])(log_probs, actions) ratios = jnp.exp(log_probs_act_taken - old_log_probs) # Advantage normalization (following the OpenAI baselines). advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) pg_loss = ratios * advantages clipped_loss = advantages * jax.lax.clamp( 1.0 - clip_param, ratios, 1.0 + clip_param ) ppo_loss = -jnp.mean(jnp.minimum(pg_loss, clipped_loss), axis=0) return ppo_loss + vf_coeff * value_loss - entropy_coeff * entropy @functools.partial(jax.jit, static_argnums=(2,)) def train_step( state: train_state.TrainState, trajectories: tuple, batch_size: int, *, clip_param: float, vf_coeff: float, entropy_coeff: float, ): """Compilable train step. Runs an entire epoch of training (i.e. the loop over minibatches within an epoch is included here for performance reasons). Args: state: the train state trajectories: tuple of the following five elements forming the experience: states: shape (steps_per_agent*num_agents, 84, 84, 4) actions: shape (steps_per_agent*num_agents, 84, 84, 4) old_log_probs: shape (steps_per_agent*num_agents, ) returns: shape (steps_per_agent*num_agents, ) advantages: (steps_per_agent*num_agents, ) batch_size: the minibatch size, static argument clip_param: the PPO clipping parameter used to clamp ratios in loss function vf_coeff: weighs value function loss in total loss entropy_coeff: weighs entropy bonus in the total loss Returns: optimizer: new optimizer after the parameters update loss: loss summed over training steps """ iterations = trajectories[0].shape[0] // batch_size trajectories = jax.tree_util.tree_map( lambda x: x.reshape((iterations, batch_size) + x.shape[1:]), trajectories ) loss = 0.0 for batch in zip(*trajectories): grad_fn = jax.value_and_grad(loss_fn) l, grads = grad_fn( state.params, state.apply_fn, batch, clip_param, vf_coeff, entropy_coeff ) loss += l state = state.apply_gradients(grads=grads) return state, loss def get_experience( state: train_state.TrainState, simulators: list[agent.RemoteSimulator], steps_per_actor: int, ): """Collect experience from agents. Runs `steps_per_actor` time steps of the game for each of the `simulators`. """ all_experience = [] # Range up to steps_per_actor + 1 to get one more value needed for GAE. for _ in range(steps_per_actor + 1): sim_states = [] for sim in simulators: sim_state = sim.conn.recv() sim_states.append(sim_state) sim_states = np.concatenate(sim_states, axis=0) log_probs, values = agent.policy_action( state.apply_fn, state.params, sim_states ) log_probs, values = jax.device_get((log_probs, values)) probs = np.exp(np.array(log_probs)) for i, sim in enumerate(simulators): probabilities = probs[i] action = np.random.choice(probs.shape[1], p=probabilities) sim.conn.send(action) experiences = [] for i, sim in enumerate(simulators): sim_state, action, reward, done = sim.conn.recv() value = values[i, 0] log_prob = log_probs[i][action] sample = agent.ExpTuple(sim_state, action, reward, value, log_prob, done) experiences.append(sample) all_experience.append(experiences) return all_experience def process_experience( experience: list[list[agent.ExpTuple]], actor_steps: int, num_agents: int, gamma: float, lambda_: float, ): """Process experience for training, including advantage estimation. Args: experience: collected from agents in the form of nested lists/namedtuple actor_steps: number of steps each agent has completed num_agents: number of agents that collected experience gamma: dicount parameter lambda_: GAE parameter Returns: trajectories: trajectories readily accessible for `train_step()` function """ obs_shape = (84, 84, 4) exp_dims = (actor_steps, num_agents) values_dims = (actor_steps + 1, num_agents) states = np.zeros(exp_dims + obs_shape, dtype=np.float32) actions = np.zeros(exp_dims, dtype=np.int32) rewards = np.zeros(exp_dims, dtype=np.float32) values = np.zeros(values_dims, dtype=np.float32) log_probs = np.zeros(exp_dims, dtype=np.float32) dones = np.zeros(exp_dims, dtype=np.float32) for t in range(len(experience) - 1): # experience[-1] only for next_values for agent_id, exp_agent in enumerate(experience[t]): states[t, agent_id, ...] = exp_agent.state actions[t, agent_id] = exp_agent.action rewards[t, agent_id] = exp_agent.reward values[t, agent_id] = exp_agent.value log_probs[t, agent_id] = exp_agent.log_prob # Dones need to be 0 for terminal states. dones[t, agent_id] = float(not exp_agent.done) for a in range(num_agents): values[-1, a] = experience[-1][a].value advantages = gae_advantages(rewards, dones, values, gamma, lambda_) returns = advantages + values[:-1, :] # After preprocessing, concatenate data from all agents. trajectories = (states, actions, log_probs, returns, advantages) trajectory_len = num_agents * actor_steps trajectories = tuple( map( lambda x: np.reshape(x, (trajectory_len,) + x.shape[2:]), trajectories ) ) return trajectories @functools.partial(jax.jit, static_argnums=1) def get_initial_params(key: jax.Array, model: nn.Module): input_dims = (1, 84, 84, 4) # (minibatch, height, width, stacked frames) init_shape = jnp.ones(input_dims, jnp.float32) initial_params = model.init(key, init_shape)['params'] return initial_params def create_train_state( params, model: nn.Module, config: ml_collections.ConfigDict, train_steps: int, ) -> train_state.TrainState: if config.decaying_lr_and_clip_param: lr = optax.linear_schedule( init_value=config.learning_rate, end_value=0.0, transition_steps=train_steps, ) else: lr = config.learning_rate tx = optax.adam(lr) state = train_state.TrainState.create( apply_fn=model.apply, params=params, tx=tx ) return state def train( model: models.ActorCritic, config: ml_collections.ConfigDict, model_dir: str ): """Main training loop. Args: model: the actor-critic model config: object holding hyperparameters and the training information model_dir: path to dictionary where checkpoints and logging info are stored Returns: optimizer: the trained optimizer """ game = config.game + 'NoFrameskip-v4' simulators = [agent.RemoteSimulator(game) for _ in range(config.num_agents)] summary_writer = tensorboard.SummaryWriter(model_dir) summary_writer.hparams(dict(config)) loop_steps = config.total_frames // (config.num_agents * config.actor_steps) log_frequency = 40 checkpoint_frequency = 500 # train_step does multiple steps per call for better performance # compute number of steps per call here to convert between the number of # train steps and the inner number of optimizer steps iterations_per_step = ( config.num_agents * config.actor_steps // config.batch_size ) initial_params = get_initial_params(jax.random.key(0), model) state = create_train_state( initial_params, model, config, loop_steps * config.num_epochs * iterations_per_step, ) del initial_params state = checkpoints.restore_checkpoint(model_dir, state) # number of train iterations done by each train_step start_step = int(state.step) // config.num_epochs // iterations_per_step logging.info('Start training from step: %s', start_step) for step in range(start_step, loop_steps): # Bookkeeping and testing. if step % log_frequency == 0: score = test_episodes.policy_test(1, state.apply_fn, state.params, game) frames = step * config.num_agents * config.actor_steps summary_writer.scalar('game_score', score, frames) logging.info( 'Step %s:\nframes seen %s\nscore %s\n\n', step, frames, score ) # Core training code. alpha = ( 1.0 - step / loop_steps if config.decaying_lr_and_clip_param else 1.0 ) all_experiences = get_experience(state, simulators, config.actor_steps) trajectories = process_experience( all_experiences, config.actor_steps, config.num_agents, config.gamma, config.lambda_, ) clip_param = config.clip_param * alpha for _ in range(config.num_epochs): permutation = np.random.permutation( config.num_agents * config.actor_steps ) trajectories = tuple(x[permutation] for x in trajectories) state, _ = train_step( state, trajectories, config.batch_size, clip_param=clip_param, vf_coeff=config.vf_coeff, entropy_coeff=config.entropy_coeff, ) if (step + 1) % checkpoint_frequency == 0: checkpoints.save_checkpoint(model_dir, state, step + 1) return state ================================================ FILE: examples/ppo/ppo_lib_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Unit tests for the PPO example.""" from absl.testing import absltest from flax.training import train_state import jax import ml_collections import numpy as np import numpy.testing as np_testing import agent import env_utils import models import ppo_lib import gymnasium as gym import ale_py gym.register_envs(ale_py) # test GAE class TestGAE(absltest.TestCase): def test_gae_shape_on_random(self): # create random data, simulating 4 parallel envs and 20 time_steps envs, steps = 10, 100 rewards = np.random.choice( [-1.0, 0.0, 1.0], size=(steps, envs), p=[0.01, 0.98, 0.01] ) terminal_masks = np.ones(shape=(steps, envs), dtype=np.float64) values = np.random.random(size=(steps + 1, envs)) discount = 0.99 gae_param = 0.95 adv = ppo_lib.gae_advantages( rewards, terminal_masks, values, discount, gae_param ) self.assertEqual(adv.shape, (steps, envs)) def test_gae_hardcoded(self): # test on small example that can be verified by hand rewards = np.array([[1.0, 0.0], [0.0, 0.0], [-1.0, 1.0]]) # one of the two episodes terminated in the middle terminal_masks = np.array([[1.0, 1.0], [0.0, 1.0], [1.0, 1.0]]) values = np.array([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]) discount = 0.5 gae_param = 0.25 correct_gae = np.array([[0.375, -0.5546875], [-1.0, -0.4375], [-1.5, 0.5]]) actual_gae = ppo_lib.gae_advantages( rewards, terminal_masks, values, discount, gae_param ) np_testing.assert_allclose(actual_gae, correct_gae) # test environment and preprocessing class TestEnvironmentPreprocessing(absltest.TestCase): def choose_random_game(self): games = [ 'BeamRider', 'Breakout', 'Pong', 'Qbert', 'Seaquest', 'SpaceInvaders', ] ind = np.random.choice(len(games)) return games[ind] + 'NoFrameskip-v4' def test_creation(self): frame_shape = (84, 84, 4) game = self.choose_random_game() env = env_utils.create_env(game, clip_rewards=True) obs = env.reset() self.assertEqual(obs.shape, frame_shape) def test_step(self): frame_shape = (84, 84, 4) game = self.choose_random_game() env = env_utils.create_env(game, clip_rewards=True) obs = env.reset() actions = [1, 2, 3, 0] for a in actions: obs, reward, done, info = env.step(a) self.assertEqual(obs.shape, frame_shape) self.assertTrue(reward <= 1.0 and reward >= -1.0) self.assertTrue(isinstance(done, bool)) self.assertTrue(isinstance(info, dict)) # test the model (creation and forward pass) class TestModel(absltest.TestCase): def choose_random_outputs(self): return np.random.choice([4, 5, 6, 7, 8, 9]) def test_model(self): outputs = self.choose_random_outputs() module = models.ActorCritic(num_outputs=outputs) params = ppo_lib.get_initial_params(jax.random.key(0), module) test_batch_size, obs_shape = 10, (84, 84, 4) random_input = np.random.random(size=(test_batch_size,) + obs_shape) log_probs, values = agent.policy_action(module.apply, params, random_input) self.assertEqual(values.shape, (test_batch_size, 1)) sum_probs = np.sum(np.exp(log_probs), axis=1) self.assertEqual(sum_probs.shape, (test_batch_size,)) np_testing.assert_allclose( sum_probs, np.ones((test_batch_size,)), atol=1e-6 ) # test one optimization step class TestOptimizationStep(absltest.TestCase): def generate_random_data(self, num_actions): data_len = 256 # equal to one default-sized batch state_shape = (84, 84, 4) states = np.random.randint(0, 255, size=((data_len,) + state_shape)) actions = np.random.choice(num_actions, size=data_len) old_log_probs = np.random.random(size=data_len) returns = np.random.random(size=data_len) advantages = np.random.random(size=data_len) return states, actions, old_log_probs, returns, advantages def test_optimization_step(self): num_outputs = 4 trn_data = self.generate_random_data(num_actions=num_outputs) clip_param = 0.1 vf_coeff = 0.5 entropy_coeff = 0.01 batch_size = 256 module = models.ActorCritic(num_outputs) initial_params = ppo_lib.get_initial_params(jax.random.key(0), module) config = ml_collections.ConfigDict({ 'learning_rate': 2.5e-4, 'decaying_lr_and_clip_param': True, }) state = ppo_lib.create_train_state(initial_params, module, config, 1000) state, _ = ppo_lib.train_step( state, trn_data, batch_size, clip_param=clip_param, vf_coeff=vf_coeff, entropy_coeff=entropy_coeff, ) self.assertIsInstance(state, train_state.TrainState) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/ppo/ppo_main.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # See issue #620. # pytype: disable=wrong-keyword-args from absl import app from absl import flags from ml_collections import config_flags import tensorflow as tf import env_utils import models import ppo_lib import gymnasium as gym import ale_py gym.register_envs(ale_py) FLAGS = flags.FLAGS flags.DEFINE_string( 'workdir', default='/tmp/ppo_training', help='Directory to save checkpoints and logging info.', ) config_flags.DEFINE_config_file( 'config', 'configs/default.py', 'File path to the default configuration file.', lock_config=True, ) def main(argv): # Make sure tf does not allocate gpu memory. tf.config.experimental.set_visible_devices([], 'GPU') config = FLAGS.config game = config.game + 'NoFrameskip-v4' num_actions = env_utils.get_num_actions(game) print(f'Playing {game} with {num_actions} actions') model = models.ActorCritic(num_outputs=num_actions) ppo_lib.train(model, config, FLAGS.workdir) if __name__ == '__main__': app.run(main) ================================================ FILE: examples/ppo/requirements.txt ================================================ absl-py==1.0.0 atari-py==0.2.5 opencv-python==4.5.4.60 flax==0.3.6 gym==0.18.3 gymnasium[atari, accept-rom-license]==0.29.0 ml-collections==0.1.0 numpy==1.22.0 optax==0.1.5 tensorflow==2.11.1 ================================================ FILE: examples/ppo/seed_rl_atari_preprocessing.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # coding=utf-8 # Copyright 2019 The SEED Authors # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """A class implementing minimal Atari 2600 preprocessing. Adapted from SEED RL, originally adapted from Dopamine. """ import cv2 import gymnasium as gym from gymnasium.spaces.box import Box import numpy as np class AtariPreprocessing: """A class implementing image preprocessing for Atari 2600 agents. Specifically, this provides the following subset from the JAIR paper (Bellemare et al., 2013) and Nature DQN paper (Mnih et al., 2015): * Frame skipping (defaults to 4). * Terminal signal when a life is lost (off by default). * Grayscale and max-pooling of the last two frames. * Downsample the screen to a square image (defaults to 84x84). More generally, this class follows the preprocessing guidelines set down in Machado et al. (2018), "Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents". It also provides random starting no-ops, which are used in the Rainbow, Apex and R2D2 papers. """ def __init__( self, environment: gym.Env, frame_skip=4, terminal_on_life_loss=False, screen_size=84, max_random_noops=0, ): """Constructor for an Atari 2600 preprocessor. Args: environment: Gym environment whose observations are preprocessed. frame_skip: int, the frequency at which the agent experiences the game. terminal_on_life_loss: bool, If True, the step() method returns is_terminal=True whenever a life is lost. See Mnih et al. 2015. screen_size: int, size of a resized Atari 2600 frame. max_random_noops: int, maximum number of no-ops to apply at the beginning of each episode to reduce determinism. These no-ops are applied at a low-level, before frame skipping. Raises: ValueError: if frame_skip or screen_size are not strictly positive. """ if frame_skip <= 0: raise ValueError( f'Frame skip should be strictly positive, got {frame_skip}' ) if screen_size <= 0: raise ValueError( 'Target screen size should be strictly positive, got {}'.format( screen_size ) ) self.environment = environment self.terminal_on_life_loss = terminal_on_life_loss self.frame_skip = frame_skip self.screen_size = screen_size self.max_random_noops = max_random_noops obs_dims = self.environment.observation_space # Stores temporary observations used for pooling over two successive # frames. self.screen_buffer = [ np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8), np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8), ] self.game_over = False self.lives = 0 # Will need to be set by reset(). @property def observation_space(self): # Return the observation space adjusted to match the shape of the processed # observations. return Box( low=0, high=255, shape=(self.screen_size, self.screen_size, 1), dtype=np.uint8, ) @property def action_space(self): return self.environment.action_space @property def reward_range(self): return self.environment.reward_range # type: ignore @property def metadata(self): return self.environment.metadata def close(self): return self.environment.close() def apply_random_noops(self): """Steps self.environment with random no-ops.""" if self.max_random_noops <= 0: return # Other no-ops implementations actually always do at least 1 no-op. We # follow them. no_ops = self.environment.np_random.randint(1, self.max_random_noops + 1) for _ in range(no_ops): _, _, game_over, _, _ = self.environment.step(0) if game_over: self.environment.reset() def reset(self): """Resets the environment. Returns: observation: numpy array, the initial observation emitted by the environment. """ self.environment.reset() self.apply_random_noops() self.lives = self.environment.unwrapped.ale.lives() # pytype:disable=attribute-error self._fetch_grayscale_observation(self.screen_buffer[0]) self.screen_buffer[1].fill(0) return self._pool_and_resize() def render(self, mode): """Renders the current screen, before preprocessing. This calls the Gym API's render() method. Args: mode: Mode argument for the environment's render() method. Valid values (str) are: 'rgb_array': returns the raw ALE image. 'human': renders to display via the Gym renderer. Returns: if mode='rgb_array': numpy array, the most recent screen. if mode='human': bool, whether the rendering was successful. """ return self.environment.render(mode) # pytype:disable=wrong-arg-count def step(self, action): """Applies the given action in the environment. Remarks: * If a terminal state (from life loss or episode end) is reached, this may execute fewer than self.frame_skip steps in the environment. * Furthermore, in this case the returned observation may not contain valid image data and should be ignored. Args: action: The action to be executed. Returns: observation: numpy array, the observation following the action. reward: float, the reward following the action. is_terminal: bool, whether the environment has reached a terminal state. This is true when a life is lost and terminal_on_life_loss, or when the episode is over. info: Gym API's info data structure. """ accumulated_reward = 0.0 for time_step in range(self.frame_skip): # We bypass the Gym observation altogether and directly fetch the # grayscale image from the ALE. This is a little faster. _, reward, game_over, _, info = self.environment.step(action) accumulated_reward += float(reward) if self.terminal_on_life_loss: new_lives = self.environment.unwrapped.ale.lives() # pytype:disable=attribute-error is_terminal = game_over or new_lives < self.lives self.lives = new_lives else: is_terminal = game_over if is_terminal: break # We max-pool over the last two frames, in grayscale. elif time_step >= self.frame_skip - 2: t = time_step - (self.frame_skip - 2) self._fetch_grayscale_observation(self.screen_buffer[t]) # Pool the last two observations. observation = self._pool_and_resize() self.game_over = game_over return observation, accumulated_reward, is_terminal, info def _fetch_grayscale_observation(self, output): """Returns the current observation in grayscale. The returned observation is stored in 'output'. Args: output: numpy array, screen buffer to hold the returned observation. Returns: observation: numpy array, the current observation in grayscale. """ self.environment.unwrapped.ale.getScreenGrayscale(output) # pytype:disable=attribute-error return output def _pool_and_resize(self): """Transforms two frames into a Nature DQN observation. For efficiency, the transformation is done in-place in self.screen_buffer. Returns: transformed_screen: numpy array, pooled, resized screen. """ # Pool if there are enough screens to do so. if self.frame_skip > 1: np.maximum( self.screen_buffer[0], self.screen_buffer[1], out=self.screen_buffer[0], ) transformed_image = cv2.resize( self.screen_buffer[0], (self.screen_size, self.screen_size), interpolation=cv2.INTER_LINEAR, ) int_image = np.asarray(transformed_image, dtype=np.uint8) return np.expand_dims(int_image, axis=2) ================================================ FILE: examples/ppo/test_episodes.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test policy by playing a full Atari game.""" import itertools from typing import Any from collections.abc import Callable import flax import numpy as np import agent import env_utils def policy_test( n_episodes: int, apply_fn: Callable[..., Any], params: flax.core.frozen_dict.FrozenDict, game: str, ): """Perform a test of the policy in Atari environment. Args: n_episodes: number of full Atari episodes to test on apply_fn: the actor-critic apply function params: actor-critic model parameters, they define the policy being tested game: defines the Atari game to test on Returns: total_reward: obtained score """ test_env = env_utils.create_env(game, clip_rewards=False) for _ in range(n_episodes): obs = test_env.reset() state = obs[None, ...] # add batch dimension total_reward = 0.0 for t in itertools.count(): log_probs, _ = agent.policy_action(apply_fn, params, state) probs = np.exp(np.array(log_probs, dtype=np.float32)) probabilities = probs[0] / probs[0].sum() action = np.random.choice(probs.shape[1], p=probabilities) obs, reward, done, _ = test_env.step(action) total_reward += reward next_state = obs[None, ...] if not done else None state = next_state if done: break return total_reward ================================================ FILE: examples/seq2seq/README.md ================================================ ## seq2seq addition This example trains a simple LSTM on a sequence-to-sequence addition task using an encoder-decoder architecture. The data is generated on the fly. Colab lets you edit the source files and interact with the model: https://colab.research.google.com/github/google/flax/blob/main/examples/seq2seq/seq2seq.ipynb ### Example output From Colab run that also generated [tfhub.dev] ``` INFO:absl:[1800] accuracy=1.0, loss=0.0020284138154238462 INFO:absl:DECODE: 14+381 = 395 (CORRECT) INFO:absl:DECODE: 68+91 = 159 (CORRECT) INFO:absl:DECODE: 0+807 = 707 (INCORRECT) correct=807 INFO:absl:DECODE: 95+532 = 627 (CORRECT) INFO:absl:DECODE: 6+600 = 606 (CORRECT) ``` [tfhub.dev]: https://tensorboard.dev/experiment/TwvKVBqzTaKWgEbyebillw/#scalars&_smoothingWeight=0 ### How to run `python train.py` The total runtime for 1200 steps on CPU (3.5GHz Intel Core i7, 16GB memory) is about 4 minutes. ================================================ FILE: examples/seq2seq/configs/default.py ================================================ # Copyright 2025 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Default Hyperparameter configuration.""" import ml_collections def get_config(): """Get the default hyperparameter configuration.""" config = ml_collections.ConfigDict() config.workdir = '/tmp/seq2seq' config.learning_rate = 0.003 config.batch_size = 128 config.hidden_size = 512 config.num_train_steps = 10000 config.decode_frequency = 200 config.max_len_query_digit = 3 return config ================================================ FILE: examples/seq2seq/input_pipeline.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Input pipeline for seq2seq addition example.""" import random from typing import Any, Dict, Optional, Tuple from collections.abc import Generator import jax.numpy as jnp import numpy as np Array = Any # pylint: disable=invalid-name class CharacterTable: """Encodes/decodes between strings and integer representations.""" def __init__(self, chars: str, max_len_query_digit: int = 3) -> None: self._chars = sorted(set(chars)) self._char_indices = {ch: idx + 2 for idx, ch in enumerate(self._chars)} self._indices_char = {idx + 2: ch for idx, ch in enumerate(self._chars)} self._indices_char[self.pad_id] = '_' # Maximum length of a single input digit. self._max_len_query_digit = max_len_query_digit @property def pad_id(self) -> int: return 0 @property def eos_id(self) -> int: return 1 @property def vocab_size(self) -> int: # All characters + pad token and eos token. return len(self._chars) + 2 @property def max_input_len(self) -> int: """Returns the max length of an input sequence.""" # The input has the form "digit1+digit2", so the max input length is # the length of two digits plus two tokens for "+" and the EOS token. return self._max_len_query_digit * 2 + 2 @property def max_output_len(self) -> int: """Returns the max length of an output sequence.""" # The output has the form "=digit". If `digit` is the result of adding # two digits of max length x, then max length of `digit` is x+1. # Additionally, we require two more tokens for "=" and " tuple[int, int, int]: return (1, self.max_input_len, self.vocab_size) @property def decoder_input_shape(self) -> tuple[int, int, int]: return (1, self.max_output_len, self.vocab_size) def encode(self, inputs: str) -> np.ndarray: """Encodes from string to list of integers.""" return np.array( [self._char_indices[char] for char in inputs] + [self.eos_id] ) def decode(self, inputs: Array) -> str: """Decodes from list of integers to string.""" chars = [] for elem in inputs.tolist(): if elem == self.eos_id: break chars.append(self._indices_char[elem]) return ''.join(chars) def one_hot(self, tokens: np.ndarray) -> np.ndarray: vecs = np.zeros((tokens.size, self.vocab_size), dtype=np.float32) vecs[np.arange(tokens.size), tokens] = 1 return vecs def encode_onehot( self, batch_inputs: Array, max_len: int | None = None ) -> np.ndarray: """One-hot encodes a string input.""" if max_len is None: max_len = self.max_input_len def encode_str(s): tokens = self.encode(s) unpadded_len = len(tokens) if unpadded_len > max_len: raise ValueError(f"Sequence too long ({len(tokens)}>{max_len}): '{s}'") tokens = np.pad(tokens, [(0, max_len - len(tokens))], mode='constant') return self.one_hot(tokens) return np.array([encode_str(inp) for inp in batch_inputs]) def decode_onehot(self, batch_inputs: Array) -> np.ndarray: """Decodes a batch of one-hot encoding to strings.""" decode_inputs = lambda inputs: self.decode(inputs.argmax(axis=-1)) return np.array(list(map(decode_inputs, batch_inputs))) def generate_examples( self, num_examples: int ) -> Generator[tuple[str, str], None, None]: """Yields `num_examples` examples.""" for _ in range(num_examples): max_digit = pow(10, self._max_len_query_digit) - 1 # TODO(marcvanzee): Use jax.random here. key = tuple(sorted((random.randint(0, 99), random.randint(0, max_digit)))) inputs = f'{key[0]}+{key[1]}' # Preprend output by the decoder's start token. outputs = '=' + str(key[0] + key[1]) yield (inputs, outputs) def get_batch(self, batch_size: int) -> dict[str, np.ndarray]: """Returns a batch of example of size @batch_size.""" inputs, outputs = zip(*self.generate_examples(batch_size)) return { 'query': self.encode_onehot(inputs), 'answer': self.encode_onehot(outputs), } def mask_sequences(sequence_batch: Array, lengths: Array) -> Array: """Sets positions beyond the length of each sequence to 0.""" return sequence_batch * ( lengths[:, np.newaxis] > np.arange(sequence_batch.shape[1])[np.newaxis] ) def get_sequence_lengths(sequence_batch: Array, eos_id: int) -> Array: """Returns the length of each one-hot sequence, including the EOS token.""" # sequence_batch.shape = (batch_size, seq_length, vocab_size) eos_row = sequence_batch[:, :, eos_id] eos_idx = jnp.argmax(eos_row, axis=-1) # returns first occurrence # `eos_idx` is 0 if EOS is not present, so we use full length in that case. return jnp.where( eos_row[jnp.arange(eos_row.shape[0]), eos_idx], eos_idx + 1, sequence_batch.shape[1], # if there is no EOS, use full length ) ================================================ FILE: examples/seq2seq/main.py ================================================ # Copyright 2025 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Main script for seq2seq example.""" from absl import app from absl import flags from absl import logging import train from ml_collections import config_flags FLAGS = flags.FLAGS config_flags.DEFINE_config_file( 'config', None, 'File path to the training hyperparameter configuration.', lock_config=True, ) def main(argv): del argv config = FLAGS.config # Set train.FLAGS values from config train.FLAGS.workdir = config.workdir train.FLAGS.learning_rate = config.learning_rate train.FLAGS.batch_size = config.batch_size train.FLAGS.hidden_size = config.hidden_size train.FLAGS.num_train_steps = config.num_train_steps train.FLAGS.decode_frequency = config.decode_frequency train.FLAGS.max_len_query_digit = config.max_len_query_digit logging.info('Starting training with config: %s', config) _ = train.train_and_evaluate(train.FLAGS.workdir) if __name__ == '__main__': app.run(main) ================================================ FILE: examples/seq2seq/models.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """seq2seq example: Mode code.""" # See issue #620. # pytype: disable=wrong-keyword-args from typing import Tuple from flax import linen as nn import jax import jax.numpy as jnp Array = jax.Array PRNGKey = jax.Array LSTMCarry = tuple[Array, Array] class DecoderLSTMCell(nn.RNNCellBase): """DecoderLSTM Module wrapped in a lifted scan transform. Attributes: teacher_force: See docstring on Seq2seq module. vocab_size: Size of the vocabulary. """ features: int teacher_force: bool vocab_size: int @nn.compact def __call__( self, carry: tuple[LSTMCarry, Array], x: Array ) -> tuple[tuple[LSTMCarry, Array], tuple[Array, Array]]: """Applies the DecoderLSTM model.""" lstm_state, last_prediction = carry if not self.teacher_force: x = last_prediction lstm_state, y = nn.LSTMCell(self.features)(lstm_state, x) logits = nn.Dense(features=self.vocab_size)(y) # Sample the predicted token using a categorical distribution over the # logits. categorical_rng = self.make_rng('lstm') predicted_token = jax.random.categorical(categorical_rng, logits) # Convert to one-hot encoding. prediction = jax.nn.one_hot( predicted_token, self.vocab_size, dtype=jnp.float32 ) return (lstm_state, prediction), (logits, prediction) @property def num_feature_axes(self) -> int: return 1 class Seq2seq(nn.Module): """Sequence-to-sequence class using encoder/decoder architecture. Attributes: teacher_force: whether to use `decoder_inputs` as input to the decoder at every step. If False, only the first input (i.e., the "=" token) is used, followed by samples taken from the previous output logits. hidden_size: int, the number of hidden dimensions in the encoder and decoder LSTMs. vocab_size: the size of the vocabulary. eos_id: EOS id. """ teacher_force: bool hidden_size: int vocab_size: int eos_id: int = 1 @nn.compact def __call__( self, encoder_inputs: Array, decoder_inputs: Array ) -> tuple[Array, Array]: """Applies the seq2seq model. Args: encoder_inputs: [batch_size, max_input_length, vocab_size]. padded batch of input sequences to encode. decoder_inputs: [batch_size, max_output_length, vocab_size]. padded batch of expected decoded sequences for teacher forcing. When sampling (i.e., `teacher_force = False`), only the first token is input into the decoder (which is the token "="), and samples are used for the following inputs. The second dimension of this tensor determines how many steps will be decoded, regardless of the value of `teacher_force`. Returns: Pair (logits, predictions), which are two arrays of length `batch_size` containing respectively decoded logits and predictions (in one hot encoding format). """ # Encode inputs. encoder = nn.RNN( nn.LSTMCell(self.hidden_size), return_carry=True, name='encoder' ) decoder = nn.RNN( DecoderLSTMCell( decoder_inputs.shape[-1], self.teacher_force, self.vocab_size ), split_rngs={'params': False, 'lstm': True}, name='decoder', ) seq_lengths = self.get_seq_lengths(encoder_inputs) encoder_state, _ = encoder(encoder_inputs, seq_lengths=seq_lengths) logits, predictions = decoder( decoder_inputs[:, :-1], initial_carry=(encoder_state, decoder_inputs[:, 0]), ) return logits, predictions def get_seq_lengths(self, inputs: Array) -> Array: """Get segmentation mask for inputs.""" # undo one-hot encoding inputs = jnp.argmax(inputs, axis=-1) # calculate sequence lengths seq_lengths = jnp.argmax(inputs == self.eos_id, axis=-1) return seq_lengths ================================================ FILE: examples/seq2seq/requirements.txt ================================================ absl-py==1.0.0 clu==0.0.6 flax==0.3.6 numpy==1.22.0 optax==0.1.0 ================================================ FILE: examples/seq2seq/seq2seq.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Flax seq2seq Example\n", "\n", "\"Open\n", "\n", "Demonstration notebook for\n", "https://github.com/google/flax/tree/main/examples/seq2seq\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The **Flax Notebook Workflow**:\n", "\n", "1. Run the entire notebook end-to-end and check out the outputs.\n", " - This will open Python files in the right-hand editor!\n", " - You'll be able to interactively explore metrics in TensorBoard.\n", "2. Change some of the hyperparameters in the command-line flags in `train.py` for different hyperparameters. Check out the updated TensorBoard plots.\n", "3. Update the code in `train.py`, `models.py`, and `input_pipeline.py`. \n", " Thanks to `%autoreload`, any changes you make in the file will \n", " automatically appear in the notebook. Some ideas to get you started:\n", " - Change the model.\n", " - Log some per-batch metrics during training.\n", " - Add new hyperparameters to `models.py` and use them in `train.py`.\n", " - Train on a different vocabulary by initializing `CharacterTable` with a\n", " different character set.\n", "4. At any time, feel free to paste code from the source code into the notebook\n", " and modify it directly there!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "outputId": "4c0a705c-8d7e-44cc-d851-873a40ac115e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[K |████████████████████████████████| 77 kB 3.1 MB/s \n", "\u001b[K |████████████████████████████████| 176 kB 30.2 MB/s \n", "\u001b[K |████████████████████████████████| 77 kB 5.2 MB/s \n", "\u001b[K |████████████████████████████████| 136 kB 45.5 MB/s \n", "\u001b[K |████████████████████████████████| 65 kB 2.8 MB/s \n", "\u001b[K |████████████████████████████████| 462 kB 44.3 MB/s \n", "\u001b[?25h Building wheel for ml-collections (setup.py) ... \u001b[?25l\u001b[?25hdone\n" ] } ], "source": [ "# Install CLU & Flax.\n", "!pip install -q clu flax" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "example_directory = 'examples/seq2seq'\n", "editor_relpaths = ('train.py', 'input_pipeline.py', 'models.py')\n", "\n", "repo, branch = 'https://github.com/google/flax', 'main'" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "outputId": "4801432e-4090-4b13-f0f2-d99a3039ce47" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cloning into 'flaxrepo'...\n", "remote: Enumerating objects: 349, done.\u001b[K\n", "remote: Counting objects: 100% (349/349), done.\u001b[K\n", "remote: Compressing objects: 100% (286/286), done.\u001b[K\n", "remote: Total 349 (delta 63), reused 220 (delta 51), pack-reused 0\u001b[K\n", "Receiving objects: 100% (349/349), 2.12 MiB | 13.39 MiB/s, done.\n", "Resolving deltas: 100% (63/63), done.\n" ] }, { "data": { "text/html": [ "

WARNING : Editing in VM - changes lost after reboot!!

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", " ((filepath) => {{\n", " if (!google.colab.kernel.accessAllowed) {{\n", " return;\n", " }}\n", " google.colab.files.view(filepath);\n", " }})(\"/content/examples/seq2seq/train.py\")" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", " ((filepath) => {{\n", " if (!google.colab.kernel.accessAllowed) {{\n", " return;\n", " }}\n", " google.colab.files.view(filepath);\n", " }})(\"/content/examples/seq2seq/input_pipeline.py\")" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", " ((filepath) => {{\n", " if (!google.colab.kernel.accessAllowed) {{\n", " return;\n", " }}\n", " google.colab.files.view(filepath);\n", " }})(\"/content/examples/seq2seq/models.py\")" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# (If you run this code in Jupyter[lab], then you're already in the\n", "# example directory and nothing needs to be done.)\n", "\n", "#@markdown **Fetch newest Flax, copy example code**\n", "#@markdown\n", "#@markdown **If you select no** below, then the files will be stored on the\n", "#@markdown *ephemeral* Colab VM. **After some time of inactivity, this VM will\n", "#@markdown be restarted and any changes are lost**.\n", "#@markdown\n", "#@markdown **If you select yes** below, then you will be asked for your\n", "#@markdown credentials to mount your personal Google Drive. In this case, all\n", "#@markdown changes you make will be *persisted*, and even if you re-run the\n", "#@markdown Colab later on, the files will still be the same (you can of course\n", "#@markdown remove directories inside your Drive's `flax/` root if you want to\n", "#@markdown manually revert these files).\n", "\n", "if 'google.colab' in str(get_ipython()):\n", " import os\n", " os.chdir('/content')\n", " # Download Flax repo from Github.\n", " if not os.path.isdir('flaxrepo'):\n", " !git clone --depth=1 -b $branch $repo flaxrepo\n", " # Copy example files & change directory.\n", " mount_gdrive = 'no' #@param ['yes', 'no']\n", " if mount_gdrive == 'yes':\n", " DISCLAIMER = 'Note : Editing in your Google Drive, changes will persist.'\n", " from google.colab import drive\n", " drive.mount('/content/gdrive')\n", " example_root_path = f'/content/gdrive/My Drive/flax/{example_directory}'\n", " else:\n", " DISCLAIMER = 'WARNING : Editing in VM - changes lost after reboot!!'\n", " example_root_path = f'/content/{example_directory}'\n", " from IPython import display\n", " display.display(display.HTML(\n", " f'

{DISCLAIMER}

'))\n", " if not os.path.isdir(example_root_path):\n", " os.makedirs(example_root_path)\n", " !cp -r flaxrepo/$example_directory/* \"$example_root_path\"\n", " os.chdir(example_root_path)\n", " from google.colab import files\n", " for relpath in editor_relpaths:\n", " s = open(f'{example_root_path}/{relpath}').read()\n", " open(f'{example_root_path}/{relpath}', 'w').write(\n", " f'## {DISCLAIMER}\\n' + '#' * (len(DISCLAIMER) + 3) + '\\n\\n' + s)\n", " files.view(f'{example_root_path}/{relpath}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "outputId": "a292a7a2-ae3c-4518-af28-9c2fa0ed2d7b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/content/examples/seq2seq\n" ] } ], "source": [ "# Note : In Colab, above cell changed the working directory.\n", "!pwd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from absl import app\n", "app.parse_flags_with_usage(['seq2seq'])\n", "\n", "from absl import logging\n", "logging.set_verbosity(logging.INFO)\n", "\n", "import jax" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "outputId": "7e1a29ce-9d8b-4715-ce60-9eae100a1df3", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The autoreload extension is already loaded. To reload it, use:\n", " %reload_ext autoreload\n" ] } ], "source": [ "# Local imports from current directory - auto reload.\n", "# Any changes you make to the three imported files will appear automatically.\n", "%load_ext autoreload\n", "%autoreload 2\n", "import input_pipeline\n", "import models\n", "import train" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "outputId": "cb5f7f6e-1e6f-40ff-e0d6-5b428511d75b" }, "outputs": [ { "data": { "text/plain": [ "[('72+789', '=861'),\n", " ('58+858', '=916'),\n", " ('77+358', '=435'),\n", " ('99+264', '=363'),\n", " ('94+115', '=209')]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Examples are generated on the fly.\n", "ctable = input_pipeline.CharacterTable('0123456789+= ')\n", "list(ctable.generate_examples(5))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "outputId": "b58ea813-e757-4cc5-f3ba-3cb0f05d35a6" }, "outputs": [ { "data": { "text/plain": [ "array([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],\n", " dtype=float32)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch = ctable.get_batch(5)\n", "# A single query (/answer) is one-hot encoded.\n", "batch['query'][0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "outputId": "3b33e061-f0b5-42d7-ad49-5058e8fd3b90" }, "outputs": [ { "data": { "text/plain": [ "array(['1+243'], dtype=' and Privacy Policy\n", ", and TensorBoard.dev's Terms of Service\n", ".\n", "\n", "This notice will not be shown again while you are logged into the uploader.\n", "To log out, run `tensorboard dev auth revoke`.\n", "\n", "Continue? (yes/NO) yes\n", "\n", "Please visit this URL to authorize this application: https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=373649185512-8v619h5kft38l4456nm2dj4ubeqsrvh6.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=openid+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email&state=IjociK9llsm6dSiC1TDFvFmJksFy49&prompt=consent&access_type=offline\n", "Enter the authorization code: 4/1AX4XfWi6J9MqoDpbZ5Z_jd1AVheW7277VuUoTOEz5_8NRs_3oP9M7S4T81c\n", "\n", "\n", "New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/pgfmdFlaQTy9odov72ZvVQ/\n", "\n", "\u001b[1m[2022-02-25T09:34:21]\u001b[0m Started scanning logdir.\n", "\u001b[1m[2022-02-25T09:34:22]\u001b[0m Total uploaded: 38 scalars, 0 tensors, 0 binary objects\n", "\u001b[1m[2022-02-25T09:34:22]\u001b[0m Done scanning logdir.\n", "\n", "\n", "Done. View your TensorBoard at https://tensorboard.dev/experiment/pgfmdFlaQTy9odov72ZvVQ/\n" ] } ], "source": [ "if 'google.colab' in str(get_ipython()):\n", " #@markdown You can upload the training results directly to https://tensorboard.dev\n", " #@markdown\n", " #@markdown Note that everybody with the link will be able to see the data.\n", " upload_data = 'yes' #@param ['yes', 'no']\n", " if upload_data == 'yes':\n", " !tensorboard dev upload --one_shot --logdir ./workdirs --name 'Flax examples/seq2seq (Colab)'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "outputId": "e22b7208-5413-4a63-abfb-b510af60f340" }, "outputs": [ { "data": { "text/plain": [ "(1, 8, 15)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs = ctable.encode_onehot(['2+40'])\n", "# batch, max_length, vocab_size\n", "inputs.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Using different random seeds generates different samples.\n", "preds = train.decode(state.params, inputs, jax.random.key(0), ctable)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "outputId": "e5cdfd75-2c66-4165-8ab7-9fdecde5062a" }, "outputs": [ { "data": { "text/plain": [ "array(['42'], dtype=' models.Seq2seq: return models.Seq2seq( teacher_force=teacher_force, hidden_size=FLAGS.hidden_size, eos_id=ctable.eos_id, vocab_size=ctable.vocab_size, ) def get_initial_params( model: models.Seq2seq, rng: PRNGKey, ctable: CTable ) -> dict[str, Any]: """Returns the initial parameters of a seq2seq model.""" rng1, rng2 = jax.random.split(rng) variables = model.init( {'params': rng1, 'lstm': rng2}, jnp.ones(ctable.encoder_input_shape, jnp.float32), jnp.ones(ctable.decoder_input_shape, jnp.float32), ) return variables['params'] def get_train_state(rng: PRNGKey, ctable: CTable) -> train_state.TrainState: """Returns a train state.""" model = get_model(ctable) params = get_initial_params(model, rng, ctable) tx = optax.adam(FLAGS.learning_rate) state = train_state.TrainState.create( apply_fn=model.apply, params=params, tx=tx ) return state def cross_entropy_loss( logits: Array, labels: Array, lengths: Array ) -> jax.Array: """Returns cross-entropy loss.""" xe = jnp.sum(nn.log_softmax(logits) * labels, axis=-1) masked_xe = jnp.mean(mask_sequences(xe, lengths)) return -masked_xe def compute_metrics( logits: Array, labels: Array, eos_id: int ) -> dict[str, jax.Array]: """Computes metrics and returns them.""" lengths = get_sequence_lengths(labels, eos_id) loss = cross_entropy_loss(logits, labels, lengths) # Computes sequence accuracy, which is the same as the accuracy during # inference, since teacher forcing is irrelevant when all output are correct. token_accuracy = jnp.argmax(logits, -1) == jnp.argmax(labels, -1) sequence_accuracy = ( jnp.sum(mask_sequences(token_accuracy, lengths), axis=-1) == lengths ) accuracy = jnp.mean(sequence_accuracy) metrics = { 'loss': loss, 'accuracy': accuracy, } return metrics @jax.jit def train_step( state: train_state.TrainState, batch: Array, lstm_rng: PRNGKey, eos_id: int ) -> tuple[train_state.TrainState, dict[str, jax.Array]]: """Trains one step.""" labels = batch['answer'][:, 1:] lstm_key = jax.random.fold_in(lstm_rng, state.step) def loss_fn(params): logits, _ = state.apply_fn( {'params': params}, batch['query'], batch['answer'], rngs={'lstm': lstm_key}, ) loss = cross_entropy_loss( logits, labels, get_sequence_lengths(labels, eos_id) ) return loss, logits grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, logits), grads = grad_fn(state.params) state = state.apply_gradients(grads=grads) metrics = compute_metrics(logits, labels, eos_id) return state, metrics def log_decode(question: str, inferred: str, golden: str): """Logs the given question, inferred query, and correct query.""" suffix = ( '(CORRECT)' if inferred == golden else (f'(INCORRECT) correct={golden}') ) logging.info('DECODE: %s = %s %s', question, inferred, suffix) @functools.partial(jax.jit, static_argnums=3) def decode( params: dict[str, Any], inputs: Array, decode_rng: PRNGKey, ctable: CTable ) -> Array: """Decodes inputs.""" init_decoder_input = ctable.one_hot(ctable.encode('=')[0:1]) init_decoder_inputs = jnp.tile( init_decoder_input, (inputs.shape[0], ctable.max_output_len, 1) ) model = get_model(ctable, teacher_force=False) _, predictions = model.apply( {'params': params}, inputs, init_decoder_inputs, rngs={'lstm': decode_rng} ) return predictions def decode_batch( state: train_state.TrainState, batch: dict[str, Array], decode_rng: PRNGKey, ctable: CTable, ): """Decodes and log results for a batch.""" inputs, outputs = batch['query'], batch['answer'][:, 1:] decode_rng = jax.random.fold_in(decode_rng, state.step) inferred = decode(state.params, inputs, decode_rng, ctable) questions = ctable.decode_onehot(inputs) infers = ctable.decode_onehot(inferred) goldens = ctable.decode_onehot(outputs) for question, inferred, golden in zip(questions, infers, goldens): log_decode(question, inferred, golden) def train_and_evaluate(workdir: str) -> train_state.TrainState: """Trains for a fixed number of steps and decode during training.""" # TODO(marcvanzee): Integrate ctable with train_state. ctable = CTable('0123456789+= ', FLAGS.max_len_query_digit) rng = jax.random.key(0) state = get_train_state(rng, ctable) writer = metric_writers.create_default_writer(workdir) for step in range(FLAGS.num_train_steps): batch = ctable.get_batch(FLAGS.batch_size) state, metrics = train_step(state, batch, rng, ctable.eos_id) if step and step % FLAGS.decode_frequency == 0: writer.write_scalars(step, metrics) batch = ctable.get_batch(5) decode_batch(state, batch, rng, ctable) return state def main(_): _ = train_and_evaluate(FLAGS.workdir) if __name__ == '__main__': app.run(main) ================================================ FILE: examples/seq2seq/train_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.examples.seq2seq.train.""" import functools from absl.testing import absltest from flax.training import train_state import jax from jax import random import numpy as np import optax import input_pipeline import train import models jax.config.parse_flags_with_absl() def create_ctable(chars='0123456789+= '): return input_pipeline.CharacterTable(chars) def create_train_state(ctable): model = models.Seq2seq( teacher_force=False, hidden_size=train.FLAGS.hidden_size, vocab_size=ctable.vocab_size, ) params = train.get_initial_params(model, jax.random.key(0), ctable) tx = optax.adam(train.FLAGS.learning_rate) state = train_state.TrainState.create( apply_fn=model.apply, params=params, tx=tx ) return state class TrainTest(absltest.TestCase): def test_character_table(self): ctable = create_ctable() text = '410+19' enc_text = ctable.encode(text) dec_text = ctable.decode(enc_text) # The text is possibly padded with whitespace, but the trimmed output should # be equal to the input. self.assertEqual(text, dec_text.strip()) def test_mask_sequences(self): np.testing.assert_equal( input_pipeline.mask_sequences( np.arange(1, 13).reshape((4, 3)), np.array([3, 2, 1, 0]) ), np.array([[1, 2, 3], [4, 5, 0], [7, 0, 0], [0, 0, 0]]), ) def test_get_sequence_lengths(self): oh_sequence_batch = jax.vmap( functools.partial(jax.nn.one_hot, num_classes=4) )(np.array([[0, 1, 0], [1, 0, 2], [1, 2, 0], [1, 2, 3]])) np.testing.assert_equal( input_pipeline.get_sequence_lengths(oh_sequence_batch, eos_id=0), np.array([1, 2, 3, 3], np.int32), ) np.testing.assert_equal( input_pipeline.get_sequence_lengths(oh_sequence_batch, eos_id=1), np.array([2, 1, 1, 1], np.int32), ) np.testing.assert_equal( input_pipeline.get_sequence_lengths(oh_sequence_batch, eos_id=2), np.array([3, 3, 2, 2], np.int32), ) def test_train_one_step(self): ctable = create_ctable() batch = ctable.get_batch(128) state = create_train_state(ctable) key = random.key(0) _, train_metrics = train.train_step(state, batch, key, ctable.eos_id) self.assertLessEqual(train_metrics['loss'], 5) self.assertGreaterEqual(train_metrics['accuracy'], 0) def test_decode_batch(self): ctable = create_ctable() batch = ctable.get_batch(5) key = random.key(0) state = create_train_state(ctable) train.decode_batch(state, batch, key, ctable) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/sst2/README.md ================================================ ## SST-2 classification Trains a simple text classifier on the SST-2 sentiment classification dataset. You can run this code and even modify it directly in Google Colab, no installation required: https://colab.research.google.com/github/google/flax/blob/main/examples/sst2/sst2.ipynb ### Requirements * TensorFlow dataset `glue/sst2` will be downloaded and prepared automatically, if necessary. ### Example output | Name | Platform | Epochs | Walltime | Accuracy | Metrics | Workdir | |:--------|:--------|--------:|:-----------|:-----------------|:----------------------------------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------| | default | TPU | 10 | 4.3m | 85.21% | [tensorboard.dev](https://tensorboard.dev/experiment/yTQjjRY9RlGRrZzg8h9PJw/) | | ``` INFO:absl:train epoch 010 loss 0.1918 accuracy 92.41 INFO:absl:eval epoch 010 loss 0.4144 accuracy 85.21 ``` ### How to run ```bash python main.py --workdir=/tmp/sst2 --config=configs/default.py` ``` #### Overriding Hyperparameter configurations The SST2 example allows specifying a hyperparameter configuration by means of setting the `--config` flag. The configuration flag is defined using [config_flags](https://github.com/google/ml_collections/tree/master#config-flags). `config_flags` allows overriding configuration fields. This can be done as follows: ```shell python main.py \ --workdir=/tmp/sst2 --config=configs/default.py \ --config.learning_rate=0.05 --config.num_epochs=5 ``` ================================================ FILE: examples/sst2/build_vocabulary.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """A vocabulary builder that generates vocab.txt to be used for training.""" import time from collections.abc import Iterable, Sequence from absl import logging import tensorflow as tf import tensorflow_datasets as tfds import tensorflow_text as tftext import vocabulary def get_tokenized_sequences( dataset: tf.data.Dataset, tokenizer: tftext.Tokenizer = tftext.WhitespaceTokenizer(), input_key: str = 'sentence', ) -> Iterable[Sequence[bytes]]: """Returns tokenized sequences for vocabulary building.""" dataset = dataset.map( lambda example: tokenizer.tokenize(example[input_key]), num_parallel_calls=tf.data.experimental.AUTOTUNE, ) yield from tfds.as_numpy(dataset) if __name__ == '__main__': logging.set_verbosity(logging.INFO) start_time = time.time() # Loads the dataset to build the vocabulary from. We use the train split. dataset = tfds.load('glue/sst2', split='train') # Tokenizes the sequences in the dataset and keeps only those. tokenized_sequences = get_tokenized_sequences(dataset) # Builds the vocabulary from the tokenized sequences. # A token needs to appear at least 3 times to be in the vocabulary. You can # play with this. It is there to make sure we don't overfit on rare words. vocab = vocabulary.Vocabulary( tokenized_sequences=tokenized_sequences, min_freq=3 ) vocab.save('vocab.txt') logging.info('Total time elapsed: %f s', time.time() - start_time) ================================================ FILE: examples/sst2/configs/default.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Default hyperparameter configuration for SST-2.""" import ml_collections def get_config(): """Get the default hyperparameter configuration.""" config = ml_collections.ConfigDict() config.embedding_size = 300 config.hidden_size = 256 config.vocab_size = None config.output_size = 1 config.vocab_path = 'vocab.txt' config.max_input_length = 60 config.dropout_rate = 0.5 config.word_dropout_rate = 0.1 config.unk_idx = 1 config.learning_rate = 0.1 config.momentum = 0.9 config.weight_decay = 3e-6 config.batch_size = 64 config.bucket_size = 8 config.num_epochs = 10 config.seed = 0 return config ================================================ FILE: examples/sst2/input_pipeline.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """SST-2 input pipeline.""" import sys from typing import Any, Dict, Optional from absl import logging import numpy as np import tensorflow as tf import tensorflow_datasets as tfds if sys.version_info < (3, 13): import tensorflow_text as text import vocabulary AUTOTUNE = tf.data.experimental.AUTOTUNE Example = dict[str, tf.Tensor] def get_bucket_boundaries(bucket_size: int, max_size: int) -> np.ndarray: """Bucket boundaries with `bucket_size` items per bucket, up to `max_size`. Example: ``` get_bucket_boundaries(8, 24) [9, 17, 25] ``` E.g., the first boundary covers items with sizes 0-8, the next boundary covers items with sizes 9-16, and the last bucket covers sizes 17-24. Each bucket covers 8 different sizes (e.g., sentence lengths). Args: bucket_size: The number of different items per bucket. max_size: The maximum size to be expected for a bucket. Returns: A list of (exclusive) bucket boundaries. """ return np.arange(bucket_size, max_size + bucket_size, bucket_size) + 1 def get_num_examples(dataset: tf.data.Dataset) -> int: """Returns the number of examples in the dataset by iterating over it once.""" return dataset.reduce(np.int64(0), lambda x, _: x + 1).numpy() def get_bucketed_batches( dataset: tf.data.Dataset, batch_size: int, bucket_size: int, max_length: int, padded_shapes: Any, example_size_fn: Any, shuffle: bool = False, shuffle_seed: int | None = None, drop_remainder: bool = False, ) -> tf.data.Dataset: """Returns padded batches of shuffled examples bucketed by length. This shuffles the examples randomly each epoch. The random order is deterministic and controlled by the seed. Batches are padded because sentences have different lengths. Sentences that are shorter in a batch will get 0s added at the end, until all sentences in the batch have the same length. For performance, examples of similar lengths are bucketed together. However, the contents of the buckets and their order is random each epoch, and controlled by the seed. Args: dataset: A TF Dataset with SST examples to be shuffled and batched. batch_size: The size of each batch. The remainder is dropped. bucket_size: How many different lengths go in each bucket. max_length: The maximum length to provide a bucket for. padded_shapes: A nested structure representing the shape to which the respective component of each input element should be padded prior to batching. See `tf.data.Dataset.padded_batch` for examples. example_size_fn: A TF function that returns the size of an example to determine in which bucket it goes. E.g., the sentence length. shuffle: Shuffle the dataset each epoch using seed. shuffle_seed: The seed that determines the shuffling order, with a different order each epoch. drop_remainder: Drop the last batch if it is not of size batch_size. Returns: A TF Dataset containing padded bucketed batches. """ if shuffle: assert shuffle_seed is not None, 'When shuffling you must provide a seed.' # For bucket_size 8 and max length 24, we get bucket boundaries [9, 17, 25]. bucket_boundaries = get_bucket_boundaries(bucket_size, max_length) logging.info('Batching bucket boundaries: %r', bucket_boundaries) # One batch size for each bucket plus one additional one (per requirement). bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1) bucket_fn = tf.data.experimental.bucket_by_sequence_length( example_size_fn, bucket_boundaries, bucket_batch_sizes, padded_shapes=padded_shapes, pad_to_bucket_boundary=True, drop_remainder=drop_remainder, ) if shuffle: # For shuffling we need to know how many training examples we have. num_examples = get_num_examples(dataset) num_batches = num_examples // batch_size return ( dataset.shuffle( num_examples, seed=shuffle_seed, reshuffle_each_iteration=True ) .apply(bucket_fn) .shuffle(num_batches, seed=shuffle_seed, reshuffle_each_iteration=True) .prefetch(tf.data.experimental.AUTOTUNE) ) return dataset.apply(bucket_fn).prefetch(tf.data.experimental.AUTOTUNE) def vocab_to_hashtable( vocab: vocabulary.Vocabulary, unk_idx: int ) -> tf.lookup.StaticHashTable: """Returns a TF lookup table (token -> ID) from a vocabulary.""" return tf.lookup.StaticHashTable( tf.lookup.KeyValueTensorInitializer( list(vocab.keys()), list(vocab.values()) ), default_value=unk_idx, ) def vocab_to_inverse_hashtable( vocab: vocabulary.Vocabulary, unk_token: bytes ) -> tf.lookup.StaticHashTable: """Returns an inverse TF lookup table (ID -> token) from a vocabulary.""" return tf.lookup.StaticHashTable( tf.lookup.KeyValueTensorInitializer( list(vocab.values()), list(vocab.keys()), key_dtype=tf.int64, value_dtype=tf.string, ), default_value=unk_token, ) def _is_text_field(feature_name_and_type): """Identifies a text field when given a feature (name, type) pair.""" _, feature_type = feature_name_and_type return isinstance(feature_type, tfds.features.Text) def _is_class_label(feature_name_and_type): """Identifies a class label field when given a feature (name, type) pair.""" _, feature_type = feature_name_and_type return isinstance(feature_type, tfds.features.ClassLabel) class TextDataset: """A text dataset with one sequence as input and a label.""" def __init__( self, tfds_name: str = 'glue/sst2', vocab_path: str = 'vocab.txt', tokenizer=None, split='train', ): """Initializes the SST2 data source.""" self.dataset, self.info = tfds.load(tfds_name, split=split, with_info=True) # Look up the feature name of the text and label in the dataset. # We assume there is one text input and one label. text_fields = filter(_is_text_field, self.info.features.items()) label_fields = filter(_is_class_label, self.info.features.items()) self.text_feature_name, _ = next(text_fields) self.label_feature_name, _ = next(label_fields) # Load the vocabulary. self.vocab = vocabulary.Vocabulary(vocab_path=vocab_path) # Convert the sentences to sequences of token IDs and compute length. if tokenizer is None: tokenizer = text.WhitespaceTokenizer() self.tokenizer = tokenizer self.tf_vocab = vocab_to_hashtable(self.vocab, unk_idx=self.vocab.unk_idx) self.examples = self.dataset.map( self.prepare_example, num_parallel_calls=AUTOTUNE ).cache() @property def padded_shapes(self): """The padded shapes used by batching functions.""" # None means variable length; pads to the longest sequence in the batch. return {'idx': [], 'token_ids': [None], 'label': [], 'length': []} def example_length_fn(self, example: Example) -> tf.Tensor: """Returns the length of the example for the purpose of the bucketing.""" return tf.size(example['token_ids']) def add_bos_eos(self, sequence: tf.Tensor) -> tf.Tensor: """Prepends BOS ID and appends EOS ID to a sequence of token IDs.""" return tf.concat([[self.vocab.bos_idx], sequence, [self.vocab.eos_idx]], 0) def prepare_example(self, example: Example) -> Example: """Prepares an example by converting text to token IDs.""" tokens = self.tokenizer.tokenize(example[self.text_feature_name]) label = example[self.label_feature_name] del example[self.text_feature_name] del example[self.label_feature_name] example['token_ids'] = self.add_bos_eos(self.tf_vocab.lookup(tokens)) example['length'] = tf.size(example['token_ids']) example['label'] = label return example def get_batches( self, batch_size: int, drop_remainder: bool = False, shuffle: bool = False, shuffle_seed: int | None = None, fixed_pad_length: int | None = None, dataset: tf.data.Dataset | None = None, ): """Returns an iterator with padded batches for the provided dataset.""" if dataset is None: dataset = self.examples if shuffle: buffer_size = get_num_examples(dataset) dataset = dataset.shuffle( buffer_size, seed=shuffle_seed, reshuffle_each_iteration=True ) padded_shapes = {k: v for k, v in self.padded_shapes.items()} if fixed_pad_length is not None: padded_shapes['token_ids'] = fixed_pad_length return dataset.padded_batch( batch_size, padded_shapes=padded_shapes, drop_remainder=drop_remainder ) def get_bucketed_batches( self, batch_size: int, bucket_size: int, max_input_length: int, drop_remainder: bool = False, shuffle: bool = False, shuffle_seed: int | None = None, dataset: tf.data.Dataset | None = None, ): """Returns an iterator with bucketed batches for the provided dataset.""" if dataset is None: dataset = self.examples return get_bucketed_batches( dataset, batch_size, bucket_size, max_input_length, self.padded_shapes, self.example_length_fn, shuffle=shuffle, shuffle_seed=shuffle_seed, drop_remainder=drop_remainder, ) ================================================ FILE: examples/sst2/input_pipeline_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import pathlib import sys import tempfile from absl.testing import absltest import tensorflow_datasets as tfds import input_pipeline import vocabulary class InputPipelineTest(absltest.TestCase): def setUp(self): super().setUp() if sys.version_info >= (3, 13): self.skipTest('Test (and tensorflow-text) does not suport Python 3.13+') self.vocab_path = self._get_vocab_path() self.dataset = self._get_dataset(self.vocab_path) def _get_vocab_path(self): """Prepares a mock vocabulary and returns a path to it.""" vocab_path = os.path.join(tempfile.mkdtemp(), 'vocab.txt') tokenized_sequences = ( (b'this', b'is', b'a', b'test', b'sentence'), (b'this', b'is', b'a', b'test', b'sentence'), (b'this', b'is', b'a', b'test', b'sentence'), ) vocab = vocabulary.Vocabulary(tokenized_sequences=tokenized_sequences) vocab.save(vocab_path) return vocab_path def _get_dataset(self, vocab_path): """Uses mock data to create the dataset.""" # Go two directories up to the root of the flax directory. flax_root_dir = pathlib.Path(__file__).parents[2] data_dir = str(flax_root_dir) + '/.tfds/metadata' # pylint: disable=unused-variable with tfds.testing.mock_data(num_examples=128, data_dir=data_dir): return input_pipeline.TextDataset(vocab_path=vocab_path, split='train') def test_bucketed_dataset(self): """Each batch should have a length that is a multiple of bucket_size.""" batch_size = 2 bucket_size = 8 for batch in self.dataset.get_bucketed_batches( batch_size=batch_size, bucket_size=bucket_size, max_input_length=60, shuffle=False, ).take(3): # Because of bucketing, sequence length must be multiple of bucket_size. length = batch['token_ids'].numpy().shape[-1] self.assertEqual(0, length % bucket_size) self.assertEqual((batch_size,), batch['length'].numpy().shape) self.assertEqual((batch_size,), batch['label'].numpy().shape) def test_batched_dataset(self): """Tests that the length of a batch matches the longest sequence.""" batch_size = 2 for batch in self.dataset.get_batches( batch_size=batch_size, shuffle=False ).take(1): # Each batch is padded to the maximum sentence length in the batch. max_length_in_batch = max(batch['length'].numpy()) length = batch['token_ids'].numpy().shape[-1] self.assertEqual(max_length_in_batch, length) self.assertEqual((batch_size,), batch['length'].numpy().shape) self.assertEqual((batch_size,), batch['label'].numpy().shape) def test_batched_dataset_fixed_length(self): """Tests that each batch has the fixed length.""" batch_size = 2 fixed_pad_length = 77 for batch in self.dataset.get_batches( batch_size=batch_size, shuffle=False, fixed_pad_length=fixed_pad_length ).take(1): length = batch['token_ids'].numpy().shape[-1] self.assertEqual(fixed_pad_length, length) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/sst2/main.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Main file for running the SST2 example. This file is intentionally kept short. The majority for logic is in libraries that can be easily tested and imported in Colab. """ from absl import app from absl import flags from absl import logging from clu import platform import jax from ml_collections import config_flags import tensorflow as tf import train FLAGS = flags.FLAGS flags.DEFINE_string('workdir', None, 'Directory to store model data.') config_flags.DEFINE_config_file( 'config', None, 'File path to the training hyperparameter configuration.', lock_config=True, ) flags.mark_flags_as_required(['config', 'workdir']) def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.experimental.set_visible_devices([], 'GPU') logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) logging.info('JAX local devices: %r', jax.local_devices()) # Add a note so that we can tell which task is which JAX host. # (Depending on the platform task 0 is not guaranteed to be host 0) platform.work_unit().set_task_status( f'process_index: {jax.process_index()}, ' f'process_count: {jax.process_count()}' ) platform.work_unit().create_artifact( platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' ) train.train_and_evaluate(FLAGS.config, FLAGS.workdir) if __name__ == '__main__': app.run(main) ================================================ FILE: examples/sst2/models.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """A text classification model.""" import functools from typing import Any, Optional from collections.abc import Callable from flax import linen as nn import jax from jax import numpy as jnp Array = jnp.ndarray def sequence_mask(lengths: Array, max_length: int) -> Array: """Computes a boolean mask over sequence positions for each given length. Example: ``` sequence_mask([1, 2], 3) [[True, False, False], [True, True, False]] ``` Args: lengths: The length of each sequence. [batch_size] max_length: The width of the boolean mask. Must be >= max(lengths). Returns: A mask with shape: [batch_size, max_length] indicating which positions are valid for each sequence. """ return jnp.arange(max_length)[None] < lengths[:, None] @jax.vmap def flip_sequences(inputs: Array, lengths: Array) -> Array: """Flips a sequence of inputs along the time dimension. This function can be used to prepare inputs for the reverse direction of a bidirectional LSTM. It solves the issue that, when naively flipping multiple padded sequences stored in a matrix, the first elements would be padding values for those sequences that were padded. This function keeps the padding at the end, while flipping the rest of the elements. Example: ```python inputs = [[1, 0, 0], [2, 3, 0] [4, 5, 6]] lengths = [1, 2, 3] flip_sequences(inputs, lengths) = [[1, 0, 0], [3, 2, 0], [6, 5, 4]] ``` Args: inputs: An array of input IDs [batch_size, seq_length]. lengths: The length of each sequence [batch_size]. Returns: An ndarray with the flipped inputs. """ # Note: since this function is vmapped, the code below is effectively for # a single example. max_length = inputs.shape[0] return jnp.flip(jnp.roll(inputs, max_length - lengths, axis=0), axis=0) class WordDropout(nn.Module): """Applies word dropout to a batch of input IDs. This is basically the same as `nn.Dropout`, but allows specifying the value of dropped out items. """ dropout_rate: float unk_idx: int deterministic: bool | None = None @nn.compact def __call__(self, inputs: Array, deterministic: bool | None = None): deterministic = nn.module.merge_param( 'deterministic', self.deterministic, deterministic ) if deterministic or self.dropout_rate == 0.0: return inputs rng = self.make_rng('dropout') mask = jax.random.bernoulli(rng, p=self.dropout_rate, shape=inputs.shape) return jnp.where(mask, jnp.array([self.unk_idx]), inputs) class Embedder(nn.Module): """Embeds batches of token IDs into feature space. Attributes: vocab_size: The size of the vocabulary (i.e., the number of embeddings). embedding_size: The dimensionality of the embeddings. embedding_init: The initializer used to initialize the embeddings. frozen: Freezes the embeddings table, keeping it fixed at initial values. dropout_rate: Percentage of units to drop after embedding the inputs. word_dropout_rate: Percentage of input words to replace with unk_idx. unk_idx: The index (integer) to use to replace inputs for word dropout. """ vocab_size: int embedding_size: int embedding_init: Callable[..., Array] = nn.initializers.normal(stddev=0.1) frozen: bool = False dropout_rate: float = 0.0 word_dropout_rate: float = 0.0 unk_idx: int | None = None deterministic: bool | None = None dtype: jnp.dtype = jnp.float32 def setup(self): self.embedding = self.param( 'embedding', self.embedding_init, (self.vocab_size, self.embedding_size), self.dtype, ) self.dropout_layer = nn.Dropout(rate=self.dropout_rate) self.word_dropout_layer = WordDropout( dropout_rate=self.word_dropout_rate, unk_idx=self.unk_idx ) def __call__( self, inputs: Array, deterministic: bool | None = None ) -> Array: """Embeds the input sequences and applies word dropout and dropout. Args: inputs: Batch of input token ID sequences [batch_size, seq_length]. deterministic: Disables dropout when set to True. Returns: The embedded inputs, shape: [batch_size, seq_length, embedding_size]. """ deterministic = nn.module.merge_param( 'deterministic', self.deterministic, deterministic ) inputs = self.word_dropout_layer(inputs, deterministic=deterministic) embedded_inputs = self.embedding[inputs] # Keep the embeddings fixed at initial (e.g. pretrained) values. if self.frozen: embedded_inputs = jax.lax.stop_gradient(embedded_inputs) return self.dropout_layer(embedded_inputs, deterministic=deterministic) class SimpleLSTM(nn.Module): """A simple unidirectional LSTM.""" hidden_size: int @functools.partial( nn.transforms.scan, variable_broadcast='params', in_axes=1, out_axes=1, split_rngs={'params': False}, ) @nn.compact def __call__(self, carry, x): return nn.OptimizedLSTMCell(self.hidden_size)(carry, x) def initialize_carry(self, input_shape): # Use fixed random key since default state init fn is just zeros. return nn.OptimizedLSTMCell(self.hidden_size, parent=None).initialize_carry( jax.random.key(0), input_shape ) class SimpleBiLSTM(nn.Module): """A simple bi-directional LSTM.""" hidden_size: int def setup(self): self.forward_lstm = SimpleLSTM(self.hidden_size) self.backward_lstm = SimpleLSTM(self.hidden_size) def __call__(self, embedded_inputs, lengths): # Forward LSTM. initial_state = self.forward_lstm.initialize_carry( embedded_inputs[:, 0].shape ) _, forward_outputs = self.forward_lstm(initial_state, embedded_inputs) # Backward LSTM. reversed_inputs = flip_sequences(embedded_inputs, lengths) initial_state = self.backward_lstm.initialize_carry( reversed_inputs[:, 0].shape ) _, backward_outputs = self.backward_lstm(initial_state, reversed_inputs) backward_outputs = flip_sequences(backward_outputs, lengths) # Concatenate the forward and backward representations. outputs = jnp.concatenate([forward_outputs, backward_outputs], -1) return outputs class MLP(nn.Module): """A simple Multilayer perceptron with 1 hidden layer. Attributes: hidden_size: The size of the hidden layer. output_size: The size of the output. activation: The activation function to apply to the hidden layer. dropout_rate: The dropout rate applied to the hidden layer. output_bias: If False, do not use a bias term in the last layer. deterministic: Disables dropout if set to True. """ hidden_size: int output_size: int activation: Callable[..., Any] = nn.tanh dropout_rate: float = 0.0 output_bias: bool = False deterministic: bool | None = None def setup(self): self.intermediate_layer = nn.Dense(self.hidden_size) self.output_layer = nn.Dense(self.output_size, use_bias=self.output_bias) self.dropout_layer = nn.Dropout(rate=self.dropout_rate) def __call__(self, inputs: Array, deterministic: bool | None = None): """Applies the MLP to the last dimension of the inputs. Args: inputs: [batch_size, ..., input_features]. deterministic: Disables dropout when set to True. Returns: The MLP output [batch_size, ..., output_size] """ deterministic = nn.module.merge_param( 'deterministic', self.deterministic, deterministic ) hidden = self.intermediate_layer(inputs) hidden = self.activation(hidden) hidden = self.dropout_layer(hidden, deterministic=deterministic) output = self.output_layer(hidden) return output class KeysOnlyMlpAttention(nn.Module): """Computes MLP-based attention scores based on keys alone, without a query. Attention scores are computed by feeding the keys through an MLP. This results in a single scalar per key, and for each sequence the attention scores are normalized using a softmax so that they sum to 1. Invalid key positions are ignored as indicated by the mask. This is also called "Bahdanau attention" and was originally proposed in: ``` Bahdanau et al., 2015. Neural Machine Translation by Jointly Learning to Align and Translate. ICLR. https://arxiv.org/abs/1409.0473 ``` Attributes: hidden_size: The hidden size of the MLP that computes the attention score. """ hidden_size: int @nn.compact def __call__(self, keys: Array, mask: Array) -> Array: """Applies model to the input keys and mask. Args: keys: The inputs for which to compute an attention score. Shape: [batch_size, seq_length, embeddings_size]. mask: A mask that determines which values in `keys` are valid. Only values for which the mask is True will get non-zero attention scores. [batch_size, seq_length]. Returns: The normalized attention scores. [batch_size, seq_length]. """ hidden = nn.Dense(self.hidden_size, name='keys', use_bias=False)(keys) energy = nn.tanh(hidden) scores = nn.Dense(1, name='energy', use_bias=False)(energy) scores = scores.squeeze(-1) # New shape: [batch_size, seq_len]. scores = jnp.where(mask, scores, -jnp.inf) # Using exp(-inf) = 0 below. scores = nn.softmax(scores, axis=-1) # Captures the scores if 'intermediates' is mutable, otherwise does nothing. self.sow('intermediates', 'attention', scores) return scores class AttentionClassifier(nn.Module): """A classifier that uses attention to summarize the inputs. Attributes: hidden_size: The hidden size of the MLP classifier. output_size: The number of output classes for the classifier. dropout_rate: The dropout rate applied over the encoded_inputs, the summary of the inputs, and inside the MLP. Applied when `deterministic` is False. deterministic: Disables dropout if True. """ hidden_size: int output_size: int dropout_rate: float = 0.0 deterministic: bool | None = None def setup(self): self.dropout_layer = nn.Dropout(rate=self.dropout_rate) self.keys_only_mlp_attention = KeysOnlyMlpAttention( hidden_size=self.hidden_size ) self.mlp = MLP( hidden_size=self.hidden_size, output_size=self.output_size, output_bias=False, dropout_rate=self.dropout_rate, ) def __call__( self, encoded_inputs: Array, lengths: Array, deterministic: bool | None = None, ) -> Array: """Applies model to the encoded inputs. Args: encoded_inputs: The inputs (e.g., sentences) that have already been encoded by some encoder, e.g., an LSTM. [batch_size, seq_length, encoded_inputs_size]. lengths: The lengths of the inputs. [batch_size]. deterministic: Disables dropout when set to True. Returns: An array of logits [batch_size, output_size]. """ deterministic = nn.module.merge_param( 'deterministic', self.deterministic, deterministic ) encoded_inputs = self.dropout_layer( encoded_inputs, deterministic=deterministic ) # Compute attention. attention.shape: [batch_size, seq_len]. mask = sequence_mask(lengths, encoded_inputs.shape[1]) attention = self.keys_only_mlp_attention(encoded_inputs, mask) # Summarize the inputs by taking their weighted sum using attention scores. context = jnp.expand_dims(attention, 1) @ encoded_inputs context = context.squeeze(1) # [batch_size, encoded_inputs_size] context = self.dropout_layer(context, deterministic=deterministic) # Make the final prediction from the context vector (the summarized inputs). logits = self.mlp(context, deterministic=deterministic) return logits class TextClassifier(nn.Module): """A Text Classification model.""" embedding_size: int hidden_size: int vocab_size: int output_size: int dropout_rate: float word_dropout_rate: float unk_idx: int = 1 deterministic: bool | None = None def setup(self): self.embedder = Embedder( vocab_size=self.vocab_size, embedding_size=self.embedding_size, dropout_rate=self.dropout_rate, word_dropout_rate=self.word_dropout_rate, unk_idx=self.unk_idx, ) self.encoder = SimpleBiLSTM(hidden_size=self.hidden_size) self.classifier = AttentionClassifier( hidden_size=self.hidden_size, output_size=self.output_size, dropout_rate=self.dropout_rate, ) def embed_token_ids( self, token_ids: Array, deterministic: bool | None = None ) -> Array: deterministic = nn.module.merge_param( 'deterministic', self.deterministic, deterministic ) return self.embedder(token_ids, deterministic=deterministic) def logits_from_embedded_inputs( self, embedded_inputs: Array, lengths: Array, deterministic: bool | None = None, ) -> Array: deterministic = nn.module.merge_param( 'deterministic', self.deterministic, deterministic ) encoded_inputs = self.encoder(embedded_inputs, lengths) return self.classifier(encoded_inputs, lengths, deterministic=deterministic) def __call__( self, token_ids: Array, lengths: Array, deterministic: bool | None = None, ) -> Array: """Embeds the token IDs, encodes them, and classifies with attention.""" embedded_inputs = self.embed_token_ids( token_ids, deterministic=deterministic ) logits = self.logits_from_embedded_inputs( embedded_inputs, lengths, deterministic=deterministic ) return logits ================================================ FILE: examples/sst2/models_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for sst2.models.""" from absl.testing import absltest from absl.testing import parameterized import models import jax import jax.test_util import numpy as np # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() class ModelTest(parameterized.TestCase): def test_embedder_returns_correct_output_shape(self): """Tests if the embedder returns the correct shape.""" vocab_size = 5 embedding_size = 3 model = models.Embedder( vocab_size=vocab_size, embedding_size=embedding_size ) rng = jax.random.key(0) token_ids = np.array([[2, 4, 3], [2, 6, 3]], dtype=np.int32) output, _ = model.init_with_output(rng, token_ids, deterministic=True) self.assertEqual((token_ids.shape) + (embedding_size,), output.shape) def test_lstm_returns_correct_output_shape(self): """Tests if the simple LSTM returns the correct shape.""" batch_size = 2 seq_len = 3 embedding_size = 4 hidden_size = 5 model = models.SimpleLSTM(5) rng = jax.random.key(0) inputs = np.random.RandomState(0).normal( size=[batch_size, seq_len, embedding_size] ) initial_state = model.initialize_carry(inputs[:, 0].shape) (_, output), _ = model.init_with_output(rng, initial_state, inputs) self.assertEqual((batch_size, seq_len, hidden_size), output.shape) def test_bilstm_returns_correct_output_shape(self): """Tests if the simple BiLSTM returns the correct shape.""" batch_size = 2 seq_len = 3 embedding_size = 4 hidden_size = 5 model = models.SimpleBiLSTM(hidden_size=hidden_size) rng = jax.random.key(0) inputs = np.random.RandomState(0).normal( size=[batch_size, seq_len, embedding_size] ) lengths = np.array([2, 3], dtype=np.int32) outputs, _ = model.init_with_output(rng, inputs, lengths) # We expect 2*hidden_size because we concatenate forward+backward LSTMs. self.assertEqual((batch_size, seq_len, 2 * hidden_size), outputs.shape) def test_text_classifier_returns_correct_output_shape(self): """Tests if a TextClassifier model returns the correct shape.""" embedding_size = 3 hidden_size = 7 vocab_size = 5 output_size = 3 dropout_rate = 0.1 word_dropout_rate = 0.2 unk_idx = 1 model = models.TextClassifier( embedding_size=embedding_size, hidden_size=hidden_size, vocab_size=vocab_size, output_size=output_size, dropout_rate=dropout_rate, word_dropout_rate=word_dropout_rate, unk_idx=unk_idx, deterministic=True, ) rng = jax.random.key(0) token_ids = np.array([[2, 4, 3], [2, 6, 3]], dtype=np.int32) lengths = np.array([2, 3], dtype=np.int32) output, _ = model.init_with_output(rng, token_ids, lengths) batch_size = token_ids.shape[0] self.assertEqual((batch_size, output_size), output.shape) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/sst2/requirements.txt ================================================ absl-py==1.0.0 clu==0.0.6 flax==0.3.6 ml-collections==0.1.0 numpy==1.22.0 optax==0.1.0 tensorflow==2.11.1 tensorflow-datasets==4.4.0 tensorflow-text==2.7.0 ================================================ FILE: examples/sst2/sst2.ipynb ================================================ { "cells": [ { "id": "29fb3c7c", "cell_type": "markdown", "source": [ "# Flax SST-2 Example\n", "\n", "\u003ca href=\"https://colab.research.google.com/github/google/flax/blob/main/examples/sst2/sst2.ipynb\" \u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n", "\n", "Demonstration notebook for\n", "https://github.com/google/flax/tree/main/examples/sst2" ], "metadata": {}, "execution_count": null }, { "id": "5b526b04", "cell_type": "markdown", "source": [ "**Before you start:** Select Runtime -\u003e Change runtime type -\u003e GPU." ], "metadata": {}, "execution_count": null }, { "id": "89ec78c1", "cell_type": "markdown", "source": [ "The **Flax Notebook Workflow**:\n", "\n", "1. Run the entire notebook end-to-end and check out the outputs.\n", " - This will open Python files in the right-hand editor!\n", " - You'll be able to interactively explore metrics in TensorBoard.\n", "2. Change `config` and train for different hyperparameters. Check out the\n", " updated TensorBoard plots.\n", "3. Update the code in `train.py`. Thanks to `%autoreload`, any changes you\n", " make in the file will automatically appear in the notebook. Some ideas to\n", " get you started:\n", " - Change the model.\n", " - Log some per-batch metrics during training.\n", " - Add new hyperparameters to `configs/default.py` and use them in\n", " `train.py`.\n", "4. At any time, feel free to paste code from `train.py` into the notebook\n", " and modify it directly there!" ], "metadata": {}, "execution_count": null }, { "id": "7e4ba0dc", "cell_type": "markdown", "source": [ "## Setup" ], "metadata": {}, "execution_count": null }, { "id": "ee8021b9", "cell_type": "code", "source": [ "example_directory = 'examples/sst2'\n", "editor_relpaths = ('configs/default.py', 'train.py', 'models.py')" ], "metadata": {}, "execution_count": null }, { "id": "36dab290", "cell_type": "code", "source": [ "# (If you run this code in Jupyter[lab], then you're already in the\n", "# example directory and nothing needs to be done.)\n", "\n", "#@markdown **Fetch newest Flax, copy example code**\n", "#@markdown\n", "#@markdown **If you select no** below, then the files will be stored on the\n", "#@markdown *ephemeral* Colab VM. **After some time of inactivity, this VM will\n", "#@markdown be restarted an any changes are lost**.\n", "#@markdown\n", "#@markdown **If you select yes** below, then you will be asked for your\n", "#@markdown credentials to mount your personal Google Drive. In this case, all\n", "#@markdown changes you make will be *persisted*, and even if you re-run the\n", "#@markdown Colab later on, the files will still be the same (you can of course\n", "#@markdown remove directories inside your Drive's `flax/` root if you want to\n", "#@markdown manually revert these files).\n", "\n", "if 'google.colab' in str(get_ipython()):\n", " import os\n", " os.chdir('/content')\n", " # Download Flax repo from Github.\n", " if not os.path.isdir('flaxrepo'):\n", " pass\n", " !git clone --depth=1 https://github.com/google/flax flaxrepo\n", " # Copy example files \u0026 change directory.\n", " mount_gdrive = 'no' #@param ['yes', 'no']\n", " if mount_gdrive == 'yes':\n", " DISCLAIMER = 'Note: Editing in your Google Drive, changes will persist.'\n", " from google.colab import drive\n", " drive.mount('/content/gdrive')\n", " example_root_path = f'/content/gdrive/My Drive/flax/{example_directory}'\n", " else:\n", " DISCLAIMER = 'WARNING: Editing in VM - changes lost after reboot!!'\n", " example_root_path = f'/content/{example_directory}'\n", " from IPython import display\n", " display.display(display.HTML(\n", " f'\u003ch1 style=\"color:red;\" class=\"blink\"\u003e{DISCLAIMER}\u003c/h1\u003e'))\n", " if not os.path.isdir(example_root_path):\n", " os.makedirs(example_root_path)\n", " !cp -r flaxrepo/$example_directory/* \"$example_root_path\"\n", " os.chdir(example_root_path)\n", " from google.colab import files\n", " for relpath in editor_relpaths:\n", " s = open(f'{example_root_path}/{relpath}').read()\n", " open(f'{example_root_path}/{relpath}', 'w').write(\n", " f'## {DISCLAIMER}\\n' + '#' * (len(DISCLAIMER) + 3) + '\\n\\n' + s)\n", " files.view(f'{example_root_path}/{relpath}')" ], "metadata": {}, "execution_count": null }, { "id": "700d9428", "cell_type": "code", "source": [ "# Note: In Colab, above cell changed the working directory.\n", "!pwd" ], "metadata": {}, "execution_count": null }, { "id": "2fbc3e64", "cell_type": "code", "source": [ "# Install SST-2 dependencies.\n", "!pip install -q -r requirements.txt" ], "metadata": {}, "execution_count": null }, { "id": "e40c50cf", "cell_type": "markdown", "source": [ "## Imports / Helpers" ], "metadata": {}, "execution_count": null }, { "id": "703f04fb", "cell_type": "code", "source": [ "import os\n", "os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'\n", "\n", "import jax\n", "jax.devices()" ], "metadata": {}, "execution_count": null }, { "id": "32e12a97", "cell_type": "code", "source": [ "from absl import logging\n", "import flax\n", "import jax.numpy as jnp\n", "import numpy as np\n", "import tensorflow as tf\n", "import tensorflow_datasets as tfds\n", "import time\n", "logging.set_verbosity(logging.INFO)\n", "\n", "# Make sure the GPU is for JAX, not for TF.\n", "tf.config.experimental.set_visible_devices([], 'GPU')" ], "metadata": {}, "execution_count": null }, { "id": "94ece24d", "cell_type": "code", "source": [ "# Local imports from current directory - auto reload.\n", "# Any changes you make to train.py will appear automatically.\n", "%load_ext autoreload\n", "%autoreload 2\n", "try:\n", " import train\n", " import models\n", " import vocabulary\n", " import input_pipeline\n", " from configs import default as config_lib\n", "except ModuleNotFoundError:\n", " # Local imports may not be available in all contexts\n", " pass\n", "config = config_lib.get_config()" ], "metadata": { "tags": [] }, "execution_count": null }, { "id": "c8a6ec00", "cell_type": "markdown", "source": [ "## Dataset" ], "metadata": {}, "execution_count": null }, { "id": "5c615ca6", "cell_type": "code", "source": [ "# Get datasets. \n", "# If you get an error you need to install tensorflow_datasets from Github.\n", "train_dataset = input_pipeline.TextDataset(split='train')\n", "eval_dataset = input_pipeline.TextDataset(split='validation')" ], "metadata": { "tags": [] }, "execution_count": null }, { "id": "7d7c55cd", "cell_type": "markdown", "source": [ "## Training" ], "metadata": {}, "execution_count": null }, { "id": "df9d52ed", "cell_type": "code", "source": [ "# Get a live update during training - use the \"refresh\" button!\n", "# (In Jupyter[lab] start \"tensorboard\" in the local directory instead.)\n", "if 'google.colab' in str(get_ipython()):\n", " pass\n", " %load_ext tensorboard\n", " %tensorboard --logdir=." ], "metadata": {}, "execution_count": null }, { "id": "f159d072", "cell_type": "code", "source": [ "config.num_epochs = 10\n", "model_name = 'bilstm'\n", "start_time = time.time()\n", "optimizer = train.train_and_evaluate(config, workdir=f'./models/{model_name}')\n", "logging.info('Walltime: %f s', time.time() - start_time)" ], "metadata": { "tags": [] }, "execution_count": null }, { "id": "b8e35c72", "cell_type": "code", "source": [ "if 'google.colab' in str(get_ipython()):\n", " #@markdown You can upload the training results directly to https://tensorboard.dev\n", " #@markdown\n", " #@markdown Note that everbody with the link will be able to see the data.\n", " upload_data = 'yes' #@param ['yes', 'no']\n", " if upload_data == 'yes':\n", " pass\n", " !tensorboard dev upload --one_shot --logdir ./models --name 'Flax examples/mnist'" ], "metadata": { "cellView": "form", "tags": [] }, "execution_count": null } ], "metadata": { "accelerator": "GPU" }, "nbformat_minor": 0, "nbformat": 4 } ================================================ FILE: examples/sst2/train.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Trains an SST2 text classifier.""" from typing import Any, Dict, Optional, Tuple, Union from collections.abc import Callable, Iterable, Sequence from absl import logging from flax import struct from flax.metrics import tensorboard from flax.training import train_state import jax import jax.numpy as jnp import ml_collections import numpy as np import optax import tensorflow as tf import input_pipeline import models Array = jnp.ndarray Example = dict[str, Array] TrainState = train_state.TrainState class Metrics(struct.PyTreeNode): """Computed metrics.""" loss: Array accuracy: Array count: int | None = None @jax.vmap def sigmoid_cross_entropy_with_logits(*, labels: Array, logits: Array) -> Array: """Sigmoid cross entropy loss.""" zeros = jnp.zeros_like(logits, dtype=logits.dtype) condition = logits >= zeros relu_logits = jnp.where(condition, logits, zeros) neg_abs_logits = jnp.where(condition, -logits, logits) return relu_logits - logits * labels + jnp.log1p(jnp.exp(neg_abs_logits)) def get_initial_params(rng, model): """Returns randomly initialized parameters.""" token_ids = jnp.ones((2, 3), jnp.int32) lengths = jnp.ones((2,), dtype=jnp.int32) variables = model.init(rng, token_ids, lengths, deterministic=True) return variables['params'] def create_train_state(rng, config: ml_collections.ConfigDict, model): """Create initial training state.""" params = get_initial_params(rng, model) tx = optax.chain( optax.sgd(learning_rate=config.learning_rate, momentum=config.momentum), optax.add_decayed_weights(weight_decay=config.weight_decay), ) state = TrainState.create(apply_fn=model.apply, params=params, tx=tx) return state def compute_metrics(*, labels: Array, logits: Array) -> Metrics: """Computes the metrics, summed across the batch if a batch is provided.""" if labels.ndim == 1: # Prevent the labels from broadcasting over the logits. labels = jnp.expand_dims(labels, axis=1) loss = sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) binary_predictions = logits >= 0.0 binary_accuracy = jnp.equal(binary_predictions, labels) return Metrics( loss=jnp.sum(loss), accuracy=jnp.sum(binary_accuracy), count=logits.shape[0], ) def model_from_config(config: ml_collections.ConfigDict): """Builds a text classification model from a config.""" model = models.TextClassifier( embedding_size=config.embedding_size, hidden_size=config.hidden_size, vocab_size=config.vocab_size, output_size=config.output_size, dropout_rate=config.dropout_rate, word_dropout_rate=config.word_dropout_rate, unk_idx=config.unk_idx, ) return model def train_step( state: TrainState, batch: dict[str, Array], rngs: dict[str, Any], ) -> tuple[TrainState, Metrics]: """Train for a single step.""" # Make sure to get a new RNG at every step. step = state.step rngs = {name: jax.random.fold_in(rng, step) for name, rng in rngs.items()} def loss_fn(params): variables = {'params': params} logits = state.apply_fn( variables, batch['token_ids'], batch['length'], deterministic=False, rngs=rngs, ) labels = batch['label'] if labels.ndim == 1: labels = jnp.expand_dims(labels, 1) loss = jnp.mean( sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) ) return loss, logits grad_fn = jax.value_and_grad(loss_fn, has_aux=True) value, grads = grad_fn(state.params) (_, logits) = value new_state = state.apply_gradients(grads=grads) metrics = compute_metrics(labels=batch['label'], logits=logits) return new_state, metrics def eval_step( state: TrainState, batch: dict[str, Array], rngs: dict[str, Any] ) -> Metrics: """Evaluate for a single step. Model should be in deterministic mode.""" variables = {'params': state.params} logits = state.apply_fn( variables, batch['token_ids'], batch['length'], deterministic=True, rngs=rngs, ) metrics = compute_metrics(labels=batch['label'], logits=logits) return metrics def normalize_batch_metrics(batch_metrics: Sequence[Metrics]) -> Metrics: """Consolidates and normalizes a list of per-batch metrics dicts.""" # Here we sum the metrics that were already summed per batch. total_loss = np.sum([metrics.loss for metrics in batch_metrics]) total_accuracy = np.sum([metrics.accuracy for metrics in batch_metrics]) total = np.sum([metrics.count for metrics in batch_metrics]) # Divide each metric by the total number of items in the data set. return Metrics( loss=total_loss.item() / total, accuracy=total_accuracy.item() / total ) def batch_to_numpy(batch: dict[str, Array]) -> dict[str, Array]: """Converts a batch with TF tensors to a batch of NumPy arrays.""" # _numpy() reuses memory, does not make a copy. # pylint: disable=protected-access return jax.tree_util.tree_map(lambda x: x._numpy(), batch) def evaluate_model( eval_step_fn: Callable[..., Any], state: TrainState, batches: Iterable[Example] | tf.data.Dataset, epoch: int, rngs: dict[str, Any] | None = None, ) -> Metrics: """Evaluate a model on a dataset.""" batch_metrics = [] for i, batch in enumerate(batches): batch = batch_to_numpy(batch) if rngs is not None: # New RNG for each step. rngs = {name: jax.random.fold_in(rng, i) for name, rng in rngs.items()} metrics = eval_step_fn(state, batch, rngs) batch_metrics.append(metrics) batch_metrics = jax.device_get(batch_metrics) metrics = normalize_batch_metrics(batch_metrics) logging.info( 'eval epoch %03d loss %.4f accuracy %.2f', epoch, metrics.loss, metrics.accuracy * 100, ) return metrics def train_epoch( train_step_fn: Callable[..., tuple[TrainState, Metrics]], state: TrainState, train_batches: tf.data.Dataset, epoch: int, rngs: dict[str, Any] | None = None, ) -> tuple[TrainState, Metrics]: """Train for a single epoch.""" batch_metrics = [] for batch in train_batches: batch = batch_to_numpy(batch) state, metrics = train_step_fn(state, batch, rngs) batch_metrics.append(metrics) # Compute the metrics for this epoch. batch_metrics = jax.device_get(batch_metrics) metrics = normalize_batch_metrics(batch_metrics) logging.info( 'train epoch %03d loss %.4f accuracy %.2f', epoch, metrics.loss, metrics.accuracy * 100, ) return state, metrics def train_and_evaluate( config: ml_collections.ConfigDict, workdir: str ) -> TrainState: """Execute model training and evaluation loop. Args: config: Hyperparameter configuration for training and evaluation. workdir: Directory where the tensorboard summaries are written to. Returns: The final train state that includes the trained parameters. """ # Prepare datasets. train_dataset = input_pipeline.TextDataset( tfds_name='glue/sst2', split='train' ) eval_dataset = input_pipeline.TextDataset( tfds_name='glue/sst2', split='validation' ) train_batches = train_dataset.get_bucketed_batches( config.batch_size, config.bucket_size, max_input_length=config.max_input_length, drop_remainder=True, shuffle=True, shuffle_seed=config.seed, ) eval_batches = eval_dataset.get_batches(batch_size=config.batch_size) # Keep track of vocab size in the config so that the embedder knows it. config.vocab_size = len(train_dataset.vocab) # Compile step functions. train_step_fn = jax.jit(train_step) eval_step_fn = jax.jit(eval_step) # Create model and a state that contains the parameters. rng = jax.random.key(config.seed) model = model_from_config(config) state = create_train_state(rng, config, model) summary_writer = tensorboard.SummaryWriter(workdir) summary_writer.hparams(dict(config)) # Main training loop. logging.info('Starting training...') for epoch in range(1, config.num_epochs + 1): # Train for one epoch. rng, epoch_rng = jax.random.split(rng) rngs = {'dropout': epoch_rng} state, train_metrics = train_epoch( train_step_fn, state, train_batches, epoch, rngs ) # Evaluate current model on the validation data. eval_metrics = evaluate_model(eval_step_fn, state, eval_batches, epoch) # Write metrics to TensorBoard. summary_writer.scalar('train_loss', train_metrics.loss, epoch) summary_writer.scalar('train_accuracy', train_metrics.accuracy * 100, epoch) summary_writer.scalar('eval_loss', eval_metrics.loss, epoch) summary_writer.scalar('eval_accuracy', eval_metrics.accuracy * 100, epoch) summary_writer.flush() return state ================================================ FILE: examples/sst2/train_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for sst2.train.""" import sys from absl.testing import absltest from absl.testing import parameterized import jax import jax.test_util import numpy as np from configs import default as default_config import train # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() class TrainTest(parameterized.TestCase): def test_train_step_updates_parameters(self): """Tests if the train step updates the parameters in train state.""" # Create model and a state that contains the parameters. config = default_config.get_config() config.vocab_size = 13 rng = jax.random.key(config.seed) model = train.model_from_config(config) state = train.create_train_state(rng, config, model) token_ids = np.array([[2, 4, 3], [2, 6, 3]], dtype=np.int32) lengths = np.array([2, 3], dtype=np.int32) labels = np.zeros_like(lengths) batch = {'token_ids': token_ids, 'length': lengths, 'label': labels} rngs = {'dropout': rng} train_step_fn = jax.jit(train.train_step) new_state, metrics = train_step_fn(state, batch, rngs) self.assertIsInstance(new_state, train.TrainState) self.assertIsInstance(metrics, train.Metrics) old_param_values = jax.tree_util.tree_leaves(state.params) new_param_values = jax.tree_util.tree_leaves(new_state.params) for old_array, new_array in zip(old_param_values, new_param_values): # Make sure parameters were updated. self.assertFalse(np.allclose(old_array, new_array)) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/sst2/vocab.txt ================================================ the , a and of . to 's is that in it as with an film its for movie this you but be on n't by more -- one at than has not about his from are like so or all have most story ' good ... into out too who ) up characters i funny ( comedy if just no does much what can even ` your their will time some `` bad little '' very way which best any love been life make work enough there only he makes us new movies never something do they humor through was well action great would own made director many we really performances plot drama her how could films sense see such better other fun audience people every off two without cast nothing feel both when being look character may should entertaining acting ever real often performance them long : while still world because script also interesting another heart kind 're those dialogue hollywood minutes watch first screen down few get big over far thriller might hard less human moments actors tale compelling romantic cinema had rather year family almost material end watching seen - worth 've seem itself picture original take before my documentary seems were emotional after our quite find old these visual comes man things back fascinating moving sweet right works between feels here scenes full come piece direction care yet ; music go dull me going takes special ultimately years ca young keep making 'll anything laughs american times why smart worst give comic experience enjoyable least cinematic lot part where beautiful entertainment history style sometimes thing though art clever kids away gives again him together bit she intelligence dark gets idea amusing engaging theater powerful same genre intelligent once star women energy subject did charming surprisingly actually summer anyone charm screenplay want point filmmaking place short narrative solid pretty flick around feeling nearly feature silly simply whose manages strong face predictable enjoy think truly war wit offers say show deeply goes know perfect satisfying then fans power need whole always effort becomes done spirit fresh beautifully premise true trying half quirky since three filmmakers suspense dramatic hilarious portrait tone horror last under fine interest effects flat rare high rich series hours probably children everyone romance ideas touching ? familiar looking modern remarkable study 'd especially imagination pleasure wonderful boring classic easy everything small exercise leave set instead level title honest stuff culture past dumb intriguing tv wo filmmaker light video actor already turn audiences sad storytelling lack matter recent stories mind obvious despite put talent written ending french images memorable project terrific visually serious adventure completely woman become opera beauty camera gentle likely talented looks emotionally mess fails day ride slow sure cliches cold having head himself reason beyond directed gorgeous inside jokes left mr. bland men melodrama proves shot impossible low ways easily run above hour stupid thoughtful contrived excellent must simple complex debut different else eyes tired ugly de fairly lacks otherwise viewer believe brilliant comedies each shows sort viewers passion warmth attempt black certainly particularly turns writing play violence welcome wrong cheap formula lost social genuine soap themselves animation book crime delightful either personal plays role sequences thoroughly hero line barely historical impressive sex version appealing fact home nor seeing along gags quality change clichés got worse adults found lives middle old-fashioned surprising ambitious death engrossing girl message next running creative fantasy important able live now pretentious worthy ! 'm decent psychological sequel ends imagine nice tragedy warm entirely michael none perfectly act creepy waste deep remains sit concept inventive job journey laugh pictures rock unfunny unsettling cool insight painful try usual vision winning attention convincing john unique bring moral neither mystery satire stylish against believable knows leaves master nature side success thin artist awful elements lacking lead reality seat tedious mood shallow working appeal considerable epic falls moment period provocative situations view create days political sentimental cinematography endearing hackneyed road scene sensitive watchable colorful earnest finally getting greatest highly parents difficult help loud uses amount call delivers dry game pace production spy understand utterly witty adaptation flaws future hold mostly possible relationships scary second sharp stand themes cliché depth emotions keeps masterpiece odd sci-fi skill throughout during edge given gone poorly several successful sustain touch tries animated coming-of-age drag examination four mindless pointless promise use disney episode flicks hit tell captures definitely execution form magic major pieces score subtle surprise twists unexpected whether complete deserves ensemble grand intense ones poor pure spielberg across actress boy depressing intimate memory oscar pacing problem somewhat straight ability approach career couple exciting forced intrigue lovely surprises thinking told york delicate goofy manner offer recommend starts taste tension thought tragic affection brings final grace mediocre sexy contemporary eye female heartfelt humanity occasionally overcome sets slightly someone tribute words air allows core cultural equally evil fully generic hand manipulative nicely pleasant plenty purpose remake stale strange today twisted uneven worthwhile case events failure felt frame funnier money monster painfully playful sophisticated striking telling vivid weird wild car crafted delight doing huge needs offensive room storyline terrible clear coming expect let living mean motion playing poetry remarkably single soul thanks urban acted casting clarity devoid insightful inspiring sounds supposed ultimate among class forget imaginative justice rarely clearly company energetic exactly issues quiet sandler wants chilling close dead dog figure funniest genuinely party refreshing respect runs situation smile spiritual strangely toward wildly called country empty extraordinary festival guys minor night product realism school sincere sloppy step tiresome apart chemistry crowd cut meditation mix share sitting taking teen treat truth viewing wasted writer-director yourself ago alone annoying begins effective gripping ms. originality politics ridiculous wonderfully add ages brain cute dream entire name problems questions superficial target universal value wise 90 chance conclusion derivative excitement fare fire ii interested lame loss merely plain sexual sick twice 10 absolutely amateurish america created detailed extremely first-rate flawed impact kid latest liked puts shots slapstick steven succeeds taken unnecessary well-made attraction battle british except hardly lots natural popcorn robert sappy start suffers triumph date frustrating heartwarming horrible joy mark meandering places previous read served spectacle typical usually alive attempts chan city credits disaster downright dreary hope incredibly mesmerizing novel opportunity performers seriously treatment behind certain forgettable frequently girls holds indie overall realistic refreshingly rest stars suspenseful top used absorbing allen balance brown credit delivery elegant grant guilty inept oddly provides riveting stunning trip weak wish yarn avoid conventional conviction dazzling doubt finds franchise person pop process reveals shame superb until visuals waiting 2002 amazing badly based crisp david editing guy lazy please poetic thrills trouble 2 achievement although bullock college commercial complicated damned false hate identity including leads maybe subjects wonder admirable b behavior bore bright comedic creates dreams escape exquisite fashion fears listless ludicrous match number routine sequels skin substance undeniably cleverly favor lousy means murphy nasty pathetic rip-off talents thought-provoking bizarre cause excuse fiction glory grief inspired mainstream pay poignant quickly sentimentality sheer standard totally various water win writer zone age authentic business central cheesy damage deal die evocative faith gay happy hell holes philosophical promising repetitive sides starring superior tender voice aspects beneath charisma color creativity devastating fall forgotten friendship front gifted glimpse heavy humorous largely members reach remember revealing sequence sound subtlety theaters blue cruel delightfully deserve designed directors fast filled five flair frank general gross-out haunting intended involving others perspective points raw unpleasant word accomplished confusing document endless exhilarating fit growing intensity laughter levels particular retread sign songs spark stay stock studio virtually acts artificial budget challenging cliched consistently entertain exploration feel-good finding force formulaic generally george happens jackson lines modest murder observations possibly provide quietly rhythms stereotypes theme utter vibrant accessible affair apparent bears decades essentially farce fatal group happen house imax marvelous noir offering pull rhythm soundtrack structure belly bored break buoyant crazy delights deliver fake favorite hip industry joke lets missing obnoxious persona potential sight slight sum television unconvincing upon addition answers chase clumsy enjoyed extreme flashy further hugely irritating loses low-budget main masterful mention niro red rewarding saw south spent timely vehicle washington achieves astonishing attractive beginning body bother carries community computer deadly disappointing eloquent empathy era frustration gradually hearts heaven logic mixed necessary older plotting season self-conscious sitcom struggle talking tells unimaginative visceral wait crap directorial expected green hoffman incoherent king laughing lived marks mayhem mildly move moviegoers onto pair relationship represents sea slice source terms undeniable unfocused unpredictable uplifting violent wannabe wholly wife 20 absurd adult affecting bond cartoon christmas crude development exploitative fan fat gangster holiday however joyous large melodramatic offbeat outstanding poem resist saturday save setting standards supporting tear wedding amounts assured boys choppy deft emotion fantastic games grandeur horrifying intoxicating list lyrical mature office pleasures precious reading roll sensual smarter somehow stuck sympathy tough vulgar weight wickedly atmosphere brought captivating charmer conflict content courage crush edited endure kiddie lady length machine mother pale passionate pat personality price realize revelatory seagal shoot soderbergh tears trademark turned vapid walk well-crafted wry $ average banal check collection detail digital dirty endeavor guns hits hopeful intricate killer lawrence missed murky pack pain positive potentially realized rendered resonant service simplistic sink sour splendid stage teenagers unlikable unoriginal vividly yes absurdity adam biting bittersweet blade brothers bunch charismatic constructed deftly eventually filmed food freshness gem heartbreaking japanese late ordinary poignancy portrayal somewhere spirits strength wars americans blend combination cult damn dvd effect emerges enthusiasm essence example fear footage frightening god grow instantly involved loved lush minds mixture plodding present proceedings profound pulls results richer said self-indulgent sell simultaneously strain street surface territory touches unfolds willing 9 artists band bold breathtaking capture commentary compassion daring deeper dignity exceptional faster fight friend honestly knowledge leaving meaningful notice peter queen release return rush soon spooky sports trite whatever adventurous angst awkward bite bottom capable changing channel child condition desperate details determined drawn endlessly fanciful fate immensely increasingly indian insights kevin magnificent measure moore nightmare nowhere odds over-the-top overly paced pic preachy predecessor presents quick ryan sadly sane smug soulful space spectacular stirring string sweetly sweetness team timing twist unforced unusual values wears white williams animal b-movie bitter broad carried comfortable committed delivered demands dish features friends grab ground harrowing heavy-handed inspiration james knowing laughable literary losing magical movements needed paper-thin patience polished pretends protagonist question rank relentlessly shapeless type understanding vital week whom wow amused appalling artistic asks bigger burns cliche confident convinced crass credible early faithful friday grows hair hands happening homage italian longer martin monty nomination pass prove public real-life relatively relief resourceful sadness science scorsese sketchy startling strike sympathetic third tiny top-notch uncompromising unfortunately ambition apparently appear award broken celebration creating degree despair direct-to-video disguise exploitation fair finale finish flimsy forgive gritty happiness hip-hop honesty innocence key lesson loves musical pleasing presence roles screenwriter seek serves signs snow sparkling state teacher ticket uninspired winner witch witless ahead aplomb appears barbershop belt biggest brutally buy charms cloying concerned cop corny cynical depiction developed earnestness enthusiastic escapism event fights folks giant guessing hastily imagery impression inevitable leaden lee maudlin merit metaphor note outrageous perverse recycled relies report rooting sand scenery seconds shocking shooting skip slap strikes structured suffering technical thrill tortured treasure trick walked worked & abandon accept admire alternative benigni breathtakingly characterizations charmless chilly cohesive comfort communal convenient delicious delivering desperately doze drive dumbed-down executed feat fighting giving herzog insanely intentions irresistible juvenile lifeless lingering lively male manage messages miss mysterious nevertheless niche overwrought potent presented racial reasonably references soulless spirited staged terribly terrifying testament tom total unfaithful urge vampire variety verve wilde adventures astonishingly below breaks brutal came clumsily constant deceptively describe desire display dubbed exceedingly forces free good-natured heard integrity international limited moviemaking national ode park pile powers psychology quirks rises sake secret shelf silliness smartly son spite splendor stop theatrical villain visit warning alternately appreciate basic belongs blast chaotic chris coherent connection demanding directing distasteful draws dullest earth embarrassment exotic feces gently graphic halfway helps hot jason labored landscape locations looked mouth occasional overwhelming phony played plots plotted police private produced proud remain rent result revenge rising roger scratch screenwriting sensational shamelessly significant somber spare spell straightforward suck sweeping tasty touched traditional trash turning unlikely wallace x younger achingly analyze blood box buried childhood choices considered control course disappointed discovery diversion dogs drugs due dying eastwood embraces enthralling equivalent evening extended factor farts flow handled imitation indeed ingenious inherent jones motivated nation near open openness paper polanski professional race richly rowdy saccharine safe screaming screenwriters showtime shyamalan slack society soggy suited thumbs tradition tricks unlike wanting weirdly wooden afraid ambiguous area artifice attract beach behold broomfield candy capacity cheese chinese consider delicately efforts emerge engaged expense experiences fierce fish follow generous gloriously greek handsome held inane incompetent inspirational intermittently jim junk lackluster language maintains martha nonsensical onscreen path peculiar perhaps personalities pick players politically quest relevant required rewards sees sentiment started substantial throws tragedies transcends uncanny unflinching using vague vitality whimsy window within wondering woody 90-minute actresses beat bringing complexity destination devices engage english enter explore fantasies father fill finely fuzzy garbage generation insipid inspire intellectual intimacy iranian keeping kissinger kline knew layers locales mainly makers marriage mere misery morning myself photography possibilities ranks rehash returns revolution rude rushed sarandon scream serve slapdash slip slowly sluggish stays stomach subtly sumptuous technology term thematic topic trailer trappings tremendous uncommonly understands verbal videos went won *** arts aside astounding attitude awe bag begin breakthrough breezy buffs celebrates celluloid choice cleverness complications context convince cover daily dealing disgusting distance distinctive distinguished drags dramatically dysfunctional edgy elaborate embarrassed encouraging existential familial famous firing graceless holocaust hoping improbable incisive jack likable loose lopez low-key mean-spirited mild overlong packed paid parker pleasurable portrays problematic proper refusal regret released reminds secrets shaped shrewd sincerity singing stands station stereotypical succeed supremely talk text unbelievable unexpectedly vs. weeks weirdness wind / accents actual adolescent altogether ask brilliantly burn camp century chances chills count dancing desperation diverting drawing eager eat elegantly enormously excess existed exit expectation exposition fisher following genial giddy gyllenhaal harry hokum hopkins humorless hybrid ignored incoherence inconsequential infidelity instincts insulting invention kicking kiss kong lends loads lovers mistake movement neat nerve newcomer pathology planet redundant reminiscent replaced resolutely respectable responsible satirical savvy scenario seats settles shake shocks subculture surely talky technically timeless trees trifle web whimsical adds aims ballot blair bloody cartoonish clueless concern confused copy creation crucial curiously dismissed downbeat elevate enduring enormous existence finest freaks gangs ghetto graceful hallmark highest historically hunter husband impeccable insult intolerable jarring jolt martial meaning midnight noise notion obviously paul pity ponder popular predictably random record scattered severe shining similarly solondz souls spoof titles unfulfilling unusually useless ya-ya '70s accurate aimless blind blockbuster brooding chops clue cobbled comparison continues critics denial design dickens divine domestic dragon easier eddie enterprise exploiting flavor fluid former frailty handful heroine implausible inviting jackie learn marginal masterfully middle-aged miracle moved muddled mythic necessarily noble nuanced observation phenomenal pianist ponderous pow preposterous presentation protagonists providing rating resonance rollicking rote sade satisfies saving seven similar sketch song sordid sorvino spend stilted stretches teenage thinks trapped unabashedly vanity witherspoon woo absolute academy account added allow and/or anybody artful blank britney broadway built calm caper cares caught chabrol classics collapses combined competent conservative construct creature curiosity current dangerous deadpan demented disposable drop elegance empowerment enjoyment equilibrium evelyn fancy figures frozen frustratingly ghost gorgeously gratuitous grown-up hong hugh hype incongruous inhabit innovative kill kung lump malaise manufactured marvelously morality morally mushy obligatory outside pitch-perfect postcard poverty pretension primarily radical recipe recovery remembering remote replete scripts spider-man stinker sudden surrounding swallow tasteful tepid translation vincent wave 30 appropriate awkwardness bill bits blandness bogs boredom breath caring catch cautionary charge classical common cross dawns dawson dim-witted doses drown ease effectively element enjoyably equal esther evident excesses expressive filling fluff freeman harsh hideously hoot impersonal impressed incessant inconsistent irish ironic kidman known land languorous larger letting lighting loneliness lurid machismo madonna memories misogyny moves musicians notable numbers observant pantheon pertinent psyche pulling pulpy recording redemption reel reign reputation rise ruined seeking sensitivity shenanigans sleep sneaks sobering somebody spectacularly spot stiff store stunts thick travel truthful unsentimental villains well-written wet willis winds wonders 1/2 101 a-list ache admit aging album alienation baffling becoming blown bone bound bracing busy carol carvey catches concert credibility crimes dahmer darker demons demonstrates difference eerie emptiness exceptionally expectations express families fiennes fills fits flashes fragile godard heal howard ice idiotic immediate impressions jennifer longing misfire moody naturally nonsense panache parody phones piano positively pretensions properly pungent ravishing run-of-the-mill santa scope self-consciously shameless shanghai shines shock slick sly solidly sophomoric spin static steal stylistic table textbook toilet torture underbelly understated vintage wanted whiny wrapped 88-minute abuse advantage affirming anger antwone arnold artistry artsy aspect authority awfully balanced beast blow capturing carry cell characteristic choose claustrophobic climactic clothes concession contradictory contribution contrivances convey cost costly dare droll dud erotic ethnic exaggerated fable failing fallen filming flamboyant flourishes format good-hearted growth hack happened heroes high-concept higher hitting hole hollow holly holm ideal idiots illuminating incomprehensible insistent irony justify kinda korean layered lessons liberating lie lifts lightweight load loving lust market massive meaningless meat mile momentum moviegoing nose objective organic payoff pays pit pointed pokemon profile punch quarter raised remembered rohmer seemingly self-discovery sensibilities shockingly sisters snipes speak splash stately status struggled stunt subversive susan swinging teeth test thousands thrillers tragically treats uncertain uncomfortable uninteresting untalented wholesome wide wondrous 1 15 arrive awkwardly background beings believing blatant bodies bombastic brother bruce building cable chest childlike chord christian clichéd closer conceits consciousness contains conveying corner criminal dentist determination directs disappointment disjointed documentaries e.t. ellen entertained excruciating foul france generate generates glaring hedonistic hilarity horribly humane image improvised incredible informative innocent insane insultingly kaufman kwan later latter lunacy matters monsters narcissism news nobody obscure opaque ourselves outtakes paints parable parts perceptive pitfalls plotless pokes pompeo predecessors proportions purely ralph recommendation recommended rely reviews ring sacrificing scores search self-satisfied sensibility separate showing shrill simone sinks slo-mo spring squanders subplots sugar surreal survival tediously tends thrilling tomorrow tones torn unabashed unconventional unintentionally unnerving unsympathetic upper victims wear woefully writer/director accent action-packed african-americans alabama alluring amazingly animals art-house assassin atlantic avary avoids bar barney big-screen blame bleak brand bridge builds casual category changes china circumstances clockstoppers combines confidence conspiracy constantly critical cry curious cuteness cutting dares darkly defeated defies department disappointingly disguised distant dreadful drug eccentric eventual evokes exact expert explosion expression exterior fail fearlessness feathers featuring flaccid focus glorious glossy gory greater heads idealism impenetrable impressively improvement includes independent indifference indulgent innuendo insomnia insurance jazzy jordan kahlo leaky life-affirming measured mike million minimalist nifty no. parade paranoia photographed picture-perfect pie pleaser plight pool precision prison produce pryor qualities raise rambling refuses rental resemble revelation sacrifice schwarzenegger scratching secretary self-deprecating self-important shoddy significance simplicity solutions sorry speaking succumbs suits sustains tap taylor texture ties tour trashy undercuts unrewarding upbeat valuable wallop website well-acted 20th abstract aimed alien am amiable anderson arc baggage base boasts brightly brilliance brooklyn casts chuckle clinic closing comedian comedy-drama compellingly consequences controversy conventions costume dashing davis deaths deceptions decidedly deniro desired dishonest disregard diverse dose dramas drunken duvall ears emphasis episodes er evans extraordinarily fathers first-time fix fly foolish foreign frida gave gifts glass goals gosling gravity grey guaranteed halloween ham-fisted harvard heated hundred huston innovation interminable j. jerry juicy jumps junior lane last-minute liberation lies lighthearted limitations literate manhattan meditative monotonous nicholas nonjudgmental nostalgic notorious nuance opens palma pants paradiso populated powerfully producer profane professionals projects promised promises psychologically puerile putting radiant refugees richard romp russian sandra satisfyingly savor saying scare scott screening shadows shoulders slightest spears stomach-turning styles sublime sucked sun sundance surprised swooping teens tense tight tool tosses transcend transforms tried trust turf unassuming uncomfortably wallet winners yearning 3d acceptable access acclaim affected aliens angry appeals argentine asian aspires assurance attal audacious awards birthday boat bourne cage celebrated cerebral cho clown con condescending continuity deals demonstrate deserved develop directly discover dozen dragonfly dynamic ego eight encounter environmental extra failed feature-length ferocity first-class focused follows forceful gang gel grating hammy heights hey hilariously hoped huppert idealistic idiosyncratic immature import improve improvisation inescapable infuses inspires intensely irrelevant jaglom japan jell joyless judd knack korea larger-than-life laugh-out-loud leading legacy lesser limp lit loaded lofty losers manhood medium mental mind-numbingly morvern muddle must-see narratively neighborhood no-nonsense nonexistent obsession obsessive offend overbearing overexposed participants penetrating pilot prime procession profoundly punching punk pushes rapid-fire recite recognize reduced rejected remind remotely rolling saga salton scenic science-fiction secondhand senses sinister sleepless smack specimen spontaneity steve strictly strongest strongly student stylized suffer suggests system tarantino teeming thematically throwing thrown tongue-in-cheek traditions triangle tsai undoubtedly uneasy uneventful unforgettable unimaginable unless unsuspecting venture victim wake whale wins yiddish zero 100 achieve adrenaline after-school aggressively ai alike allowed anthony arbitrary artistes assembled attempted bard barrel basis bilked blockbusters boasting boldly borders bow breadth bruckheimer caine callar canny center chiller coffee collectively collision composed composition compromise contest convictions crudup cultures cylinders damning dance decade definitive detached disappoint dong dreck dressed drowned effortlessly empire entry excited explores faces fairy favors field flat-out frequent gear genres gift greatness grin gross guess harmless hartley has-been heady helped henry hewitt host hurt ilk inclination ingredients inner intellect intelligently intentionally inventiveness jewish joel jumble kick laudable laughed lifestyle listen loathsome loosely lunch maintain maintaining manipulation matinee medical mediocrity miserable misguided monstrous moronic narrator nary nicholson nijinsky normally not-very-funny notes ocean opening overcooked pathos patient peak physical picked pitch practically pretend pride program randall realizing relative roberto roman rouge sack scrooge segment send shakespeare sisterhood sleeve smash squarely strict struggling substitute subtitles subtlest superficiality supernatural surrender tackling tensions thousand took traffic trap undone vein wal-mart warfare warrior warriors wearing weighty well-meaning whenever wilson wisdom wishing zany 1958 4ever 51 abroad accurately admittedly aimlessly allegory angels astute athletes atrocious auteuil awakening awe-inspiring believability benefit biopic black-and-white blaxploitation brash brave canvas cat cgi characterization christopher circuit club co-writer complexities concoction consistent conspicuous crack craft daniel dialog drew drivel driven drooling duration dustin duty elsewhere emperor emphasizes enjoying entertainingly european excels flailing flashbacks freddy freedom gasp generosity goo grim gun headed hear hidden hitchcockian hotel humanistic humanly humans hypnotic iconoclastic immediately individual inexplicably inoffensive invaluable ironies irreparable jealousy jolts lapses legendary leon lifetime likeable literally lola longest lord low-grade luck lunatic maddeningly melancholy mexican mira molly mugging mundane nickleby order peek philosophy plympton porridge precisely predictability prefer puzzling ramsay reefs relaxed reliable remaining restroom rewarded ridiculousness rigor ripe roots rowling schneider seams sensuality showcases sickening silver skillfully sloppily smooth smoother spike sportsmen sprightly squeeze stealing stick stone stooping storylines subtext successor succumb sucks suggest suit superbly support swear swim tales theories three-hour toss towards trace transformation transformed turgid undead undercut underneath uniquely universe unlimited vietnam wanders weaknesses willingness wondrously 11 absurdities adequate admission affleck alert all-time amuse amy anatomical answer arrest arrives audacity backstage baffled bands beats bible bluster bogus bonus boobs bravery breathes bride brisk brosnan bullets buttons buñuel callow chaos chicago chick chosen clooney colored colors coma conduct conjured contrivance cram crammed cranky craven creepiness crisis crossover dead-end delectable depressed derived derrida devolves diesel discipline distanced distraction dramatics dude echoes embrace embraced ended endurance enigma enjoys established excruciatingly extravaganza fabric fabulous farcical fearless feast fellow fewer fizz followed fool fulfill fury gender german gibson glorified glum golden gracefully grade grew grounded guide hates hopeless hopelessly hour-and-a-half humming ian icon idiocy illustrates improved inimitable insecure insults introduction isolation jaw-dropping journalism kapur killed knockout lavish light-hearted lips littered lock lodging lucky lurches managed manners marvel masses maze mel metaphors moralizing mountain movie-going names nash naturalistic notwithstanding numbing o one-dimensional overused painted palette palpable par paxton pedigree performed perfunctory plotline praise praises prospect racist randomness range rapes reached reaching reacting ready recklessness recycling relentless reminder rescue resistance revealed revel revelations robin romances rotten sardonic scares searching seas seeks shabby ship showcase skies skillful soldiers sounding sources spaces speaks spirituality spotlight starting starving stuart submarine suitable swept thirty thoughtfulness throw timid toback transcendent transparent two-dimensional urine vain vietnamese well-intentioned witnessed worry x. yesterday youthful 12-year-old 13 adams adding adequately afternoon analysis anguish anime antonia appreciated aspirations assayas atmospheric backdrops backmasking bank bargain-basement barry beguiling begun bucks burger canon capra carrying carved chateau cheeky chew cia clad clause clunky co-written comics compared complaint concentrates conceptual consolation crafty creations crippled cumbersome curves cuts dana darkness debate deception depraved dime-store dimension displays distinct distinguish disturbing ditsy doofus dopey dozens dubious duke dupe dutiful dynamics edits effortless elizabeth engagingly enigmatic enlightening enticing errors essential evade examines expecting experienced experiment exploit extravagant fade fascination fixating flawless floor fluffy folk formal forth full-bodied genius grasp grisly grotesque gut-wrenching hall heat heightened highs hill histrionics hokey hot-button inappropriate inauthentic incorporates indescribably inexplicable infectious inject insensitivity interaction interestingly irresistibly jackass jagged jake jolly juliette jumbo lanes layer leather legal legged levity lonely mad majestic makeup maternal mechanical melancholic merits midway modesty months mysteries nair negative nostalgia noticeably one-liners ordeal outcome overrun overtly package paint-by-number pandering pastiche payami pedestrian perform phone piercing placement pollution portray predicament procedural productions pulse push raises received relate reluctant repellent requisite research resonate rings ritchie rival robinson rocky sara schmaltz schmidt screens screwball self-caricature self-esteem sentence sequel-for-the-sake serenity sessions set-up shared shower shred shtick significantly smiles smugly so-called sober sorority soufflé spice spontaneous steers sticky stinks stoop strangers strengths stunningly sublimely sudsy suspect symbolic tart tattered teachers technique tempting terror town trek turd unbearably undercurrent unmemorable unsatisfying unwatchable urgently van vastly victimized videotape weep 1999 40 86 absurdly accomplish achievements achieving acute afterthought agent ambiguity angel antics applies arguments atop austin ball banality ben bile binoche biography brainless brains brits brushes build bump buzz calibrated cardboard channeling chasing cleaner clear-eyed climax clung-to cohesion comprehension computer-generated conquer conversations cooper costner covers damaged darling dead-on dean deliberately depression deserving destined directions disquieting downtown downward dust earn eastern education encounters enervating episodic evoking experimental expertly explode exploitive exploring facile falling fascinate fast-paced feet fest fondness fontaine forcefulness fun-for-fun guilt gutless guts hampered harris heart-pounding hitchcock hopes hypocritical insignificance instance institution interpretation interviews jaded jolie laborious lobby loveable lows madness mcdormand mechanics member merchant mercy metropolitan misanthropic moods named nebrida nelson nourishing numerous one-of-a-kind oozing opened opera-ish optimistic orlando painting parent-child passes passing peculiarly peppered perils pervasive pileup pitched pompous possibility pratfalls pre-credit proficient pulp radar rarity raunchy reasons recognized regardless region require reserved retaliation rhetoric richness rubbo sanctimony scruffy sealed self-examination sensuous sermon setups shape-shifting shapes shifting shine shrek simpleminded sin slackers sleeper slippery smartest smirk smoochy snore soccer span spins split square steady steeped stepped sticks still-inestimable sting stretched stroke suffocating super surefire surfing sway tawdry thinly thinness threadbare three-dimensional ticking top-billed traditionally tripe tug tuxedo twisty twohy ugliness unbearable ups vulnerable walsh wastes waters weave weepy what-if witness worthless wound 19th-century abomination abundant abyss actuary adorably aftertaste aids alfred all-enveloping amateur amorality anticipated aplenty arty attached attack awarded aware baby bang barrymore besson betrayal blah bloated bogdanovich bogged bolt born brush burning burr butt candid cannes carmen carpenter cheaper cheated chekhov chelsea chiefly choreographed cinderella claim cliché-riddled cluelessness coen collapse collateral collective companionable competition compulsively connect-the-dots considering crackle crew crowd-pleasing cutesy cutting-edge dad deliberate delinquent departure desiccated devastatingly develops diaries dilemma dinner diplomacy dips disappointments dive donovan dough dover drains driver drowsy dual dullard dv ear educates eerily efficient eric ethereal expects experimentation extremes fessenden fiercely flabby flash flesh forcefully forest formed fourth frantic fresh-faced freshly gasping gere glib glimpses goofball grinder handsomely haphazard hardened heartening heist helm heroic hide hitler hole-ridden hooting horrors ice-t idemoto ill-conceived ill-fitting imaginatively impersonation improves inevitably infomercial information inner-city insightfully intact interview introducing involve kicks killing kindness kosashvili lark learning librarian limits loyal lucy lulls lumbering m. marching masculine meets messiness metaphorical mill miller minute misses miyazaki monumental naughty neatly neil non-stop nonstop noticing novels nuances obstacles odyssey offerings operates outing outsiders overproduced overwritten passe passions payne pg-13 pinocchio plainly pleasantly pop-induced possess post potshots pressed prevents propaganda prose pushed qualify questionable quivering r. rap rattling reaction rediscover redundancy refused regard releases reminded reruns restrictive reveal river rock-solid rollerball root routines rugrats rumination rut satisfactory saved scarcely scarier self-glorification sensationalism septic shafer shaggy shape soar soars southern spader specifically speculation speedy spellbinding spookiness stagy staring stoked strays strokes stronger struggles stupidity substances sucker summertime surfeit survive swan sweat tacky tactics tank tatters terminally testimony thornberrys titular tolerance transcendence translate traveler trenchant troubled twaddle twisting uh unamusing uncle unconditional uniformly unknown unpretentious unrelentingly visible wan wander waves weaving welsh wendigo wesley western whimper whodunit wicked wilco william wood writers yu # 1.8 1970s 21st 6 8 accepting aesthetic aged allegedly allowing anachronistic anemic animatronic anne apparatus ararat argue arresting articulate assault associated assumes astray atmospherics atrociously australia awry bad-movie bankrupt barrels basketball beacon behave bielinsky birot blanket blip bluescreen blurry boisterous bonds boozy borrowed borrows bouncy bowling brawny brief broiling brutality budding bumbling cackles california captain car-wreck carefully cathartic champion chaplin cheat cheesiest chief chopsocky chuckles churn codswallop combat comically companion complacency comprehend conceit conception condensed confined connect conned convention corniest counterparts courageous criticism crocodile cuban cynicism danger daughters dearth decrepit dedication deepest definition deliciously denied denying depths determine diary discernible discovering disgust distracted dragons dredge edit efficiently egregious electric elephant elevates embracing enchanting engages enhances enters establishes ethical examine expressing expressions eyelids fabricated fantasma fellowship flatly flaw fondly footnote ford forgiven forgiveness foul-mouthed fragmented freighter frenetic fused future-world gang-raped gimmicky good-time goodfellas gosford gratefully grind grown-ups grownups guillen gunfight gutter hamming hawaiian head-trip healthy heart-wrenching herself hint holographic hopefully horse hostile humility hurry hysterics ignore ill ill-considered ill-timed immediacy immersed impossibly impulsive inarticulate incarnation indecent indians indie-heads infatuated influenced injects intricately jaw-droppingly jerking joie jonah joys kalvert kids-and-family-oriented kinetic kinky kitchen lasting laughably laurice lengths lengthy lewis liberties lights lisping literal logical loopy lose lowbrow luckiest luridly macabre malkovich mamet manic media mendes mentally mired molested morgan mother/daughter mtv myth na narration natter needlessly ninth no-frills nonetheless novak noyce numbingly nutty of-a-sequel oliver ought out-bad-act overdoing overflowing overladen overly-familiar owes pairing paperbacks paranoid partly perdition perfection personified phrase pill pink pinnacle pizza player preemptive primitive prior prism profanity puberty punny rapt raunch reactionary receiving recognizable redeeming rehashes relic reminding repeatedly repetition represented reptilian resembling retains revisionist riot ritter roberts rotoscope s/m sag salt satisfy savage savagely savory scale schaeffer schedule schumacher screenings seattle section seduce self-indulgence semen seriousness sexually shadowy shop shortcomings shorter sillier sits sledgehammer smoking snl solemn soon-to-be-forgettable spider spiked spine sporadic sport sprung squad stark stephen stopped stream stuffiest sub subordinate subtler succeeded successfully supply survivors suspects swooning t tapping task tastelessness tea telegraphed tendency thomas tissues toilet-humor tolerate travesty true-to-life tuned un-bear-able unable unappealing underdone underlying unimaginatively uninvolving unleashes unnoticed updating upscale vibe viewed vignettes violently voice-over wasting wayward well-shaped west whatsoever whining witnesses work-in-progress woven xerox xxx youth zippy '60s 1984 2000 3 8-year-old abc accomplishes accomplishment actions adrenalin african aisle alice alienate all-too-familiar anyway anywhere apply arduous areas argument arms astoundingly audiard austerity authentically avarice badness balancing ballistic baroque bear beaten believer beloved biased blasphemous bliss blonde bloodsucker bollywood bones books boomer bothers boundaries bowel bubba campy captivates career-best careers cartoons catalyst celebrity charged charles chasm cheering church circle claptrap clean cliche-ridden climate cocktail code comments conceive conditioning conflicts consigned constraints contemplation continuation contrast controversial cops counter-cultural credulous creek creeping cruelty crypt damon dated daydreams decided deeds define depict destiny destroy destruction dey dickensian direct disabilities disintegrating disorientated don door dragged dreamed droning dullness durable earned ebullient echelons edition edward effecting elbowed elevated employs enchantment encourage enriched environment etc. evaluate evenings everlasting everyday evolved exists expands explosions exuberance fairy-tale fashioned fees feisty ferrara fiascos fingered fires flagrantly flatter flickering flight florid foibles folly forcing friendships furiously fuss futile gaping gaze genteel geriatric goers goldmember gong goods greene guarantee gusto half-baked harder heart-warming heartache heritage high-wattage ho-tep homicide honor hook human-scale humanism impart importance impress improperly include inexperienced infantile infusion ingenuity inherently innovations inquisitive intent introspective inuit invasion ironically issue jagger jet k-19 keen kieslowski kittenish knee-jerk lampoon landscapes late-twenty-somethings law legally legend letdown liability lift likes lip-non-synching little-known losses lustrous major-league matthew maturity mawkish messenger messy metropolis michel mine misplaced molestation monsoon morbidity motions motivations nails naked narc navigate neck nerves nudity oedekerk oh-so opportunities orgasm orwell outdated over-blown pacino parking pencil pickup plucking plus potato practice precise prejudice preliminary prep-school press prevails preview producers promisingly proposal proving provoke pseudo-educational psycho puff pulled purposeless purposes pushiness putrid quentin quirkily ray real-time realistically realm reassuring receive redolent reflection reflective refreshed refreshes regalia reggio reinforcement relegated reliance repulsive retro revenge-of-the-nerds reverent revisiting riveted robotically robust rocks romancer sanctimonious says scarifying schmaltzy scientists screwed-up searing second-guess seductive self-destructive sensation serials serving shadyac shakes sham sharpener sheerly shorts shout showdown showgirls sibling signposts sing skewed slanted slap-happy slim smacks sneak sneering sonny sorrow spiced spiffy spiteful squandering staggeringly stanley stitch stoops stops strains stretch stubbornly studios suitably summer-camp surf survivable taps tear-jerking teeth-clenching theatre theory threatened thrown-together tosca triteness trope tropes tunney turkey unapologetic under-inspired underappreciated underscore undisciplined unexplainable unintentional unrealized valiantly victories volume vulnerability walker walls water-born waydowntown winking wonderous word-of-mouth worn-out wrestling wryly yawning youngsters zings 'em 146 18-year-old 80 accumulates accuracy ace acknowledges acumen adjective admirably affirms aficionados aggressive alienating anchoring ancient annals annoyed anomie antic antitrust appearing appointed array ash ashley assume astonish auditorium autobiographical available avant-garde balances banter baran barn-burningly benefits bewitched blithe blood-curdling bodily bores brainpower bravado bucked burrito calls camaraderie capability captivated card catching catholic chair chases cheapo choosing circus clash cleverest closely cloyingly co-winner companionship conclusions concocted confidently confront confrontations convoluted corrupt corruption counts court craftsmen creatures credited crowd-pleaser cultivation cup cure curse dash dead-undead dealer dear decibel defiance deja demand demme demonstrating denis derailed describes devotees didactic differences digest dimensions dimness discordant discussion dismiss distinctly disturbance ditty doing-it-for dope drab dumped earlier earthly east eating ecks elbows eloquence elusive employment encompassing enthusiasts entity enveloping even-flowing everybody everywhere evidence evolve exceeds excites executive exist exuberant facet family-friendly fascinates fatter feeble feelings fi fiend filmgoing fishy flavorful flip-flop fluent fluke foot foul-natured fraction freddie frightfest frissons frontal fulfilling gary geared generations genial-rogue glass-shattering glimmer glow goddammit gondry goofiest goofiness graphics gratingly grave grueling guilty-pleasure guy-in-a-dress half-asleep half-hearted hallmarks haneke harbor hawk hayek helpings high-minded highlight hindered hoary hodgepodge holding hungry-man hunk identification illness impassive improbability in-depth indigestion ineptly injuries installment integrated intro invited irksome irrelevancy israel italicized j.k. jane jealous jessica jfk joe jr. juice julia junkie kissing knock knockabout laid landmark languid languishing laugh-a-minute leash leavened leftovers leguizamo libertine light-footed lilo listening loser low-rent lower luster lyricism maelstrom maneuvers manifesto mann mannerisms markets marveling mcgrath memorial meticulous milestones milquetoast minority miraculous miramax miscalculation mission mist mistress misuse modern-day monday morph multiplex negligible negotiate net newfangled nickelodeon-esque norton not-so-funny noticeable næss object objectivity oily olivier on-screen one-trick open-ended orgy ours out-to-change-the-world outer-space outright overwhelmed overworked padded pains painterly passable passed pawn pearl peas penchant permeates petter phoenix phoney-feeling phoniness pile-ups pipeline pony popularity portions power-lunchers prim principal principles print privy probing proceeds pubescent puddle pumpkin punishment purposeful puzzle quandaries queens quieter raindrop rally rant rate raucously razzle-dazzle readily reconstruction reflects relying reno reprieve rerun responsibility retitle revitalize ridicule rigid ringing roughshod round rousing sadistic san sassy schizophrenia sci scooped scooping scorpion selection self-aggrandizing self-aware self-awareness self-congratulation self-hatred self-righteousness sentiments serviceable sever sewage shakespearean she-cute shown shrieky shudder siegel simple-minded sizzle skills slapping so-so soft-core solaris sorrowful sought speed squashed stages stalls stamp stand-up states stellar stevens stifles stirs stores strained strangeness streets stuffing stuffy stylists subjugate subzero succumbing sucking suffered suggested sung superhuman swedish sweetheart symbolism tackles tagline tall tame tattoo taut tavernier tax teach tearjerker teen-driven terrorism terrorists testud the-cash threat thunderstorms tian tightrope timelessness toes topless topple tract tremble two-day two-hour-and-fifteen-minute unadorned unblinking uncreative undermining unfilmable unifying united unmentionable unmistakable unrecoverable unsubtle uptight urgency urgent vacant vaguely valiant variation veers veracity via visionary visualize vivre von vu waking washed waterlogged weary wednesday well-balanced well-deserved well-developed well-done wherever wide-smiling wig wildlife winningly wore working-class workplace world-renowned zingers zombie '30s '40s + 22-year-old 25 3000 4 5 84 91-minute above-average absurdist accepts accident accused action-adventure adorable advised aesthetically affable agenda ages-old aggressiveness aggrieved ah airless albeit all-star allusions amateurishly amazement amish amnesiac amusement anciently angle anguished anonymous anymore apocalypse appreciation argento armed artfully asia assassination astronomically awake awed babak backgrounds ballplayer banderas bartlett bathroom best-sustained biographical bludgeoning boilerplate bolero boost boots boundary-hopping bourgeois bratty brim brio bronx bug-eyed burst by-the-numbers cal calculated calculating campus cardiac caretakers caricatures carlin castro celebratory challenge challenges characteristically charade cheapened cheatfully checking cheer cheesier cherish chick-flicks chimes chump cleavage cletis comedy/drama commercialism communicates compassionate competence complexly concepts conditions connected constructs consumed contemptible coolness copout costumes cotton coupled courageousness craig crane crawl creeps cricket crossing-over crossroads crummy crummy-looking crushingly cuisine curmudgeon damaged-goods damme damon/bourne dangerously daughter davies daytime daytime-drama debated debt decide decision defiantly demise demographic denzel detract devastated developments device diabolical diane diner directionless disastrous disconnects discreet disdain dismay ditched diversity dolls doubts down-to-earth downhill downsizing draw dread dreaming dreamlike drips drowns duck eagerness earplugs ecological economical ecstasy edges elling enables encompasses ensure epiphany erotically error escort espionage estranged ethics evaporates exasperating exhausted expiration expresses extent extreme-sports eyre facetious factors faltering familiarity fantasia far-flung fast-moving first-timer flesh-and-blood fleshed-out flinching flippant fluidity fore foremost forgets forum four-hour freakshow fright fumbled function fundamentally funniness gabbiest gaining giant-screen girlfriend goodly graduated grandkids grayish gremlins grieving griffiths grips grown grumble guessed guffaw gullets half-step handle handling hard-hitting haynes heart-felt heavily hennings high-octane hitch horrendously hour-and-a-half-long household houses hug hughes humanizing humankind hungry husband-and-wife ignorant iii ills impulses inadvertent incident incorporate indication induce inhale initial insignificant instinct instructive insufferable interrogation introduce investigate irresponsible irwin italicizes iwai jacques jean-claud jesse joan kennedy kinetically-charged kosminsky kurt labute ladder lags lake lamer lan large-screen larry latin learned liar license life-changing lightness lilia lobotomy long-faced long-on-the-shelf louis lucratively lumpen macdowell magnetic makhmalbaf mall margin marquis marshall meal meanderings meara menace mergers mid middling minded mormon muddy multilayered mumbo mutates muzak mystical métier natured naïveté nervy niches nights no-brainer noon noteworthy numb obnoxiously obsessions odorous often-funny ominous on-camera one-star ops orchestrated oscar-winning oscar-worthy overblown overcomes overplayed overview ozpetek pact page panoramic paths persuasive philip phlegmatic photo pillages pitiful placid placing plane plateau pleas pleasuring ploddingly plunging poet population pregnant prevent principals privileged proficiency propels psychotic pub pump punchy purer quick-witted rea reaches reasonable recall reception reckless records redone reduces regain relevance rembrandt repeated resident retooling revulsion ripening rob rocket romanek rough-hewn runteldat rustic sabotage samira satin satisfied scandalous scathing schemes schlocky scotland scottish screams screenplays seamless seamy sedate seemed seeming segal self-assured self-empowering self-preservation sexist sexuality shadow shainberg shaking shapable sharper sharply sheridan shockwaves shoe shoplifts shovel shu signals simulate siuation sketched skins slob slog slovenly smart-aleck smarts smeary smith smoothly socially soothing soul-searching species spreads sprinkled squareness stare steals stewart stillborn stolid stomach-knotting strings stud styled sub-sophomoric subgenre subjected suddenly superfluous supermarket surroundings suspend sustenance tabloids tangled targeted technological terrorist theatrically throwback thumbing tickets tickles tier time-consuming tinny titled tooth toothless topical transfigures transform transgressive travail treasured tries-so-hard-to-be-cool truck tumbleweeds tunes turbulent twin twitchy two-thirds tykwer unattractive undergrad underrated undo unfakable ungainly unhappy unholy uninhibited unknowable unmotivated unsurprising unturned update vacation veered veiling veneer verges versus veteran vicious vidgame virulent visions volcano vowing wanderers wang wannabe-hip weaponry weasels well-executed well-lensed well-thought wildean windtalkers wiseman wistful withered worm wounds wrenching yard 10-year 103-minute 105 110 1937 1975 1989 1993 300 451 48 50 83 94 95-minute abandoned abhorrent ably action-comedy addict admirers adrift advice aerial affirm afterwards agreeably ailments aisles all-night all-woman amaze amble ambrose analytical anne-sophie anniversary annoyance anteing anti-semitism appeared approaches arcane argentinian arrangements arrogant asset athlete atrocities auteur autopsy awesome awfulness bale bang-up bare barf barlow bask bearing been-there been-told-a behalf bela bender benjamins big-wave bio-pic bloodletting bloodstream blowing board bomb bombs bone-chilling boom-bam bothered bottom-feeder bouquet bravura breathing brian bridget brit buddy bull bursts caffeinated café cake cameos camouflage camouflaging canned captured caricature cat-and-mouse catapulting celebrate centering centuries ceremonies certified chafing chatter chen cherry chewy chiaroscuro chicken chocolate choppiness clamoring clamorous clams clancy clashing cleansing cliché-laden clients close-ups clubs co-operative coal cockettes collide collinwood columbia column combustible commanding commonplace compare compatible compendium composure compressed condescension confronting conjures consciously consummate continue contriving coping coppola corners corporate corpse courtship cows crackles crappy creators crescendo cringe crises crooks crossed culkin dante dass dates dawdle dazzle decay decency dedicated deep-seated defensible deflated delay delusional demeaning demographically depressingly descends desert desirable despicable destructive detailing detective deuces devastation devote devotion digital-effects-heavy digs disbelief discarded discerned discord disguising disingenuous dispassionate distaste documentary-making documented documenting doles doltish douglas dour downer downfall dramatized dramaturgy drink drives dumbness duplicate eats ed eileen eloquently embarrassingly emerging encountering endangered endear enforced entertains entranced equals escapes every-joke-has evocation evolution exasperated excepting existentialism expose exposes extension exxon eyeballs faced faked fame famine farewell-to-innocence farrelly faulty feardotcom.com fell female-bonding feminine fence ferrera fervently fine-looking finesse fingers finished flowering flowers flowery follies football forthright foster four-star fragment framing freud frighteningly funky fussy futuristic gag gap gaps generating germanic gleaned glued goggles gooding grandparents greasy grievous grit grotesquely groundbreaking guessable guilt-trip guitar gutsy hal halfwit hallucinatory hang-ups happily hard-to-predict hard-won haunted hawn hazy healing heavyweight hellish helmer heralds hiatus high-profile hilary hispanic historic homes hooks hooliganism hospital hostage hot-blooded hubert humble humdrum hurley hustler i.q. ideally ignoring illogical illusion illustrating imaginable imagined imitator immaturity inadequately inc. incapable incinerates increase indicative indictment indisputably inexcusable influence infuriating ingratiating insufferably interplay invest invigorating invincible iran irrepressible irreverent itch items j jagjit jay jimmy johnny jonathan junior-high just-above-average kate kept kieran kinds kingsley knoxville kooky labeled labyrinthine laid-back laissez-passer largest-ever lashing lazier laziest lean leap levy liberal lingers logically loquacious lore low-brow lugosi luminous lumpy maddening malarkey mannered map margarita mary matron mcculloch meander meanspirited meaty megaplex megaplexes melodramas memento menacing mends meow meticulously mick middle-of-the-road midlife milk misconceived misconstrued mishmash mistaken modeled modernize modestly monotone monster/science moonlight morrison mothers mothman motives mournful movie-star musicals musty myopic nachtwey name-calling namesake nationwide native nazi nerds nincompoop nominated not-at-all-good notably now-cliched off-beat old-fashioned-movie old-hat operative opposite oppressively optimism orchard orchestrating out-sized outshined overstays pained paint paintings panic pared parent partisans passages pasta-fagioli pasty patchwork pathetically patriot patronizing pb pbs peace pedestal peploe performing permits philippe phillip pique pitifully played-out pleased pleasingly point-and-shoot polemic poorly-constructed populace porno portrayed precocious predict predisposed prefabricated prepare primer prince pristine probe probes projectile prologue prolonged pronounced protect pythonesque quaid quibble quirkiness races rachel racism radioactive ram rapid rarest raymond reaffirming realizes receives recreates redeems reductions refresh relayed releasing rendering renders renowned reprehensible resentment resolution responsibilities restrained restraint rewritten ribcage rifkin right-thinking righteousness riot-control ripoff risks robustness rodriguez romanticization romped rug rules runyon rural russell sabotaged saddled sandwich sanguine saps saucy saves schizo schlock scrape scrutiny scummy self self-destructiveness self-hating sensitivities separation sexpot sharing shedding shekhar shimmering shmear shock-you-into-laughter shoots show-stoppingly sickeningly sights simmer singer-turned six skunk sleekness slimed small-budget smallest smile-button smoky sneers sociology sons sophistication sophomore spectator speeds spicy spiffing spinning spliced spots sputters stab stacked stasis stimulating stitched storyteller straining strands strip strives struck structures strung-together stultifyingly sturdiest stylistically successes sugarman sultry sun-drenched surgeon surrenders surrounded suspected sustained swashbucklers swings switchblade sympathies szpilman taboo tambor taxi taymor teeny-bopper ten tenor tequila terry thankfully therapy threw through-line thrust tickled tiresomely tongue topics tossed tout transparently transporting treating treatise tremors triangles trotting trouble-in-the-ghetto tucked two-way unbridled uncompelling undeterminable undistinguished unentertaining unequivocally unflattering unfolding unfussily unimpressively uninflected unmentionables unrelated vampires versions vh1 vibrance vile vin virtues virtuoso virtuous vistas volletta vulgarity wall-to-wall warmed wasp waster watered wayne weighs well-wrought welled wertmuller whimsicality widow windshield withholds withstand wladyslaw woe-is-me workable worlds writer/directors ye yearnings yelling yellow yorker young-guns yours zeal '80s 12th 1957 2,500 295 90-plus 93 abrupt accuse action-movie ad adapted address admitting adolescence adoring adorns adultery aesthetics affectionate affectionately aggravating agitprop agonizing ally alterations ambitions amicable amid amok amours amp anew annex anti-erotic approached approaching archly arnie arrogance artworks as-nasty assumption assuredly athleticism attach attracting auspicious authenticate author auto autopilot avid backdrop backward baked ballistic-pyrotechnic banged barbed bartleby beer-fueled befallen begging believed besotted bewildered bewilderingly bias big-hearted bigelow bike binks birmingham blacked blackout blanks blazingly blowout blueprint bodice-ripper bona boorishness bout brain-deadening brand-new brazenly bread breaking breathless broadcast brockovich brow bruised brusqueness brussels bubble bubbly bueller burlap byzantine c.h.o. calamity campanella canadian cannibal cannon cared carnage carpets cartons castrated censure chain chamber charlie chatty cheapening checkout cheery childish chomps chooses chore château cinemantic civic claims clips clock closest co-stars cocoon coheres coldest commerce comparisons compensate conan confessions conflagration confusion conniving conspiratorial conversion copies costuming coughed county cracker cradles craftsmanship craziness creaky creep critique crosses crucifixion crudely curling cushion danny darned deadeningly debrauwer deceit decorating defensive defined definite demonstration depleted derive described didacticism difficulty dime diminishing dimwits director/co-writer disappearing/reappearing discomfort discourse disease dissecting disservice distinguishing divorce dog-paddle dogtown dollar doomed double-pistoled drastic drawn-out dreaded dreadfulness dress drill drumline ducts dull-witted eagle eccentricities eclipses economic efficiency elicit elicits elizabethans emergence empathizes employ endeavors energizes engross equalizer equation erin escaped et etc even-handedness eventful exclamation excursion expressiveness extensive extra-dry extravagantly extremist eye-popping fallibility family-film fart fast-edit faults feminist fiddle fide fidgeted fifteen-year-old fincher fireballs fireworks flavorless flop flopping foundation frames framework franc freaky fringes frontman fudged full-length fumes fun-seeking fundamentals fussing gaiety gain gallery gallic gamble gantz garnered gas gathering gawky giggle glucose goldbacher goodwill goofily gored goth-vampire grade-grubbers grade-school grandiloquent grandness gray grease greatly greenfingers groan groaners groggy grossly guest gulzar gut-busting hangover hankies hanks haplessness happenstance hard-bitten hard-hearted hard-pressed hard-sell haul head-turner heartwarmingly heathers heed herrings high-end highlighted holofcener honorable honored hop hopped-up hussein iconography idiotically idol illogic image-mongering imbued imitative implausibility impudent impulse incarnations incessantly inconceivable indefinitely indoor inferior inhalant injustice integrates interdependence interference intergalactic intoxication invite island italics jacket janice janine janklowicz-mann jar jeunet joseph journalist julianne junk-calorie juwanna kafka keel kidnappings kids-cute kills kirkegaard knickknacks krawczyk l.a. labor ladies lameness landing lapping laramie large-scale last-place laziness league lear legs lend leonine less-compelling lethargic letter lick light-heartedness lika limbs lion literature ll lovefest low-wattage lower-wit lowered lowly lucia lulled majority male-ridden managing marathons marina marked marred masochism masquerade mattei mcdonald melange mermaid mesmerize messing mexico mgm mib milked mimics mind-destroying mindset minutely minutiae mirren misogynist mixed-up moaning moat moderately mojo monotony mopes morbid mordantly morsels motherhood motionless moulin mounting movie-of-the-week muccino mug multi-layered multi-layers multiple musclefest muttering myers myrtle nagging naiveté napoleon negatives nervous newcastle newness nick nonconformist none-too-funny north nuclear o'fallon obscenely off-putting ok old-school oleander one-hour one-sided oomph open-hearted operational oppressive oprah option opts orbits ordered orders oscar-caliber outer outlandish outrage overhearing overkill overmanipulative overripe overshadows overweight pablum packages padding papin parental parrots participate partnership patch paws peppering per perception perch perfected permission perversity piccoli pizazz plague plagued plans plastic platter playlist playwriting plethora plummer plunge poetics pointing poo-poo porn pornographic porous post-feminist postmodern potboiler practices preaching precarious preciseness preordained presenting president presiding primal profundity projector prom promenade protective proved provokes pun punctuated punctuation purity purposefully pushing quick-buck r&d rae rah-rah ramble ransacked ransom rat-a-tat raunch-fests reassuringly recalls recognizes reconciliation recovers recycle reeks reflect regarding reinvigorated repetitious repressed reputedly resembles respite resulted resurrecting retelling retiring revels rhapsodic rhapsodizes right-wing rip road-trip rough-around-the-edges row rueful ruminations rumor run-of-the-filth runner ruthlessly s.c. sacrifices saddam sameness sands sap satiric scant scherfig scooby scorn scrapbook seal seizures self-control self-mocking self-parody self-promoter self-promotion senseless sentimentalizing serpent setpieces settle shades shameful shiner shiver-inducing shootings shreve shriek signpost sillified singh sitcom-worthy sketches skilfully slasher slender slopped slugfest slurs small-town smallness smeared snail snake sneaky sneeze soap-opera soliloquies sparks speculative spends spiral spits splendidly sprawling squint squirming standbys standoffish standout star-studded stature staying steaming stench sterling sticking stimulate stonehenge stones streamed street-smart stripped strong-minded stumbles stuttering substantive suicidal super-dooper-adorability super-simple superstar supposedly surfacey surrealism surrendering swank swimfan swipe switch sword swords tables tadpole tailor tapestry targets tedium teenybopper teleprompter tenderness tens thesps thinkers third-rate thoughtfully thrives throat-singing thud tick tied tightened time-switching timeout titus tolkien tonto too-conscientious torrent train transporter transvestite travails treachery treading tricky triumphantly trots troubling tub-thumpingly tucker tundra tutorial twenty u.n. ultra-violent unacceptable unaffected unapologetically unbalanced uncertainties unchanged unclean undercover underground undeserved undramatic unerring unexceptional unforgivingly unfortunate unhibited uni-dimensional unimpeachable uninventive univac-like unpersuasive unprecedented unprovoked unrelenting unruly unsatisfied unshapely vagueness vainly vat veil velocity venice verisimilitude vices victorious virtuosic viscerally wacky walking walt watery weakness weirdo well-defined well-paced well-realized welt westerners wheedling wherein wide-eyed wild-and-woolly wizened wollter woozy worldly wretchedly wrote yielded zaidan zealand zero-dimensional **** 10,000 10th-grade 13th 163 18 1980s 20-car 21/2 60 66 89 95 99 aaliyah abel accidental accountant achieved aching acidic activism actorly actress-producer addictive addressing adhering adjusting advance afghani agile aim aladdin alas all-around all-inclusive all-male alleged aloof ambience ambivalence ample amusements andie android animations anticipation antiseptic antsy appealingly applying appreciates appropriately ardent ardently armageddon articulates artificiality artistically artless ascends asleep aspire aspired astounds astringent automatically averting b.s. badder bags balm balto banger banter-filled barbarism barker bars battlefield beard beer befuddled begley benchmark bergmanesque berkley berry bitten bitterly blazing blight bling-bling blob blond bloodshed blues boatload boffo bogging bolstered boom bordering borderline botched bracingly breathe brogue brooms buck buffoons buying bygone cad cagney cameo canadians cancer capped cardoso careless carl cars cassavetes catch-22 catharsis cellophane-pop chaiken cheek cheered cheerfully chemically chicanery chillingly chin choke christianity chronicle cinematographer classify claustrophic claw clayburgh clear-cut clerk climb clone clouds clout collaborative collar combining comfortably comfy comical commend compassionately concentration conclusive conscious considers conspicuously construction consuming contact contender continuum contradiction contradicts contrasting contributions cope corn cotswolds cox crassly crawls creed creepy-scary cribbing crisper crowds crumb cube cue cussing cutoffs cyndi da dared dawn deafening debts decisive degraded dehumanizing del delusions denouement departments dependable derek derring-do desolate devito diaz diggs dilutes dimming dip disaffected-indie-film disagree disciplined discovered disease-of-the-week dishonesty disposition disquietingly dissing distances distinctions distracting divisions dizzy documentarian dollars dolphin-gasm done-that dorkier double-barreled down-and-dirty draft draggy dreadfully dreamy drifts druggy drying dug dumas dungpile dustbin ear-pleasing earmarks earns eclair elegy elite elm embellished embroils enacted encumbers energized enhanced ensues ensuing entered enthrall enthusiastically entrée environs erratic escapist esteemed ethos euphoria evanescent everett evergreen eviction evolves exalted exception exchanges exoticism expanse explains explosive expressionistic extra-large exudes facial faint family-oriented fanatics fantastically far-fetched favorably feed feel-bad fetishistic fighters firmly fish-out-of-water fitting flakeball flame flatula flatulence flavors flck fleeing flexible flibbertigibbet flies fling floating flourish followers food-for-thought forbidden forming forster fragmentary franklin freak-outs freakish free-wheeling freely freezers freshening friggin from-television frothy frustrates fu fullness fumbles funnybone fusion gamut gedeck gelati gender-provoking generalities ghandi ghosts giants gibberish gigantic giggles girl-on-girl give-and-take glee glides globalizing globe globetrotters-generals gloom gloomy gloss go-for-broke gobbler gods gon good-bad good-naturedly goose goosebumps gore gorefest grad grand-scale graves hairs hallucinogenic handed handles hands-off hanging hardship hardy harm harrison hated hawke heart-string heartbreak hefty hell-bent helluva hence hibernation hickenlooper hiding high-powered high-strung highway hills hired hiss holden hollywood-action homophobia homosexual honeys horror/action hue humanist humidity hundreds hypocrisy i.e. ick iconic igby ignite ill-advised imbue immaculate imperfect imposed impostor inadequate incognito india indifferent indoctrinated indomitability indulged industrial-model ineffective inelegant inescapably influential innovators innumerable inquiry inseparable instilled insufficiently integrating interests interior intermezzo intrepid introduces introspection invented invites invitingly irony-free irritates israeli ivan jam jar-jar jersey jokers jokester jostling kaige keenest kenneth kid-vid kitsch l. laboratory laden land-based laptops laser laughingly laundry lauper lax le leaps lectured lectures leigh lens less-is-more library lieutenant life-altering life-embracing lightly liking lil limit linearity local loco london long-suffering long-winded look-see love-struck lower-class loyalty luis lurks luscious lynch lyne macaroni machines made-for-movie maid maker manager marivaux marking married marveled masterly masterpieces matrix mattel mattered max mcbeal-style meager medicine melancholia melville mercilessly message-mongering metal metaphysical midsection mikes milder mind-numbing ming-liang minimal misdemeanor missive mixing mode modem modus moldering mom money-oriented monitor moralism mores moron morton motivate mourns much-needed muscles muse must-own muy naipaul naive narrow nba needy neo-fascism neorealism nettelbeck network neutral new-agey nicest nicole nightmarish nohe nonchalant nonconformity none-too-original nonfiction norma normal nosedive not-nearly not-so-bright nouvelle nuts objects obligation oddballs offended oftentimes omission ontiveros opportunists orange origins oscar-sweeping out-of-kilter outbursts outrageously over-dramatic over-indulgent overinflated overtake overuse owed owen pages painstaking palatable parochial pasolini passably patiently patrolmen pauly paunchy paying penn perennial perkiness perpetually pessimism petri philosophers phonce pin-like piscopo pissed pithy plant playfully plucks plummets pocket poise pokémon polite pootie pop-music portraits portraiture portuguese position posturing pot pours powerpuff practiced pranks preachy-keen precedent preferably prepackaged prescient prescribed preserving pretention prettiest prevention priceless prickly primary princess principled procedure prolific proof properties proposes proven proverbial provided provoked pseudo-bio psychedelia psychopathic pulpiness punchier punchlines puppets puppies purdy puzzled q quaint quaking questioning quickie quietude rain rambles ramblings rampant re-creations redeemed redundancies reeking refined registering rejigger religion removed reopens repartee repeating repellantly requiring resents reside resorting resorts respond retail retard retro-refitting reverie reviewers revolutionaries rife rivalry rode roiling rose rotting rough rumblings saddest saddle sags salute salvage sandbox sarah scandal scarily scheming schnitzler scientist scoring scripting scum scuzzy sean second-rate self-amused self-centered self-delusion self-glorified self-mutilation self-righteous selves semi sense-spinning servants sewing shackles shakesperean shift shifted shindler shiri shoes shootout shore shouting show-biz sica siege signature silbersteins simmering sincerely singer singles sitcomishly slaloming slam-dunk sleaziness sleazy sleek slope slumming slump smashing smear smell smokey snail-like snared snickers sociological solace sole solidity someday soothe sparking sparkles sparse specific spill spirit-crushing spiritualism splashed splashy splitting spoofy sprouts stable-full stagey staggered stamina standup statements stein stepmother stills stockwell storm straight-ahead straight-to-video straight-up strategic strenuously studied stunt-hungry sturdiness stylist substandard substitutable suckers sugary suggesting suitcase sulky sunny sunshine super-powers superficially superhero suppose surehanded surrealistic swaggers swaying swift swimming tang tanks tartly taxicab teaming tech-geeks teeth-gnashing telegrams temple the-night thirteen thoughts three-minute throat thrusts tics todd toddler tongue-tied tongues too-frosty too-long torments torpedo toss-up touchstone toys trade treads treated tried-and-true trier trifecta trimmings trio triple triumphant tube tuck turntablists twinkle u.s. ugly-looking uhf uhhh ultra-loud unashamedly unburdened unclear undergo underlay undernourished underrehearsed understandable understatement undertaking undisputed unengaging unflappable unhidden university unreligious unsettled unwieldy upfront user-friendly utilizing vacuum valentine valid validated valley vaporize varying versatile video-game-based viewpoint vittorio voices voyeuristic waldo walled-off wallflower warmest washout water-bound watts weaves weighted welcomed well-characterized well-timed well-told well-worn werner whiff widowmaker wiggling willies wind-in-the-hair wised-up wisely wittgenstein wonderment workings workshops worldly-wise worries wreaked wreck wreckage xmas yale yarn-spinner you-are-there yourselves zelda zhang zipper '50s * 10-course 112-minute 1920 1938 1952 1995 2001 24/7 2455 80-minute = abject absorb academic accident-prone accumulated acid acquires actioners actorliness actory adage adherents admitted adolescents adrenalized affluence afford affords african-american agency agitator ahem aimlessness airs alcatraz alexandre all-powerful allegiance altar ambiguities american-russian amoral anchors angelina angling aniston ankle-deep anna ante ants apartments apex apollo appetizing apted arbitrarily architect archive ardor arguably argues armchair arrow artwork as-it ashamed asparagus asphalt assassins assaults assembles assets assigned ate attackers attendant atypically auschwitz awash baaaaaaaaad backhanded badly-rendered banquet barrage basically beaches bean became bedside befuddling beginnings begrudge berling beseechingly besides best-known bestial bestowing betters beyond-lame bicentennial bid bidder big-fisted bigger-name bind bisset bladerunner bled blender blessed bluff boardwalk bodacious boho boiling boils bolado bombshell bon boorish boot boss bottom-of-the-bill bouncing bouts bowser brass brats braveheart bravo brawn brazen brendan brink bromides bryan bugsy bump-in buoy burdened burnt burnt-out burstein bush businesses busts butterflies butterworth béart c. cacoyannis calvin campaign campaign-trail campfire candidate candy-coat capably careening career-defining careful carmichael caruso casings catastrophic caterer catherine caton-jones cattaneo cellular centered cesspool chainsaw champagne characterisations charitable charlotte cheap-looking cheeks chefs cherished chill chomp chords chronically churlish churns circumstantial clare classicism classification classified cliff-notes clinical clint clive cloak clownish clumsiness cockney coda coincidence collaboration collected columbine coma-like comatose combine coming-of-age/coming-out commands communicate comparatively complain complaining completion compliment comprehensible compromising concentrating concoctions confection confuses congratulate conjure considerably contemplative contenders contentedly conversation conveyor convolutions cookie-cutter cooler coos cor-blimey-luv-a-duck counterculture courtesy cowardly cowering crapulence crawling crazed creatively creepiest criminals critic-proof cronenberg crudity cruelly crystal crystallize cumulative cuter d'etre dadaist daft dancers das day-to-day deathly debilitating decipherable decisions decomposition decorous defend definitions delighted democracies denouements denuded depress derisive devos direct-to-video/dvd disappoints disarmingly discerning disgracefully disintegrates dislocation distort distract distractions doa doc docs dog-tag done-to-death downs drain dreamworks driver-esque drops dullingly dummies dusty dwells ego-destroying eisenstein el elder election em embellishment embody eminently encomia encourages endorses engendering england ennui ennui-hobbled enthusiasms entrapment equations equipment espn eternal evasive eve ever-growing ever-watchful executives exhaustion expensive explanation expository exposure fabulously fabulousness facing fahrenheit failings fairytale fallible falters fanatical faraway fathom faux-urban fax fertile figured filter finery finger fingering flames flashing fledgling flimsier flopped flops flower flower-power flush fluttering fluxing flying folksy food-spittingly foreground forsaken fortune francisco franco fresher freshman freudianism friel frightful frittered frothing frustrated fuddled fuel fuels fun-loving functions fuse fusty gadgets gall ganesh gangster/crime garage garner generated geniality genitals genre-busting genre-curling gentlemen gestalt ghastly giggly gimmick gimmicks girl-meets-girl girlish give-me-an-oscar glamorous gliding glinting glitter godfrey goldie golf good-naturedness goodies gooeyness government governments graham grain grandfather grandiosity grandly grasping greenlight grinds grinning grittily groan-to-guffaw groups guiding guilt-free gullible gum gunplay guru gut-clutching halfhearted handily handy hang hannibal haranguing hard-driving hard-edged hard-to-swallow hardass hardhearted harmoniously harvesting hashiguchi hat-in-hand heartily heft hell-jaunt heyday hideous hijinks hippie-turned-yuppie holiday-season homework homogenized horrified hudlin huge-screen hurried hurts hyped hyper-realistic hypnotically hysteria ichi icily icky illuminated imitations immigrant implies imponderably in-jokes inactive inconclusive inconsistencies independent-community indignant induces inexpressible inexpressive infectiously infrequently infuse infusing ingredient inhabitants inquisitiveness insider insinuation installments insulted inter-species internalized intimidated inventing investigation invulnerable iota irrational irreconcilable jabs jams jaw jeff jez joint jon journalists julie jump-in-your-seat kafka-inspired kaos kid-movie kid-pleasing kilt-wearing kiosks kitchen-sink klein knees knucklehead koepp kouyate kraft la lack-of-attention larson late-inning laugh-free laugher laugther laurence learns leatherbound lecture lesbian let-down lewd libretto limpid lingerie linking lisa live-style livelier liveliness liza lobbyists long-lived looney loopholes loosely-connected lothario lottery lovable-loser lucas lying lyrics m-16 macbeth maggots majors make-believe manipulating mankind marine marine/legal marketable mask masked massoud maximum meant meetings melted merrily mesmerised midst mika mileage millions miniseries misanthropy mishandled misty-eyed misunderstood mixes mob models moist molina money-grubbing monosyllabic monument moodiness moralistic morals mornings mouglalis mounted mourning moviegoer moviemakers muck mulan multi-character mumbo-jumbo mummy murk mythology nalin narcotized natural-seeming naturalism nausea navel ne near-fatal needless negate neighbor nerve-rattling newcomers nine-year-old no-surprise noisy nolan non-disney non-exploitive non-techies norm notches nurtured obscenity occur oddity odor officer oft-brilliant oh-so-hollywood ol' old-time olympic omniscient online operandi operatic opposed opposites organizing oscar-nominated otherworldly ounce outlet outre outweighs overdue overstating overstimulated overture overwhelmingly paean pageantry painless paint-by-numbers palate pallid paralyzed parapsychological paris parlor participant parties passive-aggressive pastry patchy pathos-filled pearce pearls peels pell-mell pender pent percentages percussion peroxide perry personable photos picaresque picpus piffle pipe pitted pivotal pixar pixilated playoff pogue point-of-view pokey pollute pop-up popped portraying posing posterity preaches preciousness predominantly prejudices prescription presume prevalent previously probation proclaim profiling progresses projection prominent prominently promotion prophet propriety-obsessed prostitute provincial provocations prowess prurient pseudo-intellectual psychologizing publicists publicity punishable purists purr pursuing putters pyrotechnics quadrangle r racing raffish rage ragged raison rake ransacks raphael ratio ratliff ravaging razzie re-working reagan reassure reassures rebel recharged recompense recovering recruiting reduce reductive reef reek refers refracting refuse regards regular rejiggering remarks remembrance reparations replacing repugnant resemblance restage restatement retaining returning revigorates revives revolting revolutionary rice riled ringside ripping roach robbed robberies rodrigues rogue romanced ron roster rounded ruins résumé saddens sale salle sally salma sanctimoniousness sandlerian sanity sayles scariest scathingly scene-chewing schmucks schticky scientific scooby-doo scotches screed screwing scuttled seamstress searches seasonal seizing self-absorbed self-absorption self-image self-knowledge self-revealing semi-amusing semimusical sent sermonize serrault settings shaky shambles sheets shepard shivers shoot-em-up shrug sickly sidekicks simpering singular skeeved skeleton skidding skimpy skims skipping skirts skullduggery skulls slam-bang sleep-inducingly slickly slogans sloppiness slyly small-screen smorgasbord smothered smutty snagged snap snappy snide snooze so-five-minutes-ago soft-porn somebodies sopranos sorriest sorrowfully souvlaki spanning spirals spiritless splendid-looking split-screen spooks sports-movie spry squeezed squirm-inducing stabs stagings stains star-power stardom steam steamy stepdad sterile stevenson stickiness stifling stiflingly stop-and-start strangling stress-reducing strolls students studio-produced stylings subconscious subjective suburban suffocate suggestion supple supplies sure-fire surplus surveillance sweet-tempered swill sword-and-sorcery swung sympathizing sytle tabloid tackled tactic tainted talked talks tardier tarzan tear-stained tearful tearing technologies teen-speak telescope temperamental terminal terrifically testosterone-charged theatrics themed therapy-dependent thirteen-year-old thorough thoughtless throbbing throwaway thurman ticks tinged tinsel tirade titillating tolerable tonight tons tops totalitarian tougher tourists toy traces traffics transfixes transparency transported traps tremendously trend trials trudge true-crime truest truffaut tryingly tunnels turmoil tweener tyco unadulterated unblinkingly uncluttered uncouth und undemanding underachiever underconfident underlies undermines underscoring undertaken undertones unearth unendurable unfold unfulfilled unguarded unified uninspiring uniqueness unit units unlikeable unpaid unparalleled unsaid unsalvageability unsatisfactorily unsuccessful unsung unveil unwary updated upping uproarious urbane variant variations vaunted veering ventura venturesome verdu verging video-cam videodrome vigils virtuosity virulently visitor vivacious void vonnegut vulgarities w. wafer-thin waif waited walken wallowing waltzed war-movie war-weary warn wartime watered-down weakly welcomes well-conceived well-contructed well-directed well-edited wheezy whippersnappers whipping whir whirlwind white-knuckled wide-awake widescreen wilder wimps windows wishy-washy wispy witlessness wizardry workers workmanlike workout workshop worried would-be wrapping writer-actor x-files yank zhuangzhuang 129-minute 140 15-year 1960s 1998 2-day 3-d 7th 87 9-11 9/11 ?!? a.c. a.e.w. abrasive absent acclaimed accompanies accompanying accomplishments acerbic acquire acting-workshop activate activities adaptations adhere adversity advertised advertisement advises affability affectingly affinity afloat afterlife aground ahola akin al aldrich all-out almodovar already-shallow alternating amalgam amusedly anarchist anchored anecdote annie anti- apartheid appearance appetite apple aptitude archives aristocracy aristocrat arliss arrived assembly assert association assures astronaut attending attics attributable audience-pleaser autistic avalanche avert awakens awareness b-12 backbone backed backlash bai balk ballerinas ballsy balzac bang-bang banking barrie barrow basest bath bathtub battered batting bawdy beat-the-clock bedfellows bedtime beginners behind-the-scenes beliefs belly-dancing bellyaching benefited benshan bernal bertrand bespeaks bestowed betrayed betting beware bickle bikes billing binary bio bitchy blanchett bleed blended blips block blockage blossom boll bolster bons borrow bottomlessly bought bounds branagh brat brazil-like breed brimming broader brothers-style bruckheimeresque buckaroo bucket build-up built-in bullseye bumps bungle bushels byatt cadence calling camerawork candles captivatingly carrey cautions cedar celebi characterizes choreography chouraqui church-wary circular civics clarke-williams classes classroom classy claude claustrophobia cleaving clicking cling clocks clones clotted clunker clutch clyde co-writer/director cocky cogent coherence coke cold-hearted combustion comedically commenting commiserating commitment common-man communicating community-college compels conceivable concludes concrete conducted configurations connections connoisseurs conrad consistency conspirators contained contemporaries contemptuous contentious contributed conundrum converts convincingly convolution cooker cooly cop-flick copenhagen copycat coral correctness counter counterproductive courtroom cow cracked cracks cranked crash-and-bash crashing crawlies credulity crispin critically criticizing critiquing cross-country crowdpleaser culture-clash curtains cut-and-paste cutting-room d.w. danish darkest date-night dave day-lewis day-old death-defying decadent decides deckhand decommissioned deemed defense deferred degrades delineate delving democracie demonic demonstrated denlopp deportment derisions desires despairing desultory developmentally di diamond dicaprio diction diego dim dirgelike disabled discontent discoveries dismantle dismember disneyland disparate distressing distressingly doctorate doings donald doors doorstep double-cross doubling doyle dozing dramatize dramatizing drang dreamscape dropped drug-induced drunk dudsville duel duly dunno dunst dwindles dystopian educational eee effectiveness egocentricities egypt eh elemental elvis embalmed embedded emerged emphasizing emphatic emulates energizing enforcement engine engorged enhance enhancing enlightenment entree entwined epics eroded eroti-comedy escapade eudora europe evacuations even-toned ex-girlfriend examples exasperatingly excellence excessively exchange exercises exiled expanded expansion expectant fabuleux facades famuyiwa fang-baring fantasti farcically fashioning fastballs fatalism fateful fault fearlessly feral fictional fidel filler fillers fillm filmgoers fire-breathing firth flabbergasting flag flag-waving fleeting flick-knife flinging flirts floats flog flows focuses folds follow-your-dream forefront forgettably formalist formidable forms forte four-year-old fourth-rate frankly fret frosting fruit full-fledged futility gadzooks gai gainsbourg galled gamesmanship garcia garcía garde garden garish gates gay-niche gays gender-bending genesis gestures gianni gifford gilliam girls-behaving-badly glancing glasses gleefully glover glows gluing gold goldberg goldman goliath gone-to-seed goodall goodness gordy goth gotten graces gracious grandmother grateful gratitude grenade griffith groove grossest gulpilil gushing gut guzman hail half-an-hour half-bad hallelujah hand-drawn handbag-clutching handheld hanna-barbera hanussen harangues hardware harmon harvey hatred headline-fresh heap hears hearst heartbreakingly heck heller helpful helping hermocrates hews hibiscus hidebound high-spirited high-tech highly-praised hipness hollowness holmes honoring hopkins/rock horizons hormonal horrific hotter hounds howlingly hubristic humanize humbling humiliated humor-seeking huskies hustlers hyper-real hypothesis ii-birkenau ill-equipped ill-wrought illuminates illumination illustrated imaginary imamura immense immersive impacts imperious impetus impish implodes imply in-your-face indignation individuality indulgence inextricably infamy infatuation inflammatory inflate infused insecurity insistence insistently instrument intacto intellectually intermediary intermingling intern internal intersect intrusive iris irritatingly item january jargon jaunt jazz jie join jonze journalistic joyful jugglers juiced juiceless justifies juxtaposition juxtapositions kathie kathy katzenberg ke khan kibbitzes kid-empowerment kiddies kilt kim kingdom kirsten kitten knockaround knockoff lackadaisical lacked lad lagging landbound larded large-format lascivious-minded lathan lawn laws lazily leaping leblanc lecherous lecter leontine leys lifelong lifting lily liman limply lina liner lioness lip-gloss lip-reading literarily litmus live-action live-wire loathe log logistically longtime louiso low-tech lukewarm lullaby lunar lushness made-up makeup-deep makin malapropisms manchild manoel manual marcken marginally margins margolo marilyn markedly marketing marry mason masters matches maxim mcmullen meddles media-soaked meet melanie memoir merchandised-to-the-max merited methodology mid-to-low middle-age miike milieu millennial millennium millisecond mined minnie miracles miserably misfiring misunderstanding mockumentary moldy-oldie mommy monologues monopoly month mopping moratorium mordant mortal mother-daughter motivation mount mountains movie-esque movie-making mush muted mystic mélange nail napoli narcissistic naval navel-gazing near-impossible near-masterpiece neeson neglecting nemesis nicolas nighttime nimble no-bull novelty nowheresville nubile nurses nurtures nymphette oblivious obscured observed occupation odoriferous off-kilter old-world oliviera omits open-mouthed oppositions opulent oral ordinances otherness outage outdoes outgag outpaces overachieving overcoming-obstacles overdone overeager overflows overlapping overlook overpowered oversexed oversimplification oversized overstylized paeans painkillers paltrow pan parables paradoxically pardon pasts patricio patriotic pause peaks peevish pellington penance peopled perceptiveness period-piece permitting personnel pet petty phenomena phenomenon philadelphia phonograph physician pics piecing pies piles pillowcases piss placed plaguing plastered plausible playbook plug plumbing plumbs poking politesse pop-cyber pop-influenced pore pornography poses positives possesses powerhouse powerment prank preceded prefeminist preoccupations preserves pressure prey prisoners pro pro-wildlife processed professors profundities prognosis progressed propelled prophecies propulsive province publishing pummel purge puréed queasy queasy-stomached quixotic quotations r&b rabbit-proof rainbows rampantly ran rancorous ranges rapport rash rat rates ravel reader ready-made realities reams recesses recorded reeked reeses reference referred refusing regimen reincarnation relays relieved religious reminders renegade-cop renner repetitively reportedly reporting repulse resolutions resources respecting respects restoring retold retreats returned revelled revived rewrite reynolds rhames ribisi rick ricture rigged rightly romanticism rooted rose-tinted roussillon rusi s1m0ne salt-of-the-earth salvation sappiness satisfaction sc2 scalds scared scarface scary-funny scented school-age schoolers scooter screenful scrutinize seater seedy seeping seesawing seinfeld seldhal self-critical self-defeatingly self-exploitation self-reflection self-satisfaction send-up sending sends sense-of-humour serendipity serene settled sexism shallows shapiro shear shell shirley shirt shocker shortness shriveled sidesplitting sidewalks silence silences silent silent-movie silly-looking simulation single-handedly single-minded sister six-time sketchiest slapped slash-fest sleep-inducing sleeping slow-paced small-scale smaller snaps snow-and-stuntwork soft softheaded solemnity solo someplace somnambulant sooooo soporific soul-stirring soullessness spain spall spending spikes spiritually splashing spoken spoofs sporting spotty spousal spring-break springer springs squander stakes stammering star-making steinis stew still-raw stodgy stranded stray streaks streamlined studiously stuffs sturm stylishly sub-formulaic sub-par sub-tarantino subliminally submerged subplot subtexts sufficient sufficiently summery sunset super-serious super-stupid superlative surface-effect surrealist suspecting swanson sweaty-palmed swiftly swirling swordfights symbiotic synergistic syrup systematically tacked tail tailor-made talancón tamer tantamount tartakovsky tasteless teams tear-drenched techniques teddy teen-exploitation teen-sleaze teetering televised tenacious tend terrified tested testing thirst thornier thoroughfare thoughtlessly threefold thrilled thrillingly tides till time-killer tiniest tits ton tootsie top-heavy tough-man traced track traits trash-cinema travel-agency travis trickery trifling trilogy trims triumphs trombone trove truck-loving truth-in-advertising tumult tumultuous tunisian turks turntablism tv-movie twilight twinkling typed types ugh uglier uma unaware uncharismatically uncharted unconcerned uncool under-10 underestimated undergraduate undertone undying unencouraging unlaughable unnamed unpleasantly unremittingly unstoppable upends upheaval uplifter urges uselessly vainglorious vary vast veins vibrantly video-shot videologue views vigorously villainous ving volatile volumes vulakoro walking-dead wall waltz wanes warm-milk warmed-over warped wash webcast weber weinstein weissman well-behaved well-formed well-mounted well-put-together welty wewannour whip-smart whitewash whoop wimmer wind-tunnel winger winter wise-beyond-her-years wisegirls wittier woods wrap wrath wretched writings yakusho yields yuen zhao zinger-filled zoning zzzzzzzzz élan !? '90s 'til 10-year-old 120 14-year-old 1790 1950s 1959 1960 1962 1973 1997 22 3/4th 6-year-old 65-minute 65-year-old 71 75 88 8th 94-minute 99-minute a-bornin abbott abdul abiding absence absorbs abused according achival acidity addresses adept adobo adroit advancing advocacy affectation affectation-free aftermath agape agreed airy alagna alan all-over-the-map alternate amélie anakin andrei angles animated-movie anthropomorphic anti-adult anti-catholic anti-feminist antonio appetites appropriated archibald archival argot aristocrats arithmetic arthur asiaphiles assaultive assuming attacks attentions attitudes attuned audrey austrian authenticity auto-critique avant avoiding axel ayatollah babbitt baboon back-stabbing backseat bait-and-switch baker ballroom banzai barbarian barking-mad baseball baseball-playing batman beachcombing bearable beavis becalmed beers beg belgium belittle belong bermuda bernard bet bettany bettany/mcdowell better-focused big-budget/all-star bilingual billy biologically biscuit bjorkness blab blacklight blatantly blemishes bloodbath blown-out bluffs blustery bmw bo boiled bombay border borscht botches botching bounces bowl box-office brady bray brims broke bronze brooks bugged bullfighters bully bumper bundling burgeoning butler butter butthead byler byron bytes büttner caesar caliber calories canada candor cantet carré cash cassel catcher categorize cattle caustic chabrolian championship character-driven charred charting chastity cheap-shot cheapen cheats chips chou-chou christian-themed chuckling chung chyna cinematically clarify claws clean-cut cleaver clinch clinically closed co-dependence co-writers coaster coinage cold-blooded cold-fish collaborators collect collegiate collie coltish columbus comeback comforting comic-book commended communication communications compass competing complicate comprise compromised compulsive conceal conceited conceptions concerning concerns conduits confuse congrats consoled conspiracies consumerist contagious contours control-alt-delete controlled cool-j coordinated corcuera coupling crave cream creator crematorium cringing critic criticizes crooning cross-dressing crudities crushed cultivated cunning currently curtsy cutes cyber dabbles dark-as-pitch debris decades-spanning decent-enough decline deconstruction deem deficit defines degrading delete delhi demi democracy densely deny dependent depictions depravity describing destin detractors devils devious diatribes dictates die-hard dies dipped dire disadvantage disagreeable disgusted dislikable dismal dismissive disoriented dispossessed distinction diver divertissement divided documentarians dodge dogme dolly dominate dosage doting doubles doubtful down-home downplaying drama/action dramatization dramedy drawings dresses dridi drinking dripping driving drudgery drum drumbeat dualistic duds dungeons dwarfs dylan dynamism dyslexia eccentricity economics editor editorial eighth-grader einstein electrocute eludes emaciated emailed embodies emphasising emphasized enabling enact enactments encyclopedia endured english-language engulfed enlightened enlivens ensures epitaph ernest errol eternity exaggeration exalts excite exemplify exhaustingly exhibitionism exhibits existing explicit exploratory extends extracting exuberantly exude faceless fai fanboy fancies farm father-and-son favored featured feeds fencing fetishes fifty figuring fiji fist fizzle flails flakiness flood floundering flounders flurries focusing fogging follow-up foreman forged forgot formalism fortify forward foundering frat-boy fresh-squeezed friendly fringe frolic fruitful frustrations fueled fuhrman funeral furious fuzziness g-rated gained gallo garnish gaudy gaye gaza generational gesture glamour glosses glumly godzilla goose-pimple gore-free governmental grabs graced graffiti granted graphically gratify graze groaner guarded guiltless guise gussied hackery hailed half-dozen half-lit ham handiwork hard-eyed hardest harshness hayao head-on headaches hearing heavyweights heels heidegger hellstenius helms heroism hideousness high-adrenaline high-buffed hinges histrionic hit-and-miss hit-man hitchens hobbled hogwash hollywood-itis hollywood-predictable holofcenter homo-eroticism hoofing hooked hoopla horrifyingly horses hossein hotels hotsies hotter-two-years-ago housing howler humbuggery hunger hunky hyper-artificiality hypermasculine iced ideology idling imbecilic imparted impatient implications implicitly implied importantly imposter impressionistic improvise in-jokey inadvertently incurably indecipherable indistinct indoors indulge indulges ineptitude infantilized infinitely informed ingenue ingest inhuman insatiable inside-show-biz insouciance instances instruct intelligibility interludes interpreting interweaves intolerant introverted investment invokes iosseliani iq irredeemably irvine irwins isabelle israeli/palestinian italian-language ivory jacobi jacquot japanimator jelly jettisoned jews joffé joined joins jostles journalistically jovial juan juliet/west jury justifying ka kahn karim kazan keenly keg kendall kids-in-peril kinnear kitschy knocks koury kubrick kumble kurys lab label lamentations landau larky late-summer lawyers lds leader ledger leery lemon leniency leroy lessen lethally li liberalism life-at-arm lifted lighter like-themed lin linear linger lint liotta little-remembered liu lizard lo location lolita lone loop looseness losin lovably love-hate lugubrious lumps lungs lusty luther lynne machinery macho macnaughton madcap magician maintained majidi malcolm malik malleable mamá mandy manhunter mapquest marveilleux marvin marxian masochistic mass masterpeice matched mateys matinee-style mccann mckay mclaughlin mediterranean medium-grade melting mending mentioned menu merge merges meshes meyjes michele michelle microscope middle-class midwest military milking milks minac mini minimum minor-league mire miscalculates miscalculations misfortune mishandle mitch mitchell moan model modernizes modicum monkey monologue morlocks morphs morris mortality movie-industry movie-specific mugs murder-on-campus murderer murders murray musketeer myriad mystification mythologizing nada nail-biter namely naomi narrated natalie naturalness nauseating neglects neo-nazism neophyte nesbitt neurotic neurotics neverland nietzsche-referencing ninety non-fan non-narrative non-porn northwest norwegian not-so-divine not-so-small notch noticed notting numbered nyc obsessively occupied off-center offset often-hilarious oh-so-important okay olives omitted omnibus one-room one-sidedness oo ooze openly opportunism opting options orbit organized orlean oscars outline outrageousness overacted overboard overrides overstated overwhelms p.t. pa pabulum page-turning pageant palestinian pamela pander pap paradigm parmentier partner partners partnerships parton paulette pauline paved peanut peralta peril period-perfect perpetual perplexing persistence persistent personally personas perversely pessimistic phantasms phantom phocion photographic picnic piercingly pig pinheads pinks pinochet plasma playstation plea plutonium point-to-point pointedly poke-mania policy pollak pollyana pomposity pondering poof pooper-scoopers populating portent possession poster poster-boy posthumously pouty-lipped powered praying pre-teen pre-wwii preteen pretence prints proceed prod produces producing prone property prostituted protestors provocateur pseudo-philosophic psychedelic published punched punitive purgatory purportedly pursued pursuers père-fils q. quashed quasi-improvised quasi-shakespearean quick-cut quicksand quintet radiates rafael ragbag raging rails ramifications randolph rape rapidly rapids rated raving raw-nerved re-hash re-invents reap recently recessive reclaiming recognizably recreation reenacting reflected register relating relied relish remove render replace resonances restate resume resurrection resuscitate rethink retooled retrograde retrospective reunion reunions reversal reverse revision revisionism reward reworking rhapsodize rhino riding rigidly rigors ringu rinzler rip-roaring ripper risky rivals roads roars rockumentary rodan roland roller romantics romeo romoli rooms rough-trade rouses route routes roy rubbish rudd ruffle ruggero ruin ruse rusted-out ruthless rye sacrificed samurai sanitised sat satisfactorily scam scar schepisi schiffer schneidermeister schrader scoob scorcher scored scratches screeching-metal seated secondary selby seldom self-determination self-expression self-pitying self-reflexive self-styled semi-stable september serbs serial serious-minded serry servicable setpiece setup seventy-minute severely sexiness shaken shatters shaw shimmeringly shoulder signing sirk sixth sixties sixties-style skid-row skilled skinny skyscraper skyscraper-trapeze slaps sleight-of-hand slickest slickness slide slivers slot slow-motion slow-moving smokers smuggling snapping sniffle snoozer soaked soapy solomonic sommers song-and-dance-man sooner special-interest spin-off splatterfests spookily spot-on spouting springboard springing squaddie squirts stadium-seat stalker stalking stallone standardized stanzas startled stayed steadfast step-printing stereotype stiletto-stomps stimulus stoner storytellers straddle straight-faced straight-shooting streetwise strenuous strikingly strip-mined stroked strutting stultifying stumble submerging sugar-coated suggestive sullivan summons sunbaked super-sized superman supremes surest surge surveys susceptible suspended swathe sweet-and-sour sweetest swims sydow t-shirt tad taiwanese takashi talkiness también tangents tap-dancing tautou techno-tripe teenaged telanovela telephone temptingly tendencies tendentious tenderly testosterone thank theorist therapeutic therefore thewlis thievery thinks-it-is thousand-times threatens thuds tick-tock tidal tie tim tipped tiring titans title-bout togetherness tolerable-to-adults tonal too-hot-for-tv tornatore touchingly toughest tradition-bound trailer-trash trained trainspotting tranquil transition traveled treacly tree trembling trial trickster trimmed triviality trivialize tron troopers tropic troubadour trusted truths tu tufano tv-insider twelve twenty-some twentysomething twinkly-eyed uk ultra-cheesy unaccountable unafraid unbelievably under-7 underdeveloped undermined underutilized underventilated undiminished unexplained unhappiness unhurried unleashed unorthodox unromantic unscathed unslick unspeakably unstable unsuccessfully unthinkable unwavering useful ushered usurp vacuous vardalos veggies venality venice/venice verdict verismo violated virgin vocalized voting wai walter warned warner watched waterboy wazoo wearisome weighed weightless well-meant well-structured well-trod wheels whet whine whiney whirling wholesale whoopee-cushion widely wildcard willie willingly wince wire wireless wiser wives wobbly wondered wonderland wonton wrecks wrestler write writer-producer-director writes wyman y ya-yas yawner yeah yorkers z-boys zap zemeckis zoe '53 'n 170 1899 1950 1980 1986 4/5ths 50-year-old 50s 52 60-second 72 96 abbass abridged accentuating accommodate accomodates accumulate acolytes action/thriller actorish addams addiction ado adopts adrian adrien advert affect afghan afterschool agendas agony aid aided al. alexander alias alienated allison altered alternatives altman altman-esque amari amaro ambitiously ambivalent amc ame amidst amini amir amuses ana analgesic anarchy anchor angela animaton anonymity answered anti-war anyplace appetizer aragorn aranda arch archetypal architecture argentinean arkansas arwen asquith assess assign assignment astronauts attentive attracts audacious-impossible augustine aurelie australian auteil autocritique aversion award-winning ayres ayurveda babies babysitter backwater bagatelle baird ballerina ballet bang-the-drum barbs barriers bathos battles be-all-end-all be-bop beast-within beasts beating becker beckons bed bedevilling behan beijing believes belinsky beloved-major benevolent bergman berkeley berlin betty bibbidy-bobbidi-bland blandly blarney blasting bless blimp bluer blush boarders bollywood/hollywood bombards bonding booths boston bots bottom-rung boxes boy-meets-girl brainy breasts breitbart brent bristles brittle broadside brody bubbles buff buffeted bug bug-eye bullwinkle buries burlesque busby buscemi bustling bypassing cabins cam cameras capitalism capitalize capitalizes captions captive captives captors carlito carousel cascade cavorting centers chai chapter characterize chilled chimney chimps chokes christelle cimarron circa circles civil civilization clarissa clearasil cliffhanger climbing closure clothing cloudy colin colleagues colorfully com commune compelled complaints components compositions computerized concealment confessional confirms conflicted conjuring conscience consequence consists constricted constrictive consume contain continually contorting contribute conveys copious corbett corpus correctly couch counting cousin covered crassness crazier creepy-crawly cremaster crimen cross-cultural crushing cryin cuaron cub curlers curve d. daddy dalloway dampened danis daredevils dazed dead-eye dearly deblois debuts decasia decisively decoder defeats defiant defuses degenerates deleting deliberateness denmark dense dependence depending depicted deranged derivativeness desecrations desplat detachment detention deteriorates developers deviant deviously dicey dichotomy dick dictator-madman differently dignified dimensional disassociation discloses disconcertingly disconnection discursive discuss discussed dishes dismally disney-style disobedience dispatching dispel displacement disposible disrobed distill distinguishable distorts distracts disturb disturbed diva diverges doctor docu-drama documentary-like dodger dogma dognini dolby dominated domineering donna doris dorm doshas dot dotted double doug dr. drek drinker drumming drung dry-eyed duking dumbo e-graveyard eardrum-dicing earnhart edged elderly elect electoral electronic eleven elicited elves elysian emergency enamored endgame endings enemies engineering entering entries envelope envelops environments epiphanies eponymous equate erupt esoteric essayist establish establishing establishment estela etched ethnography evenly everlyn exceeding exceptions exhibit expand explain explodes exploits expressively expressly exquisitely extant extrusion eye-boggling eye-catching f f. fabian face-to-face factory facts fairies famed fantasized farenheit feardotcom fetid fields fiery fifth film-culture finch finishing firmer flashback flashbulb fleder fleet-footed floyd forewarned forgoes fork forty forwards fosters founders frankenstein-monster frankie freaking freewheeling freight french-produced frenzy frighten fuses gaghan gaitskill galvanize gambles gamely gandalf gaunt gaï geeked geneva georgia gesturing gheorghiu gidget giggling girlfriends glitz gobble good-looking goombah gorgeousness governance grabowsky great-grandson greengrass grenoble grip grizzled grouchy grunge-pirate grungy gurus hack-and-slash hades hagiographic hairdo hairier half-assed half-hour halle hammer hammering hammers hammily hands-on hanley haphazardness happily-ever hard-core hardwood harlem harmed harps hart haute headbanger health heartland heartstrings hermitage herring hierarchy highlander highways hints hippopotamus hitler-study hjelje hobnail holistic holland holy homeric homosexuality hookers horde horrid hos households howling hush hypertime hélène i-heard-a-joke iben idiom ignites immortal immune imparting impatiently impresses inauspicious inclusiveness independence indispensable individuals ineffable inert inertia infants infinite infirmity ingeniously insanity insiders insular inter-family inter-racial interlocked interminably intervention interviewees intoxicatingly intractable invested ireland irreparably irreversible irrevocable irrigates iteration itinerant j.r.r. jackal jacobson jammies jangle jeremy jesus jia judgment jumbled june kane kang kathryn keener kiddie-oriented kidnapper knitting kubrick-meets-spielberg kuras kurds labors laced lacey ladles lagaan lapaglia lasker late-night lau leanest leaning lector led leers lending les less-than-thrilling lethargically letterman liana limps lingered lionize lived-in liven living-room loaf locusts lohman longley louts lovingly lucid lucks luke lushly lynch-like machinations made-for-tv mafia magi maguire maiden mail maladjusted malle malls manically march marcus margaritas marker markers maryam massacres mctiernan meanest meatier media-constructed method methodically metro middle-earth mighty mind-bender mind-bending misbegotten mischievous misleading mobius mock modern-office molehill monkeys moon mora morrissette mosque movie-biz mrs. muckraking mud mumbles murderous musset muster nanook naqoyqatsi nationalism nature/nurture necessity needing needles negated negativity neo-augustinian neo-hitchcockianism neo-noir nerve-raked newfoundland niblet nicks no-holds-barred non-bondish nonbelievers nonethnic notations numbness obligations observe occasion odd-couple oddest off-puttingly off-screen oh oh-those-wacky-brits olympia onion oodles oozes open-minded opinion orleans orthodox otar ourside outnumber outward outweigh over-familiarity overnight overwhelm owe ownership p.c. padre pageants palm pan-american pandora parachutes paranormal paraphrase parrot party-hearty pasach passionately pastel patent patting pax pay-off payback peaked pedro peep percolating persecuted pessimists pete petite phifer plaintiveness plan plate plato play-doh playboy playwright pledge plod plods plunges pluto police-procedural poo populist pork portrayals post-camp post-production post-war potter potty-mouthed practical practitioners pray pre-9 predators preferable prepared preposterousness presses prima princesses pro-serbian professor profits progression propensity pros prosaic protecting prozac pseudo-rock-video psychodrama publicist pulchritude punch-drunk punches punchline puritanical purports pyschological qatsi quasi-documentary quinn rainbow rancid rape-payback rapturous re-fried re-imagining re-voiced react readings reaffirms realization reassembled recalling reconceptualize recreating recycles redeem redefinition rediscovers reel/real referential reginald regrets reid reigen reinvention rejection rejects rekindles relocated remained rendition renewal reporters represent requiem requires residences residents respectively responses restless restored rests resumes retaliatory retrospectively reverberates reverence reyes riddle right-hand risk rolls rom romance-novel romantic-comedy romanticized romero rosario rosemary rothman rubber-face s&m safely sailor salacious salvaged sam sampi sandeman sanders sascha satan savaged scarpia schindler schools schtick scouse scrap screwy scribe scripted scripters secretions secretly seeds seldahl self-flagellation self-inflicted self-sacrifice self-serious seller selling semi-surrealist separates sept. serenely serviceability shag shamefully shatner shattering shimizu shoe-loving shoestring shopping short-story shrewdly shum simpson sinuously siren situates skipped slather sleepwalk slides slights slop sloughs slowness smashups snapshot snatch sniping snobbery sober-minded sociopathy sodden sofia solely songbird sorcerer sorely space-based spanish sparked spawned specialized specifics spied spinoff splatter spouses spout spreading squabbling stabbing stacks stacy staggers stallion steadily stepmom steps stereo stinging stink stockings stomps stoppard straddles strangest strategies stress stripe stupefying style-free subsequent subsided subtitled suicide sunk suppression surfer surround surrounds sven swingers sy sylvie symbols t. tara tarkovsky tease teasing telenovela temporal temptations ten-year-old tenth texan theology thinly-conceived three-to-one throes thump tighter tinseltown tissue-thin tit-for-tat tnt tome tommy tonally toolbags tools tormented tourism tov towers tracking translating trivializing trot true-blue trumpet trusts truth-telling turpin twinkie ugly-duckling ultra-provincial unambitious unbreakable uncinematic uncommitted under-rehearsed undergoing underway underwear underwhelming undogmatic uneasily unembarrassing unemployment unevenly unfairly unhinged unreachable unrealistic unseemly unsophisticated untrained unwillingness updates upstaged usher ut utilizes variable vaudeville ver verite vicarious vicente video-viewing video/dvd villainess villeneuve virtue viva viveka vivi wade waged waits walks war-torn warlord wash. watercolor watstein wattage weaver wedgie well-rounded welles wending whiffle-ball widget wiel wintry wishful wistfully wizard woe woodman woolf worship wounded wounding wrists wrought yawn-provoking year-end yoda yong yvan zeitgeist zen zips zwick ================================================ FILE: examples/sst2/vocabulary.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """A vocabulary that represents the tokens in a dataset and maps them to indices.""" import collections from typing import Optional from collections.abc import Iterable, Sequence from absl import logging class Vocabulary: """Represents a vocabulary that can be built from a dataset.""" def __init__( self, vocab_path: str | None = None, tokenized_sequences: Iterable[Sequence[bytes]] | None = None, min_freq: int = 1, pad_token: bytes = b'', unk_token: bytes = b'', bos_token: bytes = b'', eos_token: bytes = b'', ): """Loads the vocab from disk (if `vocab_path` is given) or builds it from `tokenized_sequences`.""" self.pad_token = pad_token self.unk_token = unk_token self.bos_token = bos_token self.eos_token = eos_token self.special_tokens = (pad_token, unk_token, bos_token, eos_token) if vocab_path: self.load(vocab_path) elif tokenized_sequences is not None: self.build(tokenized_sequences, min_freq=min_freq) else: raise ValueError( ( 'Vocabulary needs either `vocab_path` or `tokenized_sequences` to' ' be provided, got %r and %r.' ) % (vocab_path, tokenized_sequences) ) def build( self, tokenized_sequences: Iterable[Sequence[bytes]], min_freq: int = 1 ): """Builds a vocabulary over tokens with optional minimum frequency. Args: tokenized_sequences: Iterable of token sequences to build the vocabulary. min_freq: The minimum frequency of each token to be included. Default: 1. """ # Count all the tokens. counter = collections.Counter() for token_sequence in tokenized_sequences: counter.update(token_sequence) # Add special tokens to the start of vocab. vocab = collections.OrderedDict() for token in self.special_tokens: vocab[token] = len(vocab) # Add all other tokens to the vocab if their frequency is >= min_freq. for token, freq in sorted( # Sort by frequency (from high to low), and then by token string. # This makes sure high frequency tokens get a low token ID. counter.items(), key=lambda token_freq: (-token_freq[1], token_freq[0]), ): if freq >= min_freq: vocab[token] = len(vocab) logging.info('Number of unfiltered tokens: %d', len(counter)) logging.info('Vocabulary size: %d', len(vocab)) self.vocab = vocab def _getitem__(self, key: str): return self.vocab[key] def keys(self): return self.vocab.keys() def values(self): return self.vocab.values() def __len__(self): return len(self.vocab) @property def pad_idx(self): """The index of the padding token.""" return self.vocab[self.pad_token] @property def unk_idx(self): """The index of the unknown word token.""" return self.vocab[self.unk_token] @property def bos_idx(self): """The index of the beginning-of-sequence token.""" return self.vocab[self.bos_token] @property def eos_idx(self): """The index of the end-of-sequence token.""" return self.vocab[self.eos_token] def load(self, path: str) -> None: """Loads a vocabulary (one token per line) from disk.""" vocab = collections.OrderedDict() with open(path, 'rb') as f: for i, token in enumerate(f): assert isinstance(token, bytes), 'Expected byte tokens.' vocab[token.rstrip()] = i logging.info('Loaded vocabulary, size: %d', len(vocab)) self.vocab = vocab def save(self, path: str) -> None: """Saves the vocabulary to disk.""" with open(path, 'wb') as f: for token in self.vocab: assert isinstance(token, bytes), 'Expected byte tokens.' f.write(token + b'\n') logging.info('Saved vocabulary to %s.', path) ================================================ FILE: examples/vae/README.md ================================================ # Basic VAE Example This is an implementation of the paper [Auto-Encoding with Variational Bayes](http://arxiv.org/abs/1312.6114) by D.P.Kingma and M.Welling. This code follows [pytorch/examples/vae](https://github.com/pytorch/examples/blob/master/vae/README.md). ```bash pip install -r requirements.txt python main.py --workdir=/tmp/mnist --config=configs/default.py ``` ## Overriding Hyperparameter configurations This VAE example allows specifying a hyperparameter configuration by the means of setting `--config` flag. Configuration flag is defined using [config_flags](https://github.com/google/ml_collections/tree/master#config-flags). `config_flags` allows overriding configuration fields. This can be done as follows: ```shell python main.py \ --workdir=/tmp/mnist --config=configs/default.py \ --config.learning_rate=0.01 --config.num_epochs=10 ``` ## Examples If you run the code by above command, you can get some generated images: ![generated_mnist](./sample.png) and reconstructions of test set digits: ![reconstruction_mnist](./reconstruction.png) The test set loss after 10 epochs should be around `104`. ================================================ FILE: examples/vae/configs/default.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Default Hyperparameter configuration.""" import ml_collections def get_config(): """Get the default hyperparameter configuration.""" config = ml_collections.ConfigDict() config.learning_rate = 0.001 config.latents = 20 config.batch_size = 128 config.num_epochs = 30 return config ================================================ FILE: examples/vae/input_pipeline.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Input pipeline for VAE dataset.""" import jax import jax.numpy as jnp import tensorflow as tf import tensorflow_datasets as tfds def build_train_set(batch_size, ds_builder): """Builds train dataset.""" train_ds = ds_builder.as_dataset(split=tfds.Split.TRAIN) train_ds = train_ds.map(prepare_image) train_ds = train_ds.cache() train_ds = train_ds.repeat() train_ds = train_ds.shuffle(50000) train_ds = train_ds.batch(batch_size) train_ds = iter(tfds.as_numpy(train_ds)) return train_ds def build_test_set(ds_builder): """Builds train dataset.""" test_ds = ds_builder.as_dataset(split=tfds.Split.TEST) test_ds = test_ds.map(prepare_image).batch(10000) test_ds = jnp.array(list(test_ds)[0]) test_ds = jax.device_put(test_ds) return test_ds def prepare_image(x): x = tf.cast(x['image'], tf.float32) x = tf.reshape(x, (-1,)) return x ================================================ FILE: examples/vae/main.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Main file for running the VAE example. This file is intentionally kept short. The majority for logic is in libraries that can be easily tested and imported in Colab. """ from absl import app from absl import flags from absl import logging from clu import platform import jax from ml_collections import config_flags import tensorflow as tf import train FLAGS = flags.FLAGS flags.DEFINE_string('workdir', None, 'Directory to store model data.') config_flags.DEFINE_config_file( 'config', None, 'File path to the training hyperparameter configuration.', lock_config=True, ) flags.mark_flags_as_required(['config', 'workdir']) def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') # Make sure tf does not allocate gpu memory. tf.config.experimental.set_visible_devices([], 'GPU') logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) logging.info('JAX local devices: %r', jax.local_devices()) # Add a note so that we can tell which task is which JAX host. # (Depending on the platform task 0 is not guaranteed to be host 0) platform.work_unit().set_task_status( f'process_index: {jax.process_index()}, ' f'process_count: {jax.process_count()}' ) train.train_and_evaluate(FLAGS.config, FLAGS.workdir) if __name__ == '__main__': app.run(main) ================================================ FILE: examples/vae/models.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """VAE model definitions.""" from flax import linen as nn from jax import random import jax.numpy as jnp class Encoder(nn.Module): """VAE Encoder.""" latents: int @nn.compact def __call__(self, x): x = nn.Dense(500, name='fc1')(x) x = nn.relu(x) mean_x = nn.Dense(self.latents, name='fc2_mean')(x) logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x) return mean_x, logvar_x class Decoder(nn.Module): """VAE Decoder.""" @nn.compact def __call__(self, z): z = nn.Dense(500, name='fc1')(z) z = nn.relu(z) z = nn.Dense(784, name='fc2')(z) return z class VAE(nn.Module): """Full VAE model.""" latents: int = 20 def setup(self): self.encoder = Encoder(self.latents) self.decoder = Decoder() def __call__(self, x, z_rng): mean, logvar = self.encoder(x) z = reparameterize(z_rng, mean, logvar) recon_x = self.decoder(z) return recon_x, mean, logvar def generate(self, z): return nn.sigmoid(self.decoder(z)) def reparameterize(rng, mean, logvar): std = jnp.exp(0.5 * logvar) eps = random.normal(rng, logvar.shape) return mean + eps * std def model(latents): return VAE(latents=latents) ================================================ FILE: examples/vae/requirements.txt ================================================ absl-py==1.4.0 flax==0.6.9 numpy==1.23.5 optax==0.1.5 Pillow==10.2.0 tensorflow==2.12.0 tensorflow-datasets==4.9.2 ================================================ FILE: examples/vae/results/.gitignore ================================================ *.png ================================================ FILE: examples/vae/train.py ================================================ # Copyright 2023 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Training and evaluation logic.""" from typing import Any from absl import logging from flax import linen as nn import input_pipeline import models import utils as vae_utils from flax.training import train_state import jax from jax import random import jax.numpy as jnp import ml_collections import optax import tensorflow as tf import tensorflow_datasets as tfds @jax.vmap def kl_divergence(mean, logvar): return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar)) @jax.vmap def binary_cross_entropy_with_logits(logits, labels): logits = nn.log_sigmoid(logits) return -jnp.sum( labels * logits + (1.0 - labels) * jnp.log(-jnp.expm1(logits)) ) def compute_metrics(recon_x, x, mean, logvar): bce_loss = binary_cross_entropy_with_logits(recon_x, x).mean() kld_loss = kl_divergence(mean, logvar).mean() return {'bce': bce_loss, 'kld': kld_loss, 'loss': bce_loss + kld_loss} def train_step(state, batch, z_rng, latents): """Train step.""" def loss_fn(params): recon_x, mean, logvar = models.model(latents).apply( {'params': params}, batch, z_rng ) bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean() kld_loss = kl_divergence(mean, logvar).mean() loss = bce_loss + kld_loss return loss grads = jax.grad(loss_fn)(state.params) return state.apply_gradients(grads=grads) def eval_f(params, images, z, z_rng, latents): """Evaluation function.""" def eval_model(vae): recon_images, mean, logvar = vae(images, z_rng) comparison = jnp.concatenate([ images[:8].reshape(-1, 28, 28, 1), recon_images[:8].reshape(-1, 28, 28, 1), ]) generate_images = vae.generate(z) generate_images = generate_images.reshape(-1, 28, 28, 1) metrics = compute_metrics(recon_images, images, mean, logvar) return metrics, comparison, generate_images return nn.apply(eval_model, models.model(latents))({'params': params}) def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): """Train and evaulate pipeline.""" tf.io.gfile.makedirs(workdir) rng = random.key(0) rng, key = random.split(rng) ds_builder = tfds.builder('binarized_mnist') ds_builder.download_and_prepare() logging.info('Initializing dataset.') train_ds = input_pipeline.build_train_set(config.batch_size, ds_builder) test_ds = input_pipeline.build_test_set(ds_builder) logging.info('Initializing model.') init_data = jnp.ones((config.batch_size, 784), jnp.float32) params = models.model(config.latents).init(key, init_data, rng)['params'] state = train_state.TrainState.create( apply_fn=models.model(config.latents).apply, params=params, tx=optax.adam(config.learning_rate), ) rng, z_key, eval_rng = random.split(rng, 3) z = random.normal(z_key, (64, config.latents)) steps_per_epoch = ( ds_builder.info.splits['train'].num_examples // config.batch_size ) for epoch in range(config.num_epochs): for _ in range(steps_per_epoch): batch = next(train_ds) rng, key = random.split(rng) state = train_step(state, batch, key, config.latents) metrics, comparison, sample = eval_f( state.params, test_ds, z, eval_rng, config.latents ) vae_utils.save_image( comparison, f'{workdir}/reconstruction_{epoch}.png', nrow=8 ) vae_utils.save_image( sample, f'{workdir}/sample_{epoch}.png', nrow=8 ) print( 'eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}'.format( epoch + 1, metrics['loss'], metrics['bce'], metrics['kld'] ) ) ================================================ FILE: examples/vae/utils.py ================================================ # Copyright 2023 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ This code is created with reference to torchvision/utils.py. Modify: torch.tensor -> jax.numpy.DeviceArray If you want to know about this file in detail, please visit the original code: https://github.com/pytorch/vision/blob/master/torchvision/utils.py """ import math import jax.numpy as jnp import numpy as np from PIL import Image def save_image(ndarray, fp, nrow=8, padding=2, pad_value=0.0, format_img=None): """Make a grid of images and Save it into an image file. Args: ndarray (array_like): 4D mini-batch images of shape (B x H x W x C) fp: A filename(string) or file object nrow (int, optional): Number of images displayed in each row of the grid. The final grid size is ``(B / nrow, nrow)``. Default: ``8``. padding (int, optional): amount of padding. Default: ``2``. pad_value (float, optional): Value for the padded pixels. Default: ``0``. format_img(Optional): If omitted, the format to use is determined from the filename extension. If a file object was used instead of a filename, this parameter should always be used. """ if not ( isinstance(ndarray, jnp.ndarray) or ( isinstance(ndarray, list) and all(isinstance(t, jnp.ndarray) for t in ndarray) ) ): raise TypeError(f'array_like of tensors expected, got {type(ndarray)}') ndarray = jnp.asarray(ndarray) if ndarray.ndim == 4 and ndarray.shape[-1] == 1: # single-channel images ndarray = jnp.concatenate((ndarray, ndarray, ndarray), -1) # make the mini-batch of images into a grid nmaps = ndarray.shape[0] xmaps = min(nrow, nmaps) ymaps = int(math.ceil(float(nmaps) / xmaps)) height, width = ( int(ndarray.shape[1] + padding), int(ndarray.shape[2] + padding), ) num_channels = ndarray.shape[3] grid = jnp.full( (height * ymaps + padding, width * xmaps + padding, num_channels), pad_value, ).astype(jnp.float32) k = 0 for y in range(ymaps): for x in range(xmaps): if k >= nmaps: break grid = grid.at[ y * height + padding : (y + 1) * height, x * width + padding : (x + 1) * width, ].set(ndarray[k]) k = k + 1 # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer ndarr = np.array(jnp.clip(grid * 255.0 + 0.5, 0, 255).astype(jnp.uint8)) im = Image.fromarray(ndarr.copy()) im.save(fp, format=format_img) ================================================ FILE: examples/wmt/README.md ================================================ ## Machine Translation Trains a Transformer-based model (Vaswani *et al.*, 2017) on the WMT Machine Translation en-de dataset. This example uses linear learning rate warmup and inverse square root learning rate schedule. Table of contents: - [Requirements](#requirements) - [Example runs](#example-runs) - [Running on Cloud](#running-on-cloud) - [Preparing the dataset](#preparing-the-dataset) - [Google Cloud TPU](#google-cloud-tpu) ### Requirements * TensorFlow datasets `wmt17_translate/de-en` and `wmt14_translate/de-en` need to be downloaded and prepared. A sentencepiece tokenizer vocabulary will be automatically generated and saved on each training run. * This example additionally depends on the `sentencepiece` and `tensorflow-text` packages. ### Example runs You should expect to get numbers similar to these: Hardware | config | Training time | BLEU | TensorBoard.dev | Workdir -------- | ------- | ------------- | -------------- | ------------------------------------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------- TPU v3-8 | default | 24m
13h18m | 25.55
32.87 | [2021-08-04](https://tensorboard.dev/experiment/nnH7JNCxTgC1ROakWePTlg/) | [gs://flax_public/examples/wmt/default](https://console.cloud.google.com/storage/browser/flax_public/examples/wmt/default) TPU v3-32 | default | 3h1m | 32.45 | [2021-11-05](https://tensorboard.dev/experiment/7IKeXjoeRKiMtqysQlbqzw/) | [gs://flax_public/examples/wmt/default_v3-32](https://console.cloud.google.com/storage/browser/flax_public/examples/wmt/default_v3-32) GPU V100 x8 (Mixed Precision) | gpu_mixed_precision | 1h 58m | 25.69 | [2021-07-07](https://tensorboard.dev/experiment/9S2WuqNWRDemmBuQE8K6Ew/) | - ### Running on Cloud #### Preparing the WMT Datasets We recommend downloading and preparing the TFDS datasets beforehand. For Cloud TPUs, we recommend using a cheap standard instance and saving the prepared TFDS data on a storage bucket, from where it can be loaded directly. Set the `TFDS_DATA_DIR` to your storage bucket path (`gs://`). You can download and prepare any of the WMT datasets using TFDS directly: `python -m tensorflow_datasets.scripts.download_and_prepare --datasets=wmt17_translate/de-en` The typical academic BLEU evaluation also uses the WMT 2014 Test set: `python -m tensorflow_datasets.scripts.download_and_prepare --datasets=wmt14_translate/de-en` #### Google Cloud TPU See below for commands to set up a single VM with 8 TPUs attached (`--accelerator-type v3-8`), or for an entire TPU slice spanning multiple VMs (e.g. `--accelerator-type v3-32`). For more details about how to set up and use TPUs, refer to Cloud docs for [single VM setup](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm) and [pod slice setup](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm). First create a single TPUv3-8 VM and connect to it: ``` ZONE=us-central1-a TPU_TYPE=v3-8 VM_NAME=wmt gcloud alpha compute tpus tpu-vm create $VM_NAME \ --zone $ZONE \ --accelerator-type $TPU_TYPE \ --version v2-alpha gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \ -L 6006:localhost:6006 ``` When connected install JAX: ``` pip install "jax[tpu]>=0.2.21" \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html ``` Then install Flax + the example dependencies: ``` git clone --depth=1 --branch=main https://github.com/google/flax cd flax pip install -e . cd examples/wmt pip install -r requirements.txt ``` And finally start the training: ``` python3 main.py --workdir=$HOME/logs/wmt_256 \ --config.per_device_batch_size=32 \ --jax_backend_target="grpc://192.168.0.2:8470" ``` Note that you might want to set `TFDS_DATA_DIR` as explained above. You probably also want to start the long-running command above in a `tmux` session and start some monitoring in a separate pane (note that we forwarded port 6006 locally above): ``` tensorboard --logdir=$HOME/logs ``` When running on pod slices, after creating the TPU VM, there are different ways of running the training in SPMD fashion on the hosts connected to the TPUs that make up the slice. We simply send the same installation/execution shell commands to all hosts in parallel with the command below. If anything fails it's usually a good idea to connect to a single host and execute the commands interactively. For convenience, the TPU creation commands are inlined below. Please note that we define `GCS_TFDS_BUCKET` to where your data stands in your cloud bucket. Also `YOUR_BUCKET` is the work directory you are experimenting in. You should choose ZONE based on where your TPU and work directory is. [Here](https://cloud.google.com/tpu/docs/types-zones) has some useful information on which zones you can have different types of TPUs. ```shell VM_NAME=wmt REPO=https://github.com/google/flax BRANCH=main WORKDIR=gs://$YOUR_BUCKET/flax/examples/wmt/$(date +%Y%m%d_%H%M) gcloud alpha compute tpus tpu-vm create $VM_NAME \ --zone=$ZONE \ --version v2-alpha --accelerator-type v3-32 FLAGS="--config.num_train_steps=$(( 100 * 1000 * 8/32 )) --config.warmup_steps=$(( 1000 * 8/32 )) --config.checkpoint_every_steps=$(( 10 * 1000 * 8/32 ))" gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE \ --worker=all --command " set -x pip install 'jax[tpu]>=0.2.21' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && pip install --user git+$REPO.git && (test -d flax || git clone --depth=1 -b $BRANCH $REPO) && cd flax && pip install -e . && cd examples/wmt && pip install -r requirements.txt && export TFDS_DATA_DIR=gs://$GCS_TFDS_BUCKET/datasets && python3 main.py --workdir=$WORKDIR --config=configs/default.py $FLAGS " ``` Please don't forget to disconnect and delete your vm after you are done: ``` gcloud alpha compute tpus tpu-vm delete $VM_NAME \ --zone $ZONE ``` ================================================ FILE: examples/wmt/bleu.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. r"""Parallel BLEU score calculation. This version of BLEU calculation is derived from the MLPerf transformer reference. Tries to match SacreBLEU metric reasonably well, but is not identical. Refs: tokenizer at: https://github.com/tensorflow/models/blob/master/official/transformer/utils/tokenizer.py original preprocessing tokenizer: https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983 original t2t code: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py Usage: refs = '''food bar brown cow blee bloo dog sat or please take me out ''' hyps = '''foo bar brown cow blee bloo dog sit please do take me out ''' bleu_local(refs.split("\n"), hyps.split("\n")) # 39.65 """ import collections import math import re import sys import unicodedata import numpy as np class UnicodeRegex: """Ad-hoc hack to recognize all punctuation and symbols.""" def __init__(self): punctuation = self.property_chars("P") self.nondigit_punct_re = re.compile(r"([^\d])([" + punctuation + r"])") self.punct_nondigit_re = re.compile(r"([" + punctuation + r"])([^\d])") self.symbol_re = re.compile("([" + self.property_chars("S") + "])") def property_chars(self, prefix): return "".join( chr(x) for x in range(sys.maxunicode) if unicodedata.category(chr(x)).startswith(prefix) ) uregex = UnicodeRegex() def bleu_tokenize(string): r"""Tokenize a string following the official BLEU implementation. See https://github.com/moses-smt/mosesdecoder/' 'blob/master/scripts/generic/mteval-v14.pl#L954-L983 In our case, the input string is expected to be just one line and no HTML entities de-escaping is needed. So we just tokenize on punctuation and symbols, except when a punctuation is preceded and followed by a digit (e.g. a comma/dot as a thousand/decimal separator). Note that a number (e.g. a year) followed by a dot at the end of sentence is NOT tokenized, i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g` does not match this case (unless we add a space after each sentence). However, this error is already in the original mteval-v14.pl and we want to be consistent with it. Args: string: the input string Returns: a list of tokens """ string = uregex.nondigit_punct_re.sub(r"\1 \2 ", string) string = uregex.punct_nondigit_re.sub(r" \1 \2", string) string = uregex.symbol_re.sub(r" \1 ", string) return string.split() def _get_ngrams(segment, max_order): """Extracts all n-grams up to a given maximum order from an input segment. Args: segment: text segment from which n-grams will be extracted. max_order: maximum length in tokens of the n-grams returned by this methods. Returns: The Counter containing all n-grams up to max_order in segment with a count of how many times each n-gram occurred. """ ngram_counts = collections.Counter() for order in range(1, max_order + 1): for i in range(0, len(segment) - order + 1): ngram = tuple(segment[i : i + order]) ngram_counts[ngram] += 1 return ngram_counts def compute_bleu_matches(reference_corpus, translation_corpus, max_order=4): """Computes BLEU match stats of translations against one or more references. Args: reference_corpus: list of references for each translation. Each reference should be tokenized into a list of tokens. translation_corpus: list of translations to score. Each translation should be tokenized into a list of tokens. max_order: Maximum n-gram order to use when computing BLEU score. Returns: Aggregated n-gram stats for BLEU calculation. """ reference_length = 0 translation_length = 0 bp = 1.0 geo_mean = 0 matches_by_order = [0] * max_order possible_matches_by_order = [0] * max_order precisions = [] for references, translations in zip(reference_corpus, translation_corpus): reference_length += len(references) translation_length += len(translations) ref_ngram_counts = _get_ngrams(references, max_order) translation_ngram_counts = _get_ngrams(translations, max_order) overlap = { ngram: min(count, translation_ngram_counts[ngram]) for ngram, count in ref_ngram_counts.items() } for ngram in overlap: matches_by_order[len(ngram) - 1] += overlap[ngram] for ngram in translation_ngram_counts: possible_matches_by_order[len(ngram) - 1] += translation_ngram_counts[ ngram ] return ( np.array(matches_by_order), np.array(possible_matches_by_order), np.array(reference_length), np.array(translation_length), ) def bleu_partial(ref_lines, hyp_lines, case_sensitive=False): """Compute n-gram statistics for two lists of references and translations.""" if len(ref_lines) != len(hyp_lines): raise ValueError( "Reference and translation lists have different numbers of lines." ) if not case_sensitive: ref_lines = [x.lower() for x in ref_lines] hyp_lines = [x.lower() for x in hyp_lines] ref_tokens = [bleu_tokenize(x) for x in ref_lines] hyp_tokens = [bleu_tokenize(x) for x in hyp_lines] return compute_bleu_matches(ref_tokens, hyp_tokens) def complete_bleu( matches_by_order, possible_matches_by_order, reference_length, translation_length, max_order=4, use_bp=True, ): """Compute BLEU score from aggregated n-gram statistics.""" precisions = [0] * max_order smooth = 1.0 geo_mean = 0.0 for i in range(0, max_order): if possible_matches_by_order[i] > 0: precisions[i] = matches_by_order[i] / possible_matches_by_order[i] if matches_by_order[i] > 0: precisions[i] = matches_by_order[i] / possible_matches_by_order[i] else: smooth *= 2 precisions[i] = 1.0 / (smooth * possible_matches_by_order[i]) else: precisions[i] = 0.0 if max(precisions) > 0: p_log_sum = sum(math.log(p) for p in precisions if p) geo_mean = math.exp(p_log_sum / max_order) if use_bp: if not reference_length: bp = 1.0 else: ratio = translation_length / reference_length if ratio <= 0.0: bp = 0.0 elif ratio >= 1.0: bp = 1.0 else: bp = math.exp(1 - 1.0 / ratio) bleu = geo_mean * bp return float(bleu) * 100.0 def bleu_local(ref_lines, hyp_lines, case_sensitive=False): """Compute BLEU for two lists of reference and hypothesis translations.""" stats = bleu_partial(ref_lines, hyp_lines, case_sensitive=case_sensitive) return complete_bleu(*stats) * 100 ================================================ FILE: examples/wmt/configs/default.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Default Hyperparameter configuration.""" import ml_collections def get_config(): """Get the default hyperparameter configuration.""" config = ml_collections.ConfigDict() # Path to load or store sentencepiece vocab file. config.vocab_path = None # Vocabulary size if `vocab_path` is not given. config.vocab_size = 32_000 config.max_corpus_chars = 10**7 # Name of TFDS translation dataset to use. config.dataset_name = 'wmt17_translate/de-en' # Optional name of TFDS translation dataset to use for evaluation. config.eval_dataset_name = 'wmt14_translate/de-en' config.eval_split = 'test' # Reverse the direction of translation. config.reverse_translation = False # Per device batch size for training. config.per_device_batch_size = 32 # Beam size for inference. config.beam_size = 4 config.num_train_steps = 100_000 # Number of steps to take during evaluation. config.num_eval_steps = 20 # Number of steps to generate predictions (used for BLEU score). # -1 will use the whole eval dataset. config.num_predict_steps = -1 # Base learning rate. config.learning_rate = 0.002 # Linear learning rate warmup. config.warmup_steps = 1000 # Cross entropy loss label smoothing. config.label_smoothing = 0.1 # Decay factor for AdamW style weight decay. config.weight_decay = 0.0 # Maximum length cutoff for training examples. config.max_target_length = 256 # Maximum length cutoff for eval examples. config.max_eval_target_length = 256 # Maximum length cutoff for predicted tokens. config.max_predict_length = 256 # Inputs and targets share embedding. config.share_embeddings = True # Final logit transform uses embedding matrix transpose. config.logits_via_embedding = True # Number of transformer layers. config.num_layers = 6 # Size of query/key/value for attention. config.qkv_dim = 1024 # Size of embeddings. config.emb_dim = 1024 # Size of the MLP. config.mlp_dim = 4096 # Number of attention heads. config.num_heads = 16 # Dropout rate. config.dropout_rate = 0.1 # Attention dropout rate. config.attention_dropout_rate = 0.1 # Whether to save model checkpoints. config.save_checkpoints = True # Whether to restore from existing model checkpoints. config.restore_checkpoints = True # Save a checkpoint every these number of steps. config.checkpoint_every_steps = 10_000 # Frequency of eval during training, e.g. every 1000 steps. config.eval_every_steps = 1_000 # Use float16/bfloat16 (GPU/TPU) mixed precision training instead of float32. config.use_mixed_precision = True # Integer for PRNG random seed. config.seed = 0 return config def metrics(): return [ 'train_loss', 'eval_loss', 'bleu', 'eval_accuracy', 'train_accuracy', 'uptime', 'steps_per_sec', 'train_learning_rate', ] ================================================ FILE: examples/wmt/decode.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Fast decoding routines for inference from a trained model.""" import typing import flax import jax from jax import lax import jax.numpy as jnp import numpy as np # Constants # We assume the default End-of-Sentence token id is 2 (SentencePiece). EOS_ID = 2 # "Effective negative infinity" constant for masking in beam search. NEG_INF = np.array(-1.0e7) def brevity_penalty(alpha, length): """Brevity penalty function for beam search penalizing short sequences. Args: alpha: float: brevity-penalty scaling parameter. length: int: length of considered sequence. Returns: Brevity penalty score as jax scalar. """ return jnp.power(((5.0 + length) / 6.0), alpha) # Beam handling utility functions: def add_beam_dim(x, beam_size): """Creates new beam dimension in non-scalar array and tiles into it.""" if x.ndim == 0: # ignore scalars (e.g. cache index) return x x = jnp.expand_dims(x, axis=1) tile_dims = [1] * x.ndim tile_dims[1] = beam_size return jnp.tile(x, tile_dims) def flatten_beam_dim(x): """Flattens the first two dimensions of a non-scalar array.""" if x.ndim == 0: # ignore scalars (e.g. cache index) return x return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:]) def unflatten_beam_dim(x, batch_size, beam_size): """Unflattens the first, flat batch*beam dimension of a non-scalar array.""" if x.ndim == 0: # ignore scalars (e.g. cache index) return x assert batch_size * beam_size == x.shape[0] return x.reshape((batch_size, beam_size) + x.shape[1:]) def flat_batch_beam_expand(x, beam_size): """Expands each batch item by beam_size in batch_dimension.""" return flatten_beam_dim(add_beam_dim(x, beam_size)) def gather_beams(nested, beam_indices, batch_size, new_beam_size): """Gathers the beam slices indexed by beam_indices into new beam array. Args: nested: pytree of arrays or scalars (the latter ignored). beam_indices: array of beam_indices batch_size: int: size of batch. new_beam_size: int: size of _new_ beam dimension. Returns: New pytree with new beam arrays. [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...] """ batch_indices = jnp.reshape( jnp.arange(batch_size * new_beam_size) // new_beam_size, (batch_size, new_beam_size), ) def gather_fn(x): if x.ndim == 0: # ignore scalars (e.g. cache index) return x else: return x[batch_indices, beam_indices] return jax.tree_util.tree_map(gather_fn, nested) def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size): """Gathers the top-k beam slices given by score_or_log_prob array. Args: nested: pytree of arrays or scalars (the latter ignored). score_or_log_prob: [batch_size, old_beam_size] array of values to sort by for top-k selection of beam slices. batch_size: int: size of batch. new_beam_size: int: size of _new_ top-k selected beam dimension Returns: New pytree with new beam arrays containing top k new_beam_size slices. [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...] """ _, topk_indices = lax.top_k(score_or_log_prob, k=new_beam_size) topk_indices = jnp.flip(topk_indices, axis=1) return gather_beams(nested, topk_indices, batch_size, new_beam_size) # Beam search state: @flax.struct.dataclass class BeamState: """Holds beam search state data.""" # The position of the decoding loop in the length dimension. cur_index: jax.Array # scalar int32: current decoded length index # The active sequence log probabilities and finished sequence scores. live_logprobs: jax.Array # float32: [batch_size, beam_size] finished_scores: jax.Array # float32: [batch_size, beam_size] # The current active-beam-searching and finished sequences. live_seqs: jax.Array # int32: [batch_size, beam_size, max_decode_len] finished_seqs: jax.Array # int32: [batch_size, beam_size, # max_decode_len] # Records which of the 'finished_seqs' is occupied and not a filler slot. finished_flags: jax.Array # bool: [batch_size, beam_size] # The current state of the autoregressive decoding caches. cache: typing.Any # Any pytree of arrays, e.g. flax attention Cache object def beam_init(batch_size, beam_size, max_decode_len, cache): """Initializes the beam search state data structure.""" cur_index0 = jnp.array(0) live_logprobs0 = jnp.tile( jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1] ) finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF live_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_) # add beam dimension to attention cache pytree elements beam_cache0 = jax.tree_util.tree_map( lambda x: add_beam_dim(x, beam_size), cache ) return BeamState( cur_index=cur_index0, live_logprobs=live_logprobs0, finished_scores=finished_scores0, live_seqs=live_seqs0, finished_seqs=finished_seqs0, finished_flags=finished_flags0, cache=beam_cache0, ) # Beam search routine: def beam_search( inputs, cache, tokens_to_logits, beam_size=4, alpha=0.6, eos_id=EOS_ID, max_decode_len=None, ): """Beam search for transformer machine translation. Args: inputs: array: [batch_size, length] int32 sequence of tokens. cache: flax attention cache. tokens_to_logits: fast autoregressive decoder function taking single token slices and cache and returning next-token logits and updated cache. beam_size: int: number of beams to use in beam search. alpha: float: scaling factor for brevity penalty. eos_id: int: id of end-of-sentence token for target vocabulary. max_decode_len: int: maximum length of decoded translations. Returns: Tuple of: [batch_size, beam_size, max_decode_len] top-scoring sequences [batch_size, beam_size] beam-search scores. """ # We liberally annotate shape information for clarity below. batch_size = inputs.shape[0] if max_decode_len is None: max_decode_len = inputs.shape[1] end_marker = jnp.array(eos_id) # initialize beam search state beam_search_init_state = beam_init( batch_size, beam_size, max_decode_len, cache ) def beam_search_loop_cond_fn(state): """Beam search loop termination condition.""" # Have we reached max decoding length? not_at_end = state.cur_index < max_decode_len - 1 # Is no further progress in the beam search possible? # Get the best possible scores from alive sequences. min_brevity_penalty = brevity_penalty(alpha, max_decode_len) best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty # Get the worst scores from finished sequences. worst_finished_scores = jnp.min( state.finished_scores, axis=1, keepdims=True ) # Mask out scores from slots without any actual finished sequences. worst_finished_scores = jnp.where( state.finished_flags, worst_finished_scores, NEG_INF ) # If no best possible live score is better than current worst finished # scores, the search cannot improve the finished set further. search_terminated = jnp.all(worst_finished_scores > best_live_scores) # If we're not at the max decode length, and the search hasn't terminated, # continue looping. return not_at_end & (~search_terminated) def beam_search_loop_body_fn(state): """Beam search loop state update function.""" # Collect the current position slice along length to feed the fast # autoregressive decoder model. Flatten the beam dimension into batch # dimension for feeding into the model. # --> [batch * beam, 1] flat_ids = flatten_beam_dim( lax.dynamic_slice( state.live_seqs, (0, 0, state.cur_index), (batch_size, beam_size, 1) ) ) # Flatten beam dimension into batch to be compatible with model. # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} flat_cache = jax.tree_util.tree_map(flatten_beam_dim, state.cache) # Call fast-decoder model on current tokens to get next-position logits. # --> [batch * beam, vocab] flat_logits, new_flat_cache = tokens_to_logits(flat_ids, flat_cache) # unflatten beam dimension # [batch * beam, vocab] --> [batch, beam, vocab] logits = unflatten_beam_dim(flat_logits, batch_size, beam_size) # Unflatten beam dimension in attention cache arrays # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...} new_cache = jax.tree_util.tree_map( lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache ) # Gather log probabilities from logits candidate_log_probs = jax.nn.log_softmax(logits) # Add new logprobs to existing prefix logprobs. # --> [batch, beam, vocab] log_probs = candidate_log_probs + jnp.expand_dims( state.live_logprobs, axis=2 ) # We'll need the vocab size, gather it from the log probability dimension. vocab_size = log_probs.shape[2] # Each item in batch has beam_size * vocab_size candidate sequences. # For each item, get the top 2*k candidates with the highest log- # probabilities. We gather the top 2*K beams here so that even if the best # K sequences reach EOS simultaneously, we have another K sequences # remaining to continue the live beam search. beams_to_keep = 2 * beam_size # Flatten beam and vocab dimensions. flat_log_probs = log_probs.reshape((batch_size, beam_size * vocab_size)) # Gather the top 2*K scores from _all_ beams. # --> [batch, 2*beams], [batch, 2*beams] topk_log_probs, topk_indices = lax.top_k(flat_log_probs, k=beams_to_keep) # Recover the beam index by floor division. topk_beam_indices = topk_indices // vocab_size # Gather 2*k top beams. # --> [batch, 2*beams, length] topk_seq = gather_beams( state.live_seqs, topk_beam_indices, batch_size, beams_to_keep ) # Append the most probable 2*K token IDs to the top 2*K sequences # Recover token id by modulo division and expand Id array for broadcasting. # --> [batch, 2*beams, 1] topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2) # Update sequences for the 2*K top-k new sequences. # --> [batch, 2*beams, length] topk_seq = lax.dynamic_update_slice( topk_seq, topk_ids, (0, 0, state.cur_index + 1) ) # Update LIVE (in-progress) sequences: # Did any of these sequences reach an end marker? # --> [batch, 2*beams] newly_finished = topk_seq[:, :, state.cur_index + 1] == end_marker # To prevent these newly finished sequences from being added to the LIVE # set of active beam search sequences, set their log probs to a very large # negative value. new_log_probs = topk_log_probs + newly_finished * NEG_INF # Determine the top k beam indices (from top 2*k beams) from log probs. # --> [batch, beams] _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size) new_topk_indices = jnp.flip(new_topk_indices, axis=1) # Gather the top k beams (from top 2*k beams). # --> [batch, beams, length], [batch, beams] top_alive_seq, top_alive_log_probs = gather_beams( [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size ) # Determine the top k beam indices from the original set of all beams. # --> [batch, beams] top_alive_indices = gather_beams( topk_beam_indices, new_topk_indices, batch_size, beam_size ) # With these, gather the top k beam-associated caches. # --> {[batch, beams, ...], ...} top_alive_cache = gather_beams( new_cache, top_alive_indices, batch_size, beam_size ) # Update FINISHED (reached end of sentence) sequences: # Calculate new seq scores from log probabilities. new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1) # Mask out the still unfinished sequences by adding large negative value. # --> [batch, 2*beams] new_scores += (~newly_finished) * NEG_INF # Combine sequences, scores, and flags along the beam dimension and compare # new finished sequence scores to existing finished scores and select the # best from the new set of beams. finished_seqs = jnp.concatenate( # --> [batch, 3*beams, length] [state.finished_seqs, topk_seq], axis=1 ) finished_scores = jnp.concatenate( # --> [batch, 3*beams] [state.finished_scores, new_scores], axis=1 ) finished_flags = jnp.concatenate( # --> [batch, 3*beams] [state.finished_flags, newly_finished], axis=1 ) # --> [batch, beams, length], [batch, beams], [batch, beams] ( top_finished_seq, top_finished_scores, top_finished_flags, ) = gather_topk_beams( [finished_seqs, finished_scores, finished_flags], finished_scores, batch_size, beam_size, ) return BeamState( cur_index=state.cur_index + 1, live_logprobs=top_alive_log_probs, finished_scores=top_finished_scores, live_seqs=top_alive_seq, finished_seqs=top_finished_seq, finished_flags=top_finished_flags, cache=top_alive_cache, ) # Run while loop and get final beam search state. final_state = lax.while_loop( beam_search_loop_cond_fn, beam_search_loop_body_fn, beam_search_init_state ) # Account for the edge-case where there are no finished sequences for a # particular batch item. If so, return live sequences for that batch item. # --> [batch] none_finished = jnp.any(final_state.finished_flags, axis=1) # --> [batch, beams, length] finished_seqs = jnp.where( none_finished[:, None, None], final_state.finished_seqs, final_state.live_seqs, ) # --> [batch, beams] finished_scores = jnp.where( none_finished[:, None], final_state.finished_scores, final_state.live_logprobs, ) return finished_seqs, finished_scores ================================================ FILE: examples/wmt/input_pipeline.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Input pipeline for a WMT dataset.""" import os from typing import Dict, Optional, List, Union from clu import deterministic_data import ml_collections import tensorflow as tf import tensorflow_datasets as tfds import tokenizer AUTOTUNE = tf.data.AUTOTUNE Features = dict[str, tf.Tensor] class NormalizeFeatureNamesOp: """Normalizes feature names to 'inputs' and 'targets'.""" def __init__(self, ds_info: tfds.core.DatasetInfo, reverse_translation: bool): self.input_lang, self.target_lang = ds_info.supervised_keys if reverse_translation: self.input_lang, self.target_lang = self.target_lang, self.input_lang def __call__(self, features: Features) -> Features: features['inputs'] = features.pop(self.input_lang) features['targets'] = features.pop(self.target_lang) return features def get_raw_dataset( dataset_builder: tfds.core.DatasetBuilder, split: str, *, reverse_translation: bool = False, ) -> tf.data.Dataset: """Loads a raw WMT dataset and normalizes feature keys. Args: dataset_builder: TFDS dataset builder that can build `slit`. split: Split to use. This must be the full split. We shard the split across multiple hosts and currently don't support sharding subsplits. reverse_translation: bool: whether to reverse the translation direction. e.g. for 'de-en' this translates from english to german. Returns: Dataset with source and target language features mapped to 'inputs' and 'targets'. """ num_examples = dataset_builder.info.splits[split].num_examples per_host_split = deterministic_data.get_read_instruction_for_host( split, num_examples, drop_remainder=False ) ds = dataset_builder.as_dataset(split=per_host_split, shuffle_files=False) ds = ds.map( NormalizeFeatureNamesOp( dataset_builder.info, reverse_translation=reverse_translation ), num_parallel_calls=AUTOTUNE, ) return ds def pack_dataset( dataset: tf.data.Dataset, key2length: int | dict[str, int], keys: list[str] | None = None, ) -> tf.data.Dataset: """Creates a 'packed' version of a dataset on-the-fly. Adapted from the mesh-tf implementation. This is meant to replace the irritation of having to create a separate "packed" version of a dataset to train efficiently on TPU. Each example in the output dataset represents several examples in the input dataset. For each key in the input dataset, two additional keys are created: _segmentation: an int32 tensor identifying the parts representing the original example. _position: an int32 tensor identifying the position within the original example. Example: Two input examples get combined to form an output example. The input examples are: {"inputs": [8, 7, 1, 0], "targets":[4, 1, 0]} {"inputs": [2, 3, 4, 1], "targets":[5, 6, 1]} The output example is: { "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0] "inputs_segmentation": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0] "inputs_position": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0] "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0] "targets_segmentation": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0] "targets_position": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0] } 0 represents padding in both the inputs and the outputs. Sequences in the incoming examples are truncated to length "length", and the sequences in the output examples all have fixed (padded) length "length". Args: dataset: a tf.data.Dataset key2length: an integer, or a dict from feature-key to integer keys: a list of strings (e.g. ["inputs", "targets"]) Returns: a tf.data.Dataset """ shapes = tf.nest.map_structure(lambda spec: spec.shape, dataset.element_spec) if keys is None: keys = list(shapes.keys()) for k in keys: if k not in shapes: raise ValueError( 'Key %s not found in dataset. Available keys are %s' % (k, shapes.keys()) ) if not shapes[k].is_compatible_with(tf.TensorShape([None])): # type: ignore[wrong-arg-types] raise ValueError('Tensors to be packed must be one-dimensional.') # make sure that the length dictionary contains all keys as well as the # keys suffixed by "_segmentation" and "_position" if isinstance(key2length, int): key2length = {k: key2length for k in keys} for k in keys: for suffix in ['_segmentation', '_position']: key2length[k + suffix] = key2length[k] # trim to length dataset = dataset.map( lambda x: {k: x[k][: key2length[k]] for k in keys}, num_parallel_calls=AUTOTUNE, ) # Setting batch_size=length ensures that the concatenated sequences (if they # have length >=1) are sufficient to fill at least one packed example. batch_size = max(key2length.values()) dataset = dataset.padded_batch( batch_size, padded_shapes={k: [-1] for k in keys} ) dataset = _pack_with_tf_ops(dataset, keys, key2length) # Set the Tensor shapes correctly since they get lost in the process. def my_fn(x): return {k: tf.reshape(v, [key2length[k]]) for k, v in x.items()} return dataset.map(my_fn, num_parallel_calls=AUTOTUNE) def _pack_with_tf_ops( dataset: tf.data.Dataset, keys: list[str], key2length: dict[str, int] ) -> tf.data.Dataset: """Helper-function for packing a dataset which has already been batched. Helper for pack_dataset() Uses tf.while_loop. Args: dataset: a dataset containing padded batches of examples. keys: a list of strings key2length: an dict from feature-key to integer Returns: a dataset. """ empty_example = {} for k in keys: empty_example[k] = tf.zeros([0], dtype=tf.int32) empty_example[k + '_position'] = tf.zeros([0], dtype=tf.int32) keys_etc = empty_example.keys() def write_packed_example(partial, outputs): new_partial = empty_example.copy() new_outputs = {} for k in keys_etc: new_outputs[k] = outputs[k].write( outputs[k].size(), tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]), ) return new_partial, new_outputs def map_fn(x): """Internal function to flat_map over. Consumes a batch of input examples and produces a variable number of output examples. Args: x: a single example Returns: a tf.data.Dataset """ partial = empty_example.copy() i = tf.zeros([], dtype=tf.int32) dynamic_batch_size = tf.shape(x[keys[0]])[0] outputs = {} for k in keys: outputs[k] = tf.TensorArray( tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] ) outputs[k + '_position'] = tf.TensorArray( tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] ) def body_fn(i, partial, outputs): """Body function for while_loop. Args: i: integer scalar partial: dictionary of Tensor (partially-constructed example) outputs: dictionary of TensorArray Returns: A triple containing the new values of the inputs. """ can_append = True one_example = {} for k in keys: val = tf.cast(x[k][i], tf.int32) val = val[: tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))] one_example[k] = val for k in keys: can_append = tf.logical_and( can_append, tf.less_equal( tf.size(partial[k]) + tf.size(one_example[k]), key2length[k] ), ) def false_fn(): return write_packed_example(partial, outputs) def true_fn(): return partial, outputs partial, outputs = tf.cond(can_append, true_fn, false_fn) new_partial = {} for k in keys: new_seq = one_example[k][: key2length[k]] new_seq_len = tf.size(new_seq) new_partial[k] = tf.concat([partial[k], new_seq], 0) new_partial[k + '_position'] = tf.concat( [partial[k + '_position'], tf.range(new_seq_len)], 0 ) partial = new_partial return i + 1, partial, outputs # For loop over all examples in the batch. i, partial, outputs = tf.while_loop( cond=lambda *_: True, body=body_fn, loop_vars=(i, partial, outputs), shape_invariants=( tf.TensorShape([]), # type: ignore[wrong-arg-types] {k: tf.TensorShape([None]) for k in keys_etc}, # type: ignore[wrong-arg-types] {k: tf.TensorShape(None) for k in keys_etc}, # type: ignore[wrong-arg-types] ), maximum_iterations=dynamic_batch_size, ) _, outputs = write_packed_example(partial, outputs) packed = {k: outputs[k].stack() for k in keys_etc} for k in keys: packed[k + '_segmentation'] = tf.cumsum( tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1 ) * tf.cast(tf.not_equal(packed[k], 0), tf.int32) return packed dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE) return dataset.unbatch() # ----------------------------------------------------------------------------- # Main dataset prep routines. # ----------------------------------------------------------------------------- def preprocess_wmt_data( dataset, shuffle: bool, num_epochs: int | None = 1, pack_examples: bool = True, shuffle_buffer_size: int = 1024, max_length: int = 512, batch_size: int = 256, drop_remainder: bool = True, prefetch_size: int = AUTOTUNE, ): """Shuffle and batch/pack the given dataset.""" def length_filter(max_len): def filter_fn(x): source, target = x['inputs'], x['targets'] l = tf.maximum(tf.shape(source)[0], tf.shape(target)[0]) return tf.less(l, max_len + 1) return filter_fn if max_length > 0: dataset = dataset.filter(length_filter(max_length)) if shuffle: dataset = dataset.shuffle(shuffle_buffer_size) dataset = dataset.repeat(num_epochs) if pack_examples: dataset = pack_dataset(dataset, max_length) dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) else: # simple (static-shape) padded batching dataset = dataset.padded_batch( batch_size, padded_shapes={'inputs': max_length, 'targets': max_length}, padding_values={'inputs': 0, 'targets': 0}, drop_remainder=drop_remainder, ) if prefetch_size: dataset = dataset.prefetch(prefetch_size) return dataset def get_wmt_datasets( config: ml_collections.ConfigDict, *, n_devices: int, reverse_translation: bool = True, vocab_path: str | None = None, ): """Load and return dataset of batched examples for use during training.""" if vocab_path is None: vocab_path = os.path.expanduser('~/wmt_sentencepiece_model') train_ds_builder = tfds.builder(config.dataset_name) train_data = get_raw_dataset( train_ds_builder, 'train', reverse_translation=reverse_translation ) if config.eval_dataset_name: eval_ds_builder = tfds.builder(config.eval_dataset_name) else: eval_ds_builder = train_ds_builder eval_data = get_raw_dataset( eval_ds_builder, config.eval_split, reverse_translation=reverse_translation, ) # Tokenize data. sp_tokenizer = tokenizer.load_or_train_tokenizer( train_data, vocab_path=vocab_path, vocab_size=config.vocab_size, max_corpus_chars=config.max_corpus_chars, ) train_data = train_data.map( tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE ) eval_data = eval_data.map( tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE ) batch_size = config.per_device_batch_size * n_devices train_ds = preprocess_wmt_data( train_data, shuffle=True, num_epochs=None, pack_examples=True, batch_size=batch_size, max_length=config.max_target_length, ) eval_ds = preprocess_wmt_data( eval_data, shuffle=False, pack_examples=False, batch_size=batch_size, max_length=config.max_eval_target_length, ) predict_ds = preprocess_wmt_data( eval_data, shuffle=False, pack_examples=False, batch_size=batch_size, max_length=config.max_predict_length, drop_remainder=False, ) return train_ds, eval_ds, predict_ds, sp_tokenizer ================================================ FILE: examples/wmt/input_pipeline_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import pathlib import sys import tempfile from absl.testing import absltest import tensorflow_datasets as tfds from configs import default import input_pipeline # We just use different values here to verify that the input pipeline uses the # the correct value for the 3 different datasets. _TARGET_LENGTH = 32 _EVAL_TARGET_LENGTH = 48 _PREDICT_TARGET_LENGTH = 64 class InputPipelineTest(absltest.TestCase): def setUp(self): super().setUp() if sys.version_info >= (3, 13): self.skipTest('Test (and tensorflow-text) does not suport Python 3.13+') self.train_ds, self.eval_ds, self.predict_ds = self._get_datasets() def _get_datasets(self): config = default.get_config() config.per_device_batch_size = 1 config.vocab_size = 32 config.max_corpus_chars = 1000 config.max_target_length = _TARGET_LENGTH config.max_eval_target_length = _EVAL_TARGET_LENGTH config.max_predict_length = _PREDICT_TARGET_LENGTH vocab_path = os.path.join(tempfile.mkdtemp(), 'sentencepiece_model') # Go two directories up to the root of the flax directory. flax_root_dir = pathlib.Path(__file__).parents[2] data_dir = str(flax_root_dir) + '/.tfds/metadata' # pylint: disable=unused-variable with tfds.testing.mock_data(num_examples=128, data_dir=data_dir): train_ds, eval_ds, predict_ds, _ = input_pipeline.get_wmt_datasets( n_devices=2, config=config, vocab_path=vocab_path ) return train_ds, eval_ds, predict_ds def test_train_ds(self): expected_shape = [2, _TARGET_LENGTH] # 2 devices. # For training we pack multiple short examples in one example. # *_position and *_segmentation indicate the boundaries. for batch in self.train_ds.take(3): self.assertEqual( {k: v.shape.as_list() for k, v in batch.items()}, { 'inputs': expected_shape, 'inputs_position': expected_shape, 'inputs_segmentation': expected_shape, 'targets': expected_shape, 'targets_position': expected_shape, 'targets_segmentation': expected_shape, }, ) def test_eval_ds(self): expected_shape = [2, _EVAL_TARGET_LENGTH] # 2 devices. for batch in self.eval_ds.take(3): self.assertEqual( {k: v.shape.as_list() for k, v in batch.items()}, { 'inputs': expected_shape, 'targets': expected_shape, }, ) def test_predict_ds(self): expected_shape = [2, _PREDICT_TARGET_LENGTH] # 2 devices. for batch in self.predict_ds.take(3): self.assertEqual( {k: v.shape.as_list() for k, v in batch.items()}, { 'inputs': expected_shape, 'targets': expected_shape, }, ) if __name__ == '__main__': absltest.main() ================================================ FILE: examples/wmt/main.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Main file for running the WMT example. This file is intentionally kept short. The majority for logic is in libraries that can be easily tested and imported in Colab. """ from absl import app from absl import flags from absl import logging from clu import platform import jax from ml_collections import config_flags import tensorflow as tf import train FLAGS = flags.FLAGS flags.DEFINE_string('workdir', None, 'Directory to store model data.') config_flags.DEFINE_config_file( 'config', 'configs/default.py', 'File path to the training hyperparameter configuration.', lock_config=True, ) flags.mark_flags_as_required(['config', 'workdir']) def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.experimental.set_visible_devices([], 'GPU') logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) logging.info('JAX local devices: %r', jax.local_devices()) # Add a note so that we can tell which task is which JAX host. # (Depending on the platform task 0 is not guaranteed to be host 0) platform.work_unit().set_task_status( f'process_index: {jax.process_index()}, ' f'process_count: {jax.process_count()}' ) platform.work_unit().create_artifact( platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' ) train.train_and_evaluate(FLAGS.config, FLAGS.workdir) if __name__ == '__main__': jax.config.config_with_absl() app.run(main) ================================================ FILE: examples/wmt/models.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Transformer-based machine translation model.""" # pylint: disable=attribute-defined-outside-init,g-bare-generic # See issue #620. # pytype: disable=wrong-arg-count # pytype: disable=wrong-keyword-args # pytype: disable=attribute-error from collections.abc import Callable from typing import Any from flax import linen as nn from flax import struct from jax import lax import jax.numpy as jnp import numpy as np @struct.dataclass class TransformerConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" vocab_size: int output_vocab_size: int share_embeddings: bool = False logits_via_embedding: bool = False dtype: Any = jnp.float32 emb_dim: int = 512 num_heads: int = 8 num_layers: int = 6 qkv_dim: int = 512 mlp_dim: int = 2048 max_len: int = 2048 dropout_rate: float = 0.1 attention_dropout_rate: float = 0.1 deterministic: bool = False decode: bool = False kernel_init: Callable = nn.initializers.xavier_uniform() bias_init: Callable = nn.initializers.normal(stddev=1e-6) posemb_init: Callable | None = None def shift_right(x, axis=1): """Shift the input to the right by padding on axis 1.""" pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) padded = jnp.pad( x, pad_widths, mode='constant', constant_values=x.dtype.type(0) ) return padded[:, :-1] def sinusoidal_init(max_len=2048, min_scale=1.0, max_scale=10000.0): """1D Sinusoidal Position Embedding Initializer. Args: max_len: maximum possible length for the input. min_scale: float: minimum frequency-scale in sine grating. max_scale: float: maximum frequency-scale in sine grating. Returns: output: init function returning `(1, max_len, d_feature)` """ def init(key, shape, dtype=np.float32): """Sinusoidal init.""" del key, dtype d_feature = shape[-1] pe = np.zeros((max_len, d_feature), dtype=np.float32) position = np.arange(0, max_len)[:, np.newaxis] scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1) div_term = min_scale * np.exp(np.arange(0, d_feature // 2) * scale_factor) pe[:, : d_feature // 2] = np.sin(position * div_term) pe[:, d_feature // 2 : 2 * (d_feature // 2)] = np.cos(position * div_term) pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] return jnp.array(pe) return init class AddPositionEmbs(nn.Module): """Adds (optionally learned) positional embeddings to the inputs. Attributes: config: TransformerConfig dataclass containing hyperparameters. decode: whether to run in single-position autoregressive mode. """ config: TransformerConfig decode: bool = False @nn.compact def __call__(self, inputs, inputs_positions=None): """Applies AddPositionEmbs module. By default this layer uses a fixed sinusoidal embedding table. If a learned position embedding is desired, pass an initializer to posemb_init in the configuration. Args: inputs: input data. inputs_positions: input position indices for packed sequences. Returns: output: `(bs, timesteps, in_dim)` """ config = self.config # inputs.shape is (batch_size, seq_len, emb_dim) assert inputs.ndim == 3, ( 'Number of dimensions should be 3, but it is: %d' % inputs.ndim ) length = inputs.shape[1] pos_emb_shape = (1, config.max_len, inputs.shape[-1]) if config.posemb_init is None: # Use a fixed (non-learned) sinusoidal position embedding. pos_embedding = sinusoidal_init(max_len=config.max_len)( None, pos_emb_shape, None ) else: pos_embedding = self.param( 'pos_embedding', config.posemb_init, pos_emb_shape ) pe = pos_embedding[:, :length, :] # We use a cache position index for tracking decoding position. if self.decode: is_initialized = self.has_variable('cache', 'cache_index') cache_index = self.variable( 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.uint32) ) if is_initialized: i = cache_index.value cache_index.value = i + 1 _, _, df = pos_embedding.shape pe = lax.dynamic_slice(pos_embedding, jnp.array((0, i, 0)), (1, 1, df)) if inputs_positions is None: # normal unpacked case: return inputs + pe else: # for packed data we need to use known position indices: return inputs + jnp.take(pe[0], inputs_positions, axis=0) class MlpBlock(nn.Module): """Transformer MLP / feed-forward block. Attributes: config: TransformerConfig dataclass containing hyperparameters. out_dim: optionally specify out dimension. """ config: TransformerConfig out_dim: int | None = None @nn.compact def __call__(self, inputs): """Applies Transformer MlpBlock module.""" config = self.config actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim x = nn.Dense( config.mlp_dim, dtype=config.dtype, kernel_init=config.kernel_init, bias_init=config.bias_init, )(inputs) x = nn.relu(x) x = nn.Dropout(rate=config.dropout_rate)( x, deterministic=config.deterministic ) output = nn.Dense( actual_out_dim, dtype=config.dtype, kernel_init=config.kernel_init, bias_init=config.bias_init, )(x) output = nn.Dropout(rate=config.dropout_rate)( output, deterministic=config.deterministic ) return output class Encoder1DBlock(nn.Module): """Transformer encoder layer. Attributes: config: TransformerConfig dataclass containing hyperparameters. """ config: TransformerConfig @nn.compact def __call__(self, inputs, encoder_mask=None): """Applies Encoder1DBlock module. Args: inputs: input data. encoder_mask: encoder self-attention mask. Returns: output after transformer encoder block. """ config = self.config # Attention block. assert inputs.ndim == 3 x = nn.LayerNorm(dtype=config.dtype)(inputs) x = nn.MultiHeadDotProductAttention( num_heads=config.num_heads, dtype=config.dtype, qkv_features=config.qkv_dim, kernel_init=config.kernel_init, bias_init=config.bias_init, use_bias=False, broadcast_dropout=False, dropout_rate=config.attention_dropout_rate, deterministic=config.deterministic, )(x, mask=encoder_mask) x = nn.Dropout(rate=config.dropout_rate)( x, deterministic=config.deterministic ) x = x + inputs # MLP block. y = nn.LayerNorm(dtype=config.dtype)(x) y = MlpBlock(config=config)(y) return x + y class EncoderDecoder1DBlock(nn.Module): """Transformer encoder-decoder layer. Attributes: config: TransformerConfig dataclass containing hyperparameters. """ config: TransformerConfig @nn.compact def __call__( self, targets, encoded, decoder_mask=None, encoder_decoder_mask=None ): """Applies EncoderDecoder1DBlock module. Args: targets: input data for decoder encoded: input data from encoder decoder_mask: decoder self-attention mask. encoder_decoder_mask: encoder-decoder attention mask. Returns: output after transformer encoder-decoder block. """ config = self.config # Decoder block. assert targets.ndim == 3 x = nn.LayerNorm(dtype=config.dtype)(targets) x = nn.MultiHeadDotProductAttention( num_heads=config.num_heads, dtype=config.dtype, qkv_features=config.qkv_dim, kernel_init=config.kernel_init, bias_init=config.bias_init, use_bias=False, broadcast_dropout=False, dropout_rate=config.attention_dropout_rate, deterministic=config.deterministic, decode=config.decode, )(x, mask=decoder_mask) x = nn.Dropout(rate=config.dropout_rate)( x, deterministic=config.deterministic ) x = x + targets # Encoder-Decoder block. y = nn.LayerNorm(dtype=config.dtype)(x) y = nn.MultiHeadDotProductAttention( num_heads=config.num_heads, dtype=config.dtype, qkv_features=config.qkv_dim, kernel_init=config.kernel_init, bias_init=config.bias_init, use_bias=False, broadcast_dropout=False, dropout_rate=config.attention_dropout_rate, deterministic=config.deterministic, )(y, encoded, mask=encoder_decoder_mask) y = nn.Dropout(rate=config.dropout_rate)( y, deterministic=config.deterministic ) y = y + x # MLP block. z = nn.LayerNorm(dtype=config.dtype)(y) z = MlpBlock(config=config)(z) return y + z class Encoder(nn.Module): """Transformer Model Encoder for sequence to sequence translation. Attributes: config: TransformerConfig dataclass containing hyperparameters. shared_embedding: a shared embedding layer to use. """ config: TransformerConfig shared_embedding: Any = None @nn.compact def __call__(self, inputs, inputs_positions=None, encoder_mask=None): """Applies Transformer model on the inputs. Args: inputs: input data inputs_positions: input subsequence positions for packed examples. encoder_mask: decoder self-attention mask. Returns: output of a transformer encoder. """ config = self.config assert inputs.ndim == 2 # (batch, len) # Input Embedding if self.shared_embedding is None: input_embed = nn.Embed( num_embeddings=config.vocab_size, features=config.emb_dim, embedding_init=nn.initializers.normal(stddev=1.0), ) else: input_embed = self.shared_embedding x = inputs.astype('int32') x = input_embed(x) x = AddPositionEmbs(config=config, decode=False, name='posembed_input')( x, inputs_positions=inputs_positions ) x = nn.Dropout(rate=config.dropout_rate)( x, deterministic=config.deterministic ) x = x.astype(config.dtype) # Input Encoder for lyr in range(config.num_layers): x = Encoder1DBlock(config=config, name=f'encoderblock_{lyr}')( x, encoder_mask ) encoded = nn.LayerNorm(dtype=config.dtype, name='encoder_norm')(x) return encoded class Decoder(nn.Module): """Transformer Model Decoder for sequence to sequence translation. Attributes: config: TransformerConfig dataclass containing hyperparameters. shared_embedding: a shared embedding layer to use. """ config: TransformerConfig shared_embedding: Any = None @nn.compact def __call__( self, encoded, targets, targets_positions=None, decoder_mask=None, encoder_decoder_mask=None, ): """Applies Transformer model on the inputs. Args: encoded: encoded input data from encoder. targets: target inputs. targets_positions: input subsequence positions for packed examples. decoder_mask: decoder self-attention mask. encoder_decoder_mask: encoder-decoder attention mask. Returns: output of a transformer decoder. """ config = self.config assert encoded.ndim == 3 # (batch, len, depth) assert targets.ndim == 2 # (batch, len) # Target Embedding if self.shared_embedding is None: output_embed = nn.Embed( num_embeddings=config.output_vocab_size, features=config.emb_dim, embedding_init=nn.initializers.normal(stddev=1.0), ) else: output_embed = self.shared_embedding y = targets.astype('int32') if not config.decode: y = shift_right(y) y = output_embed(y) y = AddPositionEmbs( config=config, decode=config.decode, name='posembed_output' )(y, inputs_positions=targets_positions) y = nn.Dropout(rate=config.dropout_rate)( y, deterministic=config.deterministic ) y = y.astype(config.dtype) # Target-Input Decoder for lyr in range(config.num_layers): y = EncoderDecoder1DBlock( config=config, name=f'encoderdecoderblock_{lyr}' )( y, encoded, decoder_mask=decoder_mask, encoder_decoder_mask=encoder_decoder_mask, ) y = nn.LayerNorm(dtype=config.dtype, name='encoderdecoder_norm')(y) # Decoded Logits if config.logits_via_embedding: # Use the transpose of embedding matrix for logit transform. logits = output_embed.attend(y.astype(jnp.float32)) # Correctly normalize pre-softmax logits for this shared case. logits = logits / jnp.sqrt(y.shape[-1]) else: logits = nn.Dense( config.output_vocab_size, dtype=config.dtype, kernel_init=config.kernel_init, bias_init=config.bias_init, name='logitdense', )(y) return logits class Transformer(nn.Module): """Transformer Model for sequence to sequence translation. Attributes: config: TransformerConfig dataclass containing hyperparameters. """ config: TransformerConfig def setup(self): config = self.config if config.share_embeddings: if config.output_vocab_size is not None: assert ( config.output_vocab_size == config.vocab_size ), "can't share embedding with different vocab sizes." self.shared_embedding = nn.Embed( num_embeddings=config.vocab_size, features=config.emb_dim, embedding_init=nn.initializers.normal(stddev=1.0), ) else: self.shared_embedding = None self.encoder = Encoder( config=config, shared_embedding=self.shared_embedding ) self.decoder = Decoder( config=config, shared_embedding=self.shared_embedding ) def encode(self, inputs, inputs_positions=None, inputs_segmentation=None): """Applies Transformer encoder-branch on the inputs. Args: inputs: input data. inputs_positions: input subsequence positions for packed examples. inputs_segmentation: input segmentation info for packed examples. Returns: encoded feature array from the transformer encoder. """ config = self.config # Make padding attention mask. encoder_mask = nn.make_attention_mask( inputs > 0, inputs > 0, dtype=config.dtype ) # Add segmentation block-diagonal attention mask if using segmented data. if inputs_segmentation is not None: encoder_mask = nn.combine_masks( encoder_mask, nn.make_attention_mask( inputs_segmentation, inputs_segmentation, jnp.equal, dtype=config.dtype, ), ) return self.encoder( inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask ) def decode( self, encoded, inputs, # only needed for masks targets, targets_positions=None, inputs_segmentation=None, targets_segmentation=None, ): """Applies Transformer decoder-branch on encoded-input and target. Args: encoded: encoded input data from encoder. inputs: input data (only needed for masking). targets: target data. targets_positions: target subsequence positions for packed examples. inputs_segmentation: input segmentation info for packed examples. targets_segmentation: target segmentation info for packed examples. Returns: logits array from transformer decoder. """ config = self.config # Make padding attention masks. if config.decode: # for fast autoregressive decoding only a special encoder-decoder mask is # used decoder_mask = None encoder_decoder_mask = nn.make_attention_mask( jnp.ones_like(targets) > 0, inputs > 0, dtype=config.dtype ) else: decoder_mask = nn.combine_masks( nn.make_attention_mask(targets > 0, targets > 0, dtype=config.dtype), nn.make_causal_mask(targets, dtype=config.dtype), ) encoder_decoder_mask = nn.make_attention_mask( targets > 0, inputs > 0, dtype=config.dtype ) # Add segmentation block-diagonal attention masks if using segmented data. if inputs_segmentation is not None: decoder_mask = nn.combine_masks( decoder_mask, nn.make_attention_mask( targets_segmentation, targets_segmentation, jnp.equal, dtype=config.dtype, ), ) encoder_decoder_mask = nn.combine_masks( encoder_decoder_mask, nn.make_attention_mask( targets_segmentation, inputs_segmentation, jnp.equal, dtype=config.dtype, ), ) logits = self.decoder( encoded, targets, targets_positions=targets_positions, decoder_mask=decoder_mask, encoder_decoder_mask=encoder_decoder_mask, ) return logits.astype(self.config.dtype) def __call__( self, inputs, targets, inputs_positions=None, targets_positions=None, inputs_segmentation=None, targets_segmentation=None, ): """Applies Transformer model on the inputs. Args: inputs: input data. targets: target data. inputs_positions: input subsequence positions for packed examples. targets_positions: target subsequence positions for packed examples. inputs_segmentation: input segmentation info for packed examples. targets_segmentation: target segmentation info for packed examples. Returns: logits array from full transformer. """ encoded = self.encode( inputs, inputs_positions=inputs_positions, inputs_segmentation=inputs_segmentation, ) return self.decode( encoded, inputs, # only used for masks targets, targets_positions=targets_positions, inputs_segmentation=inputs_segmentation, targets_segmentation=targets_segmentation, ) ================================================ FILE: examples/wmt/requirements.txt ================================================ absl-py==1.0.0 clu==0.0.6 flax==0.6.0 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html jax[cuda11_cudnn805]>=0.3.16 # change to jax[tpu] if running on tpus ml-collections==0.1.0 numpy==1.22.0 optax==0.1.0 sentencepiece==0.1.96 six==1.15.0 tensorflow==2.11.1 tensorflow-datasets==4.4.0 tensorflow-text==2.8.1 ================================================ FILE: examples/wmt/tokenizer.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Provides op for tokenizing a dataset.""" import dataclasses import os import tempfile import time from typing import Any, Dict, Tuple from collections.abc import Iterable import sys from absl import logging import jax import tensorflow as tf if sys.version_info < (3, 13): import tensorflow_text as tftxt from sentencepiece import SentencePieceTrainer Features = dict[str, tf.Tensor] def _dump_chars_to_textfile( dataset: tf.data.Dataset, maxchars: int = int(1e7), data_keys=('inputs', 'targets'), ) -> tuple[str, int]: """Write part of a TFDS sentence dataset to lines in a text file. Args: dataset: tf.dataset containing string-data. maxchars: int: approximate number of characters to save from dataset. data_keys: Tuple[str]: what keys in dataset to dump from. Returns: name of temp file with dataset bytes, exact number of characters dumped. """ char_count = 0 ds_iter = dataset.as_numpy_iterator() with tempfile.NamedTemporaryFile( delete=False, prefix='/tmp/ds_chars' ) as outfp: while char_count < maxchars: example = next(ds_iter) for k in data_keys: line = example[k] + b'\n' char_count += len(line) outfp.write(line) return outfp.name, char_count def _train_sentencepiece( dataset: tf.data.Dataset, *, vocab_size: int, maxchars: int = int(1e7), model_path: str, model_type: str = 'unigram', character_coverage: float = 1.0, data_keys=('inputs', 'targets'), ): """Train SentencePiece tokenizer from subset of tf dataset. Args: dataset: tf.dataset vocab_size: int: size of vocab tokens to train. maxchars: int: number of characters to use for sentencepiece training. model_path: str: path of model file to save vocab model to. model_type: str: type of sentencepiece vocab to train. character_coverage: amount of characters covered by the model, good defaults are 0.9995 for languages with rich character set like Japanese or Chinese and 1.0 for other languages with small character set. data_keys: Tuple[str]: keys of dataset to use for training. Returns: path to the trained sentencepiece vocabulary model. """ if model_path.startswith('gs://'): abs_model_path = model_path else: abs_model_path = os.path.abspath(os.path.expanduser(model_path)) fname, _ = _dump_chars_to_textfile( dataset, maxchars=maxchars, data_keys=data_keys ) with tempfile.NamedTemporaryFile( delete=False, prefix='/tmp/sp_tmp' ) as model_fp: pass # we just want a prefix'd tmp-filename argstr = ' '.join([ f'--input={fname}', f'--vocab_size={vocab_size}', f'--character_coverage={character_coverage}', f'--model_prefix={model_fp.name}', f'--model_type={model_type}', ]) SentencePieceTrainer.Train(argstr) if jax.process_index() == 0: # Use an intermediate filename that is renamed to the target name to address # create and fill delays. copy_rename_path = abs_model_path + '.rntmp' tf.io.gfile.copy(model_fp.name + '.model', copy_rename_path, overwrite=True) tf.io.gfile.rename(copy_rename_path, abs_model_path, overwrite=True) logging.info('copied %s to %s', model_fp.name + '.model', abs_model_path) else: while not tf.io.gfile.exists(abs_model_path): time.sleep(1) time.sleep(1) return abs_model_path def _load_sentencepiece_tokenizer( model_path: str, add_bos: bool = False, add_eos: bool = True, reverse: bool = False, ): """Load a tf-text SentencePiece tokenizer from given model filepath.""" with tf.io.gfile.GFile(model_path, 'rb') as model_fp: sp_model = model_fp.read() sp_tokenizer = tftxt.SentencepieceTokenizer( model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse ) return sp_tokenizer def load_or_train_tokenizer( dataset: tf.data.Dataset, *, vocab_path: str, vocab_size: int, max_corpus_chars: int, data_keys: tuple[str, str] = ('inputs', 'targets'), ): """Loads the tokenizer at `vocab_path` or trains a one from `dataset`.""" try: return _load_sentencepiece_tokenizer(vocab_path) except tf.errors.NotFoundError: logging.info('SentencePiece vocab not found, building one from data.') vocab_path = _train_sentencepiece( dataset, vocab_size=vocab_size, maxchars=max_corpus_chars, model_path=vocab_path, data_keys=data_keys, ) return _load_sentencepiece_tokenizer(vocab_path) @dataclasses.dataclass class TokenizeOp: sp_tokenizer: Any data_keys: Iterable[str] = ('inputs', 'targets') def __call__(self, features: Features) -> Features: for k in self.data_keys: features[k] = self.sp_tokenizer.tokenize(features[k]) return features ================================================ FILE: examples/wmt/train.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Machine Translation example. This script trains a Transformer on a WMT dataset. """ # pytype: disable=wrong-arg-count # pytype: disable=attribute-error import collections import functools import os from absl import logging from clu import metric_writers from clu import periodic_actions from flax import jax_utils from flax import linen as nn from flax.training import checkpoints from flax.training import common_utils from flax.training import dynamic_scale as dynamic_scale_lib from flax.training import train_state import jax import jax.numpy as jnp import ml_collections import numpy as np import optax import orbax.checkpoint as ocp import tensorflow as tf import bleu import decode import input_pipeline import models class TrainState(train_state.TrainState): dynamic_scale: dynamic_scale_lib.DynamicScale def rsqrt_schedule( init_value: float, shift: int = 0, ): """Applies a reverse square-root schedule. The reverse square root schedule is simply `lr = init_value / sqrt(step)`. Args: init_value: Base learning rate (before applying the rsqrt schedule). shift: How many steps the rsqrt should be shifted. Shifting the rsqrt schedule makes it less steep in the beginning (close to 0). Returns: A schedule `count -> learning_rate`. """ def schedule(count): return init_value * (count + shift) ** -0.5 * shift**0.5 return schedule def create_learning_rate_schedule(learning_rate: float, warmup_steps: int): """Creates a rsqrt schedule with linear warmup.""" return optax.join_schedules( [ optax.linear_schedule( init_value=0, end_value=learning_rate, transition_steps=warmup_steps, ), rsqrt_schedule(init_value=learning_rate, shift=warmup_steps), ], boundaries=[warmup_steps], ) def compute_weighted_cross_entropy( logits, targets, weights=None, label_smoothing=0.0 ): """Compute weighted cross entropy and entropy for log probs and targets. Args: logits: [batch, length, num_classes] float array. targets: categorical targets [batch, length] int array. weights: None or array of shape [batch, length]. label_smoothing: label smoothing constant, used to determine the on and off values. Returns: Tuple of scalar loss and batch normalizing factor. """ if logits.ndim != targets.ndim + 1: raise ValueError( "Incorrect shapes. Got shape %s logits and %s targets" % (str(logits.shape), str(targets.shape)) ) vocab_size = logits.shape[-1] confidence = 1.0 - label_smoothing low_confidence = (1.0 - confidence) / (vocab_size - 1) normalizing_constant = -( confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) ) soft_targets = common_utils.onehot( targets, vocab_size, on_value=confidence, off_value=low_confidence ) loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1) loss = loss - normalizing_constant normalizing_factor = np.prod(targets.shape) if weights is not None: loss = loss * weights normalizing_factor = weights.sum() return loss.sum(), normalizing_factor def compute_weighted_accuracy(logits, targets, weights=None): """Compute weighted accuracy for log probs and targets. Args: logits: [batch, length, num_classes] float array. targets: categorical targets [batch, length] int array. weights: None or array of shape [batch, length] Returns: Tuple of scalar loss and batch normalizing factor. """ if logits.ndim != targets.ndim + 1: raise ValueError( "Incorrect shapes. Got shape %s logits and %s targets" % (str(logits.shape), str(targets.shape)) ) loss = jnp.equal(jnp.argmax(logits, axis=-1), targets) normalizing_factor = np.prod(logits.shape[:-1]) if weights is not None: loss = loss * weights normalizing_factor = weights.sum() return loss.sum(), normalizing_factor def compute_metrics(logits, labels, weights, label_smoothing=0.0): """Compute summary metrics.""" loss, weight_sum = compute_weighted_cross_entropy( logits, labels, weights, label_smoothing ) acc, _ = compute_weighted_accuracy(logits, labels, weights) metrics = { "loss": loss, "accuracy": acc, "denominator": weight_sum, } metrics = jax.lax.psum(metrics, axis_name="batch") return metrics # Primary training / eval / decode step functions. # ----------------------------------------------------------------------------- def train_step( state, batch, config, learning_rate_fn, label_smoothing=0.0, dropout_rng=None, ): """Perform a single training step.""" # X_position and X_segmentation are needed only when using "packed examples" # where multiple sequences are packed into the same example with this # metadata. # if such features are not present they are ignored and the example is treated # like a normal, unpacked sequence example. train_keys = [ "inputs", "targets", "inputs_position", "targets_position", "inputs_segmentation", "targets_segmentation", ] ( inputs, targets, inputs_positions, targets_positions, inputs_segmentation, targets_segmentation, ) = (batch.get(k, None) for k in train_keys) weights = jnp.where(targets > 0, 1, 0).astype(jnp.float32) dropout_rng = jax.random.fold_in(dropout_rng, state.step) def loss_fn(params): """loss function used for training.""" logits = models.Transformer(config).apply( {"params": params}, inputs, targets, inputs_positions=inputs_positions, targets_positions=targets_positions, inputs_segmentation=inputs_segmentation, targets_segmentation=targets_segmentation, rngs={"dropout": dropout_rng}, ) loss, weight_sum = compute_weighted_cross_entropy( logits, targets, weights, label_smoothing ) mean_loss = loss / weight_sum return mean_loss, logits step = state.step if state.dynamic_scale: # dynamic scale takes care of averaging gradients across replicas grad_fn = state.dynamic_scale.value_and_grad( loss_fn, has_aux=True, axis_name="batch" ) dynamic_scale, is_fin, (_, logits), grads = grad_fn(state.params) state = state.replace(dynamic_scale=dynamic_scale) else: grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (_, logits), grads = grad_fn(state.params) grads = jax.lax.pmean(grads, axis_name="batch") new_state = state.apply_gradients(grads=grads) metrics = compute_metrics(logits, targets, weights) metrics["learning_rate"] = learning_rate_fn(step) if state.dynamic_scale: # if is_fin == False the gradients contain Inf/NaNs and optimizer state and # params should be restored (= skip this step). select_fn = functools.partial(jnp.where, is_fin) # pylint: disable=undefined-variable new_state = new_state.replace( opt_state=jax.tree_util.tree_map( select_fn, new_state.opt_state, state.opt_state ), params=jax.tree_util.tree_map( select_fn, new_state.params, state.params ), ) metrics["loss_scale"] = dynamic_scale.scale * metrics["denominator"] # pylint: disable=undefined-variable return new_state, metrics def eval_step(params, batch, config, label_smoothing=0.0): """Calculate evaluation metrics on a batch.""" inputs, targets = batch["inputs"], batch["targets"] weights = jnp.where(targets > 0, 1.0, 0.0) logits = models.Transformer(config).apply({"params": params}, inputs, targets) return compute_metrics(logits, targets, weights, label_smoothing) def initialize_cache(inputs, max_decode_len, config): """Initialize a cache for a given input shape and max decode length.""" target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:] initial_variables = models.Transformer(config).init( jax.random.key(0), jnp.ones(inputs.shape, config.dtype), jnp.ones(target_shape, config.dtype), ) return initial_variables["cache"] def predict_step( inputs, params, cache, eos_id, max_decode_len, config, beam_size=4 ): """Predict translation with fast decoding beam search on a batch.""" # Prepare transformer fast-decoder call for beam search: for beam search, we # need to set up our decoder model to handle a batch size equal to # batch_size * beam_size, where each batch item's data is expanded in-place # rather than tiled. # i.e. if we denote each batch element subtensor as el[n]: # [el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2] encoded_inputs = decode.flat_batch_beam_expand( models.Transformer(config).apply( {"params": params}, inputs, method=models.Transformer.encode ), beam_size, ) raw_inputs = decode.flat_batch_beam_expand(inputs, beam_size) def tokens_ids_to_logits(flat_ids, flat_cache): """Token slice to logits from decoder model.""" # --> [batch * beam, 1, vocab] flat_logits, new_vars = models.Transformer(config).apply( {"params": params, "cache": flat_cache}, encoded_inputs, raw_inputs, # only needed for input padding mask flat_ids, mutable=["cache"], method=models.Transformer.decode, ) new_flat_cache = new_vars["cache"] # Remove singleton sequence-length dimension: # [batch * beam, 1, vocab] --> [batch * beam, vocab] flat_logits = flat_logits.squeeze(axis=1) return flat_logits, new_flat_cache # Using the above-defined single-step decoder function, run a # beam search over possible sequences given input encoding. beam_seqs, _ = decode.beam_search( inputs, cache, tokens_ids_to_logits, beam_size=beam_size, alpha=0.6, eos_id=eos_id, max_decode_len=max_decode_len, ) # Beam search returns [n_batch, n_beam, n_length + 1] with beam dimension # sorted in increasing order of log-probability. # Return the highest scoring beam sequence, drop first dummy 0 token. return beam_seqs[:, -1, 1:] # Utils for prediction and BLEU calculation # ----------------------------------------------------------------------------- def pad_examples(x, desired_batch_size): """Expand batch to desired size by repeating last slice.""" batch_pad = desired_batch_size - x.shape[0] return np.concatenate([x, np.tile(x[-1], (batch_pad, 1))], axis=0) def per_host_sum_pmap(in_tree): """Execute psum on in_tree"s leaves over one device per host.""" host2devices = collections.defaultdict(list) for d in jax.devices(): host2devices[d.process_index].append(d) devices = [host2devices[k][0] for k in host2devices] host_psum = jax.pmap(lambda x: jax.lax.psum(x, "i"), "i", devices=devices) def pre_pmap(xs): return jax.tree_util.tree_map( lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs ) def post_pmap(xs): # Avoid degraded performance under the new jax.pmap. See # https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays. return jax.tree_util.tree_map( lambda x: x.addressable_shards[0].data.squeeze(0), xs ) return post_pmap(host_psum(pre_pmap(in_tree))) def tohost(x): """Collect batches from all devices to host and flatten batch dimensions.""" n_device, n_batch, *remaining_dims = x.shape return np.array(x).reshape((n_device * n_batch,) + tuple(remaining_dims)) def evaluate( *, p_eval_step, params, eval_ds: tf.data.Dataset, num_eval_steps: int ): """Evaluate the params an return a dictionary with the metrics.""" logging.info("Gathering evaluation metrics.") eval_metrics = [] eval_iter = iter(eval_ds) # pytype: disable=wrong-arg-types for _, eval_batch in zip(range(num_eval_steps), eval_iter): eval_batch = jax.tree_util.tree_map(lambda x: x._numpy(), eval_batch) # pylint: disable=protected-access eval_batch = common_utils.shard(eval_batch) metrics = p_eval_step(params, eval_batch) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) eval_metrics_sums = jax.tree_util.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop("denominator") eval_summary = jax.tree_util.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums, ) return eval_summary def translate_and_calculate_bleu( *, p_pred_step, p_init_cache, params, predict_ds: tf.data.Dataset, decode_tokens, max_predict_length: int, ): """Translates the `predict_ds` and calculates the BLEU score.""" n_devices = jax.local_device_count() logging.info("Translating evaluation dataset.") sources, references, predictions = [], [], [] for pred_batch in predict_ds: pred_batch = jax.tree_util.tree_map(lambda x: x._numpy(), pred_batch) # pylint: disable=protected-access # Handle final odd-sized batch by padding instead of dropping it. cur_pred_batch_size = pred_batch["inputs"].shape[0] if cur_pred_batch_size % n_devices: padded_size = int(np.ceil(cur_pred_batch_size / n_devices) * n_devices) pred_batch = jax.tree_util.tree_map( lambda x: pad_examples(x, padded_size), # pylint: disable=cell-var-from-loop pred_batch, ) pred_batch = common_utils.shard(pred_batch) cache = p_init_cache(pred_batch["inputs"]) predicted = p_pred_step( pred_batch["inputs"], params, cache, decode.EOS_ID, max_predict_length ) predicted = tohost(predicted) inputs = tohost(pred_batch["inputs"]) targets = tohost(pred_batch["targets"]) # Iterate through non-padding examples of batch. for i, s in enumerate(predicted[:cur_pred_batch_size]): sources.append(decode_tokens(inputs[i])) references.append(decode_tokens(targets[i])) predictions.append(decode_tokens(s)) logging.info( "Translation: %d predictions %d references %d sources.", len(predictions), len(references), len(sources), ) # Calculate BLEU score for translated eval corpus against reference. bleu_matches = bleu.bleu_partial(references, predictions) all_bleu_matches = per_host_sum_pmap(bleu_matches) bleu_score = bleu.complete_bleu(*all_bleu_matches) # Save translation samples for tensorboard. exemplars = "" for n in np.random.choice(np.arange(len(predictions)), 8): exemplars += f"{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n" return exemplars, bleu_score def preferred_dtype(config): platform = jax.local_devices()[0].platform if config.use_mixed_precision: if platform == "tpu": return jnp.bfloat16 elif platform == "gpu": return jnp.float16 return jnp.float32 def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ tf.io.gfile.makedirs(workdir) vocab_path = config.vocab_path if vocab_path is None: vocab_path = os.path.join(workdir, "sentencepiece_model") config.vocab_path = vocab_path tf.io.gfile.makedirs(os.path.split(vocab_path)[0]) # Load Dataset # --------------------------------------------------------------------------- logging.info("Initializing dataset.") train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets( n_devices=jax.local_device_count(), config=config, reverse_translation=config.reverse_translation, vocab_path=vocab_path, ) train_iter = iter(train_ds) vocab_size = int(encoder.vocab_size()) eos_id = decode.EOS_ID # Default Sentencepiece EOS token. def decode_tokens(toks): valid_toks = toks[: np.argmax(toks == eos_id) + 1].astype(np.int32) return encoder.detokenize(valid_toks).numpy().decode("utf-8") if config.num_predict_steps > 0: predict_ds = predict_ds.take(config.num_predict_steps) logging.info("Initializing model, optimizer, and step functions.") dtype = preferred_dtype(config) # Build Model and Optimizer # --------------------------------------------------------------------------- train_config = models.TransformerConfig( vocab_size=vocab_size, output_vocab_size=vocab_size, share_embeddings=config.share_embeddings, logits_via_embedding=config.logits_via_embedding, dtype=dtype, emb_dim=config.emb_dim, num_heads=config.num_heads, num_layers=config.num_layers, qkv_dim=config.qkv_dim, mlp_dim=config.mlp_dim, max_len=max(config.max_target_length, config.max_eval_target_length), dropout_rate=config.dropout_rate, attention_dropout_rate=config.attention_dropout_rate, deterministic=False, decode=False, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6), ) eval_config = train_config.replace(deterministic=True) predict_config = train_config.replace(deterministic=True, decode=True) start_step = 0 rng = jax.random.key(config.seed) rng, init_rng = jax.random.split(rng) input_shape = (config.per_device_batch_size, config.max_target_length) target_shape = (config.per_device_batch_size, config.max_target_length) m = models.Transformer(eval_config) initial_variables = jax.jit(m.init)( init_rng, jnp.ones(input_shape, jnp.float32), jnp.ones(target_shape, jnp.float32), ) # Create train state with Adam optimizer and weight decay. learning_rate_fn = create_learning_rate_schedule( learning_rate=config.learning_rate, warmup_steps=config.warmup_steps ) dynamic_scale = None if dtype == jnp.float16: dynamic_scale = dynamic_scale_lib.DynamicScale() state = TrainState.create( apply_fn=m.apply, params=initial_variables["params"], tx=optax.adamw( learning_rate=learning_rate_fn, b1=0.9, b2=0.98, eps=1e-9, weight_decay=config.weight_decay, ), dynamic_scale=dynamic_scale, ) # We access model params only via state.params del initial_variables if config.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. state = checkpoints.restore_checkpoint(workdir, state) # Grab last step. start_step = int(state.step) writer = metric_writers.create_default_writer( workdir, just_logging=jax.process_index() > 0 ) if start_step == 0: writer.write_hparams(dict(config)) # Replicate state. state = jax_utils.replicate(state) # compile multidevice versions of train/eval/predict step and cache init fn. p_train_step = jax.pmap( functools.partial( train_step, config=train_config, learning_rate_fn=learning_rate_fn, label_smoothing=config.label_smoothing, ), axis_name="batch", donate_argnums=(0,), ) # pytype: disable=wrong-arg-types p_eval_step = jax.pmap( functools.partial(eval_step, config=eval_config), axis_name="batch" ) p_init_cache = jax.pmap( functools.partial( initialize_cache, max_decode_len=config.max_predict_length, config=predict_config, ), axis_name="batch", ) p_pred_step = jax.pmap( functools.partial( predict_step, config=predict_config, beam_size=config.beam_size ), axis_name="batch", static_broadcasted_argnums=(3, 4), ) # eos token, max_length are constant # Main Train Loop # --------------------------------------------------------------------------- # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap"d training update for performance. dropout_rngs = jax.random.split(rng, jax.local_device_count()) del rng logging.info("Starting training loop.") hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=config.num_train_steps, writer=writer ) if jax.process_index() == 0: hooks += [ report_progress, periodic_actions.Profile(logdir=workdir, num_profile_steps=5), ] train_metrics = [] with metric_writers.ensure_flushes(writer): for step in range(start_step, config.num_train_steps): is_last_step = step == config.num_train_steps - 1 # Shard data to devices and do a training step. with jax.profiler.StepTraceAnnotation("train", step_num=step): batch = common_utils.shard( jax.tree_util.tree_map(np.asarray, next(train_iter)) ) state, metrics = p_train_step(state, batch, dropout_rng=dropout_rngs) train_metrics.append(metrics) # Quick indication that training is happening. logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step) for h in hooks: h(step) # Periodic metric handling. if step % config.eval_every_steps == 0 or is_last_step: with report_progress.timed("training_metrics"): logging.info("Gathering training metrics.") train_metrics = common_utils.get_metrics(train_metrics) lr = train_metrics.pop("learning_rate").mean() metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics) denominator = metrics_sums.pop("denominator") summary = jax.tree_util.tree_map( lambda x: x / denominator, metrics_sums # pylint: disable=cell-var-from-loop ) summary["learning_rate"] = lr summary = {"train_" + k: v for k, v in summary.items()} writer.write_scalars(step, summary) train_metrics = [] with report_progress.timed("eval"): eval_results = evaluate( p_eval_step=p_eval_step, params=state.params, eval_ds=eval_ds, num_eval_steps=config.num_eval_steps, ) writer.write_scalars( step, {"eval_" + k: v for k, v in eval_results.items()} ) with report_progress.timed("translate_and_bleu"): exemplars, bleu_score = translate_and_calculate_bleu( p_pred_step=p_pred_step, p_init_cache=p_init_cache, params=state.params, predict_ds=predict_ds, decode_tokens=decode_tokens, max_predict_length=config.max_predict_length, ) writer.write_scalars(step, {"bleu": bleu_score}) writer.write_texts(step, {"samples": exemplars}) # Save a checkpoint on one host after every checkpoint_freq steps. save_checkpoint = ( step % config.checkpoint_every_steps == 0 or is_last_step ) if config.save_checkpoints and save_checkpoint: logging.info("Saving checkpoint step %d.", step) # Orbax can not handle host local arrays from pmap. replicated_state = jax.tree_util.tree_map( ocp.utils.fully_replicated_host_local_array_to_global_array, state, ) with report_progress.timed("checkpoint"): checkpoints.save_checkpoint_multiprocess( workdir, replicated_state, step ) ================================================ FILE: examples/wmt/train_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import pathlib import tempfile import sys from absl import logging from absl.testing import absltest import jax import tensorflow as tf import tensorflow_datasets as tfds from configs import default import train jax.config.update('jax_disable_most_optimizations', True) class TrainTest(absltest.TestCase): """Test cases for WMT library.""" def setUp(self): if sys.version_info >= (3, 13): self.skipTest('Test (and tensorflow-text) does not suport Python 3.13+') super().setUp() tf.config.experimental.set_visible_devices([], 'GPU') def test_train_and_evaluate(self): config = default.get_config() config.max_corpus_chars = 1000 config.vocab_size = 32 config.per_device_batch_size = 1 config.num_train_steps = 1 config.num_eval_steps = 1 config.num_predict_steps = 1 config.num_layers = 1 config.qkv_dim = 128 config.emb_dim = 128 config.mlp_dim = 512 config.num_heads = 2 config.max_target_length = 32 config.max_eval_target_length = 32 config.max_predict_length = 32 workdir = tempfile.mkdtemp() # Go two directories up to the root of the flax directory. flax_root_dir = pathlib.Path(__file__).parents[2] data_dir = str(flax_root_dir) + '/.tfds/metadata' # pylint: disable=unused-variable with tfds.testing.mock_data(num_examples=128, data_dir=data_dir): train.train_and_evaluate(config, workdir) logging.info('workdir content: %s', tf.io.gfile.listdir(workdir)) if __name__ == '__main__': absltest.main() ================================================ FILE: flax/__init__.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Flax API.""" # pylint: disable=g-import-not-at-top # pyformat: disable from flax import configurations config: configurations.Config = configurations.config del configurations from flax import core from flax import jax_utils from flax import linen from flax import serialization from flax import traverse_util from flax import version __version__: str = version.__version__ del version # DO NOT REMOVE - Marker for internal deprecated API. # DO NOT REMOVE - Marker for internal logging. # pyformat: enable # pylint: enable=g-import-not-at-top ================================================ FILE: flax/configurations.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Global configuration flags for Flax.""" import os from contextlib import contextmanager from typing import Any, Generic, NoReturn, TypeVar, overload _T = TypeVar('_T') class Config: flax_use_flaxlib: bool flax_array_ref: bool flax_pytree_module: bool flax_max_repr_depth: int | None flax_always_shard_variable: bool flax_hijax_variable: bool nnx_graph_mode: bool nnx_graph_updates: bool # See https://google.github.io/pytype/faq.html. _HAS_DYNAMIC_ATTRIBUTES = True def __init__(self): self._values = {} def _add_option(self, name, default): if name in self._values: raise RuntimeError(f'Config option {name} already defined') self._values[name] = default def _read(self, name): try: return self._values[name] except KeyError: raise LookupError(f'Unrecognized config option: {name}') @overload def update(self, name: str, value: Any, /) -> None: ... @overload def update(self, holder: 'FlagHolder[_T]', value: _T, /) -> None: ... def update(self, name_or_holder, value, /): """Modify the value of a given flag. Args: name_or_holder: the name of the flag to modify or the corresponding flag holder object. value: new value to set. """ name = name_or_holder if isinstance(name_or_holder, FlagHolder): name = name_or_holder.name if name not in self._values: raise LookupError(f'Unrecognized config option: {name}') self._values[name] = value def __repr__(self): values_repr = ', '.join(f'\n {k}={v!r}' for k, v in self._values.items()) return f'Config({values_repr}\n)' @contextmanager def temp_flip_flag(self, var_name: str, var_value: bool): """Context manager to temporarily flip feature flags for test functions. Args: var_name: the config variable name (without the 'flax_' prefix) var_value: the boolean value to set var_name to temporarily """ old_value = getattr(self, f'flax_{var_name}') try: self.update(f'flax_{var_name}', var_value) yield finally: self.update(f'flax_{var_name}', old_value) config = Config() # Config parsing utils class FlagHolder(Generic[_T]): def __init__(self, name, help): self.name = name self.__name__ = name[4:] if name.startswith('flax_') else name self.__doc__ = f'Flag holder for `{name}`.\n\n{help}' def __bool__(self) -> NoReturn: raise TypeError( "bool() not supported for instances of type '{0}' " "(did you mean to use '{0}.value' instead?)".format(type(self).__name__) ) @property def value(self) -> _T: return config._read(self.name) def bool_flag(name: str, *, default: bool, help: str) -> FlagHolder[bool]: """Set up a boolean flag. Example:: enable_foo = bool_flag( name='flax_enable_foo', default=False, help='Enable foo.', ) Now the ``FLAX_ENABLE_FOO`` shell environment variable can be used to control the process-level value of the flag, in addition to using e.g. ``config.update("flax_enable_foo", True)`` directly. Args: name: converted to lowercase to define the name of the flag. It is converted to uppercase to define the corresponding shell environment variable. default: a default value for the flag. help: used to populate the docstring of the returned flag holder object. Returns: A flag holder object for accessing the value of the flag. """ name = name.lower() config._add_option(name, static_bool_env(name.upper(), default)) fh = FlagHolder[bool](name, help) setattr(Config, name, property(lambda _: fh.value, doc=help)) return fh def int_flag(name: str, *, default: int | None, help: str) -> FlagHolder[int]: """Set up an integer flag. Example:: num_foo = int_flag( name='flax_num_foo', default=42, help='Number of foo.', ) Now the ``FLAX_NUM_FOO`` shell environment variable can be used to control the process-level value of the flag, in addition to using e.g. ``config.update("flax_num_foo", 42)`` directly. Args: name: converted to lowercase to define the name of the flag. It is converted to uppercase to define the corresponding shell environment variable. default: a default value for the flag. help: used to populate the docstring of the returned flag holder object. Returns: A flag holder object for accessing the value of the flag. """ name = name.lower() config._add_option(name, static_int_env(name.upper(), default)) fh = FlagHolder[int](name, help) setattr(Config, name, property(lambda _: fh.value, doc=help)) return fh def static_bool_env(varname: str, default: bool) -> bool: """Read an environment variable and interpret it as a boolean. This is deprecated. Please use bool_flag() unless your flag will be used in a static method and does not require runtime updates. True values are (case insensitive): 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'. Args: varname: the name of the variable default: the default boolean value Returns: boolean return value derived from defaults and environment. Raises: ValueError if the environment variable is anything else. """ val = os.getenv(varname, str(default)) val = val.lower() if val in ('y', 'yes', 't', 'true', 'on', '1'): return True elif val in ('n', 'no', 'f', 'false', 'off', '0'): return False else: raise ValueError( f'invalid truth value {val!r} for environment {varname!r}' ) def static_int_env(varname: str, default: int | None) -> int | None: """Read an environment variable and interpret it as an integer. Args: varname: the name of the variable default: the default integer value Returns: integer return value derived from defaults and environment. Raises: ValueError if the environment variable is not an integer. """ val = os.getenv(varname) if val is None: return default try: return int(val) except ValueError: raise ValueError( f'invalid integer value {val!r} for environment {varname!r}' ) from None # Flax Global Configuration Variables: flax_filter_frames = bool_flag( name='flax_filter_frames', default=True, help='Whether to hide flax-internal stack frames from tracebacks.', ) flax_profile = bool_flag( name='flax_profile', default=True, help='Whether to run Module methods under jax.named_scope for profiles.', ) flax_use_orbax_checkpointing = bool_flag( name='flax_use_orbax_checkpointing', default=True, help='Whether to use Orbax to save checkpoints.', ) flax_preserve_adopted_names = bool_flag( name='flax_preserve_adopted_names', default=False, help="When adopting outside modules, don't clobber existing names.", ) # TODO(marcuschiam): remove this feature flag once regular dict migration is complete flax_return_frozendict = bool_flag( name='flax_return_frozendict', default=False, help='Whether to return FrozenDicts when calling init or apply.', ) flax_fix_rng = bool_flag( name='flax_fix_rng_separator', default=False, help=( 'Whether to add separator characters when folding in static data into' ' PRNG keys.' ), ) flax_use_flaxlib = bool_flag( name='flax_use_flaxlib', default=False, help='Whether to use flaxlib for C++ acceleration.', ) flax_array_ref = bool_flag( name='flax_array_ref', default=False, help='Whether to use array refs.', ) flax_pytree_module = bool_flag( name='flax_pytree_module', default=True, help='Whether Modules are pytrees by default or not.', ) flax_max_repr_depth = int_flag( name='flax_max_repr_depth', default=None, help='Maximum depth of reprs for nested flax objects. Default is None (no limit).', ) flax_always_shard_variable = bool_flag( name='flax_always_shard_variable', default=True, help='Whether a `nnx.Variable` should always automatically be sharded if it contains sharding annotations.', ) flax_hijax_variable = bool_flag( name='flax_hijax_variable', default=False, help='Whether to enable HiJAX support for `nnx.Variable`.', ) nnx_graph_mode = bool_flag( name='nnx_graph_mode', default=True, help='Whether NNX APIs default to graph-mode (True) or tree-mode (False).', ) nnx_graph_updates = bool_flag( name='nnx_graph_updates', default=True, help='Whether graph-mode uses dynamic (True) or simple (False) graph traversal.', ) ================================================ FILE: flax/core/__init__.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from .axes_scan import broadcast as broadcast from .frozen_dict import ( FrozenDict as FrozenDict, copy as copy, freeze as freeze, pop as pop, pretty_repr as pretty_repr, unfreeze as unfreeze, ) from .lift import ( custom_vjp as custom_vjp, jit as jit, jvp as jvp, remat_scan as remat_scan, remat as remat, scan as scan, vjp as vjp, vmap as vmap, while_loop as while_loop, ) from .meta import ( AxisMetadata as AxisMetadata, map_axis_meta as map_axis_meta, unbox as unbox, ) from .scope import ( DenyList as DenyList, Scope as Scope, apply as apply, bind as bind, init as init, lazy_init as lazy_init, ) from .tracers import ( check_trace_level as check_trace_level, current_trace as current_trace, ) from flax.typing import ( Array as Array, ) ================================================ FILE: flax/core/axes_scan.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Wrapper around jax.lax.scan with in_axes/out_axes API.""" from collections.abc import Callable import functools from typing import Any, Optional import jax from jax import core from jax import lax from jax.extend import linear_util as lu from jax.interpreters import partial_eval as pe import jax.numpy as jnp import numpy as np ScanAxis = Optional[int] class _Broadcast: pass broadcast = _Broadcast() def build_shaped_array(x, batch_dim: bool = False) -> core.ShapedArray: """Builds ShapedArray preserving as much information from x as possible.""" shape = jnp.shape(x) sharding = x.aval.sharding if hasattr(x, "aval") else None if batch_dim: shape = shape[1:] if sharding is not None: if sharding.spec[0] is not None: raise ValueError( "Batch dimension in scan `xs` cannot be sharded." ) sharding = sharding.update( spec=jax.sharding.PartitionSpec(*sharding.spec[1:])) return core.ShapedArray( shape=shape, dtype=jnp.result_type(x), sharding=sharding, **{k: getattr(x, k) for k in ["weak_type", "manual_type"] if hasattr(x, k)}, ) def scan( fn: Callable[..., Any], in_axes: Any, out_axes: Any, length: int | None = None, reverse: bool = False, unroll: int = 1, _split_transpose: bool = False, check_constancy_invariants: bool = True, ): """A wrapper around `jax.lax.scan` with in_axes/out_axes api. Example:: def body_fn(b, c, x): return b + 2, c + 1, 2 * x loop = scan(body_fn, in_axes=0, out_axes=0) broadcast_in = 1 carry = 2 xs = jnp.arange(3) broadcast_out, carry, ys = loop(broadcast_in, carry, xs) print(broadcast_out) # prints: 3 print(carry) # prints: 5 print(ys) # prints: [0, 2, 4] Args: fn: the body function of the scan loop of the form `(broadcast_in, carry, *args) -> (broadcast_out, carry, scan_out)`. the broadcast argument allows for loop independent inputs/outputs to be computed inside `fn`. `fn` will be called once to compute `broadcast_out`. The actual loop will receive `broadcast_out` as the new `broadcast_in`. This is useful for initializing values inside the loop. in_axes: specifies the axis along which arguments are scanned. Use `broadcast` to use the same value across iterations. out_axes: specifies the axis along which outputs are concatenated. Use `broadcast` if a return value should not be concatenated and is independent of the loop body. length: number of iterations. Only needs to be specified if there is no scan axis from which it can be derived. reverse: scan in reverse order from end to start. unroll: how many scan iterations to unroll within a single iteration of a loop (default: 1). _split_transpose: An experimental feature to split the transpose of scan into a scan and a map, backed by an experimental Jax lax.scan() feature. check_constancy_invariants: If true, the scan will verify that the broadcast constants are true loop invariants, and further supports broadcast function (non-carry) outputs. This requires an extra jax tracing step however, so setting to false can reduce trace time on larger models. Returns: the function that performs the scan of the form: (broadcast_in, carry_in, *args) -> (broadcast_out, carry_out, scan_out). """ def transpose_to_front(ax, xs): if ax is broadcast: return () if ax == 0: return xs def trans(x): perm = tuple(range(x.ndim)) perm = (ax,) + tuple(np.delete(perm, ax)) return jnp.transpose(x, perm) return jax.tree_util.tree_map(trans, xs) def transpose_from_front(ax, xs): if ax is broadcast: return () if ax == 0: return xs def trans(x): if ax < 0: pax = x.ndim + ax else: pax = ax assert pax < x.ndim perm = tuple(range(1, pax + 1)) + (0,) + tuple(range(pax + 1, x.ndim)) return jnp.transpose(x, perm) return jax.tree_util.tree_map(trans, xs) def scan_fn(broadcast_in, init, *args): # Requires one extra tracing operation to test invariants: # Verifies that broadcast constants are true loop invariants, and further # supports broadcast function (non-carry) outputs. xs = jax.tree_util.tree_map(transpose_to_front, in_axes, args) def body_fn(c, xs, init_mode=False): # inject constants xs = jax.tree_util.tree_map( lambda ax, arg, x: (arg if ax is broadcast else x), in_axes, args, xs ) broadcast_out, c, ys = fn(broadcast_in, c, *xs) if init_mode: ys = jax.tree_util.tree_map( lambda ax, y: (y if ax is broadcast else ()), out_axes, ys ) return broadcast_out, ys else: ys = jax.tree_util.tree_map( lambda ax, y: (() if ax is broadcast else y), out_axes, ys ) return c, ys broadcast_body = functools.partial(body_fn, init_mode=True) init_flat, carry_tree = jax.tree.flatten(init) xs_flat, scan_tree = jax.tree.flatten(xs) carry_avals = [build_shaped_array(x) for x in init_flat] scan_avals = [build_shaped_array(x, batch_dim=True) for x in xs_flat] in_avals = [*carry_avals, *scan_avals] in_tree = jax.tree_util.treedef_tuple((carry_tree, scan_tree)) assert all(isinstance(a, core.AbstractValue) for a in in_avals), in_avals debug_info = jax.api_util.debug_info("flax scan", broadcast_body, (in_tree,), {}) f_flat, out_tree = jax.api_util.flatten_fun_nokwargs( lu.wrap_init(broadcast_body, debug_info=debug_info), in_tree ) in_pvals = list(map(pe.PartialVal.unknown, in_avals)) _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals) out_flat = [] for pv, const in out_pvals: if pv is not None: raise ValueError( 'broadcasted variable has a data dependency on the scan body.' ) out_flat.append(const) broadcast_in, constants_out = jax.tree_util.tree_unflatten( out_tree(), out_flat ) if jax.version.__version_info__ > (0, 4, 25): c, ys = lax.scan( body_fn, init, xs, length=length, reverse=reverse, unroll=unroll, _split_transpose=_split_transpose ) else: c, ys = lax.scan( body_fn, init, xs, length=length, reverse=reverse, unroll=unroll ) ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys) ys = jax.tree_util.tree_map( lambda ax, const, y: (const if ax is broadcast else y), out_axes, constants_out, ys, ) return broadcast_in, c, ys def simple_scan_fn(broadcast_in, init, *args): # Saves an extra tracing operation. # No verification of constancy, and no support for non-carry broadcast # function outputs. xs = jax.tree_util.tree_map(transpose_to_front, in_axes, args) if broadcast in jax.tree_util.tree_leaves(out_axes): raise ValueError(f"nn.scan run with check_constancy_invariants=False " f"does not support broadcast non-carry function " f"outputs. out_axes was given as {out_axes}") def body_fn(c, xs): # inject constants xs = jax.tree_util.tree_map( lambda ax, arg, x: (arg if ax is broadcast else x), in_axes, args, xs ) _, c, ys = fn(broadcast_in, c, *xs) return c, ys if jax.version.__version_info__ > (0, 4, 25): c, ys = lax.scan( body_fn, init, xs, length=length, reverse=reverse, unroll=unroll, _split_transpose=_split_transpose ) else: c, ys = lax.scan( body_fn, init, xs, length=length, reverse=reverse, unroll=unroll ) ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys) return broadcast_in, c, ys if check_constancy_invariants: return scan_fn else: return simple_scan_fn ================================================ FILE: flax/core/flax_functional_engine.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ "import functools\n", "import jax\n", "from jax import numpy as jnp, random, lax\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ "from flax import linen as nn, struct" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ "from flax.core import Scope, init, apply, Array, lift, unfreeze" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab_type": "code", "outputId": "2558605e-e485-407e-b062-74d31cc49f1e", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "FrozenDict({'params': {'kernel': DeviceArray([[ 0.15374057, -0.6807397 , -1.3350962 ],\n", " [ 0.59940743, -0.69430196, -0.7663768 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}})\n" ] }, { "data": { "text/plain": [ "(DeviceArray([[0.17045607]], dtype=float32),\n", " FrozenDict({'params': {'hidden': {'bias': DeviceArray([0., 0., 0.], dtype=float32), 'kernel': DeviceArray([[-0.22119394, 0.22075175, -0.0925657 ],\n", " [ 0.40571952, 0.27750877, 1.0542233 ]], dtype=float32)}, 'out': {'kernel': DeviceArray([[ 0.21448377],\n", " [-0.01530595],\n", " [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)}}}))" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def dense(\n", " scope: Scope,\n", " inputs: Array,\n", " features: int,\n", " bias: bool = True,\n", " kernel_init=nn.linear.default_kernel_init,\n", " bias_init=nn.initializers.zeros_init(),\n", "):\n", " kernel = scope.param('kernel', kernel_init, (inputs.shape[-1], features))\n", " y = jnp.dot(inputs, kernel)\n", " if bias:\n", " y += scope.param('bias', bias_init, (features,))\n", " return y\n", "\n", "\n", "model_fn = functools.partial(dense, features=3)\n", "\n", "x = jnp.ones((1, 2))\n", "y, params = init(model_fn)(random.key(0), x)\n", "print(params)\n", "\n", "\n", "def mlp(scope: Scope, inputs: Array, features: int):\n", " hidden = scope.child(dense, 'hidden')(inputs, features)\n", " hidden = nn.relu(hidden)\n", " return dense(scope.push('out'), hidden, 1)\n", "\n", "\n", "init(mlp)(random.key(0), x, features=3)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab_type": "code", "outputId": "5790b763-df4f-47c8-9f4e-53fd1e1eb1fd", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 0.11575121 -0.51936364 -1.113899 ]\n", " [ 0.45569834 -0.5300623 -0.5873911 ]]\n", "[ 0.45569834 -0.5300623 -0.5873911 ]\n", "[[-1.5175114 -0.6617551]]\n" ] } ], "source": [ "@struct.dataclass\n", "class Embedding:\n", " table: np.ndarray\n", "\n", " def lookup(self, indices):\n", " return self.table[indices]\n", "\n", " def attend(self, query):\n", " return jnp.dot(query, self.table.T)\n", "\n", "\n", "# all the embedding module does is provide a convenient initializers\n", "\n", "\n", "def embedding(\n", " scope: Scope,\n", " num_embeddings: int,\n", " features: int,\n", " init_fn=nn.linear.default_embed_init,\n", ") -> Embedding:\n", " table = scope.param('table', init_fn, (num_embeddings, features))\n", " return Embedding(table)\n", "\n", "\n", "embedding, _ = init(embedding)(random.key(0), num_embeddings=2, features=3)\n", "print(embedding.table)\n", "print(embedding.lookup(1))\n", "print(\n", " embedding.attend(\n", " jnp.ones((\n", " 1,\n", " 3,\n", " ))\n", " )\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab_type": "code", "outputId": "dd9c5079-10e7-4944-e09a-e9f65573a733" }, "outputs": [ { "data": { "text/plain": [ "((((1, 3), (1, 3)), (1, 3)),\n", " FrozenDict({'params': {'hf': {'bias': (3,), 'kernel': (3, 3)}, 'hg': {'bias': (3,), 'kernel': (3, 3)}, 'hi': {'bias': (3,), 'kernel': (3, 3)}, 'ho': {'bias': (3,), 'kernel': (3, 3)}, 'if': {'kernel': (2, 3)}, 'ig': {'kernel': (2, 3)}, 'ii': {'kernel': (2, 3)}, 'io': {'kernel': (2, 3)}}}))" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def lstm(\n", " scope,\n", " carry,\n", " inputs,\n", " gate_fn=nn.activation.sigmoid,\n", " activation_fn=nn.activation.tanh,\n", " kernel_init=nn.linear.default_kernel_init,\n", " recurrent_kernel_init=nn.initializers.orthogonal(),\n", " bias_init=nn.initializers.zeros_init(),\n", "):\n", " r\"\"\"A long short-term memory (LSTM) cell.\n", "\n", " the mathematical definition of the cell is as follows\n", " .. math::\n", " \\begin{array}{ll}\n", " i = \\sigma(W_{ii} x + W_{hi} h + b_{hi}) \\\\\n", " f = \\sigma(W_{if} x + W_{hf} h + b_{hf}) \\\\\n", " g = \\tanh(W_{ig} x + W_{hg} h + b_{hg}) \\\\\n", " o = \\sigma(W_{io} x + W_{ho} h + b_{ho}) \\\\\n", " c' = f * c + i * g \\\\\n", " h' = o * \\tanh(c') \\\\\n", " \\end{array}\n", " where x is the input, h is the output of the previous time step, and c is\n", " the memory.\n", "\n", " Args:\n", " carry: the hidden state of the LSTM cell,\n", " initialized using `LSTMCell.initialize_carry`.\n", " inputs: an ndarray with the input for the current time step.\n", " All dimensions except the final are considered batch dimensions.\n", " gate_fn: activation function used for gates (default: sigmoid)\n", " activation_fn: activation function used for output and memory update\n", " (default: tanh).\n", " kernel_init: initializer function for the kernels that transform\n", " the input (default: lecun_normal).\n", " recurrent_kernel_init: initializer function for the kernels that transform\n", " the hidden state (default: orthogonal).\n", " bias_init: initializer for the bias parameters (default: zeros_init())\n", " Returns:\n", " A tuple with the new carry and the output.\n", " \"\"\"\n", " c, h = carry\n", " hidden_features = h.shape[-1]\n", " # input and recurrent layers are summed so only one needs a bias.\n", " dense_h = lambda name: scope.child(dense, name)(\n", " h,\n", " features=hidden_features,\n", " bias=True,\n", " kernel_init=recurrent_kernel_init,\n", " bias_init=bias_init,\n", " )\n", " dense_i = lambda name: scope.child(dense, name)(\n", " inputs, features=hidden_features, bias=False, kernel_init=kernel_init\n", " )\n", " i = gate_fn(dense_i(name='ii') + dense_h(name='hi'))\n", " f = gate_fn(dense_i(name='if') + dense_h(name='hf'))\n", " g = activation_fn(dense_i(name='ig') + dense_h(name='hg'))\n", " o = gate_fn(dense_i(name='io') + dense_h(name='ho'))\n", " new_c = f * c + i * g\n", " new_h = o * activation_fn(new_c)\n", " return (new_c, new_h), new_h\n", "\n", "\n", "def lstm_init_carry(batch_dims, size, init_fn=jnp.zeros):\n", " shape = batch_dims + (size,)\n", " return init_fn(shape), init_fn(shape)\n", "\n", "\n", "x = jnp.ones((1, 2))\n", "carry = lstm_init_carry((1,), 3)\n", "y, variables = init(lstm)(random.key(0), carry, x)\n", "jax.tree_util.tree_map(np.shape, (y, variables))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameter shapes:\n", " {'params': {'hf': {'bias': (2,), 'kernel': (2, 2)}, 'hg': {'bias': (2,), 'kernel': (2, 2)}, 'hi': {'bias': (2,), 'kernel': (2, 2)}, 'ho': {'bias': (2,), 'kernel': (2, 2)}, 'if': {'kernel': (2, 2)}, 'ig': {'kernel': (2, 2)}, 'ii': {'kernel': (2, 2)}, 'io': {'kernel': (2, 2)}}}\n" ] } ], "source": [ "def simple_scan(scope: Scope, xs):\n", " init_carry = lstm_init_carry(xs.shape[:1], xs.shape[-1])\n", " # cell = scope.child(lstm, 'cell')\n", " # ys = []\n", " # for i in range(xs.shape[1]):\n", " # x = xs[:, i]\n", " # init_carry, y = cell(init_carry, x)\n", " # ys.append(y)\n", " # return init_carry, ys\n", " lstm_scan = lift.scan(\n", " lstm,\n", " in_axes=1,\n", " out_axes=1,\n", " variable_broadcast='params',\n", " split_rngs={'params': False},\n", " )\n", " return lstm_scan(scope, init_carry, xs)\n", "\n", "\n", "key1, key2 = random.split(random.key(0), 2)\n", "xs = random.uniform(key1, (1, 5, 2))\n", "\n", "\n", "y, init_variables = init(simple_scan)(key2, xs)\n", "\n", "print(\n", " 'initialized parameter shapes:\\n',\n", " jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)),\n", ")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "output:\n", " (DeviceArray([[-0.35626447, 0.25178757]], dtype=float32), DeviceArray([[-0.17885922, 0.13063088]], dtype=float32))\n" ] } ], "source": [ "y = apply(simple_scan)(init_variables, xs)[0]\n", "print('output:\\n', y)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": {}, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: flax/core/frozen_dict.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Frozen Dictionary.""" import collections from types import MappingProxyType from typing import Any, TypeVar from collections.abc import Hashable, Mapping import jax from flax import serialization class FrozenKeysView(collections.abc.KeysView): """A wrapper for a more useful repr of the keys in a frozen dict.""" def __repr__(self): return f'frozen_dict_keys({list(self)})' class FrozenValuesView(collections.abc.ValuesView): """A wrapper for a more useful repr of the values in a frozen dict.""" def __repr__(self): return f'frozen_dict_values({list(self)})' K = TypeVar('K') V = TypeVar('V') def _indent(x, num_spaces): indent_str = ' ' * num_spaces lines = x.split('\n') assert not lines[-1] # skip the final line because it's empty and should not be indented. return '\n'.join(indent_str + line for line in lines[:-1]) + '\n' @jax.tree_util.register_pytree_with_keys_class class FrozenDict(Mapping[K, V]): """An immutable variant of the Python dict.""" __slots__ = ('_dict', '_hash') def __init__(self, *args, __unsafe_skip_copy__=False, **kwargs): # pylint: disable=invalid-name # make sure the dict is as xs = dict(*args, **kwargs) if __unsafe_skip_copy__: self._dict = xs else: self._dict = _prepare_freeze(xs) self._hash = None def __getitem__(self, key): v = self._dict[key] if isinstance(v, dict): return FrozenDict(v) return v def __setitem__(self, key, value): raise ValueError('FrozenDict is immutable.') def __contains__(self, key): return key in self._dict def __iter__(self): return iter(self._dict) def __len__(self): return len(self._dict) def __repr__(self): return self.pretty_repr() def __reduce__(self): return FrozenDict, (self.unfreeze(),) def get(self, key, default=None): """Get an item from the FrozenDict.""" if key in self._dict: return self[key] return default def pretty_repr(self, num_spaces=4): """Returns an indented representation of the nested dictionary.""" def pretty_dict(x): if not isinstance(x, dict): return repr(x) rep = '' for key, val in x.items(): rep += f'{key}: {pretty_dict(val)},\n' if rep: return '{\n' + _indent(rep, num_spaces) + '}' else: return '{}' return f'FrozenDict({pretty_dict(self._dict)})' def __hash__(self): if self._hash is None: h = 0 for key, value in self.items(): h ^= hash((key, value)) self._hash = h return self._hash def copy( self, add_or_replace: Mapping[K, V] = MappingProxyType({}) ) -> 'FrozenDict[K, V]': """Create a new FrozenDict with additional or replaced entries.""" return type(self)({**self, **unfreeze(add_or_replace)}) # type: ignore[arg-type] def keys(self): return FrozenKeysView(self) def values(self): return FrozenValuesView(self) def items(self): for key in self._dict: yield (key, self[key]) def pop(self, key: K) -> tuple['FrozenDict[K, V]', V]: """Create a new FrozenDict where one entry is removed. Example:: >>> from flax.core import FrozenDict >>> variables = FrozenDict({'params': {...}, 'batch_stats': {...}}) >>> new_variables, params = variables.pop('params') Args: key: the key to remove from the dict Returns: A pair with the new FrozenDict and the removed value. """ value = self[key] new_dict = dict(self._dict) new_dict.pop(key) new_self = type(self)(new_dict) return new_self, value def unfreeze(self) -> dict[K, V]: """Unfreeze this FrozenDict. Returns: An unfrozen version of this FrozenDict instance. """ return unfreeze(self) def tree_flatten_with_keys(self) -> tuple[tuple[Any, ...], Hashable]: """Flattens this FrozenDict. Returns: A flattened version of this FrozenDict instance. """ sorted_keys = sorted(self._dict) return tuple( [(jax.tree_util.DictKey(k), self._dict[k]) for k in sorted_keys] ), tuple(sorted_keys) @classmethod def tree_unflatten(cls, keys, values): # data is already deep copied due to tree map mechanism # we can skip the deep copy in the constructor return cls({k: v for k, v in zip(keys, values)}, __unsafe_skip_copy__=True) def _prepare_freeze(xs: Any) -> Any: """Deep copy unfrozen dicts to make the dictionary FrozenDict safe.""" if isinstance(xs, FrozenDict): # we can safely ref share the internal state of a FrozenDict # because it is immutable. return xs._dict # pylint: disable=protected-access if not isinstance(xs, dict): # return a leaf as is. return xs # recursively copy dictionary to avoid ref sharing return {key: _prepare_freeze(val) for key, val in xs.items()} def freeze(xs: Mapping[Any, Any]) -> FrozenDict[Any, Any]: """Freeze a nested dict. Makes a nested ``dict`` immutable by transforming it into ``FrozenDict``. Args: xs: Dictionary to freeze (a regualr Python dict). Returns: The frozen dictionary. """ return FrozenDict(xs) def unfreeze(x: FrozenDict | dict[str, Any]) -> dict[Any, Any]: """Unfreeze a FrozenDict. Makes a mutable copy of a ``FrozenDict`` mutable by transforming it into (nested) dict. Args: x: Frozen dictionary to unfreeze. Returns: The unfrozen dictionary (a regular Python dict). """ if isinstance(x, FrozenDict): # deep copy internal state of a FrozenDict # the dict branch would also work here but # it is much less performant because jax.tree_util.tree_map # uses an optimized C implementation. return jax.tree_util.tree_map(lambda y: y, x._dict) # type: ignore elif isinstance(x, dict): ys = {} for key, value in x.items(): ys[key] = unfreeze(value) return ys else: return x def copy( x: FrozenDict | dict[str, Any], add_or_replace: FrozenDict[str, Any] | dict[str, Any] = FrozenDict({}), ) -> FrozenDict | dict[str, Any]: """Create a new dict with additional and/or replaced entries. This is a utility function that can act on either a FrozenDict or regular dict and mimics the behavior of ``FrozenDict.copy``. Example:: >>> from flax.core import FrozenDict, copy >>> variables = FrozenDict({'params': {...}, 'batch_stats': {...}}) >>> new_variables = copy(variables, {'additional_entries': 1}) Args: x: the dictionary to be copied and updated add_or_replace: dictionary of key-value pairs to add or replace in the dict x Returns: A new dict with the additional and/or replaced entries. """ if isinstance(x, FrozenDict): return x.copy(add_or_replace) elif isinstance(x, dict): new_dict = jax.tree_util.tree_map( lambda x: x, x ) # make a deep copy of dict x new_dict.update(add_or_replace) return new_dict raise TypeError(f'Expected FrozenDict or dict, got {type(x)}') def pop( x: FrozenDict | dict[str, Any], key: str ) -> tuple[FrozenDict | dict[str, Any], Any]: """Create a new dict where one entry is removed. This is a utility function that can act on either a FrozenDict or regular dict and mimics the behavior of ``FrozenDict.pop``. Example:: >>> from flax.core import FrozenDict, pop >>> variables = FrozenDict({'params': {...}, 'batch_stats': {...}}) >>> new_variables, params = pop(variables, 'params') Args: x: the dictionary to remove the entry from key: the key to remove from the dict Returns: A pair with the new dict and the removed value. """ if isinstance(x, FrozenDict): return x.pop(key) elif isinstance(x, dict): new_dict = jax.tree_util.tree_map( lambda x: x, x ) # make a deep copy of dict x value = new_dict.pop(key) return new_dict, value raise TypeError(f'Expected FrozenDict or dict, got {type(x)}') def pretty_repr(x: Any, num_spaces: int = 4) -> str: """Returns an indented representation of the nested dictionary. This is a utility function that can act on either a FrozenDict or regular dict and mimics the behavior of ``FrozenDict.pretty_repr``. If x is any other dtype, this function will return ``repr(x)``. Args: x: the dictionary to be represented num_spaces: the number of space characters in each indentation level Returns: An indented string representation of the nested dictionary. """ if isinstance(x, FrozenDict): return x.pretty_repr() else: def pretty_dict(x): if not isinstance(x, dict): return repr(x) rep = '' for key, val in x.items(): rep += f'{key}: {pretty_dict(val)},\n' if rep: return '{\n' + _indent(rep, num_spaces) + '}' else: return '{}' return pretty_dict(x) def _frozen_dict_state_dict(xs): str_keys = {str(k) for k in xs.keys()} if len(str_keys) != len(xs): raise ValueError( 'Dict keys do not have a unique string representation: ' f'{str_keys} vs given: {xs}' ) return {str(key): serialization.to_state_dict(value) for key, value in xs.items()} def _restore_frozen_dict(xs, states): diff = set(map(str, xs.keys())).difference(map(str, states.keys())) if diff: raise ValueError( 'The target dict keys and state dict keys do not match, target dict' f' contains keys {diff} which are not present in state dict at path' f' {serialization.current_path()}' ) return FrozenDict( { key: serialization.from_state_dict(value, states[str(key)], name=key) for key, value in xs.items() } ) serialization.register_serialization_state( FrozenDict, _frozen_dict_state_dict, _restore_frozen_dict ) ================================================ FILE: flax/core/lift.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Jax transform lifting.""" import collections from collections.abc import Callable, Iterable, Mapping, Sequence import contextlib import dataclasses import functools import threading from typing import Any, Generic, TypeVar import warnings from flax import traceback_util from flax import traverse_util from flax.typing import ( In, InOutAxis, InOutScanAxis, Out, ) import jax from jax import random from . import axes_scan, meta from .frozen_dict import freeze, unfreeze from .scope import ( CollectionFilter, DenyList, # pylint: disable=g-multiple-import Filter, LazyRng, PRNGSequenceFilter, Scope, group_collections, in_filter, intersect_filters, is_filter_empty, subtract_filters, union_filters, ) traceback_util.register_exclusion(__file__) A = TypeVar('A') @dataclasses.dataclass class TransformContext(Generic[A], threading.local): """Context for a transform.""" stack: list[A] = dataclasses.field(default_factory=list) @contextlib.contextmanager def push(self, a: A): self.stack.append(a) try: yield finally: self.stack.pop() def get(self) -> A: return self.stack[-1] def tree_map_rngs(fn, tree): """Needed for mapping JAX random.* functions over PRNGKey leaves.""" return jax.tree_util.tree_map( fn, tree, is_leaf=lambda x: hasattr(x, 'dtype') and jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key), ) def _dedup_scopes(scopes): """Deduplicated scopes.""" paths = [] # must preseve insertion order for duplication to work correctly minimal_set = collections.OrderedDict((s, ()) for s in scopes) for leaf in scopes: scope = leaf.parent max_parent = leaf max_parent_path = () path = [leaf.name] while scope is not None: if scope in minimal_set: max_parent = scope max_parent_path = tuple(reversed(path)) path.append(scope.name) scope = scope.parent if max_parent is not leaf and leaf in minimal_set: del minimal_set[leaf] paths.append((max_parent, max_parent_path)) return tuple(minimal_set), tuple(paths) def _dup_scopes(orig_scopes, scopes, paths): """Duplicated scopes.""" mapping = dict(zip(orig_scopes, scopes)) scopes = [] for root, path in paths: scope = mapping[root] for name in path: scope = scope.push(name, reuse=True) scopes.append(scope) return scopes def _transpose(xs): return tuple(zip(*xs)) def _partial_pack( scope_tree: Scope, in_variable_filters: Sequence[CollectionFilter], out_variable_filters: Sequence[CollectionFilter], rng_filters: Sequence[PRNGSequenceFilter], name=None, ) -> tuple[Callable[..., Any], Callable[..., Any], Any, Any, Callable[..., Any]]: """Pack variables and rngs for functional transformations. The _partial_pack function is the building block for all other lifted transformations. Args: fn: The function to pack. `fn` has the signature in_variable_filters: Input variable filters. out_variable_filters: Output variable filters. rng_filters: RNG filters. name: The name of the packed scope. enable_kwargs: Whether to enable kwargs or not. Returns: `(scope_fn, repack_fn, variable_groups, rng_groups, publish_results_fn)` """ # pylint: disable=protected-access scopes, treedef = jax.tree_util.tree_flatten(scope_tree) scopes, paths = _dedup_scopes(scopes) variable_groups_xs = [] for scope in scopes: scope._validate_trace_level() scope._populate_collections() variable_groups_xs.append( group_collections(scope._variables, in_variable_filters) ) variable_groups_xs_t = _transpose(variable_groups_xs) # Make sure that in-only variable collections are frozen for variable_group_xs in variable_groups_xs_t: for variable_group in variable_group_xs: for col_name, collection in variable_group.items(): col_in_out = any( in_filter(col_filter, col_name) for col_filter in out_variable_filters ) if not col_in_out: variable_group[col_name] = freeze(collection) rng_groups_xs = [] inner_rng_counters = [] for scope in scopes: rng_counters = scope.rng_counters rng_groups = group_collections(scope.rngs, rng_filters) rng_groups_xs.append(rng_groups) inner_rng_counters.append(rng_counters) rng_groups_xs_t = _transpose(rng_groups_xs) def scope_fn( variable_groups_xs_t, rng_groups_xs_t, mutable_filter: CollectionFilter = True, ): inner_scopes = [] mutable: Filter = False for out_filter in out_variable_filters: mutable = union_filters(mutable, out_filter) # could be () in the edge case where no rngs or variable_groups are lifted # in this case fallback to ((),) * len(scopes) to make sure the zip has # something to iterate over for each scope. variable_groups_xs = _transpose(variable_groups_xs_t) or ((),) * len( scopes ) rng_groups_xs = _transpose(rng_groups_xs_t) or ((),) * len(scopes) assert len(variable_groups_xs) == len(scopes) assert len(rng_groups_xs) == len(scopes) for variable_groups, rng_groups, scope, rng_counters in zip( variable_groups_xs, rng_groups_xs, scopes, inner_rng_counters ): variables = {} rngs = {} for variable_group in variable_groups: variables.update(variable_group) for rng_group in rng_groups: rngs.update(rng_group) # make sure variable dicts are cloned and can't be manipulated by ref # sharing. variables = jax.tree_util.tree_map(lambda x: x, variables) scope_mutable = intersect_filters( intersect_filters(scope.mutable, mutable), mutable_filter ) new_debug_path = scope.debug_path if name: if new_debug_path: new_debug_path = new_debug_path[:-1] + ( f'{name}({new_debug_path[-1]})', ) else: new_debug_path = (f'{name}()',) inner_scope = Scope( variables, name=scope.name, rngs=rngs, mutable=scope_mutable, parent=None, path=scope.path, debug_path=new_debug_path, flags=scope.flags, ) inner_scope.rng_counters = rng_counters inner_scopes.append(inner_scope) inner_scopes = _dup_scopes(scopes, inner_scopes, paths) return treedef.unflatten(inner_scopes) def repack_fn(inner_scope_tree): inner_scopes = treedef.flatten_up_to(inner_scope_tree) inner_scopes, inner_paths = _dedup_scopes(inner_scopes) inner_scopes = list(inner_scopes) assert [p for _, p in paths] == [p for _, p in inner_paths] out_variable_groups_xs = [] for inner_scope in inner_scopes: inner_scope.invalidate() inner_scope._validate_trace_level() mutable_variables = { key: val for key, val in inner_scope._variables.items() if in_filter(inner_scope.mutable, key) } out_variable_groups = group_collections( mutable_variables, tuple(out_variable_filters) + (True,) ) remainder = tuple(out_variable_groups[-1].keys()) if remainder: raise ValueError(f'unmapped output variables: {remainder}') out_variable_groups_xs.append(out_variable_groups[:-1]) return _transpose(out_variable_groups_xs) def publish_results_fn(out_variable_groups_xs_t): out_variable_groups_xs = _transpose(out_variable_groups_xs_t) for scope, out_variable_groups, rng_counters in zip( scopes, out_variable_groups_xs, inner_rng_counters ): for out_variable_group in out_variable_groups: for col_name, collection in out_variable_group.items(): if not scope.is_mutable_collection(col_name): # Some lifted transforms like scan return redundant variables. continue for var_name, value in collection.items(): scope.put_variable(col_name, var_name, value) return ( scope_fn, repack_fn, variable_groups_xs_t, rng_groups_xs_t, publish_results_fn, ) def pack( fn: Callable[..., Any], in_variable_filters: Sequence[CollectionFilter], out_variable_filters: Sequence[CollectionFilter], rng_filters: Sequence[PRNGSequenceFilter], name=None, enable_kwargs=False, ) -> Callable[..., Any]: """Pack variables and rngs for functional transformations. The pack function is the building block for all other lifted transformations. Args: fn: The function to pack. `fn` has the signature `(scope_fn, repack_fn, variable_groups, rng_groups, *args) -> (output, packed_variables)`. in_variable_filters: Input variable filters. out_variable_filters: Output variable filters. rng_filters: RNG filters. name: The name of the packed scope. enable_kwargs: Whether to enable kwargs or not. Returns: A callable which expects a scope as the first argument. """ @functools.wraps(fn) def wrapper(scope_tree: Scope, *args, **kwargs): if not enable_kwargs and kwargs: msg = 'kwargs are not supported in {}, so "{}" is(are) ignored' warnings.warn(msg.format(name, ', '.join(kwargs.keys())), RuntimeWarning) ( scope_fn, repack_fn, variable_groups_xs_t, rng_groups_xs_t, publish_results_fn, ) = _partial_pack(scope_tree, in_variable_filters, out_variable_filters, rng_filters, name) if enable_kwargs: y, out_variable_groups_xs_t = fn( scope_fn, repack_fn, variable_groups_xs_t, rng_groups_xs_t, *args, **kwargs, ) else: y, out_variable_groups_xs_t = fn( scope_fn, repack_fn, variable_groups_xs_t, rng_groups_xs_t, *args ) publish_results_fn(out_variable_groups_xs_t) return y return wrapper id_fn = lambda x: x def map_variables( fn: Callable[..., Any], mapped_collections: CollectionFilter, map_in_fn: Callable[..., Any] = id_fn, map_out_fn: Callable[..., Any] = id_fn, init: bool = False, mutable: bool = False, rngs: PRNGSequenceFilter = True, variables: CollectionFilter = True, ) -> Callable[..., Any]: """Map Variables inside a scope. Args: fn: the function to be transformed. mapped_collections: the collection(s) to be transformed. map_in_fn: creates a view of the target variables. map_out_fn: transforms the updated variables in the view after mutation. init: If True, variables are initialized before transformation. mutable: If True, the mapped variable collections will be mutable. rngs: PRNGSequences added to the transformed scope (default: all). variables: Additional Variable collections added to the transformed scope. Besides those specified by `target` (default: all). Returns: A callable expecting a scope as the first argument. """ is_target_out = mutable or init def wrapper(scope_fn, repack, variable_groups, rng_groups, *args, **kwargs): target, variables = variable_groups if init: scopes = scope_fn((target, variables), rng_groups) has_mutable_cols = any( not is_filter_empty(scope.mutable) for scope in jax.tree_util.tree_leaves(scopes) ) if has_mutable_cols: fn(scopes, *args, **kwargs) target, _ = repack(scopes) target = tuple(map_out_fn(x) for x in target) target = tuple(map_in_fn(unfreeze(x)) for x in target) mfilter = True if not is_target_out: # mapped collections should not be mutable # unless the mapping supports it (by init=True or mutable=True) mfilter = subtract_filters(mfilter, mapped_collections) scopes = scope_fn((target, variables), rng_groups, mutable_filter=mfilter) y = fn(scopes, *args, **kwargs) out_target, out_vars = repack(scopes) if is_target_out: out_target = tuple(map_out_fn(x) for x in out_target) return y, (out_target, out_vars) in_vars = (mapped_collections, variables) out_vars = ( in_vars if is_target_out else (False, subtract_filters(variables, mapped_collections)) ) return pack( wrapper, in_vars, out_vars, (rngs,), enable_kwargs=True, name='map_variables', ) def swap_collection(fn: Callable[..., Any], col_a: str, col_b: str): """Swap two collections.""" def swap(target): a = target[col_a] if col_a in target else {} b = target[col_b] if col_b in target else {} target[col_b], target[col_a] = a, b return target return map_variables(fn, (col_a, col_b), swap, swap, mutable=True) def _split_in_out_axes(xs: Mapping[CollectionFilter, Any]): unpack = lambda v: v.axis if isinstance(v, (In, Out)) else v in_axes = {k: unpack(v) for k, v in xs.items() if not isinstance(v, Out)} out_axes = {k: unpack(v) for k, v in xs.items() if not isinstance(v, In)} return in_axes, out_axes def _bwd_wrapper(treedef, bwd_fn, tangent): vars_grad, *inputs_grad = bwd_fn(tangent) vars_grad = treedef.unflatten(vars_grad) return (vars_grad, *inputs_grad) def vjp( fn: Callable[..., Any], scope: Scope, *primals, has_aux: bool = False, reduce_axes=(), vjp_variables: CollectionFilter = 'params', variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, ) -> tuple[Any, Callable[..., Any]] | tuple[Any, Callable[..., Any], Any]: """A lifted version of ``jax.vjp``. See ``jax.vjp`` for the unlifted vector-Jacobian product (backward gradient). Note that a gradient is returned for all variables in the collections specified by `vjp_variables`. However, the backward function only expects a cotangent for the return value of `fn`. If variables require a co-tangent as well they can be returned from `fn` using `scope.variables()`. Example:: def learn_scale(scope, x, y): p = scope.param('scale', nn.initializers.zeros_init(), ()) return p * x * y def f(scope, x, y): z, bwd = lift.vjp(learn_scale, scope, x, y) params_grad, x_grad, y_grad = bwd(jnp.ones(z.shape)) return z, params_grad, x_grad, y_grad Args: fn: Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars. It will receive the scope and primals as arguments. scope: The scope of which the variables will be differentiated. *primals: A sequence of primal values at which the Jacobian of ``fn`` should be evaluated. The length of ``primals`` should be equal to the number of positional parameters to ``fn``. Each primal value should be a tuple of arrays, scalar, or standard Python containers thereof. has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default ``False``. vjp_variables: The vjpfun will return a cotangent vector for all variable collections specified by this filter. variables: other variables collections that are available inside `fn` but do not receive a cotangent. rngs: the prngs that are available inside `fn`. Returns: If ``has_aux`` is ``False``, returns a ``(primals_out, vjpfun)`` pair, where ``primals_out`` is ``fn(*primals)``. ``vjpfun`` is a function from a cotangent vector with the same shape as ``primals_out`` to a tuple of cotangent vectors with the same shape as ``primals``, representing the vector-Jacobian product of ``fn`` evaluated at ``primals``. If ``has_aux`` is ``True``, returns a ``(primals_out, vjpfun, aux)`` tuple where ``aux`` is the auxiliary data returned by ``fn``. """ if reduce_axes: raise NotImplementedError('reduce_axes argument to vjp is deprecated') del reduce_axes def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args): vjp_vars, other_vars = variable_groups @functools.wraps(fn) def wrapper(vjp_vars, *args): variable_groups = (vjp_vars, other_vars) scope = scope_fn(variable_groups, rng_groups) if has_aux: y, aux = fn(scope, *args) else: y = fn(scope, *args) aux = () return y, (aux, repack_fn(scope)) y, bwd, (aux, out_vars) = jax.vjp( wrapper, vjp_vars, *args, has_aux=True ) treedef = jax.tree_util.tree_structure(scope) bwd = jax.tree_util.Partial(functools.partial(_bwd_wrapper, treedef), bwd) if has_aux: return (y, bwd, aux), out_vars else: return (y, bwd), out_vars return pack( inner, (vjp_variables, variables), (variables,), (rngs,), name='vjp', enable_kwargs=False, )(scope, *primals) def value_and_grad( fn: Callable[..., Any], scope: Scope, *primals, has_aux: bool = False, reduce_axes=(), variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, ) -> tuple[Any, Callable[..., Any]] | tuple[Any, Callable[..., Any], Any]: """A limited lifted version of ``jax.value_and_grad``. See ``jax.value_and_grad`` for the unlifted reverse mode gradient. Note that for this convenience function, gradients are only calculated for the function inputs (all function inputs), and not with respect to any scope variables. The target function must return a scalar-valued output. Example:: def learn_scale(scope, x, y): p = scope.param('scale', nn.initializers.zeros_init(), ()) return p * x * y def f(scope, x, y): z, x_grad, y_grad = lift.value_and_grad(learn_scale, scope, x, y) return z, x_grad, y_grad Args: fn: Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars. It will receive the scope and primals as arguments. scope: The scope of which the variables will be differentiated. *primals: A sequence of primal values at which the Jacobian of ``fn`` should be evaluated. The length of ``primals`` should be equal to the number of positional parameters to ``fn``. Each primal value should be a tuple of arrays, scalar, or standard Python containers thereof. has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default ``False``. variables: other variables collections that are available inside `fn` but do not receive a cotangent. rngs: the prngs that are available inside `fn`. Returns: If ``has_aux`` is ``False``, returns a ``(primals_out, grads)`` pair, where ``primals_out`` is ``fn(*primals)``. If ``has_aux`` is ``True``, returns a ``(primals_out, aux, grads)`` tuple where ``aux`` is the auxiliary data returned by ``fn``. """ if reduce_axes: raise NotImplementedError( 'reduce_axes argument to value_and_grad is deprecated') del reduce_axes def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args): @functools.wraps(fn) def wrapper(*args): scope = scope_fn(variable_groups, rng_groups) if has_aux: y, aux = fn(scope, *args) else: y = fn(scope, *args) aux = () return y, (aux, repack_fn(scope)) y, bwd, (aux, out_vars) = jax.vjp( wrapper, *args, has_aux=True, ) inputs_grad = bwd(jax.numpy.ones_like(y)) if has_aux: return (y, aux, inputs_grad), out_vars else: return (y, inputs_grad), out_vars return pack( inner, (variables,), (variables,), (rngs,), name='value_and_grad', enable_kwargs=False, )(scope, *primals) def jvp( fn: Callable[..., Any], scope: Scope, primals, tangents, variable_tangents, variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, ) -> tuple[Any, Any]: """A lifted version of ``jax.jvp``. See ``jax.jvp`` for the unlifted Jacobian-vector product (forward gradient). Note that no tangents are returned for variables. When variable tangents are required their value should be returned explicitly by `fn` using `scope.variables()`. Example:: def learn_scale(scope, x): p = scope.param('scale', nn.initializers.zeros_init(), ()) return p * x def f(scope, x): vars_t = jax.tree_util.tree_map(jnp.ones_like, scope.variables().get('params', {})) x, out_t = lift.jvp( learn_scale, scope, (x,), (jnp.zeros_like(x),), variable_tangents={'params': vars_t}) return out_t Args: fn: The function to be transformed. scope: The scope(s) which should be lifted into the transform. primals: The primal values at which the Jacobian of ``fun`` should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters of ``fun``. tangents: The tangent vector for which the Jacobian-vector product should be evaluated. Should be either a tuple or a list of tangents, with the same tree structure and array shapes as ``primals``. variable_tangents: A dict or PyTree fo dicts with the same structure as scopes. Each entry in the dict specifies the tangents for a variable collection. Not specifying a collection in variable_tangents is equivalent to passing a zero vector as the tangent. variables: other variables collections that are available inside `fn` but do not receive a tangent. rngs: the prngs that are available inside `fn`. Returns: A ``(primals_out, tangents_out)`` pair, where ``primals_out`` is ``fun(*primals)``, and ``tangents_out`` is the Jacobian-vector product of ``function`` evaluated at ``primals`` with ``tangents``. The ``tangents_out`` value has the same Python tree structure and shapes as ``primals_out``. """ def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args): jvp_vars, other_vars = variable_groups @functools.wraps(fn) def wrapper(vars_primals, args): variable_groups = (vars_primals, other_vars) scope = scope_fn(variable_groups, rng_groups) y = fn(scope, *args) return y, repack_fn(scope) (y, out_vars), out_tangents = jax.jvp( wrapper, (jvp_vars, args), (variable_tangents, tangents) ) return (y, out_tangents[0]), out_vars # filter out empty tangent collections because JAX will error on non-equal # tree structure for example: {"params": {}} != {}. treedef = jax.tree_util.tree_structure(scope) variable_tangents = tuple( {k: v for k, v in vt.items() if v} # pylint: disable=g-complex-comprehension for vt in treedef.flatten_up_to(variable_tangents) ) target = tuple(variable_tangents[0].keys()) return pack( inner, (target, variables), (variables,), (rngs,), name='jvp', enable_kwargs=False, )(scope, *primals) def vmap( fn: Callable[..., Any], variable_axes: Mapping[CollectionFilter, InOutAxis], split_rngs: Mapping[PRNGSequenceFilter, bool], in_axes=0, out_axes=0, axis_size: int | None = None, axis_name: str | None = None, spmd_axis_name: str | None = None, metadata_params: dict[Any, Any] = {}, ) -> Callable[..., Any]: """A lifted version of ``jax.vmap``. See ``jax.vmap`` for the unlifted batch transform in Jax. ``vmap`` can be used to add a batch axis to a scope function. For example we could create a version of ``dense`` with a batch axis that does not share parameters:: batch_dense = lift.vmap( nn.dense, in_axes=(0, None), variable_axes={'params': 0}, split_rngs={'params': True}) By using ``variable_axes={'params': 0}``, we indicate that the parameters themselves are mapped over and therefore not shared along the mapped axis. Consequently, we also split the 'params' RNG, otherwise the parameters would be initialized identically along the mapped axis. Similarly, ``vmap`` could be use to add a batch axis with parameter sharing:: batch_foo = lift.vmap( foo, in_axes=0, out_axes=0, variable_axes={'params': None}, split_rngs={'params': False}) Here we use ``variable_axes={'params': None}`` to indicate the parameter variables are shared along the mapped axis. Consequently, the 'params' RNG must also be shared. Args: fn: the function to be transformed. variable_axes: the variable collections that are lifted into the batching transformation. Use `None` to indicate a broadcasted collection or an integer to map over an axis. split_rngs: Split PRNG sequences will be different for each index of the batch dimension. Unsplit PRNGs will be broadcasted. in_axes: Specifies the mapping of the input arguments (see `jax.vmap). out_axes: Specifies the mapping of the return value (see `jax.vmap). axis_size: Specifies the size of the batch axis. This only needs to be specified if it cannot be derived from the input arguments. axis_name: Specifies a name for the batch axis. Can be used together with parallel reduction primitives (e.g. `jax.lax.pmean`, `jax.lax.ppermute`, etc.). Note, this is only used for pmap and shmap. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives. spmd_axis_name: Axis name added to any pjit sharding constraints appearing in `fn`. See also https://github.com/google/flax/blob/main/flax/linen/partitioning.py. metadata_params: arguments dict passed to AxisMetadata instances in the variable tree. Returns: A vectorized version of the input scope function. """ variable_in_axes, variable_out_axes = _split_in_out_axes(variable_axes) variable_in_groups, variable_in_axes = _unzip2(variable_in_axes.items()) variable_out_groups, variable_out_axes = _unzip2(variable_out_axes.items()) rng_groups, rng_splits = _unzip2(split_rngs.items()) rng_axes = tuple(0 if rng_split else None for rng_split in rng_splits) def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args): # optional user-defined variable transform on the way in new_variable_groups = [] for var_group, axis in zip(variable_groups, variable_in_axes): if axis is not None: new_variable_groups.append( meta.remove_axis(var_group, axis, metadata_params) ) else: new_variable_groups.append(var_group) variable_groups = tuple(new_variable_groups) # split rngs def find_axis_size(axis, x): if axis is not None: leaves = jax.tree_util.tree_leaves(x) if leaves: return leaves[0].shape[axis] return () axis_sizes = jax.tree_util.tree_map( find_axis_size, (variable_in_axes, in_axes), (variable_groups, args), is_leaf=lambda x: x is None ) axis_sizes = set(jax.tree_util.tree_leaves(axis_sizes)) if axis_size is None and len(axis_sizes) == 1: (d_axis_size,) = axis_sizes elif len(axis_sizes) > 1: raise ValueError(f'Inconsistent batch axis sizes: {axis_sizes}') elif axis_size is None: raise ValueError('axis_size should be specified manually.') else: d_axis_size = axis_size def split_fn(rng): # random.clone is only available on Jax versions 0.4.26 or newer. See: # https://jax.readthedocs.io/en/latest/jax.experimental.key_reuse.htmls if hasattr(random, 'clone'): rng = random.clone(rng) rngs = random.split(rng, d_axis_size) if spmd_axis_name is not None: args_flat, _ = jax.tree.flatten(args) axes_flat = _broadcast_prefix_tree(in_axes, args) any_vmapped_axis_sharded = any( jax.typeof(x).sharding.spec[i] == spmd_axis_name for x, i in zip(args_flat, axes_flat) if i is not None ) if any_vmapped_axis_sharded: rngs = jax.sharding.reshard(rngs, jax.P(spmd_axis_name)) return rngs rng_groups = tuple( tree_map_rngs(split_fn, rng_group) if split else rng_group for rng_group, split in zip(rng_groups, rng_splits) ) @functools.partial( jax.vmap, in_axes=(variable_in_axes, rng_axes, in_axes), out_axes=(out_axes, variable_out_axes), axis_name=axis_name, axis_size=axis_size, spmd_axis_name=spmd_axis_name, ) @functools.wraps(fn) def mapped(variable_groups, rng_groups, args): scope = scope_fn(variable_groups, rng_groups) y = fn(scope, *args) return y, repack_fn(scope) # optional user-defined variable transform on the way out y, vars_out = mapped(variable_groups, rng_groups, args) new_vars_out = [] for var_group, axis in zip(vars_out, variable_out_axes): if axis is not None: new_vars_out.append(meta.add_axis(var_group, axis, metadata_params)) else: new_vars_out.append(var_group) vars_out = tuple(new_vars_out) return y, vars_out return pack( inner, variable_in_groups, variable_out_groups, rng_groups, name='vmap' ) def scan( fn: Callable[..., Any], variable_axes: Mapping[CollectionFilter, InOutScanAxis] = {}, variable_broadcast: CollectionFilter = False, variable_carry: CollectionFilter = False, split_rngs: Mapping[PRNGSequenceFilter, bool] = {}, in_axes=0, out_axes=0, length: int | None = None, reverse: bool = False, unroll: int = 1, _split_transpose: bool = False, data_transform: Callable[..., Any] | None = None, metadata_params: dict[Any, Any] = {}, check_constancy_invariants: bool = True, ) -> Callable[..., Any]: """A lifted version of ``jax.lax.scan``. See ``jax.lax.scan`` for the unlifted scan in Jax. To improve consistency with ``vmap``, this version of scan uses ``in_axes`` and ``out_axes`` to determine which arguments are scanned over and along which axis. ``scan`` distinguishes between 3 different types of values inside the loop: 1. **scan**: a value that is iterated over in a loop. All scan values must have the same size in the axis they are scanned over. Scanned outputs will be stacked along the scan axis. 2. **carry**: A carried value is updated at each loop iteration. It must have the same shape and dtype throughout the loop. 3. **broadcast**: a value that is closed over by the loop. When a variable is broadcasted they are typically initialized inside the loop body but independent of the loop variables. The loop body should have the signature ``(scope, body, carry, *xs) -> (carry, ys)``, where ``xs`` and ``ys`` are the scan values that go in and out of the loop. Example:: scope.variable('counter', 'i', jnp.zeros, ()) def body_fn(scope, c, x): counter = scope.variable('counter', 'i', jnp.zeros, ()) counter.value += 1 x = scope.child(nn.dense)(x, 1) return c, x _, ys = lift.scan( body_fn, variable_carry='counter', variable_broadcast='params', split_rngs={'params': False})(scope, (), xs) Args: fn: the function to be transformed. variable_axes: the variable collections that are scanned over. variable_broadcast: Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot b lifted out of the loop. This is typically used to define shared parameters inside the fn. variable_carry: Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes. split_rngs: Split PRNG sequences will be different for each loop iterations. If split is False the PRNGs will be the same across iterations. in_axes: Specifies the axis to scan over for the arguments. Should be a prefix tree of the arguments. Use `flax.core.broadcast` to feed an entire input to each iteration of the scan body. out_axes: Specifies the axis to scan over for the return value. Should be a prefix tree of the return value. length: Specifies the number of loop iterations. This only needs to be specified if it cannot be derived from the scan arguments. reverse: If true, scan from end to start in reverse order. unroll: how many scan iterations to unroll within a single iteration of a loop (default: 1). _split_transpose: An experimental feature to split the transpose of a scan into a scan and a map, backed by an experimental Jax lax.scan() feature. data_transform: optional function to transform raw variable and rng groups, intended for inline SPMD annotations. metadata_params: arguments dict passed to AxisMetadata instances in the variable tree. check_constancy_invariants: If true, the scan will verify that the broadcast constants are true loop invariants, and further supports broadcast function (non-carry) outputs. This requires an extra jax tracing step however, so setting to false can reduce trace time on larger models. Returns: The scan function with the signature ``(scope, carry, *xxs) -> (carry, yys)``, where ``xxs`` and ``yys`` are the scan values that go in and out of the loop. """ variable_in_axes, variable_out_axes = _split_in_out_axes(variable_axes) variable_in_groups, variable_in_axes = _unzip2(variable_in_axes.items()) variable_out_groups, variable_out_axes = _unzip2(variable_out_axes.items()) assert all(isinstance(ax, int) for ax in variable_in_axes) assert all(isinstance(ax, int) for ax in variable_out_axes) rng_groups, rng_splits = _unzip2(split_rngs.items()) rng_axes = tuple( 0 if rng_split else axes_scan.broadcast for rng_split in rng_splits ) def inner(scope_fn, repack_fn, variable_groups, rng_groups, init, *args): def find_length(axis, x): if axis is not axes_scan.broadcast: leaves = jax.tree_util.tree_leaves(x) if leaves: return leaves[0].shape[axis] return () # split rngs lengths = jax.tree_util.tree_map(find_length, in_axes, args) lengths = set(jax.tree_util.tree_leaves(lengths)) if length is None and len(lengths) == 1: (d_length,) = lengths elif len(lengths) > 1: raise ValueError(f'Inconsistent scan lengths: {lengths}') elif length is None: raise ValueError('length should be specified manually.') else: d_length = length # random.clone is only available on Jax versions 0.4.26 or newer # see: https://jax.readthedocs.io/en/latest/jax.experimental.key_reuse.html if hasattr(random, 'clone'): split_fn = lambda rng: random.split(random.clone(rng), d_length) else: split_fn = lambda rng: random.split(rng, d_length) rng_groups = tuple( tree_map_rngs(split_fn, rng_group) if split else rng_group for rng_group, split in zip(rng_groups, rng_splits) ) @functools.partial( axes_scan.scan, in_axes=(variable_in_axes, rng_axes, in_axes), out_axes=(out_axes, variable_out_axes), length=length, reverse=reverse, unroll=unroll, _split_transpose=_split_transpose, check_constancy_invariants=check_constancy_invariants, ) def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups, args): carry_vars, c = carry variable_groups = (broadcast_vars, carry_vars) + scan_variable_groups if data_transform is not None: variable_groups, rng_groups = data_transform( variable_groups, rng_groups ) scope = scope_fn(variable_groups, rng_groups) c, y = fn(scope, c, *args) out_vars = repack_fn(scope) broadcast_vars_out = out_vars[0] carry_vars = out_vars[1] scan_vars = out_vars[2:] # add immutable broadcast vars back to broadcast output # otherwise they won't be fed to the actual scan body for in_group, out_group in zip(broadcast_vars, broadcast_vars_out): for col in in_group: if col not in out_group: out_group[col] = in_group[col] return broadcast_vars_out, (carry_vars, c), (y, scan_vars) broadcast_vars = variable_groups[0] carry_vars = variable_groups[1] scan_vars = variable_groups[2:] new_scan_vars = [] for scan_group, axis in zip(scan_vars, variable_in_axes): new_scan_vars.append(meta.remove_axis(scan_group, axis, metadata_params)) broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned( broadcast_vars, (carry_vars, init), tuple(new_scan_vars), rng_groups, args, ) new_scan_vars = [] for scan_group, axis in zip(scan_vars, variable_out_axes): new_scan_vars.append(meta.add_axis(scan_group, axis, metadata_params)) scan_vars = tuple(new_scan_vars) out_vars = ( broadcast_vars, carry_vars, ) + scan_vars return (c, ys), out_vars return pack( inner, (variable_broadcast, variable_carry) + variable_in_groups, (variable_broadcast, variable_carry) + variable_out_groups, rng_groups, name='scan', ) C = TypeVar('C') def while_loop( cond_fn: Callable[[Scope, C], bool], body_fn: Callable[[Scope, C], C], scope: Scope, init: C, carry_variables: CollectionFilter = False, broadcast_variables: CollectionFilter = True, split_rngs: Mapping[PRNGSequenceFilter, bool] = {}, ) -> C: """Lifted version of jax.lax.while_loop. The lifted scope is passed to `cond_fn` and `body_fn`. Broadcasted variables are immutable. The carry variable are mutable but cannot change shape and dtype. This also means you cannot initialize variables inside the body. Consider calling `body_fn` once manually before calling `while_loop` if variable initialization is required. Example:: def f(scope, x): def cond_fn(scope, c): return scope.get_variable('state', 'acc') < 10 def body_fn(scope, c): acc = scope.variable('state', 'acc') acc += 1 y = scope.child(nn.dense)(c, c.shape[-1]) return y c = x c = body_fn(scope, c) return lift.while_loop(cond_fn, body_fn, scope, (), carry_variables='state') Args: cond_fn: Should return True as long as the loop should continue. body_fn: The body of the while loop. scope: The scope(s) which should be lifted into the loop. init: The initial state passed to the loop carry_variables: collections that are carried through the loop and are therefore mutable (default: none). broadcast_variables: collections that are closed over and are therefore read-only (default: all collections) split_rngs: Split PRNG sequences will be different for each loop iterations. If split is False the PRNGs will be the same across iterations. Returns: The final state after executing the while loop. """ rng_groups, rng_splits = _unzip2(split_rngs.items()) def inner(scope_fn, repack_fn, variable_groups, rng_groups): carry_variables, broadcast_variables = variable_groups def make_loop_rngs(i): local_rng_groups = [] for rng_group, rng_split in zip(rng_groups, rng_splits): if rng_split: rng_group = tree_map_rngs( lambda rng: random.fold_in(rng, i), rng_group ) local_rng_groups.append(rng_group) return local_rng_groups def cond_wrapper(c): i, carry_variables, carry = c scope = scope_fn( (carry_variables, broadcast_variables), make_loop_rngs(-i), mutable_filter=False, ) return cond_fn(scope, carry) def body_wrapper(c): i, carry_variables, carry = c scope = scope_fn( (carry_variables, broadcast_variables), make_loop_rngs(i) ) carry = body_fn(scope, carry) (carry_variables,) = repack_fn(scope) return (i + 1, carry_variables, carry) c = (0, carry_variables, init) _, carry_variables, carry = jax.lax.while_loop( cond_wrapper, body_wrapper, c ) return carry, (carry_variables,) return pack( inner, (carry_variables, broadcast_variables), (carry_variables,), rng_groups, name='while_loop', )(scope) def cond( pred: Any, true_fun: Callable[..., C], false_fun: Callable[..., C], scope: Scope, *operands, variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, ) -> C: """Lifted version of ``jax.lax.cond``. The returned values from ``true_fun`` and ``false_fun`` must have the same Pytree structure, shapes, and dtypes. The variables created or updated inside the branches must also have the same structure. Note that this constraint is violated when creating variables or submodules in only one branch. Because initializing variables in just one branch causes the parameter structure to be different. Example:: def cond_example(scope, x, pred): scope.variable('state', 'true_count', lambda: 0) scope.variable('state', 'false_count', lambda: 0) def true_fn(scope, x): scope.variable('state', 'true_count').value += 1 return scope.child(nn.dense)(x, 2) def false_fn(scope, x): scope.variable('state', 'false_count').value += 1 return -scope.child(nn.dense)(x, 2) return lift.cond(pred, true_fn, false_fn, scope, x) Args: pred: determines if true_fun or false_fun is evaluated. true_fun: The function evalauted when ``pred`` is `True`. The signature is (Scope, *operands) -> T. false_fun: The function evalauted when ``pred`` is `False`. The signature is (Scope, *operands) -> T. scope: A Scope or Pytree of scopes to pass *operands: The arguments passed to ``true_fun`` and ``false_fun`` variables: The variable collections passed to the conditional branches (default: all) rngs: The PRNG sequences passed to the conditionals (default: all) Returns: The result of the evaluated branch (``true_fun`` or ``false_fun``). """ branches = [true_fun, false_fun] def inner(scope_fn, repack_fn, variable_groups, rng_groups): def branch_wrapper(branch_fn, *operands): scope = scope_fn(variable_groups, rng_groups) y = branch_fn(scope, *operands) return y, repack_fn(scope) pure_branches = [ functools.partial(branch_wrapper, branch_fn) for branch_fn in branches ] return jax.lax.cond(pred, pure_branches[0], pure_branches[1], *operands) return pack(inner, (variables,), (variables,), (rngs,), name='cond')(scope) def switch( index: Any, branches: Sequence[Callable[..., C]], scope: Scope, *operands, variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, ) -> C: """Lifted version of ``jax.lax.switch``. The returned values from ``branches`` must have the same Pytree structure, shapes, and dtypes. The variables created or updated inside the branches must also have the same structure. Note that this constraint is violated when creating variables or submodules in only one branch. Because initializing variables in just one branch causes the parameter structure to be different. Example:: def switch_example(scope, x, index): scope.variable('state', 'a_count', lambda: 0) scope.variable('state', 'b_count', lambda: 0) scope.variable('state', 'c_count', lambda: 0) def a_fn(scope, x): scope.variable('state', 'a_count').value += 1 return scope.child(nn.dense)(x, 2) def b_fn(scope, x): scope.variable('state', 'b_count').value += 1 return -scope.child(nn.dense)(x, 2) def c_fn(scope, x): scope.variable('state', 'c_count').value += 1 return scope.child(nn.dense)(x, 2) return lift.switch(index, [a_fn, b_fn, c_fn], scope, x) If you want to have a different parameter structure for each branch you should run all branch on initialization before calling switch:: def multihead_switch_example(scope, x, index): def a_fn(scope, x): x = scope.child(nn.dense)(x, 10) x = scope.child(nn.dense)(x, 7) return scope.child(nn.dense)(x, 5) def b_fn(scope, x): x = scope.child(nn.dense)(x, 11) return scope.child(nn.dense)(x, 5) def c_fn(scope, x): return scope.child(nn.dense)(x, 5) branches = [a_fn, b_fn, c_fn] # run all branches on init if scope.is_mutable_collection('params'): for branch in branches: _ = branch(scope, x) return lift.switch(index, branches, scope, x) Args: index: Integer scalar type, indicating which branch function to apply. branches: Sequence of functions to be applied based on index. The signature of each function is (Scope, *operands) -> T. scope: A Scope or Pytree of scopes to pass *operands: The arguments passed to ``true_fun`` and ``false_fun`` variables: The variable collections passed to the conditional branches (default: all) rngs: The PRNG sequences passed to the conditionals (default: all) Returns: The result of the evaluated branch. """ def inner(scope_fn, repack_fn, variable_groups, rng_groups): def branch_wrapper(branch_fn, *operands): scope = scope_fn(variable_groups, rng_groups) y = branch_fn(scope, *operands) return y, repack_fn(scope) pure_branches = [ functools.partial(branch_wrapper, branch_fn) for branch_fn in branches ] return jax.lax.switch(index, pure_branches, *operands) return pack(inner, (variables,), (variables,), (rngs,), name='switch')(scope) def custom_vjp( fn: Callable[..., Any], forward_fn: Callable[..., Any], backward_fn: Callable[..., Any], grad_vars: CollectionFilter = 'params', nondiff_argnums=(), ): """Lifted version of `jax.custom_vjp`. `forward_fn` and `backward_fn` together define a custom vjp for `fn`. The original `fn` will run in case a vjp (backward gradient) is not computed. The `forward_fn` receives the same arguments as `fn` but is expected to return a tuple containing the output of `fn(scope, *args)` and the residuals that are passed to `backward_fn`. The `backward_fn` receives the nondiff arguments, residuals, and the output tangents. It should return a tuple containing the variable and input tangents. Note that the vjp function returned by `lift.vjp` can be passed as residual and used in the `backward_fn`. The scope is unavailable during the backward pass. If the scope is required in `backward_fn`, a snapshot of the variables can be taken and returned as a residual in the `forward_fn`. Example:: f = nn.dense def fwd(scope, x, features): y, vjp_fn = lift.vjp(partial(f, features=features), scope, x) return y, vjp_fn def bwd(features, vjp_fn, y_t): params_t, *inputs_t = vjp_fn(y_t) params_t = jax.tree_util.tree_map(jnp.sign, params_t) return (params_t, *inputs_t) dense_sign_grad = lift.custom_vjp( f, forward_fn=fwd, backward_fn=bwd, nondiff_argnums=(2,)) Args: fn: The function to define a custom_vjp for. The first argument should be a ``Module`` instance. forward_fn: A function with the same arguments as `fn` returning an tuple with the original output and the residuals that will be passed to `backward_fn`. backward_fn: arguments are passed as (*nondiff_args, residuals, tangents) The function should return a tuple containing the tangents for the variable in the collections specified by `grad_vars` and the input arguments (except the scope and nondiff args). grad_vars: The collections for which a vjp will be computed (default: "params"). nondiff_argnums: arguments for which no vjp is computed. Returns: A function with the same signature as `fn` with the custom vjp. """ def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args): grad_variables, other_variables = variable_groups scopes_treedef = None def f(grad_variables, *args): scope = scope_fn((grad_variables, other_variables), rng_groups) y = fn(scope, *args) vars_out = repack_fn(scope) return y, vars_out f = jax.custom_vjp(f, nondiff_argnums=nondiff_argnums) def f_fwd(grad_variables, *args): nonlocal scopes_treedef scopes = scope_fn((grad_variables, other_variables), rng_groups) scopes_treedef = jax.tree_util.tree_structure(scopes) y, res = forward_fn(scopes, *args) vars_out = repack_fn(scopes) return (y, vars_out), res def f_bwd(*args): # the backward function does not pass a lifted scope to the user. # Currently, there is no way to have side effects flow out of backward # pass. Even without mutation variables would be ill-defined. For example, # would we take a snapshot of the variables before or after calling # `forward_fn`? nondiff_args = args[:-2] res, g = args[-2:] # pylint: disable=unbalanced-tuple-unpacking g_y, _ = g var_t, *inputs_t = backward_fn(*nondiff_args, res, g_y) assert scopes_treedef is not None, 'backward called before forward?!' var_t = tuple(scopes_treedef.flatten_up_to(var_t)) return (var_t, *inputs_t) f.defvjp(f_fwd, f_bwd) return f(grad_variables, *args) variable_in_groups = (grad_vars, True) variable_out_groups = (grad_vars, True) rng_groups = (True,) return pack( inner, variable_in_groups, variable_out_groups, rng_groups, name='custom_vjp', ) def checkpoint( fn: Callable[..., Any], variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, concrete: bool = False, prevent_cse: bool = True, static_argnums: int | tuple[int, ...] = (), policy: Callable[..., bool] | None = None, ) -> Callable[..., Any]: """Lifted version of ``jax.checkpoint``. This function is aliased to ``lift.remat`` just like ``jax.remat``. Args: fn: scope function for which intermediate computations should be re-computed when computing gradients. variables: The variable collections that are lifted. By default all collections are lifted. rngs: The PRNG sequences that are lifted. By default all PRNG sequences are lifted. concrete: Optional, boolean indicating whether ``fun`` may involve value-dependent Python control flow (default ``False``). Support for such control flow is optional, and disabled by default, because in some edge-case compositions with :func:`jax.jit` it can lead to some extra computation. prevent_cse: Optional, boolean indicating whether to prevent common subexpression elimination (CSE) optimizations in the HLO generated from differentiation. This CSE prevention has costs because it can foil other optimizations, and because it can incur high overheads on some backends, especially GPU. The default is True because otherwise, under a ``jit`` or ``pmap``, CSE can defeat the purpose of this decorator. But in some settings, like when used inside a ``scan``, this CSE prevention mechanism is unnecessary, in which case ``prevent_cse`` can be set to False. static_argnums: Optional, int or sequence of ints, indicates which argument values on which to specialize for tracing and caching purposes. Specifying arguments as static can avoid ConcretizationTypeErrors when tracing, but at the cost of more retracing overheads. policy: Experimental checkpoint policy, see ``jax.checkpoint``. Returns: A wrapped version of ``fn``. When computing gradients intermediate computations will be re-computed when computing gradients. """ def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args, **kwargs): # add 2 to each static_argnums because we add two initial arguments to rematted static_argnums_ = jax.tree_util.tree_map(lambda x: x + 2, static_argnums) # After JAX v0.3.16, concrete=False is a no-op and concrete=True raises # NotImplementedError. Starting in JAX v0.8.2, the concrete argument is # deprecated and will be removed in the future. if concrete: raise NotImplementedError( "The concrete argument is deprecated. Use static_argnums instead." " for more information, see" " https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html" ) @functools.partial( jax.remat, static_argnums=static_argnums_, prevent_cse=prevent_cse, policy=policy, ) @functools.wraps(fn) def rematted(variable_groups, rng_groups, *args, **kwargs): scope = scope_fn(variable_groups, rng_groups) y = fn(scope, *args, **kwargs) return y, repack_fn(scope) return rematted(variable_groups, rng_groups, *args, **kwargs) return pack( inner, (variables,), (variables,), (rngs,), name='remat', enable_kwargs=True, ) remat = checkpoint def _hashable_filter(x): """Hashable version of CollectionFilter.""" if isinstance(x, str): return (x,) if isinstance(x, Iterable): return tuple(x) # convert un-hashable list & sets to tuple if isinstance(x, DenyList): return DenyList( _hashable_filter(x.deny) ) # convert inner filter recursively return x class CountsHolder: def __init__(self, flat_d): self.flat_d = flat_d @classmethod def make(cls, d): flat_d = traverse_util.flatten_dict(d) flat_d = {k: v for k, v in flat_d.items()} return cls(flat_d) def sub(self, other): delta_flat_d = {} new_flat_d = collections.defaultdict(int, self.flat_d) old_flat_d = collections.defaultdict(int, other.flat_d) for k in new_flat_d: delta_flat_d[k] = new_flat_d[k] - old_flat_d[k] return CountsHolder(delta_flat_d) def add(self, other): delta_flat_d = {} new_flat_d = collections.defaultdict(int, self.flat_d) old_flat_d = collections.defaultdict(int, other.flat_d) for k in new_flat_d: delta_flat_d[k] = new_flat_d[k] + old_flat_d[k] return CountsHolder(delta_flat_d) def unflat(self): return traverse_util.unflatten_dict(self.flat_d) def set_from_dict(original, updates): for k in updates: if k not in original: original[k] = updates[k] else: if isinstance(updates[k], dict): set_from_dict(original[k], updates[k]) else: original[k] = updates[k] class _SideEffectCache(threading.local): def __init__(self): self.cache = {} _side_effect_cache = _SideEffectCache() def _restore_rng_counters(scopes, fingerprint, capture_old_counts): if fingerprint not in _side_effect_cache.cache: capture_new_counts = jax.tree.map( lambda s: CountsHolder.make(s.rng_counters), scopes ) capture_delta_counts = jax.tree.map( lambda old, new: new.sub(old), capture_old_counts, capture_new_counts, ) _side_effect_cache.cache[fingerprint] = capture_delta_counts else: updated_counts = jax.tree.map( lambda x, y: x.add(y).unflat(), _side_effect_cache.cache[fingerprint], capture_old_counts, ) jax.tree.map( lambda s, u: set_from_dict(s.rng_counters, u), scopes, updated_counts, ) def jit( fn: Callable[..., Any], variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, static_argnums: int | Iterable[int] = (), static_argnames: str | Iterable[str] = (), donate_argnums: int | Iterable[int] = (), device=None, backend: str | None = None, ) -> Callable[..., Any]: """Lifted version of ``jax.jit``. Args: fn: Scope function to be jitted. variables: The variable collections that are lifted. By default all collections are lifted. rngs: The PRNG sequences that are lifted. By default all PRNG sequences are lifted. static_argnums: An int or collection of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), and so the corresponding argument values can be any Python object. Static arguments should be hashable, meaning both ``__hash__`` and ``__eq__`` are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. If the jitted function is called with fewer positional arguments than indicated by ``static_argnums`` then an error is raised. Arguments that are not arrays or containers thereof must be marked as static. Defaults to (). static_argnames: An optional string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on ``static_argnums`` for details. If not provided but ``static_argnums`` is set, the default is based on calling ``inspect.signature(fun)`` to find corresponding named arguments. donate_argnums: Specify which arguments are "donated" to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation, JAX will raise an error if you try to. device: This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via :py:func:`jax.devices`.) The default is inherited from XLA's DeviceAssignment logic and is usually to use ``jax.devices()[0]``. backend: a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or ``'tpu'``. Returns: A wrapped version of ``fn``, set up for just-in-time compilation. """ if not isinstance(static_argnums, Iterable): static_argnums = (static_argnums,) if not isinstance(donate_argnums, Iterable): donate_argnums = (donate_argnums,) # offset argnums by two because first argument in the original function is the # scope while jitted has 3 functions before the user arguments. static_argnums = (0,) + tuple(i + 2 for i in static_argnums if i > 0) donate_argnums = tuple(i + 2 for i in donate_argnums if i > 0) # Close over scope_fn & repack_fn to avoid recompilation # this is impure but we use the fingerprint arg to differentiate between cases # where scope_fn or repack_fn actually produce non-identical results. jit_context = TransformContext[tuple[Callable, Callable]]() @functools.partial( jax.jit, static_argnums=static_argnums, static_argnames=static_argnames, donate_argnums=donate_argnums, device=device, backend=backend, ) @functools.wraps(fn) def jitted(fingerprint, variable_groups, rng_groups, *args, **kwargs): scope_fn, repack_fn = jit_context.get() hash_key = fingerprint[1] # fingerprint is only used to differentiate the cache signature # del fingerprint scope = scope_fn(variable_groups, rng_groups) # pylint: disable=not-callable y = fn(scope, hash_key, *args, **kwargs) return y, repack_fn(scope) # pylint: disable=not-callable def inner( scope_fn, repack_fn, variable_groups, rng_groups, module_hash_key, *args, **kwargs, ): with jit_context.push((scope_fn, repack_fn)): scopes: list[Scope] = jax.tree_util.tree_leaves( scope_fn(variable_groups, rng_groups) ) mutable = tuple(_hashable_filter(scope.mutable) for scope in scopes) rng_groups = jax.tree.map( lambda x: x.clear_suffix() if isinstance(x, LazyRng) else x, rng_groups, is_leaf=lambda x: isinstance(x, LazyRng), ) fingerprint = (mutable, module_hash_key) capture_old_counts = jax.tree.map( lambda s: CountsHolder.make(s.rng_counters), scopes ) res = jitted(fingerprint, variable_groups, rng_groups, *args, **kwargs) _restore_rng_counters(scopes, fingerprint, capture_old_counts) return res return pack( inner, (variables,), (variables,), (rngs,), name='jit', enable_kwargs=True ) def remat_scan( body_fn: Callable[..., Any], lengths: Sequence[int], policy: Callable[..., bool] | None = None, variable_broadcast: CollectionFilter = False, variable_carry: CollectionFilter = False, variable_axes: Mapping[CollectionFilter, InOutScanAxis] = {True: 0}, split_rngs: Mapping[PRNGSequenceFilter, bool] = {True: True}, ) -> Callable[..., Any]: """Combines `lift.remat` and `lift.scan` for memory efficiency and constant time compilation. ``remat_scan`` allows for constant compile times and sublinear memory usage with respect to model depth. At a small constant penalty. This is typically beneficial for very deep models. Example:: def body_fn(scope, x): return nn.dense(scope, x, features=x.shape[-1]) # 100x dense with O(sqrt(N)) memory for gradient computation y = lift.remat_scan(body_fn, lengths=(10, 10))(scope, x) Args: body_fn: Scope function to be repeated using a (nested scan) lengths: number of loop iterations at the given level. The total number of iterations `n = prod(lengths)`. each loop is rematerialized. This way the memory consumption is proportional to `n^(1 / d)` where `d = len(lengths)`. Minimal memory consumptions requires tuning the lengths such that the same amount of memory is consumed at each level of the nested loop. policy: Experimental checkpoint policy, see ``jax.checkpoint``. variable_broadcast: Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside the fn. variable_carry: Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes. variable_axes: the variable collections that are scanned over. split_rngs: Split PRNG sequences will be different for each loop iterations. If split is False the PRNGs will be the same across iterations. Returns: A wrapped version of ``body_fn`` that repeats itself prod(lengths) times. """ # TODO(jheek) should remat scan have scan inputs/outputs? scan_fn = functools.partial( scan, variable_broadcast=variable_broadcast, variable_carry=variable_carry, variable_axes=variable_axes, split_rngs=split_rngs, ) if len(lengths) == 1: def wrapper(scope, carry): return body_fn(scope, carry), () fn = lambda scope, c: scan_fn(wrapper, length=lengths[0])(scope, c)[0] else: @functools.partial(remat, policy=policy, prevent_cse=False) def inner_loop(scope, carry): carry = remat_scan( body_fn, lengths[1:], policy, variable_broadcast, variable_carry, variable_axes, split_rngs, )(scope, carry) return carry, () fn = lambda scope, c: scan_fn(inner_loop, length=lengths[0])(scope, c)[0] return fn def _unzip2(xs): ys = tuple(zip(*xs)) return ys if ys else ((), ()) def _broadcast_prefix_tree(prefix_tree: Any, full_tree: Any) -> list[Any]: bcast_flat = [] num_leaves_fn = lambda t: jax.tree.flatten(t)[1].num_leaves jax.tree.map( lambda x, subtree: bcast_flat.extend([x] * num_leaves_fn(subtree)), prefix_tree, full_tree, is_leaf=lambda x: x is None, ) return bcast_flat def fold_rngs( fn: Callable[..., Any], variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, ) -> Callable[..., Any]: # Close over scope_fn & repack_fn to avoid recompilation # this is impure but we use the fingerprint arg to differentiate between cases # where scope_fn or repack_fn actually produce non-identical results. fold_rngs_context = TransformContext[tuple[Callable, Callable]]() @functools.wraps(fn) def wrapped_fold_rngs(fingerprint, variable_groups, rng_groups, *args, **kwargs): scope_fn, repack_fn = fold_rngs_context.get() hash_key = fingerprint[1] # fingerprint is only used to differentiate the cache signature # del fingerprint scope = scope_fn(variable_groups, rng_groups) # pylint: disable=not-callable y = fn(scope, hash_key, *args, **kwargs) return y, repack_fn(scope) # pylint: disable=not-callable def inner_fold_rngs( scope_fn, repack_fn, variable_groups, rng_groups, module_hash_key, *args, **kwargs, ): with fold_rngs_context.push((scope_fn, repack_fn)): scopes: list[Scope] = jax.tree_util.tree_leaves( scope_fn(variable_groups, rng_groups) ) mutable = tuple(_hashable_filter(scope.mutable) for scope in scopes) rng_groups = jax.tree.map( lambda x: x.clear_suffix() if isinstance(x, LazyRng) else x, rng_groups, is_leaf=lambda x: isinstance(x, LazyRng), ) fingerprint = (mutable, module_hash_key) capture_old_counts = jax.tree.map( lambda s: CountsHolder.make(s.rng_counters), scopes ) res = wrapped_fold_rngs( fingerprint, variable_groups, rng_groups, *args, **kwargs ) _restore_rng_counters(scopes, fingerprint, capture_old_counts) return res return pack( inner_fold_rngs, (variables,), (variables,), (rngs,), name='fold_rngs', enable_kwargs=True, ) ================================================ FILE: flax/core/meta.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Boxed Metadata API Boxed metadata enables tracking arbitrary metadata for linen variables that is compatible with lifted transformations. See ``Partitioned`` for a practical example on how to use this metadata to keep track of how variables should be partitioned with ``jax.pjit``. """ import abc import dataclasses import functools from typing import Any, Generic, TypeVar from collections.abc import Callable from flax import errors, struct from flax.typing import LogicalNames import jax A = TypeVar('A') B = TypeVar('B') TAxisMetadata = TypeVar('TAxisMetadata', bound='AxisMetadata[Any]') class AxisMetadata(Generic[A], metaclass=abc.ABCMeta): """Abstract base class for boxed Metadata. ``AxisMetadata`` enables arbitrary, per axis metadata for variables. By using ``unbox`` the metadata is stripped away to obtain the original variables. By using unboxing, most code handling variables does not need to handle ``AxisMetadata`` specifically, but can directly operate on the JAX arrays that they wrap. Additionally, ``AxisMetadata`` supports updating metadata whenever an axis is added or removed by a functional transformation (e.g.: ``nn.scan`` or ``nn.vmap``) using the ``add_axis`` and ``remove_axis`` methods. By extending ``AxisMetadata``, custom metadata can be stored. See ``Partitioned`` for a specific implementation. """ @abc.abstractmethod def unbox(self) -> A: """Returns the content of the AxisMetadata box. Note that unlike ``meta.unbox`` the unbox call should not recursively unbox metadata. It should simply return value that it wraps directly even if that value itself is an instance of AxisMetadata. In practise, AxisMetadata subclasses should be registered as PyTree nodes to support passing instances to JAX and Flax APIs. The leaves returned for this node should correspond to the value returned by unbox. Returns: The unboxed value. """ pass @abc.abstractmethod def replace_boxed(self, val: B) -> 'AxisMetadata[B]': """Replaces the boxed value with the provided value. Args: val: The new value to be boxed by this AxisMetadata wrapper Returns: A new instance of the same type as self with `val` as the new ``unbox`` content """ pass @abc.abstractmethod def add_axis( self: TAxisMetadata, index: int, params: dict[Any, Any] ) -> TAxisMetadata: """Adds a new axis to the axis metadata. Note that add_axis and remove_axis should act as each other's inverse (meaning: ``x.add_axis(i, p).remove_axis(i, p) == x``) Args: index: The position at which the new axis will be inserted params: An arbitrary dictionary of parameters passed by the transformation that introduces the new axis (e.g.: ``nn.scan`` or ``nn.vmap``). The user passes this dictionary as the `metadata_param` argument to the transformation. Returns: A new instance of the same type as self and with the same ``unbox`` content with updated axis metadata. """ pass @abc.abstractmethod def remove_axis( self: TAxisMetadata, index: int, params: dict[Any, Any] ) -> TAxisMetadata: """Removes an axis from the axis metadata. Note that add_axis and remove_axis should act as each other's inverse (meaning: ``x.remove_axis(i, p).add_axis(i, p) == x``) Args: index: The position of the axis that is to be removed params: An arbitrary dictionary of parameters passed by the transformation that introduced the axis (e.g.: ``nn.scan`` or ``nn.vmap``). The user passes this dictionary as the `metadata_param` argument to the transformation. Returns: A new instance of the same type as self and with the same ``unbox`` content with updated axis metadata. """ pass def is_axis_metadata(val: Any) -> bool: """Returns whether the argument is an instance of AxisMetadata.""" return isinstance(val, AxisMetadata) def map_axis_meta(fn: Callable[[AxisMetadata[Any]], Any], tree: Any) -> Any: """Maps over all PyTree nodes that are AxisMetadata instances.""" def wrapper(x): if isinstance(x, AxisMetadata): return fn(x) else: return x return jax.tree_util.tree_map(wrapper, tree, is_leaf=is_axis_metadata) def add_axis(tree: Any, index: int, params: dict[Any, Any]) -> Any: """Add an axis to each AxisMetadata node in a PyTree.""" return map_axis_meta(lambda x: x.add_axis(index, params), tree) def remove_axis(tree: Any, index: int, params: dict[Any, Any]) -> Any: """Remove an axis from each AxisMetadata node in a PyTree.""" return map_axis_meta(lambda x: x.remove_axis(index, params), tree) def unbox(tree: Any) -> Any: """Strips all AxisMetadata boxes from a PyTree.""" return map_axis_meta(lambda x: unbox(x.unbox()), tree) def replace_boxed(tree: Any, updates: Any) -> Any: """Updates all AxisMetadata boxes with the values in updates.""" def inner_update(c, v): if isinstance(c, AxisMetadata): return c.replace_boxed(replace_boxed(c.unbox(), v)) else: return v return jax.tree_util.tree_map( inner_update, tree, updates, is_leaf=is_axis_metadata ) PARTITION_NAME = 'partition_name' def get_global_mesh() -> jax.sharding.AbstractMesh | jax.sharding.Mesh | None: mesh = jax.sharding.get_abstract_mesh() if mesh.empty: return None return mesh def global_mesh_defined() -> bool: """Checks if global mesh resource environment is defined.""" mesh = get_global_mesh() return mesh is not None class Partitioned(struct.PyTreeNode, AxisMetadata[A]): """Wrapper for partitioning metadata. ``Partitioned`` is used to extend variables with partitioning information required for ``jax.experimental.pjit``. The easiest way to define Partitioned variables is by using the ``with_partitioning`` wrapper around the variable initializer. Example:: class MLP(nn.Module): hidden_size: int @nn.compact def __call__(self, x): ki = nn.linear.default_kernel_init h = nn.Dense( self.hidden_size, kernel_init=nn.with_partitioning(ki, ('data', 'model')))(x) h = nn.relu(h) return nn.Dense( x.shape[-1], kernel_init=nn.with_partitioning(ki, ('model', 'data')))(h) mlp = MLP(4096) x = jnp.ones((8 * 1024, 1024)) # use eval_shape to get the Partitioned instances for the variables. # this way we can determine the PartitionSpecs for the init variables # before we call the init fn. var_spec = nn.get_partition_spec( jax.eval_shape(mlp.init, random.key(0), x)) init_fn = mesh(pjit(mlp.init, (None, PartitionSpec("data", "model")), var_spec)) variables = init_fn(random.key(0), x) apply_fn = mesh(pjit( mlp.apply, (var_spec, PartitionSpec("data", "model")), PartitionSpec("data", "model"))) apply_fn(variables, x) ``Partitioned`` values can gain additional axes when using transformations like ``nn.vmap`` and ``nn.scan``. In this case you can specify the name of the new axis with the `metadata_params` args in vmap/scan:: class Model(nn.Module): @nn.compact def __call__(self, x): def body(mdl, c): c = MLP(4096)(c) return c, () c, _ = nn.scan( body, variable_axes={"params": 0}, split_rngs={"params": 0}, length=8, metadata_params={nn.meta.PARTITION_NAME: "layers"})(self, x) return c """ value: Any names: LogicalNames = struct.field(pytree_node=False) mesh: jax.sharding.Mesh | None = struct.field( default=None, pytree_node=False ) def unbox(self, apply_constraint=True) -> A: """Returns the wrapped value with the partitioning applied as a sharding constraint.""" if apply_constraint and (global_mesh_defined() or self.mesh is not None): axis_resource = self.get_partition_spec() if self.mesh is not None: sharding = jax.sharding.NamedSharding(self.mesh, axis_resource) return jax.lax.with_sharding_constraint(self.value, sharding) return jax.lax.with_sharding_constraint(self.value, axis_resource) else: return self.value def replace_boxed(self, val: B) -> 'Partitioned[B]': return self.replace(value=val) # type: ignore def _get_partition_name(self, params: dict[Any, Any]) -> str: if PARTITION_NAME not in params: raise errors.PartitioningUnspecifiedError(self) return params[PARTITION_NAME] def add_axis(self, index: int, params: dict[Any, Any]) -> 'Partitioned[A]': axis_name = self._get_partition_name(params) names = list(self.names) while len(names) < index: names.append(None) # type: ignore names.insert(index, axis_name) # type: ignore return self.replace(names=tuple(names)) def remove_axis(self, index: int, params: dict[Any, Any]) -> 'Partitioned[A]': axis_name = self._get_partition_name(params) names = list(self.names) assert names.pop(index) == axis_name return self.replace(names=tuple(names)) def get_partition_spec(self) -> jax.sharding.PartitionSpec: """Returns the ``Partitionspec`` for this partitioned value.""" return jax.sharding.PartitionSpec(*self.names) def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding: """Returns the ``NamedSharding`` for this partitioned value.""" return jax.sharding.NamedSharding(mesh, self.get_partition_spec()) def to_nnx_metadata(self) -> dict[str, Any]: """Return a dict of metadata that can translate into an `nnx.Variable`.""" metadata = dict(vars(self)) metadata['out_sharding'] = metadata.pop('names') return metadata @classmethod def from_nnx_metadata(cls, metadata: dict[str, Any]): """Given a dict of `nnx.Variable` format metadata, create a `nn.Partitioned`.""" metadata['names'] = metadata.pop('out_sharding') fields = {x.name for x in dataclasses.fields(cls)} return cls(**{k: v for k, v in metadata.items() if k in fields}) def with_partitioning( fn: Callable[..., Any], names: LogicalNames, mesh: jax.sharding.Mesh | None = None, ) -> Callable[..., Partitioned[Any]]: """Wraps a function's return value with Partitioned. Example:: >>> import flax.linen as nn >>> kernel_init = nn.with_partitioning( ... nn.initializers.lecun_normal(), (None, "data")) >>> partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init) Args: fn: The function to be wrapped. Typically this is an initializer. names: The logical axis passed to ``Partitioned``. mesh: The mesh to use for the partitioning. If None, the global mesh resource is used if available. Returns: A function wrapping ``fn`` that will return an instance of ``Partitioned``. """ @functools.wraps(fn) def wrapper(*args, **kwargs): return Partitioned(fn(*args, **kwargs), names, mesh=mesh) return wrapper def _get_leaf_pspec(x: Any) -> jax.sharding.PartitionSpec | None: if hasattr(x, 'get_partition_spec'): return x.get_partition_spec() # Unboxed arrays, which should be replicated across all devices elif hasattr(x, 'shape'): return jax.sharding.PartitionSpec() else: return None def get_partition_spec(tree: Any) -> Any: """Extracts a PartitionSpec tree from a PyTree containing ``Partitioned`` values.""" return jax.tree_util.tree_map( _get_leaf_pspec, tree, is_leaf=lambda x: isinstance(x, AxisMetadata) ) def get_sharding(tree: Any, mesh: jax.sharding.Mesh) -> Any: """Extracts a jax.sharding tree from a PyTree containing ``Partitioned`` values and a mesh.""" def f(x: Any) -> jax.sharding.Sharding | None: if hasattr(x, 'get_sharding'): return x.get_sharding(mesh) pspec = _get_leaf_pspec(x) if pspec is None: return None return jax.sharding.NamedSharding(mesh, pspec) return jax.tree_util.tree_map( f, tree, is_leaf=lambda x: isinstance(x, AxisMetadata) ) ================================================ FILE: flax/core/nn/__init__.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Flax Neural Network api.""" # pylint: disable=g-multiple-import # re-export commonly used modules and functions from flax.linen import activation as activation from flax.linen import initializers as initializers from flax.linen.activation import ( celu as celu, elu as elu, gelu as gelu, glu as glu, leaky_relu as leaky_relu, log_sigmoid as log_sigmoid, log_softmax as log_softmax, relu as relu, sigmoid as sigmoid, silu as silu, soft_sign as soft_sign, softmax as softmax, softplus as softplus, swish as swish, tanh as tanh, ) from flax.linen.pooling import (avg_pool as avg_pool, max_pool as max_pool) from .attention import ( dot_product_attention as dot_product_attention, multi_head_dot_product_attention as multi_head_dot_product_attention, ) from .linear import ( Embedding as Embedding, conv_transpose as conv_transpose, conv as conv, dense_general as dense_general, dense as dense, embedding as embedding, ) from .normalization import ( batch_norm as batch_norm, group_norm as group_norm, layer_norm as layer_norm, ) from .stochastic import dropout as dropout # pylint: enable=g-multiple-import ================================================ FILE: flax/core/nn/attention.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Attention core modules for Flax.""" import functools from collections.abc import Iterable # pylint: disable=g-importing-member from typing import Any from collections.abc import Callable import jax import jax.numpy as jnp import numpy as np from jax import lax, random from flax import struct from flax.core import Scope from flax.linen import initializers from .linear import default_kernel_init, dense_general def dot_product_attention( scope, query, key, value, dtype=jnp.float32, bias=None, axis=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, precision=None, ): """Computes dot-product attention given query, key, and value. This is the core function for applying attention based on https://arxiv.org/abs/1706.03762. It calculates the attention weights given query and key and combines the values using the attention weights. This function supports multi-dimensional inputs. Args: query: queries for calculating attention with shape of `[batch_size, dim1, dim2, ..., dimN, num_heads, mem_channels]`. key: keys for calculating attention with shape of `[batch_size, dim1, dim2, ..., dimN, num_heads, mem_channels]`. value: values to be used in attention with shape of `[batch_size, dim1, dim2,..., dimN, num_heads, value_channels]`. dtype: the dtype of the computation (default: float32) bias: bias for the attention weights. This can be used for incorporating autoregressive mask, padding mask, proximity bias. axis: axises over which the attention is applied. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout dropout_rate: dropout rate deterministic: bool, deterministic or not (to apply dropout) precision: numerical precision of the computation see `jax.lax.Precision` for details. Returns: Output of shape `[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`. """ assert key.shape[:-1] == value.shape[:-1] assert query.shape[0:1] == key.shape[0:1] and query.shape[-1] == key.shape[-1] if axis is None: axis = tuple(range(1, key.ndim - 2)) if not isinstance(axis, Iterable): axis = (axis,) assert key.ndim == query.ndim assert key.ndim == value.ndim for ax in axis: if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2): raise ValueError( 'Attention axis must be between the batch axis and the last-two axes.' ) depth = query.shape[-1] n = key.ndim # batch_dims is , num_heads> batch_dims = tuple(np.delete(range(n), axis + (n - 1,))) # q & k -> (bs, , num_heads, , channels) qk_perm = batch_dims + axis + (n - 1,) key = key.transpose(qk_perm) query = query.transpose(qk_perm) # v -> (bs, , num_heads, channels, ) v_perm = batch_dims + (n - 1,) + axis value = value.transpose(v_perm) query = query / jnp.sqrt(depth).astype(dtype) batch_dims_t = tuple(range(len(batch_dims))) attn_weights = lax.dot_general( query, key, (((n - 1,), (n - 1,)), (batch_dims_t, batch_dims_t)), precision=precision, ) # apply attention bias: masking, droput, proximity bias, ect. if bias is not None: attn_weights = attn_weights + bias # normalize the attention weights norm_dims = tuple(range(attn_weights.ndim - len(axis), attn_weights.ndim)) attn_weights = lax.exp( attn_weights - jax.scipy.special.logsumexp(attn_weights, axis=norm_dims, keepdims=True) ) attn_weights = attn_weights.astype(dtype) # apply dropout if not deterministic and dropout_rate > 0.0: if dropout_rng is None: dropout_rng = scope.make_rng('dropout') keep_prob = 1.0 - dropout_rate if broadcast_dropout: # dropout is broadcast across the batch+head+non-attention dimension dropout_dims = attn_weights.shape[-(2 * len(axis)) :] dropout_shape = tuple([1] * len(batch_dims_t)) + dropout_dims keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) else: keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) multiplier = keep.astype(attn_weights.dtype) / jnp.asarray( keep_prob, dtype=dtype ) attn_weights = attn_weights * multiplier # compute the new values given the attention weights wv_contracting_dims = (norm_dims, range(value.ndim - len(axis), value.ndim)) y = lax.dot_general( attn_weights, value, (wv_contracting_dims, (batch_dims_t, batch_dims_t)), precision=precision, ) # back to (bs, dim1, dim2, ..., dimN, num_heads, channels) perm_inv = _invert_perm(qk_perm) y = y.transpose(perm_inv) return y def _invert_perm(perm): perm_inv = [0] * len(perm) for i, j in enumerate(perm): perm_inv[j] = i return tuple(perm_inv) class CacheEntry(struct.PyTreeNode): key: np.ndarray value: np.ndarray i: np.ndarray def multi_head_dot_product_attention( scope: Scope, inputs_q, inputs_kv, num_heads, dtype=jnp.float32, qkv_features=None, out_features=None, attention_axis=None, causal_mask=False, padding_mask=None, key_padding_mask=None, segmentation=None, key_segmentation=None, cache=False, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, precision=None, kernel_init=default_kernel_init, bias_init=initializers.zeros_init(), bias=True, attention_fn=dot_product_attention, ): """Applies multi-head dot product attention on the input data. Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention and project the results to an output vector. This can be used for encoder-decoder attention by specifying both `inputs_q` and `inputs_kv` orfor self-attention by only specifying `inputs_q` and setting `inputs_kv` to None. Args: inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`. inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]` or None for self-attention, inn which case key/values will be derived from inputs_q. num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) should be divisible by the number of heads. dtype: the dtype of the computation (default: float32) qkv_features: dimension of the key, query, and value. out_features: dimension of the last projection attention_axis: axes over which the attention is applied ( 'None' means attention over all axes, but batch, heads, and features). causal_mask: boolean specifying whether to apply a causal mask on the attention weights. If True, the output at timestep `t` will not depend on inputs at timesteps strictly greater than `t`. padding_mask: boolean specifying query tokens that are pad token. key_padding_mask: boolean specifying key-value tokens that are pad token. segmentation: segment indices for packed inputs_q data. key_segmentation: segment indices for packed inputs_kv data. cache: an instance of `flax.deprecated.nn.attention.Cache` used for efficient autoregressive decoding. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout dropout_rate: dropout rate deterministic: bool, deterministic or not (to apply dropout) precision: numerical precision of the computation see `jax.lax.Precision` for details. kernel_init: initializer for the kernel of the Dense layers. bias_init: initializer for the bias of the Dense layers. bias: bool: whether pointwise QKVO dense transforms use bias. attention_fn: dot_product_attention or compatible function. Accepts query, key, value, and returns output of shape `[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`` Returns: output of shape `[bs, dim1, dim2, ..., dimN, features]`. """ assert ( causal_mask or not cache ), 'Caching is only support for causal attention.' if inputs_kv is None: inputs_kv = inputs_q if attention_axis is None: attention_axis = tuple(range(1, inputs_q.ndim - 1)) features = out_features or inputs_q.shape[-1] qkv_features = qkv_features or inputs_q.shape[-1] assert ( qkv_features % num_heads == 0 ), 'Memory dimension must be divisible by number of heads.' head_dim = qkv_features // num_heads dense = functools.partial( dense_general, axis=-1, dtype=dtype, features=(num_heads, head_dim), kernel_init=kernel_init, bias_init=bias_init, bias=bias, precision=precision, ) # project inputs_q to multi-headed q/k/v # dimensions are then [bs, dims..., n_heads, n_features_per_head] query = scope.child(dense, 'query')(inputs_q) key = scope.child(dense, 'key')(inputs_kv) value = scope.child(dense, 'value')(inputs_kv) if cache: cache_entry: Callable[[Any], CacheEntry] | CacheEntry if not scope.has_variable('cache', 'entry'): ndim, tail_shape = (key.ndim, key.shape[-2:]) def init_fn(shape, dtype=jnp.float32): full_shape = shape + tail_shape if len(full_shape) != ndim: raise ValueError( 'Shape should be a tuple with the shape of the batch' 'and attention dims.' ) return CacheEntry( key=jnp.zeros(full_shape, dtype), value=jnp.zeros(full_shape, dtype), i=jnp.zeros((), jnp.uint32), ) cache_entry = init_fn else: cache_entry = scope.get_variable('cache', 'entry') if not isinstance(cache_entry, CacheEntry): raise ValueError('Cache is not initialized.') expected_shape = list(cache_entry.key.shape[:-2]) for attn_dim in attention_axis: expected_shape[attn_dim] = 1 expected_shape = tuple(expected_shape) + inputs_q.shape[-1:] if expected_shape != inputs_q.shape: raise ValueError( 'Invalid shape provided, expected shape %s instead got %s.' % (expected_shape, inputs_q.shape) ) cshape = cache_entry.key.shape indices = [0] * len(cshape) i = cache_entry.i attn_size = np.prod(np.take(cshape, attention_axis)) for attn_dim in attention_axis: attn_size //= cshape[attn_dim] indices[attn_dim] = i // attn_size i = i % attn_size key = lax.dynamic_update_slice(cache_entry.key, key, indices) # type: ignore value = lax.dynamic_update_slice(cache_entry.value, value, indices) # type: ignore one = jnp.array(1, jnp.uint32) cache_entry = cache_entry.replace( i=cache_entry.i + one, key=key, value=value ) # TODO(levskaya): verify this is still needed in translation decoding. key_padding_mask = jnp.broadcast_to( (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2] ) key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None] scope.put_variable('cache', 'entry', cache_entry) # create attention masks mask_components = [] if causal_mask: if cache and isinstance(cache_entry, CacheEntry): bias_pre_shape = (1,) * (key.ndim - 1) attn_shape = tuple(np.take(key.shape, attention_axis)) attn_size = np.prod(attn_shape) ii = jnp.arange(attn_size, dtype=jnp.uint32) mask = ii < cache_entry.i mask_components.append(mask.reshape(bias_pre_shape + attn_shape)) else: mask_components.append(_make_causal_mask(key, attention_axis)) if padding_mask is not None: if key_padding_mask is None: key_padding_mask = padding_mask padding_mask = make_padding_mask( padding_mask_query=padding_mask, padding_mask_key=key_padding_mask, query_shape=query.shape, key_shape=key.shape, attention_axis=attention_axis, ) mask_components.append(padding_mask) if segmentation is not None: if key_segmentation is None: key_segmentation = segmentation segmentation_mask = make_padding_mask( padding_mask_query=segmentation, padding_mask_key=key_segmentation, query_shape=query.shape, key_shape=key.shape, attention_axis=attention_axis, segmentation_mask=True, ) mask_components.append(segmentation_mask) if mask_components: attention_mask = mask_components[0] for component in mask_components[1:]: attention_mask = jnp.logical_and(attention_mask, component) # attention mask in the form of attention bias attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(dtype), jnp.full(attention_mask.shape, -1e10).astype(dtype), ) else: attention_bias = None # apply attention x = scope.child(attention_fn)( query, key, value, dtype=dtype, axis=attention_axis, bias=attention_bias, precision=precision, dropout_rng=dropout_rng, dropout_rate=dropout_rate, broadcast_dropout=broadcast_dropout, deterministic=deterministic, ) # back to the original inputs dimensions out = scope.child(dense_general, name='out')( x, features=features, axis=(-2, -1), kernel_init=kernel_init, bias_init=bias_init, bias=bias, dtype=dtype, precision=precision, ) return out # TODO(flax-dev): Consider refactoring MultiHeadDotProductAttention and moving # causal_mask and cache support into this class instead. # SelfAttention = MultiHeadDotProductAttention.partial(inputs_kv=None) def make_padding_mask( padding_mask_query, padding_mask_key, query_shape, key_shape, attention_axis=None, segmentation_mask=False, ): """Makes padding mask for attention weights. In case of 1d inputs (i.e., `[bs, len, features]`, the attention weights will be `[bs, len, len]` and this function makes a square matrix [len, len]. Args: padding_mask_query: padding mask of query padding_mask_key: padding mask of query query_shape: shape of the query key_shape: shape of the key, which is equal to the shape of value. attention_axis: axis over which attention is applied. segmentation_mask: bool: if true use equality on cartesian product rather than outer product for constructing segmentation masks. Returns: The padding mask for attention weights. """ assert query_shape[0] == key_shape[0] assert len(query_shape) == len(key_shape) ndim = len(key_shape) if attention_axis is None: attention_axis = tuple(range(1, ndim - 2)) assert isinstance(attention_axis, tuple) for ax in attention_axis: if not (ndim >= 3 and 1 <= ax < ndim - 2): raise ValueError( 'Attention axis must be between the batch axis and the last-two axes.' ) mask_shape_final = (query_shape[0], 1) # batch_size, 1 (for all heads)s for ax in attention_axis: mask_shape_final += (query_shape[ax],) for ax in attention_axis: mask_shape_final += (key_shape[ax],) padding_mask_query = padding_mask_query[..., None] padding_mask_key = padding_mask_key[..., None] perm = (0,) + tuple(np.flip(np.arange(padding_mask_key.ndim)))[:-1] if segmentation_mask: mask = jnp.equal(padding_mask_query, padding_mask_key.transpose(perm)) else: mask = jnp.multiply(padding_mask_query, padding_mask_key.transpose(perm)) mask = mask.reshape(mask_shape_final) mask = jax.lax.convert_element_type(mask, jnp.float32) return mask def _make_causal_mask(key, attention_axis=None, self_mask=False): """Makes a causal mask, to be used for masking out the future for attention. In case of 1d inputs (i.e., `[bs, len, features]`, the attention weights will be `[bs, len, len]` and this function makes a square matrix [len, len] with zeros in upper triangle and ones in lower triangle. Args: key: shape of the key, which is equal to the shape of value and is assumed to be equal to the shape of the query (since this is used in self-attention when decoding). attention_axis: axis over which attention is applied. self_mask: if mask out the diagonal or not. Returns: A causal mask to be used to mask out future positions. """ if attention_axis is None: attention_axis = tuple(range(1, key.ndim - 2)) assert isinstance(attention_axis, tuple) for ax in attention_axis: if not (key.ndim >= 3 and 1 <= ax < key.ndim - 2): raise ValueError( 'Attention axis must be between the batch axis and the last-two axes.' ) mask_shape = tuple([1] * (key.ndim - len(attention_axis) - 1)) mask_shape_final = mask_shape for _ in range(2): flatten_dim = 1 for ax in attention_axis: mask_shape_final += (key.shape[ax],) flatten_dim *= key.shape[ax] mask_shape += (flatten_dim,) def tri(n, m, k=0): # Tie in the key to avoid the mask becoming a constant. # This way XLA can construct the mask during computation and fuse it # with the attention ops. x = jnp.arange(n, dtype=jnp.int32) y = jnp.arange(m, dtype=jnp.int32) mask = lax.ge( (lax.broadcast_in_dim(x, shape=(n, m), broadcast_dimensions=(0,))) + k, lax.broadcast(y, [n]), ) return mask k = -1 if self_mask else 0 mask = tri(*mask_shape[-2:], k=k).reshape(mask_shape_final) return mask ================================================ FILE: flax/core/nn/linear.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Linear modules.""" from collections.abc import Iterable # pylint: disable=g-importing-member import jax.numpy as jnp import numpy as np from jax import lax from flax import struct from flax.core import Scope from flax.linen import initializers default_kernel_init = initializers.lecun_normal() def _normalize_axes(axes, ndim): # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. return tuple(ax if ax >= 0 else ndim + ax for ax in axes) def dense_general( scope, inputs, features, axis=-1, batch_dims=(), bias=True, dtype=jnp.float32, kernel_init=default_kernel_init, bias_init=initializers.zeros_init(), precision=None, ): """Applies a linear transformation to the inputs along multiple dimensions. Args: inputs: The nd-array to be transformed. features: tuple with numbers of output features. axis: tuple with axes to apply the transformation on. batch_dims: tuple with batch axes. bias: whether to add a bias to the output (default: True). dtype: the dtype of the computation (default: float32). kernel_init: initializer function for the weight matrix. bias_init: initializer function for the bias. precision: numerical precision of the computation see `jax.lax.Precision` for details. Returns: The transformed input. """ inputs = jnp.asarray(inputs, dtype) if not isinstance(features, Iterable): features = (features,) if not isinstance(axis, Iterable): axis = (axis,) if not isinstance(batch_dims, Iterable): batch_dims = (batch_dims,) features, axis, batch_dims = tuple(features), tuple(axis), tuple(batch_dims) if batch_dims: max_dim = np.max(batch_dims) if set(batch_dims) != set(range(max_dim + 1)): raise ValueError( 'batch_dims %s must be consecutive leading ' 'dimensions starting from 0.' % str(batch_dims) ) ndim = inputs.ndim n_batch_dims = len(batch_dims) axis = _normalize_axes(axis, ndim) batch_dims = _normalize_axes(batch_dims, ndim) n_axis, n_features = len(axis), len(features) def kernel_init_wrap(rng, shape, dtype=jnp.float32): size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32) flat_shape = ( np.prod(shape[n_batch_dims : n_axis + n_batch_dims]), np.prod(shape[-n_features:]), ) kernel = jnp.concatenate( [kernel_init(rng, flat_shape, dtype) for _ in range(size_batch_dims)], axis=0, ) return jnp.reshape(kernel, shape) batch_shape = tuple(inputs.shape[ax] for ax in batch_dims) kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features kernel = scope.param('kernel', kernel_init_wrap, batch_shape + kernel_shape) kernel = jnp.asarray(kernel, dtype) batch_ind = tuple(range(n_batch_dims)) contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims)) out = lax.dot_general( inputs, kernel, ((axis, contract_ind), (batch_dims, batch_ind)), precision=precision, ) if bias: def bias_init_wrap(rng, shape, dtype=jnp.float32): size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32) flat_shape = (np.prod(shape[-n_features:]),) bias = jnp.concatenate( [bias_init(rng, flat_shape, dtype) for _ in range(size_batch_dims)], axis=0, ) return jnp.reshape(bias, shape) bias = scope.param('bias', bias_init_wrap, batch_shape + features) # Reshape bias for broadcast. expand_dims = sorted(set(range(inputs.ndim)) - set(axis) - set(batch_dims)) for ax in expand_dims: bias = jnp.expand_dims(bias, ax) bias = jnp.asarray(bias, dtype) out = out + bias return out def dense( scope, inputs, features, bias=True, dtype=jnp.float32, precision=None, kernel_init=default_kernel_init, bias_init=initializers.zeros_init(), ): """Applies a linear transformation to the inputs along the last dimension. Args: inputs: The nd-array to be transformed. features: the number of output features. bias: whether to add a bias to the output (default: True). dtype: the dtype of the computation (default: float32). precision: numerical precision of the computation see `jax.lax.Precision` for details. kernel_init: initializer function for the weight matrix. bias_init: initializer function for the bias. Returns: The transformed input. """ inputs = jnp.asarray(inputs, dtype) kernel = scope.param('kernel', kernel_init, (inputs.shape[-1], features)) kernel = jnp.asarray(kernel, dtype) y = lax.dot_general( inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=precision, ) if bias: bias = scope.param('bias', bias_init, (features,)) bias = jnp.asarray(bias, dtype) y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) return y def _conv_dimension_numbers(input_shape): """Computes the dimension numbers based on the input shape.""" ndim = len(input_shape) lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1)) rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2)) out_spec = lhs_spec return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec) def conv( scope, inputs, features, kernel_size, strides=None, padding='SAME', input_dilation=None, kernel_dilation=None, feature_group_count=1, bias=True, dtype=jnp.float32, precision=None, kernel_init=default_kernel_init, bias_init=initializers.zeros_init(), ): """Applies a convolution to the inputs. Args: inputs: input data with dimensions (batch, spatial_dims..., features). features: number of convolution filters. kernel_size: shape of the convolutional kernel. strides: a sequence of `n` integers, representing the inter-window strides. padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of `n` `(low, high)` integer pairs that give the padding to apply before and after each spatial dimension. input_dilation: `None`, or a sequence of `n` integers, giving the dilation factor to apply in each spatial dimension of `inputs`. Convolution with input dilation `d` is equivalent to transposed convolution with stride `d`. kernel_dilation: `None`, or a sequence of `n` integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Convolution with kernel dilation is also known as 'atrous convolution'. feature_group_count: integer, default 1. If specified divides the input features into groups. bias: whether to add a bias to the output (default: True). dtype: the dtype of the computation (default: float32). precision: numerical precision of the computation see `jax.lax.Precision` for details. kernel_init: initializer for the convolutional kernel. bias_init: initializer for the bias. Returns: The convolved data. """ inputs = jnp.asarray(inputs, dtype) if strides is None: strides = (1,) * (inputs.ndim - 2) in_features = inputs.shape[-1] assert in_features % feature_group_count == 0 kernel_shape = kernel_size + (in_features // feature_group_count, features) kernel = scope.param('kernel', kernel_init, kernel_shape) kernel = jnp.asarray(kernel, dtype) dimension_numbers = _conv_dimension_numbers(inputs.shape) y = lax.conv_general_dilated( inputs, kernel, strides, padding, lhs_dilation=input_dilation, rhs_dilation=kernel_dilation, dimension_numbers=dimension_numbers, feature_group_count=feature_group_count, precision=precision, ) if bias: bias = scope.param('bias', bias_init, (features,)) bias = jnp.asarray(bias, dtype) y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) return y def conv_transpose( scope, inputs, features, kernel_size, strides=None, padding='SAME', kernel_dilation=None, bias=True, dtype=jnp.float32, precision=None, kernel_init=default_kernel_init, bias_init=initializers.zeros_init(), ): """Applies a transposed convolution to the inputs. Behaviour mirrors that of `jax.lax.conv_transpose`. Args: scope: functional scope. inputs: input data with dimensions (batch, spatial_dims..., features). features: number of convolution filters. kernel_size: shape of the convolutional kernel. strides: a sequence of `n` integers, representing the inter-window strides. padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of `n` `(low, high)` integer pairs that give the padding to apply before and after each spatial dimension. kernel_dilation: `None`, or a sequence of `n` integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Convolution with kernel dilation is also known as 'atrous convolution'. bias: whether to add a bias to the output (default: True). dtype: the dtype of the computation (default: float32). precision: numerical precision of the computation see `jax.lax.Precision` for details. kernel_init: initializer for the convolutional kernel. bias_init: initializer for the bias. Returns: The convolved data. """ inputs = jnp.asarray(inputs, dtype) strides = strides or (1,) * (inputs.ndim - 2) in_features = inputs.shape[-1] kernel_shape = kernel_size + (in_features, features) kernel = scope.param('kernel', kernel_init, kernel_shape) kernel = jnp.asarray(kernel, dtype) y = lax.conv_transpose( inputs, kernel, strides, padding, rhs_dilation=kernel_dilation, precision=precision, ) if bias: bias = scope.param('bias', bias_init, (features,)) bias = jnp.asarray(bias, dtype) y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) return y default_embed_init = initializers.variance_scaling( 1.0, 'fan_in', 'normal', out_axis=0 ) @struct.dataclass class Embedding: table: np.ndarray def lookup(self, indices): """Embeds the inputs along the last dimension. Args: indices: input data, all dimensions are considered batch dimensions. Returns: Output which is embedded input data. The output shape follows the input, with an additional `features` dimension appended. """ if indices.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]: raise ValueError('Input type must be an integer or unsigned integer.') return self.table[indices] def attend(self, query): """Attend over the embedding using a query array. Args: query: array with last dimension equal the feature depth `features` of the embedding. Returns: An array with final dim `num_embeddings` corresponding to the batched inner-product of the array of query vectors against each embedding. Commonly used for weight-sharing between embeddings and logit transform in NLP models. """ return jnp.dot(query, self.table.T) def embedding( scope: Scope, num_embeddings: int, features: int, init_fn=default_embed_init ) -> Embedding: """Creates embedding dataclass. Args: num_embeddings: number of embeddings. features: Number of feature dimensions for each embedding. embedding_init: embedding initializer. Returns: Embedding dataclass with lookup and attend methods. """ table = scope.param('table', init_fn, (num_embeddings, features)) return Embedding(table) # type: ignore ================================================ FILE: flax/core/nn/normalization.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Normalization modules for Flax.""" import jax.numpy as jnp from jax import lax from flax.core import Scope from flax.linen import initializers def _absolute_dims(ndim, dims): return tuple(ndim + dim if dim < 0 else dim for dim in dims) def batch_norm( scope: Scope, x, use_running_average=False, axis=-1, momentum=0.99, epsilon=1e-5, dtype=jnp.float32, bias=True, scale=True, bias_init=initializers.zeros_init(), scale_init=initializers.ones_init(), axis_name=None, axis_index_groups=None, kind='batch_stats', ): x = jnp.asarray(x, jnp.float32) axis = axis if isinstance(axis, tuple) else (axis,) axis = _absolute_dims(x.ndim, axis) redux = tuple(i for i in range(x.ndim) if i not in axis) def pmean(x): m = jnp.mean(x, redux, keepdims=True) if axis_name is not None: m = lax.pmean(m, axis_name=axis_name, axis_index_groups=axis_index_groups) return m mean = pmean(x) squeeze_shape = jnp.squeeze(mean).shape mean2 = pmean(jnp.square(x)) var = mean2 - jnp.square(mean) is_init = not scope.has_variable(kind, 'mean') ra_mean = scope.variable(kind, 'mean', jnp.zeros, squeeze_shape) ra_var = scope.variable(kind, 'var', jnp.ones, squeeze_shape) if use_running_average: # if ra_mean is not None: # raise ValueError('batch_stats should be provided if use_running_averages=True') mean = jnp.reshape(ra_mean.value, mean.shape) var = jnp.reshape(ra_var.value, var.shape) else: if not is_init: beta = 1.0 - momentum ra_mean.value += beta * (jnp.squeeze(mean) - ra_mean.value) ra_var.value += beta * (jnp.squeeze(var) - ra_var.value) y = x - mean mul = lax.rsqrt(var + epsilon) if scale: mul = mul * scope.param('scale', scale_init, squeeze_shape).reshape( mean.shape ) y = y * mul if bias: y = y + scope.param('bias', bias_init, squeeze_shape).reshape(mean.shape) return jnp.asarray(y, dtype) def layer_norm( scope: Scope, x, epsilon=1e-6, dtype=jnp.float32, bias=True, scale=True, bias_init=initializers.zeros_init(), scale_init=initializers.ones_init(), ): """Applies layer normalization on the input. It normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that maintains the mean activation within each example close to 0 and the activation standard deviation close to 1. Args: x: the inputs epsilon: A small float added to variance to avoid dividing by zero. dtype: the dtype of the computation (default: float32). bias: If True, bias (beta) is added. scale: If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer. bias_init: Initializer for bias, by default, zero. scale_init: Initializer for scale, by default, one. Returns: Normalized inputs (the same shape as inputs). """ features = x.shape[-1] mean = jnp.mean(x, axis=-1, keepdims=True) mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) var = mean2 - lax.square(mean) mul = lax.rsqrt(var + epsilon) if scale: mul = mul * jnp.asarray( scope.param('scale', scale_init, (features,)), dtype ) y = (x - mean) * mul if bias: y = y + jnp.asarray(scope.param('bias', bias_init, (features,)), dtype) return y def group_norm( scope, x, num_groups=32, group_size=None, epsilon=1e-6, dtype=jnp.float32, bias=True, scale=True, bias_init=initializers.zeros_init(), scale_init=initializers.ones_init(), ): """Applies group normalization to the input (arxiv.org/abs/1803.08494). This op is similar to batch normalization, but statistics are shared across equally-sized groups of channels and not shared across batch dimension. Thus, group normalization does not depend on the batch composition and does not require maintaining internal state for storing statistics. The user should either specify the total number of channel groups or the number of channels per group. Args: x: the input of shape N...C, where N is a batch dimension and C is a channels dimensions. `...` represents an arbitrary number of extra dimensions that are used to accumulate statistics over. num_groups: the total number of channel groups. The default value of 32 is proposed by the original group normalization paper. group_size: the number of channels in a group. epsilon: A small float added to variance to avoid dividing by zero. dtype: the dtype of the computation (default: float32). bias: If True, bias (beta) is added. scale: If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer. bias_init: Initializer for bias, by default, zero. scale_init: Initializer for scale, by default, one. Returns: Normalized inputs (the same shape as inputs). """ x = jnp.asarray(x, jnp.float32) if (num_groups is None and group_size is None) or ( num_groups is not None and group_size is not None ): raise ValueError( 'Either `num_groups` or `group_size` should be ' 'specified, but not both of them.' ) if group_size is not None: channels = x.shape[-1] if channels % group_size != 0: raise ValueError( 'Number of channels ({}) is not multiple of the ' 'group size ({}).'.format(channels, group_size) ) num_groups = channels // group_size input_shape = x.shape group_shape = x.shape[:-1] + (num_groups, x.shape[-1] // num_groups) x = x.reshape(group_shape) reduction_axis = list(range(1, x.ndim - 2)) + [x.ndim - 1] mean = jnp.mean(x, axis=reduction_axis, keepdims=True) mean_of_squares = jnp.mean(jnp.square(x), axis=reduction_axis, keepdims=True) var = mean_of_squares - jnp.square(mean) x = (x - mean) * lax.rsqrt(var + epsilon) x = x.reshape(input_shape) feature_shape = tuple([1 for d in input_shape[:-1]] + [input_shape[-1]]) if scale: x = x * scope.param('scale', scale_init, feature_shape) if bias: x = x + scope.param('bias', bias_init, feature_shape) return x.astype(dtype) ================================================ FILE: flax/core/nn/stochastic.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Stochastic modules.""" import jax.numpy as jnp from jax import lax, random def dropout(scope, inputs, rate, deterministic=False, rng=None): """Applies a random dropout mask to the input. Args: inputs: the inputs that should be randomly masked. rate: the probablity of masking out a value. deterministic: if false the inputs are scaled by `1 / (1 - rate)` and masked, whereas if true, no mask is applied and the inputs are returned as is. rng: an optional `jax.random.PRNGKey`. By default `nn.make_rng()` will be used. Returns: The masked inputs. """ if rate == 0.0: return inputs keep_prob = 1.0 - rate if deterministic: return inputs else: if rng is None: rng = scope.make_rng('dropout') mask = random.bernoulli(rng, p=keep_prob, shape=inputs.shape) return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)) ================================================ FILE: flax/core/partial_eval.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import functools from typing import Any import jax from jax import core from jax.extend import linear_util as lu from jax.interpreters import partial_eval as pe from flax import errors def _maybe_unknown(x: Any) -> pe.PartialVal: if isinstance(x, jax.ShapeDtypeStruct): return pe.PartialVal.unknown(core.ShapedArray(x.shape, x.dtype)) else: return pe.PartialVal.known(x) def lazy_init(fn): """Lazily evaluates a function by using the shapes of the inputs. The returned function accepts a combination of JAX values and ``jax.ShapeDtypeStruct`` instances for the inputs for which we don't need concrete values (only the shape and dtype). This API is used by ``core.lazy_init`` or ``Module.lazy_init`` to initialize variables without doing any actual computation on the inputs. Args: fn: the function to be lazily evaluated. Returns: A new function that accepts a mix of concrete values and ``jax.ShapeDtypeStruct`` instances. """ @functools.wraps(fn) def wrapper(*args, **kwargs): # TODO(mattjj,jheek): use a public JAX API # flatten fn and prepare for internal JAX transform inputs_flat, in_tree = jax.tree_util.tree_flatten((args, kwargs)) debug_info = jax.api_util.debug_info("lazy_init", fn, (in_tree,), {}) f_flat, out_tree = jax.api_util.flatten_fun( lu.wrap_init(fn, debug_info=debug_info), in_tree) # map inputs to PartialVal known/unknown # only the computations depending on knowns will be executed in_pvals = [_maybe_unknown(x) for x in inputs_flat] _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals) # all outputs should be knowns. If this fails # the user is creating variables that depend on a # argument that was passed as a ShapeDtypeStruct. out_flat = [] for pv, const in out_pvals: if pv is None: # const is the actual value of the known output out_flat.append(const) else: raise errors.LazyInitError(pv) return jax.tree_util.tree_unflatten(out_tree(), out_flat) return wrapper ================================================ FILE: flax/core/scope.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Flax functional core: Scopes.""" import collections import contextlib import dataclasses import functools import hashlib import typing from typing import ( Any, Generic, Literal, Optional, TypeVar, Union, cast, overload, ) from collections.abc import Callable, Iterable, Mapping, Sequence import jax import numpy as np from jax import numpy as jnp from jax import random, tree_util from flax import config as config from flax import errors, struct, traceback_util from flax.ids import uuid from flax.typing import ( PRNGKey, Array, RNGSequences, Collection, MutableCollection, VariableDict, FrozenVariableDict as FrozenVariableDict, MutableVariableDict, PRNGFoldable, ) from . import meta, partial_eval, tracers from .frozen_dict import FrozenDict, freeze, unfreeze traceback_util.register_exclusion(__file__) T = TypeVar('T') Filter = Union[bool, str, typing.Collection[str], 'DenyList'] # When conditioning on filters we require explicit boolean comparisons. # pylint: disable=g-bool-id-comparison @dataclasses.dataclass(frozen=True, eq=True) class DenyList: """DenyList represents an opt-out based mutability filter. DenyList can be used to make every collection mutable except the ones defined in the given filter. To for example make everything but the params collection mutable:: nn.apply(fn, mutable=nn.DenyList(["params"])) Attributes: deny: The filter representing the collections that are not mutable. """ deny: Filter def __lt__(self, other): if isinstance(other, str): return False if isinstance(other, DenyList): return str(self.deny) < str(other.deny) return NotImplemented def __gt__(self, other): if isinstance(other, str): return True if isinstance(other, DenyList): return str(self.deny) > str(other.deny) return NotImplemented CollectionFilter = Filter PRNGSequenceFilter = Filter class LazyRng(struct.PyTreeNode): """Wrapper around JAX PRNGKey that lazily maintains a tuple of static data to be folded into the rng.""" rng: PRNGKey suffix: tuple[PRNGFoldable, ...] = struct.field(pytree_node=False) def as_jax_rng(self) -> PRNGKey: return _fold_in_static(self.rng, self.suffix) @staticmethod def create( rng: Union['LazyRng', PRNGKey], *suffix: PRNGFoldable ) -> 'LazyRng': if isinstance(rng, LazyRng): return LazyRng(rng.rng, rng.suffix + suffix) else: return LazyRng(rng, suffix) def clear_suffix(self): key = self.rng return LazyRng(key, ()) def _fold_in_static( rng: PRNGKey, data: typing.Collection[PRNGFoldable] ) -> PRNGKey: """Folds static data (strings & ints) into a jax.random.PRNGKey using its SHA-1 hash. This is faster than splitting an PRNGKey because it allows generating new PRNG keys in parallel that are independent of each other. Args: rng: the rng to fold the string into. data: the string to be folded in. Returns: The newly generated PRNG key. """ if not data: return rng m = hashlib.sha1() for x in data: if config.flax_fix_rng_separator: # encode seperate to avoid collisions like for example: ("ab", "c") and ("a", "bc") m.update(b'\00') if isinstance(x, str): m.update(x.encode('utf-8')) elif isinstance(x, int): m.update(x.to_bytes((x.bit_length() + 7) // 8, byteorder='big')) else: raise ValueError(f'Expected int or string, got: {x}') d = m.digest() hash_int = int.from_bytes(d[:4], byteorder='big') return random.fold_in(rng, jnp.uint32(hash_int)) # type: ignore def is_filter_empty(filter_like: Filter) -> bool: """Returns True if `filter_like` is an empty filter. Args: filter_like: The filter to test. Returns: A filter is empty when it is an empty collection, it is a bool with value False, ir it is a DenyList that matches everything. A string filter is never empty. """ if isinstance(filter_like, str): return False if isinstance(filter_like, typing.Collection): return not filter_like if isinstance(filter_like, bool): return not filter_like if isinstance(filter_like, DenyList): # if any arbitrary collection is in the denylist it matches everything so # the filter is empty. This is checked with a stub. return in_filter(filter_like.deny, '__flax_internal_stub__') raise errors.InvalidFilterError(filter_like) def in_filter(filter_like: Filter, col: str) -> bool: """Checks whether a filter can be applied to a collection. Used for both collections and rng sequence filters. Args: filter_like: a filter (either a boolean, a string, or a list of strings) for a collection. col: a collection, which is a string identifying a dictionary of data, for instance "params" or "batch_stats". Returns: True if either `filter_like` is True, equal to `col`, or a sequence containing `col`. """ if isinstance(filter_like, str): return col == filter_like if isinstance(filter_like, typing.Collection): return col in filter_like if isinstance(filter_like, bool): return filter_like if isinstance(filter_like, DenyList): return not in_filter(filter_like.deny, col) raise errors.InvalidFilterError(filter_like) def filter_to_set(x: Filter) -> set[str]: """Converts a Filter into a set of collections, fails on the infinite set. Args: x: a filter (boolean, string, or list of strings). Returns: The input filter represented as a set of strings. """ assert x is not True and not isinstance(x, DenyList), 'Infinite set' if x is False: return set() if isinstance(x, str): return {x} if isinstance(x, typing.Collection): return set(x) raise errors.InvalidFilterError(x) def union_filters(a: Filter, b: Filter) -> Filter: """Takes the union of two filters (similar to a logical or). Args: a: a filter. b: a filter. Returns: The union of the two input filters. For instance, `union_filters('f1', ['f2']) = {'f1', 'f2'}`. """ if a is True or b is True: return True if isinstance(a, DenyList) and isinstance(b, DenyList): return DenyList(intersect_filters(a.deny, b.deny)) if isinstance(b, DenyList): a, b = b, a if isinstance(a, DenyList): return DenyList(subtract_filters(a.deny, b)) a = filter_to_set(a) b = filter_to_set(b) return a.union(b) def subtract_filters(a: Filter, b: Filter) -> Filter: """Returns the subtraction of b from a. Args: a: a filter. b: a filter. Returns: A filter matching with values in a that are not in b. """ if b is True: return False if a is True: return DenyList(b) if isinstance(a, DenyList) and isinstance(b, DenyList): return subtract_filters(b.deny, a.deny) if isinstance(a, DenyList): return DenyList(union_filters(a.deny, b)) if isinstance(b, DenyList): return intersect_filters(a, b.deny) a = filter_to_set(a) b = filter_to_set(b) return a - b def intersect_filters(a: Filter, b: Filter) -> Filter: """Take the intersection of two filters (similar to a logical and). Args: a: a filter. b: a filter. Returns: The intersection of the two input filters. For instance, `intersect_filters('f1', ['f1', 'f2']) = {'f1'}`. """ if a is True: return b if b is True: return a if isinstance(a, DenyList) and isinstance(b, DenyList): return DenyList(union_filters(b.deny, a.deny)) if isinstance(b, DenyList): b, a = a, b if isinstance(a, DenyList): return subtract_filters(b, a.deny) a = filter_to_set(a) b = filter_to_set(b) return a.intersection(b) def group_collections( xs: VariableDict, col_filters: Sequence[CollectionFilter] ) -> Sequence[MutableVariableDict]: """Groups variables by collection filters. Iteratively applies the filters in `col_filters` to `xs`, and adds the result of applying each filter to the output sequence. Each key in `xs` is only added to the output once. Args: xs: a dictionary of variables, keyed by collections (strings). col_filters: a list of collection filters. Returns: A sequence S with `len(S) == len(col_filters)`. Each `S[i]` is the result of applying filter `col_filters[i]` to the remaining keys in `xs`. """ cols: Iterable[str] cols = xs.keys() groups = [] for col_filter in col_filters: remaining_cols = [] group = {} for col in cols: if in_filter(col_filter, col): group[col] = jax.tree_util.tree_map(lambda x: x, xs[col]) else: remaining_cols.append(col) cols = remaining_cols groups.append(group) return tuple(groups) class Variable(Generic[T]): """A Variable object allows mutable access to a variable in a VariableDict. Variables are identified by a collection (e.g., "batch_stats") and a name (e.g., "moving_mean"). The value property gives access to the variable's content and can be assigned to for mutation. """ def __init__(self, scope: 'Scope', collection: str, name: str, unbox: bool): """Initializes a variable. Args: scope: The scope in which the variable is stored. collection: The collection of the variable (e.g., "params"). name: The name of the variable (e.g., "dense"). unbox: Whether to unbox boxed values with metadata. """ self._id = uuid() self.scope = scope self.collection = collection self.name = name self.unbox = unbox @property def value(self) -> T: """Returns the value of this Variable.""" v = self.scope.get_variable(self.collection, self.name) return meta.unbox(v) if self.unbox else v @value.setter def value(self, value: T): """Updates the value of this Variable.""" if self.unbox: cur = self.scope.get_variable(self.collection, self.name) cur_struct = tree_util.tree_structure(cur, is_leaf=meta.is_axis_metadata) value_struct = tree_util.tree_structure( value, is_leaf=meta.is_axis_metadata ) has_meta = any(map(meta.is_axis_metadata, cur_struct.flatten_up_to(cur))) if cur_struct == value_struct and has_meta: # type: ignore[operator] value = meta.replace_boxed(cur, value) self.scope.put_variable(self.collection, self.name, value) def is_mutable(self) -> bool: """Checks if this Variable is mutable.""" return self.scope.is_mutable_collection(self.collection) class _ChildRNGSentinel: pass # used to identify that an rng counter is meant for a child scope child_rng_token = _ChildRNGSentinel() class _DefaultSentinel: pass # used to denote no default flag value on scope no_flag = _DefaultSentinel() # Make sure reference sharing of child variable dictionaries isn't broken. # See https://github.com/google/flax/issues/2022 for more details. def _put_variable(target, key, val): if ( key in target and isinstance(target[key], dict) and isinstance(val, Mapping) ): for k, v in val.items(): _put_variable(target[key], k, v) else: target[key] = val class Scope: """A Scope allows easy access to variables and manages RNGS of a neural network layer. Scopes are purely functional and encapsulated in :class:`flax.linen.module.Module`, so users writing neural network code usually generally do not interact with ``Scopes`` directly. See `core design tests `_ for a number of examples using ``Scopes``. """ reservations: dict[str, set[str | None]] def __init__( self, variables: MutableVariableDict, rngs: RNGSequences | dict[str, LazyRng] | None = None, name: str | None = None, mutable: CollectionFilter = False, parent: Optional['Scope'] = None, path: Iterable[str] = (), debug_path: Iterable[str] = (), flags: Mapping | None = None, ): """Initializes a Scope. Args: variables: VariableDict to initialize the Scope with. rngs: RNGs used in this scope or one of the child scopes. name: name of this scope. mutable: A CollectionFilter determining which variables are mutable. parent: The parent scope. path: The path in the variable tree from the root scope to this scope. It exactly matches the module path. debug_path: Similar to path but could contain transformation decorators. flags: internal flags. """ rngs = {k: LazyRng.create(v) for k, v in rngs.items()} if rngs else {} self._variables = variables self.parent = parent self.name = name self.path = tuple(path) self.debug_path = tuple(debug_path) or self.path self.rngs = rngs self.mutable = mutable self.flags = freeze({} if flags is None else flags) self._root = parent.root if parent else None self.trace_level = tracers.current_trace() self.rng_counters = {key: 0 for key in self.rngs} self.reservations = collections.defaultdict(set) self._invalid = False def __eq__(self, other: Any) -> bool: # If the root variable dict and path are the same, then two scopes behave # identically. Effectively, a scope is nothing more than a cursor into a # variable dict and an rng counter dict. if not isinstance(other, Scope): return False if self is other: return True return ( self.root._variables is other.root._variables and self.path == other.path and self.rng_counters is other.rng_counters ) def __hash__(self) -> int: # see __eq__ return hash((id(self.root._variables), self.path, id(self.rng_counters))) @property def root(self) -> 'Scope': return self._root or self @property def path_text(self) -> str: """Returns the debug path as a human readable string.""" return '/' + '/'.join(self.debug_path) @property def invalid(self) -> bool: """Returns true if this scope is invalidated as a result of `Scope.temporary`.""" return self._invalid def _check_valid(self): if self._invalid: raise errors.InvalidScopeError(self.name) @contextlib.contextmanager def temporary(self): """Returns a context manager that will invalidate this Scope when leaving the context.""" try: yield self finally: self.invalidate() def invalidate(self): """Invalidates the Scope.""" self._invalid = True def mutable_variables(self) -> VariableDict | dict[str, Any]: """Returns an immutable copy of the mutable variables belonging to this Scope.""" self._populate_collections() xs = { k: v for k, v in self._variables.items() if in_filter(self.mutable, k) } if config.flax_return_frozendict: return freeze(xs) return xs def variables(self) -> VariableDict | dict[str, Any]: """Returns an immutable copy of the variables belonging to this Scope.""" self._populate_collections() if config.flax_return_frozendict: return freeze(self._variables) return self._variables def _validate_trace_level(self): tracers.check_trace_level(self.trace_level) def rewound(self, rewind_rngs: bool = False) -> 'Scope': """Returns a rewound version of this Scope. Args: rewind_rngs: if true, reset the RNG counter of this scope. Returns: A rewound version of this scope, which means reservations are emptied, and the rng counter is optionally rewound. """ self._check_valid() scope = Scope( self._variables, self.rngs, self.name, self.mutable, self.parent, path=self.path, debug_path=self.debug_path, flags=self.flags, ) if not rewind_rngs: scope.rng_counters = self.rng_counters return scope def name_reserved(self, name: str, col: str | None = None) -> bool: """Checks whether a name for a child Scope or Variable is taken. Args: name: the name to check for collision. col: if a variable, the collection used. """ if name in self.reservations: # allow the same name for two variables in # different collections, otherwise raise error. if ( None in self.reservations[name] or col is None or col in self.reservations[name] ): return True return False def reserve(self, name: str, col: str | None = None): """Reserves a name for a child Scope or Variable. Throws an error if the name exists already. Args: name: the name to reserve. col: if a variable, the collection used. """ if not isinstance(name, str): raise TypeError( 'The type of scope "{name}" should be string but ' f'it is {type(name)}' ) if self.name_reserved(name, col): raise ValueError(f'Duplicate use of scope name: "{name}"') self.reservations[name].add(col) def default_name(self, prefix: str) -> str: """Generates an unreserved name with the given prefix. Args: prefix: prefix to use for generating an unreserved name. Returns: The generated name. """ i = 0 while True: name = f'{prefix}{i}' if name not in self.reservations: return name i += 1 def push( self, name: str | None = None, prefix: str = '', reuse=False ) -> 'Scope': """Creates a child Scope. Args: name: optional name of the child. prefix: prefix used for generating the name if `name` is `None`. reuse: if True will return a pre-existing child scope with the given name instead of throwing an error. Returns: The child scope. """ self._check_valid() self._validate_trace_level() if name is None: name = self.default_name(prefix) if not reuse or name not in self.reservations: self.reserve(name) rngs = {key: LazyRng.create(rng, name) for key, rng in self.rngs.items()} rng_key = (child_rng_token, name) if rng_key in self.rng_counters: rng_counters = self.rng_counters.get(rng_key) # type: ignore else: rng_counters = {key: 0 for key in rngs} self.rng_counters[rng_key] = rng_counters # type: ignore scope = Scope( {}, name=name, rngs=rngs, parent=self, mutable=self.mutable, path=self.path + (name,), debug_path=self.debug_path + (name,), flags=self.flags, ) scope.rng_counters = rng_counters return scope def child( self, fn: Callable[..., Any], name: str | None = None, prefix: str | None = None, named_call: bool = True, **partial_kwargs, ) -> Callable[..., Any]: """Partially applies a child scope to fn. When calling the returned function multiple times variables will be reused. Args: fn: the function to partially apply the child Scope to. name: optional name of the child. prefix: prefix used for generating name if it is `None`. named_call: if true, `fn` will be run under `jax.named_scope`. The XLA profiler will use this to name tag the computation. **partial_kwargs: additional kwargs partially applied to `fn`. Returns: The function with a partially applied scope. """ if name is None: if prefix is None: prefix = fn.__name__ + '_' if hasattr(fn, '__name__') else '' name = self.default_name(prefix) scope = self.push(name) @functools.wraps(fn) def wrapper(*args, **kwargs): kwargs = dict(partial_kwargs, **kwargs) if named_call: with jax.named_scope(name): res = fn(scope.rewound(), *args, **kwargs) else: res = fn(scope.rewound(), *args, **kwargs) return res return wrapper def is_mutable_collection(self, col: str) -> bool: """Returns true if the collection `col` is mutable.""" return in_filter(self.mutable, col) def is_collection_empty(self, col: str) -> bool: """Returns true if the collection is empty.""" if col in self.root._variables: # pylint: disable=protected-access return not self.root._variables[col] # pylint: disable=protected-access return True def _mutable_collection(self, col: str) -> MutableCollection: """Returns the collection `col` as a mutable object.""" assert self.is_mutable_collection(col), f'Collection {col} is not mutable' # The actual variable dict is stored in the root scope only, and subscopes # hold references to subtrees relevant to them. This function ensures that # the collections are created in the top-level Scope and we return the # correct reference. if col not in self._variables: if not self.parent: # If this is the top-level Scope, just add an empty collection. self._variables[col] = {} else: assert self.name is not None # Only top-level Scope have name None. # Populate the parent collections recursively and obtain a reference to # the direct parent (which, by transitivity, is be a reference to a # dict in the root Scope). parent_col = self.parent._mutable_collection(col) # pylint: disable=protected-access if self.name not in parent_col: # If this Scope's name does not occur in the parent collection, add it # to the parent scope (updating the parent's variable dict). parent_col[self.name] = {} # Store a reference to the parent's scope collection for in this scope's # variable dict. self._variables[col] = parent_col[self.name] return self._variables[col] def _collection(self, col: str) -> Collection: """Returns a collection of variables of collection `col`.""" if col not in self._variables: if self.parent: assert self.name is not None parent_col = self.parent._collection(col) # pylint: disable=protected-access if self.name not in parent_col: return FrozenDict() self._variables[col] = parent_col[self.name] else: return FrozenDict() return self._variables[col] def has_rng(self, name: str) -> bool: """Returns true if a PRNGSequence with name `name` exists.""" return name in self.rngs def make_rng(self, name: str = 'params') -> PRNGKey: """Generates A PRNGKey from a PRNGSequence with name `name`.""" if not self.has_rng(name): if self.has_rng('params'): name = 'params' else: raise errors.InvalidRngError(f'{self.name} needs PRNG for "{name}"') self._check_valid() self._validate_trace_level() self.rng_counters[name] += 1 return LazyRng.create(self.rngs[name], self.rng_counters[name]).as_jax_rng() def get_variable(self, col: str, name: str, default: Any = None) -> Any: """Retrieves the value of a Variable. Args: col: the variable collection. name: the name of the variable. default: the default value to return if the variable does not exist in this scope. Returns: The value of the input variable, of the default value if the variable doesn't exist in this scope. """ variables = self._collection(col) if name in variables: return variables[name] else: return default def has_variable(self, col: str, name: str) -> bool: """Returns true if the given variable exists in this scope. Args: col: the collection of the variable. name: the name of the variable. """ variables = self._collection(col) return name in variables def put_variable(self, col: str, name: str, value: Any): """Updates the value of the given variable if it is mutable, or an error otherwise. Args: col: the collection of the variable. name: the name of the variable. value: the new value of the given variable. """ self._check_valid() self._validate_trace_level() if not self.is_mutable_collection(col): raise errors.ModifyScopeVariableError(col, name, self.path_text) variables = self._mutable_collection(col) _put_variable(variables, name, value) @overload def variable( self, col: str, name: str, init_fn: Callable[..., T] | None = None, *init_args, ) -> Variable[T]: ... @overload def variable( self, col: str, name: str, init_fn: Callable[..., T] | None = None, *init_args, unbox: Literal[True], **init_kwargs, ) -> Variable[T]: ... @overload def variable( self, col: str, name: str, init_fn: Callable[..., T] | None = None, *init_args, unbox: Literal[False], **init_kwargs, ) -> Variable[meta.AxisMetadata[T]]: ... @overload def variable( self, col: str, name: str, init_fn: Callable[..., T] | None = None, *init_args, unbox: bool = True, **init_kwargs, ) -> Variable[T] | Variable[meta.AxisMetadata[T]]: ... def variable( self, col: str, name: str, # pylint: disable=keyword-arg-before-vararg init_fn: Callable[..., T] | None = None, *init_args, unbox: bool = True, **init_kwargs, ) -> Variable[T] | Variable[meta.AxisMetadata[T]]: """Creates a variable if it doesn't exist yet in this scope and returns it. Args: col: the collection of the variable. name: the name of the variable. init_fn: a function taking a PRNGKey plus any other number of positional arguments. If None, the variable must already be initialized otherwise an error is raised. *init_args: the positional arguments to evaluate init_fn on lazily. unbox: If True, ``AxisMetadata`` instances are replaced by their unboxed value, see ``flax.nn.meta.unbox`` (default: True). **init_kwargs: the key-word arguments to evaluate init_fn on lazily. Returns: The variable. Throws an error if the variable exists already. """ self.reserve(name, col) if not self.has_variable(col, name): if not self.is_mutable_collection(col) or init_fn is None: if self.is_collection_empty(col): raise errors.ScopeCollectionNotFound(col, name, self.path_text) raise errors.ScopeVariableNotFoundError(name, col, self.path_text) init_value = init_fn(*init_args, **init_kwargs) self.put_variable(col, name, init_value) # cast to make static analyzers happy return cast( Union[Variable[T], Variable[meta.AxisMetadata[T]]], Variable(self, col, name, unbox=unbox), ) @overload def param( self, name: str, init_fn: Callable[..., T], *init_args, ) -> T: ... @overload def param( self, name: str, init_fn: Callable[..., meta.AxisMetadata[T]] | Callable[..., T], *init_args, unbox: Literal[True], **init_kwargs, ) -> T: ... @overload def param( self, name: str, init_fn: Callable[..., T], *init_args, unbox: Literal[False], **init_kwargs, ) -> T: ... @overload def param( self, name: str, init_fn: Callable[..., T | meta.AxisMetadata[T]], *init_args, unbox: bool, **init_kwargs, ) -> T | meta.AxisMetadata[T]: ... def param( self, name: str, init_fn: Callable[..., T], *init_args, unbox: bool = True, **init_kwargs, ) -> T | meta.AxisMetadata[T]: """Creates a parameter if it doesn't exist yet in this scope and returns it. If the parameter exists already, the existing value is simply returned. Args: name: the name of the parameter. init_fn: a function taking a PRNGKey plus any other number of positional arguments. *init_args: the positional arguments to evaluate init_fn on lazily. unbox: If True, ``AxisMetadata`` instances are replaced by their unboxed value, see ``flax.nn.meta.unbox`` (default: True). **init_kwargs: the key-word arguments to evaluate init_fn on lazily. Returns: The parameters. Throws an error if the params exist already. """ self.reserve(name, 'params') if self.has_variable('params', name): value = self.get_variable('params', name) if unbox: value = meta.unbox(value) # Validate that the shape of the init_fn output is the same as the shape # of the existing parameter. This is to make sure that the hparams set up # in a Flax Module match the shapes coming in during apply, and if not, # catch it with an error message. # NOTE: We could consider moving this to `self.` abs_value = jax.eval_shape( lambda: init_fn(random.key(0), *init_args, **init_kwargs) ) abs_value_flat = jax.tree_util.tree_leaves(abs_value) value_flat = jax.tree_util.tree_leaves(value) for val, abs_val in zip(value_flat, abs_value_flat): # NOTE: We could check dtype consistency here as well but its usefulness # is less obvious. We might intentionally change the dtype for inference # to a half float type for example. if np.shape(val) != np.shape(abs_val): raise errors.ScopeParamShapeError( name, self.path_text, np.shape(val), np.shape(abs_val) ) else: if not self.is_mutable_collection('params'): if self.is_collection_empty('params'): raise errors.ScopeCollectionNotFound('params', name, self.path_text) raise errors.ScopeParamNotFoundError(name, self.path_text) value = init_fn(self.make_rng('params'), *init_args, **init_kwargs) self.put_variable('params', name, value) if unbox: value = meta.unbox(value) return value def _populate_collections(self): collections = self.root._variables.keys() # pylint: disable=protected-access for col in collections: self._collection(col) def has_flag(self, key) -> bool: return key in self.flags def get_flag(self, key, default=no_flag) -> Any: if key not in self.flags and default is no_flag: return ValueError(f'Flag {key} not present on scope.') return self.flags.get(key, default) def _unfreeze_variables(variables, mutable): new_variables = {} for key, value in variables.items(): if in_filter(mutable, key): new_variables[key] = unfreeze(value) else: new_variables[key] = value return new_variables def bind( variables: VariableDict, rngs: RNGSequences | None = None, mutable: CollectionFilter = False, flags: Mapping | None = None, ): """Binds variables and rngs to a new ``Scope``. bind provides a ``Scope`` instance without transforming a function with ``apply``. This is particularly useful for debugging and interactive use cases like notebooks where a function would limit the ability split up code into different cells. a ``Scope`` instance is a stateful object. Note that idiomatic JAX is functional and therefore a ``Scope` does not mix well well with vanilla JAX APIs. Therefore, we recommend using ``apply`` when code should be reusable and compatible across the JAX software ecosystem. Args: variables: Variable dictionary to bind. rngs: RNGs to bind. mutable: Which variable collections to treat as mutable. flags: internal flags. Returns: A new scope with the variables and rngs bound to it. """ if not _is_valid_variables(variables): raise errors.ApplyScopeInvalidVariablesTypeError() if rngs is not None and not _is_valid_rngs(rngs): raise errors.InvalidRngError( 'rngs should be a dictionary mapping strings to `jax.PRNGKey`.' ) new_variables = _unfreeze_variables(variables, mutable) return Scope(new_variables, rngs=rngs, mutable=mutable, flags=flags) def apply( fn: Callable[..., Any], mutable: CollectionFilter = False, flags: Mapping | None = None, ) -> Callable[..., Any]: """Functionalize a `Scope` function. Args: fn: a function taking a `Scope` as its first argument. mutable: the filter determining which variable collections are mutable. flags: internal flags. Returns: `fn` with the scope partially applied. """ @functools.wraps(fn) def wrapper( variables: VariableDict, *args, rngs: PRNGKey | RNGSequences | None = None, **kwargs, ) -> Any | tuple[Any, VariableDict | dict[str, Any]]: if rngs is not None: if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs): raise ValueError( 'The ``rngs`` argument passed to an apply function should be a ' '``jax.PRNGKey`` or a dictionary mapping strings to ' '``jax.PRNGKey``.' ) if not isinstance(rngs, (dict, FrozenDict)): rngs = {'params': rngs} # Try to detect if user accidentally passed {'params': {'params': ...}. if ( 'params' in variables and isinstance(variables['params'], (dict, FrozenDict)) and 'params' in variables['params'] ): raise errors.ApplyScopeInvalidVariablesStructureError(variables) with bind( variables, rngs=rngs, mutable=mutable, flags=flags ).temporary() as root: y = fn(root, *args, **kwargs) if mutable is not False: return y, root.mutable_variables() else: return y return wrapper def init( fn: Callable[..., Any], mutable: CollectionFilter = True, flags: Mapping | None = None, ) -> Callable[..., Any]: """Functionalize a `Scope` function for initialization. Args: fn: a function taking a `Scope` as its first argument. mutable: the filter determining which variable collections are mutable. flags: internal flags. Returns: `fn` with the scope partially applied. """ @functools.wraps(fn) def wrapper(rngs, *args, **kwargs) -> tuple[Any, VariableDict]: if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs): raise ValueError( 'First argument passed to an init function should be a ' '``jax.PRNGKey`` or a dictionary mapping strings to ' '``jax.PRNGKey``.' ) if not isinstance(rngs, (dict, FrozenDict)): rngs = {'params': rngs} init_flags = {**(flags if flags is not None else {}), 'initializing': True} return apply(fn, mutable=mutable, flags=init_flags)( {}, *args, rngs=rngs, **kwargs ) return wrapper def lazy_init( fn: Callable[..., Any], mutable: CollectionFilter = True, flags: Mapping | None = None, ) -> Callable[..., Any]: """Functionalizes a `Scope` function for lazy initialization. Similar to ``init`` except that the init function now accepts ``jax.ShapeDtypeStruct`` instances for arguments that do not affect the variable initialization (typically this is all the input data). Example:: def f(scope, x): # the kernel init only uses the shape of x so we don't actually # need a value for x and can pass it as a ShapeDtypeStruct in lazy_init. k = scope.param("kernel", nn.initializers.lecun_normal(), (x.shape[-1], x.shape[-1])) return x @ k init_fn = lazy_init(f) variables = init_fn(random.key(0), jax.ShapeDtypeStruct((1, 128), jnp.float32)) Args: fn: a function taking a `Scope` as its first argument. mutable: the filter determining which variable collections are mutable. flags: internal flags. Returns: `fn` with the scope partially applied. Unlike ``init`` which returns a tuple of function output and variables, the lazy init function only returns the variables. """ return partial_eval.lazy_init( lambda *args, **kwargs: init(fn, mutable, flags)(*args, **kwargs)[1] ) def _is_valid_collection(col: VariableDict): if not isinstance(col, (FrozenDict, dict)): return False for name in col.keys(): # Any value can be stored in a collection so only keys can be verified. if not isinstance(name, str): return False return True def _is_valid_variables(variables: VariableDict) -> bool: """Checks whether the given variable dict is valid. Args: variables: A variable dict. Returns: True if `variables` is a valid variable dict. """ for name, col in variables.items(): if not isinstance(name, str): return False if not _is_valid_collection(col): return False return True def _is_valid_rng(rng: Array): """Checks whether rng is a valid JAX PRNGKey, also handling custom prngs.""" # Allow for user-provided LazyRng - useful for compatibility when refactoring. if isinstance(rng, LazyRng): return True # This check is valid for either new-style or old-style PRNG keys if not isinstance(rng, (np.ndarray, jnp.ndarray)): return False # Handle new-style typed PRNG keys if jax.dtypes.issubdtype(rng.dtype, jax.dtypes.prng_key): return rng.shape == () # Handle old-style raw PRNG keys expected_rng = jax.eval_shape( lambda s: jax.random.key_data(jax.random.key(s)), 0 ) if (rng.shape, rng.dtype) != (expected_rng.shape, expected_rng.dtype): return False return True def _is_valid_rngs(rngs: PRNGKey | RNGSequences): if not isinstance(rngs, (FrozenDict, dict)): return False for key, val in rngs.items(): if not isinstance(key, str): return False if not _is_valid_rng(val): return False return True ================================================ FILE: flax/core/spmd.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import contextlib import dataclasses import threading import jax from jax.sharding import PartitionSpec, NamedSharding from flax.core import meta from jax.experimental.layout import Format from flax.typing import ( LogicalRules, Sharding, ) def get_pspec(sharding, sharding_rules = None) -> PartitionSpec: """Given an `nnx.Variable`, return its `PartitionSpec`.""" return apply_rules(sharding, sharding_rules) # type: ignore def map_sharding(f, sharding): if isinstance(sharding, PartitionSpec) or isinstance(sharding, tuple): return PartitionSpec(*map(f, sharding)) elif isinstance(sharding, NamedSharding): return NamedSharding(sharding.mesh, map_sharding(f, sharding.spec)) # type: ignore elif isinstance(sharding, Format): return Format(sharding.layout, map_sharding(f, sharding.sharding)) # type: ignore def get_mesh(sharding): if isinstance(sharding, PartitionSpec) or isinstance(sharding, tuple): return None elif isinstance(sharding, NamedSharding): return sharding.mesh elif isinstance(sharding, Format): return get_mesh(sharding.sharding) def apply_rules(sharding, sharding_rules): """Rename the axes of a sharding specification (which can include `PartitionSpec`, `NamedSharding` or `Format` objects).""" if get_logical_axis_rules() or sharding_rules: context_rules = get_logical_axis_rules() rules = {alias: on_mesh for (alias, on_mesh) in composite_rules(context_rules, sharding_rules)} else: rules = {} return map_sharding(lambda a: rules.get(a, a), sharding) def _apply_sharding(value, sharding, mesh): if isinstance(sharding, Format): return jax.lax.with_sharding_constraint(value, sharding) if mesh.are_all_axes_explicit: return jax.sharding.reshard(value, sharding) elif mesh.are_all_axes_auto: return jax.lax.with_sharding_constraint(value, sharding) else: raise ValueError( 'Mesh must have all axes as Explicit or all axes as Auto. ' f'Got mixed axis types: {mesh.axis_types}') def shard_value(value, out_sharding, sharding_rules, mesh): if not out_sharding: return value if mesh is None: mesh = meta.get_global_mesh() out_sharding = apply_rules(out_sharding, sharding_rules) sharding_mesh = get_mesh(out_sharding) if sharding_mesh: if mesh: assert mesh == out_sharding.mesh mesh = sharding_mesh if mesh is None: raise ValueError( 'An auto mesh context or metadata is required if creating a variable' f' with annotation {out_sharding=}. ' 'For more guidance, see https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html.') if isinstance(out_sharding, PartitionSpec): out_sharding = NamedSharding(mesh, out_sharding) return _apply_sharding(value, out_sharding, mesh) # Dynamic Axis Mapping Context # ------------------------------------------------------------------------------ @dataclasses.dataclass class _AxisRules(threading.local): """Dynamic logical axis to mesh axis binding context.""" rules: LogicalRules = () # Global axis binding context. _axis_rules = _AxisRules() def set_logical_axis_rules(rules: LogicalRules): """Sets the global logical axis to mesh axis binding.""" _axis_rules.rules = rules def get_logical_axis_rules() -> LogicalRules: """Returns the global logical axis to mesh axis binding.""" return _axis_rules.rules @contextlib.contextmanager def logical_axis_rules(rules: LogicalRules): """Context manager for setting the logical to mesh axis bindings.""" old_rules = _axis_rules.rules try: _axis_rules.rules = rules yield finally: _axis_rules.rules = old_rules def composite_rules(rule1, rule2): if not rule1 and not rule2: return () if rule1 and not rule2: return rule1 if rule2 and not rule1: return rule2 rules = {alias: value for alias, value in rule1} for alias, value in rule2: if alias in rules and rules[alias] != value: raise ValueError( f'Inconsistent logical axis annotations for {alias}: ' f'{rules[alias]} vs {value}' ) rules[alias] = value return tuple(rules.items()) def from_sharding_rules( sharding: Sharding, sharding_rules: LogicalRules ) -> Sharding: rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules} return tuple( rules[str(s)] if (s and str(s) in rules) else s for s in sharding ) ================================================ FILE: flax/core/tracers.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Functionality for inspecting jax tracers.""" import jax import jax.core def current_trace(): """Returns the current JAX state tracer.""" if jax.__version_info__ <= (0, 4, 33): top = jax.core.find_top_trace(()) if top: return top.level else: return float('-inf') return jax.core.get_opaque_trace_state(convention="flax") def check_trace_level(base_level): # TODO(cgarciae): skipping for now as it breaks # too many internal tests. # level = current_trace() # if level != base_level: # raise errors.JaxTransformError() pass ================================================ FILE: flax/core/variables.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """A variable dict is a normal Python dictionary, which is a container for one or more "variable collections", each of which are nested dictionaries whose leaves are ``jax.numpy`` arrays. The different variable collections share the same nested tree structure. For example, consider the following variable dictionary:: { "params": { "Conv1": { "weight": ..., "bias": ... }, "BatchNorm1": { "scale": ..., "mean": ... }, "Conv2": {...} }, "batch_stats": { "BatchNorm1": { "moving_mean": ..., "moving_average": ...} } } In this case, the ``"BatchNorm1"`` key lives in both the ``"params"`` and ```"batch_stats""`` collections. This reflects the fact that the submodule named ``""BatchNorm1""`` has both trainable parameters (the ``"params"`` collection), as well as other non-trainable variables (the ``"batch_stats"`` collection) TODO: Make "variable dict" design note, and link to it from here. """ from .scope import Variable ================================================ FILE: flax/cursor.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import dataclasses import enum from typing import ( Any, Generic, Protocol, TypeVar, runtime_checkable, ) from collections.abc import Callable, Generator, Mapping from flax.core import FrozenDict from flax.errors import CursorFindError, TraverseTreeError A = TypeVar('A') Key = Any @runtime_checkable class Indexable(Protocol): def __getitem__(self, key) -> Any: ... class AccessType(enum.Enum): ITEM = enum.auto() ATTR = enum.auto() @dataclasses.dataclass class ParentKey(Generic[A]): parent: 'Cursor[A]' key: Key access_type: AccessType def is_named_tuple(obj): return ( isinstance(obj, tuple) and hasattr(obj, '_fields') and hasattr(obj, '_asdict') and hasattr(obj, '_replace') ) def _traverse_tree(path, obj, *, update_fn=None, cond_fn=None): """Helper function for ``Cursor.apply_update`` and ``Cursor.find_all``. Exactly one of ``update_fn`` and ``cond_fn`` must be not None. - If ``update_fn`` is not None, then ``Cursor.apply_update`` is calling this function and ``_traverse_tree`` will return a generator where each generated element is of type Tuple[Tuple[Union[str, int], AccessType], Any]. The first element is a tuple of the key path and access type where the change was applied from the ``update_fn``, and the second element is the newly modified value. If the generator is non-empty, then the tuple key path will always be non-empty as well. - If ``cond_fn`` is not None, then ``Cursor.find_all`` is calling this function and ``_traverse_tree`` will return a generator where each generated element is of type Tuple[Union[str, int], AccessType]. The tuple contains the key path and access type where the object was found that fulfilled the conditions of the ``cond_fn``. """ if not (bool(update_fn) ^ bool(cond_fn)): raise TraverseTreeError(update_fn, cond_fn) if path: str_path = '/'.join(str(key) for key, _ in path) if update_fn: new_obj = update_fn(str_path, obj) if new_obj is not obj: yield path, new_obj return elif cond_fn(str_path, obj): # type: ignore yield path return if isinstance(obj, (FrozenDict, dict)): items = obj.items() access_type = AccessType.ITEM elif is_named_tuple(obj): items = ((name, getattr(obj, name)) for name in obj._fields) # type: ignore access_type = AccessType.ATTR elif isinstance(obj, (list, tuple)): items = enumerate(obj) access_type = AccessType.ITEM elif dataclasses.is_dataclass(obj): items = ( (f.name, getattr(obj, f.name)) for f in dataclasses.fields(obj) if f.init ) access_type = AccessType.ATTR else: return if update_fn: for key, value in items: yield from _traverse_tree( path + ((key, access_type),), value, update_fn=update_fn ) else: for key, value in items: yield from _traverse_tree( path + ((key, access_type),), value, cond_fn=cond_fn ) class Cursor(Generic[A]): _obj: A _parent_key: ParentKey[A] | None _changes: dict[Any, 'Cursor[A]'] def __init__(self, obj: A, parent_key: ParentKey[A] | None): # NOTE: we use `vars` here to avoid calling `__setattr__` # vars(self) = self.__dict__ vars(self)['_obj'] = obj vars(self)['_parent_key'] = parent_key vars(self)['_changes'] = {} @property def _root(self) -> 'Cursor[A]': if self._parent_key is None: return self else: return self._parent_key.parent._root # type: ignore @property def _path(self) -> str: if self._parent_key is None: return '' if self._parent_key.access_type == AccessType.ITEM: # type: ignore if isinstance(self._parent_key.key, str): # type: ignore key = "'" + self._parent_key.key + "'" # type: ignore else: key = str(self._parent_key.key) # type: ignore return self._parent_key.parent._path + '[' + key + ']' # type: ignore # self.parent_key.access_type == AccessType.ATTR: return self._parent_key.parent._path + '.' + self._parent_key.key # type: ignore def __getitem__(self, key) -> 'Cursor[A]': if key in self._changes: return self._changes[key] if not isinstance(self._obj, Indexable): raise TypeError(f'Cannot index into {self._obj}') if isinstance(self._obj, Mapping) and key not in self._obj: raise KeyError(f'Key {key} not found in {self._obj}') if is_named_tuple(self._obj): return getattr(self, self._obj._fields[key]) # type: ignore child = Cursor(self._obj[key], ParentKey(self, key, AccessType.ITEM)) self._changes[key] = child return child def __getattr__(self, name) -> 'Cursor[A]': if name in self._changes: return self._changes[name] if not hasattr(self._obj, name): raise AttributeError(f'Attribute {name} not found in {self._obj}') child = Cursor( getattr(self._obj, name), ParentKey(self, name, AccessType.ATTR) ) self._changes[name] = child return child def __setitem__(self, key, value): if is_named_tuple(self._obj): return setattr(self, self._obj._fields[key], value) # type: ignore self._changes[key] = Cursor(value, ParentKey(self, key, AccessType.ITEM)) def __setattr__(self, name, value): self._changes[name] = Cursor(value, ParentKey(self, name, AccessType.ATTR)) def set(self, value) -> A: """Set a new value for an attribute, property, element or entry in the Cursor object and return a copy of the original object, containing the new set value. Example:: >>> from flax.cursor import cursor >>> from flax.training import train_state >>> import optax >>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]} >>> modified_dict_obj = cursor(dict_obj)['b'][0].set(10) >>> assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]} >>> state = train_state.TrainState.create( ... apply_fn=lambda x: x, ... params=dict_obj, ... tx=optax.adam(1e-3), ... ) >>> modified_state = cursor(state).params['b'][1].set(10) >>> assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]} Args: value: the value used to set an attribute, property, element or entry in the Cursor object Returns: A copy of the original object with the new set value. """ if self._parent_key is None: return value parent, key = self._parent_key.parent, self._parent_key.key # type: ignore parent._changes[key] = Cursor(value, self._parent_key) return parent._root.build() def build(self) -> A: """Create and return a copy of the original object with accumulated changes. This method is to be called after making changes to the Cursor object. .. note:: The new object is built bottom-up, the changes will be first applied to the leaf nodes, and then its parent, all the way up to the root. Example:: >>> from flax.cursor import cursor >>> from flax.training import train_state >>> import optax >>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]} >>> c = cursor(dict_obj) >>> c['b'][0] = 10 >>> c['a'] = (100, 200) >>> modified_dict_obj = c.build() >>> assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]} >>> state = train_state.TrainState.create( ... apply_fn=lambda x: x, ... params=dict_obj, ... tx=optax.adam(1e-3), ... ) >>> new_fn = lambda x: x + 1 >>> c = cursor(state) >>> c.params['b'][1] = 10 >>> c.apply_fn = new_fn >>> modified_state = c.build() >>> assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]} >>> assert modified_state.apply_fn == new_fn Returns: A copy of the original object with the accumulated changes. """ changes = { key: child.build() if isinstance(child, Cursor) else child for key, child in self._changes.items() } if isinstance(self._obj, FrozenDict): obj = self._obj.copy(changes) # type: ignore elif isinstance(self._obj, (dict, list)): obj = self._obj.copy() # type: ignore for key, value in changes.items(): obj[key] = value elif is_named_tuple(self._obj): obj = self._obj._replace(**changes) # type: ignore elif isinstance(self._obj, tuple): obj = list(self._obj) # type: ignore for key, value in changes.items(): obj[key] = value obj = tuple(obj) # type: ignore elif dataclasses.is_dataclass(self._obj): obj = dataclasses.replace(self._obj, **changes) # type: ignore else: obj = self._obj # type: ignore return obj # type: ignore def apply_update( self, update_fn: Callable[[str, Any], Any], ) -> 'Cursor[A]': """Traverse the Cursor object and record conditional changes recursively via an ``update_fn``. The changes are recorded in the Cursor object's ``._changes`` dictionary. To generate a copy of the original object with the accumulated changes, call the ``.build`` method after calling ``.apply_update``. The ``update_fn`` has a function signature of ``(str, Any) -> Any``: - The input arguments are the current key path (in the form of a string delimited by ``'/'``) and value at that current key path - The output is the new value (either modified by the ``update_fn`` or same as the input value if the condition wasn't fulfilled) .. note:: - If the ``update_fn`` returns a modified value, this method will not recurse any further down that branch to record changes. For example, if we intend to replace an attribute that points to a dictionary with an int, we don't need to look for further changes inside the dictionary, since the dictionary will be replaced anyways. - The ``is`` operator is used to determine whether the return value is modified (by comparing it to the input value). Therefore if the ``update_fn`` modifies a mutable container (e.g. lists, dicts, etc.) and returns the same container, ``.apply_update`` will treat the returned value as unmodified as it contains the same ``id``. To avoid this, return a copy of the modified value. - ``.apply_update`` WILL NOT call the ``update_fn`` to the value at the top-most level of the pytree (i.e. the root node). The ``update_fn`` will first be called on the root node's children, and then the pytree traversal will continue recursively from there. Example:: >>> import flax.linen as nn >>> from flax.cursor import cursor >>> import jax, jax.numpy as jnp >>> class Model(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(3)(x) ... x = nn.relu(x) ... x = nn.Dense(3)(x) ... x = nn.relu(x) ... x = nn.Dense(3)(x) ... x = nn.relu(x) ... return x >>> params = Model().init(jax.random.key(0), jnp.empty((1, 2)))['params'] >>> def update_fn(path, value): ... '''Multiply all dense kernel params by 2 and add 1. ... Subtract the Dense_1 bias param by 1.''' ... if 'kernel' in path: ... return value * 2 + 1 ... elif 'Dense_1' in path and 'bias' in path: ... return value - 1 ... return value >>> c = cursor(params) >>> new_params = c.apply_update(update_fn).build() >>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'): ... assert (new_params[layer]['kernel'] == 2 * params[layer]['kernel'] + 1).all() ... if layer == 'Dense_1': ... assert (new_params[layer]['bias'] == params[layer]['bias'] - 1).all() ... else: ... assert (new_params[layer]['bias'] == params[layer]['bias']).all() >>> assert jax.tree_util.tree_all( ... jax.tree_util.tree_map( ... lambda x, y: (x == y).all(), ... params, ... Model().init(jax.random.key(0), jnp.empty((1, 2)))[ ... 'params' ... ], ... ) ... ) # make sure original params are unchanged Args: update_fn: the function that will conditionally record changes to the Cursor object Returns: The current Cursor object with the recorded conditional changes specified by the ``update_fn``. To generate a copy of the original object with the accumulated changes, call the ``.build`` method after calling ``.apply_update``. """ for path, value in _traverse_tree((), self._obj, update_fn=update_fn): child = self for key, access_type in path[:-1]: if access_type is AccessType.ITEM: child = child[key] else: # access_type is AccessType.ATTR child = getattr(child, key) key, access_type = path[-1] if access_type is AccessType.ITEM: child[key] = value else: # access_type is AccessType.ATTR setattr(child, key, value) return self def find(self, cond_fn: Callable[[str, Any], bool]) -> 'Cursor[A]': """Traverse the Cursor object and return a child Cursor object that fulfill the conditions in the ``cond_fn``. The ``cond_fn`` has a function signature of ``(str, Any) -> bool``: - The input arguments are the current key path (in the form of a string delimited by ``'/'``) and value at that current key path - The output is a boolean, denoting whether to return the child Cursor object at this path Raises a :meth:`CursorFindError ` if no object or more than one object is found that fulfills the condition of the ``cond_fn``. We raise an error because the user should always expect this method to return the only object whose corresponding key path and value fulfill the condition of the ``cond_fn``. .. note:: - If the ``cond_fn`` evaluates to True at a particular key path, this method will not recurse any further down that branch; i.e. this method will find and return the "earliest" child node that fulfills the condition in ``cond_fn`` in a particular key path - ``.find`` WILL NOT search the the value at the top-most level of the pytree (i.e. the root node). The ``cond_fn`` will be evaluated recursively, starting at the root node's children. Example:: >>> import flax.linen as nn >>> from flax.cursor import cursor >>> import jax, jax.numpy as jnp >>> class Model(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(3)(x) ... x = nn.relu(x) ... x = nn.Dense(3)(x) ... x = nn.relu(x) ... x = nn.Dense(3)(x) ... x = nn.relu(x) ... return x >>> params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params'] >>> def cond_fn(path, value): ... '''Find the second dense layer params.''' ... return 'Dense_1' in path >>> new_params = cursor(params).find(cond_fn)['bias'].set(params['Dense_1']['bias'] + 1) >>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'): ... if layer == 'Dense_1': ... assert (new_params[layer]['bias'] == params[layer]['bias'] + 1).all() ... else: ... assert (new_params[layer]['bias'] == params[layer]['bias']).all() >>> c = cursor(params) >>> c2 = c.find(cond_fn) >>> c2['kernel'] += 2 >>> c2['bias'] += 2 >>> new_params = c.build() >>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'): ... if layer == 'Dense_1': ... assert (new_params[layer]['kernel'] == params[layer]['kernel'] + 2).all() ... assert (new_params[layer]['bias'] == params[layer]['bias'] + 2).all() ... else: ... assert (new_params[layer]['kernel'] == params[layer]['kernel']).all() ... assert (new_params[layer]['bias'] == params[layer]['bias']).all() >>> assert jax.tree_util.tree_all( ... jax.tree_util.tree_map( ... lambda x, y: (x == y).all(), ... params, ... Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[ ... 'params' ... ], ... ) ... ) # make sure original params are unchanged Args: cond_fn: the function that will conditionally find child Cursor objects Returns: A child Cursor object that fulfills the condition in the ``cond_fn``. """ generator = self.find_all(cond_fn) try: cursor = next(generator) except StopIteration: raise CursorFindError() try: cursor2 = next(generator) raise CursorFindError(cursor, cursor2) except StopIteration: return cursor def find_all( self, cond_fn: Callable[[str, Any], bool] ) -> Generator['Cursor[A]', None, None]: """Traverse the Cursor object and return a generator of child Cursor objects that fulfill the conditions in the ``cond_fn``. The ``cond_fn`` has a function signature of ``(str, Any) -> bool``: - The input arguments are the current key path (in the form of a string delimited by ``'/'``) and value at that current key path - The output is a boolean, denoting whether to return the child Cursor object at this path .. note:: - If the ``cond_fn`` evaluates to True at a particular key path, this method will not recurse any further down that branch; i.e. this method will find and return the "earliest" child nodes that fulfill the condition in ``cond_fn`` in a particular key path - ``.find_all`` WILL NOT search the the value at the top-most level of the pytree (i.e. the root node). The ``cond_fn`` will be evaluated recursively, starting at the root node's children. Example:: >>> import flax.linen as nn >>> from flax.cursor import cursor >>> import jax, jax.numpy as jnp >>> class Model(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(3)(x) ... x = nn.relu(x) ... x = nn.Dense(3)(x) ... x = nn.relu(x) ... x = nn.Dense(3)(x) ... x = nn.relu(x) ... return x >>> params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params'] >>> def cond_fn(path, value): ... '''Find all dense layer params.''' ... return 'Dense' in path >>> c = cursor(params) >>> for dense_params in c.find_all(cond_fn): ... dense_params['bias'] += 1 >>> new_params = c.build() >>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'): ... assert (new_params[layer]['bias'] == params[layer]['bias'] + 1).all() >>> assert jax.tree_util.tree_all( ... jax.tree_util.tree_map( ... lambda x, y: (x == y).all(), ... params, ... Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[ ... 'params' ... ], ... ) ... ) # make sure original params are unchanged Args: cond_fn: the function that will conditionally find child Cursor objects Returns: A generator of child Cursor objects that fulfill the condition in the ``cond_fn``. """ for path in _traverse_tree((), self._obj, cond_fn=cond_fn): child = self for key, access_type in path: if access_type is AccessType.ITEM: child = child[key] else: # access_type is AccessType.ATTR child = getattr(child, key) yield child def __str__(self): return str(self._obj) def __repr__(self): return self._pretty_repr() def _pretty_repr(self, indent=2, _prefix_indent=0): s = 'Cursor(\n' obj_str = repr(self._obj).replace( '\n', '\n' + ' ' * (_prefix_indent + indent) ) s += ' ' * (_prefix_indent + indent) + f'_obj={obj_str},\n' s += ' ' * (_prefix_indent + indent) + '_changes={' if self._changes: s += '\n' for key in self._changes: str_key = repr(key) prefix = ' ' * (_prefix_indent + 2 * indent) + str_key + ': ' s += ( prefix + self._changes[key]._pretty_repr( indent=indent, _prefix_indent=len(prefix) ) + ',\n' ) s = s[ :-2 ] # remove comma and newline character for last element in self._changes s += '\n' + ' ' * (_prefix_indent + indent) + '}\n' else: s += '}\n' s += ' ' * _prefix_indent + ')' return s def __len__(self): return len(self._obj) def __iter__(self): if isinstance(self._obj, (tuple, list)): return (self[i] for i in range(len(self._obj))) else: raise NotImplementedError( '__iter__ method only implemented for tuples and lists, not type' f' {type(self._obj)}' ) def __reversed__(self): if isinstance(self._obj, (tuple, list)): return (self[i] for i in range(len(self._obj) - 1, -1, -1)) else: raise NotImplementedError( '__reversed__ method only implemented for tuples and lists, not type' f' {type(self._obj)}' ) def __add__(self, other): return self._obj + other def __sub__(self, other): return self._obj - other def __mul__(self, other): return self._obj * other def __matmul__(self, other): return self._obj @ other def __truediv__(self, other): return self._obj / other def __floordiv__(self, other): return self._obj // other def __mod__(self, other): return self._obj % other def __divmod__(self, other): return divmod(self._obj, other) def __pow__(self, other): return pow(self._obj, other) def __lshift__(self, other): return self._obj << other def __rshift__(self, other): return self._obj >> other def __and__(self, other): return self._obj & other def __xor__(self, other): return self._obj ^ other def __or__(self, other): return self._obj | other def __radd__(self, other): return other + self._obj def __rsub__(self, other): return other - self._obj def __rmul__(self, other): return other * self._obj def __rmatmul__(self, other): return other @ self._obj def __rtruediv__(self, other): return other / self._obj def __rfloordiv__(self, other): return other // self._obj def __rmod__(self, other): return other % self._obj def __rdivmod__(self, other): return divmod(other, self._obj) def __rpow__(self, other): return pow(other, self._obj) def __rlshift__(self, other): return other << self._obj def __rrshift__(self, other): return other >> self._obj def __rand__(self, other): return other & self._obj def __rxor__(self, other): return other ^ self._obj def __ror__(self, other): return other | self._obj def __neg__(self): return -self._obj def __pos__(self): return +self._obj def __abs__(self): return abs(self._obj) def __invert__(self): return ~self._obj def __round__(self, ndigits=None): return round(self._obj, ndigits) def __lt__(self, other): return self._obj < other def __le__(self, other): return self._obj <= other def __eq__(self, other): return self._obj == other def __ne__(self, other): return self._obj != other def __gt__(self, other): return self._obj > other def __ge__(self, other): return self._obj >= other def cursor(obj: A) -> Cursor[A]: """Wrap :class:`Cursor ` over ``obj`` and return it. Changes can then be applied to the Cursor object in the following ways: - single-line change via the ``.set`` method - multiple changes, and then calling the ``.build`` method - multiple changes conditioned on the pytree path and node value via the ``.apply_update`` method, and then calling the ``.build`` method ``.set`` example:: >>> from flax.cursor import cursor >>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]} >>> modified_dict_obj = cursor(dict_obj)['b'][0].set(10) >>> assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]} ``.build`` example:: >>> from flax.cursor import cursor >>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]} >>> c = cursor(dict_obj) >>> c['b'][0] = 10 >>> c['a'] = (100, 200) >>> modified_dict_obj = c.build() >>> assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]} ``.apply_update`` example:: >>> from flax.cursor import cursor >>> from flax.training import train_state >>> import optax >>> def update_fn(path, value): ... '''Replace params with empty dictionary.''' ... if 'params' in path: ... return {} ... return value >>> state = train_state.TrainState.create( ... apply_fn=lambda x: x, ... params={'a': 1, 'b': 2}, ... tx=optax.adam(1e-3), ... ) >>> c = cursor(state) >>> state2 = c.apply_update(update_fn).build() >>> assert state2.params == {} >>> assert state.params == {'a': 1, 'b': 2} # make sure original params are unchanged If the underlying ``obj`` is a ``list`` or ``tuple``, iterating over the Cursor object to get the child Cursors is also possible:: >>> from flax.cursor import cursor >>> c = cursor(((1, 2), (3, 4))) >>> for child_c in c: ... child_c[1] *= -1 >>> assert c.build() == ((1, -2), (3, -4)) View the docstrings for each method to see more examples of their usage. Args: obj: the object you want to wrap the Cursor in Returns: A Cursor object wrapped around obj. """ return Cursor(obj, None) ================================================ FILE: flax/errors.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """""" # Use an empty top-level docstring so Sphinx won't output the one below. """Flax error classes. === When to create a Flax error class? If an error message requires more explanation than a one-liner, it is useful to add it as a separate error class. This may lead to some duplication with existing documentation or docstrings, but it will provide users with more help when they are debugging a problem. We can also point to existing documentation from the error docstring directly. === How to name the error class? * If the error occurs when doing something, name the error Error For instance, if you want to raise an error when applying a module with an invalid method, the error can be: ApplyModuleInvalidMethodError. is optional, for instance if there is only one error when modifying a variable, the error can simply be: ModifyVariableError. * If there is no concrete action involved the only a description of the error is sufficient. For instance: InvalidFilterError, NameInUseError, etc. === Copy/pastable template for new error messages: class Template(FlaxError): "" " "" " def __init__(self): super().__init__(f'') """ class FlaxError(Exception): def __init__(self, message): error_page = ( 'https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html' ) module_name = self.__class__.__module__ class_name = self.__class__.__name__ if error_page not in message: # do not add a FlaxError link on unpickling message = f'{message} ({error_page}#{module_name}.{class_name})' super().__init__(message) def __reduce__(self): return (FlaxError, (str(self),)) ################################################# # NNX errors # ################################################# class TraceContextError(FlaxError): pass ################################################# # lazy_init.py errors # ################################################# class LazyInitError(FlaxError): """Lazy Init function has uncomputable return values. This happens when passing an argument to lazy_init with ``jax.ShapeDtypeStruct`` that affects the initialized variables. Make sure the init function only uses the shape and dtype or pass an actual JAX array if this is impossible. Example:: class Foo(nn.Module): @compact def __call__(self, x): # This parameter depends on the input x # this causes an error when using lazy_init. k = self.param("kernel", lambda _: x) return x * k Foo().lazy_init(random.key(0), jax.ShapeDtypeStruct((8, 4), jnp.float32)) """ def __init__(self, partial_val): super().__init__( 'Lazy init encountered a value that could with ' f'the given inputs (shape: {partial_val}).' ) ################################################# # scope.py errors # ################################################# class InvalidRngError(FlaxError): """All rngs used in a Module should be passed to :meth:`Module.init() ` and :meth:`Module.apply() ` appropriately. We explain both separately using the following example:: class Bar(nn.Module): @nn.compact def __call__(self, x): some_param = self.param('some_param', nn.initializers.zeros_init(), (1, )) dropout_rng = self.make_rng('dropout') x = nn.Dense(features=4)(x) ... class Foo(nn.Module): @nn.compact def __call__(self, x): x = Bar()(x) ... **PRNGs for Module.init()** In this example, two rngs are used: * ``params`` is used for initializing the parameters of the model. This rng is used to initialize the ``some_params`` parameter, and for initializing the weights of the ``Dense`` Module used in ``Bar``. * ``dropout`` is used for the dropout rng that is used in ``Bar``. So, ``Foo`` is initialized as follows:: init_rngs = {'params': random.key(0), 'dropout': random.key(1)} variables = Foo().init(init_rngs, init_inputs) If a Module only requires an rng for ``params``, you can use:: SomeModule().init(rng, ...) # Shorthand for {'params': rng} **PRNGs for Module.apply()** When applying ``Foo``, only the rng for ``dropout`` is needed, because ``params`` is only used for initializing the Module parameters:: Foo().apply(variables, inputs, rngs={'dropout': random.key(2)}) If a Module only requires an rng for ``params``, you don't have to provide rngs for apply at all:: SomeModule().apply(variables, inputs) # rngs=None """ def __init__(self, msg): # For this error message we pass the entire message, since there are various # different kinds of RNG errors and we want to be able to be more specific # in the error message, while always linking to the same documentation. super().__init__(msg) class ApplyScopeInvalidVariablesTypeError(FlaxError): """When calling :meth:`Module.apply() `, the first argument should be a variable dict. For more explanation on variable dicts, please see :mod:`flax.core.variables`. """ def __init__(self): super().__init__( 'The first argument passed to an apply function should be ' 'a dictionary of collections. Each collection should be a ' 'dictionary with string keys.' ) class ApplyScopeInvalidVariablesStructureError(FlaxError): """This error is thrown when the dict passed as ``variables`` to apply() has an extra 'params' layer, i.e. {'params': {'params': ...}}. For more explanation on variable dicts, please see :mod:`flax.core.variables`. """ def __init__(self, variables): super().__init__( f'Expect the `variables` (first argument) passed to apply() ' f'to be a dict with the structure {{"params": ...}}, but got a dict ' f'with an extra params layer, i.e. {{"params": {{"params": ... }} }}. ' f'You should instead pass in your dict\'s ["params"].' ) class ScopeParamNotFoundError(FlaxError): """This error is thrown when trying to access a parameter that does not exist. For instance, in the code below, the initialized embedding name 'embedding' does not match the apply name 'embed':: class Embed(nn.Module): num_embeddings: int features: int @nn.compact def __call__(self, inputs, embed_name='embedding'): inputs = inputs.astype('int32') embedding = self.param(embed_name, jax.nn.initializers.lecun_normal(), (self.num_embeddings, self.features)) return embedding[inputs] model = Embed(4, 8) variables = model.init(random.key(0), jnp.ones((5, 5, 1))) _ = model.apply(variables, jnp.ones((5, 5, 1)), 'embed') """ def __init__(self, param_name, scope_path): super().__init__( f'Could not find parameter named "{param_name}" in scope ' f'"{scope_path}".' ) class ScopeCollectionNotFound(FlaxError): """This error is thrown when trying to access a variable from an empty collection. There are two common causes: 1. | The collection was not passed to ``apply`` correctly. | For example, you might have used ``module.apply(params, ...)`` instead | of ``module.apply({'params': params}, ...)``. 2. | The collection is empty because the variables need to be initialized. | In this case, you should have made the collection mutable during | apply (e.g.: ``module.apply(variables, ..., mutable=['state'])``. """ def __init__(self, col_name, var_name, scope_path): super().__init__( f'Tried to access "{var_name}" from collection "{col_name}" in ' f'"{scope_path}" but the collection is empty.' ) class ScopeParamShapeError(FlaxError): """This error is thrown when the shape of an existing parameter is different from the shape of the return value of the ``init_fn``. This can happen when the shape provided during :meth:`Module.apply() ` is different from the one used when initializing the module. For instance, the following code throws this error because the apply shape (``(5, 5, 1)``) is different from the init shape (``(5, 5``). As a result, the shape of the kernel during ``init`` is ``(1, 8)``, and the shape during ``apply`` is ``(5, 8)``, which results in this error.:: class NoBiasDense(nn.Module): features: int = 8 @nn.compact def __call__(self, x): kernel = self.param('kernel', lecun_normal(), (x.shape[-1], self.features)) # <--- ERROR y = lax.dot_general(x, kernel, (((x.ndim - 1,), (0,)), ((), ()))) return y variables = NoBiasDense().init(random.key(0), jnp.ones((5, 5, 1))) _ = NoBiasDense().apply(variables, jnp.ones((5, 5))) """ def __init__(self, param_name, scope_path, value_shape, init_shape): super().__init__( f'For parameter "{param_name}" in "{scope_path}", the given ' f'initializer is expected to generate shape {init_shape}, but the ' f'existing parameter it received has shape {value_shape}.' ) class ScopeVariableNotFoundError(FlaxError): """This error is thrown when trying to use a variable in a Scope in a collection that is immutable. In order to create this variable, mark the collection as mutable explicitly using the ``mutable`` keyword in :meth:`Module.apply() `. """ def __init__(self, name, col, scope_path): super().__init__( f'No Variable named "{name}" for collection "{col}" ' f'exists in "{scope_path}".' ) class InvalidFilterError(FlaxError): """A filter should be either a boolean, a string or a container object.""" def __init__(self, filter_like): super().__init__(f'Invalid Filter: "{filter_like}"') class InvalidScopeError(FlaxError): """A temporary Scope is only valid within the context in which it is created:: with Scope(variables, rngs=rngs).temporary() as root: y = fn(root, *args, **kwargs) # Here root is valid. # Here root is invalid. """ def __init__(self, scope_name): super().__init__(f'The scope "{scope_name}" is no longer valid.') class ModifyScopeVariableError(FlaxError): """You cannot update a variable if the collection it belongs to is immutable. When you are applying a Module, you should specify which variable collections are mutable:: class MyModule(nn.Module): @nn.compact def __call__(self, x): ... var = self.variable('batch_stats', 'mean', ...) var.value = ... ... v = MyModule.init(...) ... logits = MyModule.apply(v, batch) # This throws an error. logits = MyModule.apply(v, batch, mutable=['batch_stats']) # This works. """ def __init__(self, col, variable_name, scope_path): super().__init__( f'Cannot update variable "{variable_name}" in ' f'"{scope_path}" because collection "{col}" is immutable.' ) class ImmutableVariableError(FlaxError): """You cannot update a variable that is marked as immutable. This error occurs when attempting to modify a Variable that has been set to 'immutable' mode. Variables in immutable mode are read-only and cannot be changed after creation. To fix this error, either: 1. Use a different variable mode (e.g., 'qdd' or 'pytree') 2. Or ensure you're not trying to modify the variable's value """ def __init__(self, message): super().__init__(message) class JaxTransformError(FlaxError): """JAX transforms and Flax modules cannot be mixed. JAX's functional transformations expect pure function. When you want to use JAX transformations **inside** Flax models, you should make use of the Flax transformation wrappers (e.g.: ``flax.linen.vmap``, ``flax.linen.scan``, etc.). """ def __init__(self): super().__init__('Jax transforms and Flax models cannot be mixed.') ################################################# # meta.py errors # ################################################# class PartitioningUnspecifiedError(FlaxError): """This error is raised when trying to add an axis to a Partitioned variable by using a transformation (e.g.: ``scan``, ``vmap``) without specifying the "partition_name" in the ``metadata_params`` dict. """ def __init__(self, target): super().__init__( 'Trying to transform a Partitioned variable but "partition_name"' f' is not specified in metadata_params: {target}' ) ################################################# # module.py errors # ################################################# class NameInUseError(FlaxError): """This error is raised when trying to create a submodule, param, or variable with an existing name. They are all considered to be in the same namespace. **Sharing Submodules** This is the wrong pattern for sharing submodules:: y = nn.Dense(feature=3, name='bar')(x) z = nn.Dense(feature=3, name='bar')(x+epsilon) Instead, modules should be shared by instance:: dense = nn.Dense(feature=3, name='bar') y = dense(x) z = dense(x+epsilon) If submodules are not provided with a name, a unique name will be given to them automatically:: class MyModule(nn.Module): @nn.compact def __call__(self, x): x = MySubModule()(x) x = MySubModule()(x) # This is fine. return x **Parameters and Variables** A parameter name can collide with a submodule or variable, since they are all stored in the same variable dict:: class Foo(nn.Module): @nn.compact def __call__(self, x): bar = self.param('bar', nn.initializers.zeros_init(), (1, )) embed = nn.Embed(num_embeddings=2, features=5, name='bar') # <-- ERROR! Variables should also have unique names, even if they have their own collection:: class Foo(nn.Module): @nn.compact def __call__(self, inputs): _ = self.param('mean', initializers.lecun_normal(), (2, 2)) _ = self.variable('stats', 'mean', initializers.zeros_init(), (2, 2)) """ def __init__(self, key_type, value, module_name): # key_type is in {param, variable, submodule}. super().__init__( f'Could not create {key_type} "{value}" in Module ' f'{module_name}: Name in use.' ) class AssignSubModuleError(FlaxError): """You are only allowed to create submodules in two places: 1. If your Module is noncompact: inside :meth:`Module.setup() `. 2. If your Module is compact: inside the method wrapped in :meth:`nn.compact() `. For instance, the following code throws this error, because ``nn.Conv`` is created in ``__call__``, which is not marked as compact:: class Foo(nn.Module): def setup(self): pass def __call__(self, x): conv = nn.Conv(features=3, kernel_size=3) Foo().init(random.key(0), jnp.zeros((1,))) Note that this error is also thrown if you partially defined a Module inside setup:: class Foo(nn.Module): def setup(self): self.conv = functools.partial(nn.Conv, features=3) def __call__(self, x): x = self.conv(kernel_size=4)(x) return x Foo().init(random.key(0), jnp.zeros((1,))) In this case, ``self.conv(kernel_size=4)`` is called from ``__call__``, which is disallowed because it's neither within ``setup`` nor a method wrapped in x``nn.compact``. """ def __init__(self, cls): super().__init__( f'Submodule {cls} must be defined in `setup()` or in a ' 'method wrapped in `@compact`' ) class SetAttributeInModuleSetupError(FlaxError): """You are not allowed to modify Module class attributes in :meth:`Module.setup() `:: class Foo(nn.Module): features: int = 6 def setup(self): self.features = 3 # <-- ERROR def __call__(self, x): return nn.Dense(self.features)(x) variables = SomeModule().init(random.key(0), jnp.ones((1, ))) Instead, these attributes should be set when initializing the Module:: class Foo(nn.Module): features: int = 6 @nn.compact def __call__(self, x): return nn.Dense(self.features)(x) variables = SomeModule(features=3).init(random.key(0), jnp.ones((1, ))) TODO(marcvanzee): Link to a design note explaining why it's necessary for modules to stay frozen (otherwise we can't safely clone them, which we use for lifted transformations). """ def __init__(self): super().__init__(f'Module construction attributes are frozen.') class SetAttributeFrozenModuleError(FlaxError): """You can only assign Module attributes to ``self`` inside :meth:`Module.setup() `. Outside of that method, the Module instance is frozen (i.e., immutable). This behavior is similar to frozen Python dataclasses. For instance, this error is raised in the following case:: class SomeModule(nn.Module): @nn.compact def __call__(self, x, num_features=10): self.num_features = num_features # <-- ERROR! x = nn.Dense(self.num_features)(x) return x s = SomeModule().init(random.key(0), jnp.ones((5, 5))) Similarly, the error is raised when trying to modify a submodule's attributes after constructing it, even if this is done in the ``setup()`` method of the parent module:: class Foo(nn.Module): def setup(self): self.dense = nn.Dense(features=10) self.dense.features = 20 # <--- This is not allowed def __call__(self, x): return self.dense(x) """ def __init__(self, module_cls, attr_name, attr_val): super().__init__( f"Can't set {attr_name}={attr_val} for Module of type " f'{module_cls}: Module instance is frozen outside of ' 'setup method.' ) class MultipleMethodsCompactError(FlaxError): """The ``@compact`` decorator may only be added to at most one method in a Flax module. In order to resolve this, you can: * remove ``@compact`` and define submodules and variables using :meth:`Module.setup() `. * Use two separate modules that both have a unique ``@compact`` method. TODO(marcvanzee): Link to a design note explaining the motivation behind this. There is no need for an equivalent to ``hk.transparent`` and it makes submodules much more sane because there is no need to prefix the method names. """ def __init__(self): super().__init__(f'Only one method per class can be @compact') class ReservedModuleAttributeError(FlaxError): """This error is thrown when creating a Module that is using reserved attributes. The following attributes are reserved: * ``parent``: The parent Module of this Module. * ``name``: The name of this Module. """ def __init__(self, annotations): super().__init__( f'properties `parent` and `name` are reserved: {annotations}' ) class ApplyModuleInvalidMethodError(FlaxError): """When calling :meth:`Module.apply() `, you can specify the method to apply using parameter ``method``. This error is thrown if the provided parameter is not a method in the Module and not a function with at least one argument. Learn more on the reference docs for :meth:`Module.apply() `. """ def __init__(self, method): super().__init__( f'Cannot call apply(): {method} is not a valid function for apply().' ) class CallCompactUnboundModuleError(FlaxError): """This error occurs when you are trying to call a Module directly, rather than through :meth:`Module.apply() `. For instance, the error will be raised when trying to run this code:: from flax import linen as nn import jax.numpy as jnp test_dense = nn.Dense(10) test_dense(jnp.ones((5,5))) Instead, you should pass the variables (parameters and other state) via :meth:`Module.apply() ` (or use :meth:`Module.init() ` to get initial variables):: from jax import random variables = test_dense.init(random.key(0), jnp.ones((5,5))) y = test_dense.apply(variables, jnp.ones((5,5))) """ def __init__(self): super().__init__("Can't call compact methods on unbound modules") class CallSetupUnboundModuleError(FlaxError): """This error occurs when you are trying to call ``.setup()`` directly. For instance, the error will be raised when trying to run this code:: from flax import linen as nn import jax.numpy as jnp class MyModule(nn.Module): def setup(self): self.submodule = MySubModule() module = MyModule() module.setup() # <-- ERROR! submodule = module.submodule In general you shouldn't call ``.setup()`` yourself, if you need to get access to a field or submodule defined inside ``setup`` you can instead create a function to extract it and pass it to ``nn.apply``:: # setup() will be called automatically by ``nn.apply`` def get_submodule(module): return module.submodule.clone() # avoid leaking the Scope empty_variables = {} # you can also use the real variables submodule = nn.apply(get_submodule, module)(empty_variables) """ def __init__(self): super().__init__("Can't call compact methods on unbound modules") class CallUnbindOnUnboundModuleError(FlaxError): """This error occurs when you are trying to call ``.unbind()`` on an unbound Module. For instance, when you try running the following example, an error will be raised:: from flax import linen as nn class MyModule(nn.Module): @nn.compact def __call__(self, x): return nn.Dense(features=10)(x) module = MyModule() module.unbind() # <-- ERROR! Instead, you should ``bind`` the Module to a variable collection before calling ``.unbind()``:: bound_module = module.bind(variables) ... # do something with bound_module module = bound_module.unbind() # <-- OK! """ def __init__(self): super().__init__("Can't call `unbind()` on unbound modules") class CallShareScopeOnUnboundModuleError(FlaxError): """This error occurs when you are trying to call ``nn.share_scope`` on an unbound Module. For instance, when you try to use ``nn.share_scope`` at the top-level:: from flax import linen as nn class CustomDense(nn.Dense): def __call__(self, x): return super().__call__(x) + 1 custom_dense = CustomDense(5) dense = nn.Dense(5) # has the parameters nn.share_scope(custom_dense, dense) # <-- ERROR! """ def __init__(self): super().__init__("Can't call `share_scope` on unbound modules") class InvalidInstanceModuleError(FlaxError): """This error occurs when you are trying to call ``.init()``, ``.init_with_output()``, ``.apply()`` or ``.bind()`` on the Module class itself, instead of an instance of the Module class. For example, the error will be raised when trying to run this code:: class B(nn.Module): @nn.compact def __call__(self, x): return x k = random.key(0) x = random.uniform(random.key(1), (2,)) B.init(k, x) # B is module class, not B() a module instance B.apply(vs, x) # similar issue with apply called on class instead of instance. """ def __init__(self): super().__init__( 'Can only call init, init_with_output or apply methods on an instance' ' of the Module class, not the Module class itself' ) class IncorrectPostInitOverrideError(FlaxError): """This error occurs when you overrode ``.__post_init__()`` without calling ``super().__post_init__()``. For example, the error will be raised when trying to run this code:: from flax import linen as nn import jax.numpy as jnp import jax class A(nn.Module): x: float def __post_init__(self): self.x_square = self.x ** 2 # super().__post_init__() <-- forgot to add this line @nn.compact def __call__(self, input): return input + 3 r = A(x=3) r.init(jax.random.key(2), jnp.ones(3)) """ def __init__(self): super().__init__( 'Overrode `.__post_init__()` without calling `super().__post_init__()`' ) class DescriptorAttributeError(FlaxError): """This error occurs when you are trying to access a property that is accessing a non-existent attribute. For example, the error will be raised when trying to run this code:: class Foo(nn.Module): @property def prop(self): return self.non_existent_field # ERROR! def __call__(self, x): return self.prop foo = Foo() variables = foo.init(jax.random.key(0), jnp.ones(shape=(1, 8))) """ def __init__(self): super().__init__( 'Trying to access a property that is accessing a non-existent' ' attribute.' ) class InvalidCheckpointError(FlaxError): """A checkpoint cannot be stored in a directory that already has a checkpoint at the current or a later step. You can pass ``overwrite=True`` to disable this behavior and overwrite existing checkpoints in the target directory. """ def __init__(self, path, step): super().__init__( f'Trying to save an outdated checkpoint at step: "{step}" and path:' f' "{path}".' ) class MPACheckpointingRequiredError(FlaxError): """To optimally save and restore a multiprocess array (GDA or jax Array outputted from pjit), use GlobalAsyncCheckpointManager. You can create an GlobalAsyncCheckpointManager at top-level and pass it as argument:: from jax.experimental.gda_serialization import serialization as gdas gda_manager = gdas.GlobalAsyncCheckpointManager() save_checkpoint(..., gda_manager=gda_manager) """ def __init__(self, path, step): super().__init__( f'Checkpoint failed at step: "{step}" and path: "{path}": Target ' 'contains a multiprocess array should be saved/restored with a ' 'GlobalAsyncCheckpointManager.' ) class MPARestoreTargetRequiredError(FlaxError): """Provide a valid target when restoring a checkpoint with a multiprocess array. Multiprocess arrays need a sharding (global meshes and partition specs) to be initialized. Therefore, to restore a checkpoint that contains a multiprocess array, make sure the ``target`` you passed contains valid multiprocess arrays at the corresponding tree structure location. If you cannot provide a full valid ``target``, consider ``allow_partial_mpa_restoration=True``. """ def __init__(self, path, step, key=None): error_msg = ( f'Restore checkpoint failed at step: "{step}" and path: "{path}": ' 'Checkpoints containing a multiprocess array need to be restored with ' 'a target with pre-created arrays. If you cannot provide a full valid ' 'target, consider ``allow_partial_mpa_restoration=True``. ' ) if key: error_msg += f'This error fired when trying to restore array at {key}.' super().__init__(error_msg) class MPARestoreDataCorruptedError(FlaxError): """A multiprocess array stored in Google Cloud Storage doesn't contain a "commit_success.txt" file, which should be written at the end of the save. Failure of finding it could indicate a corruption of your saved GDA data. """ def __init__(self, step, path): super().__init__( f'Restore checkpoint failed at step: "{step}" on multiprocess array at ' f' "{path}": No "commit_success.txt" found on this "_gda" directory. ' 'Was its save halted before completion?' ) ################################################# # transforms.py errors # ################################################# class TransformedMethodReturnValueError(FlaxError): """Transformed Module methods cannot return other Modules or Variables.""" def __init__(self, name): super().__init__( f'Transformed module method {name} cannot return Modules or Variables.' ) class TransformTargetError(FlaxError): """Linen transformations must be applied to Modules classes or functions taking a Module instance as the first argument. This error occurs when passing an invalid target to a linen transform (nn.vmap, nn.scan, etc.). This occurs for example when trying to transform a Module instance:: nn.vmap(nn.Dense(features))(x) # raises TransformTargetError You can transform the ``nn.Dense`` class directly instead:: nn.vmap(nn.Dense)(features)(x) Or you can create a function that takes the module instance as the first argument:: class BatchDense(nn.Module): @nn.compact def __call__(self, x): return nn.vmap( lambda mdl, x: mdl(x), variable_axes={'params': 0}, split_rngs={'params': True})(nn.Dense(3), x) """ def __init__(self, target): super().__init__( 'Linen transformations must be applied to Modules classes or' ' functions taking a Module instance as the first argument.' f' The provided target is not a Module class or callable: {target}' ) ################################################# # io.py errors # ################################################# class AlreadyExistsError(FlaxError): """Attempting to overwrite a file via copy. You can pass ``overwrite=True`` to disable this behavior and overwrite existing files in. """ def __init__(self, path): super().__init__(f'Trying overwrite an existing file: "{path}".') ################################################# # cursor.py errors # ################################################# class CursorFindError(FlaxError): """Error when calling :meth:`Cursor.find() `. This error occurs if no object or more than one object is found, given the conditions of the ``cond_fn``. """ def __init__(self, cursor=None, cursor2=None): if cursor and cursor2: super().__init__( 'More than one object found given the conditions of the cond_fn. ' 'The first two objects found have the following paths: ' f'{cursor._path} and {cursor2._path}' ) else: super().__init__('No object found given the conditions of the cond_fn.') class TraverseTreeError(FlaxError): """Error when calling ``Cursor._traverse_tree()``. This function has two modes: - if ``update_fn`` is not None, it will traverse the tree and return a generator of tuples containing the path where the ``update_fn`` was applied and the newly modified value. - if ``cond_fn`` is not None, it will traverse the tree and return a generator of tuple paths that fulfilled the conditions of the ``cond_fn``. This error occurs if either both ``update_fn`` and ``cond_fn`` are None, or both are not None. """ def __init__(self, update_fn, cond_fn): if update_fn is None and cond_fn is None: super().__init__( 'Both update_fn and cond_fn are None. Exactly one of them must be' ' None.' ) else: super().__init__( 'Both update_fn and cond_fn are not None. Exactly one of them must be' ' not None.' ) ================================================ FILE: flax/experimental/__init__.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: flax/experimental/nnx.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from absl import logging from flax.nnx import * logging.warning( "Using 'flax.experimental.nnx' is deprecated. Please use 'flax.nnx' instead." ) ================================================ FILE: flax/ids.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """UUIDs for Flax internals.""" import threading class UUIDManager: """Globally unique counter-based id manager. We need globally unique key ids for Module and Variable object instances to preserve and recreate sharing-by-reference relationship when lifting transforms and adopting outside Modules. - Use of id() is unacceptable because these identifiers are literally pointers which can be recycled, so we rely on a globally unique counter id instead. - We need to handle copy/deepcopy uniqueness via a wrapped type. """ def __init__(self): self._lock = threading.Lock() self._id = 0 def __call__(self): with self._lock: self._id += 1 return FlaxId(self._id) uuid = UUIDManager() class FlaxId: """Hashable wrapper for ids that handles uniqueness of copies.""" def __init__(self, rawid): self.id = rawid def __eq__(self, other): return isinstance(other, FlaxId) and other.id == self.id def __hash__(self): return hash(self.id) def __repr__(self): return f'FlaxId({self.id})' def __deepcopy__(self, memo): del memo return uuid() def __copy__(self): return uuid() ================================================ FILE: flax/io.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """IO Abstraction Layer. The sole purpose of this abstraction layer is to avoid requiring tensorflow as an open-source dependency solely for its tensorflow.io.gfile functions. """ import contextlib import glob as glob_module import importlib import os import shutil from enum import Enum from absl import logging from . import errors # Global Modes and selective import of tensorflow.io gfile. class BackendMode(Enum): DEFAULT = 0 TF = 1 io_mode = None gfile = None if importlib.util.find_spec('tensorflow'): from tensorflow.io import gfile # type: ignore io_mode = BackendMode.TF else: logging.warning( 'Tensorflow library not found, tensorflow.io.gfile ' 'operations will use native shim calls. ' "GCS paths (i.e. 'gs://...') cannot be accessed." ) io_mode = BackendMode.DEFAULT # Constants and Exceptions if io_mode == BackendMode.TF: from tensorflow import errors as tf_errors # type: ignore NotFoundError = tf_errors.NotFoundError else: NotFoundError = FileNotFoundError # Overrides for testing. @contextlib.contextmanager def override_mode(override: BackendMode): # pylint: disable=g-doc-return-or-yield """Returns a context manager that changes backend IO mode. Args: override: BackendMode enum value to set IO mode inside context. """ # pylint: enable=g-doc-return-or-yield global io_mode io_mode_prev = io_mode io_mode = override try: yield finally: io_mode = io_mode_prev def set_mode(override: BackendMode): """Sets global io mode. Args: override: BackendMode enum value to set for IO mode. """ global io_mode io_mode = override # tensorflow.io.gfile API shim functions. def GFile(name, mode): # pylint: disable=invalid-name if io_mode == BackendMode.DEFAULT: if 'b' in mode: return open(name, mode) # pylint: disable=unspecified-encoding else: return open(name, mode, encoding='utf-8') elif io_mode == BackendMode.TF: return gfile.GFile(name, mode) else: raise ValueError('Unknown IO Backend Mode.') def listdir(path): if io_mode == BackendMode.DEFAULT: return os.listdir(path=path) elif io_mode == BackendMode.TF: return gfile.listdir(path=path) else: raise ValueError('Unknown IO Backend Mode.') def isdir(path): if io_mode == BackendMode.DEFAULT: return os.path.isdir(path) elif io_mode == BackendMode.TF: return gfile.isdir(path) else: raise ValueError('Unknown IO Backend Mode.') def copy(src, dst, overwrite=False): if io_mode == BackendMode.DEFAULT: if os.path.exists(dst) and not overwrite: raise errors.AlreadyExistsError(dst) shutil.copy(src, dst) return elif io_mode == BackendMode.TF: return gfile.copy(src, dst, overwrite=overwrite) else: raise ValueError('Unknown IO Backend Mode.') def rename(src, dst, overwrite=False): if io_mode == BackendMode.DEFAULT: if os.path.exists(dst) and not overwrite: raise errors.AlreadyExistsError(dst) return os.rename(src, dst) elif io_mode == BackendMode.TF: return gfile.rename(src, dst, overwrite=overwrite) else: raise ValueError('Unknown IO Backend Mode.') def exists(path): if io_mode == BackendMode.DEFAULT: return os.path.exists(path) elif io_mode == BackendMode.TF: return gfile.exists(path) else: raise ValueError('Unknown IO Backend Mode.') def makedirs(path): if io_mode == BackendMode.DEFAULT: return os.makedirs(path, exist_ok=True) elif io_mode == BackendMode.TF: return gfile.makedirs(path) else: raise ValueError('Unknown IO Backend Mode.') def glob(pattern): if io_mode == BackendMode.DEFAULT: return [ path.rstrip('/') for path in glob_module.glob(pattern, recursive=False) ] elif io_mode == BackendMode.TF: return gfile.glob(pattern) else: raise ValueError('Unknown IO Backend Mode.') def remove(path): """Remove the file at path. Might fail if used on a directory path.""" if io_mode == BackendMode.DEFAULT: return os.remove(path) elif io_mode == BackendMode.TF: return gfile.remove(path) else: raise ValueError('Unknown IO Backend Mode.') def rmtree(path): """Remove a directory and recursively all contents inside. Might fail if used on a file path.""" if io_mode == BackendMode.DEFAULT: return shutil.rmtree(path) elif io_mode == BackendMode.TF: return gfile.rmtree(path) else: raise ValueError('Unknown IO Backend Mode.') def getsize(path): """Return the size, in bytes, of path.""" if io_mode == BackendMode.DEFAULT: return os.path.getsize(path) elif io_mode == BackendMode.TF: return gfile.stat(path).length else: raise ValueError('Unknown IO Backend Mode.') ================================================ FILE: flax/jax_utils.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities we could consider upstreaming to Jax.""" import collections import itertools import warnings from collections.abc import Iterable # pylint: disable=g-importing-member import jax import jax.numpy as jnp import numpy as np from jax import core, lax from jax.extend import linear_util as lu from jax.interpreters import partial_eval as pe def _pmap_device_order(): return jax.local_devices() def replicate(tree, devices=None): """Replicates arrays to multiple devices. Args: tree: a pytree containing the arrays that should be replicated. devices: the devices the data is replicated to (default: same order as expected by ``jax.pmap()``). Returns: A new pytree containing the replicated arrays. """ devices = devices or _pmap_device_order() return jax.device_put_replicated(tree, devices) def unreplicate(tree): """Returns a single instance of a replicated array.""" def _unreplicate_one(x): # Avoid degraded performance under the new jax.pmap. # Handle 0-dimensional (scalar) arrays - cannot index into them if hasattr(x, 'ndim') and x.ndim == 0: return x if ( not hasattr(x, 'sharding') or isinstance(x.sharding, jax.sharding.SingleDeviceSharding) or len(jax.local_devices()) == 1 ): return x[0] if x.sharding.is_fully_replicated: return x.addressable_shards[0].data return x.addressable_shards[0].data.squeeze(0) return jax.tree_util.tree_map(_unreplicate_one, tree) def pmean(xs, axis_name): warnings.warn('use jax.lax.pmean instead', DeprecationWarning) return lax.pmean(xs, axis_name) def partial_eval_by_shape(fn, input_spec, *args, **kwargs): """Lazily evaluate a function by using the shapes of the inputs. This function is similar to ``jax.eval_shape`` with the key difference that function outputs that can be computed without a concrete value of the inputs are returned as is instead of only the shape. See for example ``module.init_by_shape`` where this functionality is used to initialize a model without using input data lr computation. Args: fn: the function to be lazily evaluated. input_spec: an iterable of shapes or (shape, dtype) tuples specifying the shape and type of the inputs. If unspecified the dtype is float32. *args: other arguments passed to the module's apply function **kwargs: keyword arguments passed to the module's apply function Returns: A pair consisting of the model output and an instance of Model """ # output cannot be returned in lazy_create because jax.eval_shape will only # return the shape and dtype. # TODO(mattjj,jheek): use a public JAX API f = lambda *inputs: fn(*inputs, *args, **kwargs) input_structs = [_parse_spec(spec) for spec in input_spec] inputs_flat, in_tree = jax.tree_util.tree_flatten(input_structs) debug_info = jax.api_util.debug_info("flax partial_eval_by_shape", f, (in_tree,), {}) f_flat, out_tree = jax.api_util.flatten_fun_nokwargs( lu.wrap_init(f, debug_info=debug_info), in_tree) in_pvals = [ pe.PartialVal.unknown(core.ShapedArray(x.shape, x.dtype)) for x in inputs_flat ] _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals) out_flat = [ const if pv is None else jax.ShapeDtypeStruct(pv.shape, pv.dtype) for pv, const in out_pvals ] return jax.tree_util.tree_unflatten(out_tree(), out_flat) def _parse_spec(spec): """Parse an input spec of the form (shape, dtype) or shape into a jax.ShapeDtypeStruct.""" spec = tuple(spec) if len(spec) == 2 and isinstance(spec[0], Iterable): return jax.ShapeDtypeStruct(tuple(spec[0]), spec[1]) else: return jax.ShapeDtypeStruct(spec, jnp.float32) def prefetch_to_device(iterator, size, devices=None): """Shard and prefetch batches on device. This utility takes an iterator and returns a new iterator which fills an on device prefetch buffer. Eager prefetching can improve the performance of training loops significantly by overlapping compute and data transfer. This utility is mostly useful for GPUs, for TPUs and CPUs it should not be necessary -- the TPU & CPU memory allocators (normally) don't pick a memory location that isn't free yet so they don't block. Instead those allocators OOM. Args: iterator: an iterator that yields a pytree of ndarrays where the first dimension is sharded across devices. size: the size of the prefetch buffer. If you're training on GPUs, 2 is generally the best choice because this guarantees that you can overlap a training step on GPU with a data prefetch step on CPU. devices: the list of devices to which the arrays should be prefetched. Defaults to the order of devices expected by ``jax.pmap``. Yields: The original items from the iterator where each ndarray is now sharded to the specified devices. """ queue = collections.deque() devices = _pmap_device_order() if devices is None else devices def _prefetch(xs): return jax.device_put_sharded(list(xs), devices) def enqueue(n): # Enqueues *up to* `n` elements from the iterator. for data in itertools.islice(iterator, n): queue.append(jax.tree_util.tree_map(_prefetch, data)) enqueue(size) # Fill up the buffer. while queue: yield queue.popleft() enqueue(1) def _scan_nd(body_fn, init, xs, n=1, unroll=(1,)): """Utility for performing an n-dimensional `lax.scan`. The n-d scan is simply recursive call of 1-d scan. Args: body_fn: the body of the loop of type (c, x) -> (c, y). init: initial value for the carry. xs: a pytree of tensors to scan over. n: number of dimensions to scan over (default: 1) Returns: A tuple of the final carry and the values returned by the body. """ if n == 1: return lax.scan(body_fn, init, xs, unroll=unroll[0]) else: def scan_body(c, x): return _scan_nd(body_fn, c, x, n=n - 1, unroll=unroll[1:]) return lax.scan(scan_body, init, xs, unroll=unroll[0]) def _invert_perm(perm): perm_inv = [0] * len(perm) for i, j in enumerate(perm): perm_inv[j] = i return tuple(perm_inv) def scan_in_dim(body_fn, init, xs, axis=(0,), unroll=(1,), keepdims=False): """utility for doing a scan along arbitrary dimensions. See `lax.scan` for details on how the scan operation works. Note on `unroll`: This argument gets left padded with ones to match the size of `axis`. Doing so allows unrolls to performed from the innermost loop first. For example, `scan_in_dim(..., axis=(1, 2, 3), unroll=5)` is equivalent to `scan_in_dim(..., axis=(1, 2, 3), unroll=(1, 1, 5))`. Args: body_fn: the body of the loop of type (c, x) -> (c, y). init: initial value for the carry. xs: a pytree of tensors to scan over. axis: the axis to scan over. keepdims: keep the dimensions that are scanned over. unroll: an optional positive integer, or tuple of positive integers showing how many iterations of the loop to be unrolled into a single iteration for each axis. Returns: A tuple of the final carry and the values returned by the body. """ if not isinstance(axis, Iterable): axis = (axis,) if not isinstance(unroll, Iterable): unroll = (unroll,) # Pad unroll with ones so we start unrolling from the innermost loop len_diff = len(axis) - len(unroll) unroll = (1,) * len_diff + unroll def transpose_in(x): perm = axis + tuple(np.delete(np.arange(x.ndim), axis)) return x.transpose(perm) def transpose_out(x): perm = axis + tuple(np.delete(np.arange(x.ndim), axis)) return x.transpose(_invert_perm(perm)) def body_wrapper(c, xs): if keepdims: xs = jax.tree_util.tree_map( lambda x: x.reshape((1,) * len(axis) + x.shape), xs ) xs = jax.tree_util.tree_map(transpose_out, xs) c, ys = body_fn(c, xs) if keepdims: ys = jax.tree_util.tree_map(transpose_in, ys) ys = jax.tree_util.tree_map(lambda x: x.reshape(x.shape[len(axis) :]), ys) return c, ys xs = jax.tree_util.tree_map(transpose_in, xs) c, ys = _scan_nd(body_wrapper, init, xs, n=len(axis), unroll=unroll) ys = jax.tree_util.tree_map(transpose_out, ys) return c, ys # Copied from https://github.com/google-research/big_vision def pad_shard_unpad( wrapped, static_argnums=(0,), static_argnames=(), static_return=False ): """Wraps a function with code that pads, shards, then un-shards, un-pads. Args: wrapped: the function to be wrapped. Signature is ``params, *args, *kwargs``. static_argnums: indices of arguments to ``wrapped`` that should _not_ be padded and sharded, but instead be forwarded as-is. The default is (0,) because by far the most common use-case is to pass ``params`` first. static_argnames: names of kwargs to ``wrapped`` that should _not_ be padded and sharded, but instead be forwarded as-is. static_return: whether not to un-shard, and un-pad the return value; static return values are typically used with eval steps that compute metrics Returns: A new function that pads and shards its arguments before passing them to the wrapped function, and un-shards and un-pads the returned pytree. This is useful for calling a pmap'ed function with inputs that aren't divisible by the number of devices. A typical use is: @pad_shard_unpad @jax.pmap def forward(params, x): ... Notes: The padding is done in host-memory before being passed to the function, and the values returned by the function are transferred back to host memory. The returned function is augmented with a new keyword-only argument ``min_device_batch`` that, if specified, forces padding inputs to at least this size per device. This can be useful to avoid recompiles for the last batch and reduce memory fragmentation. For more information refer to https://flax.readthedocs.io/en/latest/guides/data_preprocessing/full_eval.html """ def pad_shard_unpad_wrapper(*args, min_device_batch=None, **kw): d = jax.local_device_count() # d = devices, b = batch batch_sizes = set() for i, a in enumerate(args): if i not in static_argnums: batch_sizes |= {t.shape[0] for t in jax.tree_util.tree_leaves(a)} for k, v in kw.items(): if k not in static_argnames: batch_sizes |= {t.shape[0] for t in jax.tree_util.tree_leaves(v)} assert len(batch_sizes) == 1, f'Inconsistent batch-sizes: {batch_sizes}' b = batch_sizes.pop() def pad(x): _, *shape = x.shape db, rest = divmod(b, d) if rest: x = np.concatenate([x, np.zeros((d - rest, *shape), x.dtype)], axis=0) db += 1 if min_device_batch and db < min_device_batch: x = np.concatenate( [x, np.zeros((d * (min_device_batch - db), *shape), x.dtype)] ) db = min_device_batch return x.reshape(d, db, *shape) def maybe_pad(tree, actually_pad=True): if not actually_pad: return tree # For call-site convenience below. return jax.tree_util.tree_map(pad, tree) args = [maybe_pad(a, i not in static_argnums) for i, a in enumerate(args)] kw = {k: maybe_pad(v, k not in static_argnames) for k, v in kw.items()} out = wrapped(*args, **kw) def unpad(x): # Transfer back before cutting, to reduce on-device shape diversity. return jax.device_get(x).reshape([np.prod(x.shape[:2]), *x.shape[2:]])[:b] return out if static_return else jax.tree_util.tree_map(unpad, out) return pad_shard_unpad_wrapper ================================================ FILE: flax/linen/README.md ================================================ # Linen: A comfortable evolution of Flax Linen is a neural network API developed based on learning from our users and the broader JAX community. Linen improves on much of the former `flax.nn` API (removed since v0.4.0), such as submodule sharing and better support for non-trainable variables. Moreover, Linen builds on a "functional core", enabling direct usage of JAX transformations such as `vmap`, `remat` or `scan` inside your modules. In Linen, Modules behave much closer to vanilla Python objects, while still letting you opt-in to the concise single-method pattern many of our users love. The Linen Module API is stable and currently recommended for new projects. We are already supporting users in the OSS community and within Google. Minor changes may come to the top-level `apply` and `init` patterns, which we will communicate clearly. We plan a few improvements, including writing up short design notes, adding more design tests (see last link below), and an API for interactive module instances. Please open a [discussion](https://github.com/google/flax/discussions) if you have any questions or thoughts. **See the [Linen API reference docs](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/index.html)**, or take a look at our additional material: * 2-page intro to the [Linen Design Principles](https://docs.google.com/document/d/1ZlL_4bXCw5Xl0WstQw1GpnZqfb9JFOeUGAPcBVk-kn8) * [Slides from a talk to the JAX core team](https://docs.google.com/presentation/d/1ngKWUwsSqAwPRvATG8sAxMzu9ujv4N__cKsUofdNno0) * [Brief Intro to Linen](https://colab.research.google.com/github/google/flax/blob/main/docs/linen_intro.ipynb) in Colab * An [upgrade guide](https://docs.google.com/document/d/1hYavTVPaKVVe9Be8pCB7yW7r6dDv3RALVNit8NZca4c) + some additional questions we're considering * Ported [examples](https://github.com/google/flax/tree/main/examples) * "Design tests" used to ensure that our "functional core" supports [various advanced use-cases](https://github.com/google/flax/tree/main/tests/core/design), and that the mostly-syntactic-sugar Module abstraction [doesn't get in the way](https://github.com/google/flax/tree/main/examples/linen_design_test) ================================================ FILE: flax/linen/__init__.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """The Flax Module system.""" # pylint: disable=g-multiple-import,useless-import-alias # re-export commonly used modules and functions from flax.core import ( DenyList as DenyList, FrozenDict as FrozenDict, broadcast as broadcast, meta as meta, ) from flax.core.meta import ( PARTITION_NAME as PARTITION_NAME, Partitioned as Partitioned, get_partition_spec as get_partition_spec, get_sharding as get_sharding, unbox as unbox, with_partitioning as with_partitioning, ) from flax.core.spmd import ( get_logical_axis_rules as get_logical_axis_rules, logical_axis_rules as logical_axis_rules, set_logical_axis_rules as set_logical_axis_rules, ) from .activation import ( PReLU as PReLU, celu as celu, elu as elu, gelu as gelu, glu as glu, hard_sigmoid as hard_sigmoid, hard_silu as hard_silu, hard_swish as hard_swish, hard_tanh as hard_tanh, leaky_relu as leaky_relu, log_sigmoid as log_sigmoid, log_softmax as log_softmax, logsumexp as logsumexp, normalize as normalize, one_hot as one_hot, relu6 as relu6, relu as relu, selu as selu, sigmoid as sigmoid, silu as silu, soft_sign as soft_sign, softmax as softmax, softplus as softplus, standardize as standardize, swish as swish, tanh as tanh, ) from .attention import ( MultiHeadAttention as MultiHeadAttention, MultiHeadDotProductAttention as MultiHeadDotProductAttention, SelfAttention as SelfAttention, combine_masks as combine_masks, dot_product_attention_weights as dot_product_attention_weights, dot_product_attention as dot_product_attention, make_attention_mask as make_attention_mask, make_causal_mask as make_causal_mask, ) from .batch_apply import BatchApply as BatchApply from .combinators import Sequential as Sequential from .fp8_ops import ( Fp8DirectDotGeneralOp as Fp8DirectDotGeneralOp, Fp8DotGeneral as Fp8DotGeneral, Fp8DotGeneralOp as Fp8DotGeneralOp, Fp8Einsum as Fp8Einsum, NANOOFp8DotGeneralOp as NANOOFp8DotGeneralOp, ) from .initializers import ( ones_init as ones_init, ones as ones, zeros_init as zeros_init, zeros as zeros, ) from .linear import ( ConvLocal as ConvLocal, ConvTranspose as ConvTranspose, Conv as Conv, DenseGeneral as DenseGeneral, Dense as Dense, Einsum as Einsum, Embed as Embed, ) from .module import ( Module as Module, Variable as Variable, apply as apply, compact_name_scope as compact_name_scope, compact as compact, disable_named_call as disable_named_call, enable_named_call as enable_named_call, init_with_output as init_with_output, init as init, intercept_methods as intercept_methods, merge_param as merge_param, nowrap as nowrap, override_named_call as override_named_call, share_scope as share_scope, ) from .normalization import ( BatchNorm as BatchNorm, GroupNorm as GroupNorm, InstanceNorm as InstanceNorm, LayerNorm as LayerNorm, RMSNorm as RMSNorm, SpectralNorm as SpectralNorm, WeightNorm as WeightNorm, ) from .pooling import (avg_pool as avg_pool, max_pool as max_pool, pool as pool) from .recurrent import ( Bidirectional as Bidirectional, ConvLSTMCell as ConvLSTMCell, GRUCell as GRUCell, LSTMCell as LSTMCell, MGUCell as MGUCell, OptimizedLSTMCell as OptimizedLSTMCell, RNNCellBase as RNNCellBase, RNN as RNN, SimpleCell as SimpleCell, ) from .spmd import ( LogicallyPartitioned as LogicallyPartitioned, logical_to_mesh, logical_to_mesh_axes, logical_to_mesh_sharding, with_logical_constraint, with_logical_partitioning as with_logical_partitioning, ) from .stochastic import Dropout as Dropout from .summary import tabulate from .transforms import ( add_metadata_axis, checkpoint as checkpoint, cond as cond, custom_vjp as custom_vjp, fold_rngs as fold_rngs, grad as grad, jit as jit, jvp as jvp, map_variables as map_variables, named_call as named_call, remat_scan as remat_scan, remat as remat, scan as scan, switch as switch, value_and_grad as value_and_grad, vjp as vjp, vmap as vmap, while_loop as while_loop, ) # pylint: enable=g-multiple-import ================================================ FILE: flax/linen/activation.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Activation functions.""" # pylint: disable=unused-import # re-export activation functions from jax.nn from typing import Any, Optional from flax.linen.module import compact from flax.linen.module import Module from flax.linen.linear import Dense from flax.typing import Array, Dtype from jax.nn import celu from jax.nn import elu from jax.nn import gelu from jax.nn import glu from jax.nn import hard_sigmoid from jax.nn import hard_silu from jax.nn import hard_swish from jax.nn import hard_tanh from jax.nn import leaky_relu from jax.nn import log_sigmoid from jax.nn import log_softmax from jax.nn import logsumexp from jax.nn import one_hot from jax.nn import relu from jax.nn import relu6 from jax.nn import selu from jax.nn import sigmoid from jax.nn import silu from jax.nn import soft_sign from jax.nn import softmax from jax.nn import softplus from jax.nn import standardize from jax.nn import swish import jax.numpy as jnp from jax.numpy import tanh # Normalize is a deprecated alias of standardize normalize = standardize # pylint: enable=unused-import class PReLU(Module): """Parametric Rectified Linear Unit (PReLU) activation function. Note that PReLU is a Flax layer and not a simple activation function, so it needs to be initialized before being called. Example usage:: >>> import flax.linen as nn >>> class MLP(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(2)(x) ... x = nn.PReLU()(x) # initialized ... return x Attributes: param_dtype: the dtype passed to parameter initializers (default: float32). negative_slope_init: the value to initialize the negative slope (default 0.01). """ param_dtype: Dtype = jnp.float32 negative_slope_init: float = 0.01 @compact def __call__(self, inputs: Array) -> Array: """Applies an activation to the inputs. Args: inputs: the nd-array to apply the activation function to. Returns: The transformed input. """ negative_slope = self.param( 'negative_slope', lambda k: jnp.asarray(self.negative_slope_init, self.param_dtype), ) return jnp.where( inputs >= 0, inputs, jnp.asarray(negative_slope, inputs.dtype) * inputs ) ================================================ FILE: flax/linen/attention.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Attention core modules for Flax.""" from __future__ import annotations import functools import inspect import warnings from typing import Any, overload from collections.abc import Callable import jax import jax.numpy as jnp from jax import lax, random from flax.linen import initializers from flax.linen.dtypes import promote_dtype from flax.linen.linear import ( DenseGeneral, default_kernel_init, ) from flax.linen.module import Module, compact, merge_param from flax.linen.normalization import LayerNorm from flax.typing import ( Array, PRNGKey, Dtype, Shape as Shape, Initializer, PrecisionLike, DotGeneralT, ) def dot_product_attention_weights( query: Array, key: Array, bias: Array | None = None, mask: Array | None = None, broadcast_dropout: bool = True, dropout_rng: PRNGKey | None = None, dropout_rate: float = 0.0, deterministic: bool = False, dtype: Dtype | None = None, precision: PrecisionLike = None, module: Module | None = None, force_fp32_for_softmax: bool = False, einsum_dot_general: Callable[..., Array] | None = None, einsum: Callable[..., Array] | None = None, ): """Computes dot-product attention weights given query and key. Used by :func:`dot_product_attention`, which is what you'll most likely use. But if you want access to the attention weights for introspection, then you can directly call this function and call einsum yourself. Args: query: queries for calculating attention with shape of ``[batch..., q_length, num_heads, qk_depth_per_head]``. key: keys for calculating attention with shape of ``[batch..., kv_length, num_heads, qk_depth_per_head]``. bias: bias for the attention weights. This should be broadcastable to the shape ``[batch..., num_heads, q_length, kv_length]``. This can be used for incorporating causal masks, padding masks, proximity bias, etc. mask: mask for the attention weights. This should be broadcastable to the shape ``[batch..., num_heads, q_length, kv_length]``. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is ``False``. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout dropout_rate: dropout rate deterministic: bool, deterministic or not (to apply dropout) dtype: the dtype of the computation (default: infer from inputs and params) precision: numerical precision of the computation see ``jax.lax.Precision`` for details. module: the Module that will sow the attention weights into the 'intermediates' collection. Remember to mark 'intermediates' as mutable via ``mutable=['intermediates']`` in order to have that collection returned. If ``module`` is None, the attention weights will not be sowed. force_fp32_for_softmax: bool, whether to force the softmax to be computed in fp32. This is useful for mixed-precision training where higher precision is desired for numerical stability. einsum_dot_general: the dot_general to use in einsum. einsum: If unspecified, default `jnp.einsum` will be used. This argument is mutually exclusive with `precision` and `einsum_dot_general`. Raises: ValueError: if both `precision`/`einsum_dot_general` and `einsum` are specified. Returns: Output of shape ``[batch..., num_heads, q_length, kv_length]``. """ if (precision or einsum_dot_general) and einsum: raise ValueError( 'precision/einsum_dot_general and einsum are mutually exclusive. Please' ' specify only one of them.' ) if not einsum: einsum = functools.partial( jnp.einsum, precision=precision, _dot_general=einsum_dot_general if einsum_dot_general else jax.lax.dot_general, ) query, key = promote_dtype(query, key, dtype=dtype) dtype = query.dtype assert query.ndim == key.ndim, 'q, k must have same rank.' assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.' assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.' assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' # calculate attention matrix depth = query.shape[-1] query = query / jnp.sqrt(depth).astype(dtype) # attn weight shape is (batch..., num_heads, q_length, kv_length) attn_weights = einsum('...qhd,...khd->...hqk', query, key) # apply attention bias: masking, dropout, proximity bias, etc. if bias is not None: attn_weights = attn_weights + bias # apply attention mask if mask is not None: big_neg = jnp.finfo(dtype).min attn_weights = jnp.where(mask, attn_weights, big_neg) # normalize the attention weights if force_fp32_for_softmax and dtype != jnp.float32: attn_weights = jax.nn.softmax(attn_weights.astype(jnp.float32)) else: attn_weights = jax.nn.softmax(attn_weights).astype(dtype) if module: module.sow('intermediates', 'attention_weights', attn_weights) # apply attention dropout if not deterministic and dropout_rate > 0.0: keep_prob = 1.0 - dropout_rate if broadcast_dropout: # dropout is broadcast across the batch + head dimensions dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:] keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) # type: ignore else: keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) # type: ignore multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype) attn_weights = attn_weights * multiplier return attn_weights def dot_product_attention( query: Array, key: Array, value: Array, bias: Array | None = None, mask: Array | None = None, broadcast_dropout: bool = True, dropout_rng: PRNGKey | None = None, dropout_rate: float = 0.0, deterministic: bool = False, dtype: Dtype | None = None, precision: PrecisionLike = None, module: Module | None = None, force_fp32_for_softmax: bool = False, einsum_dot_general: Callable[..., Array] | None = None, qk_attn_weights_einsum: Callable[..., Array] | None = None, attn_weights_value_einsum: Callable[..., Array] | None = None, ): """Computes dot-product attention given query, key, and value. This is the core function for applying attention based on https://arxiv.org/abs/1706.03762. It calculates the attention weights given query and key and combines the values using the attention weights. .. note:: ``query``, ``key``, ``value`` needn't have any batch dimensions. Args: query: queries for calculating attention with shape of ``[batch..., q_length, num_heads, qk_depth_per_head]``. key: keys for calculating attention with shape of ``[batch..., kv_length, num_heads, qk_depth_per_head]``. value: values to be used in attention with shape of ``[batch..., kv_length, num_heads, v_depth_per_head]``. bias: bias for the attention weights. This should be broadcastable to the shape ``[batch..., num_heads, q_length, kv_length]``. This can be used for incorporating causal masks, padding masks, proximity bias, etc. mask: mask for the attention weights. This should be broadcastable to the shape ``[batch..., num_heads, q_length, kv_length]``. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is ``False``. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout dropout_rate: dropout rate deterministic: bool, deterministic or not (to apply dropout) dtype: the dtype of the computation (default: infer from inputs) precision: numerical precision of the computation see ``jax.lax.Precision` for details. module: the Module that will sow the attention weights into the 'intermediates' collection. Remember to mark 'intermediates' as mutable via ``mutable=['intermediates']`` in order to have that collection returned. If ``module`` is None, the attention weights will not be sowed. force_fp32_for_softmax: bool, whether to force the softmax to be computed in fp32. This is useful for mixed-precision training where higher precision is desired for numerical stability. einsum_dot_general: the dot_general to use in `jnp.einsum`. qk_attn_weights_einsum: the einsum for computing the attention weights. When unspecified, the default `jnp.einsum` will be used. This argument is mutually exclusive with `precision` and `einsum_dot_general`. attn_weights_value_einsum: the einsum for computing the product of the attention weights and the values. When unspecified, the default `jnp.einsum` will be used. This argument is mutually exclusive with `precision` and `einsum_dot_general`. Returns: Output of shape ``[batch..., q_length, num_heads, v_depth_per_head]``. Raises: ValueError: if both `precision`/`einsum_dot_general` and `qk_attn_weights_einsum`/`attn_weights_value_einsum` are specified. """ if (qk_attn_weights_einsum and not attn_weights_value_einsum) or ( not qk_attn_weights_einsum and attn_weights_value_einsum ): raise ValueError( 'qk_attn_weights_einsum and attn_weights_value_einsum must be specified' ' together.' ) if (precision or einsum_dot_general) and ( qk_attn_weights_einsum or attn_weights_value_einsum ): raise ValueError( 'precision/einsum_dot_general and' ' qk_attn_weights_einsum/attn_weights_value_einsum are mutually' ' exclusive. Please specify only one of them.' ) query, key, value = promote_dtype(query, key, value, dtype=dtype) dtype = query.dtype assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' assert ( query.shape[:-3] == key.shape[:-3] == value.shape[:-3] ), 'q, k, v batch dims must match.' assert ( query.shape[-2] == key.shape[-2] == value.shape[-2] ), 'q, k, v num_heads must match.' assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' # compute attention weights attn_weights = dot_product_attention_weights( query, key, bias, mask, broadcast_dropout, dropout_rng, dropout_rate, deterministic, dtype, precision, module, force_fp32_for_softmax, einsum_dot_general=einsum_dot_general, einsum=qk_attn_weights_einsum, ) if not attn_weights_value_einsum: attn_weights_value_einsum = functools.partial( jnp.einsum, precision=precision, _dot_general=einsum_dot_general if einsum_dot_general else jax.lax.dot_general, ) # return weighted sum over values for each query position return attn_weights_value_einsum( '...hqk,...khd->...qhd', attn_weights, value, ) class MultiHeadDotProductAttention(Module): """Multi-head dot-product attention. Example usage:: >>> import flax.linen as nn >>> import jax >>> layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16) >>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6) >>> shape = (4, 3, 2, 5) >>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape) >>> variables = layer.init(jax.random.key(0), q) >>> # different inputs for inputs_q, inputs_k and inputs_v >>> out = layer.apply(variables, q, k, v) >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k) >>> out = layer.apply(variables, q, k) >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q) >>> out = layer.apply(variables, q) >>> attention_kwargs = dict( ... num_heads=8, ... qkv_features=16, ... kernel_init=nn.initializers.ones, ... bias_init=nn.initializers.zeros, ... dropout_rate=0.5, ... deterministic=False, ... ) >>> class Module(nn.Module): ... attention_kwargs: dict ... ... @nn.compact ... def __call__(self, x, dropout_rng=None): ... out1 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng) ... out2 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng) ... return out1, out2 >>> module = Module(attention_kwargs) >>> variables = module.init({'params': key1, 'dropout': key2}, q) >>> # out1 and out2 are different. >>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3}) >>> # out3 and out4 are different. >>> # out1 and out3 are different. out2 and out4 are different. >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4}) >>> # out1 and out2 are the same. >>> out1, out2 = module.apply(variables, q, dropout_rng=key5) >>> # out1 and out2 are the same as out3 and out4. >>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply` >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5) Attributes: num_heads: Number of attention heads. Features (i.e. inputs_q.shape[-1]) should be divisible by the number of heads. dtype: The dtype of the computation (default: infer from inputs and params) param_dtype: The dtype passed to parameter initializers (default: float32) qkv_features: Dimension of the key, query, and value. out_features: Dimension of the last projection broadcast_dropout: Use a broadcasted dropout along batch dims. dropout_rate: Dropout rate. deterministic: If False, the attention weight is masked randomly using dropout, whereas if True, the attention weights are deterministic. precision: Numerical precision of the computation see ``jax.lax.Precision`` for details. kernel_init: Initializer for the kernel of the Dense layers. out_kernel_init: Optional Initializer for the kernel of the output Dense layer, if None, ``kernel_init`` will be used. bias_init: Initializer for the bias of the Dense layers. out_bias_init: Optional Initializer for the bias of the output Dense layer, if None, ``bias_init`` will be used. use_bias: Whether pointwise QKVO dense transforms use bias. attention_fn: dot_product_attention or compatible function. Accepts query, key, value, and returns output of shape ``[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`` decode: Whether to prepare and use an autoregressive cache. normalize_qk: Should QK normalization be applied (arxiv.org/abs/2302.05442). qk_attn_weights_einsum_cls: factory function to create the einsum for computing the attention weights. attn_weights_value_einsum_cls: factory function to create the einsum for computing the product of the attention weights and the values. """ num_heads: int dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 qkv_features: int | None = None out_features: int | None = None broadcast_dropout: bool = True dropout_rate: float = 0.0 deterministic: bool | None = None precision: PrecisionLike = None kernel_init: Initializer = default_kernel_init out_kernel_init: Initializer | None = None bias_init: Initializer = initializers.zeros_init() out_bias_init: Initializer | None = None use_bias: bool = True attention_fn: Callable[..., Array] = dot_product_attention decode: bool = False normalize_qk: bool = False force_fp32_for_softmax: bool = False # Deprecated, will be removed. qkv_dot_general: DotGeneralT | None = None out_dot_general: DotGeneralT | None = None qkv_dot_general_cls: Any = None out_dot_general_cls: Any = None qk_attn_weights_einsum_cls: Callable[..., Callable[..., Array]] | None = None attn_weights_value_einsum_cls: Callable[..., Callable[..., Array]] | None = ( None ) @overload def __call__( self, inputs_q: Array, inputs_k: Array | None = None, inputs_v: Array | None = None, *, mask: Array | None = None, deterministic: bool | None = None, dropout_rng: PRNGKey | None = None, sow_weights: bool = False, ): ... @overload def __call__( self, inputs_q: Array, *, inputs_kv: Array | None = None, mask: Array | None = None, deterministic: bool | None = None, dropout_rng: PRNGKey | None = None, sow_weights: bool = False, ): ... @compact def __call__( self, inputs_q: Array, inputs_k: Array | None = None, inputs_v: Array | None = None, *, inputs_kv: Array | None = None, mask: Array | None = None, deterministic: bool | None = None, dropout_rng: PRNGKey | None = None, sow_weights: bool = False, ): """Applies multi-head dot product attention on the input data. Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention and project the results to an output vector. If both inputs_k and inputs_v are None, they will both copy the value of inputs_q (self attention). If only inputs_v is None, it will copy the value of inputs_k. Args: inputs_q: input queries of shape ``[batch_sizes..., length, features]``. inputs_k: key of shape ``[batch_sizes..., length, features]``. If None, inputs_k will copy the value of inputs_q. inputs_v: values of shape ``[batch_sizes..., length, features]``. If None, inputs_v will copy the value of inputs_k. inputs_kv: key/values of shape ``[batch_sizes..., length, features]``. If None, inputs_kv will copy the value of inputs_q. This arg will be deprecated soon. Use inputs_k and inputs_v instead. mask: attention mask of shape ``[batch_sizes..., num_heads, query_length, key/value_length]``. Attention weights are masked out if their corresponding mask value is ``False``. deterministic: if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic. dropout_rng: optional rng key to pass to the attention layer's dropout mask. Otherwise, self.make_rng('dropout') is used instead. sow_weights: if ``True``, the attention weights are sowed into the 'intermediates' collection. Remember to mark 'intermediates' as mutable via ``mutable=['intermediates']`` in order to have that collection returned. Returns: output of shape ``[batch_sizes..., length, features]``. """ if inputs_kv is not None: if inputs_k is not None or inputs_v is not None: raise ValueError( 'If either `inputs_k` or `inputs_v` is not None, ' '`inputs_kv` must be None. If `inputs_kv` is not None, both `inputs_k` ' 'and `inputs_v` must be None. We recommend using `inputs_k` and ' '`inputs_v` args, since `inputs_kv` will be deprecated soon. See ' 'https://github.com/google/flax/discussions/3389 for more ' 'information.' ) inputs_k = inputs_v = inputs_kv warnings.warn( 'The inputs_kv arg will be deprecated soon. ' 'Use inputs_k and inputs_v instead. See ' 'https://github.com/google/flax/discussions/3389 ' 'for more information.', DeprecationWarning, ) else: if inputs_k is None: if inputs_v is not None: raise ValueError( '`inputs_k` cannot be None if `inputs_v` is not None. ' 'To have both `inputs_k` and `inputs_v` be the same value, pass in the ' 'value to `inputs_k` and leave `inputs_v` as None.' ) inputs_k = inputs_q if inputs_v is None: inputs_v = inputs_k elif inputs_v.shape[-1] == inputs_v.shape[-2]: warnings.warn( f'You are passing an array of shape {inputs_v.shape} ' 'to the `inputs_v` arg, when you may have intended ' 'to pass it to the `mask` arg. As of Flax version ' '0.7.4, the function signature of ' "MultiHeadDotProductAttention's `__call__` method " 'has changed to `__call__(inputs_q, inputs_k=None, ' 'inputs_v=None, *, inputs_kv=None, mask=None, ' 'deterministic=None)`. Use the kwarg `mask` instead. ' 'See https://github.com/google/flax/discussions/3389 ' 'and read the docstring for more information.', DeprecationWarning, ) features = self.out_features or inputs_q.shape[-1] qkv_features = self.qkv_features or inputs_q.shape[-1] assert qkv_features % self.num_heads == 0, ( f'Memory dimension ({qkv_features}) must be divisible by number of' f' heads ({self.num_heads}).' ) head_dim = qkv_features // self.num_heads dense = functools.partial( DenseGeneral, axis=-1, dtype=self.dtype, param_dtype=self.param_dtype, features=(self.num_heads, head_dim), kernel_init=self.kernel_init, bias_init=self.bias_init, use_bias=self.use_bias, precision=self.precision, dot_general=self.qkv_dot_general, dot_general_cls=self.qkv_dot_general_cls, ) # project inputs_q to multi-headed q/k/v # dimensions are then [batch..., length, n_heads, n_features_per_head] query, key, value = ( dense(name='query')(inputs_q), dense(name='key')(inputs_k), dense(name='value')(inputs_v), ) if self.normalize_qk: # Normalizing query and key projections stabilizes training with higher # LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis. query = LayerNorm( name='query_ln', use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype, )(query) # type: ignore[call-arg] key = LayerNorm( name='key_ln', use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype, )(key) # type: ignore[call-arg] # During fast autoregressive decoding, we feed one position at a time, # and cache the keys and values step by step. if self.decode: # detect if we're initializing by absence of existing cache data. is_initialized = self.has_variable('cache', 'cached_key') cached_key = self.variable( 'cache', 'cached_key', jnp.zeros, key.shape, key.dtype ) cached_value = self.variable( 'cache', 'cached_value', jnp.zeros, value.shape, value.dtype ) cache_index = self.variable( 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.int32) ) if is_initialized: ( *batch_dims, max_length, num_heads, depth_per_head, ) = cached_key.value.shape # shape check of cached keys against query input expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head) if expected_shape != query.shape: raise ValueError( 'Autoregressive cache shape error, ' 'expected query shape %s instead got %s.' % (expected_shape, query.shape) ) # update key, value caches with our new 1d spatial slices cur_index = cache_index.value zero = jnp.array(0, dtype=lax.dtype(cur_index.dtype)) indices: tuple[int | jax.Array, ...] = (zero,) * len( batch_dims ) + ( cur_index, zero, zero, ) key = lax.dynamic_update_slice(cached_key.value, key, indices) value = lax.dynamic_update_slice(cached_value.value, value, indices) cached_key.value = key cached_value.value = value cache_index.value = cache_index.value + 1 # causal mask for cached decoder self-attention: # our single query position should only attend to those key # positions that have already been generated and cached, # not the remaining zero elements. mask = combine_masks( mask, jnp.broadcast_to( jnp.arange(max_length) <= cur_index, tuple(batch_dims) + (1, 1, max_length), ), ) if ( self.dropout_rate > 0.0 ): # Require `deterministic` only if using dropout. m_deterministic = merge_param( 'deterministic', self.deterministic, deterministic ) if not m_deterministic and dropout_rng is None: dropout_rng = self.make_rng('dropout') else: m_deterministic = True # `qk_attn_weights_einsum` and `attn_weights_value_einsum` are optional # arguments that can be used to override the default `jnp.einsum`. They # exist for quantized einsum support in AQT. qk_attn_weights_einsum = ( self.qk_attn_weights_einsum_cls() if self.qk_attn_weights_einsum_cls else None ) attn_weights_value_einsum = ( self.attn_weights_value_einsum_cls() if self.attn_weights_value_einsum_cls else None ) # apply attention attn_args = (query, key, value) # This kwargs list match the default nn.dot_product_attention. # For custom `attention_fn`s, invalid kwargs will be filtered. attn_kwargs = dict( mask=mask, dropout_rng=dropout_rng, dropout_rate=self.dropout_rate, broadcast_dropout=self.broadcast_dropout, deterministic=m_deterministic, dtype=self.dtype, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax, qk_attn_weights_einsum=qk_attn_weights_einsum, attn_weights_value_einsum=attn_weights_value_einsum, ) attn_kwargs = { k: v for k, v in attn_kwargs.items() if k in inspect.signature(self.attention_fn).parameters } if sow_weights: x = self.attention_fn(*attn_args, **attn_kwargs, module=self) else: x = self.attention_fn(*attn_args, **attn_kwargs) # back to the original inputs dimensions out = DenseGeneral( features=features, axis=(-2, -1), kernel_init=self.out_kernel_init or self.kernel_init, bias_init=self.out_bias_init or self.bias_init, use_bias=self.use_bias, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision, dot_general=self.out_dot_general, dot_general_cls=self.out_dot_general_cls, name='out', # type: ignore[call-arg] )(x) return out class MultiHeadAttention(MultiHeadDotProductAttention): """Multi-head dot-product attention. Alias for ``MultiHeadDotProductAttention``. **NOTE**: ``MultiHeadAttention`` is a wrapper of ``MultiHeadDotProductAttention``, and so their implementations are identical. However ``MultiHeadAttention`` layers will, by default, be named ``MultiHeadAttention_{index}``, whereas ``MultiHeadDotProductAttention`` will be named ``MultiHeadDotProductAttention_{index}``. Therefore, this could affect checkpointing, param collection names and RNG threading (since the layer name is used when generating new RNG's) within the module. Example usage:: >>> import flax.linen as nn >>> import jax >>> layer = nn.MultiHeadAttention(num_heads=8, qkv_features=16) >>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6) >>> shape = (4, 3, 2, 5) >>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape) >>> variables = layer.init(jax.random.key(0), q) >>> # different inputs for inputs_q, inputs_k and inputs_v >>> out = layer.apply(variables, q, k, v) >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k) >>> out = layer.apply(variables, q, k) >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q) >>> out = layer.apply(variables, q) >>> attention_kwargs = dict( ... num_heads=8, ... qkv_features=16, ... kernel_init=nn.initializers.ones, ... bias_init=nn.initializers.zeros, ... dropout_rate=0.5, ... deterministic=False, ... ) >>> class Module(nn.Module): ... attention_kwargs: dict ... ... @nn.compact ... def __call__(self, x, dropout_rng=None): ... out1 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng) ... out2 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng) ... return out1, out2 >>> module = Module(attention_kwargs) >>> variables = module.init({'params': key1, 'dropout': key2}, q) >>> # out1 and out2 are different. >>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3}) >>> # out3 and out4 are different. >>> # out1 and out3 are different. out2 and out4 are different. >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4}) >>> # out1 and out2 are the same. >>> out1, out2 = module.apply(variables, q, dropout_rng=key5) >>> # out1 and out2 are the same as out3 and out4. >>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply` >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5) Attributes: num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) should be divisible by the number of heads. dtype: the dtype of the computation (default: infer from inputs and params) param_dtype: the dtype passed to parameter initializers (default: float32) qkv_features: dimension of the key, query, and value. out_features: dimension of the last projection broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rate: dropout rate deterministic: if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic. precision: numerical precision of the computation see ``jax.lax.Precision`` for details. kernel_init: initializer for the kernel of the Dense layers. bias_init: initializer for the bias of the Dense layers. use_bias: bool: whether pointwise QKVO dense transforms use bias. attention_fn: dot_product_attention or compatible function. Accepts query, key, value, and returns output of shape ``[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`` decode: whether to prepare and use an autoregressive cache. normalize_qk: should QK normalization be applied (arxiv.org/abs/2302.05442). """ class SelfAttention(MultiHeadDotProductAttention): """Self-attention special case of multi-head dot-product attention. This layer is deprecated in favor of ``MultiHeadDotProductAttention``. Example usage:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16) >>> variables = layer.init(jax.random.key(0), jnp.ones((4, 3, 2, 5))) """ @compact def __call__( # type: ignore self, inputs_q: Array, mask: Array | None = None, deterministic: bool | None = None, dropout_rng: PRNGKey | None = None, sow_weights: bool = False, ): """Applies multi-head dot product self-attention on the input data. Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention and project the results to an output vector. Args: inputs_q: input queries of shape ``[batch_sizes..., length, features]``. mask: attention mask of shape ``[batch_sizes..., num_heads, query_length, key/value_length]``. Attention weights are masked out if their corresponding mask value is ``False``. deterministic: if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic. Returns: output of shape ``[batch_sizes..., length, features]``. """ warnings.warn( 'SelfAttention will be deprecated soon. Use ' '`MultiHeadDotProductAttention.__call__(inputs_q)` instead. ' 'See https://github.com/google/flax/discussions/3389 ' 'for more information.', DeprecationWarning, ) return super().__call__( inputs_q, mask=mask, deterministic=deterministic, dropout_rng=dropout_rng, sow_weights=sow_weights, ) # mask-making utility functions def make_attention_mask( query_input: Array, key_input: Array, pairwise_fn: Callable[..., Any] = jnp.multiply, extra_batch_dims: int = 0, dtype: Dtype = jnp.float32, ): """Mask-making helper for attention weights. In case of 1d inputs (i.e., ``[batch..., len_q]``, ``[batch..., len_kv]``, the attention weights will be ``[batch..., heads, len_q, len_kv]`` and this function will produce ``[batch..., 1, len_q, len_kv]``. Args: query_input: a batched, flat input of query_length size key_input: a batched, flat input of key_length size pairwise_fn: broadcasting elementwise comparison function extra_batch_dims: number of extra batch dims to add singleton axes for, none by default dtype: mask return dtype Returns: A ``[batch..., 1, len_q, len_kv]`` shaped mask for 1d attention. """ mask = pairwise_fn( jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, axis=-2) ) mask = jnp.expand_dims(mask, axis=-3) mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims))) return mask.astype(dtype) def make_causal_mask( x: Array, extra_batch_dims: int = 0, dtype: Dtype = jnp.float32 ) -> Array: """Make a causal mask for self-attention. In case of 1d inputs (i.e., ``[batch..., len]``, the self-attention weights will be ``[batch..., heads, len, len]`` and this function will produce a causal mask of shape ``[batch..., 1, len, len]``. Args: x: input array of shape ``[batch..., len]`` extra_batch_dims: number of batch dims to add singleton axes for, none by default dtype: mask return dtype Returns: A ``[batch..., 1, len, len]`` shaped causal mask for 1d attention. """ idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape) return make_attention_mask( idxs, idxs, jnp.greater_equal, extra_batch_dims=extra_batch_dims, dtype=dtype, ) def combine_masks( *masks: Array | None, dtype: Dtype = jnp.float32 ) -> Array | None: """Combine attention masks. Args: *masks: set of attention mask arguments to combine, some can be None. dtype: dtype for the returned mask. Returns: Combined mask, reduced by logical and, returns None if no masks given. """ masks_list = [m for m in masks if m is not None] if not masks_list: return None assert all( map(lambda x: x.ndim == masks_list[0].ndim, masks_list) ), f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks_list))}' mask, *other_masks = masks_list for other_mask in other_masks: mask = jnp.logical_and(mask, other_mask) return mask.astype(dtype) ================================================ FILE: flax/linen/batch_apply.py ================================================ # Copyright 2023 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Batch apply.""" import jax, jax.numpy as jnp import numpy as np def ndim_at_least(x, num_dims): if not (isinstance(x, jax.Array) or isinstance(x, np.ndarray)): x = jnp.asarray(x) return x.ndim >= num_dims def arbitrary_mergeable_leaf(min_num_dims, args, kwargs): for a in jax.tree_util.tree_leaves(args): if ndim_at_least(a, min_num_dims): return a for k in jax.tree_util.tree_leaves(kwargs): if ndim_at_least(k, min_num_dims): return k # Couldn't find a satisfactory leaf. return None def merge_leading_dims(x, num_dims): """Merge leading dimensions.""" # Don't merge if there aren't dimensions to merge. if not ndim_at_least(x, num_dims): return x new_shape = (np.prod(x.shape[:num_dims]),) + x.shape[num_dims:] return x.reshape(new_shape) def split_leading_dim(x, to_dim): new_shape = to_dim + x.shape[1:] return x.reshape(new_shape) class BatchApply: r"""Temporarily merges leading dimensions of input tensors. Merges the leading dimensions of a tensor into a single dimension, runs the given callable, then splits the leading dimension of the result to match the input. Input arrays whose rank is smaller than the number of dimensions to collapse are passed unmodified. This may be useful for applying a module to each timestep of e.g. a ``[Time, Batch, ...]`` array. For some ``f``\ s and platforms, this may be more efficient than :func:`jax.vmap`, especially when combined with other transformations like :func:`jax.grad`. Example usage:: >>> import jax, jax.numpy as jnp >>> a = jax.random.normal(jax.random.key(0), [2, 3, 4]) >>> b = jax.random.normal(jax.random.key(1), [4]) >>> def raises(a, b): ... if len(a.shape) != 2: ... raise ValueError("a must be shape 2") ... if len(b.shape) != 1: ... raise ValueError("b must be shape 1") ... return jnp.dot(a, b) >>> out = BatchApply(raises)(a, b) >>> expected_merged_leading = raises(a.reshape(2*3, 4), b) >>> expected = expected_merged_leading.reshape((2, 3) + expected_merged_leading.shape[1:]) >>> np.testing.assert_array_equal(out, expected) """ def __init__(self, f, num_dims=2): """Constructs a :class:`BatchApply` module. Args: f: The callable to be applied to the reshaped array. num_dims: The number of dimensions to merge. """ self._f = f self.num_dims = num_dims def __call__(self, *args, **kwargs): example = arbitrary_mergeable_leaf(self.num_dims, args, kwargs) if example is None: raise ValueError( 'BatchApply requires at least one input with ndim >= ' f'{self.num_dims}.' ) merge = lambda x: merge_leading_dims(x, self.num_dims) split = lambda x: split_leading_dim(x, example.shape[:self.num_dims]) args = jax.tree_util.tree_map(merge, args) kwargs = jax.tree_util.tree_map(merge, kwargs) outputs = self._f(*args, **kwargs) return jax.tree_util.tree_map(split, outputs) ================================================ FILE: flax/linen/combinators.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Combinators of modules, such as a Sequential.""" from typing import Any from collections.abc import Callable, Sequence from flax.linen.module import Module, compact class Sequential(Module): """Applies a linear chain of Modules. Meant to be used only for the simple case of fusing together callables where the input of a particular module/op is the output of the previous one. Modules will be applied in the order that they are passed in the constructor. The ``__call__`` method of Sequential accepts any input and forwards it to the first module it contains. It chains the output sequentially to the input of the next module and returns the output of the final module. Example usage:: >>> import flax.linen as nn >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... return nn.Sequential([nn.Dense(4), ... nn.relu, ... nn.Dense(2), ... nn.log_softmax])(x) Since `Sequential.__call__` is a `compact` method, you can also pass functions that construct Modules inline if you need shape inference:: module = nn.Sequential([ # << more layers lambda x: SomeModule(x.shape[-1])(x), # shape inference # << more layers ]) This combinator supports also layers that return multiple outputs if returned as a tuple or a dictionary. If the output of a layer is a ``tuple`` it will be expanded as ``*args`` in the next layer, if its a ``dict`` it will be expanded as ``**kwargs``. Example usage:: >>> class CrossAttentionBlock(nn.Module): ... num_heads: int = 2 ... qkv_features: int = 16 ... ... @nn.compact ... def __call__(self, query, key_value): ... output = nn.MultiHeadDotProductAttention( ... num_heads=self.num_heads, qkv_features=self.qkv_features)(query, ... key_value) ... output = nn.Dense(self.qkv_features)(output) ... return dict(query=output, key_value=key_value) # also works for tuples >>> from typing import Sequence >>> class CrossAttentionNetwork(nn.Module): ... num_layers: Sequence[int] ... ... @nn.compact ... def __call__(self, x): ... return nn.Sequential([CrossAttentionBlock() for _ in ... range(self.num_layers)])(query, key_value) Attributes: layers: A sequence of callables to be applied in order. Raises: ValueError: If layers is not a sequence. """ layers: Sequence[Callable[..., Any]] def __post_init__(self): if not isinstance(self.layers, Sequence): raise ValueError( f"'layers' must be a sequence, got '{type(self.layers).__name__}'." ) super().__post_init__() @compact def __call__(self, *args, **kwargs): if not self.layers: raise ValueError(f'Empty Sequential module {self.name}.') outputs = self.layers[0](*args, **kwargs) for layer in self.layers[1:]: if isinstance(outputs, tuple): outputs = layer(*outputs) elif isinstance(outputs, dict): outputs = layer(**outputs) else: outputs = layer(outputs) return outputs ================================================ FILE: flax/linen/dtypes.py ================================================ # Copyright 2022 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """APIs for handling dtypes in Linen Modules.""" from typing import Any, TypeVar from flax.typing import Dtype from jax import numpy as jnp T = TypeVar('T', bound=tuple) def canonicalize_dtype( *args, dtype: Dtype | None = None, inexact: bool = True ) -> Dtype: """Canonicalize an optional dtype to the definitive dtype. If the ``dtype`` is None this function will infer the dtype. If it is not None it will be returned unmodified or an exceptions is raised if the dtype is invalid. from the input arguments using ``jnp.result_type``. Args: *args: JAX array compatible values. None values are ignored. dtype: Optional dtype override. If specified the arguments are cast to the specified dtype instead and dtype inference is disabled. inexact: When True, the output dtype must be a subdtype of `jnp.inexact`. Inexact dtypes are real or complex floating points. This is useful when you want to apply operations that don't work directly on integers like taking a mean for example. Returns: The dtype that *args should be cast to. """ if dtype is None: args_filtered = [jnp.asarray(x) for x in args if x is not None] dtype = jnp.result_type(*args_filtered) if inexact and not jnp.issubdtype(dtype, jnp.inexact): dtype = jnp.promote_types(jnp.float32, dtype) if inexact and not jnp.issubdtype(dtype, jnp.inexact): raise ValueError(f'Dtype must be inexact: {dtype}') return dtype def promote_dtype(*args, dtype=None, inexact=True) -> list[Any]: """ "Promotes input arguments to a specified or inferred dtype. All args are cast to the same dtype. See ``canonicalize_dtype`` for how this dtype is determined. The behavior of promote_dtype is mostly a convinience wrapper around ``jax.numpy.promote_types``. The differences being that it automatically casts all input to the inferred dtypes, allows inference to be overridden by a forced dtype, and has an optional check to garantuee the resulting dtype is inexact. Args: *args: JAX array compatible values. None values are returned as is. dtype: Optional dtype override. If specified the arguments are cast to the specified dtype instead and dtype inference is disabled. inexact: When True, the output dtype must be a subdtype of `jnp.inexact`. Inexact dtypes are real or complex floating points. This is useful when you want to apply operations that don't work directly on integers like taking a mean for example. Returns: The arguments cast to arrays of the same dtype. """ dtype = canonicalize_dtype(*args, dtype=dtype, inexact=inexact) return [jnp.asarray(x, dtype) if x is not None else None for x in args] ================================================ FILE: flax/linen/experimental/layers_with_named_axes.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Experimental layers with named axes for the partitioning API.""" import dataclasses from typing import Any from collections.abc import Callable, Iterable, Sequence import jax.numpy as jnp from jax import lax from flax import linen as nn from flax.linen import initializers from flax.linen.partitioning import param_with_axes, with_sharding_constraint from flax.typing import ( Array, Dtype, Axes, Initializer, PrecisionLike, DotGeneralT, ) # Type annotations Activation = Callable[..., Array] default_kernel_init = initializers.lecun_normal() default_embed_init = initializers.variance_scaling( 1.0, 'fan_in', 'normal', out_axis=0 ) class Dense(nn.Module): """A Dense layer with named axes for :meth:`jax.experimental.pjit.pjit`. .. warning:: This class is hightly EXPERIMENTAL and the API is likely to change. For regular (non-pjit) use, please use :class:`flax.linen.linear.Dense`. Attributes: features: the number of output features. use_bias: whether to add a bias to the output (default: True). dtype: the dtype of the computation (default: float32). param_dtype: the dtype passed to parameter initializers (default: float32). precision: numerical precision of the computation see `jax.lax.Precision` for details. kernel_init: initializer function for the weight matrix. bias_init: initializer function for the bias. """ features: int use_bias: bool = True dtype: Dtype = jnp.float32 param_dtype: Dtype = jnp.float32 precision: PrecisionLike = None kernel_init: Initializer = default_kernel_init bias_init: Initializer = initializers.zeros_init() kernel_axes: tuple[str, ...] = () # Deprecated. Will be removed. dot_general: DotGeneralT | None = None dot_general_cls: Any = None @nn.compact def __call__(self, inputs: Array) -> Array: """Applies a linear transformation to the inputs along the last dimension. Args: inputs: The nd-array to be transformed. Returns: The transformed input. """ inputs = jnp.asarray(inputs, self.dtype) kernel = param_with_axes( 'kernel', self.kernel_init, (inputs.shape[-1], self.features), self.param_dtype, axes=self.kernel_axes, ) kernel = jnp.asarray(kernel, self.dtype) if self.dot_general_cls is not None: dot_general = self.dot_general_cls() elif self.dot_general is not None: dot_general = self.dot_general else: dot_general = lax.dot_general y = dot_general( inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision, ) if self.use_bias: bias = param_with_axes( 'bias', self.bias_init, (self.features,), self.param_dtype, axes=(self.kernel_axes[-1],), ) bias = jnp.asarray(bias, self.dtype) y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) return y class Embed(nn.Module): """An embedding layer with named axes for :meth:`jax.experimental.pjit.pjit`. .. warning:: This class is hightly EXPERIMENTAL and the API is likely to change. For regular (non-pjit) use, please use :class:`flax.linen.linear.Embed`. Attributes: num_embeddings: number of embeddings. features: number of feature dimensions for each embedding. dtype: the dtype of the embedding vectors (default: float32). param_dtype: the dtype passed to parameter initializers (default: float32). embedding_init: embedding initializer. one_hot: performs the gather with a one-hot contraction rather than a true gather. This is currently needed for SPMD partitioning. """ num_embeddings: int features: int cast_input_dtype: Dtype | None = None dtype: Dtype = jnp.float32 param_dtype: Dtype = jnp.float32 attend_dtype: Dtype | None = None embedding_init: Initializer = default_embed_init one_hot: bool = False embedding: Array = dataclasses.field(init=False) def setup(self): self.embedding = param_with_axes( 'embedding', self.embedding_init, (self.num_embeddings, self.features), self.param_dtype, axes=('vocab', 'embed'), ) def __call__(self, inputs: Array) -> Array: """Embeds the inputs along the last dimension. Args: inputs: input data, all dimensions are considered batch dimensions. Returns: Output which is embedded input data. The output shape follows the input, with an additional `features` dimension appended. """ if self.cast_input_dtype: inputs = inputs.astype(self.cast_input_dtype) if not jnp.issubdtype(inputs.dtype, jnp.integer): raise ValueError('Input type must be an integer or unsigned integer.') if self.one_hot: iota = lax.iota(jnp.int32, self.num_embeddings) one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype) output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype)) else: output = jnp.asarray(self.embedding, self.dtype)[inputs] output = with_sharding_constraint(output, ('batch', 'length', 'embed')) return output def attend(self, query: Array) -> Array: """Attend over the embedding using a query array. Args: query: array with last dimension equal the feature depth `features` of the embedding. Returns: An array with final dim `num_embeddings` corresponding to the batched inner-product of the array of query vectors against each embedding. Commonly used for weight-sharing between embeddings and logit transform in NLP models. """ dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype return jnp.dot(query, jnp.asarray(self.embedding, dtype).T) def _canonicalize_axes(rank: int, axes: Axes) -> Sequence[int]: """Returns a tuple of deduplicated, sorted, and positive axes.""" if not isinstance(axes, Iterable): axes = (axes,) return tuple({rank + axis if axis < 0 else axis for axis in axes}) def _abs_sq(x): """Computes the elementwise square of the absolute value |x|^2.""" if jnp.iscomplexobj(x): return lax.square(lax.real(x)) + lax.square(lax.imag(x)) else: return lax.square(x) def _compute_stats(x: Array, axes: Axes): """Computes mean and variance statistics. This implementation takes care of a few important details: - Computes in float32 precision for half precision inputs - mean and variance is computable in a single XLA fusion, by using Var = E[|x|^2] - |E[x]|^2 instead of Var = E[|x - E[x]|^2]). - Clips negative variances to zero which can happen due to roundoff errors. This avoids downstream NaNs. - Supports averaging across a parallel axis and subgroups of a parallel axis with a single `lax.pmean` call to avoid latency. """ # promote x to at least float32, this avoids half precision computation # but preserves double or complex floating points x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x))) mean = jnp.mean(x, axes) mean2 = jnp.mean(_abs_sq(x), axes) # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due # to floating point round-off errors. var = jnp.maximum(0.0, mean2 - _abs_sq(mean)) return mean, var def _normalize( mdl: nn.Module, x: Array, mean: Array, var: Array, reduction_axes: Axes, feature_axes: Axes, dtype: Dtype, param_dtype: Dtype, epsilon: float, use_bias: bool, use_scale: bool, bias_init: Initializer, scale_init: Initializer, ): """ "Normalizes the input of a normalization layer and optionally applies a learned scale and bias. A separate bias and scale is learned for each feature as specified by feature_axes. """ reduction_axes = _canonicalize_axes(x.ndim, reduction_axes) feature_axes = _canonicalize_axes(x.ndim, feature_axes) stats_shape = list(x.shape) for axis in reduction_axes: stats_shape[axis] = 1 mean = mean.reshape(stats_shape) var = var.reshape(stats_shape) feature_shape = [1] * x.ndim reduced_feature_shape = [] for ax in feature_axes: feature_shape[ax] = x.shape[ax] reduced_feature_shape.append(x.shape[ax]) y = x - mean mul = lax.rsqrt(var + epsilon) if use_scale: scale = mdl.param_with_axes( 'scale', scale_init, reduced_feature_shape, param_dtype, axes=('embed',) ).reshape(feature_shape) mul *= scale y *= mul if use_bias: bias = mdl.param_with_axes( 'bias', bias_init, reduced_feature_shape, param_dtype, axes=('embed',) ).reshape(feature_shape) y += bias return jnp.asarray(y, dtype) class LayerNorm(nn.Module): """Layer normalization (https://arxiv.org/abs/1607.06450) with named axes for :meth:`jax.experimental.pjit.pjit`. .. warning:: This class is hightly EXPERIMENTAL and the API is likely to change. For regular (non-pjit) use, please use :class:`flax.linen.normalization.LayerNorm`. Operates on the last axis of the input data. It normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that maintains the mean activation within each example close to 0 and the activation standard deviation close to 1. Attributes: epsilon: A small float added to variance to avoid dividing by zero. dtype: the dtype of the computation (default: float32). param_dtype: the dtype passed to parameter initializers (default: float32). use_bias: If True, bias (beta) is added. use_scale: If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer. bias_init: Initializer for bias, by default, zero. scale_init: Initializer for scale, by default, one. """ epsilon: float = 1e-6 dtype: Any = jnp.float32 param_dtype: Dtype = jnp.float32 use_bias: bool = True use_scale: bool = True bias_init: Initializer = initializers.zeros_init() scale_init: Initializer = initializers.ones_init() @nn.compact def __call__(self, x): """Applies layer normalization on the input. Args: x: the inputs Returns: Normalized inputs (the same shape as inputs). """ reduction_axes = (-1,) feature_axes = (-1,) mean, var = _compute_stats(x, reduction_axes) return _normalize( self, x, mean, var, reduction_axes, feature_axes, self.dtype, self.param_dtype, self.epsilon, self.use_bias, self.use_scale, self.bias_init, self.scale_init, ) ================================================ FILE: flax/linen/fp8_ops.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import dataclasses import itertools import numpy as np import warnings from functools import partial from typing import Any DType = Any import jax from jax import custom_jvp, custom_vjp, lax, random from jax import numpy as jnp from jax._src import core from jax._src import dtypes from jax.typing import DTypeLike try: from jax._src import earray from jax._src.interpreters import pxla CAN_USE_EARRAY = True except (ModuleNotFoundError, ImportError): CAN_USE_EARRAY = False from flax.linen import initializers, module OVERWRITE_WITH_GRADIENT = '_overwrite_with_gradient' # Define a custom dtype for FP8 meta params. class Fp8MetaTyRules: # tell JAX how to lower this dtype to an HLO dtype @staticmethod def physical_element_aval(dtype) -> core.ShapedArray: return core.ShapedArray((), dtype.float_dtype) if jax.__version_info__ < (0, 4, 29): @staticmethod def replicate_trailing_dims(ctx, val, aval): del ctx, aval return val @staticmethod def logical_sharding(aval, phys_sharding): return phys_sharding @staticmethod def physical_sharding(aval, sharding): return sharding # unlike KeyTyRules, assume same shape # allow conversions to and from the corresponding float type @staticmethod def convert_from(fp8_meta_dtype, other_dtype) -> bool: return fp8_meta_dtype.float_dtype == other_dtype @staticmethod def convert_to(other_dtype, fp8_meta_dtype) -> bool: return fp8_meta_dtype.float_dtype == other_dtype # define how autodiff should accumulate these values @staticmethod def add(dt, x, y): from_fp8_meta = partial(lax.convert_element_type, new_dtype=dt.float_dtype) to_fp8_meta = partial(lax.convert_element_type, new_dtype=dt) return to_fp8_meta(lax.max(from_fp8_meta(x), from_fp8_meta(y))) @staticmethod def zero(dt): neginf = np.array(-np.inf if dtypes.supports_inf(dt.float_dtype) else dtypes.finfo(dt.float_dtype).min, dt.float_dtype) return lax.convert_element_type(neginf, dt) @staticmethod def tangent_dtype(dtype): return dtype @staticmethod def full(shape, fill_value, dtype): fill_value = lax.convert_element_type(fill_value, dtype.float_dtype) out_raw = lax.full(shape, fill_value, dtype.float_dtype) return lax.convert_element_type(out_raw, dtype) @staticmethod def global_sharded_result_handler(aval, out_sharding, committed): if not CAN_USE_EARRAY: raise NotImplementedError("convert back under the jit") phys_sharding = out_sharding # unlike KeyTyRules, assume same shape phys_aval = core.physical_aval(aval) phys_handler_maker = pxla.global_result_handlers[core.ShapedArray] phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) return lambda bufs: earray.EArray(aval, phys_handler(bufs)) # class to use as second argument to jax.dtypes.issubdtype class fp8_meta_dtype(dtypes.extended): pass # parameterized datatype for use in e.g. lax.convert_element_type @dataclasses.dataclass(frozen=True) class fp8_meta_dtype_wrapper(dtypes.ExtendedDType): float_dtype: dtypes.DType _rules: type = Fp8MetaTyRules type: type = fp8_meta_dtype def __repr__(self) -> str: nbits = dtypes.finfo(self.float_dtype).bits return f'fp8_meta{nbits}' name = property(__repr__) fm32 = fp8_meta_dtype_wrapper(jnp.float32) fp32_max_grad = fp8_meta_dtype_wrapper(jnp.float32) def get_fp8_max(fp8_dtype, out_dtype): assert fp8_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2, jnp.float8_e4m3fnuz, jnp.float8_e5m2fnuz) return jnp.finfo(fp8_dtype).max.astype(out_dtype) def quantize(x, q_dtype, scale, compute_dtype): # Explicitly cast the max values to the compute dtype to avoid unnecessary # casting to FP32 during the subsequent math operations." dtype_max = get_fp8_max(q_dtype, compute_dtype) scaled_x = x / jnp.broadcast_to(scale.astype(compute_dtype), x.shape) clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max) return clipped_x.astype(q_dtype) def dequantize(x, dq_dtype, scale): return x.astype(dq_dtype) * jnp.broadcast_to(scale.astype(dq_dtype), x.shape) def qdq(x, q_dtype, scale, compute_dtype): qx = quantize(x, q_dtype, scale, compute_dtype) return dequantize(qx, x.dtype, scale) def compute_scale(amax, scale, fp8_max, margin=0): # The algorithm for computing the new scale is sourced from # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#transformer_engine.jax.update_fp8_metas # wherein the `original_scale` corresponds to the reciprocal of the `scale` # passed in this function. scale = 1.0 / scale sf = (fp8_max / amax) / (2**margin) sf = jnp.where(amax > 0.0, sf, scale) sf = jnp.where(jnp.isfinite(amax), sf, scale) return 1.0 / sf def compute_amax_history(x, amax_history): amax_update = jnp.max(jnp.abs(x)).astype(amax_history.dtype) new_history = jnp.roll(amax_history, shift=-1, axis=0).at[0].set(amax_update) return new_history def update_fp8_meta( x, q_dtype, scale, amax_history ): is_fmax32 = (scale.dtype == fm32 and amax_history.dtype == fm32) # convert fm32->f32 so we can do math if is_fmax32: amax_history = lax.convert_element_type(amax_history, jnp.float32) scale = lax.convert_element_type(scale, jnp.float32) # Update the fp8 meta dtype_max = get_fp8_max(q_dtype, jnp.float32) amax_from_history = jnp.max(amax_history, axis=0) new_scale = compute_scale(amax_from_history, scale, dtype_max) new_history = compute_amax_history(x, amax_history) if is_fmax32: new_history = lax.convert_element_type(new_history, fp32_max_grad) new_scale = lax.convert_element_type(new_scale, fp32_max_grad) return new_scale, new_history def quantize_dequantize_update(x, q_dtype, scale, amax_history, compute_dtype): updated_scale, updated_history = update_fp8_meta(x, q_dtype, scale, amax_history) qdq_x = qdq(x, q_dtype, _fm32_to_float32(updated_scale), compute_dtype) return qdq_x, updated_scale, updated_history def _fm32_to_float32(value): if value.dtype == fm32: return lax.convert_element_type(value, jnp.float32) return value def dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, preferred_element_type: DTypeLike | None, swap_ans=False): def _remaining(original, *removed_lists): removed = set(itertools.chain(*removed_lists)) return [i for i in original if i not in removed] def _ranges_like(*xs): start = 0 for x in xs: x_len = len(x) yield range(start, start + x_len) start += x_len (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers x_ndim = x.aval.ndim x_kept = _remaining(range(x_ndim), x_contract, x_batch) y_kept = _remaining(range(np.ndim(y)), y_contract, y_batch) if swap_ans: ans_batch, ans_y, _ = _ranges_like(x_batch, y_kept, x_kept) else: ans_batch, _, ans_y = _ranges_like(x_batch, x_kept, y_kept) dims = ((ans_y, y_kept), (ans_batch, y_batch)) x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y) x_bar = lax.transpose( lax.dot_general( g, y, dims, precision=precision, preferred_element_type=preferred_element_type ), tuple(out_axes) ) return x_bar def dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision, preferred_element_type: DTypeLike | None): (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch)) y_bar = dot_general_transpose_lhs( g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, swap_ans=True) return y_bar @partial(custom_vjp, nondiff_argnums=(0, 1)) def in_qdq(compute_dtype, q_dtype, inp, scale, amax_history): qin, _, _ = quantize_dequantize_update( inp, q_dtype, scale, amax_history, compute_dtype ) return qin def in_qdq_fwd(compute_dtype, q_dtype, inp, scale, amax_history): qin, new_scale, new_history = quantize_dequantize_update( inp, q_dtype, scale, amax_history, compute_dtype ) return qin, (new_scale, new_history) def in_qdq_bwd(compute_dtype, q_dtype, res, g): new_scale, new_history = res q_g = g return q_g, new_scale, new_history in_qdq.defvjp(in_qdq_fwd, in_qdq_bwd) @partial(custom_vjp, nondiff_argnums=(0, 1)) def out_qdq(compute_dtype, q_dtype, out, scale, amax_history): return out def out_qdq_fwd(compute_dtype, q_dtype, out, scale, amax_history): return out, (scale, amax_history) def out_qdq_bwd(compute_dtype, q_dtype, res, g): scale, amax_history = res q_g, new_scale, new_history = quantize_dequantize_update( g, q_dtype, scale, amax_history, compute_dtype ) return q_g, new_scale, new_history out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd) @partial(custom_vjp, nondiff_argnums=(0, 1)) def in_q(compute_dtype, q_dtype, inp, scale, amax_history): new_scale, _ = update_fp8_meta(inp, q_dtype, scale, amax_history) qin = quantize(inp, q_dtype, _fm32_to_float32(new_scale), compute_dtype) return qin, new_scale def in_q_fwd(compute_dtype, q_dtype, inp, scale, amax_history): new_scale, new_history = update_fp8_meta(inp, q_dtype, scale, amax_history) qin = quantize(inp, q_dtype, _fm32_to_float32(new_scale), compute_dtype) return (qin, new_scale), (new_scale, new_history) def in_q_bwd(compute_dtype, q_dtype, res, _): new_scale, new_history = res # We don't compute gradients for inp, scale and amax_history, but we pass through scale and history return None, new_scale, new_history in_q.defvjp(in_q_fwd, in_q_bwd) @partial(custom_vjp, nondiff_argnums=(0, )) def out_dq(dq_type, lhs_scale, rhs_scale, out): q_out = dequantize( out, dq_type, _fm32_to_float32(lhs_scale) * _fm32_to_float32(rhs_scale) ) return q_out def out_dq_fwd(dq_type, lhs_scale, rhs_scale, out): return out_dq(dq_type, lhs_scale, rhs_scale, out), None def out_dq_bwd(dq_type, _, g): return None, None, g out_dq.defvjp(out_dq_fwd, out_dq_bwd) @partial(custom_vjp, nondiff_argnums=(8, 9)) def quantized_dot( lhs, q_lhs, lhs_scale, # scale for this step rhs, q_rhs, rhs_scale, # scale for this step out_grad_scale, # scale from previous step out_grad_amax_history, # amax history from previous step dimension_numbers, preferred_element_type=None ): return lax.dot_general( q_lhs, q_rhs, dimension_numbers, preferred_element_type=preferred_element_type, precision=lax.Precision.DEFAULT, ) def quantized_dot_fwd( lhs, q_lhs, lhs_scale, rhs, q_rhs, rhs_scale, out_grad_scale, out_grad_amax_history, dimension_numbers, preferred_element_type, ): out = lax.dot_general( q_lhs, q_rhs, dimension_numbers, preferred_element_type=preferred_element_type, precision=lax.Precision.DEFAULT, ) res = ( lhs, q_lhs, lhs_scale, rhs, q_rhs, rhs_scale, out_grad_scale, out_grad_amax_history, ) return out, res def quantized_dot_bwd( dimension_numbers, preferred_element_type, res, g ): ( lhs, q_lhs, lhs_scale, rhs, q_rhs, rhs_scale, out_grad_scale, out_grad_amax_history, ) = res new_out_grad_scale, new_out_grad_amax_history = update_fp8_meta( g, jnp.float8_e5m2, out_grad_scale, out_grad_amax_history, ) q_g = quantize( g, jnp.float8_e5m2, _fm32_to_float32(new_out_grad_scale), preferred_element_type ) grad_lhs = dot_general_transpose_lhs( q_g, lhs, q_rhs, dimension_numbers=dimension_numbers, precision=lax.Precision.HIGHEST, preferred_element_type=preferred_element_type, ) grad_lhs = dequantize( grad_lhs, preferred_element_type, _fm32_to_float32(rhs_scale) * _fm32_to_float32(new_out_grad_scale) ) grad_rhs = dot_general_transpose_rhs( q_g, q_lhs, rhs, dimension_numbers=dimension_numbers, precision=lax.Precision.HIGHEST, preferred_element_type=preferred_element_type, ) grad_rhs = dequantize( grad_rhs, preferred_element_type, _fm32_to_float32(lhs_scale) * _fm32_to_float32(new_out_grad_scale) ) return ( grad_lhs, None, None, grad_rhs, None, None, new_out_grad_scale, new_out_grad_amax_history, ) quantized_dot.defvjp(quantized_dot_fwd, quantized_dot_bwd) # Wrapper function to achieve the same effect as the dot_general function # but with fp8 quantization and dequantization. def fp8_scaled_dot_general( lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None, *, lhs_scale=None, rhs_scale=None, grad_scale=None, lhs_amax_history=None, rhs_amax_history=None, grad_amax_history=None, quantize_compute_type=jnp.float32, ): if precision != None: warnings.warn( 'The function fp8_scaled_dot_general will set the "precision" and ' 'disregard any provided "precision" argument.' ) q_lhs, new_lhs_scale = in_q( quantize_compute_type, jnp.float8_e4m3fn, lhs, lhs_scale, lhs_amax_history ) q_rhs, new_rhs_scale = in_q( quantize_compute_type, jnp.float8_e4m3fn, rhs, rhs_scale, rhs_amax_history ) y = quantized_dot( lhs, q_lhs, new_lhs_scale, rhs, q_rhs, new_rhs_scale, grad_scale, grad_amax_history, dimension_numbers, preferred_element_type ) y = out_dq( dq_type=preferred_element_type, lhs_scale=new_lhs_scale, rhs_scale=new_rhs_scale, out=y ) return y # type: ignore @partial(custom_jvp, nondiff_argnums=(2, 3, 4)) def dot_general_with_precision( lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None ): if precision != None or preferred_element_type != None: warnings.warn( 'The function dot_general_with_precision will set the ' 'precision/preferred_element_type and disregard any provided ' 'values.' ) return lax.dot_general( lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT ) @dot_general_with_precision.defjvp def dot_general_with_precision_jvp( dimension_numbers, precision, preferred_element_type, primals, tangents ): lhs, rhs = primals lhs_dot, rhs_dot = tangents out = lax.dot_general( lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT ) grad_out = lax.dot_general( lhs_dot, rhs, dimension_numbers, precision=lax.Precision.HIGHEST ) + lax.dot_general( lhs, rhs_dot, dimension_numbers, precision=lax.Precision.HIGHEST ) return out, grad_out def _parse_dot_inputs(*args, **kwargs): assert len(args) == 3 x = args[0] k = args[1] dimension_numbers = args[2] # Use the `k.dtype` since it aligns with the `dtype` of its layers, # namely, the computation data type. comp_dtype = k.dtype x = jnp.asarray(x, comp_dtype) return x, k, dimension_numbers, comp_dtype class Fp8DotGeneralBase(module.Module): amax_history_length: int = 1024 e4m3_dtype: DType = jnp.float8_e4m3fn e5m2_dtype: DType = jnp.float8_e5m2 def setup(self) -> None: scale_args = ( initializers.ones_init(), random.PRNGKey(0), (1,), jnp.float32, ) amax_history_args = ( initializers.zeros_init(), random.PRNGKey(0), (self.amax_history_length,), jnp.float32, ) self.input_amax_history = self.variable( OVERWRITE_WITH_GRADIENT, 'input_amax_history', *amax_history_args ) self.kernel_amax_history = self.variable( OVERWRITE_WITH_GRADIENT, 'kernel_amax_history', *amax_history_args ) self.output_grad_amax_history = self.variable( OVERWRITE_WITH_GRADIENT, 'output_grad_amax_history', *amax_history_args ) self.input_scale = self.variable( OVERWRITE_WITH_GRADIENT, 'input_scale', *scale_args ) self.kernel_scale = self.variable( OVERWRITE_WITH_GRADIENT, 'kernel_scale', *scale_args ) self.output_grad_scale = self.variable( OVERWRITE_WITH_GRADIENT, 'output_grad_scale', *scale_args ) class Fp8DotGeneralOp(Fp8DotGeneralBase): def __post_init__(self): super().__post_init__() if type(self) is Fp8DotGeneralOp: warnings.warn( 'The Fp8DotGeneralOp is deprecated. Use Fp8DirectDotGeneralOp or ' 'Fp8Einsum instead.', DeprecationWarning, ) def __call__(self, *args, **kwargs): x, k, dimension_numbers, comp_dtype = _parse_dot_inputs( *args, **kwargs ) x_qdq = in_qdq( comp_dtype, self.e4m3_dtype, x, self.input_scale.value, self.input_amax_history.value ) k_qdq = in_qdq( comp_dtype, self.e4m3_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value ) y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers) # type: ignore y = out_qdq( comp_dtype, self.e5m2_dtype, y_qdq, self.output_grad_scale.value, self.output_grad_amax_history.value, ) return y # type: ignore class Fp8DirectDotGeneralOp(Fp8DotGeneralBase): def __call__(self, *args, **kwargs): x, k, dimension_numbers, comp_dtype = _parse_dot_inputs( *args, **kwargs ) y = fp8_scaled_dot_general( x, k, dimension_numbers, precision=None, preferred_element_type=x.dtype, lhs_scale=self.input_scale.value, rhs_scale=self.kernel_scale.value, grad_scale=self.output_grad_scale.value, lhs_amax_history=self.input_amax_history.value, rhs_amax_history=self.kernel_amax_history.value, grad_amax_history=self.output_grad_amax_history.value, quantize_compute_type=comp_dtype, ) return y # type: ignore class NANOOFp8DotGeneralOp(Fp8DotGeneralOp): e4m3_dtype: DType = jnp.float8_e4m3fnuz e5m2_dtype: DType = jnp.float8_e5m2fnuz class Fp8Einsum(Fp8DotGeneralBase): def __call__(self, eqn, lhs: jnp.ndarray, rhs: jnp.ndarray, precision: lax.Precision | None = None, preferred_element_type: DTypeLike | None = None) -> jnp.ndarray: # Here we assume that the rhs is the weight and its dtype is the actual compute dtype (not storage dtype). # TODO(kaixih@nvidia): Better way to handle this? actual_compute_dtype = rhs.dtype lhs = lhs.astype(actual_compute_dtype) dot_general_fn = partial( fp8_scaled_dot_general, lhs_scale=self.input_scale.value, rhs_scale=self.kernel_scale.value, grad_scale=self.output_grad_scale.value, lhs_amax_history=self.input_amax_history.value, rhs_amax_history=self.kernel_amax_history.value, grad_amax_history=self.output_grad_amax_history.value, quantize_compute_type=actual_compute_dtype ) out = jnp.einsum(eqn, lhs, rhs, precision=precision, preferred_element_type=preferred_element_type, _dot_general=dot_general_fn) return out # Alias for backward compatibility Fp8DotGeneral = Fp8DirectDotGeneralOp ================================================ FILE: flax/linen/initializers.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Initializers for Flax.""" # pylint: disable=unused-import # re-export initializer functions from jax.nn from jax.nn.initializers import constant as constant from jax.nn.initializers import delta_orthogonal as delta_orthogonal from jax.nn.initializers import glorot_normal as glorot_normal from jax.nn.initializers import glorot_uniform as glorot_uniform from jax.nn.initializers import he_normal as he_normal from jax.nn.initializers import he_uniform as he_uniform from jax.nn.initializers import kaiming_normal as kaiming_normal from jax.nn.initializers import kaiming_uniform as kaiming_uniform from jax.nn.initializers import lecun_normal as lecun_normal from jax.nn.initializers import lecun_uniform as lecun_uniform from jax.nn.initializers import normal as normal from jax.nn.initializers import ones as ones from jax.nn.initializers import orthogonal as orthogonal from jax.nn.initializers import truncated_normal as truncated_normal from jax.nn.initializers import uniform as uniform from jax.nn.initializers import variance_scaling as variance_scaling from jax.nn.initializers import xavier_normal as xavier_normal from jax.nn.initializers import xavier_uniform as xavier_uniform from jax.nn.initializers import zeros as zeros from flax.typing import Initializer as Initializer # pylint: enable=unused-import def zeros_init() -> Initializer: """Builds an initializer that returns a constant array full of zeros. >>> import jax, jax.numpy as jnp >>> from flax.linen.initializers import zeros_init >>> zeros_initializer = zeros_init() >>> zeros_initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[0., 0., 0.], [0., 0., 0.]], dtype=float32) """ return zeros def ones_init() -> Initializer: """Builds an initializer that returns a constant array full of ones. >>> import jax, jax.numpy as jnp >>> from flax.linen.initializers import ones_init >>> ones_initializer = ones_init() >>> ones_initializer(jax.random.key(42), (3, 2), jnp.float32) Array([[1., 1.], [1., 1.], [1., 1.]], dtype=float32) """ return ones ================================================ FILE: flax/linen/kw_only_dataclasses.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Support for keyword-only fields in dataclasses for Python versions <3.10. This module provides wrappers for `dataclasses.dataclass` and `dataclasses.field` that simulate support for keyword-only fields for Python versions before 3.10 (which is the version where dataclasses added keyword-only field support). If this module is imported in Python 3.10+, then `kw_only_dataclasses.dataclass` and `kw_only_dataclasses.field` will simply be aliases for `dataclasses.dataclass` and `dataclasses.field`. For earlier Python versions, when constructing a dataclass, any fields that have been marked as keyword-only (including inherited fields) will be moved to the end of the constructor's argument list. This makes it possible to have a base class that defines a field with a default, and a subclass that defines a field without a default. E.g.: >>> from flax.linen import kw_only_dataclasses >>> @kw_only_dataclasses.dataclass ... class Parent: ... name: str = kw_only_dataclasses.field(default='', kw_only=True) >>> @kw_only_dataclasses.dataclass ... class Child(Parent): ... size: float # required. >>> import inspect >>> print(inspect.signature(Child.__init__)) (self, size: float, name: str = '') -> None (If we used `dataclasses` rather than `kw_only_dataclasses` for the above example, then it would have failed with TypeError "non-default argument 'size' follows default argument.") WARNING: fields marked as keyword-only will not *actually* be turned into keyword-only parameters in the constructor; they will only be moved to the end of the parameter list (after all non-keyword-only parameters). """ import dataclasses import functools import inspect import sys from types import MappingProxyType from typing import Any, TypeVar import typing_extensions as tpe import flax M = TypeVar('M', bound='flax.linen.Module') FieldName = str Annotation = Any Default = Any class _KwOnlyType: """Metadata tag used to tag keyword-only fields.""" def __repr__(self): return 'KW_ONLY' KW_ONLY = _KwOnlyType() def field(*, metadata=None, kw_only=dataclasses.MISSING, **kwargs): """Wrapper for dataclasses.field that adds support for kw_only fields. Args: metadata: A mapping or None, containing metadata for the field. kw_only: If true, the field will be moved to the end of `__init__`'s parameter list. **kwargs: Keyword arguments forwarded to `dataclasses.field` Returns: A `dataclasses.Field` object. """ if kw_only is not dataclasses.MISSING and kw_only: if ( kwargs.get('default', dataclasses.MISSING) is dataclasses.MISSING and kwargs.get('default_factory', dataclasses.MISSING) is dataclasses.MISSING ): raise ValueError('Keyword-only fields with no default are not supported.') if metadata is None: metadata = {} metadata[KW_ONLY] = True return dataclasses.field(metadata=metadata, **kwargs) @tpe.dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required] def dataclass(cls=None, extra_fields=None, **kwargs): """Wrapper for dataclasses.dataclass that adds support for kw_only fields. Args: cls: The class to transform (or none to return a decorator). extra_fields: A list of `(name, type, Field)` tuples describing extra fields that should be added to the dataclass. This is necessary for linen's use-case of this module, since the base class (linen.Module) is *not* a dataclass. In particular, linen.Module class is used as the base for both frozen and non-frozen dataclass subclasses; but the frozen status of a dataclass must match the frozen status of any base dataclasses. **kwargs: Additional arguments for `dataclasses.dataclass`. Returns: `cls`. """ def wrap(cls): return _process_class(cls, extra_fields=extra_fields, **kwargs) return wrap if cls is None else wrap(cls) def _process_class(cls: type[M], extra_fields=None, **kwargs): """Transforms `cls` into a dataclass that supports kw_only fields.""" if sys.version_info < (3, 14) and '__annotations__' not in cls.__dict__: cls.__annotations__ = {} # The original __dataclass_fields__ dicts for all base classes. We will # modify these in-place before turning `cls` into a dataclass, and then # restore them to their original values. base_dataclass_fields = {} # dict[cls, cls.__dataclass_fields__.copy()] # The keyword only fields from `cls` or any of its base classes. kw_only_fields: dict[FieldName, tuple[Annotation, Default]] = {} # Scan for KW_ONLY marker. kw_only_name = None for name, annotation in cls.__annotations__.items(): if annotation is KW_ONLY: if kw_only_name is not None: raise TypeError('Multiple KW_ONLY markers') kw_only_name = name elif kw_only_name is not None: if not hasattr(cls, name): raise ValueError( 'Keyword-only fields with no default are not supported.' ) default = getattr(cls, name) if isinstance(default, dataclasses.Field): default.metadata = MappingProxyType({**default.metadata, KW_ONLY: True}) else: default = field(default=default, kw_only=True) setattr(cls, name, default) if kw_only_name: del cls.__annotations__[kw_only_name] # Inject extra fields. if extra_fields: for name, annotation, default in extra_fields: if not (isinstance(name, str) and isinstance(default, dataclasses.Field)): raise ValueError( 'Expected extra_fields to a be a list of ' '(name, type, Field) tuples.' ) setattr(cls, name, default) cls.__annotations__[name] = annotation # Extract kw_only fields from base classes' __dataclass_fields__. for base in reversed(cls.__mro__[1:]): if not dataclasses.is_dataclass(base): continue if sys.version_info < (3, 14): base_annotations = base.__dict__.get('__annotations__', {}) else: base_annotations = inspect.get_annotations(base) base_dataclass_fields[base] = dict( getattr(base, '__dataclass_fields__', {}) ) for base_field in list(dataclasses.fields(base)): field_name = base_field.name if base_field.metadata.get(KW_ONLY) or field_name in kw_only_fields: kw_only_fields[field_name] = ( base_annotations.get(field_name), base_field, ) del base.__dataclass_fields__[field_name] # Remove any keyword-only fields from this class. if sys.version_info < (3, 14): cls_annotations = cls.__dict__['__annotations__'] else: cls_annotations = cls.__annotations__ for name, annotation in list(cls_annotations.items()): value = getattr(cls, name, None) if ( isinstance(value, dataclasses.Field) and value.metadata.get(KW_ONLY) ) or name in kw_only_fields: del cls_annotations[name] kw_only_fields[name] = (annotation, value) # Add keyword-only fields at the end of __annotations__, in the order they # were found in the base classes and in this class. for name, (annotation, default) in kw_only_fields.items(): setattr(cls, name, default) cls_annotations.pop(name, None) cls_annotations[name] = annotation create_init = '__init__' not in vars(cls) and kwargs.get('init', True) # Apply the dataclass transform. transformed_cls: type[M] = dataclasses.dataclass(cls, **kwargs) # Restore the base classes' __dataclass_fields__. for _cls, fields in base_dataclass_fields.items(): _cls.__dataclass_fields__ = fields if create_init: dataclass_init = transformed_cls.__init__ # use sum to count the number of init fields that are not keyword-only expected_num_args = sum( f.init and not f.metadata.get(KW_ONLY, False) for f in dataclasses.fields(transformed_cls) ) @functools.wraps(dataclass_init) def init_wrapper(self, *args, **kwargs): num_args = len(args) if num_args > expected_num_args: # we add + 1 to each to account for `self`, matching python's # default error message raise TypeError( f'__init__() takes {expected_num_args + 1} positional ' f'arguments but {num_args + 1} were given' ) dataclass_init(self, *args, **kwargs) init_wrapper.__signature__ = inspect.signature(dataclass_init) # type: ignore transformed_cls.__init__ = init_wrapper # type: ignore[method-assign] # Return the transformed dataclass return transformed_cls ================================================ FILE: flax/linen/linear.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Linear modules.""" from collections.abc import Iterable, Sequence from typing import Any, Protocol from flax.core import meta from flax.linen import initializers from flax.linen import module from flax.linen.dtypes import promote_dtype from flax.linen.module import Module, compact from flax.typing import ( Array, ConvGeneralDilatedT, DotGeneralT, Dtype, Initializer, LaxPadding, PRNGKey as PRNGKey, PaddingLike, PrecisionLike, Shape as Shape, ) import jax from jax import eval_shape, lax from jax.core import ShapedArray import jax.numpy as jnp import numpy as np import opt_einsum class PromoteDtypeFn(Protocol): def __call__( self, *args: jax.Array | None, dtype: Any = None, inexact: bool = True ) -> list[jax.Array | None]: ... default_kernel_init = initializers.lecun_normal() def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]: # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. return tuple(sorted(ax if ax >= 0 else ndim + ax for ax in axes)) def _canonicalize_tuple(x: Sequence[int] | int) -> tuple[int, ...]: if isinstance(x, Iterable): return tuple(x) else: return (x,) class DenseGeneral(Module): """A linear transformation with flexible axes. Example usage:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> # equivalent to `nn.Dense(features=4)` >>> layer = nn.DenseGeneral(features=4) >>> # output features (4, 5) >>> layer = nn.DenseGeneral(features=(4, 5)) >>> params = layer.init(jax.random.key(0), jnp.ones((1, 3))) >>> jax.tree_util.tree_map(jnp.shape, params) {'params': {'bias': (4, 5), 'kernel': (3, 4, 5)}} >>> # apply transformation on the the second and last axes >>> layer = nn.DenseGeneral(features=(4, 5), axis=(1, -1)) >>> params = layer.init(jax.random.key(0), jnp.ones((1, 3, 6, 7))) >>> jax.tree_util.tree_map(jnp.shape, params) {'params': {'bias': (4, 5), 'kernel': (3, 7, 4, 5)}} Attributes: features: int or tuple with number of output features. axis: int or tuple with axes to apply the transformation on. For instance, (-2, -1) will apply the transformation to the last two axes. batch_dims: tuple with batch axes. use_bias: whether to add a bias to the output (default: True). dtype: the dtype of the computation (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). kernel_init: initializer function for the weight matrix. bias_init: initializer function for the bias. precision: numerical precision of the computation see ``jax.lax.Precision`` for details. promote_dtype: function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of ``(inputs, kernel, bias)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. """ features: int | Sequence[int] axis: int | Sequence[int] = -1 batch_dims: Sequence[int] = () use_bias: bool = True dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 kernel_init: Initializer = default_kernel_init bias_init: Initializer = initializers.zeros_init() precision: PrecisionLike = None promote_dtype: PromoteDtypeFn = promote_dtype # Deprecated. Will be removed. dot_general: DotGeneralT | None = None dot_general_cls: Any = None @compact def __call__(self, inputs: Array) -> Array: """Applies a linear transformation to the inputs along multiple dimensions. Args: inputs: The nd-array to be transformed. Returns: The transformed input. """ features = _canonicalize_tuple(self.features) axis = _canonicalize_tuple(self.axis) batch_dims = _canonicalize_tuple(self.batch_dims) if batch_dims: max_dim = np.max(batch_dims) if set(batch_dims) != set(range(max_dim + 1)): raise ValueError( 'batch_dims %s must be consecutive leading ' 'dimensions starting from 0.' % str(batch_dims) ) ndim = inputs.ndim n_batch_dims = len(batch_dims) axis = _normalize_axes(axis, ndim) batch_dims = _normalize_axes(batch_dims, ndim) n_axis, n_features = len(axis), len(features) def kernel_init_wrap(rng, shape, dtype=jnp.float32): flat_shape = ( np.prod(shape[:n_batch_dims]) * np.prod(shape[n_batch_dims : n_axis + n_batch_dims]), np.prod(shape[-n_features:]), ) flat_shape = jax.tree_util.tree_map(int, flat_shape) kernel = self.kernel_init(rng, flat_shape, dtype) if isinstance(kernel, meta.AxisMetadata): return meta.replace_boxed(kernel, jnp.reshape(kernel.unbox(), shape)) return jnp.reshape(kernel, shape) batch_shape = tuple(inputs.shape[ax] for ax in batch_dims) # batch and non-contracting dims of input with 1s for batch dims. expanded_batch_shape = tuple( inputs.shape[ax] if ax in batch_dims else 1 for ax in range(inputs.ndim) if ax not in axis ) kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features kernel = self.param( 'kernel', kernel_init_wrap, batch_shape + kernel_shape, self.param_dtype ) batch_ind = tuple(range(n_batch_dims)) contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims)) if self.use_bias: def bias_init_wrap(rng, shape, dtype=jnp.float32): flat_shape = ( np.prod(shape[:n_batch_dims]) * np.prod(shape[-n_features:]), ) flat_shape = jax.tree_util.tree_map(int, flat_shape) bias = self.bias_init(rng, flat_shape, dtype) if isinstance(bias, meta.AxisMetadata): return meta.replace_boxed(bias, jnp.reshape(bias.unbox(), shape)) return jnp.reshape(bias, shape) bias = self.param( 'bias', bias_init_wrap, batch_shape + features, self.param_dtype ) else: bias = None inputs, kernel, bias = self.promote_dtype( inputs, kernel, bias, dtype=self.dtype ) if self.dot_general_cls is not None: dot_general = self.dot_general_cls() elif self.dot_general is not None: dot_general = self.dot_general else: dot_general = lax.dot_general out = dot_general( inputs, kernel, ((axis, contract_ind), (batch_dims, batch_ind)), precision=self.precision, ) # dot_general output has shape [batch_dims/group_dims] + [feature_dims] if self.use_bias: # expand bias shape to broadcast bias over batch dims. assert bias is not None bias = jnp.reshape(bias, expanded_batch_shape + features) out += bias return out class Dense(Module): """A linear transformation applied over the last dimension of the input. Example usage:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> layer = nn.Dense(features=4) >>> params = layer.init(jax.random.key(0), jnp.ones((1, 3))) >>> jax.tree_util.tree_map(jnp.shape, params) {'params': {'bias': (4,), 'kernel': (3, 4)}} Attributes: features: the number of output features. use_bias: whether to add a bias to the output (default: True). dtype: the dtype of the computation (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). precision: numerical precision of the computation see ``jax.lax.Precision`` for details. kernel_init: initializer function for the weight matrix. bias_init: initializer function for the bias. promote_dtype: function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of ``(inputs, kernel, bias)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. """ features: int use_bias: bool = True dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 precision: PrecisionLike = None kernel_init: Initializer = default_kernel_init bias_init: Initializer = initializers.zeros_init() promote_dtype: PromoteDtypeFn = promote_dtype dot_general: DotGeneralT | None = None dot_general_cls: Any = None @compact def __call__(self, inputs: Array) -> Array: """Applies a linear transformation to the inputs along the last dimension. Args: inputs: The nd-array to be transformed. Returns: The transformed input. """ kernel = self.param( 'kernel', self.kernel_init, (jnp.shape(inputs)[-1], self.features), self.param_dtype, ) if self.use_bias: bias = self.param( 'bias', self.bias_init, (self.features,), self.param_dtype ) else: bias = None inputs, kernel, bias = self.promote_dtype( inputs, kernel, bias, dtype=self.dtype ) assert inputs is not None assert kernel is not None if self.dot_general_cls is not None: dot_general = self.dot_general_cls() elif self.dot_general is not None: dot_general = self.dot_general else: dot_general = lax.dot_general y = dot_general( inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision, ) if bias is not None: y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) return y class Einsum(Module): """An einsum transformation with learnable kernel and bias. Example usage:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> layer = nn.Einsum((5, 6, 7), 'abc,cde->abde') >>> variables = layer.init(jax.random.key(0), jnp.ones((3, 4, 5))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'bias': (6, 7), 'kernel': (5, 6, 7)}} Attributes: shape: the shape of the kernel. einsum_str: a string to denote the einsum equation. The equation must have exactly two operands, the lhs being the input passed in, and the rhs being the learnable kernel. Exactly one of ``einsum_str`` in the constructor argument and call argument must be not None, while the other must be None. use_bias: whether to add a bias to the output (default: True). dtype: the dtype of the computation (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). precision: numerical precision of the computation see ``jax.lax.Precision`` for details. kernel_init: initializer function for the weight matrix. bias_init: initializer function for the bias. promote_dtype: function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of ``(inputs, kernel, bias)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. """ shape: Shape einsum_str: str | None = None use_bias: bool = True dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 precision: PrecisionLike = None kernel_init: Initializer = default_kernel_init bias_init: Initializer = initializers.zeros_init() promote_dtype: PromoteDtypeFn = promote_dtype preferred_element_type: Dtype | None = None @compact def __call__(self, inputs: Array, einsum_str: str | None = None) -> Array: """Applies a linear transformation to the inputs along the last dimension. Args: inputs: The nd-array to be transformed. einsum_str: a string to denote the einsum equation. The equation must have exactly two operands, the lhs being the input passed in, and the rhs being the learnable kernel. The ``einsum_str`` passed into the call method will take precedence over the ``einsum_str`` passed into the constructor. Returns: The transformed input. """ einsum_str = module.merge_param('einsum_str', self.einsum_str, einsum_str) einsum_str = einsum_str.replace(' ', '') if '->' not in einsum_str: raise ValueError( '`einsum_str` equation must be explicit and include "->".' ) if einsum_str.count(',') != 1: raise ValueError( '`einsum_str` equation must have exactly two operands and ' 'therefore, exactly one comma character, instead of ' f'{einsum_str.count(",")}' ) kernel = self.param( 'kernel', self.kernel_init, self.shape, self.param_dtype, ) if self.use_bias: bias_shape, broadcasted_bias_shape = self._get_bias_shape( einsum_str, inputs, kernel ) bias = self.param('bias', self.bias_init, bias_shape, self.param_dtype) else: bias = None inputs, kernel, bias = self.promote_dtype( inputs, kernel, bias, dtype=self.dtype ) y = jnp.einsum( einsum_str, inputs, kernel, precision=self.precision, preferred_element_type=self.preferred_element_type, ) if bias is not None: y += jnp.reshape(bias, broadcasted_bias_shape) return y def _get_bias_shape(self, einsum_str: str, lhs: Array, rhs: Array): """Infer the bias shape and broadcasted bias shape given the ``einsum_str``, ``lhs`` and ``rhs`` arrays. This is needed for instantiating the bias parameter and adding the bias to the output during forward inference. This function first replaces all ellipses with actual letter characters, then computes the bias shape by checking to see which axes in the rhs array remain in the resulting array after einsumming. These axes are the embedding/feature dimensions, and all other axes in rhs are reduction axes. """ # More details on the parsing function: https://github.com/dgasmith/opt_einsum/blob/c826bb7df16f470a69f7bf90598fc27586209d11/opt_einsum/parser.py#L246 # returns the einsum string representation of the operands and result, with # ellipsis replaced by actual letter characters operands_str, result_str, _ = opt_einsum.parser.parse_einsum_input( (einsum_str, lhs, rhs) ) # rhs_dict is a dict{character:index} mapping that maps every character in # the rhs einsum string representation to its corresponding index position in the string rhs_dict = {c: i for i, c in enumerate(operands_str.split(',')[1])} assert len(rhs_dict) == len(self.shape) broadcasted_bias_shape = [1] * len(result_str) bias_shape = [] for i, c in enumerate(result_str): if c in rhs_dict: broadcasted_bias_shape[i] = self.shape[rhs_dict[c]] bias_shape.append(self.shape[rhs_dict[c]]) return bias_shape, broadcasted_bias_shape def _conv_dimension_numbers(input_shape): """Computes the dimension numbers based on the input shape.""" ndim = len(input_shape) lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1)) rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2)) out_spec = lhs_spec return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec) def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding: """ "Canonicalizes conv padding to a jax.lax supported format.""" if isinstance(padding, str): return padding if isinstance(padding, int): return [(padding, padding)] * rank if isinstance(padding, Sequence) and len(padding) == rank: new_pad = [] for p in padding: if isinstance(p, int): new_pad.append((p, p)) elif isinstance(p, tuple) and len(p) == 2: new_pad.append(p) else: break if len(new_pad) == rank: return new_pad raise ValueError( f'Invalid padding format: {padding}, should be str, int,' f' or a sequence of len {rank} where each element is an' ' int or pair of ints.' ) class _Conv(Module): """Convolution Module wrapping ``lax.conv_general_dilated``. Attributes: features: number of convolution filters. kernel_size: shape of the convolutional kernel. An integer will be interpreted as a tuple of the single integer. strides: an integer or a sequence of `n` integers, representing the inter-window strides (default: 1). padding: either the string ``'SAME'``, the string ``'VALID'``, the string ``'CIRCULAR'`` (periodic boundary conditions), the string `'REFLECT'` (reflection across the padding boundary), or a sequence of ``n`` ``(low, high)`` integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims and assign a single int in a sequence causes the same padding to be used on both sides. ``'CAUSAL'`` padding for a 1D convolution will left-pad the convolution axis, resulting in same-sized output. input_dilation: an integer or a sequence of ``n`` integers, giving the dilation factor to apply in each spatial dimension of ``inputs`` (default: 1). Convolution with input dilation ``d`` is equivalent to transposed convolution with stride ``d``. kernel_dilation: an integer or a sequence of ``n`` integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as 'atrous convolution'. feature_group_count: integer, default 1. If specified divides the input features into groups. use_bias: whether to add a bias to the output (default: True). mask: Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix. dtype: the dtype of the computation (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). precision: numerical precision of the computation see ``jax.lax.Precision`` for details. kernel_init: initializer for the convolutional kernel. bias_init: initializer for the bias. promote_dtype: function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of ``(inputs, kernel, bias)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. """ features: int kernel_size: int | Sequence[int] strides: None | int | Sequence[int] = 1 padding: PaddingLike = 'SAME' input_dilation: None | int | Sequence[int] = 1 kernel_dilation: None | int | Sequence[int] = 1 feature_group_count: int = 1 use_bias: bool = True mask: Array | None = None dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 precision: PrecisionLike = None kernel_init: Initializer = default_kernel_init bias_init: Initializer = initializers.zeros_init() promote_dtype: PromoteDtypeFn = promote_dtype # Deprecated. Will be removed. conv_general_dilated: ConvGeneralDilatedT | None = None conv_general_dilated_cls: Any = None @property def shared_weights(self) -> bool: # type: ignore """Defines whether weights are shared or not between different pixels. Returns: ``True`` to use shared weights in convolution (regular convolution). ``False`` to use different weights at different pixels, a.k.a. "locally connected layer", "unshared convolution", or "local convolution". """ ... @compact def __call__(self, inputs: Array) -> Array: """Applies a (potentially unshared) convolution to the inputs. Args: inputs: input data with dimensions ``(*batch_dims, spatial_dims..., features)``. This is the channels-last convention, i.e. NHWC for a 2d convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by ``lax.conv_general_dilated``, which puts the spatial dimensions last. Note: If the input has more than 1 batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap'ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension it will be added for the convolution and removed n return, an allowance made to enable writing single-example code. Returns: The convolved data. """ kernel_size: Sequence[int] if isinstance(self.kernel_size, int): kernel_size = (self.kernel_size,) else: kernel_size = tuple(self.kernel_size) def maybe_broadcast( x: int | Sequence[int] | None, ) -> tuple[int, ...]: if x is None: # backward compatibility with using None as sentinel for # broadcast 1 x = 1 if isinstance(x, int): return (x,) * len(kernel_size) return tuple(x) # Combine all input batch dimensions into a single leading batch axis. num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1) if num_batch_dimensions != 1: input_batch_shape = inputs.shape[:num_batch_dimensions] flat_input_shape = (-1,) + inputs.shape[ num_batch_dimensions: ] inputs = jnp.reshape(inputs, flat_input_shape) # self.strides or (1,) * (inputs.ndim - 2) strides = maybe_broadcast(self.strides) input_dilation = maybe_broadcast(self.input_dilation) kernel_dilation = maybe_broadcast(self.kernel_dilation) padding_lax = canonicalize_padding(self.padding, len(kernel_size)) if padding_lax in ('CIRCULAR', 'REFLECT'): assert isinstance(padding_lax, str) kernel_size_dilated = [ (k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation) ] zero_pad: list[tuple[int, int]] = [(0, 0)] pads = ( zero_pad + [((k - 1) // 2, k // 2) for k in kernel_size_dilated] + [(0, 0)] ) padding_mode = {'CIRCULAR': 'wrap', 'REFLECT': 'reflect'}[padding_lax] inputs = jnp.pad(inputs, pads, mode=padding_mode) padding_lax = 'VALID' elif padding_lax == 'CAUSAL': if len(kernel_size) != 1: raise ValueError( 'Causal padding is only implemented for 1D convolutions.' ) left_pad = kernel_dilation[0] * (kernel_size[0] - 1) pads = [(0, 0), (left_pad, 0), (0, 0)] inputs = jnp.pad(inputs, pads) padding_lax = 'VALID' dimension_numbers = _conv_dimension_numbers(inputs.shape) in_features = jnp.shape(inputs)[-1] if self.shared_weights: # One shared convolutional kernel for all pixels in the output. assert in_features % self.feature_group_count == 0 kernel_shape = kernel_size + ( in_features // self.feature_group_count, self.features, ) else: if self.feature_group_count != 1: raise NotImplementedError( '`lax.conv_general_dilated_local` does not support ' f'`feature_group_count != 1`, got `{self.feature_group_count}`.' ) # Need to know the spatial output shape of a standard convolution to # create the unshared convolution kernel. if self.conv_general_dilated_cls is not None: conv_general_dilated = self.conv_general_dilated_cls() elif self.conv_general_dilated is not None: conv_general_dilated = self.conv_general_dilated else: conv_general_dilated = lax.conv_general_dilated conv_output_shape = eval_shape( lambda lhs, rhs: conv_general_dilated( # pylint: disable=g-long-lambda lhs=lhs, rhs=rhs, window_strides=strides, padding=padding_lax, dimension_numbers=dimension_numbers, lhs_dilation=input_dilation, rhs_dilation=kernel_dilation, ), inputs, ShapedArray(kernel_size + (in_features, self.features), inputs.dtype), ).shape # One (unshared) convolutional kernel per each pixel in the output. kernel_shape = conv_output_shape[1:-1] + ( np.prod(kernel_size) * in_features, self.features, ) if self.mask is not None and self.mask.shape != kernel_shape: raise ValueError( 'Mask needs to have the same shape as weights. ' f'Shapes are: {self.mask.shape}, {kernel_shape}' ) kernel = self.param( 'kernel', self.kernel_init, kernel_shape, self.param_dtype ) if self.mask is not None: kernel *= self.mask if self.use_bias: if self.shared_weights: # One bias weight per output channel, shared between pixels. bias_shape = (self.features,) else: # One bias weight per output entry, unshared betwen pixels. bias_shape = conv_output_shape[1:] bias = self.param('bias', self.bias_init, bias_shape, self.param_dtype) else: bias = None inputs, kernel, bias = self.promote_dtype( inputs, kernel, bias, dtype=self.dtype ) assert inputs is not None assert kernel is not None if self.shared_weights: if self.conv_general_dilated_cls is not None: conv_general_dilated = self.conv_general_dilated_cls() elif self.conv_general_dilated is not None: conv_general_dilated = self.conv_general_dilated else: conv_general_dilated = lax.conv_general_dilated y = conv_general_dilated( inputs, kernel, strides, padding_lax, lhs_dilation=input_dilation, rhs_dilation=kernel_dilation, dimension_numbers=dimension_numbers, feature_group_count=self.feature_group_count, precision=self.precision, ) else: y = lax.conv_general_dilated_local( lhs=inputs, rhs=kernel, window_strides=strides, padding=padding_lax, filter_shape=kernel_size, lhs_dilation=input_dilation, rhs_dilation=kernel_dilation, dimension_numbers=dimension_numbers, precision=self.precision, ) if self.use_bias: bias = bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape) # type: ignore y += bias if num_batch_dimensions != 1: output_shape = input_batch_shape + y.shape[1:] y = jnp.reshape(y, output_shape) return y class Conv(_Conv): """Convolution Module wrapping ``lax.conv_general_dilated``. Example usage:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> # valid padding >>> layer = nn.Conv(features=4, kernel_size=(3,), padding='VALID') >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'bias': (4,), 'kernel': (3, 3, 4)}} >>> out.shape (1, 6, 4) >>> # circular padding with stride 2 >>> layer = nn.Conv(features=4, kernel_size=(3, 3), strides=2, padding='CIRCULAR') >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'bias': (4,), 'kernel': (3, 3, 3, 4)}} >>> out.shape (1, 4, 4) >>> # apply lower triangle mask >>> mask = jnp.tril(jnp.ones((3, 3, 4))) >>> layer = nn.Conv(features=4, kernel_size=(3,), mask=mask, padding='VALID') >>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3))) Attributes: features: number of convolution filters. kernel_size: shape of the convolutional kernel. An integer will be interpreted as a tuple of the single integer. strides: an integer or a sequence of `n` integers, representing the inter-window strides (default: 1). padding: either the string ``'SAME'``, the string ``'VALID'``, the string ``'CIRCULAR'`` (periodic boundary conditions), or a sequence of ``n`` ``(low, high)`` integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims and assign a single int in a sequence causes the same padding to be used on both sides. ``'CAUSAL'`` padding for a 1D convolution will left-pad the convolution axis, resulting in same-sized output. input_dilation: an integer or a sequence of ``n`` integers, giving the dilation factor to apply in each spatial dimension of ``inputs`` (default: 1). Convolution with input dilation ``d`` is equivalent to transposed convolution with stride ``d``. kernel_dilation: an integer or a sequence of ``n`` integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as 'atrous convolution'. feature_group_count: integer, default 1. If specified divides the input features into groups. use_bias: whether to add a bias to the output (default: True). mask: Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix. dtype: the dtype of the computation (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). precision: numerical precision of the computation see ``jax.lax.Precision` for details. kernel_init: initializer for the convolutional kernel. bias_init: initializer for the bias. promote_dtype: function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of ``(inputs, kernel, bias)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. """ @property def shared_weights(self) -> bool: return True class ConvLocal(_Conv): """Local convolution Module wrapping ``lax.conv_general_dilated_local``. Example usage:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> # valid padding >>> layer = nn.ConvLocal(features=4, kernel_size=(3,), padding='VALID') >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'bias': (6, 4), 'kernel': (6, 9, 4)}} >>> out.shape (1, 6, 4) >>> # circular padding with stride 2 >>> layer = nn.ConvLocal(features=4, kernel_size=(3, 3), strides=2, padding='CIRCULAR') >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'bias': (1, 4, 4), 'kernel': (1, 4, 27, 4)}} >>> out.shape (1, 4, 4) >>> # apply lower triangle mask >>> mask = jnp.tril(jnp.ones((6, 9, 4))) >>> layer = nn.ConvLocal(features=4, kernel_size=(3,), mask=mask, padding='VALID') >>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3))) Attributes: features: number of convolution filters. kernel_size: shape of the convolutional kernel. An integer will be interpreted as a tuple of the single integer. strides: an integer or a sequence of `n` integers, representing the inter-window strides (default: 1). padding: either the string ``'SAME'``, the string ``'VALID'``, the string ``'CIRCULAR'`` (periodic boundary conditions), or a sequence of ``n`` ``(low, high)`` integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims and assign a single int in a sequence causes the same padding to be used on both sides. ``'CAUSAL'`` padding for a 1D convolution will left-pad the convolution axis, resulting in same-sized output. input_dilation: an integer or a sequence of ``n`` integers, giving the dilation factor to apply in each spatial dimension of ``inputs`` (default: 1). Convolution with input dilation ``d`` is equivalent to transposed convolution with stride ``d``. kernel_dilation: an integer or a sequence of ``n`` integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as 'atrous convolution'. feature_group_count: integer, default 1. If specified divides the input features into groups. use_bias: whether to add a bias to the output (default: True). mask: Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix. dtype: the dtype of the computation (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). precision: numerical precision of the computation see ``jax.lax.Precision`` for details. kernel_init: initializer for the convolutional kernel. bias_init: initializer for the bias. promote_dtype: function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of ``(inputs, kernel, bias)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. """ @property def shared_weights(self) -> bool: return False class ConvTranspose(Module): """Convolution Module wrapping ``lax.conv_transpose``. Example usage:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> # valid padding >>> layer = nn.ConvTranspose(features=4, kernel_size=(3,), padding='VALID') >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'bias': (4,), 'kernel': (3, 3, 4)}} >>> out.shape (1, 10, 4) >>> # circular padding with stride 2 >>> layer = nn.ConvTranspose(features=4, kernel_size=(6, 6), strides=(2, 2), padding='CIRCULAR', transpose_kernel=True) >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 15, 15, 3))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'bias': (4,), 'kernel': (6, 6, 4, 3)}} >>> out.shape (1, 30, 30, 4) >>> # apply lower triangle mask >>> mask = jnp.tril(jnp.ones((3, 3, 4))) >>> layer = nn.ConvTranspose(features=4, kernel_size=(3,), mask=mask, padding='VALID') >>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3))) Attributes: features: number of convolution filters. kernel_size: shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer, which will be interpreted as a tuple of the single integer. For all other cases, it must be a sequence of integers. strides: an integer or a sequence of `n` integers, representing the inter-window strides. padding: either the string `'SAME'`, the string `'VALID'`, the string `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low, high)` integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims and assign a single int in a sequence causes the same padding to be used on both sides. kernel_dilation: ``None``, or an integer or a sequence of ``n`` integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Convolution with kernel dilation is also known as 'atrous convolution'. use_bias: whether to add a bias to the output (default: True). mask: Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix. dtype: the dtype of the computation (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). precision: numerical precision of the computation see ``jax.lax.Precision`` for details. kernel_init: initializer for the convolutional kernel. bias_init: initializer for the bias. transpose_kernel: if ``True`` flips spatial axes and swaps the input/output channel axes of the kernel. promote_dtype: function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of ``(inputs, kernel, bias)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. """ features: int kernel_size: int | Sequence[int] strides: Sequence[int] | None = None padding: PaddingLike = 'SAME' kernel_dilation: Sequence[int] | None = None use_bias: bool = True mask: Array | None = None dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 precision: PrecisionLike = None kernel_init: Initializer = default_kernel_init bias_init: Initializer = initializers.zeros_init() transpose_kernel: bool = False promote_dtype: PromoteDtypeFn = promote_dtype preferred_element_type: Dtype | None = None @compact def __call__(self, inputs: Array) -> Array: """Applies a transposed convolution to the inputs. Behaviour mirrors of ``jax.lax.conv_transpose``. Args: inputs: input data with dimensions ``(*batch_dims, spatial_dims..., features).`` This is the channels-last convention, i.e. NHWC for a 2d convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by ``lax.conv_general_dilated``, which puts the spatial dimensions last. Note: If the input has more than 1 batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap'ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension it will be added for the convolution and removed n return, an allowance made to enable writing single-example code. Returns: The convolved data. """ kernel_size: tuple[int, ...] if isinstance(self.kernel_size, int): kernel_size = (self.kernel_size,) else: kernel_size = tuple(self.kernel_size) def maybe_broadcast( x: int | Sequence[int] | None, ) -> tuple[int, ...]: if x is None: # backward compatibility with using None as sentinel for # broadcast 1 x = 1 if isinstance(x, int): return (x,) * len(kernel_size) return tuple(x) # Combine all input batch dimensions into a single leading batch axis. num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1) if num_batch_dimensions != 1: input_batch_shape = inputs.shape[:num_batch_dimensions] flat_input_shape = (-1,) + inputs.shape[ num_batch_dimensions: ] inputs = jnp.reshape(inputs, flat_input_shape) strides = maybe_broadcast(self.strides) kernel_dilation = maybe_broadcast(self.kernel_dilation) in_features = jnp.shape(inputs)[-1] if self.transpose_kernel: kernel_shape = kernel_size + (self.features, in_features) else: kernel_shape = kernel_size + (in_features, self.features) if self.mask is not None and self.mask.shape != kernel_shape: raise ValueError( 'Mask needs to have the same shape as weights. ' f'Shapes are: {self.mask.shape}, {kernel_shape}' ) kernel = self.param( 'kernel', self.kernel_init, kernel_shape, self.param_dtype ) if self.mask is not None: kernel *= self.mask padding_lax = canonicalize_padding(self.padding, len(kernel_size)) if padding_lax == 'CIRCULAR': padding_lax = 'VALID' if self.use_bias: bias = self.param( 'bias', self.bias_init, (self.features,), self.param_dtype ) else: bias = None inputs, kernel, bias = self.promote_dtype( inputs, kernel, bias, dtype=self.dtype ) assert inputs is not None assert kernel is not None y = lax.conv_transpose( inputs, kernel, strides, padding_lax, rhs_dilation=kernel_dilation, transpose_kernel=self.transpose_kernel, precision=self.precision, preferred_element_type=self.preferred_element_type, ) if self.padding == 'CIRCULAR': # For circular padding, we need to identify the size of the final output # ("period") along each spatial dimension, pad each dimension to an # integer number of periods, and wrap the array periodically around each # dimension. Padding should be done in such a way that the start of the # original input data inside the padded array is located at integer # number of periods - otherwise the result would be circularly shifted. # Compute period along each spatial dimension - it's input size scaled # by the stride. scaled_x_dims = [ x_dim * stride for x_dim, stride in zip(jnp.shape(inputs)[1:-1], strides) ] # Compute difference between the current size of y and the final output # size, and complement this difference to 2 * period - that gives how # much we need to pad. size_diffs = [ -(y_dim - x_dim) % (2 * x_dim) for y_dim, x_dim in zip(y.shape[1:-1], scaled_x_dims) ] if self.transpose_kernel: # If the kernel is transposed, the "+1" is put on the right to # mirror the regular convolution. If the same kernel parameters are used # as for Conv, this layer then computes the proper transpose convolution. total_pad = [ (size_diff // 2, (size_diff + 1) // 2) for size_diff in size_diffs ] else: # Divide the padding equally between left and right. The choice to put # "+1" on the left (and not on the right) represents a convention for # aligning even-sized kernels. total_pad = [ ((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs ] y = jnp.pad(y, [(0, 0)] + total_pad + [(0, 0)]) # Wrap the result periodically around each spatial dimension, # one by one. for i in range(1, y.ndim - 1): y = y.reshape( y.shape[:i] + (-1, scaled_x_dims[i - 1]) + y.shape[i + 1 :] ) y = y.sum(axis=i) if self.use_bias: y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) # type: ignore if num_batch_dimensions != 1: output_shape = input_batch_shape + y.shape[1:] y = jnp.reshape(y, output_shape) return y default_embed_init = initializers.variance_scaling( 1.0, 'fan_in', 'normal', out_axis=0 ) class Embed(Module): """Embedding Module. A parameterized function from integers [0, ``num_embeddings``) to ``features``-dimensional vectors. This ``Module`` will create an ``embedding`` matrix with shape ``(num_embeddings, features)``. When calling this layer, the input values will be used to 0-index into the ``embedding`` matrix. Indexing on a value greater than or equal to ``num_embeddings`` will result in ``nan`` values. When ``num_embeddings`` equals to 1, it will broadcast the ``embedding`` matrix to input shape with ``features`` dimension appended. Example usage:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> layer = nn.Embed(num_embeddings=5, features=3) >>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]]) >>> variables = layer.init(jax.random.key(0), indices_input) >>> variables {'params': {'embedding': Array([[ 0.04396089, -0.9328513 , -0.97328115], [ 0.41147125, 0.66334754, 0.49469155], [ 0.09719624, 0.49861377, 0.49519277], [-0.13316602, 0.6697022 , 0.3710195 ], [-0.5039532 , 0.287319 , 1.4369922 ]], dtype=float32)}} >>> # get the first three and last three embeddings >>> layer.apply(variables, indices_input) Array([[[ 0.04396089, -0.9328513 , -0.97328115], [ 0.41147125, 0.66334754, 0.49469155], [ 0.09719624, 0.49861377, 0.49519277]], [[-0.5039532 , 0.287319 , 1.4369922 ], [-0.13316602, 0.6697022 , 0.3710195 ], [ 0.09719624, 0.49861377, 0.49519277]]], dtype=float32) Attributes: num_embeddings: number of embeddings / vocab size. features: number of feature dimensions for each embedding. dtype: the dtype of the embedding vectors (default: same as embedding). param_dtype: the dtype passed to parameter initializers (default: float32). embedding_init: embedding initializer. promote_dtype: function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of ``(embedding,)`` during ``__call__`` or ``(query, embedding)`` during ``attend``, and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. """ num_embeddings: int features: int dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 embedding_init: Initializer = default_embed_init promote_dtype: PromoteDtypeFn = promote_dtype def setup(self): self.embedding = self.param( 'embedding', self.embedding_init, (self.num_embeddings, self.features), self.param_dtype, ) def __call__(self, inputs: Array) -> Array: """Embeds the inputs along the last dimension. Args: inputs: input data, all dimensions are considered batch dimensions. Values in the input array must be integers. Returns: Output which is embedded input data. The output shape follows the input, with an additional ``features`` dimension appended. """ if not jnp.issubdtype(inputs.dtype, jnp.integer): raise ValueError('Input type must be an integer or unsigned integer.') # Use take because fancy indexing numpy arrays with JAX indices does not # work correctly. (embedding,) = self.promote_dtype( self.embedding, dtype=self.dtype, inexact=False ) assert embedding is not None if self.num_embeddings == 1: return jnp.broadcast_to(embedding, inputs.shape + (self.features,)) return jnp.take(embedding, inputs, axis=0) def attend(self, query: Array) -> Array: """Attend over the embedding using a query array. Args: query: array with last dimension equal the feature depth ``features`` of the embedding. Returns: An array with final dim ``num_embeddings`` corresponding to the batched inner-product of the array of query vectors against each embedding. Commonly used for weight-sharing between embeddings and logit transform in NLP models. """ embedding: Array query, embedding = self.promote_dtype( query, self.embedding, dtype=self.dtype ) assert query is not None assert embedding is not None return jnp.dot(query, embedding.T) ================================================ FILE: flax/linen/module.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Flax Module.""" import contextlib import dataclasses import enum import functools import inspect import sys import threading import typing import weakref from types import MappingProxyType from typing import ( Any, Literal, Optional, TypeVar, Union, overload, ) from collections.abc import Callable, Iterable, Iterator, Mapping import jax import jax.numpy as jnp import typing_extensions as tpe import flax import flax.linen as nn from flax import ( config, core, errors, serialization, traceback_util, traverse_util, ) from flax.core import Scope, meta, partial_eval from flax.core.frozen_dict import FrozenDict from flax.core.scope import ( CollectionFilter, DenyList, Variable, union_filters, ) from flax.ids import FlaxId, uuid from flax.linen import kw_only_dataclasses from flax.typing import ( RNGSequences, PRNGKey, FrozenVariableDict, VariableDict, ) traceback_util.register_exclusion(__file__) T = TypeVar('T') K = TypeVar('K') M = TypeVar('M', bound='Module') _CallableT = TypeVar('_CallableT', bound=Callable) # Used for abstractly testing module behavior. TestScope = type( 'TestScope', (Scope,), {'make_rng': lambda self, name: jax.random.key(0)}, ) # pylint: disable=protected-access,attribute-defined-outside-init def _get_fn_name(fn): if isinstance(fn, functools.partial): return _get_fn_name(fn.func) return getattr(fn, '__name__', 'unnamed_function') def _indent(x: str, num_spaces: int): indent_str = ' ' * num_spaces lines = x.split('\n') # skip last line because it is always empty and should not be indented. assert not lines[-1] return '\n'.join(indent_str + line for line in lines[:-1]) + '\n' def _attr_repr(value: Any): if callable(value) and ( (isinstance(value, nn.Module) and value.__dict__.get('__name__', None)) or (not isinstance(value, nn.Module) and getattr(value, '__name__', None)) ): value_rep = value.__name__ else: value_rep = repr(value) return value_rep def _module_repr(module: 'Module', num_spaces: int = 4): """Returns a pretty printed representation of the module.""" cls = type(module) try: fields = dataclasses.fields(cls) except TypeError: # Edge case with no fields e.g. module = nn.Module() causes error later. return object.__repr__(module) cls_name = cls.__name__ rep = '' attributes = { f.name: f.type for f in fields if f.name not in ('parent', 'name') and f.repr } child_modules = { k: v for k, v in module._state.children.items() # pytype: disable=attribute-error if isinstance(v, Module) } if attributes: rep += '# attributes\n' for attr in attributes.keys(): # TODO(jheek): can we get a nice string representation of attribute types? value = module.__dict__.get(attr, None) value_rep = _attr_repr(value) rep += f'{attr} = {value_rep}\n' if child_modules: rep += '# children\n' for name, child in child_modules.items(): child_rep = _module_repr(child, num_spaces) rep += f'{name} = {child_rep}\n' if rep: return f'{cls_name}(\n{_indent(rep, num_spaces)})' else: return f'{cls_name}()' # Tabulation utilities. # ----------------------------------------------------------------------------- @dataclasses.dataclass class _CallInfo: index: int path: tuple[str, ...] module: 'Module' rngs: dict[str, core.scope.PRNGKey | core.scope.LazyRng] | None mutable: bool method: str args: tuple[Any, ...] kwargs: dict[str, Any] outputs: Any @dataclasses.dataclass class _CallInfoContext(threading.local): index: int calls: list[_CallInfo] def get_call_index(self) -> int: index = self.index self.index += 1 return index @contextlib.contextmanager def _tabulate_context(): _context.call_info_stack.append(_CallInfoContext(0, [])) try: yield finally: _context.call_info_stack.pop() # Track parent relationship across Modules. # ----------------------------------------------------------------------------- class _DynamicContext(threading.local): """Dynamic context.""" # TODO(marcvanzee): switch to using contextvars once minimum python version is # 3.7 def __init__(self): self.module_stack: list['Module' | None] = [ None, ] self.capture_stack = [] self.call_info_stack: list[_CallInfoContext] = [] # The global context _context = _DynamicContext() class _Sentinel: def __copy__(self): return self # Do not copy singleton sentinel. def __deepcopy__(self, memo): del memo return self # Do not copy singleton sentinel. def __reduce__(self): return _get_unspecified_parent, () def _get_unspecified_parent(): return _unspecified_parent _unspecified_parent = _Sentinel() # Enable automatic named_call wrapping for labelling profile traces. # ----------------------------------------------------------------------------- _use_named_call = config.flax_profile def _derive_profiling_name(module, fn): fn_name = _get_fn_name(fn) method_suffix = f'.{fn_name}' if fn_name != '__call__' else '' module_name = module.name or module.__class__.__name__ return f'{module_name}{method_suffix}' def enable_named_call(): """Enables named call wrapping for labelling profile traces. When named call wrapping is enabled all JAX ops executed in a Module will be run under ``jax.named_scope``. The ``Module`` class name will show up around the operations belonging to that Module in the Tensorboard profiling UI, simplifying the profiling process. Note that ``jax.named_scope`` only works for compiled functions (e.g.: using jax.jit or jax.pmap). """ global _use_named_call _use_named_call = True def disable_named_call(): """Disables named call wrapping. See ``enable_named_call`` """ global _use_named_call _use_named_call = False @contextlib.contextmanager def override_named_call(enable: bool = True): # pylint: disable=g-doc-return-or-yield """Returns a context manager that enables/disables named call wrapping. Args: enable: If true, enables named call wrapping for labelling profile traces. (see ``enabled_named_call``). """ # pylint: enable=g-doc-return-or-yield global _use_named_call use_named_call_prev = _use_named_call _use_named_call = enable try: yield finally: _use_named_call = use_named_call_prev # Intercept module methods. # ----------------------------------------------------------------------------- @dataclasses.dataclass(frozen=True) class InterceptorContext: """Read only state showing the calling context for method interceptors. Attributes: module: The Module instance whose method is being called. method_name: The name of the method being called on the module. orig_method: The original method defined on the module. Calling it will short circuit all other interceptors. """ module: 'Module' method_name: str orig_method: Callable[..., Any] class ThreadLocalStack(threading.local): """Thread-local stack.""" def __init__(self): self._storage = [] def push(self, elem: Any) -> None: self._storage.append(elem) def pop(self) -> Any: return self._storage.pop() def __iter__(self) -> Iterator[Any]: return iter(reversed(self._storage)) def __len__(self) -> int: return len(self._storage) def __repr__(self) -> str: return f'{self.__class__.__name__}({self._storage})' Args = tuple[Any] Kwargs = dict[str, Any] NextGetter = Callable[..., Any] Interceptor = Callable[[NextGetter, Args, Kwargs, InterceptorContext], Any] _global_interceptor_stack = ThreadLocalStack() @contextlib.contextmanager def intercept_methods(interceptor: Interceptor): # pylint: disable=g-doc-return-or-yield r"""Registers a new method interceptor. Method interceptors allow you to (at a distance) intercept method calls to modules. It works similarly to decorators. You could modify args/kwargs before calling the underlying method and/or modify the result returning from calling the underlying method. Or you could completely skip calling the underlying method and decide to do something differently. For example:: >>> import flax.linen as nn >>> import jax.numpy as jnp ... >>> class Foo(nn.Module): ... def __call__(self, x): ... return x ... >>> def my_interceptor1(next_fun, args, kwargs, context): ... print('calling my_interceptor1') ... return next_fun(*args, **kwargs) ... >>> foo = Foo() >>> with nn.intercept_methods(my_interceptor1): ... _ = foo(jnp.ones([1])) calling my_interceptor1 You could also register multiple interceptors on the same method. Interceptors will run in order. For example:: >>> def my_interceptor2(next_fun, args, kwargs, context): ... print('calling my_interceptor2') ... return next_fun(*args, **kwargs) ... >>> with nn.intercept_methods(my_interceptor1), \ ... nn.intercept_methods(my_interceptor2): ... _ = foo(jnp.ones([1])) calling my_interceptor1 calling my_interceptor2 You could skip other interceptors by directly calling the ``context.orig_method``. For example:: >>> def my_interceptor3(next_fun, args, kwargs, context): ... print('calling my_interceptor3') ... return context.orig_method(*args, **kwargs) >>> with nn.intercept_methods(my_interceptor3), \ ... nn.intercept_methods(my_interceptor1), \ ... nn.intercept_methods(my_interceptor2): ... _ = foo(jnp.ones([1])) calling my_interceptor3 The following methods couldn't be intercepted: 1. Methods decoratored with ``nn.nowrap``. 2. Dunder methods including ``__eq__``, ``__repr__``, ``__init__``, ``__hash__``, and ``__post_init__``. 3. Module dataclass fields. 4. Module descriptors. Args: interceptor: A method interceptor. """ _global_interceptor_stack.push(interceptor) try: yield finally: assert _global_interceptor_stack.pop() is interceptor def run_interceptors( orig_method: Callable[..., Any], module: 'Module', *args, **kwargs, ) -> Any: """Runs method interceptors.""" method_name = _get_fn_name(orig_method) fun = functools.partial(orig_method, module) context = InterceptorContext(module, method_name, fun) def wrap_interceptor(interceptor, fun): """Wraps `fun` with `interceptor`.""" @functools.wraps(fun) def wrapped(*args, **kwargs): return interceptor(fun, args, kwargs, context) return wrapped # Wraps interceptors around the original method. The innermost interceptor is # the last one added and directly wrapped around the original bound method. for interceptor in _global_interceptor_stack: fun = wrap_interceptor(interceptor, fun) return fun(*args, **kwargs) # Utilities for pytrees of Modules defined inside setup() # ----------------------------------------------------------------------------- def _sorted_items(x): """Returns items of a dict ordered by keys.""" return sorted(x.items(), key=lambda x: x[0]) def _get_suffix_value_pairs( tree_or_leaf: Any, ) -> list[tuple[str, type['Module']]]: """Helper for naming pytrees of submodules.""" dict_or_leaf = serialization.to_state_dict(tree_or_leaf) if not isinstance(dict_or_leaf, dict) or not dict_or_leaf: return [('', tree_or_leaf)] else: flat_dict = traverse_util.flatten_dict(dict_or_leaf) return [('_' + '_'.join(k), v) for k, v in _sorted_items(flat_dict)] def _map_over_modules_in_tree(fn, tree_or_leaf): """Helper for mapping function over submodules.""" dict_or_leaf = serialization.to_state_dict(tree_or_leaf) if not isinstance(dict_or_leaf, dict) or not dict_or_leaf: return fn('', tree_or_leaf) else: flat_dict = traverse_util.flatten_dict(dict_or_leaf, keep_empty_nodes=True) mapped_flat_dict = { k: fn('_' + '_'.join(k), v) for k, v in _sorted_items(flat_dict) } return serialization.from_state_dict( tree_or_leaf, traverse_util.unflatten_dict(mapped_flat_dict) ) def _freeze_attr(val: Any) -> Any: """Recursively wrap the given attribute `var` in ``FrozenDict``.""" if isinstance(val, (dict, FrozenDict)): return FrozenDict({k: _freeze_attr(v) for k, v in val.items()}) elif isinstance(val, tuple): # Special case namedtuples and special JAX tuple structures otherwise they # would be downgraded to normal tuples. if hasattr(val, '_fields') or type(val).__name__ == 'PartitionSpec': return type(val)(*[_freeze_attr(v) for v in val]) else: return tuple(_freeze_attr(v) for v in val) elif isinstance(val, list): return tuple(_freeze_attr(v) for v in val) else: return val # Method wrapping of "compact methods" and setup() # ----------------------------------------------------------------------------- def compact(fun: _CallableT) -> _CallableT: """Marks the given module method allowing inlined submodules. Methods wrapped in @compact can define submodules directly within the method. For instance:: >>> import flax.linen as nn >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x, features): ... x = nn.Dense(features)(x) ... ... ... return x At most one method in each Module may be wrapped with @compact. Args: fun: The Module method to mark as compact. Returns: The given function ``fun`` marked as compact. """ fun.compact = True # type: ignore[attr-defined] return fun def nowrap(fun: _CallableT) -> _CallableT: """Marks the given module method as a helper method that needn't be wrapped. Methods wrapped in ``@nowrap`` are private helper methods that needn't be wrapped with the state handler or a separate named_call transform. This is needed in several concrete instances: - if you're subclassing a method like Module.param and don't want this overridden core function decorated with the state management wrapper. - If you want a method to be callable from an unbound Module (e.g.: a function of construction of arguments that doesn't depend on params/RNGs). If you want to learn more about how Flax Modules manage their state read the [The Flax Module lifecycle](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html) guide. For instance:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> class Foo(nn.Module): ... num_features: int ... @nn.nowrap ... def _make_dense(self, num_features): ... return nn.Dense(num_features) ... @nn.compact ... def __call__(self, x): ... # now safe to use constructor helper even if using named_call ... dense = self._make_dense(self.num_features) ... return dense(x) Args: fun: The Module method to mark as nowrap. Returns: The given function ``fun`` marked as nowrap. """ fun.nowrap = True # type: ignore[attr-defined] return fun def compact_name_scope(fun: _CallableT) -> _CallableT: """Creates compact submodules from a method. This is a decorator that allows you to define compact submodules from a method. It's intention is to make it easier to port code Haiku code to Flax by providing the same functionality. Example:: >>> import flax.linen as nn >>> import jax >>> import jax.numpy as jnp >>> from flax.core import pretty_repr ... >>> class Foo(nn.Module): ... @nn.compact_name_scope ... def up(self, x): ... return nn.Dense(3)(x) ... ... @nn.compact_name_scope ... def down(self, x): ... return nn.Dense(3)(x) ... ... def __call__(self, x): ... return self.up(x) + self.down(x) ... >>> module = Foo() >>> variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 2))) >>> params = variables['params'] >>> print(pretty_repr(jax.tree_util.tree_map(jnp.shape, params))) { down: { Dense_0: { bias: (3,), kernel: (2, 3), }, }, up: { Dense_0: { bias: (3,), kernel: (2, 3), }, }, } You can also use ``compact_name_scope`` inside ``@compact`` methods or even other ``compact_name_scope`` methods. Methods that are decorated with ``compact_name_scope`` can also be called directly from ``init`` or ``apply`` via the ``method`` argument:: >>> y_down = module.apply({'params': params}, jnp.ones((1, 2)), method='down') >>> y_down.shape (1, 3) Args: fun: The Module method to mark as compact_name_scope. Returns: The given function ``fun`` marked as compact_name_scope. """ @functools.wraps(fun) def compact_name_scope_wrapper(self: nn.Module, *args, **kwargs): name = fun.__name__ if not hasattr(self, '_compact_name_scope_modules'): raise ValueError( f'Cannot call compact_name_scope method {name!r} on a Module that has not been ' f'setup. This is likely because you are calling {name!r} ' 'from outside of init or apply.' ) module = self._compact_name_scope_modules[name] return module(*args, **kwargs) compact_name_scope_wrapper.compact_name_scope = True # type: ignore[attr-defined] compact_name_scope_wrapper.inner_fun = fun # type: ignore[attr-defined] compact_name_scope_wrapper.nowrap = True # type: ignore[attr-defined] return compact_name_scope_wrapper # type: ignore[return-value] def _get_local_method_names( cls: Any, exclude: Iterable[str] = () ) -> tuple[str, ...]: """Gets method names of a class, excluding class and static methods. Args: cls: The class to get method names for. exclude: Names to exclude from output. Returns: A list of method names. """ true_methods = set() for m in cls.__dict__: if callable(cls.__dict__[m]) and not inspect.isclass( cls.__dict__[m] ): # pytype: disable=not-supported-yet mtype = type(cls.__dict__[m]) if mtype != staticmethod and mtype != classmethod: true_methods.add(m) return tuple(true_methods.difference(set(exclude))) def _get_local_descriptor_names( cls: Any, exclude: Iterable[str] = () ) -> tuple[str, ...]: """Gets descriptor names of a class. Args: cls: The class to get property names for. exclude: Names to exclude from output. Returns: A list of property names. """ true_properties = set() for m, attr in cls.__dict__.items(): if not callable(attr) and ( hasattr(attr, '__get__') or hasattr(attr, '__set__') or hasattr(attr, '__delete__') ): mtype = type(attr) if mtype != staticmethod and mtype != classmethod: true_properties.add(m) return tuple(true_properties.difference(set(exclude))) def wrap_method_once(fun: Callable[..., Any]) -> Callable[..., Any]: """Manages Module state for a given user-defined method. Args: fun: User-defined Module method to manage state for. Returns: Wrapped method. """ # Don't rewrap methods that have already had the state management wrapper # applied in the decorator stack. This wrapper should always be applied # before transformation wrappers. if hasattr(fun, 'method_handler_wrapped'): return fun @functools.wraps(fun) def wrapped_module_method(*args, **kwargs): # We might have incorrectly wrappped a callable # that is not a method. Check whether the first arg is self, # otherwise call the wrapped function as is. if args and isinstance(args[0], Module): self, args = args[0], args[1:] return self._call_wrapped_method(fun, args, kwargs) else: return fun(*args, **kwargs) wrapped_module_method.method_handler_wrapped = True # type: ignore[attr-defined] return wrapped_module_method def wrap_descriptor_once(descriptor) -> 'DescriptorWrapper': """Wraps a descriptor to give better error messages. Args: descriptor: User-defined Module attribute descriptor. Returns: Wrapped descriptor. """ # Don't rewrap descriptors. if isinstance(descriptor, DescriptorWrapper): return descriptor return create_descriptor_wrapper(descriptor) def _wrap_hash(hash_fn: Callable[..., Any]) -> Callable[..., Any]: """Wraps a hash function with some check for Flax Modules.""" @functools.wraps(hash_fn) def wrapped(self): if self.scope is not None: raise TypeError("Can't call __hash__ on modules that hold variables.") try: hash_value = hash_fn(self) except TypeError as exc: raise TypeError( 'Failed to hash Flax Module. ' 'The module probably contains unhashable attributes. ' f'Module={self}' ) from exc return hash_value return wrapped def _get_unbound_fn(method_or_fn: Callable[..., Any]) -> Callable[..., Any]: """Returns an unbound function from a method that is possibly bound. This means that if the passed function belongs of an instance of a class, then the returned function does no longer depend on the instance, which is passed as the first argument to the function. Args: method_or_fn: A class method or function. Returns: An unbound version of input function. """ if inspect.ismethod(method_or_fn) and isinstance( method_or_fn.__self__, Module ): # pytype: disable=attribute-error method_or_fn = method_or_fn.__func__ # pytype: disable=attribute-error # The method should be callable, and it should have at least one argument # representing the class that is passed in. if ( not callable(method_or_fn) or len(inspect.signature(method_or_fn).parameters) < 1 ): raise errors.ApplyModuleInvalidMethodError(method_or_fn) return method_or_fn def _map_submodules(fn: Callable[['Module'], Any], tree): """Map a function over all submodules in a tree.""" g = lambda _, x: fn(x) if isinstance(x, Module) else x return _freeze_attr(_map_over_modules_in_tree(g, tree)) class SetupState(enum.IntEnum): # setup() has not been called. NEW = 0 # setup() has been called outside a transform boundary. TRANSFORMED = 1 # setup() has been called. DONE = 2 @dataclasses.dataclass class _ModuleInternalState: """Ephemeral Module Evaluation State. For clarity, we collect all of the temporary flags and ephemeral state used by Modules for autonaming and error messages here, alongside the rules used to pass this ephemeral state across transform boundaries. """ in_compact_method: bool = False in_setup: bool = False setup_called: SetupState = SetupState.NEW is_initialized: bool = False autoname_cursor: dict[str, int] = dataclasses.field(default_factory=dict) children: dict[str, Union[str, 'Module']] = dataclasses.field( default_factory=dict ) def reset(self) -> None: """Resets transient state. This function is called after each module method, so only attributes that are method-dependent are reset. """ self.in_compact_method = False self.in_setup = False self.autoname_cursor = dict() def export(self) -> '_ModuleInternalState': """Exports transform-preserved state across transform boundary.""" setup_state = ( SetupState.TRANSFORMED if self.setup_called else SetupState.NEW ) cloned = _ModuleInternalState( in_compact_method=self.in_compact_method, in_setup=self.in_setup, setup_called=setup_state, is_initialized=self.is_initialized, autoname_cursor=dict(self.autoname_cursor), ) return cloned def reimport(self, other: '_ModuleInternalState') -> None: """Re-imports transform-preserved state from across transform boundary.""" self.in_compact_method = other.in_compact_method self.in_setup = other.in_setup self.is_initialized = other.is_initialized self.autoname_cursor = dict(other.autoname_cursor) _uninitialized_module_internal_state = _ModuleInternalState() _UNDEFINED_COPY_PICKLE_METHODS = ( '__getstate__', '__setstate__', '__getnewargs_ex__', '__reduce__', '__reduce_ex__', '__copy__', '__deepcopy__', ) _caches: 'weakref.WeakKeyDictionary[Scope, weakref.WeakValueDictionary[FlaxId, Module]]' = weakref.WeakKeyDictionary() tuple_reduce = lambda xs, x: xs + (x,) tuple_init = lambda: () capture_call_intermediates = lambda _, method_name: method_name == '__call__' class ParentDescriptor: """Wraps parent module references in weak refs. This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation. Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms. """ def __get__(self, obj, objtype=None): # check if obj is None, happens during %autoreload if obj is None: return None parent = object.__getattribute__(obj, '_parent_ref') return parent() if isinstance(parent, weakref.ReferenceType) else parent def __set__(self, obj, value): maybe_weak = weakref.ref(value) if isinstance(value, Module) else value object.__setattr__(obj, '_parent_ref', maybe_weak) class Descriptor(tpe.Protocol): __isabstractmethod__: bool def __get__(self, obj, objtype=None) -> Any: ... def __set__(self, obj, value) -> None: ... def __delete__(self, obj) -> None: ... def __set_name__(self, owner, name) -> None: ... class DescriptorWrapper: pass def create_descriptor_wrapper(descriptor: Descriptor): """Creates a descriptor wrapper that calls a get_fn on the descriptor.""" class _DescriptorWrapper(DescriptorWrapper): """A descriptor that can wrap any descriptor.""" if hasattr(descriptor, '__isabstractmethod__'): __isabstractmethod__ = descriptor.__isabstractmethod__ def __init__(self, wrapped: Descriptor): self.wrapped = wrapped # conditionally define descriptor methods if hasattr(descriptor, '__get__'): def __get__(self, *args, **kwargs): # here we will catch internal AttributeError and re-raise it as a # more informative and correct error message. try: return self.wrapped.__get__(*args, **kwargs) except AttributeError as e: raise errors.DescriptorAttributeError() from e if hasattr(descriptor, '__set__'): def __set__(self, *args, **kwargs): return self.wrapped.__set__(*args, **kwargs) if hasattr(descriptor, '__delete__'): def __delete__(self, *args, **kwargs): return self.wrapped.__delete__(*args, **kwargs) if hasattr(descriptor, '__set_name__'): def __set_name__(self, *args, **kwargs): self.wrapped.__set_name__(*args, **kwargs) def __getattr__(self, name): if 'wrapped' not in vars(self): raise AttributeError() return getattr(self.wrapped, name) return _DescriptorWrapper(descriptor) # Base Module definition. # ----------------------------------------------------------------------------- def module_field(*, kw_only: bool = False, default: Any | None = ...) -> Any: ... # The ModuleBase class is created only to make static analyzers happy # mainly pytype and pyright. Some notes: # * pyright (correctly) complains that Module itself is not a dataclass, even # though all its subclasses and intances ARE dataclasses. Because there is no # way to annotate this in a way that pyright understands, we create a # ModuleBase class decorated with `dataclass_transform` such that pyright # thinks Module is a dataclass (in reality only subclasses are instantiated # so this is fine). # * The `__dataclass_fields__` attribute is needed because pytype seems to # not understand the `dataclass_transform` decorator, therefore we need # to add the attribute manually. # * Other attributes are annotated for completeness. Because we are using # the `if typing.TYPE_CHECKING` pattern, these annotations are not present # at runtime so they don't affect the dataclass behavior. @tpe.dataclass_transform(field_specifiers=(module_field,)) # type: ignore[literal-required] class ModuleBase: if typing.TYPE_CHECKING: scope: Scope | None _state: _ModuleInternalState _parent_ref: Union['Module', weakref.ReferenceType['Module'], None] __dataclass_fields__: dict[str, dataclasses.Field] class Module(ModuleBase): """Base class for all neural network modules. Layers and models should subclass this class. All Flax Modules are Python 3.7 `dataclasses `_. Since dataclasses take over ``__init__``, you should instead override :meth:`setup`, which is automatically called to initialize the module. Modules can contain submodules, and in this way can be nested in a tree structure. Submodels can be assigned as regular attributes inside the :meth:`setup` method. You can define arbitrary "forward pass" methods on your Module subclass. While no methods are special-cased, ``__call__`` is a popular choice because it allows you to use module instances as if they are functions:: >>> from flax import linen as nn >>> from typing import Tuple >>> class Module(nn.Module): ... features: Tuple[int, ...] = (16, 4) ... def setup(self): ... self.dense1 = nn.Dense(self.features[0]) ... self.dense2 = nn.Dense(self.features[1]) ... def __call__(self, x): ... return self.dense2(nn.relu(self.dense1(x))) Optionally, for more concise module implementations where submodules definitions are co-located with their usage, you can use the :meth:`compact` wrapper. """ if typing.TYPE_CHECKING: name: str | None = module_field(kw_only=True, default=None) parent: Union['Module', _Sentinel, None] = module_field( kw_only=True, default=None ) def __init__(self, *args, **kwargs): # this stub makes sure pytype accepts constructor arguments. pass def __call__(self, *args, **kwargs) -> Any: # this stub allows pytype to accept Modules as Callables. pass @classmethod def __init_subclass__(cls, kw_only: bool = False, **kwargs: Any) -> None: """Automatically initializes all subclasses as custom dataclasses.""" super().__init_subclass__(**kwargs) # All Flax Modules are dataclasses. We force this convention since # it encourages the stateless behavior needed to clone module instances for # functional transformation. Instead of using a python metaclass, we # automatically transform Modules into dataclasses at subclass creation # time, and we set the last dataclass arguments to `parent` and `name`. cls._customized_dataclass_transform(kw_only) # We wrap user-defined methods including setup and __call__ to enforce # a number of different checks and to provide clear error messages. cls._find_compact_name_scope_methods() cls._wrap_module_attributes() # Set empty class defaults. cls._state = _uninitialized_module_internal_state # type: ignore[attr-defined] cls.scope: Scope | None = None # type: ignore # Handles weak referencing of parent Modules to prevent reference cycles. cls._parent_ref = None # type: ignore[attr-defined] cls.parent = ParentDescriptor() # type: ignore[assignment] @classmethod def _customized_dataclass_transform(cls, kw_only: bool): """Transforms `cls` into a dataclass, with custom additional behavior. 1. Inject `parent` and `name` fields. (If they are already present, then check that they have the expected types.) 2. Set compare, hash, and repr to False for non-init fields. 3. Generate a hash function (if not provided by cls). """ # Check reserved attributes have expected type annotations. if sys.version_info < (3, 14): annotations = dict(cls.__dict__.get('__annotations__', {})) else: annotations = inspect.get_annotations(cls) if annotations.get('parent', _ParentType) != _ParentType: raise errors.ReservedModuleAttributeError(annotations) if annotations.get('name', str) not in ('str', str, Optional[str]): raise errors.ReservedModuleAttributeError(annotations) # any non-init field will only be set in setup # During __hash__ and __eq__ the field is not set yet # so it should not be used in compare, hash or repr. for field in annotations: field_meta = getattr(cls, field, None) if isinstance(field_meta, dataclasses.Field) and not field_meta.init: field_meta.compare = False field_meta.hash = False field_meta.repr = False extra_fields = [ ( 'parent', _ParentType, kw_only_dataclasses.field( repr=False, default=_unspecified_parent, kw_only=True ), ), ( 'name', Optional[str], kw_only_dataclasses.field(default=None, kw_only=True), ), ] if kw_only: if tuple(sys.version_info)[:3] >= (3, 10, 0): for ( name, annotation, # pytype: disable=invalid-annotation default, ) in extra_fields: setattr(cls, name, default) cls.__annotations__[name] = annotation dataclasses.dataclass( # type: ignore[call-overload] unsafe_hash='__hash__' not in cls.__dict__, repr=False, kw_only=True, )(cls) else: raise TypeError('`kw_only` is not available before Py 3.10.') else: # Now apply dataclass transform (which operates in-place). # Do generate a hash function only if not provided by the class. kw_only_dataclasses.dataclass( cls, unsafe_hash='__hash__' not in cls.__dict__, repr=False, extra_fields=extra_fields, ) # pytype: disable=wrong-keyword-args cls.__hash__ = _wrap_hash(cls.__hash__) # type: ignore[method-assign] @classmethod def _find_compact_name_scope_methods(cls): """Finds all compact_name_scope methods in the class.""" methods = [m[0] for m in inspect.getmembers(cls, predicate=callable)] compact_name_scope_fns = tuple( method_name for method_name in methods if hasattr(getattr(cls, method_name), 'compact_name_scope') ) cls._compact_name_scope_methods = compact_name_scope_fns @classmethod def _wrap_module_attributes(cls): """Wraps user-defined non-inherited methods and descriptors with state management functions. """ # wrap methods method_exclusions = [f.name for f in dataclasses.fields(cls)] + [ '__eq__', '__repr__', '__init__', '__hash__', '__post_init__', ] for key in _get_local_method_names(cls, exclude=method_exclusions): method = getattr(cls, key) if hasattr(method, 'nowrap'): continue setattr(cls, key, wrap_method_once(method)) # wrap descriptors descriptor_exclusions = [f.name for f in dataclasses.fields(cls)] + [ 'parent', '__dict__', ] for key in _get_local_descriptor_names(cls, descriptor_exclusions): # don't use getattr here, since it will call the descriptor descriptor = cls.__dict__[key] if hasattr(descriptor, 'nowrap'): continue setattr(cls, key, wrap_descriptor_once(descriptor)) return cls def _call_wrapped_method(self, fun, args, kwargs): """Calls a wrapped method. This function is responsible for setting up the thread local state correctly before calling the method and cleaning up afterwards. This includes storing intermediates, setup of the compact scope, and making sure setup is called before any other method. Args: fun: The wrapped method. args: Named arguments passed to ``fun``. kwargs: Keyword arguments passed to ``fun``. Returns: The results of calling ``fun``. """ is_compact_method = hasattr(fun, 'compact') fun_name = _get_fn_name(fun) is_setup_method = fun_name == 'setup' add_call_info = not is_setup_method and len(_context.call_info_stack) > 0 # We lazily call setup() only when needed. if is_setup_method: if self.scope is None: raise errors.CallSetupUnboundModuleError() is_recurrent = self._state.in_setup self._state.in_setup = True else: self._try_setup() if is_compact_method: if self.scope is None: raise errors.CallCompactUnboundModuleError() is_recurrent = self._state.in_compact_method self._state.in_compact_method = True _context.module_stack.append(self) try: # get call info if add_call_info: assert self.scope is not None call_index = _context.call_info_stack[-1].get_call_index() if _global_interceptor_stack: run_fun = functools.partial(run_interceptors, fun) else: run_fun = fun # call method if _use_named_call: with jax.named_scope(_derive_profiling_name(self, fun)): y = run_fun(self, *args, **kwargs) else: y = run_fun(self, *args, **kwargs) if _context.capture_stack: filter_fn = _context.capture_stack[-1] if filter_fn and filter_fn(self, fun_name): self.sow('intermediates', fun_name, y) if add_call_info: _args, _kwargs, _y = flax.linen.summary._represent_tree( (args, kwargs, y) ) _context.call_info_stack[-1].calls.append( _CallInfo( call_index, self.path, self.clone(), self.scope.rngs, self.scope.mutable, fun.__name__, _args, _kwargs, _y, ) ) return y finally: _context.module_stack.pop() if is_compact_method: object.__setattr__(self, 'scope', self.scope.rewound()) # setup or compact calls can be recurrent for example due to super calls # resetting the state would cause is compact/setup method # to be set to False prematurely. if (is_compact_method or is_setup_method) and not is_recurrent: self._state.reset() def __setattr__(self, name: str, val: Any): """Sets an attribute on this Module. We overload setattr solely to support pythonic naming via assignment of submodules in the special :meth:`setup` function:: self.submodule_name = MyModule(...) We also support lists and other general pytrees, e.g.:: self.submodules = [MyModule0(..), MyModule1(..), ...] Args: name: Attribute to set. val: Value of the attribute. """ fields = self.__dataclass_fields__ # pytype: disable=attribute-error is_dataclass_attr = name in fields and fields[name].init if not self._state.in_setup: if not self._state.is_initialized: # Setting attributes before end of Module.__post_init__() object.__setattr__(self, name, val) return else: # If the attribute is a python special method, we allow setting it (this # is useful e.g. for IPython auto-reload). if name.startswith('__'): object.__setattr__(self, name, val) return # We're past all initialization and setup logic: # Raises a TypeError just like frozen python dataclasses. raise errors.SetAttributeFrozenModuleError( self.__class__.__name__, name, val ) # We're inside the setup() method: if is_dataclass_attr: # These names are specified as dataclass fields. They should not be # initialized within the setup() method, but can be modified freely # before it. raise errors.SetAttributeInModuleSetupError() # Values (that may be variables or submodules) are being defined and # attached in setup(), we run some extra logic in that case. self._register_submodules(name, val) def __getattr__(self, name: str) -> Any: """Call setup() before getting any setup-defined attributes.""" # We don't want to return anything for python copy / pickle methods. if name in _UNDEFINED_COPY_PICKLE_METHODS: raise AttributeError() self._try_setup() if name in self.__dict__: return self.__dict__[name] else: msg = f'"{self.__class__.__name__}" object has no attribute "{name}".' if self.scope is None: msg += ( f' If "{name}" is defined in \'.setup()\', remember these fields ' "are only accessible from inside 'init' or 'apply'." ) raise AttributeError(msg) def __dir__(self) -> list[str]: """Call setup() before listing attributes.""" self._try_setup() return object.__dir__(self) # type: ignore def __post_init__(self) -> None: # DO NOT REMOVE - Marker for internal logging. # In dataclasses, __init__ is overridden to process dataclass arguments, # and __post_init__ is called immediately afterwards. Here, depending on the # type of `parent` passed to initialize the Module, we either defer # initialization, attach this Module as a submodule of a parent, or bind # this Module at the top-level to variables and rngs. object.__setattr__(self, '_id', uuid()) object.__setattr__(self, '_state', _ModuleInternalState()) # Typically we set the parent based on the dynamic module context. if self.parent is _unspecified_parent: # pytype: disable=attribute-error object.__setattr__(self, 'parent', _context.module_stack[-1]) # Initialization is deferred for top level Modules or any other "orphan" # Modules until attachment by __setattr__ i.e. MyModule(..., parent=None) if self.parent is None: return # Register submodule on parent Module. if isinstance(self.parent, Module): # When initializing an unnamed Module inside setup() # initialization is deferred until attachment by __setattr__ # i.e. self.mymodule = MyModule(...) self.name: str | None if ( self.parent._state.in_setup and self.name is None ): # pytype: disable=attribute-error return if not self.parent._initialization_allowed: raise errors.AssignSubModuleError(self.__class__.__name__) # Autonaming of submodules. if self.name is None: # pytype: disable=attribute-error prefix = f'{self.__class__.__name__}' cursor = self.parent._state.autoname_cursor.get(prefix, 0) self.name = f'{prefix}_{cursor}' self.parent._state.autoname_cursor[prefix] = cursor + 1 # Allow scope aliasing under transforms for submodules defined in setup. reuse_scopes = ( self.parent._state.in_setup and self.parent._state.setup_called == SetupState.TRANSFORMED ) # Perform name-collision check. if self.parent._name_taken(self.name, reuse_scopes=reuse_scopes): parent_class = self.parent.__class__.__name__ raise errors.NameInUseError('submodule', self.name, parent_class) # Finalize attachment to parent and scope initialization. self.parent._state.children[self.name] = self assert self.parent.scope is not None object.__setattr__( self, 'scope', self.parent.scope.push(self.name, reuse=reuse_scopes) ) # Top-level invocation with a functional Scope. elif isinstance(self.parent, Scope): object.__setattr__(self, 'scope', self.parent) else: raise ValueError('parent must be None, Module or Scope') # eagerly bind submodules if scope is available if self.scope is not None: for field in dataclasses.fields(self): if field.name not in ('parent', 'name') and field.init: self._register_submodules(field.name, getattr(self, field.name)) self._state.is_initialized = True def __repr__(self) -> str: return _module_repr(self) def setup(self) -> None: """Initializes a Module lazily (similar to a lazy ``__init__``). ``setup`` is called once lazily on a module instance when a module is bound, immediately before any other methods like ``__call__`` are invoked, or before a ``setup``-defined attribute on ``self`` is accessed. This can happen in three cases: 1. Immediately when invoking :meth:`apply`, :meth:`init` or :meth:`init_and_output`. 2. Once the module is given a name by being assigned to an attribute of another module inside the other module's ``setup`` method (see :meth:`__setattr__`):: >>> class MyModule(nn.Module): ... def setup(self): ... submodule = nn.Conv(...) ... # Accessing `submodule` attributes does not yet work here. ... # The following line invokes `self.__setattr__`, which gives ... # `submodule` the name "conv1". ... self.conv1 = submodule ... # Accessing `submodule` attributes or methods is now safe and ... # either causes setup() to be called once. 3. Once a module is constructed inside a method wrapped with :meth:`compact`, immediately before another method is called or ``setup`` defined attribute is accessed. """ pass def _register_submodules(self, name, val): """Registers a submodule.""" assert self.scope, 'Trying to register submodules on unbound scope.' root = self.scope.root cache = _caches.get(root, weakref.WeakValueDictionary()) _caches[root] = cache queue = [] preserve_adopted_names = config.flax_preserve_adopted_names if hasattr(type(self), 'preserve_adopted_names'): preserve_adopted_names = type(self).preserve_adopted_names def adopt_attr_modules(cache, queue, suffix, subvalue): if isinstance(subvalue, Module): current_name = subvalue.name adopted_name = None if subvalue.parent is None: # Preserve sharing-by-reference relationships during adoption # via cache keyed on unique instance ids. key = subvalue._id # Module was passed from outside. It needs to be cloned. # Outside modules are named by attachment, not an outer name, # UNLESS we're using new adopted name policy, in which case an existing # name will be used, as is often supplied by config systems. if preserve_adopted_names: adopted_name = object.__getattribute__(subvalue, 'name') if key in cache: subvalue = cache[key] else: subvalue = subvalue.clone(name=None) cache[key] = subvalue if subvalue.name is None: object.__setattr__(subvalue, 'parent', self) if adopted_name is None: adopted_name = ( f'{name}{suffix}' if not isinstance(subvalue, CompactNameScope) else current_name ) object.__setattr__(subvalue, 'name', adopted_name) queue.append(subvalue) return subvalue val = _freeze_attr( _map_over_modules_in_tree( functools.partial(adopt_attr_modules, cache, queue), val ) ) object.__setattr__(self, name, val) for x in queue: x.__post_init__() def _try_setup(self, shallow: bool = False) -> None: """Tries to setup module if scope is available and setup has not been called yet.""" if ( self.scope and not self._state.in_setup and self._state.setup_called != SetupState.DONE ): try: self._state.in_setup = True # A shallow setup will only register attribute submodules but it does # not call the user's setup. This avoids running before a # transformation. for field in dataclasses.fields(self): if field.name not in ('parent', 'name') and field.init: self._register_submodules(field.name, getattr(self, field.name)) if not shallow: self.setup() # create NonTransparent Modules self._compact_name_scope_modules = { name: CompactNameScope( getattr(type(self), name).inner_fun, lambda: self, name=name ) for name in self._compact_name_scope_methods } # We run static checks abstractly once for setup before any transforms # to detect name collisions and other python errors. elif self._state.setup_called == SetupState.NEW: self._validate_setup() finally: self._state.in_setup = False if not shallow: self._state.setup_called = SetupState.DONE def _validate_setup(self) -> None: """Abstractly evaluates setup only to run static checks.""" def run_setup_only(x): wrapped_id = wrap_method_once(lambda m, x: x) with TestScope({}, rngs={}, mutable=True).temporary() as root: return wrapped_id(self.clone(parent=root), x) _ = jax.eval_shape(run_setup_only, 0) def _name_taken( self, name: str, reuse_scopes: bool = False, collection: str | None = None, ) -> bool: assert self.scope is not None if reuse_scopes: return False return self.scope.name_reserved(name, collection) @property def _initialization_allowed(self): return ( not self._state.is_initialized # allow eager attachment in post-init or self._state.in_setup or self._state.in_compact_method ) @property def path(self): """Get the path of this Module. Top-level root modules have an empty path ``()``. Note that this method can only be used on bound modules that have a valid scope. Example usage:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> class SubModel(nn.Module): ... @nn.compact ... def __call__(self, x): ... print(f'SubModel path: {self.path}') ... return x >>> class Model(nn.Module): ... @nn.compact ... def __call__(self, x): ... print(f'Model path: {self.path}') ... return SubModel()(x) >>> model = Model() >>> variables = model.init(jax.random.key(0), jnp.ones((1, 2))) Model path: () SubModel path: ('SubModel_0',) """ if self.scope is None: raise ValueError("Can't access module paths on unbound modules.") return self.scope.path def clone( self: M, *, parent: Union[Scope, 'Module', _Sentinel] | None = None, _deep_clone: bool | weakref.WeakValueDictionary = False, _reset_names: bool = False, **updates, ) -> M: """Creates a clone of this Module, with optionally updated arguments. NOTE: end users are encouraged to use the ``copy`` method. ``clone`` is used primarily for internal routines, and ``copy`` offers simpler arguments and better defaults. Args: parent: The parent of the clone. The clone will have no parent if no explicit parent is specified. _deep_clone: A boolean or a weak value dictionary to control deep cloning of submodules. If True, submodules will be cloned recursively. If a weak value dictionary is passed, it will be used to cache cloned submodules. This flag is used by init/apply/bind to avoid scope leakage. _reset_names: If True, ``name=None`` is also passed to submodules when cloning. Resetting names in submodules is necessary when calling ``.unbind``. **updates: Attribute updates. Returns: A clone of the this Module with the updated attributes and parent. """ attrs = { f.name: getattr(self, f.name) for f in dataclasses.fields(self) if f.init } attrs.update(parent=parent, **updates) # Here we implement deep cloning of submodules, this is necessary to avoid scope leakage # from external submodules into init/apply/bind while preserving sharing-by-reference # relationships between submodules. if _deep_clone != False: # We use a weak value dictionary to cache cloned submodules. When a shared # submodule is cloned, its only cloned once else its fetched from the cache. cache = ( weakref.WeakValueDictionary() if isinstance(_deep_clone, bool) else _deep_clone ) def clone_fn(m: Module) -> Module: if hasattr(m, '_id'): key = m._id if key in cache: return cache[key] else: if _reset_names: clone = m.clone( _deep_clone=cache, _reset_names=_reset_names, name=None ) else: clone = m.clone(_deep_clone=cache) cache[key] = clone return clone else: # If the module doesn't have an _id attribute it could be a mock object # so we return it as is. return m # _map_submodules will map over all submodules inside attrs # value here can be any pytree, non-module values are ignored for field_name, value in attrs.items(): if field_name == 'parent': continue attrs[field_name] = _map_submodules(clone_fn, value) module = self.__class__(**attrs) return module def copy( self: M, *, parent: Union[Scope, 'Module', _Sentinel] | None = _unspecified_parent, name: str | None = None, **updates, ) -> M: """Creates a copy of this Module, with optionally updated arguments. Args: parent: The parent of the copy. By default the current module is taken as parent if not explicitly specified. name: A new name for the copied Module, by default a new automatic name will be given. **updates: Attribute updates. Returns: A copy of the this Module with the updated name, parent, and attributes. """ return self.clone( parent=parent, name=name, _deep_clone=True, _reset_names=False, **updates ) @overload def variable( self, col: str, name: str, init_fn: Callable[..., T] | None = None, *init_args, ) -> Variable[T]: ... @overload def variable( self, col: str, name: str, init_fn: Callable[..., T] | None = None, *init_args, unbox: Literal[True], **init_kwargs, ) -> Variable[T]: ... @overload def variable( self, col: str, name: str, init_fn: Callable[..., T] | None = None, *init_args, unbox: Literal[False], **init_kwargs, ) -> Variable[meta.AxisMetadata[T]]: ... @overload def variable( self, col: str, name: str, init_fn: Callable[..., T] | None = None, *init_args, unbox: bool = True, **init_kwargs, ) -> Variable[T] | Variable[meta.AxisMetadata[T]]: ... def variable( self, col: str, name: str, init_fn: Callable[..., T] | None = None, *init_args, unbox: bool = True, **init_kwargs, ) -> Variable[T] | Variable[meta.AxisMetadata[T]]: """Declares and returns a variable in this Module. See :mod:`flax.core.variables` for more information. See also :meth:`param` for a shorthand way to define read-only variables in the "params" collection. Contrary to :meth:`param`, all arguments passing using ``init_fn`` should be passed on explicitly:: >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(4)(x) ... key = self.make_rng('stats') ... mean = self.variable('stats', 'mean', nn.initializers.lecun_normal(), key, x.shape) ... ... ... return x * mean.value >>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}}, 'stats': {'mean': (2, 4)}} In the example above, the function ``lecun_normal`` expects two arguments: ``key`` and ``shape``, and both have to be passed on. The PRNG for ``stats`` has to be provided explicitly when calling :meth:`init` and :meth:`apply`. Args: col: The variable collection name. name: The variable name. init_fn: The function that will be called to compute the initial value of this variable. This function will only be called the first time this variable is used in this module. If None, the variable must already be initialized otherwise an error is raised. *init_args: The positional arguments to pass to init_fn. unbox: If True, ``AxisMetadata`` instances are replaced by their unboxed value, see ``flax.nn.meta.unbox`` (default: True). **init_kwargs: The key-word arguments to pass to init_fn Returns: A :class:`flax.core.variables.Variable` that can be read or set via ".value" attribute. Throws an error if the variable exists already. """ if not self._initialization_allowed: raise ValueError( 'Variables must be initialized in `setup()` or in a method ' 'wrapped in `@compact`' ) if self._name_taken(name, collection=col): raise errors.NameInUseError('variable', name, self.__class__.__name__) assert self.scope is not None v = self.scope.variable( col, name, init_fn, *init_args, unbox=unbox, **init_kwargs ) self._state.children[name] = col return v @overload def param( self, name: str, init_fn: Callable[..., T], *init_args, ) -> T: ... @overload def param( self, name: str, init_fn: Callable[..., meta.AxisMetadata[T]] | Callable[..., T], *init_args, unbox: Literal[True], **init_kwargs, ) -> T: ... @overload def param( self, name: str, init_fn: Callable[..., T], *init_args, unbox: Literal[False], **init_kwargs, ) -> T: ... @overload def param( self, name: str, init_fn: Callable[..., T | meta.AxisMetadata[T]], *init_args, unbox: bool, **init_kwargs, ) -> T | meta.AxisMetadata[T]: ... def param( self, name: str, init_fn: Callable[..., T | meta.AxisMetadata[T]], *init_args, unbox: bool = True, **init_kwargs, ) -> T | meta.AxisMetadata[T]: """Declares and returns a parameter in this Module. Parameters are read-only variables in the collection named "params". See :mod:`flax.core.variables` for more details on variables. The first argument of ``init_fn`` is assumed to be a PRNG key, which is provided automatically and does not have to be passed using ``init_args`` or ``init_kwargs``:: >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(4)(x) ... mean = self.param('mean', nn.initializers.lecun_normal(), x.shape) ... ... ... return x * mean >>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}, 'mean': (2, 4)}} In the example above, the function ``lecun_normal`` expects two arguments: ``key`` and ``shape``, but only ``shape`` has to be provided explicitly; ``key`` is set automatically using the PRNG for ``params`` that is passed when initializing the module using :meth:`init`. Args: name: The parameter name. init_fn: The function that will be called to compute the initial value of this variable. This function will only be called the first time this parameter is used in this module. *init_args: The positional arguments to pass to init_fn. unbox: If True, ``AxisMetadata`` instances are replaced by their unboxed value, see ``flax.nn.meta.unbox`` (default: True). **init_kwargs: The key-word arguments to pass to init_fn. Returns: The value of the initialized parameter. Throws an error if the parameter exists already. """ if not self._initialization_allowed: raise ValueError( 'Parameters must be initialized in `setup()` or in a method ' 'wrapped in `@compact`' ) if self._name_taken(name, collection='params'): raise errors.NameInUseError('param', name, self.__class__.__name__) assert self.scope is not None v: T | meta.AxisMetadata[T] = self.scope.param( name, init_fn, *init_args, unbox=unbox, **init_kwargs ) self._state.children[name] = 'params' return v def has_variable(self, col: str, name: str) -> bool: """Checks if a variable of given collection and name exists in this Module. See :mod:`flax.core.variables` for more explanation on variables and collections. Args: col: The variable collection name. name: The name of the variable. Returns: True if the variable exists. """ if self.scope is None: raise ValueError("Can't access variables on unbound modules") return self.scope.has_variable(col, name) def is_mutable_collection(self, col: str) -> bool: """Returns true if the collection ``col`` is mutable.""" if self.scope is None: raise ValueError("Can't check mutability on unbound modules") return self.scope.is_mutable_collection(col) def has_rng(self, name: str) -> bool: """Returns true if a PRNGSequence with name ``name`` exists.""" if self.scope is None: raise ValueError("Can't query for RNGs on unbound modules") return self.scope.has_rng(name) def make_rng(self, name: str = 'params') -> PRNGKey: """Returns a new RNG key from a given RNG sequence for this Module. The new RNG key is split from the previous one. Thus, every call to ``make_rng`` returns a new RNG key, while still guaranteeing full reproducibility. .. note:: If an invalid name is passed (i.e. no RNG key was passed by the user in ``.init`` or ``.apply`` for this name), then ``name`` will default to ``'params'``. Example:: >>> import jax >>> import flax.linen as nn >>> class ParamsModule(nn.Module): ... def __call__(self): ... return self.make_rng('params') >>> class OtherModule(nn.Module): ... def __call__(self): ... return self.make_rng('other') >>> key = jax.random.key(0) >>> params_out, _ = ParamsModule().init_with_output({'params': key}) >>> # self.make_rng('other') will default to using the 'params' RNG stream >>> other_out, _ = OtherModule().init_with_output({'params': key}) >>> assert params_out == other_out Learn more about RNG's by reading the Flax RNG guide: https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html Args: name: The RNG sequence name. Returns: The newly generated RNG key. """ if self.scope is None: raise ValueError("Can't use RNGs on unbound modules") return self.scope.make_rng(name) def is_initializing(self) -> bool: """Returns True if running under self.init(...) or nn.init(...)(). This is a helper method to handle the common case of simple initialization where we wish to have setup logic occur when only called under ``module.init`` or ``nn.init``. For more complicated multi-phase initialization scenarios it is better to test for the mutability of particular variable collections or for the presence of particular variables that potentially need to be initialized. """ if self.scope is None: raise ValueError("Can't check if running under init() on unbound modules") return self.scope.get_flag('initializing', False) def _module_checks(self): """Run standard runtime checks.""" if not isinstance(self, Module): raise errors.InvalidInstanceModuleError() overridden_post_init = self.__post_init__ != Module.__post_init__ if overridden_post_init and not hasattr(self, '_id'): raise errors.IncorrectPostInitOverrideError() @traceback_util.api_boundary def bind( self: M, variables: VariableDict, *args, rngs: RNGSequences | None = None, mutable: CollectionFilter = False, ) -> M: """Creates an interactive Module instance by binding variables and RNGs. ``bind`` provides an "interactive" instance of a Module directly without transforming a function with ``apply``. This is particularly useful for debugging and interactive use cases like notebooks where a function would limit the ability to split up code into different cells. Once the variables (and optionally RNGs) are bound to a ``Module`` it becomes a stateful object. Note that idiomatic JAX is functional and therefore an interactive instance does not mix well with vanilla JAX APIs. ``bind()`` should only be used for interactive experimentation, and in all other cases we strongly encourage users to use ``apply()`` instead. Example:: >>> import jax >>> import jax.numpy as jnp >>> import flax.linen as nn >>> class AutoEncoder(nn.Module): ... def setup(self): ... self.encoder = nn.Dense(3) ... self.decoder = nn.Dense(5) ... ... def __call__(self, x): ... return self.decoder(self.encoder(x)) >>> x = jnp.ones((16, 9)) >>> ae = AutoEncoder() >>> variables = ae.init(jax.random.key(0), x) >>> model = ae.bind(variables) >>> z = model.encoder(x) >>> x_reconstructed = model.decoder(z) Args: variables: A dictionary containing variables keyed by variable collections. See :mod:`flax.core.variables` for more details about variables. *args: Named arguments (not used). rngs: a dict of PRNGKeys to initialize the PRNG sequences. mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: ``bool``: all/no collections are mutable. ``str``: The name of a single mutable collection. ``list``: A list of names of mutable collections. Returns: A copy of this instance with bound variables and RNGs. """ Module._module_checks(self) del args scope = core.bind(variables, rngs=rngs, mutable=mutable) return self.clone(parent=scope, _deep_clone=True) def unbind(self: M) -> tuple[M, VariableDict]: """Returns an unbound copy of a Module and its variables. ``unbind`` helps create a stateless version of a bound Module. An example of a common use case: to extract a sub-Module defined inside ``setup()`` and its corresponding variables: 1) temporarily ``bind`` the parent Module; and then 2) ``unbind`` the desired sub-Module. (Recall that ``setup()`` is only called when the Module is bound.):: >>> class Encoder(nn.Module): ... @nn.compact ... def __call__(self, x): ... ... ... return nn.Dense(256)(x) >>> class Decoder(nn.Module): ... @nn.compact ... def __call__(self, x): ... ... ... return nn.Dense(784)(x) >>> class AutoEncoder(nn.Module): ... def setup(self): ... self.encoder = Encoder() ... self.decoder = Decoder() ... ... def __call__(self, x): ... return self.decoder(self.encoder(x)) >>> module = AutoEncoder() >>> variables = module.init(jax.random.key(0), jnp.ones((1, 784))) >>> # Extract the Encoder sub-Module and its variables >>> encoder, encoder_vars = module.bind(variables).encoder.unbind() Returns: A tuple with an unbound copy of this Module and its variables. """ Module._module_checks(self) if self.scope is None: raise errors.CallUnbindOnUnboundModuleError() variables = self.variables module = self.clone(_deep_clone=True, _reset_names=True, name=None) return module, variables @traceback_util.api_boundary def apply( self, variables: VariableDict, *args, rngs: PRNGKey | RNGSequences | None = None, method: Callable[..., Any] | str | None = None, mutable: CollectionFilter = False, capture_intermediates: bool | Callable[['Module', str], bool] = False, **kwargs, ) -> Any | tuple[Any, FrozenVariableDict | dict[str, Any]]: """Applies a module method to variables and returns output and modified variables. Note that ``method`` should be set if one would like to call ``apply`` on a different class method than ``__call__``. For instance, suppose a Transformer modules has a method called ``encode``, then the following calls ``apply`` on that method:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> import numpy as np >>> class Transformer(nn.Module): ... def encode(self, x): ... ... >>> x = jnp.ones((16, 9)) >>> model = Transformer() >>> variables = model.init(jax.random.key(0), x, method=Transformer.encode) >>> encoded = model.apply(variables, x, method=Transformer.encode) If a function instance is provided, the unbound function is used. For instance, the example below is equivalent to the one above:: >>> encoded = model.apply(variables, x, method=model.encode) You can also pass a string to a callable attribute of the module. For example, the previous can be written as:: >>> encoded = model.apply(variables, x, method='encode') Note ``method`` can also be a function that is not defined in ``Transformer``. In that case, the function should have at least one argument representing an instance of the Module class:: >>> def other_fn(instance, x): ... # instance.some_module_attr(...) ... instance.encode ... ... >>> model.apply(variables, x, method=other_fn) If you pass a single ``PRNGKey``, Flax will use it to feed the ``'params'`` RNG stream. If you want to use a different RNG stream or need to use multiple streams, you can pass a dictionary mapping each RNG stream name to its corresponding ``PRNGKey`` to ``apply``. If ``self.make_rng(name)`` is called on an RNG stream name that isn't passed by the user, it will default to using the ``'params'`` RNG stream. Example:: >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x, add_noise=False): ... x = nn.Dense(16)(x) ... x = nn.relu(x) ... ... if add_noise: ... # Add gaussian noise ... noise_key = self.make_rng('noise') ... x = x + jax.random.normal(noise_key, x.shape) ... ... return nn.Dense(1)(x) >>> x = jnp.empty((1, 7)) >>> module = Foo() >>> rngs = {'params': jax.random.key(0), 'noise': jax.random.key(1)} >>> variables = module.init(rngs, x) >>> out0 = module.apply(variables, x, add_noise=True, rngs=rngs) >>> rngs['noise'] = jax.random.key(0) >>> out1 = module.apply(variables, x, add_noise=True, rngs=rngs) >>> # different output (key(1) vs key(0)) >>> np.testing.assert_raises(AssertionError, np.testing.assert_allclose, out0, out1) >>> del rngs['noise'] >>> # self.make_rng('noise') will default to using the 'params' RNG stream >>> out2 = module.apply(variables, x, add_noise=True, rngs=rngs) >>> # same output (key(0)) >>> np.testing.assert_allclose(out1, out2) >>> # passing in a single key is equivalent to passing in {'params': key} >>> out3 = module.apply(variables, x, add_noise=True, rngs=jax.random.key(0)) >>> # same output (key(0)) >>> np.testing.assert_allclose(out2, out3) Args: variables: A dictionary containing variables keyed by variable collections. See :mod:`flax.core.variables` for more details about variables. *args: Named arguments passed to the specified apply method. rngs: a dict of PRNGKeys to initialize the PRNG sequences. The "params" PRNG sequence is used to initialize parameters. method: A function to call apply on. This is generally a function in the module. If provided, applies this method. If not provided, applies the ``__call__`` method of the module. A string can also be provided to specify a method by name. mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: ``bool``: all/no collections are mutable. ``str``: The name of a single mutable collection. ``list``: A list of names of mutable collections. capture_intermediates: If ``True``, captures intermediate return values of all Modules inside the "intermediates" collection. By default, only the return values of all ``__call__`` methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored. **kwargs: Keyword arguments passed to the specified apply method. Returns: If ``mutable`` is False, returns output. If any collections are mutable, returns ``(output, vars)``, where ``vars`` are is a dict of the modified collections. """ Module._module_checks(self) if rngs is not None and not isinstance(rngs, dict): if not core.scope._is_valid_rng(rngs): raise errors.InvalidRngError( 'RNGs should be of shape (2,) or PRNGKey in Module ' f'{self.__class__.__name__}, but rngs are: {rngs}' ) rngs = {'params': rngs} if isinstance(method, str): attribute_name = method method = getattr(self, attribute_name) if not callable(method): class_name = type(self).__name__ raise TypeError( f"'{class_name}.{attribute_name}' must be a callable, got" f' {type(method)}.' ) # if the `method` string is a submodule, we create a lambda function # that calls the submodule, forwarding all arguments. if isinstance(method, Module): method = lambda self, *args, **kwargs: getattr(self, attribute_name)( *args, **kwargs ) elif method is None: method = self.__call__ method = _get_unbound_fn(method) return apply( method, self, mutable=mutable, capture_intermediates=capture_intermediates, )(variables, *args, **kwargs, rngs=rngs) @traceback_util.api_boundary def init_with_output( self, rngs: PRNGKey | RNGSequences, *args, method: Callable[..., Any] | str | None = None, mutable: CollectionFilter = DenyList('intermediates'), capture_intermediates: bool | Callable[['Module', str], bool] = False, **kwargs, ) -> tuple[Any, FrozenVariableDict | dict[str, Any]]: """Initializes a module method with variables and returns output and modified variables. Args: rngs: The rngs for the variable collections. *args: Named arguments passed to the init function. method: An optional method. If provided, applies this method. If not provided, applies the ``__call__`` method. A string can also be provided to specify a method by name. mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: ``bool``: all/no collections are mutable. ``str``: The name of a single mutable collection. ``list``: A list of names of mutable collections. By default, all collections except "intermediates" are mutable. capture_intermediates: If ``True``, captures intermediate return values of all Modules inside the "intermediates" collection. By default only the return values of all ``__call__`` methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored. **kwargs: Keyword arguments passed to the init function. Returns: ``(output, vars)``, where ``vars`` are is a dict of the modified collections. """ Module._module_checks(self) if not isinstance(rngs, dict): if not core.scope._is_valid_rng(rngs): raise errors.InvalidRngError( 'RNGs should be of shape (2,) or PRNGKey in Module ' f'{self.__class__.__name__}, but rngs are: {rngs}' ) rngs = {'params': rngs} if isinstance(method, str): attribute_name = method method = getattr(self, attribute_name) if not callable(method): class_name = type(self).__name__ raise TypeError( f"'{class_name}.{attribute_name}' must be a callable, got" f' {type(method)}.' ) elif method is None: method = self.__call__ method = _get_unbound_fn(method) return init_with_output( method, self, mutable=mutable, capture_intermediates=capture_intermediates, )(rngs, *args, **kwargs) @traceback_util.api_boundary def init( self, rngs: PRNGKey | RNGSequences, *args, method: Callable[..., Any] | str | None = None, mutable: CollectionFilter = DenyList('intermediates'), capture_intermediates: bool | Callable[['Module', str], bool] = False, **kwargs, ) -> FrozenVariableDict | dict[str, Any]: """Initializes a module method with variables and returns modified variables. ``init`` takes as first argument either a single ``PRNGKey``, or a dictionary mapping variable collections names to their ``PRNGKeys``, and will call ``method`` (which is the module's ``__call__`` function by default) passing ``*args`` and ``**kwargs``, and returns a dictionary of initialized variables. Example:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> import numpy as np >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x, train): ... x = nn.Dense(16)(x) ... x = nn.BatchNorm(use_running_average=not train)(x) ... x = nn.relu(x) ... return nn.Dense(1)(x) >>> x = jnp.empty((1, 7)) >>> module = Foo() >>> key = jax.random.key(0) >>> variables = module.init(key, x, train=False) If you pass a single ``PRNGKey``, Flax will use it to feed the ``'params'`` RNG stream. If you want to use a different RNG stream or need to use multiple streams, you can pass a dictionary mapping each RNG stream name to its corresponding ``PRNGKey`` to ``init``. If ``self.make_rng(name)`` is called on an RNG stream name that isn't passed by the user, it will default to using the ``'params'`` RNG stream. Example:: >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(16)(x) ... x = nn.relu(x) ... ... other_variable = self.variable( ... 'other_collection', ... 'other_variable', ... lambda x: jax.random.normal(self.make_rng('other_rng'), x.shape), ... x, ... ) ... x = x + other_variable.value ... ... return nn.Dense(1)(x) >>> module = Foo() >>> rngs = {'params': jax.random.key(0), 'other_rng': jax.random.key(1)} >>> variables0 = module.init(rngs, x) >>> rngs['other_rng'] = jax.random.key(0) >>> variables1 = module.init(rngs, x) >>> # equivalent params (key(0)) >>> _ = jax.tree_util.tree_map( ... np.testing.assert_allclose, variables0['params'], variables1['params'] ... ) >>> # different other_variable (key(1) vs key(0)) >>> np.testing.assert_raises( ... AssertionError, ... np.testing.assert_allclose, ... variables0['other_collection']['other_variable'], ... variables1['other_collection']['other_variable'], ... ) >>> del rngs['other_rng'] >>> # self.make_rng('other_rng') will default to using the 'params' RNG stream >>> variables2 = module.init(rngs, x) >>> # equivalent params (key(0)) >>> _ = jax.tree_util.tree_map( ... np.testing.assert_allclose, variables1['params'], variables2['params'] ... ) >>> # equivalent other_variable (key(0)) >>> np.testing.assert_allclose( ... variables1['other_collection']['other_variable'], ... variables2['other_collection']['other_variable'], ... ) >>> # passing in a single key is equivalent to passing in {'params': key} >>> variables3 = module.init(jax.random.key(0), x) >>> # equivalent params (key(0)) >>> _ = jax.tree_util.tree_map( ... np.testing.assert_allclose, variables2['params'], variables3['params'] ... ) >>> # equivalent other_variable (key(0)) >>> np.testing.assert_allclose( ... variables2['other_collection']['other_variable'], ... variables3['other_collection']['other_variable'], ... ) Jitting ``init`` initializes a model lazily using only the shapes of the provided arguments, and avoids computing the forward pass with actual values. Example:: >>> module = nn.Dense(1) >>> init_jit = jax.jit(module.init) >>> variables = init_jit(jax.random.key(0), x) ``init`` is a light wrapper over ``apply``, so other ``apply`` arguments like ``method``, ``mutable``, and ``capture_intermediates`` are also available. Args: rngs: The rngs for the variable collections. *args: Named arguments passed to the init function. method: An optional method. If provided, applies this method. If not provided, applies the ``__call__`` method. A string can also be provided to specify a method by name. mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: ``bool``: all/no collections are mutable. ``str``: The name of a single mutable collection. ``list``: A list of names of mutable collections. By default all collections except "intermediates" are mutable. capture_intermediates: If ``True``, captures intermediate return values of all Modules inside the "intermediates" collection. By default only the return values of all ``__call__`` methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored. **kwargs: Keyword arguments passed to the init function. Returns: The initialized variable dict. """ Module._module_checks(self) _, v_out = self.init_with_output( rngs, *args, method=method, mutable=mutable, capture_intermediates=capture_intermediates, **kwargs, ) return v_out @traceback_util.api_boundary def lazy_init( self, rngs: PRNGKey | RNGSequences, *args, method: Callable[..., Any] | None = None, mutable: CollectionFilter = DenyList('intermediates'), **kwargs, ) -> FrozenVariableDict: """Initializes a module without computing on an actual input. lazy_init will initialize the variables without doing unnecessary compute. The input data should be passed as a ``jax.ShapeDtypeStruct`` which specifies the shape and dtype of the input but no concrete data. Example:: >>> model = nn.Dense(features=256) >>> variables = model.lazy_init( ... jax.random.key(0), jax.ShapeDtypeStruct((1, 128), jnp.float32)) The args and kwargs args passed to ``lazy_init`` can be a mix of concrete (jax arrays, scalars, bools) and abstract (ShapeDtypeStruct) values. Concrete values are only necessary for arguments that affect the initialization of variables. For example, the model might expect a keyword arg that enables/disables a subpart of the model. In this case, an explicit value (True/Flase) should be passed otherwise ``lazy_init`` cannot infer which variables should be initialized. Args: rngs: The rngs for the variable collections. *args: arguments passed to the init function. method: An optional method. If provided, applies this method. If not provided, applies the ``__call__`` method. mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: ``bool``: all/no collections are mutable. ``str``: The name of a single mutable collection. ``list``: A list of names of mutable collections. By default all collections except "intermediates" are mutable. **kwargs: Keyword arguments passed to the init function. Returns: The initialized variable dict. """ Module._module_checks(self) def lazy_wrapper(rngs, *args, **kwargs): return self.init(rngs, *args, method=method, mutable=mutable, **kwargs) return partial_eval.lazy_init(lazy_wrapper)(rngs, *args, **kwargs) @property def variables(self) -> VariableDict: """Returns the variables in this module.""" if self.scope is None: raise ValueError("Can't access variables on unbound modules") return self.scope.variables() def get_variable(self, col: str, name: str, default: T | None = None) -> T: """Retrieves the value of a Variable. Args: col: the variable collection. name: the name of the variable. default: the default value to return if the variable does not exist in this scope. Returns: The value of the input variable, of the default value if the variable doesn't exist in this scope. """ if self.scope is None: raise ValueError("Can't access variables on unbound modules") return self.scope.get_variable(col, name, default) def put_variable(self, col: str, name: str, value: Any): """Updates the value of the given variable if it is mutable, or an error otherwise. Args: col: the variable collection. name: the name of the variable. value: the new value of the variable. """ if self.scope is None: raise ValueError("Can't access variables on unbound modules") self.scope.put_variable(col, name, value) @overload def sow(self, col: str, name: str, value: Any) -> bool: ... @overload def sow( self, col: str, name: str, value: T, reduce_fn: Callable[[K, T], K] = tuple_reduce, init_fn: Callable[[], K] = tuple_init, # type: ignore ) -> bool: ... def sow( self, col: str, name: str, value: T, reduce_fn: Callable[[K, T], K] = tuple_reduce, init_fn: Callable[[], K] = tuple_init, # type: ignore ) -> bool: """Stores a value in a collection. Collections can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call. If the target collection is not mutable ``sow`` behaves like a no-op and returns ``False``. Example:: >>> import jax >>> import jax.numpy as jnp >>> import flax.linen as nn >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... h = nn.Dense(4)(x) ... self.sow('intermediates', 'h', h) ... return nn.Dense(2)(h) >>> x = jnp.ones((16, 9)) >>> model = Foo() >>> variables = model.init(jax.random.key(0), x) >>> y, state = model.apply(variables, x, mutable=['intermediates']) >>> jax.tree.map(jnp.shape, state['intermediates']) {'h': ((16, 4),)} By default the values are stored in a tuple and each stored value is appended at the end. This way all intermediates can be tracked when the same module is called multiple times. Alternatively, a custom init/reduce function can be passed:: >>> class Foo2(nn.Module): ... @nn.compact ... def __call__(self, x): ... init_fn = lambda: 0 ... reduce_fn = lambda a, b: a + b ... self.sow('intermediates', 'h', x, ... init_fn=init_fn, reduce_fn=reduce_fn) ... self.sow('intermediates', 'h', x * 2, ... init_fn=init_fn, reduce_fn=reduce_fn) ... return x >>> x = jnp.ones((1, 1)) >>> model = Foo2() >>> variables = model.init(jax.random.key(0), x) >>> y, state = model.apply( ... variables, x, mutable=['intermediates']) >>> print(state['intermediates']) {'h': Array([[3.]], dtype=float32)} Args: col: The name of the variable collection. name: The name of the variable. value: The value of the variable. reduce_fn: The function used to combine the existing value with the new value. The default is to append the value to a tuple. init_fn: For the first value stored, ``reduce_fn`` will be passed the result of ``init_fn`` together with the value to be stored. The default is an empty tuple. Returns: ``True`` if the value has been stored successfully, ``False`` otherwise. """ if self.scope is None: raise ValueError("Can't store variables on unbound modules") if not self.scope.is_mutable_collection(col): return False if self.scope.has_variable(col, name): xs = self.scope.get_variable(col, name) else: self.scope.reserve(name, col) self._state.children[name] = col xs = init_fn() xs = reduce_fn(xs, value) self.scope.put_variable(col, name, xs) return True def perturb( self, name: str, value: T, collection: str = 'perturbations' ) -> T: """Add an zero-value variable ('perturbation') to the intermediate value. The gradient of ``value`` would be the same as the gradient of this perturbation variable. Therefore, if you define your loss function with both params and perturbations as standalone arguments, you can get the intermediate gradients of ``value`` by running ``jax.grad`` on the perturbation argument. .. note:: This is an experimental API and may be tweaked later for better performance and usability. At its current stage, it creates extra dummy variables that occupies extra memory space. Use it only to debug gradients in training. Example:: >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(3)(x) ... x = self.perturb('dense3', x) ... return nn.Dense(2)(x) >>> def loss(variables, inputs, targets): ... preds = model.apply(variables, inputs) ... return jnp.square(preds - targets).mean() >>> x = jnp.ones((2, 9)) >>> y = jnp.ones((2, 2)) >>> model = Foo() >>> variables = model.init(jax.random.key(0), x) >>> intm_grads = jax.grad(loss, argnums=0)(variables, x, y) >>> print(intm_grads['perturbations']['dense3']) [[-0.04684732 0.06573904 -0.3194327 ] [-0.04684732 0.06573904 -0.3194327 ]] If perturbations are not passed to ``apply``, ``perturb`` behaves like a no-op so you can easily disable the behavior when not needed:: >>> model.apply(variables, x) # works as expected Array([[-0.04579116, 0.50412744], [-0.04579116, 0.50412744]], dtype=float32) >>> model.apply({'params': variables['params']}, x) # behaves like a no-op Array([[-0.04579116, 0.50412744], [-0.04579116, 0.50412744]], dtype=float32) >>> intm_grads = jax.grad(loss, argnums=0)({'params': variables['params']}, x, y) >>> 'perturbations' not in intm_grads True """ if self.scope is None: raise ValueError("Can't store variables on unbound modules") if self.is_mutable_collection(collection): if not self.scope.has_variable(collection, name): self.scope.reserve(name, collection) self._state.children[name] = collection zeros = jax.tree.map(jnp.zeros_like, value) self.scope.put_variable(collection, name, zeros) # type: ignore if collection in self.scope.root._variables: if self.scope.has_variable(collection, name): old_value = self.scope.get_variable(collection, name) value = jax.tree.map(jnp.add, value, old_value) # type: ignore else: raise ValueError(f"Perturbation collection {collection} present, but " f"missing perturbation variable {name}") return value def tabulate( self, rngs: PRNGKey | RNGSequences, *args, depth: int | None = None, show_repeated: bool = False, mutable: CollectionFilter = DenyList('intermediates'), console_kwargs: Mapping[str, Any] | None = None, table_kwargs: Mapping[str, Any] = MappingProxyType({}), column_kwargs: Mapping[str, Any] = MappingProxyType({}), compute_flops: bool = False, compute_vjp_flops: bool = False, **kwargs, ) -> str: """Creates a summary of the Module represented as a table. This method has the same signature and internally calls ``Module.init``, but instead of returning the variables, it returns the string summarizing the Module in a table. ``tabulate`` uses ``jax.eval_shape`` to run the forward computation without consuming any FLOPs or allocating memory. Additional arguments can be passed into the ``console_kwargs`` argument, for example, ``{'width': 120}``. For a full list of ``console_kwargs`` arguments, see: https://rich.readthedocs.io/en/stable/reference/console.html#rich.console.Console Example:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... h = nn.Dense(4)(x) ... return nn.Dense(2)(h) >>> x = jnp.ones((16, 9)) >>> # print(Foo().tabulate( >>> # jax.random.key(0), x, compute_flops=True, compute_vjp_flops=True)) This gives the following output:: Foo Summary ┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ path ┃ module ┃ inputs ┃ outputs ┃ flops ┃ vjp_flops ┃ params ┃ ┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ │ Foo │ float32[16,9] │ float32[16,2] │ 1504 │ 4460 │ │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ Dense_0 │ Dense │ float32[16,9] │ float32[16,4] │ 1216 │ 3620 │ bias: │ │ │ │ │ │ │ │ float32[4] │ │ │ │ │ │ │ │ kernel: │ │ │ │ │ │ │ │ float32[9,4] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 40 (160 B) │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ Dense_1 │ Dense │ float32[16,4] │ float32[16,2] │ 288 │ 840 │ bias: │ │ │ │ │ │ │ │ float32[2] │ │ │ │ │ │ │ │ kernel: │ │ │ │ │ │ │ │ float32[4,2] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 10 (40 B) │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ │ │ │ │ │ Total │ 50 (200 B) │ └─────────┴────────┴───────────────┴───────────────┴───────┴───────────┴─────────────────┘ Total Parameters: 50 (200 B) **Note**: rows order in the table does not represent execution order, instead it aligns with the order of keys in ``variables`` which are sorted alphabetically. **Note**: ``vjp_flops`` returns ``0`` if the module is not differentiable. Args: rngs: The rngs for the variable collections as passed to ``Module.init``. *args: The arguments to the forward computation. depth: controls how many submodule deep the summary can go. By default, its ``None`` which means no limit. If a submodule is not shown because of the depth limit, its parameter count and bytes will be added to the row of its first shown ancestor such that the sum of all rows always adds up to the total number of parameters of the Module. show_repeated: If ``True``, repeated calls to the same module will be shown in the table, otherwise only the first call will be shown. Default is ``False``. mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: ``bool``: all/no collections are mutable. ``str``: The name of a single mutable collection. ``list``: A list of names of mutable collections. By default, all collections except 'intermediates' are mutable. console_kwargs: An optional dictionary with additional keyword arguments that are passed to ``rich.console.Console`` when rendering the table. Default arguments are ``{'force_terminal': True, 'force_jupyter': False}``. table_kwargs: An optional dictionary with additional keyword arguments that are passed to ``rich.table.Table`` constructor. column_kwargs: An optional dictionary with additional keyword arguments that are passed to ``rich.table.Table.add_column`` when adding columns to the table. compute_flops: whether to include a ``flops`` column in the table listing the estimated FLOPs cost of each module forward pass. Does incur actual on-device computation / compilation / memory allocation, but still introduces overhead for large modules (e.g. extra 20 seconds for a Stable Diffusion's UNet, whereas otherwise tabulation would finish in 5 seconds). compute_vjp_flops: whether to include a ``vjp_flops`` column in the table listing the estimated FLOPs cost of each module backward pass. Introduces a compute overhead of about 2-3X of ``compute_flops``. **kwargs: keyword arguments to pass to the forward computation. Returns: A string summarizing the Module. """ from flax.linen import summary tabulate_fn = summary.tabulate( self, rngs, depth=depth, show_repeated=show_repeated, mutable=mutable, console_kwargs=console_kwargs, table_kwargs=table_kwargs, column_kwargs=column_kwargs, compute_flops=compute_flops, compute_vjp_flops=compute_vjp_flops, ) return tabulate_fn(*args, **kwargs) def module_paths( self, rngs: PRNGKey | RNGSequences, *args, show_repeated: bool = False, mutable: CollectionFilter = DenyList('intermediates'), **kwargs, ) -> dict[str, 'Module']: """Returns a dictionary mapping module paths to module instances. This method has the same signature and internally calls ``Module.init``, but instead of returning the variables, it returns a dictionary mapping module paths to unbounded copies of module instances that were used at runtime. ``module_paths`` uses ``jax.eval_shape`` to run the forward computation without consuming any FLOPs or allocating memory. Example:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... h = nn.Dense(4)(x) ... return nn.Dense(2)(h) >>> x = jnp.ones((16, 9)) >>> modules = Foo().module_paths(jax.random.key(0), x) >>> print({ ... p: type(m).__name__ for p, m in modules.items() ... }) {'': 'Foo', 'Dense_0': 'Dense', 'Dense_1': 'Dense'} Args: rngs: The rngs for the variable collections as passed to ``Module.init``. *args: The arguments to the forward computation. show_repeated: If ``True``, repeated calls to the same module will be shown in the table, otherwise only the first call will be shown. Default is ``False``. mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: ``bool``: all/no collections are mutable. ``str``: The name of a single mutable collection. ``list``: A list of names of mutable collections. By default, all collections except 'intermediates' are mutable. **kwargs: keyword arguments to pass to the forward computation. Returns: A dict`ionary mapping module paths to module instances. """ from flax.linen import summary table = summary._get_module_table( module=self, depth=None, show_repeated=show_repeated, compute_flops=False, compute_vjp_flops=False, )(rngs, *args, **kwargs, mutable=mutable) return {'/'.join(row.path): row.module_copy for row in table} _ParentType = Union[Module, Scope, _Sentinel, None] def merge_param(name: str, a: T | None, b: T | None) -> T: """Merges construction- and call-time argument. This is a utility for supporting a pattern where a Module hyperparameter can be passed either to ``__init__`` or ``__call__``, and the value that is not ``None`` will be used. Example:: >>> import flax.linen as nn >>> from typing import Optional >>> class Foo(nn.Module): ... train: Optional[bool] = None ... def __call__(self, train: Optional[bool] = None): ... train = nn.merge_param('train', self.train, train) An error is thrown when both arguments are ``None`` or both values are not ``None``. Args: name: the name of the parameter. Used for error messages. a: option a b: option b Returns: a or b whichever is not ``None``. """ if a is None and b is None: raise ValueError( f'Parameter "{name}" must be passed to the constructor or at call time.' ) if a is not None and b is not None: raise ValueError( f'Parameter "{name}" was passed to the constructor and at call time.' ' Should be passed just once.' ) if a is None: assert b is not None return b return a @traceback_util.api_boundary def apply( fn: Callable[..., Any], module: Module, mutable: CollectionFilter = False, capture_intermediates: bool | Callable[[Module, str], bool] = False, ) -> Callable[..., Any]: """Creates an apply function to call ``fn`` with a bound module. Unlike ``Module.apply`` this function returns a new function with the signature ``(variables, *args, rngs=None, **kwargs) -> T`` where ``T`` is the return type of ``fn``. If ``mutable`` is not ``False`` the return type is a tuple where the second item is a ``FrozenDict`` with the mutated variables. The apply function that is returned can be directly composed with JAX transformations like ``jax.jit``:: >>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> variables = {} >>> foo = Foo() >>> f_jitted = jax.jit(nn.apply(f, foo)) >>> f_jitted(variables, jnp.ones((1, 3))) Args: fn: The function that should be applied. The first argument passed will be a module instance of the ``module`` with variables and RNGs bound to it. module: The ``Module`` that will be used to bind variables and RNGs to. The ``Module`` passed as the first argument to ``fn`` will be a clone of module. mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: ``bool``: all/no collections are mutable. ``str``: The name of a single mutable collection. ``list``: A list of names of mutable collections. capture_intermediates: If ``True``, captures intermediate return values of all Modules inside the "intermediates" collection. By default, only the return values of all `__call__` methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored. Returns: The apply function wrapping ``fn``. """ @functools.wraps(fn) def scope_fn(scope, *args, **kwargs): _context.capture_stack.append(capture_intermediates) try: return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs) finally: _context.capture_stack.pop() if capture_intermediates is True: # pylint: disable=g-bool-id-comparison capture_intermediates = capture_call_intermediates if capture_intermediates: mutable = union_filters(mutable, 'intermediates') return core.apply(scope_fn, mutable=mutable) @traceback_util.api_boundary def init_with_output( fn: Callable[..., Any], module: Module, mutable: CollectionFilter = DenyList('intermediates'), capture_intermediates: bool | Callable[[Module, str], bool] = False, ) -> Callable[..., tuple[Any, FrozenVariableDict | dict[str, Any]]]: """Creates an init function to call ``fn`` with a bound module that also returns the function outputs. Unlike ``Module.init_with_output`` this function returns a new function with the signature ``(rngs, *args, **kwargs) -> (T, variables)`` where ``T`` is the return type of ``fn``. The rngs can be a dict of PRNGKeys or a single ```PRNGKey`` which is equivalent to passing a dict with one PRNGKey with the name "params". The init function that is returned can be directly composed with JAX transformations like ``jax.jit``:: >>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> foo = Foo() >>> f_jitted = jax.jit(nn.init_with_output(f, foo)) >>> y, variables = f_jitted(jax.random.key(0), jnp.ones((1, 3))) Args: fn: The function that should be applied. The first argument passed will be a module instance of the ``module`` with variables and RNGs bound to it. module: The ``Module`` that will be used to bind variables and RNGs to. The ``Module`` passed as the first argument to ``fn`` will be a clone of module. mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: ``bool``: all/no collections are mutable. ``str``: The name of a single mutable collection. ``list``: A list of names of mutable collections. By default, all collections except "intermediates" are mutable. capture_intermediates: If ``True``, captures intermediate return values of all Modules inside the "intermediates" collection. By default, only the return values of all `__call__` methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored. Returns: The init function wrapping ``fn``. """ @functools.wraps(fn) def scope_fn(scope, *args, **kwargs): _context.capture_stack.append(capture_intermediates) try: return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs) finally: _context.capture_stack.pop() if capture_intermediates is True: # pylint: disable=g-bool-id-comparison capture_intermediates = capture_call_intermediates if capture_intermediates: mutable = union_filters(mutable, 'intermediates') return core.init(scope_fn, mutable=mutable) @traceback_util.api_boundary def init( fn: Callable[..., Any], module: Module, mutable: CollectionFilter = DenyList('intermediates'), capture_intermediates: bool | Callable[[Module, str], bool] = False, ) -> Callable[..., FrozenVariableDict | dict[str, Any]]: """Creates an init function to call ``fn`` with a bound module. Unlike ``Module.init`` this function returns a new function with the signature ``(rngs, *args, **kwargs) -> variables``. The rngs can be a dict of PRNGKeys or a single ```PRNGKey`` which is equivalent to passing a dict with one PRNGKey with the name "params". The init function that is returned can be directly composed with JAX transformations like ``jax.jit``:: >>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> foo = Foo() >>> f_jitted = jax.jit(nn.init(f, foo)) >>> variables = f_jitted(jax.random.key(0), jnp.ones((1, 3))) Args: fn: The function that should be applied. The first argument passed will be a module instance of the ``module`` with variables and RNGs bound to it. module: The ``Module`` that will be used to bind variables and RNGs to. The ``Module`` passed as the first argument to ``fn`` will be a clone of module. mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: ``bool``: all/no collections are mutable. ``str``: The name of a single mutable collection. ``list``: A list of names of mutable collections. By default, all collections except "intermediates" are mutable. capture_intermediates: If `True`, captures intermediate return values of all Modules inside the "intermediates" collection. By default, only the return values of all `__call__` methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored. Returns: The init function wrapping ``fn``. """ init_fn = init_with_output(fn, module, mutable, capture_intermediates) @functools.wraps(init_fn) def init_wrapper(*args, **kwargs): return init_fn(*args, **kwargs)[1] return init_wrapper # TODO(cgarciae): we are defining CompactNameScope just to # avoid a pytype bug with the Flax overlay. We should aim to # remove in the at some point as its not ergonomic. if not typing.TYPE_CHECKING: class CompactNameScope(Module): fn: Callable module_fn: Callable[[], Module] @compact def __call__(self, *args, **kwargs) -> Any: return self.fn(self.module_fn(), *args, **kwargs) else: @dataclasses.dataclass class CompactNameScope: fn: Callable module_fn: Callable name: str def __call__(self, *args, **kwargs) -> Any: ... def share_scope(module: Module, other: Module, /): """Modifies one of the Modules such that they share the same scope. This is useful when you want to wrap a Module and extend its functionality without changing the parameter structure. ``share_scope`` takes two Modules, ``module`` and ``other``. ``module`` will use ``other``'s scope if ``other`` has a scope and its not a descendant of``module``'s scope:: >>> import flax.linen as nn >>> import jax >>> from jax import numpy as jnp, random ... >>> class DenseLoRA(nn.Module): ... base: nn.Dense ... rank: int ... ... def setup(self): ... nn.share_scope(self, self.base) ... ... @nn.compact ... def __call__(self, x: jax.Array): ... din, dout = x.shape[-1], self.base.features ... A = self.param('A', nn.zeros_init(), (din, self.rank)) ... B = self.param('B', nn.zeros_init(), (self.rank, dout)) ... return self.base(x) + x @ A @ B ... >>> class Model(nn.Module): ... @nn.compact ... def __call__(self, x: jax.Array): ... dense = nn.Dense(10) # base scope ... return DenseLoRA(dense, rank=2)(x) # reuse the base scope ... >>> model = Model() ... >>> params = model.init(random.key(0), jnp.ones((1, 5)))['params'] >>> list(params['Dense_0'].keys()) ['A', 'B', 'kernel', 'bias'] When ``other``'s scope is a descendant of ``module``'s scope then ``other`` will use ``module``'s scope instead:: >>> class DenseLoRA(nn.Module): ... features: int ... rank: int ... ... def setup(self): ... self.child = nn.Dense(self.features) ... nn.share_scope(self, self.child) ... ... @nn.compact ... def __call__(self, x: jax.Array): ... din, dout = x.shape[-1], self.features ... A = self.param('A', nn.zeros_init(), (din, self.rank)) ... B = self.param('B', nn.zeros_init(), (self.rank, dout)) ... return self.child(x) + x @ A @ B ... >>> class Model(nn.Module): ... @nn.compact ... def __call__(self, x: jax.Array): ... return DenseLoRA(10, rank=2)(x) ... >>> model = Model() ... >>> params = model.init(random.key(0), jnp.ones((1, 5)))['params'] >>> list(params['DenseLoRA_0'].keys()) ['A', 'B', 'kernel', 'bias'] """ if module.scope is None or other.scope is None: raise errors.CallShareScopeOnUnboundModuleError() def _is_child_scope(scope: Scope, other: Scope) -> bool: target: Scope | None = other while target is not None: if target is scope: return True target = target.parent return False if _is_child_scope(module.scope, other.scope): # Child is a true child, overwrite its scope module_to_update = other new_scope = module.scope else: # Child has its own independent scope, overwrite # parent scope, so that we preserve the sharing module_to_update = module new_scope = other.scope old_scope = module_to_update.scope object.__setattr__(module_to_update, 'scope', new_scope) # Reattach all the children to the new scope as well. for m in module_to_update._state.children.values(): if not isinstance(m, Module): continue # Should we go recursively to check if any of the ancestors point to the old # scope? if m.scope and m.scope.parent == old_scope: # Reserve the scope, so that if there is a conflict we can raise an error. if isinstance(m.scope.name, str): new_scope.reserve(m.scope.name) m.scope.parent = new_scope ================================================ FILE: flax/linen/normalization.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Normalization modules for Flax.""" import dataclasses import functools from typing import Any from collections.abc import Iterable import jax import jax.numpy as jnp from jax import lax from jax.nn import initializers from flax.linen import dtypes, module, transforms from flax.typing import ( Array, PRNGKey as PRNGKey, Dtype, Shape as Shape, Initializer, Axes, ) field = dataclasses.field canonicalize_dtype = dtypes.canonicalize_dtype compact = module.compact Module = module.Module merge_param = module.merge_param map_variables = transforms.map_variables def _canonicalize_axes(rank: int, axes: Axes) -> tuple[int, ...]: """Returns a tuple of deduplicated, sorted, and positive axes.""" if not isinstance(axes, Iterable): axes = (axes,) return tuple({rank + axis if axis < 0 else axis for axis in axes}) def _abs_sq(x): """Computes the elementwise square of the absolute value |x|^2.""" if jnp.iscomplexobj(x): return lax.square(lax.real(x)) + lax.square(lax.imag(x)) else: return lax.square(x) def _compute_stats( x: Array, axes: Axes, dtype: Dtype | None, axis_name: str | None = None, axis_index_groups: Any = None, use_mean: bool = True, use_fast_variance: bool = True, mask: Array | None = None, force_float32_reductions=True, ): """Computes mean and variance statistics. This implementation takes care of a few important details: - By default, computes in float32 precision for stability in half precision training. - If `use_fast_variance` is `True`, mean and variance are computed using Var = E[|x|^2] - |E[x]|^2, instead of Var = E[|x - E[x]|^2]), in a single XLA fusion. - Clips negative variances to zero which can happen due to roundoff errors. This avoids downstream NaNs. - Supports averaging across a parallel axis and subgroups of a parallel axis with a single `lax.pmean` call to avoid latency. Arguments: x: Input array. axes: The axes in ``x`` to compute mean and variance statistics for. dtype: Optional dtype specifying the minimal precision. Statistics are always at least float32 for stability (default: dtype of x). axis_name: Optional name for the pmapped axis to compute mean over. Note, this is only used for pmap and shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives. axis_index_groups: Optional groups of indices within that named axis. use_mean: If true, calculate the mean from the input and use it when computing the variance. If false, set the mean to zero and compute the variance without subtracting the mean. use_fast_variance: If true, use a faster, but less numerically stable, calculation for the variance. mask: Binary array of shape broadcastable to `inputs` tensor, indicating the positions for which the mean and variance should be computed. force_float32_reductions: If false, this will skip float32 promotion and use the input dtype or inherited dtype from ``x``. Returns: A pair ``(mean, var)``. """ if dtype is None: dtype = jnp.result_type(x) # promote x to at least float32, this avoids half precision computation # but preserves double or complex floating points if force_float32_reductions: dtype = jnp.promote_types(dtype, jnp.float32) if isinstance(x, jax.Array): x = x.astype(dtype) else: x = jnp.asarray(x, dtype) axes = _canonicalize_axes(x.ndim, axes) def maybe_distributed_mean(*xs, mask=None): if mask is not None: mask = jnp.asarray(mask, dtype=bool) mus = tuple(x.mean(axes, where=mask) for x in xs) if axis_name is None: return mus if len(xs) > 1 else mus[0] else: # In the distributed case we stack multiple arrays to speed comms. if len(xs) > 1: reduced_mus = lax.pmean( jnp.stack(mus, axis=0), axis_name, axis_index_groups=axis_index_groups, ) return tuple(reduced_mus[i] for i in range(len(xs))) else: return lax.pmean(mus[0], axis_name, axis_index_groups=axis_index_groups) if use_mean: if use_fast_variance: mu, mu2 = maybe_distributed_mean(x, _abs_sq(x), mask=mask) # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due # to floating point round-off errors. var = jnp.maximum(0.0, mu2 - _abs_sq(mu)) else: mu = maybe_distributed_mean(x, mask=mask) var = maybe_distributed_mean( _abs_sq(x - jnp.expand_dims(mu, axes)), mask=mask ) else: var = maybe_distributed_mean(_abs_sq(x), mask=mask) mu = jnp.zeros_like(var) return mu, var def _normalize( mdl: Module, x: Array, mean: Array, var: Array, reduction_axes: Axes, feature_axes: Axes, dtype: Dtype | None, param_dtype: Dtype, epsilon: float, use_bias: bool, use_scale: bool, bias_init: Initializer, scale_init: Initializer, force_float32_reductions: bool = True ): """Normalizes the input of a normalization layer and optionally applies a learned scale and bias. Arguments: mdl: Module to apply the normalization in (normalization params will reside in this module). x: The input. mean: Mean to use for normalization. var: Variance to use for normalization. reduction_axes: The axes in ``x`` to reduce. feature_axes: Axes containing features. A separate bias and scale is learned for each specified feature. dtype: The dtype of the result (default: infer from input and params). param_dtype: The dtype of the parameters. epsilon: Normalization epsilon. use_bias: If true, add a bias term to the output. use_scale: If true, scale the output. bias_init: Initialization function for the bias term. scale_init: Initialization function for the scaling function. force_float32_reductions: If false, the scale and bias parameters use the param_dtype. Otherwise, they will have at least float32 precision due to the mean and var being promoted to float32. Returns: The normalized input. """ reduction_axes = _canonicalize_axes(x.ndim, reduction_axes) feature_axes = _canonicalize_axes(x.ndim, feature_axes) feature_shape = [1] * x.ndim reduced_feature_shape = [] for ax in feature_axes: feature_shape[ax] = x.shape[ax] reduced_feature_shape.append(x.shape[ax]) mean = jnp.expand_dims(mean, reduction_axes) var = jnp.expand_dims(var, reduction_axes) y = x - mean mul = lax.rsqrt(var + epsilon) args = [x] if use_scale: scale = mdl.param( 'scale', scale_init, reduced_feature_shape, param_dtype ).reshape(feature_shape) if not force_float32_reductions: scale = jnp.asarray(scale, param_dtype) mul *= scale args.append(scale) y *= mul if use_bias: bias = mdl.param( 'bias', bias_init, reduced_feature_shape, param_dtype ).reshape(feature_shape) if not force_float32_reductions: bias = jnp.asarray(bias, param_dtype) y += bias args.append(bias) dtype = dtypes.canonicalize_dtype(*args, dtype=dtype) return jnp.asarray(y, dtype) def _l2_normalize(x, axis=None, eps=1e-12): """Normalizes along dimension `axis` using an L2 norm. This specialized function exists for numerical stability reasons. Args: x: An input ndarray. axis: Dimension along which to normalize, e.g. `1` to separately normalize vectors in a batch. Passing `None` views `t` as a flattened vector when calculating the norm (equivalent to Frobenius norm). eps: Epsilon to avoid dividing by zero. Returns: An array of the same shape as 'x' L2-normalized along 'axis'. """ return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps) class BatchNorm(Module): """BatchNorm Module. Usage Note: If we define a model with BatchNorm, for example:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> BN = nn.BatchNorm(momentum=0.9, epsilon=1e-5, dtype=jnp.float32) The initialized variables dict will contain, in addition to a 'params' collection, a separate 'batch_stats' collection that will contain all the running statistics for all the BatchNorm layers in a model:: >>> x = jax.random.normal(jax.random.key(0), (5, 6)) >>> variables = BN.init(jax.random.key(1), x, use_running_average=False) >>> jax.tree_util.tree_map(jnp.shape, variables) {'batch_stats': {'mean': (6,), 'var': (6,)}, 'params': {'bias': (6,), 'scale': (6,)}} We then update the batch_stats during training by specifying that the ``batch_stats`` collection is mutable in the ``apply`` method for our module.:: >>> y, new_batch_stats = BN.apply(variables, x, mutable=['batch_stats'], use_running_average=False) During eval we would define BN with ``use_running_average=True`` and use the batch_stats collection from training to set the statistics. In this case we are not mutating the batch statistics collection, and needn't mark it mutable:: >>> y = BN.apply(variables, x, mutable=['batch_stats'], use_running_average=True) Attributes: use_running_average: if True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input. axis: the feature or non-batch axis of the input. momentum: decay rate for the exponential moving average of the batch statistics. epsilon: a small float added to variance to avoid dividing by zero. dtype: the dtype of the result (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). use_bias: if True, bias (beta) is added. use_scale: if True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer. bias_init: initializer for bias, by default, zero. scale_init: initializer for scale, by default, one. axis_name: the axis name used to combine batch statistics from multiple devices. See ``jax.pmap`` for a description of axis names (default: None). Note, this is only used for pmap and shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives. axis_index_groups: groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the examples on the first two and last two devices. See ``jax.lax.psum`` for more details. This argument is currently not supported for SPMD jit. use_fast_variance: If true, use a faster, but less numerically stable, calculation for the variance. """ use_running_average: bool | None = None axis: int = -1 momentum: float = 0.99 epsilon: float = 1e-5 dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 use_bias: bool = True use_scale: bool = True bias_init: Initializer = initializers.zeros scale_init: Initializer = initializers.ones axis_name: str | None = None axis_index_groups: Any = None use_fast_variance: bool = True force_float32_reductions: bool = True @compact def __call__( self, x, use_running_average: bool | None = None, *, mask: jax.Array | None = None, ): """Normalizes the input using batch statistics. .. note:: During initialization (when ``self.is_initializing()`` is ``True``) the running average of the batch statistics will not be updated. Therefore, the inputs fed during initialization don't need to match that of the actual input distribution and the reduction axis (set with ``axis_name``) does not have to exist. Args: x: the input to be normalized. use_running_average: if true, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input. mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating the positions for which the mean and variance should be computed. Returns: Normalized inputs (the same shape as inputs). """ use_running_average = module.merge_param( 'use_running_average', self.use_running_average, use_running_average ) feature_axes = _canonicalize_axes(x.ndim, self.axis) reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes) feature_shape = [x.shape[ax] for ax in feature_axes] ra_mean = self.variable( 'batch_stats', 'mean', lambda s: jnp.zeros( s, jnp.float32 if self.force_float32_reductions else self.param_dtype, ), feature_shape, ) ra_var = self.variable( 'batch_stats', 'var', lambda s: jnp.ones( s, jnp.float32 if self.force_float32_reductions else self.param_dtype, ), feature_shape, ) if use_running_average: mean = ( ra_mean.value if self.force_float32_reductions else jnp.asarray(ra_mean.value, self.param_dtype) ) var = ( ra_var.value if self.force_float32_reductions else jnp.asarray(ra_var.value, self.param_dtype) ) else: mean, var = _compute_stats( x, reduction_axes, dtype=self.dtype, axis_name=self.axis_name if not self.is_initializing() else None, axis_index_groups=self.axis_index_groups, use_fast_variance=self.use_fast_variance, mask=mask, force_float32_reductions=self.force_float32_reductions, ) if not self.is_initializing(): ra_mean.value = ( self.momentum * ra_mean.value + (1 - self.momentum) * mean ) ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var return _normalize( self, x, mean, var, reduction_axes, feature_axes, self.dtype, self.param_dtype, self.epsilon, self.use_bias, self.use_scale, self.bias_init, self.scale_init, self.force_float32_reductions, ) class LayerNorm(Module): """Layer normalization (https://arxiv.org/abs/1607.06450). LayerNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that maintains the mean activation within each example close to 0 and the activation standard deviation close to 1. .. note:: This normalization operation is identical to InstanceNorm and GroupNorm; the difference is simply which axes are reduced and the shape of the feature axes (i.e. the shape of the learnable scale and bias parameters). Example usage:: >>> import flax.linen as nn >>> import jax >>> import numpy as np >>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6)) >>> layer = nn.LayerNorm() >>> variables = layer.init(jax.random.key(1), x) >>> variables {'params': {'scale': Array([1., 1., 1., 1., 1., 1.], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0.], dtype=float32)}} >>> y = layer.apply(variables, x) >>> y = nn.LayerNorm(reduction_axes=(1, 2, 3)).apply(variables, x) >>> y2 = nn.GroupNorm(num_groups=1).apply(variables, x) >>> np.testing.assert_allclose(y, y2) >>> y = nn.LayerNorm(reduction_axes=(1, 2), feature_axes=-1).apply(variables, x) >>> y2 = nn.InstanceNorm(feature_axes=-1).apply(variables, x) >>> np.testing.assert_allclose(y, y2) Attributes: epsilon: A small float added to variance to avoid dividing by zero. dtype: the dtype of the result (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). use_bias: If True, bias (beta) is added. use_scale: If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer. bias_init: Initializer for bias, by default, zero. scale_init: Initializer for scale, by default, one. reduction_axes: Axes for computing normalization statistics. feature_axes: Feature axes for learned bias and scaling. axis_name: the axis name used to combine batch statistics from multiple devices. See ``jax.pmap`` for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap or shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives. axis_index_groups: groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the examples on the first two and last two devices. See ``jax.lax.psum`` for more details. This argument is currently not supported for SPMD jit. use_fast_variance: If true, use a faster, but less numerically stable, calculation for the variance. """ epsilon: float = 1e-6 dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 use_bias: bool = True use_scale: bool = True bias_init: Initializer = initializers.zeros scale_init: Initializer = initializers.ones reduction_axes: Axes = -1 feature_axes: Axes = -1 axis_name: str | None = None axis_index_groups: Any = None use_fast_variance: bool = True force_float32_reductions: bool = True @compact def __call__(self, x, *, mask: jax.Array | None = None): """Applies layer normalization on the input. Args: x: the inputs mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating the positions for which the mean and variance should be computed. Returns: Normalized inputs (the same shape as inputs). """ mean, var = _compute_stats( x, self.reduction_axes, self.dtype, self.axis_name, self.axis_index_groups, use_fast_variance=self.use_fast_variance, mask=mask, force_float32_reductions=self.force_float32_reductions, ) return _normalize( self, x, mean, var, self.reduction_axes, self.feature_axes, self.dtype, self.param_dtype, self.epsilon, self.use_bias, self.use_scale, self.bias_init, self.scale_init, self.force_float32_reductions, ) class RMSNorm(Module): """RMS Layer normalization (https://arxiv.org/abs/1910.07467). RMSNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. Unlike LayerNorm which re-centers the mean to be 0 and normalizes by the standard deviation of the activations, RMSNorm does not re-center at all and instead normalizes by the root mean square of the activations. Example usage:: >>> import flax.linen as nn >>> import jax >>> x = jax.random.normal(jax.random.key(0), (5, 6)) >>> layer = nn.RMSNorm() >>> variables = layer.init(jax.random.key(1), x) >>> variables {'params': {'scale': Array([1., 1., 1., 1., 1., 1.], dtype=float32)}} >>> y = layer.apply(variables, x) Attributes: epsilon: A small float added to variance to avoid dividing by zero. dtype: the dtype of the result (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). use_scale: If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer. scale_init: Initializer for scale, by default, one. reduction_axes: Axes for computing normalization statistics. feature_axes: Feature axes for learned bias and scaling. axis_name: the axis name used to combine batch statistics from multiple devices. See ``jax.pmap`` for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap or shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives. axis_index_groups: groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the examples on the first two and last two devices. See ``jax.lax.psum`` for more details. This argument is currently not supported for SPMD jit. use_fast_variance: If true, use a faster, but less numerically stable, calculation for the variance. """ epsilon: float = 1e-6 dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 use_scale: bool = True scale_init: Initializer = initializers.ones reduction_axes: Axes = -1 feature_axes: Axes = -1 axis_name: str | None = None axis_index_groups: Any = None use_fast_variance: bool = True force_float32_reductions: bool = True @compact def __call__(self, x, *, mask: jax.Array | None = None): """Applies RMS layer normalization on the input. Args: x: the inputs mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating the positions for which the mean and variance should be computed. Returns: Normalized inputs (the same shape as inputs). """ mean, var = _compute_stats( x, self.reduction_axes, self.dtype, self.axis_name, self.axis_index_groups, use_mean=False, use_fast_variance=self.use_fast_variance, mask=mask, force_float32_reductions=self.force_float32_reductions, ) return _normalize( self, x, mean, var, self.reduction_axes, self.feature_axes, self.dtype, self.param_dtype, self.epsilon, False, self.use_scale, initializers.zeros, self.scale_init, self.force_float32_reductions, ) class GroupNorm(Module): """Group normalization (arxiv.org/abs/1803.08494). This op is similar to batch normalization, but statistics are shared across equally-sized groups of channels and not shared across batch dimension. Thus, group normalization does not depend on the batch composition and does not require maintaining internal state for storing statistics. The user should either specify the total number of channel groups or the number of channels per group. .. note:: LayerNorm is a special case of GroupNorm where ``num_groups=1``, and InstanceNorm is a special case of GroupNorm where ``group_size=1``. Example usage:: >>> import flax.linen as nn >>> import jax >>> import numpy as np >>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6)) >>> layer = nn.GroupNorm(num_groups=3) >>> variables = layer.init(jax.random.key(1), x) >>> variables {'params': {'scale': Array([1., 1., 1., 1., 1., 1.], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0.], dtype=float32)}} >>> y = layer.apply(variables, x) >>> y = nn.GroupNorm(num_groups=1).apply(variables, x) >>> y2 = nn.LayerNorm(reduction_axes=(1, 2, 3)).apply(variables, x) >>> np.testing.assert_allclose(y, y2) >>> y = nn.GroupNorm(num_groups=None, group_size=1).apply(variables, x) >>> y2 = nn.InstanceNorm(feature_axes=-1).apply(variables, x) >>> np.testing.assert_allclose(y, y2) Attributes: num_groups: the total number of channel groups. The default value of 32 is proposed by the original group normalization paper. group_size: the number of channels in a group. epsilon: A small float added to variance to avoid dividing by zero. dtype: the dtype of the result (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). use_bias: If True, bias (beta) is added. use_scale: If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer. bias_init: Initializer for bias, by default, zero. scale_init: Initializer for scale, by default, one. reduction_axes: List of axes used for computing normalization statistics. This list must include the final dimension, which is assumed to be the feature axis. Furthermore, if the input used at call time has additional leading axes compared to the data used for initialisation, for example due to batching, then the reduction axes need to be defined explicitly. axis_name: the axis name used to combine batch statistics from multiple devices. See ``jax.pmap`` for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap or shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives. axis_index_groups: groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the examples on the first two and last two devices. See ``jax.lax.psum`` for more details. This argument is currently not supported for SPMD jit. use_fast_variance: If true, use a faster, but less numerically stable, calculation for the variance. """ num_groups: int | None = 32 group_size: int | None = None epsilon: float = 1e-6 dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 use_bias: bool = True use_scale: bool = True bias_init: Initializer = initializers.zeros scale_init: Initializer = initializers.ones reduction_axes: Axes | None = None axis_name: str | None = None axis_index_groups: Any = None use_fast_variance: bool = True force_float32_reductions: bool = True @compact def __call__(self, x, *, mask: jax.Array | None = None): """Applies group normalization to the input (arxiv.org/abs/1803.08494). Args: x: the input of shape ``...C`` where ``C`` is a channels dimension and ``...`` represents an arbitrary number of extra dimensions that can be used to accumulate statistics over. If no reduction axes have been specified then all additional dimensions ``...`` will be used to accumulate statistics apart from the leading dimension which is assumed to represent the batch. mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating the positions for which the mean and variance should be computed. Returns: Normalized inputs (the same shape as inputs). """ if self.reduction_axes is not None: reduction_axes = self.reduction_axes else: reduction_axes = list(range(1, x.ndim - 1)) + [-1] feature_axis = -1 reduction_axes = _canonicalize_axes(x.ndim, reduction_axes) if reduction_axes[-1] != (feature_axis % x.ndim): raise ValueError( 'The reduction axes must include the final dimension ' 'as this is assumed to be the feature axis.' ) if (self.num_groups is None and self.group_size is None) or ( self.num_groups is not None and self.group_size is not None ): raise ValueError( 'Either `num_groups` or `group_size` should be ' 'specified. If `group_size` is to be specified, ' 'pass `num_groups=None` as argument to override ' 'the default `num_groups` value of 32.' ) channels = x.shape[-1] if self.group_size is not None: if channels % self.group_size != 0: raise ValueError( 'Number of channels ({}) is not multiple of the ' 'group size ({}).'.format(channels, self.group_size) ) num_groups = channels // self.group_size else: num_groups = self.num_groups assert isinstance(num_groups, int) if num_groups <= 0 or channels % num_groups != 0: raise ValueError( 'Number of groups ({}) does not divide the number' ' of channels ({}).'.format(num_groups, channels) ) group_size = x.shape[-1] // num_groups group_shape = x.shape[:-1] + (num_groups, group_size) if mask is not None: mask = mask.reshape(mask.shape[:-1] + (num_groups, group_size)) mean, var = _compute_stats( x.reshape(group_shape), list(reduction_axes[:-1]) + [-1], self.dtype, self.axis_name, self.axis_index_groups, use_fast_variance=self.use_fast_variance, mask=mask, force_float32_reductions=self.force_float32_reductions, ) mean = jnp.repeat(mean, group_size, axis=-1) var = jnp.repeat(var, group_size, axis=-1) return _normalize( self, x, mean, var, reduction_axes[:-1], (feature_axis,), self.dtype, self.param_dtype, self.epsilon, self.use_bias, self.use_scale, self.bias_init, self.scale_init, self.force_float32_reductions, ) class InstanceNorm(Module): """Instance normalization (https://arxiv.org/abs/1607.08022v3). InstanceNorm normalizes the activations of the layer for each channel (rather than across all channels like Layer Normalization), and for each given example in a batch independently (rather than across an entire batch like Batch Normalization). i.e. applies a transformation that maintains the mean activation within each channel within each example close to 0 and the activation standard deviation close to 1. .. note:: This normalization operation is identical to LayerNorm and GroupNorm; the difference is simply which axes are reduced and the shape of the feature axes (i.e. the shape of the learnable scale and bias parameters). Example usage:: >>> import flax.linen as nn >>> import jax >>> import numpy as np >>> # dimensions: (batch, height, width, channel) >>> x = jax.random.normal(jax.random.key(0), (2, 3, 4, 5)) >>> layer = nn.InstanceNorm() >>> variables = layer.init(jax.random.key(1), x) >>> variables {'params': {'scale': Array([1., 1., 1., 1., 1.], dtype=float32), 'bias': Array([0., 0., 0., 0., 0.], dtype=float32)}} >>> y = layer.apply(variables, x) >>> # having a channel_axis of -1 in InstanceNorm is identical to reducing all non-batch, >>> # non-channel axes and using the feature_axes as the feature_axes in LayerNorm >>> y2 = nn.LayerNorm(reduction_axes=[1, 2], feature_axes=-1).apply(variables, x) >>> np.testing.assert_allclose(y, y2, atol=1e-7) >>> y3 = nn.GroupNorm(num_groups=x.shape[-1]).apply(variables, x) >>> np.testing.assert_allclose(y, y3, atol=1e-7) Attributes: epsilon: A small float added to variance to avoid dividing by zero. dtype: the dtype of the result (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). use_bias: If True, bias (beta) is added. use_scale: If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer. bias_init: Initializer for bias, by default, zero. scale_init: Initializer for scale, by default, one. feature_axes: Axes for features. The learned bias and scaling parameters will be in the shape defined by the feature axes. All other axes except the batch axes (which is assumed to be the leading axis) will be reduced. axis_name: the axis name used to combine batch statistics from multiple devices. See ``jax.pmap`` for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap or shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives. axis_index_groups: groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the examples on the first two and last two devices. See ``jax.lax.psum`` for more details. This argument is currently not supported for SPMD jit. use_fast_variance: If true, use a faster, but less numerically stable, calculation for the variance. """ epsilon: float = 1e-6 dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 use_bias: bool = True use_scale: bool = True bias_init: Initializer = initializers.zeros scale_init: Initializer = initializers.ones feature_axes: Axes = -1 axis_name: str | None = None axis_index_groups: Any = None use_fast_variance: bool = True force_float32_reductions: bool = True @compact def __call__(self, x, *, mask: jax.Array | None = None): """Applies instance normalization on the input. Args: x: the inputs mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating the positions for which the mean and variance should be computed. Returns: Normalized inputs (the same shape as inputs). """ feature_axes = _canonicalize_axes(x.ndim, self.feature_axes) if 0 in feature_axes: raise ValueError('The channel axes cannot include the leading dimension ' 'as this is assumed to be the batch axis.') reduction_axes = [i for i in range(1, x.ndim) if i not in feature_axes] mean, var = _compute_stats( x, reduction_axes, self.dtype, self.axis_name, self.axis_index_groups, use_fast_variance=self.use_fast_variance, mask=mask, force_float32_reductions=self.force_float32_reductions, ) return _normalize( self, x, mean, var, reduction_axes, feature_axes, self.dtype, self.param_dtype, self.epsilon, self.use_bias, self.use_scale, self.bias_init, self.scale_init, self.force_float32_reductions, ) class SpectralNorm(Module): """Spectral normalization. See: - https://arxiv.org/abs/1802.05957 - https://arxiv.org/abs/1805.08318 - https://arxiv.org/abs/1809.11096 Spectral normalization normalizes the weight params so that the spectral norm of the matrix is equal to 1. This is implemented as a layer wrapper where each wrapped layer will have its params spectral normalized before computing its ``__call__`` output. .. note:: The initialized variables dict will contain, in addition to a 'params' collection, a separate 'batch_stats' collection that will contain a ``u`` vector and ``sigma`` value, which are intermediate values used when performing spectral normalization. During training, we pass in ``update_stats=True`` and ``mutable=['batch_stats']`` so that ``u`` and ``sigma`` are updated with the most recently computed values using power iteration. This will help the power iteration method approximate the true singular value more accurately over time. During eval, we pass in ``update_stats=False`` to ensure we get deterministic behavior from the model. Example usage:: >>> import flax, flax.linen as nn >>> import jax, jax.numpy as jnp >>> import optax >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x, train): ... x = nn.Dense(3)(x) ... # only spectral normalize the params of the second Dense layer ... x = nn.SpectralNorm(nn.Dense(4))(x, update_stats=train) ... x = nn.Dense(5)(x) ... return x >>> # init >>> x = jnp.ones((1, 2)) >>> y = jnp.ones((1, 5)) >>> model = Foo() >>> variables = model.init(jax.random.PRNGKey(0), x, train=False) >>> flax.core.freeze(jax.tree_util.tree_map(jnp.shape, variables)) FrozenDict({ batch_stats: { SpectralNorm_0: { Dense_1/kernel/sigma: (), Dense_1/kernel/u: (1, 4), }, }, params: { Dense_0: { bias: (3,), kernel: (2, 3), }, Dense_1: { bias: (4,), kernel: (3, 4), }, Dense_2: { bias: (5,), kernel: (4, 5), }, }, }) >>> # train >>> def train_step(variables, x, y): ... def loss_fn(params): ... logits, updates = model.apply( ... {'params': params, 'batch_stats': variables['batch_stats']}, ... x, ... train=True, ... mutable=['batch_stats'], ... ) ... loss = jnp.mean(optax.l2_loss(predictions=logits, targets=y)) ... return loss, updates ... ... (loss, updates), grads = jax.value_and_grad(loss_fn, has_aux=True)( ... variables['params'] ... ) ... return { ... 'params': jax.tree_util.tree_map( ... lambda p, g: p - 0.1 * g, variables['params'], grads ... ), ... 'batch_stats': updates['batch_stats'], ... }, loss >>> for _ in range(10): ... variables, loss = train_step(variables, x, y) >>> # inference / eval >>> out = model.apply(variables, x, train=False) Attributes: layer_instance: Module instance that is wrapped with SpectralNorm n_steps: How many steps of power iteration to perform to approximate the singular value of the weight params. epsilon: A small float added to l2-normalization to avoid dividing by zero. dtype: the dtype of the result (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). error_on_non_matrix: Spectral normalization is only defined on matrices. By default, this module will return scalars unchanged and flatten higher-order tensors in their leading dimensions. Setting this flag to True will instead throw an error if a weight tensor with dimension greater than 2 is used by the layer. collection_name: Name of the collection to store intermediate values used when performing spectral normalization. """ layer_instance: Module n_steps: int = 1 epsilon: float = 1e-12 dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 error_on_non_matrix: bool = False collection_name: str = 'batch_stats' @compact def __call__(self, *args, update_stats: bool, **kwargs): """Compute the largest singular value of the weights in ``self.layer_instance`` using power iteration and normalize the weights using this value before computing the ``__call__`` output. Args: *args: positional arguments to be passed into the call method of the underlying layer instance in ``self.layer_instance``. update_stats: if True, update the internal ``u`` vector and ``sigma`` value after computing their updated values using power iteration. This will help the power iteration method approximate the true singular value more accurately over time. **kwargs: keyword arguments to be passed into the call method of the underlying layer instance in ``self.layer_instance``. Returns: Output of the layer using spectral normalized weights. """ def layer_forward(layer_instance): return layer_instance(*args, **kwargs) return transforms.map_variables( layer_forward, trans_in_fn=lambda vs: jax.tree_util.tree_map_with_path( functools.partial( self._spectral_normalize, update_stats=update_stats, ), vs, ), init=self.is_initializing(), mutable=True, )(self.layer_instance) def _spectral_normalize(self, path, vs, update_stats): """Compute the largest singular value using power iteration and normalize the variables ``vs`` using this value. This is intended to be a helper function used in this Module's ``__call__`` method in conjunction with ``nn.transforms.map_variables`` and ``jax.tree_util.tree_map_with_path``. Args: path: dict key path, used for naming the ``u`` and ``sigma`` variables vs: variables to be spectral normalized update_stats: if True, update the ``u`` vector and ``sigma`` variables after computing their updated values using power iteration. This will help the power iteration method approximate the true singular value more accurately over time. """ value = jnp.asarray(vs) value_shape = value.shape # Skip and return value if input is scalar, vector or if number of power # iterations is less than 1 if value.ndim <= 1 or self.n_steps < 1: return value # Handle higher-order tensors. elif value.ndim > 2: if self.error_on_non_matrix: raise ValueError( f'Input is {value.ndim}D but error_on_non_matrix is True' ) else: value = jnp.reshape(value, (-1, value.shape[-1])) u_var_name = ( self.layer_instance.name + '/' + '/'.join(dict_key.key for dict_key in path[1:]) + '/u' ) u_var = self.variable( self.collection_name, u_var_name, jax.random.normal, self.make_rng('params') if not self.has_variable(self.collection_name, u_var_name) else None, (1, value.shape[-1]), self.param_dtype, ) u0 = u_var.value sigma_var_name = ( self.layer_instance.name + '/' + '/'.join(dict_key.key for dict_key in path[1:]) + '/sigma' ) sigma_var = self.variable( self.collection_name, sigma_var_name, jnp.ones, (), self.param_dtype ) # Power iteration for the weight's singular value. for _ in range(self.n_steps): v0 = _l2_normalize( jnp.matmul(u0, value.transpose([1, 0])), eps=self.epsilon ) u0 = _l2_normalize(jnp.matmul(v0, value), eps=self.epsilon) u0 = jax.lax.stop_gradient(u0) v0 = jax.lax.stop_gradient(v0) sigma = jnp.matmul(jnp.matmul(v0, value), jnp.transpose(u0))[0, 0] value /= jnp.where(sigma != 0, sigma, 1) value_bar = value.reshape(value_shape) if update_stats: u_var.value = u0 sigma_var.value = sigma dtype = dtypes.canonicalize_dtype(vs, u0, v0, sigma, dtype=self.dtype) return jnp.asarray(value_bar, dtype) class WeightNorm(Module): """L2 weight normalization (https://arxiv.org/abs/1602.07868). Weight normalization normalizes the weight params so that the l2-norm of the matrix is equal to 1. This is implemented as a layer wrapper where each wrapped layer will have its params l2-normalized before computing its ``__call__`` output. Example usage:: >>> import flax, flax.linen as nn >>> import jax, jax.numpy as jnp >>> class Baz(nn.Module): ... @nn.compact ... def __call__(self, x): ... return nn.Dense(2)(x) >>> class Bar(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = Baz()(x) ... x = nn.Dense(3)(x) ... x = Baz()(x) ... x = nn.Dense(3)(x) ... return x >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(3)(x) ... # l2-normalize all params of the second Dense layer ... x = nn.WeightNorm(nn.Dense(4), variable_filter=None)(x) ... x = nn.Dense(5)(x) ... # l2-normalize all kernels in the Bar submodule and all params in ... # the Baz submodule ... x = nn.WeightNorm(Bar(), variable_filter={'kernel', 'Baz'})(x) ... return x >>> # init >>> x = jnp.ones((1, 2)) >>> model = Foo() >>> variables = model.init(jax.random.key(0), x) >>> flax.core.freeze(jax.tree_util.tree_map(jnp.shape, variables)) FrozenDict({ params: { Bar_0: { Baz_0: { Dense_0: { bias: (2,), kernel: (5, 2), }, }, Baz_1: { Dense_0: { bias: (2,), kernel: (3, 2), }, }, Dense_0: { bias: (3,), kernel: (2, 3), }, Dense_1: { bias: (3,), kernel: (2, 3), }, }, Dense_0: { bias: (3,), kernel: (2, 3), }, Dense_1: { bias: (4,), kernel: (3, 4), }, Dense_2: { bias: (5,), kernel: (4, 5), }, WeightNorm_0: { Dense_1/bias/scale: (4,), Dense_1/kernel/scale: (4,), }, WeightNorm_1: { Bar_0/Baz_0/Dense_0/bias/scale: (2,), Bar_0/Baz_0/Dense_0/kernel/scale: (2,), Bar_0/Baz_1/Dense_0/bias/scale: (2,), Bar_0/Baz_1/Dense_0/kernel/scale: (2,), Bar_0/Dense_0/kernel/scale: (3,), Bar_0/Dense_1/kernel/scale: (3,), }, }, }) Attributes: layer_instance: Module instance that is wrapped with WeightNorm epsilon: A small float added to l2-normalization to avoid dividing by zero. dtype: the dtype of the result (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). use_scale: If True, creates a learnable variable ``scale`` that is multiplied to the ``layer_instance`` variables after l2-normalization. scale_init: Initialization function for the scaling function. feature_axes: The feature axes dimension(s). The l2-norm is calculated by reducing the ``layer_instance`` variables over the remaining (non-feature) axes. Therefore a separate l2-norm value is calculated and a separate scale (if ``use_scale=True``) is learned for each specified feature. By default, the trailing dimension is treated as the feature axis. variable_filter: An optional iterable that contains string items. The WeightNorm layer will selectively apply l2-normalization to the ``layer_instance`` variables whose key path (delimited by '/') has a match with ``variable_filter``. For example, ``variable_filter={'kernel'}`` will only apply l2-normalization to variables whose key path contains 'kernel'. By default, ``variable_filter={'kernel'}``. """ layer_instance: Module epsilon: float = 1e-12 dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 use_scale: bool = True scale_init: Initializer = initializers.ones feature_axes: Axes | None = -1 variable_filter: Iterable | None = dataclasses.field( default_factory=lambda: {'kernel'} ) @compact def __call__(self, *args, **kwargs): """Compute the l2-norm of the weights in ``self.layer_instance`` and normalize the weights using this value before computing the ``__call__`` output. Args: *args: positional arguments to be passed into the call method of the underlying layer instance in ``self.layer_instance``. **kwargs: keyword arguments to be passed into the call method of the underlying layer instance in ``self.layer_instance``. Returns: Output of the layer using l2-normalized weights. """ def layer_forward(layer_instance): return layer_instance(*args, **kwargs) return transforms.map_variables( layer_forward, trans_in_fn=lambda vs: jax.tree_util.tree_map_with_path( self._l2_normalize, vs, ), init=self.is_initializing(), )(self.layer_instance) def _l2_normalize(self, path, vs): """Compute the l2-norm and normalize the variables ``vs`` using this value. This is intended to be a helper function used in this Module's ``__call__`` method in conjunction with ``nn.transforms.map_variables`` and ``jax.tree_util.tree_map_with_path``. Args: path: dict key path, used for naming the ``scale`` variable vs: variables to be l2-normalized """ value = jnp.asarray(vs) str_path = ( self.layer_instance.name + '/' + jax.tree_util.keystr(path[1:], simple=True, separator='/') ) if self.variable_filter: for variable_name in self.variable_filter: if variable_name in str_path: break else: return value if self.feature_axes is None: feature_axes = () reduction_axes = tuple(i for i in range(value.ndim)) else: feature_axes = _canonicalize_axes(value.ndim, self.feature_axes) reduction_axes = tuple( i for i in range(value.ndim) if i not in feature_axes ) feature_shape = [1] * value.ndim reduced_feature_shape = [] for ax in feature_axes: feature_shape[ax] = value.shape[ax] reduced_feature_shape.append(value.shape[ax]) value_bar = _l2_normalize(value, axis=reduction_axes, eps=self.epsilon) args = [vs] if self.use_scale: scale = self.param( str_path + '/scale', self.scale_init, reduced_feature_shape, self.param_dtype, ).reshape(feature_shape) value_bar *= scale args.append(scale) dtype = dtypes.canonicalize_dtype(*args, dtype=self.dtype) return jnp.asarray(value_bar, dtype) ================================================ FILE: flax/linen/partitioning.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Legacy utilities for working with pjit and partitioned models. **Experimental: please give feedback, and expect changes.** This module introduces `axis_rules`, `logical_to_mesh_axes`, `with_sharding_constraint` for appyling pjit sharding constraints in terms of "logical named axes" rather than pjit's default mesh axes. Additionally, flax linen methods `param_with_axes` and `variable_with_axes` are introduced alongside `get_axis_names` for defining variables and parameters and variables with logical axis name annotations that are managed as metadata. Lastly, `*_with_axes` versions of `nn.scan` and `nn.vmap` are introduced to add logical axis metadata to the underlying Lifted transformations. """ import functools import re from typing import (Any, Optional, Tuple) from collections.abc import Callable, Mapping import flax from flax import linen as nn from flax import struct from flax.core.frozen_dict import freeze from flax.core.frozen_dict import unfreeze from flax.core.scope import ( CollectionFilter as CollectionFilter, PRNGSequenceFilter as PRNGSequenceFilter, ) from flax.core.spmd import logical_axis_rules as axis_rules # pylint: disable=unused-import from flax.core.spmd import set_logical_axis_rules as set_axis_rules # pylint: disable=unused-import from flax.core.spmd import get_logical_axis_rules as get_axis_rules # pylint: disable=unused-import from flax.linen.spmd import _is_logical_spec from flax.linen.spmd import _with_sharding_constraint # pylint: disable=unused-import from flax.linen.spmd import logical_to_mesh # pylint: disable=unused-import from flax.linen.spmd import logical_to_mesh_axes # pylint: disable=unused-import from flax.linen.spmd import RulesFallback from flax.linen.spmd import with_logical_constraint as with_sharding_constraint from flax.traverse_util import flatten_dict from flax.traverse_util import unflatten_dict from flax.typing import ( Array, In as ScanIn, # pylint: disable=unused-import Out as ScanOut, # pylint: disable=unused-import InOutAxis, InOutScanAxis, LogicalRules, # pylint: disable=unused-import ArrayPytree, # pylint: disable=unused-import LogicalPartitionSpec, # pylint: disable=unused-import LogicalPartitionSpecPytree, PartitionSpecPytree, # pylint: disable=unused-import ) import jax # ------------------------------------------------------------------------------ # NOTICE: This experimental partitioning utility API is deprecated # # We intend to continue supporting it indefinitely for those using it, but # we encourage new users to adopt the simpler metadata handling system found # in "spmd.py". # ------------------------------------------------------------------------------ # Annotated parameters and Module axis metadata handling. # ------------------------------------------------------------------------------ @struct.dataclass class AxisMetadata: """Contains a tuple of axis names, which is passed through FLAX.""" names: LogicalPartitionSpecPytree = struct.field(pytree_node=False) def _param_with_axes_sow_reduce_fn(x, y): """Reduction function for sow() calls. Args: x: Existing value, or () if there was none. y: New axis names sown. Returns: New axis names. Raises: TypeError: If the newly sown value is not an AxisMetadata. ValueError: If the newly sown axis names don't match previously sown axis names. AssertionError: If a previously sown value was truthy and not an AxisMetadata. """ if not isinstance(y, AxisMetadata): raise TypeError('Expected newly sown value to be an AxisMetadata') if isinstance(x, AxisMetadata): if x != y: raise ValueError( 'If axis names are sown twice, expected them to match. ' f'Got {x} and {y}.' ) elif x: # Shouldn't happen, so raise a fairly internal error. raise AssertionError(f'Non-initial-or-AxisMetadata value encountered: {x}') return y def param_with_axes( name: str, init_fn, *init_args, axes: tuple[str, ...] | None = None, module: Optional['nn.Module'] = None, **init_kwargs, ): """Declares and returns a parameter with logical axes in the current Module. See :mod:`flax.linen.module.param` for original docstring. Args: name: The parameter name. init_fn: The function that will be called to compute the initial value of this variable. This function will only be called the first time this parameter is used in this module. *init_args: The positional arguments to pass to init_fn. axes: A tuple of axis names, must match the rank of the param array. module: Use an explicit module instead of deriving the most recent from dynamic module context. **init_kwargs: The key-word arguments to pass to init_fn. Returns: The value of the initialized parameter. Raises: TypeError: if axes specification is mal-formed. ValueError: if specified logical axes don't match parameter rank. """ # get current module if not explicitly provided if module is None: module = nn.module._context.module_stack[-1] # pylint: disable=protected-access assert module is not None # define/fetch parameter on that module module_param = module.param(name, init_fn, *init_args, **init_kwargs) if axes is not None: # apply logical axis constraint immediately module_param = with_sharding_constraint( module_param, jax.sharding.PartitionSpec(*axes) ) # record logical axis constraint for global axis metadata module.sow( 'params_axes', f'{name}_axes', AxisMetadata(axes), # type: ignore reduce_fn=_param_with_axes_sow_reduce_fn, ) return module_param class PartitionedVariable(flax.core.scope.Variable): """A PartitionedVariable object allows mutable access to a variable. PartitionedVariables are identified by a collection (e.g., "batch_stats") and a name (e.g., "moving_mean"). The value property gives access to the variable's content and can be assigned to for mutation. Additionally, PartitionedVariables enforce logical sharding constraints on both retrieval and assignment. """ def __init__( self, scope, collection: str, name: str, axes: tuple[str, ...] | None = None, fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED, ): """Initializes a partitioned variable. Args: scope: The scope in which the variable is stored. collection: The collection of the variable (e.g., "params"). name: The name of the variable (e.g., "dense"). axes: logical axes name of variable. fallback: Fallback behavior if no matching rule is found. """ self.scope = scope self.collection = collection self.name = name self.axes = axes self.fallback = fallback @property def value(self): """Returns the value of this Variable.""" value = self.scope.get_variable(self.collection, self.name) if self.axes is not None: value = with_sharding_constraint(value, self.axes, fallback=self.fallback) return value @value.setter def value(self, value): """Updates the value of this Variable.""" if self.axes is not None: value = with_sharding_constraint(value, self.axes, fallback=self.fallback) self.scope.put_variable(self.collection, self.name, value) def _core_variable_with_axes( scope, col: str, name: str, init_fn: Callable[..., Any], *init_args, axes: tuple[str, ...] | None = None, fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED, **init_kwargs, ): """Variant of flax core variable scope call with sharding constraints.""" scope.reserve(name) if not scope.has_variable(col, name): if not scope.is_mutable_collection(col): raise flax.errors.ScopeVariableNotFoundError(name, col, scope.path_text) init_value = init_fn(*init_args, **init_kwargs) if axes is not None: init_value = with_sharding_constraint(init_value, axes, fallback=fallback) scope.put_variable(col, name, init_value) return PartitionedVariable(scope, col, name, axes, fallback) def variable_with_axes( collection: str, name: str, init_fn, *init_args, axes: tuple[str, ...] | None = None, module: Optional['nn.Module'] = None, fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED, **init_kwargs, ): """Declares and returns a variable with logical axes in the current Module. See :mod:`flax.linen.module.variable` for original docstring. Args: collection: The name of the variable collection. name: The variable name. init_fn: The function that will be called to compute the initial value of this variable. This function will only be called the first time this parameter is used in this module. *init_args: The positional arguments to pass to init_fn. axes: A tuple of axis names, must match the rank of the variable array. module: Use an explicit module instead of deriving the most recent from dynamic module context. fallback: How sharding should behave if there is no rule covering some axis. **init_kwargs: The key-word arguments to pass to init_fn. Returns: A flax `PartitionedVariable` object referencing the initialized variable array. Raises: TypeError: if axes specification is mal-formed. ValueError: if specified logical axes don't match parameter rank. """ # get current module if not explicitly provided if module is None: module = nn.module._context.module_stack[-1] # pylint: disable=protected-access assert module is not None module_var = _core_variable_with_axes( module.scope, collection, name, init_fn, *init_args, axes=axes, fallback=fallback, **init_kwargs, ) if axes is not None: # record logical axis constraint for global axis metadata module.sow( f'{collection}_axes', f'{name}_axes', AxisMetadata(axes), # type: ignore reduce_fn=_param_with_axes_sow_reduce_fn, ) return module_var def get_axis_names(axes_metadata): """Gets axis names for variables as logical PartitionSpecs. Args: axes_metadata: a single axes-metadata collection from a flax-initialized set of collections. Returns: Collection of Partitionspecs with logical axis names, with the "_axes" suffix on variable names removed to match original variable collection for annotations. """ def leaf_rewrite(x): return None if x is None else jax.sharding.PartitionSpec(*x) def rewrite(tree): return jax.tree_util.tree_map(leaf_rewrite, tree, is_leaf=_is_logical_spec) axes_metadata = unfreeze(axes_metadata) # pytype: disable=wrong-arg-types flat_dict = { re.sub(r'_axes$', '', '/'.join(k)): rewrite(v.names) for k, v in flatten_dict(axes_metadata).items() } return freeze( unflatten_dict({tuple(k.split('/')): v for k, v in flat_dict.items()}) ) # Metadata Aware Scan # ----------------------------------------------------------------------------- def _tree_map_axes(fn, tree): """Only map over AxisMetadata leaves in pytree - identity for other leaves.""" safe_fn = lambda x: fn(x) if isinstance(x, AxisMetadata) else x return jax.tree_util.tree_map( safe_fn, tree, is_leaf=lambda x: isinstance(x, AxisMetadata) ) def _is_mutable(axis_col: str) -> bool: """Determines whether a collection is mutable. For example, when a module is called with `module.apply(..., mutable=['z'])`, this function will return True for `axis_col='z'` and False otherwise. If there is no module in scope, this function will return True. Args: axis_col: Name of the collection in question. Returns: Whether it is currently mutable. """ last = nn.module._context.module_stack[-1] # pylint: disable=protected-access if last: return last.is_mutable_collection(axis_col) else: return True # uses this variable_transform to change 'params_axes' pytree as it bubbles # up / out from scan. def _add_axis_to_metadata(fn, axis_pos, axis_name, axis_col='params_axes'): """Insert a named axis to axes metadata.""" # Handle In() / Out() scan axis marker types. if hasattr(axis_pos, 'axis'): axis_pos = axis_pos.axis def insert_fn_leaf(names): if names is None: return names names = list(names) names.insert(axis_pos, axis_name) return tuple(names) def insert_fn(x): new_names = jax.tree_util.tree_map( insert_fn_leaf, x.names, is_leaf=_is_logical_spec ) return x.replace(names=new_names) def remove_fn_leaf(names): if names is None: return names names = list(names) if names[axis_pos] != axis_name: raise ValueError( f'Expected axis {axis_name} at position {axis_pos} in ' f'axis metadata {names}.' ) names.pop(axis_pos) return tuple(names) def remove_fn(x): new_names = jax.tree_util.tree_map( remove_fn_leaf, x.names, is_leaf=_is_logical_spec ) return x.replace(names=new_names) return nn.transforms.map_variables( fn, axis_col, mutable=_is_mutable(axis_col), trans_in_fn=lambda tree: _tree_map_axes(remove_fn, tree), trans_out_fn=lambda tree: _tree_map_axes(insert_fn, tree), ) # pylint: disable=dangerous-default-value def scan_with_axes( target: 'flax.linen.transforms.Target', variable_axes: Mapping[ CollectionFilter, InOutScanAxis ] = {}, variable_broadcast: CollectionFilter = False, variable_carry: CollectionFilter = False, split_rngs: Mapping[PRNGSequenceFilter, bool] = {}, in_axes=0, out_axes=0, length: int | None = None, reverse: bool = False, unroll: int = 1, axis_name: str = 'layers', axes_collections: tuple[str, ...] = ('params',), data_transform: Callable[..., Any] | None = None, methods=None, ) -> 'flax.linen.transforms.Target': """Wrapped version of nn.scan that handles logical axis metadata.""" # we broadcast the static metadata collections. axes_filters = tuple(f'{col}_axes' for col in axes_collections) variable_broadcast = flax.core.scope.union_filters( variable_broadcast, axes_filters ) # perform usual lifted scan scanned = flax.linen.transforms.lift_transform( flax.core.lift.scan, target, variable_axes=variable_axes, variable_broadcast=variable_broadcast, variable_carry=variable_carry, split_rngs=split_rngs, in_axes=in_axes, out_axes=out_axes, length=length, reverse=reverse, unroll=unroll, data_transform=data_transform, methods=methods, ) # add scan axis to logical axes metadata for col in axes_collections: if col in variable_axes: scanned = _add_axis_to_metadata( scanned, axis_pos=variable_axes[col], axis_name=axis_name, axis_col=f'{col}_axes', ) return scanned # pylint: disable=dangerous-default-value def vmap_with_axes( target: 'flax.linen.transforms.Target', variable_axes: Mapping[ CollectionFilter, InOutAxis ], split_rngs: Mapping[PRNGSequenceFilter, bool] = {}, in_axes=0, out_axes=0, axis_size: int | None = None, axis_name: str | None = None, partitioning_axis_names: Mapping[Any, str] = {}, spmd_axis_name: str | None = None, methods=None, ) -> 'flax.linen.transforms.Target': """Wrapped version of nn.vmap that handles logical axis metadata.""" # tell normal vmap to broadcast axis metadata. variable_axes = dict(variable_axes) # shallow copy for name in partitioning_axis_names: variable_axes[f'{name}_axes'] = None # perform usual lifted vmap vmapped = flax.linen.transforms.lift_transform( flax.core.lift.vmap, target, variable_axes=variable_axes, split_rngs=split_rngs, in_axes=in_axes, out_axes=out_axes, axis_size=axis_size, axis_name=axis_name, spmd_axis_name=spmd_axis_name, methods=methods, ) for collection_name, axis in variable_axes.items(): if collection_name in partitioning_axis_names: vmapped = _add_axis_to_metadata( # pylint: disable=protected-access vmapped, axis_pos=axis, axis_name=partitioning_axis_names[collection_name], axis_col=f'{collection_name}_axes', ) return vmapped # Remat abstraction bug hotfix # ------------------------------------------------------------------------------ # TODO(levskaya): upstream this fix into main flax.core.lift.remat. # Workaround a scan(remat(...)) abstraction bug by manually implementing a # static_argnums behavior for flax remat via closure before applying jax remat. def core_remat_static( fn, variables=True, rngs=True, concrete=False, prevent_cse=True, static_argnums=(), policy=None, ): """Flax functional core remat version with static_argnums.""" static_argnums = tuple(sorted(static_argnums)) def _repack_remat_args(dyn_args, static_args): """Remake arg list from static and dynamic args given static_argnums.""" args = [] s_cnt, d_cnt = 0, 0 for i in range(len(dyn_args) + len(static_args)): if i in static_argnums: args.append(static_args[s_cnt]) s_cnt += 1 else: args.append(dyn_args[d_cnt]) d_cnt += 1 return tuple(args) def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args): static_args = tuple(x for i, x in enumerate(args) if i in static_argnums) dyn_args = tuple(x for i, x in enumerate(args) if i not in static_argnums) # After JAX v0.3.16, concrete=False is a no-op and concrete=True raises # NotImplementedError. Starting in JAX v0.8.2, the concrete argument is # deprecated and will be removed in the future. if concrete: raise NotImplementedError( "The concrete argument is deprecated. Use static_argnums instead." " for more information, see" " https://docs.jax.dev/en/latest/jep/11830-new-remat-checkpoint.html" ) @functools.partial( jax.remat, prevent_cse=prevent_cse, policy=policy ) @functools.wraps(fn) def rematted(variable_groups, rng_groups, *dyn_args): args = _repack_remat_args(dyn_args, static_args) scope = scope_fn(variable_groups, rng_groups) y = fn(scope, *args) return y, repack_fn(scope) return rematted(variable_groups, rng_groups, *dyn_args) return flax.core.lift.pack( inner, (variables,), (variables,), (rngs,), name='remat' ) def remat( target, variables=True, rngs=True, concrete=False, prevent_cse=True, static_argnums=(), policy=None, methods=None, ): """Flax lifted remat that supports static_argnums.""" return flax.linen.transforms.lift_transform( core_remat_static, target, variables=variables, rngs=rngs, concrete=concrete, prevent_cse=prevent_cse, static_argnums=static_argnums, policy=policy, methods=methods, ) ================================================ FILE: flax/linen/pooling.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Pooling modules.""" import jax.numpy as jnp import numpy as np from jax import lax def pool(inputs, init, reduce_fn, window_shape, strides, padding): """Helper function to define pooling functions. Pooling functions are implemented using the ReduceWindow XLA op. .. note:: Be aware that pooling is not generally differentiable. That means providing a reduce_fn that is differentiable does not imply that pool is differentiable. Args: inputs: input data with dimensions (batch, window dims..., features). init: the initial value for the reduction reduce_fn: a reduce function of the form ``(T, T) -> T``. window_shape: a shape tuple defining the window to reduce over. strides: a sequence of ``n`` integers, representing the inter-window strides (default: ``(1, ..., 1)``). padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence of ``n`` ``(low, high)`` integer pairs that give the padding to apply before and after each spatial dimension. Returns: The output of the reduction for each window slice. """ num_batch_dims = inputs.ndim - (len(window_shape) + 1) strides = strides or (1,) * len(window_shape) assert len(window_shape) == len( strides ), f'len({window_shape}) must equal len({strides})' strides = (1,) * num_batch_dims + strides + (1,) dims = (1,) * num_batch_dims + window_shape + (1,) is_single_input = False if num_batch_dims == 0: # add singleton batch dimension because lax.reduce_window always # needs a batch dimension. inputs = inputs[None] strides = (1,) + strides dims = (1,) + dims is_single_input = True assert inputs.ndim == len(dims), f'len({inputs.shape}) != len({dims})' if not isinstance(padding, str): padding = tuple(map(tuple, padding)) assert len(padding) == len(window_shape), ( f'padding {padding} must specify pads for same number of dims as ' f'window_shape {window_shape}' ) assert all( [len(x) == 2 for x in padding] ), f'each entry in padding {padding} must be length 2' padding = ((0, 0),) + padding + ((0, 0),) y = lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding) if is_single_input: y = jnp.squeeze(y, axis=0) return y def avg_pool( inputs, window_shape, strides=None, padding='VALID', count_include_pad=True ): """Pools the input by taking the average over a window. Args: inputs: input data with dimensions (batch, window dims..., features). window_shape: a shape tuple defining the window to reduce over. strides: a sequence of ``n`` integers, representing the inter-window strides (default: ``(1, ..., 1)``). padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence of ``n`` ``(low, high)`` integer pairs that give the padding to apply before and after each spatial dimension (default: ``'VALID'``). count_include_pad: a boolean whether to include padded tokens in the average calculation (default: ``True``). Returns: The average for each window slice. """ y = pool(inputs, 0.0, lax.add, window_shape, strides, padding) if count_include_pad: y = y / np.prod(window_shape) else: div_shape = inputs.shape[:-1] + (1,) if len(div_shape) - 2 == len(window_shape): div_shape = (1,) + div_shape[1:] y = y / pool( jnp.ones(div_shape), 0.0, lax.add, window_shape, strides, padding ) return y def max_pool(inputs, window_shape, strides=None, padding='VALID'): """Pools the input by taking the maximum of a window slice. Args: inputs: input data with dimensions (batch, window dims..., features). window_shape: a shape tuple defining the window to reduce over. strides: a sequence of ``n`` integers, representing the inter-window strides (default: ``(1, ..., 1)``). padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence of ``n`` ``(low, high)`` integer pairs that give the padding to apply before and after each spatial dimension (default: ``'VALID'``). Returns: The maximum for each window slice. """ y = pool(inputs, -jnp.inf, lax.max, window_shape, strides, padding) return y def min_pool(inputs, window_shape, strides=None, padding='VALID'): """Pools the input by taking the minimum of a window slice. Args: inputs: Input data with dimensions (batch, window dims..., features). window_shape: A shape tuple defining the window to reduce over. strides: A sequence of ``n`` integers, representing the inter-window strides (default: ``(1, ..., 1)``). padding: Either the string ``'SAME'``, the string ``'VALID'``, or a sequence of ``n`` ``(low, high)`` integer pairs that give the padding to apply before and after each spatial dimension (default: ``'VALID'``). Returns: The minimum for each window slice. """ return pool(inputs, jnp.inf, lax.min, window_shape, strides, padding) ================================================ FILE: flax/linen/recurrent.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Recurrent neural network modules. THe RNNCell modules can be scanned using lifted transforms. For more information see: https://flax.readthedocs.io/en/latest/developer_notes/lift.html. """ from functools import partial # pylint: disable=g-importing-member from typing import ( Any, TypeVar, ) from collections.abc import Callable, Mapping, Sequence import jax import numpy as np from absl import logging from jax import numpy as jnp from jax import random from typing_extensions import Protocol from flax.core.frozen_dict import FrozenDict from flax.core.scope import CollectionFilter, PRNGSequenceFilter from flax.linen import initializers, transforms from flax.linen.activation import sigmoid, tanh from flax.linen.dtypes import promote_dtype from flax.linen.linear import Conv, Dense, default_kernel_init from flax.linen.module import Module, compact, nowrap from flax.typing import ( Array, PRNGKey, Dtype, InOutScanAxis, Initializer, PrecisionLike, ) A = TypeVar('A') Carry = Any CarryHistory = Any Output = Any class RNNCellBase(Module): """RNN cell base class.""" @nowrap def initialize_carry( self, rng: PRNGKey, input_shape: tuple[int, ...] ) -> Carry: """Initialize the RNN cell carry. Args: rng: random number generator passed to the init_fn. input_shape: a tuple providing the shape of the input to the cell. Returns: An initialized carry for the given RNN cell. """ raise NotImplementedError @property def num_feature_axes(self) -> int: """Returns the number of feature axes of the RNN cell.""" raise NotImplementedError class LSTMCell(RNNCellBase): r"""LSTM cell. The mathematical definition of the cell is as follows .. math:: \begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array} where x is the input, h is the output of the previous time step, and c is the memory. Example usage:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> x = jax.random.normal(jax.random.key(0), (2, 3)) >>> layer = nn.LSTMCell(features=4) >>> carry = layer.initialize_carry(jax.random.key(1), x.shape) >>> variables = layer.init(jax.random.key(2), carry, x) >>> new_carry, out = layer.apply(variables, carry, x) Attributes: features: number of output features. gate_fn: activation function used for gates (default: sigmoid). activation_fn: activation function used for output and memory update (default: tanh). kernel_init: initializer function for the kernels that transform the input (default: lecun_normal). recurrent_kernel_init: initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()). bias_init: initializer for the bias parameters (default: initializers.zeros_init()) dtype: the dtype of the computation (default: infer from inputs and params). param_dtype: the dtype passed to parameter initializers (default: float32). """ features: int gate_fn: Callable[..., Any] = sigmoid activation_fn: Callable[..., Any] = tanh kernel_init: Initializer = default_kernel_init recurrent_kernel_init: Initializer = initializers.orthogonal() bias_init: Initializer = initializers.zeros_init() dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 carry_init: Initializer = initializers.zeros_init() @compact def __call__(self, carry, inputs): r"""A long short-term memory (LSTM) cell. Args: carry: the hidden state of the LSTM cell, initialized using ``LSTMCell.initialize_carry``. inputs: an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions. Returns: A tuple with the new carry and the output. """ c, h = carry hidden_features = h.shape[-1] # input and recurrent layers are summed so only one needs a bias. dense_h = partial( Dense, features=hidden_features, use_bias=True, kernel_init=self.recurrent_kernel_init, bias_init=self.bias_init, dtype=self.dtype, param_dtype=self.param_dtype, ) dense_i = partial( Dense, features=hidden_features, use_bias=False, kernel_init=self.kernel_init, dtype=self.dtype, param_dtype=self.param_dtype, ) i = self.gate_fn(dense_i(name='ii')(inputs) + dense_h(name='hi')(h)) f = self.gate_fn(dense_i(name='if')(inputs) + dense_h(name='hf')(h)) g = self.activation_fn(dense_i(name='ig')(inputs) + dense_h(name='hg')(h)) o = self.gate_fn(dense_i(name='io')(inputs) + dense_h(name='ho')(h)) new_c = f * c + i * g new_h = o * self.activation_fn(new_c) return (new_c, new_h), new_h @nowrap def initialize_carry( self, rng: PRNGKey, input_shape: tuple[int, ...] ) -> tuple[Array, Array]: """Initialize the RNN cell carry. Args: rng: random number generator passed to the init_fn. input_shape: a tuple providing the shape of the input to the cell. Returns: An initialized carry for the given RNN cell. """ batch_dims = input_shape[:-1] key1, key2 = random.split(rng) mem_shape = batch_dims + (self.features,) c = self.carry_init(key1, mem_shape, self.param_dtype) h = self.carry_init(key2, mem_shape, self.param_dtype) return (c, h) @property def num_feature_axes(self) -> int: return 1 class DenseParams(Module): """Dummy module for creating parameters matching ``flax.linen.Dense``.""" features: int use_bias: bool = True param_dtype: Dtype = jnp.float32 precision: PrecisionLike = None kernel_init: Initializer = default_kernel_init bias_init: Initializer = initializers.zeros_init() @compact def __call__(self, inputs: Array) -> tuple[Array, Array | None]: k = self.param( 'kernel', self.kernel_init, (inputs.shape[-1], self.features), self.param_dtype, ) if self.use_bias: b = self.param('bias', self.bias_init, (self.features,), self.param_dtype) else: b = None return k, b class OptimizedLSTMCell(RNNCellBase): r"""More efficient LSTM Cell that concatenates state components before matmul. The parameters are compatible with ``LSTMCell``. Note that this cell is often faster than ``LSTMCell`` as long as the hidden size is roughly <= 2048 units. The mathematical definition of the cell is the same as ``LSTMCell`` and as follows .. math:: \begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array} where x is the input, h is the output of the previous time step, and c is the memory. Example usage:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> x = jax.random.normal(jax.random.key(0), (2, 3)) >>> layer = nn.OptimizedLSTMCell(features=4) >>> carry = layer.initialize_carry(jax.random.key(1), x.shape) >>> variables = layer.init(jax.random.key(2), carry, x) >>> new_carry, out = layer.apply(variables, carry, x) Attributes: gate_fn: activation function used for gates (default: sigmoid). activation_fn: activation function used for output and memory update (default: tanh). kernel_init: initializer function for the kernels that transform the input (default: lecun_normal). recurrent_kernel_init: initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()). bias_init: initializer for the bias parameters (default: initializers.zeros_init()). dtype: the dtype of the computation (default: infer from inputs and params). param_dtype: the dtype passed to parameter initializers (default: float32). """ features: int gate_fn: Callable[..., Any] = sigmoid activation_fn: Callable[..., Any] = tanh kernel_init: Initializer = default_kernel_init recurrent_kernel_init: Initializer = initializers.orthogonal() bias_init: Initializer = initializers.zeros_init() dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 carry_init: Initializer = initializers.zeros_init() @compact def __call__( self, carry: tuple[Array, Array], inputs: Array ) -> tuple[tuple[Array, Array], Array]: r"""An optimized long short-term memory (LSTM) cell. Args: carry: the hidden state of the LSTM cell, initialized using ``LSTMCell.initialize_carry``. inputs: an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions. Returns: A tuple with the new carry and the output. """ c, h = carry hidden_features = h.shape[-1] def _concat_dense( inputs: Array, params: Mapping[str, tuple[Array, Array | None]], use_bias: bool = True, ) -> dict[str, Array]: # Concatenates the individual kernels and biases, given in params, into a # single kernel and single bias for efficiency before applying them using # dot_general. kernels = [kernel for kernel, _ in params.values()] kernel = jnp.concatenate(kernels, axis=-1) if use_bias: biases = [] for _, bias in params.values(): if bias is None: raise ValueError('bias is None but use_bias is True.') biases.append(bias) bias = jnp.concatenate(biases, axis=-1) else: bias = None inputs, kernel, bias = promote_dtype( inputs, kernel, bias, dtype=self.dtype ) y = jnp.dot(inputs, kernel) if use_bias: # This assert is here since mypy can't infer that bias cannot be None assert bias is not None y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) # Split the result back into individual (i, f, g, o) outputs. split_indices = np.cumsum([kernel.shape[-1] for kernel in kernels[:-1]]) ys = jnp.split(y, split_indices, axis=-1) return dict(zip(params.keys(), ys)) # Create params with the same names/shapes as `LSTMCell` for compatibility. dense_params_h = {} dense_params_i = {} for component in ['i', 'f', 'g', 'o']: dense_params_i[component] = DenseParams( features=hidden_features, use_bias=False, param_dtype=self.param_dtype, kernel_init=self.kernel_init, bias_init=self.bias_init, name=f'i{component}', # type: ignore[call-arg] )(inputs) dense_params_h[component] = DenseParams( features=hidden_features, use_bias=True, param_dtype=self.param_dtype, kernel_init=self.recurrent_kernel_init, bias_init=self.bias_init, name=f'h{component}', # type: ignore[call-arg] )(h) dense_h = _concat_dense(h, dense_params_h, use_bias=True) dense_i = _concat_dense(inputs, dense_params_i, use_bias=False) i = self.gate_fn(dense_h['i'] + dense_i['i']) f = self.gate_fn(dense_h['f'] + dense_i['f']) g = self.activation_fn(dense_h['g'] + dense_i['g']) o = self.gate_fn(dense_h['o'] + dense_i['o']) new_c = f * c + i * g new_h = o * self.activation_fn(new_c) return (new_c, new_h), new_h @nowrap def initialize_carry( self, rng: PRNGKey, input_shape: tuple[int, ...] ) -> tuple[Array, Array]: """Initialize the RNN cell carry. Args: rng: random number generator passed to the init_fn. input_shape: a tuple providing the shape of the input to the cell. Returns: An initialized carry for the given RNN cell. """ batch_dims = input_shape[:-1] key1, key2 = random.split(rng) mem_shape = batch_dims + (self.features,) c = self.carry_init(key1, mem_shape, self.param_dtype) h = self.carry_init(key2, mem_shape, self.param_dtype) return c, h @property def num_feature_axes(self) -> int: return 1 class SimpleCell(RNNCellBase): r"""Simple cell. The mathematical definition of the cell is as follows .. math:: \begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h) \end{array} where x is the input and h is the output of the previous time step. If `residual` is `True`, .. math:: \begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h + h) \end{array} Example usage:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> x = jax.random.normal(jax.random.key(0), (2, 3)) >>> layer = nn.SimpleCell(features=4) >>> carry = layer.initialize_carry(jax.random.key(1), x.shape) >>> variables = layer.init(jax.random.key(2), carry, x) >>> new_carry, out = layer.apply(variables, carry, x) Attributes: features: number of output features. activation_fn: activation function used for output and memory update (default: tanh). kernel_init: initializer function for the kernels that transform the input (default: lecun_normal). recurrent_kernel_init: initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()). bias_init: initializer for the bias parameters (default: initializers.zeros_init()) dtype: the dtype of the computation (default: None). param_dtype: the dtype passed to parameter initializers (default: float32). residual: pre-activation residual connection (https://arxiv.org/abs/1801.06105). """ features: int activation_fn: Callable[..., Any] = tanh kernel_init: Initializer = default_kernel_init recurrent_kernel_init: Initializer = initializers.orthogonal() bias_init: Initializer = initializers.zeros_init() dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 carry_init: Initializer = initializers.zeros_init() residual: bool = False @compact def __call__(self, carry, inputs): """Simple cell. Args: carry: the hidden state of the Simple cell, initialized using ``SimpleCell.initialize_carry``. inputs: an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions. Returns: A tuple with the new carry and the output. """ hidden_features = carry.shape[-1] # input and recurrent layers are summed so only one needs a bias. dense_h = partial( Dense, features=hidden_features, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=self.recurrent_kernel_init, ) dense_i = partial( Dense, features=hidden_features, use_bias=True, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=self.kernel_init, bias_init=self.bias_init, ) new_carry = dense_i(name='i')(inputs) + dense_h(name='h')(carry) if self.residual: new_carry += carry new_carry = self.activation_fn(new_carry) return new_carry, new_carry @nowrap def initialize_carry(self, rng: PRNGKey, input_shape: tuple[int, ...]): """Initialize the RNN cell carry. Args: rng: random number generator passed to the init_fn. input_shape: a tuple providing the shape of the input to the cell. Returns: An initialized carry for the given RNN cell. """ batch_dims = input_shape[:-1] mem_shape = batch_dims + (self.features,) return self.carry_init(rng, mem_shape, self.param_dtype) @property def num_feature_axes(self) -> int: return 1 class GRUCell(RNNCellBase): r"""GRU cell. The mathematical definition of the cell is as follows .. math:: \begin{array}{ll} r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\ z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\ n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ h' = (1 - z) * n + z * h \\ \end{array} where x is the input and h is the output of the previous time step. Example usage:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> x = jax.random.normal(jax.random.key(0), (2, 3)) >>> layer = nn.GRUCell(features=4) >>> carry = layer.initialize_carry(jax.random.key(1), x.shape) >>> variables = layer.init(jax.random.key(2), carry, x) >>> new_carry, out = layer.apply(variables, carry, x) Attributes: features: number of output features. gate_fn: activation function used for gates (default: sigmoid). activation_fn: activation function used for output and memory update (default: tanh). kernel_init: initializer function for the kernels that transform the input (default: lecun_normal). recurrent_kernel_init: initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()). bias_init: initializer for the bias parameters (default: initializers.zeros_init()) dtype: the dtype of the computation (default: None). param_dtype: the dtype passed to parameter initializers (default: float32). """ features: int gate_fn: Callable[..., Any] = sigmoid activation_fn: Callable[..., Any] = tanh kernel_init: Initializer = default_kernel_init recurrent_kernel_init: Initializer = initializers.orthogonal() bias_init: Initializer = initializers.zeros_init() dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 carry_init: Initializer = initializers.zeros_init() @compact def __call__(self, carry, inputs): """Gated recurrent unit (GRU) cell. Args: carry: the hidden state of the GRU cell, initialized using ``GRUCell.initialize_carry``. inputs: an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions. Returns: A tuple with the new carry and the output. """ h = carry hidden_features = h.shape[-1] # input and recurrent layers are summed so only one needs a bias. dense_h = partial( Dense, features=hidden_features, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=self.recurrent_kernel_init, bias_init=self.bias_init, ) dense_i = partial( Dense, features=hidden_features, use_bias=True, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=self.kernel_init, bias_init=self.bias_init, ) r = self.gate_fn(dense_i(name='ir')(inputs) + dense_h(name='hr')(h)) z = self.gate_fn(dense_i(name='iz')(inputs) + dense_h(name='hz')(h)) # add bias because the linear transformations aren't directly summed. n = self.activation_fn( dense_i(name='in')(inputs) + r * dense_h(name='hn', use_bias=True)(h) ) new_h = (1.0 - z) * n + z * h return new_h, new_h @nowrap def initialize_carry(self, rng: PRNGKey, input_shape: tuple[int, ...]): """Initialize the RNN cell carry. Args: rng: random number generator passed to the init_fn. input_shape: a tuple providing the shape of the input to the cell. Returns: An initialized carry for the given RNN cell. """ batch_dims = input_shape[:-1] mem_shape = batch_dims + (self.features,) return self.carry_init(rng, mem_shape, self.param_dtype) @property def num_feature_axes(self) -> int: return 1 class MGUCell(RNNCellBase): r"""MGU cell (https://arxiv.org/pdf/1603.09420.pdf). The mathematical definition of the cell is as follows .. math:: \begin{array}{ll} f = \sigma(W_{if} x + b_{if} + W_{hf} h) \\ n = \tanh(W_{in} x + b_{in} + f * (W_{hn} h + b_{hn})) \\ h' = (1 - f) * n + f * h \\ \end{array} where x is the input and h is the output of the previous time step. If ``reset_gate`` is false, the above becomes .. math:: \begin{array}{ll} f = \sigma(W_{if} x + b_{if} + W_{hf} h) \\ n = \tanh(W_{in} x + b_{in} + W_{hn} h) \\ h' = (1 - f) * n + f * h \\ \end{array} Example usage:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> x = jax.random.normal(jax.random.key(0), (2, 3)) >>> layer = nn.MGUCell(features=4) >>> carry = layer.initialize_carry(jax.random.key(1), x.shape) >>> variables = layer.init(jax.random.key(2), carry, x) >>> new_carry, out = layer.apply(variables, carry, x) Attributes: features: number of output features. gate_fn: activation function used for gates (default: sigmoid). activation_fn: activation function used for output and memory update (default: tanh). kernel_init: initializer function for the kernels that transform the input (default: lecun_normal). recurrent_kernel_init: initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()). forget_bias_init: initializer for the bias parameters of the forget gate. The default is set to initializers.ones_init() because this prevents vanishing gradients. See https://proceedings.mlr.press/v37/jozefowicz15.pdf, section 2.2 for more details. activation_bias_init: initializer for the bias parameters of the activation output (default: initializers.zeros_init()). dtype: the dtype of the computation (default: None). param_dtype: the dtype passed to parameter initializers (default: float32). reset_gate: flag for applying reset gating. """ features: int gate_fn: Callable[..., Any] = sigmoid activation_fn: Callable[..., Any] = tanh kernel_init: Initializer = default_kernel_init recurrent_kernel_init: Initializer = initializers.orthogonal() forget_bias_init: Initializer = initializers.ones_init() activation_bias_init: Initializer = initializers.zeros_init() dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 carry_init: Initializer = initializers.zeros_init() reset_gate: bool = True @compact def __call__(self, carry, inputs): """Minimal gated unit (MGU) cell. Args: carry: the hidden state of the MGU cell, initialized using ``MGUCell.initialize_carry``. inputs: an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions. Returns: A tuple with the new carry and the output. """ h = carry hidden_features = h.shape[-1] # input and recurrent layers are summed so only one needs a bias. dense_h = partial( Dense, features=hidden_features, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=self.recurrent_kernel_init, bias_init=self.activation_bias_init, ) dense_i = partial( Dense, features=hidden_features, use_bias=True, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=self.kernel_init, ) f = self.gate_fn( dense_i(name='if', bias_init=self.forget_bias_init)(inputs) + dense_h(name='hf')(h) ) # add bias when the linear transformations aren't directly summed. x = dense_h(name="hn", use_bias=self.reset_gate)(h) if self.reset_gate: x *= f n = self.activation_fn( dense_i(name="in", bias_init=self.activation_bias_init)(inputs) + x ) new_h = (1.0 - f) * n + f * h return new_h, new_h @nowrap def initialize_carry(self, rng: PRNGKey, input_shape: tuple[int, ...]): """Initialize the RNN cell carry. Args: rng: random number generator passed to the init_fn. input_shape: a tuple providing the shape of the input to the cell. Returns: An initialized carry for the given RNN cell. """ batch_dims = input_shape[:-1] mem_shape = batch_dims + (self.features,) return self.carry_init(rng, mem_shape, self.param_dtype) @property def num_feature_axes(self) -> int: return 1 class ConvLSTMCell(RNNCellBase): r"""A convolutional LSTM cell. The implementation is based on xingjian2015convolutional. Given x_t and the previous state (h_{t-1}, c_{t-1}) the core computes .. math:: \begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array} where * denotes the convolution operator; i_t, f_t, o_t are input, forget and output gate activations, and g_t is a vector of cell updates. .. note:: Forget gate initialization: Following jozefowicz2015empirical we add 1.0 to b_f after initialization in order to reduce the scale of forgetting in the beginning of the training. Example usage:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> x = jax.random.normal(jax.random.key(0), (3, 5, 5)) >>> layer = nn.ConvLSTMCell(features=4, kernel_size=(2, 2)) >>> carry = layer.initialize_carry(jax.random.key(1), x.shape) >>> variables = layer.init(jax.random.key(2), carry, x) >>> new_carry, out = layer.apply(variables, carry, x) Attributes: features: number of convolution filters. kernel_size: shape of the convolutional kernel. strides: a sequence of ``n`` integers, representing the inter-window strides. padding: either the string ``'SAME'``, the string ``'VALID'``, or a sequence of ``n`` ``(low, high)`` integer pairs that give the padding to apply before and after each spatial dimension. bias: whether to add a bias to the output (default: True). dtype: the dtype of the computation (default: None). param_dtype: the dtype passed to parameter initializers (default: float32). """ features: int kernel_size: Sequence[int] strides: Sequence[int] | None = None padding: str | Sequence[tuple[int, int]] = 'SAME' use_bias: bool = True dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 carry_init: Initializer = initializers.zeros_init() @compact def __call__(self, carry, inputs): """Constructs a convolutional LSTM. Args: carry: the hidden state of the Conv2DLSTM cell, initialized using ``Conv2DLSTM.initialize_carry``. inputs: input data with dimensions (batch, spatial_dims..., features). Returns: A tuple with the new carry and the output. """ c, h = carry input_to_hidden = partial( Conv, features=4 * self.features, kernel_size=self.kernel_size, strides=self.strides, padding=self.padding, use_bias=self.use_bias, dtype=self.dtype, param_dtype=self.param_dtype, name='ih', ) hidden_to_hidden = partial( Conv, features=4 * self.features, kernel_size=self.kernel_size, strides=self.strides, padding=self.padding, use_bias=self.use_bias, dtype=self.dtype, param_dtype=self.param_dtype, name='hh', ) gates = input_to_hidden()(inputs) + hidden_to_hidden()(h) i, g, f, o = jnp.split(gates, indices_or_sections=4, axis=-1) f = sigmoid(f + 1) new_c = f * c + sigmoid(i) * jnp.tanh(g) new_h = sigmoid(o) * jnp.tanh(new_c) return (new_c, new_h), new_h @nowrap def initialize_carry(self, rng: PRNGKey, input_shape: tuple[int, ...]): """Initialize the RNN cell carry. Args: rng: random number generator passed to the init_fn. input_shape: a tuple providing the shape of the input to the cell. Returns: An initialized carry for the given RNN cell. """ # (*batch_dims, *signal_dims, features) signal_dims = input_shape[-self.num_feature_axes : -1] batch_dims = input_shape[: -self.num_feature_axes] key1, key2 = random.split(rng) mem_shape = batch_dims + signal_dims + (self.features,) c = self.carry_init(key1, mem_shape, self.param_dtype) h = self.carry_init(key2, mem_shape, self.param_dtype) return c, h @property def num_feature_axes(self) -> int: return len(self.kernel_size) + 1 class RNN(Module): """The ``RNN`` module takes any :class:`RNNCellBase` instance and applies it over a sequence using :func:`flax.linen.scan`. Example:: >>> import jax.numpy as jnp >>> import jax >>> import flax.linen as nn >>> x = jnp.ones((10, 50, 32)) # (batch, time, features) >>> lstm = nn.RNN(nn.LSTMCell(64)) >>> variables = lstm.init(jax.random.key(0), x) >>> y = lstm.apply(variables, x) >>> y.shape # (batch, time, cell_size) (10, 50, 64) As shown above, RNN uses the ``cell_size`` argument to set the ``size`` argument for the cell's ``initialize_carry`` method, in practice this is typically the number of hidden units you want for the cell. However, this may vary depending on the cell you are using, for example the :class:`ConvLSTMCell` requires a ``size`` argument of the form ``(kernel_height, kernel_width, features)``:: >>> x = jnp.ones((10, 50, 32, 32, 3)) # (batch, time, height, width, features) >>> conv_lstm = nn.RNN(nn.ConvLSTMCell(64, kernel_size=(3, 3))) >>> y, variables = conv_lstm.init_with_output(jax.random.key(0), x) >>> y.shape # (batch, time, height, width, features) (10, 50, 32, 32, 64) By default RNN expect the time dimension after the batch dimension (``(*batch, time, *features)``), if you set ``time_major=True`` RNN will instead expect the time dimension to be at the beginning (``(time, *batch, *features)``):: >>> x = jnp.ones((50, 10, 32)) # (time, batch, features) >>> lstm = nn.RNN(nn.LSTMCell(64), time_major=True) >>> variables = lstm.init(jax.random.key(0), x) >>> y = lstm.apply(variables, x) >>> y.shape # (time, batch, cell_size) (50, 10, 64) The output is an array of shape ``(*batch, time, *cell_size)`` by default (typically), however if you set ``return_carry=True`` it will instead return a tuple of the final carry and the output:: >>> x = jnp.ones((10, 50, 32)) # (batch, time, features) >>> lstm = nn.RNN(nn.LSTMCell(64), return_carry=True) >>> variables = lstm.init(jax.random.key(0), x) >>> carry, y = lstm.apply(variables, x) >>> jax.tree_util.tree_map(jnp.shape, carry) # ((batch, cell_size), (batch, cell_size)) ((10, 64), (10, 64)) >>> y.shape # (batch, time, cell_size) (10, 50, 64) To support variable length sequences, you can pass a ``seq_lengths`` which is an integer array of shape ``(*batch)`` where each element is the length of the sequence in the batch. For example:: >>> seq_lengths = jnp.array([3, 2, 5]) The output elements corresponding to padding elements are NOT zeroed out. If ``return_carry`` is set to ``True`` the carry will be the state of the last valid element of each sequence. RNN also accepts some of the arguments of :func:`flax.linen.scan`, by default they are set to work with cells like :class:`LSTMCell` and :class:`GRUCell` but they can be overridden as needed. Overriding default values to scan looks like this:: >>> lstm = nn.RNN( ... nn.LSTMCell(64), ... unroll=1, variable_axes={}, variable_broadcast='params', ... variable_carry=False, split_rngs={'params': False}) Attributes: cell: an instance of :class:`RNNCellBase`. time_major: if ``time_major=False`` (default) it will expect inputs with shape ``(*batch, time, *features)``, else it will expect inputs with shape ``(time, *batch, *features)``. return_carry: if ``return_carry=False`` (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence. reverse: if ``reverse=False`` (default) the sequence is processed from left to right and returned in the original order, else it will be processed from right to left, and returned in reverse order. If ``seq_lengths`` is passed, padding will always remain at the end of the sequence. keep_order: if ``keep_order=True``, when ``reverse=True`` the output will be reversed back to the original order after processing, this is useful to align sequences in bidirectional RNNs. If ``keep_order=False`` (default), the output will remain in the order specified by ``reverse``. unroll: how many scan iterations to unroll within a single iteration of a loop, defaults to 1. This argument will be passed to ``nn.scan``. variable_axes: a dictionary mapping each collection to either an integer ``i`` (meaning we scan over dimension ``i``) or ``None`` (replicate rather than scan). This argument is forwarded to ``nn.scan``. variable_broadcast: Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside the fn. This argument is forwarded to ``nn.scan``. variable_carry: Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes. This argument is forwarded to ``nn.scan``. split_rngs: a mapping from PRNGSequenceFilter to bool specifying whether a collection's PRNG key should be split such that its values are different at each step, or replicated such that its values remain the same at each step. This argument is forwarded to ``nn.scan``. """ cell: RNNCellBase time_major: bool = False return_carry: bool = False reverse: bool = False keep_order: bool = False unroll: int = 1 variable_axes: Mapping[ CollectionFilter, InOutScanAxis ] = FrozenDict() variable_broadcast: CollectionFilter = 'params' variable_carry: CollectionFilter = False split_rngs: Mapping[PRNGSequenceFilter, bool] = FrozenDict( {'params': False} ) def __call__( self, inputs: jax.Array, *, initial_carry: Carry | None = None, init_key: PRNGKey | None = None, seq_lengths: Array | None = None, return_carry: bool | None = None, time_major: bool | None = None, reverse: bool | None = None, keep_order: bool | None = None, ) -> Output | tuple[Carry, Output]: """ Applies the RNN to the inputs. ``__call__`` allows you to optionally override some attributes like ``return_carry`` and ``time_major`` defined in the constructor. Arguments: inputs: the input sequence. initial_carry: the initial carry, if not provided it will be initialized using the cell's :meth:`RNNCellBase.initialize_carry` method. init_key: a PRNG key used to initialize the carry, if not provided ``jax.random.key(0)`` will be used. Most cells will ignore this argument. seq_lengths: an optional integer array of shape ``(*batch)`` indicating the length of each sequence, elements whose index in the time dimension is greater than the corresponding length will be considered padding and will be ignored. return_carry: if ``return_carry=False`` (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence. time_major: if ``time_major=False`` (default) it will expect inputs with shape ``(*batch, time, *features)``, else it will expect inputs with shape ``(time, *batch, *features)``. reverse: overrides the ``reverse`` attribute, if ``reverse=False`` (default) the sequence is processed from left to right and returned in the original order, else it will be processed from right to left, and returned in reverse order. If ``seq_lengths`` is passed, padding will always remain at the end of the sequence. keep_order: overrides the ``keep_order`` attribute, if ``keep_order=True``, when ``reverse=True`` the output will be reversed back to the original order after processing, this is useful to align sequences in bidirectional RNNs. If ``keep_order=False`` (default), the output will remain in the order specified by ``reverse``. Returns: if ``return_carry=False`` (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence. """ if return_carry is None: return_carry = self.return_carry if time_major is None: time_major = self.time_major if reverse is None: reverse = self.reverse if keep_order is None: keep_order = self.keep_order # Infer the number of batch dimensions from the input shape. # Cells like ConvLSTM have additional spatial dimensions. time_axis = ( 0 if time_major else inputs.ndim - (self.cell.num_feature_axes + 1) ) # make time_axis positive if time_axis < 0: time_axis += inputs.ndim if time_major: # we add +1 because we moved the time axis to the front batch_dims = inputs.shape[1 : -self.cell.num_feature_axes] else: batch_dims = inputs.shape[:time_axis] # maybe reverse the sequence if reverse: inputs = jax.tree_util.tree_map( lambda x: flip_sequences( x, seq_lengths, num_batch_dims=len(batch_dims), time_major=time_major, # type: ignore ), inputs, ) carry: Carry if initial_carry is None: if init_key is None: init_key = random.key(0) input_shape = inputs.shape[:time_axis] + inputs.shape[time_axis + 1 :] carry = self.cell.initialize_carry(init_key, input_shape) else: carry = initial_carry slice_carry = seq_lengths is not None and return_carry def scan_fn( cell: RNNCellBase, carry: Carry, x: Array ) -> tuple[Carry, Array] | tuple[Carry, tuple[Carry, Array]]: carry, y = cell(carry, x) # When we have a segmentation mask we return the carry as an output # so that we can select the last carry for each sequence later. # This uses more memory but is faster than using jnp.where at each # iteration. As a small optimization do this when we really need it. if slice_carry: return carry, (carry, y) else: return carry, y scan = transforms.scan( scan_fn, in_axes=time_axis, out_axes=(0, time_axis) if slice_carry else time_axis, unroll=self.unroll, variable_axes=self.variable_axes, variable_broadcast=self.variable_broadcast, variable_carry=self.variable_carry, split_rngs=self.split_rngs, ) scan_output = scan(self.cell, carry, inputs) # Next we select the final carry. If a segmentation mask was provided and # return_carry is True we slice the carry history and select the last valid # carry for each sequence. Otherwise we just use the last carry. if slice_carry: assert seq_lengths is not None _, (carries, outputs) = scan_output # seq_lengths[None] expands the shape of the mask to match the # number of dimensions of the carry. carry = _select_last_carry(carries, seq_lengths) else: carry, outputs = scan_output if reverse and keep_order: outputs = jax.tree_util.tree_map( lambda x: flip_sequences( x, seq_lengths, num_batch_dims=len(batch_dims), time_major=time_major, # type: ignore ), outputs, ) if return_carry: return carry, outputs else: return outputs def _select_last_carry(sequence: A, seq_lengths: jnp.ndarray) -> A: last_idx = seq_lengths - 1 def _slice_array(x: jnp.ndarray): return x[last_idx, jnp.arange(x.shape[1])] return jax.tree_util.tree_map(_slice_array, sequence) def _expand_dims_like(x, target): """Expands the shape of `x` to match `target`'s shape by adding singleton dimensions.""" return x.reshape(list(x.shape) + [1] * (target.ndim - x.ndim)) def flip_sequences( inputs: Array, seq_lengths: Array | None, num_batch_dims: int, time_major: bool, ) -> Array: """Flips a sequence of inputs along the time axis. This function can be used to prepare inputs for the reverse direction of a bidirectional LSTM. It solves the issue that, when naively flipping multiple padded sequences stored in a matrix, the first elements would be padding values for those sequences that were padded. This function keeps the padding at the end, while flipping the rest of the elements. Example: ```python inputs = [[1, 0, 0], [2, 3, 0] [4, 5, 6]] lengths = [1, 2, 3] flip_sequences(inputs, lengths) = [[1, 0, 0], [3, 2, 0], [6, 5, 4]] ``` Args: inputs: An array of input IDs [batch_size, seq_length]. lengths: The length of each sequence [batch_size]. Returns: An ndarray with the flipped inputs. """ # Compute the indices to put the inputs in flipped order as per above example. time_axis = 0 if time_major else num_batch_dims max_steps = inputs.shape[time_axis] if seq_lengths is None: # reverse inputs and return inputs = jnp.flip(inputs, axis=time_axis) return inputs seq_lengths = jnp.expand_dims(seq_lengths, axis=time_axis) # create indexes idxs = jnp.arange(max_steps - 1, -1, -1) # [max_steps] if time_major: idxs = jnp.reshape(idxs, [max_steps] + [1] * num_batch_dims) else: idxs = jnp.reshape( idxs, [1] * num_batch_dims + [max_steps] ) # [1, ..., max_steps] idxs = (idxs + seq_lengths) % max_steps # [*batch, max_steps] idxs = _expand_dims_like( idxs, target=inputs ) # [*batch, max_steps, *features] # Select the inputs in flipped order. outputs = jnp.take_along_axis(inputs, idxs, axis=time_axis) return outputs def _concatenate(a: Array, b: Array) -> Array: """Concatenates two arrays along the last dimension.""" return jnp.concatenate([a, b], axis=-1) class RNNBase(Protocol): def __call__( self, inputs: jax.Array, *, initial_carry: Carry | None = None, init_key: PRNGKey | None = None, seq_lengths: Array | None = None, return_carry: bool | None = None, time_major: bool | None = None, reverse: bool | None = None, keep_order: bool | None = None, ) -> Output | tuple[Carry, Output]: ... class Bidirectional(Module): """Processes the input in both directions and merges the results. Example usage:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> layer = nn.Bidirectional(nn.RNN(nn.GRUCell(4)), nn.RNN(nn.GRUCell(4))) >>> x = jnp.ones((2, 3)) >>> variables = layer.init(jax.random.key(0), x) >>> out = layer.apply(variables, x) """ forward_rnn: RNNBase backward_rnn: RNNBase merge_fn: Callable[[Array, Array], Array] = _concatenate time_major: bool = False return_carry: bool = False def __call__( self, inputs: jax.Array, *, initial_carry: Carry | None = None, init_key: PRNGKey | None = None, seq_lengths: Array | None = None, return_carry: bool | None = None, time_major: bool | None = None, reverse: bool | None = None, keep_order: bool | None = None, ) -> Output | tuple[Carry, Output]: if time_major is None: time_major = self.time_major if return_carry is None: return_carry = self.return_carry if init_key is not None: key_forward, key_backward = random.split(init_key) else: key_forward = key_backward = None if initial_carry is not None: initial_carry_forward, initial_carry_backward = initial_carry else: initial_carry_forward = initial_carry_backward = None # Throw a warning in case the user accidentally re-uses the forward RNN # for the backward pass and does not intend for them to share parameters. if self.forward_rnn is self.backward_rnn: logging.warning( 'forward_rnn and backward_rnn is the same object, so ' 'they will share parameters.' ) # Encode in the forward direction. carry_forward, outputs_forward = self.forward_rnn( inputs, initial_carry=initial_carry_forward, init_key=key_forward, seq_lengths=seq_lengths, return_carry=True, time_major=time_major, reverse=False, ) carry_backward, outputs_backward = self.backward_rnn( inputs, initial_carry=initial_carry_backward, init_key=key_backward, seq_lengths=seq_lengths, return_carry=True, time_major=time_major, reverse=True, keep_order=True, ) carry = (carry_forward, carry_backward) outputs = jax.tree_util.tree_map( self.merge_fn, outputs_forward, outputs_backward ) if return_carry: return carry, outputs else: return outputs ================================================ FILE: flax/linen/spmd.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for working with jit and partitioned models. This module introduces ``axis_rules``, ``logical_to_mesh_axes``, ``logical_to_mesh``, ``with_logical_constraint`` for appyling jit sharding constraints in terms of "logical named axes" rather than jit's default mesh axes. Additionally the ``LogicallyPartitioned`` metadata wrapper is defined as well as the initializer function wrapper ``with_logical_partitioning``for introducing logical axis metadata into a model's variables. """ import collections import dataclasses import enum import functools from typing import Any from collections.abc import Callable, Sequence import jax from jax import lax from flax import struct from flax.core import meta from flax.core.spmd import ( get_logical_axis_rules, ) from flax.typing import ( Array, LogicalNames, LogicalRules, ArrayPytree, # pylint: disable=invalid-name LogicalPartitionSpec, # pylint: disable=unused-import LogicalPartitionSpecPytree, # pylint: disable=invalid-name ) class _UnassignedAxis: """Sentinel class for unassigned logical axis name.""" def __repr__(self): return 'UnassignedAxis' def __bool__(self): return False _unassigned_axis = _UnassignedAxis() def _mesh_assignment_free(new_assignment, existing_assignments): """Determines if a given mesh axis has already been assigned.""" new = set(jax.tree_util.tree_leaves(new_assignment)) existing = set(jax.tree_util.tree_leaves(existing_assignments)) new.discard(jax.sharding.PartitionSpec.UNCONSTRAINED) new.discard(None) if existing.intersection(new): return False return True def _logical_to_mesh_axes( array_dim_names: Sequence[str | None] | None, rules: LogicalRules | None = None, ) -> list[_UnassignedAxis | None | str | tuple[str, ...]] | None: """Same as logical_to_mesh_axes, but doesn't fill in _unassigned_axis.""" if array_dim_names is None: return None if rules is None: rules = get_logical_axis_rules() axis_name_counts = collections.Counter(array_dim_names) # None and special values such as PartitionSpec.UNCONSTRAINED can appear more # then once. dups = tuple( k for k, v in axis_name_counts.items() if v > 1 and isinstance(k, str) ) if dups: raise ValueError( f'Unsupported: Dimensions {dups} occur more than once in array names.' ) if not isinstance(rules, (tuple, list)): raise ValueError('Unknown axis rule specification type.') # We assign mesh axes using a priority based ruleset over logical axis names. result: list[_UnassignedAxis | None | str | tuple[str, ...]] result = [ (_unassigned_axis if isinstance(name, str) else name) for name in array_dim_names ] for rule_model_name, rule_mesh_names in rules: if rule_model_name in array_dim_names: pos = array_dim_names.index(rule_model_name) if ( _mesh_assignment_free(rule_mesh_names, result) and result[pos] == _unassigned_axis ): result[pos] = rule_mesh_names return result def logical_to_mesh_axes( array_dim_names: Sequence[str | None] | None, rules: LogicalRules | None = None, ) -> jax.sharding.PartitionSpec | None: """Compute layout for an array. The rules are in order of precedence, and consist of pairs: ``(ArrayDimensionName, MeshDimensionName)``, meaning that the given array dimension (if present and unused) should be sharded across the given mesh dimension (if present and unused). A Layout of an Array is expressed as a tuple with one element for each dimension in the Array. The element is either None, or is the name of a mesh-dimension, meaning that this dimension of the array is sharded across this dimension of the mesh. For example, given an array with:: array_dim_names = ('batch', 'length', 'heads', 'features') and the layout rules are:: rules = (('batch', 'X'), ('features', 'X'), ('heads', 'Y'), ('batch', 'Z')) then this function will return:: PartitionSpec('X', None, 'Y', None) Args: array_dim_names: Tuple of array dimension names or None. rules: Optional logical to mesh rules override. Defaults to using the rules defined in the dynamic context set from the ``axis_rules`` function. Returns: PartitionSpec for the parameter. """ result = _logical_to_mesh_axes(array_dim_names, rules) if result is None: return None # We default to None - ie unsharded along the dimension. result = [None if x is _unassigned_axis else x for x in result] return jax.sharding.PartitionSpec(*result) def logical_to_mesh(tree: Any, rules: LogicalRules | None = None) -> Any: """Applies logical_to_mesh_axes to pytrees of logical PartitionSpecs.""" return jax.tree_util.tree_map( lambda x: logical_to_mesh_axes(x, rules), tree, is_leaf=lambda x: isinstance(x, jax.sharding.PartitionSpec), ) def logical_to_mesh_sharding( tree: Any, mesh: jax.sharding.Mesh, rules: LogicalRules | None = None, ) -> Any: """Convert pytrees of logical PartitionSpecs to shardings.""" return jax.tree_util.tree_map( lambda x: jax.sharding.NamedSharding(mesh, x), logical_to_mesh(tree, rules), is_leaf=lambda x: isinstance(x, jax.sharding.PartitionSpec), ) class RulesFallback(enum.Enum): """How a sharding constraint should behave when no matching rule is found.""" AXIS_IS_UNSHARDED = 'axis_is_unsharded' RAISE_ERROR = 'raise_error' NO_CONSTRAINT = 'no_constraint' def _with_sharding_constraint( x: Array, axis_resources: jax.sharding.PartitionSpec | None, mesh: jax.sharding.Mesh | None = None, ): """Wrapper for lax.with_sharding_constraint, no-op on cpu or outside jit.""" if not meta.global_mesh_defined() and mesh is None: return x else: if mesh is not None and axis_resources is not None: sharding = jax.sharding.NamedSharding(mesh, axis_resources) return lax.with_sharding_constraint(x, sharding) return lax.with_sharding_constraint(x, axis_resources) def _with_sharding_constraint_one_fallback( axis_resources: LogicalPartitionSpec, x: Array, fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED, rules: LogicalRules | None = None, mesh: jax.sharding.Mesh | None = None, ): """Either imposes a sharding constraint or applies fallback.""" mesh_axes = _logical_to_mesh_axes(axis_resources, rules) if mesh_axes is None: return _with_sharding_constraint(x, None, mesh=mesh) if fallback == RulesFallback.AXIS_IS_UNSHARDED: mesh_axes = [None if x is _unassigned_axis else x for x in mesh_axes] else: if any(x is _unassigned_axis for x in mesh_axes): if fallback == RulesFallback.RAISE_ERROR: raise ValueError(f'Axis names {axis_resources} did not match a rule') else: return x return _with_sharding_constraint( x, jax.sharding.PartitionSpec(*mesh_axes), mesh=mesh ) def _is_axis_spec(x): return ( isinstance(x, str) or x is jax.sharding.PartitionSpec.UNCONSTRAINED or x is None ) def _is_logical_spec(x): return x is None or ( isinstance(x, tuple) and all(_is_axis_spec(e) for e in x) ) def with_logical_constraint( x: ArrayPytree, logical_axis_resources: LogicalPartitionSpecPytree, rules: LogicalRules | None = None, mesh: jax.sharding.Mesh | None = None, fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED, ): """Version of jit's with_sharding_constraint that uses logical axis names.""" # If no axis binding is set, this is a no-op. if rules is None: rules = get_logical_axis_rules() if not rules or logical_axis_resources is None: return x # Translate logical names to mesh assignments. return jax.tree_util.tree_map( functools.partial( _with_sharding_constraint_one_fallback, fallback=fallback, rules=rules, mesh=mesh, ), logical_axis_resources, x, is_leaf=_is_logical_spec, ) # Logical Partitioning Axis Metadata # ------------------------------------------------------------------------------ class LogicallyPartitioned(meta.Partitioned): rules: LogicalRules | None = struct.field(default=None, pytree_node=False) # Directly comparing members of the type `jax.Array` can throw an error: # "The truth value of an array with more than one element is ambiguous." # So we bring back an explicit implementation of __eq__ like it was prior to # Python 3.13 in order work around this possibility. def __eq__(self, other): if self is other: return True if other.__class__ is self.__class__: return (self.value,) == (other.value,) return NotImplemented def unbox(self, apply_constraint=True) -> Any: """Returns the wrapped value with the partitioning constraint applied.""" if apply_constraint and (meta.global_mesh_defined() or self.mesh is not None): return with_logical_constraint( self.value, self.get_partition_spec(), rules=self.rules, mesh=self.mesh, ) else: return self.value def to_nnx_metadata(self) -> dict[str, Any]: """Return a dict of metadata that can translate into an `nnx.Variable`.""" metadata = vars(self) if 'names' in metadata: metadata['out_sharding'] = metadata.pop('names') if 'rules' in metadata: metadata['sharding_rules'] = metadata.pop('rules') return metadata @classmethod def from_nnx_metadata(cls, metadata: dict[str, Any]): """Given a dict of `nnx.Variable` format metadata, create a `nn.LogicallyPartitioned`.""" metadata['names'] = metadata.pop('out_sharding') metadata['rules'] = metadata.pop('sharding_rules') fields = {x.name for x in dataclasses.fields(cls)} return cls(**{k: v for k, v in metadata.items() if k in fields}) def with_logical_partitioning( fn: Callable[..., Any], names: LogicalNames, mesh: jax.sharding.Mesh | None = None, rules: LogicalRules | None = None, ) -> Callable[..., LogicallyPartitioned]: """Wraps a function's return value with LogicallyPartitioned. Example:: >>> import flax.linen as nn >>> kernel_init = nn.with_logical_partitioning( ... nn.initializers.lecun_normal(), (None, "data")) >>> partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init) Args: fn: The function to be wrapped. Typically this is an initializer. names: The logical axis passed to ``LogicallyPartitioned``. mesh: The mesh to use for the partitioning. If None, the global mesh resource is used if available. rules: Optional logical to mesh rules use. If None, the global rules are used if available. Returns: A function wrapping ``fn`` that will return an instance of ``LogicallyPartitioned``. """ @functools.wraps(fn) def wrapper(*args, **kwargs): return LogicallyPartitioned( fn(*args, **kwargs), names, mesh=mesh, rules=rules ) # pytype: disable=wrong-keyword-args return wrapper ================================================ FILE: flax/linen/stochastic.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Stochastic modules.""" from collections.abc import Sequence import jax.numpy as jnp from jax import lax, random from flax.linen.module import Module, compact, merge_param from flax.typing import PRNGKey class Dropout(Module): """Create a dropout layer. .. note:: When using :meth:`Module.apply() `, make sure to include an RNG seed named ``'dropout'``. Dropout isn't necessary for variable initialization. Example usage:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> class MLP(nn.Module): ... @nn.compact ... def __call__(self, x, train): ... x = nn.Dense(4)(x) ... x = nn.Dropout(0.5, deterministic=not train)(x) ... return x >>> model = MLP() >>> x = jnp.ones((1, 3)) >>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout >>> model.apply(variables, x, train=False) # don't use dropout Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32) >>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout Array([[-0.35751054, 3.2510893 , 0. , 0. ]], dtype=float32) Attributes: rate: the dropout probability. (_not_ the keep rate!) broadcast_dims: dimensions that will share the same dropout mask deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` and masked, whereas if true, no mask is applied and the inputs are returned as is. rng_collection: the rng collection name to use when requesting an rng key. """ rate: float broadcast_dims: Sequence[int] = () deterministic: bool | None = None rng_collection: str = 'dropout' @compact def __call__( self, inputs, deterministic: bool | None = None, rng: PRNGKey | None = None, ): """Applies a random dropout mask to the input. Args: inputs: the inputs that should be randomly masked. deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` and masked, whereas if true, no mask is applied and the inputs are returned as is. rng: an optional PRNGKey used as the random key, if not specified, one will be generated using ``make_rng`` with the ``rng_collection`` name. Returns: The masked inputs reweighted to preserve mean. """ deterministic = merge_param( 'deterministic', self.deterministic, deterministic ) if (self.rate == 0.0) or deterministic: return inputs # Prevent gradient NaNs in 1.0 edge-case. if self.rate == 1.0: return jnp.zeros_like(inputs) keep_prob = 1.0 - self.rate if rng is None: rng = self.make_rng(self.rng_collection) broadcast_shape = list(inputs.shape) for dim in self.broadcast_dims: broadcast_shape[dim] = 1 mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) mask = jnp.broadcast_to(mask, inputs.shape) return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)) ================================================ FILE: flax/linen/summary.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Flax Module summary library.""" import dataclasses import enum import io from abc import ABC, abstractmethod from types import MappingProxyType from typing import ( Any, ) from collections.abc import Callable, Iterable, Mapping, Sequence import jax import jax.numpy as jnp import numpy as np import rich.console import rich.table import rich.text import yaml import flax.linen.module as module_lib from flax.core import meta, unfreeze from flax.core.scope import ( CollectionFilter, DenyList, LazyRng, ) from flax.typing import ( Array, PRNGKey, RNGSequences, FrozenVariableDict, MutableVariableDict, LogicalNames, ) class _ValueRepresentation(ABC): """A class that represents a value in the summary table.""" @abstractmethod def render(self) -> str: ... @dataclasses.dataclass class _ArrayRepresentation(_ValueRepresentation): shape: tuple[int, ...] dtype: Any @classmethod def from_array(cls, x: Array) -> '_ArrayRepresentation': return cls(jnp.shape(x), jnp.result_type(x)) @classmethod def render_array(cls, x) -> str: return cls.from_array(x).render() def render(self): shape_repr = ','.join(str(x) for x in self.shape) return f'[dim]{self.dtype}[/dim][{shape_repr}]' @dataclasses.dataclass class _PartitionedArrayRepresentation(_ValueRepresentation): array_representation: _ArrayRepresentation names: LogicalNames @classmethod def from_partitioned( cls, partitioned: meta.Partitioned ) -> '_PartitionedArrayRepresentation': return cls( _ArrayRepresentation.from_array(partitioned.value), partitioned.names ) def render(self): return self.array_representation.render() + f' [dim]P[/dim]{self.names}' @dataclasses.dataclass class _ObjectRepresentation(_ValueRepresentation): obj: Any def render(self): return repr(self.obj) @dataclasses.dataclass class Row: """Contains the information about a single row in the summary table. Attributes: path: A tuple of strings that represents the path to the module. module_copy: A copy of the module being summarized. method: method of the module called. inputs: inputs to the module. outputs: Output of the Module as reported by `capture_intermediates`. module_variables: Dictionary of variables in the module (no submodules included). counted_variables: Dictionary of variables that should be counted for this row, if no summarization is done (e.g. `depth=None` in `module_summary`) then this field is the same as `module_variables`, however if a summarization is done then this dictionary potentially contains parameters from submodules depending on the depth of the Module in question. flops: FLOPs cost of calling the module method. vjp_flops: FLOPs cost of calling the VJP of the module method. """ path: tuple[str, ...] module_copy: module_lib.Module method: str inputs: Any outputs: Any module_variables: dict[str, dict[str, Any]] counted_variables: dict[str, dict[str, Any]] flops: int vjp_flops: int def __post_init__(self): self.inputs = self.inputs self.outputs = self.outputs self.module_variables = self.module_variables self.counted_variables = self.counted_variables def size_and_bytes( self, collections: Iterable[str] ) -> dict[str, tuple[int, int]]: return { col: ( _size_and_bytes(self.counted_variables[col]) if col in self.counted_variables else (0, 0) ) for col in collections } class Table(list[Row]): """A list of Row objects. Table inherits from `List[Row]` so it has all the methods of a list, however it also contains some additional fields: * `module`: the module that this table is summarizing * `collections`: a list containing the parameter collections (e.g. 'params', 'batch_stats', etc) """ def __init__( self, module: module_lib.Module, collections: Sequence[str], rows: Iterable[Row], ): super().__init__(rows) self.module = module self.collections = collections def tabulate( module: module_lib.Module, rngs: PRNGKey | RNGSequences, depth: int | None = None, show_repeated: bool = False, mutable: CollectionFilter = DenyList('intermediates'), console_kwargs: Mapping[str, Any] | None = None, table_kwargs: Mapping[str, Any] = MappingProxyType({}), column_kwargs: Mapping[str, Any] = MappingProxyType({}), compute_flops: bool = False, compute_vjp_flops: bool = False, **kwargs, ) -> Callable[..., str]: """Returns a function that creates a summary of the Module represented as a table. This function accepts most of the same arguments and internally calls `Module.init`, except that it returns a function of the form `(*args, **kwargs) -> str` where `*args` and `**kwargs` are passed to `method` (e.g. `__call__`) during the forward pass. `tabulate` uses `jax.eval_shape` under the hood to run the forward computation without consuming any FLOPs or allocating memory. Additional arguments can be passed into the `console_kwargs` argument, for example, `{'width': 120}`. For a full list of `console_kwargs` arguments, see: https://rich.readthedocs.io/en/stable/reference/console.html#rich.console.Console Example:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... h = nn.Dense(4)(x) ... return nn.Dense(2)(h) >>> x = jnp.ones((16, 9)) >>> tabulate_fn = nn.tabulate( ... Foo(), jax.random.key(0), compute_flops=True, compute_vjp_flops=True) >>> # print(tabulate_fn(x)) This gives the following output:: Foo Summary ┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ path ┃ module ┃ inputs ┃ outputs ┃ flops ┃ vjp_flops ┃ params ┃ ┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ │ Foo │ float32[16,9] │ float32[16,2] │ 1504 │ 4460 │ │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ Dense_0 │ Dense │ float32[16,9] │ float32[16,4] │ 1216 │ 3620 │ bias: │ │ │ │ │ │ │ │ float32[4] │ │ │ │ │ │ │ │ kernel: │ │ │ │ │ │ │ │ float32[9,4] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 40 (160 B) │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ Dense_1 │ Dense │ float32[16,4] │ float32[16,2] │ 288 │ 840 │ bias: │ │ │ │ │ │ │ │ float32[2] │ │ │ │ │ │ │ │ kernel: │ │ │ │ │ │ │ │ float32[4,2] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 10 (40 B) │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ │ │ │ │ │ Total │ 50 (200 B) │ └─────────┴────────┴───────────────┴───────────────┴───────┴───────────┴─────────────────┘ Total Parameters: 50 (200 B) **Note**: rows order in the table does not represent execution order, instead it aligns with the order of keys in `variables` which are sorted alphabetically. **Note**: `vjp_flops` returns `0` if the module is not differentiable. Args: module: The module to tabulate. rngs: The rngs for the variable collections as passed to `Module.init`. depth: controls how many submodule deep the summary can go. By default its `None` which means no limit. If a submodule is not shown because of the depth limit, its parameter count and bytes will be added to the row of its first shown ancestor such that the sum of all rows always adds up to the total number of parameters of the Module. show_repeated: If `True`, repeated calls to the same module will be shown in the table, otherwise only the first call will be shown. Default is `False`. mutable: Can be bool, str, or list. Specifies which collections should be treated as mutable: ``bool``: all/no collections are mutable. ``str``: The name of a single mutable collection. ``list``: A list of names of mutable collections. By default all collections except 'intermediates' are mutable. console_kwargs: An optional dictionary with additional keyword arguments that are passed to `rich.console.Console` when rendering the table. Default arguments are `{'force_terminal': True, 'force_jupyter': False}`. table_kwargs: An optional dictionary with additional keyword arguments that are passed to `rich.table.Table` constructor. column_kwargs: An optional dictionary with additional keyword arguments that are passed to `rich.table.Table.add_column` when adding columns to the table. compute_flops: whether to include a `flops` column in the table listing the estimated FLOPs cost of each module forward pass. Does incur actual on-device computation / compilation / memory allocation, but still introduces overhead for large modules (e.g. extra 20 seconds for a Stable Diffusion's UNet, whereas otherwise tabulation would finish in 5 seconds). compute_vjp_flops: whether to include a `vjp_flops` column in the table listing the estimated FLOPs cost of each module backward pass. Introduces a compute overhead of about 2-3X of `compute_flops`. **kwargs: Additional arguments passed to `Module.init`. Returns: A function that accepts the same `*args` and `**kwargs` of the forward pass (`method`) and returns a string with a tabular representation of the Modules. """ # add non-default arguments to kwargs, this prevents some issue we overloading init # see: https://github.com/google/flax/issues/3299 if mutable != DenyList('intermediates'): kwargs['mutable'] = mutable def _tabulate_fn(*fn_args, **fn_kwargs): table_fn = _get_module_table( module, depth=depth, show_repeated=show_repeated, compute_flops=compute_flops, compute_vjp_flops=compute_vjp_flops, ) table = table_fn(rngs, *fn_args, **fn_kwargs, **kwargs) non_param_cols = [ 'path', 'module', 'inputs', 'outputs', ] if compute_flops: non_param_cols.append('flops') if compute_vjp_flops: non_param_cols.append('vjp_flops') return _render_table( table, console_kwargs, table_kwargs, column_kwargs, non_param_cols ) return _tabulate_fn def _get_flops(fn, *args, **kwargs): e = jax.jit(fn).lower(*args, **kwargs) cost = e.cost_analysis() if cost is None: return 0 flops = int(cost['flops']) if 'flops' in cost else 0 return flops def _get_call_flops( c: module_lib._CallInfo, compute_flops: bool, compute_vjp_flops: bool, ) -> tuple[int, int]: """Return the FLOPs of executing the call `c` in the call stack. Does not perform actual computation / compilation / memory allocation, but still introduces overhead for large modules. Args: c: ``_CallInfo``. compute_flops: whether to compute forward pass FLOPs. Return `-1` otherwise. compute_vjp_flops: whether to compute backward pass FLOPs. Return `-1` otherwise. Returns: FLOPs of executing forward pass of `c`, and its VJP. """ if not compute_flops and not compute_vjp_flops: return -1, -1 rngs = jax.tree_util.tree_map( lambda x: x.rng, c.rngs, is_leaf=lambda x: isinstance(x, LazyRng) ) args = jax.tree_util.tree_map(_from_value_representation, c.args) kwargs = jax.tree_util.tree_map(_from_value_representation, c.kwargs) leaves, treedef = jax.tree_util.tree_flatten((args, kwargs)) dynamic_leaves = [] dynamic_idxs = [] for i, arg in enumerate(leaves): if isinstance(arg, jax.ShapeDtypeStruct): dynamic_leaves.append(arg) dynamic_idxs.append(i) def _get_inputs(dynamic_leaves): new_leaves: list[Any] = leaves.copy() for i, arg in zip(dynamic_idxs, dynamic_leaves): new_leaves[i] = arg return treedef.unflatten(new_leaves) def init(rngs, dynamic_leaves): """`c.module.init` closed over static keyword arguments.""" args, kwargs = _get_inputs(dynamic_leaves) return c.module.init( rngs, *args, method=c.method, mutable=c.mutable, **kwargs, ) variables = jax.eval_shape(init, rngs, dynamic_leaves) def apply(variables, rngs, dynamic_leaves): """`c.module.apply` closed over static keyword arguments.""" args, kwargs = _get_inputs(dynamic_leaves) return c.module.apply( variables, *args, rngs=rngs, method=c.method, mutable=c.mutable, **kwargs, ) # Forward pass FLOPs if compute_flops: flops = _get_flops(apply, variables, rngs, dynamic_leaves) else: flops = -1 if compute_vjp_flops: # Backward pass FLOPs def apply_vjp(variables, rngs, dynamic_leaves): """VJP of `c.module.apply` closed over static keyword arguments.""" out, vjp_fn = jax.vjp(apply, variables, rngs, dynamic_leaves) return vjp_fn(out) vjp_flops = _get_flops(apply_vjp, variables, rngs, dynamic_leaves) else: vjp_flops = -1 return flops, vjp_flops def _get_module_table( module: module_lib.Module, depth: int | None, show_repeated: bool, compute_flops: bool, compute_vjp_flops: bool, ) -> Callable[..., Table]: """A function that takes a Module and returns function with the same signature as `init` but returns the Table representation of the Module.""" def _get_table_fn(*args, **kwargs): with module_lib._tabulate_context(): def _get_variables(): return module.init(*args, **kwargs) # TODO(cgarciae): is it possible to avoid leaking tracers for summaries? with jax.check_tracer_leaks(False): variables = jax.eval_shape(_get_variables) calls = module_lib._context.call_info_stack[-1].calls calls.sort(key=lambda c: c.index) collections: set[str] = set(variables.keys()) rows = [] all_paths: set[tuple[str, ...]] = {call.path for call in calls} visited_paths: set[tuple[str, ...]] = set() for c in calls: call_depth = len(c.path) inputs = _process_inputs(c.args, c.kwargs) if c.path in visited_paths or not hasattr(c.module, c.method): if not show_repeated: continue module_vars = {} counted_vars = {} elif depth is not None: if call_depth > depth: continue module_vars, _ = _get_module_variables(c.path, variables, all_paths) if call_depth == depth: counted_vars = _get_path_variables(c.path, variables) else: counted_vars = module_vars else: module_vars, _ = _get_module_variables(c.path, variables, all_paths) counted_vars = module_vars visited_paths.add(c.path) rows.append( Row( c.path, c.module.copy(parent=None), c.method, inputs, c.outputs, module_vars, counted_vars, *_get_call_flops(c, compute_flops, compute_vjp_flops), ) ) return Table(module, tuple(collections), rows) return _get_table_fn def _get_module_variables( path: tuple[str, ...], variables: FrozenVariableDict, all_paths: set[tuple[str, ...]], ) -> tuple[MutableVariableDict, Any]: """A function that takes a path and variables structure and returns a (module_variables, submodule_variables) tuple for that path. _get_module_variables uses the `all_paths` set to determine if a variable belongs to a submodule or not. """ module_variables = _get_path_variables(path, variables) submodule_variables: Any = {collection: {} for collection in module_variables} all_keys = { key for collection in module_variables.values() for key in collection } for key in all_keys: submodule_path = path + (key,) if submodule_path in all_paths: for collection in module_variables: if key in module_variables[collection]: submodule_variables[collection][key] = module_variables[ collection ].pop(key) return module_variables, submodule_variables def _get_path_variables( path: tuple[str, ...], variables: FrozenVariableDict ) -> MutableVariableDict: """A function that takes a path and a variables structure and returns the variable structure at that path. """ path_variables = {} for collection in variables: collection_variables = variables[collection] for name in path: if name not in collection_variables: collection_variables = None break collection_variables = collection_variables[name] if collection_variables is not None: path_variables[collection] = unfreeze(collection_variables) return path_variables def _process_inputs(args, kwargs) -> Any: """A function that normalizes the representation of the ``args`` and ``kwargs`` for the ``inputs`` column. """ if args and kwargs: input_values = (*args, kwargs) elif args and not kwargs: input_values = args[0] if len(args) == 1 else args elif kwargs and not args: input_values = kwargs else: input_values = () return input_values def _render_table( table: Table, console_extras: Mapping[str, Any] | None, table_kwargs: Mapping[str, Any], column_kwargs: Mapping[str, Any], non_params_cols: list[str], ) -> str: """A function that renders a Table to a string representation using rich.""" console_kwargs = {'force_terminal': True, 'force_jupyter': False} if console_extras is not None: console_kwargs.update(console_extras) rich_table = rich.table.Table( show_header=True, show_lines=True, show_footer=True, title=f'{table.module.__class__.__name__} Summary', **table_kwargs, ) for c in non_params_cols: rich_table.add_column(c, **column_kwargs) for col in table.collections: rich_table.add_column(col, **column_kwargs) for row in table: collections_size_repr = [] for collection, size_bytes in row.size_and_bytes(table.collections).items(): col_repr = '' if collection in row.module_variables: module_variables = _represent_tree(row.module_variables[collection]) module_variables = _normalize_structure(module_variables) col_repr += _as_yaml_str( _summary_tree_map(_maybe_render, module_variables) ) if col_repr: col_repr += '\n\n' col_repr += f'[bold]{_size_and_bytes_repr(*size_bytes)}[/bold]' collections_size_repr.append(col_repr) no_show_methods = {'__call__', ''} path_repr = '/'.join(row.path) method_repr = ( f' [dim]({row.method})[/dim]' if row.method not in no_show_methods else '' ) rich_table.add_row( path_repr, type(row.module_copy).__name__ + method_repr, *( _as_yaml_str( _summary_tree_map( _maybe_render, _normalize_structure(getattr(row, c)) ) ) for c in non_params_cols[2:] ), *collections_size_repr, ) # add footer with totals n_non_params_cols = len(non_params_cols) rich_table.columns[n_non_params_cols - 1].footer = rich.text.Text.from_markup( 'Total', justify='right' ) # get collection totals collection_total = {col: (0, 0) for col in table.collections} for row in table: for col, size_bytes in row.size_and_bytes(table.collections).items(): collection_total[col] = ( collection_total[col][0] + size_bytes[0], collection_total[col][1] + size_bytes[1], ) # add totals to footer for i, col in enumerate(table.collections): rich_table.columns[n_non_params_cols + i].footer = _size_and_bytes_repr( *collection_total[col] ) # add final totals to caption caption_totals = (0, 0) for size, num_bytes in collection_total.values(): caption_totals = ( caption_totals[0] + size, caption_totals[1] + num_bytes, ) rich_table.caption_style = 'bold' rich_table.caption = ( f'\nTotal Parameters: {_size_and_bytes_repr(*caption_totals)}' ) return '\n' + _get_rich_repr(rich_table, console_kwargs) + '\n' def _summary_tree_map(f, tree, *rest): return jax.tree_util.tree_map(f, tree, *rest, is_leaf=lambda x: x is None) def _size_and_bytes_repr(size: int, num_bytes: int) -> str: if not size: return '' bytes_repr = _bytes_repr(num_bytes) return f'{size:,} [dim]({bytes_repr})[/dim]' def _size_and_bytes(pytree: Any) -> tuple[int, int]: leaves = jax.tree_util.tree_leaves(pytree) size = sum(x.size for x in leaves if hasattr(x, 'size')) num_bytes = sum( x.size * x.dtype.itemsize for x in leaves if hasattr(x, 'size') ) return size, num_bytes def _get_rich_repr(obj, console_kwargs): f = io.StringIO() console = rich.console.Console(file=f, **console_kwargs) console.print(obj) return f.getvalue() def _as_yaml_str(value) -> str: if (hasattr(value, '__len__') and len(value) == 0) or value is None: return '' file = io.StringIO() yaml.safe_dump( value, file, default_flow_style=False, indent=2, sort_keys=False, explicit_end=False, ) return file.getvalue().replace('\n...', '').replace("'", '').strip() def _normalize_structure(obj): if isinstance(obj, _ValueRepresentation): return obj if isinstance(obj, (tuple, list)): return tuple(map(_normalize_structure, obj)) elif isinstance(obj, Mapping): return { _normalize_structure(k): _normalize_structure(v) for k, v in obj.items() } elif dataclasses.is_dataclass(obj): return { f.name: _normalize_structure(getattr(obj, f.name)) for f in dataclasses.fields(obj) } elif isinstance(obj, enum.Enum): # `yaml.safe_dump` does not support Enum key types so extract the underlying value return obj.value else: return obj def _bytes_repr(num_bytes): count, units = ( (f'{num_bytes / 1e9 :,.1f}', 'GB') if num_bytes > 1e9 else (f'{num_bytes / 1e6 :,.1f}', 'MB') if num_bytes > 1e6 else (f'{num_bytes / 1e3 :,.1f}', 'KB') if num_bytes > 1e3 else (f'{num_bytes:,}', 'B') ) return f'{count} {units}' def _get_value_representation(x: Any) -> _ValueRepresentation: if isinstance(x, (int, float, bool, type(None))) or ( isinstance(x, np.ndarray) and np.isscalar(x) ): return _ObjectRepresentation(x) elif isinstance(x, meta.Partitioned): return _PartitionedArrayRepresentation.from_partitioned(x) try: return _ArrayRepresentation.from_array(x) except: return _ObjectRepresentation(x) def _from_value_representation(x: _ValueRepresentation) -> Any: if isinstance(x, _ArrayRepresentation): return jax.ShapeDtypeStruct(x.shape, x.dtype) elif isinstance(x, _PartitionedArrayRepresentation): return jax.ShapeDtypeStruct( x.array_representation.shape, x.array_representation.dtype ) elif isinstance(x, _ObjectRepresentation): return x.obj raise TypeError(x, type(x)) def _represent_tree(x): """Returns a tree with the same structure as `x` but with each leaf replaced by a `_ValueRepresentation` object.""" return jax.tree_util.tree_map( _get_value_representation, x, is_leaf=lambda x: x is None or isinstance(x, meta.Partitioned), ) def _maybe_render(x): return x.render() if hasattr(x, 'render') else repr(x) ================================================ FILE: flax/linen/transforms.py ================================================ # Copyright 2023 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """JAX transformations on Modules. Jax functional transformations operate on pure functions. Flax extends these transformations to also operate on Module's which have stateful variables and PRNG sequences. We refer to these extended versions as "lifted transformations". A lifted transformation can be applied to a ``Module`` class or a function that takes a ``Module`` instance as its first argument. """ from collections.abc import Callable, Iterable, Mapping, Sequence import contextlib import dataclasses import functools import inspect from typing import ( Any, TypeVar, Union, ) import weakref from flax import core from flax import errors, struct, traceback_util from flax import serialization from flax.core import Scope, lift, meta from flax.core.frozen_dict import FrozenDict from flax.core.scope import ( CollectionFilter, LazyRng, PRNGSequenceFilter, ) from flax.ids import FlaxId from flax.linen import module as linen_module from flax.linen.module import ( Module, Variable, _derive_profiling_name, _get_unbound_fn, wrap_method_once, ) from flax.typing import ( InOutAxis, InOutScanAxis, ) import jax traceback_util.register_exclusion(__file__) # pylint: disable=protected-access,dangerous-default-value # Utils # ----------------------------------------------------------------------------- def clean_clone(x): """Remove scopes and tracers from children.""" if isinstance(x, Module): object.__setattr__( x, 'children', {k: clean_clone(v) for k, v in x.children.items()} ) object.__setattr__(x, 'scope', None) return x @struct.dataclass class VariablePlaceholder: """Used to mark Variables in a JAX-compatible way when lifting arguments.""" collection: str = struct.field(pytree_node=False) name: str = struct.field(pytree_node=False) unbox: bool = struct.field(pytree_node=False) id: int = struct.field(pytree_node=False) @struct.dataclass class InstancePlaceholder: """Marks module instances in a JAX-compatible way when lifting arguments.""" cls: type[Any] = struct.field(pytree_node=False) attrs: dict[Any, Any] = struct.field(pytree_node=False) id: int = struct.field(pytree_node=False) def _memoize_by_id(fn, refs): """Memoization by module/variable id to handle aliasing in traversal.""" @functools.wraps(fn) def wrapped_fn(x): nonlocal refs if isinstance(x, (VariablePlaceholder, InstancePlaceholder)): x_id = x.id elif isinstance(x, (Variable, Module)): x_id = x._id else: return fn(x) if x_id not in refs: refs[x_id] = fn(x) return refs[x_id] return wrapped_fn def get_module_scopes(module, args=None, kwargs=None): """Get all scopes on module, including constructor Module arguments. To properly functionalize a Module that has other bound Modules passed in "from the outside" as dataclass attributes, we need to traverse all dataclass fields to find the Scopes associated with the Module. Additionally, because we allow Modules to be passed inside pytrees on the dataclass attributes, we must traverse all dataclass attributes as pytrees to find all Modules. We additionally handle lifting Variables (which are just references to data in particular scopes) and Module instances that are passed as arguments to methods. Args: module: a bound flax Module. args: an *args list possibly containing Variables or Module instances referencing a scope. kwargs: a **kwargs dict possibly containing Variables or Module instances referencing a scope. Returns: A list of all functional-core Scopes bound on self and inside dataclass fields as well as any Scopes passed via argument Variables, an updated args list, and an updated kwargs dict that have both had Variables replaced with VariablePlaceholders and Module instances replaced with InstancePlaceholders that are compatible with jax functions. """ scopes: list[Scope] = [] refs = {} # Gather scopes associated with Variables and Module instances passed as # positional and keyword arguments. @functools.partial(_memoize_by_id, refs=refs) def get_arg_scope(x): nonlocal scopes if isinstance(x, Variable) and isinstance(x.scope, Scope): scopes.append(x.scope) return VariablePlaceholder(x.collection, x.name, x.unbox, x._id) elif isinstance(x, Module) and isinstance(x.scope, Scope): x._try_setup(shallow=True) scopes.append(x.scope) attrs = { f.name: getattr(x, f.name) for f in dataclasses.fields(x) if f.name != 'parent' and f.init } attrs = jax.tree_util.tree_map(get_arg_scope, attrs) return InstancePlaceholder(x.__class__, attrs, x._id) return x new_args, new_kwargs = jax.tree_util.tree_map(get_arg_scope, (args, kwargs)) # Gather scopes in Variables and Submodules passed as Module attributes. @functools.partial(_memoize_by_id, refs=refs) def get_scopes(module): nonlocal scopes module._try_setup(shallow=True) def get_scopes_inner(x): nonlocal scopes if isinstance(x, Module) and isinstance(x.scope, Scope): get_scopes(x) elif isinstance(x, Variable) and isinstance(x.scope, Scope): scopes.append(x.scope) attrs = { f.name: getattr(module, f.name) for f in dataclasses.fields(module) if f.name != 'parent' and f.init } for leaf in jax.tree_util.tree_leaves(attrs): get_scopes_inner(leaf) scopes.append(module.scope) get_scopes(module) return scopes, new_args, new_kwargs def set_module_scopes(module, args, kwargs, scopes): """Set all scopes on module, including those on Modules in dataclass fields. To properly functionalize a Module we must also "rehydrate" it with Scopes from `get_module_scopes`. We need to set scopes not just on the Module but also on any Module living inside dataclass attributes or even pytrees in its dataclass attributes. We additionally handle restoring Variables and Module instances from their placeholders in the method positional and keyword arguments. The order of traversal through this method is the same as in `get_module_scopes`, guaranteeing the correct Scopes are applied to each Module. Args: module: a flax Module. args: an *args list possibly containing VariablePlaceholder or InstancePlaceholder members. kwargs: a **kwargs dict possibly containing VariablePlaceholder or InstancePlaceholder members. scopes: a list of Scopes corresponding to this Module and its arguments that was created by the `get_module_scopes` function. Returns: A copy of the module with it and its attributes bound to the scopes passed to this function, an updated args list, and an updated kwargs dict with updated Variable and Module instance references. """ idx = 0 refs = {} # Set scopes associated with Variables and Module instances passed as # positional and keyword arguments. @functools.partial(_memoize_by_id, refs=refs) def set_arg_scope(x): nonlocal idx if isinstance(x, VariablePlaceholder): new_x = Variable( scope=scopes[idx], collection=x.collection, name=x.name, unbox=x.unbox ) idx += 1 return new_x elif isinstance(x, InstancePlaceholder): instance_scope = scopes[idx] idx += 1 instance_attrs = jax.tree_util.tree_map(set_arg_scope, x.attrs) return x.cls(parent=instance_scope, **instance_attrs) return x def is_placeholder(x): return isinstance(x, (VariablePlaceholder, InstancePlaceholder)) new_args, new_kwargs = jax.tree_util.tree_map( set_arg_scope, (args, kwargs), is_leaf=is_placeholder ) # set scopes in Variables and Submodules passed as Module attributes @functools.partial(_memoize_by_id, refs=refs) def set_scopes(module): nonlocal idx def set_scopes_inner(x): nonlocal idx if isinstance(x, Module) and isinstance(x.scope, Scope): return set_scopes(x) elif isinstance(x, Variable) and isinstance(x.scope, Scope): new_x = Variable( scope=scopes[idx], collection=x.collection, name=x.name, unbox=x.unbox, ) idx += 1 return new_x else: return x attrs = { f.name: getattr(module, f.name) for f in dataclasses.fields(module) if f.name != 'parent' and f.init } new_attrs = jax.tree_util.tree_map(set_scopes_inner, attrs) new_module = module.clone(parent=scopes[idx], **new_attrs) idx += 1 return new_module new_module = set_scopes(module) assert len(scopes) == idx, f'scope list mismatch {len(scopes)} != {idx}' return new_module, new_args, new_kwargs def _test_transformed_return_values(tree, method_name): """Tests whether the return value contains any Modules or Variables.""" impure = any( map( lambda x: isinstance(x, (Module, Variable)), jax.tree_util.tree_leaves(tree), ) ) if impure: raise errors.TransformedMethodReturnValueError(method_name) # Class lifting # ----------------------------------------------------------------------------- def module_class_lift_transform( transform, module_class, *trafo_args, methods=None, **trafo_kwargs ): """Module class lift transform.""" # TODO(marcvanzee): Improve docstrings (#1977). # TODO(levskaya): find nicer argument convention for multi-method case? # Prepare per-method transform args, kwargs. if methods is None: # Default case, just transform __call__ class_trafo_args = {'__call__': (trafo_args, trafo_kwargs)} elif isinstance(methods, (list, tuple)): # Transform every method in methods with given args, kwargs. class_trafo_args = {m: (trafo_args, trafo_kwargs) for m in methods} elif isinstance(methods, dict): # Pass different trafo args per each method. class_trafo_args = {k: ((), v) for k, v in methods.items()} else: raise ValueError( 'transform methods argument must be None, tuple, list, or dict.' ) # Handle partially initialized module class constructors. if isinstance(module_class, functools.partial) and issubclass( module_class.func, Module ): partial_object = module_class module_class = module_class.func else: partial_object = None def create_trans_fn(fn_name, fn_trafo_args): # get existing unbound method from class fn = getattr(module_class, fn_name) trafo_args, trafo_kwargs = fn_trafo_args # we need to create a scope-function from our class for the given method @functools.wraps(fn) def wrapped_fn(self, *args, **kwargs): state = self._state.export() # make a scope-function to transform def core_fn(scopes, *args, **kwargs): # make a clone of self using its arguments attrs = { f.name: getattr(self, f.name) for f in dataclasses.fields(self) if f.name != 'parent' and f.init } # we reference module_class, not self.__class__ to avoid infinite loop cloned = module_class(parent=None, **attrs) cloned, args, kwargs = set_module_scopes(cloned, args, kwargs, scopes) object.__setattr__(cloned, '_state', state.export()) res = fn(cloned, *args, **kwargs) self._state.reimport(cloned._state) _test_transformed_return_values(res, fn_name) return res # here we apply the given lifting transform to the scope-ingesting fn trafo_fn = transform(core_fn, *trafo_args, **trafo_kwargs) module_scopes, args, kwargs = get_module_scopes(self, args, kwargs) ret = trafo_fn(module_scopes, *args, **kwargs) return ret return wrapped_fn transformed_fns = { fn_name: create_trans_fn(fn_name, fn_trafo_args) for fn_name, fn_trafo_args in class_trafo_args.items() } # construct new dynamic class w. transformed methods transformed_cls = type( transform.__name__.capitalize() + module_class.__name__, (module_class,), transformed_fns, ) # Handle partially initialized module class constructors. if partial_object is not None: transformed_cls = functools.partial( transformed_cls, *partial_object.args, **partial_object.keywords ) return transformed_cls # Function lifting as decorator on methods __inside__ class definition. # ----------------------------------------------------------------------------- def decorator_lift_transform( transform, class_fn, *trafo_args, multi_scope=True, **trafo_kwargs ): """Decorator for lifted transform.""" # TODO(marcvanzee): Improve docstrings (#1977). # Due to the ordering of method decorators, we must wrap the class_fn # with the module state management wrapper first to maintain Module state # correctly. if isinstance(class_fn, tuple): class_fns = class_fn else: class_fns = (class_fn,) prewrapped_fns = [wrap_method_once(class_fn) for class_fn in class_fns] @functools.wraps(prewrapped_fns[0]) def wrapped_fn(self: Module, *args, **kwargs): state = self._state.export() # make a scope-function to transform def core_fn(prewrapped_fn, class_fn, scopes, *args, **kwargs): if not multi_scope: scopes = [scopes] cloned, args, kwargs = set_module_scopes(self, args, kwargs, scopes) object.__setattr__(cloned, '_state', state.export()) res = prewrapped_fn(cloned, *args, **kwargs) self._state.reimport(cloned._state) _test_transformed_return_values(res, getattr(class_fn, '__name__', None)) return res core_fns = [ functools.partial(core_fn, prewrapped_fn, class_fn) for prewrapped_fn, class_fn in zip(prewrapped_fns, class_fns) ] # here we apply the given lifting transform to the scope-ingesting fn trafo_fn = transform(*core_fns, *trafo_args, **trafo_kwargs) module_scopes, args, kwargs = get_module_scopes(self, args, kwargs) if not multi_scope: if len(module_scopes) != 1: # TODO(levskaya): transforms like jvp & vjp have args that follow the # pytree structure of scopes. The user doesn't explicitly control shared # modules passed as arguments to methods or as attributes to Module # constructors. Therefore, there is no obvious API for specifying # arguments per lifted Module. raise NotImplementedError( 'This transform does not yet support' ' Modules that include other Modules passed as arguments.' ) module_scopes = module_scopes[0] return trafo_fn(module_scopes, *args, **kwargs) return wrapped_fn @dataclasses.dataclass(frozen=True) class _HashableProxy: """A hashable proxy object that is use to define a hash for Modules. The hash produced by _HashableProxy is useful for nn.jit to decide if a function should be retraced or not """ module_ref: weakref.ref hash_key: int @classmethod def from_module(cls, module: Module) -> '_HashableProxy': fingerprint = _module_fingerprint(module) hash_key = hash(fingerprint) return cls(weakref.ref(module), hash_key) def __hash__(self): return self.hash_key def __eq__(self, other): return isinstance(other, _HashableProxy) and self.hash_key == other.hash_key @property def module(self): return self.module_ref() def _module_fingerprint(module: Module) -> tuple[type[Any], Any]: return _fingerprint_recursive(module, (), {}) def _fingerprint_recursive( obj: Any, path: tuple[str, ...], seen_modules: dict[FlaxId, int] ) -> Any: """Creates a hashable representation for a Module by traversing its structure recursively.""" def _get_fingerprint(name: str, value: Any) -> tuple[str, Any]: return name, _fingerprint_recursive(value, (*path, name), seen_modules) if isinstance(obj, str): return obj elif hasattr(obj, '__fn_or_cls__'): # support PaxConfig objects return _fingerprint_recursive(obj.__fn_or_cls__, path, seen_modules) elif isinstance(obj, Module): fingerprint: Any if obj._id in seen_modules: # if we have already seen the module we just use the index # as its static component fingerprint = seen_modules[obj._id] return type(obj), fingerprint else: # if its a new module we add it to the cache and give it # a new index seen_modules[obj._id] = len(seen_modules) # TODO(cgarciae): define a way for the user of nn.jit to define # what fields it wants to ignore per Module instance. fingerprints = [] for field in dataclasses.fields(obj): if not hasattr(obj, field.name): continue if field.name not in ('parent', 'name'): value = getattr(obj, field.name) fingerprints.append(_get_fingerprint(field.name, value)) # add state fingerprint state_fingerprint = ( _get_fingerprint('in_compact_method', obj._state.in_compact_method), _get_fingerprint('in_setup', obj._state.in_setup), _get_fingerprint('setup_called', obj._state.setup_called), _get_fingerprint('is_initialized', obj._state.is_initialized), _get_fingerprint('autoname_cursor', obj._state.autoname_cursor), ) fingerprints.append(('_state', state_fingerprint)) # add scope fingerprint scope = obj.scope if scope is not None: static_scope = ( _get_fingerprint('mutable', scope.mutable), _get_fingerprint('flags', scope.flags), _get_fingerprint('rng_counts', scope.rng_counters), _get_fingerprint('reservations', scope.reservations), ) _check_field_is_hashable((*path, 'scope'), static_scope) fingerprints.append(('scope', static_scope)) fingerprint = tuple(fingerprints) return type(obj), fingerprint elif dataclasses.is_dataclass(obj): fingerprints = [] for field in dataclasses.fields(obj): if not hasattr(obj, field.name): continue value = getattr(obj, field.name) value_fingerprint = _get_fingerprint(field.name, value) fingerprints.append((field.name, value_fingerprint)) return type(obj), tuple(fingerprints) elif isinstance(obj, core.DenyList): return type(obj), _get_fingerprint('deny', obj.deny) elif isinstance(obj, dict): fingerprint = tuple((k, _get_fingerprint(k, v)) for k, v in obj.items()) return fingerprint elif serialization.is_serializable(obj): state = serialization.to_state_dict(obj) fingerprint = _fingerprint_recursive(state, path, seen_modules) return type(obj), fingerprint elif isinstance(obj, Mapping): return tuple((k, _get_fingerprint(k, v)) for k, v in obj.items()) elif isinstance(obj, Iterable): return tuple(_get_fingerprint(str(i), v) for i, v in enumerate(obj)) else: _check_field_is_hashable(path, obj) return obj def _check_field_is_hashable(path: tuple[str, ...], x: Any): """Checks if a field is hashable.""" try: hash(x) except Exception as e: path_name = '/'.join(path) raise ValueError(f"Value at '{path_name}' is not hashable: {e}") from e def decorator_lift_transform_cached(transform, class_fn, **trafo_kwargs): """Decorator for lifted transform. Similar to `decorator_lift_transform` but specialized for `jit`, it reuses the previous transform when available to avoid retracing. """ # TODO(marcvanzee): Improve docstrings (#1977). # Due to the ordering of method decorators, we must wrap the class_fn # with the module state management wrapper first to maintain Module state # correctly. multi_scope = True if isinstance(class_fn, tuple): class_fns = class_fn else: class_fns = (class_fn,) prewrapped_fns = [wrap_method_once(class_fn) for class_fn in class_fns] trafo_fn = None @functools.wraps(prewrapped_fns[0]) def wrapped_fn(self: Module, *args, **kwargs): nonlocal trafo_fn state = self._state.export() # increment rng counters for all rngs in scope with fork_rngs(self): # make a scope-function to transform def core_fn( prewrapped_fn, class_fn, scopes, module_hash, *args, **kwargs, ): # self = hash_key.obj self: Module = module_hash.module if not multi_scope: scopes = [scopes] cloned, args, kwargs = set_module_scopes(self, args, kwargs, scopes) object.__setattr__(cloned, '_state', state.export()) res = prewrapped_fn(cloned, *args, **kwargs) self._state.reimport(cloned._state) _test_transformed_return_values( res, getattr(class_fn, '__name__', None) ) return res core_fns = [ functools.wraps(class_fn)( functools.partial(core_fn, prewrapped_fn, class_fn) ) for prewrapped_fn, class_fn in zip(prewrapped_fns, class_fns) ] # here we apply the given lifting transform to the scope-ingesting fn if trafo_fn is None: trafo_fn = transform(*core_fns, **trafo_kwargs) module_scopes, args, kwargs = get_module_scopes(self, args, kwargs) if not multi_scope: if len(module_scopes) != 1: # TODO(levskaya): transforms like jvp & vjp have args that follow the # pytree structure of scopes. The user doesn't explicitly control shared # modules passed as arguments to methods or as attributes to Module # constructors. Therefore, there is no obvious API for specifying # arguments per lifted Module. raise NotImplementedError( 'This transform does not yet support' ' Modules that include other Modules passed as arguments.' ) module_scopes = module_scopes[0] # get a hashable proxy object for the Module hash_key = _HashableProxy.from_module(self) return trafo_fn(module_scopes, hash_key, *args, **kwargs) return wrapped_fn @contextlib.contextmanager def fork_rngs(module: Module): """Context manager to fork rngs in a module.""" if module.scope is None: yield return current_rngs = module.scope.rngs.copy() module.scope.rngs = { name: LazyRng.create(module.make_rng(name)) for name in current_rngs } try: yield finally: module.scope.rngs = current_rngs def module_class_lift_transform_cached( transform, module_class, methods=None, **trafo_kwargs ): """Module class lift transform.""" # TODO(marcvanzee): Improve docstrings (#1977). # TODO(levskaya): find nicer argument convention for multi-method case? trafo_args = () # Prepare per-method transform args, kwargs. if methods is None: # Default case, just transform __call__ class_trafo_args = {'__call__': (trafo_args, trafo_kwargs)} elif isinstance(methods, (list, tuple)): # Transform every method in methods with given args, kwargs. class_trafo_args = {m: (trafo_args, trafo_kwargs) for m in methods} elif isinstance(methods, dict): # Pass different trafo args per each method. class_trafo_args = {k: ((), v) for k, v in methods.items()} else: raise ValueError( 'transform methods argument must be None, tuple, list, or dict.' ) # Handle partially initialized module class constructors. if isinstance(module_class, functools.partial) and issubclass( module_class.func, Module ): partial_object = module_class module_class = module_class.func else: partial_object = None def create_trans_fn(fn_name, fn_trafo_args): # get existing unbound method from class fn = getattr(module_class, fn_name) trafo_args, trafo_kwargs = fn_trafo_args trafo_fn = None # we need to create a scope-function from our class for the given method @functools.wraps(fn) def wrapped_fn(self: Module, *args, **kwargs): assert self.scope is not None nonlocal trafo_fn state = self._state.export() # increment rng counters for all rngs in scope with fork_rngs(self): # make a scope-function to transform def core_fn(scopes, module_hash, *args, **kwargs): self: Module = module_hash.module # make a clone of self using its arguments attrs = { f.name: getattr(self, f.name) for f in dataclasses.fields(self) if f.name != 'parent' and f.init } # we reference module_class, not self.__class__ to avoid infinite loop cloned = module_class(parent=None, **attrs) cloned, args, kwargs = set_module_scopes(cloned, args, kwargs, scopes) object.__setattr__(cloned, '_state', state.export()) res = fn(cloned, *args, **kwargs) self._state.reimport(cloned._state) _test_transformed_return_values(res, fn_name) return res # here we apply the given lifting transform to the scope-ingesting fn trafo_fn = trafo_fn or transform(core_fn, *trafo_args, **trafo_kwargs) module_scopes, args, kwargs = get_module_scopes(self, args, kwargs) # get a hash for the Module by using its repr as a proxy hash_key = _HashableProxy.from_module(self) ret = trafo_fn(module_scopes, hash_key, *args, **kwargs) return ret return wrapped_fn transformed_fns = { fn_name: create_trans_fn(fn_name, fn_trafo_args) for fn_name, fn_trafo_args in class_trafo_args.items() } # construct new dynamic class w. transformed methods transformed_cls = type( transform.__name__.capitalize() + module_class.__name__, (module_class,), transformed_fns, ) # Handle partially initialized module class constructors. if partial_object is not None: transformed_cls = functools.partial( transformed_cls, *partial_object.args, **partial_object.keywords ) return transformed_cls # Utility to wrap a class or to use as decorator in def of class method. # ----------------------------------------------------------------------------- TransformTarget = Union[type[Module], Callable[..., Any]] Target = TypeVar('Target', bound=TransformTarget) def _is_module_class(target: TransformTarget) -> bool: return ( inspect.isclass(target) and issubclass(target, Module) or (isinstance(target, functools.partial)) and _is_module_class(target.func) ) def lift_transform( transform, target, *trafo_args, methods=None, **trafo_kwargs ): """Applies to class or as a decorator on class fns.""" # TODO(marcvanzee): Improve docstrings (#1977). if _is_module_class(target): return module_class_lift_transform( transform, target, *trafo_args, methods=methods, **trafo_kwargs ) # we presume this is being used as a function decorator in class definition elif callable(target) and not isinstance(target, Module): return decorator_lift_transform( transform, target, *trafo_args, **trafo_kwargs ) else: raise errors.TransformTargetError(target) def lift_transform_cached( transform, target, *trafo_args, methods=None, **trafo_kwargs ): """Applies to class or as a decorator on class fns.""" # TODO(marcvanzee): Improve docstrings (#1977). if _is_module_class(target): return module_class_lift_transform_cached( transform, target, *trafo_args, methods=methods, **trafo_kwargs ) # we presume this is being used as a function decorator in class definition elif callable(target) and not isinstance(target, Module): return decorator_lift_transform_cached( transform, target, *trafo_args, **trafo_kwargs ) else: raise errors.TransformTargetError(target) def lift_direct_transform( transform: Callable[..., Any], targets: tuple[Callable[..., Any], ...], mdl: Module, *args, multi_scope=True, **kwargs, ): """Lift direct transform.""" # TODO(marcvanzee): Improve docstrings (#1977). for target in targets: if _is_module_class(target): raise ValueError( f'The {transform.__name__} transform can only be applied on a Module' ' method. That is function that takes a Module instance as its first' ' arg.' ) elif not callable(target): raise ValueError('transform target must be callable') # normalize self.foo bound methods to class.foo unbound methods. targets = tuple(_get_unbound_fn(target) for target in targets) aug_transform = lambda *fns: functools.partial(transform, *fns) return decorator_lift_transform( aug_transform, targets, multi_scope=multi_scope )(mdl, *args, **kwargs) def vmap( target: Target, variable_axes: Mapping[CollectionFilter, InOutAxis] = FrozenDict(), split_rngs: Mapping[PRNGSequenceFilter, bool] = FrozenDict(), in_axes=0, out_axes=0, axis_size: int | None = None, axis_name: str | None = None, spmd_axis_name: str | None = None, metadata_params: Mapping[Any, Any] = {}, methods=None, ) -> Target: """A lifted version of ``jax.vmap``. See ``jax.vmap`` for the unlifted batch transform in Jax. ``vmap`` can be used to add a batch axis to a ``Module``. For example we could create a version of ``Dense`` with a batch axis that does not share parameters:: >>> import flax.linen as nn >>> BatchDense = nn.vmap( ... nn.Dense, ... in_axes=0, out_axes=0, ... variable_axes={'params': 0}, ... split_rngs={'params': True}) By using ``variable_axes={'params': 0}``, we indicate that the parameters themselves are mapped over and therefore not shared along the mapped axis. Consequently, we also split the 'params' RNG, otherwise the parameters would be initialized identically along the mapped axis. Similarly, ``vmap`` could be used to add a batch axis with parameter sharing:: >>> import flax.linen as nn >>> BatchDense = nn.vmap( ... nn.Dense, ... in_axes=0, out_axes=0, ... variable_axes={'params': None}, ... split_rngs={'params': False}) Here we use ``variable_axes={'params': None}`` to indicate the parameter variables are shared along the mapped axis. Consequently, the 'params' RNG must also be shared. Args: target: a ``Module`` or a function taking a ``Module`` as its first argument. variable_axes: the variable collections that are lifted into the batching transformation. Use ``None`` to indicate a broadcasted collection or an integer to map over an axis. For example, passing in ``variable_axes={'params': None}`` will indicate that the parameter variables should be shared along the mapped axis. split_rngs: Split PRNG sequences will be different for each index of the batch dimension. Unsplit PRNGs will be broadcasted. in_axes: Specifies the mapping of the input arguments (see ``jax.vmap``). out_axes: Specifies the mapping of the return value (see ``jax.vmap``). axis_size: Specifies the size of the batch axis. This only needs to be specified if it cannot be derived from the input arguments. axis_name: Specifies a name for the batch axis. Can be used together with parallel reduction primitives (e.g. ``jax.lax.pmean``, ``jax.lax.ppermute``, etc.). Note, this is only used for pmap and shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives. methods: If ``target`` is a ``Module``, the methods of ``Module`` to vmap over. spmd_axis_name: Axis name added to any pjit sharding constraints appearing in ``fn``. See also https://github.com/google/flax/blob/main/flax/linen/partitioning.py. metadata_params: arguments dict passed to AxisMetadata instances in the variable tree. Returns: A batched/vectorized version of ``target``, with the same arguments but with extra axes at positions indicated by ``in_axes``, and the same return value, but with extra axes at positions indicated by ``out_axes``. """ return lift_transform( lift.vmap, target, variable_axes, split_rngs, methods=methods, in_axes=in_axes, out_axes=out_axes, axis_size=axis_size, axis_name=axis_name, metadata_params=metadata_params, spmd_axis_name=spmd_axis_name, ) def jit( target: Target, variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, static_argnums: int | Iterable[int] = (), static_argnames: str | Iterable[str] = (), donate_argnums: int | Iterable[int] = (), device=None, backend: str | None = None, methods=None, ) -> Target: """Lifted version of ``jax.jit``. Args: target: a ``Module`` or a function taking a ``Module`` as its first argument. variables: The variable collections that are lifted. By default all collections are lifted. rngs: The PRNG sequences that are lifted. By default all PRNG sequences are lifted. static_argnums: An int or collection of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), and so the corresponding argument values can be any Python object. Static arguments should be hashable, meaning both ``__hash__`` and ``__eq__`` are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. If the jitted function is called with fewer positional arguments than indicated by ``static_argnums`` then an error is raised. Arguments that are not arrays or containers thereof must be marked as static. Defaults to (). static_argnames: An optional string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on ``static_argnums`` for details. If not provided but ``static_argnums`` is set, the default is based on calling ``inspect.signature(fun)`` to find corresponding named arguments. donate_argnums: Specify which arguments are "donated" to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation, JAX will raise an error if you try to. device: This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via :py:func:`jax.devices`.) The default is inherited from XLA's DeviceAssignment logic and is usually to use ``jax.devices()[0]``. backend: a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or ``'tpu'``. methods: If ``target`` is a ``Module``, the methods of ``Module`` to jit. Returns: A wrapped version of target, set up for just-in-time compilation. """ # TODO(marcvanzee): Improve docstrings (#1977). return lift_transform_cached( lift.jit, target, variables=variables, rngs=rngs, static_argnums=static_argnums, static_argnames=static_argnames, donate_argnums=donate_argnums, device=device, backend=backend, methods=methods, ) def checkpoint( target: Target, variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, concrete: bool = False, prevent_cse: bool = True, static_argnums: int | tuple[int, ...] = (), policy: Callable[..., bool] | None = None, methods=None, ) -> Target: """Lifted version of ``jax.checkpoint``. Checkpointing is a technique for reducing memory usage by recomputing activations during backpropagation. When training large models, it can be helpful to checkpoint parts of the model to trade off memory usage for additional computation. Example:: >>> import jax >>> import jax.numpy as jnp >>> import flax.linen as nn ... >>> class CheckpointedMLP(nn.Module): ... @nn.checkpoint ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(128)(x) ... x = nn.relu(x) ... x = nn.Dense(1)(x) ... return x ... >>> model = CheckpointedMLP() >>> variables = model.init(jax.random.key(0), jnp.ones((1, 16))) This function is aliased to ``remat`` just like ``jax.remat``. Args: target: a ``Module`` or a function taking a ``Module`` as its first argument. intermediate computations will be re-computed when computing gradients for the target. variables: The variable collections that are lifted. By default all collections are lifted. rngs: The PRNG sequences that are lifted. By default all PRNG sequences are lifted. concrete: Optional, boolean indicating whether ``fun`` may involve value-dependent Python control flow (default ``False``). Support for such control flow is optional, and disabled by default, because in some edge-case compositions with :func:`jax.jit` it can lead to some extra computation. prevent_cse: Optional, boolean indicating whether to prevent common subexpression elimination (CSE) optimizations in the HLO generated from differentiation. This CSE prevention has costs because it can foil other optimizations, and because it can incur high overheads on some backends, especially GPU. The default is True because otherwise, under a ``jit`` or ``pmap``, CSE can defeat the purpose of this decorator. But in some settings, like when used inside a ``scan``, this CSE prevention mechanism is unnecessary, in which case ``prevent_cse`` should be set to False. static_argnums: Optional, int or sequence of ints, indicates which argument values on which to specialize for tracing and caching purposes. Specifying arguments as static can avoid ConcretizationTypeErrors when tracing, but at the cost of more retracing overheads. policy: Experimental checkpoint policy, see ``jax.checkpoint``. methods: An optional list of method names that will be lifted, if ``methods`` is None (default) only the ``__call__`` method will be lifted. If``target`` is a function, ``methods`` is ignored. Returns: A wrapped version of ``target``. When computing gradients intermediate computations will be re-computed on the backward pass. """ # subtract 1 from each static_argnums because 'self' is not passed to the # lifted function static_argnums = jax.tree_util.tree_map(lambda x: x - 1, static_argnums) return lift_transform( lift.checkpoint, target, variables=variables, rngs=rngs, concrete=concrete, static_argnums=static_argnums, prevent_cse=prevent_cse, policy=policy, methods=methods, ) remat = checkpoint def remat_scan( target: Target, lengths: Sequence[int] | None = (), policy: Callable[..., bool] | None = None, variable_broadcast: CollectionFilter = False, variable_carry: CollectionFilter = False, variable_axes: Mapping[CollectionFilter, InOutScanAxis] = FrozenDict( {True: 0} ), split_rngs: Mapping[PRNGSequenceFilter, bool] = FrozenDict({True: True}), ) -> Target: """Combines remat and scan for memory efficiency and constant time compilation. ``remat_scan`` allows for constant compile times and sublinear memory usage with respect to model depth. At a small constant penalty. This is typically beneficial for very deep models. Example:: >>> import flax.linen as nn >>> class BigModel(nn.Module): ... @nn.compact ... def __call__(self, x): ... DenseStack = nn.remat_scan(nn.Dense, lengths=(10, 10)) ... # 100x dense with O(sqrt(N)) memory for gradient computation ... return DenseStack(8, name="dense_stack")(x) Args: target: a ``Module`` or a function taking a ``Module`` as its first argument. lengths: number of loop iterations at the given level. The total number of iterations ``n = prod(lengths)``. each loop is rematerialized. This way the memory consumption is proportional to ``n^(1 / d)`` where ``d = len(lengths)``. Minimal memory consumptions requires tuning the lengths such that the same amount of memory is consumed at each level of the nested loop. policy: Experimental checkpoint policy, see ``jax.checkpoint``. variable_broadcast: Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside the fn. variable_carry: Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes. variable_axes: the variable collections that are scanned over. Defaults to ``{True: 0}``. split_rngs: Split PRNG sequences will be different for each loop iterations. If split is False the PRNGs will be the same across iterations. Defaults to ``{True: True}``. Returns: A wrapped version of ``target`` that repeats itself prod(lengths) times. """ return lift_transform( lift.remat_scan, target, lengths=lengths, variable_broadcast=variable_broadcast, variable_carry=variable_carry, variable_axes=variable_axes, split_rngs=split_rngs, policy=policy, ) def scan( target: Target, variable_axes: Mapping[CollectionFilter, InOutScanAxis] = FrozenDict(), variable_broadcast: CollectionFilter = False, variable_carry: CollectionFilter = False, split_rngs: Mapping[PRNGSequenceFilter, bool] = FrozenDict(), in_axes=0, out_axes=0, length: int | None = None, reverse: bool = False, unroll: int = 1, data_transform: Callable[..., Any] | None = None, metadata_params: Mapping[Any, Any] = {}, methods=None, _split_transpose: bool = False, check_constancy_invariants: bool = True, ) -> Target: """A lifted version of ``jax.lax.scan``. See ``jax.lax.scan`` for the unlifted scan in Jax. To improve consistency with ``vmap``, this version of scan uses ``in_axes`` and ``out_axes`` to determine which arguments are scanned over and along which axis. ``scan`` distinguishes between 3 different types of values inside the loop: #. **scan**: a value that is iterated over in a loop. All scan values must have the same size in the axis they are scanned over. Scanned outputs will be stacked along the scan axis. #. **carry**: A carried value is updated at each loop iteration. It must have the same shape and dtype throughout the loop. #. **broadcast**: a value that is closed over by the loop. When a variable is broadcasted they are typically initialized inside the loop body but independent of the loop variables. The ``target`` should have the signature ``(module, carry, *xs) -> (carry, ys)``, where ``xs`` and ``ys`` are the scan values that go in and out of the loop. Example:: >>> import flax.linen as nn >>> import jax >>> import jax.numpy as jnp ... >>> class LSTM(nn.Module): ... features: int ... ... @nn.compact ... def __call__(self, x): ... ScanLSTM = nn.scan( ... nn.LSTMCell, variable_broadcast="params", ... split_rngs={"params": False}, in_axes=1, out_axes=1) ... ... lstm = ScanLSTM(self.features) ... input_shape = x[:, 0].shape ... carry = lstm.initialize_carry(jax.random.key(0), input_shape) ... carry, x = lstm(carry, x) ... return x ... >>> x = jnp.ones((4, 12, 7)) >>> module = LSTM(features=32) >>> y, variables = module.init_with_output(jax.random.key(0), x) Note that when providing a function to ``nn.scan``, the scanning happens over all arguments starting from the third argument, as specified by ``in_axes``. The previous example could also be written using the functional form as:: >>> class LSTM(nn.Module): ... features: int ... ... @nn.compact ... def __call__(self, x): ... ... cell = nn.LSTMCell(self.features) ... def body_fn(cell, carry, x): ... carry, y = cell(carry, x) ... return carry, y ... scan = nn.scan( ... body_fn, variable_broadcast="params", ... split_rngs={"params": False}, in_axes=1, out_axes=1) ... ... input_shape = x[:, 0].shape ... carry = cell.initialize_carry( ... jax.random.key(0), input_shape) ... carry, x = scan(cell, carry, x) ... return x ... >>> module = LSTM(features=32) >>> variables = module.init(jax.random.key(0), jnp.ones((4, 12, 7))) You can also use ``scan`` to reduce the compilation time of your JAX program by merging multiple layers into a single scan loop, you can do this when you have a sequence of identical layers that you want to apply iteratively to an input. For example:: >>> class ResidualMLPBlock(nn.Module): ... @nn.compact ... def __call__(self, x, _): ... h = nn.Dense(features=2)(x) ... h = nn.relu(h) ... return x + h, None ... >>> class ResidualMLP(nn.Module): ... n_layers: int = 4 ... ... @nn.compact ... def __call__(self, x): ... ScanMLP = nn.scan( ... ResidualMLPBlock, variable_axes={'params': 0}, ... variable_broadcast=False, split_rngs={'params': True}, ... length=self.n_layers) ... x, _ = ScanMLP()(x, None) ... return x ... >>> model = ResidualMLP(n_layers=4) >>> variables = model.init(jax.random.key(42), jnp.ones((1, 2))) To reduce both compilation and memory usage, you can use :func:`remat_scan` which will in addition checkpoint each layer in the scan loop. Args: target: a ``Module`` or a function taking a ``Module`` as its first argument. variable_axes: the variable collections that are scanned over. variable_broadcast: Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside the fn. variable_carry: Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes. split_rngs: Split PRNG sequences will be different for each loop iterations. If split is False the PRNGs will be the same across iterations. in_axes: Specifies the axis to scan over for the arguments. Should be a prefix tree of the arguments. Use ``flax.core.broadcast`` to feed an entire input to each iteration of the scan body. out_axes: Specifies the axis to scan over for the return value. Should be a prefix tree of the return value. length: Specifies the number of loop iterations. This only needs to be specified if it cannot be derived from the scan arguments. reverse: If true, scan from end to start in reverse order. unroll: how many scan iterations to unroll within a single iteration of a loop (default: 1). data_transform: optional function to transform raw functional-core variable and rng groups inside lifted scan body_fn, intended for inline SPMD annotations. metadata_params: arguments dict passed to AxisMetadata instances in the variable tree. methods: If ``target`` is a ``Module``, the methods of ``Module`` to scan over. _split_transpose: An experimental feature to split the transpose of a scan into a scan and a map, backed by an experimental Jax lax.scan() feature. check_constancy_invariants: If true, the scan will verify that the broadcast constants are true loop invariants, and further supports broadcast function (non-carry) outputs. This requires an extra jax tracing step however, so setting to false can reduce trace time on larger models. Returns: The scan function with the signature ``(module, carry, *xs) -> (carry, ys)``, where ``xs`` and ``ys`` are the scan values that go in and out of the loop. """ return lift_transform( lift.scan, target, variable_axes=variable_axes, variable_broadcast=variable_broadcast, variable_carry=variable_carry, split_rngs=split_rngs, in_axes=in_axes, out_axes=out_axes, length=length, reverse=reverse, unroll=unroll, _split_transpose=_split_transpose, data_transform=data_transform, metadata_params=metadata_params, methods=methods, check_constancy_invariants=check_constancy_invariants, ) def map_variables( target: Target, mapped_collections: CollectionFilter = True, trans_in_fn: Callable[..., Any] = lift.id_fn, trans_out_fn: Callable[..., Any] = lift.id_fn, init: bool = False, mutable: bool = False, rngs: PRNGSequenceFilter = True, variables: CollectionFilter = True, methods=None, ) -> Target: """Map Variables inside a module. ``map_variables`` can be used to transform the variables inside a module both before and after the module is applied. This is useful among other things for masking the weights of a module without having to modify the module itself. Example:: >>> import jax >>> import jax.numpy as jnp >>> import flax.linen as nn ... >>> class CausalDense(nn.Module): ... '''A dense layer that masks the weights such that the output is ... causal, i.e. output i only depends on input <= i. ... ''' ... features: int ... ... def apply_mask(self, variables): ... return (jax.tree_util.tree_map(jnp.triu, variables) ... if not self.is_initializing() else variables) ... ... def setup(self): ... # temporary class ... _CausalDense = nn.map_variables( ... nn.Dense, 'params', self.apply_mask, init=self.is_initializing()) ... self.dense = _CausalDense(features=self.features, use_bias=False) ... ... def __call__(self, x): ... return self.dense(x) ... >>> module = CausalDense(features=5) >>> variables = module.init(jax.random.key(0), jnp.ones((1, 5))) Args: target: the module or function to be transformed. mapped_collections: the collection(s) to be transformed. trans_in_fn: modifies the variables before applying the module or function. trans_out_fn: modifies the variables after applying the module or function, it is only applied if either ``init`` or ``mutable`` are not False. init: If True, variables are initialized before transformation. mutable: If True, the mapped variable collections will be mutable. rngs: PRNGSequences added to the transformed scope (default: all). variables: Additional Variable collections added to the transformed scope. Besides those specified by ``target`` (default: all). methods: If ``target`` is a ``Module``, the methods of ``Module`` to map variables for. Returns: a wrapped version of ``target`` that will map the specified collections. """ return lift_transform( lift.map_variables, target, mapped_collections, trans_in_fn, trans_out_fn, init, mutable, rngs, variables, methods=methods, ) def vjp( fn: Callable[..., Any], mdl: Module, *primals, has_aux: bool = False, reduce_axes=(), vjp_variables: CollectionFilter = 'params', variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, multi_scope: bool = False, ): """A lifted version of ``jax.vjp``. See ``jax.vjp`` for the unlifted vector-Jacobian product (backward gradient). Note that a gradient is returned for all variables in the collections specified by ``vjp_variables``. However, the backward function only expects a cotangent for the return value of ``fn``. If variables require a co-tangent as well they can be returned from ``fn`` using ``Module.variables``. Example:: >>> import flax.linen as nn >>> import jax.numpy as jnp >>> class LearnScale(nn.Module): ... @nn.compact ... def __call__(self, x, y): ... p = self.param('scale', nn.initializers.zeros_init(), ()) ... return p * x * y >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x, y): ... z, bwd = nn.vjp(lambda mdl, x, y: mdl(x, y), LearnScale(), x, y) ... params_grad, x_grad, y_grad = bwd(jnp.ones(z.shape)) ... return z, params_grad, x_grad, y_grad Args: fn: Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars. It will receive the scope and primals as arguments. mdl: The module of which the variables will be differentiated. *primals: A sequence of primal values at which the Jacobian of ``fn`` should be evaluated. The length of ``primals`` should be equal to the number of positional parameters to ``fn``. Each primal value should be a tuple of arrays, scalar, or standard Python containers thereof. has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default ``False``. vjp_variables: The vjpfun will return a cotangent vector for all variable collections specified by this filter. variables: other variables collections that are available inside ``fn`` but do not receive a cotangent. rngs: the prngs that are available inside ``fn``. multi_scope: for Modules containing multiple scopes from outside modules passed in, allow for variable gradients to be returned for multiple scopes instead of erroring. Returns: If ``has_aux`` is ``False``, returns a ``(primals_out, vjpfun)`` pair, where ``primals_out`` is ``fn(*primals)``. ``vjpfun`` is a function from a cotangent vector with the same shape as ``primals_out`` to a tuple of cotangent vectors with the same shape as ``primals``, representing the vector-Jacobian product of ``fn`` evaluated at ``primals``. If ``has_aux`` is ``True``, returns a ``(primals_out, vjpfun, aux)`` tuple where ``aux`` is the auxiliary data returned by ``fn``. """ if reduce_axes: raise NotImplementedError('reduce_axes argument to vjp is deprecated') del reduce_axes return lift_direct_transform( lift.vjp, (fn,), mdl, *primals, multi_scope=multi_scope, has_aux=has_aux, vjp_variables=vjp_variables, variables=variables, rngs=rngs, ) def value_and_grad( fn: Callable[..., Any], mdl: Module, *primals, has_aux: bool = False, reduce_axes=(), variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, ): """A limited, lifted equivalent of ``jax.value_and_grad``. Note that for this convenience function, gradients are only calculated for the function inputs, and not with respect to any module variables. The target function must return a scalar-valued output. For a more general lifted vjp, see ``nn.vjp`` for the lifted vector-Jacobian product. Example:: class LearnScale(nn.Module): @nn.compact def __call__(self, x, y): p = self.param('scale', nn.initializers.zeros_init(), ()) return p * x * y class Foo(nn.Module): @nn.compact def __call__(self, x, y): z, (x_grad, y_grad) = nn.value_and_grad( lambda mdl, x, y: mdl(x, y), LearnScale(), x, y) return z, x_grad, y_grad Args: fn: Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars. It will receive the scope and primals as arguments. mdl: The module of which the variables will be differentiated. *primals: A sequence of primal values at which the Jacobian of ``fn`` should be evaluated. The length of ``primals`` should be equal to the number of positional parameters to ``fn``. Each primal value should be a tuple of arrays, scalar, or standard Python containers thereof. has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default ``False``. variables: variables collections that are available inside ``fn`` but do not receive a cotangent. rngs: the prngs that are available inside ``fn``. Returns: If ``has_aux`` is ``False``, returns a ``primals_out, grads`` pair, where ``primals_out`` is ``fn(*primals)``. ``grads`` are the gradients for the corresponding primals and do not include the gradients for module variables. If ``has_aux`` is ``True``, returns a ``(primals_out, aux), grads`` tuple where ``aux`` is the auxiliary data returned by ``fn``. """ if reduce_axes: raise NotImplementedError( 'reduce_axes argument to value_and_grad is deprecated') del reduce_axes grad_partial = functools.partial( lift_direct_transform, lift.value_and_grad, (fn,), mdl, *primals, has_aux=has_aux, variables=variables, rngs=rngs, ) if has_aux: out, aux, argument_grads = grad_partial() if out.shape != (): raise ValueError( 'grad can only work on functions with ' f'scalar-valued outputs. out shape={out.shape}' ) return (out, aux), argument_grads else: out, argument_grads = grad_partial() if out.shape != (): raise ValueError( 'grad can only work on functions with ' f'scalar-valued outputs. out shape={out.shape}' ) return out, argument_grads def grad( fn: Callable[..., Any], mdl: Module, *primals, has_aux: bool = False, reduce_axes=(), variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, ): """A limited, lifted equivalent of ``jax.grad``. Note that for this convenience function, gradients are only calculated for the function inputs, and not with respect to any module variables. The target function must return a scalar-valued output. For a more general lifted vjp, see ``nn.vjp`` for the lifted vector-Jacobian product. Example:: class LearnScale(nn.Module): @nn.compact def __call__(self, x, y): p = self.param('scale', nn.initializers.zeros_init(), ()) return p * x * y class Foo(nn.Module): @nn.compact def __call__(self, x, y): x_grad, y_grad = nn.grad( lambda mdl, x, y: mdl(x, y), LearnScale(), x, y) return x_grad, y_grad Args: fn: Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars. It will receive the scope and primals as arguments. mdl: The module of which the variables will be differentiated. *primals: A sequence of primal values at which the Jacobian of ``fn`` should be evaluated. The length of ``primals`` should be equal to the number of positional parameters to ``fn``. Each primal value should be a tuple of arrays, scalar, or standard Python containers thereof. has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default ``False``. variables: variables collections that are available inside ``fn`` but do not receive a cotangent. rngs: the prngs that are available inside ``fn``. Returns: If ``has_aux`` is ``False``, returns ``grads``, where ``grads`` are the gradients for the corresponding primals and do not include the gradients for module variables. If ``has_aux`` is ``True``, returns a ``(grads, aux)`` tuple where ``aux`` is the auxiliary data returned by ``fn``. """ if reduce_axes: raise NotImplementedError('reduce_axes argument to grad is deprecated') del reduce_axes value_and_grad_partial = functools.partial( value_and_grad, fn, mdl, *primals, has_aux=has_aux, variables=variables, rngs=rngs, ) if has_aux: (_, aux), argument_grads = value_and_grad_partial() return argument_grads, aux else: _, argument_grads = value_and_grad_partial() return argument_grads def jvp( fn: Callable[..., Any], mdl: Module, primals, tangents, variable_tangents, variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, ) -> tuple[Any, Callable[..., Any]] | tuple[Any, Callable[..., Any], Any]: """A lifted version of ``jax.jvp``. See ``jax.jvp`` for the unlifted Jacobian-vector product (forward gradient). Note that no tangents are returned for variables. When variable tangents are required their value should be returned explicitly by ``fn`` using ``Module.variables``:: >>> import flax.linen as nn >>> import jax.numpy as jnp >>> class LearnScale(nn.Module): ... @nn.compact ... def __call__(self, x): ... p = self.param('test', nn.initializers._init(), ()) ... return p * x >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... scale = LearnScale() ... vars_t = jax.tree_util.tree_map(jnp.ones_like, ... scale.variables.get('params', {})) ... _, out_t = nn.jvp( ... lambda mdl, x: mdl(x), scale, (x,), (jnp.zeros_like(x),), ... variable_tangents={'params': vars_t}) ... return out_t Example:: >>> def learn_scale(scope, x): ... p = scope.param('scale', nn.initializers.zeros_init(), ()) ... return p * x >>> def f(scope, x): ... vars_t = jax.tree_util.tree_map(jnp.ones_like, scope.variables().get('params', {})) ... x, out_t = lift.jvp( ... learn_scale, scope, (x,), (jnp.zeros_like(x),), ... variable_tangents={'params': vars_t}) ... return out_t Args: fn: Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars. It will receive the scope and primals as arguments. mdl: The module of which the variables will be differentiated. primals: The primal values at which the Jacobian of ``fun`` should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters of ``fun``. tangents: The tangent vector for which the Jacobian-vector product should be evaluated. Should be either a tuple or a list of tangents, with the same tree structure and array shapes as ``primals``. variable_tangents: A dict or PyTree fo dicts with the same structure as scopes. Each entry in the dict specifies the tangents for a variable collection. Not specifying a collection in variable_tangents is equivalent to passing a zero vector as the tangent. variables: other variables collections that are available in ``fn`` but do not receive a tangent. rngs: the prngs that are available inside ``fn``. Returns: A ``(primals_out, tangents_out)`` pair, where ``primals_out`` is ``fun(*primals)``, and ``tangents_out`` is the Jacobian-vector product of ``function`` evaluated at ``primals`` with ``tangents``. The ``tangents_out`` value has the same Python tree structure and shapes as ``primals_out``. """ return lift_direct_transform( lift.jvp, (fn,), mdl, primals, tangents, variable_tangents, multi_scope=False, variables=variables, rngs=rngs, ) ModuleT = TypeVar('ModuleT', bound=Module) C = TypeVar('C') def while_loop( cond_fn: Callable[[ModuleT, C], bool], body_fn: Callable[[ModuleT, C], C], mdl: ModuleT, init: C, carry_variables: CollectionFilter = False, broadcast_variables: CollectionFilter = True, split_rngs: Mapping[PRNGSequenceFilter, bool] = FrozenDict(), ) -> C: """Lifted version of jax.lax.while_loop. The lifted scope is passed to ``cond_fn`` and ``body_fn``. Broadcasted variables are immutable. The carry variable are mutable but cannot change shape and dtype. This also means you cannot initialize variables inside the body. Consider calling ``body_fn`` once manually before calling ``while_loop`` if variable initialization is required. Example:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> class WhileLoopExample(nn.Module): ... @nn.compact ... def __call__(self, x): ... def cond_fn(mdl, c): ... return mdl.variables['state']['acc'] < 10 ... def body_fn(mdl, c): ... acc = mdl.variable('state', 'acc', lambda: jnp.array(0)) ... acc.value += 1 ... y = nn.Dense(c.shape[-1])(c) ... return y ... c = x ... if self.is_mutable_collection('params'): ... return body_fn(self, c) ... else: ... return nn.while_loop(cond_fn, body_fn, self, c, ... carry_variables='state') >>> k = jax.random.key(0) >>> x = jnp.ones((2, 2)) >>> initial_vars = WhileLoopExample().init(k, x) >>> result, state = WhileLoopExample().apply(initial_vars, x, mutable=['state']) Args: cond_fn: Should return True as long as the loop should continue. body_fn: The body of the while loop. mdl: The Module which should be lifted into the loop. init: The initial state passed to the loop carry_variables: collections that are carried through the loop and are therefore mutable (default: none). broadcast_variables: collections that are closed over and are therefore read-only (default: all collections) split_rngs: Split PRNG sequences will be different for each loop iterations. If split is False the PRNGs will be the same across iterations. Returns: The final state after executing the while loop. """ return lift_direct_transform( lift.while_loop, (cond_fn, body_fn), mdl, init, carry_variables, broadcast_variables, split_rngs, ) def _cond_wrapper(t_fn, f_fn, scope, pred, *ops, variables, rngs): return lift.cond( pred, t_fn, f_fn, scope, *ops, variables=variables, rngs=rngs ) def cond( pred: Any, true_fun: Callable[..., C], false_fun: Callable[..., C], mdl: Module, *operands, variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, ) -> C: """Lifted version of ``jax.lax.cond``. The returned values from ``true_fun`` and ``false_fun`` must have the same Pytree structure, shapes, and dtypes. The variables created or updated inside the branches must also have the same structure. Note that this constraint is violated when creating variables or submodules in only one branch. Because initializing variables in just one branch causes the parameter structure to be different. Example:: >>> import flax.linen as nn >>> class CondExample(nn.Module): ... @nn.compact ... def __call__(self, x, pred): ... self.variable('state', 'true_count', lambda: 0) ... self.variable('state', 'false_count', lambda: 0) ... def true_fn(mdl, x): ... mdl.variable('state', 'true_count').value += 1 ... return nn.Dense(2, name='dense')(x) ... def false_fn(mdl, x): ... mdl.variable('state', 'false_count').value += 1 ... return -nn.Dense(2, name='dense')(x) ... return nn.cond(pred, true_fn, false_fn, self, x) Args: pred: determines if true_fun or false_fun is evaluated. true_fun: The function evaluated when ``pred`` is ``True``. The signature is (module, *operands) -> T. false_fun: The function evaluated when ``pred`` is ``False``. The signature is (module, *operands) -> T. mdl: A Module target to pass. *operands: The arguments passed to ``true_fun`` and ``false_fun`` variables: The variable collections passed to the conditional branches (default: all) rngs: The PRNG sequences passed to the conditionals (default: all) Returns: The result of the evaluated branch (``true_fun`` or ``false_fun``). """ return lift_direct_transform( _cond_wrapper, (true_fun, false_fun), mdl, pred, *operands, variables=variables, rngs=rngs, ) def _switch_wrapper(*args, variables, rngs, n_branches): # first n_branches arguments are branches. # then scope, index, and the rest are *operands branches = args[:n_branches] scope, index, *operands = args[n_branches:] return lift.switch( index, branches, scope, *operands, variables=variables, rngs=rngs ) def switch( index: Any, branches: Sequence[Callable[..., C]], mdl: Module, *operands, variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, ) -> C: """Lifted version of ``jax.lax.switch``. The returned values from ``branches`` must have the same Pytree structure, shapes, and dtypes. The variables created or updated inside the branches must also have the same structure. Note that this constraint is violated when creating variables or submodules in only one branch. Because initializing variables in just one branch causes the parameter structure to be different. Example:: >>> import flax.linen as nn >>> class SwitchExample(nn.Module): ... @nn.compact ... def __call__(self, x, index): ... self.variable('state', 'a_count', lambda: 0) ... self.variable('state', 'b_count', lambda: 0) ... self.variable('state', 'c_count', lambda: 0) ... def a_fn(mdl, x): ... mdl.variable('state', 'a_count').value += 1 ... return nn.Dense(2, name='dense')(x) ... def b_fn(mdl, x): ... mdl.variable('state', 'b_count').value += 1 ... return -nn.Dense(2, name='dense')(x) ... def c_fn(mdl, x): ... mdl.variable('state', 'c_count').value += 1 ... return nn.Dense(2, name='dense')(x) ... return nn.switch(index, [a_fn, b_fn, c_fn], self, x) If you want to have a different parameter structure for each branch you should run all branches on initialization before calling switch:: >>> class MultiHeadSwitchExample(nn.Module): ... def setup(self) -> None: ... self.heads = [ ... nn.Sequential([nn.Dense(10), nn.Dense(7), nn.Dense(5)]), ... nn.Sequential([nn.Dense(11), nn.Dense(5)]), ... nn.Dense(5), ... ] ... ... @nn.compact ... def __call__(self, x, index): ... def head_fn(i): ... return lambda mdl, x: mdl.heads[i](x) ... branches = [head_fn(i) for i in range(len(self.heads))] ... ... # run all branches on init ... if self.is_mutable_collection('params'): ... for branch in branches: ... _ = branch(self, x) ... ... return nn.switch(index, branches, self, x) Args: index: Integer scalar type, indicating which branch function to apply. branches: Sequence of functions to be applied based on index. The signature of each function is (module, *operands) -> T. mdl: A Module target to pass. *operands: The arguments passed to the branches. variables: The variable collections passed to the conditional branches (default: all) rngs: The PRNG sequences passed to the conditionals (default: all) Returns: The result of the evaluated branch. """ return lift_direct_transform( _switch_wrapper, tuple(branches), mdl, index, *operands, variables=variables, rngs=rngs, n_branches=len(branches), ) # a version of lift.custom_vjp with a single scope function # this avoids having to lift multiple functions in # lift_transform. def _custom_vjp_single_scope_fn( fn: Callable[..., Any], backward_fn: Callable[..., Any], grad_vars: CollectionFilter = 'params', nondiff_argnums=(), ): nodiff_fn = functools.partial(fn, needs_residual=False) forward_fn = functools.partial(fn, needs_residual=True) return lift.custom_vjp( nodiff_fn, forward_fn, backward_fn, grad_vars, nondiff_argnums ) def custom_vjp( fn: Callable[..., Any], forward_fn: Callable[..., Any], backward_fn: Callable[..., Any], grad_vars: CollectionFilter = 'params', nondiff_argnums=(), ): """Lifted version of ``jax.custom_vjp``. ``forward_fn`` and ``backward_fn`` together define a custom vjp for ``fn``. The original ``fn`` will run in case a vjp (backward gradient) is not computed. The ``forward_fn`` receives the same arguments as ``fn`` but is expected to return a tuple containing the output of ``fn(mdl, *args)`` and the residuals that are passed to ``backward_fn``. The ``backward_fn`` receives the nondiff arguments, residuals, and the output tangents. It should return a tuple containing the variable and input tangents. Note that the vjp function returned by ``nn.vjp`` can be passed as residual and used in the ``backward_fn``. The scope is unavailable during the backward pass. If the module is required in ``backward_fn``, a snapshot of the variables can be taken and returned as a residual in the ``forward_fn``. Example:: >>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... def f(mdl, x): ... return mdl(x) ... ... def fwd(mdl, x): ... return nn.vjp(f, mdl, x) ... ... def bwd(vjp_fn, y_t): ... params_t, *inputs_t = vjp_fn(y_t) ... params_t = jax.tree_util.tree_map(jnp.sign, params_t) ... return (params_t, *inputs_t) ... ... sign_grad = nn.custom_vjp( ... f, forward_fn=fwd, backward_fn=bwd) ... return sign_grad(nn.Dense(1), x).reshape(()) >>> x = jnp.ones((2,)) >>> variables = Foo().init(jax.random.key(0), x) >>> grad = jax.grad(Foo().apply)(variables, x) Args: fn: The function to define a custom_vjp for. forward_fn: A function with the same arguments as ``fn`` returning an tuple with the original output and the residuals that will be passed to ``backward_fn``. backward_fn: arguments are passed as ``(*nondiff_args, residuals, tangents)`` The function should return a tuple containing the tangents for the variable in the collections specified by ``grad_vars`` and the input arguments (except the module and nondiff args). grad_vars: The collections for which a vjp will be computed (default: "params"). nondiff_argnums: arguments for which no vjp is computed. Returns: A function with the same signature as ``fn`` with the custom vjp. """ def shared_forward_fn(*args, needs_residual, **kwargs): if needs_residual: return forward_fn(*args, **kwargs) else: return fn(*args, **kwargs) return decorator_lift_transform( _custom_vjp_single_scope_fn, shared_forward_fn, backward_fn=backward_fn, grad_vars=grad_vars, nondiff_argnums=nondiff_argnums, multi_scope=False, ) def named_call(class_fn, force=True): """Labels a method for labelled traces in profiles. Note that it is better to use the `jax.named_scope` context manager directly to add names to JAX's metadata name stack. Args: class_fn: The class method to label. force: If True, the named_call transform is applied even if it is globally disabled. (e.g.: by calling `flax.linen.disable_named_call()`) Returns: A wrapped version of ``class_fn`` that is labeled. """ # We use JAX's dynamic name-stack named_call. No transform boundary needed! @functools.wraps(class_fn) def wrapped_fn(self, *args, **kwargs): if (not force and not linen_module._use_named_call) or self._state.in_setup: # pylint: disable=protected-access # pylint: disable=protected-access return class_fn(self, *args, **kwargs) full_name = _derive_profiling_name(self, class_fn) return jax.named_call(class_fn, name=full_name)(self, *args, **kwargs) return wrapped_fn def add_metadata_axis( target: Target, variable_axes: Mapping[CollectionFilter, InOutAxis] = FrozenDict(), metadata_params: dict[Any, Any] = {}, ) -> Target: """A helper to manipulate boxed axis metadata. This is a helper to manipulate the *metadata* in boxed variables, similar to how lifted ``vmap`` and ``scan`` will handle the introduction and stripping of the new metadata axis across a transform boundary. Args: target: a ``Module`` or a function taking a ``Module`` as its first argument. variable_axes: the variable collections whose axis metadata is being transformed. Use `None` to indicate a broadcasted collection or an integer to specify an axis index for an introduced axis. methods: If `target` is a `Module`, the methods of `Module` to vmap over. metadata_params: arguments dict passed to AxisMetadata instances in the variable tree. Returns: A transformed version of ``target`` that performs a transform of the axis metadata on its variables. """ def add_fn(axis): return lambda x: meta.add_axis(x, axis, metadata_params) def remove_fn(axis): return lambda x: meta.remove_axis(x, axis, metadata_params) for col_name, axis in variable_axes.items(): target = map_variables( target, col_name, trans_in_fn=remove_fn(axis), trans_out_fn=add_fn(axis), mutable=True, ) return target def fold_rngs( target: Target, variables: CollectionFilter = True, rngs: PRNGSequenceFilter = True, ) -> Target: return lift_transform_cached( lift.fold_rngs, target, variables=variables, rngs=rngs, ) ================================================ FILE: flax/metrics/__init__.py ================================================ # Copyright 2023 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: flax/metrics/tensorboard.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Write Summaries from JAX for use with Tensorboard.""" import contextlib import functools import os import numpy as np import tensorflow as tf # pytype: disable=import-error from tensorboard.plugins.hparams import api as hparams_api # pylint: disable=g-import-not-at-top from flax import io def _flatten_dict(input_dict, parent_key='', sep='.'): """Flattens and simplifies dict such that it can be used by hparams. Args: input_dict: Input dict, e.g., from ConfigDict. parent_key: String used in recursion. sep: String used to separate parent and child keys. Returns: Flattened dict. """ items = [] for k, v in input_dict.items(): new_key = parent_key + sep + k if parent_key else k # Valid types according to https://github.com/tensorflow/tensorboard/blob/1204566da5437af55109f7a4af18f9f8b7c4f864/tensorboard/plugins/hparams/summary_v2.py valid_types = ( bool, int, float, str, np.bool_, np.integer, np.floating, np.character, ) if isinstance(v, dict): # Recursively flatten the dict. items.extend(_flatten_dict(v, new_key, sep=sep).items()) continue elif not isinstance(v, valid_types): # Cast any incompatible values as strings such that they can be handled by hparams v = str(v) items.append((new_key, v)) return dict(items) @contextlib.contextmanager def _as_default(summary_writer: tf.summary.SummaryWriter, auto_flush: bool): """No-flush variation of summary_writer.as_default().""" context_manager = summary_writer.as_default() try: context_manager.__enter__() yield summary_writer finally: old_flush = summary_writer.flush new_flush = old_flush if auto_flush else lambda: None summary_writer.flush = new_flush context_manager.__exit__(None, None, None) summary_writer.flush = old_flush class SummaryWriter: """Saves data in event and summary protos for tensorboard.""" def __init__(self, log_dir, auto_flush=True): """Create a new SummaryWriter. Args: log_dir: path to record tfevents files in. auto_flush: if true, flush after every reported metric. """ log_dir = os.fspath(log_dir) # If needed, create log_dir directory as well as missing parent directories. if not io.isdir(log_dir): io.makedirs(log_dir) self._event_writer = tf.summary.create_file_writer(log_dir) self._as_default = functools.partial(_as_default, auto_flush=auto_flush) self._closed = False def close(self): """Close SummaryWriter. Final!""" if not self._closed: self._event_writer.close() self._closed = True del self._event_writer def flush(self): self._event_writer.flush() def scalar(self, tag, value, step): """Saves scalar value. Args: tag: str: label for this data value: int/float: number to log step: int: training step """ value = float(np.array(value)) with self._as_default(self._event_writer): tf.summary.scalar(name=tag, data=value, step=step) def image(self, tag, image, step, max_outputs=3): """Saves RGB image summary from np.ndarray [H,W], [H,W,1], or [H,W,3]. Args: tag: str: label for this data image: ndarray: [H,W], [H,W,1], [H,W,3], [K,H,W], [K,H,W,1], [K,H,W,3] Save image in greyscale or colors. Pixel values could be either uint8 or float. Floating point values should be in range [0, 1). step: int: training step max_outputs: At most this many images will be emitted at each step. Defaults to 3. """ image = np.array(image) # tf.summary.image expects image to have shape [k, h, w, c] where, # k = number of samples, h = height, w = width, c = number of channels. if len(np.shape(image)) == 2: image = image[np.newaxis, :, :, np.newaxis] elif len(np.shape(image)) == 3: # this could be either [k, h, w] or [h, w, c] if np.shape(image)[-1] in (1, 3): image = image[np.newaxis, :, :, :] else: image = image[:, :, :, np.newaxis] if np.shape(image)[-1] == 1: image = np.repeat(image, 3, axis=-1) # Convert to tensor value as tf.summary.image expects data to be a tensor. image = tf.convert_to_tensor(image) with self._as_default(self._event_writer): tf.summary.image(name=tag, data=image, step=step, max_outputs=max_outputs) def audio(self, tag, audiodata, step, sample_rate=44100, max_outputs=3): """Saves audio as wave. NB: single channel only right now. Args: tag: str: label for this data audiodata: ndarray [Nsamples, Nframes, Nchannels]: audio data to be saved as wave. The data will be clipped to [-1.0, 1.0]. step: int: training step sample_rate: sample rate of passed in audio buffer max_outputs: At most this many audio clips will be emitted at each step. Defaults to 3. """ # tf.summary.audio expects the audio data to have floating values in # [-1.0, 1.0]. audiodata = np.clip(np.array(audiodata), -1, 1) # Convert to tensor value as tf.summary.audio expects data to be a tensor. audio = tf.convert_to_tensor(audiodata, dtype=tf.float32) with self._as_default(self._event_writer): tf.summary.audio( name=tag, data=audio, sample_rate=sample_rate, step=step, max_outputs=max_outputs, encoding='wav', ) def histogram(self, tag, values, step, bins=None): """Saves histogram of values. Args: tag: str: label for this data values: ndarray: will be flattened by this routine step: int: training step bins: number of bins in histogram """ values = np.array(values) values = np.reshape(values, -1) with self._as_default(self._event_writer): tf.summary.histogram(name=tag, data=values, step=step, buckets=bins) def text(self, tag, textdata, step): """Saves a text summary. Args: tag: str: label for this data textdata: string step: int: training step Note: markdown formatting is rendered by tensorboard. """ if not isinstance(textdata, (str, bytes)): raise ValueError('`textdata` should be of the type `str` or `bytes`.') with self._as_default(self._event_writer): tf.summary.text(name=tag, data=tf.constant(textdata), step=step) def write(self, tag, tensor, step, metadata=None): """Saves an arbitrary tensor summary. Useful when working with custom plugins or constructing a summary directly. Args: tag: str: label for this data tensor: ndarray: tensor data to save. step: int: training step metadata: Optional SummaryMetadata, as a proto or serialized bytes. Note: markdown formatting is rendered by tensorboard. """ with self._as_default(self._event_writer): tf.summary.write(tag=tag, tensor=tensor, step=step, metadata=metadata) def hparams(self, hparams): """Saves hyper parameters. Args: hparams: Flat mapping from hyper parameter name to value. """ with self._as_default(self._event_writer): hparams_api.hparams(hparams=_flatten_dict(hparams)) ================================================ FILE: flax/nnx/README.md ================================================ [![codecov](https://codecov.io/gh/cgarciae/nnx/branch/main/graph/badge.svg?token=VqJjL474Z7)](https://codecov.io/gh/cgarciae/nnx) # NNX _**N**eural **N**etworks for JA**X**_ - | [docs](https://flax.readthedocs.io/en/latest/nnx/index.html) | NNX is a JAX-based neural network library that focuses on providing the best development experience to make building and experimenting with neural networks as easy and intuitive as possible. * **Pythonic**: Modules are standard Python classes, promoting ease of use and a more familiar development experience. * **Easy-to-use**: NNX provides a set of transforms that take care of state management, allowing users to focus on building their models and training loops. * **Expressive**: NNX allows fine-grained over the Module state with lifted transforms, enabling users to define complex architectures. * **Compatible**: NNX allows functionalizing Module state, making it possible to directly use JAX transformations when needed. ## What does NNX look like? NNX removes most of the friction from building and training neural networks in JAX. It provides a Module system that uses standard Python classes, and a set of transforms that extend JAX to handle objects. ```python from flax import nnx import optax class Model(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dmid, rngs=rngs) self.bn = nnx.BatchNorm(dmid, rngs=rngs) self.dropout = nnx.Dropout(0.2, rngs=rngs) self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x): x = nnx.relu(self.dropout(self.bn(self.linear(x)))) return self.linear_out(x) model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) # reference sharing @nnx.jit # automatic state management def train_step(model, optimizer, x, y): def loss_fn(model): y_pred = model(x) # call methods directly return ((y_pred - y) ** 2).mean() loss, grads = nnx.value_and_grad(loss_fn)(model) optimizer.update(model, grads) # inplace updates return loss ``` To learn more about the `Module` abstraction, check out our [NNX Basics](https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#) guide. ## Installation To get started with `nnx`, install Flax from GitHub: ``` pip install git+https://github.com/google/flax.git ``` ### Examples * [LM1B](https://github.com/google/flax/tree/main/examples/lm1b_nnx): A language model trained on the 1 Billion Word Benchmark dataset. #### Toy Examples * [Basic Example](https://github.com/google/flax/tree/main/examples/nnx_toy_examples/02_lifted_transforms.py): Shows how to train a simple model using NNX. * [Using the Functional API](https://github.com/google/flax/tree/main/examples/nnx_toy_examples/01_functional_api.py): Shows how to train a simple model using the functional API. * [Training a VAE](https://github.com/google/flax/tree/main/examples/nnx_toy_examples/05_vae.py): Shows how to train a VAE on the binarized MNIST dataset. * [Scan over layers](https://github.com/google/flax/tree/main/examples/nnx_toy_examples/06_scan_over_layers.py): An contrived example that implements scan over layers with dropout and a share BatcNorm layer to showcase how lifted transforms can be implemented. It uses the functional API along with `jax.vmap` and `jax.lax.scan`. ================================================ FILE: flax/nnx/__init__.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from flax.core.spmd import logical_axis_rules as logical_axis_rules from flax.linen.pooling import avg_pool as avg_pool from flax.linen.pooling import max_pool as max_pool from flax.linen.pooling import min_pool as min_pool from flax.linen.pooling import pool as pool from flax.typing import Initializer as Initializer from .bridge import wrappers as wrappers from .filterlib import WithTag as WithTag from .filterlib import PathContains as PathContains from .filterlib import OfType as OfType from .filterlib import Any as Any from .filterlib import All as All from .filterlib import Not as Not from .filterlib import Everything as Everything from .filterlib import Nothing as Nothing from .graphlib import GraphDef as GraphDef from .graphlib import GraphState as GraphState from .graphlib import PureState as PureState from . import pytreelib as object from .pytreelib import Pytree as Pytree from .pytreelib import Object as Object from .pytreelib import Data as Data from .pytreelib import Static as Static from .pytreelib import dataclass as dataclass from .pytreelib import data as data from .pytreelib import static as static from .pytreelib import register_data_type as register_data_type from .pytreelib import is_data as is_data from .pytreelib import has_data as has_data from .pytreelib import check_pytree as check_pytree from .helpers import Dict as Dict from .helpers import List as List from .helpers import Sequential as Sequential from .helpers import TrainState as TrainState from .module import M as M from .module import Module as Module from .module import capture as capture from .module import view as view from .module import view_info as view_info from .module import with_attributes as with_attributes from .module import iter_children as iter_children, iter_modules as iter_modules from .graphlib import merge as merge from .graphlib import UpdateContext as UpdateContext from .graphlib import update_context as update_context from .graphlib import current_update_context as current_update_context from .graphlib import split as split from .graphlib import update as update from .graphlib import clone as clone from .graphlib import pop as pop from .graphlib import state as state from .graphlib import graphdef as graphdef from .graphlib import iter_graph as iter_graph from .graphlib import recursive_map as recursive_map from .graphlib import find_duplicates as find_duplicates from .graphlib import map as map from .graphlib import call as call from .graphlib import set_metadata as set_metadata from .graphlib import SplitContext as SplitContext from .graphlib import split_context as split_context from .graphlib import MergeContext as MergeContext from .graphlib import merge_context as merge_context from .graphlib import variables as variables from .graphlib import vars_as as vars_as from .graphlib import pure as pure from .graphlib import cached_partial as cached_partial from .graphlib import flatten as flatten from .graphlib import unflatten as unflatten from .graphlib import set_graph_mode as set_graph_mode from .graphlib import set_graph_updates as set_graph_updates from .nn import initializers as initializers from .nn.activations import celu as celu from .nn.activations import elu as elu from .nn.activations import gelu as gelu from .nn.activations import glu as glu from .nn.activations import hard_sigmoid as hard_sigmoid from .nn.activations import hard_silu as hard_silu from .nn.activations import hard_swish as hard_swish from .nn.activations import hard_tanh as hard_tanh from .nn.activations import leaky_relu as leaky_relu from .nn.activations import log_sigmoid as log_sigmoid from .nn.activations import log_softmax as log_softmax from .nn.activations import logsumexp as logsumexp from .nn.activations import one_hot as one_hot from .nn.activations import relu as relu from .nn.activations import relu6 as relu6 from .nn.activations import selu as selu from .nn.activations import sigmoid as sigmoid from .nn.activations import identity as identity from .nn.activations import silu as silu from .nn.activations import soft_sign as soft_sign from .nn.activations import softmax as softmax from .nn.activations import softplus as softplus from .nn.activations import standardize as standardize from .nn.activations import swish as swish from .nn.activations import tanh as tanh from .nn.activations import PReLU as PReLU from .nn.attention import MultiHeadAttention as MultiHeadAttention from .nn.attention import combine_masks as combine_masks from .nn.attention import dot_product_attention as dot_product_attention from .nn.attention import make_attention_mask as make_attention_mask from .nn.attention import make_causal_mask as make_causal_mask from .nn.recurrent import RNNCellBase as RNNCellBase from .nn.recurrent import LSTMCell as LSTMCell from .nn.recurrent import GRUCell as GRUCell from .nn.recurrent import OptimizedLSTMCell as OptimizedLSTMCell from .nn.recurrent import SimpleCell as SimpleCell from .nn.recurrent import RNN as RNN from .nn.recurrent import Bidirectional as Bidirectional from .nn.linear import Conv as Conv from .nn.linear import ConvTranspose as ConvTranspose from .nn.linear import Embed as Embed from .nn.linear import Linear as Linear from .nn.linear import LinearGeneral as LinearGeneral from .nn.linear import Einsum as Einsum from .nn.lora import LoRA as LoRA from .nn.lora import LoRALinear as LoRALinear from .nn.lora import LoRAParam as LoRAParam from .nn.normalization import BatchNorm as BatchNorm from .nn.normalization import LayerNorm as LayerNorm from .nn.normalization import RMSNorm as RMSNorm from .nn.normalization import GroupNorm as GroupNorm from .nn.normalization import InstanceNorm as InstanceNorm from .nn.normalization import SpectralNorm as SpectralNorm from .nn.normalization import WeightNorm as WeightNorm from .nn.stochastic import Dropout as Dropout from .rnglib import Rngs as Rngs from .rnglib import RngStream as RngStream from .rnglib import RngState as RngState from .rnglib import RngKey as RngKey from .rnglib import RngCount as RngCount from .rnglib import fork_rngs as fork_rngs from .rnglib import reseed as reseed from .rnglib import split_rngs as split_rngs from .rnglib import restore_rngs as restore_rngs from .spmd import PARTITION_NAME as PARTITION_NAME from .spmd import get_partition_spec as get_partition_spec from .spmd import get_named_sharding as get_named_sharding from .spmd import with_partitioning as with_partitioning from .spmd import get_abstract_model as get_abstract_model from .spmd import abstract_with_sharding as abstract_with_sharding from .statelib import FlatState as FlatState from .statelib import State as State from .statelib import to_flat_state as to_flat_state from .statelib import from_flat_state as from_flat_state from .statelib import to_pure_dict as to_pure_dict from .statelib import replace_by_pure_dict as replace_by_pure_dict from .statelib import restore_int_paths as restore_int_paths from .statelib import filter_state as filter_state from .statelib import merge_state as merge_state from .statelib import split_state as split_state from .statelib import map_state as map_state from .training import metrics as metrics from .variablelib import Param as Param # this needs to be imported before optimizer to prevent circular import from .training import optimizer as optimizer from .training.metrics import Metric as Metric from .training.metrics import MultiMetric as MultiMetric from .training.optimizer import OptState as OptState from .training.optimizer import OptArray as OptArray from .training.optimizer import OptVariable as OptVariable from .training.optimizer import Optimizer as Optimizer from .training.optimizer import ModelAndOptimizer as ModelAndOptimizer from .training.optimizer import OptState as OptState from .transforms.autodiff import DiffState as DiffState from .transforms.autodiff import grad as grad from .transforms.autodiff import value_and_grad as value_and_grad from .transforms.autodiff import custom_vjp as custom_vjp from .transforms.autodiff import vjp as vjp from .transforms.autodiff import jvp as jvp from .transforms.autodiff import remat as remat from .transforms.compilation import jit as jit from .transforms.compilation import jit_partial as jit_partial from .transforms.compilation import shard_map as shard_map from .transforms.compilation import StateSharding as StateSharding from .transforms.iteration import Carry as Carry from .transforms.iteration import scan as scan from .transforms.iteration import vmap as vmap from .transforms.iteration import pmap as pmap from .transforms.transforms import eval_shape as eval_shape from .transforms.transforms import cond as cond from .transforms.transforms import switch as switch from .transforms.transforms import checkify as checkify from .transforms.iteration import while_loop as while_loop from .transforms.iteration import fori_loop as fori_loop from .transforms.iteration import StateAxes as StateAxes from .transforms.iteration import transform_metadata as transform_metadata from .variablelib import A as A from .variablelib import BatchStat as BatchStat from .variablelib import Cache as Cache from .variablelib import Intermediate as Intermediate from .variablelib import Perturbation as Perturbation from .variablelib import Variable as Variable from .variablelib import VariableMetadata as VariableMetadata from .variablelib import with_metadata as with_metadata from .variablelib import variable_type_from_name as variable_type_from_name from .variablelib import variable_name_from_type as variable_name_from_type from .variablelib import register_variable_name as register_variable_name from .variablelib import use_eager_sharding as use_eager_sharding, using_eager_sharding as using_eager_sharding, var_defaults as var_defaults from .visualization import display as display from .extract import to_tree as to_tree from .extract import from_tree as from_tree from .extract import NodeStates as NodeStates from .summary import tabulate as tabulate from . import traversals as traversals from . import graphlib as graphlib # import last to prevent potential import cycles from . import graph as graph from . import compat as compat import typing as _tp if _tp.TYPE_CHECKING: VariableState = Variable else: def __getattr__(name): if name == "VariableState": import warnings warnings.warn( "'VariableState' was removed, this is just an alias to 'Variable'. " "Plase use 'Variable' directly instead.", DeprecationWarning, stacklevel=2, ) return Variable raise AttributeError(f"Module {__name__} has no attribute '{name}'") ================================================ FILE: flax/nnx/bridge/__init__.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from .wrappers import functional as functional from .wrappers import Functional as Functional from .wrappers import ToNNX as ToNNX from .wrappers import lazy_init as lazy_init from .wrappers import ToLinen as ToLinen from .wrappers import to_linen as to_linen from .variables import NNXMeta as NNXMeta from .variables import with_partitioning as with_partitioning from .module import Module as Module from .module import Scope as Scope from .module import AttrPriority as AttrPriority from .module import compact as compact from .module import current_context as current_context from .module import current_module as current_module from .interop import nnx_in_bridge_mdl as nnx_in_bridge_mdl from .interop import linen_in_bridge_mdl as linen_in_bridge_mdl from flax.nnx.nn import initializers as initializers ================================================ FILE: flax/nnx/bridge/interop.py ================================================ # Copyright 2025 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import typing as tp from flax.linen import module as nn_module from flax.nnx import graphlib, rnglib from flax.nnx.bridge import wrappers from flax.nnx.bridge import module as bdg_module import flax.nnx.module as nnx_module from flax.nnx.transforms.transforms import eval_shape as nnx_eval_shape from flax.nnx.transforms.compilation import jit as nnx_jit def nnx_in_bridge_mdl(factory: tp.Callable[[rnglib.Rngs], nnx_module.Module], name: str | None = None) -> nnx_module.Module: """Make pure NNX modules a submodule of a bridge module. Create module at init time, or make abstract module and let parent bind it with its state. Use current bridge module scope for RNG generation. Args: factory: a function that takes an `nnx.Rngs` arg and returns an NNX module. name: the name of the module. Only used during `bridge.compact` functions; in setup() function the user will set it to an attribute explicitly. Returns: A submodule (`nnx.Module`) of the bridge module. """ parent_ctx, parent = bdg_module.current_context(), bdg_module.current_module() assert parent_ctx is not None and parent is not None, 'nnx_in_bridge_mdl() only needed inside bridge Module' parent = parent_ctx.module assert parent.scope is not None if parent.is_initializing(): module = factory(parent.scope.rngs) else: rngs = parent.scope.rngs if parent.scope.rngs else rnglib.Rngs(7) # dummy module = nnx_eval_shape(factory, rngs, graph=True) @nnx_jit def rng_state(rngs): return graphlib.state(factory(rngs), rnglib.RngState, graph=True) # Make sure the internal rng state is not abstract - other vars shall be if parent.scope.rngs: graphlib.update(module, rng_state(parent.scope.rngs)) # Automatically set the attribute if compact. If setup, user is responsible # for setting the attribute of the superlayer. if parent_ctx.in_compact: if name is None: name = bdg_module._auto_submodule_name(parent_ctx, type(module)) setattr(parent, name, module) return module def linen_in_bridge_mdl(linen_module: nn_module.Module, name: str | None = None) -> nnx_module.Module: """Make Linen modules a submodule of a bridge module using wrappers.ToNNX(). Args: linen_module: the underlying Linen module instance. name: the name of the module. Only used during `bridge.compact` functions; in setup() function the user will set it to an attribute explicitly. Returns: A submodule (`nnx.Module`) of the bridge module. """ parent_ctx, parent = bdg_module.current_context(), bdg_module.current_module() assert parent_ctx is not None and parent is not None, 'linen_in_bridge_mdl() only needed inside bridge Module' assert parent.scope is not None module = wrappers.ToNNX(linen_module, parent.scope.rngs) wrappers._set_initializing(module, parent.is_initializing()) if parent_ctx.in_compact: if name is None: name = bdg_module._auto_submodule_name(parent_ctx, type(linen_module)) setattr(parent, name, module) return module ================================================ FILE: flax/nnx/bridge/module.py ================================================ # Copyright 2025 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from collections import defaultdict import dataclasses import enum import functools import inspect import threading import typing as tp import jax import typing_extensions as tpe from flax import errors from flax.core import meta from flax.core.scope import CollectionFilter from flax.core.frozen_dict import FrozenDict from flax.nnx import graphlib, rnglib, statelib, traversals from flax.nnx import variablelib import flax.nnx.module as nnx_module from flax.nnx.pytreelib import Pytree from flax.nnx import variablelib from flax.nnx.bridge import variables as bridge_variables import numpy as np A = tp.TypeVar('A') M = tp.TypeVar('M', bound='Module') F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) @dataclasses.dataclass class ModuleStackEntry: module: Module in_compact: bool type_counter: defaultdict[type, int] = dataclasses.field( default_factory=lambda: defaultdict(int) ) @dataclasses.dataclass class ModuleContext(threading.local): module_stack: list[ModuleStackEntry | None] = dataclasses.field( default_factory=lambda: [None] ) MODULE_CONTEXT = ModuleContext() class ModuleState(statelib.State): pass from flax.nnx.pytreelib import register_data_type register_data_type(ModuleState) class Scope(Pytree): def __init__(self, rngs: rnglib.Rngs, mutable: CollectionFilter): self.rngs = rngs self.mutable = mutable def copy(self): return Scope(self.rngs, self.mutable) class _HasSetup(tp.Protocol): def setup(self) -> None: ... def has_setup(x: tp.Any) -> tp.TypeGuard[_HasSetup]: return hasattr(x, 'setup') def _maybe_call_setup(module: Module): if ( has_setup(module) and isinstance(module, Module) and not module._pytree__state.is_setup ): # void parent context MODULE_CONTEXT.module_stack.append( ModuleStackEntry(module, in_compact=False) ) try: module.setup() # type: ignore[attribute-error] module._pytree__state._is_setup = True finally: MODULE_CONTEXT.module_stack.pop() def _bind_module(parent: Module, module: Module) -> Module: assert parent.scope is not None for _, value in reversed(list(graphlib.iter_graph(module, graph=True))): if isinstance(value, Module): if module.scope is None: value.scope = parent.scope.copy() # type: ignore[attribute-error] _maybe_call_setup(value) return module def current_context() -> ModuleStackEntry | None: return MODULE_CONTEXT.module_stack[-1] def current_module() -> Module | None: """A quick util to get the current bridge module.""" ctx = current_context() if ctx is None: return None return ctx.module def _auto_submodule_name(parent_ctx, cls): """Increment type count and generate a new submodule name.""" type_index = parent_ctx.type_counter[cls] parent_ctx.type_counter[cls] += 1 return f'{cls.__name__}_{type_index}' class ModuleMeta(nnx_module.ModuleMeta): def _pytree_meta_construct(cls, self, *args, **kwargs): vars(self)['scope'] = None super()._pytree_meta_construct(self, *args, **kwargs) def _module_meta_call(cls: type[M], *args, **kwargs) -> M: # compact behavior parent_ctx = MODULE_CONTEXT.module_stack[-1] parent = None module: M name = None if parent_ctx is not None: if 'parent' in kwargs: parent = kwargs.pop('parent') if parent_ctx.in_compact and parent is not None: raise ValueError( f"'parent' can only be set to None, got {type(parent).__name__}" ) else: parent = parent_ctx.module if 'name' in kwargs: name = kwargs['name'] if not 'name' in inspect.get_annotations(cls): kwargs.pop('name') if not isinstance(name, str): raise ValueError(f"'name' must be a 'str', got {type(name).__name__}") elif parent_ctx.in_compact: name = _auto_submodule_name(parent_ctx, cls) module = nnx_module.ModuleMeta.__call__(cls, *args, **kwargs) module.scope = None module.attr_priorities = {} if parent is not None: assert parent.scope is not None # compact, or setup if `name` exists if name is not None: setattr(parent, name, module) parent.set_attr_priority(name, AttrPriority.INIT_PARENT) return module # type: ignore # set ModuleMeta.__call__ like this because pytype doesn't understand # the use of TYPE_CHECKING conditionals for metaclass methods ModuleMeta.__call__ = _module_meta_call # type: ignore class AttrPriority(enum.IntEnum): HIGH = 0 INIT_PARENT = 20 DEFAULT = 50 LOW = 100 class PriorityStr(str): _priority: AttrPriority def __new__(cls, priority: AttrPriority, value: str): obj = super().__new__(cls, value) obj._priority = priority return obj def _check_and_get_priority(self, other) -> AttrPriority: if not isinstance(other, (str, PriorityStr)): raise NotImplementedError( f'Cannot compare {type(self)} with {type(other)}' ) if isinstance(other, PriorityStr): return other._priority return AttrPriority.DEFAULT def __lt__(self, other) -> bool: other_priority = self._check_and_get_priority(other) if self._priority == other_priority: return super().__lt__(other) return self._priority < other_priority def __gt__(self, other) -> bool: other_priority = self._check_and_get_priority(other) if self._priority == other_priority: return super().__gt__(other) return self._priority > other_priority class ModuleBase: if tp.TYPE_CHECKING: scope: Scope | None attr_priorities: dict[str, AttrPriority] @tpe.dataclass_transform(field_specifiers=(dataclasses.field,)) # type: ignore[not-supported-yet] class Module(nnx_module.Module, ModuleBase, metaclass=ModuleMeta): def __init_subclass__(cls) -> None: super().__init_subclass__(pytree=False) cls = dataclasses.dataclass(repr=False)(cls) cls.__hash__ = object.__hash__ # type: ignore[method-assign] def __getattribute__(self, name: str): return type(self)._getattr(self, name) def _getattr(self, name: str) -> tp.Any: value = super().__getattribute__(name) if isinstance(value, ModuleState): raise AttributeError return value def _setattr(self, name: str, value: tp.Any) -> None: if self.scope is not None: if name in vars(self) and isinstance( state := vars(self)[name], ModuleState ): graphlib.update(value, state) for leaf in jax.tree.leaves(value, is_leaf=graphlib.is_graph_node): if isinstance(leaf, Module): leaf._pytree__state._initializing = self.is_initializing() _bind_module(self, leaf) super()._setattr(name, value) def _graph_node_flatten(self): nodes = vars(self).copy() keys = ( PriorityStr(self.attr_priorities.get(k, AttrPriority.DEFAULT), k) for k in nodes.keys() ) sorted_nodes = list((k, nodes[k]) for k in sorted(keys)) return sorted_nodes, type(self) def set_attr_priority(self, name: str, value: AttrPriority): self.attr_priorities[name] = value def make_rng(self, name: str = 'default') -> jax.Array: if self.scope is None: raise ValueError("Can't use RNGs on unbound modules") return self.scope.rngs[name]() # type: ignore[attribute-error] def param( # type: ignore[invalid-annotation] self, name: str, init_fn: tp.Callable[..., A], *init_args, unbox: bool = True, **init_kwargs, ) -> variablelib.Param[A]: # TODO(cgarciae): implement same condition as linen if self.scope is None: raise ValueError( 'Parameters must be initialized in `setup()` or in a method ' 'wrapped in `@compact`' ) if hasattr(self, name): value = getattr(self, name) # TODO(cgarciae): implement reservations # if self._name_taken(name): # raise errors.NameInUseError('param', name, self.__class__.__name__) if isinstance(value, variablelib.Variable): if not isinstance(value, variablelib.Param): raise ValueError( f"Expected '{name}' to be a Param, got {type(value).__name__}" ) return value abs_value = jax.eval_shape( lambda: init_fn(jax.random.key(0), *init_args, **init_kwargs) ) abs_value_flat = jax.tree.leaves(abs_value) value_flat = jax.tree.leaves(value) for val, abs_val in zip(value_flat, abs_value_flat): if np.shape(val) != np.shape(abs_val): raise errors.ScopeParamShapeError( name, '', np.shape(abs_val), np.shape(val) ) if isinstance(abs_value, variablelib.VariableMetadata): abs_value.raw_value = value value = abs_value variable = variablelib.Param(value) else: value = init_fn(self.make_rng('params'), *init_args, **init_kwargs) variable = variablelib.Param(value) setattr(self, name, variable) return variable def variable( # type: ignore[invalid-annotation] self, col: str, name: str, init_fn: tp.Callable[..., A] | None = None, *init_args, unbox: bool = True, **init_kwargs, ) -> variablelib.Variable[A]: variable_type = variablelib.variable_type_from_name( col, allow_register=True ) if self.scope is None: raise ValueError( 'Variables must be initialized in `setup()` or in a method ' 'wrapped in `@compact`' ) if hasattr(self, name): value = getattr(self, name) # TODO(cgarciae): implement reservations # if self._name_taken(name): # raise errors.NameInUseError('param', name, self.__class__.__name__) if isinstance(value, variablelib.Variable): return value if init_fn is None: raise ValueError(f"Expected 'init_fn' to be a callable, got None") abs_value = jax.eval_shape(lambda: init_fn(*init_args, **init_kwargs)) abs_value_flat = jax.tree.leaves(abs_value) value_flat = jax.tree.leaves(value) for val, abs_val in zip(value_flat, abs_value_flat): if np.shape(val) != np.shape(abs_val): raise errors.ScopeParamShapeError( name, '', np.shape(abs_val), np.shape(val) ) if isinstance(abs_value, variablelib.VariableMetadata): abs_value.raw_value = value value = abs_value variable = variable_type(value) else: if init_fn is None: raise ValueError(f"Expected 'init_fn' to be a callable, got None") value = init_fn(*init_args, **init_kwargs) variable = variable_type(value) setattr(self, name, variable) return variable def _get_variables(self) -> tp.Mapping: state = graphlib.state(self, graph=True) _variables: dict = {} variable: variablelib.Variable for path, variable in statelib.to_flat_state(state): if isinstance(variable, rnglib.RngState): # Don't return RNG states, since Linen doesn't have them. continue try: collection = variablelib.variable_name_from_type(type(variable)) except ValueError: collection = type(variable).__name__ if collection not in _variables: _variables[collection] = {} if isinstance( variable, variablelib.Variable ) and bridge_variables.is_vanilla_variable(variable): leaf = variable.get_value() else: leaf = bridge_variables.to_linen_var(variable) _variables[collection][path] = leaf _variables = { collection: traversals.unflatten_mapping(flat_state) for collection, flat_state in _variables.items() } return _variables @property def variables(self): _variables = FrozenDict(self._get_variables()) return _variables def apply( self, variables: dict[str, tp.Mapping], *args, rngs: int | jax.Array | dict[str, jax.Array] | rnglib.Rngs | None = None, method: tp.Callable[..., tp.Any] | str = '__call__', mutable: CollectionFilter = False, _initialize: bool = False, **kwargs, ) -> tp.Any: module = graphlib.clone(self, graph=True) # create variables real_variables = dict(variables) for col_name, linen_collection in variables.items(): def to_variable(value): return bridge_variables.to_nnx_var(col_name, value) linen_collection = jax.tree.map( to_variable, linen_collection, is_leaf=lambda x: isinstance(x, meta.AxisMetadata), ) real_variables[col_name] = linen_collection states = ({},) if not real_variables else real_variables.values() state = statelib.merge_state(*states, cls=ModuleState) graphlib.update(module, state) if rngs is None: rngs = rnglib.Rngs() elif isinstance(rngs, jax.Array | int): rngs = rnglib.Rngs(rngs) elif isinstance(rngs, dict): rngs = rnglib.Rngs(**rngs) # get method _method: tp.Callable[..., tp.Any] if isinstance(method, str): attribute_name = method _method = getattr(module, attribute_name) if not callable(_method): class_name = type(module).__name__ raise TypeError( f"'{class_name}.{attribute_name}' must be a callable, got" f' {type(_method)}.' ) # if the `method` string is a submodule, we create a lambda function # that calls the submodule, forwarding all arguments. if isinstance(_method, Module): _method = lambda module, *args, **kwargs: getattr( module, attribute_name )(*args, **kwargs) else: _method = method _method = _get_unbound_fn(_method) # set temporary state for _, value in graphlib.iter_graph(module, graph=True): if isinstance(value, Pytree): value._pytree__state._initializing = _initialize if isinstance(value, Module): value.scope = Scope(rngs, mutable) _maybe_call_setup(value) MODULE_CONTEXT.module_stack.append( ModuleStackEntry(module, in_compact=False) ) try: out = _method(module, *args, **kwargs) finally: MODULE_CONTEXT.module_stack.pop() # reset temporary state for _, value in graphlib.iter_graph(module, graph=True): if isinstance(value, Pytree): value._pytree__state._initializing = False if isinstance(value, Module): value.scope = None _variables: tp.Mapping = module._get_variables() if mutable is False: return out else: return out, _variables def init( self, rngs: int | jax.Array | dict[str, jax.Array] | rnglib.Rngs | None = None, *args, method: tp.Callable[..., tp.Any] | str = '__call__', **kwargs, ): out, variables = self.apply( {}, *args, _initialize=True, mutable=True, rngs=rngs, method=method, **kwargs, ) return variables def init_with_output( self, rngs: int | jax.Array | dict[str, jax.Array] | rnglib.Rngs | None = None, *args, method: tp.Callable[..., tp.Any] | str = '__call__', mutable: tp.Any = False, # capture_intermediates: bool | Callable[['Module', str], bool] = False, **kwargs, ) -> tuple[tp.Any, dict[str, tp.Mapping]]: return self.apply( {}, *args, rngs=rngs, method=method, mutable=True, _initialize=True, **kwargs, ) def is_initializing(self) -> bool: return self._pytree__state._initializing def compact(f: F) -> F: @functools.wraps(f) def compact_wrapper(self, *args, **kwargs): if not isinstance(self, Module): raise ValueError( f"Expected 'self' to be a nnx.bridge.Module, got {type(self).__name__}" ) MODULE_CONTEXT.module_stack.append(ModuleStackEntry(self, in_compact=True)) try: return f(self, *args, **kwargs) finally: MODULE_CONTEXT.module_stack.pop() return compact_wrapper # type: ignore def _get_unbound_fn(method_or_fn: tp.Callable) -> tp.Callable: if inspect.ismethod(method_or_fn) and isinstance( method_or_fn.__self__, Module ): # pytype: disable=attribute-error method_or_fn = method_or_fn.__func__ # pytype: disable=attribute-error if ( not callable(method_or_fn) or len(inspect.signature(method_or_fn).parameters) < 1 ): raise errors.ApplyModuleInvalidMethodError(method_or_fn) return method_or_fn ================================================ FILE: flax/nnx/bridge/variables.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, TypeVar import typing as tp import jax from flax import struct from flax.core import meta from flax.nnx import spmd from flax.nnx import traversals from flax.nnx import variablelib from flax.typing import LogicalNames A = TypeVar('A') B = TypeVar('B') def sort_variable_types(types: tp.Iterable[type]): def _variable_parents_count(t: type): return sum(1 for p in t.mro() if issubclass(p, variablelib.Variable)) parent_count = {t: _variable_parents_count(t) for t in types} return sorted(types, key=lambda t: -parent_count[t]) ############################################# ### NNX Variable <-> Linen metadata boxes ### ############################################# class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]): """Default Flax metadata class for `nnx.Variable`.""" var_type: type[variablelib.Variable[tp.Any]] = struct.field(pytree_node=False) value: Any = struct.field(pytree_node=True) metadata: dict[str, tp.Any] = struct.field(pytree_node=False) def unbox(self) -> A: return self.value def replace_boxed(self, val: B) -> 'NNXMeta[B]': return self.replace(value=val) # type: ignore def add_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]': # TODO: implement this, supporting hooks return self def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]': # TODO: implement this, supporting hooks return self def get_partition_spec(self) -> jax.sharding.PartitionSpec: """Returns the ``Partitionspec`` for this partitioned value.""" nnx_var = self.to_nnx_variable() spec = spmd.get_partition_spec(nnx_var).get_raw_value() assert isinstance(spec, jax.sharding.PartitionSpec) return spec def to_nnx_variable(self) -> variablelib.Variable: return self.var_type(self.value, **self.metadata) def is_vanilla_variable(vs: variablelib.Variable) -> bool: """A variable is vanilla if its metadata is essentially blank. Returns False only if it has non-empty hooks or any non-built-in attribute. """ for key, value in vs.get_metadata().items(): if key in variablelib.Variable.required_metadata: continue if key.endswith('_hooks') and value == (): continue return False return True def to_linen_var(vs: variablelib.Variable) -> meta.AxisMetadata: metadata = vs.get_metadata() if 'linen_meta_type' in metadata: linen_type = metadata['linen_meta_type'] if hasattr(linen_type, 'from_nnx_metadata'): return linen_type.from_nnx_metadata({'value': vs.get_value(), **metadata}) return linen_type(vs.get_value(), **metadata) if is_vanilla_variable(vs): return vs.get_value() return NNXMeta(type(vs), vs.get_value(), metadata) def get_col_name(keypath: tp.Sequence[Any]) -> str: """Given the keypath of a Flax variable type, return its Linen collection name.""" # Infer variable type from the leaf's path, which contains its Linen collection name assert isinstance(keypath[0], jax.tree_util.DictKey) return str(keypath[0].key) def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variablelib.Variable: """Convert a Linen variable to an NNX variable.""" vtype = variablelib.variable_type_from_name(col, allow_register=True) if isinstance(x, NNXMeta): assert vtype == x.var_type, f'Type stored in NNXMeta {x.var_type} != type inferred from collection name {vtype}' return x.to_nnx_variable() if isinstance(x, meta.AxisMetadata): x_metadata = vars(x) if hasattr(x, 'to_nnx_metadata'): x_metadata = x.to_nnx_metadata() assert hasattr(x, 'value') return vtype(**x_metadata, linen_meta_type=type(x)) return vtype(x) def _recursive_merge(dict1, dict2): """Recursively merge two dicts.""" flat_map = traversals.flatten_mapping(dict1) flat_map |= traversals.flatten_mapping(dict2) return traversals.unflatten_mapping(flat_map) def linen_vars_to_nnx_attrs(variables: tp.Mapping[str, Any]) -> dict[str, Any]: """Convert a dict of Linen-style variables to NNX variables.""" nnx_vars = jax.tree_util.tree_map_with_path( lambda kp, x: to_nnx_var(get_col_name(kp), x), variables, is_leaf=lambda x: not isinstance(x, dict), ) flat_paths: dict[tuple, tp.Any] = {} for col_name, col_variables in nnx_vars.items(): # pylint: disable=unused-variable for path, variable in traversals.flatten_mapping(col_variables).items(): if path in flat_paths: raise ValueError( f"Found duplicate variable path {path} with variables " f"{flat_paths[path]} and {variable}. " "This is not allowed in NNX." ) flat_paths[path] = variable nnx_vars = traversals.unflatten_mapping(flat_paths) return nnx_vars def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict: """Convert a dict of NNX variables to Linen-style variables.""" linen_structured = {} for kp, v in traversals.flatten_mapping(nnx_attrs).items(): if isinstance(v, variablelib.Variable): col_name = variablelib.variable_name_from_type(type(v)) v = to_linen_var(v.to_state()) else: raise ValueError(f'Cannot infer collection name from value: {v}') linen_structured[(col_name, *kp)] = v variables = traversals.unflatten_mapping(linen_structured) return variables def with_partitioning( fn: tp.Callable[..., tp.Any], names: LogicalNames, mesh: jax.sharding.Mesh | None = None, ) -> tp.Callable[..., meta.Partitioned[tp.Any]]: """Same interface as Linen, but calls NNX `with_partitioning` within.""" return spmd.with_partitioning(fn, names, mesh, linen_meta_type=meta.Partitioned) ================================================ FILE: flax/nnx/bridge/wrappers.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from functools import partial import typing as tp from typing import Any import warnings import dataclasses from flax import linen from flax import core from flax import nnx from flax.core import FrozenDict from flax.core import meta from flax.nnx import graphlib from flax.nnx import variablelib from flax.nnx.bridge import variables as bv from flax.nnx.bridge import module as bdg_module from flax.nnx.module import Module from flax.nnx.statelib import State from flax.nnx.pytreelib import Pytree from flax.nnx.rnglib import Rngs import jax from jax import tree_util as jtu M = tp.TypeVar('M', bound=Module) # Flax-like style is NNX @dataclasses.dataclass class Functional(tp.Generic[M]): module_type: tp.Type[M] graphdef: tp.Optional[graphlib.GraphDef[M]] args: tuple[tp.Any, ...] kwargs: dict[str, tp.Any] def init(self, *, rngs: tp.Optional[Rngs] = None) -> State: kwargs = {} if rngs is not None: kwargs['rngs'] = rngs module = self.module_type(*self.args, **self.kwargs, **kwargs) graphdef, state = nnx.split(module) self.graphdef = graphdef return state # type: ignore def apply(self, *states: tp.Any): assert self.graphdef is not None return self.graphdef.apply(*states) def functional(cls: tp.Type[M]) -> tp.Callable[..., Functional[M]]: def _functional_constructor(*args: tp.Any, **kwargs: tp.Any) -> Functional[M]: return Functional(cls, None, args, kwargs) return _functional_constructor def _set_initializing(module: Module, initializing: bool): for _, value in graphlib.iter_graph(module, graph=True): if isinstance(value, Pytree): value._pytree__state._initializing = initializing def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs): """To run through an arbitrary nnx.Module method and initialize all its needed state. Here used to trigger initialization of all `LinenToNNX` module variables.""" if isinstance(fn, Module): module = fn assert callable(fn) else: if not (hasattr(fn, '__self__') and isinstance(fn.__self__, Module)): raise ValueError(f'{fn = } needs to be a method of an NNX Module.') module = fn.__self__ _set_initializing(module, True) try: _ = fn(*args, **kwargs) finally: _set_initializing(module, False) return fn def current_linen_module() -> linen.Module | None: """Get the current Linen module from the Linen context.""" if linen.module._context.module_stack: # pylint: disable=W0212 return linen.module._context.module_stack[-1] # pylint: disable=W0212 return None class ToNNX(Module): """A wrapper to turn any Linen module into an NNX module. The result NNX module can be used standalone with all NNX APIs, or as a submodule of another NNX module. Since Linen module initialization requires a sample input, you need to call `lazy_init` with an argument to initialize the variables. Example:: >>> from flax import linen as nn, nnx >>> import jax >>> linen_module = nn.Dense(features=64) >>> x = jax.numpy.ones((1, 32)) >>> # Like Linen init(), initialize with a sample input >>> model = nnx.bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x) >>> # Like Linen apply(), but using NNX's direct call method >>> y = model(x) >>> model.kernel.shape (32, 64) Args: module: The Linen Module instance. rngs: The `nnx.Rngs` instance being passed to any NNX module. Returns: A stateful NNX module that behaves the same as the wrapped Linen module. """ def __init__( self, module: linen.Module, rngs: Rngs | jax.Array | None = None, ): self.to_nnx__module = module self.to_nnx__rngs: Rngs | None if isinstance(rngs, jax.Array): self.to_nnx__rngs = Rngs(params=rngs) elif isinstance(rngs, nnx.Rngs): self.to_nnx__rngs = rngs.fork() else: self.to_nnx__rngs = rngs @property def rngs(self) -> Rngs | None: warnings.warn( '`ToNNX.rngs` is deprecated. Please use `to_nnx__rngs` instead.', DeprecationWarning, ) return self.to_nnx__rngs @property def module(self) -> linen.Module: warnings.warn( '`ToNNX.module` is deprecated. Please use `to_nnx__module` instead.', DeprecationWarning, ) return self.to_nnx__module def _setattr(self, name, value): if not nnx.is_data(value) and nnx.has_data(value): value = nnx.data(value) super()._setattr(name, value) def lazy_init(self, *args, **kwargs): """A shortcut of calling `nnx.bridge.lazy_init()` upon this module.""" return lazy_init(self, *args, **kwargs) def __getattr__(self, name: str): if hasattr(super(), name): return super().__getattribute__(name) maybe_method = getattr(type(self.to_nnx__module), name, None) if callable(maybe_method): method = partial(self.__call__, method=maybe_method) method.__self__ = self return method return super().__getattribute__(name) def __call__( self, *args: Any, rngs: Rngs | jax.Array | None = None, method: tp.Callable[..., Any] | str | None = None, mutable: tp.Any = None, **kwargs: Any, ) -> Any: # Shape-based lazy init of the flax variables if rngs is None: rngs = self.to_nnx__rngs if isinstance(rngs, nnx.Rngs): _rngs = {name: stream() for name, stream in rngs.items()} elif isinstance(rngs, jax.Array): _rngs = {'params': rngs} else: _rngs = {} # rename default to params if 'params' not in _rngs and 'default' in _rngs: _rngs['params'] = _rngs.pop('default') if self._pytree__state.initializing: out, updates = self.to_nnx__module.init_with_output(_rngs, *args, method=method, **kwargs) else: nnx_attrs = { k: v for k, v in vars(self).items() if not k.startswith('to_nnx__') and not k.startswith('_pytree__') } variables = bv.nnx_attrs_to_linen_vars(nnx_attrs) # Get `mutable` from top level bridge.Module context if any if mutable is not None: pass elif (m := bdg_module.current_module()) is not None: # type: ignore[assignment] assert m.scope is not None mutable = m.scope.mutable elif (m := current_linen_module()) is not None: # type: ignore[assignment] assert m.scope is not None mutable = m.scope.mutable else: mutable = False out = self.to_nnx__module.apply( variables, *args, rngs=_rngs, method=method, mutable=mutable, **kwargs ) # Split out the updates if `mutable` is passed into the Flax module if mutable is not False: out, updates = out else: updates = None # Split out the updates if `mutable` is passed into the Flax module if updates: nnx_attrs = bv.linen_vars_to_nnx_attrs(updates) # nnx.update(self, nnx_attrs) # TODO(cgarciae): ideally we just do an update but currently dictionaries don't allow # insertion of new keys, we need to enable this in NNX to simplify the code bellow # to the simple nnx.update(self, nnx_attrs) above. for attr_name, value in nnx_attrs.items(): if hasattr(self, attr_name) and isinstance(value, dict): original_value = getattr(self, attr_name) new_values = bv._recursive_merge(original_value, value) setattr(self, attr_name, nnx.data(new_values)) else: setattr(self, attr_name, nnx.data(value)) return out def linen_rngs_dict(linen_module: linen.Module, add_default: bool = False): """Given a module, split out one of its every active RNG key collections.""" assert linen_module.scope is not None, 'linen_rngs_dict() must be called inside a Linen module.' rngs: dict[str, tp.Any] = { name: linen_module.make_rng(name) for name in linen_module.scope.rngs.keys() } if add_default and 'default' not in rngs: rngs['default'] = 0 return rngs def _get_module_method(module, method: tp.Callable[..., Any] | str | None): """Get a callable method from the module, or raise TypeError.""" if method is None: method = '__call__' if isinstance(method, str): attribute_name = method method = getattr(type(module), attribute_name) if not callable(method): class_name = type(module).__name__ raise TypeError( f"'{class_name}.{attribute_name}' must be a callable, got" f' {type(method)}.' ) if not callable(method): class_name = type(module).__name__ raise TypeError( f"'{method}' must be a callable, got {type(method)}." ) return method class ToLinen(linen.Module): """A wrapper to turn any NNX module into a Linen module. The result Linen module can be used standalone with all Linen APIs, or as a submodule of another Linen module. Since NNX modules are stateful and owns the state, we only create it once during init time, and will track its state and static data as separate variables. Example:: >>> from flax import linen as nn, nnx >>> import jax >>> model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64)) >>> x = jax.numpy.ones((1, 32)) >>> y, variables = model.init_with_output(jax.random.key(0), x) >>> y.shape (1, 64) >>> variables['params']['kernel'].shape (32, 64) >>> # The static GraphDef of the underlying NNX module >>> variables.keys() dict_keys(['params']) Args: nnx_class: The NNX Module class (not instance!). args: The arguments that normally would be passed in to create the NNX module. kwargs: The keyword arguments that normally would be passed in to create the NNX module. skip_rng: True if this NNX module doesn't need `rngs` arg during initialization (not common). Returns: A stateful NNX module that behaves the same as the wrapped Linen module. """ nnx_class: tp.Callable[..., Module] args: tp.Sequence = () kwargs: tp.Mapping[str, tp.Any] = FrozenDict({}) skip_rng: bool = False metadata_fn: tp.Callable[[variablelib.Variable], tp.Any] | None = bv.to_linen_var @linen.compact def __call__( self, *args, nnx_method: tp.Callable[..., Any] | str | None = None, **kwargs ): def _module_kwargs(): maybe_add_default = not self.is_initializing() module_kwargs = dict(self.kwargs) if not self.skip_rng: module_kwargs['rngs'] = nnx.Rngs( **linen_rngs_dict(self, add_default=maybe_add_default) ) return module_kwargs # init codepath if self.is_initializing(): module = self.nnx_class(*self.args, **_module_kwargs()) # TODO: add lazy_init here in case there's an `ToNNX` submodule under `module`. # update linen variables before call module to save initial state self._update_variables(module) method_fn = _get_module_method(module, nnx_method) out = method_fn(module, *args, **kwargs) return out # create the nnx module module = self.nnx_class(*self.args, **_module_kwargs()) # update nnx module from linen variables def maybe_unbox(x): if isinstance(x, meta.AxisMetadata): return x.unbox() return x states = jtu.tree_map( maybe_unbox, list(core.unfreeze(self.variables).values()), # type: ignore[wrong-arg-types, arg-type] is_leaf=lambda x: isinstance(x, meta.AxisMetadata), ) if not states: states = ({},) new_state = nnx.merge_state(*states) new_state_flat = nnx.traversals.flatten_mapping(new_state) current_state_flat = nnx.traversals.flatten_mapping(nnx.state(module)) unknown_state_flat = {path: v for path, v in new_state_flat.items() if path not in current_state_flat} if unknown_state_flat: paths_str = "" for path, _ in unknown_state_flat.items(): paths_str += f"\n - {'/'.join(map(str, path))}" warnings.warn(f"Found unknown module paths in incoming state:{paths_str}") nnx.update(module, new_state) method_fn = _get_module_method(module, nnx_method) out = method_fn(module, *args, **kwargs) self._update_variables(module) return out def __getattr__(self, name: str): if hasattr(super(), name): return super().__getattribute__(name) if name in self.kwargs: return self.kwargs[name] maybe_method = getattr(self.nnx_class, name, None) if callable(maybe_method): method = partial(self.__call__, nnx_method=maybe_method) method.__self__ = self return method return super().__getattribute__(name) def _update_variables(self, module): """Store the NNX module's graph def and state inside Linen module variables.""" state = nnx.state(module, nnx.Not(nnx.RngState)) collection_flat_state: dict[str, list[tuple[tuple[tp.Any, ...], tp.Any]]] = {} # group state by collection for path, leaf in nnx.to_flat_state(state): type_ = type(leaf) collection = variablelib.variable_name_from_type( type_, allow_register=True ) if collection not in collection_flat_state: collection_flat_state[collection] = [] collection_flat_state[collection].append((path, leaf)) # update linen variables for collection, flat_state in collection_flat_state.items(): if self.is_mutable_collection(collection): def _to_linen_var(x): if isinstance(x, nnx.Variable): if self.metadata_fn is not None: return self.metadata_fn(x) # pylint: disable=too-many-function-args else: return x.get_value() return x collection_state = nnx.traversals.unflatten_mapping(flat_state) collection_state = jax.tree.map( _to_linen_var, collection_state, is_leaf=lambda x: isinstance(x, nnx.Variable), ) for k, v in collection_state.items(): self.put_variable(collection, k, v) class _Missing: ... _MISSING = _Missing() def to_linen( nnx_class: tp.Callable[..., Module], *args, metadata_fn: ( tp.Callable[[variablelib.Variable], tp.Any] | None ) = bv.to_linen_var, name: str | None = None, skip_rng: bool = False, abstract_init: bool = True, **kwargs, ): """Shortcut of `nnx.bridge.ToLinen` if user is not changing any of its default fields.""" return ToLinen( nnx_class, args=args, kwargs=FrozenDict(kwargs), metadata_fn=metadata_fn, skip_rng=skip_rng, name=name, ) def to_linen_class( base_nnx_class: type[M], base_metadata_fn: tp.Callable[[variablelib.VariableState], tp.Any] | None = bv.to_linen_var, base_skip_rng: bool = False, **partial_kwargs: tp.Any, ) -> type[ToLinen]: """Dynamically wraps an NNX module class into a Flax Linen module class.""" class ToLinenPartial(ToLinen): """A dynamically created Linen Module that wraps a specific NNX Module. This class is not meant to be used directly. Instead, it is created and returned by the `to_linen_class` function. It acts as a "partially applied" version of the `ToLinen` wrapper, where the NNX module to be wrapped and its default arguments are pre-configured. When you instantiate this class, it behaves like a standard Linen module. The arguments you provide during instantiation can override the defaults that were set when this class was created by `to_linen_class`. For example: >>> from flax import linen as nn, nnx >>> from maxtext.src.maxtext.layers import linears >>> # Create a specialized Linen wrapper for linears.DenseGeneral >>> LinenDenseGeneral = to_linen_class(linears.DenseGeneral) >>> # Now, LinenDenseGeneral can be used like a regular Linen module >>> class MyModel(nn.Module): ... def setup(self): ... # Instantiate the wrapped linears.DenseGeneral with its arguments ... self.dense = LinenDenseGeneral( ... in_features_shape=10, out_features_shape=5 ... ) ... def __call__(self, x): ... return self.dense(x) Attributes: (The attributes are dynamically set by the `ToLinen` parent class based on the arguments provided during instantiation.) """ def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) def __init__( self, args=None, kwargs=None, nnx_class=None, skip_rng=None, metadata_fn=None, name=_MISSING, parent=_MISSING, **other_kwargs, ): linen_kwargs = {} if not isinstance(parent, _Missing): linen_kwargs["parent"] = parent if not isinstance(name, _Missing): linen_kwargs["name"] = name ToLinen.__init__( self, nnx_class=nnx_class or base_nnx_class, args=args or (), metadata_fn=metadata_fn or base_metadata_fn, skip_rng=skip_rng or base_skip_rng, kwargs=FrozenDict({**partial_kwargs, **(kwargs or {}), **other_kwargs}), **linen_kwargs, ) cls.__init__ = __init__ return ToLinenPartial ================================================ FILE: flax/nnx/compat.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Compat API. Compatibility wrappers for NNX APIs. Each function in this module mirrors the corresponding ``nnx.*`` API but enforces ``graph=True`` (and ``graph_updates=True`` for transforms), preserving the pre-tree-mode behavior. """ import functools from flax.nnx import graphlib as _graphlib from flax.nnx import module as _module from flax.nnx import rnglib as _rnglib from flax.nnx.transforms import autodiff as _autodiff from flax.nnx.transforms import compilation as _compilation from flax.nnx.transforms import iteration as _iteration from flax.nnx.transforms import transforms as _transforms from flax.nnx import spmd as _spmd # graphlib split = functools.partial(_graphlib.split, graph=True) state = functools.partial(_graphlib.state, graph=True) clone = functools.partial(_graphlib.clone, graph=True) graphdef = functools.partial(_graphlib.graphdef, graph=True) flatten = functools.partial(_graphlib.flatten, graph=True) iter_graph = functools.partial(_graphlib.iter_graph, graph=True) recursive_map = functools.partial(_graphlib.recursive_map, graph=True) # module view = functools.partial(_module.view, graph=True) view_info = functools.partial(_module.view_info, graph=True) iter_modules = functools.partial(_module.iter_modules, graph=True) iter_children = functools.partial(_module.iter_children, graph=True) # type: ignore[has-type] # rnglib split_rngs = functools.partial(_rnglib.split_rngs, graph=True) fork_rngs = functools.partial(_rnglib.fork_rngs, graph=True) reseed = functools.partial(_rnglib.reseed, graph=True) backup_keys = functools.partial(_rnglib.backup_keys, graph=True) # transforms - compilation jit = functools.partial(_compilation.jit, graph=True, graph_updates=True) shard_map = functools.partial( _compilation.shard_map, graph=True, graph_updates=True ) # transforms - autodiff grad = functools.partial(_autodiff.grad, graph=True, graph_updates=True) value_and_grad = functools.partial( _autodiff.value_and_grad, graph=True, graph_updates=True ) custom_vjp = functools.partial( _autodiff.custom_vjp, graph=True, graph_updates=True ) vjp = functools.partial(_autodiff.vjp, graph=True, graph_updates=True) jvp = functools.partial(_autodiff.jvp, graph=True, graph_updates=True) remat = functools.partial(_autodiff.remat, graph=True, graph_updates=True) # transforms - iteration vmap = functools.partial(_iteration.vmap, graph=True, graph_updates=True) scan = functools.partial(_iteration.scan, graph=True, graph_updates=True) pmap = functools.partial(_iteration.pmap, graph=True, graph_updates=True) while_loop = functools.partial( _iteration.while_loop, graph=True, graph_updates=True ) fori_loop = functools.partial( _iteration.fori_loop, graph=True, graph_updates=True ) # transforms - general eval_shape = functools.partial( _transforms.eval_shape, graph=True, graph_updates=True ) cond = functools.partial(_transforms.cond, graph=True, graph_updates=True) switch = functools.partial(_transforms.switch, graph=True, graph_updates=True) checkify = functools.partial( _transforms.checkify, graph=True, graph_updates=True ) # spmd get_abstract_model = functools.partial(_spmd.get_abstract_model, graph=True) ================================================ FILE: flax/nnx/extract.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import abc import functools import typing as tp from collections import namedtuple import jax from flax import struct from flax import typing from flax.nnx.pytreelib import Pytree from flax.typing import Missing, PathParts from flax.nnx import graphlib, variablelib A = tp.TypeVar('A') Index = int KeyPath = tuple[tp.Hashable, ...] Prefix = tp.Any Leaf = tp.Any class PrefixMapping(abc.ABC): @abc.abstractmethod def map_prefix( self, path: typing.PathParts, variable: variablelib.Variable, /, ) -> tp.Any: ... def check_consistent_aliasing( node: tp.Any, prefix: tp.Any, /, *, node_prefixes: dict[int, list[tuple[PathParts, tp.Any]]] | None = None, ): """Check for consistent aliasing of nodes when extracting graph.""" if node_prefixes is None: node_prefixes = {} # Store variable references for error messages node_id_to_variable: dict[int, tp.Any] = {} # collect all paths and prefixes for each node for path, value in graphlib.iter_graph(node, graph=True): if graphlib.is_graph_node(value) or isinstance(value, graphlib.Variable): if isinstance(value, Pytree): value._check_valid_context( lambda: f'Trying to extract graph node from different trace level, got {value!r}' ) if isinstance(value, graphlib.Variable): if not value._can_update: raise ValueError( f'Cannot extract graph node from different trace level, got {value!r}' ) if isinstance(prefix, PrefixMapping): variable_prefix = prefix.map_prefix(path, value) else: variable_prefix = prefix value_id = id(value) node_id_to_variable[value_id] = value if value_id in node_prefixes: paths_prefixes = node_prefixes[value_id] paths_prefixes.append((path, variable_prefix)) else: node_prefixes[value_id] = [(path, variable_prefix)] # check for inconsistent aliasing node_msgs = [] for node_id, paths_prefixes in node_prefixes.items(): unique_prefixes = {prefix for _, prefix in paths_prefixes} if len(unique_prefixes) > 1: path_prefix_repr = '\n'.join( f' {"/".join(map(str,path)) if path else ""}: {prefix}' for path, prefix in paths_prefixes ) # Get the variable type name if available if node_id in node_id_to_variable: variable = node_id_to_variable[node_id] node_type_name = type(variable).__name__ else: node_type_name = f'Node ID: {node_id}' nodes_msg = f'Node: {node_type_name}\n{path_prefix_repr}' node_msgs.append(nodes_msg) if node_msgs: raise ValueError( 'Inconsistent aliasing detected. The following nodes have different prefixes:\n' + '\n'.join(node_msgs) ) def check_consistent_aliasing2( node: tp.Any, prefix: tp.Any, /, *, base_path: tuple[tp.Any, ...] = (), node_prefixes: dict[int, list[tuple[PathParts, tp.Any]]] | None = None, ): if node_prefixes is None: node_prefixes = {} node_id_to_variable: dict[int, tp.Any] = {} for path, value in graphlib.iter_graph(node, graph=True): path = base_path + path if graphlib.is_graph_node(value) or isinstance(value, graphlib.Variable): value_id = id(value) node_id_to_variable[value_id] = value if value_id in node_prefixes: node_prefixes[value_id].append((path, prefix)) else: node_prefixes[value_id] = [(path, prefix)] node_msgs = [] for node_id, paths_prefixes in node_prefixes.items(): unique_prefixes = {p for _, p in paths_prefixes} if len(unique_prefixes) > 1: path_prefix_repr = '\n'.join( f' {"/".join(map(str,path)) if path else ""}: {p}' for path, p in paths_prefixes ) if node_id in node_id_to_variable: variable = node_id_to_variable[node_id] node_type_name = type(variable).__name__ else: node_type_name = f'Node ID: {node_id}' node_msgs.append(f'Node: {node_type_name}\n{path_prefix_repr}') if node_msgs: raise ValueError( 'Inconsistent aliasing detected. The following nodes have different prefixes:\n' + '\n'.join(node_msgs) ) # ----------------------------- # to_tree/from_tree # ----------------------------- def broadcast_prefix( prefix_tree: tp.Any, full_tree: tp.Any, prefix_is_leaf: tp.Callable[[tp.Any], bool] | None = None, tree_is_leaf: tp.Callable[[tp.Any], bool] | None = None, ) -> list[tp.Any]: # If prefix_tree is not a tree prefix of full_tree, this code can raise a # ValueError; use prefix_errors to find disagreements and raise more precise # error messages. result = [] num_leaves = lambda t: jax.tree_util.tree_structure( t, is_leaf=tree_is_leaf ).num_leaves add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree)) jax.tree.map( add_leaves, prefix_tree, full_tree, is_leaf=lambda x: isinstance(x, variablelib.Variable) or graphlib.is_graph_node(x) or (prefix_is_leaf is not None and prefix_is_leaf(x)), ) return result def broadcast_prefix2( prefix_tree: tp.Any, full_tree: tp.Any, is_leaf: tp.Callable[[tp.Any], bool] | None = None, ) -> tuple[list[KeyPath], list[tp.Any]]: paths: list[KeyPath] = [] leaves: list[tp.Any] = [] num_leaves = lambda t: jax.tree_util.tree_structure(t).num_leaves def add_leaves(path, x, subtree): n = num_leaves(subtree) paths.extend([path] * n) leaves.extend([x] * n) jax.tree.map_with_path(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf) return paths, leaves def broadcast_prefix_map( f: tp.Callable[..., tp.Any], prefix_tree: tp.Any, full_tree: tp.Any, *rest: tp.Any, is_leaf: tp.Callable[[tp.Any], bool] | None = None, ) -> tp.Any: paths, prefix_leaves = broadcast_prefix2(prefix_tree, full_tree, is_leaf=is_leaf) leaves, treedef = jax.tree_util.tree_flatten(full_tree, is_leaf=is_leaf) full_prefix_tree = treedef.unflatten(prefix_leaves) return jax.tree.map_with_path(f, full_prefix_tree, full_tree, *rest, is_leaf=is_leaf) class GraphDefState(struct.PyTreeNode): graphdef: graphlib.GraphDef[tp.Any] = struct.field(pytree_node=False) state: graphlib.GraphState = struct.field(pytree_node=True) S = tp.TypeVar( 'S', bound=graphlib.GraphState | graphlib.GraphFlatState | list[tp.Any] ) class NodeStates(struct.PyTreeNode): _graphdef: graphlib.GraphDef[tp.Any] | None states: tuple[tp.Any, ...] metadata: tp.Any = struct.field(pytree_node=False) @property def graphdef(self) -> graphlib.GraphDef[tp.Any]: if self._graphdef is None: raise ValueError('No graphdef available') return self._graphdef @property def state(self) -> tp.Any: if len(self.states) != 1: raise ValueError( f'Expected exactly one GraphDefState, got {len(self.states)}' ) return self.states[0] @classmethod def from_split( cls, graphdef: graphlib.GraphDef[tp.Any] | None, state: tp.Any, /, *states: tp.Any, metadata: tp.Any = None, ): return cls(_graphdef=graphdef, states=(state, *states), metadata=metadata) @classmethod def from_states( cls, state: tp.Any, *states: tp.Any, ): return cls(_graphdef=None, states=(state, *states), metadata=None) @classmethod def from_prefixes( cls, prefixes: tp.Iterable[tp.Any], /, *, metadata: tp.Any = None, ): return cls(_graphdef=None, states=tuple(prefixes), metadata=metadata) def default_split_fn( ctx: graphlib.SplitContext, path: KeyPath, prefix: Prefix, leaf: Leaf ) -> tp.Any: return NodeStates.from_split(*ctx.split(leaf)) def to_tree( tree, /, *, prefix: tp.Any = Missing, split_fn: tp.Callable[ [graphlib.SplitContext, KeyPath, Prefix, Leaf], tp.Any ] = default_split_fn, map_non_graph_nodes: bool = False, ctxtag: tp.Hashable | None = None, check_aliasing: bool = True, ) -> tp.Any: if prefix is Missing or prefix is None: # fast path, no need for prefix broadcasting or consistent aliasing checks with graphlib.split_context(ctxtag) as split_ctx: return jax.tree.map( lambda x: split_fn(split_ctx, (), prefix, x) if map_non_graph_nodes or graphlib.is_graph_node(x) or isinstance(x, variablelib.Variable) else x, tree, is_leaf=lambda x: isinstance(x, variablelib.Variable) or graphlib.is_graph_node(x), ) leaf_prefixes = broadcast_prefix( prefix, tree, prefix_is_leaf=lambda x: x is None or isinstance(x, variablelib.Variable) or graphlib.is_graph_node(x), tree_is_leaf=lambda x: isinstance(x, variablelib.Variable) or graphlib.is_graph_node(x), ) leaf_keys, treedef = jax.tree_util.tree_flatten_with_path( tree, is_leaf=lambda x: isinstance(x, variablelib.Variable) or graphlib.is_graph_node(x), ) assert len(leaf_keys) == len(leaf_prefixes) leaves_out = [] node_prefixes: dict[int, list[tuple[PathParts, tp.Any]]] = {} with graphlib.split_context(ctxtag) as split_ctx: for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes): if graphlib.is_graph_node(leaf) or isinstance(leaf, variablelib.Variable): if check_aliasing: check_consistent_aliasing( leaf, leaf_prefix, node_prefixes=node_prefixes ) tree_node = split_fn(split_ctx, keypath, leaf_prefix, leaf) leaves_out.append(tree_node) else: if map_non_graph_nodes: leaf = split_fn(split_ctx, keypath, leaf_prefix, leaf) leaves_out.append(leaf) pytree_out = jax.tree.unflatten(treedef, leaves_out) return pytree_out def to_tree2( tree, /, *, prefix: tp.Any = Missing, check_aliasing: bool = True, ) -> tp.Any: """to_tree2 has two main tasks: 1. Convert all graph nodes to NodeStates (a tree representation). 2. Check all Variables are aliased consistently given the prefix tree, e.g. vmap's in/out_axes arguments. Each NodeState contains the `GraphDef` and State for each object, these are generated using `graphlib.flatten`. `extract.broadcast_prefix` is used to calculate the prefix for each node, `check_consistent_aliasing2` traverses the nodes subgraph and checks for Variable aliasing. """ ref_index: graphlib.RefMap = graphlib.RefMap() def _to_node_states(leaf): if not (graphlib.is_graph_node(leaf) or isinstance(leaf, variablelib.Variable)): return leaf graphdef, flat_state = graphlib.flatten( leaf, ref_index=ref_index, graph=True ) (state,) = graphlib._to_nested_state(graphdef, (flat_state,)) return NodeStates.from_split(graphdef, state) is_leaf = lambda x: ( isinstance(x, variablelib.Variable) or graphlib.is_graph_node(x) ) if prefix is Missing or prefix is None: return jax.tree.map(_to_node_states, tree, is_leaf=is_leaf) leaf_prefixes = broadcast_prefix( prefix, tree, prefix_is_leaf=lambda x: x is None or is_leaf(x), tree_is_leaf=is_leaf, ) leaf_paths, treedef = jax.tree_util.tree_flatten_with_path(tree, is_leaf=is_leaf) assert len(leaf_paths) == len(leaf_prefixes) leaves_out = [] node_prefixes: dict[int, list[tuple[PathParts, tp.Any]]] = {} for (keypath, leaf), leaf_prefix in zip(leaf_paths, leaf_prefixes): if is_leaf(leaf): if check_aliasing: base_path = graphlib.jax_to_nnx_path(keypath) check_consistent_aliasing2( leaf, leaf_prefix, base_path=base_path, node_prefixes=node_prefixes ) leaves_out.append(_to_node_states(leaf)) else: leaves_out.append(leaf) return jax.tree.unflatten(treedef, leaves_out) def from_tree2(tree: tp.Any, /) -> tp.Any: index_ref = graphlib.IndexMap() def _from_node_states(x): if not isinstance(x, NodeStates): return x state = graphlib._merge_to_flat_state(x.states) return graphlib.unflatten( x.graphdef, state, index_ref=index_ref, ) return jax.tree.map( _from_node_states, tree, is_leaf=lambda x: ( isinstance(x, NodeStates) or graphlib.is_graph_node(x) or isinstance(x, variablelib.Variable) ), ) def merge_tree_node( ctx: graphlib.MergeContext, path: KeyPath, prefix: Prefix, leaf: Leaf ) -> tp.Any: if not isinstance(leaf, NodeStates): raise ValueError(f'Expected TreeNode, got {type(leaf)} at path {path}') return ctx.merge(leaf.graphdef, *leaf.states) def is_tree_node(x): return isinstance(x, NodeStates) def from_tree( tree: tp.Any, /, *, prefix: tp.Any = Missing, merge_fn: tp.Callable[ [graphlib.MergeContext, KeyPath, Prefix, Leaf], tp.Any ] = merge_tree_node, is_node_leaf: tp.Callable[[Leaf], bool] = is_tree_node, is_leaf: tp.Callable[[Leaf], bool] = is_tree_node, map_non_graph_nodes: bool = False, is_inner: bool | None = None, ctxtag: tp.Hashable | None = None, ) -> tp.Any: if prefix is Missing or prefix is None: # fast path, no need for prefix broadcasting or consistent aliasing checks with graphlib.merge_context(ctxtag, is_inner) as merge_ctx: def maybe_split(x): if ( map_non_graph_nodes or is_node_leaf(x) or isinstance(x, variablelib.Variable) ): return merge_fn(merge_ctx, (), prefix, x) return x return jax.tree.map(maybe_split, tree, is_leaf=is_leaf) leaf_prefixes = broadcast_prefix( prefix, tree, prefix_is_leaf=lambda x: x is None or is_leaf(x), tree_is_leaf=is_leaf, ) leaf_keys, treedef = jax.tree_util.tree_flatten_with_path( tree, is_leaf=is_leaf ) assert len(leaf_keys) == len(leaf_prefixes) leaves_out = [] with graphlib.merge_context(ctxtag, is_inner) as merge_ctx: for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes): if ( map_non_graph_nodes or is_node_leaf(leaf) or isinstance(leaf, variablelib.Variable) ): leaf = merge_fn(merge_ctx, keypath, leaf_prefix, leaf) leaves_out.append(leaf) pytree_out = jax.tree.unflatten(treedef, leaves_out) return pytree_out def clear_non_graph_nodes(tree): return jax.tree.map( lambda x: x if graphlib.is_graph_node(x) or isinstance(x, variablelib.Variable) else None, tree, is_leaf=lambda x: isinstance(x, variablelib.Variable) or graphlib.is_graph_node(x), ) class Mask(tp.NamedTuple): pass def mask_at(t: tuple, index: int | None) -> tuple: if index is None: return t return tuple( Mask() if i == index else x for i, x in enumerate(t) ) def replace_at(t: tuple, index: int, value: tp.Any) -> tuple: return tuple( value if i == index else x for i, x in enumerate(t) ) def updates_and_snapshot(args: A) -> tuple[A, A]: is_leaf = lambda x: isinstance(x, variablelib.Variable) leaves, treedef = jax.tree.flatten(args, is_leaf=is_leaf) updates_leaves: list[variablelib.Variable | Mask] = [] snapshot_leaves: list[variablelib.Variable | Mask] = [] for leaf in leaves: if isinstance(leaf, variablelib.Variable): updates_leaves.append(leaf) snapshot_leaves.append(leaf.copy()) else: updates_leaves.append(Mask()) snapshot_leaves.append(Mask()) updates = jax.tree.unflatten(treedef, updates_leaves) snapshot = jax.tree.unflatten(treedef, snapshot_leaves) return updates, snapshot def check_no_aliases(fn_name: str, /, **kwargs): Attrs = namedtuple('Attrs', kwargs.keys()) # type: ignore[misc] container = Attrs(**kwargs) is_leaf = lambda x: isinstance(x, variablelib.Variable) seen: dict[int, jax.tree_util.KeyPath] = {} for path, leaf in jax.tree.leaves_with_path( container, is_leaf=is_leaf ): if not isinstance(leaf, variablelib.Variable): continue var_id = id(leaf) if var_id in seen: path_str = jax.tree_util.keystr(path) seen_path_str = jax.tree_util.keystr(seen[var_id]) raise ValueError( f'Duplicate {leaf}\nfound at paths:\n\n' f' - {seen_path_str}\n' f' - {path_str}\n\n' f'nnx.{fn_name} with graph_updates=False does not support ' 'returning input Variables as outputs. ' f'Consider the following options:\n\n' f'1. Remove the duplicate Variables.\n' f'2. Create new Variables via nnx.clone() and use those instead.\n' f'3. Enable graph mode and graph updates by passing graph=True and ' f'graph_updates=True to {fn_name}\n\n' f' nnx.{fn_name}(..., graph=True, graph_updates=True)\n\n' f'4. Use nnx.compat.{fn_name} (sets graph and graph_updates to True ' f'automatically)\n\n' f' nnx.compat.{fn_name}(...)' ) seen[var_id] = path def check_prefix(prefix: tp.Any, prefix_name: str, fn_name: str): def _check(path, leaf): if graphlib.is_graph_node(leaf) or isinstance(leaf, variablelib.Variable): raise ValueError( f'Found graph node or Variable of type {type(leaf).__name__} ' f'at path {jax.tree_util.keystr(path)} in `{prefix_name}` ' f'for nnx.{fn_name}. Graph nodes and Variables are not allowed ' f'as prefixes when graph=True and graph_updates=False' ) jax.tree.map_with_path( _check, prefix, is_leaf=lambda x: isinstance(x, variablelib.Variable) or graphlib.is_graph_node(x), ) def variable_changed(post: variablelib.Variable, pre: variablelib.Variable) -> bool: post_leaves, post_td = jax.tree.flatten(post) pre_leaves, pre_td = jax.tree.flatten(pre) return post_td != pre_td or any( # type: ignore[operator] a is not b for a, b in zip(post_leaves, pre_leaves) ) KeepFn = tp.Callable[ [PathParts, tp.Any, variablelib.Variable, variablelib.Variable], bool ] def mask_variable_updates( current_tree: A, snapshot_tree: A, *, prefix: tp.Any = Missing, keep_fn: KeepFn | None = None, ) -> A: if keep_fn is None: keep_fn = lambda _, _pfx, cur, snap: variable_changed(cur, snap) def _mask_updates(path, prefix_leaf, current, snapshot): if isinstance(current, variablelib.Variable): if current.hijax or current.ref: return Mask() if keep_fn(path, prefix_leaf, current, snapshot): return current return Mask() is_leaf = lambda x: isinstance(x, variablelib.Variable) or x is None if prefix is Missing: return jax.tree.map_with_path( lambda path, cur, snap: _mask_updates(path, None, cur, snap), current_tree, snapshot_tree, is_leaf=is_leaf, ) return broadcast_prefix_map( _mask_updates, prefix, current_tree, snapshot_tree, is_leaf=is_leaf, ) def apply_variable_updates(args_tree: A, updates_tree: A): is_leaf = lambda x: isinstance(x, variablelib.Variable) or isinstance(x, Mask) args_leaves = jax.tree.leaves(args_tree, is_leaf=is_leaf) _, treedef = jax.tree.flatten(args_tree, is_leaf=is_leaf) updates_leaves = treedef.flatten_up_to(updates_tree) for variable, update in zip(args_leaves, updates_leaves, strict=True): if isinstance(update, variablelib.Variable): assert isinstance(variable, variablelib.Variable) variable.update_from_state(update) def treemap_copy_args(f): @functools.wraps(f) def wrapper(*args, **kwargs): args, kwargs = jax.tree.map(lambda x: x, (args, kwargs)) return f(*args, **kwargs) return wrapper def check_same_variables(inputs, outputs, transform_name: str = ''): def _check(in_leaf, out_leaf): if isinstance(in_leaf, variablelib.Variable) and in_leaf is not out_leaf: raise ValueError( f'{transform_name} Variable identity must be preserved ' 'across iterations.' ) is_leaf = lambda x: isinstance(x, (Mask, variablelib.Variable)) jax.tree.map( _check, inputs, outputs, is_leaf=is_leaf, ) def update_carry_variables(init_val, val_out): def _update(in_leaf, out_leaf): if isinstance(in_leaf, variablelib.Variable): in_leaf.update_from_state(out_leaf) return in_leaf return out_leaf return jax.tree.map( _update, init_val, val_out, is_leaf=lambda x: isinstance(x, variablelib.Variable), ) ================================================ FILE: flax/nnx/filterlib.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import builtins import dataclasses from flax.typing import Key, PathParts import typing as tp if tp.TYPE_CHECKING: ellipsis = builtins.ellipsis else: ellipsis = tp.Any Predicate = tp.Callable[[PathParts, tp.Any], bool] FilterLiteral = tp.Union[type, str, Predicate, bool, ellipsis, None] Filter = tp.Union[FilterLiteral, tuple['Filter', ...], list['Filter']] def to_predicate(filter: Filter) -> Predicate: """Converts a Filter to a predicate function. See `Using Filters `__. """ if isinstance(filter, str): return WithTag(filter) elif isinstance(filter, type): return OfType(filter) elif isinstance(filter, bool): if filter: return Everything() else: return Nothing() elif filter is Ellipsis: return Everything() elif filter is None: return Nothing() elif callable(filter): return filter elif isinstance(filter, (list, tuple)): return Any(*filter) else: raise TypeError(f'Invalid collection filter: {filter:!r}. ') def filters_to_predicates( filters: tp.Sequence[Filter], ) -> tuple[Predicate, ...]: for i, filter_ in enumerate(filters): if filter_ in (..., True) and i != len(filters) - 1: remaining_filters = filters[i + 1 :] if not all(f in (..., True) for f in remaining_filters): raise ValueError( '`...` or `True` can only be used as the last filters, ' f'got {filter_} it at index {i}.' ) return tuple(map(to_predicate, filters)) class HasTag(tp.Protocol): tag: str def _has_tag(x: tp.Any) -> tp.TypeGuard[HasTag]: return hasattr(x, 'tag') @dataclasses.dataclass(frozen=True) class WithTag: tag: str def __call__(self, path: PathParts, x: tp.Any): return _has_tag(x) and x.tag == self.tag def __repr__(self): return f'WithTag({self.tag!r})' @dataclasses.dataclass(frozen=True) class PathContains: key: Key | str exact: bool = True def __call__(self, path: PathParts, x: tp.Any): if self.exact: return self.key in path return any(str(self.key) in str(part) for part in path) def __repr__(self): return f'PathContains({self.key!r}, exact={self.exact})' class PathIn: def __init__(self, *paths: PathParts): self.paths = frozenset(paths) def __call__(self, path: PathParts, x: tp.Any): return path in self.paths def __repr__(self): paths_repr = ','.join(map(repr, self.paths)) return f'PathIn({paths_repr})' def __eq__(self, other): return isinstance(other, PathIn) and self.paths == other.paths def __hash__(self): return hash(self.paths) @dataclasses.dataclass(frozen=True) class OfType: type: type def __call__(self, path: PathParts, x: tp.Any): return isinstance(x, self.type) def __repr__(self): return f'OfType({self.type!r})' class Any: def __init__(self, *filters: Filter): self.predicates = tuple( to_predicate(collection_filter) for collection_filter in filters ) def __call__(self, path: PathParts, x: tp.Any): return any(predicate(path, x) for predicate in self.predicates) def __repr__(self): return f'Any({", ".join(map(repr, self.predicates))})' def __eq__(self, other): return isinstance(other, Any) and self.predicates == other.predicates def __hash__(self): return hash(self.predicates) class All: def __init__(self, *filters: Filter): self.predicates = tuple( to_predicate(collection_filter) for collection_filter in filters ) def __call__(self, path: PathParts, x: tp.Any): return all(predicate(path, x) for predicate in self.predicates) def __repr__(self): return f'All({", ".join(map(repr, self.predicates))})' def __eq__(self, other): return isinstance(other, All) and self.predicates == other.predicates def __hash__(self): return hash(self.predicates) class Not: def __init__(self, collection_filter: Filter, /): self.predicate = to_predicate(collection_filter) def __call__(self, path: PathParts, x: tp.Any): return not self.predicate(path, x) def __repr__(self): return f'Not({self.predicate!r})' def __eq__(self, other): return isinstance(other, Not) and self.predicate == other.predicate def __hash__(self): return hash(self.predicate) class Everything: def __call__(self, path: PathParts, x: tp.Any): return True def __repr__(self): return 'Everything()' def __eq__(self, other): return isinstance(other, Everything) def __hash__(self): return hash(Everything) class Nothing: def __call__(self, path: PathParts, x: tp.Any): return False def __repr__(self): return 'Nothing()' def __eq__(self, other): return isinstance(other, Nothing) def __hash__(self): return hash(Nothing) ================================================ FILE: flax/nnx/graph.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Graph module. Re-exports APIs from ``flax.nnx.graphlib``. This module is kept for backward compatibility with code that imports from ``flax.nnx.graph``. """ from flax.nnx.graphlib import * ================================================ FILE: flax/nnx/graphlib.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import contextlib import dataclasses import functools import threading import typing as tp import builtins import jax.core from flax import config from flax.nnx import filterlib, reprlib, traversals, variablelib from flax.nnx import statelib from flax.nnx.proxy_caller import ( ApplyCaller, CallableProxy, DelayedAccessor, ) from flax.nnx.statelib import FlatState, State, map_state from flax.nnx.variablelib import Variable, is_array_ref, V from flax.typing import BaseConfigContext, HashableMapping, Key, PathParts, is_key_like import jax import numpy as np import treescope # type: ignore[import-not-found,import-untyped] import typing_extensions as tpe A = tp.TypeVar('A') B = tp.TypeVar('B') C = tp.TypeVar('C') F = tp.TypeVar('F', bound=tp.Callable) HA = tp.TypeVar('HA', bound=tp.Hashable) HB = tp.TypeVar('HB', bound=tp.Hashable) KeyT = tp.TypeVar('KeyT', bound=Key) Index = int def _tree_mode_suggestion_api(fn_name: str) -> str: return ( f'Consider the following options:\n\n' '1. Remove the duplicates and guarantee a tree structure.\n' f'2. Enable graph mode by passing graph=True to {fn_name} e.g.\n\n' f' nnx.{fn_name}(..., graph=True)\n\n' f'3. Use nnx.compat.{fn_name} instead e.g.\n\n' f' nnx.compat.{fn_name}(...)' ) def _tree_mode_suggestion_transform(fn_name: str) -> str: return ( f'Consider the following options:\n\n' '1. Remove the duplicates.\n' f'2. Enable graph mode and graph updates by passing graph=True and ' f'graph_updates=True to {fn_name} e.g.\n\n' f' nnx.{fn_name}(..., graph=True, graph_updates=True)\n\n' f'3. Use nnx.compat.{fn_name} instead e.g.\n\n' f' nnx.compat.{fn_name}(...)' ) def _check_valid_pytree( node: tp.Any, fn_name: str, path: str = '', ) -> None: from flax.nnx import pytreelib if ( isinstance(node, pytreelib.Pytree) and not node._pytree__is_pytree ): msg = ( f"Cannot use '{fn_name}' with graph=False on a " f"'{type(node).__name__}' instance that has pytree=False. " ) if path: msg += f"Found at path: {path}. " msg += ( f"Pytree subclasses with pytree=False are not registered as " f"JAX pytrees and cannot be used in tree-mode. " + _tree_mode_suggestion_api(fn_name) ) raise ValueError(msg) Names = tp.Sequence[int] Node = tp.TypeVar('Node') Leaf = tp.TypeVar('Leaf') AuxData = tp.TypeVar('AuxData') @jax.tree_util.register_static @dataclasses.dataclass(frozen=True, slots=True) class NoUpdate: ... NO_UPDATE = NoUpdate() @jax.tree_util.register_static @dataclasses.dataclass(frozen=True, slots=True) class Repeated: ... REPEATED = Repeated() @jax.tree_util.register_dataclass @dataclasses.dataclass(frozen=True, slots=True, repr=False) class ArrayRefOutput(reprlib.Representable): value: jax.Array def __nnx_repr__(self): yield reprlib.Object(type=type(self)) yield reprlib.Attr('value', self.value) def __treescope_repr__(self, path, subtree_renderer): return treescope.repr_lib.render_object_constructor( object_type=type(self), attributes={ 'value': self.value, }, path=path, subtree_renderer=subtree_renderer, ) LeafType = tp.Union[ Variable, jax.Array, np.ndarray, variablelib.Ref, ArrayRefOutput, NoUpdate, ] GraphState = State[Key, LeafType] GraphFlatState = FlatState[LeafType] def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[LeafType]: return isinstance(x, LeafType) or variablelib.is_array_ref(x) # type: ignore[misc, arg-type] class IndexMap(dict[Index, tp.Any]): @staticmethod def from_refmap(refmap: RefMap) -> IndexMap: return IndexMap((index, value) for value, index in refmap.items()) if config.flax_use_flaxlib: import flaxlib # type: ignore[import] globals()['IndexMap'] = flaxlib.IndexMap # RefMap = dict class RefMap(tp.MutableMapping[tp.Any, int], reprlib.MappingReprMixin): """A mapping that hashes keys by their identity.""" def __init__( self, mapping: tp.Mapping[tp.Any, int] | tp.Iterable[tuple[tp.Any, int]] | None = None, /, ): self._mapping: dict[int, tuple[tp.Any, int]] = dict() if mapping is not None: self.update(mapping) @staticmethod def from_indexmap(indexmap: IndexMap) -> RefMap: refmap = RefMap() refmap.update((value, index) for index, value in indexmap.items()) return refmap def get(self, key: tp.Any, default: int | None = None) -> int | None: # type: ignore[override] return self._mapping.get(id(key), (None, default))[1] def __getitem__(self, key: tp.Any) -> int: return self._mapping[id(key)][1] def __setitem__(self, key: tp.Any, value: int): self._mapping[id(key)] = (key, value) def __delitem__(self, key: tp.Any): del self._mapping[id(key)] def __len__(self) -> int: return len(self._mapping) def __contains__(self, key: tp.Any) -> bool: return id(key) in self._mapping def __iter__(self) -> tp.Iterator[tp.Any]: for key, _ in self._mapping.values(): yield key def items(self) -> tp.ItemsView[tp.Any, int]: return self._mapping.values() # type: ignore # save python version PythonRefMap = RefMap if config.flax_use_flaxlib: import flaxlib # type: ignore[import] globals()['RefMap'] = flaxlib.RefMap @dataclasses.dataclass(frozen=True, slots=True) class NodeImplBase(tp.Generic[Node, Leaf, AuxData]): type: type[Node] flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]] def node_dict(self, node: Node) -> dict[Key, tp.Any]: nodes, _ = self.flatten(node) return { key: value.value if isinstance(value, DataElem | StaticElem) else value for key, value in nodes } @dataclasses.dataclass(frozen=True, slots=True) class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]): set_key: tp.Callable[[Node, Key, Leaf], None] pop_key: tp.Callable[[Node, Key], Leaf] create_empty: tp.Callable[[AuxData], Node] clear: tp.Callable[[Node], None] init: tp.Callable[[Node, tp.Iterable[tuple[Key, Leaf]]], None] @dataclasses.dataclass(frozen=True, slots=True) class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]): unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node] set_key: tp.Callable[[Node, Key, Leaf], None] | None pop_key: tp.Callable[[Node, Key], Leaf] | None NodeImpl = tp.Union[ GraphNodeImpl[Node, Leaf, AuxData], PytreeNodeImpl[Node, Leaf, AuxData] ] GRAPH_REGISTRY: dict[type, NodeImpl[tp.Any, tp.Any, tp.Any]] = {} PYTREE_REGISTRY: dict[type, PytreeNodeImpl[tp.Any, tp.Any, tp.Any]] = {} def register_graph_node_type( type: type, flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]], set_key: tp.Callable[[Node, Key, Leaf], None], pop_key: tp.Callable[[Node, Key], Leaf], create_empty: tp.Callable[[AuxData], Node], clear: tp.Callable[[Node], None], init: tp.Callable[[Node, tp.Iterable[tuple[Key, Leaf]]], None], ): if type in GRAPH_REGISTRY: raise ValueError(f'Node type {type} is already registered.') GRAPH_REGISTRY[type] = GraphNodeImpl( type=type, flatten=flatten, set_key=set_key, pop_key=pop_key, create_empty=create_empty, clear=clear, init=init, ) def register_pytree_node_type( type: type, flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]], unflatten: tp.Callable[[tp.Sequence[tuple[Key, Leaf]], AuxData], Node], *, set_key: tp.Callable[[Node, Key, Leaf], None] | None = None, pop_key: tp.Callable[[Node, Key], Leaf] | None = None, ): if type in PYTREE_REGISTRY: raise ValueError(f'Node type {type} is already registered.') PYTREE_REGISTRY[type] = PytreeNodeImpl( type=type, flatten=flatten, unflatten=unflatten, set_key=set_key, pop_key=pop_key, ) def is_node(x: tp.Any) -> bool: if isinstance(x, Variable): return False if type(x) in GRAPH_REGISTRY: return True return is_pytree_node(x) def is_graph_node(x: tp.Any) -> bool: return ( type(x) in GRAPH_REGISTRY or variablelib.is_array_ref(x) or isinstance(x, Variable) ) def is_node_type(x: type[tp.Any]) -> bool: return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any] | None: if isinstance(x, Variable): return None node_type = type(x) if node_type in GRAPH_REGISTRY: return GRAPH_REGISTRY[node_type] elif node_type in PYTREE_REGISTRY: return PYTREE_REGISTRY[node_type] elif node_type in JAX_PYTREE_REGISTRY or issubclass(node_type, tuple): return PYTREE_NODE_IMPL # type: ignore else: return None def get_node_impl_for_type( x: type[Node], ) -> NodeImpl[Node, tp.Any, tp.Any] | None: if x is GenericPytree: return PYTREE_NODE_IMPL # type: ignore elif x in PYTREE_REGISTRY: return PYTREE_REGISTRY[x] elif x in GRAPH_REGISTRY: return GRAPH_REGISTRY[x] else: return None # use type-aware sorting to support int keys def _type_aware_sort(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]: key, _ = item if isinstance(key, int): return (0, key) elif isinstance(key, str): return (1, key) else: raise ValueError(f'Unsupported key type: {type(key)!r}') @jax.tree_util.register_static @dataclasses.dataclass(frozen=True, repr=False) class NodeRef(tp.Generic[Node], reprlib.Representable): index: int def __nnx_repr__(self): yield reprlib.Object(type=type(self)) yield reprlib.Attr('index', self.index) def __treescope_repr__(self, path, subtree_renderer): return treescope.repr_lib.render_object_constructor( object_type=type(self), attributes={'index': self.index}, path=path, subtree_renderer=subtree_renderer, ) if config.flax_use_flaxlib: import flaxlib # type: ignore[import] jax.tree_util.register_static(flaxlib.NodeRef) globals()['NodeRef'] = flaxlib.NodeRef @dataclasses.dataclass(frozen=True, repr=False) class VariableDef(reprlib.Representable, tp.Generic[Node]): type: type[Node] index: int outer_index: int | None metadata: HashableMapping[str, tp.Any] array_refdef: ArrayRefDef | NodeRef | None def with_no_outer_index(self) -> VariableDef: return VariableDef( type=self.type, index=self.index, outer_index=None, metadata=self.metadata, array_refdef=self.array_refdef.with_no_outer_index() if isinstance(self.array_refdef, ArrayRefDef) else self.array_refdef, ) def with_same_outer_index(self) -> VariableDef: return VariableDef( type=self.type, index=self.index, outer_index=self.index, metadata=self.metadata, array_refdef=self.array_refdef.with_same_outer_index() if isinstance(self.array_refdef, ArrayRefDef) else self.array_refdef, ) def with_matching_outer_index(self, other) -> VariableDef: return VariableDef( type=self.type, index=self.index, outer_index=other.outer_index, metadata=self.metadata, array_refdef=self.array_refdef.with_matching_outer_index(other.array_refdef) if isinstance(self.array_refdef, ArrayRefDef) else self.array_refdef ) def __nnx_repr__(self): yield reprlib.Object(type=type(self)) yield reprlib.Attr('type', self.type.__name__) yield reprlib.Attr('index', self.index) yield reprlib.Attr('outer_index', self.outer_index) yield reprlib.Attr('metadata', reprlib.PrettyMapping(self.metadata)) def __treescope_repr__(self, path, subtree_renderer): return treescope.repr_lib.render_object_constructor( object_type=type(self), attributes={ 'type': self.type, 'index': self.index, 'outer_index': self.outer_index, 'metadata': self.metadata, }, path=path, subtree_renderer=subtree_renderer, ) if config.flax_use_flaxlib: import flaxlib # type: ignore[import] jax.tree_util.register_static(flaxlib.VariableDef) globals()['VariableDef'] = flaxlib.VariableDef @dataclasses.dataclass(frozen=True, repr=False) class ArrayRefDef(reprlib.Representable): index: int outer_index: int | None def with_no_outer_index(self): return ArrayRefDef( index=self.index, outer_index=None, ) def with_same_outer_index(self): return ArrayRefDef( index=self.index, outer_index=self.index, ) def with_matching_outer_index(self, other): return ArrayRefDef( index=self.index, outer_index=other.outer_index, ) def __nnx_repr__(self): yield reprlib.Object(type=type(self)) yield reprlib.Attr('index', self.index) yield reprlib.Attr('outer_index', self.outer_index) def __treescope_repr__(self, path, subtree_renderer): return treescope.repr_lib.render_object_constructor( object_type=type(self), attributes={ 'index': self.index, 'outer_index': self.outer_index, }, path=path, subtree_renderer=subtree_renderer, ) @jax.tree_util.register_static @dataclasses.dataclass(frozen=True, repr=False, slots=True) class NodeDef(tp.Generic[Node], reprlib.Representable): """A dataclass that denotes the tree structure of a :class:`Module`. A ``GraphDef`` can be generated by either calling :func:`split` or :func:`graphdef` on the :class:`Module`.""" type: tp.Type[Node] index: int | None outer_index: int | None num_attributes: int metadata: tp.Any def with_no_outer_index(self) -> NodeDef[Node]: return NodeDef( type=self.type, index=self.index, outer_index=None, num_attributes=self.num_attributes, metadata=self.metadata, ) def with_same_outer_index(self) -> NodeDef[Node]: return NodeDef( type=self.type, index=self.index, outer_index=self.index, num_attributes=self.num_attributes, metadata=self.metadata, ) def with_matching_outer_index(self, other) -> NodeDef[Node]: return NodeDef( type=self.type, index=self.index, outer_index=other.outer_index, num_attributes=self.num_attributes, metadata=self.metadata, ) def __nnx_repr__(self): yield reprlib.Object(type=type(self)) yield reprlib.Attr('type', self.type.__name__) yield reprlib.Attr('index', self.index) yield reprlib.Attr('outer_index', self.outer_index) yield reprlib.Attr('num_attributes', self.num_attributes) yield reprlib.Attr('metadata', self.metadata) def __treescope_repr__(self, path, subtree_renderer): return treescope.repr_lib.render_object_constructor( object_type=type(self), attributes={ 'type': self.type, 'index': self.index, 'outer_index': self.outer_index, 'num_attributes': self.num_attributes, 'metadata': self.metadata, }, path=path, subtree_renderer=subtree_renderer, ) if config.flax_use_flaxlib: import flaxlib # type: ignore[import] jax.tree_util.register_static(flaxlib.NodeDef) globals()['NodeDef'] = flaxlib.NodeDef @jax.tree_util.register_static @dataclasses.dataclass(frozen=True, slots=True) class TreeNodeDef(tp.Generic[Node]): type: tp.Type[Node] treedef: jax.tree_util.PyTreeDef path_index: tuple[tuple[PathParts, int], ...] def with_no_outer_index(self) -> TreeNodeDef[Node]: return self def with_same_outer_index(self) -> TreeNodeDef[Node]: return self def with_matching_outer_index(self, other) -> TreeNodeDef[Node]: return self NodeDefType = tp.Union[ NodeDef[Node], NodeRef[Node], VariableDef[Node], ArrayRefDef, TreeNodeDef[Node], ] @dataclasses.dataclass(frozen=True, slots=True) class NodeAttr: pass NODE_ATTR = NodeAttr() @dataclasses.dataclass(frozen=True, slots=True) class LeafAttr: pass LEAF_ATTR = LeafAttr() AttrType = tp.Union[ NodeAttr, LeafAttr, 'Static[tp.Any]', ] # GraphDef = tp.Union[NodeDef[Node], NodeRef[Node], VariableDef[Node]] @jax.tree_util.register_static @dataclasses.dataclass(frozen=True, slots=True) class GraphDef(tp.Generic[Node]): nodes: list[NodeDefType[tp.Any]] attributes: list[tuple[Key, AttrType]] num_leaves: int def __hash__(self) -> int: return hash((tuple(self.nodes), tuple(self.attributes))) def with_no_outer_index(self) -> GraphDef[Node]: return GraphDef( nodes=[ node.with_no_outer_index() if not isinstance(node, NodeRef) else node for node in self.nodes ], attributes=self.attributes, num_leaves=self.num_leaves, ) def with_matching_outer_index(self, other) -> GraphDef[Node]: return GraphDef( nodes=[ node.with_matching_outer_index(other_node) if not isinstance(node, NodeRef) else node for node, other_node in zip(self.nodes, other.nodes) ], attributes=self.attributes, num_leaves=self.num_leaves, ) def with_same_outer_index(self) -> GraphDef[Node]: return GraphDef( nodes=[ node.with_same_outer_index() if not isinstance(node, NodeRef) else node for node in self.nodes ], attributes=self.attributes, num_leaves=self.num_leaves, ) # TODO(cgarciae): remove this method def apply( self, state: GraphState, *states: GraphState, graph: bool | None = None, ) -> ApplyCaller[tuple[GraphDef[Node], GraphState]]: accessor = DelayedAccessor() def _apply( accessor: DelayedAccessor, *args, **kwargs ) -> tuple[tp.Any, tuple[GraphDef[Node], GraphState]]: module = merge(self, state, *states) fn = accessor(module) out = fn(*args, **kwargs) if graph is None: _graph = config.nnx_graph_mode else: _graph = graph graphdef, flat_state = flatten(module, graph=_graph) state_ = statelib.from_flat_state(flat_state) return out, (graphdef, state_) return CallableProxy(_apply, accessor) # type: ignore PureState = tuple[GraphDef[Node], GraphState] def _tree_flatten( node: tp.Any, nodes: list[NodeDefType[tp.Any]], leaves: list[tp.Any], paths: list[PathParts] | None, ) -> None: seen_variables: dict[int, str] = {} seen_refs: dict[int, str] = {} def _is_leaf(path, x): if isinstance(x, Variable): var_id = id(x) str_path = jax.tree_util.keystr(path) if var_id in seen_variables: raise ValueError( f'Duplicate {x}\nfound at paths:\n\n' f' - {seen_variables[var_id]}\n' f' - {str_path}\n\n' 'Tree mode (graph=False) does not support shared references. ' + _tree_mode_suggestion_api('split') ) seen_variables[var_id] = str_path return True if variablelib.is_array_ref(x): ref_id = id(x) str_path = jax.tree_util.keystr(path) if ref_id in seen_refs: raise ValueError( f'Duplicate {x}\nfound at paths:\n\n' f' - {seen_refs[ref_id]}\n' f' - {str_path}\n\n' 'Tree mode (graph=False) does not support shared references. ' + _tree_mode_suggestion_api('split') ) seen_refs[ref_id] = str_path _check_valid_pytree(x, 'flatten', jax.tree_util.keystr(path)) return False jax_leaves, treedef = jax.tree_util.tree_flatten_with_path( node, is_leaf=_is_leaf, is_leaf_takes_path=True ) nnx_paths_and_leaves: list[tuple[PathParts, tp.Any]] = [ (jax_to_nnx_path(jax_path), value) for jax_path, value in jax_leaves ] original_indices = {p: i for i, (p, _) in enumerate(nnx_paths_and_leaves)} nnx_paths_and_leaves.sort() path_index = tuple( (p, original_indices[p]) for p, _ in nnx_paths_and_leaves ) tree_nodedef: TreeNodeDef[tp.Any] = TreeNodeDef( type=type(node), treedef=treedef, path_index=path_index, ) nodes.append(tree_nodedef) sorted_leaf_index = 0 for nnx_path, value in nnx_paths_and_leaves: if isinstance(value, Variable): nodes.append(VariableDef( type=value.var_type, index=sorted_leaf_index, outer_index=None, metadata=HashableMapping(value.get_metadata()), array_refdef=None, )) leaves.append(value) if paths is not None: paths.append(nnx_path) sorted_leaf_index += 1 @tp.overload def flatten( # type: ignore[invalid-annotation] node: Node, /, *, ref_index: RefMap | None = ..., ref_outer_index: RefMap | None = ..., graph: bool = ..., ) -> tuple[GraphDef[Node], FlatState[tp.Any]]: ... @tp.overload def flatten( # type: ignore[invalid-annotation] node: Node, /, *, with_paths: tp.Literal[True], ref_index: RefMap | None = ..., ref_outer_index: RefMap | None = ..., graph: bool = ..., ) -> tuple[ GraphDef[Node], FlatState[tp.Any], ]: ... @tp.overload def flatten( # type: ignore[invalid-annotation] node: Node, /, *, with_paths: tp.Literal[False], ref_index: RefMap | None = ..., ref_outer_index: RefMap | None = ..., graph: bool = ..., ) -> tuple[ GraphDef[Node], list[tp.Any], ]: ... @tp.overload def flatten( # type: ignore[invalid-annotation] node: Node, /, *, with_paths: bool, ref_index: RefMap | None = ..., ref_outer_index: RefMap | None = ..., graph: bool = ..., ) -> tuple[ GraphDef[Node], FlatState[tp.Any] | list[tp.Any], ]: ... def flatten( # type: ignore[invalid-annotation] node: Node, /, *, with_paths: bool = True, ref_index: RefMap | None = None, ref_outer_index: RefMap | None = None, graph: bool | None = None, ) -> tuple[ GraphDef[Node], FlatState[tp.Any] | list[tp.Any], ]: """Flattens a graph node into a (graphdef, state) pair. Args: x: A graph node. ref_index: A mapping from nodes to indexes, defaults to None. If not provided, a new empty dictionary is created. This argument can be used to flatten a sequence of graph nodes that share references. with_paths: A boolean that indicates whether to return a FlatState object that includes the paths, or just a list of the Variable's inner values. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. """ if graph is None: graph = set_graph_mode.current_value() if ref_index is None: ref_index = RefMap() leaves: list[tp.Any] = [] path: list[Key] | None = [] if with_paths else None paths: list[PathParts] | None = [] if with_paths else None nodes: list[NodeDefType[tp.Any]] = [] attributes: list[tuple[Key, AttrType]] = [] if graph: node_impl = get_node_impl(node) _graph_flatten( node, node_impl, path, ref_index, ref_outer_index, nodes, attributes, leaves, paths, ) else: _tree_flatten( node, nodes, leaves, paths, ) graphdef: GraphDef = GraphDef( nodes=nodes, attributes=attributes, num_leaves=len(leaves) ) if paths is not None: return graphdef, FlatState.from_sorted_keys_values(tuple(paths), leaves) # type: ignore[return-value] else: return graphdef, leaves @dataclasses.dataclass(frozen=True, slots=True) class DataElem: value: tp.Any @dataclasses.dataclass(frozen=True, slots=True) class StaticElem: value: tp.Any def _graph_flatten( node: Node, node_impl: NodeImpl[Node, Leaf, AuxData] | None, path: list[Key] | None, ref_index: RefMap, ref_outer_index: RefMap | None, nodes: list[NodeDefType[tp.Any]], attributes: list[tuple[Key, AttrType]], leaves: list[tp.Any], paths: list[PathParts] | None, ) -> None: is_pytree_node_ = type(node_impl) is PytreeNodeImpl index: int | None if not is_pytree_node_ and node in ref_index: nodes.append(NodeRef(index := ref_index[node])) return is_graph_node_ = type(node_impl) is GraphNodeImpl is_variable = isinstance(node, Variable) is_array_ref = variablelib.is_array_ref(node) # only cache graph nodes, we don't add array refs here # as they are added in the make_mutable_arraydef function if is_graph_node_ or is_variable: index = len(ref_index) ref_index[node] = index else: index = None def make_mutable_arraydef(value: variablelib.Ref): if value in ref_index: index = ref_index[value] return NodeRef(index), REPEATED else: index = len(ref_index) ref_index[value] = index output_value: NoUpdate | ArrayRefOutput | variablelib.Ref if ref_outer_index is not None: if value in ref_outer_index: outer_index = ref_outer_index[value] output_value = NO_UPDATE array_refdef = ArrayRefDef(index=index, outer_index=outer_index) else: output_value = ArrayRefOutput(value[...]) array_refdef = ArrayRefDef(index=index, outer_index=None) else: output_value = value array_refdef = ArrayRefDef(index=index, outer_index=None) return array_refdef, output_value if is_variable: assert isinstance(node, Variable) assert index is not None prev_inner_value = node.get_raw_value() if variablelib.is_array_ref(prev_inner_value): array_refdef, inner_value = make_mutable_arraydef(prev_inner_value) else: array_refdef = None inner_value = prev_inner_value if path is None: leaf = inner_value else: leaf = node # type: ignore[assignment] if inner_value is not prev_inner_value: leaf.set_raw_value(inner_value) variabledef = VariableDef( type=node.var_type, # type: ignore index=index, outer_index=ref_outer_index.get(node, None) if ref_outer_index else None, metadata=HashableMapping(node.get_metadata()), array_refdef=array_refdef, ) if type(inner_value) is not Repeated: assert not isinstance(leaf, Repeated) leaves.append(leaf) if path is not None: assert paths is not None paths.append(tuple(path)) nodes.append(variabledef) return elif is_array_ref: array_refdef, leaf = make_mutable_arraydef(node) # type: ignore[arg-type] if not isinstance(leaf, Repeated): leaves.append(leaf) if path is not None: assert paths is not None paths.append(tuple(path)) nodes.append(array_refdef) return elif not is_pytree_node_ and not is_graph_node_: # unkown leaf leaves.append(node) if path is not None: assert paths is not None paths.append(tuple(path)) return if node_impl is None: raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') values, metadata = node_impl.flatten(node) num_attributes = len(values) nodedef = NodeDef( node_impl.type, index, ref_outer_index[node] if is_graph_node_ and ref_outer_index and node in ref_outer_index else None, num_attributes, metadata, ) nodes.append(nodedef) for key, value in values: is_data = None if isinstance(value, DataElem): value = value.value is_data = True elif isinstance(value, StaticElem): value = value.value is_data = False if is_data is False: attributes.append((key, Static(value))) continue value_node_impl = get_node_impl(value) if path is not None: path.append(key) if value_node_impl is not None or isinstance(value, Variable): attributes.append((key, NODE_ATTR)) _graph_flatten( value, value_node_impl, path, ref_index, ref_outer_index, nodes, attributes, leaves, paths, ) elif variablelib.is_array_ref(value): attributes.append((key, NODE_ATTR)) array_refdef, leaf = make_mutable_arraydef(value) if not isinstance(leaf, Repeated): leaves.append(leaf) if paths is not None: paths.append(tuple(path)) # type: ignore nodes.append(array_refdef) elif isinstance(value, (jax.Array, np.ndarray)) or is_data: attributes.append((key, LEAF_ATTR)) if paths is not None: paths.append(tuple(path)) # type: ignore leaves.append(value) else: attributes.append((key, Static(value))) if path is not None: path.pop() return def _get_sorted_leaves( xs: tp.Mapping[tp.Any, tp.Any], ) -> list[tp.Any]: if not isinstance(xs, tp.Mapping): # type: ignore raise TypeError(f'expected Mapping; got {type(xs).__qualname__}') leaves: list[tp.Any] = [] def _flatten(xs): if not isinstance(xs, tp.Mapping): leaves.append(xs) else: for _, value in sorted(xs.items()): _flatten(value) _flatten(xs) return leaves def _tree_unflatten( graphdef: GraphDef[tp.Any], leaves: list[tp.Any], copy_variables: bool, ) -> tp.Any: tree_nodedef = graphdef.nodes[0] assert isinstance(tree_nodedef, TreeNodeDef) variable_defs_iter = iter( node for node in graphdef.nodes[1:] if isinstance(node, VariableDef) ) variabledef = next(variable_defs_iter, None) original_leaves: list[tp.Any] = [None] * len(leaves) for i, (path, original_index) in enumerate(tree_nodedef.path_index): leaf = leaves[i] if variabledef is not None and variabledef.index == i: if isinstance(leaf, Variable): if copy_variables: leaf = leaf.copy() else: leaf = variabledef.type.from_metadata( leaf, dict(variabledef.metadata) ) variabledef = next(variable_defs_iter, None) original_leaves[original_index] = leaf return tree_nodedef.treedef.unflatten(original_leaves) def unflatten( # type: ignore[invalid-annotation] graphdef: GraphDef[Node], state: State[Key, tp.Any] | FlatState[tp.Any] | list[tp.Any], /, *, index_ref: IndexMap | None = None, outer_index_outer_ref: IndexMap | None = None, copy_variables: bool = False, ) -> Node: """Unflattens a graphdef into a node with the given state. Args: graphdef: A GraphDef instance. state: A State instance. index_ref: A mapping from indexes to nodes references found during the graph traversal, defaults to None. If not provided, a new empty dictionary is created. This argument can be used to unflatten a sequence of (graphdef, state) pairs that share the same index space. index_ref_cache: A mapping from indexes to existing nodes that can be reused. When an reference is reused, ``GraphNodeImpl.clear`` is called to leave the object in an empty state and then filled by the unflatten process, as a result existing graph nodes are mutated to have the new content/topology specified by the graphdef. copy_variables: If True variables in the state will be copied onto the new new structure, else variables will be shared. Default is False. """ if isinstance(state, (State, dict)): leaves = _get_sorted_leaves(state) elif isinstance(state, FlatState): leaves = state.leaves elif isinstance(state, list): # type: ignore leaves = state else: raise ValueError(f'Unsupported state type: {type(state)}') if len(leaves) != graphdef.num_leaves: raise ValueError( f'Incorrect number of leaves, expected {graphdef.num_leaves} leaves, but got {len(leaves)}.' ) if graphdef.nodes and isinstance(graphdef.nodes[0], TreeNodeDef): return _tree_unflatten(graphdef, leaves, copy_variables) if index_ref is None: index_ref = IndexMap() if len(graphdef.nodes) == 0: return leaves[0] elif isinstance(nodedef := graphdef.nodes[0], NodeRef): node = index_ref[nodedef.index] else: node_iter = iter(graphdef.nodes) attribute_iter = iter(graphdef.attributes) leaves_iter = iter(leaves) nodedef = next(node_iter) assert not isinstance(nodedef, NodeRef) if isinstance(nodedef, ArrayRefDef): node_impl = None else: node_impl = get_node_impl_for_type(nodedef.type) node = _graph_unflatten( nodedef, node_impl, node_iter, attribute_iter, leaves_iter, index_ref, outer_index_outer_ref, copy_variables, ) try: next(leaves_iter) except StopIteration: pass else: raise ValueError('Incorrect number of leaves in state.') return node def _graph_unflatten( nodedef: NodeDefType[Node], node_impl: NodeImpl[Node, Leaf, AuxData] | None, node_iter: tp.Iterator[NodeDefType[Node]], attribute_iter: tp.Iterator[tuple[Key, AttrType]], leaves_iter: tp.Iterator[tp.Any], index_ref: IndexMap, outer_index_outer_ref: IndexMap | None, copy_variables: bool, ) -> Node: """Recursive helper for graph_unflatten. Args: nodedef: A GraphDef instance or an index to a node in the cache. state: A mapping from attribute names to variables or subgraphs. index_ref: A mapping from indexes to nodes that have been traversed. If a node is already in the cache, it won't be traversed again. outer_index_outer_ref: A mapping from indexes to existing nodes that can be reused. When an reference is reused, ``GraphNodeImpl.clear`` is called to leave the object in an empty state and then filled by the unflatten process, as a result existing graph nodes are mutated to have the new content/topology specified by the nodedef. """ def get_mutable_array(array_refdef: ArrayRefDef, leaf): assert type(array_refdef) is ArrayRefDef if ( outer_index_outer_ref is not None and array_refdef.outer_index is not None and array_refdef.outer_index in outer_index_outer_ref ): # if array ref exists, update it array_ref = outer_index_outer_ref[array_refdef.outer_index] if not variablelib.is_array_ref(array_ref): raise RuntimeError(f'Expected a ArrayRef type but got {array_ref}.') if type(leaf) is not NoUpdate: raise RuntimeError(f'Expected a no update for ArrayRef but got {leaf}.') elif type(leaf) in (NoUpdate, Repeated): raise ValueError( f"Expected a ArrayRefOutput type but got '{leaf}.'" ) elif type(leaf) is ArrayRefOutput: array_ref = jax.new_ref(leaf.value) elif variablelib.is_array_ref(leaf): array_ref = leaf else: # here we allow merging frozen arrays and will not create a new array ref array_ref = leaf index_ref[array_refdef.index] = array_ref return array_ref if type(nodedef) is NodeRef: return index_ref[nodedef.index] if type(nodedef) is VariableDef: variabledef = tp.cast(VariableDef[Variable], nodedef) # its a unseen variable, create a new one if variabledef.array_refdef is not None: if type(variabledef.array_refdef) is NodeRef: value = index_ref[variabledef.array_refdef.index] else: value = next(leaves_iter) assert type(variabledef.array_refdef) is ArrayRefDef if isinstance(value, Variable): copy_ref = not isinstance( value.get_raw_value(), (NoUpdate, Repeated, ArrayRefOutput) ) value = value.copy(_copy_ref=copy_ref) if copy_variables else value inner_value = value.get_raw_value() array_ref = get_mutable_array(variabledef.array_refdef, inner_value) if array_ref is not inner_value: value.set_raw_value(array_ref) else: # if value is an array or array ref, we need call get_mutable_array # to register it in the index_ref value = get_mutable_array(variabledef.array_refdef, value) else: value = next(leaves_iter) if isinstance(value, Variable) and copy_variables: copy_ref = not isinstance( value.get_raw_value(), (NoUpdate, Repeated, ArrayRefOutput) ) value = value.copy(_copy_ref=copy_ref) # when idxmap is present, check if the Varable exists there # and update existing variables if it does if ( outer_index_outer_ref is not None and variabledef.outer_index is not None and variabledef.outer_index in outer_index_outer_ref ): # if variable exists, update it variable = outer_index_outer_ref[variabledef.outer_index] if not isinstance(variable, Variable): raise ValueError(f'Expected a Variable type but got {type(variable)}.') elif isinstance(value, Variable): variable.update_from_state(value) else: variable.set_raw_value(value) else: # variabledef.index not in index_ref_cache # variable reference does not exist outside, create a new one if isinstance(value, Variable): variable = value else: variable = variabledef.type.from_metadata( value, dict(variabledef.metadata) ) index_ref[variabledef.index] = variable return variable # type: ignore[return-value] if type(nodedef) is ArrayRefDef: leaf = next(leaves_iter) array_ref = get_mutable_array(nodedef, leaf) return array_ref # type: ignore[return-value] assert type(nodedef) is NodeDef if node_impl is None: raise RuntimeError(f'Unsupported type: {nodedef.type}, this is a bug.') if nodedef.index is not None and nodedef.index in index_ref: raise RuntimeError(f'GraphDef index {nodedef.index} already used.') def _get_children() -> list[tuple[Key, tp.Any]]: children: list[tuple[Key, LeafType | Node]] = [] # type: ignore[invalid-annotation] assert type(nodedef) is NodeDef for _ in range(nodedef.num_attributes): key, value = next(attribute_iter) if type(value) is Static: children.append((key, value.value)) # type: ignore[attribute-error] elif type(value) is LeafAttr: leaf = next(leaves_iter) children.append((key, leaf)) elif type(value) is NodeAttr: node_def = next(node_iter) if isinstance(node_def, NodeRef): node = index_ref[node_def.index] elif isinstance(node_def, ArrayRefDef): leaf = next(leaves_iter) node = get_mutable_array(node_def, leaf) elif isinstance(node_def, NodeDef | VariableDef): value_node_impl = get_node_impl_for_type(node_def.type) node = _graph_unflatten( node_def, value_node_impl, node_iter, attribute_iter, leaves_iter, index_ref, outer_index_outer_ref, copy_variables, ) else: raise RuntimeError(f'Unknown node definition: {node_def!r}') children.append((key, node)) elif type(value) is NodeRef: children.append((key, index_ref[value.index])) # type: ignore[attribute-error] else: raise RuntimeError(f'Unknown static field: {key!r}') return children if isinstance(node_impl, GraphNodeImpl): # we create an empty node first and add it to the index # this avoids infinite recursion when there is a reference cycle assert type(nodedef) is NodeDef if ( outer_index_outer_ref is not None and nodedef.outer_index is not None and nodedef.outer_index in outer_index_outer_ref ): node = outer_index_outer_ref[nodedef.outer_index] if type(node) != nodedef.type: raise ValueError( f'Expected a node of type {nodedef.type} for index ' f'{nodedef.index}, but got a node of type {type(node)}.' ) node_impl.clear(node) else: node = node_impl.create_empty(nodedef.metadata) assert nodedef.index is not None index_ref[nodedef.index] = node node_impl.init(node, _get_children()) else: # if the node type does not support the creation of an empty object it means # that it cannot reference itself, so we can create its children first node = node_impl.unflatten(_get_children(), nodedef.metadata) return node def graph_pop( node: tp.Any, filters: tuple[filterlib.Filter, ...], ) -> tuple[GraphState, ...]: id_to_index: dict[int, Index] = {} path_parts: PathParts = () predicates = tuple(filterlib.to_predicate(filter) for filter in filters) flat_states: tuple[dict[PathParts, LeafType], ...] = tuple( {} for _ in predicates ) _graph_pop(node, id_to_index, path_parts, flat_states, predicates) return tuple( statelib.from_flat_state(flat_state) for flat_state in flat_states ) def _graph_pop( node: tp.Any, id_to_index: dict[int, Index], path_parts: PathParts, flat_states: tuple[dict[PathParts, LeafType], ...], predicates: tuple[filterlib.Predicate, ...], ) -> None: if not is_node(node): raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') if id(node) in id_to_index: return id_to_index[id(node)] = len(id_to_index) node_impl = get_node_impl(node) if node_impl is None: raise TypeError(f'Unknown node type: {type(node)}') node_dict = node_impl.node_dict(node) for name, value in node_dict.items(): if is_node(value): _graph_pop( node=value, id_to_index=id_to_index, path_parts=(*path_parts, name), flat_states=flat_states, predicates=predicates, ) continue elif not is_node_leaf(value): continue elif id(value) in id_to_index: continue node_path = (*path_parts, name) node_impl = get_node_impl(node) if node_impl is None: raise TypeError(f'Unknown node type: {type(node)}') for state, predicate in zip(flat_states, predicates): if predicate(node_path, value): if node_impl.pop_key is None: raise ValueError( f'Cannot pop key {name!r} from node of type {type(node).__name__}' ) id_to_index[id(value)] = len(id_to_index) node_impl.pop_key(node, name) if isinstance(value, Variable): value = value state[node_path] = value # type: ignore[index] # mypy is wrong here? break else: # NOTE: should we raise an error here? pass def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[KeyT, tp.Any]): def _update_variable(node: Variable, value): if isinstance(value, Variable): # updated from Variable node.update_from_state(value) else: # updated from raw value if isinstance(value, State) and not value: # NOTE: this is a special case when trying to update a Variable from state # created when flattening into a NodeRef, which creates an empty State. This # can happen when using standalone Variables with `grad` pass else: if is_array_ref(node.get_raw_value()) and ( isinstance(value, jax.Array) or is_array_ref(value) ): node[...] = value[...] else: node.set_raw_value(value, _unsafe_bypass_check=True) if isinstance(node, Variable): _update_variable(node, state) return if not is_node(node): raise RuntimeError(f'Unsupported type: {type(node)}') node_impl = get_node_impl(node) if node_impl is None: raise TypeError(f'Unknown node type: {type(node)}') node_dict = node_impl.node_dict(node) for key, value in state.items(): # case 1: new state is being added if key not in node_dict: if node_impl.set_key is None: raise ValueError( f'Cannot set key {key!r} on immutable node of ' f'type {type(node).__name__}' ) if isinstance(value, Variable): copy_ref = not isinstance( value.get_raw_value(), (NoUpdate, Repeated, ArrayRefOutput) ) value = value.copy(_copy_ref=copy_ref) node_impl.set_key(node, key, value) continue current_value = node_dict[key] # case 2: subgraph is being updated if is_array_ref(current_value): current_value[...] = value elif is_node(current_value): if is_node_leaf(value): raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}') _graph_update_dynamic(current_value, value) elif isinstance(current_value, Variable): _update_variable(current_value, value) elif node_impl.set_key is not None: node_impl.set_key(node, key, value) else: raise ValueError( f'Cannot set key {key!r} on immutable node of ' f'type {type(node).__name__}' ) # -------------------------------------------------------- # UpdateContext # -------------------------------------------------------- class StaticCache(tp.NamedTuple): graphdef: GraphDef[tp.Any] final_graphdef: GraphDef[tp.Any] paths: tuple[PathParts, ...] variables: list[Variable[tp.Any]] new_ref_index: RefMap new_index_ref: IndexMap @staticmethod def create( graphdef: GraphDef[tp.Any], paths: tuple[PathParts, ...], variables: list[Variable[tp.Any]], new_ref_index: RefMap, ): new_index_ref = IndexMap.from_refmap(new_ref_index) final_graphdef: GraphDef[tp.Any] final_graphdef = graphdef.with_same_outer_index() return StaticCache( graphdef=graphdef, final_graphdef=final_graphdef, paths=paths, variables=variables, new_ref_index=new_ref_index, new_index_ref=new_index_ref, ) @dataclasses.dataclass class GraphContext(threading.local): update_context_stacks: dict[tp.Hashable, list[UpdateContext]] = ( dataclasses.field(default_factory=dict) ) ref_index_stack: list[SplitContext] = dataclasses.field(default_factory=list) index_ref_stack: list[MergeContext] = dataclasses.field(default_factory=list) tmp_static_cache: tp.MutableMapping[tp.Any, StaticCache] | None = None caching: bool = False graph_mode_stack: list[bool] = dataclasses.field(default_factory=list) graph_updates_stack: list[bool] = dataclasses.field(default_factory=list) GRAPH_CONTEXT = GraphContext() class set_graph_mode(BaseConfigContext): get_default = classmethod(lambda cls: config.nnx_graph_mode) get_stack = classmethod(lambda cls: GRAPH_CONTEXT.graph_mode_stack) class set_graph_updates(BaseConfigContext): get_default = classmethod(lambda cls: config.nnx_graph_updates) get_stack = classmethod(lambda cls: GRAPH_CONTEXT.graph_updates_stack) @contextlib.contextmanager def static_cache(static_cache: tp.MutableMapping[tp.Any, StaticCache]): if GRAPH_CONTEXT.caching: yield return GRAPH_CONTEXT.tmp_static_cache = static_cache try: yield finally: if GRAPH_CONTEXT.tmp_static_cache is not None: raise ValueError( 'GRAPH_CONTEXT.tmp_static_cache should be None, no context consumed it.' ) def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args, graph: bool | None = None): """Create a partial from a NNX transformed function alog with some cached input arguments and reduces the python overhead by caching the traversal of NNX graph nodes. This is useful for speed up function that are called repeatedly with the same subset of inputs e.g. a ``train_step`` with a ``model`` and ``optimizer``:: >>> from flax import nnx >>> import jax.numpy as jnp >>> import optax ... >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=nnx.Param) ... >>> @nnx.jit ... def train_step(model, optimizer, x, y): ... def loss_fn(model): ... return jnp.mean((model(x) - y) ** 2) ... ... loss, grads = nnx.value_and_grad(loss_fn)(model) ... optimizer.update(model, grads) ... return loss ... >>> cached_train_step = nnx.cached_partial(train_step, model, optimizer) ... >>> for step in range(total_steps:=2): ... x, y = jnp.ones((10, 2)), jnp.ones((10, 3)) ... # loss = train_step(model, optimizer, x, y) ... loss = cached_train_step(x, y) ... print(f'Step {step}: loss={loss:.3f}') Step 0: loss=2.669 Step 1: loss=2.660 Note that ``cached_partial`` will clone all cached graph nodes to gurantee the validity of the cache, and these clones will contain references to the same Variable objects which guarantees that state is propagated correctly back to the original graph nodes. Because of the previous, the final structure of all graph nodes must be the same after each call to the cached function, otherwise an error will be raised. Temporary mutations are allowed (e.g. the use of ``Module.sow``) as long as they are cleaned up before the function returns (e.g. via ``nnx.pop``). Args: f: A function to cache. *cached_args: A subset of the input arguments containing the graph nodes to cache. Returns: A partial function expecting the remaining arguments to the original function. """ if graph is None: graph = set_graph_mode.current_value() if not graph: raise ValueError( 'cached_partial is a graph-mode-only API and does not support ' 'tree-mode (graph=False).' ) cache: tp.MutableMapping[tp.Any, StaticCache] = PythonRefMap() # type: ignore original_ref_index: RefMap = RefMap() index_ref: IndexMap = IndexMap() cached_ref_index: RefMap = RefMap() def create_static_cache(x): # TODO(cgarciae): support Array attribute updates for graph nodes if is_graph_node(x) or isinstance(x, Variable): graphdef, flat_state = flatten( x, with_paths=True, ref_index=original_ref_index, graph=True ) paths = flat_state.paths variables = flat_state.leaves # clone but keep the same variable references node_cache = unflatten( graphdef, flat_state, index_ref=index_ref, copy_variables=False, ) start_index = len(cached_ref_index) flatten( node_cache, ref_index=cached_ref_index, with_paths=False, graph=True, ) cached_new_ref_index = RefMap( (key, value) for key, value in cached_ref_index.items() if value >= start_index ) cache[node_cache] = StaticCache.create( graphdef, paths, variables, cached_new_ref_index ) return node_cache return x cached_args = jax.tree.map( create_static_cache, cached_args, is_leaf=lambda x: is_graph_node(x) or isinstance(x, Variable), ) @functools.wraps(f) def cache_args_wrapper(*args, **kwargs): with static_cache(cache): return f(*cached_args, *args, **kwargs) return cache_args_wrapper if tp.TYPE_CHECKING: cached_partial = functools.partial else: cached_partial = _cached_partial @dataclasses.dataclass class SplitContext: ctxtag: tp.Hashable | None ref_index: RefMap is_inner: bool | None @tp.overload def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ... # type: ignore[invalid-annotation] @tp.overload def split( # type: ignore[invalid-annotation] self, graph_node: A, first: filterlib.Filter, / ) -> tuple[GraphDef[A], GraphState]: ... @tp.overload def split( self, graph_node: A, first: filterlib.Filter, second: filterlib.Filter, /, *filters: filterlib.Filter, ) -> tuple[GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]]]: ... # type: ignore[not-supported-yet] def split( self, node: A, *filters: filterlib.Filter ) -> tuple[GraphDef[A], tpe.Unpack[tuple[GraphState, ...]]]: # type: ignore[not-supported-yet] ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None ) inner_ref_outer_index = ( ctx.inner_ref_outer_index if ctx and ctx.inner_ref_outer_index else None ) graphdef, flat_state = flatten( node, ref_index=self.ref_index, ref_outer_index=inner_ref_outer_index, graph=True ) flat_states = _split_state(flat_state, filters) states = _to_nested_state(graphdef, flat_states) return graphdef, *states @tp.overload def flatten( # type: ignore[invalid-annotation] self, graph_node: A, /, *, with_paths: tp.Literal[False], ) -> tuple[GraphDef[A], list[tp.Any]]: ... @tp.overload def flatten( # type: ignore[invalid-annotation] self, graph_node: A, /, ) -> tuple[GraphDef[A], FlatState[tp.Any]]: ... @tp.overload def flatten( # type: ignore[invalid-annotation] self, graph_node: A, first: filterlib.Filter, /, ) -> tuple[GraphDef[A], FlatState[tp.Any]]: ... @tp.overload def flatten( # type: ignore[invalid-annotation] self, graph_node: A, first: filterlib.Filter, second: filterlib.Filter, /, *filters: filterlib.Filter, ) -> tuple[ GraphDef[A], FlatState[tp.Any], tpe.Unpack[tuple[FlatState[tp.Any], ...]], ]: ... def flatten( # type: ignore[invalid-annotation] self, node: A, *filters: filterlib.Filter, with_paths: bool = True, ) -> tuple[ GraphDef[A], FlatState[tp.Any] | list[tp.Any], tpe.Unpack[tuple[FlatState[tp.Any], ...]], ]: if not with_paths and filters: raise ValueError('Cannot use filters with with_paths=False') ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None ) static_cache = ( ctx.static_cache if ctx is not None and self.is_inner is False else None ) ref_outer_index = ( ctx.inner_ref_outer_index if ctx and ctx.inner_ref_outer_index else None ) flat_state: FlatState[tp.Any] | list[tp.Any] leaves: list[tp.Any] if node in self.ref_index: # node is already in the ref_index, call flatten which will return a NodeRef graphdef, flat_state = flatten( node, ref_index=self.ref_index, ref_outer_index=ref_outer_index, with_paths=with_paths, graph=True, ) if with_paths: assert isinstance(flat_state, FlatState) paths = flat_state.paths leaves = flat_state.leaves else: assert isinstance(flat_state, list) paths = None leaves = flat_state elif static_cache is not None and node in static_cache: node_static_cache = static_cache[node] graphdef = node_static_cache.graphdef # add the new references to the ref_index self.ref_index.update(node_static_cache.new_ref_index) if with_paths: paths = node_static_cache.paths leaves = node_static_cache.variables else: paths = None leaves = [ variable.get_raw_value() for variable in node_static_cache.variables ] else: graphdef, flat_state = flatten( node, ref_index=self.ref_index, ref_outer_index=ref_outer_index, with_paths=with_paths, graph=True, ) if with_paths: assert isinstance(flat_state, FlatState) paths = flat_state.paths leaves = flat_state.leaves else: assert isinstance(flat_state, list) paths = None leaves = flat_state if with_paths: assert paths is not None flat_state = FlatState.from_sorted_keys_values(paths, leaves) flat_states = _split_state(flat_state, filters) return graphdef, *flat_states # type: ignore[bad-return-type] else: return graphdef, leaves @contextlib.contextmanager def split_context(ctxtag: tp.Hashable | None = None): ctx = current_update_context(ctxtag) if ctxtag is not None else None is_inner = ctx.outer_ref_outer_index is not None if ctx is not None else None GRAPH_CONTEXT.ref_index_stack.append(SplitContext(ctxtag, RefMap(), is_inner)) try: yield GRAPH_CONTEXT.ref_index_stack[-1] finally: flatten_ctx = GRAPH_CONTEXT.ref_index_stack.pop() if ctxtag is not None: ctx = current_update_context(ctxtag) ctx.flatten_end(flatten_ctx.ref_index) del flatten_ctx.ref_index del flatten_ctx.ctxtag @dataclasses.dataclass class MergeContext: ctxtag: tp.Hashable | None index_ref: IndexMap is_inner: bool | None def merge( # type: ignore[invalid-annotation] self, graphdef: GraphDef[A], state: GraphState, /, *states: GraphState, ) -> A: ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None ) outer_index_outer_ref = ( ctx.outer_index_outer_ref if ctx and ctx.outer_index_outer_ref else None ) _state = _merge_to_flat_state((state, *states)) node = unflatten( graphdef, _state, index_ref=self.index_ref, outer_index_outer_ref=outer_index_outer_ref, copy_variables=True, ) return node def unflatten( # type: ignore[invalid-annotation] self, graphdef: GraphDef[A], flat_state: GraphFlatState | list[tp.Any], /, *flat_states: GraphFlatState, ) -> A: ctx = ( current_update_context(self.ctxtag) if self.ctxtag is not None else None ) static_cache = ( ctx.static_cache if ctx is not None and self.is_inner is False else None ) state: FlatState[tp.Any] | list[tp.Any] if type(flat_state) is list: if flat_states: raise ValueError( 'Cannot use multiple flat_states when flat_state is a list, ' f'got flat_state: {flat_state!r}, flat_states: {flat_states!r}' ) state = flat_state else: state = FlatState.merge(flat_state, *flat_states) if type(graphdef.nodes[0]) is NodeRef: node = unflatten( graphdef, state, index_ref=self.index_ref, ) elif static_cache is not None: assert isinstance(graphdef.nodes[0], NodeDef) or isinstance(graphdef.nodes[0], VariableDef) assert ctx is not None if (outer_index := graphdef.nodes[0].outer_index) is not None: outer_index_outer_ref = ctx.outer_index_outer_ref assert outer_index_outer_ref is not None node = outer_index_outer_ref[outer_index] if node in static_cache: static_cache_node = static_cache[node] if static_cache_node.final_graphdef != graphdef: raise ValueError( 'The graph structure of a node added to cached_partial was mutated inside the transformation, ' f'this is not allowed.\nNode: {node}\nOuput graphdef: {graphdef}\nExpected graphdef: {static_cache_node.final_graphdef}' ) if type(state) is list: leaves = state elif type(state) is FlatState: leaves = state.leaves else: raise ValueError(f'Unsupported state type: {type(state)}') if len(leaves) != len(static_cache_node.variables): raise ValueError( f'Incorrect number of leaves: expected {len(static_cache_node.variables)} ' f'leaves in the state, got {len(leaves)}' ) for variable, leaf in zip(static_cache_node.variables, leaves): if isinstance(leaf, Variable): variable.update_from_state(leaf) else: variable.set_raw_value(leaf) self.index_ref.update(static_cache_node.new_index_ref) else: # uncached node, create it node = unflatten( graphdef, state, index_ref=self.index_ref, outer_index_outer_ref=outer_index_outer_ref, ) else: # graphdef.outer_index is None # its a new node, create it node = unflatten( graphdef, state, index_ref=self.index_ref, ) else: outer_index_outer_ref = ( ctx.outer_index_outer_ref if ctx and ctx.outer_index_outer_ref else None ) node = unflatten( graphdef, state, index_ref=self.index_ref, outer_index_outer_ref=outer_index_outer_ref, ) return node @tp.overload @contextlib.contextmanager def merge_context() -> tp.Generator[MergeContext, None, None]: ... # type: ignore[bad-return-type] @tp.overload @contextlib.contextmanager def merge_context( ctxtag: tp.Hashable | None, inner: bool | None ) -> tp.Generator[MergeContext, None, None]: ... # type: ignore[bad-return-type] @contextlib.contextmanager def merge_context(ctxtag: tp.Hashable | None = None, inner: bool | None = None): GRAPH_CONTEXT.index_ref_stack.append(MergeContext(ctxtag, IndexMap(), inner)) try: yield GRAPH_CONTEXT.index_ref_stack[-1] finally: unflatten_ctx = GRAPH_CONTEXT.index_ref_stack.pop() index_ref = unflatten_ctx.index_ref if ctxtag is not None: if inner is None: raise ValueError('inner_merge must be specified when using ctxtag') ctx = current_update_context(ctxtag) ctx.unflatten_end(index_ref, inner) del unflatten_ctx.index_ref del unflatten_ctx.ctxtag @jax.tree_util.register_static @dataclasses.dataclass class UpdateContext: """A context manager for handling complex state updates.""" tag: tp.Hashable outer_ref_outer_index: RefMap | None outer_index_inner_ref: IndexMap | None # reverse caches outer_index_outer_ref: IndexMap | None inner_ref_outer_index: RefMap | None static_cache: tp.MutableMapping[tp.Any, StaticCache] | None # define hash and eq to make this an opaque object def __hash__(self): return 0 def __eq__(self, other): return isinstance(other, UpdateContext) def flatten_end(self, ref_index: RefMap): if self.outer_ref_outer_index is None: # outer split (1), store the references self.outer_ref_outer_index = ref_index self.outer_index_outer_ref = IndexMap.from_refmap( self.outer_ref_outer_index ) else: # inner split (3), clear index_ref self.outer_index_inner_ref = None self.inner_ref_outer_index = None def unflatten_end(self, index_ref: IndexMap, inner_merge: bool): if inner_merge: # inner merge (2) self.outer_index_inner_ref = index_ref self.inner_ref_outer_index = RefMap.from_indexmap(index_ref) @dataclasses.dataclass class UpdateContextManager: tag: tp.Hashable def __enter__(self): if GRAPH_CONTEXT.tmp_static_cache is not None: # take current static cache static_cache = GRAPH_CONTEXT.tmp_static_cache GRAPH_CONTEXT.tmp_static_cache = None else: static_cache = None ctx = UpdateContext( tag=self.tag, outer_ref_outer_index=None, outer_index_inner_ref=None, outer_index_outer_ref=None, inner_ref_outer_index=None, static_cache=static_cache, ) if self.tag not in GRAPH_CONTEXT.update_context_stacks: GRAPH_CONTEXT.update_context_stacks[self.tag] = [ctx] else: GRAPH_CONTEXT.update_context_stacks[self.tag].append(ctx) return ctx def __exit__(self, *args): if self.tag not in GRAPH_CONTEXT.update_context_stacks: raise RuntimeError( f'No update context found for tag {self.tag!r}, this is a bug.' ) stack = GRAPH_CONTEXT.update_context_stacks[self.tag] ctx = stack.pop() # clear references del ctx.outer_ref_outer_index del ctx.outer_index_inner_ref del ctx.outer_index_outer_ref del ctx.inner_ref_outer_index if not stack: del GRAPH_CONTEXT.update_context_stacks[self.tag] def __call__(self, f: F) -> F: @functools.wraps(f) def update_context_manager_wrapper(*args, **kwargs): with self: return f(*args, **kwargs) return update_context_manager_wrapper # type: ignore def update_context(tag: tp.Hashable): """Creates an :class:`UpdateContext` context manager which can be used to handle more complex state updates beyond what ``nnx.update`` can handle, including updates to static properties and graph structure. UpdateContext exposes a ``split`` and ``merge`` API with the same signature as ``nnx.split`` / ``nnx.merge`` but performs some bookkeeping to have the necessary information in order to perfectly update the input objects based on the changes made inside the transform. The UpdateContext must call split and merge a total of 4 times, the first and last calls happen outside the transform and the second and third calls happen inside the transform as shown in the diagram below:: idxmap (2) merge ─────────────────────────────► split (3) ▲ │ │ inside │ │. . . . . . . . . . . . . . . . . . │ index_mapping │ outside │ │ ▼ (1) split──────────────────────────────► merge (4) refmap The first call to split ``(1)`` creates a ``refmap`` which keeps track of the outer references, and the first call to merge ``(2)`` creates an ``idxmap`` which keeps track of the inner references. The second call to split ``(3)`` combines the refmap and idxmap to produce the ``index_mapping`` which indicates how the outer references map to the inner references. Finally, the last call to merge ``(4)`` uses the index_mapping and the refmap to reconstruct the output of the transform while reusing/updating the inner references. To avoid memory leaks, the idxmap is cleared after ``(3)`` and the refmap is cleared after ``(4)``, and both are cleared after the context manager exits. Here is a simple example showing the use of ``update_context``:: >>> from flax import nnx ... >>> class Foo(nnx.Module): pass ... >>> m1 = Foo() >>> with nnx.update_context('example'): ... with nnx.split_context('example') as ctx: ... graphdef, state = ctx.split(m1) ... @jax.jit ... def f(graphdef, state): ... with nnx.merge_context('example', inner=True) as ctx: ... m2 = ctx.merge(graphdef, state) ... m2.a = 1 ... m2.ref = m2 # create a reference cycle ... with nnx.split_context('example') as ctx: ... return ctx.split(m2) ... graphdef_out, state_out = f(graphdef, state) ... with nnx.merge_context('example', inner=False) as ctx: ... m3 = ctx.merge(graphdef_out, state_out) ... >>> assert m1 is m3 >>> assert m1.a == 1 >>> assert m1.ref is m1 Note that ``update_context`` takes in a ``tag`` argument which is used primarily as a safety mechanism reduce the risk of accidentally using the wrong UpdateContext when using :func:`current_update_context` to access the current active context. ``update_context`` can also be used as a decorator that creates/activates an UpdateContext context for the duration of the function:: >>> from flax import nnx ... >>> class Foo(nnx.Module): pass ... >>> m1 = Foo() >>> @jax.jit ... def f(graphdef, state): ... with nnx.merge_context('example', inner=True) as ctx: ... m2 = ctx.merge(graphdef, state) ... m2.a = 1 # insert static attribute ... m2.ref = m2 # create a reference cycle ... with nnx.split_context('example') as ctx: ... return ctx.split(m2) ... >>> @nnx.update_context('example') ... def g(m1): ... with nnx.split_context('example') as ctx: ... graphdef, state = ctx.split(m1) ... graphdef_out, state_out = f(graphdef, state) ... with nnx.merge_context('example', inner=False) as ctx: ... return ctx.merge(graphdef_out, state_out) ... >>> m3 = g(m1) >>> assert m1 is m3 >>> assert m1.a == 1 >>> assert m1.ref is m1 The context can be accessed using :func:`current_update_context`. Args: tag: A string tag to identify the context. """ return UpdateContextManager(tag=tag) def current_update_context(tag: tp.Hashable) -> UpdateContext: """Returns the current active :class:`UpdateContext` for the given tag.""" if tag not in GRAPH_CONTEXT.update_context_stacks: raise ValueError(f'No update context found for tag {tag!r}.') return GRAPH_CONTEXT.update_context_stacks[tag][-1] # -------------------------------------------------------- # Functional API # -------------------------------------------------------- def _split_state( state: FlatState[tp.Any], filters: tuple[filterlib.Filter, ...], ) -> tuple[FlatState[tp.Any], tpe.Unpack[tuple[FlatState[tp.Any], ...]]]: if not filters: return (state,) # type: ignore[bad-return-type] states = state.split(*filters) if not isinstance(states, tuple): return (states,) # type: ignore[bad-return-type] assert len(states) > 0 return states # type: ignore[return-value] @tp.overload def split( # type: ignore[invalid-annotation] graph_node: A, /, *, graph: bool | None = None, ) -> tuple[GraphDef[A], GraphState]: ... @tp.overload def split( # type: ignore[invalid-annotation] graph_node: A, first: filterlib.Filter, /, *, graph: bool | None = None, ) -> tuple[GraphDef[A], GraphState]: ... @tp.overload def split( # type: ignore[invalid-annotation] graph_node: A, first: filterlib.Filter, second: filterlib.Filter, /, *filters: filterlib.Filter, graph: bool | None = None, ) -> tuple[ GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]], ]: ... def split( # type: ignore[invalid-annotation] node: A, *filters: filterlib.Filter, graph: bool | None = None, ) -> tuple[ GraphDef[A], GraphState, tpe.Unpack[tuple[GraphState, ...]], ]: """Split a graph node into a :class:`GraphDef` and one or more :class:`State`s. State is a ``Mapping`` from strings or integers to ``Variables``, Arrays or nested States. GraphDef contains all the static information needed to reconstruct a ``Module`` graph, it is analogous to JAX's ``PyTreeDef``. :func:`split` is used in conjunction with :func:`merge` to switch seamlessly between stateful and stateless representations of the graph. Example usage:: >>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> class Foo(nnx.Module): ... def __init__(self, rngs): ... self.batch_norm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... >>> node = Foo(nnx.Rngs(0)) >>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat) ... >>> jax.tree.map(jnp.shape, params) State({ 'batch_norm': { 'bias': Param( value=(2,) ), 'scale': Param( value=(2,) ) }, 'linear': { 'bias': Param( value=(3,) ), 'kernel': Param( value=(2, 3) ) } }) >>> jax.tree.map(jnp.shape, batch_stats) State({ 'batch_norm': { 'mean': BatchStat( value=(2,) ), 'var': BatchStat( value=(2,) ) } }) :func:`split` and :func:`merge` are primarily used to interact directly with JAX transformations, see `Functional API `__ for more information. Arguments: node: graph node to split. *filters: some optional filters to group the state into mutually exclusive substates. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Returns: ``GraphDef`` and one or more ``States`` equal to the number of filters passed. If no filters are passed, a single ``State`` is returned. """ if graph is None: graph = set_graph_mode.current_value() graphdef, flat_state = flatten(node, graph=graph) flat_states = _split_state(flat_state, filters) states = _to_nested_state(graphdef, flat_states) return graphdef, *states # type: ignore[return-value] def _to_nested_state( graphdef: GraphDef[A], flat_states: tp.Iterable[tp.Any] ) -> tuple[tp.Any, ...]: def _nested_or_leaf(flat_state): if not flat_state: return State({}) if len(flat_state) == 1 and flat_state[0][0] == (): return flat_state[0][1] return statelib.from_flat_state(flat_state) states = tuple( _nested_or_leaf(flat_state) for flat_state in flat_states ) return states def _merge_to_flat_state(states: tp.Iterable[tp.Any]): flat_state: list[tuple[PathParts, tp.Any]] = [] for state in states: if isinstance(state, dict | State): flat_state.extend(traversals.flatten_to_sequence(state)) elif isinstance(state, FlatState): flat_state.extend(state) else: flat_state.append(((), state)) flat_state.sort() return [value for _, value in flat_state] def merge( # type: ignore[invalid-annotation] graphdef: GraphDef[A], state: tp.Any, /, *states: tp.Any, copy: bool = False, ) -> A: """The inverse of :func:`flax.nnx.split`. ``nnx.merge`` takes a :class:`flax.nnx.GraphDef` and one or more :class:`flax.nnx.State`'s and creates a new node with the same structure as the original node. Recall: :func:`flax.nnx.split` is used to represent a :class:`flax.nnx.Module` by: 1) a static ``nnx.GraphDef`` that captures its Pythonic static information; and 2) one or more :class:`flax.nnx.Variable` ``nnx.State``'(s) that capture its ``jax.Array``'s in the form of JAX pytrees. ``nnx.merge`` is used in conjunction with ``nnx.split`` to switch seamlessly between stateful and stateless representations of the graph. Example usage:: >>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> class Foo(nnx.Module): ... def __init__(self, rngs): ... self.batch_norm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... >>> node = Foo(nnx.Rngs(0)) >>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat) ... >>> new_node = nnx.merge(graphdef, params, batch_stats) >>> assert isinstance(new_node, Foo) >>> assert isinstance(new_node.batch_norm, nnx.BatchNorm) >>> assert isinstance(new_node.linear, nnx.Linear) ``nnx.split`` and ``nnx.merge`` are primarily used to interact directly with JAX transformations (refer to `Functional API `__ for more information. Args: graphdef: A :class:`flax.nnx.GraphDef` object. state: A :class:`flax.nnx.State` object. *states: Additional :class:`flax.nnx.State` objects. copy: Whether to create new copies of the Variables in the states, defaults to ``False``. Returns: The merged :class:`flax.nnx.Module`. """ if isinstance(state, list): if len(states) != 0: raise ValueError(f'Only one state can be passed as a list.') _state = state else: _state = _merge_to_flat_state((state, *states)) node = unflatten(graphdef, _state, copy_variables=copy) return node def update(node, state: tp.Any, /, *states: tp.Any) -> None: """Update the given graph node with a new state(s) in-place. Example usage:: >>> from flax import nnx >>> import jax, jax.numpy as jnp >>> x = jnp.ones((1, 2)) >>> y = jnp.ones((1, 3)) >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> def loss_fn(model, x, y): ... return jnp.mean((y - model(x))**2) >>> prev_loss = loss_fn(model, x, y) >>> grads = nnx.grad(loss_fn)(model, x, y) >>> new_state = jax.tree.map(lambda p, g: p - 0.1*g, nnx.state(model), grads) >>> nnx.update(model, new_state) >>> assert loss_fn(model, x, y) < prev_loss Args: node: A graph node to update. state: A :class:`State` object. *states: Additional :class:`State` objects. """ if states: if isinstance(node, Variable): non_empty_states = [ _state for _state in (state, *states) if not isinstance(_state, tp.Mapping) or _state ] if len(non_empty_states) != 1: all_states = (state, *states) raise ValueError( f'Expected exactly one non-empty state, got: {all_states!r}' ) state = non_empty_states[0] else: state = statelib.merge_state(state, *states) _graph_update_dynamic(node, state) @tp.overload def state(node, /, *, graph: bool | None = None) -> GraphState: ... @tp.overload def state(node, first: filterlib.Filter, /, *, graph: bool | None = None) -> GraphState: ... @tp.overload def state( node, first: filterlib.Filter, second: filterlib.Filter, /, *filters: filterlib.Filter, graph: bool | None = None, ) -> tuple[GraphState, ...]: ... def state( node, *filters: filterlib.Filter, graph: bool | None = None, ) -> tp.Union[GraphState, tuple[GraphState, ...]]: """Similar to :func:`split` but only returns the :class:`State`'s indicated by the filters. Example usage:: >>> from flax import nnx >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.batch_norm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... def __call__(self, x): ... return self.linear(self.batch_norm(x)) >>> model = Model(rngs=nnx.Rngs(0)) >>> # get the learnable parameters from the batch norm and linear layer >>> params = nnx.state(model, nnx.Param) >>> # get the batch statistics from the batch norm layer >>> batch_stats = nnx.state(model, nnx.BatchStat) >>> # get them separately >>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat) >>> # get them together >>> state = nnx.state(model) Args: node: A graph node object. *filters: One or more :class:`Variable` objects to filter by. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Returns: One or more :class:`State` mappings. """ if graph is None: graph = set_graph_mode.current_value() _, flat_state = flatten(node, graph=graph) state = flat_state.to_nested_state() states: GraphState | tuple[GraphState, ...] if len(filters) == 0: states = state # type: ignore[assignment] elif len(filters) == 1: states = statelib.filter_state(state, filters[0]) else: states = statelib.filter_state(state, filters[0], filters[1], *filters[2:]) return states variables = state def map( f: tp.Callable[[tuple, tp.Any], tp.Any], node: A, /, *, graph: bool | None = None, ) -> A: """Map a function over the state of a graph node. ``map`` extracts the state from ``node`` using :func:`split`, applies ``f`` to every ``(path, value)`` pair using :func:`map_state`, and returns a new node with the mapped values merged back into the original structure. Note that the leaves in the state are :class:`Variable` objects, so ``f`` should handle them accordingly. Example usage:: >>> from flax import nnx >>> import jax.numpy as jnp >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> new_model = nnx.map(lambda path, v: v.replace(jnp.zeros_like(v)), model) >>> assert jnp.all(new_model.kernel[...] == 0) >>> assert jnp.all(new_model.bias[...] == 0) Args: f: A callable ``(path, value) -> new_value`` applied to each leaf in the state. ``path`` is a tuple of path parts and ``value`` is the corresponding leaf (typically a :class:`Variable`). node: A graph node object. graph: If ``True``, uses graph-mode which supports the full NNX feature set including shared references. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Returns: A :class:`State` with the mapped values. """ graphdef, state = split(node, graph=graph) state = statelib.map_state(f, state) return merge(graphdef, state) def graphdef( node: tp.Any, /, *, graph: bool | None = None, ) -> GraphDef[tp.Any]: """Get the :class:`GraphDef` of the given graph node. Example usage:: >>> from flax import nnx >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> graphdef, _ = nnx.split(model) >>> assert graphdef == nnx.graphdef(model) Args: node: A graph node object. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Returns: The :class:`GraphDef` of the :class:`Module` object. """ if graph is None: graph = set_graph_mode.current_value() graphdef, _ = flatten(node, graph=graph) return graphdef @tp.overload def pop( node, filter: filterlib.Filter, /, ) -> GraphState: ... @tp.overload def pop( node, filter: filterlib.Filter, filter2: filterlib.Filter, /, *filters: filterlib.Filter, ) -> tuple[GraphState, ...]: ... def pop( node, *filters: filterlib.Filter ) -> tp.Union[GraphState, tuple[GraphState, ...]]: """Pop one or more :class:`Variable` types from the graph node. Example usage:: >>> from flax import nnx >>> import jax.numpy as jnp >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear1 = nnx.Linear(2, 3, rngs=rngs) ... self.linear2 = nnx.Linear(3, 4, rngs=rngs) ... def __call__(self, x): ... x = self.linear1(x) ... self.i = nnx.Intermediate(x) ... x = self.linear2(x) ... return x >>> x = jnp.ones((1, 2)) >>> model = Model(rngs=nnx.Rngs(0)) >>> assert not hasattr(model, 'i') >>> y = model(x) >>> assert hasattr(model, 'i') >>> intermediates = nnx.pop(model, nnx.Intermediate) >>> assert intermediates['i'].shape == (1, 3) >>> assert not hasattr(model, 'i') Args: node: A graph node object. *filters: One or more :class:`Variable` objects to filter by. Returns: The popped :class:`State` containing the :class:`Variable` objects that were filtered for. """ if len(filters) == 0: raise ValueError('Expected at least one filter') id_to_index: dict[int, Index] = {} path_parts: PathParts = () predicates = tuple(filterlib.to_predicate(filter) for filter in filters) flat_states: tuple[dict[PathParts, LeafType], ...] = tuple( {} for _ in predicates ) _graph_pop( node=node, id_to_index=id_to_index, path_parts=path_parts, flat_states=flat_states, predicates=predicates, ) states = tuple( statelib.from_flat_state(flat_state) for flat_state in flat_states ) if len(states) == 1: return states[0] else: return states def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> Node: """Create a deep copy of the given graph node. Example usage:: >>> from flax import nnx >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> cloned_model = nnx.clone(model) >>> model.bias[...] += 1 >>> assert (model.bias[...] != cloned_model.bias[...]).all() Args: node: A graph node object. variables: If ``True`` (default) copies of the :class:`Variable` objects are created, otherwise the Variables are shared between the original and cloned node. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Returns: A deep copy of the :class:`Module` object. """ graphdef, state = split(node, graph=graph) return merge(graphdef, state, copy=variables) def vars_as( node: A, /, *, hijax: bool | None = None, ref: bool | None = None, mutable: bool | None = None, only: filterlib.Filter = ..., allow_duplicates: bool = False, ) -> A: """ """ new_attrs: dict[str, bool] = {} if hijax is not None: new_attrs['hijax'] = hijax if ref is not None: new_attrs['ref'] = ref if mutable is not None: new_attrs['mutable'] = mutable def _different_vars(path, x): return isinstance(x, Variable) and any( getattr(x, attr) != value for attr, value in new_attrs.items() ) only = filterlib.All(_different_vars, only) predicate = filterlib.to_predicate(only) if not allow_duplicates and ( all_duplicates := find_duplicates(node, only=only) ): duplicates_strs = '\n ---' for node_duplicates in all_duplicates: for path in node_duplicates: path_str = '/'.join(builtins.map(str, path)) duplicates_strs += f'\n {path_str}' duplicates_strs += '\n ---' raise ValueError(f'Found duplicate at paths:{duplicates_strs}') def _to_refs(jax_path, x): if predicate(jax_to_nnx_path(jax_path), x): assert isinstance(x, Variable) variable = x.copy(**new_attrs) return variable return x node = jax.tree.map_with_path( _to_refs, node, is_leaf=lambda x: isinstance(x, Variable) ) return node def pure(tree: A) -> A: """Returns a new tree with all ``Variable`` objects replaced with inner values. This can be used to remove Variable metadata when its is not needed for tasks like serialization or exporting. Example:: >>> from flax import nnx >>> import jax >>> import jax.numpy as jnp ... >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> graphdef, state = nnx.split(model) >>> jax.tree.map(jnp.shape, state) State({ 'bias': Param( value=(3,) ), 'kernel': Param( value=(2, 3) ) }) >>> pure_state = nnx.pure(state) >>> jax.tree.map(jnp.shape, pure_state) State({ 'bias': (3,), 'kernel': (2, 3) }) Args: tree: A pytree potentially containing ``Variable`` objects. Returns: A new pytree with all ``Variable`` objects replaced with their inner values. """ def _pure_fn(x): if isinstance(x, Variable): return pure(x.get_raw_value()) elif variablelib.is_array_ref(x): return x[...] return x return jax.tree.map( _pure_fn, tree, is_leaf=lambda x: isinstance(x, Variable), ) def call( graphdef_state: tuple[GraphDef[A], GraphState], / ) -> ApplyCaller[tuple[GraphDef[A], GraphState]]: """Calls a method underlying graph node defined by a (GraphDef, State) pair. ``call`` takes a ``(GraphDef, State)`` pair and creates a proxy object that can be used to call methods on the underlying graph node. When a method is called, the output is returned along with a new (GraphDef, State) pair that represents the updated state of the graph node. ``call`` is equivalent to :func:`merge` > ``method`` > :func:`split` but is more convenient to use in pure JAX functions. Example:: >>> from flax import nnx >>> import jax >>> import jax.numpy as jnp ... >>> class StatefulLinear(nnx.Module): ... def __init__(self, din, dout, rngs): ... self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout))) ... self.b = nnx.Param(jnp.zeros((dout,))) ... self.count = Variable(jnp.array(0, dtype=jnp.uint32)) ... ... def increment(self): ... self.count[...] += 1 ... ... def __call__(self, x): ... self.increment() ... return x @ self.w + self.b ... >>> linear = StatefulLinear(3, 2, nnx.Rngs(0)) >>> linear_state = nnx.split(linear) ... >>> @jax.jit ... def forward(x, linear_state): ... y, linear_state = nnx.call(linear_state)(x) ... return y, linear_state ... >>> x = jnp.ones((1, 3)) >>> y, linear_state = forward(x, linear_state) >>> y, linear_state = forward(x, linear_state) ... >>> linear = nnx.merge(*linear_state) >>> linear.count[...] Array(2, dtype=uint32) The proxy object returned by ``call`` supports indexing and attribute access to access nested methods. In the example below, the ``increment`` method indexing is used to call the ``increment`` method of the ``StatefulLinear`` module at the ``b`` key of a ``nodes`` dictionary. >>> class StatefulLinear(nnx.Module): ... def __init__(self, din, dout, rngs): ... self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout))) ... self.b = nnx.Param(jnp.zeros((dout,))) ... self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32)) ... ... def increment(self): ... self.count[...] += 1 ... ... def __call__(self, x): ... self.increment() ... return x @ self.w + self.b ... >>> rngs = nnx.Rngs(0) >>> nodes = dict( ... a=StatefulLinear(3, 2, rngs), ... b=StatefulLinear(2, 1, rngs), ... ) ... >>> node_state = nnx.split(nodes) >>> # use attribute access >>> _, node_state = nnx.call(node_state)['b'].increment() ... >>> nodes = nnx.merge(*node_state) >>> nodes['a'].count[...] Array(0, dtype=uint32) >>> nodes['b'].count[...] Array(1, dtype=uint32) """ def pure_caller(accessor: DelayedAccessor, *args, **kwargs): node = merge(*graphdef_state) method = accessor(node) out = method(*args, **kwargs) return out, split(node) return CallableProxy(pure_caller) # type: ignore def set_metadata( node: tp.Any, /, *, only: filterlib.Filter = Variable, **metadata: tp.Any ) -> None: """Sets the metadata of all :class:`Variable` objects in the given graph node in-place. Example:: >>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> class Foo(nnx.Module): ... def __init__(self): ... self.param = nnx.Param(0.0) ... self.variable = nnx.Variable(0.0) ... >>> node = Foo() ... >>> # set differentiable to False for all nnx.Param objects >>> nnx.set_metadata(node, differentiable=False, only=nnx.Param) ... >>> # check that only the nnx.Param was updated >>> assert node.param.get_metadata('differentiable') is False Args: node: A graph node object. only: A Filter to specify which :class:`Variable` objects to set metadata for. metadata: Key-value pairs to set as metadata on the :class:`Variable` objects. """ def _set_metadata(path: PathParts, variable: V) -> None: del path # unused if isinstance(variable, Variable): variable.set_metadata(**metadata) # inplace update of variable_state metadata map_state(_set_metadata, state(node, only)) def iter_graph( node: tp.Any, /, *, graph: bool | None = None, ) -> tp.Iterator[tuple[PathParts, tp.Any]]: """Iterates over all nested nodes and leaves of the given graph node, including the current node. ``iter_graph`` creates a generator that yields path and value pairs, where the path is a tuple of strings or integers representing the path to the value from the root. Repeated nodes are visited only once. Leaves include static values. Example:: >>> from flax import nnx >>> import jax.numpy as jnp ... >>> class Linear(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): ... self.din, self.dout = din, dout ... self.w = nnx.Param(jax.random.uniform(rngs.next(), (din, dout))) ... self.b = nnx.Param(jnp.zeros((dout,))) ... >>> module = Linear(3, 4, rngs=nnx.Rngs(0)) >>> graph = [module, module] ... >>> for path, value in nnx.iter_graph(graph): ... print(path, type(value).__name__) ... (0, '_pytree__nodes') HashableMapping (0, '_pytree__state') PytreeState (0, 'b') Param (0, 'din') int (0, 'dout') int (0, 'w') Param (0,) Linear () list Args: node: A graph node object. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. """ if graph is None: graph = set_graph_mode.current_value() if graph: return _iter_graph(node) else: return _iter_tree(node) def _iter_graph(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]: visited: set[int] = set() stack: list[tuple[PathParts, tp.Any, bool]] = [((), node, False)] while stack: path_parts, node, traversed = stack.pop(-1) if traversed or not (is_node(node) or isinstance(node, Variable)): yield path_parts, node continue if id(node) in visited: continue visited.add(id(node)) if (node_impl := get_node_impl(node)) is None: yield path_parts, node continue stack.append((path_parts, node, True)) for key, child in reversed(node_impl.node_dict(node).items()): stack.append(((*path_parts, key), child, False)) def _iter_tree(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]: in_progress: dict[int, str] = {} seen_refs: dict[int, str] = {} stack: list[tuple[PathParts, tp.Any, bool]] = [((), node, False)] while stack: path, current, traversed = stack.pop() if traversed: in_progress.pop(id(current), None) yield path, current continue if not is_pytree_node(current, check_graph_registry=False): _check_valid_pytree(current, 'iter_graph', '/'.join(builtins.map(str, path))) if isinstance(current, Variable) or variablelib.is_array_ref(current): obj_id = id(current) str_path = '/'.join(builtins.map(str, path)) if obj_id in seen_refs: raise ValueError( f'Duplicate {current}\nfound at paths:\n\n' f' - {seen_refs[obj_id]}\n' f' - {str_path}\n\n' 'Tree mode (graph=False) does not support shared references. ' + _tree_mode_suggestion_api('iter_graph') ) seen_refs[obj_id] = str_path yield path, current continue obj_id = id(current) str_path = '/'.join(builtins.map(str, path)) if obj_id in in_progress: raise ValueError( f'Cycle detected for {type(current).__name__}\nfound at paths:\n\n' f' - {in_progress[obj_id]}\n' f' - {str_path}\n\n' 'Cycles are not supported with graph=False. ' + _tree_mode_suggestion_api('iter_graph') ) in_progress[obj_id] = str_path stack.append((path, current, True)) children, _ = jax.tree_util.tree_flatten_with_path( current, is_leaf=lambda x: x is not current ) for jax_key_path, child in reversed(children): key = _key_path_to_key(jax_key_path[0]) stack.append(((*path, key), child, False)) def iter_children( node: tp.Any, /, *, graph: bool | None = None, ) -> tp.Iterator[tuple[Key, tp.Any]]: """Iterates over all immediate child nodes of a given node. This function is similar to :func:`iter_graph`, except it only iterates over the immediate children, and does not recurse further down. Specifically, this function creates a generator that yields the key and the child node instance, where the key is a string representing the attribute name to access the corresponding child. Example:: >>> from flax import nnx ... >>> class SubModule(nnx.Module): ... def __init__(self, din, dout, rngs): ... self.linear1 = nnx.Linear(din, dout, rngs=rngs) ... self.linear2 = nnx.Linear(din, dout, rngs=rngs) ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): ... self.linear = nnx.Linear(din, dout, rngs=rngs) ... self.submodule = SubModule(din, dout, rngs=rngs) ... self.dropout = nnx.Dropout(0.5) ... self.batch_norm = nnx.BatchNorm(10, rngs=rngs) ... >>> model = Block(2, 5, rngs=nnx.Rngs(0)) >>> for path, module in nnx.iter_children(model): ... print(path, type(module).__name__) ... batch_norm BatchNorm dropout Dropout linear Linear submodule SubModule Args: node: A graph node object. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. """ if graph is None: graph = set_graph_mode.current_value() if graph: node_impl = get_node_impl(node) if node_impl is None: raise ValueError( f'Expected a graph node, got {type(node).__name__}. ' 'If this is a regular pytree, use graph=False.' ) node_dict = node_impl.node_dict(node) for key, value in node_dict.items(): if is_graph_node(value): yield key, value else: _check_valid_pytree(node, 'iter_children') if not is_pytree_node(node, check_graph_registry=False): raise ValueError( f'Expected a pytree node, got {type(node).__name__}. ' 'If this is a graph node, use graph=True.' ) children, _ = jax.tree_util.tree_flatten_with_path( node, is_leaf=lambda x: x is not node ) for jax_key_path, child in children: if is_graph_node(child): key = _key_path_to_key(jax_key_path[0]) yield key, child def recursive_map( f: tp.Callable[[PathParts, tp.Any], tp.Any], node: tp.Any, /, *, graph: bool | None = None, ): """Recursively applies a function to all nodes and leaves of the given graph node. Example:: >>> from flax import nnx >>> class MyModule(nnx.Module): ... def __init__(self, *, rngs: nnx.Rngs): ... self.lin = nnx.Linear(16, 16, rngs=rngs) ... self.conv = nnx.Conv(16, 3, 1, 1, rngs=rngs) ... >>> def print_modules(path, node): ... if isinstance(node, nnx.Module): ... s = "." + ".".join(path) ... print(f"Path = {s:<10}{node.__class__.__name__}") ... return node ... >>> model = MyModule(rngs=nnx.Rngs(0)) >>> new_model = nnx.recursive_map(print_modules, model) ... Path = .conv Conv Path = .lin Linear Path = . MyModule Args: f: A function that takes a path and a node and returns a new node. node: A graph node object. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. """ if graph is None: graph = set_graph_mode.current_value() if graph: node = clone(node, variables=False, graph=True) path_parts: PathParts = () visited: set[int] = set() results: dict[int, tp.Any] = {} return _recursive_map_graph(f, node, path_parts, visited, results) else: return _recursive_map_tree(f, node) def _recursive_map_graph( f: tp.Callable[[PathParts, tp.Any], tp.Any], node: tp.Any, path: PathParts, visited: set[int], results: dict[int, tp.Any], ) -> tp.Any: node_id = id(node) if node_id in visited: if node_id in results: return results[node_id] path_str = '/'.join(builtins.map(str, path)) raise ValueError( f"Found cycle in the graph at path '{path_str}'. Node of type" f' {type(node)} has already been visited but has not been returned yet.' ) node_impl = get_node_impl(node) if ( type(node_impl) is GraphNodeImpl or isinstance(node, Variable) or is_array_ref(node) ): visited.add(node_id) if node_impl is not None: for key, value in node_impl.node_dict(node).items(): new_value = _recursive_map_graph(f, value, (*path, key), visited, results) if new_value is not value: if node_impl.set_key is not None and value is not new_value: node_impl.set_key(node, key, new_value) else: raise ValueError( f"Cannot update key '{key}' for node of type '{type(node)}'" ' because the node does not support mutation.' ) new_node = f(path, node) results[node_id] = new_node return new_node def _recursive_map_tree( f: tp.Callable[[PathParts, tp.Any], tp.Any], node: tp.Any, ) -> tp.Any: in_progress: dict[int, str] = {} seen_refs: dict[int, str] = {} def _recurse(path: PathParts, current: tp.Any) -> tp.Any: if not is_pytree_node(current, check_graph_registry=False): _check_valid_pytree(current, 'recursive_map', '/'.join(builtins.map(str, path))) if isinstance(current, Variable) or is_array_ref(current): obj_id = id(current) str_path = '/'.join(builtins.map(str, path)) if obj_id in seen_refs: raise ValueError( f'Duplicate {current}\nfound at paths:\n\n' f' - {seen_refs[obj_id]}\n' f' - {str_path}\n\n' 'Tree mode (graph=False) does not support shared references. ' + _tree_mode_suggestion_api('recursive_map') ) seen_refs[obj_id] = str_path return f(path, current) obj_id = id(current) str_path = '/'.join(builtins.map(str, path)) if obj_id in in_progress: raise ValueError( f'Cycle detected for {type(current).__name__}\nfound at paths:\n\n' f' - {in_progress[obj_id]}\n' f' - {str_path}\n\n' 'Cycles are not supported with graph=False. ' + _tree_mode_suggestion_api('recursive_map') ) in_progress[obj_id] = str_path children_with_path, treedef = jax.tree_util.tree_flatten_with_path( current, is_leaf=lambda x: x is not current ) new_children = [] for jax_key_path, child in children_with_path: key = _key_path_to_key(jax_key_path[0]) new_child = _recurse((*path, key), child) new_children.append(new_child) new_node = treedef.unflatten(new_children) result = f(path, new_node) in_progress.pop(obj_id, None) return result return _recurse((), node) def find_duplicates(node: tp.Any, /, *, only: filterlib.Filter = ...) -> list[list[PathParts]]: """Finds duplicate nodes or node leaves in the given node. This function traverses the graph node and collects paths to nodes and leaves that have the same identity. It returns a list of lists, where each inner list contains paths to nodes or leaves that are duplicates. Example:: >>> from flax import nnx >>> import jax.numpy as jnp ... >>> class SharedVariables(nnx.Module): ... def __init__(self): ... self.a = nnx.Param(jnp.array(1.0)) ... self.b = nnx.Param(jnp.array(2.0)) ... self.c = self.b # shared Variable ... >>> model = SharedVariables() >>> duplicates = nnx.find_duplicates(model) >>> len(duplicates) 1 >>> for path in duplicates[0]: ... print(path) ('b',) ('c',) ``find_duplicates`` will also find duplicates nodes such as Modules that are referenced multiple times in the graph:: >>> class SharedModules(nnx.Module): ... def __init__(self, rngs: nnx.Rngs): ... self.a = nnx.Linear(1, 1, rngs=rngs) ... self.b = nnx.Linear(1, 1, rngs=rngs) ... self.c = self.a # shared Module ... >>> model = SharedModules(nnx.Rngs(0)) >>> for duplicate_paths in nnx.find_duplicates(model): ... print(duplicate_paths) [('a',), ('c',)] Args: node: A graph node object. only: A Filter to specify which nodes or leaves to consider for duplicates. Returns: A list of lists, where each inner list contains the different paths for a for a duplicate node or leaf. """ node_paths: dict[int, list[PathParts]] = {} duplicate_candidate = filterlib.to_predicate(only) _node_paths(node, node_paths, (), duplicate_candidate) _duplicates = [paths for paths in node_paths.values() if len(paths) > 1] return _duplicates def _node_paths( node: tp.Any, node_paths: dict[int, list[PathParts]], path: PathParts, duplicate_candidate: filterlib.Predicate, /, ): _is_graph_node = is_graph_node(node) _is_pytree_node = is_pytree_node(node) _is_node_leaf = is_node_leaf(node) if _is_graph_node or _is_pytree_node or _is_node_leaf: node_id = id(node) if node_id in node_paths: if (_is_graph_node or _is_node_leaf) and duplicate_candidate(path, node): node_paths[node_id].append(path) return if _is_graph_node or _is_node_leaf: node_paths[node_id] = [path] node_impl = get_node_impl(node) if node_impl is None: return node_dict = node_impl.node_dict(node) for key, value in node_dict.items(): _node_paths(value, node_paths, (*path, key), duplicate_candidate) @jax.tree_util.register_static @dataclasses.dataclass(frozen=True, slots=True) class Static(tp.Generic[A]): """An empty pytree node that treats its inner value as static. ``value`` must define ``__eq__`` and ``__hash__``. """ value: A # --------------------------------------------------------- # Pytree # --------------------------------------------------------- class GenericPytree: ... from jax._src.tree_util import _registry as JAX_PYTREE_REGISTRY def is_pytree_node( x: tp.Any, *, check_graph_registry: bool = True, ) -> bool: if check_graph_registry and type(x) in GRAPH_REGISTRY: return False elif isinstance(x, Variable): return False elif type(x) in JAX_PYTREE_REGISTRY: return True elif isinstance(x, tuple): return True else: return False def _key_path_to_key(key: tp.Any) -> Key: if isinstance(key, jax.tree_util.SequenceKey): return key.idx elif isinstance( key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey) ): if not is_key_like(key.key): # type: ignore[not-supported-yet] raise ValueError( f'Invalid key: {key.key}. May be due to its type not being hashable or comparable.' ) return key.key elif isinstance(key, jax.tree_util.GetAttrKey): return key.name else: return str(key) def jax_to_nnx_path(jax_path: tuple, /): return tuple(_key_path_to_key(part) for part in jax_path) class IndexesPytreeDef(tp.NamedTuple): key_index: HashableMapping[Key, int] treedef: jax.tree_util.PyTreeDef def _flatten_pytree(pytree: tp.Any): leaves, treedef = jax.tree_util.tree_flatten_with_path( pytree, is_leaf=lambda x: x is not pytree ) nodes = [(_key_path_to_key(path[0]), value) for path, value in leaves] key_index = HashableMapping( {key: i for i, (key, _) in enumerate(nodes)}, copy=False ) # Sort by key to match the path-sorted order used by _merge_to_flat_state. # key_index records the original jax tree_flatten order so _unflatten_pytree # can restore it before calling treedef.unflatten. nodes.sort() return nodes, IndexesPytreeDef(key_index, treedef) def _unflatten_pytree( nodes: tuple[tuple[Key, tp.Any], ...], metadata: IndexesPytreeDef ): # sort to original order sorted_nodes = sorted(nodes, key=lambda x: metadata.key_index[x[0]]) pytree = metadata.treedef.unflatten(value for _, value in sorted_nodes) return pytree PYTREE_NODE_IMPL = PytreeNodeImpl( type=GenericPytree, flatten=_flatten_pytree, unflatten=_unflatten_pytree, # type: ignore set_key=None, pop_key=None, ) def _list_set_key(x: list[tp.Any], key: int, value: tp.Any): x[key] = value # common pytrees # list register_pytree_node_type( list, flatten=lambda x: (list(enumerate(x)), None), unflatten=lambda nodes, _: [value for _, value in nodes], # type: ignore set_key=_list_set_key, # type: ignore ) # tuple register_pytree_node_type( tuple, flatten=lambda x: (list(enumerate(x)), None), unflatten=lambda nodes, _: tuple(value for _, value in nodes), # type: ignore ) def _mutable_mapping_set_key( x: tp.MutableMapping[Key, tp.Any], key: Key, value: tp.Any ): x[key] = value def _mutable_mapping_pop_key(x: tp.MutableMapping[Key, tp.Any], key: Key): x.pop(key) # dict register_pytree_node_type( dict, flatten=lambda x: (sorted(x.items()), None), unflatten=lambda nodes, _: dict(nodes), # type: ignore set_key=_mutable_mapping_set_key, pop_key=_mutable_mapping_pop_key, ) # State register_pytree_node_type( State, flatten=lambda x: (sorted(x.raw_mapping.items()), None), unflatten=lambda nodes, _: State(nodes), # type: ignore set_key=_mutable_mapping_set_key, pop_key=_mutable_mapping_pop_key, ) # None register_pytree_node_type( type(None), flatten=lambda x: ([], None), unflatten=lambda _, __: None, # type: ignore ) ================================================ FILE: flax/nnx/helpers.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import inspect import typing as tp import jax import jax.numpy as jnp import optax from flax.nnx import graphlib, reprlib from flax.nnx.graphlib import GraphDef from flax.nnx.module import Module from flax.nnx.proxy_caller import ApplyCaller from flax.nnx.rnglib import Rngs from flax.nnx.statelib import State from flax.training.train_state import struct from flax.nnx.variablelib import Variable A = tp.TypeVar('A') M = tp.TypeVar('M', bound=Module) TS = tp.TypeVar('TS', bound='TrainState') class Dict(reprlib.MappingReprMixin, Module, tp.MutableMapping[str, A]): """A Module that implements a mutable mapping. This class provides a way to store and manipulate a mapping of keys to values contained a mixed set of data (e.g. Array, Variables, Modules) and static (e.g. functions, strings) types. Example: >>> from flax import nnx ... >>> rngs = nnx.Rngs(0) >>> layers = nnx.Dict({ ... 'kernel1': nnx.Linear(1, 32, rngs=rngs), # data ... 'activation1': nnx.relu, # static ... 'kernel2': nnx.Linear(32, 1, rngs=rngs), # data ... }) """ @tp.overload def __init__(self, iterable: tp.Iterable[tp.Tuple[str, A]], /): ... @tp.overload def __init__( self, mapping: tp.Optional[tp.Mapping[str, A]] = None, /, **kwargs: A ): ... def __init__(self, *args, **kwargs): for name, value in dict(*args, **kwargs).items(): setattr(self, name, value) def __getitem__(self, key) -> A: try: return getattr(self, key) except AttributeError as e: raise KeyError(key) from e def __setitem__(self, key, value): setattr(self, key, value) def __iter__(self) -> tp.Iterator[str]: return (k for k in vars(self) if not k.startswith('_pytree__')) def __len__(self) -> int: length = 0 for _ in self: length += 1 return length def __hash__(self) -> int: return id(self) def __delitem__(self, key: str) -> None: try: delattr(self, key) except AttributeError as e: raise KeyError(key) from e if tp.TYPE_CHECKING: def __getattr__(self, key: str) -> A: ... def __setattr__(self, key: str, value: A) -> None: ... class List(reprlib.SequenceReprMixin, Module, tp.MutableSequence[A]): """A Module that implements a mutable sequence. This class provides a way to store and manipulate a sequence of values contained a mixed set of data (e.g. Array, Variables, Modules) and static (e.g. functions, strings) types. Example: >>> from flax import nnx ... >>> rngs = nnx.Rngs(0) >>> layers = nnx.List([ ... nnx.Linear(1, 32, rngs=rngs), # data ... nnx.relu, # static ... nnx.Linear(32, 1, rngs=rngs), # data ... ]) """ def __init__(self, it: tp.Iterable[A] | None = None, /): """ Args: it: An iterable of values to initialize the list. """ self._length: int = 0 if it is not None: for value in it: self.append(value) def _get_elem(self, key: int) -> A: return getattr(self, str(key)) def _set_elem(self, key: int, value: A) -> None: setattr(self, str(key), value) def _del_elem(self, key: int) -> None: delattr(self, str(key)) def __len__(self) -> int: return self._length def append(self, value: A) -> None: self._set_elem(self._length, value) self._length += 1 def insert(self, index: int, value: A) -> None: if index < 0: index += self._length if index < 0: index = 0 if index > self._length: index = self._length # Shift elements to the right for i in range(self._length, index, -1): self._set_elem(i, self._get_elem(i - 1)) # Insert the new value self._set_elem(index, value) self._length += 1 def __iter__(self) -> tp.Iterator[A]: for i in range(self._length): yield self._get_elem(i) @tp.overload def __getitem__(self, index: int) -> A: ... @tp.overload def __getitem__(self, index: slice) -> tp.List[A]: ... def __getitem__(self, index: int | slice) -> A | tp.List[A]: if isinstance(index, int): if index < 0: index += self._length if index < 0 or index >= self._length: raise IndexError('Index out of bounds') return self._get_elem(index) elif isinstance(index, slice): idxs = list(range(self._length))[index] return [self._get_elem(i) for i in idxs] else: raise TypeError('Invalid index type') def __setitem__(self, index: int | slice, value: A | tp.Iterable[A]) -> None: if isinstance(index, int): if index < 0: index += self._length if index < 0 or index >= self._length: raise IndexError('Index out of bounds') self._set_elem(index, value) # type: ignore[arg-type] elif isinstance(index, slice): if not isinstance(value, tp.Iterable): raise TypeError('Expected an iterable') values = list(value) idxs = list(range(self._length))[index] if len(idxs) != len(values): raise ValueError('Length mismatch') for i, v in zip(idxs, values): self._set_elem(i, v) else: raise TypeError('Invalid index type') def _graph_node_set_key(self, key: str, value: tp.Any): if not isinstance(key, int): raise KeyError(f'Invalid key: {key}') elif key < len(self): if isinstance(variable := self[key], Variable) and isinstance(value, Variable): variable.update_from_state(value) else: self[key] = value else: self.insert(key, value) def __delitem__(self, index: int | slice) -> None: if isinstance(index, int): if index < 0: index += self._length if index < 0 or index >= self._length: raise IndexError('Index out of bounds') self._del_elem(index) for i in range(index + 1, self._length): self._set_elem(i - 1, self._get_elem(i)) self._length -= 1 elif isinstance(index, slice): idxs = list(range(self._length))[index] for i in reversed(idxs): # implement recursively del self[i] else: raise TypeError('Invalid index type') _pytree__has_int_keys = True class Sequential(Module): """A Module that applies a sequence of callables. This class provides a way to store and manipulate a sequence of callables (e.g. layers, activation functions) and apply them in order. Example:: >>> from flax import nnx >>> import jax.numpy as jnp ... >>> rngs = nnx.Rngs(0) >>> model = nnx.Sequential( ... nnx.Linear(1, 4, rngs=rngs), # data ... nnx.relu, # static ... nnx.Linear(4, 2, rngs=rngs), # data ... ) >>> x = jnp.ones((1, 1)) >>> y = model(x) >>> y.shape (1, 2) """ def __init__(self, *fns: tp.Callable[..., tp.Any]): """ Args: *fns: A sequence of callables to apply. """ self.layers = List(fns) def __call__(self, *args, rngs: tp.Optional[Rngs] = None, **kwargs) -> tp.Any: if len(self.layers) == 0: if len(args) == 1: return args[0] elif len(args) > 0: return args elif len(kwargs) > 0: return kwargs else: return None output: tp.Any = None for i, f in enumerate(self.layers): if not callable(f): raise TypeError(f'Sequence[{i}] is not callable: {f}') if i > 0: if isinstance(output, tuple): args = output kwargs = {} elif isinstance(output, dict): args = () kwargs = output else: args = (output,) kwargs = {} if rngs is not None and has_keyword_arg(f, 'rngs'): kwargs['rngs'] = rngs output = f(*args, **kwargs) return output class ModuleDefApply(tp.Protocol, tp.Generic[M]): def __call__( self, state: State, *states: State ) -> ApplyCaller[tuple[State, GraphDef[M]]]: ... class TrainState(tp.Generic[M], struct.PyTreeNode): graphdef: graphlib.GraphDef[M] params: State opt_state: optax.OptState step: jax.Array tx: optax.GradientTransformation = struct.field(pytree_node=False) @classmethod def create( cls, graphdef: graphlib.GraphDef[M], *, params: State, tx: optax.GradientTransformation, step: int = 0, **kwargs, ): return cls( graphdef=graphdef, params=params, opt_state=tx.init(params), step=jnp.asarray(step), tx=tx, **kwargs, ) if tp.TYPE_CHECKING: def __getattr__(self, key: str) -> tp.Any: ... def apply( self, state: tp.Union[State, str], *states: tp.Union[State, str] ) -> ApplyCaller[tuple[GraphDef[M], State]]: states = (state, *states) _states: list[State] = [] for _state in states: if isinstance(_state, str): _state_key = _state _state = getattr(self, _state_key) if not isinstance(_state, State): raise TypeError( f'Expected {self.__class__.__name__}.{_state_key} to be a State, got {type(_state)}' ) _states.append(_state) return self.graphdef.apply(*_states) def apply_gradients(self: TS, grads: State, **kwargs) -> TS: updates, opt_state = self.tx.update(grads, self.opt_state, self.params) params = optax.apply_updates(self.params, updates) # type: ignore step = self.step + 1 return self.replace( params=params, opt_state=opt_state, step=step, **kwargs, ) def has_keyword_arg(func: tp.Callable[..., tp.Any], name: str) -> bool: """Return True if func has keyword-only arguments with the given name.""" return any( param.name == name and param.kind in (param.KEYWORD_ONLY, param.POSITIONAL_OR_KEYWORD) for param in inspect.signature(func).parameters.values() ) ================================================ FILE: flax/nnx/ids.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """UUIDs for Flax internals.""" import threading class UUIDManager: """Globally unique counter-based id manager. We need globally unique key ids for Module and Variable object instances to preserve and recreate sharing-by-reference relationship when lifting transforms and adopting outside Modules. - Use of id() is unacceptable because these identifiers are literally pointers which can be recycled, so we rely on a globally unique counter id instead. - We need to handle copy/deepcopy uniqueness via a wrapped type. """ def __init__(self): self._lock = threading.Lock() self._id = 0 def __call__(self): with self._lock: self._id += 1 return UUID(self._id) uuid = UUIDManager() class UUID: """Hashable wrapper for ids that handles uniqueness of copies.""" def __init__(self, rawid): self.id = rawid def __eq__(self, other): return isinstance(other, UUID) and other.id == self.id def __hash__(self): return hash(self.id) def __repr__(self): return f'UUID({self.id})' def __deepcopy__(self, memo): del memo return uuid() def __copy__(self): return uuid() ================================================ FILE: flax/nnx/module.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import inspect import typing as tp import jax import jax.numpy as jnp from flax.nnx import ( filterlib, graphlib, pytreelib, ) from flax.nnx import variablelib as variableslib from flax.nnx.pytreelib import Pytree, PytreeMeta from flax.nnx.graphlib import GraphState from flax.nnx.statelib import split_state, State import functools as ft from flax.typing import Key, Path, PathParts from collections.abc import MutableMapping import warnings A = tp.TypeVar('A') B = tp.TypeVar('B') M = tp.TypeVar('M', bound='Module') S = tp.TypeVar('S', bound=tp.Union[GraphState, tuple[GraphState, ...]]) V = tp.TypeVar('V', bound=variableslib.Variable[tp.Any]) F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) StateMapping = tp.Mapping[Path, tp.Any] tuple_reduce = lambda xs, x: xs + (x,) tuple_init = lambda: () class ModuleMeta(PytreeMeta): # we keep a trivial derived class just in case we need to # add more functionality in the future pass class Module(Pytree, metaclass=ModuleMeta): """Base class for all neural network modules. Layers and models should subclass this class. ``Module``'s can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the ``__init__`` method. You can define arbitrary "forward pass" methods on your ``Module`` subclass. While no methods are special-cased, ``__call__`` is a popular choice since you can call the ``Module`` directly:: >>> from flax import nnx >>> import jax.numpy as jnp >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear1 = nnx.Linear(2, 3, rngs=rngs) ... self.linear2 = nnx.Linear(3, 4, rngs=rngs) ... def __call__(self, x): ... x = self.linear1(x) ... x = nnx.relu(x) ... x = self.linear2(x) ... return x >>> x = jnp.ones((1, 2)) >>> model = Model(rngs=nnx.Rngs(0)) >>> y = model(x) """ def sow( self, variable_type: type[variableslib.Variable[B]] | str, name: str, value: A, reduce_fn: tp.Callable[[B, A], B] = tuple_reduce, init_fn: tp.Callable[[], B] = tuple_init, # type: ignore ) -> bool: """Store intermediate values during module execution for later extraction. Used with :func:`nnx.capture` decorator to collect intermediate values without explicitly passing containers through module calls. Values are stored under the specified ``name`` in a collection associated with ``variable_type``. By default, values are appended to a tuple, allowing multiple values to be tracked when the same module is called multiple times. Example usage:: >>> from flax import nnx >>> import jax.numpy as jnp >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear1 = nnx.Linear(2, 3, rngs=rngs) ... self.linear2 = nnx.Linear(3, 4, rngs=rngs) ... def __call__(self, x): ... x = self.linear1(x) ... self.sow(nnx.Intermediate, 'features', x) ... x = self.linear2(x) ... return x >>> # With the capture decorator, sow returns intermediates >>> model = Model(rngs=nnx.Rngs(0)) >>> @nnx.capture(nnx.Intermediate) ... def forward(model, x): ... return model(x) >>> result, intermediates = forward(model, jnp.ones(2)) >>> assert 'features' in intermediates Custom init/reduce functions can be passed to control accumulation:: >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... def __call__(self, x): ... x = self.linear(x) ... self.sow(nnx.Intermediate, 'sum', x, ... init_fn=lambda: 0, ... reduce_fn=lambda prev, curr: prev+curr) ... return x Args: variable_type: The :class:`Variable` type for the stored value. Typically :class:`Intermediate` or a subclass is used. name: A string key for storing the value in the collection. value: The value to be stored. reduce_fn: Function to combine existing and new values. Default appends to a tuple. init_fn: Function providing initial value for first ``reduce_fn`` call. Default is an empty tuple. """ if isinstance(variable_type, str): variable_type = variableslib.variable_type_from_name( variable_type, allow_register=True ) if hasattr(self, '__captures__'): for var in self.__captures__: if type(var) == variable_type: if name in var: var[name] = reduce_fn(var[name], value) else: var[name] = reduce_fn(init_fn(), value) return True else: return False elif hasattr(self, name): variable = getattr(self, name) if not isinstance(variable, variableslib.Variable): raise ValueError( f"Expected '{name}' to be a Variable, got {type(variable).__name__}" ) elif type(variable) != variable_type: raise ValueError( f"Expected '{name}' to be of type '{variable_type.__name__}', " f"got '{type(variable).__name__}'" ) variable.set_value(reduce_fn(variable.get_value(), value)) else: reduced_value = reduce_fn(init_fn(), value) setattr(self, name, variable_type(reduced_value)) warnings.warn( """Using 'Module.sow()' outside of 'nnx.capture()' is deprecated; see https://flax.readthedocs.io/en/stable/capturing_intermediates.html for more information. """, DeprecationWarning, stacklevel=2, ) return True def perturb( self, name: str, value: tp.Any, variable_type: ( str | type[variableslib.Variable[tp.Any]] ) = variableslib.Perturbation, ): """Extract gradients of intermediate values during training. Used with :func:`nnx.capture` to record intermediate values in the forward pass and their gradients in the backward pass. Returns the value plus whatever perturbation is stored under ``name`` in the current capture context, allowing gradient computation via ``nnx.grad``. The workflow has four steps: 1. Initialize perturbations with ``nnx.capture(model, nnx.Perturbation)`` 2. Run model with ``nnx.capture(model, nnx.Intermediate, init=perturbations)`` 3. Take gradients with respect to perturbations using ``nnx.grad`` 4. Combine results with ``nnx.merge_state(perturb_grads, intermediates)`` .. note:: This creates extra variables of the same size as ``value``, thus occupies more memory. Use it only to debug gradients in training. Example usage:: >>> from flax import nnx >>> import jax.numpy as jnp >>> class Model(nnx.Module): ... def __call__(self, x): ... x2 = self.perturb('grad_of_x', x) ... return 3 * x2 >>> model = Model() >>> x = 1.0 >>> # Step 1: Initialize perturbations >>> forward = nnx.capture(model, nnx.Perturbation) >>> _, perturbations = forward(x) >>> # Steps 2-4: Capture gradients >>> def train_step(model, perturbations, x): ... def loss(model, perturbations, x): ... return nnx.capture(model, nnx.Intermediate, init=perturbations)(x) ... (grads, perturb_grads), sowed = nnx.grad(loss, argnums=(0, 1), has_aux=True)(model, perturbations, x) ... return nnx.merge_state(perturb_grads, sowed) >>> metrics = train_step(model, perturbations, x) >>> # metrics contains gradients of intermediate values Args: name: A string key for storing the perturbation value. value: The intermediate value to capture gradients for. You must use the returned value (not the original) for gradient capturing to work. variable_type: The :class:`Variable` type for the stored perturbation. Default is :class:`nnx.Perturbation`. """ if isinstance(variable_type, str): variable_type = variableslib.variable_type_from_name( variable_type, allow_register=True ) if hasattr(self, '__captures__'): for var in self.__captures__: if type(var) == variable_type: if name not in var: zeros = jax.tree.map(jnp.zeros_like, value) var[name] = zeros old_value = var[name] return old_value + value else: return value elif hasattr(self, name): var = getattr(self, name) if not isinstance(var, variable_type): raise ValueError( f"Expected '{name}' to be of type '{variable_type.__name__}', " f"got '{type(var).__name__}'" ) old_value = var.get_value() else: old_value = jax.tree.map(jnp.zeros_like, value) setattr(self, name, variable_type(old_value)) warnings.warn(""" Using 'Module.perturb()' outside of 'nnx.capture()' is deprecated; see https://flax.readthedocs.io/en/stable/capturing_intermediates.html for more information. """, DeprecationWarning, stacklevel=2) return old_value + value def iter_modules(self) -> tp.Iterator[tuple[PathParts, Module]]: """ Warning: this method is method is deprecated; use :func:`iter_modules` instead. Recursively iterates over all nested :class:`Module`'s of the current Module, including the current Module. Alias of :func:`iter_modules`. """ warnings.warn( "The 'm.iter_modules()' method is deprecated; use the 'nnx.iter_modules(m)' function instead.", DeprecationWarning, stacklevel=2, ) yield from iter_modules(self) def iter_children(self) -> tp.Iterator[tuple[Key, Module]]: """ Warning: this method is method is deprecated; use :func:`iter_children` instead. Iterates over all children :class:`Module`'s of the current Module. This method is similar to :func:`iter_modules`, except it only iterates over the immediate children, and does not recurse further down. Alias of :func:`iter_children`. """ warnings.warn( "The 'm.iter_children()' method is deprecated; use the 'nnx.iter_children(m)' function instead.", DeprecationWarning, stacklevel=2, ) yield from iter_children(self) def set_attributes( self, *filters: filterlib.Filter, raise_if_not_found: bool = True, graph: bool | None = None, **attributes: tp.Any, ) -> None: """Sets the attributes of nested Modules including the current Module. If the attribute is not found in the Module, it is ignored. Example:: >>> from flax import nnx ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): ... self.linear = nnx.Linear(din, dout, rngs=rngs) ... self.dropout = nnx.Dropout(0.5, deterministic=False) ... self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs) ... >>> block = Block(2, 5, rngs=nnx.Rngs(0)) >>> block.dropout.deterministic, block.batch_norm.use_running_average (False, False) >>> block.set_attributes(deterministic=True, use_running_average=True) >>> block.dropout.deterministic, block.batch_norm.use_running_average (True, True) ``Filter``'s can be used to set the attributes of specific Modules:: >>> block = Block(2, 5, rngs=nnx.Rngs(0)) >>> block.set_attributes(nnx.Dropout, deterministic=True) >>> # Only the dropout will be modified >>> block.dropout.deterministic, block.batch_norm.use_running_average (True, False) Args: *filters: Filters to select the Modules to set the attributes of. raise_if_not_found: If True (default), raises a ValueError if at least one attribute instance is not found in one of the selected Modules. **attributes: The attributes to set. """ remaining_attributes = set(attributes.keys()) if not filters: filters = (True,) predicates = tuple(map(filterlib.to_predicate, filters)) for path, module in iter_modules(self, graph=graph): for predicate in predicates: if predicate(path, module): for name, value in attributes.items(): if hasattr(module, name): if name in remaining_attributes: remaining_attributes.remove(name) setattr(module, name, value) break if remaining_attributes and raise_if_not_found: raise ValueError( 'Could not find at least one instance of the following' f' attributes: {sorted(remaining_attributes)}' ) def train(self, **attributes): """Sets the Module to training mode. ``train`` uses ``set_attributes`` to recursively set attributes ``deterministic=False`` and ``use_running_average=False`` of all nested Modules that have these attributes. Its primarily used to control the runtime behavior of the ``Dropout`` and ``BatchNorm`` Modules. Example:: >>> from flax import nnx ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): ... self.linear = nnx.Linear(din, dout, rngs=rngs) ... # initialize Dropout and BatchNorm in eval mode ... self.dropout = nnx.Dropout(0.5, deterministic=True) ... self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs) ... >>> block = Block(2, 5, rngs=nnx.Rngs(0)) >>> block.dropout.deterministic, block.batch_norm.use_running_average (True, True) >>> block.train() >>> block.dropout.deterministic, block.batch_norm.use_running_average (False, False) Args: **attributes: additional attributes passed to ``set_attributes``. """ return self.set_attributes( deterministic=False, use_running_average=False, **attributes, raise_if_not_found=False, ) def eval(self, **attributes): """Sets the Module to evaluation mode. ``eval`` uses ``set_attributes`` to recursively set attributes ``deterministic=True`` and ``use_running_average=True`` of all nested Modules that have these attributes. Its primarily used to control the runtime behavior of the ``Dropout`` and ``BatchNorm`` Modules. Example:: >>> from flax import nnx ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): ... self.linear = nnx.Linear(din, dout, rngs=rngs) ... self.dropout = nnx.Dropout(0.5) ... self.batch_norm = nnx.BatchNorm(10, rngs=rngs) ... >>> block = Block(2, 5, rngs=nnx.Rngs(0)) >>> block.dropout.deterministic, block.batch_norm.use_running_average (False, False) >>> block.eval() >>> block.dropout.deterministic, block.batch_norm.use_running_average (True, True) Args: **attributes: additional attributes passed to ``set_attributes``. """ return self.set_attributes( deterministic=True, use_running_average=True, **attributes, raise_if_not_found=False, ) def view(node: A, /, *, only: filterlib.Filter = ..., raise_if_not_found: bool = True, graph: bool | None = None, **kwargs) -> A: """Creates a new node with static attributes updated according to ``**kwargs``. The new node contains references to jax arrays in the original node. If a kwarg is not found in any module, this method raises a ValueError. Uses the ``set_view`` class method in nnx.Modules. ``set_view`` class methods should return any unused kwargs. Example:: >>> from flax import nnx ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): ... self.linear = nnx.Linear(din, dout, rngs=rngs) ... self.dropout = nnx.Dropout(0.5, deterministic=False) ... self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs) ... >>> block = Block(2, 5, rngs=nnx.Rngs(0)) >>> block.dropout.deterministic, block.batch_norm.use_running_average (False, False) >>> new_block = nnx.view(block, deterministic=True, use_running_average=True) >>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average (True, True) ``Filter``'s can be used to set the attributes of specific Modules:: >>> block = Block(2, 5, rngs=nnx.Rngs(0)) >>> new_block = nnx.view(block, only=nnx.Dropout, deterministic=True) >>> # Only the dropout will be modified >>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average (True, False) Args: node: the object to create a copy of. only: Filters to select the Modules to set the attributes of. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. **kwargs: The attributes to set. """ predicate = filterlib.to_predicate(only) remaining = set(kwargs) def _set_mode_fn(path, node): if hasattr(node, 'set_view') and predicate(path, node): sig = inspect.signature(node.set_view) has_var_keyword = any( p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() ) if has_var_keyword: node.set_view(**kwargs) remaining.clear() else: named_params = { name for name, p in sig.parameters.items() if p.kind in ( inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY, ) } filtered_kwargs = { k: v for k, v in kwargs.items() if k in named_params } node.set_view(**filtered_kwargs) remaining.difference_update(named_params) return node out = graphlib.recursive_map(_set_mode_fn, node, graph=graph) if raise_if_not_found and remaining: raise ValueError(f"Unused keys found in nnx.view: {sorted(remaining)}") return out def with_attributes( node: A, /, *, only: filterlib.Filter = ..., raise_if_not_found: bool = True, graph: bool | None = None, **attributes: tp.Any, ) -> A: """Creates a new node with attributes updated according to ``**attributes``. The new node contains references to jax arrays in the original node. Unlike ``set_attributes``, this function does not modify the original node. Example:: >>> from flax import nnx ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): ... self.linear = nnx.Linear(din, dout, rngs=rngs) ... self.dropout = nnx.Dropout(0.5, deterministic=False) ... self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs) ... >>> block = Block(2, 5, rngs=nnx.Rngs(0)) >>> block.dropout.deterministic, block.batch_norm.use_running_average (False, False) >>> new_block = nnx.with_attributes(block, deterministic=True, use_running_average=True) >>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average (True, True) >>> block.dropout.deterministic, block.batch_norm.use_running_average (False, False) ``Filter``'s can be used to set the attributes of specific Modules:: >>> block = Block(2, 5, rngs=nnx.Rngs(0)) >>> new_block = nnx.with_attributes(block, only=nnx.Dropout, deterministic=True) >>> # Only the dropout will be modified >>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average (True, False) Args: node: the object to create a copy of. only: Filters to select the Modules to set the attributes of. raise_if_not_found: If True (default), raises a ValueError if at least one attribute instance is not found in one of the selected Modules. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. **attributes: The attributes to set. """ predicate = filterlib.to_predicate(only) remaining_attributes = set(attributes.keys()) def _set_attributes_fn(path, node): if isinstance(node, Module) and predicate(path, node): for name, value in attributes.items(): if hasattr(node, name): setattr(node, name, value) remaining_attributes.discard(name) return node out = graphlib.recursive_map(_set_attributes_fn, node, graph=graph) if remaining_attributes and raise_if_not_found: raise ValueError( 'Could not find at least one instance of the ' f'following attributes: {sorted(remaining_attributes)}' ) return out def _parse_docstring_args(doc_str: str) -> dict[str, str]: """Parses parameters from `Args:` section of a function docstring. Assumes Google style docstrings. Returns a dictionary with keys representing argument names and values representing descriptions. Each description has lines starting with 4 spaces. """ lines = doc_str.split("\n") # Get lines with the parameter names inds = [i for i, l in enumerate(lines) if l.startswith(" ") and not l.startswith(" ")] inds.append(len(lines)) out = dict() # Parse each argument for i in range(len(inds)-1): start, end = inds[i], inds[i+1] # Process first line for the description first_colon = lines[start].find(":") name = lines[start][:first_colon].strip() desc = [" "*4 + lines[start][first_colon+1:].strip()] # Append remaining description lines for j in range(start+1, end): desc.append(lines[j]) out[name] = "\n".join(desc) return out def view_info(node: Module, /, *, only: filterlib.Filter = ..., graph: bool | None = None) -> str: """Provides information about the ``view`` arguments for a module and all submodules. If no docstring is provided for a module's `set_view`, this function puts the `set_view` signature below the function. Example:: >>> from flax import nnx ... >>> class CustomModel(nnx.Module): ... def __init__(self, *, rngs): ... self.mha = nnx.MultiHeadAttention(4, 8, 32, rngs=rngs) ... self.drop = nnx.Dropout(0.5, rngs=rngs) ... self.bn = nnx.BatchNorm(32, rngs=rngs) ... >>> model = CustomModel(rngs=nnx.Rngs(0)) >>> print(nnx.view_info(model)) BatchNorm: use_running_average: bool | None = None if True, the stored batch statistics will be used instead of computing the batch statistics on the input. Dropout: deterministic: bool | None = None if True, disables dropout masking. MultiHeadAttention: deterministic: bool | None = None if True, the module is set to deterministic mode. decode: bool | None = None if True, the module is set to decode mode. batch_size: int | Shape | None = None the batch size to use for the cache. max_length: int | None = None the max length to use for the cache. Args: node: the object to display ``view`` information for. only: Filters to select the Modules to display information for. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. """ predicate = filterlib.to_predicate(only) classes: set[Module] = set() def _set_mode_info_fn(path, node): if hasattr(node, 'set_view') and predicate(path, node): classes.add(node.__class__) return node graphlib.recursive_map(_set_mode_info_fn, node, graph=graph) class_list = sorted(list(classes), key=lambda x: x.__qualname__) out_str = [] for c in class_list: out_str.append(f"{c.__qualname__}:") sig = inspect.signature(c.set_view) doc = inspect.getdoc(c.set_view) # Parse docstring if isinstance(doc, str): start, end = doc.find("Args:\n"), doc.find("Returns:\n") if end == -1: end = len(doc) doc = doc[start+6:end] parsed_docstring = _parse_docstring_args(doc) # Generate output from signature and docstring skip_names = {"self", "args", "kwargs"} for name, param in sig.parameters.items(): if name in skip_names: continue if param.default is inspect.Parameter.empty: out_str.append(f" {name}: {param.annotation}") else: out_str.append(f" {name}: {param.annotation} = {param.default}") out_str.append(parsed_docstring[name]) else: out_str.append(f" set_view{sig}") return "\n".join(out_str) def first_from(*args: tp.Optional[A], error_msg: str) -> A: """Return the first non-None argument. If all arguments are None, raise a ValueError with the given error message. Args: *args: the arguments to check error_msg: the error message to raise if all arguments are None Returns: The first non-None argument. """ for arg in args: if arg is not None: return arg raise ValueError(error_msg) def iter_modules( module: Module, /, *, graph: bool | None = None, ) -> tp.Iterator[tuple[PathParts, Module]]: """Recursively iterates over all nested :class:`Module`'s of the given Module, including the argument. Specifically, this function creates a generator that yields the path and the Module instance, where the path is a tuple of strings or integers representing the path to the Module from the root Module. Example:: >>> from flax import nnx ... >>> class SubModule(nnx.Module): ... def __init__(self, din, dout, rngs): ... self.linear1 = nnx.Linear(din, dout, rngs=rngs) ... self.linear2 = nnx.Linear(din, dout, rngs=rngs) ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): ... self.linear = nnx.Linear(din, dout, rngs=rngs) ... self.submodule = SubModule(din, dout, rngs=rngs) ... self.dropout = nnx.Dropout(0.5) ... self.batch_norm = nnx.BatchNorm(10, rngs=rngs) ... >>> model = Block(2, 5, rngs=nnx.Rngs(0)) >>> for path, module in nnx.iter_modules(model): ... print(path, type(module).__name__) ... ('batch_norm',) BatchNorm ('dropout',) Dropout ('linear',) Linear ('submodule', 'linear1') Linear ('submodule', 'linear2') Linear ('submodule',) SubModule () Block Args: module: A :class:`Module` object. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. """ for path, value in graphlib.iter_graph(module, graph=graph): if isinstance(value, Module): yield path, value iter_children = graphlib.iter_children P = tp.ParamSpec("P") R = tp.TypeVar("R") @tp.overload def capture( fn: tp.Callable[P, R], *var_types: type[variableslib.Variable], init: tp.Optional[State] = None, method_outputs: tp.Optional[type[variableslib.Variable]] = None ) -> tp.Callable[P, tuple[R, State]]: ... @tp.overload def capture( fn: type[variableslib.Variable], *var_types: type[variableslib.Variable], init: tp.Optional[State] = None, method_outputs: tp.Optional[type[variableslib.Variable]] = None ) -> tp.Callable[[tp.Callable[P, R]], tp.Callable[P, tuple[R, State]]]: ... def capture(fn: tp.Callable[P, R] | type[variableslib.Variable], *var_types: type[variableslib.Variable], init : tp.Optional[State] = None, method_outputs : tp.Optional[type[variableslib.Variable]] = None ) -> tp.Callable[P, tuple[R, State]] | tp.Callable[[tp.Callable[P, R]], tp.Callable[P, tuple[R, State]]]: """Wraps a function to capture intermediate values from a module during execution. This function wraps a `Callable`, executing it while collecting intermediate values that were stored using ``Module.sow()`` or ``Module.perturb()``. The `fn` argument can be either a function, a Module instance, or a bound method. If `fn` is a function, its first argument should be the module in which intermediate values are to be recorded. If `fn` is a bound method, the module used for storage is inferred from the instance. If `fn` is a Module, its `__call__` method will be wrapped. Args: fn: The `Callable` to wrap. var_types: Variable types to capture. If None, defaults to []. init: MutableMapping used to initialize perturbation values. This is useful for gradient extraction. method_outputs: If provided, automatically sows the output of each method in the module and its submodules using this variable type. Returns: A wrapped function that returns a tuple of (result, *intermediates) where result is the output of the function and each intermediate is a State containing the captured values with the corresponding type in `var_types`. Example with manual sowing:: class Foo(nnx.Module): def __call__(self, x): self.sow(nnx.Intermediate, 'features', x) return x model = Foo(rngs=nnx.Rngs(0)) forward = nnx.capture(model, nnx.Intermediate) result, intermediates = forward(x) # intermediates['features'] contains the sowed value Example with method outputs:: class Foo(nnx.Module): def features(self, x): return x def classifier(self, x): return x def __call__(self, x): return self.classifier(self.features(x)) model = Foo(rngs=nnx.Rngs(0)) result, intermediates = nnx.capture( model, method_output_type=nnx.Intermediate)(x) # intermediates contains outputs of features(), classifier(), and __call__() Example with gradient extraction:: class Model(nnx.Module): def __call__(self, x): x2 = self.perturb('grad_of_x', x) return 3 * x2 model = Model() forward = nnx.capture(lambda model, x: model(x), nnx.Perturbation) # Initialize perturbations _, perturbations = forward_capture(model, x) # Compute gradients with respect to perturbations loss = nnx.capture(forward, init=perturbations) grads, sowed = nnx.grad(loss, has_aux=True)(model, perturbations, x) """ # Handle partial evaluation when first arg is a Variable type if isinstance(fn, type) and issubclass(fn, variableslib.Variable): # Partial application: return a function that waits for the actual fn all_var_types = (fn,) + var_types def partial_capture(actual_fn: tp.Callable[P, R] | Module) -> tp.Callable[P, tuple[R, State]]: return capture(actual_fn, *all_var_types, init=init, method_outputs=method_outputs) return partial_capture # Handle bound methods and callable Modules module_instance = None if inspect.ismethod(fn) and isinstance(fn.__self__, Module): module_instance = fn.__self__ elif isinstance(fn, Module): module_instance = fn ft.wraps(fn) def wrapper(*fn_args, **kwargs): if module_instance is None: module = fn_args[0] else: module = module_instance # Extract initial values from state state_by_path = _collect_state_by_path(init) if init else {} # Initialize __captures__ as a tuple of Variables (one per type) for path, m in iter_modules(module): # Create initial dicts for each variable type initial_dicts = {} for var_type in var_types: initial_dicts[var_type] = {} # Populate from state if available if path in state_by_path: for name, var in state_by_path[path].items(): var_type = type(var) if var_type not in initial_dicts: initial_dicts[var_type] = {} initial_dicts[var_type][name] = var.get_value() # Create the captures tuple captures_tuple = tuple(k(v) for (k,v) in initial_dicts.items()) m.__captures__ = pytreelib.data(captures_tuple) # Wrap methods with capturing if required if method_outputs: for _, m in iter_modules(module): _add_capturing(type(m), method_outputs) try: result = fn(*fn_args, **kwargs) finally: # Undo method sowing modification for _, m in iter_modules(module): _remove_capturing(type(m)) # Extract intermediates manually from __captures__ interms = State({}) _extract_captures(module, interms, set(var_types)) if len(var_types) == 0: return result split_states = split_state(interms, *var_types) if len(var_types) == 1: return result, split_states else: return (result, *split_states) return wrapper def _collect_state_by_path(state): """Build a mapping from module path to state Variables.""" state_by_path = {} def collect(s, path_parts): if isinstance(s, MutableMapping): for key, value in s.items(): if isinstance(value, variableslib.Variable): path_tuple = tuple(path_parts) if path_tuple not in state_by_path: state_by_path[path_tuple] = {} state_by_path[path_tuple][key] = value elif isinstance(value, MutableMapping): collect(value, path_parts + [key]) collect(state, []) return state_by_path def _navigate_to_path(state, path): """Navigate to a nested path in state, creating dicts as needed.""" current = state for part in path: if part not in current: current[part] = State({}) current = current[part] return current def _extract_captures(module, state, var_types): """Extract intermediates from __captures__ tuple into state dict.""" for path, mod in iter_modules(module): if hasattr(mod, '__captures__'): captures_tuple = mod.__captures__ for var in captures_tuple: if not type(var) in var_types: continue current = _navigate_to_path(state, path) for key, value in var.items(): current[key] = type(var)(value) delattr(mod, '__captures__') def _add_capturing(cls, variable_type): """Adds capturing to methods of a Module. Does not instrument superclass methods.""" for name, method in cls.__dict__.items(): if callable(method) and (not name.startswith('_') or name == '__call__'): if not hasattr(method, '_does_capturing'): def closure(name, method): # Necessary to make 'name' immutable during iteration @ft.wraps(method) def wrapper(self, *args, **kwargs): result = method(self, *args, **kwargs) self.sow(variable_type, name, result) return result wrapper._does_capturing = True setattr(cls, name, wrapper) closure(name, method) return cls def _remove_capturing(cls): """Remove capturing methods from a Module.""" for name, method in cls.__dict__.items(): if hasattr(method, '_does_capturing'): setattr(cls, name, method.__wrapped__) return cls ================================================ FILE: flax/nnx/nn/__init__.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: flax/nnx/nn/activations.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import typing as tp from types import MappingProxyType from jax.nn import ( celu, elu, gelu, glu, hard_sigmoid, hard_silu, hard_swish, hard_tanh, leaky_relu, log_sigmoid, log_softmax, logsumexp, one_hot, relu, identity, relu6, selu, sigmoid, silu, soft_sign, softmax, softplus, standardize, swish, ) import jax.numpy as jnp from jax.numpy import tanh from flax import nnx from flax.nnx.nn import dtypes from flax.typing import Array, Dtype, PromoteDtypeFn __all__ = [ 'celu', 'elu', 'gelu', 'glu', 'hard_sigmoid', 'hard_silu', 'hard_swish', 'hard_tanh', 'leaky_relu', 'log_sigmoid', 'log_softmax', 'logsumexp', 'one_hot', 'relu', 'identity', 'relu6', 'selu', 'sigmoid', 'silu', 'soft_sign', 'softmax', 'softplus', 'standardize', 'swish', 'tanh', 'PReLU', ] class PReLU(nnx.Module): """Parametric Rectified Linear Unit (PReLU) activation function. Note that PReLU is a Flax layer and not a simple activation function, so it needs to be initialized before being called. Example:: >>> import flax.nnx as nnx >>> class MLP(nnx.Module): ... def __init__(self): ... self.linear = nnx.Linear(3, 2) ... self.act = nnx.PReLU(negative_slope_init=0.1) ... ... def __call__(self, x): ... x = self.linear(x) ... x = self.act(x) ... return x Args: negative_slope_init: the value to initialize the negative slope (default 0.01). dtype: the dtype of the computation (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). promote_dtype: function to promote the dtype of all input array arguments (including Variables accessed through ``self``) to the desired dtype. The function should accept a tuple of ``(inputs, negative_slope)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. negative_slope_metadata: Optional metadata dictionary to set when initializing the negative slope. """ def __init__( self, negative_slope_init: float = 0.01, *, dtype: Dtype | None = None, param_dtype: Dtype = jnp.float32, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, negative_slope_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), ): self.negative_slope = nnx.Param( jnp.asarray(negative_slope_init, dtype=param_dtype), **negative_slope_metadata ) self.dtype = dtype self.param_dtype = param_dtype self.promote_dtype = promote_dtype def __call__(self, inputs: Array) -> Array: negative_slope = self.negative_slope[...] if self.dtype is not None: inputs, negative_slope = self.promote_dtype( (inputs, negative_slope), dtype=self.dtype ) else: # Match Linen behavior: cast parameter to input dtype negative_slope = jnp.asarray(negative_slope, inputs.dtype) return jnp.where( inputs >= 0, inputs, negative_slope * inputs, ) ================================================ FILE: flax/nnx/nn/attention.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Attention core modules for Flax.""" from __future__ import annotations import functools from typing import Any from collections.abc import Mapping from types import MappingProxyType from collections.abc import Callable import math import jax import jax.numpy as jnp from jax import lax, random from flax import nnx from flax.nnx import rnglib from flax.nnx.module import Module, first_from from flax.nnx.nn import initializers from flax.nnx.nn import dtypes from flax.nnx.nn.linear import ( LinearGeneral, default_kernel_init, ) from flax.nnx.nn.normalization import LayerNorm from flax.typing import ( Dtype, PromoteDtypeFn, Shape, Initializer, PrecisionLike, DotGeneralT, ) Array = jax.Array def dot_product_attention_weights( query: Array, key: Array, bias: Array | None = None, mask: Array | None = None, broadcast_dropout: bool = True, dropout_rng: Array | None = None, dropout_rate: float = 0.0, deterministic: bool = False, dtype: Dtype | None = None, precision: PrecisionLike = None, module: Module | None = None, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, is_causal: bool = False, ): """Computes dot-product attention weights given query and key. Used by :func:`dot_product_attention`, which is what you'll most likely use. But if you want access to the attention weights for introspection, then you can directly call this function and call einsum yourself. Args: query: queries for calculating attention with shape of `[batch..., q_length, num_heads, qk_depth_per_head]`. key: keys for calculating attention with shape of `[batch..., kv_length, num_heads, qk_depth_per_head]`. bias: bias for the attention weights. This should be broadcastable to the shape `[batch..., num_heads, q_length, kv_length]`. This can be used for incorporating causal masks, padding masks, proximity bias, etc. mask: mask for the attention weights. This should be broadcastable to the shape `[batch..., num_heads, q_length, kv_length]`. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is `False`. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout dropout_rate: dropout rate deterministic: bool, deterministic or not (to apply dropout) dtype: the dtype of the computation (default: infer from inputs and params) precision: numerical precision of the computation see `jax.lax.Precision` for details. module: the Module that will sow the attention weights into the ``nnx.Intermediate`` collection. If ``module`` is None, the attention weights will not be sowed. promote_dtype: function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of ``(query, key)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. is_causal: If true, causal attention will be applied. Note, some implementations like xla will generate a mask tensor and apply it to the logits to mask out the non-causal parts of the attention matrix, but other implementations like cudnn will avoid computing the non-causal regions, providing speedups. Returns: Output of shape `[batch..., num_heads, q_length, kv_length]`. """ query, key = promote_dtype((query, key), dtype=dtype) # type: ignore[bad-unpacking] dtype = query.dtype assert query.ndim == key.ndim, 'q, k must have same rank.' assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.' assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' # check if we need to broadcast Key heads to match Query heads is_gqa = False if query.shape[-2] != key.shape[-2]: q_heads = query.shape[-2] k_heads = key.shape[-2] if q_heads % k_heads != 0: raise ValueError( f"Query heads ({q_heads}) must be multiple of " f"Key heads ({k_heads}) for Grouped Query Attention." ) n_rep = q_heads // k_heads is_gqa = True # Reshape Query: [..., Q, H_k * n_rep, D] -> [..., Q, H_k, n_rep, D] query = query.reshape(query.shape[:-2] + (k_heads, n_rep, query.shape[-1])) # Expand Key: [..., K, H_k, D] -> [..., K, H_k, 1, D] key = jnp.expand_dims(key, axis=-2) # Contract: q(h)gd, k(h)1d -> hgqk (h=H_k, g=n_rep) einsum_str = '...qhgd,...kh1d->...hgqk' else: q_heads = query.shape[-2] einsum_str = '...qhd,...khd->...hqk' assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.' # calculate attention matrix depth = query.shape[-1] query = query / jnp.sqrt(depth).astype(dtype) # attn weight shape is (batch..., num_heads, q_length, kv_length) attn_weights = jnp.einsum(einsum_str, query, key, precision=precision) if is_gqa: attn_weights = attn_weights.reshape(attn_weights.shape[:-4] + (q_heads, attn_weights.shape[-2], attn_weights.shape[-1])) # apply attention bias: masking, dropout, proximity bias, etc. if bias is not None: attn_weights = attn_weights + bias # apply attention mask if mask is not None or is_causal: big_neg = jnp.finfo(dtype).min masks = [m for m in [mask] if m is not None] if is_causal: T, S = attn_weights.shape[-2:] causal_mask = jnp.tril(jnp.ones((T, S), dtype=dtype)) target_shape = mask.shape if mask is not None else attn_weights.shape masks.append(jnp.broadcast_to(causal_mask, target_shape)) combined_mask = combine_masks(*masks, dtype=dtype) assert combined_mask is not None attn_weights = jnp.where(combined_mask, attn_weights, big_neg) # normalize the attention weights attn_weights = jax.nn.softmax(attn_weights).astype(dtype) if module: module.sow(nnx.Intermediate, 'attention_weights', attn_weights) # apply attention dropout if not deterministic and dropout_rate > 0.0: keep_prob = 1.0 - dropout_rate # use original key.ndim because we might have expanded key dim ndim_base = key.ndim - 1 if is_gqa else key.ndim if broadcast_dropout: # dropout is broadcast across the batch + head dimensions dropout_shape = tuple([1] * (ndim_base - 2)) + attn_weights.shape[-2:] keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) # type: ignore else: keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) # type: ignore multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype) attn_weights = attn_weights * multiplier return attn_weights def dot_product_attention( query: Array, key: Array, value: Array, bias: Array | None = None, mask: Array | None = None, broadcast_dropout: bool = True, dropout_rng: Array | None = None, dropout_rate: float = 0.0, deterministic: bool = False, dtype: Dtype | None = None, precision: PrecisionLike = None, module: Module | None = None, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, is_causal: bool = False, ): """Computes dot-product attention given query, key, and value. This is the core function for applying attention based on https://arxiv.org/abs/1706.03762. It calculates the attention weights given query and key and combines the values using the attention weights. Will use the more optimized `jax.nn.dot_product_attention` if dropout is not activated and `module=None`. .. note:: ``query``, ``key``, ``value`` needn't have any batch dimensions. Args: query: queries for calculating attention with shape of ``[batch..., q_length, num_heads, qk_depth_per_head]``. key: keys for calculating attention with shape of ``[batch..., kv_length, num_heads, qk_depth_per_head]``. value: values to be used in attention with shape of ``[batch..., kv_length, num_heads, v_depth_per_head]``. bias: bias for the attention weights. This should be broadcastable to the shape `[batch..., num_heads, q_length, kv_length]`. This can be used for incorporating causal masks, padding masks, proximity bias, etc. mask: mask for the attention weights. This should be broadcastable to the shape `[batch..., num_heads, q_length, kv_length]`. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is `False`. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout dropout_rate: dropout rate deterministic: bool, deterministic or not (to apply dropout) dtype: the dtype of the computation (default: infer from inputs) precision: numerical precision of the computation see `jax.lax.Precision` for details. module: the Module that will sow the attention weights into the ``nnx.Intermediate`` collection. If ``module`` is None, the attention weights will not be sowed. promote_dtype: function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of ``(query, key, value)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. is_causal: If true, causal attention will be applied. Note, some implementations like xla will generate a mask tensor and apply it to the logits to mask out the non-causal parts of the attention matrix, but other implementations like cudnn will avoid computing the non-causal regions, providing speedups. Returns: Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`. """ query, key, value = promote_dtype((query, key, value), dtype=dtype) # type: ignore[bad-unpacking] dtype = query.dtype assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' assert ( query.shape[:-3] == key.shape[:-3] == value.shape[:-3] ), 'q, k, v batch dims must match.' assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' # Criteria that invoke the more optimized dot product attention if dropout_rate == 0.0 and module is None: # make sure qkv batch are compressed to one dim query_shape = query.shape if len(query_shape) > 4: def reshape_4d(x): return jnp.reshape(x, (math.prod(x.shape[:-3]), *x.shape[-3:])) query, key, value, bias, mask = jax.tree.map( reshape_4d, (query, key, value, bias, mask)) if mask is not None: mask = mask.astype(jnp.bool) out = jax.nn.dot_product_attention(query, key, value, bias, mask, is_causal=is_causal) if len(query_shape) > 4: out = jnp.reshape(out, query_shape) return out # compute attention weights attn_weights = dot_product_attention_weights( query, key, bias, mask, broadcast_dropout, dropout_rng, dropout_rate, deterministic, dtype, precision, module, promote_dtype, is_causal, ) # return weighted sum over values for each query position # check if need to broadcast Value heads to match Query heads (GQA) if attn_weights.shape[-3] != value.shape[-2]: q_heads = attn_weights.shape[-3] v_heads = value.shape[-2] if q_heads % v_heads != 0: raise ValueError(f"Query heads ({q_heads}) must be multiple of Value heads ({v_heads})") n_rep = q_heads // v_heads # Reshape weights: [..., H_v, n_rep, Q, K] attn_weights = attn_weights.reshape(attn_weights.shape[:-3] + (v_heads, n_rep) + attn_weights.shape[-2:]) # Expand Value: [..., K, H_v, 1, D] value = jnp.expand_dims(value, axis=-2) # Contract: hgqk, kh1d -> qhgd (h=H_v, g=n_rep) out = jnp.einsum('...hgqk,...kh1d->...qhgd', attn_weights, value, precision=precision) # Flatten: [..., Q, H_q, D] out = out.reshape(out.shape[:-3] + (q_heads, out.shape[-1])) else: out = jnp.einsum( '...hqk,...khd->...qhd', attn_weights, value, precision=precision ) return out class MultiHeadAttention(Module): """Multi-head attention. Example usage:: >>> from flax import nnx >>> import jax >>> layer = nnx.MultiHeadAttention(num_heads=8, in_features=5, qkv_features=16, ... decode=False, rngs=nnx.Rngs(0)) >>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3) >>> shape = (4, 3, 2, 5) >>> q, k, v = ( ... jax.random.uniform(key1, shape), ... jax.random.uniform(key2, shape), ... jax.random.uniform(key3, shape), ... ) >>> # different inputs for inputs_q, inputs_k and inputs_v >>> out = layer(q, k, v) >>> # equivalent output when inferring v >>> assert (layer(q, k) == layer(q, k, k)).all() >>> # equivalent output when inferring k and v >>> assert (layer(q) == layer(q, q)).all() >>> assert (layer(q) == layer(q, q, q)).all() Args: num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) should be divisible by the number of heads. in_features: int or tuple with number of input features. qkv_features: dimension of the key, query, and value. out_features: dimension of the last projection. in_kv_features: number of input features for computing key and value. num_kv_heads: number of key and value heads. If None, it defaults to ``num_heads``. If set to a value smaller than ``num_heads``, Grouped Query Attention (GQA) is used. ``num_heads`` must be divisible by ``num_kv_heads``. dtype: the dtype of the computation (default: infer from inputs and params) param_dtype: the dtype passed to parameter initializers (default: float32) broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rate: dropout rate deterministic: if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic. precision: numerical precision of the computation see `jax.lax.Precision` for details. kernel_init: initializer for the kernel of the Dense layers. out_kernel_init: optional initializer for the kernel of the output Dense layer, if None, the kernel_init is used. bias_init: initializer for the bias of the Dense layers. out_bias_init: optional initializer for the bias of the output Dense layer, if None, the bias_init is used. use_bias: bool: whether pointwise QKVO dense transforms use bias. attention_fn: dot_product_attention or compatible function. Accepts query, key, value, and returns output of shape `[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`` decode: whether to prepare and use an autoregressive cache. normalize_qk: should QK normalization be applied (arxiv.org/abs/2302.05442). qkv_promote_dtype: function to promote the dtype of all input array arguments (including Variables accessed through ``self``) to the desired dtype for the query, key, and value LinearGeneral submodules. out_promote_dtype: function to promote the dtype of all input array arguments (including Variables accessed through ``self``) to the desired dtype for the output LinearGeneral submodule. ln_promote_dtype: function to promote the dtype of all input array arguments (including Variables accessed through ``self``) to the desired dtype for the LayerNorm submodules (query_ln and key_ln) when normalize_qk=True. rngs: rng key. keep_rngs: whether to store the input rngs as attribute (i.e. `self.rngs = rngs`) (default: True). If rngs is stored, we should split the module as `graphdef, params, nondiff = nnx.split(module, nnx.Param, ...)` where `nondiff` contains RNG object associated with stored `self.rngs`. kernel_metadata: Optional metadata dictionary to set when initializing the Dense layers. out_kernel_metadata: Optional metadata dictionary to set when initializing the output Dense layers. If None, the kernel_metadata is used. bias_metadata: Optional metadata dictionary to set when initializing the bias of the Dense layers. out_bias_metadata: Optional metadata dictionary to set when initializing the bias of the output Dense layers. If None, the bias_metadata is used. query_ln_scale_metadata: Optional metadata dictionary to set when initializing the scale of the query layer norm layer. key_ln_scale_metadata: Optional metadata dictionary to set when initializing the scale of the key layer norm layer. """ def __init__( self, num_heads: int, in_features: int, qkv_features: int | None = None, out_features: int | None = None, num_kv_heads: int | None = None, in_kv_features: int | None = None, *, dtype: Dtype | None = None, param_dtype: Dtype = jnp.float32, broadcast_dropout: bool = True, dropout_rate: float = 0.0, deterministic: bool | None = None, precision: PrecisionLike = None, kernel_init: Initializer = default_kernel_init, out_kernel_init: Initializer | None = None, bias_init: Initializer = initializers.zeros_init(), out_bias_init: Initializer | None = None, use_bias: bool = True, attention_fn: Callable[..., Array] = dot_product_attention, decode: bool | None = None, normalize_qk: bool = False, qkv_promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, out_promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, ln_promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, # Deprecated, will be removed. qkv_dot_general: DotGeneralT | None = None, out_dot_general: DotGeneralT | None = None, qkv_dot_general_cls: Any = None, out_dot_general_cls: Any = None, rngs: rnglib.Rngs, keep_rngs: bool = True, kernel_metadata: Mapping[str, Any] = MappingProxyType({}), out_kernel_metadata: Mapping[str, Any] = MappingProxyType({}), bias_metadata: Mapping[str, Any] = MappingProxyType({}), out_bias_metadata: Mapping[str, Any] = MappingProxyType({}), query_ln_scale_metadata: Mapping[str, Any] = MappingProxyType({}), key_ln_scale_metadata: Mapping[str, Any] = MappingProxyType({}), ): self.num_heads = num_heads self.in_features = in_features self.qkv_features = ( qkv_features if qkv_features is not None else in_features ) self.out_features = ( out_features if out_features is not None else in_features ) self.in_kv_features = ( in_kv_features if in_kv_features is not None else in_features ) self.num_kv_heads = ( num_kv_heads if num_kv_heads is not None else num_heads ) if self.num_heads % self.num_kv_heads != 0: raise ValueError( f"num_heads ({self.num_heads}) must be divisible by " f"num_kv_heads ({self.num_kv_heads})." ) self.dtype = dtype self.param_dtype = param_dtype self.broadcast_dropout = broadcast_dropout self.dropout_rate = dropout_rate self.deterministic = deterministic self.precision = precision self.use_bias = use_bias self.attention_fn = attention_fn self.decode = decode self.normalize_qk = normalize_qk self.qkv_promote_dtype = qkv_promote_dtype self.out_promote_dtype = out_promote_dtype self.ln_promote_dtype = ln_promote_dtype self.qkv_dot_general = qkv_dot_general self.out_dot_general = out_dot_general self.qkv_dot_general_cls = qkv_dot_general_cls self.out_dot_general_cls = out_dot_general_cls if self.qkv_features % self.num_heads != 0: raise ValueError( f'Memory dimension ({self.qkv_features}) must be divisible by ' f"'num_heads' heads ({self.num_heads})." ) self.head_dim = self.qkv_features // self.num_heads linear_general = functools.partial( LinearGeneral, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=kernel_init, bias_init=bias_init, use_bias=self.use_bias, precision=self.precision, promote_dtype=self.qkv_promote_dtype, dot_general=self.qkv_dot_general, dot_general_cls=self.qkv_dot_general_cls, kernel_metadata=kernel_metadata, bias_metadata=bias_metadata, ) # project inputs_q to multi-headed q/k/v # dimensions are then [batch..., length, n_heads, n_features_per_head] self.query = linear_general( self.in_features, out_features=(self.num_heads, self.head_dim), rngs=rngs ) self.key = linear_general( self.in_kv_features, out_features=(self.num_kv_heads, self.head_dim), rngs=rngs ) self.value = linear_general( self.in_kv_features, out_features=(self.num_kv_heads, self.head_dim), rngs=rngs ) self.query_ln: LayerNorm | None self.key_ln: LayerNorm | None if self.normalize_qk: # Normalizing query and key projections stabilizes training with higher # LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis. self.query_ln = LayerNorm( self.head_dim, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype, promote_dtype=self.ln_promote_dtype, rngs=rngs, scale_metadata=query_ln_scale_metadata, ) self.key_ln = LayerNorm( self.head_dim, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype, promote_dtype=self.ln_promote_dtype, rngs=rngs, scale_metadata=key_ln_scale_metadata, ) else: self.query_ln = nnx.data(None) self.key_ln = nnx.data(None) self.out = LinearGeneral( in_features=(self.num_heads, self.head_dim), out_features=self.out_features, axis=(-2, -1), kernel_init=out_kernel_init or kernel_init, bias_init=out_bias_init or bias_init, use_bias=self.use_bias, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision, promote_dtype=self.out_promote_dtype, dot_general=self.out_dot_general, dot_general_cls=self.out_dot_general_cls, rngs=rngs, kernel_metadata=out_kernel_metadata or kernel_metadata, bias_metadata=out_bias_metadata or bias_metadata, ) self.rngs = rngs.dropout.fork() if keep_rngs and dropout_rate > 0 else None self.cached_key: nnx.Cache[Array] | None = nnx.data(None) self.cached_value: nnx.Cache[Array] | None = nnx.data(None) self.cache_index: nnx.Cache[Array] | None = nnx.data(None) def __call__( self, inputs_q: Array, inputs_k: Array | None = None, inputs_v: Array | None = None, *, mask: Array | None = None, deterministic: bool | None = None, rngs: rnglib.Rngs | rnglib.RngStream | None = None, sow_weights: bool = False, decode: bool | None = None, ): """Applies multi-head dot product attention on the input data. Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention and project the results to an output vector. If both inputs_k and inputs_v are None, they will both copy the value of inputs_q (self attention). If only inputs_v is None, it will copy the value of inputs_k. Args: inputs_q: input queries of shape `[batch_sizes..., length, features]`. inputs_k: key of shape `[batch_sizes..., length, features]`. If None, inputs_k will copy the value of inputs_q. inputs_v: values of shape `[batch_sizes..., length, features]`. If None, inputs_v will copy the value of inputs_k. mask: attention mask of shape `[batch_sizes..., num_heads, query_length, key/value_length]`. Attention weights are masked out if their corresponding mask value is `False`. deterministic: if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic. The ``deterministic`` flag passed into the call method will take precedence over the ``deterministic`` flag passed into the constructor. rngs: rng key. The rng key passed into the call method will take precedence over the rng key passed into the constructor. sow_weights: if ``True``, the attention weights are sowed into the 'intermediates' collection. decode: whether to prepare and use an autoregressive cache. The ``decode`` flag passed into the call method will take precedence over the ``decode`` flag passed into the constructor. Returns: output of shape `[batch_sizes..., length, features]`. """ if rngs is None: rngs = self.rngs elif isinstance(rngs, rnglib.Rngs): rngs = rngs.dropout if inputs_k is None: if inputs_v is not None: raise ValueError( '`inputs_k` cannot be None if `inputs_v` is not None. ' 'To have both `inputs_k` and `inputs_v` be the same value, pass in the ' 'value to `inputs_k` and leave `inputs_v` as None.' ) inputs_k = inputs_q if inputs_v is None: inputs_v = inputs_k if inputs_q.shape[-1] != self.in_features: raise ValueError( f'Incompatible input dimension, got {inputs_q.shape[-1]} ' f'but module expects {self.in_features}.' ) query = self.query(inputs_q) key = self.key(inputs_k) value = self.value(inputs_v) if self.normalize_qk: assert self.query_ln is not None and self.key_ln is not None # Normalizing query and key projections stabilizes training with higher # LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis. query = self.query_ln(query) key = self.key_ln(key) # During fast autoregressive decoding, we feed one position at a time, # and cache the keys and values step by step. decode = first_from( decode, self.decode, error_msg="""No `decode` argument was provided to MultiHeadAttention as either a __call__ argument, class attribute, or nnx.flag.""", ) if decode: if ( self.cached_key is None or self.cached_value is None or self.cache_index is None ): raise ValueError( 'Autoregressive cache not initialized, call ``init_cache`` first.' ) ( *batch_dims, max_length, num_kv_heads, depth_per_head, ) = self.cached_key.shape # shape check of cached keys against key input expected_shape = tuple(batch_dims) + (1, num_kv_heads, depth_per_head) if expected_shape != key.shape: raise ValueError( 'Autoregressive cache shape error, ' f'expected key shape {expected_shape} instead got {key.shape}.' ) # update key, value caches with our new 1d spatial slices cur_index = self.cache_index[...] zero = jnp.array(0, dtype=lax.dtype(cur_index.dtype)) indices = (zero,) * len(batch_dims) + (cur_index, zero, zero) key = lax.dynamic_update_slice(self.cached_key[...], key, indices) value = lax.dynamic_update_slice(self.cached_value[...], value, indices) self.cached_key[...] = key self.cached_value[...] = value self.cache_index[...] += 1 # causal mask for cached decoder self-attention: # our single query position should only attend to those key # positions that have already been generated and cached, # not the remaining zero elements. mask = combine_masks( mask, jnp.broadcast_to( jnp.arange(max_length) <= cur_index, tuple(batch_dims) + (1, 1, max_length), ), ) if ( self.dropout_rate > 0.0 ): # Require `deterministic` only if using dropout. deterministic = first_from( deterministic, self.deterministic, error_msg="""No `deterministic` argument was provided to MultiHeadAttention as either a __call__ argument, class attribute, or nnx.flag.""", ) if not deterministic: if rngs is None: raise ValueError( "'rngs' must be provided to __call__ method if " "MultiHeadAttention instance is defined with keep_rngs=False." ) dropout_rng = rngs() else: dropout_rng = None else: deterministic = True dropout_rng = None # apply attention x = self.attention_fn( query, key, value, mask=mask, dropout_rng=dropout_rng, dropout_rate=self.dropout_rate, broadcast_dropout=self.broadcast_dropout, deterministic=deterministic, dtype=self.dtype, precision=self.precision, module=self if sow_weights else None, ) # back to the original inputs dimensions out = self.out(x) return out def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32): """Initializes cache for fast autoregressive decoding. When ``decode=True``, this method must be called first before performing forward inference. When in decode mode, only one token must be passed at a time. Example usage:: >>> from flax import nnx >>> import jax.numpy as jnp ... >>> batch_size = 5 >>> embed_dim = 3 >>> x = jnp.ones((batch_size, 1, embed_dim)) # single token ... >>> model_nnx = nnx.MultiHeadAttention( ... num_heads=2, ... in_features=3, ... qkv_features=6, ... out_features=6, ... decode=True, ... rngs=nnx.Rngs(42), ... ) ... >>> # out_nnx = model_nnx(x) <-- throws an error because cache isn't initialized ... >>> model_nnx.init_cache(x.shape) >>> out_nnx = model_nnx(x) """ cache_shape = (*input_shape[:-1], self.num_kv_heads, self.head_dim) self.cached_key = nnx.Cache(jnp.zeros(cache_shape, dtype)) self.cached_value = nnx.Cache(jnp.zeros(cache_shape, dtype)) self.cache_index = nnx.Cache(jnp.array(0, dtype=jnp.int32)) def set_view( self, deterministic: bool | None = None, decode: bool | None = None, batch_size: int | Shape | None = None, max_length: int | None = None, ): """Class method used by ``nnx.view``. Args: train: if True, the module is set to training mode. deterministic: if True, the module is set to deterministic mode. decode: if True, the module is set to decode mode. batch_size: the batch size to use for the cache. max_length: the max length to use for the cache. """ if deterministic is not None: self.deterministic = deterministic if decode is not None: self.decode = decode if ( not hasattr(self, 'cached_key') or not hasattr(self, 'cached_value') or not hasattr(self, 'cache_index') ): if batch_size is None: raise TypeError( "'batch_size' must be provided when initializing cache." ) if max_length is None: raise TypeError( "'max_length' must be provided when initializing cache." ) if isinstance(batch_size, int): batch_size = (batch_size,) # initialize cache cache_shape = (*batch_size, max_length, self.num_kv_heads, self.head_dim) self.cached_key = nnx.Cache(jnp.zeros(cache_shape, self.dtype)) self.cached_value = nnx.Cache(jnp.zeros(cache_shape, self.dtype)) self.cache_index = nnx.Cache(jnp.array(0, dtype=jnp.int32)) # mask-making utility functions def make_attention_mask( query_input: Array, key_input: Array, pairwise_fn: Callable[..., Any] = jnp.multiply, extra_batch_dims: int = 0, dtype: Dtype = jnp.float32, ): """Mask-making helper for attention weights. In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`, the attention weights will be `[batch..., heads, len_q, len_kv]` and this function will produce `[batch..., 1, len_q, len_kv]`. Args: query_input: a batched, flat input of query_length size key_input: a batched, flat input of key_length size pairwise_fn: broadcasting elementwise comparison function extra_batch_dims: number of extra batch dims to add singleton axes for, none by default dtype: mask return dtype Returns: A `[batch..., 1, len_q, len_kv]` shaped mask for 1d attention. """ mask = pairwise_fn( jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, axis=-2) ) mask = jnp.expand_dims(mask, axis=-3) mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims))) return mask.astype(dtype) def make_causal_mask( x: Array, extra_batch_dims: int = 0, dtype: Dtype = jnp.float32 ) -> Array: """Make a causal mask for self-attention. In case of 1d inputs (i.e., `[batch..., len]`, the self-attention weights will be `[batch..., heads, len, len]` and this function will produce a causal mask of shape `[batch..., 1, len, len]`. Args: x: input array of shape `[batch..., len]` extra_batch_dims: number of batch dims to add singleton axes for, none by default dtype: mask return dtype Returns: A `[batch..., 1, len, len]` shaped causal mask for 1d attention. """ idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape) return make_attention_mask( idxs, idxs, jnp.greater_equal, extra_batch_dims=extra_batch_dims, dtype=dtype, ) def combine_masks( *masks: Array | None, dtype: Dtype = jnp.float32 ) -> Array | None: """Combine attention masks. Args: *masks: set of attention mask arguments to combine, some can be None. dtype: dtype for the returned mask. Returns: Combined mask, reduced by logical and, returns None if no masks given. """ masks_list = [m for m in masks if m is not None] if not masks_list: return None assert all( map(lambda x: x.ndim == masks_list[0].ndim, masks_list) ), f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks_list))}' mask, *other_masks = masks_list for other_mask in other_masks: mask = jnp.logical_and(mask, other_mask) return mask.astype(dtype) ================================================ FILE: flax/nnx/nn/dtypes.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import typing as tp from flax.typing import Dtype from jax import numpy as jnp T = tp.TypeVar('T', bound=tuple) def canonicalize_dtype( *args, dtype: Dtype | None = None, inexact: bool = True ) -> Dtype: """Canonicalize an optional dtype to the definitive dtype. If the ``dtype`` is None this function will infer the dtype. If it is not None it will be returned unmodified or an exceptions is raised if the dtype is invalid. from the input arguments using ``jnp.result_type``. Args: *args: JAX array compatible values. None values are ignored. dtype: Optional dtype override. If specified the arguments are cast to the specified dtype instead and dtype inference is disabled. inexact: When True, the output dtype must be a subdtype of `jnp.inexact`. Inexact dtypes are real or complex floating points. This is useful when you want to apply operations that don't work directly on integers like taking a mean for example. Returns: The dtype that *args should be cast to. """ if dtype is None: args_filtered = [jnp.asarray(x) for x in args if x is not None] dtype = jnp.result_type(*args_filtered) if inexact and not jnp.issubdtype(dtype, jnp.inexact): dtype = jnp.promote_types(jnp.float32, dtype) if inexact and not jnp.issubdtype(dtype, jnp.inexact): raise ValueError(f'Dtype must be inexact: {dtype}') return dtype def promote_dtype(args: T, /, *, dtype=None, inexact=True) -> T: """ "Promotes input arguments to a specified or inferred dtype. All args are cast to the same dtype. See ``canonicalize_dtype`` for how this dtype is determined. The behavior of promote_dtype is mostly a convinience wrapper around ``jax.numpy.promote_types``. The differences being that it automatically casts all input to the inferred dtypes, allows inference to be overridden by a forced dtype, and has an optional check to garantuee the resulting dtype is inexact. Args: *args: JAX array compatible values. None values are returned as is. dtype: Optional dtype override. If specified the arguments are cast to the specified dtype instead and dtype inference is disabled. inexact: When True, the output dtype must be a subdtype of `jnp.inexact`. Inexact dtypes are real or complex floating points. This is useful when you want to apply operations that don't work directly on integers like taking a mean for example. Returns: The arguments cast to arrays of the same dtype. """ dtype = canonicalize_dtype(*args, dtype=dtype, inexact=inexact) arrays = tuple(jnp.asarray(x, dtype) if x is not None else None for x in args) return arrays # type: ignore[return-value] ================================================ FILE: flax/nnx/nn/initializers.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import typing as tp from jax.nn.initializers import constant as constant from jax.nn.initializers import delta_orthogonal as delta_orthogonal from jax.nn.initializers import glorot_normal as glorot_normal from jax.nn.initializers import glorot_uniform as glorot_uniform from jax.nn.initializers import he_normal as he_normal from jax.nn.initializers import he_uniform as he_uniform from jax.nn.initializers import kaiming_normal as kaiming_normal from jax.nn.initializers import kaiming_uniform as kaiming_uniform from jax.nn.initializers import lecun_normal as lecun_normal from jax.nn.initializers import lecun_uniform as lecun_uniform from jax.nn.initializers import normal as normal from jax.nn.initializers import ones as ones from jax.nn.initializers import orthogonal as orthogonal from jax.nn.initializers import truncated_normal as truncated_normal from jax.nn.initializers import uniform as uniform from jax.nn.initializers import variance_scaling as variance_scaling from jax.nn.initializers import xavier_normal as xavier_normal from jax.nn.initializers import xavier_uniform as xavier_uniform from jax.nn.initializers import zeros as zeros from flax.typing import Initializer DtypeLikeInexact = tp.Any def zeros_init() -> Initializer: """Builds an initializer that returns a constant array full of zeros. >>> import jax, jax.numpy as jnp >>> from flax.nnx import initializers >>> zeros_initializer = initializers.zeros_init() >>> zeros_initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[0., 0., 0.], [0., 0., 0.]], dtype=float32) """ return zeros def ones_init() -> Initializer: """Builds an initializer that returns a constant array full of ones. >>> import jax, jax.numpy as jnp >>> from flax.nnx import initializers >>> ones_initializer = initializers.ones_init() >>> ones_initializer(jax.random.key(42), (3, 2), jnp.float32) Array([[1., 1.], [1., 1.], [1., 1.]], dtype=float32) """ return ones ================================================ FILE: flax/nnx/nn/linear.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import typing as tp from types import MappingProxyType import jax import jax.numpy as jnp import numpy as np from jax import lax import opt_einsum from flax.core.frozen_dict import FrozenDict from flax import nnx from flax.nnx import rnglib, variablelib from flax.nnx.module import Module, first_from from flax.nnx.nn import dtypes, initializers from flax.typing import ( Dtype, Shape, Initializer, PrecisionLike, DotGeneralT, ConvGeneralDilatedT, PaddingLike, LaxPadding, PromoteDtypeFn, EinsumT, ) Array = jax.Array Axis = int Size = int default_kernel_init = initializers.lecun_normal() default_bias_init = initializers.zeros_init() def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding: """ "Canonicalizes conv padding to a jax.lax supported format.""" if isinstance(padding, str): return padding if isinstance(padding, int): return [(padding, padding)] * rank if isinstance(padding, tp.Sequence) and len(padding) == rank: new_pad = [] for p in padding: if isinstance(p, int): new_pad.append((p, p)) elif isinstance(p, tuple) and len(p) == 2: new_pad.append(p) else: break if len(new_pad) == rank: return new_pad raise ValueError( f'Invalid padding format: {padding}, should be str, int,' f' or a sequence of len {rank} where each element is an' ' int or pair of ints.' ) def _conv_dimension_numbers(input_shape): """Computes the dimension numbers based on the input shape.""" ndim = len(input_shape) lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1)) rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2)) out_spec = lhs_spec return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec) def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]: # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. return tuple(sorted(ax if ax >= 0 else ndim + ax for ax in axes)) def _canonicalize_tuple(x: tp.Sequence[int] | int) -> tuple[int, ...]: if isinstance(x, tp.Iterable): return tuple(x) else: return (x,) class LinearGeneral(Module): """A linear transformation with flexible axes. Example usage:: >>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> # equivalent to `nnx.Linear(2, 4)` >>> layer = nnx.LinearGeneral(2, 4, rngs=nnx.Rngs(0)) >>> layer.kernel.shape (2, 4) >>> # output features (4, 5) >>> layer = nnx.LinearGeneral(2, (4, 5), rngs=nnx.Rngs(0)) >>> layer.kernel.shape (2, 4, 5) >>> layer.bias.shape (4, 5) >>> # apply transformation on the the second and last axes >>> layer = nnx.LinearGeneral((2, 3), (4, 5), axis=(1, -1), rngs=nnx.Rngs(0)) >>> layer.kernel.shape (2, 3, 4, 5) >>> layer.bias.shape (4, 5) >>> y = layer(jnp.ones((16, 2, 3))) >>> y.shape (16, 4, 5) Args: in_features: int or tuple with number of input features. out_features: int or tuple with number of output features. axis: int or tuple with axes to apply the transformation on. For instance, (-2, -1) will apply the transformation to the last two axes. batch_axis: mapping of batch axis indices to axis size. use_bias: whether to add a bias to the output (default: True). dtype: the dtype of the computation (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). kernel_init: initializer function for the weight matrix. bias_init: initializer function for the bias. dot_general: dot product function (default: None). If neither this nor ``dot_general_cls`` are provided, ``jax.lax.dot_general`` is used. dot_general_cls: dot product function class to instantiate a dot product function as ``dot_general = dot_general_cls()`` (default: None). precision: numerical precision of the computation see ``jax.lax.Precision`` for details. promote_dtype: function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of ``(inputs, kernel, bias)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. preferred_element_type: Optional parameter controls the data type output by the dot product. This argument is passed to ``dot_general`` function. See ``jax.lax.dot`` for details. rngs: rng key. kernel_metadata: Optional metadata dictionary to set when initializing the weight matrix. bias_metadata: Optional metadata dictionary to set when initializing the bias. """ def __init__( self, in_features: Size | tp.Sequence[Size], out_features: Size | tp.Sequence[Size], *, axis: Axis | tp.Sequence[Axis] = -1, batch_axis: tp.Mapping[Axis, Size] = FrozenDict({}), use_bias: bool = True, dtype: Dtype | None = None, param_dtype: Dtype = jnp.float32, kernel_init: Initializer = default_kernel_init, bias_init: Initializer = default_bias_init, precision: PrecisionLike = None, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, dot_general: DotGeneralT | None = None, dot_general_cls: tp.Any = None, preferred_element_type: Dtype | None = None, rngs: rnglib.Rngs, kernel_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), bias_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), ): self.in_features = _canonicalize_tuple(in_features) self.out_features = _canonicalize_tuple(out_features) self.axis = _canonicalize_tuple(axis) self.batch_axis = FrozenDict[Axis, Size](batch_axis) self.use_bias = use_bias self.dtype = dtype self.param_dtype = param_dtype self.precision = precision self.dot_general = dot_general self.dot_general_cls = dot_general_cls self.promote_dtype = promote_dtype self.preferred_element_type = preferred_element_type if len(self.in_features) != len(self.axis): raise ValueError( 'in_features and axis must have the same length. ' f'Got {self.in_features} and {self.axis}.' ) if batch_axis: batch_dims = tuple(batch_axis.keys()) max_dim = np.max(batch_dims) if set(batch_dims) != set(range(max_dim + 1)): raise ValueError( 'batch_dims %s must be consecutive leading ' 'dimensions starting from 0.' % str(batch_dims) ) n_batch_axis = len(self.batch_axis) n_in_features = len(self.in_features) n_out_features = len(self.out_features) def kernel_init_wrap(rng, shape, dtype): flat_shape = ( np.prod(shape[:n_batch_axis]) * np.prod(shape[n_batch_axis : n_in_features + n_batch_axis]), np.prod(shape[-n_out_features:]), ) flat_shape = jax.tree.map(int, flat_shape) kernel = kernel_init(rng, flat_shape, dtype) if isinstance(kernel, variablelib.VariableMetadata): kernel.raw_value = jnp.reshape(kernel.raw_value, shape) else: kernel = jnp.reshape(kernel, shape) return kernel batch_shape = tuple(self.batch_axis.values()) kernel_shape = ( *batch_shape, *self.in_features, *self.out_features, ) self.kernel = nnx.Param( kernel_init_wrap(rngs.params(), kernel_shape, self.param_dtype), **kernel_metadata ) self.bias: nnx.Param[jax.Array] | None if self.use_bias: def bias_init_wrap(rng, shape, dtype): flat_shape = (int(np.prod(shape)),) bias = bias_init(rng, flat_shape, dtype) if isinstance(bias, variablelib.VariableMetadata): bias.raw_value = jnp.reshape(bias.raw_value, shape) else: bias = jnp.reshape(bias, shape) return bias bias_shape = (*batch_shape, *self.out_features) self.bias = nnx.Param( bias_init_wrap(rngs.params(), bias_shape, self.param_dtype), **bias_metadata, ) else: self.bias = nnx.data(None) def __call__(self, inputs: Array, out_sharding = None) -> Array: """Applies a linear transformation to the inputs along multiple dimensions. Args: inputs: The nd-array to be transformed. Returns: The transformed input. """ ndim = inputs.ndim n_batch_dims = len(self.batch_axis) axis = _normalize_axes(self.axis, ndim) batch_axis = _normalize_axes(tuple(self.batch_axis.keys()), ndim) n_axis = len(axis) # batch and non-contracting dims of input with 1s for batch dims. expanded_batch_shape = tuple( inputs.shape[ax] if ax in batch_axis else 1 for ax in range(inputs.ndim) if ax not in axis ) kernel = self.kernel[...] bias = self.bias[...] if self.bias is not None else None batch_ind = tuple(range(n_batch_dims)) contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims)) inputs, kernel, bias = self.promote_dtype( (inputs, kernel, bias), dtype=self.dtype ) if self.dot_general_cls is not None: dot_general = self.dot_general_cls() elif self.dot_general is not None: dot_general = self.dot_general else: dot_general = lax.dot_general # We use dot_general_kwargs for BC compatibility with # user custom dot_general/dot_general_cls which may not have # preferred_element_type argument to avoid breaking # existing code dot_general_kwargs = {'out_sharding': out_sharding} if self.preferred_element_type is not None: dot_general_kwargs["preferred_element_type"] = self.preferred_element_type out = dot_general( inputs, kernel, ((axis, contract_ind), (batch_axis, batch_ind)), precision=self.precision, **dot_general_kwargs, ) # dot_general output has shape [batch_dims/group_dims] + [feature_dims] if bias is not None: # expand bias shape to broadcast bias over batch dims. bias = jnp.reshape(bias, (*expanded_batch_shape, *self.out_features)) out += bias return out class Linear(Module): """A linear transformation applied over the last dimension of the input. Example usage:: >>> from flax import nnx >>> import jax, jax.numpy as jnp >>> layer = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(layer)) State({ 'bias': Param( value=(4,) ), 'kernel': Param( value=(3, 4) ) }) Args: in_features: the number of input features. out_features: the number of output features. use_bias: whether to add a bias to the output (default: True). dtype: the dtype of the computation (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). precision: numerical precision of the computation see ``jax.lax.Precision`` for details. kernel_init: initializer function for the weight matrix. bias_init: initializer function for the bias. dot_general: dot product function. promote_dtype: function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of ``(inputs, kernel, bias)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. preferred_element_type: Optional parameter controls the data type output by the dot product. This argument is passed to ``dot_general`` function. See ``jax.lax.dot`` for details. rngs: rng key. kernel_metadata: Optional metadata dictionary to set when initializing the weight matrix. bias_metadata: Optional metadata dictionary to set when initializing the bias. """ def __init__( self, in_features: int, out_features: int, *, use_bias: bool = True, dtype: tp.Optional[Dtype] = None, param_dtype: Dtype = jnp.float32, precision: PrecisionLike = None, kernel_init: Initializer = default_kernel_init, bias_init: Initializer = default_bias_init, dot_general: DotGeneralT = lax.dot_general, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, preferred_element_type: Dtype | None = None, rngs: rnglib.Rngs, kernel_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), bias_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), ): kernel_key = rngs.params() self.kernel = nnx.Param( kernel_init(kernel_key, (in_features, out_features), param_dtype), **kernel_metadata, ) self.bias: nnx.Param[jax.Array] | None if use_bias: bias_key = rngs.params() self.bias = nnx.Param( bias_init(bias_key, (out_features,), param_dtype), **bias_metadata, ) else: self.bias = nnx.data(None) self.in_features = in_features self.out_features = out_features self.use_bias = use_bias self.dtype = dtype self.param_dtype = param_dtype self.precision = precision self.dot_general = dot_general self.promote_dtype = promote_dtype self.preferred_element_type = preferred_element_type def __call__(self, inputs: Array, out_sharding = None) -> Array: """Applies a linear transformation to the inputs along the last dimension. Args: inputs: The nd-array to be transformed. Returns: The transformed input. """ kernel = self.kernel[...] bias = self.bias[...] if self.bias is not None else None inputs, kernel, bias = self.promote_dtype( (inputs, kernel, bias), dtype=self.dtype ) # We use dot_general_kwargs for BC compatibility with # user custom self.dot_general method which may not have # preferred_element_type argument to avoid breaking # existing code dot_general_kwargs = {'out_sharding': out_sharding} if self.preferred_element_type is not None: dot_general_kwargs["preferred_element_type"] = self.preferred_element_type y = self.dot_general( inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision, **dot_general_kwargs, ) assert self.use_bias == (bias is not None) if bias is not None: y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) return y class Einsum(Module): """An einsum transformation with learnable kernel and bias. Example usage:: >>> from flax import nnx >>> import jax.numpy as jnp ... >>> layer = nnx.Einsum('nta,hab->nthb', (8, 2, 4), (8, 4), rngs=nnx.Rngs(0)) >>> layer.kernel.shape (8, 2, 4) >>> layer.bias.shape (8, 4) >>> y = layer(jnp.ones((16, 11, 2))) >>> y.shape (16, 11, 8, 4) Args: einsum_str: a string to denote the einsum equation. The equation must have exactly two operands, the lhs being the input passed in, and the rhs being the learnable kernel. Exactly one of ``einsum_str`` in the constructor argument and call argument must be not None, while the other must be None. kernel_shape: the shape of the kernel. bias_shape: the shape of the bias. If this is None, a bias won't be used. dtype: the dtype of the computation (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). precision: numerical precision of the computation see ``jax.lax.Precision`` for details. kernel_init: initializer function for the weight matrix. bias_init: initializer function for the bias. promote_dtype: function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of ``(inputs, kernel, bias)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. einsum_op: An injectable alternative of `jnp.einsum` to do the computation. Should support same signature as `jnp.einsum`. preferred_element_type: Optional parameter controls the data type output by the dot product. This argument is passed to ``dot_general`` function. See ``jax.lax.dot`` for details. rngs: rng key. kernel_metadata: Optional metadata dictionary to set when initializing the weight matrix. bias_metadata: Optional metadata dictionary to set when initializing the bias. """ def __init__( self, einsum_str: str, kernel_shape: Shape, bias_shape: tp.Optional[Shape] = None, *, dtype: tp.Optional[Dtype] = None, param_dtype: Dtype = jnp.float32, precision: PrecisionLike = None, kernel_init: Initializer = default_kernel_init, bias_init: Initializer = default_bias_init, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, einsum_op: EinsumT = jnp.einsum, preferred_element_type: Dtype | None = None, rngs: rnglib.Rngs, kernel_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), bias_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), ): einsum_str = einsum_str.replace(' ', '') self._einsum_str_check(einsum_str) kernel_key = rngs.params() self.kernel = nnx.Param( kernel_init(kernel_key, kernel_shape, param_dtype), **kernel_metadata ) self.bias: nnx.Param | None if bias_shape is not None: bias_key = rngs.params() self.bias = nnx.Param( bias_init(bias_key, bias_shape, param_dtype), **bias_metadata ) else: self.bias = nnx.data(None) self.einsum_str = einsum_str self.kernel_shape = kernel_shape self.bias_shape = bias_shape self.dtype = dtype self.param_dtype = param_dtype self.precision = precision self.promote_dtype = promote_dtype self.einsum_op = einsum_op self.preferred_element_type = preferred_element_type def __call__( self, inputs: Array, einsum_str: tp.Optional[str] = None, out_sharding = None ) -> Array: """Applies a linear transformation to the inputs along the last dimension. Args: inputs: The nd-array to be transformed. einsum_str: a string to denote the einsum equation. The equation must have exactly two operands, the lhs being the input passed in, and the rhs being the learnable kernel. Exactly one of ``einsum_str`` in the constructor argument and call argument must be not None, while the other must be None. Returns: The transformed input. """ einsum_str = first_from( einsum_str, self.einsum_str, error_msg="""No `einsum_str` argument was provided to Einsum as either a __call__ argument, or class attribute.""", ) einsum_str = einsum_str.replace(' ', '') self._einsum_str_check(einsum_str) inputs, kernel, bias = self.promote_dtype( ( inputs, self.kernel[...], self.bias[...] if self.bias is not None else self.bias, ), dtype=self.dtype, ) # We use einsum_op_kwargs for BC compatibility with # user custom self.einsum_op method which may not have # preferred_element_type argument to avoid breaking # existing code einsum_op_kwargs = {'out_sharding': out_sharding} if self.preferred_element_type is not None: einsum_op_kwargs["preferred_element_type"] = self.preferred_element_type y = self.einsum_op( einsum_str, inputs, kernel, precision=self.precision, **einsum_op_kwargs ) if bias is not None: broadcasted_bias_shape = self._infer_broadcasted_bias_shape( einsum_str, inputs, kernel ) y += jnp.reshape(bias, broadcasted_bias_shape) return y def _infer_broadcasted_bias_shape( self, einsum_str: str, lhs: Array, rhs: Array ): """Infer the broadcasted bias shape given the ``einsum_str``, ``lhs`` and ``rhs`` arrays. This is needed reshaping the bias and it to the output during forward inference. This function first replaces all ellipses with actual letter characters, then computes the broadcasted bias shape by checking to see which axes in the rhs array remain in the resulting array after einsumming. These axes are the embedding/feature dimensions, and all other axes in rhs are reduction axes. """ # More details on the parsing function: https://github.com/dgasmith/opt_einsum/blob/c826bb7df16f470a69f7bf90598fc27586209d11/opt_einsum/parser.py#L246 # returns the einsum string representation of the operands and result, with # ellipsis replaced by actual letter characters operands_str, result_str, _ = opt_einsum.parser.parse_einsum_input( (einsum_str, lhs, rhs) ) # rhs_dict is a dict{character:index} mapping that maps every character in # the rhs einsum string representation to its corresponding index position in the string rhs_dict = {c: i for i, c in enumerate(operands_str.split(',')[1])} assert len(rhs_dict) == len(self.kernel_shape) broadcasted_bias_shape = [1] * len(result_str) for i, c in enumerate(result_str): if c in rhs_dict: broadcasted_bias_shape[i] = self.kernel_shape[rhs_dict[c]] return broadcasted_bias_shape def _einsum_str_check(self, einsum_str): if '->' not in einsum_str: raise ValueError( '`einsum_str` equation must be explicit and include "->".' ) if einsum_str.count(',') != 1: raise ValueError( '`einsum_str` equation must have exactly two operands and ' 'therefore, exactly one comma character, instead of ' f'{einsum_str.count(",")}' ) class Conv(Module): """Convolution Module wrapping ``lax.conv_general_dilated``. Example usage:: >>> from flax import nnx >>> import jax.numpy as jnp >>> rngs = nnx.Rngs(0) >>> x = jnp.ones((1, 8, 3)) >>> # valid padding >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,), ... padding='VALID', rngs=rngs) >>> layer.kernel.shape (3, 3, 4) >>> layer.bias.shape (4,) >>> out = layer(x) >>> out.shape (1, 6, 4) >>> # circular padding with stride 2 >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3, 3), ... strides=2, padding='CIRCULAR', rngs=rngs) >>> layer.kernel.shape (3, 3, 3, 4) >>> layer.bias.shape (4,) >>> out = layer(x) >>> out.shape (1, 4, 4) >>> # apply lower triangle mask >>> mask = jnp.tril(jnp.ones((3, 3, 4))) >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,), ... mask=mask, padding='VALID', rngs=rngs) >>> out = layer(x) Args: in_features: int or tuple with number of input features. out_features: int or tuple with number of output features. kernel_size: shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer, which will be interpreted as a tuple of the single integer. For all other cases, it must be a sequence of integers. strides: an integer or a sequence of ``n`` integers, representing the inter-window strides (default: 1). padding: either the string ``'SAME'``, the string ``'VALID'``, the string ``'CIRCULAR'`` (periodic boundary conditions), the string `'REFLECT'` (reflection across the padding boundary), or a sequence of ``n`` ``(low, high)`` integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims and passing a single int in a sequence causes the same padding to be used on both sides. ``'CAUSAL'`` padding for a 1D convolution will left-pad the convolution axis, resulting in same-sized output. input_dilation: an integer or a sequence of ``n`` integers, giving the dilation factor to apply in each spatial dimension of ``inputs`` (default: 1). Convolution with input dilation ``d`` is equivalent to transposed convolution with stride ``d``. kernel_dilation: an integer or a sequence of ``n`` integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as 'atrous convolution'. feature_group_count: integer, default 1. If specified divides the input features into groups. use_bias: whether to add a bias to the output (default: True). mask: Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix. dtype: the dtype of the computation (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). conv_general_dilated: the convolution function to use (default: ``jax.lax.conv_general_dilated``). precision: numerical precision of the computation see ``jax.lax.Precision`` for details. kernel_init: initializer for the convolutional kernel. bias_init: initializer for the bias. promote_dtype: function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of ``(inputs, kernel, bias)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. preferred_element_type: Optional parameter controls the data type output by the convolution. This argument is passed to ``conv_general_dilated`` function. See ``jax.lax.conv_general_dilated`` for details. rngs: rng key. kernel_metadata: Optional metadata dictionary to set when initializing the weight matrix. bias_metadata: Optional metadata dictionary to set when initializing the bias. """ def __init__( self, in_features: int, out_features: int, kernel_size: int | tp.Sequence[int], strides: tp.Union[None, int, tp.Sequence[int]] = 1, *, padding: PaddingLike = 'SAME', input_dilation: tp.Union[None, int, tp.Sequence[int]] = 1, kernel_dilation: tp.Union[None, int, tp.Sequence[int]] = 1, feature_group_count: int = 1, use_bias: bool = True, mask: tp.Optional[Array] = None, dtype: tp.Optional[Dtype] = None, param_dtype: Dtype = jnp.float32, precision: PrecisionLike = None, kernel_init: Initializer = default_kernel_init, bias_init: Initializer = default_bias_init, conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, preferred_element_type: Dtype | None = None, rngs: rnglib.Rngs, kernel_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), bias_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), ): if isinstance(kernel_size, int): kernel_size = (kernel_size,) else: kernel_size = tuple(kernel_size) kernel_shape = kernel_size + ( in_features // feature_group_count, out_features, ) kernel_key = rngs.params() self.kernel_shape = kernel_shape self.kernel = nnx.Param( kernel_init(kernel_key, kernel_shape, param_dtype), **kernel_metadata ) self.bias: nnx.Param[jax.Array] | None if use_bias: bias_shape = (out_features,) bias_key = rngs.params() self.bias = nnx.Param( bias_init(bias_key, bias_shape, param_dtype), **bias_metadata ) else: self.bias = nnx.data(None) self.in_features = in_features self.out_features = out_features self.kernel_size = kernel_size self.strides = strides self.padding = padding self.input_dilation = input_dilation self.kernel_dilation = kernel_dilation self.feature_group_count = feature_group_count self.use_bias = use_bias self.mask = mask self.dtype = dtype self.param_dtype = param_dtype self.precision = precision self.conv_general_dilated = conv_general_dilated self.promote_dtype = promote_dtype self.preferred_element_type = preferred_element_type def __call__(self, inputs: Array, out_sharding=None) -> Array: """Applies a (potentially unshared) convolution to the inputs. Args: inputs: input data with dimensions ``(*batch_dims, spatial_dims..., features)``. This is the channels-last convention, i.e. NHWC for a 2d convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by ``lax.conv_general_dilated``, which puts the spatial dimensions last. Note: If the input has more than 1 batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap'ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension it will be added for the convolution and removed n return, an allowance made to enable writing single-example code. out_sharding: Optional sharding specification (e.g., ``jax.sharding.PartitionSpec``) for the output array. When using JAX's explicit sharding mode with a mesh context with ``AxisType.Explicit``. If ``None`` (default), the compiler automatically determines output sharding. Returns: The convolved data. """ assert isinstance(self.kernel_size, tuple) kernel_size = self.kernel_size def maybe_broadcast( x: tp.Optional[tp.Union[int, tp.Sequence[int]]], ) -> tuple[int, ...]: if x is None: # backward compatibility with using None as sentinel for # broadcast 1 x = 1 if isinstance(x, int): return (x,) * len(kernel_size) return tuple(x) # Combine all input batch dimensions into a single leading batch axis. num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1) if num_batch_dimensions != 1: input_batch_shape = inputs.shape[:num_batch_dimensions] flat_input_shape = (-1,) + inputs.shape[ num_batch_dimensions: ] inputs = jnp.reshape(inputs, flat_input_shape) # self.strides or (1,) * (inputs.ndim - 2) strides = maybe_broadcast(self.strides) input_dilation = maybe_broadcast(self.input_dilation) kernel_dilation = maybe_broadcast(self.kernel_dilation) padding_lax = canonicalize_padding(self.padding, len(kernel_size)) if padding_lax in ('CIRCULAR', 'REFLECT'): assert isinstance(padding_lax, str) kernel_size_dilated = [ (k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation) ] zero_pad: list[tuple[int, int]] = [(0, 0)] pads = ( zero_pad + [((k - 1) // 2, k // 2) for k in kernel_size_dilated] + [(0, 0)] ) padding_mode = {'CIRCULAR': 'wrap', 'REFLECT': 'reflect'}[padding_lax] inputs = jnp.pad(inputs, pads, mode=padding_mode) padding_lax = 'VALID' elif padding_lax == 'CAUSAL': if len(kernel_size) != 1: raise ValueError( 'Causal padding is only implemented for 1D convolutions.' ) left_pad = kernel_dilation[0] * (kernel_size[0] - 1) pads = [(0, 0), (left_pad, 0), (0, 0)] inputs = jnp.pad(inputs, pads) padding_lax = 'VALID' dimension_numbers = _conv_dimension_numbers(inputs.shape) # One shared convolutional kernel for all pixels in the output. assert self.in_features % self.feature_group_count == 0 if self.mask is not None and self.mask.shape != self.kernel_shape: raise ValueError( 'Mask needs to have the same shape as weights. ' f'Shapes are: {self.mask.shape}, {self.kernel_shape}' ) kernel = self.kernel[...] if self.mask is not None: kernel *= self.mask bias = self.bias[...] if self.bias is not None else None inputs, kernel, bias = self.promote_dtype( (inputs, kernel, bias), dtype=self.dtype ) # We use conv_kwargs for BC compatibility with # user custom self.conv_general_dilated method which may not have # preferred_element_type argument to avoid breaking # existing code conv_kwargs = {'out_sharding': out_sharding} if self.preferred_element_type is not None: conv_kwargs["preferred_element_type"] = self.preferred_element_type y = self.conv_general_dilated( inputs, kernel, strides, padding_lax, lhs_dilation=input_dilation, rhs_dilation=kernel_dilation, dimension_numbers=dimension_numbers, feature_group_count=self.feature_group_count, precision=self.precision, **conv_kwargs, ) if self.use_bias: bias = bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape) # type: ignore y += bias if num_batch_dimensions != 1: output_shape = input_batch_shape + y.shape[1:] y = jnp.reshape(y, output_shape) return y class ConvTranspose(Module): """Convolution Module wrapping ``lax.conv_transpose``. **Note:** The `padding` argument behaves differently from PyTorch; see the argument description below. Example usage:: >>> from flax import nnx >>> import jax.numpy as jnp >>> rngs = nnx.Rngs(0) >>> x = jnp.ones((1, 8, 3)) >>> # valid padding >>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(3,), ... padding='VALID', rngs=rngs) >>> layer.kernel.shape (3, 3, 4) >>> layer.bias.shape (4,) >>> out = layer(x) >>> out.shape (1, 10, 4) >>> # circular padding with stride 2 >>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(6, 6), ... strides=(2, 2), padding='CIRCULAR', ... transpose_kernel=True, rngs=rngs) >>> layer.kernel.shape (6, 6, 4, 3) >>> layer.bias.shape (4,) >>> out = layer(jnp.ones((1, 15, 15, 3))) >>> out.shape (1, 30, 30, 4) >>> # apply lower triangle mask >>> mask = jnp.tril(jnp.ones((3, 3, 4))) >>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,), ... mask=mask, padding='VALID', rngs=rngs) >>> out = layer(x) Args: in_features: int or tuple with number of input features. out_features: int or tuple with number of output features. kernel_size: shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer, which will be interpreted as a tuple of the single integer. For all other cases, it must be a sequence of integers. strides: an integer or a sequence of ``n`` integers, representing the inter-window strides (default: 1). padding: either a string indicating a specialized padding mode, or a sequence of ``n`` ``(low, high)`` integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims and a single int in a sequence causes the same padding to be used on both sides. **Note that this behavior is different from PyTorch**. In PyTorch, the padding argument effectively adds ``dilation * (kernel_size - 1) - padding`` amount of zero padding to the input instead. This is set so that when ``torch.Conv2d`` and ``torch.ConvTranspose2d`` are initialized with the same parameters, they are inverses of each other in regard to the input and output shapes. ``nnx.Conv`` and ``nnx.ConvTranspose`` do *not* have this behavior; if you want a ``nnx.ConvTranspose`` layer to invert the shape change produced by a ``nnx.Conv`` layer with a given padding and dilation, you should explicitly pass ``dilation * (kernel_size - 1) - padding`` as the `padding` argument to the ``nnx.ConvTranspose`` layer. Strings for specifying padding modes can be one of the following: - ``VALID`` adds ``dilation * (kernel_size - 1)`` padding to all dimensions. This is set so that a ``nnx.Conv`` layer with ``VALID`` padding would produce the inverse shape transformation. - ``SAME`` pads the input so that the output shape is the same as the input shape. - ``CIRCULAR`` pads the input with periodic boundary conditions. - ``CAUSAL`` padding for a 1D convolution will left-pad the convolution axis, resulting in same-sized output. kernel_dilation: an integer or a sequence of ``n`` integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as 'atrous convolution'. use_bias: whether to add a bias to the output (default: True). mask: Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix. dtype: the dtype of the computation (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). precision: numerical precision of the computation see ``jax.lax.Precision`` for details. kernel_init: initializer for the convolutional kernel. bias_init: initializer for the bias. transpose_kernel: if ``True`` flips spatial axes and swaps the input/output channel axes of the kernel. promote_dtype: function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of ``(inputs, kernel, bias)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. preferred_element_type: Optional parameter controls the data type output by the transposed convolution. This argument is passed to ``jax.lax.conv_transpose`` function. See ``jax.lax.conv_transpose`` for details. rngs: rng key. kernel_metadata: Optional metadata dictionary to set when initializing the weight matrix. bias_metadata: Optional metadata dictionary to set when initializing the bias. """ def __init__( self, in_features: int, out_features: int, kernel_size: int | tp.Sequence[int], strides: int | tp.Sequence[int] | None = None, *, padding: PaddingLike = 'SAME', kernel_dilation: int | tp.Sequence[int] | None = None, use_bias: bool = True, mask: Array | None = None, dtype: Dtype | None = None, param_dtype: Dtype = jnp.float32, precision: PrecisionLike | None = None, kernel_init: Initializer = default_kernel_init, bias_init: Initializer = default_bias_init, transpose_kernel: bool = False, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, preferred_element_type: Dtype | None = None, rngs: rnglib.Rngs, kernel_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), bias_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), ): if isinstance(kernel_size, int): kernel_size = (kernel_size,) else: kernel_size = tuple(kernel_size) self.kernel_size = kernel_size self.in_features = in_features self.out_features = out_features self.strides = strides self.padding = padding self.kernel_dilation = kernel_dilation self.use_bias = use_bias self.mask = mask self.dtype = dtype self.param_dtype = param_dtype self.precision = precision self.transpose_kernel = transpose_kernel self.promote_dtype = promote_dtype self.preferred_element_type = preferred_element_type if self.transpose_kernel: kernel_shape = kernel_size + (self.out_features, in_features) else: kernel_shape = kernel_size + (in_features, self.out_features) self.kernel_shape = kernel_shape self.kernel = nnx.Param( kernel_init(rngs.params(), kernel_shape, self.param_dtype), **kernel_metadata ) self.bias: nnx.Param | None if self.use_bias: self.bias = nnx.Param( bias_init(rngs.params(), (self.out_features,), self.param_dtype), **bias_metadata ) else: self.bias = nnx.data(None) def __call__(self, inputs: Array) -> Array: """Applies a transposed convolution to the inputs. Behaviour mirrors of ``jax.lax.conv_transpose``. Args: inputs: input data with dimensions ``(*batch_dims, spatial_dims..., features). This is the channels-last convention, i.e. NHWC for a 2d convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by ``lax.conv_general_dilated``, which puts the spatial dimensions last. Note: If the input has more than 1 batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap'ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension it will be added for the convolution and removed n return, an allowance made to enable writing single-example code. Returns: The convolved data. """ kernel_size = self.kernel_size def maybe_broadcast( x: tp.Optional[tp.Union[int, tp.Sequence[int]]], ) -> tuple[int, ...]: if x is None: # backward compatibility with using None as sentinel for # broadcast 1 x = 1 if isinstance(x, int): return (x,) * len(kernel_size) return tuple(x) # Combine all input batch dimensions into a single leading batch axis. num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1) if num_batch_dimensions != 1: input_batch_shape = inputs.shape[:num_batch_dimensions] flat_input_shape = (-1,) + inputs.shape[ num_batch_dimensions: ] inputs = jnp.reshape(inputs, flat_input_shape) strides = maybe_broadcast(self.strides) kernel_dilation = maybe_broadcast(self.kernel_dilation) kernel_shape = self.kernel_shape if self.mask is not None and self.mask.shape != kernel_shape: raise ValueError( 'Mask needs to have the same shape as weights. ' f'Shapes are: {self.mask.shape}, {kernel_shape}' ) kernel = self.kernel[...] if self.mask is not None: kernel *= self.mask padding_lax = canonicalize_padding(self.padding, len(kernel_size)) if padding_lax == 'CIRCULAR': padding_lax = 'VALID' bias = self.bias[...] if self.bias is not None else None inputs, kernel, bias = self.promote_dtype( (inputs, kernel, bias), dtype=self.dtype ) y = lax.conv_transpose( inputs, kernel, strides, padding_lax, rhs_dilation=kernel_dilation, transpose_kernel=self.transpose_kernel, precision=self.precision, preferred_element_type=self.preferred_element_type ) if self.padding == 'CIRCULAR': # For circular padding, we need to identify the size of the final output # ("period") along each spatial dimension, pad each dimension to an # integer number of periods, and wrap the array periodically around each # dimension. Padding should be done in such a way that the start of the # original input data inside the padded array is located at integer # number of periods - otherwise the result would be circularly shifted. # Compute period along each spatial dimension - it's input size scaled # by the stride. scaled_x_dims = [ x_dim * stride for x_dim, stride in zip(jnp.shape(inputs)[1:-1], strides) ] # Compute difference between the current size of y and the final output # size, and complement this difference to 2 * period - that gives how # much we need to pad. size_diffs = [ -(y_dim - x_dim) % (2 * x_dim) for y_dim, x_dim in zip(y.shape[1:-1], scaled_x_dims) ] if self.transpose_kernel: # If the kernel is transposed, the "+1" is put on the right to # mirror the regular convolution. If the same kernel parameters are used # as for Conv, this layer then computes the proper transpose convolution. total_pad = [ (size_diff // 2, (size_diff + 1) // 2) for size_diff in size_diffs ] else: # Divide the padding equally between left and right. The choice to put # "+1" on the left (and not on the right) represents a convention for # aligning even-sized kernels. total_pad = [ ((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs ] y = jnp.pad(y, [(0, 0)] + total_pad + [(0, 0)]) # Wrap the result periodically around each spatial dimension, # one by one. for i in range(1, y.ndim - 1): y = y.reshape( y.shape[:i] + (-1, scaled_x_dims[i - 1]) + y.shape[i + 1 :] ) y = y.sum(axis=i) if self.use_bias: y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) # type: ignore if num_batch_dimensions != 1: output_shape = input_batch_shape + y.shape[1:] y = jnp.reshape(y, output_shape) return y default_embed_init = initializers.variance_scaling( 1.0, 'fan_in', 'normal', out_axis=0 ) class Embed(Module): """Embedding Module. Example usage:: >>> from flax import nnx >>> import jax.numpy as jnp >>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0)) >>> nnx.state(layer) State({ 'embedding': Param( # 15 (60 B) value=Array([[ 0.57966787, -0.523274 , -0.43195742], [-0.676289 , -0.50300646, 0.33996582], [ 0.41796115, -0.59212935, 0.95934135], [-1.0917838 , -0.7441663 , 0.07713798], [-0.66570747, 0.13815777, 1.007365 ]], dtype=float32) ) }) >>> # get the first three and last three embeddings >>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]]) >>> layer(indices_input) Array([[[ 0.57966787, -0.523274 , -0.43195742], [-0.676289 , -0.50300646, 0.33996582], [ 0.41796115, -0.59212935, 0.95934135]], [[-0.66570747, 0.13815777, 1.007365 ], [-1.0917838 , -0.7441663 , 0.07713798], [ 0.41796115, -0.59212935, 0.95934135]]], dtype=float32) A parameterized function from integers [0, ``num_embeddings``) to ``features``-dimensional vectors. This ``Module`` will create an ``embedding`` matrix with shape ``(num_embeddings, features)``. When calling this layer, the input values will be used to 0-index into the ``embedding`` matrix. Indexing on a value greater than or equal to ``num_embeddings`` will result in ``nan`` values. When ``num_embeddings`` equals to 1, it will broadcast the ``embedding`` matrix to input shape with ``features`` dimension appended. Args: num_embeddings: number of embeddings / vocab size. features: number of feature dimensions for each embedding. dtype: the dtype of the embedding vectors (default: same as embedding). param_dtype: the dtype passed to parameter initializers (default: float32). embedding_init: embedding initializer. promote_dtype: function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of ``(embedding,)`` during ``__call__`` or ``(query, embedding)`` during ``attend``, and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. rngs: rng key. embedding_metadata: Optional metadata dictionary to set when initializing the embedding matrix. """ def __init__( self, num_embeddings: int, features: int, *, dtype: tp.Optional[Dtype] = None, param_dtype: Dtype = jnp.float32, embedding_init: Initializer = default_embed_init, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, rngs: rnglib.Rngs, embedding_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), ): self.embedding = nnx.Param( embedding_init(rngs.params(), (num_embeddings, features), param_dtype), **embedding_metadata, ) self.num_embeddings = num_embeddings self.features = features self.dtype = dtype or self.embedding.dtype self.param_dtype = param_dtype self.promote_dtype = promote_dtype def __call__(self, inputs: Array, out_sharding=None) -> Array: """Embeds the inputs along the last dimension. Args: inputs: input data, all dimensions are considered batch dimensions. Values in the input array must be integers. Returns: Output which is embedded input data. The output shape follows the input, with an additional ``features`` dimension appended. """ if not jnp.issubdtype(inputs.dtype, jnp.integer): raise ValueError('Input type must be an integer or unsigned integer.') # Use take because fancy indexing numpy arrays with JAX indices does not # work correctly. (embedding,) = self.promote_dtype( (self.embedding[...],), dtype=self.dtype, inexact=False ) if self.num_embeddings == 1: return jnp.broadcast_to(embedding, inputs.shape + (self.features,)) if out_sharding is not None: # Use auto_axes to handle out_sharding as jnp.take does not support it. take_fn = lambda embedding, inputs: jnp.take(embedding, inputs, axis=0) sharded_take = jax.sharding.auto_axes(out_sharding=out_sharding)(take_fn) return sharded_take(embedding, inputs) return jnp.take(embedding, inputs, axis=0) def attend(self, query: Array, out_sharding=None) -> Array: """Attend over the embedding using a query array. Args: query: array with last dimension equal the feature depth ``features`` of the embedding. out_sharding: Optional sharding specification (e.g., ``jax.sharding.PartitionSpec``) for the output array. When using JAX's explicit sharding mode with a mesh context with ``AxisType.Explicit``. If ``None`` (default), the compiler automatically determines output sharding. Returns: An array with final dim ``num_embeddings`` corresponding to the batched inner-product of the array of query vectors against each embedding. Commonly used for weight-sharing between embeddings and logit transform in NLP models. """ query, embedding = self.promote_dtype( (query, self.embedding[...]), dtype=self.dtype ) return jnp.dot(query, embedding.T, out_sharding=out_sharding) ================================================ FILE: flax/nnx/nn/lora.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import typing as tp from types import MappingProxyType from flax.nnx import rnglib, variablelib from flax.nnx.module import Module from flax.nnx.nn import initializers, dtypes from flax.nnx.nn.linear import Linear from flax.typing import Dtype, Initializer, PromoteDtypeFn import jax import jax.numpy as jnp Array = jax.Array Axis = int Size = int A = tp.TypeVar('A') default_a_initializer = initializers.he_uniform() default_b_initializer = initializers.zeros class LoRAParam(variablelib.Param[A]): pass class LoRA(Module): """A standalone LoRA layer. Example usage:: >>> from flax import nnx >>> import jax, jax.numpy as jnp >>> layer = nnx.LoRA(3, 2, 4, rngs=nnx.Rngs(0)) >>> layer.lora_a.shape (3, 2) >>> layer.lora_b.shape (2, 4) >>> # Wrap around existing layer >>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0)) >>> wrapper = nnx.LoRA(3, 2, 4, base_module=linear, rngs=nnx.Rngs(1)) >>> assert wrapper.base_module == linear >>> wrapper.lora_a.shape (3, 2) >>> layer.lora_b.shape (2, 4) >>> y = layer(jnp.ones((16, 3))) >>> y.shape (16, 4) Args: in_features: the number of input features. lora_rank: the rank of the LoRA dimension. out_features: the number of output features. base_module: a base module to call and substitute, if possible. dtype: the dtype of the computation (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). precision: numerical precision of the computation see `jax.lax.Precision` for details. a_initializer: initializer function for the fan-in matrices. Default to `he_uniform`. b_initializer: initializer function for the fan-out matrices. Default to `zero initializer`. lora_param_type: the type of the LoRA params. promote_dtype: function to promote the dtype of all input array arguments (including Variables accessed through ``self``) to the desired dtype. The function should accept a tuple of ``(inputs, lora_a, lora_b)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. rngs: rng key. a_metadata: Optional metadata dictionary to set when initializing the fan-in matrices. b_metadata: Optional metadata dictionary to set when initializing the fan-out matrices. """ def __init__( self, in_features: int, lora_rank: int, out_features: int, *, base_module: tp.Optional[Module] = None, dtype: tp.Optional[Dtype] = None, param_dtype: Dtype = jnp.float32, a_initializer: Initializer = default_a_initializer, b_initializer: Initializer = default_b_initializer, lora_param_type: tp.Type[variablelib.Variable] = LoRAParam, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, rngs: rnglib.Rngs, a_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), b_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), ): self.in_features = in_features self.out_features = out_features self.dtype = dtype self.param_dtype = param_dtype self.lora_param_type = lora_param_type self.base_module = base_module self.promote_dtype = promote_dtype self.lora_a = lora_param_type( a_initializer(rngs.params(), (in_features, lora_rank), param_dtype), **a_metadata, ) self.lora_b = lora_param_type( b_initializer(rngs.params(), (lora_rank, out_features), param_dtype), **b_metadata, ) def __call__(self, x: jax.Array): x, lora_a, lora_b = self.promote_dtype( (x, self.lora_a[...], self.lora_b[...]), dtype=self.dtype ) out = x @ lora_a @ lora_b if self.base_module is not None: if not callable(self.base_module): raise ValueError('`self.base_module` must be callable.') out += self.base_module(x) return out class LoRALinear(Linear): """An `nnx.Linear` layer in which the output will be LoRAified. The model state structure will be compatible with that of Linear. Example usage:: >>> from flax import nnx >>> import jax, jax.numpy as jnp >>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0)) >>> lora_linear = nnx.LoRALinear(3, 4, lora_rank=2, rngs=nnx.Rngs(0)) >>> linear.kernel.shape (3, 4) >>> lora_linear.kernel.shape (3, 4) >>> lora_linear.lora.lora_a.shape (3, 2) >>> jnp.allclose(linear.kernel[...], lora_linear.kernel[...]) Array(True, dtype=bool) >>> y = lora_linear(jnp.ones((16, 3))) >>> y.shape (16, 4) Args: in_features: the number of input features. out_features: the number of output features. lora_rank: the rank of the LoRA dimension. base_module: a base module to call and substitute, if possible. dtype: the dtype of the computation (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). precision: numerical precision of the computation see `jax.lax.Precision` for details. a_initializer: initializer function for the fan-in matrices. Default to `he_uniform`. b_initializer: initializer function for the fan-out matrices. Default to `zero initializer`. lora_param_type: the type of the LoRA params. lora_promote_dtype: function to promote the dtype for the LoRA submodule. a_metadata: Optional metadata dictionary to set when initializing the fan-in matrices. b_metadata: Optional metadata dictionary to set when initializing the fan-out matrices. """ def __init__( self, in_features: int, out_features: int, *, lora_rank: int, lora_dtype: tp.Optional[Dtype] = None, lora_param_dtype: Dtype = jnp.float32, a_initializer: Initializer = default_a_initializer, b_initializer: Initializer = default_b_initializer, lora_param_type: tp.Type[variablelib.Variable] = LoRAParam, lora_promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, rngs: rnglib.Rngs, a_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), b_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), **kwargs, ): super().__init__(in_features, out_features, rngs=rngs, **kwargs) self.lora = LoRA( in_features, lora_rank, out_features, dtype=lora_dtype, param_dtype=lora_param_dtype, a_initializer=a_initializer, b_initializer=b_initializer, lora_param_type=lora_param_type, promote_dtype=lora_promote_dtype, rngs=rngs, a_metadata=a_metadata, b_metadata=b_metadata, ) def __call__(self, x: jax.Array, out_sharding = None): y = super().__call__(x, out_sharding=out_sharding) y += self.lora(x) return y ================================================ FILE: flax/nnx/nn/normalization.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import typing as tp from types import MappingProxyType import jax import jax.numpy as jnp from jax import lax from flax import nnx from flax.nnx import rnglib from flax.nnx.module import Module, first_from from flax.nnx.nn import dtypes, initializers from flax.typing import ( Array, Dtype, Initializer, Axes, PromoteDtypeFn, ) def _canonicalize_axes(rank: int, axes: Axes) -> tp.Tuple[int, ...]: """Returns a tuple of deduplicated, sorted, and positive axes.""" if not isinstance(axes, tp.Iterable): axes = (axes,) return tuple({rank + axis if axis < 0 else axis for axis in axes}) def _abs_sq(x): """Computes the elementwise square of the absolute value |x|^2.""" if jnp.iscomplexobj(x): return lax.square(lax.real(x)) + lax.square(lax.imag(x)) else: return lax.square(x) def _compute_stats( x: Array, axes: Axes, dtype: tp.Optional[Dtype], axis_name: tp.Optional[str] = None, axis_index_groups: tp.Any = None, use_mean: bool = True, use_fast_variance: bool = True, mask: tp.Optional[Array] = None, ) -> tuple[Array, Array]: """Computes mean and variance statistics. This implementation takes care of a few important details: - Computes in float32 precision for stability in half precision training. - If ``use_fast_variance`` is ``True``, mean and variance are computed using Var = E[|x|^2] - |E[x]|^2, instead of Var = E[|x - E[x]|^2]), in a single XLA fusion. - Clips negative variances to zero which can happen due to roundoff errors. This avoids downstream NaNs. - Supports averaging across a parallel axis and subgroups of a parallel axis with a single ``lax.pmean`` call to avoid latency. Arguments: x: Input array. axes: The axes in ``x`` to compute mean and variance statistics for. dtype: Optional dtype specifying the minimal precision. Statistics are always at least float32 for stability (default: dtype of x). axis_name: Optional name for the pmapped axis to compute mean over. Note, this is only used for pmap and shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives. axis_index_groups: Optional groups of indices within that named axis. use_mean: If true, calculate the mean from the input and use it when computing the variance. If false, set the mean to zero and compute the variance without subtracting the mean. use_fast_variance: If true, use a faster, but less numerically stable, calculation for the variance. mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating the positions for which the mean and variance should be computed. Returns: A pair ``(mean, var)``. """ if dtype is None: dtype = jnp.result_type(x) # promote x to at least float32, this avoids half precision computation # but preserves double or complex floating points dtype = jnp.promote_types(dtype, jnp.float32) x = jnp.asarray(x, dtype) axes = _canonicalize_axes(x.ndim, axes) def maybe_distributed_mean(*xs, mask=None): mus = tuple(x.mean(axes, where=mask) for x in xs) if axis_name is None: return mus if len(xs) > 1 else mus[0] else: # In the distributed case we stack multiple arrays to speed comms. if len(xs) > 1: reduced_mus = lax.pmean( jnp.stack(mus, axis=0), axis_name, axis_index_groups=axis_index_groups, ) return tuple(reduced_mus[i] for i in range(len(xs))) else: return lax.pmean(mus[0], axis_name, axis_index_groups=axis_index_groups) if use_mean: if use_fast_variance: mu, mu2 = maybe_distributed_mean(x, _abs_sq(x), mask=mask) # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due # to floating point round-off errors. var = jnp.maximum(0.0, mu2 - _abs_sq(mu)) else: mu = maybe_distributed_mean(x, mask=mask) var = maybe_distributed_mean( _abs_sq(x - jnp.expand_dims(mu, axes)), mask=mask ) else: var = maybe_distributed_mean(_abs_sq(x), mask=mask) mu = jnp.zeros_like(var) return mu, var def _normalize( x: Array, mean: Array, var: Array, scale: tp.Optional[Array], bias: tp.Optional[Array], reduction_axes: Axes, feature_axes: Axes, dtype: tp.Optional[Dtype], epsilon: float, ): """ "Normalizes the input of a normalization layer and optionally applies a learned scale and bias. Arguments: x: The input. mean: Mean to use for normalization. var: Variance to use for normalization. reduction_axes: The axes in ``x`` to reduce. feature_axes: Axes containing features. A separate bias and scale is learned for each specified feature. dtype: The dtype of the result (default: infer from input and params). epsilon: Normalization epsilon. Returns: The normalized input. """ reduction_axes = _canonicalize_axes(x.ndim, reduction_axes) feature_axes = _canonicalize_axes(x.ndim, feature_axes) stats_shape = list(x.shape) for axis in reduction_axes: stats_shape[axis] = 1 mean = mean.reshape(stats_shape) var = var.reshape(stats_shape) feature_shape = [1] * x.ndim for ax in feature_axes: feature_shape[ax] = x.shape[ax] y = x - mean mul = lax.rsqrt(var + epsilon) args = [x] if scale is not None: scale = scale.reshape(feature_shape) mul *= scale args.append(scale) y *= mul if bias is not None: bias = bias.reshape(feature_shape) y += bias args.append(bias) dtype = dtypes.canonicalize_dtype(*args, dtype=dtype) return jnp.asarray(y, dtype) def _l2_normalize(x, axis=None, eps=1e-12): """Normalizes along dimension `axis` using an L2 norm. This specialized function exists for numerical stability reasons. Args: x: An input ndarray. axis: Dimension along which to normalize, e.g. `1` to separately normalize vectors in a batch. Passing `None` views `t` as a flattened vector when calculating the norm (equivalent to Frobenius norm). eps: Epsilon to avoid dividing by zero. Returns: An array of the same shape as 'x' L2-normalized along 'axis'. """ return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps) class BatchNorm(Module): """BatchNorm Module. To calculate the batch norm on the input and update the batch statistics, call the :func:`train` method (or pass in ``use_running_average=False`` in the constructor or during call time). To use the stored batch statistics' running average, call the :func:`eval` method (or pass in ``use_running_average=True`` in the constructor or during call time). Example usage:: >>> from flax import nnx >>> import jax, jax.numpy as jnp >>> x = jax.random.normal(jax.random.key(0), (5, 6)) >>> layer = nnx.BatchNorm(num_features=6, momentum=0.9, epsilon=1e-5, ... dtype=jnp.float32, rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(layer)) State({ 'bias': Param( value=(6,) ), 'mean': BatchStat( value=(6,) ), 'scale': Param( value=(6,) ), 'var': BatchStat( value=(6,) ) }) >>> # calculate batch norm on input and update batch statistics >>> layer.train() >>> y = layer(x) >>> batch_stats1 = nnx.clone(nnx.state(layer, nnx.BatchStat)) # keep a copy >>> y = layer(x) >>> batch_stats2 = nnx.state(layer, nnx.BatchStat) >>> assert (batch_stats1['mean'][...] != batch_stats2['mean'][...]).all() >>> assert (batch_stats1['var'][...] != batch_stats2['var'][...]).all() >>> # use stored batch statistics' running average >>> layer.eval() >>> y = layer(x) >>> batch_stats3 = nnx.state(layer, nnx.BatchStat) >>> assert (batch_stats2['mean'][...] == batch_stats3['mean'][...]).all() >>> assert (batch_stats2['var'][...] == batch_stats3['var'][...]).all() Args: num_features: the number of input features. use_running_average: if True, the stored batch statistics will be used instead of computing the batch statistics on the input. axis: the feature or non-batch axis of the input. momentum: decay rate for the exponential moving average of the batch statistics. epsilon: a small float added to variance to avoid dividing by zero. dtype: the dtype of the result (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). use_bias: if True, bias (beta) is added. use_scale: if True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer. bias_init: initializer for bias, by default, zero. scale_init: initializer for scale, by default, one. axis_name: the axis name used to combine batch statistics from multiple devices. See ``jax.pmap`` for a description of axis names (default: None). axis_index_groups: groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the examples on the first two and last two devices. See ``jax.lax.psum`` for more details. This argument is currently not supported for SPMD jit. use_fast_variance: If true, use a faster, but less numerically stable, calculation for the variance. promote_dtype: function to promote the dtype of all input array arguments (including Variables accessed through ``self``) to the desired dtype. The function should accept a tuple of ``(inputs, mean, var, scale, bias)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. rngs: rng key. bias_metadata: Optional metadata dictionary to set when initializing the bias. scale_metadata: Optional metadata dictionary to set when initializing the scale. """ def __init__( self, num_features: int, *, use_running_average: bool = False, axis: int = -1, momentum: float = 0.99, epsilon: float = 1e-5, dtype: tp.Optional[Dtype] = None, param_dtype: Dtype = jnp.float32, use_bias: bool = True, use_scale: bool = True, bias_init: Initializer = initializers.zeros_init(), scale_init: Initializer = initializers.ones_init(), axis_name: tp.Optional[str] = None, axis_index_groups: tp.Any = None, use_fast_variance: bool = True, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, rngs: rnglib.Rngs, bias_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), scale_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), ): feature_shape = (num_features,) self.mean = nnx.BatchStat(jnp.zeros(feature_shape, jnp.float32)) self.var = nnx.BatchStat(jnp.ones(feature_shape, jnp.float32)) self.scale: nnx.Param[jax.Array] | None if use_scale: key = rngs.params() self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype), **scale_metadata) else: self.scale = nnx.data(None) self.bias: nnx.Param[jax.Array] | None if use_bias: key = rngs.params() self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype), **bias_metadata) else: self.bias = nnx.data(None) self.num_features = num_features self.use_running_average = use_running_average self.axis = axis self.momentum = momentum self.epsilon = epsilon self.dtype = dtype self.param_dtype = param_dtype self.use_bias = use_bias self.use_scale = use_scale self.axis_name = axis_name self.axis_index_groups = axis_index_groups self.use_fast_variance = use_fast_variance self.promote_dtype = promote_dtype def __call__( self, x, use_running_average: tp.Optional[bool] = None, *, mask: tp.Optional[jax.Array] = None, ): """Normalizes the input using batch statistics. Args: x: the input to be normalized. use_running_average: if true, the stored batch statistics will be used instead of computing the batch statistics on the input. The ``use_running_average`` flag passed into the call method will take precedence over the ``use_running_average`` flag passed into the constructor. Returns: Normalized inputs (the same shape as inputs). """ use_running_average = first_from( use_running_average, self.use_running_average, error_msg="""No `use_running_average` argument was provided to BatchNorm as either a __call__ argument, class attribute, or nnx.flag.""", ) feature_axes = _canonicalize_axes(x.ndim, self.axis) reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes) # Promote dtypes for input and all Variables scale = self.scale[...] if self.scale is not None else None bias = self.bias[...] if self.bias is not None else None x, mean, var, scale, bias = self.promote_dtype( (x, self.mean[...], self.var[...], scale, bias), dtype=self.dtype ) if not use_running_average: mean, var = _compute_stats( x, reduction_axes, dtype=self.dtype, axis_name=self.axis_name, axis_index_groups=self.axis_index_groups, use_fast_variance=self.use_fast_variance, mask=mask, ) # stop_gradient only for flax_array_ref if self.mean._can_update or self.var._can_update: stop_gradient = jax.lax.stop_gradient else: stop_gradient = lambda x: x self.mean[...] = stop_gradient( self.momentum * self.mean[...] + (1 - self.momentum) * mean ) self.var[...] = stop_gradient( self.momentum * self.var[...] + (1 - self.momentum) * var ) return _normalize( x, mean, var, scale, bias, reduction_axes, feature_axes, self.dtype, self.epsilon, ) def set_view( self, use_running_average: bool | None = None, ): """Class method used by ``nnx.view``. Args: use_running_average: if True, the stored batch statistics will be used instead of computing the batch statistics on the input. """ if use_running_average is not None: self.use_running_average = use_running_average class LayerNorm(Module): """Layer normalization (https://arxiv.org/abs/1607.06450). LayerNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that maintains the mean activation within each example close to 0 and the activation standard deviation close to 1. Example usage:: >>> from flax import nnx >>> import jax >>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6)) >>> layer = nnx.LayerNorm(num_features=6, rngs=nnx.Rngs(0)) >>> nnx.state(layer) State({ 'bias': Param( # 6 (24 B) value=Array([0., 0., 0., 0., 0., 0.], dtype=float32) ), 'scale': Param( # 6 (24 B) value=Array([1., 1., 1., 1., 1., 1.], dtype=float32) ) }) >>> y = layer(x) Args: num_features: the number of input features. epsilon: A small float added to variance to avoid dividing by zero. dtype: the dtype of the result (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). use_bias: If True, bias (beta) is added. use_scale: If True, multiply by scale (gamma). When the next layer is linear (also e.g. nnx.relu), this can be disabled since the scaling will be done by the next layer. bias_init: Initializer for bias, by default, zero. scale_init: Initializer for scale, by default, one. reduction_axes: Axes for computing normalization statistics. feature_axes: Feature axes for learned bias and scaling. axis_name: the axis name used to combine batch statistics from multiple devices. See ``jax.pmap`` for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap. axis_index_groups: groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the examples on the first two and last two devices. See ``jax.lax.psum`` for more details. This argument is currently not supported for SPMD jit. use_fast_variance: If true, use a faster, but less numerically stable, calculation for the variance. promote_dtype: function to promote the dtype of all input array arguments (including Variables accessed through ``self``) to the desired dtype. The function should accept a tuple of ``(inputs, scale, bias)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. rngs: rng key. bias_metadata: Optional metadata dictionary to set when initializing the bias. scale_metadata: Optional metadata dictionary to set when initializing the scale. """ def __init__( self, num_features: int, *, epsilon: float = 1e-6, dtype: tp.Optional[Dtype] = None, param_dtype: Dtype = jnp.float32, use_bias: bool = True, use_scale: bool = True, bias_init: Initializer = initializers.zeros_init(), scale_init: Initializer = initializers.ones_init(), reduction_axes: Axes = -1, feature_axes: Axes = -1, axis_name: tp.Optional[str] = None, axis_index_groups: tp.Any = None, use_fast_variance: bool = True, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, rngs: rnglib.Rngs, bias_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), scale_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), ): feature_shape = (num_features,) self.scale: nnx.Param[jax.Array] | None if use_scale: key = rngs.params() self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype), **scale_metadata) else: self.scale = nnx.data(None) self.bias: nnx.Param[jax.Array] | None if use_bias: key = rngs.params() self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype), **bias_metadata) else: self.bias = nnx.data(None) self.num_features = num_features self.epsilon = epsilon self.dtype = dtype self.param_dtype = param_dtype self.use_bias = use_bias self.use_scale = use_scale self.reduction_axes = reduction_axes self.feature_axes = feature_axes self.axis_name = axis_name self.axis_index_groups = axis_index_groups self.use_fast_variance = use_fast_variance self.promote_dtype = promote_dtype def __call__(self, x, *, mask: tp.Optional[jax.Array] = None): """Applies layer normalization on the input. Args: x: the inputs Returns: Normalized inputs (the same shape as inputs). """ # Promote dtypes for input and all Variables scale = self.scale[...] if self.scale else None bias = self.bias[...] if self.bias else None x, scale, bias = self.promote_dtype( (x, scale, bias), dtype=self.dtype ) mean, var = _compute_stats( x, self.reduction_axes, self.dtype, self.axis_name, self.axis_index_groups, use_fast_variance=self.use_fast_variance, mask=mask, ) return _normalize( x, mean, var, scale, bias, self.reduction_axes, self.feature_axes, self.dtype, self.epsilon, ) class RMSNorm(Module): """RMS Layer normalization (https://arxiv.org/abs/1910.07467). RMSNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. Unlike LayerNorm which re-centers the mean to be 0 and normalizes by the standard deviation of the activations, RMSNorm does not re-center at all and instead normalizes by the root mean square of the activations. Example usage:: >>> from flax import nnx >>> import jax >>> x = jax.random.normal(jax.random.key(0), (5, 6)) >>> layer = nnx.RMSNorm(num_features=6, rngs=nnx.Rngs(0)) >>> nnx.state(layer) State({ 'scale': Param( # 6 (24 B) value=Array([1., 1., 1., 1., 1., 1.], dtype=float32) ) }) >>> y = layer(x) Args: num_features: the number of input features. epsilon: A small float added to variance to avoid dividing by zero. dtype: the dtype of the result (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). use_scale: If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer. scale_init: Initializer for scale, by default, one. reduction_axes: Axes for computing normalization statistics. feature_axes: Feature axes for learned bias and scaling. axis_name: the axis name used to combine batch statistics from multiple devices. See ``jax.pmap`` for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap. axis_index_groups: groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the examples on the first two and last two devices. See ``jax.lax.psum`` for more details. This argument is currently not supported for SPMD jit. use_fast_variance: If true, use a faster, but less numerically stable, calculation for the variance. promote_dtype: function to promote the dtype of all input array arguments (including Variables accessed through ``self``) to the desired dtype. The function should accept a tuple of ``(inputs, scale)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. rngs: rng key. scale_metadata: Optional metadata dictionary to set when initializing the scale. """ def __init__( self, num_features: int, *, epsilon: float = 1e-6, dtype: tp.Optional[Dtype] = None, param_dtype: Dtype = jnp.float32, use_scale: bool = True, scale_init: Initializer = initializers.ones, reduction_axes: Axes = -1, feature_axes: Axes = -1, axis_name: tp.Optional[str] = None, axis_index_groups: tp.Any = None, use_fast_variance: bool = True, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, rngs: rnglib.Rngs, scale_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), ): feature_shape = (num_features,) self.scale: nnx.Param[jax.Array] | None if use_scale: key = rngs.params() self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype), **scale_metadata) else: self.scale = nnx.data(None) self.num_features = num_features self.epsilon = epsilon self.dtype = dtype self.param_dtype = param_dtype self.use_scale = use_scale self.reduction_axes = reduction_axes self.feature_axes = feature_axes self.axis_name = axis_name self.axis_index_groups = axis_index_groups self.use_fast_variance = use_fast_variance self.promote_dtype = promote_dtype def __call__(self, x, mask: tp.Optional[jax.Array] = None): """Applies layer normalization on the input. Args: x: the inputs Returns: Normalized inputs (the same shape as inputs). """ # Promote dtypes for input and all Variables scale = self.scale[...] if self.scale else None x, scale = self.promote_dtype( (x, scale), dtype=self.dtype ) mean, var = _compute_stats( x, self.reduction_axes, self.dtype, self.axis_name, self.axis_index_groups, use_mean=False, use_fast_variance=self.use_fast_variance, mask=mask, ) return _normalize( x, mean, var, scale, None, self.reduction_axes, self.feature_axes, self.dtype, self.epsilon, ) class GroupNorm(Module): """Group normalization (arxiv.org/abs/1803.08494). This op is similar to batch normalization, but statistics are shared across equally-sized groups of channels and not shared across batch dimension. Thus, group normalization does not depend on the batch composition and does not require maintaining internal state for storing statistics. The user should either specify the total number of channel groups or the number of channels per group. .. note:: LayerNorm is a special case of GroupNorm where ``num_groups=1``. Example usage:: >>> from flax import nnx >>> import jax >>> import numpy as np ... >>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6)) >>> layer = nnx.GroupNorm(num_features=6, num_groups=3, rngs=nnx.Rngs(0)) >>> nnx.state(layer) State({ 'bias': Param( # 6 (24 B) value=Array([0., 0., 0., 0., 0., 0.], dtype=float32) ), 'scale': Param( # 6 (24 B) value=Array([1., 1., 1., 1., 1., 1.], dtype=float32) ) }) >>> y = layer(x) ... >>> y = nnx.GroupNorm(num_features=6, num_groups=1, rngs=nnx.Rngs(0))(x) >>> y2 = nnx.LayerNorm(num_features=6, reduction_axes=(1, 2, 3), rngs=nnx.Rngs(0))(x) >>> np.testing.assert_allclose(y, y2) Args: num_features: the number of input features/channels. num_groups: the total number of channel groups. The default value of 32 is proposed by the original group normalization paper. group_size: the number of channels in a group. epsilon: A small float added to variance to avoid dividing by zero. dtype: the dtype of the result (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). use_bias: If True, bias (beta) is added. use_scale: If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer. bias_init: Initializer for bias, by default, zero. scale_init: Initializer for scale, by default, one. reduction_axes: List of axes used for computing normalization statistics. This list must include the final dimension, which is assumed to be the feature axis. Furthermore, if the input used at call time has additional leading axes compared to the data used for initialisation, for example due to batching, then the reduction axes need to be defined explicitly. axis_name: the axis name used to combine batch statistics from multiple devices. See ``jax.pmap`` for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap or shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives. axis_index_groups: groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the examples on the first two and last two devices. See ``jax.lax.psum`` for more details. This argument is currently not supported for SPMD jit. use_fast_variance: If true, use a faster, but less numerically stable, calculation for the variance. promote_dtype: function to promote the dtype of all input array arguments (including Variables accessed through ``self``) to the desired dtype. The function should accept a tuple of ``(inputs, scale, bias)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. rngs: rng key. bias_metadata: Optional metadata dictionary to set when initializing the bias. scale_metadata: Optional metadata dictionary to set when initializing the scale. """ def __init__( self, num_features: int, num_groups: tp.Optional[int] = 32, group_size: tp.Optional[int] = None, *, epsilon: float = 1e-6, dtype: tp.Optional[Dtype] = None, param_dtype: Dtype = jnp.float32, use_bias: bool = True, use_scale: bool = True, bias_init: Initializer = initializers.zeros_init(), scale_init: Initializer = initializers.ones_init(), reduction_axes: tp.Optional[Axes] = None, axis_name: tp.Optional[str] = None, axis_index_groups: tp.Any = None, use_fast_variance: bool = True, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, rngs: rnglib.Rngs, bias_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), scale_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), ): self.feature_axis = -1 if (num_groups is None and group_size is None) or ( num_groups is not None and group_size is not None ): raise ValueError( 'Either `num_groups` or `group_size` should be ' 'specified. If `group_size` is to be specified, ' 'pass `num_groups=None` as argument to override ' 'the default `num_groups` value of 32.' ) if group_size is not None: if num_features % group_size != 0: raise ValueError( 'Number of features ({}) is not multiple of the ' 'group size ({}).'.format(num_features, group_size) ) self.num_groups = num_features // group_size self.group_size = group_size else: if not isinstance(num_groups, int) or num_groups <= 0 or ( num_features % num_groups != 0 ): raise ValueError( 'Number of groups ({}) does not divide the number' ' of channels ({}).'.format(num_groups, num_features) ) self.num_groups = num_groups self.group_size = num_features // num_groups feature_shape = (num_features,) self.scale: nnx.Param[jax.Array] | None if use_scale: key = rngs.params() self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype), **scale_metadata) else: self.scale = nnx.data(None) self.bias: nnx.Param[jax.Array] | None if use_bias: key = rngs.params() self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype), **bias_metadata) else: self.bias = nnx.data(None) self.epsilon = epsilon self.dtype = dtype self.param_dtype = param_dtype self.use_bias = use_bias self.use_scale = use_scale self.reduction_axes = reduction_axes self.axis_name = axis_name self.axis_index_groups = axis_index_groups self.use_fast_variance = use_fast_variance self.promote_dtype = promote_dtype def __call__(self, x, *, mask: tp.Optional[jax.Array] = None): """Applies group normalization to the input (arxiv.org/abs/1803.08494). Args: x: the input of shape ``...self.num_features`` where ``self.num_features`` is a channels dimension and ``...`` represents an arbitrary number of extra dimensions that can be used to accumulate statistics over. If no reduction axes have been specified then all additional dimensions ``...`` will be used to accumulate statistics apart from the leading dimension which is assumed to represent the batch. mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating the positions for which the mean and variance should be computed. Returns: Normalized inputs (the same shape as inputs). """ if self.reduction_axes is not None: reduction_axes = self.reduction_axes else: reduction_axes = list(range(1, x.ndim - 1)) + [-1] reduction_axes = _canonicalize_axes(x.ndim, reduction_axes) group_shape = x.shape[:-1] + (self.num_groups, self.group_size) if mask is not None: mask = mask.reshape(mask.shape[:-1] + (self.num_groups, self.group_size)) # Promote dtypes for input and all Variables scale = self.scale[...] if self.scale else None bias = self.bias[...] if self.bias else None x, scale, bias = self.promote_dtype( (x, scale, bias), dtype=self.dtype ) mean, var = _compute_stats( x.reshape(group_shape), list(reduction_axes[:-1]) + [-1], self.dtype, self.axis_name, self.axis_index_groups, use_fast_variance=self.use_fast_variance, mask=mask, ) mean = jnp.repeat(mean, self.group_size, axis=1) var = jnp.repeat(var, self.group_size, axis=1) return _normalize( x, mean, var, scale, bias, reduction_axes[:-1], (self.feature_axis,), self.dtype, self.epsilon, ) class WeightNorm(nnx.Module): """L2 weight normalization (https://arxiv.org/abs/1602.07868). Weight normalization normalizes the weight params so that the l2-norm of the matrix is equal to 1. This is implemented as a layer wrapper where each wrapped layer will have its params l2-normalized before computing its ``__call__`` output. Example usage:: >>> import jax >>> import numpy as np >>> from flax import nnx >>> class Foo(nnx.Module): ... def __init__(self, rngs: nnx.Rngs): ... self.normed_linear = nnx.WeightNorm( ... nnx.Linear(8, 4, rngs=rngs), ... variable_filter=nnx.PathContains('kernel'), ... rngs=rngs, ... ) ... ... def __call__(self, x: jax.Array) -> jax.Array: ... return self.normed_linear(x) >>> rng = jax.random.key(42) >>> model = Foo(rngs=nnx.Rngs(rng)) >>> x = jax.random.normal(rng, (5, 8)) >>> y = model(x) >>> y.shape (5, 4) >>> w = model.normed_linear.layer_instance.kernel[...] >>> col_norms = np.linalg.norm(np.array(w), axis=0) >>> np.testing.assert_allclose(col_norms, np.ones(4)) Args: layer_instance: The layer instance to wrap. feature_axes: The axes to normalize. use_scale: Whether to use a scale parameter. scale_init: The initializer for the scale parameter, by default ones. epsilon: The epsilon value for the normalization, by default 1e-12. dtype: The dtype of the result, by default infer from input and params. param_dtype: The dtype of the parameters, by default float32. variable_filter: The variable filter, by default ``nnx.PathContains('kernel')``. promote_dtype: function to promote the dtype of all input array arguments (including Variables accessed through ``self``) to the desired dtype. This is used internally by WeightNorm when normalizing weights. rngs: The rng key. """ def __init__( self, layer_instance: nnx.Module, *, feature_axes: Axes | None = -1, use_scale: bool = True, scale_init: Initializer = initializers.ones, epsilon: float = 1e-12, dtype: tp.Optional[Dtype] = None, param_dtype: Dtype = jnp.float32, variable_filter: nnx.filterlib.Filter = nnx.PathContains('kernel'), promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, rngs: rnglib.Rngs, ): self.layer_instance = layer_instance self.feature_axes = () if feature_axes is None else feature_axes self.use_scale = use_scale self.scale_init = scale_init self.epsilon = epsilon self.dtype = dtype self.param_dtype = param_dtype self.variable_filter = nnx.filterlib.to_predicate(variable_filter) self.promote_dtype = promote_dtype self.scales : tp.Optional[dict] = None if use_scale: state = nnx.state(self.layer_instance, nnx.Param) def init_scales(param): feature_axes = _canonicalize_axes(param.ndim, self.feature_axes) scale_shape = tuple(param.shape[ax] for ax in feature_axes) return scale_init(rngs['params'], scale_shape) self.scales = nnx.data({ path: init_scales(param) for path, param in nnx.to_flat_state(state) if self.variable_filter(path, param)}) def _weightnorm_inplace(self, path, param): if not self.variable_filter(path, param): return if self.feature_axes is None: feature_axes = () reduction_axes = tuple(range(param.ndim)) else: feature_axes = _canonicalize_axes(param.ndim, self.feature_axes) reduction_axes = tuple( i for i in range(param.ndim) if i not in feature_axes ) value_bar = _l2_normalize(param, axis=reduction_axes, eps=self.epsilon) if self.use_scale: if path not in self.scales: raise RuntimeError( f'Could not find the scale corresponding to the param {path} ' 'in scales dict. Parameters of the layer_instance should not change!' ) scale_value = self.scales[path] if len(feature_axes) < param.ndim: broadcast_shape = [1] * param.ndim for ax in feature_axes: broadcast_shape[ax] = param.shape[ax] scale_value = scale_value.reshape(broadcast_shape) value_bar = value_bar * scale_value cast_args = [param] if self.use_scale: cast_args.append(scale_value) final_dtype = dtypes.canonicalize_dtype(*cast_args, dtype=self.dtype) param.set_value(jnp.asarray(value_bar, final_dtype)) def __call__(self, x: Array, *args, **kwargs) -> Array: """Compute the l2-norm of the weights in ``self.layer_instance`` and normalize the weights using this value before computing the ``__call__`` output. Args: *args: positional arguments to be passed into the call method of the underlying layer instance in ``self.layer_instance``. **kwargs: keyword arguments to be passed into the call method of the underlying layer instance in ``self.layer_instance``. Returns: Output of the layer using l2-normalized weights. """ state = nnx.state(self.layer_instance) for path, param in nnx.to_flat_state(state): self._weightnorm_inplace(path, param) return self.layer_instance(x, *args, **kwargs) # type: ignore class InstanceNorm(Module): """Instance normalization (https://arxiv.org/abs/1607.08022v3). InstanceNorm normalizes the activations of the layer for each channel (rather than across all channels like Layer Normalization), and for each given example in a batch independently (rather than across an entire batch like Batch Normalization). i.e. applies a transformation that maintains the mean activation within each channel within each example close to 0 and the activation standard deviation close to 1. .. note:: This normalization operation is identical to LayerNorm and GroupNorm; the difference is simply which axes are reduced and the shape of the feature axes (i.e. the shape of the learnable scale and bias parameters). Example usage:: >>> from flax import nnx >>> import jax >>> import numpy as np >>> # dimensions: (batch, height, width, channel) >>> x = jax.random.normal(jax.random.key(0), (2, 3, 4, 5)) >>> layer = nnx.InstanceNorm(5, rngs=nnx.Rngs(0)) >>> nnx.state(layer, nnx.Param) State({ 'bias': Param( # 5 (20 B) value=Array([0., 0., 0., 0., 0.], dtype=float32) ), 'scale': Param( # 5 (20 B) value=Array([1., 1., 1., 1., 1.], dtype=float32) ) }) >>> y = layer(x) >>> # having a channel_axis of -1 in InstanceNorm is identical to reducing all non-batch, >>> # non-channel axes and using the feature_axes as the feature_axes in LayerNorm >>> y2 = nnx.LayerNorm(5, reduction_axes=[1, 2], feature_axes=-1, rngs=nnx.Rngs(0))(x) >>> np.testing.assert_allclose(y, y2, atol=1e-7) >>> y3 = nnx.GroupNorm(5, num_groups=x.shape[-1], rngs=nnx.Rngs(0))(x) >>> np.testing.assert_allclose(y, y3, atol=1e-7) Args: num_features: the number of input features/channels. epsilon: A small float added to variance to avoid dividing by zero. dtype: the dtype of the result (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). use_bias: If True, bias (beta) is added. use_scale: If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer. bias_init: Initializer for bias, by default, zero. scale_init: Initializer for scale, by default, one. feature_axes: Axes for features. The learned bias and scaling parameters will be in the shape defined by the feature axes. All other axes except the batch axes (which is assumed to be the leading axis) will be reduced. axis_name: the axis name used to combine batch statistics from multiple devices. See ``jax.pmap`` for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap or shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives. axis_index_groups: groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the examples on the first two and last two devices. See ``jax.lax.psum`` for more details. This argument is currently not supported for SPMD jit. use_fast_variance: If true, use a faster, but less numerically stable, calculation for the variance. promote_dtype: function to promote the dtype of all input array arguments (including Variables accessed through ``self``) to the desired dtype. The function should accept a tuple of ``(inputs, scale, bias)`` and a ``dtype`` keyword argument, and return a tuple of arrays with the promoted dtype. rngs: The rng key. bias_metadata: Optional metadata dictionary to set when initializing the bias. scale_metadata: Optional metadata dictionary to set when initializing the scale. """ def __init__( self, num_features: int, *, epsilon: float = 1e-6, dtype: tp.Optional[Dtype] = None, param_dtype: Dtype = jnp.float32, use_bias: bool = True, use_scale: bool = True, bias_init: Initializer = initializers.zeros, scale_init: Initializer = initializers.ones, feature_axes: Axes = -1, axis_name: tp.Optional[str] = None, axis_index_groups: tp.Any = None, use_fast_variance: bool = True, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, rngs: rnglib.Rngs, bias_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), scale_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), ): feature_shape = (num_features,) self.scale: nnx.Param[jax.Array] | None if use_scale: key = rngs.params() self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype), **scale_metadata) else: self.scale = None self.bias: nnx.Param[jax.Array] | None if use_bias: key = rngs.params() self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype), **bias_metadata) else: self.bias = None self.num_features = num_features self.epsilon = epsilon self.dtype = dtype self.param_dtype = param_dtype self.use_bias = use_bias self.use_scale = use_scale self.feature_axes = feature_axes self.axis_name = axis_name self.axis_index_groups = axis_index_groups self.use_fast_variance = use_fast_variance self.promote_dtype = promote_dtype def __call__(self, x, *, mask: tp.Optional[jax.Array] = None): """Applies instance normalization on the input. Args: x: the inputs mask: Binary array of shape broadcastable to ``inputs`` array, indicating the positions for which the mean and variance should be computed. Returns: Normalized inputs (the same shape as inputs). """ feature_axes = _canonicalize_axes(x.ndim, self.feature_axes) if 0 in feature_axes: raise ValueError('The channel axes cannot include the leading dimension ' 'as this is assumed to be the batch axis.') reduction_axes = [i for i in range(1, x.ndim) if i not in feature_axes] # Promote dtypes for input and all Variables scale = self.scale[...] if self.scale else None bias = self.bias[...] if self.bias else None x, scale, bias = self.promote_dtype( (x, scale, bias), dtype=self.dtype ) mean, var = _compute_stats( x, reduction_axes, self.dtype, self.axis_name, self.axis_index_groups, use_fast_variance=self.use_fast_variance, mask=mask, ) return _normalize( x, mean, var, scale, bias, reduction_axes, feature_axes, self.dtype, self.epsilon, ) class SpectralNorm(Module): """Spectral normalization. See: - https://arxiv.org/abs/1802.05957 - https://arxiv.org/abs/1805.08318 - https://arxiv.org/abs/1809.11096 Spectral normalization normalizes the weight params so that the spectral norm of the matrix is equal to 1. This is implemented as a layer wrapper where each wrapped layer will have its params spectral normalized before computing its ``__call__`` output. .. note:: The initialized variables dict will contain, in addition to a 'params' collection, a separate 'batch_stats' collection that will contain a ``u`` vector and ``sigma`` value, which are intermediate values used when performing spectral normalization. During training, we pass in ``update_stats=True`` so that ``u`` and ``sigma`` are updated with the most recently computed values using power iteration. This will help the power iteration method approximate the true singular value more accurately over time. During eval, we pass in ``update_stats=False`` to ensure we get deterministic behavior from the model. Example usage:: >>> from flax import nnx >>> import jax >>> rngs = nnx.Rngs(0) >>> x = jax.random.normal(jax.random.key(0), (3, 4)) >>> layer = nnx.SpectralNorm(nnx.Linear(4, 5, rngs=rngs), rngs=rngs) >>> jax.tree.map(jax.numpy.shape, nnx.state(layer, nnx.Param)) State({ 'layer_instance': { 'bias': Param( value=(5,) ), 'kernel': Param( value=(4, 5) ) } }) >>> y = layer(x, update_stats=True) Args: layer_instance: Module instance that is wrapped with SpectralNorm n_steps: How many steps of power iteration to perform to approximate the singular value of the weight params. epsilon: A small float added to l2-normalization to avoid dividing by zero. dtype: the dtype of the result (default: infer from input and params). param_dtype: the dtype passed to parameter initializers (default: float32). error_on_non_matrix: Spectral normalization is only defined on matrices. By default, this module will return scalars unchanged and flatten higher-order tensors in their leading dimensions. Setting this flag to True will instead throw an error if a weight tensor with dimension greater than 2 is used by the layer. update_stats: if True, the stored batch statistics will be used instead of computing the batch statistics on the input. rngs: rng key. """ def __init__( self, layer_instance: Module, *, n_steps: int = 1, epsilon: float = 1e-12, dtype: tp.Optional[Dtype] = None, param_dtype: Dtype = jnp.float32, error_on_non_matrix: bool = False, update_stats: bool = True, rngs: rnglib.Rngs, ): self.layer_instance = layer_instance self.n_steps = n_steps self.epsilon = epsilon self.dtype = dtype self.param_dtype = param_dtype self.error_on_non_matrix = error_on_non_matrix # We define here self.use_running_average attribute to make # .train() and .eval() work. These methods internally change # self.use_running_average to False in train mode and # to True in eval mode # update_stats flag has the opposite logic. self.use_running_average = not update_stats # Initialize batch stat variables: state = nnx.state(self.layer_instance, nnx.Param) def init_batch_stats(path, param): if param.ndim <= 1 or self.n_steps < 1: return None elif param.ndim > 2: if self.error_on_non_matrix: raise ValueError( f'Layer instance parameter is {param.ndim}D but error_on_non_matrix is True' ) else: param = jnp.reshape(param, (-1, param.shape[-1])) path_u = path + ("u", ) path_sigma = path + ("sigma", ) key = rngs.params() return [ ( path_u, nnx.BatchStat( initializers.normal()(key, (1, param.shape[-1]), self.param_dtype)) ), ( path_sigma, nnx.BatchStat(initializers.ones(key, (), self.param_dtype)), ) ] batch_stats: dict[tuple, nnx.BatchStat] = {} for path, param in nnx.to_flat_state(state): batch_stats_per_param = init_batch_stats(path, param) if batch_stats_per_param is None: continue for new_path, bstat in batch_stats_per_param: batch_stats[new_path] = bstat self.batch_stats = nnx.data(batch_stats) def __call__( self, x, update_stats: tp.Optional[bool] = None, ): """Compute the largest singular value of the weights in ``self.layer_instance`` using power iteration and normalize the weights using this value before computing the ``__call__`` output. Args: x: the input array of the nested layer update_stats: if True, update the internal ``u`` vector and ``sigma`` value after computing their updated values using power iteration. This will help the power iteration method approximate the true singular value more accurately over time. Returns: Output of the layer using spectral normalized weights. """ update_stats = first_from( update_stats, not self.use_running_average, error_msg="""No `update_stats` argument was provided to SpectralNorm as either a __call__ argument or __init__ argument.""", ) state = nnx.state(self.layer_instance, nnx.Param) for path, param in nnx.to_flat_state(state): self._spectral_normalize_inplace(path, param, update_stats=update_stats) return self.layer_instance(x) def _spectral_normalize_inplace(self, path, orig_param, update_stats): param = orig_param param_shape = param.shape if param.ndim <= 1 or self.n_steps < 1: return elif param.ndim > 2: if self.error_on_non_matrix: raise ValueError( f'Layer instance parameter is {param.ndim}D but error_on_non_matrix is True' ) else: param = jnp.reshape(param, (-1, param.shape[-1])) path_u = path + ("u", ) path_sigma = path + ("sigma", ) if path_u not in self.batch_stats: raise RuntimeError( f"Could not find the path for u batch stat corresponding to the param {path} " "in the batch stats dict. Parameters of the layer_instance should not change!" ) if path_sigma not in self.batch_stats: raise RuntimeError( f"Could not find the path for sigma batch stat corresponding to the param {path} " "in the batch stats dict. Parameters of the layer_instance should not change!" ) u = self.batch_stats[path_u][...] for _ in range(self.n_steps): v = _l2_normalize(jnp.matmul(u, param.T), eps=self.epsilon) u = _l2_normalize(jnp.matmul(v, param), eps=self.epsilon) u = lax.stop_gradient(u) v = lax.stop_gradient(v) sigma = jnp.matmul(jnp.matmul(v, param), u.T)[0, 0] param = param / jnp.where(sigma != 0, sigma, 1) param = param.reshape(param_shape) if update_stats: self.batch_stats[path_u][...] = u self.batch_stats[path_sigma][...] = sigma dtype = dtypes.canonicalize_dtype(param, u, v, sigma, dtype=self.dtype) orig_param[...] = jnp.asarray(param, dtype) ================================================ FILE: flax/nnx/nn/recurrent.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """RNN modules for Flax.""" import warnings from typing import Any, TypeVar from collections.abc import Mapping from types import MappingProxyType from collections.abc import Mapping from collections.abc import Callable from functools import partial from typing_extensions import Protocol from absl import logging import jax import jax.numpy as jnp from flax import nnx from flax.nnx import filterlib, rnglib from flax.nnx.module import Module from flax.nnx.nn import initializers, dtypes from flax.nnx.nn.linear import Linear from flax.nnx.nn.activations import sigmoid from flax.nnx.nn.activations import tanh from flax.nnx.transforms import iteration from flax.typing import Dtype, Initializer, PromoteDtypeFn, Shape default_kernel_init = initializers.lecun_normal() default_bias_init = initializers.zeros_init() A = TypeVar("A") Array = jax.Array Output = Any Carry = Any class RNNCellBase(Module): """RNN cell base class.""" def initialize_carry( self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | rnglib.RngStream | None = None, carry_init: Initializer | None = None, ) -> Carry: """Initialize the RNN cell carry. Args: input_shape: a tuple providing the shape of the input to the cell. rngs: random number generator passed to the init_fn. carry_init: optional carry initializer. Returns: An initialized carry for the given RNN cell. """ raise NotImplementedError def __call__( self, carry: Carry, inputs: Array ) -> tuple[Carry, Array]: """Run the RNN cell. Args: carry: the hidden state of the RNN cell. inputs: an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions. Returns: A tuple with the new carry and the output. """ raise NotImplementedError @property def num_feature_axes(self) -> int: """Returns the number of feature axes of the RNN cell.""" raise NotImplementedError def modified_orthogonal(key: Array, shape: Shape, dtype: Dtype = jnp.float32) -> Array: """Modified orthogonal initializer for compatibility with half precision.""" initializer = initializers.orthogonal() return initializer(key, shape).astype(dtype) class LSTMCell(RNNCellBase): r"""LSTM cell. The mathematical definition of the cell is as follows .. math:: \begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array} where x is the input, h is the output of the previous time step, and c is the memory. """ def __init__( self, in_features: int, hidden_features: int, *, gate_fn: Callable[..., Any] = sigmoid, activation_fn: Callable[..., Any] = tanh, kernel_init: Initializer = default_kernel_init, recurrent_kernel_init: Initializer = modified_orthogonal, bias_init: Initializer = initializers.zeros_init(), dtype: Dtype | None = None, param_dtype: Dtype = jnp.float32, carry_init: Initializer | None = None, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, keep_rngs: bool = False, rngs: rnglib.Rngs, kernel_metadata: Mapping[str, Any] = MappingProxyType({}), recurrent_kernel_metadata: Mapping[str, Any] = MappingProxyType({}), bias_metadata: Mapping[str, Any] = MappingProxyType({}), ): self.in_features = in_features self.hidden_features = hidden_features self.gate_fn = gate_fn self.activation_fn = activation_fn self.dtype = dtype self.param_dtype = param_dtype self.promote_dtype = promote_dtype self.rngs: rnglib.RngStream | None if keep_rngs: self.rngs = rngs.carry.fork() else: self.rngs = nnx.data(None) # input and recurrent layers are summed so only one needs a bias. dense_i = partial( Linear, in_features=in_features, out_features=hidden_features, use_bias=False, kernel_init=kernel_init, dtype=self.dtype, param_dtype=self.param_dtype, promote_dtype=self.promote_dtype, rngs=rngs, kernel_metadata=kernel_metadata, ) dense_h = partial( Linear, in_features=hidden_features, out_features=hidden_features, use_bias=True, kernel_init=recurrent_kernel_init, bias_init=bias_init, dtype=self.dtype, param_dtype=self.param_dtype, promote_dtype=self.promote_dtype, rngs=rngs, kernel_metadata=recurrent_kernel_metadata, bias_metadata=bias_metadata, ) self.ii = dense_i() self.if_ = dense_i() self.ig = dense_i() self.io = dense_i() self.hi = dense_h() self.hf = dense_h() self.hg = dense_h() self.ho = dense_h() if carry_init: warnings.warn( "carry_init is provided in __init__. " "Please, use carry_init argument in `initialize_carry` method instead to initialize the carry. " "Otherwise, two instances with same configuration but different carry_init " "functions will have different graphdefs." ) self.carry_init = carry_init def __call__( self, carry: tuple[Array, Array], inputs: Array ) -> tuple[tuple[Array, Array], Array]: # type: ignore[override] r"""A long short-term memory (LSTM) cell. Args: carry: the hidden state of the LSTM cell, initialized using ``LSTMCell.initialize_carry``. inputs: an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions. Returns: A tuple with the new carry and the output. """ c, h = carry i = self.gate_fn(self.ii(inputs) + self.hi(h)) f = self.gate_fn(self.if_(inputs) + self.hf(h)) g = self.activation_fn(self.ig(inputs) + self.hg(h)) o = self.gate_fn(self.io(inputs) + self.ho(h)) new_c = f * c + i * g new_h = o * self.activation_fn(new_c) return (new_c, new_h), new_h def initialize_carry( self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | rnglib.RngStream | None = None, carry_init: Initializer | None = None, ) -> tuple[Array, Array]: # type: ignore[override] """Initialize the RNN cell carry. Args: rng: random number generator passed to the init_fn. input_shape: a tuple providing the shape of the input to the cell. Returns: An initialized carry for the given RNN cell. """ batch_dims = input_shape[:-1] if rngs is None: rngs = self.rngs if isinstance(rngs, rnglib.Rngs): rngs = rngs.carry if rngs is None: raise ValueError('RNGs must be provided to initialize the cell carry.') if self.carry_init is None and carry_init is None: carry_init = initializers.zeros_init() elif carry_init is None: carry_init = self.carry_init assert carry_init is not None # just to please mypy mem_shape = batch_dims + (self.hidden_features,) c = carry_init(rngs(), mem_shape, self.param_dtype) h = carry_init(rngs(), mem_shape, self.param_dtype) return (c, h) @property def num_feature_axes(self) -> int: return 1 class OptimizedLSTMCell(RNNCellBase): r"""More efficient LSTM Cell that concatenates state components before matmul. The parameters are compatible with ``LSTMCell``. Note that this cell is often faster than ``LSTMCell`` as long as the hidden size is roughly <= 2048 units. The mathematical definition of the cell is the same as ``LSTMCell`` and as follows: .. math:: \begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array} where x is the input, h is the output of the previous time step, and c is the memory. Args: gate_fn: activation function used for gates (default: sigmoid). activation_fn: activation function used for output and memory update (default: tanh). kernel_init: initializer function for the kernels that transform the input (default: lecun_normal). recurrent_kernel_init: initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()). bias_init: initializer for the bias parameters (default: initializers.zeros_init()). dtype: the dtype of the computation (default: infer from inputs and params). param_dtype: the dtype passed to parameter initializers (default: float32). keep_rngs: whether to store the input rngs as attribute (i.e. `self.rngs = rngs`) (default: True). If rngs is stored, we should split the module as `graphdef, params, nondiff = nnx.split(module, nnx.Param, ...)` where `nondiff` contains RNG object associated with stored `self.rngs`. rngs: rng key. kernel_metadata: Optional metadata dictionary to set when initializing the kernels that transform the input. recurrent_kernel_metadata: Optional metadata dictionary to set when initializing the kernels that transform the hidden state. bias_metadata: Optional metadata dictionary to set when initializing the bias of layers that transform the hidden state. """ def __init__( self, in_features: int, hidden_features: int, *, gate_fn: Callable[..., Any] = sigmoid, activation_fn: Callable[..., Any] = tanh, kernel_init: Initializer = default_kernel_init, recurrent_kernel_init: Initializer = initializers.orthogonal(), bias_init: Initializer = initializers.zeros_init(), dtype: Dtype | None = None, param_dtype: Dtype = jnp.float32, carry_init: Initializer | None = None, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, keep_rngs: bool = False, rngs: rnglib.Rngs, kernel_metadata: Mapping[str, Any] = MappingProxyType({}), recurrent_kernel_metadata: Mapping[str, Any] = MappingProxyType({}), bias_metadata: Mapping[str, Any] = MappingProxyType({}), ): self.in_features = in_features self.hidden_features = hidden_features self.gate_fn = gate_fn self.activation_fn = activation_fn self.dtype = dtype self.param_dtype = param_dtype self.promote_dtype = promote_dtype self.rngs: rnglib.RngStream | None if keep_rngs: self.rngs = rngs.carry.fork() else: self.rngs = nnx.data(None) # input and recurrent layers are summed so only one needs a bias. self.dense_i = Linear( in_features=in_features, out_features=4 * hidden_features, use_bias=False, kernel_init=kernel_init, dtype=self.dtype, param_dtype=self.param_dtype, promote_dtype=self.promote_dtype, rngs=rngs, kernel_metadata=kernel_metadata, ) self.dense_h = Linear( in_features=hidden_features, out_features=4 * hidden_features, use_bias=True, kernel_init=recurrent_kernel_init, bias_init=bias_init, dtype=self.dtype, param_dtype=self.param_dtype, promote_dtype=self.promote_dtype, rngs=rngs, kernel_metadata=recurrent_kernel_metadata, bias_metadata=bias_metadata, ) if carry_init: warnings.warn( "carry_init is provided in __init__. " "Please, use carry_init argument in `initialize_carry` method instead to initialize the carry. " "Otherwise, two instances with same configuration but different carry_init " "functions will have different graphdefs." ) self.carry_init = carry_init def __call__( self, carry: tuple[Array, Array], inputs: Array ) -> tuple[tuple[Array, Array], Array]: # type: ignore[override] r"""An optimized long short-term memory (LSTM) cell. Args: carry: the hidden state of the LSTM cell, initialized using ``LSTMCell.initialize_carry``. inputs: an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions. Returns: A tuple with the new carry and the output. """ c, h = carry # Compute combined transformations for inputs and hidden state y = self.dense_i(inputs) + self.dense_h(h) # Split the combined transformations into individual gates i, f, g, o = jnp.split(y, indices_or_sections=4, axis=-1) # Apply gate activations i = self.gate_fn(i) f = self.gate_fn(f) g = self.activation_fn(g) o = self.gate_fn(o) # Update cell state and hidden state new_c = f * c + i * g new_h = o * self.activation_fn(new_c) return (new_c, new_h), new_h def initialize_carry( self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | rnglib.RngStream | None = None, carry_init: Initializer | None = None, ) -> tuple[Array, Array]: # type: ignore[override] """Initialize the RNN cell carry. Args: rngs: random number generator passed to the init_fn. input_shape: a tuple providing the shape of the input to the cell. Returns: An initialized carry for the given RNN cell. """ batch_dims = input_shape[:-1] if rngs is None: rngs = self.rngs if isinstance(rngs, rnglib.Rngs): rngs = rngs.carry if rngs is None: raise ValueError('RNGs must be provided to initialize the cell carry.') mem_shape = batch_dims + (self.hidden_features,) if self.carry_init is None and carry_init is None: carry_init = initializers.zeros_init() elif carry_init is None: carry_init = self.carry_init assert carry_init is not None # just to please mypy c = carry_init(rngs(), mem_shape, self.param_dtype) h = carry_init(rngs(), mem_shape, self.param_dtype) return (c, h) @property def num_feature_axes(self) -> int: return 1 class SimpleCell(RNNCellBase): r"""Simple cell. The mathematical definition of the cell is as follows .. math:: \begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h) \end{array} where x is the input and h is the output of the previous time step. If `residual` is `True`, .. math:: \begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h + h) \end{array} """ def __init__( self, in_features: int, hidden_features: int, # not inferred from carry for now *, dtype: Dtype = jnp.float32, param_dtype: Dtype = jnp.float32, carry_init: Initializer | None = None, residual: bool = False, activation_fn: Callable[..., Any] = tanh, kernel_init: Initializer = initializers.lecun_normal(), recurrent_kernel_init: Initializer = initializers.orthogonal(), bias_init: Initializer = initializers.zeros_init(), promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, keep_rngs: bool = False, rngs: rnglib.Rngs, kernel_metadata: Mapping[str, Any] = MappingProxyType({}), recurrent_kernel_metadata: Mapping[str, Any] = MappingProxyType({}), bias_metadata: Mapping[str, Any] = MappingProxyType({}), ): self.in_features = in_features self.hidden_features = hidden_features self.dtype = dtype self.param_dtype = param_dtype self.residual = residual self.activation_fn = activation_fn self.promote_dtype = promote_dtype self.rngs: rnglib.RngStream | None if keep_rngs: self.rngs = rngs.carry.fork() else: self.rngs = nnx.data(None) # self.hidden_features = carry.shape[-1] # input and recurrent layers are summed so only one needs a bias. self.dense_h = Linear( in_features=self.hidden_features, out_features=self.hidden_features, use_bias=False, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=recurrent_kernel_init, promote_dtype=self.promote_dtype, rngs=rngs, kernel_metadata=recurrent_kernel_metadata, ) self.dense_i = Linear( in_features=self.in_features, out_features=self.hidden_features, use_bias=True, dtype=self.dtype, param_dtype=self.param_dtype, kernel_init=kernel_init, bias_init=bias_init, promote_dtype=self.promote_dtype, rngs=rngs, kernel_metadata=kernel_metadata, bias_metadata=bias_metadata, ) if carry_init: warnings.warn( "carry_init is provided in __init__. " "Please, use carry_init argument in `initialize_carry` method instead to initialize the carry. " "Otherwise, two instances with same configuration but different carry_init " "functions will have different graphdefs." ) self.carry_init = carry_init def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]: # type: ignore[override] new_carry = self.dense_i(inputs) + self.dense_h(carry) if self.residual: new_carry += carry new_carry = self.activation_fn(new_carry) return new_carry, new_carry def initialize_carry( self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | rnglib.RngStream | None = None, carry_init: Initializer | None = None, ) -> Array: # type: ignore[override] """Initialize the RNN cell carry. Args: rng: random number generator passed to the init_fn. input_shape: a tuple providing the shape of the input to the cell. Returns: An initialized carry for the given RNN cell. """ if rngs is None: rngs = self.rngs if isinstance(rngs, rnglib.Rngs): rngs = rngs.carry if rngs is None: raise ValueError('RNGs must be provided to initialize the cell carry.') batch_dims = input_shape[:-1] mem_shape = batch_dims + (self.hidden_features,) if self.carry_init is None and carry_init is None: carry_init = initializers.zeros_init() elif carry_init is None: carry_init = self.carry_init assert carry_init is not None # just to please mypy return carry_init(rngs(), mem_shape, self.param_dtype) @property def num_feature_axes(self) -> int: return 1 class GRUCell(RNNCellBase): r"""GRU cell. The mathematical definition of the cell is as follows .. math:: \begin{array}{ll} r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\ z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\ n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ h' = (1 - z) * n + z * h \\ \end{array} where x is the input and h is the output of the previous time step. Args: in_features: number of input features. hidden_features: number of output features. gate_fn: activation function used for gates (default: sigmoid). activation_fn: activation function used for output and memory update (default: tanh). kernel_init: initializer function for the kernels that transform the input (default: lecun_normal). recurrent_kernel_init: initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()). bias_init: initializer for the bias parameters (default: initializers.zeros_init()). dtype: the dtype of the computation (default: None). param_dtype: the dtype passed to parameter initializers (default: float32). keep_rngs: whether to store the input rngs as attribute (i.e. `self.rngs = rngs`) (default: True). If rngs is stored, we should split the module as `graphdef, params, nondiff = nnx.split(module, nnx.Param, ...)` where `nondiff` contains RNG object associated with stored `self.rngs`. rngs: rng key. kernel_metadata: Optional metadata dictionary to set when initializing the kernels that transform the input. recurrent_kernel_metadata: Optional metadata dictionary to set when initializing the kernels that transform the hidden state. bias_metadata: Optional metadata dictionary to set when initializing the bias of layers that transform the input. """ def __init__( self, in_features: int, hidden_features: int, *, gate_fn: Callable[..., Any] = sigmoid, activation_fn: Callable[..., Any] = tanh, kernel_init: Initializer = default_kernel_init, recurrent_kernel_init: Initializer = initializers.orthogonal(), bias_init: Initializer = initializers.zeros_init(), dtype: Dtype | None = None, param_dtype: Dtype = jnp.float32, carry_init: Initializer | None = None, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, keep_rngs: bool = False, rngs: rnglib.Rngs, kernel_metadata: Mapping[str, Any] = MappingProxyType({}), recurrent_kernel_metadata: Mapping[str, Any] = MappingProxyType({}), bias_metadata: Mapping[str, Any] = MappingProxyType({}), ): self.in_features = in_features self.hidden_features = hidden_features self.gate_fn = gate_fn self.activation_fn = activation_fn self.dtype = dtype self.param_dtype = param_dtype self.promote_dtype = promote_dtype self.rngs: rnglib.RngStream | None if keep_rngs: self.rngs = rngs.carry.fork() else: self.rngs = nnx.data(None) # Combine input transformations into a single linear layer self.dense_i = Linear( in_features=in_features, out_features=3 * hidden_features, # r, z, n use_bias=True, kernel_init=kernel_init, bias_init=bias_init, dtype=self.dtype, param_dtype=self.param_dtype, promote_dtype=self.promote_dtype, rngs=rngs, kernel_metadata=kernel_metadata, bias_metadata=bias_metadata, ) self.dense_h = Linear( in_features=hidden_features, out_features=3 * hidden_features, # r, z, n use_bias=False, kernel_init=recurrent_kernel_init, dtype=self.dtype, param_dtype=self.param_dtype, promote_dtype=self.promote_dtype, rngs=rngs, kernel_metadata=recurrent_kernel_metadata, ) if carry_init: warnings.warn( "carry_init is provided in __init__. " "Please, use carry_init argument in `initialize_carry` method instead to initialize the carry. " "Otherwise, two instances with same configuration but different carry_init " "functions will have different graphdefs." ) self.carry_init = carry_init def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]: # type: ignore[override] """Gated recurrent unit (GRU) cell. Args: carry: the hidden state of the GRU cell, initialized using ``GRUCell.initialize_carry``. inputs: an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions. Returns: A tuple with the new carry and the output. """ h = carry # Compute combined transformations for inputs and hidden state x_transformed = self.dense_i(inputs) h_transformed = self.dense_h(h) # Split the combined transformations into individual components xi_r, xi_z, xi_n = jnp.split(x_transformed, 3, axis=-1) hh_r, hh_z, hh_n = jnp.split(h_transformed, 3, axis=-1) # Compute gates r = self.gate_fn(xi_r + hh_r) z = self.gate_fn(xi_z + hh_z) # Compute n with an additional linear transformation on h n = self.activation_fn(xi_n + r * hh_n) # Update hidden state new_h = (1.0 - z) * n + z * h return new_h, new_h def initialize_carry( self, input_shape: tuple[int, ...], rngs: rnglib.Rngs | rnglib.RngStream | None = None, carry_init: Initializer | None = None, ) -> Array: # type: ignore[override] """Initialize the RNN cell carry. Args: rngs: random number generator passed to the init_fn. input_shape: a tuple providing the shape of the input to the cell. Returns: An initialized carry for the given RNN cell. """ batch_dims = input_shape[:-1] if rngs is None: rngs = self.rngs if isinstance(rngs, rnglib.Rngs): rngs = rngs.carry if rngs is None: raise ValueError('RNGs must be provided to initialize the cell carry.') mem_shape = batch_dims + (self.hidden_features,) if self.carry_init is None and carry_init is None: carry_init = initializers.zeros_init() elif carry_init is None: carry_init = self.carry_init assert carry_init is not None # just to please mypy h = carry_init(rngs(), mem_shape, self.param_dtype) return h @property def num_feature_axes(self) -> int: return 1 class RNN(Module): """The ``RNN`` module takes any :class:`RNNCellBase` instance and applies it over a sequence using :func:`flax.nnx.scan`. """ state_axes: dict[str, int | type[iteration.Carry] | None] def __init__( self, cell: RNNCellBase, *, time_major: bool = False, return_carry: bool = False, reverse: bool = False, keep_order: bool = False, unroll: int = 1, state_axes: Mapping[str, int | type[iteration.Carry] | None] | None = None, broadcast_rngs: filterlib.Filter = None, rngs: rnglib.Rngs | rnglib.RngStream | bool = True, ): self.cell = cell self.time_major = time_major self.return_carry = return_carry self.reverse = reverse self.keep_order = keep_order self.unroll = unroll self.rngs: rnglib.RngStream | None if rngs is True: self.rngs = rnglib.RngStream(0, tag='carry') elif isinstance(rngs, rnglib.Rngs): self.rngs = rngs.carry.fork() elif rngs is False: self.rngs = nnx.data(None) else: raise ValueError( 'Expected rngs to be a jax.Array, int, Rngs, or bool. ' f'Got {type(rngs)}.' ) self.state_axes = state_axes or nnx.StateAxes({...: iteration.Carry}) # type: ignore self.broadcast_rngs = broadcast_rngs def __call__( self, inputs: Array, *, initial_carry: Carry | None = None, seq_lengths: Array | None = None, return_carry: bool | None = None, time_major: bool | None = None, reverse: bool | None = None, keep_order: bool | None = None, rngs: rnglib.Rngs | rnglib.RngStream | None = None, ): if return_carry is None: return_carry = self.return_carry if time_major is None: time_major = self.time_major if reverse is None: reverse = self.reverse if keep_order is None: keep_order = self.keep_order # Infer the number of batch dimensions from the input shape. # Cells like ConvLSTM have additional spatial dimensions. time_axis = 0 if time_major else inputs.ndim - (self.cell.num_feature_axes + 1) # make time_axis positive if time_axis < 0: time_axis += inputs.ndim if time_major: # we add +1 because we moved the time axis to the front batch_dims = inputs.shape[1 : -self.cell.num_feature_axes] else: batch_dims = inputs.shape[:time_axis] # maybe reverse the sequence if reverse: inputs = jax.tree_util.tree_map( lambda x: flip_sequences( x, seq_lengths, num_batch_dims=len(batch_dims), time_major=time_major, # type: ignore ), inputs, ) if rngs is None: rngs = self.rngs if isinstance(rngs, rnglib.Rngs): rngs = rngs.carry.fork() carry: Carry = ( self.cell.initialize_carry( inputs.shape[:time_axis] + inputs.shape[time_axis + 1 :], rngs ) if initial_carry is None else initial_carry ) slice_carry = seq_lengths is not None and return_carry broadcast_rngs = nnx.All(nnx.RngState, self.broadcast_rngs) state_axes = iteration.StateAxes({broadcast_rngs: None, **self.state_axes}) # type: ignore[misc] # we use split_rngs with splits=1 and squeeze=True to get unique rngs # every time RNN is called @nnx.split_rngs(splits=1, only=self.broadcast_rngs, squeeze=True) @nnx.scan( in_axes=(state_axes, iteration.Carry, time_axis), out_axes=(iteration.Carry, (0, time_axis)) if slice_carry else (iteration.Carry, time_axis), unroll=self.unroll, ) def scan_fn( cell: RNNCellBase, carry: Carry, x: Array ) -> tuple[Carry, Array] | tuple[Carry, tuple[Carry, Array]]: carry, y = cell(carry, x) if slice_carry: return carry, (carry, y) return carry, y scan_output = scan_fn(self.cell, carry, inputs) # Next we select the final carry. If a segmentation mask was provided and # return_carry is True we slice the carry history and select the last valid # carry for each sequence. Otherwise we just use the last carry. if slice_carry: assert seq_lengths is not None _, (carries, outputs) = scan_output # seq_lengths[None] expands the shape of the mask to match the # number of dimensions of the carry. carry = _select_last_carry(carries, seq_lengths) else: carry, outputs = scan_output if reverse and keep_order: outputs = jax.tree_util.tree_map( lambda x: flip_sequences( x, seq_lengths, num_batch_dims=len(batch_dims), time_major=time_major, # type: ignore ), outputs, ) if return_carry: return carry, outputs else: return outputs def _select_last_carry(sequence: A, seq_lengths: jnp.ndarray) -> A: last_idx = seq_lengths - 1 def _slice_array(x: jnp.ndarray): return x[last_idx, jnp.arange(x.shape[1])] return jax.tree_util.tree_map(_slice_array, sequence) def _expand_dims_like(x, target): """Expands the shape of `x` to match `target`'s shape by adding singleton dimensions.""" return x.reshape(list(x.shape) + [1] * (target.ndim - x.ndim)) def flip_sequences( inputs: Array, seq_lengths: Array | None, num_batch_dims: int, time_major: bool, ) -> Array: """Flips a sequence of inputs along the time axis. This function can be used to prepare inputs for the reverse direction of a bidirectional LSTM. It solves the issue that, when naively flipping multiple padded sequences stored in a matrix, the first elements would be padding values for those sequences that were padded. This function keeps the padding at the end, while flipping the rest of the elements. Example:: >>> from flax.nnx.nn.recurrent import flip_sequences >>> from jax import numpy as jnp >>> inputs = jnp.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]]) >>> lengths = jnp.array([1, 2, 3]) >>> flip_sequences(inputs, lengths, 1, False) Array([[1, 0, 0], [3, 2, 0], [6, 5, 4]], dtype=int32) Args: inputs: An array of input IDs [batch_size, seq_length]. lengths: The length of each sequence [batch_size]. Returns: An ndarray with the flipped inputs. """ # Compute the indices to put the inputs in flipped order as per above example. time_axis = 0 if time_major else num_batch_dims max_steps = inputs.shape[time_axis] if seq_lengths is None: # reverse inputs and return inputs = jnp.flip(inputs, axis=time_axis) return inputs seq_lengths = jnp.expand_dims(seq_lengths, axis=time_axis) # create indexes idxs = jnp.arange(max_steps - 1, -1, -1) # [max_steps] if time_major: idxs = jnp.reshape(idxs, [max_steps] + [1] * num_batch_dims) else: idxs = jnp.reshape( idxs, [1] * num_batch_dims + [max_steps] ) # [1, ..., max_steps] idxs = (idxs + seq_lengths) % max_steps # [*batch, max_steps] idxs = _expand_dims_like(idxs, target=inputs) # [*batch, max_steps, *features] # Select the inputs in flipped order. outputs = jnp.take_along_axis(inputs, idxs, axis=time_axis) return outputs def _concatenate(a: Array, b: Array) -> Array: """Concatenates two arrays along the last dimension.""" return jnp.concatenate([a, b], axis=-1) class RNNBase(Protocol): def __call__( self, inputs: Array, *, initial_carry: Carry | None = None, rngs: rnglib.Rngs | rnglib.RngStream | None = None, seq_lengths: Array | None = None, return_carry: bool | None = None, time_major: bool | None = None, reverse: bool | None = None, keep_order: bool | None = None, ) -> Output | tuple[Carry, Output]: ... class Bidirectional(Module): """Processes the input in both directions and merges the results. Example usage:: >>> from flax import nnx >>> import jax >>> import jax.numpy as jnp >>> # Define forward and backward RNNs >>> forward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0))) >>> backward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0))) >>> # Create Bidirectional layer >>> layer = Bidirectional(forward_rnn=forward_rnn, backward_rnn=backward_rnn) >>> # Input data >>> x = jnp.ones((2, 3, 3)) >>> # Apply the layer >>> out = layer(x) >>> print(out.shape) (2, 3, 8) """ forward_rnn: RNNBase backward_rnn: RNNBase merge_fn: Callable[[Array, Array], Array] = _concatenate time_major: bool = False return_carry: bool = False def __init__( self, forward_rnn: RNNBase, backward_rnn: RNNBase, *, merge_fn: Callable[[Array, Array], Array] = _concatenate, time_major: bool = False, return_carry: bool = False, rngs: rnglib.Rngs | rnglib.RngStream | bool = True, ): self.forward_rnn = forward_rnn self.backward_rnn = backward_rnn self.merge_fn = merge_fn self.time_major = time_major self.return_carry = return_carry self.rngs: rnglib.RngStream | None if rngs is True: self.rngs = rnglib.RngStream(0, tag='carry') elif rngs is False: self.rngs = None elif isinstance(rngs, rnglib.Rngs): self.rngs = rngs.carry.fork() elif isinstance(rngs, rnglib.RngStream): self.rngs = rngs else: raise TypeError( f'rngs must be a Rngs, jax.Array, int, or bool, but got {type(rngs)}.' ) def __call__( self, inputs: Array, *, initial_carry: tuple[Carry, Carry] | None = None, rngs: rnglib.Rngs | rnglib.RngStream | None = None, seq_lengths: Array | None = None, return_carry: bool | None = None, time_major: bool | None = None, reverse: bool | None = None, # unused keep_order: bool | None = None, # unused ) -> Output | tuple[tuple[Carry, Carry], Output]: if time_major is None: time_major = self.time_major if return_carry is None: return_carry = self.return_carry if rngs is None: rngs = self.rngs if isinstance(rngs, rnglib.Rngs): rngs = rngs.carry if initial_carry is not None: initial_carry_forward, initial_carry_backward = initial_carry else: initial_carry_forward = None initial_carry_backward = None # Throw a warning in case the user accidentally re-uses the forward RNN # for the backward pass and does not intend for them to share parameters. if self.forward_rnn is self.backward_rnn: logging.warning( 'forward_rnn and backward_rnn is the same object, so ' 'they will share parameters.' ) # Encode in the forward direction. carry_forward, outputs_forward = self.forward_rnn( inputs, initial_carry=initial_carry_forward, rngs=rngs, seq_lengths=seq_lengths, return_carry=True, time_major=time_major, reverse=False, ) # Encode in the backward direction. carry_backward, outputs_backward = self.backward_rnn( inputs, initial_carry=initial_carry_backward, rngs=rngs, seq_lengths=seq_lengths, return_carry=True, time_major=time_major, reverse=True, keep_order=True, ) carry = (carry_forward, carry_backward) if return_carry else None outputs = jax.tree_util.tree_map( self.merge_fn, outputs_forward, outputs_backward ) if return_carry: return carry, outputs else: return outputs ================================================ FILE: flax/nnx/nn/stochastic.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from collections.abc import Sequence import jax import jax.numpy as jnp from jax import lax, random from flax.nnx import rnglib from flax.nnx.module import Module, first_from from flax import nnx class Dropout(Module): """Create a dropout layer. To use dropout, call the :func:`train` method (or pass in ``deterministic=False`` in the constructor or during call time). To disable dropout, call the :func:`eval` method (or pass in ``deterministic=True`` in the constructor or during call time). Example usage:: >>> from flax import nnx >>> import jax.numpy as jnp >>> class MLP(nnx.Module): ... def __init__(self, rngs): ... self.linear = nnx.Linear(in_features=3, out_features=4, rngs=rngs) ... self.dropout = nnx.Dropout(0.5, rngs=rngs) ... def __call__(self, x): ... x = self.linear(x) ... x = self.dropout(x) ... return x >>> model = MLP(rngs=nnx.Rngs(0)) >>> x = jnp.ones((1, 3)) >>> model.train() # use dropout >>> model(x) Array([[ 2.1067007, -2.5359864, -1.592019 , -2.5238838]], dtype=float32) >>> model.eval() # don't use dropout >>> model(x) Array([[ 1.0533503, -1.2679932, -0.7960095, -1.2619419]], dtype=float32) Args: rate: the dropout probability. (_not_ the keep rate!) broadcast_dims: dimensions that will share the same dropout mask deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` and masked, whereas if true, no mask is applied and the inputs are returned as is. rng_collection: the rng collection name to use when requesting an rng key. rngs: rng key. """ def __init__( self, rate: float, *, broadcast_dims: Sequence[int] = (), deterministic: bool = False, rng_collection: str = 'dropout', rngs: rnglib.Rngs | rnglib.RngStream | None = None, ): self.rate = rate self.broadcast_dims = broadcast_dims self.deterministic = deterministic self.rng_collection = rng_collection if isinstance(rngs, rnglib.Rngs): self.rngs = rngs[self.rng_collection].fork() elif isinstance(rngs, rnglib.RngStream): self.rngs = rngs.fork() elif rngs is None: self.rngs = nnx.data(None) else: raise TypeError( f'rngs must be a Rngs, RngStream or None, but got {type(rngs)}.' ) def __call__( self, inputs, *, deterministic: bool | None = None, rngs: rnglib.Rngs | rnglib.RngStream | jax.Array | None = None, ) -> jax.Array: """Applies a random dropout mask to the input. Args: inputs: the inputs that should be randomly masked. deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` and masked, whereas if true, no mask is applied and the inputs are returned as is. The ``deterministic`` flag passed into the call method will take precedence over the ``deterministic`` flag passed into the constructor. rngs: an optional key, RngStream, or Rngs object used to generate the dropout mask. If given it will take precedence over the rngs passed into the constructor. Returns: The masked inputs reweighted to preserve mean. """ deterministic = first_from( deterministic, self.deterministic, error_msg="""No `deterministic` argument was provided to Dropout as either a __call__ argument or class attribute""", ) if (self.rate == 0.0) or deterministic: return inputs # Prevent gradient NaNs in 1.0 edge-case. if self.rate == 1.0: return jnp.zeros_like(inputs) rngs = first_from( # type: ignore[assignment] rngs, self.rngs, error_msg="""`deterministic` is False, but no `rngs` argument was provided to Dropout as either a __call__ argument or class attribute.""", ) if isinstance(rngs, rnglib.Rngs): key = rngs[self.rng_collection]() elif isinstance(rngs, rnglib.RngStream): key = rngs() elif isinstance(rngs, jax.Array): key = rngs else: raise TypeError( f'rngs must be a Rngs, RngStream or jax.Array, but got {type(rngs)}.' ) keep_prob = 1.0 - self.rate broadcast_shape = list(inputs.shape) for dim in self.broadcast_dims: broadcast_shape[dim] = 1 mask = random.bernoulli( key, p=keep_prob, shape=broadcast_shape, out_sharding=jax.typeof(inputs).sharding ) mask = jnp.broadcast_to(mask, inputs.shape) return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)) def set_view( self, deterministic: bool | None = None, ): """Class method used by ``nnx.view``. Args: deterministic: if True, disables dropout masking. """ if deterministic is not None: self.deterministic = deterministic ================================================ FILE: flax/nnx/proxy_caller.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import dataclasses import typing as tp import jax A = tp.TypeVar('A', covariant=True) # type: ignore[not-supported-yet] def _identity(x): return x @dataclasses.dataclass(frozen=True) class GetItem: key: tp.Any @dataclasses.dataclass(frozen=True) class GetAttr: name: str @dataclasses.dataclass(frozen=True) class DelayedAccessor: actions: tuple[GetItem | GetAttr, ...] = () def __call__(self, x): for action in self.actions: if isinstance(action, GetItem): x = x[action.key] elif isinstance(action, GetAttr): x = getattr(x, action.name) return x def __getattr__(self, name): return DelayedAccessor(self.actions + (GetAttr(name),)) def __getitem__(self, key): return DelayedAccessor(self.actions + (GetItem(key),)) jax.tree_util.register_static(DelayedAccessor) class _AccessorCall(tp.Protocol): def __call__(self, accessor: DelayedAccessor, /, *args, **kwargs) -> tp.Any: ... class CallableProxy: def __init__( self, callable: _AccessorCall, accessor: DelayedAccessor | None = None ): self._callable = callable self._accessor = DelayedAccessor() if accessor is None else accessor def __call__(self, *args, **kwargs): return self._callable(self._accessor, *args, **kwargs) def __getattr__(self, name) -> CallableProxy: return CallableProxy(self._callable, getattr(self._accessor, name)) def __getitem__(self, key) -> CallableProxy: return CallableProxy(self._callable, self._accessor[key]) class ApplyCaller(tp.Protocol, tp.Generic[A]): def __getattr__(self, __name) -> ApplyCaller[A]: ... def __getitem__(self, __name) -> ApplyCaller[A]: ... def __call__(self, *args, **kwargs) -> tuple[tp.Any, A]: ... ================================================ FILE: flax/nnx/pytreelib.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import dataclasses import inspect import os import threading import typing as tp from abc import ABCMeta from copy import deepcopy import warnings from flax.nnx import variablelib import jax import numpy as np import treescope # type: ignore[import-untyped] from treescope import rendering_parts from flax import errors, nnx from flax.nnx import ( graphlib, reprlib, tracers, visualization, ) from flax import config from flax.nnx.variablelib import Variable from flax.typing import MISSING, Missing, SizeBytes BUILDING_DOCS = 'FLAX_DOC_BUILD' in os.environ A = tp.TypeVar('A') P = tp.TypeVar('P', bound='Pytree') T = tp.TypeVar('T', bound=type) DataAnnotation = '__data__' Data = tp.Annotated[A, DataAnnotation] Data.__doc__ = """Data marks attributes of a class as pytree data using type annotations. Data annotations must be used at the class level and will apply to all instances. The usage of Data is recommended when type annotations are used already present or required e.g. for dataclasses. """ DATA_REGISTRY: set[type] = set() @tp.overload def data(value: A, /) -> A: ... @tp.overload def data( *, default: A = dataclasses.MISSING, # type: ignore[assignment] default_factory: tp.Callable[[], A] | None = None, # type: ignore[assignment] init: bool = True, repr: bool = True, hash: bool | None = None, compare: bool = True, metadata: tp.Mapping[str, tp.Any] | None = None, kw_only: bool = False, ) -> tp.Any: ... def data(value: tp.Any = MISSING, /, **kwargs) -> tp.Any: """Annotates a an attribute as pytree data. The return value from `data` must be directly assigned to an Object attribute which will be registered as a pytree data attribute. Example:: from flax import nnx import jax class Foo(nnx.Pytree): def __init__(self): self.data_attr = nnx.data(42) # pytree data self.static_attr = "hello" # static attribute foo = Foo() assert jax.tree.leaves(foo) == [42] Args: value: The value to annotate as data. Returns: A value which will register the attribute as data on assignment. """ if not isinstance(value, Missing) and kwargs: raise TypeError( 'nnx.data() accepts either a single positional argument or keyword' ' arguments, but not both.' ) metadata = {'nnx_value': value} if 'metadata' in kwargs and kwargs['metadata'] is not None: if 'static' in kwargs['metadata']: raise ValueError( "Cannot use 'static' key in metadata argument for nnx.data." ) metadata.update(kwargs.pop('metadata')) metadata['static'] = False return dataclasses.field(**kwargs, metadata=metadata) # type: ignore[return-value] def register_data_type(type_: T, /) -> T: """Registers a type as pytree data type recognized by Object. Custom types registered as data will be automatically recognized as data attributes when assigned to an Object attribute. This means that values of this type do not need to be wrapped in `nnx.data(...)` for Object to mark the attribute its being assigned to as data. Example:: from flax import nnx from dataclasses import dataclass @dataclass(frozen=True) class MyType: value: int nnx.register_data_type(MyType) class Foo(nnx.Pytree): def __init__(self, a): self.a = MyType(a) # Automatically registered as data self.b = "hello" # str not registered as data foo = Foo(42) assert nnx.is_data(foo.a) # True assert jax.tree.leaves(foo) == [MyType(value=42)] Can also be used as a decorator:: @nnx.register_data_type @dataclass(frozen=True) class MyType: value: int """ DATA_REGISTRY.add(type_) return type_ def is_data(value: tp.Any, /) -> bool: """Checks if a value is a registered data type. This function checks a the value is registered data type, which means it is automatically recognized as data when assigned a :class:`flax.nnx.Pytree` attribute. Data types are: - ``jax.Array`` - ``np.ndarray`` - ``ArrayRef`` - Variables (:class:`flax.nnx.Param`, :class:`flax.nnx.BatchStat`, `nnx.RngState`, etc.) - All graph nodes (:class:`flax.nnx.Object`, :class:`flax.nnx.Module`, :class:`flax.nnx.Rngs`, etc.) - Any type registered with :func:`flax.nnx.register_data_type` - Any pytree that contains at least one node or leaf element of the above Example:: >>> from flax import nnx >>> import jax.numpy as jnp ... # ------ DATA ------------ >>> assert nnx.is_data( jnp.array(0) ) # Arrays >>> assert nnx.is_data( nnx.Param(1) ) # Variables >>> assert nnx.is_data( nnx.Rngs(2) ) # nnx.Pytrees >>> assert nnx.is_data( nnx.Linear(1, 1,rngs=nnx.Rngs(0)) ) # Modules ... # ------ STATIC ------------ >>> assert not nnx.is_data( 'hello' ) # strings, arbitrary objects >>> assert not nnx.is_data( 42 ) # int, float, bool, complex, etc. >>> assert not nnx.is_data( [1, 2.0, 3j, jnp.array(1)] ) # list, dict, tuple, pytrees Args: value: The value to check. Returns: A string representing the attribute status. """ return ( graphlib.is_node_leaf(value) or graphlib.is_graph_node(value) or type(value) in DATA_REGISTRY ) def has_data(value: tp.Any, /) -> list[tp.Any]: visited: set[int] = set() def _is_leaf(x): if id(x) in visited: return True visited.add(id(x)) return is_data(x) leaves = jax.tree.leaves(value, is_leaf=_is_leaf) return [leaf for leaf in leaves if is_data(leaf)] StaticAnnotation = '__static__' Static = tp.Annotated[A, StaticAnnotation] Static.__doc__ = """Static marks attributes of a class as static using type annotations. Static annotations must be used at the class level and will apply to all instances. The usage of Static is recommended when type annotations are used already present or required e.g. for dataclasses. """ @tp.overload def static(value: A, /) -> A: ... @tp.overload def static( *, default: A = dataclasses.MISSING, # type: ignore[assignment] default_factory: tp.Callable[[], A] | None = None, init: bool = True, repr: bool = True, hash: bool | None = None, compare: bool = True, metadata: tp.Mapping[str, tp.Any] | None = None, kw_only: bool = False, ) -> tp.Any: ... def static(value: tp.Any = MISSING, /, **kwargs) -> tp.Any: """Annotates a an attribute as static. The return value from `static` must be directly assigned to an Object attribute which will be registered as static attribute. Example:: from flax import nnx class Foo(nnx.Pytree): def __init__(self, a, b): self.a = nnx.static(a) # pytree metadata self.b = nnx.data(b) # pytree data foo = Foo("one", "two") assert jax.tree.leaves(foo) == ["two"] By default ``nnx.Pytree`` will ... """ if not isinstance(value, Missing) and kwargs: raise TypeError( 'nnx.static() accepts either a single positional argument or keyword' ' arguments, but not both.' ) metadata = {'nnx_value': value} if 'metadata' in kwargs and kwargs['metadata'] is not None: if 'static' in kwargs['metadata']: raise ValueError( "Cannot use 'static' key in metadata argument for nnx.static." ) metadata.update(kwargs.pop('metadata')) metadata['static'] = True return dataclasses.field(**kwargs, metadata=metadata) # type: ignore[return-value] @tp.overload def dataclass(cls: type[A], /) -> type[A]: ... @tp.overload def dataclass( *, init: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False, ) -> tp.Callable[[type[A]], type[A]]: ... @tp.dataclass_transform(field_specifiers=(dataclasses.field, data, static)) def dataclass( cls=None, /, *, init: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, match_args: bool = True, kw_only: bool = False, slots: bool = False, weakref_slot: bool = False, ) -> tp.Any: return dataclasses.dataclass( cls, init=init, eq=eq, order=order, unsafe_hash=unsafe_hash, match_args=match_args, kw_only=kw_only, slots=slots, weakref_slot=weakref_slot, ) def _collect_stats( node: tp.Any, node_stats: dict[int, dict[type[Variable], SizeBytes]] ): if not graphlib.is_node(node) and not isinstance(node, Variable): raise ValueError(f'Expected a graph node or Variable, got {type(node)!r}.') if id(node) in node_stats: return stats: dict[type[Variable], SizeBytes] = {} node_stats[id(node)] = stats if isinstance(node, Variable): var_type = node.var_type if issubclass(var_type, nnx.RngState): var_type = nnx.RngState size_bytes = SizeBytes.from_any(node.get_raw_value()) if size_bytes: stats[var_type] = size_bytes else: node_impl = graphlib.get_node_impl(node) assert node_impl is not None node_dict = node_impl.node_dict(node) for key, value in node_dict.items(): if id(value) in node_stats: continue if graphlib.is_node(value) or isinstance(value, Variable): _collect_stats(value, node_stats) child_stats = node_stats[id(value)] for var_type, size_bytes in child_stats.items(): if var_type in stats: stats[var_type] += size_bytes else: stats[var_type] = size_bytes @dataclasses.dataclass class ObjectContext(threading.local): seen_modules_repr: set[int] | None = None node_stats: dict[int, dict[type[Variable], SizeBytes]] | None = None OBJECT_CONTEXT = ObjectContext() class PytreeState(reprlib.Representable): __slots__ = ('_trace_state', '_initializing', '_is_setup') def __init__(self, initializing: bool = False, is_setup: bool = False): self._trace_state = tracers.TraceState() self._initializing = initializing self._is_setup = is_setup @property def trace_state(self) -> tracers.TraceState: return self._trace_state @property def initializing(self) -> bool: return self._initializing @property def is_setup(self) -> bool: return self._is_setup def __nnx_repr__(self): yield reprlib.Object(type(self)) yield reprlib.Attr('trace_state', self._trace_state) def __treescope_repr__(self, path, subtree_renderer): return visualization.render_object_constructor( object_type=type(self), attributes={'trace_state': self._trace_state}, path=path, subtree_renderer=subtree_renderer, ) def _flatten_pytree_state(state: PytreeState): return (), (state.initializing, state.is_setup) def _unflatten_pytree_state(static: tuple[bool, bool], _): initializing, setup = static return PytreeState(initializing, setup) jax.tree_util.register_pytree_node( PytreeState, _flatten_pytree_state, _unflatten_pytree_state, ) def check_pytree(pytree): """Checks if a pytree is valid.""" if not isinstance(pytree, Pytree): raise TypeError(f'Expected a Pytree, got {type(pytree)}.') for name, value in vars(pytree).items(): pytree._check_value(name, value, new_status=None) class PytreeMeta(ABCMeta): if not tp.TYPE_CHECKING: def __call__(cls, *args: Any, **kwargs: Any) -> Any: return _graph_node_meta_call(cls, *args, **kwargs) def _pytree_meta_construct(cls, self, *args, **kwargs): self.__init__(*args, **kwargs) ObjectMeta = PytreeMeta def _graph_node_meta_call(cls: tp.Type[P], *args, **kwargs) -> P: node = cls.__new__(cls, *args, **kwargs) object.__setattr__(node, '_pytree__state', PytreeState()) object.__setattr__(node, '_pytree__nodes', cls._pytree__nodes) cls._pytree_meta_construct(node, *args, **kwargs) if cls._pytree__is_pytree: missing: dict[str, bool] = {} for name, value in vars(node).items(): if name not in node._pytree__nodes: missing[name] = is_data(value) if missing: object.__setattr__( node, '_pytree__nodes', node._pytree__nodes.update(missing) ) check_pytree(node) return node @dataclasses.dataclass(frozen=True, repr=False) class ArrayRepr(reprlib.Representable): shape: tp.Tuple[int, ...] dtype: tp.Any @staticmethod def from_array(array: jax.Array | np.ndarray) -> ArrayRepr: return ArrayRepr(array.shape, array.dtype) def __nnx_repr__(self): yield reprlib.Object(type='Array', same_line=True) yield reprlib.Attr('shape', self.shape) yield reprlib.Attr('dtype', self.dtype) @dataclasses.dataclass(frozen=True, repr=False) class VariableRepr(reprlib.Representable): var_type: type[Variable] value: tp.Any metadata: dict[str, tp.Any] def __nnx_repr__(self): variable = self.var_type._new(self.value, self.metadata) yield from variable.__nnx_repr__() @dataclasses.dataclass(frozen=True, repr=False) class MutableArrayRepr(reprlib.Representable): shape: tp.Tuple[int, ...] dtype: tp.Any @staticmethod def from_array(array: jax.Array | np.ndarray) -> MutableArrayRepr: return MutableArrayRepr(array.shape, array.dtype) def __nnx_repr__(self): yield reprlib.Object(type='ArrayRef', same_line=True) yield reprlib.Attr('shape', self.shape) yield reprlib.Attr('dtype', self.dtype) def _to_shape_dtype(x): if isinstance(x, Variable): value = x.get_raw_value() metadata = x.get_metadata() value = jax.tree.map(_to_shape_dtype, value) return VariableRepr(x.var_type, value, metadata) elif variablelib.is_array_ref(x) and np.prod(x.shape) > 1: return MutableArrayRepr(x.shape, x.dtype) elif ( isinstance(x, (np.ndarray, jax.Array)) and np.prod(x.shape) > 1 ): return ArrayRepr(x.shape, x.dtype) return x class AttributeStatus(tp.NamedTuple): is_data: bool explicit: bool class Pytree(reprlib.Representable, metaclass=PytreeMeta): """Base class for all NNX objects.""" if tp.TYPE_CHECKING: _pytree__nodes: graphlib.HashableMapping[tp.Any, bool] _pytree__state: PytreeState _pytree__is_pytree: bool def __init_subclass__( cls, *, pytree: bool = config.flax_pytree_module, **kwargs, ) -> None: super().__init_subclass__(**kwargs) if slots := getattr(cls, '__slots__', ()): raise TypeError( 'Pytree currently does not support __slots__, ' f"found __slots__={slots} in '{cls.__name__}'." ) cls._pytree__is_pytree = pytree graphlib.register_graph_node_type( type=cls, flatten=cls._graph_node_flatten, set_key=cls._graph_node_set_key, # type: ignore pop_key=cls._graph_node_pop_key, # type: ignore create_empty=cls._graph_node_create_empty, clear=cls._graph_node_clear, init=cls._graph_node_init, # type: ignore ) nodes: dict[str, bool] = dict(getattr(cls, '_pytree__nodes', ())) nodes['_pytree__state'] = True try: type_hints = tp.get_type_hints( cls, globals(), {cls.__name__: cls}, include_extras=True ) except NameError: type_hints = cls.__annotations__ # add annotation attributes for name, type_ in type_hints.items(): if isinstance(type_, str): if type_.startswith('nnx.Data'): warnings.warn( f"'Data' is deprecated, please replace:\n\n" ' some_field: nnx.Data[SomeType]\n\n' f'with:\n\n' ' some_field: SomeType = nnx.data()\n\n', DeprecationWarning, stacklevel=2, ) nodes[name] = True elif type_.startswith('nnx.Static'): warnings.warn( f"'Static' is deprecated, please replace:\n\n" ' some_field: nnx.Static[SomeType]\n\n' f'with:\n\n' ' some_field: SomeType = nnx.static()\n\n', DeprecationWarning, stacklevel=2, ) nodes[name] = False else: type_metadata = getattr(type_, '__metadata__', ()) if DataAnnotation in type_metadata: warnings.warn( f"'Data' is deprecated, please replace:\n\n" ' some_field: nnx.Data[SomeType]\n\n' f'with:\n\n' ' some_field: SomeType = nnx.data()\n\n', DeprecationWarning, stacklevel=2, ) nodes[name] = True elif StaticAnnotation in type_metadata: warnings.warn( f"'Static' is deprecated, please replace:\n\n" ' some_field: nnx.Static[SomeType]\n\n' f'with:\n\n' ' some_field: SomeType = nnx.static()\n\n', DeprecationWarning, stacklevel=2, ) nodes[name] = False for name, value in vars(cls).items(): if isinstance(value, dataclasses.Field) and 'static' in value.metadata: if not isinstance(value.metadata['static'], bool): raise ValueError( f"Invalid 'static' metadata for attribute" f" '{cls.__name__}.{name}': expected bool, got" f' {type(value.metadata["static"]).__name__}.' ) is_node = not value.metadata['static'] if name in nodes and nodes[name] != is_node: raise ValueError( f'Conflicting pytree annotation for attribute' f" '{cls.__name__}.{name}': previously registered as" f' {"data" if nodes[name] else "static"}, but found' f' nnx.{"data" if is_node else "static"}(...) annotation.' ) nodes[name] = is_node cls._pytree__nodes = graphlib.HashableMapping(nodes, copy=False) if pytree: jax.tree_util.register_pytree_with_keys( cls, flatten_with_keys=cls._pytree__flatten_with_paths, unflatten_func=cls._pytree__unflatten, flatten_func=cls._pytree__flatten, ) if BUILDING_DOCS: # set correct signature for sphinx cls.__signature__ = inspect.signature(cls.__init__) # Backward compatibility with PR #4863 @property def _object__nodes(self): warnings.warn( "'_object__nodes' is deprecated, use '_pytree__nodes' instead.", DeprecationWarning, stacklevel=2, ) return self._pytree__nodes @property def _object__state(self): warnings.warn( "'_object__state' is deprecated, use '_pytree__state' instead.", DeprecationWarning, stacklevel=2, ) return self._pytree__state if not tp.TYPE_CHECKING: def __setattr__(self, name: str, value: Any) -> None: self._setattr(name, value) def _setattr(self, name, value: tp.Any) -> None: self._check_valid_context( lambda: f"Cannot mutate '{type(self).__name__}' from different trace level" ) data: bool = False explicit: bool = False if isinstance(value, dataclasses.Field) and 'nnx_value' in value.metadata: is_static = value.metadata['static'] value = value.metadata['nnx_value'] if self._pytree__is_pytree: data = not is_static explicit = True elif self._pytree__is_pytree: data = is_data(value) if self._pytree__is_pytree: self._check_value(name, value, AttributeStatus(data, explicit)) if name not in self._pytree__nodes or ( explicit and self._pytree__nodes[name] != data ): object.__setattr__( self, '_pytree__nodes', self._pytree__nodes.update({name: data}) ) object.__setattr__(self, name, value) def _check_value(self, key, value, new_status: AttributeStatus | None): def _get_annotations(leaves): return { 'static' if leaf.metadata['static'] else 'data' for leaf in leaves if isinstance(leaf, dataclasses.Field) and 'nnx_value' in leaf.metadata } def _has_visited(x): if id(x) in visited: return True visited.add(id(x)) return False current_is_data = ( self._pytree__nodes[key] if key in self._pytree__nodes else False ) existing_attr = key in vars(self) if ( new_status is not None and not new_status.explicit and new_status.is_data and existing_attr and not current_is_data ): raise ValueError( f"Cannot assign data value of type '{type(value)}' to static" f" attribute '{key}' of Pytree type '{type(self)}'. To override the" ' status explicitly wrap the value with nnx.data on assignment:\n\n ' f' _.{key} = nnx.data(...)\n\n' ) visited: set[int] = set() leaves = jax.tree.leaves( value, is_leaf=lambda x: _has_visited(x) or is_data(x) ) data_leaves = [leaf for leaf in leaves if is_data(leaf)] if data_leaves: # check no data in nnx.static assignments if new_status is not None: if not new_status.is_data and new_status.explicit: raise ValueError( f"Found data in value of type '{type(value)}' annotated with " f"nnx.static(...) when setting attribute '{key}' of Pytree type " f"'{type(self)}'." ) if not new_status.is_data and not current_is_data: base_pytree_type = Pytree for t in type(self).mro()[1:]: if issubclass(t, nnx.Pytree): base_pytree_type = t break data_leaves_type_names = {type(leaf).__name__ for leaf in data_leaves} raise ValueError( f"Found data on value of type '{type(value)}' assigned to" f" static attribute '{key}' of Pytree type '{type(self)}'. Static" ' attributes should not contain data values. Consider one of' ' the following options:\n\n1. If the attribute is meant to be' f' static, remove the data values of type(s):\n\n {", ".join(data_leaves_type_names)}' f'\n\n2. If the attribute is meant to be data, wrap the value with nnx.data ' f' on assignment:\n\n _.{key} = nnx.data(...)\n\n3. Alternatively,' ' annotate the class attribute with nnx.data:\n\n class' f' {type(self).__name__}({base_pytree_type.__name__}):\n ' f' {key}: {type(value).__name__} = nnx.data()\n\n4. If the' ' container is a list or dict, try using nnx.List(...) or' ' nnx.Dict(...) instead.\n\n5. Disable pytree' ' for this class:\n\n class' f' {type(self).__name__}({base_pytree_type.__name__},' ' pytree=False):\n\n' ) # check no data in static attributes after __init__ elif not current_is_data: base_pytree_type = Pytree for t in type(self).mro()[1:]: if issubclass(t, nnx.Pytree): base_pytree_type = t break raise ValueError( f'Found unexpected data on value of type {type(value)} in static' f" attribute '{key}' of Pytree type '{type(self)}'. This is an" ' error starting from Flax version 0.12.0.\nConsider one of the' ' following options:\n\n1. If the attribute is meant to be static,' ' either remove the data value or wrap it in a static' ' container.\n2. Wrap the value with nnx.data on' f' assignment:\n\n _.{key} = nnx.data(...)\n\n3. Annotate the' ' class attribute with nnx.data:\n\n class' f' {type(self).__name__}({base_pytree_type.__name__}):\n {key}:' f' {type(value).__name__} = nnx.data()\n\n4. If the container is a' ' list or dict, try using nnx.List(...) or nnx.Dict(...)' ' instead.\n5. Disable pytree for this class:\n\n class' f' {type(self).__name__}({base_pytree_type.__name__},' f' pytree=False):\n\n' ) if tags := _get_annotations(leaves): raise ValueError( f'Found unexpected tags {tags} on attribute' f" '{type(self).__name__}.{key}'. Values from nnx.data(...)" ' and\nnnx.static(...) should be assigned to nnx.Pytree attributes' ' directly, they should not be inside other structures. Got value of' f" type '{type(value)}' on Pytree of type '{type(self)}'." ) def _check_valid_context(self, error_msg: tp.Callable[[], str]) -> None: if not self._pytree__state.trace_state.is_valid(): raise errors.TraceContextError(error_msg()) def __deepcopy__(self: P, memo=None) -> P: graphdef, state = graphlib.split(self, graph=True) graphdef = deepcopy(graphdef) state = deepcopy(state) return graphlib.merge(graphdef, state) def __nnx_repr__(self): if OBJECT_CONTEXT.node_stats is None or id(self) not in OBJECT_CONTEXT.node_stats: node_stats: dict[int, dict[type[Variable], SizeBytes]] = {} _collect_stats(self, node_stats) OBJECT_CONTEXT.node_stats = node_stats stats = node_stats[id(self)] clear_node_stats = True else: stats = OBJECT_CONTEXT.node_stats[id(self)] clear_node_stats = False if OBJECT_CONTEXT.seen_modules_repr is None: OBJECT_CONTEXT.seen_modules_repr = set() clear_seen = True else: clear_seen = False if id(self) in OBJECT_CONTEXT.seen_modules_repr: yield reprlib.Object(type=type(self), empty_repr='...') return try: if stats: stats_repr = ' # ' + ', '.join( f'{var_type.__name__}: {size_bytes}' for var_type, size_bytes in stats.items() ) if len(stats) > 1: total_bytes = sum(stats.values(), SizeBytes(0, 0)) stats_repr += f', Total: {total_bytes}' else: stats_repr = '' yield reprlib.Object(type=type(self), comment=stats_repr) OBJECT_CONTEXT.seen_modules_repr.add(id(self)) for name, value in vars(self).items(): if str(name).startswith('_pytree__'): continue if str(name).startswith('_') and not self._pytree__nodes.get(str(name), False): continue value = jax.tree.map(_to_shape_dtype, value, is_leaf=graphlib.is_graph_node) yield reprlib.Attr(name, value) finally: if clear_seen: OBJECT_CONTEXT.seen_modules_repr = None if clear_node_stats: OBJECT_CONTEXT.node_stats = None def __treescope_repr__(self, path, subtree_renderer): from flax import nnx if OBJECT_CONTEXT.node_stats is None or id(self) not in OBJECT_CONTEXT.node_stats: node_stats: dict[int, dict[type[Variable], SizeBytes]] = {} _collect_stats(self, node_stats) OBJECT_CONTEXT.node_stats = node_stats stats = node_stats[id(self)] clear_node_stats = True else: stats = OBJECT_CONTEXT.node_stats[id(self)] clear_node_stats = False try: if stats: stats_repr = ' # ' + ', '.join( f'{var_type.__name__}: {size_bytes}' for var_type, size_bytes in stats.items() ) if len(stats) > 1: total_bytes = sum(stats.values(), SizeBytes(0, 0)) stats_repr += f', Total: {total_bytes}' first_line_annotation = rendering_parts.comment_color( rendering_parts.text(f'{stats_repr}') ) else: first_line_annotation = None children = {} for name, value in vars(self).items(): if str(name).startswith('_pytree__'): continue if str(name).startswith('_') and not self._pytree__nodes.get(str(name), False): continue children[name] = value if isinstance(self, nnx.Module): color = treescope.formatting_util.color_from_string( type(self).__qualname__ ) else: color = None return visualization.render_object_constructor( object_type=type(self), attributes=children, path=path, subtree_renderer=subtree_renderer, first_line_annotation=first_line_annotation, color=color, ) finally: if clear_node_stats: OBJECT_CONTEXT.node_stats = None # pickle support def __getstate__(self): return vars(self).copy() def __setstate__(self, state): vars(self).update(state) # ------------------------- # Pytree Definition # ------------------------- _pytree__has_int_keys: bool = False def _pytree__flatten_with_paths(self): obj_items = vars(self).items() if self._pytree__has_int_keys: obj_items = ((_maybe_int(name), value) for name, value in obj_items) key_fn = graphlib._type_aware_sort else: key_fn = None node_attributes = self._pytree__nodes node_keys: list[str | int] = [] node_attrs: list[tuple[tp.Any, tp.Any]] = [] static_keys: list[str | int] = [] static_attrs: list[tp.Any] = [] for key, value in sorted(obj_items, key=key_fn): # get string representation of the key because # node_attributes keys are strings key_str = _get_str(key) if key_str in node_attributes and node_attributes[key_str]: node_keys.append(key) node_attrs.append(( jax.tree_util.GetAttrKey(key) if isinstance(key, str) else jax.tree_util.SequenceKey(key), value, )) else: static_keys.append(key) static_attrs.append(value) return ( node_attrs, (tuple(node_keys), tuple(static_keys), tuple(static_attrs)), ) def _pytree__flatten(self): obj_items = vars(self).items() if self._pytree__has_int_keys: obj_items = ((_maybe_int(name), value) for name, value in obj_items) key_fn = graphlib._type_aware_sort else: key_fn = None node_attributes = self._pytree__nodes node_keys: list[str | int] = [] node_attrs: list[tp.Any] = [] static_keys: list[str | int] = [] static_attrs: list[tp.Any] = [] for key, value in sorted(obj_items, key=key_fn): # get string representation of the key because # node_attributes keys are strings key_str = _get_str(key) if key_str in node_attributes and node_attributes[key_str]: node_keys.append(key) node_attrs.append(value) else: static_keys.append(key) static_attrs.append(value) return ( node_attrs, (tuple(node_keys), tuple(static_keys), tuple(static_attrs)), ) @classmethod def _pytree__unflatten( cls, static: tuple[tp.Iterable[str | int], tp.Iterable[str | int], tp.Iterable[tp.Any]], node_attrs: tp.Iterable[tp.Any], ): node_keys, static_keys, static_attrs = static obj = object.__new__(cls) for name, value in zip(node_keys, node_attrs, strict=True): object.__setattr__(obj, _get_str(name), value) for name, value in zip(static_keys, static_attrs, strict=True): object.__setattr__(obj, _get_str(name), value) return obj # ------------------------- # Graph Definition # ------------------------- def _graph_node_flatten(self): obj_items = vars(self).items() if self._pytree__is_pytree: pytree_nodes = self._pytree__nodes obj_items = ( ( name, nnx.graphlib.DataElem(value) if name in pytree_nodes and pytree_nodes[name] else nnx.graphlib.StaticElem(value), ) for name, value in obj_items ) if self._pytree__has_int_keys: obj_items = ((_maybe_int(name), value) for name, value in obj_items) key_fn = graphlib._type_aware_sort else: key_fn = None nodes = sorted(obj_items, key=key_fn) return nodes, type(self) def _graph_node_set_key(self, key, value: tp.Any): if self._pytree__has_int_keys and isinstance(key, int): key = str(key) if not isinstance(key, str): raise KeyError(f'Invalid key: {key!r}') elif ( hasattr(self, key) and isinstance(variable := getattr(self, key), Variable) and isinstance(value, Variable) ): variable.update_from_state(value) else: setattr(self, key, value) def _graph_node_pop_key(self, key): if self._pytree__has_int_keys and isinstance(key, int): key = str(key) value = getattr(self, key) delattr(self, key) return value def __delattr__(self, name: str) -> None: if name in self._pytree__nodes: mapping = {k: v for k, v in self._pytree__nodes.items() if k != name} object.__setattr__( self, '_pytree__nodes', graphlib.HashableMapping(mapping, copy=False) ) super().__delattr__(name) @staticmethod def _graph_node_create_empty(node_type: tp.Type[P]) -> P: node = object.__new__(node_type) return node def _graph_node_clear(self): vars(self).clear() def _graph_node_init(self, attributes: tp.Iterable[tuple[str | int, tp.Any]]): for name, value in attributes: object.__setattr__(self, _get_str(name), value) if tp.TYPE_CHECKING: def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any: ... class Object(Pytree, pytree=False): """Base class for NNX objects that are not pytrees.""" def __init_subclass__(cls, **kwargs): pytree = kwargs.pop('pytree', False) if pytree is not False: raise ValueError( "Object is not a pytree, but 'pytree' was explicitly set to " f'{pytree!r} for type {cls}.' ) super().__init_subclass__(pytree=pytree, **kwargs) def _maybe_int(x): try: return int(x) except (ValueError, TypeError): return x def _get_str(x): return x if isinstance(x, str) else str(x) ================================================ FILE: flax/nnx/reprlib.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import dataclasses import os import sys import threading import typing as tp from flax import config as flax_config A = tp.TypeVar('A') B = tp.TypeVar('B') def supports_color() -> bool: """ Returns True if the running system's terminal supports color, and False otherwise. """ try: from IPython import get_ipython ipython_available = get_ipython() is not None except ImportError: ipython_available = False supported_platform = sys.platform != 'win32' or 'ANSICON' in os.environ is_a_tty = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty() return (supported_platform and is_a_tty) or ipython_available class Color(tp.NamedTuple): TYPE: str ATTRIBUTE: str SEP: str PAREN: str COMMENT: str INT: str STRING: str FLOAT: str BOOL: str NONE: str END: str NO_COLOR = Color( TYPE='', ATTRIBUTE='', SEP='', PAREN='', COMMENT='', INT='', STRING='', FLOAT='', BOOL='', NONE='', END='', ) # Use python vscode theme colors if supports_color(): COLOR = Color( TYPE='\x1b[38;2;79;201;177m', ATTRIBUTE='\033[38;2;156;220;254m', SEP='\x1b[38;2;212;212;212m', PAREN='\x1b[38;2;255;213;3m', # COMMENT='\033[38;2;87;166;74m', COMMENT='\033[38;2;105;105;105m', # Dark gray INT='\x1b[38;2;182;207;169m', STRING='\x1b[38;2;207;144;120m', FLOAT='\x1b[38;2;182;207;169m', BOOL='\x1b[38;2;86;156;214m', NONE='\x1b[38;2;86;156;214m', END='\x1b[0m', ) else: COLOR = NO_COLOR @dataclasses.dataclass class ReprContext(threading.local): current_color: Color = COLOR depth: int = 0 REPR_CONTEXT = ReprContext() def colorized(x, /): c = REPR_CONTEXT.current_color if isinstance(x, list): return f'{c.PAREN}[{c.END}{", ".join(map(lambda i: colorized(i), x))}{c.PAREN}]{c.END}' elif isinstance(x, tuple): if len(x) == 1: return f'{c.PAREN}({c.END}{colorized(x[0])},{c.PAREN}){c.END}' return f'{c.PAREN}({c.END}{", ".join(map(lambda i: colorized(i), x))}{c.PAREN}){c.END}' elif isinstance(x, dict): open, close = '{', '}' return f'{c.PAREN}{open}{c.END}{", ".join(f"{c.STRING}{k!r}{c.END}: {colorized(v)}" for k, v in x.items())}{c.PAREN}{close}{c.END}' elif isinstance(x, set): open, close = '{', '}' return f'{c.PAREN}{open}{c.END}{", ".join(map(lambda i: colorized(i), x))}{c.PAREN}{close}{c.END}' elif isinstance(x, type): return f'{c.TYPE}{x.__name__}{c.END}' elif isinstance(x, bool): return f'{c.BOOL}{x}{c.END}' elif isinstance(x, int): return f'{c.INT}{x}{c.END}' elif isinstance(x, str): return f'{c.STRING}{x!r}{c.END}' elif isinstance(x, float): return f'{c.FLOAT}{x}{c.END}' elif x is None: return f'{c.NONE}{x}{c.END}' elif isinstance(x, Representable): return get_repr(x) else: return repr(x) @dataclasses.dataclass class Object: type: tp.Union[str, type] start: str = '(' end: str = ')' kv_sep: str = '=' indent: str = ' ' empty_repr: str = '' comment: str = '' same_line: bool = False @property def elem_sep(self): return ', ' if self.same_line else ',\n' @dataclasses.dataclass class Attr: key: str value: tp.Union[str, tp.Any] start: str = '' end: str = '' use_raw_value: bool = False use_raw_key: bool = False class Representable: __slots__ = () def __nnx_repr__(self) -> tp.Iterator[tp.Union[Object, Attr]]: raise NotImplementedError def __repr__(self) -> str: current_color = REPR_CONTEXT.current_color REPR_CONTEXT.current_color = NO_COLOR try: return get_repr(self) finally: REPR_CONTEXT.current_color = current_color def __str__(self) -> str: return get_repr(self) def get_repr(obj: Representable) -> str: REPR_CONTEXT.depth += 1 try: if not isinstance(obj, Representable): raise TypeError(f'Object {obj!r} is not representable') c = REPR_CONTEXT.current_color iterator = obj.__nnx_repr__() config = next(iterator) if not isinstance(config, Object): raise TypeError(f'First item must be Config, got {type(config).__name__}') kv_sep = f'{c.SEP}{config.kv_sep}{c.END}' def _repr_elem(elem: tp.Any) -> str: if not isinstance(elem, Attr): raise TypeError(f'Item must be Elem, got {type(elem).__name__}') value_repr = elem.value if elem.use_raw_value else colorized(elem.value) value_repr = value_repr.replace('\n', '\n' + config.indent) key = elem.key if elem.use_raw_key else f'{c.ATTRIBUTE}{elem.key}{c.END}' indent = '' if config.same_line else config.indent return f'{indent}{elem.start}{key}{kv_sep}{value_repr}{elem.end}' max_depth_reached = ( flax_config.flax_max_repr_depth is not None and REPR_CONTEXT.depth > flax_config.flax_max_repr_depth ) if max_depth_reached: elems = '...' else: elems = config.elem_sep.join(map(_repr_elem, iterator)) if elems: if config.same_line or max_depth_reached: elems_repr = elems comment = '' else: elems_repr = '\n' + elems + '\n' comment = f'{c.COMMENT}{config.comment}{c.END}' else: elems_repr = config.empty_repr comment = '' type_repr = ( config.type if isinstance(config.type, str) else config.type.__name__ ) type_repr = f'{c.TYPE}{type_repr}{c.END}' if type_repr else '' start = f'{c.PAREN}{config.start}{c.END}' if config.start else '' end = f'{c.PAREN}{config.end}{c.END}' if config.end else '' out = f'{type_repr}{start}{comment}{elems_repr}{end}' return out finally: REPR_CONTEXT.depth -= 1 class MappingReprMixin(Representable): def __nnx_repr__(self): yield Object(type=type(self), kv_sep=': ', start='({', end='})') for key, value in self.items(): # type: ignore yield Attr(colorized(key), value, use_raw_key=True) @dataclasses.dataclass(repr=False) class PrettyMapping(Representable): mapping: tp.Mapping def __nnx_repr__(self): yield Object(type=type(self), kv_sep=': ', start='({', end='})') for key, value in self.mapping.items(): yield Attr(colorized(key), value, use_raw_key=True) @dataclasses.dataclass(repr=False) class SequenceReprMixin(Representable): def __nnx_repr__(self): yield Object(type=type(self), kv_sep='', start='([', end='])') for value in self: # type: ignore yield Attr('', value, use_raw_key=True) @dataclasses.dataclass(repr=False) class PrettySequence(Representable): sequence: tp.Sequence def __nnx_repr__(self): yield Object(type=type(self), kv_sep='', start='([', end='])') for value in self.sequence: yield Attr('', value, use_raw_key=True) ================================================ FILE: flax/nnx/rnglib.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import functools import typing as tp import jax from jax import random import jax.numpy as jnp from flax import struct from flax import typing from flax.nnx import graphlib from flax.nnx.nn import initializers from flax.nnx.variablelib import Variable from flax.nnx import filterlib from flax.nnx.pytreelib import Pytree from flax.typing import MISSING, Key, Missing import warnings F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) A = tp.TypeVar('A') Counts = list[int] AxesValue = tp.Union[int, None] SplitPattern = tp.Union[AxesValue, tuple[AxesValue, ...]] OutShardingType: tp.TypeAlias = ( jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None ) Fargs = tp.ParamSpec('Fargs') @tp.runtime_checkable class KeylessInitializer(tp.Protocol): def __call__( self, shape: typing.Shape, dtype: tp.Any | None = None, out_sharding: OutShardingType = None, ) -> jax.Array: raise NotImplementedError def _to_keyless( initializer_constructor: tp.Callable[Fargs, jax.nn.initializers.Initializer], ) -> tp.Callable[Fargs, KeylessInitializer]: raise NotImplementedError def _function_to_method(random_f): @functools.wraps(random_f) def rngs_random_method(self: Rngs | RngStream, *args, **kwargs) -> jax.Array: return random_f(self(), *args, **kwargs) return rngs_random_method def _initializer_to_method( initializer_constructor: tp.Callable[Fargs, jax.nn.initializers.Initializer], ): def rngs_initializer_method( self: Rngs | RngStream, *args: Fargs.args, **kwargs: Fargs.kwargs ) -> KeylessInitializer: init_fn = initializer_constructor(*args, **kwargs) def rngs_keyless_initializer(*init_args, **init_kwargs): return init_fn(self(), *init_args, **init_kwargs) return rngs_keyless_initializer return rngs_initializer_method class RngState(Variable[jax.Array]): tag: str class RngCount(RngState): ... class RngKey(RngState): ... NotKey = filterlib.All(RngState, filterlib.Not(RngKey)) class RngStream(Pytree): def __init__( self, key: jax.Array | int, *, tag: str, ): if isinstance(key, int): key = random.key(key) elif isinstance(key, jax.Array) and key.dtype == jnp.uint32: key = random.wrap_key_data(key) if not isinstance(key, jax.Array) or not jnp.issubdtype(key.dtype, jax.dtypes.prng_key): raise ValueError(f'Invalid rng value: {key}, expected a ' f'jax.Array of jax.dtypes.prng_key sub-dtype') count = jnp.zeros(key.shape, dtype=jnp.uint32) self.tag = tag self.key = RngKey(key, tag=tag) self.count = RngCount(count, tag=tag) def __call__(self) -> jax.Array: self.count._check_can_update() key = random.fold_in(self.key[...], self.count[...]) self.count[...] += 1 return key def split(self, k: int | tuple[int, ...]): key = random.split(self(), k) return type(self)(key, tag=self.tag) def fork(self, *, split: int | tuple[int, ...] | None = None): if split is not None: warnings.warn( "The 'split' argument of 'fork' is deprecated; use the 'split' method instead.", DeprecationWarning, stacklevel=2, ) key = self() if split is not None: key = random.split(key, split) return type(self)(key, tag=self.tag) # ---------------------------------------------------------- # random functions # ---------------------------------------------------------- if tp.TYPE_CHECKING: bits = staticmethod(functools.partial(random.bits, random.key(0))) uniform = staticmethod( functools.partial(random.uniform, random.key(0)) ) randint = staticmethod( functools.partial(random.randint, random.key(0)) ) permutation = staticmethod( functools.partial(random.permutation, random.key(0)) ) choice = staticmethod(functools.partial(random.choice, random.key(0))) normal = staticmethod(functools.partial(random.normal, random.key(0))) multivariate_normal = staticmethod( functools.partial(random.multivariate_normal, random.key(0)) ) truncated_normal = staticmethod( functools.partial(random.truncated_normal, random.key(0)) ) bernoulli = staticmethod( functools.partial(random.bernoulli, random.key(0)) ) beta = staticmethod(functools.partial(random.beta, random.key(0))) cauchy = staticmethod(functools.partial(random.cauchy, random.key(0))) dirichlet = staticmethod( functools.partial(random.dirichlet, random.key(0)) ) exponential = staticmethod( functools.partial(random.exponential, random.key(0)) ) gamma = staticmethod(functools.partial(random.gamma, random.key(0))) loggamma = staticmethod( functools.partial(random.loggamma, random.key(0)) ) poisson = staticmethod( functools.partial(random.poisson, random.key(0)) ) gumbel = staticmethod(functools.partial(random.gumbel, random.key(0))) categorical = staticmethod( functools.partial(random.categorical, random.key(0)) ) laplace = staticmethod( functools.partial(random.laplace, random.key(0)) ) logistic = staticmethod( functools.partial(random.logistic, random.key(0)) ) pareto = staticmethod(functools.partial(random.pareto, random.key(0))) t = staticmethod(functools.partial(random.t, random.key(0))) chisquare = staticmethod( functools.partial(random.chisquare, random.key(0)) ) f = staticmethod(functools.partial(random.f, random.key(0))) rademacher = staticmethod( functools.partial(random.rademacher, random.key(0)) ) maxwell = staticmethod( functools.partial(random.maxwell, random.key(0)) ) double_sided_maxwell = staticmethod( functools.partial(random.double_sided_maxwell, random.key(0)) ) weibull_min = staticmethod( functools.partial(random.weibull_min, random.key(0)) ) orthogonal = staticmethod( functools.partial(random.orthogonal, random.key(0)) ) generalized_normal = staticmethod( functools.partial(random.generalized_normal, random.key(0)) ) ball = staticmethod(functools.partial(random.ball, random.key(0))) rayleigh = staticmethod( functools.partial(random.rayleigh, random.key(0)) ) wald = staticmethod(functools.partial(random.wald, random.key(0))) geometric = staticmethod( functools.partial(random.geometric, random.key(0)) ) triangular = staticmethod( functools.partial(random.triangular, random.key(0)) ) lognormal = staticmethod( functools.partial(random.lognormal, random.key(0)) ) binomial = staticmethod( functools.partial(random.binomial, random.key(0)) ) multinomial = staticmethod( functools.partial(random.multinomial, random.key(0)) ) else: bits = _function_to_method(random.bits) uniform = _function_to_method(random.uniform) randint = _function_to_method(random.randint) permutation = _function_to_method(random.permutation) choice = _function_to_method(random.choice) normal = _function_to_method(random.normal) multivariate_normal = _function_to_method(random.multivariate_normal) truncated_normal = _function_to_method(random.truncated_normal) bernoulli = _function_to_method(random.bernoulli) beta = _function_to_method(random.beta) cauchy = _function_to_method(random.cauchy) dirichlet = _function_to_method(random.dirichlet) exponential = _function_to_method(random.exponential) gamma = _function_to_method(random.gamma) loggamma = _function_to_method(random.loggamma) poisson = _function_to_method(random.poisson) gumbel = _function_to_method(random.gumbel) categorical = _function_to_method(random.categorical) laplace = _function_to_method(random.laplace) logistic = _function_to_method(random.logistic) pareto = _function_to_method(random.pareto) t = _function_to_method(random.t) chisquare = _function_to_method(random.chisquare) f = _function_to_method(random.f) rademacher = _function_to_method(random.rademacher) maxwell = _function_to_method(random.maxwell) double_sided_maxwell = _function_to_method(random.double_sided_maxwell) weibull_min = _function_to_method(random.weibull_min) orthogonal = _function_to_method(random.orthogonal) generalized_normal = _function_to_method(random.generalized_normal) ball = _function_to_method(random.ball) rayleigh = _function_to_method(random.rayleigh) wald = _function_to_method(random.wald) geometric = _function_to_method(random.geometric) triangular = _function_to_method(random.triangular) lognormal = _function_to_method(random.lognormal) binomial = _function_to_method(random.binomial) multinomial = _function_to_method(random.multinomial) # ---------------------------------------------------------- # initializers # ---------------------------------------------------------- if tp.TYPE_CHECKING: # skip constant delta_orthogonal = staticmethod(_to_keyless(initializers.delta_orthogonal)) glorot_normal = staticmethod(_to_keyless(initializers.glorot_normal)) glorot_uniform = staticmethod(_to_keyless(initializers.glorot_uniform)) he_normal = staticmethod(_to_keyless(initializers.he_normal)) he_uniform = staticmethod(_to_keyless(initializers.he_uniform)) kaiming_normal = staticmethod(_to_keyless(initializers.kaiming_normal)) kaiming_uniform = staticmethod(_to_keyless(initializers.kaiming_uniform)) lecun_normal = staticmethod(_to_keyless(initializers.lecun_normal)) lecun_uniform = staticmethod(_to_keyless(initializers.lecun_uniform)) # skip normal as it conflicts with jax.random.normal # skip ones # skip orthogonal as it conflicts with jax.random.orthogonal # skip truncated_normal as it conflicts with jax.random.truncated_normal # skip uniform as it conflicts with jax.random.uniform variance_scaling = staticmethod(_to_keyless(initializers.variance_scaling)) xavier_normal = staticmethod(_to_keyless(initializers.xavier_normal)) xavier_uniform = staticmethod(_to_keyless(initializers.xavier_uniform)) # skip zeros else: # skip constant delta_orthogonal = _initializer_to_method(initializers.delta_orthogonal) glorot_normal = _initializer_to_method(initializers.glorot_normal) glorot_uniform = _initializer_to_method(initializers.glorot_uniform) he_normal = _initializer_to_method(initializers.he_normal) he_uniform = _initializer_to_method(initializers.he_uniform) kaiming_normal = _initializer_to_method(initializers.kaiming_normal) kaiming_uniform = _initializer_to_method(initializers.kaiming_uniform) lecun_normal = _initializer_to_method(initializers.lecun_normal) lecun_uniform = _initializer_to_method(initializers.lecun_uniform) # skip normal as it conflicts with jax.random.normal # skip ones # skip orthogonal as it conflicts with jax.random.orthogonal # skip truncated_normal as it conflicts with jax.random.truncated_normal # skip uniform as it conflicts with jax.random.uniform variance_scaling = _initializer_to_method(initializers.variance_scaling) xavier_normal = _initializer_to_method(initializers.xavier_normal) xavier_uniform = _initializer_to_method(initializers.xavier_uniform) # skip zeros RngValue = tp.Union[int, jax.Array] class Rngs(Pytree): """A small abstraction to manage RNG state. ``Rngs`` allows the creation of ``RngStream`` which are used to easily generate new unique random keys on demand. An ``RngStream`` is a wrapper around a JAX random ``key``, and a ``counter``. Every time a key is requested, the counter is incremented and the key is generated from the seed key and the counter by using ``jax.random.fold_in``. To create an ``Rngs`` pass in an integer or ``jax.random.key`` to the constructor as a keyword argument with the name of the stream. The key will be used as the starting seed for the stream, and the counter will be initialized to zero. Then call the stream to get a key:: >>> from flax import nnx >>> import jax, jax.numpy as jnp >>> rngs = nnx.Rngs(params=0, dropout=1) >>> param_key1 = rngs.params() >>> param_key2 = rngs.params() >>> dropout_key1 = rngs.dropout() >>> dropout_key2 = rngs.dropout() ... >>> assert param_key1 != dropout_key1 Trying to generate a key for a stream that was not specified during construction will result in an error being raised:: >>> rngs = nnx.Rngs(params=0, dropout=1) >>> try: ... key = rngs.unkown_stream() ... except AttributeError as e: ... print(e) No RngStream named 'unkown_stream' found in Rngs. The ``default`` stream can be created by passing in a key to the constructor without specifying a stream name. When the ``default`` stream is set the ``rngs`` object can be called directly to get a key, and calling streams that were not specified during construction will fallback to ``default``:: >>> rngs = nnx.Rngs(0, params=1) ... >>> key1 = rngs.default() # uses 'default' >>> key2 = rngs() # uses 'default' >>> key3 = rngs.params() # uses 'params' >>> key4 = rngs.dropout() # uses 'default' >>> key5 = rngs.unkown_stream() # uses 'default' """ def __init__( self, default: RngValue | RngStream | tp.Mapping[str, RngValue | RngStream] | None = None, **rngs: RngValue | RngStream, ): """ Args: default: the starting seed for the ``default`` stream, defaults to None. **rngs: keyword arguments specifying the starting seed for each stream. The key can be an integer or a ``jax.random.key``. """ if default is not None: if isinstance(default, tp.Mapping): rngs = {**default, **rngs} else: rngs['default'] = default for tag, key in rngs.items(): if isinstance(key, RngStream): key = key.key.get_value() stream = RngStream( key=key, tag=tag, ) setattr(self, tag, stream) def _get_stream(self, name: str, error_type: type[Exception]) -> RngStream: stream_vars = vars(self) if name not in stream_vars: if 'default' not in stream_vars: raise error_type(f"No RngStream named '{name}' found in Rngs.") stream = stream_vars['default'] else: stream = stream_vars[name] return stream def __getitem__(self, name: str): return self._get_stream(name, KeyError) def __getattr__(self, name: str): return self._get_stream(name, AttributeError) def __call__(self): return self.default() def __iter__(self) -> tp.Iterator[str]: for name, stream in vars(self).items(): if isinstance(stream, RngStream): yield name def __len__(self) -> int: return sum( 1 for stream in vars(self).values() if isinstance(stream, RngStream) ) def __contains__(self, name: tp.Any) -> bool: return name in vars(self) def items(self): for name, stream in vars(self).items(): if isinstance(stream, RngStream): yield name, stream def split(self, k: tp.Mapping[filterlib.Filter, int | tuple[int, ...]] | int | tuple[int, ...]): """ Splits the keys of the newly created ``Rngs`` object. Example:: >>> from flax import nnx ... >>> rngs = nnx.Rngs(params=1, dropout=2) >>> new_rngs = rngs.split(5) ... >>> assert new_rngs.params.key.shape == (5,) >>> assert new_rngs.dropout.key.shape == (5,) ``split`` also accepts a mapping of `Filters `__ to split sizes or None to control which streams are split and how they are split:: >>> rngs = nnx.Rngs(params=1, dropout=2, noise=3) >>> new_rngs = rngs.split({ ... 'params': 5, # split params into 5 keys ... 'dropout': None, # don't split dropout ... ...: (2, 5), # split anything else into 2x5 keys ... }) ... >>> assert new_rngs.params.key.shape == (5,) >>> assert new_rngs.dropout.key.shape == () >>> assert new_rngs.noise.key.shape == (2, 5) """ if isinstance(k, int): k = {...: k} elif isinstance(k, tuple): k = {...: k} split_predicates = {filterlib.to_predicate(k): v for k, v in k.items()} keys: dict[str, RngStream] = {} for name, stream in self.items(): for predicate, num_splits in split_predicates.items(): if predicate((), stream): if num_splits is None: keys[name] = stream else: keys[name] = stream.split(num_splits) break else: keys[name] = stream return Rngs(**keys) def fork( self, /, *, split: tp.Mapping[filterlib.Filter, int | tuple[int, ...]] | int | tuple[int, ...] | None = None, ): """Returns a new Rngs object with new unique RNG keys. Example:: >>> from flax import nnx ... >>> rngs = nnx.Rngs(params=1, dropout=2) >>> new_rngs = rngs.fork() ... >>> assert rngs.params() != new_rngs.params() """ if split is not None: warnings.warn( "The 'split' argument of 'fork' is deprecated; use the 'split' method instead.", DeprecationWarning, stacklevel=2, ) if split is None: split = {} elif isinstance(split, int): split = {...: split} elif isinstance(split, tuple): split = {...: split} split_predicates = {filterlib.to_predicate(k): v for k, v in split.items()} keys: dict[str, RngStream] = {} for name, stream in self.items(): for predicate, num_splits in split_predicates.items(): if predicate((), stream): keys[name] = stream.fork(split=num_splits) break else: keys[name] = stream.fork() return Rngs(**keys) # ---------------------------------------------------------- # random functions # ---------------------------------------------------------- if tp.TYPE_CHECKING: bits = staticmethod(functools.partial(random.bits, random.key(0))) uniform = staticmethod( functools.partial(random.uniform, random.key(0)) ) randint = staticmethod( functools.partial(random.randint, random.key(0)) ) permutation = staticmethod( functools.partial(random.permutation, random.key(0)) ) choice = staticmethod(functools.partial(random.choice, random.key(0))) normal = staticmethod(functools.partial(random.normal, random.key(0))) multivariate_normal = staticmethod( functools.partial(random.multivariate_normal, random.key(0)) ) truncated_normal = staticmethod( functools.partial(random.truncated_normal, random.key(0)) ) bernoulli = staticmethod( functools.partial(random.bernoulli, random.key(0)) ) beta = staticmethod(functools.partial(random.beta, random.key(0))) cauchy = staticmethod(functools.partial(random.cauchy, random.key(0))) dirichlet = staticmethod( functools.partial(random.dirichlet, random.key(0)) ) exponential = staticmethod( functools.partial(random.exponential, random.key(0)) ) gamma = staticmethod(functools.partial(random.gamma, random.key(0))) loggamma = staticmethod( functools.partial(random.loggamma, random.key(0)) ) poisson = staticmethod( functools.partial(random.poisson, random.key(0)) ) gumbel = staticmethod(functools.partial(random.gumbel, random.key(0))) categorical = staticmethod( functools.partial(random.categorical, random.key(0)) ) laplace = staticmethod( functools.partial(random.laplace, random.key(0)) ) logistic = staticmethod( functools.partial(random.logistic, random.key(0)) ) pareto = staticmethod(functools.partial(random.pareto, random.key(0))) t = staticmethod(functools.partial(random.t, random.key(0))) chisquare = staticmethod( functools.partial(random.chisquare, random.key(0)) ) f = staticmethod(functools.partial(random.f, random.key(0))) rademacher = staticmethod( functools.partial(random.rademacher, random.key(0)) ) maxwell = staticmethod( functools.partial(random.maxwell, random.key(0)) ) double_sided_maxwell = staticmethod( functools.partial(random.double_sided_maxwell, random.key(0)) ) weibull_min = staticmethod( functools.partial(random.weibull_min, random.key(0)) ) orthogonal = staticmethod( functools.partial(random.orthogonal, random.key(0)) ) generalized_normal = staticmethod( functools.partial(random.generalized_normal, random.key(0)) ) ball = staticmethod(functools.partial(random.ball, random.key(0))) rayleigh = staticmethod( functools.partial(random.rayleigh, random.key(0)) ) wald = staticmethod(functools.partial(random.wald, random.key(0))) geometric = staticmethod( functools.partial(random.geometric, random.key(0)) ) triangular = staticmethod( functools.partial(random.triangular, random.key(0)) ) lognormal = staticmethod( functools.partial(random.lognormal, random.key(0)) ) binomial = staticmethod( functools.partial(random.binomial, random.key(0)) ) multinomial = staticmethod( functools.partial(random.multinomial, random.key(0)) ) else: bits = _function_to_method(random.bits) uniform = _function_to_method(random.uniform) randint = _function_to_method(random.randint) permutation = _function_to_method(random.permutation) choice = _function_to_method(random.choice) normal = _function_to_method(random.normal) multivariate_normal = _function_to_method(random.multivariate_normal) truncated_normal = _function_to_method(random.truncated_normal) bernoulli = _function_to_method(random.bernoulli) beta = _function_to_method(random.beta) cauchy = _function_to_method(random.cauchy) dirichlet = _function_to_method(random.dirichlet) exponential = _function_to_method(random.exponential) gamma = _function_to_method(random.gamma) loggamma = _function_to_method(random.loggamma) poisson = _function_to_method(random.poisson) gumbel = _function_to_method(random.gumbel) categorical = _function_to_method(random.categorical) laplace = _function_to_method(random.laplace) logistic = _function_to_method(random.logistic) pareto = _function_to_method(random.pareto) t = _function_to_method(random.t) chisquare = _function_to_method(random.chisquare) f = _function_to_method(random.f) rademacher = _function_to_method(random.rademacher) maxwell = _function_to_method(random.maxwell) double_sided_maxwell = _function_to_method(random.double_sided_maxwell) weibull_min = _function_to_method(random.weibull_min) orthogonal = _function_to_method(random.orthogonal) generalized_normal = _function_to_method(random.generalized_normal) ball = _function_to_method(random.ball) rayleigh = _function_to_method(random.rayleigh) wald = _function_to_method(random.wald) geometric = _function_to_method(random.geometric) triangular = _function_to_method(random.triangular) lognormal = _function_to_method(random.lognormal) binomial = _function_to_method(random.binomial) multinomial = _function_to_method(random.multinomial) # ---------------------------------------------------------- # initializers # ---------------------------------------------------------- if tp.TYPE_CHECKING: # skip constant delta_orthogonal = staticmethod(_to_keyless(initializers.delta_orthogonal)) glorot_normal = staticmethod(_to_keyless(initializers.glorot_normal)) glorot_uniform = staticmethod(_to_keyless(initializers.glorot_uniform)) he_normal = staticmethod(_to_keyless(initializers.he_normal)) he_uniform = staticmethod(_to_keyless(initializers.he_uniform)) kaiming_normal = staticmethod(_to_keyless(initializers.kaiming_normal)) kaiming_uniform = staticmethod(_to_keyless(initializers.kaiming_uniform)) lecun_normal = staticmethod(_to_keyless(initializers.lecun_normal)) lecun_uniform = staticmethod(_to_keyless(initializers.lecun_uniform)) # skip normal as it conflicts with jax.random.normal # skip ones # skip orthogonal as it conflicts with jax.random.orthogonal # skip truncated_normal as it conflicts with jax.random.truncated_normal # skip uniform as it conflicts with jax.random.uniform variance_scaling = staticmethod(_to_keyless(initializers.variance_scaling)) xavier_normal = staticmethod(_to_keyless(initializers.xavier_normal)) xavier_uniform = staticmethod(_to_keyless(initializers.xavier_uniform)) # skip zeros else: # skip constant delta_orthogonal = _initializer_to_method(initializers.delta_orthogonal) glorot_normal = _initializer_to_method(initializers.glorot_normal) glorot_uniform = _initializer_to_method(initializers.glorot_uniform) he_normal = _initializer_to_method(initializers.he_normal) he_uniform = _initializer_to_method(initializers.he_uniform) kaiming_normal = _initializer_to_method(initializers.kaiming_normal) kaiming_uniform = _initializer_to_method(initializers.kaiming_uniform) lecun_normal = _initializer_to_method(initializers.lecun_normal) lecun_uniform = _initializer_to_method(initializers.lecun_uniform) # skip normal as it conflicts with jax.random.normal # skip ones # skip orthogonal as it conflicts with jax.random.orthogonal # skip truncated_normal as it conflicts with jax.random.truncated_normal # skip uniform as it conflicts with jax.random.uniform variance_scaling = _initializer_to_method(initializers.variance_scaling) xavier_normal = _initializer_to_method(initializers.xavier_normal) xavier_uniform = _initializer_to_method(initializers.xavier_uniform) # skip zeros StreamBackup = ( tuple[RngStream, jax.Array, jax.Array] | tuple[RngStream, jax.Array] ) class SplitBackups(struct.PyTreeNode, tp.Iterable[StreamBackup]): backups: list[StreamBackup] def __iter__(self) -> tp.Iterator[StreamBackup]: return iter(self.backups) def __enter__(self): return self def __exit__(self, *args): restore_rngs(self) @tp.overload def split_rngs( node: tp.Any, /, *, splits: int | tuple[int, ...], only: filterlib.Filter = ..., squeeze: bool = False, graph: tp.Literal[True] | None = None, ) -> SplitBackups: ... @tp.overload def split_rngs( node: A, /, *, splits: int | tuple[int, ...], only: filterlib.Filter = ..., squeeze: bool = False, graph: tp.Literal[False], ) -> A: ... @tp.overload def split_rngs( *, splits: int | tuple[int, ...], only: filterlib.Filter = ..., squeeze: bool = False, graph: bool | None = None, ) -> tp.Callable[[F], F]: ... def split_rngs( node: tp.Any = MISSING, /, *, splits: int | tuple[int, ...], only: filterlib.Filter = ..., squeeze: bool = False, graph: bool | None = None, ) -> SplitBackups | tp.Any | tp.Callable[[F], F]: """Splits the (nested) Rng states of the given node. Args: node: the base node containing the rng states to split. splits: an integer or tuple of integers specifying the shape of the split rng keys. only: a Filter selecting which rng states to split. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Returns: A SplitBackups iterable if ``node`` is provided, otherwise a decorator that splits the rng states of the inputs to the decorated function. Example:: >>> from flax import nnx ... >>> rngs = nnx.Rngs(params=0, dropout=1) >>> _ = nnx.split_rngs(rngs, splits=5) >>> rngs.params.key.shape, rngs.dropout.key.shape ((5,), (5,)) >>> rngs = nnx.Rngs(params=0, dropout=1) >>> _ = nnx.split_rngs(rngs, splits=(2, 5)) >>> rngs.params.key.shape, rngs.dropout.key.shape ((2, 5), (2, 5)) >>> rngs = nnx.Rngs(params=0, dropout=1) >>> _ = nnx.split_rngs(rngs, splits=5, only='params') >>> rngs.params.key.shape, rngs.dropout.key.shape ((5,), ()) Once split, random state can be used with transforms like :func:`nnx.vmap`:: >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... self.dropout = nnx.Dropout(0.5, rngs=rngs) ... >>> rngs = nnx.Rngs(params=0, dropout=1) >>> _ = nnx.split_rngs(rngs, splits=5, only='params') ... >>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None}) ... >>> @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes) ... def create_model(rngs): ... return Model(rngs) ... >>> model = create_model(rngs) >>> model.dropout.rngs.key.shape () ``split_rngs`` returns a SplitBackups object that can be used to restore the original unsplit rng states using :func:`nnx.restore_rngs`, this is useful when you only want to split the rng states temporarily:: >>> rngs = nnx.Rngs(params=0, dropout=1) ... >>> backups = nnx.split_rngs(rngs, splits=5, only='params') >>> model = create_model(rngs) >>> nnx.restore_rngs(backups) ... >>> model.dropout.rngs.key.shape () SplitBackups can also be used as a context manager to automatically restore the rng states when exiting the context:: >>> rngs = nnx.Rngs(params=0, dropout=1) ... >>> with nnx.split_rngs(rngs, splits=5, only='params'): ... model = create_model(rngs) ... >>> model.dropout.rngs.key.shape () >>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None}) ... >>> @nnx.split_rngs(splits=5, only='params') ... @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes) ... def create_model(rngs): ... return Model(rngs) ... >>> rngs = nnx.Rngs(params=0, dropout=1) >>> model = create_model(rngs) >>> model.dropout.rngs.key.shape () """ if graph is None: graph = graphlib.set_graph_mode.current_value() if isinstance(node, Missing): def split_rngs_decorator(f: F) -> F: @functools.wraps(f) def split_rngs_wrapper(*args, **kwargs): if graph: with split_rngs( (args, kwargs), splits=splits, only=only, squeeze=squeeze, graph=True, ): return f(*args, **kwargs) else: args, kwargs = split_rngs( (args, kwargs), splits=splits, only=only, squeeze=squeeze, graph=False, ) return f(*args, **kwargs) return tp.cast(F, split_rngs_wrapper) return split_rngs_decorator # type: ignore[bad-return-type] if squeeze and splits != 1: raise ValueError('squeeze=True is only supported for splits=1') if graph: return _graph_split_rngs( node, splits=splits, only=only, squeeze=squeeze, ) else: return _tree_split_rngs( node, splits=splits, only=only, squeeze=squeeze, ) def _graph_split_rngs( node: tp.Any, /, *, splits: int | tuple[int, ...], only: filterlib.Filter = ..., squeeze: bool = False, ) -> SplitBackups: predicate = filterlib.to_predicate(only) backups: list[StreamBackup] = [] for path, stream in graphlib.iter_graph(node, graph=True): if ( isinstance(stream, RngStream) and predicate((*path, 'key'), stream.key) and predicate((*path, 'count'), stream.count) ): key = stream() backups.append((stream, stream.key[...], stream.count[...])) key = random.split(key, splits) if squeeze: key = key[0] stream.key.set_value(key) if squeeze: counts_shape = stream.count.shape elif isinstance(splits, int): counts_shape = (splits, *stream.count.shape) else: counts_shape = (*splits, *stream.count.shape) stream.count.set_value(jnp.zeros(counts_shape, dtype=jnp.uint32)) return SplitBackups(backups) def _tree_split_rngs( node: tp.Any, /, *, splits: int | tuple[int, ...], only: filterlib.Filter = ..., squeeze: bool = False, ) -> tp.Any: predicate = filterlib.to_predicate(only) def _split_stream(path, node): if ( isinstance(node, RngStream) and predicate((*path, 'key'), node.key) and predicate((*path, 'count'), node.count) ): key = random.split(node(), splits) if squeeze: key = key[0] if squeeze: counts_shape = node.count.shape elif isinstance(splits, int): counts_shape = (splits, *node.count.shape) else: counts_shape = (*splits, *node.count.shape) node.key = RngKey(key, tag=node.tag) node.count = RngCount( jnp.zeros(counts_shape, dtype=jnp.uint32), tag=node.tag ) return node return graphlib.recursive_map(_split_stream, node, graph=False) @tp.overload def fork_rngs( node: tp.Any, /, *, split: tp.Mapping[filterlib.Filter, int | tuple[int, ...] | None] | int | None = None, graph: bool | None = None, ) -> SplitBackups: ... @tp.overload def fork_rngs( *, split: tp.Mapping[filterlib.Filter, int | tuple[int, ...] | None] | int | None = None, graph: bool | None = None, ) -> tp.Callable[[F], F]: ... def fork_rngs( node: tp.Any = MISSING, /, *, split: tp.Mapping[filterlib.Filter, int | tuple[int, ...] | None] | int | None = None, graph: bool | None = None, ) -> SplitBackups | tp.Callable[[F], F]: """Forks the (nested) Rng states of the given node. Args: node: the base node containing the rng states to fork. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Returns: A SplitBackups iterable if ``node`` is provided, otherwise a decorator that forks the rng states of the inputs to the decorated function. Example:: >>> from flax import nnx ... >>> rngs = nnx.Rngs(params=0, dropout=1) >>> _ = nnx.fork_rngs(rngs) ``fork_rngs`` returns a SplitBackups object that can be used to restore the original unforked rng states using :func:`nnx.restore_rngs`, this is useful when you only want to fork the rng states temporarily:: >>> rngs = nnx.Rngs(params=0, dropout=1) ... >>> backups = nnx.fork_rngs(rngs) >>> model = nnx.Linear(2, 3, rngs=rngs) >>> nnx.restore_rngs(backups) ... SplitBackups can also be used as a context manager to automatically restore the rng states when exiting the context:: >>> rngs = nnx.Rngs(params=0, dropout=1) ... >>> with nnx.fork_rngs(rngs): ... model = nnx.Linear(2, 3, rngs=rngs) """ if isinstance(node, Missing): def fork_rngs_decorator(f: F) -> F: @functools.wraps(f) def fork_rngs_wrapper(*args, **kwargs): with fork_rngs((args, kwargs), split=split): return f(*args, **kwargs) return tp.cast(F, fork_rngs_wrapper) return fork_rngs_decorator # type: ignore[bad-return-type] if split is None: split = {...: None} elif isinstance(split, int | tuple): split = {...: split} predicate_splits = { filterlib.to_predicate(k): v for k, v in split.items() } backups: list[StreamBackup] = [] for path, stream in graphlib.iter_graph(node, graph=graph): for predicate, splits in predicate_splits.items(): if ( isinstance(stream, RngStream) and predicate((*path, 'key'), stream.key) and predicate((*path, 'count'), stream.count) ): forked_stream = stream.fork(split=splits) # backup the original stream state backups.append((stream, stream.key[...], stream.count[...])) # apply the forked key and count to the original stream stream.key.set_value(forked_stream.key.get_value()) stream.count.set_value(forked_stream.count.get_value()) return SplitBackups(backups) def backup_keys(node: tp.Any, /, *, graph: bool | None = None): backups: list[StreamBackup] = [] for _, stream in graphlib.iter_graph(node, graph=graph): if isinstance(stream, RngStream): backups.append((stream, stream.key[...])) return backups def _scalars_only( path: tuple[Key, ...], scalar_key: jax.Array, target_shape: tuple[int, ...] ) -> jax.Array: if target_shape != (): raise ValueError( f'Cannot reseed stream at path {path!r} becuase it has a non-scalar key, ' f'found key with shape {target_shape}. If all your multi-dimensional ' 'keys have unique values on all dimensions, set policy="match_shape", ' 'else provide a custom reseed policy.' ) return scalar_key def _match_shape( path: tuple[Key, ...], scalar_key: jax.Array, target_shape: tuple[int, ...] ) -> jax.Array: if target_shape == (): return scalar_key return random.split(scalar_key, target_shape) def reseed( node, /, *, graph: bool | None = None, policy: tp.Literal['scalars_only', 'match_shape'] | tp.Callable[ [tuple, jax.Array, tuple[int, ...]], jax.Array ] = 'scalars_only', **stream_keys: RngValue, ): """Update the keys of the specified RNG streams with new keys. Args: node: the node to reseed the RNG streams in. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. policy: defines how the the new scalar key is for each RngStream is used to reseed the stream. If ``'scalars_only'`` is given (the default), an error is raised if the target stream key is not a scalar. If ``'match_shape'`` is given, the new scalar key is split to match the shape of the target stream key. A callable of the form ``(path, scalar_key, target_shape) -> new_key`` can be passed to define a custom reseeding policy. **stream_keys: a mapping of stream names to new keys. The keys can be either integers or ``jax.random.key``. Example:: >>> from flax import nnx >>> import jax.numpy as jnp ... >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... self.dropout = nnx.Dropout(0.5, rngs=rngs) ... def __call__(self, x): ... return self.dropout(self.linear(x)) ... >>> model = Model(nnx.Rngs(params=0, dropout=42)) >>> x = jnp.ones((1, 2)) ... >>> y1 = model(x) ... >>> # reset the ``dropout`` stream key to 42 >>> nnx.reseed(model, dropout=42) >>> y2 = model(x) ... >>> jnp.allclose(y1, y2) Array(True, dtype=bool) """ if policy == 'scalars_only': policy = _scalars_only elif policy == 'match_shape': policy = _match_shape elif not callable(policy): raise ValueError( f'policy must be "scalars_only", "match_shape" or a callable, ' f'got {policy!r}' ) rngs = Rngs(**stream_keys) for path, stream in graphlib.iter_graph(node, graph=graph): if isinstance(stream, RngStream): if stream.key.tag in stream_keys: key = rngs[stream.key.tag]() key = policy(path, key, stream.key.shape) stream.key.set_value(key) stream.count.set_value(jnp.zeros(key.shape, dtype=jnp.uint32)) def restore_rngs(backups: tp.Iterable[StreamBackup], /): for backup in backups: stream = backup[0] stream.key.set_value(backup[1]) if len(backup) == 3: stream.count.set_value(backup[2]) # count ================================================ FILE: flax/nnx/scripts/requirements.txt ================================================ datasets>=2.12.0 ================================================ FILE: flax/nnx/scripts/run-all-examples.bash ================================================ set -e source .venv/bin/activate for f in $(find examples/nnx_toy_examples -name "*.py" -maxdepth 1); do echo -e "\n---------------------------------" echo "$f" echo "---------------------------------" MPLBACKEND=Agg python "$f" done ================================================ FILE: flax/nnx/spmd.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import typing as tp import flax.core.spmd as core_spmd from flax.nnx import variablelib, graphlib from flax.nnx.transforms.transforms import eval_shape from flax.typing import ( Sharding, ) import jax from jax.sharding import PartitionSpec A = tp.TypeVar('A') F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) PARTITION_NAME = 'partition_name' # Transform axis change helpers # ------------------------------------------------------------------------------ def add_axis(tree: A, index: int, transform_metadata: tp.Mapping) -> A: axis_name, other_meta = _get_partition_name_and_metadata(transform_metadata) def insert_field(fields, index, value): iterable = list(fields) while len(iterable) < index: iterable.append(None) iterable.insert(index, value) return tuple(iterable) def _add_axis(x: tp.Any): if isinstance(x, variablelib.Variable): metadata = x.get_metadata() if 'out_sharding' in metadata and metadata['out_sharding']: sharding = metadata['out_sharding'] x.set_metadata(out_sharding=insert_field(sharding, index, axis_name)) for k, v in other_meta.items(): if hasattr(x, k) and (t := getattr(x, k)) and isinstance(t, tuple): x.set_metadata(k, insert_field(t, index, v)) assert isinstance(x, variablelib.Variable) x.add_axis(index, axis_name) return x return jax.tree.map( _add_axis, tree, is_leaf=lambda x: isinstance(x, variablelib.Variable) ) def remove_axis( tree: A, index: int, transform_metadata: tp.Mapping[tp.Any, tp.Any] ) -> A: axis_name, other_meta = _get_partition_name_and_metadata(transform_metadata) def remove_field(fields, index, value): iterable = list(fields) removed = iterable.pop(index) if removed != value: raise ValueError( f'Expected to remove {value!r} at index {index} from ' f'{fields!r}, but found {removed!r}.' ) return tuple(iterable) def _remove_axis(x: tp.Any): if isinstance(x, variablelib.Variable): if hasattr(x, 'out_sharding') and x.out_sharding is not None: x.set_metadata( out_sharding=remove_field(x.out_sharding, index, axis_name) ) for k, v in other_meta.items(): if hasattr(x, k) and (t := getattr(x, k)) and isinstance(t, tuple): x.set_metadata(k, remove_field(t, index, v)) x.remove_axis(index, axis_name) return x return jax.tree.map( _remove_axis, tree, is_leaf=lambda x: isinstance(x, variablelib.Variable), ) def _get_partition_name_and_metadata( transform_metadata: tp.Mapping[tp.Any, tp.Any], ) -> tuple[str, tp.Mapping[tp.Any, tp.Any]]: if PARTITION_NAME not in transform_metadata: raise ValueError( 'Trying to transform a Partitioned variable but "partition_name" ' f'is not specified in transform_metadata: {transform_metadata}' ) other_meta = dict(transform_metadata) # shallow copy other_meta.pop(PARTITION_NAME) return transform_metadata[PARTITION_NAME], other_meta # Annotation handling # ------------------------------------------------------------------------------ def with_partitioning( initializer: F, sharding: Sharding, mesh: tp.Optional[jax.sharding.Mesh] = None, **metadata: tp.Any, ) -> F: """A wrapper over any initializer to add sharding annotation data to a `Variable`.""" return variablelib.with_metadata( initializer, out_sharding=sharding, mesh=mesh, **metadata, ) def get_var_pspec(v: variablelib.Variable) -> PartitionSpec | None: """Given an `nnx.Variable`, return its `PartitionSpec`.""" metadata = v.get_metadata() if 'out_sharding' in metadata and metadata['out_sharding']: sharding = metadata['out_sharding'] if core_spmd.get_logical_axis_rules() or 'sharding_rules' in metadata: context_rules = core_spmd.get_logical_axis_rules() local_rules = metadata.get('sharding_rules', ()) rules = core_spmd.composite_rules(context_rules, local_rules) return PartitionSpec(*core_spmd.from_sharding_rules(sharding, rules)) return PartitionSpec(*sharding) elif hasattr(v, 'shape'): return PartitionSpec() return None def get_partition_spec(tree: A) -> A: """Extracts a PartitionSpec tree from a PyTree containing ``Variable`` values.""" def f(x): if isinstance(x, variablelib.Variable): return x.replace(get_var_pspec(x)) elif hasattr(x, 'shape'): return PartitionSpec() return None return jax.tree.map( f, tree, is_leaf=lambda x: isinstance(x, variablelib.Variable) ) def get_named_sharding(tree: A, mesh: jax.sharding.Mesh) -> A: spec = get_partition_spec(tree) sharding = jax.tree.map(lambda p: jax.sharding.NamedSharding(mesh, p), spec) return sharding # Other utilities # ------------------------------------------------------------------------------ def get_abstract_model(init_fn, mesh, *, graph: bool | None = None): with jax.set_mesh(mesh): abs_model = eval_shape(init_fn, graph=graph) gdef, abs_state = graphlib.split(abs_model, graph=graph) abs_state = jax.tree.map( lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s), abs_state, get_named_sharding(abs_state, mesh) ) return gdef, abs_state def abstract_with_sharding( tree: A, graph: bool | None = None ) -> A: """Add sharding information to abstract Variables. When creating models with :func:`eval_shape`, Variables are abstract (backed by ``jax.ShapeDtypeStruct``) and may not carry sharding information, especially when using meshes with :attr:`jax.sharding.AxisType.Auto` axes. ``abstract_with_sharding`` inspects each Variable in ``tree`` and, if it has ``out_sharding`` metadata but no sharding already set, attaches a :class:`jax.sharding.NamedSharding` derived from the Variable's ``out_sharding`` and either its ``mesh`` metadata or the current abstract mesh (``jax.sharding.get_abstract_mesh``). Example usage:: from flax import nnx import jax mesh = jax.make_mesh((2, 2), ('a', 'b'), axis_types=(jax.sharding.AxisType.Auto,) * 2) with jax.set_mesh(mesh): abs_model = nnx.eval_shape( lambda: nnx.Linear(4, 8, rngs=nnx.Rngs(0), kernel_metadata={'out_sharding': ('a', 'b')})) abs_model = nnx.abstract_with_sharding(abs_model) assert abs_model.kernel.sharding.spec == jax.P('a', 'b') Args: tree: A graph node (e.g. an :class:`nnx.Module`) whose Variables should be annotated with sharding (via ``out_sharding`` metadata). graph: Forwarded to :func:`nnx.map`. If ``True``, uses graph-mode; if ``False``, uses tree-mode. Returns: A tree with sharding-annotated ShapeDtypeStruct values inside Variables. """ def add_sharding(_path, x): if ( isinstance(x, variablelib.Variable) and hasattr(value := x.get_value(), 'shape') and hasattr(value, 'dtype') and getattr(value, 'sharding', None) is None and x.has_metadata('out_sharding') ): if x.has_metadata('mesh'): mesh = x.get_metadata('mesh') else: mesh = jax.sharding.get_abstract_mesh() specs = get_var_pspec(x) sharding = jax.sharding.NamedSharding(mesh, specs) abs_var = x.replace( jax.ShapeDtypeStruct(value.shape, value.dtype, sharding=sharding) ) return abs_var return x return graphlib.map(add_sharding, tree, graph=graph) ================================================ FILE: flax/nnx/statelib.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # pytype: skip-file from __future__ import annotations import typing as tp from collections.abc import MutableMapping from functools import partial import warnings import jax import jax.tree_util as jtu import treescope # type: ignore[import-not-found,import-untyped] from flax.nnx import filterlib, reprlib, traversals, variablelib from flax.typing import Key, PathParts A = tp.TypeVar('A') K = tp.TypeVar('K', bound=tp.Hashable) S = tp.TypeVar('S', bound='State') V = tp.TypeVar('V') ExtractValueFn = tp.Callable[[tp.Any], tp.Any] SetValueFn = tp.Callable[[V, tp.Any], V] class NestedStateRepr(reprlib.Representable): def __init__(self, state: State): self.state = state def __nnx_repr__(self): yield reprlib.Object('', kv_sep=': ', start='{', end='}') for r in self.state.__nnx_repr__(): if isinstance(r, reprlib.Object): continue yield r def __treescope_repr__(self, path, subtree_renderer): children = {} for k, v in self.state.items(): if isinstance(v, State): v = NestedStateRepr(v) children[k] = v # Render as the dictionary itself at the same path. return subtree_renderer(children, path=path) class FlatState(tp.Sequence[tuple[PathParts, V]], reprlib.Representable): __slots__ = ('_keys', '_values') _keys: tuple[PathParts, ...] _values: list[V] def __init__(self, items: tp.Iterable[tuple[PathParts, V]], /, *, sort: bool): keys, values = [], [] if sort: items = sorted(items) for key, value in items: keys.append(key) values.append(value) self._keys = tuple(keys) self._values = values @staticmethod def from_sorted_keys_values( keys: tuple[PathParts, ...], values: list[V], / ) -> FlatState[V]: flat_state = object.__new__(FlatState) flat_state._keys = keys flat_state._values = values return flat_state @property def paths(self) -> tp.Tuple[PathParts, ...]: return self._keys @property def leaves(self) -> list[V]: return self._values def __nnx_repr__(self): yield reprlib.Object(type='FlatState', kv_sep='', start='([', end='])') for value in self: yield reprlib.Attr('', value) @tp.overload def __getitem__(self, index: int) -> tuple[PathParts, V]: ... @tp.overload def __getitem__(self, index: slice) -> FlatState[V]: ... def __getitem__( self, index: int | slice ) -> tuple[PathParts, V] | FlatState[V]: if isinstance(index, int): return self._keys[index], self._values[index] return FlatState(zip(self._keys[index], self._values[index]), sort=False) def __len__(self) -> int: return len(self._keys) def __iter__(self) -> tp.Iterator[tuple[PathParts, V]]: return iter(zip(self._keys, self._values)) def to_nested_state(self) -> State[Key, V]: return from_flat_state(self) @tp.overload def split(self, first: filterlib.Filter, /) -> FlatState[V]: ... @tp.overload def split( self, first: filterlib.Filter, second: filterlib.Filter, /, *filters: filterlib.Filter, ) -> tuple[FlatState[V], ...]: ... @tp.overload def split( self, /, *filters: filterlib.Filter ) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]: ... def split( # type: ignore[misc] self, first: filterlib.Filter, /, *filters: filterlib.Filter ) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]: filters = (first, *filters) *flat_states_, rest = _split_state(self, *filters) if rest: raise ValueError( 'Non-exhaustive filters, got a non-empty remainder: ' f'{rest}.\nUse `...` to match all remaining elements.' ) flat_states: FlatState[V] | tuple[FlatState[V], ...] if len(flat_states_) == 1: flat_states = flat_states_[0] else: flat_states = tuple(flat_states_) return flat_states # type: ignore @tp.overload def filter(self, first: filterlib.Filter, /) -> FlatState[V]: ... @tp.overload def filter( self, first: filterlib.Filter, second: filterlib.Filter, /, *filters: filterlib.Filter, ) -> tuple[FlatState[V], ...]: ... def filter( self, first: filterlib.Filter, /, *filters: filterlib.Filter, ) -> tp.Union[FlatState[V], tuple[FlatState[V], ...]]: *flat_states_, _rest = _split_state(self, first, *filters) assert len(flat_states_) == len(filters) + 1 flat_states: FlatState[V] | tuple[FlatState[V], ...] if len(flat_states_) == 1: flat_states = flat_states_[0] else: flat_states = tuple(flat_states_) return flat_states # type: ignore @staticmethod def merge( flat_state: tp.Iterable[tuple[PathParts, V]], /, *flat_states: tp.Iterable[tuple[PathParts, V]], ) -> FlatState[V]: if not flat_states: if isinstance(flat_state, FlatState): return flat_state return FlatState(flat_state, sort=True) flat_states = (flat_state, *flat_states) return FlatState( (elem for flat_state in flat_states for elem in flat_state), sort=True ) def _flat_state_pytree_flatten(x: FlatState[V]): return x._values, x._keys def _flat_state_pytree_unflatten( keys: tuple[PathParts, ...], values: list[V] ) -> FlatState[V]: flat_state = object.__new__(FlatState) flat_state._keys = keys flat_state._values = values return flat_state jax.tree_util.register_pytree_node( FlatState, _flat_state_pytree_flatten, _flat_state_pytree_unflatten, ) class State(MutableMapping[K, V], reprlib.Representable): """A pytree-like ``Mapping`` with hashable and comparable keys. """ def __init__( self, mapping: tp.Union[ tp.Mapping[K, tp.Mapping | V], tp.Iterator[tuple[K, tp.Mapping | V]], ], /, *, _copy: bool = True, ): if _copy: _mapping = dict(mapping) else: if not isinstance(mapping, dict): raise ValueError( 'Expected a dictionary when `_copy=False`, ' f'got {type(mapping)} instead.' ) _mapping = mapping if tp.TYPE_CHECKING: self._mapping = _mapping else: super().__setattr__('_mapping', _mapping) @property def raw_mapping(self) -> dict[K, tp.Mapping[K, tp.Any] | V]: return self._mapping # type: ignore def __contains__(self, key) -> bool: return key in self._mapping def __getitem__(self, key: K) -> State | V: # type: ignore value = self._mapping[key] if isinstance(value, dict): return type(self)(value, _copy=False) return value # type: ignore[return-value] def __getattr__(self, key: K) -> State | V: # type: ignore[misc] if '_mapping' not in vars(self) or key not in self._mapping: raise AttributeError(f"No attribute '{key}' in State") return self[key] def __setitem__(self, key: K, value: State | V) -> None: if key == '__orig_class__': object.__setattr__(self, key, value) # type: ignore elif isinstance(value, State): self._mapping[key] = value._mapping else: self._mapping[key] = value __setattr__ = __setitem__ # type: ignore def __delitem__(self, key: K) -> None: del self._mapping[key] def __iter__(self) -> tp.Iterator[K]: return iter(self._mapping) def __len__(self) -> int: return len(self._mapping) def __nnx_repr__(self): yield reprlib.Object(type(self), kv_sep=': ', start='({', end='})') for k, v in self.items(): if isinstance(v, State): v = NestedStateRepr(v) yield reprlib.Attr(repr(k), v) def __treescope_repr__(self, path, subtree_renderer): children = {} for k, v in self.items(): if isinstance(v, State): v = NestedStateRepr(v) children[k] = v return treescope.repr_lib.render_dictionary_wrapper( object_type=type(self), wrapped_dict=children, path=path, subtree_renderer=subtree_renderer, ) def map(self, f: tp.Callable[[tuple, V], V]) -> State[K, V]: warnings.warn( '`flax.nnx.State` will be deprecated and be replaced by the built-in ' 'Python dict. Please use the equivalent `nnx.map_state` instead.', DeprecationWarning, ) return map_state(f, self) def flat_state(self) -> FlatState[V]: warnings.warn( '`flax.nnx.State` will be deprecated and be replaced by the built-in ' 'Python dict. Please use the equivalent `nnx.to_flat_state` instead.', DeprecationWarning, ) return to_flat_state(self) @classmethod def from_flat_path( cls, flat_state: tp.Mapping[PathParts, V] | tp.Iterable[tuple[PathParts, V]], /, ): warnings.warn( '`flax.nnx.State` will be deprecated and be replaced by the built-in ' 'Python dict. Please use the equivalent `nnx.from_flat_state` instead.', DeprecationWarning, ) return from_flat_state(flat_state, cls=cls) def to_pure_dict(self, extract_fn: ExtractValueFn | None = None ) -> dict[str, tp.Any]: warnings.warn( '`flax.nnx.State` will be deprecated and be replaced by the built-in ' 'Python dict. Please use the equivalent `nnx.to_pure_dict` instead.', DeprecationWarning, ) return to_pure_dict(self, extract_fn) def replace_by_pure_dict(self, pure_dict: dict[str, tp.Any], replace_fn: SetValueFn | None = None): warnings.warn( '`flax.nnx.State` will be deprecated and be replaced by the built-in ' 'Python dict. Please use the equivalent `nnx.replace_by_pure_dict` ' 'instead.', DeprecationWarning, ) return replace_by_pure_dict(self, pure_dict, replace_fn) @tp.overload def split(self, first: filterlib.Filter, /) -> State[K, V]: ... @tp.overload def split( self, first: filterlib.Filter, second: filterlib.Filter, /, *filters: filterlib.Filter, ) -> tuple[State[K, V], ...]: ... @tp.overload def split( self, /, *filters: filterlib.Filter ) -> tp.Union[State[K, V], tuple[State[K, V], ...]]: ... def split( # type: ignore[misc] self, first: filterlib.Filter, /, *filters: filterlib.Filter ) -> tp.Union[State[K, V], tuple[State[K, V], ...]]: warnings.warn( '`flax.nnx.State` will be deprecated and be replaced by the built-in ' 'Python dict. Please use the equivalent `nnx.split_state` instead.', DeprecationWarning, ) return split_state(self, first, *filters) @tp.overload def filter( self, first: filterlib.Filter, /, ) -> State[K, V]: ... @tp.overload def filter( self, first: filterlib.Filter, second: filterlib.Filter, /, *filters: filterlib.Filter, ) -> tuple[State[K, V], ...]: ... def filter( self, first: filterlib.Filter, /, *filters: filterlib.Filter, ) -> tp.Union[State[K, V], tuple[State[K, V], ...]]: warnings.warn( '`flax.nnx.State` will be deprecated and be replaced by the built-in ' 'Python dict. Please use the equivalent `nnx.filter_state` instead.', DeprecationWarning, ) return filter_state(self, first, *filters) @classmethod def merge(cls, state: tp.Mapping[K, V], /, *states: tp.Mapping[K, V]): warnings.warn( '`flax.nnx.State` will be deprecated and be replaced by the built-in ' 'Python dict. Please use the equivalent `nnx.merge_state` instead.', DeprecationWarning, ) return merge_state(state, *states) def __or__(self, other: State[K, V]) -> State[K, V]: if not other: return self return merge_state(self, other) def __sub__(self, other: State[K, V]) -> State[K, V]: warnings.warn( '`flax.nnx.State` will be deprecated and be replaced by the built-in ' 'Python dict. Please use the equivalent `nnx.diff` instead.', DeprecationWarning, ) return diff(self, other) def __init_subclass__(cls) -> None: super().__init_subclass__() jax.tree_util.register_pytree_with_keys( cls, _state_flatten_with_keys, partial(_state_unflatten, cls), # type: ignore[arg-type] ) def _state_flatten_with_keys(x: State): items = sorted(x._mapping.items()) children = tuple((jtu.DictKey(key), value) for key, value in items) return children, tuple(key for key, _ in items) def _state_unflatten( cls: type[S], static: tuple[K, ...], leaves: tuple[V, ...] | tuple[dict[K, V]], ): return cls(zip(static, leaves)) jax.tree_util.register_pytree_with_keys( State, _state_flatten_with_keys, partial(_state_unflatten, State), # type: ignore[arg-type] ) def map_state(f: tp.Callable[[tuple, tp.Any], tp.Any], state: State) -> State: """Map ``f`` over :class:`State` object. Arguments: f: A function to be mapped state: A :class:`State` object. Returns: New state :class:`State`. """ flat_state = to_flat_state(state) result = [ (path, f(path, variable_state)) for path, variable_state in flat_state ] return from_flat_state(result) def to_flat_state(state: State) -> FlatState: """Convert state into flat state Arguments: state: A :class:`State` object. Returns: Flat state :class:`FlatState` """ return FlatState(traversals.flatten_to_sequence(state._mapping), sort=True) def from_flat_state( flat_state: tp.Mapping[PathParts, V] | tp.Iterable[tuple[PathParts, V]], *, cls = State, # for compatibility with State subclasses ) -> State: """Convert flat state object into :class:`State` object. Arguments: flat_state: A :class:`FlatState` object. Returns: State :class:`State` object. """ if not isinstance(flat_state, tp.Mapping): flat_state = dict(flat_state) nested_state = traversals.unflatten_mapping(flat_state) return cls(nested_state) def to_pure_dict( state: State, extract_fn: ExtractValueFn | None = None ) -> dict[str, tp.Any]: """Convert :class:`State` object into pure dictionary state. Arguments: state: A :class:`State` object. extract_fn: optional extraction function. Returns: Pure dictionary. """ # Works for nnx.Variable if extract_fn is None: extract_fn = lambda x: x.get_value() if isinstance(x, variablelib.Variable) else x flat_values = {k: extract_fn(x) for k, x in to_flat_state(state)} return traversals.unflatten_mapping(flat_values) def restore_int_paths(pure_dict: dict[str, tp.Any]): """Restore integer paths from string value in the dict. This method can be helpful when restoring the state from a checkpoint as pure dictionary: Example:: >>> from flax import nnx >>> import orbax.checkpoint as ocp >>> import tempfile ... >>> model = nnx.List([nnx.Linear(10, 10, rngs=nnx.Rngs(0)) for _ in range(2)]) >>> pure_dict_state = nnx.to_pure_dict(nnx.state(model)) >>> list(pure_dict_state.keys()) [0, 1] >>> checkpointer = ocp.StandardCheckpointer() >>> with tempfile.TemporaryDirectory() as tmpdir: ... checkpointer.save(f'{tmpdir}/ckpt', pure_dict_state) ... restored_pure_dict = checkpointer.restore(f'{tmpdir}/ckpt') ... list(restored_pure_dict.keys()) ['0', '1'] >>> restored_pure_dict = nnx.restore_int_paths(restored_pure_dict) >>> list(restored_pure_dict.keys()) [0, 1] Arguments: pure_dict: state as pure dictionary Returns: state as pure dictionary with restored integers paths """ def try_convert_int(x): try: return int(x) except ValueError: return x fixed = { tuple(map(try_convert_int, path)): value for path, value in traversals.flatten_mapping(pure_dict).items() } return traversals.unflatten_mapping(fixed) def replace_by_pure_dict( state: State, pure_dict: dict[str, tp.Any], replace_fn: SetValueFn | None = None ): """Replace input ``state`` values with ``pure_dict`` values. Arguments: state: A :class:`State` object. pure_dict: pure dictionary with values to be used for replacement. replace_fn: optional replace function. """ def try_convert_int(x): try: return int(x) except ValueError: return x # Works for nnx.Variable if replace_fn is None: replace_fn = lambda x, v: x.replace(v) if hasattr(x, 'replace') else v current_flat = dict(to_flat_state(state)) for kp, v in traversals.flatten_mapping(pure_dict).items(): # Try exact match first, then integer conversion if kp in current_flat: matched_key = kp else: int_kp = tuple(map(try_convert_int, kp)) if int_kp in current_flat: matched_key = int_kp else: raise ValueError(f'key in pure_dict not available in state: {kp}') current_flat[matched_key] = replace_fn(current_flat[matched_key], v) state.update(traversals.unflatten_mapping(current_flat)) @tp.overload def split_state(state: State, first: filterlib.Filter, /) -> State: ... @tp.overload def split_state( state: State, first: filterlib.Filter, second: filterlib.Filter, /, *filters: filterlib.Filter, ) -> tuple[State, ...]: ... @tp.overload def split_state( state: State, /, *filters: filterlib.Filter ) -> tp.Union[State, tuple[State, ...]]: ... def split_state( # type: ignore[misc] state: State, first: filterlib.Filter, /, *filters: filterlib.Filter ) -> tp.Union[State, tuple[State, ...]]: """Split a :class:`State` into one or more :class:`State`'s. The user must pass at least one ``Filter`` (i.e. :class:`Variable`), and the filters must be exhaustive (i.e. they must cover all :class:`Variable` types in the ``State``). Example usage:: >>> from flax import nnx >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.batchnorm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... def __call__(self, x): ... return self.linear(self.batchnorm(x)) >>> model = Model(rngs=nnx.Rngs(0)) >>> state = nnx.state(model) >>> param, batch_stats = nnx.split_state(state, nnx.Param, nnx.BatchStat) Arguments: first: The first filter *filters: The optional, additional filters to group the state into mutually exclusive substates. Returns: One or more ``States`` equal to the number of filters passed. """ filters = (first, *filters) flat_states = _split_state(to_flat_state(state), *filters) *states_, rest = (state.to_nested_state() for state in flat_states) if rest: raise ValueError( 'Non-exhaustive filters, got a non-empty remainder: ' f'{rest}.\nUse `...` to match all remaining elements.' ) states: State | tuple[State, ...] if len(states_) == 1: states = states_[0] else: states = tuple(states_) return states # type: ignore @tp.overload def filter_state( state: State, first: filterlib.Filter, /, ) -> State: ... @tp.overload def filter_state( state: State, first: filterlib.Filter, second: filterlib.Filter, /, *filters: filterlib.Filter, ) -> tuple[State, ...]: ... def filter_state( state: State, first: filterlib.Filter, /, *filters: filterlib.Filter, ) -> tp.Union[State, tuple[State, ...]]: """Filter a ``State`` into one or more ``State``'s. The user must pass at least one ``Filter`` (i.e. :class:`Variable`). This method is similar to :meth:`split() `, except the filters can be non-exhaustive. Example usage:: >>> from flax import nnx >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.batchnorm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... def __call__(self, x): ... return self.linear(self.batchnorm(x)) >>> model = Model(rngs=nnx.Rngs(0)) >>> state = nnx.state(model) >>> param = nnx.filter_state(state, nnx.Param) >>> batch_stats = nnx.filter_state(state, nnx.BatchStat) >>> param, batch_stats = nnx.filter_state(state, nnx.Param, nnx.BatchStat) Arguments: first: The first filter *filters: The optional, additional filters to group the state into mutually exclusive substates. Returns: One or more ``States`` equal to the number of filters passed. """ flat_states = _split_state(to_flat_state(state), first, *filters) *states_, _rest = (state.to_nested_state() for state in flat_states) assert len(states_) == len(filters) + 1 states: State | tuple[State, ...] if len(states_) == 1: states = states_[0] else: states = tuple(states_) return states # type: ignore def merge_state(state: tp.Mapping, /, *states: tp.Mapping, cls = State # for compatibility with State subclasses ) -> State: """The inverse of :meth:`split() `. ``merge`` takes one or more :class:`State`'s and creates a new :class:`State`. Example usage:: >>> from flax import nnx >>> import jax.numpy as jnp >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.batchnorm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... def __call__(self, x): ... return self.linear(self.batchnorm(x)) >>> model = Model(rngs=nnx.Rngs(0)) >>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat) >>> params['linear']['bias'][...] += 1 >>> state = nnx.merge_state(params, batch_stats) >>> nnx.update(model, state) >>> assert (model.linear.bias[...] == jnp.array([1, 1, 1])).all() Args: state: A :class:`State` object. *states: Additional :class:`State` objects. Returns: The merged :class:`State`. """ if not states: if isinstance(state, cls): return state return cls(state) states = (state, *states) new_state: dict[PathParts, tp.Any] = {} for state in states: new_state.update(traversals.flatten_mapping(state)) # type: ignore[attribute-error] # pytype is wrong here return from_flat_state(new_state, cls=cls) def diff(state: State, other: State) -> State: if not other: return state self_flat = to_flat_state(state) other_flat = to_flat_state(other) diff = {k: v for k, v in self_flat if k not in other_flat.paths} return from_flat_state(diff) def _split_state( flat_state: FlatState[V], *filters: filterlib.Filter, ) -> tuple[FlatState[V], ...]: for i, filter_ in enumerate(filters): if filter_ in (..., True) and i != len(filters) - 1: remaining_filters = filters[i + 1 :] if not all(f in (..., True) for f in remaining_filters): raise ValueError( '`...` or `True` can only be used as the last filters, ' f'got {filter_} it at index {i}.' ) predicates = tuple(map(filterlib.to_predicate, filters)) # we have n + 1 states, where n is the number of predicates # the last state is for values that don't match any predicate flat_states: tuple[list[tuple[PathParts, V]], ...] = tuple( [] for _ in range(len(predicates) + 1) ) for path, value in flat_state: for i, predicate in enumerate(predicates): if predicate(path, value): flat_states[i].append((path, value)) # type: ignore[index] # mypy is wrong here? break else: # if we didn't break, set leaf to last state flat_states[-1].append((path, value)) # type: ignore[index] # mypy is wrong here? return tuple(FlatState(flat_state, sort=False) for flat_state in flat_states) def create_path_filters(state: State): flat_state = to_flat_state(state) value_paths: dict[tp.Any, set[PathParts]] = {} for path, value in flat_state: if isinstance(value, variablelib.Variable): value = value.get_value() value_paths.setdefault(value, set()).add(path) return {filterlib.PathIn(*value_paths[value]): value for value in value_paths} ================================================ FILE: flax/nnx/summary.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections import defaultdict import dataclasses import inspect import io import typing as tp from types import MappingProxyType import functools import itertools import jax import numpy as np import rich.console import rich.table import rich.text import yaml import jax.numpy as jnp from flax import nnx from flax import typing from flax.nnx import graphlib, statelib, variablelib from functools import wraps try: from IPython import get_ipython in_ipython = get_ipython() is not None except ImportError: in_ipython = False # Custom YAML dumper to represent None as 'None' string (not YAML 'null') for clarity class NoneDumper(yaml.SafeDumper): pass NoneDumper.add_representer( type(None), lambda dumper, data: dumper.represent_scalar('tag:yaml.org,2002:str', 'None'), ) class SizeBytes(typing.SizeBytes): def __repr__(self) -> str: bytes_repr = _bytes_repr(self.bytes) return f'{self.size:,} [dim]({bytes_repr})[/dim]' class ObjectInfo(tp.NamedTuple): path: statelib.PathParts stats: dict[type[variablelib.Variable], SizeBytes] variable_groups: defaultdict[ type[variablelib.Variable], defaultdict[typing.Key, variablelib.Variable] ] NodeStats = dict[int, ObjectInfo | None] def _collect_stats( path: statelib.PathParts, node: tp.Any, node_stats: NodeStats, object_types: set[type], ): if not graphlib.is_node(node) and not isinstance(node, variablelib.Variable): raise ValueError(f'Expected a graph node or Variable, got {type(node)!r}.') if id(node) in node_stats: return stats: dict[type[variablelib.Variable], SizeBytes] = {} variable_groups: defaultdict[ type[variablelib.Variable], defaultdict[typing.Key, variablelib.Variable] ] = defaultdict(lambda: defaultdict()) node_stats[id(node)] = ObjectInfo(path, stats, variable_groups) if isinstance(node, nnx.Pytree): node._nnx_tabulate_id = id(node) # type: ignore object_types.add(type(node)) node_impl = graphlib.get_node_impl(node) assert node_impl is not None node_dict = node_impl.node_dict(node) for key, value in node_dict.items(): if id(value) in node_stats: continue elif isinstance(value, variablelib.Variable): var_type = type(value) if issubclass(var_type, nnx.RngState): var_type = nnx.RngState size_bytes = SizeBytes.from_any(value.get_value()) if var_type in stats: stats[var_type] += size_bytes else: stats[var_type] = size_bytes variable_groups[var_type][key] = value node_stats[id(value)] = None elif graphlib.is_node(value): _collect_stats((*path, key), value, node_stats, object_types) # accumulate stats from children child_info = node_stats[id(value)] assert child_info is not None for var_type, size_bytes in child_info.stats.items(): if var_type in stats: stats[var_type] += size_bytes else: stats[var_type] = size_bytes @dataclasses.dataclass(frozen=True, repr=False) class ArrayRepr: shape: tuple[int, ...] dtype: tp.Any @classmethod def from_array(cls, x: jax.Array | np.ndarray): return cls(jnp.shape(x), jnp.result_type(x)) def __str__(self): shape_repr = ','.join(str(x) for x in self.shape) return f'[dim]{self.dtype}[/dim][{shape_repr}]' @dataclasses.dataclass class CallInfo: call_order: int object_id: int type: type path: statelib.PathParts inputs_repr: str outputs: tp.Any flops: int | None vjp_flops: int | None class SimpleObjectRepr: def __init__(self, obj: tp.Any): self.type = type(obj) def __str__(self): return f'{self.type.__name__}(...)' def __repr__(self): return f'{self.type.__name__}(...)' def _to_dummy_array(x): if isinstance(x,jax.ShapeDtypeStruct): return ArrayRepr(x.shape, x.dtype) elif isinstance(x, jax.Array | np.ndarray): return ArrayRepr.from_array(x) elif graphlib.is_graph_node(x): return SimpleObjectRepr(x) else: return x def _pure_nnx_vjp(f, model, *args, **kwargs): "Wrap nnx functional api around jax.vjp. Only handles pure method calls." graphdef, state = nnx.split(model) def inner(state, *args, **kwargs): model = nnx.merge(graphdef, state) return f(model, *args, **kwargs) return jax.vjp(inner, state, *args, **kwargs) def filter_rng_streams(row: CallInfo): return not issubclass(row.type, nnx.RngStream) def _create_obj_env(object_types): "Turn a set of object types into a dictionary mapping (type, method name) pairs to methods" result = {} for obj_type in object_types: for name, top_method in inspect.getmembers(obj_type, inspect.isfunction): if not name.startswith('_') or name == '__call__': result[(obj_type, name)] = top_method return result def _get_inputs_repr(args, kwargs): input_args, input_kwargs = jax.tree.map( _to_dummy_array, (args, kwargs) ) inputs_repr = '' if input_args: if len(input_args) == 1 and not input_kwargs: inputs_repr += _as_yaml_str(input_args[0]) else: inputs_repr += _as_yaml_str(input_args) if input_kwargs: inputs_repr += '\n' if input_kwargs: inputs_repr += _as_yaml_str(input_kwargs) return inputs_repr def _save_call_info(counter, tracer_args, f, node_stats, compute_flops, compute_vjp_flops, seen): "Wrap a function to save its arguments" # Used when computing vjp flops def do_vjp(*args, **kwargs): primals, f_vjp = jax.vjp(f, *args, **kwargs) return f_vjp(primals) method_name = f.__name__ @functools.partial(jax.jit) def jit_f(graphdef, state): args, kwargs = nnx.merge(graphdef, state) return f(*args, **kwargs) @wraps(f) def wrapper(obj, *args, **kwargs): inputs_repr = _get_inputs_repr(args, kwargs) object_id = getattr(obj, '_nnx_tabulate_id') node_info = node_stats[object_id] path = node_info.path if method_name != '__call__': path = (*path, method_name) identifier = (inputs_repr, object_id) counter_val = next(counter) graphdef, state = nnx.split(((obj, *args), kwargs)) if compute_flops: lowered = jit_f.lower(graphdef, state) flops = _get_flops(lowered) outputs = lowered.out_info else: flops = None outputs = jit_f(graphdef, state) if identifier not in seen: seen.add(identifier) output_repr = jax.tree.map(_to_dummy_array, outputs) vjp_flops = _get_flops(jax.jit(do_vjp).lower( obj, *args, **kwargs)) if compute_vjp_flops else None tracer_args.append( CallInfo(counter_val, object_id, type(obj), path, inputs_repr, output_repr, flops, vjp_flops)) return jit_f(graphdef, state) return wrapper def _overwrite_methods(env): "Overwrite methods with functions from an environment" for (obj_type, name), f in env.items(): setattr(obj_type, name, f) def _get_flops(e) -> int: cost = e.cost_analysis() or e.compile().cost_analysis() return 0 if cost is None or 'flops' not in cost else int(cost['flops']) def tabulate( obj, *input_args, depth: int | None = None, method: str = '__call__', row_filter: tp.Callable[[CallInfo], bool] = filter_rng_streams, table_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}), column_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}), console_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}), compute_flops: bool = False, compute_vjp_flops: bool = False, **input_kwargs, ) -> str: """Creates a summary of the graph object represented as a table. The table summarizes the object's state and metadata. The table is structured as follows: - The first column represents the path of the object in the graph. - The second column represents the type of the object. - The third column represents the input arguments passed to the object's method. - The fourth column represents the output of the object's method. - The following columns provide information about the object's state, grouped by Variable types. Example:: >>> from flax import nnx ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, rngs: nnx.Rngs): ... self.linear = nnx.Linear(din, dout, rngs=rngs) ... self.bn = nnx.BatchNorm(dout, rngs=rngs) ... self.dropout = nnx.Dropout(0.2, rngs=rngs) ... ... def __call__(self, x): ... return nnx.relu(self.dropout(self.bn(self.linear(x)))) ... >>> class Foo(nnx.Module): ... def __init__(self, rngs: nnx.Rngs): ... self.block1 = Block(32, 128, rngs=rngs) ... self.block2 = Block(128, 10, rngs=rngs) ... ... def __call__(self, x): ... return self.block2(self.block1(x)) ... >>> foo = Foo(nnx.Rngs(0)) >>> # print(nnx.tabulate(foo, jnp.ones((1, 32)))) Foo Summary ┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓ ┃ path ┃ type ┃ inputs ┃ outputs ┃ BatchStat ┃ Param ┃ RngState ┃ ┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩ │ │ Foo │ float32[1,32] │ float32[1,10] │ 276 (1.1 KB) │ 5,790 (23.2 KB) │ 2 (12 B) │ ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ │ block1 │ Block │ float32[1,32] │ float32[1,128] │ 256 (1.0 KB) │ 4,480 (17.9 KB) │ 2 (12 B) │ ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ │ block1/linear │ Linear │ float32[1,32] │ float32[1,128] │ │ bias: float32[128] │ │ │ │ │ │ │ │ kernel: float32[32,128] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 4,224 (16.9 KB) │ │ ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ │ block1/bn │ BatchNorm │ float32[1,128] │ float32[1,128] │ mean: float32[128] │ bias: float32[128] │ │ │ │ │ │ │ var: float32[128] │ scale: float32[128] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 256 (1.0 KB) │ 256 (1.0 KB) │ │ ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ │ block1/dropout │ Dropout │ float32[1,128] │ float32[1,128] │ │ │ 2 (12 B) │ ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ │ block2 │ Block │ float32[1,128] │ float32[1,10] │ 20 (80 B) │ 1,310 (5.2 KB) │ │ ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ │ block2/linear │ Linear │ float32[1,128] │ float32[1,10] │ │ bias: float32[10] │ │ │ │ │ │ │ │ kernel: float32[128,10] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 1,290 (5.2 KB) │ │ ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ │ block2/bn │ BatchNorm │ float32[1,10] │ float32[1,10] │ mean: float32[10] │ bias: float32[10] │ │ │ │ │ │ │ var: float32[10] │ scale: float32[10] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 20 (80 B) │ 20 (80 B) │ │ ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ │ block2/dropout │ Dropout │ float32[1,10] │ float32[1,10] │ │ │ │ ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ │ │ │ │ Total │ 276 (1.1 KB) │ 5,790 (23.2 KB) │ 2 (12 B) │ └────────────────┴───────────┴────────────────┴────────────────┴────────────────────┴─────────────────────────┴──────────┘ Total Parameters: 6,068 (24.3 KB) Note that ``block2/dropout`` is not shown in the table because it shares the same ``RngState`` with ``block1/dropout``. Args: obj: A object to summarize. It can a pytree or a graph objects such as nnx.Module or nnx.Optimizer. *input_args: Positional arguments passed to the object's method. **input_kwargs: Keyword arguments passed to the object's method. depth: The depth of the table. method: The method to call on the object. Default is ``'__call__'``. row_filter: A callable that filters the rows to be displayed in the table. By default, it filters out rows with type ``nnx.RngStream``. table_kwargs: An optional dictionary with additional keyword arguments that are passed to ``rich.table.Table`` constructor. column_kwargs: An optional dictionary with additional keyword arguments that are passed to ``rich.table.Table.add_column`` when adding columns to the table. console_kwargs: An optional dictionary with additional keyword arguments that are passed to `rich.console.Console` when rendering the table. Default arguments are ``'force_terminal': True``, and ``'force_jupyter'`` is set to ``True`` if the code is running in a Jupyter notebook, otherwise it is set to ``False``. compute_flops: whether to include a `flops` column in the table listing the estimated FLOPs cost of each module forward pass. Does incur actual on-device computation / compilation / memory allocation, but still introduces overhead for large modules (e.g. extra 20 seconds for a Stable Diffusion's UNet, whereas otherwise tabulation would finish in 5 seconds). compute_vjp_flops: whether to include a `vjp_flops` column in the table listing the estimated FLOPs cost of each module backward pass. Introduces a compute overhead of about 2-3X of `compute_flops`. Returns: A string summarizing the object. """ _console_kwargs = {'force_terminal': True, 'force_jupyter': in_ipython} _console_kwargs.update(console_kwargs) obj = graphlib.clone(obj) # create copy to avoid side effects node_stats: NodeStats = {} object_types: set[type] = set() _collect_stats((), obj, node_stats, object_types) _variable_types: set[type] = { nnx.RngState # type: ignore[misc] if isinstance(leaf, nnx.RngState) else type(leaf) for _, leaf in nnx.to_flat_state(nnx.state(obj)) } variable_types: list[type] = sorted(_variable_types, key=lambda t: t.__name__) # Create a dictionary-version of the object's class. This makes # iteration over methods easier. env = _create_obj_env(object_types) # Information is recorded in post-order, but should be presented as a pre-order traversal. # This keeps track of the order of calls. counter = itertools.count(0) # Modify all the object's methods to save their lowered JIT representations. rows : list[CallInfo] = [] seen : set = set() jits = {k: _save_call_info(counter, rows, v, node_stats, compute_flops, compute_vjp_flops, seen) for k, v in env.items()} _overwrite_methods(jits) # Trace the top function (which indirectly traces all the others) jits[(type(obj), method)](obj, *input_args, **input_kwargs) # Sort call info in pre-order traversal order rows.sort(key=lambda x: x.call_order) # Restore the object's original methods _overwrite_methods(env) if depth is not None: rows = [row for row in rows if len(row.path) <= depth and row_filter(row)] else: rows = [row for row in rows if row_filter(row)] rich_table = rich.table.Table( show_header=True, show_lines=True, show_footer=True, title=f'{type(obj).__name__} Summary', **table_kwargs, ) rich_table.add_column('path', **column_kwargs) rich_table.add_column('type', **column_kwargs) rich_table.add_column('inputs', **column_kwargs) rich_table.add_column('outputs', **column_kwargs) if compute_flops: rich_table.add_column('flops', **column_kwargs) if compute_vjp_flops: rich_table.add_column('vjp_flops', **column_kwargs) for var_type in variable_types: rich_table.add_column(var_type.__name__, **column_kwargs) for row in rows: node_info = node_stats[row.object_id] assert node_info is not None col_reprs: list[str] = [] path_str = '/'.join(map(str, row.path)) col_reprs.append(path_str) col_reprs.append(row.type.__name__) col_reprs.append(row.inputs_repr) col_reprs.append(_as_yaml_str(row.outputs)) if compute_flops: col_reprs.append(str(row.flops)) if compute_vjp_flops: col_reprs.append(str(row.vjp_flops)) for var_type in variable_types: attributes = {} variable: variablelib.Variable for name, variable in node_info.variable_groups[var_type].items(): value = variable.get_value() value_repr = _render_array(value) if _has_shape_dtype(value) else '' metadata = variable.get_metadata() for required_key in var_type.required_metadata: metadata.pop(required_key, None) if metadata: attributes[name] = { 'value': value_repr, **metadata, } elif value_repr: attributes[name] = value_repr # type: ignore[assignment] if attributes: col_repr = _as_yaml_str(attributes) + '\n\n' else: col_repr = '' size_bytes = node_info.stats.get(var_type) # type: ignore[call-overload] if size_bytes: col_repr += f'[bold]{size_bytes}[/bold]' col_reprs.append(col_repr) rich_table.add_row(*col_reprs) total_offset = 3 + int(compute_flops) + int(compute_vjp_flops) rich_table.columns[total_offset].footer = rich.text.Text.from_markup( 'Total', justify='right' ) node_info = node_stats[id(obj)] assert node_info is not None for i, var_type in enumerate(variable_types): size_bytes = node_info.stats[var_type] rich_table.columns[i + total_offset + 1].footer = str(size_bytes) rich_table.caption_style = 'bold' total_size = sum(node_info.stats.values(), SizeBytes(0, 0)) rich_table.caption = f'\nTotal Parameters: {total_size}' return _get_rich_repr(rich_table, _console_kwargs) def _get_rich_repr(obj, console_kwargs): f = io.StringIO() console = rich.console.Console(file=f, **console_kwargs) console.print(obj) return f.getvalue() def _size_and_bytes(pytree: tp.Any) -> tuple[int, int]: leaves = jax.tree.leaves(pytree) size = sum(x.size for x in leaves if hasattr(x, 'size')) num_bytes = sum( x.size * x.dtype.itemsize for x in leaves if hasattr(x, 'size') ) return size, num_bytes def _size_and_bytes_repr(size: int, num_bytes: int) -> str: if not size: return '' bytes_repr = _bytes_repr(num_bytes) return f'{size:,} [dim]({bytes_repr})[/dim]' def _bytes_repr(num_bytes): count, units = ( (f'{num_bytes / 1e9:,.1f}', 'GB') if num_bytes > 1e9 else (f'{num_bytes / 1e6:,.1f}', 'MB') if num_bytes > 1e6 else (f'{num_bytes / 1e3:,.1f}', 'KB') if num_bytes > 1e3 else (f'{num_bytes:,}', 'B') ) return f'{count} {units}' def _has_shape_dtype(value): return hasattr(value, 'shape') and hasattr(value, 'dtype') def _normalize_values(x): if isinstance(x, type): return f'type[{x.__name__}]' elif isinstance(x, ArrayRepr | SimpleObjectRepr): return str(x) else: return repr(x) def _maybe_pytree_to_dict(pytree: tp.Any): path_leaves = jax.tree_util.tree_flatten_with_path(pytree)[0] path_leaves = [ (tuple(map(graphlib._key_path_to_key, path)), value) for path, value in path_leaves ] if len(path_leaves) < 1: return pytree elif len(path_leaves) == 1 and path_leaves[0][0] == (): return pytree else: return _unflatten_to_simple_structure(path_leaves, original=pytree) def _unflatten_to_simple_structure( xs: list[tuple[tuple[tp.Any, ...], tp.Any]], *, original: tp.Any ): """Rebuild a simple Python structure from path/value leaves. This variant is aware of the original object so it can: - Preserve empty containers that were elided by JAX flattening. - Pad trailing missing list/tuple items using the original length. - Distinguish placeholders for empty dict/list vs None. """ def _get_by_path(x, path: tuple[tp.Any, ...]): cur = x for k in path: cur = cur[k] return cur def _to_simple(x): # Convert to display-friendly simple structures if isinstance(x, (list, tuple)): return [_to_simple(e) for e in x] if isinstance(x, dict): return {k: _to_simple(v) for k, v in x.items()} return x result: list | dict = ( [] if len(xs) > 0 and isinstance(xs[0][0][0], int) else {} ) for path, value in xs: cursor = result for i, key in enumerate(path[:-1]): if isinstance(cursor, list): # Ensure list has slot for current key; infer placeholder from original while len(cursor) <= key: # path to the slot we are about to create slot_path = path[:i] + (len(cursor),) try: orig_slot = _get_by_path(original, slot_path) except Exception: orig_slot = None if isinstance(orig_slot, (list, tuple)): cursor.append([]) elif isinstance(orig_slot, dict): cursor.append({}) else: cursor.append(None) else: if key not in cursor: next_key = path[i + 1] if isinstance(next_key, int): cursor[key] = [] else: cursor[key] = {} cursor = cursor[key] if isinstance(cursor, list): # Handle gaps in indices caused by JAX flattening eliding empty containers while len(cursor) <= path[-1]: slot_path = path[:-1] + (len(cursor),) try: orig_slot = _get_by_path(original, slot_path) except Exception: orig_slot = None if isinstance(orig_slot, (list, tuple)): cursor.append([]) elif isinstance(orig_slot, dict): cursor.append({}) else: cursor.append(None) cursor[path[-1]] = value else: assert isinstance(cursor, dict) cursor[path[-1]] = value # If original is a sequence and result is a list, pad trailing items if isinstance(original, (list, tuple)) and isinstance(result, list): for i in range(len(result), len(original)): slot = original[i] result.append(_to_simple(slot)) return result def _as_yaml_str(value) -> str: if (hasattr(value, '__len__') and len(value) == 0) or value is None: return '' value = jax.tree.map(_normalize_values, value) value = _maybe_pytree_to_dict(value) file = io.StringIO() yaml.dump( value, file, Dumper=NoneDumper, default_flow_style=False, indent=2, sort_keys=False, explicit_end=False, ) return file.getvalue().replace('\n...', '').replace("'", '').strip() def _render_array(x): shape, dtype = jnp.shape(x), jnp.result_type(x) shape_repr = ','.join(str(x) for x in shape) return f'[dim]{dtype}[/dim][{shape_repr}]' def _sort_variable_types(types: tp.Iterable[type]) -> list[type]: def _variable_parents_count(t: type): return sum(1 for p in t.mro() if issubclass(p, variablelib.Variable)) type_sort_key = {t: (-_variable_parents_count(t), t.__name__) for t in types} return sorted(types, key=lambda t: type_sort_key[t]) ================================================ FILE: flax/nnx/tracers.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Taken from flax/core/tracer.py 🏴‍☠️ import jax import jax.core import treescope # type: ignore[import-not-found,import-untyped] from flax.nnx import reprlib def current_jax_trace(): """Returns the Jax tracing state.""" if jax.__version_info__ <= (0, 4, 33): return jax.core.thread_local_state.trace_state.trace_stack.dynamic return jax.core.get_opaque_trace_state(convention="nnx") class TraceState(reprlib.Representable): __slots__ = ['_jax_trace'] def __init__(self): self._jax_trace = current_jax_trace() @property def jax_trace(self): return self._jax_trace def is_valid(self) -> bool: return self._jax_trace == current_jax_trace() def __nnx_repr__(self): yield reprlib.Object(f'{type(self).__name__}') yield reprlib.Attr('jax_trace', self._jax_trace) def __treescope_repr__(self, path, subtree_renderer): return treescope.repr_lib.render_object_constructor( object_type=type(self), attributes={'jax_trace': self._jax_trace}, path=path, subtree_renderer=subtree_renderer, ) def __eq__(self, other): if jax.__version_info__ <= (0, 4, 33): return isinstance(other, TraceState) and self._jax_trace is other._jax_trace return isinstance(other, TraceState) and self._jax_trace == other._jax_trace # pickle support def __getstate__(self): return {} def __setstate__(self, state): self._jax_trace = current_jax_trace() def _flatten_trace_state(trace_state: TraceState): return (), None def _unflatten_trace_state(_1, _2): return TraceState() jax.tree_util.register_pytree_node( TraceState, _flatten_trace_state, _unflatten_trace_state, ) ================================================ FILE: flax/nnx/training/__init__.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: flax/nnx/training/metrics.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import typing as tp import numpy as np from flax import struct from flax.nnx import filterlib, graphlib from flax.nnx.pytreelib import Pytree from flax.nnx.variablelib import Variable import jax, jax.numpy as jnp # TODO: add tests and docstrings class MetricState(Variable): """Wrapper class for Metric Variables.""" pass class Metric(Pytree): """Base class for metrics. Any class that subclasses ``Metric`` should implement a ``compute``, ``reset`` and ``update`` method.""" def __init__(self): raise NotImplementedError('Must override `__init__()` method.') def reset(self) -> None: """In-place reset the ``Metric``.""" raise NotImplementedError('Must override `reset()` method.') def update(self, **kwargs) -> None: """In-place update the ``Metric``.""" raise NotImplementedError('Must override `update()` method.') def compute(self): """Compute and return the value of the ``Metric``.""" raise NotImplementedError('Must override `compute()` method.') def split(self, *filters: filterlib.Filter): return graphlib.split(self, *filters) class Average(Metric): """Average metric. Example usage:: >>> import jax.numpy as jnp >>> from flax import nnx >>> batch_loss = jnp.array([1, 2, 3, 4]) >>> batch_loss2 = jnp.array([3, 2, 1, 0]) >>> metrics = nnx.metrics.Average() >>> metrics.compute() Array(nan, dtype=float32) >>> metrics.update(values=batch_loss) >>> metrics.compute() Array(2.5, dtype=float32) >>> metrics.update(values=batch_loss2) >>> metrics.compute() Array(2., dtype=float32) >>> metrics.reset() >>> metrics.compute() Array(nan, dtype=float32) """ def __init__(self, argname: str = 'values'): """Pass in a string denoting the key-word argument that :func:`update` will use to derive the new value. For example, constructing the metric as ``avg = Average('test')`` would allow you to make updates with ``avg.update(test=new_value)``. Args: argname: an optional string denoting the key-word argument that :func:`update` will use to derive the new value. Defaults to ``'values'``. """ self.argname = argname self.total = MetricState(jnp.array(0, dtype=jnp.float32)) self.count = MetricState(jnp.array(0, dtype=jnp.int32)) def reset(self) -> None: """Reset this ``Metric``.""" self.total[...] = jnp.array(0, dtype=jnp.float32) self.count[...] = jnp.array(0, dtype=jnp.int32) def update(self, mask: jax.Array | None = None, **kwargs) -> None: """In-place update this ``Metric``. This method will use the value from ``kwargs[self.argname]`` to update the metric, where ``self.argname`` is defined on construction. Args: mask: optional mask to ignore the values from the computation. We use `values * mask` rule and mask shape should be broadcastable to the shape of values array. **kwargs: the key-word arguments that contains a ``self.argname`` entry that maps to the value we want to use to update this metric. """ if self.argname not in kwargs: raise TypeError(f"Expected keyword argument '{self.argname}'") values: tp.Union[int, float, jax.Array] = kwargs[self.argname] if mask is not None and isinstance(values, (int, float)): raise ValueError(f"If mask is provided, {self.argname} should be a jax array") if isinstance(values, (int, float)): self.total[...] += values self.count[...] += 1 elif mask is None: self.total[...] += values.sum() self.count[...] += values.size else: self.total[...] += (values * mask.astype(values.dtype)).sum() self.count[...] += mask.sum().astype(self.count.dtype) def compute(self) -> jax.Array: """Compute and return the average.""" return self.total / self.count @struct.dataclass class Statistics: mean: jnp.float32 standard_error_of_mean: jnp.float32 standard_deviation: jnp.float32 class Welford(Metric): """Uses Welford's algorithm to compute the mean and variance of a stream of data. Example usage:: >>> import jax.numpy as jnp >>> from flax import nnx >>> batch_loss = jnp.array([1, 2, 3, 4]) >>> batch_loss2 = jnp.array([3, 2, 1, 0]) >>> metrics = nnx.metrics.Welford() >>> metrics.compute() Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32)) >>> metrics.update(values=batch_loss) >>> metrics.compute() Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32)) >>> metrics.update(values=batch_loss2) >>> metrics.compute() Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32)) >>> metrics.reset() >>> metrics.compute() Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32)) """ def __init__(self, argname: str = 'values'): """Pass in a string denoting the key-word argument that :func:`update` will use to derive the new value. For example, constructing the metric as ``wf = Welford('test')`` would allow you to make updates with ``wf.update(test=new_value)``. Args: argname: an optional string denoting the key-word argument that :func:`update` will use to derive the new value. Defaults to ``'values'``. """ self.argname = argname self.count = MetricState(jnp.array(0, dtype=jnp.int32)) self.mean = MetricState(jnp.array(0, dtype=jnp.float32)) self.m2 = MetricState(jnp.array(0, dtype=jnp.float32)) def reset(self) -> None: """Reset this ``Metric``.""" self.count[...] = jnp.array(0, dtype=jnp.uint32) self.mean[...] = jnp.array(0, dtype=jnp.float32) self.m2[...] = jnp.array(0, dtype=jnp.float32) def update(self, **kwargs) -> None: """In-place update this ``Metric``. This method will use the value from ``kwargs[self.argname]`` to update the metric, where ``self.argname`` is defined on construction. Args: **kwargs: the key-word arguments that contains a ``self.argname`` entry that maps to the value we want to use to update this metric. """ if self.argname not in kwargs: raise TypeError(f"Expected keyword argument '{self.argname}'") values: tp.Union[int, float, jax.Array] = kwargs[self.argname] count = 1 if isinstance(values, (int, float)) else values.size original_count = self.count[...] self.count[...] += count delta = ( values if isinstance(values, (int, float)) else values.mean() ) - self.mean self.mean[...] += delta * count / self.count m2 = 0.0 if isinstance(values, (int, float)) else values.var() * count self.m2[...] += m2 + delta * delta * count * original_count / self.count def compute(self) -> Statistics: """Compute and return the mean and variance statistics in a ``Statistics`` dataclass object. """ variance = self.m2 / self.count standard_deviation = variance**0.5 sem = standard_deviation / (self.count**0.5) return Statistics( mean=self.mean[...], standard_error_of_mean=sem, standard_deviation=standard_deviation, ) class Accuracy(Average): """Accuracy metric. This metric subclasses :class:`Average`, and so they share the same ``reset`` and ``compute`` method implementations. Unlike :class:`Average`, no string needs to be passed to ``Accuracy`` during construction. Example usage:: >>> from flax import nnx >>> import jax, jax.numpy as jnp >>> logits = jax.random.normal(jax.random.key(0), (5, 2)) >>> labels = jnp.array([0, 1, 1, 1, 0]) >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2)) >>> labels2 = jnp.array([0, 1, 1, 1, 1]) >>> metrics = nnx.metrics.Accuracy() >>> metrics.compute() Array(nan, dtype=float32) >>> metrics.update(logits=logits, labels=labels) >>> metrics.compute() Array(0.6, dtype=float32) >>> metrics.update(logits=logits2, labels=labels2) >>> metrics.compute() Array(0.4, dtype=float32) >>> metrics.reset() >>> metrics.compute() Array(nan, dtype=float32) >>> logits3 = jax.random.normal(jax.random.key(2), (5,)) >>> labels3 = jnp.array([0, 1, 0, 1, 1]) >>> accuracy = nnx.metrics.Accuracy(threshold=0.5) >>> accuracy.update(logits=logits3, labels=labels3) >>> accuracy.compute() Array(0.8, dtype=float32) """ def __init__(self, threshold: float | None = None, *args, **kwargs): """For binary classification, pass in a float denoting a threshold to determine if a prediction is positive. For example, constructing the metric as ``acc = Accuracy(threshold=0.5)`` would cause any logit greater than or equal to 0.5 to be interpreted as a positive classification. For multi-class classification, do not pass in a threshold. Args: threshold: for binary classification, determines if a prediction is positive. Defaults to None. """ if (threshold is not None) and (not isinstance(threshold, float)): raise TypeError(f'Expected threshold to be a float, got {type(threshold)}') self.threshold = threshold super().__init__(*args, **kwargs) def update( # type: ignore[override] self, *, logits: jax.Array, labels: jax.Array, mask: jax.Array | None = None, **_ ) -> None: """In-place update this ``Metric``. Args: logits: the outputted predicted activations. For multi-class classification, these values are argmax-ed (on the trailing dimension), before comparing them to the labels. For binary classification, these values are compared to the labels directly. labels: the ground truth integer labels. mask: optional mask to ignore the logits and labels values from the computation. We use `array * mask` rule and mask shape should be broadcastable to the shape of labels array. """ if self.threshold is not None: # Binary classification case if logits.ndim != labels.ndim: raise ValueError( 'For binary classification, expected logits.ndim==labels.ndim, got ' f'{logits.ndim} and {labels.ndim}' ) elif logits.ndim != labels.ndim + 1: # Multi-class classification case raise ValueError( 'For multi-class classification, expected logits.ndim==labels.ndim+1, ' f'got {logits.ndim} and {labels.ndim}' ) if labels.dtype in (jnp.int64, np.int32, np.int64): labels = jnp.astype(labels, jnp.int32) elif labels.dtype != jnp.int32: raise ValueError(f'Expected labels.dtype==jnp.int32, got {labels.dtype}') if self.threshold is not None: # Binary classification case super().update(values=((logits >= self.threshold) == (labels > 0)), mask=mask) return # Multi-class classification case super().update(values=(logits.argmax(axis=-1) == labels), mask=mask) class MultiMetric(Metric): """MultiMetric class to store multiple metrics and update them in a single call. Example usage:: >>> from flax import nnx >>> import jax, jax.numpy as jnp >>> metrics = nnx.MultiMetric( ... accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average() ... ) >>> metrics MultiMetric( # MetricState: 4 (16 B) accuracy=Accuracy( # MetricState: 2 (8 B) threshold=None, argname='values', total=MetricState( # 1 (4 B) value=Array(0., dtype=float32) ), count=MetricState( # 1 (4 B) value=Array(0, dtype=int32) ) ), loss=Average( # MetricState: 2 (8 B) argname='values', total=MetricState( # 1 (4 B) value=Array(0., dtype=float32) ), count=MetricState( # 1 (4 B) value=Array(0, dtype=int32) ) ) ) >>> metrics.accuracy Accuracy( # MetricState: 2 (8 B) threshold=None, argname='values', total=MetricState( # 1 (4 B) value=Array(0., dtype=float32) ), count=MetricState( # 1 (4 B) value=Array(0, dtype=int32) ) ) >>> metrics.loss Average( # MetricState: 2 (8 B) argname='values', total=MetricState( # 1 (4 B) value=Array(0., dtype=float32) ), count=MetricState( # 1 (4 B) value=Array(0, dtype=int32) ) ) >>> logits = jax.random.normal(jax.random.key(0), (5, 2)) >>> labels = jnp.array([0, 1, 1, 1, 0]) >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2)) >>> labels2 = jnp.array([0, 1, 1, 1, 1]) >>> batch_loss = jnp.array([1, 2, 3, 4]) >>> batch_loss2 = jnp.array([3, 2, 1, 0]) >>> metrics.compute() {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)} >>> metrics.update(logits=logits, labels=labels, values=batch_loss) >>> metrics.compute() {'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)} >>> metrics.update(logits=logits2, labels=labels2, values=batch_loss2) >>> metrics.compute() {'accuracy': Array(0.4, dtype=float32), 'loss': Array(2., dtype=float32)} >>> metrics.reset() >>> metrics.compute() {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)} """ def __init__(self, **metrics): """Pass in key-word arguments to the constructor, e.g. ``MultiMetric(keyword1=Average(), keyword2=Accuracy(), ...)``. Args: **metrics: the key-word arguments that will be used to access the corresponding ``Metric``. """ # TODO: raise error if a kwarg is passed that is in ('reset', 'update', 'compute'), since these names are reserved for methods self._metric_names = [] for metric_name, metric in metrics.items(): self._metric_names.append(metric_name) setattr(self, metric_name, metric) def reset(self) -> None: """Reset all underlying ``Metric``'s.""" for metric_name in self._metric_names: getattr(self, metric_name).reset() def update(self, **updates) -> None: """In-place update all underlying ``Metric``'s in this ``MultiMetric``. All ``**updates`` will be passed to the ``update`` method of all underlying ``Metric``'s. Args: **updates: the key-word arguments that will be passed to the underlying ``Metric``'s ``update`` method. It can contain ``mask`` argument as mask to be passed to the underlying metrics. ``mask`` can be a ``jax.Array` and it will be passed to all the metrics. ``mask`` can be a dictionary with metric name as keys and ``jax.Array`` as values, e.g ``{metric_name1: metric_mask1, ...}`` """ # TODO: should we give the option of updating only some of the metrics and not all? e.g. if for some kwargs==None, don't do update # TODO: should we raise an error if a kwarg is passed into **updates that has no match with any underlying metric? e.g. user typo mask = updates.pop("mask", None) for metric_name in self._metric_names: metric_mask_kwarg = {} metric_mask = mask.get(metric_name, None) if isinstance(mask, dict) else mask if metric_mask is not None: metric_mask_kwarg = {"mask": metric_mask} getattr(self, metric_name).update(**(updates | metric_mask_kwarg)) def compute(self) -> dict[str, tp.Any]: """Compute and return the value of all underlying ``Metric``'s. This method will return a dictionary, mapping strings (defined by the key-word arguments ``**metrics`` passed to the constructor) to the corresponding metric value. """ return { f'{metric_name}': getattr(self, metric_name).compute() for metric_name in self._metric_names } ================================================ FILE: flax/nnx/training/optimizer.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import functools import typing as tp import jax import jax.numpy as jnp import optax from flax import nnx from flax.nnx import filterlib from flax.nnx.pytreelib import Pytree from flax.nnx.variablelib import Variable M = tp.TypeVar('M', bound=nnx.Module) F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) class OptState(Variable): """Any optimizer state""" pass class OptArray(OptState): """Optimizer state for an array.""" pass class OptVariable(OptState): """Optimizer state for a Variable.""" pass def to_opt_state(tree): def _to_opt_state(x): if isinstance(x, Variable): opt_state = OptVariable(x.get_value(), **x.get_metadata()) # type: ignore else: opt_state = OptArray(x) return opt_state tree = jax.tree.map( _to_opt_state, tree, is_leaf=lambda x: isinstance(x, Variable), ) return tree class _Missing: pass MISSING = _Missing() def _check_grads_arg_passed(f: F) -> F: @functools.wraps(f) def _check_grads_wrapper(self, model, grads=MISSING, **kwargs): if isinstance(grads, _Missing): raise TypeError( 'Missing required argument `grads`. As of Flax 0.11.0 update requires both (model, grads) arguments ' 'to be passed. If you want to keep the previous use nnx.ModelAndOptimizer instead of nnx.Optimizer.' ) return f(self, model, grads, **kwargs) return _check_grads_wrapper # type: ignore def _check_wrt_arg_passed(f: F) -> F: @functools.wraps(f) def _check_wrt_wrapper(*args, wrt=MISSING, **kwargs): if isinstance(wrt, _Missing): raise TypeError( 'Missing required argument `wrt`. As of Flax 0.11.0 the `wrt` argument is required, ' 'if you want to keep the previous use nnx.ModelAndOptimizer instead of nnx.Optimizer.' ) return f(*args, wrt=wrt, **kwargs) return _check_wrt_wrapper # type: ignore class Optimizer(Pytree, tp.Generic[M]): """Simple train state for the common case with a single Optax optimizer. Example usage:: >>> import jax, jax.numpy as jnp >>> from flax import nnx >>> import optax ... >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear1 = nnx.Linear(2, 3, rngs=rngs) ... self.linear2 = nnx.Linear(3, 4, rngs=rngs) ... def __call__(self, x): ... return self.linear2(self.linear1(x)) ... >>> x = jax.random.normal(jax.random.key(0), (1, 2)) >>> y = jnp.ones((1, 4)) ... >>> model = Model(nnx.Rngs(0)) >>> tx = optax.adam(1e-3) >>> optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) ... >>> loss_fn = lambda model: ((model(x) - y) ** 2).mean() >>> loss_fn(model) Array(2.3359997, dtype=float32) >>> grads = nnx.grad(loss_fn)(model) >>> optimizer.update(model, grads) >>> loss_fn(model) Array(2.310461, dtype=float32) Attributes: step: An ``OptState`` :class:`Variable` that tracks the step count. tx: An Optax gradient transformation. opt_state: The Optax optimizer state. """ @_check_wrt_arg_passed def __init__( self, model: M, tx: optax.GradientTransformation, *, wrt: filterlib.Filter, # type: ignore ): """ Instantiate the class and wrap the :class:`Module` and Optax gradient transformation. Instantiate the optimizer state to keep track of :class:`Variable` types specified in ``wrt``. Set the step count to 0. Args: model: An NNX Module. tx: An Optax gradient transformation. wrt: filter to specify for which :class:`Variable`'s to keep track of in the optimizer state. These should be the :class:`Variable`'s that you plan on updating; i.e. this argument value should match the ``wrt`` argument passed to the ``nnx.grad`` call that will generate the gradients that will be passed into the ``grads`` argument of the :func:`update` method. The filter should match the filter used in nnx.grad. """ self.step = OptState(jnp.array(0, dtype=jnp.uint32)) self.tx = tx self.opt_state = nnx.data( to_opt_state(tx.init(nnx.state(model, wrt))) ) self.wrt = wrt if not tp.TYPE_CHECKING: def __getattribute__(self, name: str) -> tp.Any: if name == 'model' and name not in vars(self): raise AttributeError( f"{type(self).__name__} does not have attribute 'model' since Flax 0.11.0. " "To keep the previous behavior, use nnx.ModelAndOptimizer instead of nnx.Optimizer." ) return super().__getattribute__(name) @_check_grads_arg_passed def update(self, model: M, grads, /, **kwargs): """Updates the optimizer state and model parameters given the gradients. Example:: >>> from flax import nnx >>> import jax, jax.numpy as jnp >>> import optax ... >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... self.count = nnx.Variable(jnp.array(0)) ... ... def __call__(self, x): ... self.count[...] += 1 ... return self.linear(x) ... >>> model = Model(rngs=nnx.Rngs(0)) ... >>> loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean() >>> optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) >>> grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, nnx.Param))( ... model, jnp.ones((1, 2)), jnp.ones((1, 3)) ... ) >>> optimizer.update(model, grads) Note that internally this function calls ``.tx.update()`` followed by a call to ``optax.apply_updates()`` to update ``params`` and ``opt_state``. Args: grads: the gradients derived from ``nnx.grad``. **kwargs: additional keyword arguments passed to the tx.update, to support ``GradientTransformationExtraArgs``, such as ``optax.scale_by_backtracking_linesearch``. """ param_arrays = nnx.pure(nnx.state(model, self.wrt)) grad_arrays = nnx.pure(nnx.state(grads, self.wrt)) opt_state_arrays = nnx.pure(self.opt_state) kwargs_arrays = nnx.pure(kwargs) updates, new_opt_state = self.tx.update( grad_arrays, opt_state_arrays, param_arrays, **kwargs_arrays ) new_params = optax.apply_updates(param_arrays, updates) nnx.update(model, new_params) nnx.update(self.opt_state, nnx.state(new_opt_state)) self.step[...] += 1 class ModelAndOptimizer(Optimizer[M]): """A convenience class that combines a model and an optimizer. This class is deprecated and will be removed in a future release. Use :class:`Optimizer` instead. """ def __init__(self, model: M, tx: optax.GradientTransformation, *, wrt: filterlib.Filter = nnx.Param): super().__init__(model, tx, wrt=wrt) self.model = model def update(self, grads, /, **kwargs): # type: ignore return super().update(self.model, grads, **kwargs) ================================================ FILE: flax/nnx/transforms/__init__.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: flax/nnx/transforms/autodiff.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections import deque import dataclasses import functools import typing as tp from flax import struct from flax.nnx import ( extract, filterlib, graphlib, variablelib, ) from flax.nnx.statelib import State import jax from flax.nnx.transforms import general from flax.nnx.transforms.transforms import ( resolve_kwargs, _resolve_bound_callable, _raise_bound_method_error, ) from flax.typing import MISSING, Missing A = tp.TypeVar('A') # C = tp.TypeVar('C') # B = tp.TypeVar('B') F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) # G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any]) # M = tp.TypeVar('M', bound=Module) # MA = tp.TypeVar('MA', bound=Module) # N = tp.TypeVar('N', bound=Module) # StrInt = tp.TypeVar('StrInt', str, int) AxisName = tp.Hashable # Leaves = tp.List[Leaf] # Index = int # ------------------------------- # grad # ------------------------------- @dataclasses.dataclass(frozen=True) class DiffState: argnum: int filter: filterlib.Filter @dataclasses.dataclass(eq=False) class SimpleGradFn: f: tp.Callable[..., tp.Any] has_aux: bool graph: bool def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args, **kwargs): updates, snapshot = extract.updates_and_snapshot((args, kwargs)) if self.graph: args, kwargs = extract.from_tree2((args, kwargs)) out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out) extract.check_no_aliases('grad', args=updates[0], kwargs=updates[1], out=out) updates = extract.mask_variable_updates(updates, snapshot) if self.has_aux: loss, aux = out return loss, (updates, aux) else: return out, updates @dataclasses.dataclass(eq=False) class GradFn: f: tp.Callable[..., tp.Any] has_aux: bool nondiff_states: deque[State | None] def __post_init__(self): functools.update_wrapper(self, self.f) def __call__(self, *pure_args): # rebuild diff_state from substates in args def _grad_merge_fn( ctx: graphlib.MergeContext, path, prefix, value: extract.NodeStates ): nondiff = self.nondiff_states.popleft() if nondiff is None: return ctx.merge(value.graphdef, value.state) else: return ctx.merge(value.graphdef, value.state, nondiff) args = extract.from_tree( pure_args, merge_fn=_grad_merge_fn, ctxtag='grad', is_inner=True ) out = self.f(*args) args_out = extract.clear_non_graph_nodes(args) pure_args_out, pure_out = extract.to_tree((args_out, out), ctxtag='grad') if self.has_aux: loss, pure_aux = pure_out fn_out = (loss, (pure_args_out, pure_aux)) else: loss = pure_out fn_out = (loss, pure_args_out) return fn_out def _grad_general( f: tp.Callable[..., tp.Any], argnums: int | DiffState | tp.Sequence[int | DiffState], has_aux: bool, holomorphic: bool, allow_int: bool, return_value: bool, graph: bool, graph_updates: bool, ) -> tp.Callable[..., tp.Any]: transform = jax.value_and_grad if return_value else jax.grad if not graph or not graph_updates: if any(isinstance(x, DiffState) for x in jax.tree.leaves(argnums)): raise ValueError( '`argnums` cannot contain `DiffState` objects ' 'when `graph=False`. ' + graphlib._tree_mode_suggestion_transform('grad') ) gradded_fn = transform( SimpleGradFn(f, has_aux, graph=graph), argnums=argnums, # type: ignore[arg-type] has_aux=True, holomorphic=holomorphic, allow_int=allow_int, ) def tree_grad_wrapper(*args, **kwargs): if graph: diff_argnums = (argnums,) if isinstance(argnums, int) else argnums args_prefix = tuple( i in diff_argnums for i in range(len(args)) ) args, kwargs = extract.to_tree2( (args, kwargs), prefix=(args_prefix, False), ) extract.check_no_aliases('grad', args=args, kwargs=kwargs) fn_out = gradded_fn(*args, **kwargs) if return_value: if has_aux: (loss, (updates, aux)), grads = fn_out if graph: grads, aux = extract.from_tree2((grads, aux)) result = (loss, aux), grads else: (loss, updates), grads = fn_out if graph: grads = extract.from_tree2(grads) result = loss, grads else: if has_aux: grads, (updates, aux) = fn_out if graph: grads, aux = extract.from_tree2((grads, aux)) result = grads, aux else: grads, updates = fn_out if graph: grads = extract.from_tree2(grads) result = grads extract.apply_variable_updates((args, kwargs), updates) return result return tree_grad_wrapper jax_argnums: int | tuple[int, ...] if isinstance(argnums, (int, DiffState)): jax_argnums = argnums.argnum if isinstance(argnums, DiffState) else argnums else: jax_argnums = tuple( x.argnum if isinstance(x, DiffState) else x for x in argnums ) _argnums = (argnums,) if isinstance(argnums, (int, DiffState)) else argnums index_filter: dict[int, DiffState] = {} for argnum in _argnums: index = argnum.argnum if isinstance(argnum, DiffState) else argnum if index in index_filter: raise ValueError(f'argnum {index} is repeated in argnums') index_filter[index] = ( dataclasses.replace(argnum, argnum=-1) if isinstance(argnum, DiffState) else DiffState(-1, variablelib.Param) ) @graphlib.update_context('grad') def grad_wrapper(*args, **kwargs): args = resolve_kwargs(f, args, kwargs) del kwargs nondiff_states: deque[State | variablelib.Variable | None] = deque() def _grad_split_fn( ctx: graphlib.SplitContext, path, prefix: DiffState | None, value ): if prefix is None or (prefix.argnum == -1 and isinstance(value, variablelib.Variable)): nondiff_states.append(None) return extract.NodeStates.from_split(*ctx.split(value)) else: graphdef, diff, nondiff = ctx.split(value, prefix.filter, ...) # type: ignore[misc] nondiff_states.append(nondiff) # type: ignore[container-type-mismatch] return extract.NodeStates.from_split(graphdef, diff) arg_filters = tuple(index_filter.get(i) for i in range(len(args))) pure_args = extract.to_tree( args, prefix=arg_filters, split_fn=_grad_split_fn, ctxtag='grad' ) gradded_fn = transform( GradFn(f, has_aux, nondiff_states), argnums=jax_argnums, has_aux=True, holomorphic=holomorphic, allow_int=allow_int, ) fn_out = gradded_fn(*pure_args) def process_grads(grads): return jax.tree.map( lambda x: x.state if isinstance(x, extract.NodeStates) else x, grads, is_leaf=lambda x: isinstance(x, extract.NodeStates), ) def process_out(pure_out: A, /) -> A: return extract.from_tree(pure_out, ctxtag='grad', is_inner=False) if return_value: # unpack value_and_grad output if has_aux: (loss, (pure_args_out, pure_aux)), grads = fn_out grads = process_grads(grads) _args_out, aux = process_out((pure_args_out, pure_aux)) return (loss, aux), grads else: (loss, pure_args_out), grads = fn_out grads = process_grads(grads) _args_out = process_out(pure_args_out) return loss, grads else: # unpack grad output if has_aux: grads, (pure_args_out, pure_aux) = fn_out grads = process_grads(grads) _args_out, aux = process_out((pure_args_out, pure_aux)) return grads, aux else: grads, pure_args_out = fn_out grads = process_grads(grads) _args_out = process_out(pure_args_out) return grads return grad_wrapper @tp.overload def grad( f: tp.Callable[..., tp.Any], *, argnums: int | DiffState | tp.Sequence[int | DiffState] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: tp.Sequence[AxisName] = (), graph: bool | None = None, graph_updates: bool | None = None, ) -> tp.Callable[..., tp.Any]: ... @tp.overload def grad( *, argnums: int | DiffState | tp.Sequence[int | DiffState] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: tp.Sequence[AxisName] = (), graph: bool | None = None, graph_updates: bool | None = None, ) -> tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]: ... def grad( f: tp.Callable[..., tp.Any] | Missing = MISSING, *, argnums: int | DiffState | tp.Sequence[int | DiffState] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: tp.Sequence[AxisName] = (), graph: bool | None = None, graph_updates: bool | None = None, ) -> ( tp.Callable[..., tp.Any] | tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]] ): """Object-aware version of ``jax.grad`` that can handle Modules / graph nodes as arguments. Example:: >>> from flax import nnx >>> import jax.numpy as jnp ... >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> x = jnp.ones((1, 2)) >>> y = jnp.ones((1, 3)) ... >>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2) >>> grad_fn = nnx.grad(loss_fn) ... >>> grads = grad_fn(m, x, y) >>> jax.tree.map(jnp.shape, grads) State({ 'bias': Param( value=(3,) ), 'kernel': Param( value=(2, 3) ) }) By default, NNX objects are differentiated with respect to all their ``nnx.Param`` Variables. You can specify which substates are differentiable by passing a ``DiffState`` object to the ``argnums`` argument. For example, if you want to differentiate only the ``kernel`` attribute of the ``Linear`` class, you can use the ``PathContains`` filter:: >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) ... >>> kernel_attribute = nnx.PathContains('kernel') >>> diff_state = nnx.DiffState(0, kernel_attribute) ... >>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2) >>> grad_fn = nnx.grad(loss_fn, argnums=diff_state) ... >>> grads = grad_fn(m, x, y) >>> jax.tree.map(jnp.shape, grads) State({ 'kernel': Param( value=(2, 3) ) }) For more information on how to create custom filters, see `Using Filters `__ guide. Args: fun: Function to be differentiated. Its arguments at positions specified by ``argnums`` should be arrays, scalars, graph nodes or standard Python containers. Argument arrays in the positions specified by ``argnums`` must be of inexact (i.e., floating-point or complex) type. It should return a scalar (which includes arrays with shape ``()`` but not arrays with shape ``(1,)`` etc.) argnums: Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0). has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False. holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be holomorphic. If True, inputs and outputs must be complex. Default False. allow_int: Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references and reference semantics. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Tree-mode does not support ``DiffState`` or shared ``Variable`` references. graph_updates: If ``True``, propagates updates on graph structure that happen inside the transform to the input graphs, has no effect when ``graph=False``. When ``False``, using ``DiffState`` is not supported. """ if graph is None: graph = graphlib.set_graph_mode.current_value() if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() if reduce_axes: raise NotImplementedError('reduce_axes argument to grad is deprecated') del reduce_axes if isinstance(f, Missing): return functools.partial( grad, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic, allow_int=allow_int, graph=graph, graph_updates=graph_updates, ) # Detect bound nnx.Module methods and raise error. f_unbound, _, was_bound = _resolve_bound_callable(f) if was_bound: _raise_bound_method_error('grad') return _grad_general( f_unbound, argnums, has_aux, holomorphic, allow_int, return_value=False, graph=graph, graph_updates=graph_updates, ) @tp.overload def value_and_grad( f: tp.Callable[..., tp.Any], *, argnums: int | DiffState | tp.Sequence[int | DiffState] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: tp.Sequence[AxisName] = (), graph: bool | None = None, graph_updates: bool | None = None, ) -> tp.Callable[..., tp.Any]: ... @tp.overload def value_and_grad( *, argnums: int | DiffState | tp.Sequence[int | DiffState] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: tp.Sequence[AxisName] = (), graph: bool | None = None, graph_updates: bool | None = None, ) -> tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]: ... def value_and_grad( f: tp.Callable[..., tp.Any] | type[Missing] = Missing, *, argnums: int | DiffState | tp.Sequence[int | DiffState] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: tp.Sequence[AxisName] = (), graph: bool | None = None, graph_updates: bool | None = None, ) -> ( tp.Callable[..., tp.Any] | tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]] ): """Object-aware version of ``jax.value_and_grad``. Like :func:`grad`, but returns both the value and the gradient of ``f``. Args: f: Function to be differentiated. Its arguments at positions specified by ``argnums`` should be arrays, scalars, graph nodes or standard Python containers. Argument arrays in the positions specified by ``argnums`` must be of inexact (i.e., floating-point or complex) type. It should return a scalar (which includes arrays with shape ``()`` but not arrays with shape ``(1,)`` etc.) argnums: Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0). has_aux: Optional, bool. Indicates whether ``f`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False. holomorphic: Optional, bool. Indicates whether ``f`` is promised to be holomorphic. If True, inputs and outputs must be complex. Default False. allow_int: Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references and reference semantics. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Tree-mode does not support ``DiffState`` or shared ``Variable`` references. graph_updates: If ``True``, propagates updates on graph structure that happen inside the transform to the input graphs, has no effect when ``graph=False``. When ``False``, using ``DiffState`` is not supported. Returns: A function with the same arguments as ``f`` that evaluates both ``f`` and the gradient of ``f`` and returns them as a pair (a two-element tuple). If ``argnums`` is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a sequence of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If ``has_aux`` is True then a tuple of ((value, auxiliary_data), gradient) is returned. """ if graph is None: graph = graphlib.set_graph_mode.current_value() if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() if reduce_axes: raise NotImplementedError( 'reduce_axes argument to value_and_grad is deprecated') del reduce_axes if f is Missing: return functools.partial( value_and_grad, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic, allow_int=allow_int, graph=graph, graph_updates=graph_updates, ) # Detect bound nnx.Module methods and raise error. f_unbound, _, was_bound = _resolve_bound_callable(f) if was_bound: _raise_bound_method_error('value_and_grad') return _grad_general( f_unbound, argnums, has_aux, holomorphic, allow_int, return_value=True, graph=graph, graph_updates=graph_updates, ) # ----------------------------------------------- # vjp # ----------------------------------------------- @dataclasses.dataclass(eq=False) class SimpleVjpFn: f: tp.Callable[..., tp.Any] has_aux: bool graph: bool def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args): updates, snapshot = extract.updates_and_snapshot(args) if self.graph: args = extract.from_tree2(args) out = self.f(*args) if self.graph: out = extract.to_tree2(out) extract.check_no_aliases('vjp', args=updates, out=out) updates = extract.mask_variable_updates(updates, snapshot) if self.has_aux: primals_out, aux = out return primals_out, (updates, aux) else: return out, updates @tp.overload def vjp( f: tp.Callable[..., tp.Any], *primals: tp.Any, has_aux: bool = False, reduce_axes: tp.Sequence[AxisName] = (), graph: bool | None = None, graph_updates: bool | None = None, ) -> tuple[tp.Any, tp.Callable] | tuple[tp.Any, tp.Callable, tp.Any]: ... @tp.overload def vjp( *, has_aux: bool = False, reduce_axes: tp.Sequence[AxisName] = (), graph: bool | None = None, graph_updates: bool | None = None, ) -> tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]: ... def vjp( f: tp.Callable[..., tp.Any] | Missing = MISSING, *primals: tp.Any, has_aux: bool = False, reduce_axes: tp.Sequence[AxisName] = (), graph: bool | None = None, graph_updates: bool | None = None, ) -> ( tuple[tp.Any, tp.Callable] | tuple[tp.Any, tp.Callable, tp.Any] | tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]] ): """Stateful version of ``jax.vjp`` that propagates NNX Variable updates. Example:: >>> from flax import nnx >>> import jax.numpy as jnp ... >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> x = jnp.ones((1, 2)) ... >>> def loss_fn(m, x): ... return jnp.sum(m(x)) ... >>> primals_out, vjp_fn = nnx.vjp(loss_fn, m, x, graph=False) >>> m_grad, x_grad = vjp_fn(jnp.ones_like(primals_out)) Can also be used as a decorator:: >>> @nnx.vjp(graph=False) ... def f(m, x): ... return jnp.sum(m(x)) ... >>> primals_out, vjp_fn = f(m, x) Args: f: Function to be differentiated. Its arguments can be arrays, scalars, or pytrees containing arrays and NNX Variables. *primals: A sequence of primal values at which the Jacobian of ``f`` should be evaluated. has_aux: Optional, bool. Indicates whether ``f`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False. reduce_axes: Deprecated, do not use. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references and reference semantics. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. graph_updates: If ``True``, propagates updates on graph structure that happen inside the transform to the input graphs, has no effect when ``graph=False``. Returns: If ``has_aux`` is False, returns a ``(primals_out, vjp_fn)`` pair. ``vjp_fn`` takes a cotangent with the same structure as ``primals_out`` and returns gradients for each primal argument. If ``has_aux`` is True, returns ``(primals_out, vjp_fn, aux)``. """ if graph is None: graph = graphlib.set_graph_mode.current_value() if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() if graph and graph_updates: raise NotImplementedError( 'graph-mode with graph_updates is not supported for nnx.vjp. ' 'Set graph=False or graph_updates=False.' ) if reduce_axes: raise NotImplementedError('reduce_axes argument to vjp is deprecated') del reduce_axes if isinstance(f, Missing): return functools.partial( # type: ignore[return-value] vjp, has_aux=has_aux, graph=graph, graph_updates=graph_updates, ) f_unbound, _, was_bound = _resolve_bound_callable(f) if was_bound: _raise_bound_method_error('vjp') if not primals: return functools.partial( # type: ignore[return-value] vjp, f, has_aux=has_aux, graph=graph, graph_updates=graph_updates, ) if graph: primals = extract.to_tree2(primals) extract.check_no_aliases('vjp', primals=primals) primals_out, vjp_fn, aux = jax.vjp( SimpleVjpFn(f_unbound, has_aux=has_aux, graph=graph), *primals, has_aux=True, ) if has_aux: updates, user_aux = aux else: updates = aux user_aux = None if graph: primals_out = extract.from_tree2(primals_out) raw_vjp_fn = vjp_fn def vjp_fn(g): return extract.from_tree2(raw_vjp_fn(g)) extract.apply_variable_updates(primals, updates) if has_aux: return primals_out, vjp_fn, user_aux else: return primals_out, vjp_fn # ----------------------------------------------- # jvp # ----------------------------------------------- @dataclasses.dataclass(eq=False) class SimpleJvpFn: f: tp.Callable[..., tp.Any] has_aux: bool graph: bool def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args): updates, snapshot = extract.updates_and_snapshot(args) if self.graph: args = extract.from_tree2(args) out = self.f(*args) if self.graph: out = extract.to_tree2(out) extract.check_no_aliases('jvp', args=updates, out=out) updates = extract.mask_variable_updates(updates, snapshot) if self.has_aux: primals_out, aux = out return (primals_out, updates), aux else: return out, updates @tp.overload def jvp( f: tp.Callable[..., tp.Any], primals: tuple[tp.Any, ...], tangents: tuple[tp.Any, ...], *, has_aux: bool = False, graph: bool | None = None, graph_updates: bool | None = None, ) -> tuple[tp.Any, ...]: ... @tp.overload def jvp( *, has_aux: bool = False, graph: bool | None = None, graph_updates: bool | None = None, ) -> tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]: ... @tp.overload def jvp( f: tp.Callable[..., tp.Any], *, has_aux: bool = False, graph: bool | None = None, graph_updates: bool | None = None, ) -> tp.Callable[..., tp.Any]: ... def jvp( f: tp.Callable[..., tp.Any] | Missing = MISSING, primals: tuple[tp.Any, ...] | Missing = MISSING, tangents: tuple[tp.Any, ...] | Missing = MISSING, *, has_aux: bool = False, graph: bool | None = None, graph_updates: bool | None = None, ) -> ( tuple[tp.Any, ...] | tp.Callable[..., tp.Any] | tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]] ): """Stateful version of ``jax.jvp`` that propagates NNX Variable updates. Example:: >>> from flax import nnx >>> import jax.numpy as jnp ... >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> x = jnp.ones((1, 2)) ... >>> def f(m, x): ... return jnp.sum(m(x)) ... >>> m_tangent = jax.tree.map(jnp.zeros_like, m) >>> x_tangent = jnp.ones_like(x) >>> primals_out, tangent_out = nnx.jvp( ... f, (m, x), (m_tangent, x_tangent), graph=False ... ) Can also be used as a decorator:: >>> @nnx.jvp(graph=False) ... def f(m, x): ... return jnp.sum(m(x)) ... >>> primals_out, tangent_out = f((m, x), (m_tangent, x_tangent)) Args: f: Function to be differentiated. Its arguments can be arrays, scalars, or pytrees containing arrays and NNX Variables. primals: A tuple of primal values at which the Jacobian of ``f`` should be evaluated. tangents: A tuple of tangent vectors, with the same structure as ``primals``. has_aux: Optional, bool. Indicates whether ``f`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references and reference semantics. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. graph_updates: If ``True``, propagates updates on graph structure that happen inside the transform to the input graphs, has no effect when ``graph=False``. Returns: If ``has_aux`` is False, returns ``(primals_out, tangent_out)``. If ``has_aux`` is True, returns ``(primals_out, tangent_out, aux)``. """ if graph is None: graph = graphlib.set_graph_mode.current_value() if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() if graph and graph_updates: raise NotImplementedError( 'graph-mode with graph_updates is not supported for nnx.jvp. ' 'Set graph=False or graph_updates=False.' ) if isinstance(f, Missing): return functools.partial( jvp, has_aux=has_aux, graph=graph, graph_updates=graph_updates, ) f_unbound, _, was_bound = _resolve_bound_callable(f) if was_bound: _raise_bound_method_error('jvp') if isinstance(primals, Missing) or isinstance(tangents, Missing): return functools.partial( jvp, f, has_aux=has_aux, graph=graph, graph_updates=graph_updates, ) if graph: primals = extract.to_tree2(primals) tangents = extract.to_tree2(tangents) extract.check_no_aliases('jvp', primals=primals) extract.check_no_aliases('jvp', tangents=tangents) if has_aux: (primals_out, updates), (tangent_out, _updates_tangent), aux = jax.jvp( SimpleJvpFn(f_unbound, has_aux=True, graph=graph), primals, tangents, has_aux=True, ) else: (primals_out, updates), (tangent_out, _updates_tangent) = jax.jvp( SimpleJvpFn(f_unbound, has_aux=False, graph=graph), primals, tangents, ) if graph: primals_out = extract.from_tree2(primals_out) tangent_out = extract.from_tree2(tangent_out) extract.apply_variable_updates(primals, updates) if has_aux: return primals_out, tangent_out, aux else: return primals_out, tangent_out # ----------------------------------------------- # custom_vjp # ----------------------------------------------- @dataclasses.dataclass(eq=False) class SimpleCustomVjpFn: f: tp.Callable[..., tp.Any] graph: bool nondiff_argnums: tuple[int, ...] def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args): updates, snapshot = extract.updates_and_snapshot(args) if self.graph: args = extract.from_tree2(args) out = self.f(*args) if self.graph: out = extract.to_tree2(out) extract.check_no_aliases('custom_vjp', args=updates, out=out) diff_prefix = tuple( i not in self.nondiff_argnums for i in range(len(args)) ) def keep_fn(path, diff_arg, cur, snap): assert isinstance(diff_arg, bool) changed = extract.variable_changed(cur, snap) if diff_arg and changed: raise ValueError( f'Variables in differentiable argument were mutated inside ' f'custom_vjp at {jax.tree_util.keystr(path)}.\n' f'This is not supported when ' f'graph_updates=False because the gradient for the Variable ' f'updates would be silently dropped. Move the Variable mutation ' f'to a non-differentiable argument, or use graph_updates=True.' ) return changed updates = extract.mask_variable_updates( updates, snapshot, prefix=diff_prefix, keep_fn=keep_fn, ) return out, updates @dataclasses.dataclass(eq=False) class SimpleFwdFn: fwd: tp.Callable[..., tp.Any] graph: bool def __post_init__(self): functools.update_wrapper(self, self.fwd, updated=()) @extract.treemap_copy_args def __call__(self, *args): updates, snapshot = extract.updates_and_snapshot(args) if self.graph: args = extract.from_tree2(args) out, residual = self.fwd(*args) if self.graph: out = extract.to_tree2(out) residual = extract.to_tree2(residual) extract.check_no_aliases('custom_vjp', args=updates, out=out) updates = extract.mask_variable_updates(updates, snapshot) return (out, updates), residual @dataclasses.dataclass(eq=False) class SimpleBwdFn: bwd: tp.Callable[..., tp.Any] graph: bool def __post_init__(self): functools.update_wrapper(self, self.bwd, updated=()) @extract.treemap_copy_args def __call__(self, *args): *nondiff, residual, (out_g, _updates_g) = args if self.graph: nondiff = extract.from_tree2(nondiff) residual = extract.from_tree2(residual) result = self.bwd(*nondiff, residual, out_g) if self.graph: result = extract.to_tree2(result) return result class SimpleCustomVjp(tp.Generic[A]): def __init__( self, fun: tp.Callable[..., A], nondiff_argnums: tuple[int, ...], graph: bool, ): functools.update_wrapper(self, fun) self.fun = fun self.nondiff_argnums = nondiff_argnums self.graph = graph self.custom_vjp_fn = jax.custom_vjp( fun=SimpleCustomVjpFn(fun, graph=graph, nondiff_argnums=nondiff_argnums), nondiff_argnums=nondiff_argnums, ) def __call__( self, *args: tp.Any, **kwargs: tp.Any ) -> A: args = resolve_kwargs(self.fun, args, kwargs) del kwargs if self.graph: prefix = tuple( i not in self.nondiff_argnums for i in range(len(args)) ) args = extract.to_tree2(args, prefix=prefix) extract.check_no_aliases('custom_vjp', args=args) (out, updates) = self.custom_vjp_fn(*args) if self.graph: out = extract.from_tree2(out) extract.apply_variable_updates(args, updates) return out def defvjp( self, fwd: tp.Callable[..., tuple[A, tp.Any]], bwd: tp.Callable[..., tuple[tp.Any, ...]], symbolic_zeros: bool = False, ) -> None: self.fwd = fwd self.bwd = bwd self.symbolic_zeros = symbolic_zeros self.custom_vjp_fn.defvjp( fwd=SimpleFwdFn(fwd, graph=self.graph), bwd=SimpleBwdFn(bwd, graph=self.graph), symbolic_zeros=symbolic_zeros, ) # custom_vjp is one of the most complicated transforms as it requires # to handle 4 different functions: # 1. CustomVJP: the main object that runs the outer logic, converts input graph nodes # to pytrees and output pytrees to graph nodes. # 2. CustomVjpFnWrapper: function that wraps the user's function, it converts # its input pytrees to graph nodes and output graph nodes to pytrees. # 3. FwdFn: wraps the user's fwd function, it converts its input pytrees to graph nodes # and output graph nodes to pytrees. Since it might run by itself in a separate context, # it needs to be aware if the update_context is active or not in order to update the outer # referenes. # 4. BwdFn: wraps the user's bwd function, it converts its input pytrees to graph nodes # and output graph nodes to pytrees. It doesn't need to be aware of the outer context # since it will never update the outer references as it runs during the backward pass. def _custom_vjp_merge_fn( ctx: graphlib.MergeContext, path, prefix: bool | DiffState, value: extract.NodeStates, *, nondiff_states: deque[extract.GraphDefState], ): nondiff = nondiff_states.popleft() return ctx.merge(nondiff.graphdef, value.state, nondiff.state) def _custom_vjp_split_fn( ctx: graphlib.SplitContext, path, prefix: bool | DiffState, value, *, nondiff_states: list[extract.GraphDefState], ): broadcast: graphlib.GraphState if prefix is False: # pure non-differentiable arg, not supported raise TypeError( 'Passing integers to nondiff_argnums for graph nodes arguments in custom_vjp is not supported. ' f'Got {prefix} at path {jax.tree_util.keystr(path)} for value {value}' ) elif prefix is True: # pure differentiable arg, we pass all the state through # but we return a TreeNode.from_states which doesn't have a graphdef # in order to keep the gradients clean from any metadata graphdef, passed = ctx.split(value) broadcast = State({}) nondiff_states.append(extract.GraphDefState(graphdef, broadcast)) return extract.NodeStates.from_states(passed) else: # differentiable arg with DiffState filter, we use the filter to split the state # as before we return a TreeNode.from_states to keep the gradients clean # from any metadata, the non-differentiable state is stored in a deque # which is broadcasted during the forward pass graphdef, passed, broadcast = ctx.split(value, prefix.filter, ...) # type: ignore[misc] nondiff_states.append(extract.GraphDefState(graphdef, broadcast)) return extract.NodeStates.from_states(passed) nondiff_argnums: tuple[int, ...] = struct.field(pytree_node=False) tangent_tree_node_args: tuple[tp.Any, ...] = struct.field(pytree_node=False) def _extract_nodedefs(x, *, nodedefs: deque[graphlib.GraphDef]): if isinstance(x, graphlib.GraphDef): nodedefs.append(x) return x.with_no_outer_index() return x @dataclasses.dataclass(eq=False) class CustomVjpFnWrapper: f: tp.Callable[..., tp.Any] jax_nondiff_argnums: tuple[int, ...] ctxtag: str nondiff_states: list[extract.GraphDefState] nodedefs: deque[graphlib.GraphDef] def __post_init__(self): functools.update_wrapper(self, self.f) def __call__(self, *pure_args): nondiff_states = deque(self.nondiff_states) args = extract.from_tree( pure_args, merge_fn=functools.partial( _custom_vjp_merge_fn, nondiff_states=nondiff_states ), ctxtag=self.ctxtag, is_inner=True, ) out = self.f(*args) # remove nondiff from pure_args_out_g args_out = tuple( x for i, x in enumerate(args) if i not in self.jax_nondiff_argnums ) args_out = extract.clear_non_graph_nodes(args_out) pure_args_out, pure_out = extract.to_tree( (args_out, out), ctxtag=self.ctxtag ) # remove outer_index from GraphDef's but store them in global context pure_args_out, pure_out = jax.tree.map( functools.partial(_extract_nodedefs, nodedefs=self.nodedefs), (pure_args_out, pure_out), is_leaf=lambda x: isinstance(x, graphlib.GraphDef), ) return pure_args_out, pure_out @dataclasses.dataclass(eq=False) class FwdFn: fwd: tp.Callable[..., tp.Any] nondiff_argnums: tuple[int, ...] ctxtag: str nondiff_states: list[extract.GraphDefState] nodedefs: deque[graphlib.GraphDef] def __post_init__(self): functools.update_wrapper(self, self.fwd) def __call__(self, *pure_args): # here we need to be aware if the update_context is active or not # when its not active, index_mappings will be None # when its active, we will remove the index_mappings from the GraphDef's and store them # in the index_mappings deque created by CustomVjp update_context_active = ( self.ctxtag in graphlib.GRAPH_CONTEXT.update_context_stacks ) nondiff_states = deque(self.nondiff_states) args = extract.from_tree( pure_args, merge_fn=functools.partial( _custom_vjp_merge_fn, nondiff_states=nondiff_states ), ctxtag=self.ctxtag if update_context_active else None, is_inner=True, ) out, residual = self.fwd(*args) # remove nondiff from pure_args_out_g args_out = tuple( x for i, x in enumerate(args) if i not in self.nondiff_argnums ) args_out = extract.clear_non_graph_nodes(args_out) pure_args_out, pure_out = extract.to_tree( (args_out, out), ctxtag=self.ctxtag if update_context_active else None, ) pure_residual = extract.to_tree(residual) if update_context_active: # remove outer_index from GraphDef's but store them in global context pure_args_out, pure_out = jax.tree.map( functools.partial(_extract_nodedefs, nodedefs=self.nodedefs), (pure_args_out, pure_out), is_leaf=lambda x: isinstance(x, graphlib.GraphDef), ) return (pure_args_out, pure_out), pure_residual @dataclasses.dataclass(eq=False) class BwdFn: bwd: tp.Callable[..., tp.Any] tree_node_args: tuple[tp.Any, ...] def __post_init__(self): functools.update_wrapper(self, self.bwd) def __call__(self, *args): *nondiff, pure_residual, (pure_args_out_g, pure_out_g) = args residual = extract.from_tree(pure_residual, is_inner=True) (pure_args_out_g, pure_out_g) = jax.tree.map( lambda x: x.state if isinstance(x, extract.NodeStates) else x, (pure_args_out_g, pure_out_g), is_leaf=lambda x: isinstance(x, extract.NodeStates), ) tangent = self.bwd(*nondiff, residual, (pure_args_out_g, pure_out_g)) def state_to_node_states(is_differentiable: bool, x): if is_differentiable: if isinstance(x, jax.Array): return x elif not isinstance(x, State | variablelib.Variable): raise ValueError(f'Expected State or Variable, got {type(x)}') return extract.NodeStates.from_states(x) return x pure_tangent = jax.tree.map( state_to_node_states, self.tree_node_args, tangent, is_leaf=lambda x: isinstance(x, State | variablelib.Variable), ) return pure_tangent class CustomVjp(tp.Generic[A]): def __init__( self, fun: tp.Callable[..., A], nondiff_argnums: tuple[int | DiffState, ...], ): functools.update_wrapper(self, fun) # first argument is metadata self.jax_nondiff_argnums = tuple( x for x in nondiff_argnums if isinstance(x, int) ) self.ctxtag = f'custom_vjp_{fun.__name__}_{id(fun)}' self.fun = fun self.fwd: tp.Callable | None = None self.bwd: tp.Callable | None = None self.symbolic_zeros: bool | None = None self.nondiff_argnums = nondiff_argnums self.diff_filter: dict[int, tp.Literal[False] | DiffState] = {} for argnum in self.nondiff_argnums: index = argnum.argnum if isinstance(argnum, DiffState) else argnum if index in self.diff_filter: raise ValueError(f'argnum {index} is repeated in nondiff_argnums') self.diff_filter[index] = ( dataclasses.replace(argnum, argnum=-1) if isinstance(argnum, DiffState) else False ) # def __getattr__(self, name: str) -> tp.Any: # if not hasattr(self.custom_vjp_fn, name): # raise AttributeError(f'{type(self).__name__} has no attribute {name}') # return getattr(self.custom_vjp_fn, name) def __call__( self, *args: tp.Any, **kwargs: tp.Any ) -> A: # pytype: disable=invalid-annotation with graphlib.update_context(self.ctxtag): args = resolve_kwargs(self.fun, args, kwargs) del kwargs nondiff_states: list[extract.GraphDefState] = [] arg_filters = tuple( self.diff_filter.get(i, True) for i in range(len(args)) ) pure_args = extract.to_tree( args, prefix=arg_filters, split_fn=functools.partial( _custom_vjp_split_fn, nondiff_states=nondiff_states ), ctxtag=self.ctxtag, ) tree_node_args = jax.tree.map( lambda x: isinstance(x, extract.NodeStates), pure_args, is_leaf=lambda x: isinstance(x, extract.NodeStates), ) tree_node_args = tuple( x for i, x in enumerate(tree_node_args) if i not in self.jax_nondiff_argnums ) nodedefs: deque[graphlib.GraphDef] = deque() if self.fwd is None or self.bwd is None or self.symbolic_zeros is None: raise ValueError() custom_vjp_fn = jax.custom_vjp( fun=CustomVjpFnWrapper( f=self.fun, jax_nondiff_argnums=self.jax_nondiff_argnums, ctxtag=self.ctxtag, nondiff_states=nondiff_states, nodedefs=nodedefs, ), nondiff_argnums=self.jax_nondiff_argnums, ) custom_vjp_fn.defvjp( fwd=FwdFn( fwd=self.fwd, nondiff_argnums=self.jax_nondiff_argnums, ctxtag=self.ctxtag, nondiff_states=nondiff_states, nodedefs=nodedefs, ), bwd=BwdFn( bwd=self.bwd, tree_node_args=tree_node_args, ), symbolic_zeros=self.symbolic_zeros, ) pure_args_out, pure_out = custom_vjp_fn(*pure_args) # insert index_mappings def _insert_index_mappings(x): if isinstance(x, graphlib.GraphDef): nodedef = nodedefs.popleft() return nodedef return x pure_args_out, pure_out = jax.tree_util.tree_map( _insert_index_mappings, (pure_args_out, pure_out), is_leaf=lambda x: isinstance(x, graphlib.GraphDef), ) args_out, out = extract.from_tree( (pure_args_out, pure_out), ctxtag=self.ctxtag, is_inner=False ) return out def defvjp( self, fwd: tp.Callable[..., tuple[A, tp.Any]], bwd: tp.Callable[..., tuple[tp.Any, ...]], symbolic_zeros: bool = False, ) -> None: self.fwd = fwd self.bwd = bwd self.symbolic_zeros = symbolic_zeros @tp.overload def custom_vjp( fun: tp.Callable[..., A], *, nondiff_argnums: tuple[int | DiffState, ...] = (), graph: bool | None = None, graph_updates: bool | None = None, ) -> CustomVjp[A] | SimpleCustomVjp[A]: ... @tp.overload def custom_vjp( *, nondiff_argnums: tuple[int | DiffState, ...] = (), graph: bool | None = None, graph_updates: bool | None = None, ) -> tp.Callable[[tp.Callable[..., A]], CustomVjp[A] | SimpleCustomVjp[A]]: ... def custom_vjp( fun: tp.Callable[..., A] | Missing = MISSING, *, nondiff_argnums: tuple[int | DiffState, ...] = (), graph: bool | None = None, graph_updates: bool | None = None, ) -> CustomVjp[A] | SimpleCustomVjp[A] | tp.Callable[[tp.Callable[..., A]], CustomVjp[A] | SimpleCustomVjp[A]]: """Reference aware version of `jax.custom_vjp `__. ``nnx.custom_vjp`` accepts Modules and other Flax NNX objects as arguments. The main difference with the JAX version is that, because Modules follow reference semantics, they propagate the State updates for the inputs as auxiliary outputs. This means that the incoming gradients in the ``bwd`` function will have the form ``(input_updates_g, out_g)`` where ``input_updates_g`` is the gradient updated state of the inputs w.r.t. to the inputs. All Module terms on the inputs will an associated ``State`` term in ``input_updates_g``, while all non-Module terms will appear as None. The shape of the tangent will be expected to have the same shape as the input, with ``State`` terms in place of the corresponding Module terms. Example:: >>> import jax >>> import jax.numpy as jnp >>> from flax import nnx ... >>> class Foo(nnx.Module): ... def __init__(self, x, y): ... self.x = nnx.Param(x) ... self.y = nnx.Param(y) ... >>> @nnx.custom_vjp ... def f(m: Foo): ... return jnp.sin(m.x) * m.y ... >>> def f_fwd(m: Foo): ... return f(m), (jnp.cos(m.x), jnp.sin(m.x), m) ... >>> def f_bwd(res, g): ... input_updates_g, out_g = g ... cos_x, sin_x, m = res ... (m_updates_g,) = input_updates_g ... m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy ... ... m_g['x'][...] = cos_x * out_g * m.y ... m_g['y'][...] = sin_x * out_g ... return (m_g,) ... >>> f.defvjp(f_fwd, f_bwd) ... >>> m = Foo(x=jnp.array(1.), y=jnp.array(2.)) >>> grads = nnx.grad(f)(m) ... >>> jax.tree.map(jnp.shape, grads) State({ 'x': Param( value=() ), 'y': Param( value=() ) }) Note that the State objects that represent Module terms on ``input_updates_g`` have the same shape as the State objects expected in the output tanget. This means that you can usually just copy them from ``input_updates_g`` and update them with their corresponding gradient values. You can select which substates are differentiable (have a tangent) for Modules and other graph nodes by passing a ``DiffState`` to ``nondiff_argnums``. For example, if you want to differentiate only the ``x`` attribute of the ``Foo`` class, you can do the following:: >>> x_attribute = nnx.PathContains('x') >>> diff_state = nnx.DiffState(0, x_attribute) ... >>> @nnx.custom_vjp(nondiff_argnums=(diff_state,)) ... def f(m: Foo): ... return jnp.sin(m.x) * m.y # type: ignore >>> def f_fwd(m: Foo): ... y = f(m) ... res = (jnp.cos(m.x), m) # type: ignore ... return y, res ... >>> def f_bwd(res, g): ... input_updates_g, out_g = g ... cos_x, m = res ... (m_updates_g,) = input_updates_g ... m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy ... ... m_g.x[...] = cos_x * out_g * m.y ... del m_g['y'] # y is not differentiable ... return (m_g,) >>> f.defvjp(f_fwd, f_bwd) ... >>> m = Foo(x=jnp.array(1.), y=jnp.array(2.)) >>> grad = nnx.grad(f, argnums=nnx.DiffState(0, x_attribute))(m) ... >>> jax.tree.map(jnp.shape, grad) State({ 'x': Param( value=() ) }) Note that ``grad`` cannot calculate gradients for states that don't have a tangent defined by ``custom_vjp``, in the example above we reuse the same ``x_attribute`` filter to keep ``custom_vjp`` and ``grad`` in sync. **graph_updates=False** When ``graph_updates=False`` or ``graph=False``, the behavior is closer to ``jax.custom_vjp``: the ``bwd`` function receives ``out_g`` directly, and tangent types are the same as the input types, this means the tangent for a Module is a Module instance with gradient values set on its attributes. This mode does not support ``DiffState`` in ``nondiff_argnums``. Additionally, Variables in differentiable arguments cannot be mutated inside ``f``. If mutations are needed, pass the relevant Variables through a non-differentiable argument instead. Example:: >>> @nnx.custom_vjp(graph_updates=False) ... def f(m: Foo): ... return jnp.sin(m.x) * m.y ... >>> def f_fwd(m: Foo): ... return f(m), (jnp.cos(m.x), jnp.sin(m.x), m) ... >>> def f_bwd(res, g): ... cos_x, sin_x, m = res ... m_g = nnx.clone(m) ... m_g.x[...] = cos_x * g * m.y ... m_g.y[...] = sin_x * g ... return (m_g,) ... >>> f.defvjp(f_fwd, f_bwd) Args: fun: Callable base function. nondiff_argnums: Tuple of integers or DiffState objects specifying the argument indices that are not differentiated. By default all arguments are differentiated. Integers cannot be used to mark graph nodes such as Modules as non-differentiable, in this case use a DiffState object. DiffState objects define the set of differentiable substates, contrary to what the name of this argument suggests, this is done for compatibility with ``grad``. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references and reference semantics. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Tree-mode does not support ``DiffState`` in ``nondiff_argnums``. graph_updates: If ``True``, propagates updates on graph structure that happen inside the transform to the input graphs, has no effect when ``graph=False``. When ``False``, using ``DiffState`` is not supported. """ if graph is None: graph = graphlib.set_graph_mode.current_value() if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() if isinstance(fun, Missing): return functools.partial( custom_vjp, nondiff_argnums=nondiff_argnums, graph=graph, graph_updates=graph_updates, ) # Detect bound nnx.Module methods and raise error. fun_unbound, _, was_bound = _resolve_bound_callable(fun) if was_bound: _raise_bound_method_error('custom_vjp') if not graph or not graph_updates: if any(isinstance(x, DiffState) for x in nondiff_argnums): raise ValueError( '`nondiff_argnums` cannot contain `DiffState` objects ' 'when `graph=False`. ' + graphlib._tree_mode_suggestion_transform('custom_vjp') ) return SimpleCustomVjp(fun_unbound, nondiff_argnums, graph=graph) # type: ignore[arg-type] return CustomVjp(fun_unbound, nondiff_argnums) # ------------------------------- # remat # ------------------------------- @dataclasses.dataclass(eq=False) class SimpleRematFn: f: tp.Callable[..., tp.Any] graph: bool def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args, **kwargs): updates, snapshot = extract.updates_and_snapshot((args, kwargs)) if self.graph: args, kwargs = extract.from_tree2((args, kwargs)) out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out) extract.check_no_aliases('remat', args=updates[0], kwargs=updates[1], out=out) updates = extract.mask_variable_updates(updates, snapshot) return out, updates @tp.overload def remat( *, prevent_cse: bool = True, static_argnums: int | tuple[int, ...] = (), policy: tp.Callable[..., bool] | None = None, graph: bool | None = None, graph_updates: bool | None = None, ) -> tp.Callable[[F], F]: ... @tp.overload def remat( f: F, *, prevent_cse: bool = True, static_argnums: int | tuple[int, ...] = (), policy: tp.Callable[..., bool] | None = None, graph: bool | None = None, graph_updates: bool | None = None, ) -> F: ... def remat( f: F | Missing = MISSING, *, prevent_cse: bool = True, static_argnums: int | tuple[int, ...] = (), policy: tp.Callable[..., bool] | None = None, graph: bool | None = None, graph_updates: bool | None = None, ) -> F | tp.Callable[[F], F]: """A 'lifted' version of the `jax.checkpoint `__ (a.k.a. ``jax.remat``). ``flax.nnx.remat``, similar to ``jax.checkpoint`` can provide control over, for example, how ``flax.nnx.grad`` values are computed and saved during the forward pass versus how they are recomputed during the backward pass, trading off memory and FLOPs. Learn more in `Flax NNX vs JAX Transformations `_. To learn about ``jax.remat``, go to JAX's `fundamentals of jax.checkpoint `_ and `practical notes `_. Args: f: Function to be rematerialized. prevent_cse: Optional, bool. If True, prevents common subexpression elimination. Default True. static_argnums: Optional, int or tuple of ints. Specifies which positional arguments to treat as static. policy: Optional, callable. A policy for which intermediates to save during the forward pass. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references and reference semantics. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Tree-mode does not support shared ``Variable`` references. graph_updates: If ``True``, propagates updates on graph structure that happen inside the transform to the input graphs, has no effect when ``graph=False``. """ if graph is None: graph = graphlib.set_graph_mode.current_value() if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() if isinstance(f, Missing): return functools.partial( remat, prevent_cse=prevent_cse, static_argnums=static_argnums, policy=policy, graph=graph, graph_updates=graph_updates, ) # type: ignore[return-value] f_unbound, _, was_bound = _resolve_bound_callable(f) if was_bound: _raise_bound_method_error('remat') if not graph or not graph_updates: checkpointed_fn = jax.checkpoint( SimpleRematFn(f_unbound, graph=graph), prevent_cse=prevent_cse, static_argnums=static_argnums, policy=policy, ) @functools.wraps(f_unbound) def simple_remat_wrapper(*args, **kwargs): if graph: args, kwargs = extract.to_tree2((args, kwargs)) extract.check_no_aliases('remat', args=args, kwargs=kwargs) out, updates = checkpointed_fn(*args, **kwargs) if graph: out = extract.from_tree2(out) extract.apply_variable_updates((args, kwargs), updates) return out return simple_remat_wrapper # type: ignore[return-value] # Unbound function path: preserve the concise composition used in NNX. return resolve_kwargs()( # type: ignore[return-value] graphlib.update_context('remat')( general.split_inputs( jax.checkpoint( general.merge_inputs(f_unbound, ctxtag='remat'), prevent_cse=prevent_cse, static_argnums=static_argnums, policy=policy, ), ctxtag='remat', ), ) ) ================================================ FILE: flax/nnx/transforms/compilation.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # pytype: skip-file from __future__ import annotations import dataclasses import functools import inspect import operator import typing as tp import jax from jax.sharding import AbstractMesh, Mesh, PartitionSpec from flax.nnx import ( extract, filterlib, graphlib, statelib, variablelib, ) from flax.nnx.transforms.transforms import ( _resolve_bound_callable, _raise_bound_method_error, ) from flax.typing import MISSING, Missing, PathParts F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) P = tp.ParamSpec('P') R = tp.TypeVar('R') Specs = tp.Any AxisName = tp.Hashable # ------------------------------- # jit # ------------------------------- class StateSharding(extract.PrefixMapping): def __init__( self, filter_sharding: statelib.State | tp.Mapping[filterlib.Filter, tp.Any] | tp.Iterable[tuple[filterlib.Filter, tp.Any]], /, ): if isinstance(filter_sharding, statelib.State): filter_sharding = statelib.create_path_filters(filter_sharding) # type: ignore iterable = tuple( filter_sharding.items() if isinstance(filter_sharding, tp.Mapping) else filter_sharding ) self._filters = tuple(filter for filter, _ in iterable) self._shardings = tuple(axis for _, axis in iterable) @property def filters(self) -> tuple[filterlib.Filter, ...]: return self._filters @property def shardings(self) -> tuple[tp.Any, ...]: return self._shardings def map_prefix( self, path: PathParts, variable: variablelib.Variable ) -> tp.Any: for filter, sharding in zip(self.filters, self.shardings): predicate = filterlib.to_predicate(filter) if predicate(path, variable): return sharding raise ValueError(f'No axis found for {path=}, {variable=}') def __repr__(self): return f'StateSharding({dict(zip(self.filters, self.shardings))})' def __eq__(self, other): return ( isinstance(other, StateSharding) and self.filters == other.filters and self.shardings == other.shardings ) def __hash__(self): return hash((self.filters, self.shardings)) def _jit_split_fn(ctx: graphlib.SplitContext, path, prefix, x): if isinstance(prefix, StateSharding): graphdef, *states = ctx.flatten(x, *prefix.filters) return extract.NodeStates.from_split(graphdef, *states, metadata=prefix) return extract.NodeStates.from_split(*ctx.flatten(x, with_paths=False)) def _jit_merge_fn(ctx: graphlib.MergeContext, path, prefix, leaf) -> tp.Any: if not isinstance(leaf, extract.NodeStates): raise ValueError(f'Expected TreeNode, got {type(leaf)} at path {path}') return ctx.unflatten(leaf.graphdef, *leaf.states) @dataclasses.dataclass(eq=False) class JitFn: f: tp.Callable[..., tp.Any] in_shardings: tp.Any out_shardings: tp.Any kwarg_shardings: tp.Any ctxtag: tp.Hashable def __post_init__(self): # Prevent overwriting our ctxtag info with the child function's orig_ctxtag = self.ctxtag functools.update_wrapper(self, self.f, updated=()) self.ctxtag = orig_ctxtag def __call__(self, *pure_args, **pure_kwargs): args, kwargs = extract.from_tree( (pure_args, pure_kwargs), merge_fn=_jit_merge_fn, ctxtag=self.ctxtag, is_inner=True, ) out = self.f(*args, **kwargs) args_out, kwargs_out = extract.clear_non_graph_nodes((args, kwargs)) pure_args_out, pure_kwargs_out, pure_out = extract.to_tree( (args_out, kwargs_out, out), prefix=(self.in_shardings, self.kwarg_shardings, self.out_shardings), ctxtag=self.ctxtag, split_fn=_jit_split_fn, ) return pure_args_out, pure_kwargs_out, pure_out @tp.overload def jit( *, in_shardings: tp.Any = None, out_shardings: tp.Any = None, static_argnums: int | tp.Sequence[int] | None = None, static_argnames: str | tp.Iterable[str] | None = None, donate_argnums: int | tp.Sequence[int] | None = None, donate_argnames: str | tp.Iterable[str] | None = None, keep_unused: bool = False, device: tp.Optional[jax.Device] = None, backend: tp.Optional[str] = None, inline: bool = False, graph: bool | None = None, graph_updates: bool | None = None, ) -> tp.Callable[[tp.Callable[P, R]], JitWrapped[P, R]]: ... @tp.overload def jit( fun: tp.Callable[P, R], *, in_shardings: tp.Any = None, out_shardings: tp.Any = None, static_argnums: int | tp.Sequence[int] | None = None, static_argnames: str | tp.Iterable[str] | None = None, donate_argnums: int | tp.Sequence[int] | None = None, donate_argnames: str | tp.Iterable[str] | None = None, keep_unused: bool = False, device: tp.Optional[jax.Device] = None, backend: tp.Optional[str] = None, inline: bool = False, graph: bool | None = None, graph_updates: bool | None = None, ) -> JitWrapped[P, R]: ... def jit( fun: tp.Callable[P, R] | Missing = MISSING, *, in_shardings: tp.Any = None, out_shardings: tp.Any = None, static_argnums: int | tp.Sequence[int] | None = None, static_argnames: str | tp.Iterable[str] | None = None, donate_argnums: int | tp.Sequence[int] | None = None, donate_argnames: str | tp.Iterable[str] | None = None, keep_unused: bool = False, device: tp.Optional[jax.Device] = None, backend: tp.Optional[str] = None, inline: bool = False, graph: bool | None = None, graph_updates: bool | None = None, ) -> JitWrapped[P, R] | tp.Callable[[tp.Callable[P, R]], JitWrapped[P, R]]: """ Lifted version of ``jax.jit`` that can handle Modules / graph nodes as arguments. .. note:: If jitted function has a model and an optimizer as inputs, we can reduce accelerator's memory usage if we specify them in ``donate_argnums`` or ``donate_argnames``: >>> from flax import nnx >>> >>> @nnx.jit(donate_argnames=("model", "optimizer")) ... def func(model: nnx.Module, optimizer: nnx.Optimizer, other_args): ... pass For details please see `this discussion `_. Args: fun: Function to be jitted. ``fun`` should be a pure function, as side-effects may only be executed once. The arguments and return value of ``fun`` should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by ``static_argnums`` can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined. JAX keeps a weak reference to ``fun`` for use as a compilation cache key, so the object ``fun`` must be weakly-referenceable. Most :class:`Callable` objects will already satisfy this requirement. .. note:: Bound methods (e.g., ``module.method``) are not supported. Use the decorator form ``@nnx.jit`` on the method definition or call ``nnx.jit(MyClass.method)(instance, ...)`` with the unbound method. in_shardings: Pytree of structure matching that of arguments to ``fun``, with all actual arguments replaced by resource assignment specifications. It is also valid to specify a pytree prefix (e.g. one value in place of a whole subtree), in which case the leaves get broadcast to all values in that subtree. The ``in_shardings`` argument is optional. JAX will infer the shardings from the input :py:class:`jax.Array`'s and defaults to replicating the input if the sharding cannot be inferred. The valid resource assignment specifications are: - :py:class:`Sharding`, which will decide how the value will be partitioned. With this, using a mesh context manager is not required. - :py:obj:`None`, will give JAX the freedom to choose whatever sharding it wants. For in_shardings, JAX will mark is as replicated but this behavior can change in the future. For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings. The size of every dimension has to be a multiple of the total number of resources assigned to it. This is similar to pjit's in_shardings. out_shardings: Like ``in_shardings``, but specifies resource assignment for function outputs. This is similar to pjit's out_shardings. The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit` will use GSPMD's sharding propagation to figure out what the sharding of the output(s) should be. static_argnums: An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), and so the corresponding argument values can be any Python object. Static arguments should be hashable, meaning both ``__hash__`` and ``__eq__`` are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not arrays or containers thereof must be marked as static. If neither ``static_argnums`` nor ``static_argnames`` is provided, no arguments are treated as static. If ``static_argnums`` is not provided but ``static_argnames`` is, or vice versa, JAX uses :code:`inspect.signature(fun)` to find any positional arguments that correspond to ``static_argnames`` (or vice versa). If both ``static_argnums`` and ``static_argnames`` are provided, ``inspect.signature`` is not used, and only actual parameters listed in either ``static_argnums`` or ``static_argnames`` will be treated as static. static_argnames: An optional string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on ``static_argnums`` for details. If not provided but ``static_argnums`` is set, the default is based on calling ``inspect.signature(fun)`` to find corresponding named arguments. donate_argnums: Specify which positional argument buffers are "donated" to the computation. It is safe to donate argument buffers if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation, JAX will raise an error if you try to. By default, no argument buffers are donated. If neither ``donate_argnums`` nor ``donate_argnames`` is provided, no arguments are donated. If ``donate_argnums`` is not provided but ``donate_argnames`` is, or vice versa, JAX uses :code:`inspect.signature(fun)` to find any positional arguments that correspond to ``donate_argnames`` (or vice versa). If both ``donate_argnums`` and ``donate_argnames`` are provided, ``inspect.signature`` is not used, and only actual parameters listed in either ``donate_argnums`` or ``donate_argnames`` will be donated. For more details on buffer donation see the `FAQ `_. donate_argnames: An optional string or collection of strings specifying which named arguments are donated to the computation. See the comment on ``donate_argnums`` for details. If not provided but ``donate_argnums`` is set, the default is based on calling ``inspect.signature(fun)`` to find corresponding named arguments. keep_unused: If `False` (the default), arguments that JAX determines to be unused by `fun` *may* be dropped from resulting compiled XLA executables. Such arguments will not be transferred to the device nor provided to the underlying executable. If `True`, unused arguments will not be pruned. device: This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via :py:func:`jax.devices`.) The default is inherited from XLA's DeviceAssignment logic and is usually to use ``jax.devices()[0]``. backend: This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or ``'tpu'``. inline: Specify whether this function should be inlined into enclosing jaxprs (rather than being represented as an application of the xla_call primitive with its own subjaxpr). Default False. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references, reference semantics, and structural changes to Modules inside the jitted function. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Tree-mode is faster but does not support shared ``Variable`` references or returning mutable array references from the jitted function. Returns: A wrapped version of ``fun``, set up for just-in-time compilation. """ if graph is None: graph = graphlib.set_graph_mode.current_value() if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() if isinstance(fun, Missing): return functools.partial( jit, in_shardings=in_shardings, out_shardings=out_shardings, static_argnums=static_argnums, static_argnames=static_argnames, donate_argnums=donate_argnums, donate_argnames=donate_argnames, keep_unused=keep_unused, device=device, backend=backend, inline=inline, graph=graph, graph_updates=graph_updates, ) # type: ignore[return-value] fun_unbound, _, was_bound = _resolve_bound_callable(fun) if was_bound: _raise_bound_method_error('jit') if not graph: if any(isinstance(x, StateSharding) for x in jax.tree.leaves(in_shardings)): raise ValueError( '`in_shardings` cannot contain `StateSharding` objects ' 'when `graph=False`. ' + graphlib._tree_mode_suggestion_transform('jit') ) if any(isinstance(x, StateSharding) for x in jax.tree.leaves(out_shardings)): raise ValueError( '`out_shardings` cannot contain `StateSharding` objects ' 'when `graph=False`. ' + graphlib._tree_mode_suggestion_transform('jit') ) if graph and not graph_updates: if in_shardings is not None: extract.check_prefix(in_shardings, 'in_shardings', 'jit') if out_shardings is not None: extract.check_prefix(out_shardings, 'out_shardings', 'jit') wrapped_cls: tp.Any if graph and graph_updates: wrapped_cls = JitWrapped else: wrapped_cls = functools.partial(SimpleJitWrapped, graph=graph) return wrapped_cls( fun_unbound, in_shardings=in_shardings, out_shardings=out_shardings, static_argnums=static_argnums, static_argnames=static_argnames, donate_argnums=donate_argnums, donate_argnames=donate_argnames, keep_unused=keep_unused, device=device, backend=backend, inline=inline, ) @dataclasses.dataclass(frozen=True, slots=True) class PartialState: """Container for a pre-flattened partial argument. Stores the pytree structure (``treedef``) as static metadata and the flattened leaves as dynamic data. Variables within the original argument are kept as leaves so their values can change between calls without triggering recompilation. """ treedef: jax.tree_util.PyTreeDef leaves: list[tp.Any] jax.tree_util.register_dataclass( PartialState, data_fields=['leaves'], meta_fields=['treedef'], ) def _flatten_to_partial_state( arg: tp.Any, ref_index: graphlib.RefMap | None, ) -> PartialState: if ref_index is not None: graphdef, flat_state = graphlib.flatten(arg, ref_index=ref_index, graph=True) return PartialState(treedef=graphdef, leaves=flat_state.leaves) is_leaf = lambda x: isinstance(x, variablelib.Variable) leaves, treedef = jax.tree.flatten(arg, is_leaf=is_leaf) return PartialState(treedef=treedef, leaves=leaves) @dataclasses.dataclass(eq=False) class SimpleJitFn: f: tp.Callable[..., tp.Any] out_shardings: tp.Any donate_argnums: frozenset[int] donate_argnames: frozenset[str] graph: bool def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args, **kwargs): updates, snapshot = extract.updates_and_snapshot((args, kwargs)) args_updates, kwargs_updates = updates args_snapshot, kwargs_snapshot = snapshot if self.graph: args, kwargs = extract.from_tree2((args, kwargs)) out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out, prefix=self.out_shardings) extract.check_no_aliases('jit', args=args_updates, kwargs=kwargs_updates, out=out) def donated_arg(jax_path, prefix, c, s): path = graphlib.jax_to_nnx_path(jax_path) return path[0] in self.donate_argnums or extract.variable_changed(c, s) args_updates = extract.mask_variable_updates( args_updates, args_snapshot, keep_fn=donated_arg) def donated_kwarg(jax_path, prefix, c, s): path = graphlib.jax_to_nnx_path(jax_path) return path[0] in self.donate_argnames or extract.variable_changed(c, s) kwargs_updates = extract.mask_variable_updates( kwargs_updates, kwargs_snapshot, keep_fn=donated_kwarg) return out, (args_updates, kwargs_updates) class SimpleJitWrapped(tp.Generic[P, R]): def __init__( self, fun: tp.Callable[P, R], in_shardings: tp.Any, out_shardings: tp.Any, static_argnums: int | tp.Sequence[int] | None = None, static_argnames: str | tp.Iterable[str] | None = None, donate_argnums: int | tp.Sequence[int] | None = None, donate_argnames: str | tp.Iterable[str] | None = None, keep_unused: bool = False, device: tp.Optional[jax.Device] = None, backend: tp.Optional[str] = None, inline: bool = False, partial_args: tuple[PartialState, ...] = (), graph: bool = True, ): functools.update_wrapper(self, fun) self.fun: tp.Callable[P, R] = fun self.out_shardings = out_shardings self.partial_args = partial_args self.graph = graph if in_shardings is not None and isinstance(in_shardings, (tuple, list)) and ( static_argnums or static_argnames ): resolved = _resolve_argnums(fun, static_argnums, static_argnames) expanded = list(in_shardings) for i in sorted(resolved): expanded.insert(i, None) self.in_shardings = tuple(expanded) else: self.in_shardings = in_shardings jit_out_shardings: tp.Any if in_shardings is not None or out_shardings is not None: if isinstance(in_shardings, (tuple, list)) and ( static_argnums or static_argnames ): resolved = _resolve_argnums(fun, static_argnums, static_argnames) expanded = list(in_shardings) for i in sorted(resolved): expanded.insert(i, None) out_in_shardings = tuple(expanded) else: out_in_shardings = in_shardings jit_out_shardings = (out_shardings, (out_in_shardings, None)) else: jit_out_shardings = None donate_argnums_set = frozenset( (donate_argnums,) if isinstance(donate_argnums, int) else donate_argnums or () ) donate_argnames_set = frozenset( (donate_argnames,) if isinstance(donate_argnames, str) else donate_argnames or () ) self.jitted_fn = jax.jit( SimpleJitFn(fun, out_shardings, donate_argnums_set, donate_argnames_set, graph), in_shardings=in_shardings, out_shardings=jit_out_shardings, static_argnums=static_argnums, static_argnames=static_argnames, donate_argnums=donate_argnums, donate_argnames=donate_argnames, keep_unused=keep_unused, device=device, backend=backend, inline=inline, ) def _maybe_to_tree(self, args, kwargs): if self.graph: args, kwargs = extract.to_tree2( (args, kwargs), prefix=(self.in_shardings, None) if self.in_shardings is not None else None, check_aliasing=self.in_shardings is not None, ) return args, kwargs def _maybe_from_tree(self, out): if self.graph: out = extract.from_tree2(out) return out def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: args = (*self.partial_args, *args) # type: ignore[assignment] args, kwargs = self._maybe_to_tree(args, kwargs) if not self.graph: # skip check for graph mode extract.check_no_aliases('jit', args=args, kwargs=kwargs) out, updates = self.jitted_fn(*args, **kwargs) extract.apply_variable_updates((args, kwargs), updates) return self._maybe_from_tree(out) def __get__(self, obj, objtype=None): if obj is None: return self return functools.partial(self, obj) def eval_shape(self, *args, **kwargs): args = (*self.partial_args, *args) args, kwargs = self._maybe_to_tree(args, kwargs) if not self.graph: extract.check_no_aliases('jit', args=args, kwargs=kwargs) out, updates = self.jitted_fn.eval_shape(*args, **kwargs) return self._maybe_from_tree(out) def trace(self, *args, **kwargs): args = (*self.partial_args, *args) args, kwargs = self._maybe_to_tree(args, kwargs) if not self.graph: extract.check_no_aliases('jit', args=args, kwargs=kwargs) traced = self.jitted_fn.trace(*args, **kwargs) return SimpleTraced(traced, self) def lower(self, *args, **kwargs): args = (*self.partial_args, *args) args, kwargs = self._maybe_to_tree(args, kwargs) if not self.graph: extract.check_no_aliases('jit', args=args, kwargs=kwargs) lowered = self.jitted_fn.lower(*args, **kwargs) return SimpleLowered(lowered, self) def jit_partial( fun: tp.Callable[..., R], *partial_args: tp.Any, in_shardings: tp.Any = None, out_shardings: tp.Any = None, donate_argnums: int | tp.Sequence[int] | None = None, donate_argnames: str | tp.Iterable[str] | None = None, keep_unused: bool = False, device: tp.Optional[jax.Device] = None, backend: tp.Optional[str] = None, inline: bool = False, graph: bool | None = None, graph_updates: bool | None = None, ) -> SimpleJitWrapped[..., R]: """JIT-compile ``fun`` with pre-flattened partial arguments. Similar to ``nnx.cached_partial`` but designed for tree-mode (``graph=False``). Each ``partial_arg`` is flattened into a ``PartialState`` whose pytree structure is fixed at construction time. Variable values inside partial arguments can still change between calls without triggering recompilation, and any mutations to Variables are propagated back to the originals after each call. Example usage:: >>> from flax import nnx >>> import jax.numpy as jnp >>> import optax ... >>> x, y = jnp.ones((4, 2)), jnp.ones((4, 3)) >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=nnx.Param) ... >>> def train_step(model, optimizer, x, y): ... def loss_fn(model): ... return jnp.mean((model(x) - y) ** 2) ... loss, grads = nnx.value_and_grad(loss_fn)(model) ... optimizer.update(model, grads) ... return loss ... >>> train_step_fn = nnx.jit_partial(train_step, model, optimizer, graph=False) ... >>> loss = train_step_fn(x, y) Args: fun: The function to JIT-compile. *partial_args: Arguments to be pre-flattened and bound. These must appear as the first positional arguments of ``fun``. in_shardings: Sharding specification for inputs. When a tuple/list, the first ``len(partial_args)`` entries correspond to partial arguments and are broadcast against their original pytree structure. A non-tuple value (e.g. a single ``PartitionSpec``) is passed through directly to ``jax.jit`` and broadcast across all arguments uniformly. out_shardings: Like ``in_shardings``, but for function outputs. donate_argnums: Positional argument indices whose buffers may be donated to the computation. donate_argnames: Named arguments whose buffers may be donated. keep_unused: If ``True``, unused arguments are not pruned. device: Optional device to run on. backend: Optional backend to use. inline: If ``True``, inline the function. graph: If ``None``, uses the ``nnx_graph_mode`` config value. graph_updates: If ``True``, propagates updates on graph structure that happen inside the transform to the input graphs, has no effect when ``graph=False``. When ``False``, using ``StateSharding`` is not supported. Returns: A callable expecting the remaining (runtime) arguments. """ if graph is None: graph = graphlib.set_graph_mode.current_value() if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() if graph_updates and graph: raise ValueError( '`graph_updates` not supported by `jit_partial`' ) if any(isinstance(x, StateSharding) for x in jax.tree.leaves(in_shardings)): raise ValueError( '`in_shardings` cannot contain `StateSharding` objects ' 'in `jit_partial`' ) if any(isinstance(x, StateSharding) for x in jax.tree.leaves(out_shardings)): raise ValueError( '`out_shardings` cannot contain `StateSharding` objects ' 'in `jit_partial`' ) is_variable = lambda x: isinstance(x, variablelib.Variable) ref_index = graphlib.RefMap() if graph else None flat_partial_args = tuple( _flatten_to_partial_state(arg, ref_index=ref_index) for arg in partial_args ) jit_in_shardings: tp.Any = None if in_shardings is not None and isinstance(in_shardings, (tuple, list)) and not graph: num_partial = len(partial_args) partial_shardings = in_shardings[:num_partial] runtime_shardings = in_shardings[num_partial:] flat_partial_shardings = [] for flat_arg, orig_arg, sharding in zip( flat_partial_args, partial_args, partial_shardings): broadcasted = extract.broadcast_prefix( sharding, orig_arg, prefix_is_leaf=lambda x: x is None or isinstance(x, variablelib.Variable), tree_is_leaf=is_variable, ) flat_partial_shardings.append( PartialState(treedef=flat_arg.treedef, leaves=broadcasted) ) jit_in_shardings = (*flat_partial_shardings, *runtime_shardings) else: jit_in_shardings = in_shardings @functools.wraps(fun) def wrapped_fun(*args, **kwargs): index_ref = graphlib.IndexMap() if graph else None def _unflatten(arg): if not isinstance(arg, PartialState): return arg elif graph: return graphlib.unflatten( arg.treedef, arg.leaves, index_ref=index_ref, copy_variables=False, ) else: return jax.tree.unflatten(arg.treedef, arg.leaves) args = (_unflatten(a) for a in args) return fun(*args, **kwargs) return SimpleJitWrapped( wrapped_fun, in_shardings=jit_in_shardings, out_shardings=out_shardings, donate_argnums=donate_argnums, donate_argnames=donate_argnames, keep_unused=keep_unused, device=device, backend=backend, inline=inline, partial_args=flat_partial_args, graph=graph, ) class JitWrapped(tp.Generic[P, R]): """A function ready to be traced, lowered, and compiled. This protocol reflects the output of functions such as ``jax.jit``. Calling it results in JIT (just-in-time) lowering, compilation, and execution. It can also be explicitly lowered prior to compilation, and the result compiled prior to execution. """ def __init__( self, fun: tp.Callable[P, R], in_shardings: tp.Any, out_shardings: tp.Any, static_argnums: int | tp.Sequence[int] | None = None, static_argnames: str | tp.Iterable[str] | None = None, donate_argnums: int | tp.Sequence[int] | None = None, donate_argnames: str | tp.Iterable[str] | None = None, keep_unused: bool = False, device: tp.Optional[jax.Device] = None, backend: tp.Optional[str] = None, inline: bool = False, ): functools.update_wrapper(self, fun) self.fun: tp.Callable[P, R] = fun kwarg_shardings = None self.jax_in_shardings = jax.tree.map( lambda x: extract.NodeStates.from_prefixes(x.shardings, metadata=x) if isinstance(x, StateSharding) else x, in_shardings, ) self.jax_out_shardings = jax.tree.map( lambda x: extract.NodeStates.from_prefixes(x.shardings, metadata=x) if isinstance(x, StateSharding) else x, out_shardings, ) if isinstance(in_shardings, (tuple, list)) and (static_argnums or static_argnames): # We should reintroduce None values into in_shardings corresponding to static arguments static_argnums = _resolve_argnums(fun, static_argnums, static_argnames) in_shardings = list(in_shardings) for static_arg_index in sorted(static_argnums): in_shardings.insert(static_arg_index, None) in_shardings = tuple(in_shardings) jax_out_in_shardings = jax.tree.map( lambda x: extract.NodeStates.from_prefixes(x.shardings, metadata=x) if isinstance(x, StateSharding) else x, in_shardings, ) self.jitted_fn = jax.jit( JitFn(fun, in_shardings, out_shardings, kwarg_shardings, self), in_shardings=self.jax_in_shardings, out_shardings=(jax_out_in_shardings, kwarg_shardings, self.jax_out_shardings), static_argnums=static_argnums, static_argnames=static_argnames, donate_argnums=donate_argnums, donate_argnames=donate_argnames, keep_unused=keep_unused, device=device, backend=backend, inline=inline, ) self.in_shardings = in_shardings self.out_shardings = out_shardings self.kwarg_shardings = kwarg_shardings self.static_argnums = static_argnums # implement descriptor protocol so that we can use this as a method def __get__(self, obj, objtype=None): if obj is None: return self return functools.partial(self, obj) def _get_pure_args_kwargs(self, args, kwargs): pure_args, pure_kwargs = extract.to_tree( (args, kwargs), prefix=(self.in_shardings, self.kwarg_shardings) if self.in_shardings is not None or self.kwarg_shardings is not None else None, split_fn=_jit_split_fn, check_aliasing=self.in_shardings is not None or self.kwarg_shardings is not None, ctxtag=self, ) return pure_args, pure_kwargs def _get_non_pure_out(self, pure_args_out, pure_kwargs_out, pure_out, /): _args_out, _kwargs_out, out = extract.from_tree( (pure_args_out, pure_kwargs_out, pure_out), merge_fn=_jit_merge_fn, is_inner=False, ctxtag=self, ) return out def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: # run dynamic_cache_context before update_context with graphlib.update_context(self): pure_args, pure_kwargs = self._get_pure_args_kwargs(args, kwargs) pure_args_out, pure_kwargs_out, pure_out = self.jitted_fn( *pure_args, **pure_kwargs ) out = self._get_non_pure_out(pure_args_out, pure_kwargs_out, pure_out) return out def eval_shape(self, *args, **kwargs): """See ``jax.eval_shape``.""" args, kwargs = graphlib.clone((args, kwargs)) with graphlib.update_context(self): pure_args, pure_kwargs = self._get_pure_args_kwargs(args, kwargs) pure_args_out, pure_kwargs_out, pure_out = self.jitted_fn.eval_shape( *pure_args, **pure_kwargs ) out = self._get_non_pure_out(pure_args_out, pure_kwargs_out, pure_out) return out def trace(self, *args, **kwargs) -> Traced: """Trace this function explicitly for the given arguments. A traced function is staged out of Python and translated to a jaxpr. It is ready for lowering but not yet lowered. Returns: A ``Traced`` instance representing the tracing. """ with graphlib.update_context(self): pure_args, pure_kwargs = self._get_pure_args_kwargs(args, kwargs) traced = self.jitted_fn.trace(*pure_args, **pure_kwargs) return Traced(traced, self) def lower(self, *args, **kwargs) -> Lowered: """Lower this function explicitly for the given arguments. This is a shortcut for ``self.trace(*args, **kwargs).lower()``. A lowered function is staged out of Python and translated to a compiler's input language, possibly in a backend-dependent manner. It is ready for compilation but not yet compiled. Returns: A ``Lowered`` instance representing the lowering. """ with graphlib.update_context(self): pure_args, pure_kwargs = self._get_pure_args_kwargs(args, kwargs) lowered = self.jitted_fn.lower(*pure_args, **pure_kwargs) return Lowered(lowered, self) class Stage: args_info: tp.Any # PyTree of ArgInfo @property def _inner_obj(self) -> tp.Any: raise NotImplementedError @property def in_tree(self) -> jax.tree_util.PyTreeDef: return self._inner_obj.in_tree @property def in_avals(self): return self._inner_obj.in_avals @property def donate_argnums(self): return self._inner_obj.donate_argnums @dataclasses.dataclass(frozen=True, slots=True) class Compiled(Stage): """Compiled representation of a function specialized to types/values. A compiled computation is associated with an executable and the remaining information needed to execute it. It also provides a common API for querying properties of compiled computations across JAX's various compilation paths and backends. """ compiled: jax.stages.Compiled jit_wrapped: JitWrapped @property def _inner_obj(self): return self.compiled @property def args_info(self) -> tp.Any: # PyTree of ArgInfo raise self.compiled.args_info @staticmethod def call(*args, **kwargs): raise NotImplementedError def __call__(self, *args, **kwargs): with graphlib.update_context(self.jit_wrapped): pure_args, pure_kwargs = self.jit_wrapped._get_pure_args_kwargs( args, kwargs ) pure_args_out, pure_kwargs_out, pure_out = self.compiled( *pure_args, **pure_kwargs ) out = self.jit_wrapped._get_non_pure_out( pure_args_out, pure_kwargs_out, pure_out ) return out @property def out_tree(self) -> jax.tree_util.PyTreeDef: return self.compiled.out_tree def as_text(self) -> str | None: """A human-readable text representation of this executable. Intended for visualization and debugging purposes. This is not a valid nor reliable serialization. Returns ``None`` if unavailable, e.g. based on backend, compiler, or runtime. """ return self.compiled.as_text() def cost_analysis(self) -> tp.Any | None: """A summary of execution cost estimates. Intended for visualization and debugging purposes. The object output by this is some simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations. Returns ``None`` if unavailable, e.g. based on backend, compiler, or runtime. """ return self.compiled.cost_analysis() def memory_analysis(self) -> tp.Any | None: """A summary of estimated memory requirements. Intended for visualization and debugging purposes. The object output by this is some simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations. Returns ``None`` if unavailable, e.g. based on backend, compiler, or runtime. """ return self.compiled.memory_analysis() def runtime_executable(self) -> tp.Any | None: """An arbitrary object representation of this executable. Intended for debugging purposes. This is not valid nor reliable serialization. The output has no guarantee of consistency across invocations. Returns ``None`` if unavailable, e.g. based on backend, compiler, or runtime. """ return self.compiled.runtime_executable() @property def input_shardings(self): # PyTree[sharding.Sharding] return self.compiled.input_shardings @property def output_shardings(self): # PyTree[sharding.Sharding] return self.compiled.output_shardings @property def input_layouts(self): return self.compiled.input_formats @dataclasses.dataclass(frozen=True, slots=True) class Lowered(Stage): """Lowering of a function specialized to argument types and values. A lowering is a computation ready for compilation. This class carries a lowering together with the remaining information needed to later compile and execute it. It also provides a common API for querying properties of lowered computations across JAX's various lowering paths (:func:`~jax.jit`, :func:`~jax.pmap`, etc.). """ lowered: jax.stages.Lowered jit_wrapped: JitWrapped @property def _inner_obj(self): return self.lowered @property def args_info(self) -> tp.Any: # PyTree of ArgInfo return self.lowered.args_info @property def out_tree(self): return self.lowered.out_tree @classmethod def from_flat_info( cls, lowering: tp.Any, # type: ignore[name-defined] in_tree: jax.tree_util.PyTreeDef, in_avals, donate_argnums: tuple[int, ...], out_tree: jax.tree_util.PyTreeDef, no_kwargs: bool = False, ): raise NotImplementedError def compile( self, compiler_options: jax.stages.CompilerOptions | None = None ) -> Compiled: """Compile, returning a corresponding ``Compiled`` instance.""" compiled = self.lowered.compile(compiler_options) return Compiled(compiled, self.jit_wrapped) def as_text( self, dialect: str | None = None, *, debug_info: bool = False ) -> str: """A human-readable text representation of this lowering. Intended for visualization and debugging purposes. This need not be a valid nor reliable serialization. Use `jax.export` if you want reliable and portable serialization. Args: dialect: Optional string specifying a lowering dialect (e.g. "stablehlo", or "hlo"). debug_info: Whether to include debugging information, e.g., source location. """ return self.lowered.as_text(dialect=dialect, debug_info=debug_info) def compiler_ir(self, dialect: str | None = None) -> tp.Any | None: """An arbitrary object representation of this lowering. Intended for debugging purposes. This is not a valid nor reliable serialization. The output has no guarantee of consistency across invocations. Use `jax.export` if you want reliable and portable serialization. Returns ``None`` if unavailable, e.g. based on backend, compiler, or runtime. Args: dialect: Optional string specifying a lowering dialect (e.g. "stablehlo", or "hlo"). """ return self.lowered.compiler_ir(dialect=dialect) def cost_analysis(self) -> tp.Any | None: """A summary of execution cost estimates. Intended for visualization and debugging purposes. The object output by this is some simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations. Returns ``None`` if unavailable, e.g. based on backend, compiler, or runtime. """ return self.lowered.cost_analysis() @dataclasses.dataclass(frozen=True, slots=True) class Traced(Stage): """Traced form of a function specialized to argument types and values. A traced computation is ready for lowering. This class carries the traced representation with the remaining information needed to later lower, compile, and execute it. """ traced: jax.stages.Traced jit_wrapped: JitWrapped @property def _inner_obj(self): return self.traced @property def out_info(self): return self.traced.out_info def lower( self, *, lowering_platforms: tuple[str, ...] | None = None ) -> Lowered: """Lower to compiler input, returning a ``Lowered`` instance.""" lowered = self.traced.lower(lowering_platforms=lowering_platforms) return Lowered(lowered, self.jit_wrapped) @dataclasses.dataclass(frozen=True, slots=True) class SimpleCompiled(Stage): compiled: jax.stages.Compiled jit_wrapped: SimpleJitWrapped @property def _inner_obj(self): return self.compiled @property def args_info(self) -> tp.Any: raise self.compiled.args_info @staticmethod def call(*args, **kwargs): raise NotImplementedError def __call__(self, *args, **kwargs): args = (*self.jit_wrapped.partial_args, *args) args, kwargs = self.jit_wrapped._maybe_to_tree(args, kwargs) if not self.jit_wrapped.graph: extract.check_no_aliases('jit', args=args, kwargs=kwargs) out, updates = self.compiled(*args, **kwargs) extract.apply_variable_updates((args, kwargs), updates) return self.jit_wrapped._maybe_from_tree(out) @property def out_tree(self) -> jax.tree_util.PyTreeDef: return self.compiled.out_tree def as_text(self) -> str | None: return self.compiled.as_text() def cost_analysis(self) -> tp.Any | None: return self.compiled.cost_analysis() def memory_analysis(self) -> tp.Any | None: return self.compiled.memory_analysis() def runtime_executable(self) -> tp.Any | None: return self.compiled.runtime_executable() @property def input_shardings(self): return self.compiled.input_shardings @property def output_shardings(self): return self.compiled.output_shardings @property def input_layouts(self): return self.compiled.input_formats @dataclasses.dataclass(frozen=True, slots=True) class SimpleLowered(Stage): lowered: jax.stages.Lowered jit_wrapped: SimpleJitWrapped @property def _inner_obj(self): return self.lowered @property def args_info(self) -> tp.Any: return self.lowered.args_info @property def out_tree(self): return self.lowered.out_tree def compile( self, compiler_options: jax.stages.CompilerOptions | None = None ) -> SimpleCompiled: compiled = self.lowered.compile(compiler_options) return SimpleCompiled(compiled, self.jit_wrapped) def as_text( self, dialect: str | None = None, *, debug_info: bool = False ) -> str: return self.lowered.as_text(dialect=dialect, debug_info=debug_info) def compiler_ir(self, dialect: str | None = None) -> tp.Any | None: return self.lowered.compiler_ir(dialect=dialect) def cost_analysis(self) -> tp.Any | None: return self.lowered.cost_analysis() @dataclasses.dataclass(frozen=True, slots=True) class SimpleTraced(Stage): traced: jax.stages.Traced jit_wrapped: SimpleJitWrapped @property def _inner_obj(self): return self.traced @property def out_info(self): return self.traced.out_info def lower( self, *, lowering_platforms: tuple[str, ...] | None = None ) -> SimpleLowered: lowered = self.traced.lower(lowering_platforms=lowering_platforms) return SimpleLowered(lowered, self.jit_wrapped) # ------------------------------- # shard_map # ------------------------------- # TODO: create StateSpec and consider enabling a mode that does # not use filters during split for performance. Overall there might # be performance limitations for using shard_map at a top-level @dataclasses.dataclass(eq=False) class SimpleShardMapFn: f: tp.Callable[..., tp.Any] graph: bool out_specs: tp.Any def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args): updates, snapshot = extract.updates_and_snapshot(args) if self.graph: args = extract.from_tree2(args) out = self.f(*args) if self.graph: out = extract.to_tree2(out, prefix=self.out_specs) extract.check_no_aliases('shard_map', args=updates, out=out) updates = extract.mask_variable_updates(updates, snapshot) return out, updates @dataclasses.dataclass(eq=False) class ShardMapFn: f: tp.Callable[..., tp.Any] in_specs: tp.Any out_specs: tp.Any kwarg_specs: tp.Any ctxtag: tp.Hashable def __post_init__(self): functools.update_wrapper(self, self.f) def __call__(self, *pure_args, **pure_kwargs): args, kwargs = extract.from_tree( (pure_args, pure_kwargs), merge_fn=_jit_merge_fn, ctxtag=self.ctxtag, is_inner=True, ) out = self.f(*args, **kwargs) args_out, kwargs_out = extract.clear_non_graph_nodes((args, kwargs)) pure_args_out, pure_kwargs_out, pure_out = extract.to_tree( (args_out, kwargs_out, out), prefix=(self.in_specs, self.kwarg_specs, self.out_specs), ctxtag=self.ctxtag, split_fn=_jit_split_fn, ) return pure_args_out, pure_kwargs_out, pure_out @tp.overload def shard_map( f: F, *, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs, axis_names: tp.AbstractSet[AxisName] = frozenset(), check_vma: bool = True, graph: bool | None = None, graph_updates: bool | None = None, ) -> F: ... @tp.overload def shard_map( *, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs, axis_names: tp.AbstractSet[AxisName] = frozenset(), check_vma: bool = True, graph: bool | None = None, graph_updates: bool | None = None, ) -> tp.Callable[[F], F]: ... def shard_map( f: F | type[Missing] = Missing, *, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs, axis_names: tp.AbstractSet[AxisName] = frozenset(), check_vma: bool = True, graph: bool | None = None, graph_updates: bool | None = None, ) -> F | tp.Callable[[F], F]: """ Lifted version of `jax.shard_map `_ that can handle Modules / graph nodes as arguments. Simple data parallel example:: import jax import jax.numpy as jnp from flax import nnx from jax.sharding import PartitionSpec as P mesh = jax.sharding.Mesh(jax.local_devices(), ('data',)) m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) x = jnp.ones((32, 2)) @nnx.shard_map( mesh=mesh, in_specs=(P(None), P('data')), out_specs=P('data') ) def f(m, x): return m(x) y = f(m, x) jax.debug.visualize_array_sharding(y) Notice that here we simply used some ``PartitionSpec`` to define the spec the the whole model and data. This works for simple cases but if we need to assign different ``PartitionSpec`` to different parts of the model we need to use ``StateSharding`` and create some filters that allow us to target specific parts of the model. Here's an example of how to do tensor parallelism for a simple MLP block using ``StateSharding`` and filters:: mesh = jax.sharding.Mesh(jax.local_devices(), ('model',)) class MLP(nnx.Module): def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs) self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs) def __call__(self, x): return self.linear2(jax.nn.relu(self.linear1(x))) m = MLP(2, 64, 3, rngs=nnx.Rngs(0)) x = jnp.ones((32, 2)) def path_ends_with(*path_suffix): # custom filter return lambda path, value: path[-len(path_suffix):] == path_suffix model_spec = nnx.StateSharding({ path_ends_with('linear1', 'kernel'): P(None, 'model'), path_ends_with('linear2', 'kernel'): P('model', None), }) @nnx.shard_map(mesh=mesh, in_specs=(model_spec, P(None)), out_specs=P(None)) def f(m, x): y = m(x) return jax.lax.psum(y, 'model') y = f(m, x) jax.debug.visualize_array_sharding(m.linear1.kernel[...]) jax.debug.visualize_array_sharding(m.linear2.kernel[...]) Alternatively, a ``State`` object with the exact PartitionSpec for each state then you can be passed to ``StateSharding``:: mesh = jax.sharding.Mesh(jax.local_devices(), ('model',)) class MLP(nnx.Module): def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs) self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs) def __call__(self, x): return self.linear2(jax.nn.relu(self.linear1(x))) m = MLP(2, 64, 3, rngs=nnx.Rngs(0)) x = jnp.ones((32, 2)) model_spec = nnx.State( { 'linear1': {'kernel': P(None, 'model')}, 'linear2': {'kernel': P('model', None)}, } ) @nnx.shard_map( mesh=mesh, in_specs=(nnx.StateSharding(model_spec), P(None)), out_specs=P(None), ) def f(m, x): y = m(x) return jax.lax.psum(y, 'model') y = f(m, x) jax.debug.visualize_array_sharding(m.linear1.kernel[...]) jax.debug.visualize_array_sharding(m.linear2.kernel[...]) Here ``model_spec`` was created manually but you can also automate this process by using ``nnx.get_partition_spec`` to automatically create it for you (see `Scale up on multiple devices `_ ). Args: f: callable to be mapped. Each application of ``f``, or "instance" of ``f``, takes as input a shard of the mapped-over arguments and produces a shard of the output. mesh: a ``jax.sharding.Mesh`` representing the array of devices over which to shard the data and on which to execute instances of ``f``. The names of the ``Mesh`` can be used in collective communication operations in ``f``. This is typically created by a utility function like :func:`jax.experimental.mesh_utils.create_device_mesh`. in_specs: a pytree with ``jax.sharding.PartitionSpec``or ``nnx.StateSharding`` (mapping substates to ``PartitionSpec``s) instances as leaves, with a tree structure that is a tree prefix of the args tuple to be mapped over. Similar to ``jax.sharding.NamedSharding``, each ``PartitionSpec`` represents how the corresponding argument (or subtree of arguments) should be sharded along the named axes of ``mesh``. In each ``PartitionSpec``, mentioning a ``mesh`` axis name at a position expresses sharding the corresponding argument array axis along that positional axis; not mentioning an axis name expresses replication. If an argument, or argument subtree, has a corresponding spec of None, that argument is not sharded. out_specs: a pytree with ``jax.sharding.PartitionSpec`` or ``nnx.StateSharding`` (mapping substates to ``PartitionSpec``s) instances as leaves, with a tree structure that is a tree prefix of the output of ``f``. Each ``PartitionSpec`` represents how the corresponding output shards should be concatenated. In each ``PartitionSpec``, metioning a ``mesh`` axis name at a position expresses concatenation of that mesh axis's shards along the corresponding positional axis. Not mentioning a ``mesh`` axis name expresses a promise that the output values are equal along that mesh axis, and that rather than concatenating only a single value should be produced. axis_names: optional set of axis names from ``mesh`` over which the function ``f`` is manual. If empty, ``f``, is manual over all mesh axes. check_vma: optional boolean representing whether to enable additional validity checks and automatic differentiation optimizations. The validity checks concern whether any mesh axis names not mentioned in ``out_specs`` are consistent with how the outputs of ``f`` are replicated. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references and reference semantics. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Tree-mode does not support ``StateSharding`` or shared ``Variable`` references. graph_updates: If ``True``, propagates updates on graph structure that happen inside the transform to the input graphs, has no effect when ``graph=False``. When ``False``, using ``StateSharding`` is not supported. Returns: A callable that applies the input function ``f`` across data sharded according to the ``mesh`` and ``in_specs``. """ if graph is None: graph = graphlib.set_graph_mode.current_value() if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() if f is Missing: return functools.partial( shard_map, mesh=mesh, in_specs=in_specs, out_specs=out_specs, axis_names=axis_names, check_vma=check_vma, graph=graph, graph_updates=graph_updates, ) # type: ignore[return-value] assert not isinstance(f, type) f_unbound, _, was_bound = _resolve_bound_callable(f) if was_bound: _raise_bound_method_error('shard_map') if not graph or not graph_updates: if any(isinstance(x, StateSharding) for x in jax.tree.leaves(in_specs)): raise ValueError( '`in_specs` cannot contain `StateSharding` objects ' 'when `graph=False`' ) if any(isinstance(x, StateSharding) for x in jax.tree.leaves(out_specs)): raise ValueError( '`out_specs` cannot contain `StateSharding` objects ' 'when `graph=False`' ) if graph: extract.check_prefix(in_specs, 'in_specs', 'shard_map') extract.check_prefix(out_specs, 'out_specs', 'shard_map') shard_map_fn = jax.shard_map( SimpleShardMapFn(f_unbound, graph=graph, out_specs=out_specs), mesh=mesh, in_specs=in_specs, out_specs=(out_specs, in_specs), axis_names=axis_names, check_vma=check_vma, ) @functools.wraps(f) def shard_map_wrapper(*args, **kwargs): if graph: args = extract.to_tree2( args, prefix=in_specs, check_aliasing=in_specs is not None, ) extract.check_no_aliases('shard_map', args=args) out, updates = shard_map_fn(*args, **kwargs) extract.apply_variable_updates(args, updates) if graph: out = extract.from_tree2(out) return out shard_map_wrapper.inner = shard_map_fn # type: ignore return shard_map_wrapper # type: ignore kwarg_specs = PartitionSpec() jax_in_specs = jax.tree.map( lambda x: extract.NodeStates( _graphdef=PartitionSpec(), # type: ignore[arg-type] states=x.shardings, metadata=x, ) if isinstance(x, StateSharding) else x, in_specs, ) jax_out_specs = jax.tree.map( lambda x: extract.NodeStates( _graphdef=PartitionSpec(), # type: ignore[arg-type] states=x.shardings, metadata=x, ) if isinstance(x, StateSharding) else x, out_specs, ) @functools.wraps(f) # type: ignore[no-redef] def shard_map_wrapper(*args, **kwargs): with graphlib.update_context(shard_map_wrapper): pure_args, pure_kwargs = extract.to_tree( (args, kwargs), prefix=(in_specs, kwarg_specs) if in_specs is not None or kwarg_specs is not None else None, split_fn=_jit_split_fn, check_aliasing=in_specs is not None or kwarg_specs is not None, ctxtag=shard_map_wrapper, ) pure_args_out, pure_kwargs_out, pure_out = shard_map_fn( *pure_args, **pure_kwargs ) _args_out, _kwargs_out, out = extract.from_tree( (pure_args_out, pure_kwargs_out, pure_out), merge_fn=_jit_merge_fn, is_inner=False, ctxtag=shard_map_wrapper, ) return out shard_map_fn = jax.shard_map( ShardMapFn(f_unbound, in_specs, out_specs, kwarg_specs, shard_map_wrapper), mesh=mesh, in_specs=jax_in_specs, out_specs=(jax_in_specs, kwarg_specs, jax_out_specs), # type: ignore axis_names=axis_names, check_vma=check_vma, ) shard_map_wrapper.inner = shard_map_fn # type: ignore return shard_map_wrapper # type: ignore # We can't use private methods from jax._src.api_util # We copy the function: api_util.fun_signature def _fun_signature(fun: tp.Callable) -> inspect.Signature | None: try: return inspect.signature(fun) except (ValueError, TypeError): return None # Adapted copy of private jax function from api_util: fun_signature def _resolve_argnums( fun: tp.Callable, static_argnums: int | tp.Sequence[int] | None, static_argnames: str | tp.Iterable[str] | None, ) -> tuple[int, ...]: def _ensure_index_tuple(x: tp.Any) -> tuple[int, ...]: """Convert x to a tuple of indices.""" try: return (operator.index(x),) except TypeError: return tuple(map(operator.index, x)) def _ensure_str(x: str) -> str: if not isinstance(x, str): raise TypeError(f"argument is not a string: {x}") return x def _ensure_str_tuple(x: str | tp.Iterable[str]) -> tuple[str, ...]: """Convert x to a tuple of strings.""" if isinstance(x, str): return (x,) else: return tuple(map(_ensure_str, x)) signature = _fun_signature(fun) if signature is None: # Some built-in functions don't support signature. # See: https://github.com/python/cpython/issues/73485 # In this case no validation is done static_argnums = () if static_argnums is None else _ensure_index_tuple( static_argnums) else: # Infer argnums and argnames according to docstring # If nums is None and names is not None, then nums are inferred from the # names and vice-versa. _POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD _POSITIONAL_ARGUMENTS = ( inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD ) def infer_argnums_and_argnames( sig: inspect.Signature, argnums: int | tp.Iterable[int] | None, argnames: str | tp.Iterable[str] | None, ) -> tuple[tuple[int, ...], tuple[str, ...]]: """Infer missing argnums and argnames for a function with inspect.""" if argnums is None and argnames is None: return (), () if argnums is not None and argnames is not None: argnums = _ensure_index_tuple(argnums) argnames = _ensure_str_tuple(argnames) return argnums, argnames parameters = sig.parameters if argnums is None: assert argnames is not None argnames = _ensure_str_tuple(argnames) argnums = tuple( i for i, (k, param) in enumerate(parameters.items()) if param.kind == _POSITIONAL_OR_KEYWORD and k in argnames ) else: argnums = _ensure_index_tuple(argnums) argnames = tuple( k for i, (k, param) in enumerate(parameters.items()) if param.kind == _POSITIONAL_OR_KEYWORD and i in argnums ) return argnums, argnames def _validate_argnums(sig: inspect.Signature, argnums: tuple[int, ...], argnums_name: str) -> None: n_pos_args = 0 for param in sig.parameters.values(): if param.kind in _POSITIONAL_ARGUMENTS: n_pos_args += 1 elif param.kind is inspect.Parameter.VAR_POSITIONAL: # We can have any number of positional arguments return if argnums and (-min(argnums) > n_pos_args or max(argnums) >= n_pos_args): raise ValueError(f"Jitted function has {argnums_name}={argnums}, " f"but only accepts {n_pos_args} positional arguments.") static_argnums, static_argnames = infer_argnums_and_argnames( signature, static_argnums, static_argnames) # Validation _validate_argnums(signature, static_argnums, "static_argnums") return static_argnums ================================================ FILE: flax/nnx/transforms/general.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import functools import typing as tp from flax.nnx import ( extract, graphlib, ) from flax.typing import MISSING, Missing A = tp.TypeVar('A') F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) # ------------------------------- # (split|merge)_inputs # ------------------------------- @tp.overload def split_inputs( *, ctxtag: str = 'split_merge_inputs', ) -> tp.Callable[[F], F]: ... @tp.overload def split_inputs( f: F, *, ctxtag: str = 'split_merge_inputs', ) -> F: ... def split_inputs( f: F | Missing = MISSING, *, ctxtag: str = 'split_merge_inputs', ) -> F | tp.Callable[[F], F]: """Takes in a function that contains graph nodes in the inputs and outputs, and returns a function that replaces the graph nodes with some jax-compatible data structures. Must be used in conjunction with :func:`merge_inputs`. Args: f: The function to be transformed. ctxtag: The context tag to be used for the transformation. Defaults to 'split_merge_inputs'. Returns: The transformed function. ``split_inputs`` and ``merge_inputs`` can be used to lift functions that operate on jax datastructures (pytrees) to functions that operate on graph nodes. ``split_inputs`` will take graph nodes in the inputs and outputs and replace them with jax-compatible data structures, usually before calling into the transformed function, while ``merge_inputs`` will convert the jax-compatible data structures back to graph nodes, usually inside the transformed function. For common transforms like ``jax.jit`` and ``jax.vmap`` NNX will provide a version that works with graph nodes, but for other transforms you can use ``split_inputs`` and ``merge_inputs`` to manually lift the function. The following example demonstrates how to use ``split_inputs`` and ``merge_inputs`` to lift ``jax.jit`` to work over a silly function has a stateful operation that zeros out the kernel of a linear layer:: >>> from flax import nnx >>> import jax.numpy as jnp >>> import jax ... >>> @split_inputs ... @jax.jit ... @merge_inputs ... def forward_and_zero(model: nnx.Linear, x: jax.Array): ... y = model(x) ... model.kernel[...] *= 0 # zero out the kernel ... return y ... >>> model = nnx.Linear(2, 2, rngs=nnx.Rngs(0)) >>> y = forward_and_zero(model, jnp.ones((1, 2))) >>> y.shape (1, 2) >>> assert jnp.allclose(model.kernel, 0) As shown above, not only does the lifted function work with graph nodes, but it also propagates the side effects of the original function. **Note**: in practice use ``nnx.jit`` instead. Splitting and merging can also be applied to multiple functions in a pipeline. The following example show how to lift ``jax.lax.cond`` by using ``split_inputs`` over ``cond`` and ``merge_inputs`` over the branches:: >>> model = nnx.Linear(2, 2, rngs=nnx.Rngs(0)) >>> x = jnp.ones((1, 2)) ... >>> true_fn = lambda m, x: m(x) >>> false_fn = lambda m, x: x + 1 ... >>> y = split_inputs(jax.lax.cond)( ... False, ... merge_inputs(true_fn), ... merge_inputs(false_fn), # <== gets called ... model, ... x, ... ) >>> assert jnp.allclose(y, 2) **Lifting functions with output semantics** ``merge_inputs`` internally returns a ``(inputs, output)`` tuple, where ``inputs`` is the tuple of the input arguments with non-graph node leaves set to ``None``, and ``output`` is the output of the function. This is done to propage all the state changes in the function to the graph nodes outside the function. If the transform function has output semantics like e.g. ``jax.vmap``'s ``out_axes``, you must account for this in the by configuring the arguments accordingly:: >>> from functools import partial ... >>> model = nnx.Linear(2, 2, rngs=nnx.Rngs(0)) ... >>> in_axes = (None, 0) >>> out_axes = (in_axes, 0) # <== internal output arrangement ... >>> @split_inputs ... @partial(jax.vmap, in_axes=in_axes, out_axes=out_axes) ... @merge_inputs ... def forward(model: nnx.Linear, x: jax.Array): ... return model(x) ... >>> x = jnp.ones((10, 2)) >>> y = forward(model, x) >>> y.shape (10, 2) .. note:: If the transform has a rigid output structure like ``jax.grad`` or ``jax.lax.scan`` then ``split_inputs`` and ``merge_inputs`` will not work. In this case, use the `Functional API `__. """ if isinstance(f, Missing): return functools.partial(split_inputs, ctxtag=ctxtag) # type: ignore[return-value] @graphlib.update_context(ctxtag) @functools.wraps(f) def split_inputs_wrapper(*args): pure_args = extract.to_tree(args, ctxtag=ctxtag) pure_args_out, pure_out = f(*pure_args) args_out, out = extract.from_tree( (pure_args_out, pure_out), ctxtag=ctxtag, is_inner=False ) return out return split_inputs_wrapper # type: ignore @tp.overload def merge_inputs( *, ctxtag: str = 'split_merge_inputs', ) -> tp.Callable[[F], F]: ... @tp.overload def merge_inputs( f: F, *, ctxtag: str = 'split_merge_inputs', ) -> F: ... def merge_inputs( f: F | Missing = MISSING, *, ctxtag: str = 'split_merge_inputs', ) -> F | tp.Callable[[F], F]: """Takes in a function that contains jax-compatible data structures in the inputs and outputs, and returns a function that replaces the jax-compatible data structures the corresponding graph nodes. Must be used in conjunction with :func:`split_inputs`. Args: f: The function to be transformed. ctxtag: The context tag to be used for the transformation. Defaults to 'split_merge_inputs'. Returns: The transformed function. For more information and examples, see :func:`split_inputs`. """ if isinstance(f, Missing): return functools.partial(merge_inputs, ctxtag=ctxtag) # type: ignore[return-value] @functools.wraps(f) def merge_inputs_wrapper(*pure_args): args = extract.from_tree(pure_args, ctxtag=ctxtag, is_inner=True) out = f(*args) args_out = extract.clear_non_graph_nodes(args) pure_args_out, pure_out = extract.to_tree((args_out, out), ctxtag=ctxtag) return pure_args_out, pure_out return merge_inputs_wrapper # type: ignore ================================================ FILE: flax/nnx/transforms/iteration.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # pytype: skip-file from collections import deque import dataclasses import functools import typing as tp from flax import struct from flax import typing from flax.core.frozen_dict import FrozenDict from flax.nnx import extract, filterlib, graphlib, spmd, variablelib from flax.nnx import statelib from flax.nnx.module import Module from flax.nnx.statelib import State from flax.nnx.transforms.transforms import ( resolve_kwargs, _resolve_bound_callable, _raise_bound_method_error, ) from flax.typing import Leaf, Missing, PytreeDeque import jax import jax.core import jax.numpy as jnp import jax.stages import numpy as np A = tp.TypeVar('A') C = tp.TypeVar('C') B = tp.TypeVar('B') F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any]) M = tp.TypeVar('M', bound=Module) MA = tp.TypeVar('MA', bound=Module) N = tp.TypeVar('N', bound=Module) T = tp.TypeVar('T') StrInt = tp.TypeVar('StrInt', str, int) AxisName = tp.Hashable Leaves = list[Leaf] Index = int class Carry: """Helper class for :func:`flax.nnx.scan` function to mark input and output axis as carry. """ pass # ------------------------------- # transform_metadata # ------------------------------- def _apply_axis_fn( tree: tp.Any, axes: tp.Any, metadata: tp.Mapping[str, tp.Any], axis_fn: tp.Callable[..., tp.Any], ) -> None: is_leaf = lambda x: x is None or isinstance(x, variablelib.Variable) _, per_leaf_axes = extract.broadcast_prefix2(axes, tree, is_leaf=is_leaf) leaves = jax.tree_util.tree_leaves(tree, is_leaf=is_leaf) for leaf, axis in zip(leaves, per_leaf_axes): if (axis is None or isinstance(axis, int)) and isinstance( leaf, variablelib.Variable ): axis_fn(leaf, axis, metadata) @tp.overload def transform_metadata( *, in_axes: tp.Any = 0, out_axes: tp.Any = 0, partition: str, graph: bool | None = None, ) -> tp.Callable[[F], F]: ... @tp.overload def transform_metadata( f: F, *, in_axes: tp.Any = 0, out_axes: tp.Any = 0, graph: bool | None = None, partition: str, ) -> F: ... def transform_metadata( f: F | type[Missing] = Missing, *, in_axes: tp.Any = 0, out_axes: tp.Any = 0, graph: bool | None = None, partition: str, ) -> F | tp.Callable[[F], F]: if f is Missing: return functools.partial( transform_metadata, in_axes=in_axes, out_axes=out_axes, partition=partition, graph=graph, ) # type: ignore[return-value] if graph is None: graph = graphlib.set_graph_mode.current_value() metadata: tp.Mapping[str, tp.Any] = { spmd.PARTITION_NAME: partition, } if graph: extract.check_prefix(in_axes, 'in_axes', 'transform_metadata') extract.check_prefix(out_axes, 'out_axes', 'transform_metadata') @functools.wraps(f) def wrapper(*in_args, **in_kwargs): in_args = resolve_kwargs(f, in_args, in_kwargs) if graph: in_args = extract.to_tree2(in_args, prefix=in_axes) extract.check_no_aliases('transform_metadata', args=in_args) args = graphlib.clone(in_args, graph=graph) _apply_axis_fn(args, in_axes, metadata, spmd.remove_axis) updates, snapshot = extract.updates_and_snapshot(args) if graph: args = extract.from_tree2(args) out = f(*args) if graph: out = extract.to_tree2(out, prefix=out_axes) extract.check_no_aliases('transform_metadata', args=updates, out=out) _apply_axis_fn(args, in_axes, metadata, spmd.add_axis) _apply_axis_fn(out, out_axes, metadata, spmd.add_axis) updates = extract.mask_variable_updates(updates, snapshot) extract.apply_variable_updates(in_args, updates) if graph: out = extract.from_tree2(out) return out return wrapper # type: ignore[return-value] # ------------------------------- # vmap # ------------------------------- class StateAxes(extract.PrefixMapping, tp.Mapping): def __init__( self, filter_axes: ( statelib.State | tp.Mapping[filterlib.Filter, Index | type[Carry] | None] | tp.Iterable[tuple[filterlib.Filter, Index | type[Carry] | None]] ), /, ): if isinstance(filter_axes, statelib.State): filter_axes = statelib.create_path_filters(filter_axes) # type: ignore iterable = tuple( filter_axes.items() if isinstance(filter_axes, tp.Mapping) else filter_axes ) self._filters = tuple(filter for filter, _ in iterable) self._axes = tuple(axis for _, axis in iterable) @property def filters(self) -> tuple[filterlib.Filter, ...]: return self._filters @property def axes(self) -> tuple[Index | type[Carry] | None, ...]: return self._axes def map_prefix( self, path: typing.PathParts, variable: variablelib.Variable ) -> tp.Any: for filter, axis in zip(self.filters, self.axes): predicate = filterlib.to_predicate(filter) if predicate(path, variable): return axis raise ValueError(f'No axis found for {path=}, {variable=}') def __repr__(self): return f'StateAxes({dict(self.items())})' def items(self): return zip(self.filters, self.axes) def __getitem__(self, key): return self.axes[self.filters.index(key)] def __iter__(self): return iter(self.filters) def __len__(self): return len(self.filters) def __eq__(self, other): return ( isinstance(other, StateAxes) and self.filters == other.filters and self.axes == other.axes ) def __hash__(self): return hash((self.filters, self.axes)) AxisFn = tp.Callable[ [graphlib.GraphState | variablelib.Variable, int, tp.Mapping], graphlib.GraphState | variablelib.Variable, ] def _update_variable_sharding_metadata( tree, transform_metadata, axis_fn: AxisFn ): def _update_axes_fn(node_states): if isinstance(node_states, extract.NodeStates) and isinstance( node_states.metadata, (StateAxes, int) ): if isinstance(node_states.metadata, int): state = node_states.state assert isinstance(state, State | variablelib.Variable) state = axis_fn(state, node_states.metadata, transform_metadata) return node_states.replace(states=(state,)) else: states_out: list[graphlib.GraphState | variablelib.Variable] = [] for state, axis in zip(node_states.states, node_states.metadata.axes): assert isinstance(state, graphlib.State | variablelib.Variable) if isinstance(axis, int): state = axis_fn(state, axis, transform_metadata) states_out.append(state) return node_states.replace(states=tuple(states_out)) return node_states return jax.tree.map( _update_axes_fn, tree, is_leaf=lambda x: isinstance(x, extract.NodeStates) ) def _vmap_split_fn(ctx: graphlib.SplitContext, path, prefix, x): if isinstance(prefix, StateAxes): return extract.NodeStates.from_split( *ctx.split(x, *prefix.filters), metadata=prefix ) return extract.NodeStates.from_split(*ctx.split(x), metadata=prefix) @dataclasses.dataclass(eq=False) class SimpleVmapFn: f: tp.Callable[..., tp.Any] graph: bool out_axes: tp.Any def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args, **kwargs): updates, snapshot = extract.updates_and_snapshot((args, kwargs)) if self.graph: args, kwargs = extract.from_tree2((args, kwargs)) out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out, prefix=self.out_axes) extract.check_no_aliases('vmap', args=updates[0], kwargs=updates[1], out=out) updates = extract.mask_variable_updates(updates, snapshot) return out, updates @dataclasses.dataclass(eq=False) class SimplePmapFn: f: tp.Callable[..., tp.Any] graph: bool out_axes: tp.Any def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args, **kwargs): updates, snapshot = extract.updates_and_snapshot((args, kwargs)) if self.graph: args, kwargs = extract.from_tree2((args, kwargs)) out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out, prefix=self.out_axes) extract.check_no_aliases('pmap', args=updates[0], kwargs=updates[1], out=out) updates = extract.mask_variable_updates(updates, snapshot) return out, updates @dataclasses.dataclass(eq=False) class VmapFn: f: tp.Callable[..., tp.Any] transform_metadata: tp.Mapping[str, tp.Any] in_axes: tp.Any out_axes: tp.Any def __post_init__(self): functools.update_wrapper(self, self.f) def __call__(self, *pure_args: tuple[tp.Any, ...]): if spmd.PARTITION_NAME in self.transform_metadata: pure_args = _update_variable_sharding_metadata( pure_args, self.transform_metadata, spmd.remove_axis ) args = extract.from_tree(pure_args, ctxtag='vmap', is_inner=True) out = self.f(*args) args_out = extract.clear_non_graph_nodes(args) pure_args_out, pure_out = extract.to_tree( (args_out, out), prefix=(self.in_axes, self.out_axes), split_fn=_vmap_split_fn, ctxtag='vmap', ) if spmd.PARTITION_NAME in self.transform_metadata: pure_args_out, pure_out = _update_variable_sharding_metadata( (pure_args_out, pure_out), self.transform_metadata, spmd.add_axis ) return pure_args_out, pure_out @tp.overload def vmap( *, in_axes: int | None | tp.Sequence[tp.Any] = 0, out_axes: tp.Any = 0, axis_name: AxisName | None = None, axis_size: int | None = None, spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None, # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, graph_updates: bool | None = None, ) -> tp.Callable[[F], F]: ... @tp.overload def vmap( f: F, *, in_axes: int | None | tp.Sequence[tp.Any] = 0, out_axes: tp.Any = 0, axis_name: AxisName | None = None, axis_size: int | None = None, spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None, # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, graph_updates: bool | None = None, ) -> F: ... def vmap( f: F | type[Missing] = Missing, *, in_axes: int | None | tp.Sequence[tp.Any] = 0, out_axes: tp.Any = 0, axis_name: AxisName | None = None, axis_size: int | None = None, spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None, # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, graph_updates: bool | None = None, ) -> F | tp.Callable[[F], F]: """Reference-aware version of `jax.vmap `__. Args: f: Function to be mapped over additional axes. in_axes: An integer, None, or sequence of values specifying which input array axes to map over (see `jax.vmap `__). In addition to integers and None, :class:`StateAxes` can be used to control how graph nodes like Modules are vectorized by specifying the axes to be applied to substates of the graph node given a `Filter `__. out_axes: An integer, None, or pytree indicating where the mapped axis should appear in the output (see `jax.vmap `__). axis_name: Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied. axis_size: Optional, an integer indicating the size of the axis to be mapped. If not provided, the mapped axis size is inferred from arguments. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references and reference semantics. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Tree-mode does not support ``StateAxes`` or shared ``Variable`` references. graph_updates: If ``True``, propagates updates on graph structure that happen inside the transform to the input graphs, has no effect when ``graph=False``. When ``False``, using ``StateAxes`` is not supported. Returns: Batched/vectorized version of ``f`` with arguments that correspond to those of ``f``, but with extra array axes at positions indicated by ``in_axes``, and a return value that corresponds to that of ``f``, but with extra array axes at positions indicated by ``out_axes``. Example:: >>> from flax import nnx >>> from jax import random, numpy as jnp ... >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> x = jnp.ones((5, 2)) ... >>> @nnx.vmap(in_axes=(None, 0), out_axes=0) ... def forward(model, x): ... return model(x) ... >>> y = forward(model, x) >>> y.shape (5, 3) >>> class LinearEnsemble(nnx.Module): ... def __init__(self, num, rngs): ... self.w = nnx.Param(jax.random.uniform(rngs(), (num, 2, 3))) ... >>> model = LinearEnsemble(5, rngs=nnx.Rngs(0)) >>> x = jnp.ones((2,)) ... >>> @nnx.vmap(in_axes=(0, None), out_axes=0) ... def forward(model, x): ... return x @ model.w ... >>> y = forward(model, x) >>> y.shape (5, 3) To control control how graph node substates are vectorized, ``StateAxes`` can be passed to ``in_axes`` and ``out_axes`` specifying the axes to be applied to each substate given a filter. The following example shows how to share the parameters between the ensemble members which keeping different batch statistics and dropout random state:: >>> class Foo(nnx.Module): ... def __init__(self): ... self.a = nnx.Param(jnp.arange(4)) ... self.b = nnx.BatchStat(jnp.arange(4)) ... >>> state_axes = nnx.StateAxes({nnx.Param: 0, nnx.BatchStat: None}) >>> @nnx.vmap(in_axes=(state_axes,), out_axes=0) ... def mul(foo): ... return foo.a * foo.b ... >>> foo = Foo() >>> y = mul(foo) >>> y Array([[0, 0, 0, 0], [0, 1, 2, 3], [0, 2, 4, 6], [0, 3, 6, 9]], dtype=int32) """ if graph is None: graph = graphlib.set_graph_mode.current_value() if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() if f is Missing: return functools.partial( vmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name, axis_size=axis_size, spmd_axis_name=spmd_axis_name, transform_metadata=transform_metadata, graph=graph, graph_updates=graph_updates, ) # type: ignore[return-value] f_unbound, _, was_bound = _resolve_bound_callable(f) if was_bound: _raise_bound_method_error('vmap') if not graph or not graph_updates: if any(isinstance(x, StateAxes) for x in jax.tree.leaves(in_axes)): raise ValueError( '`in_axes` cannot contain `StateAxes` objects ' 'when `graph=False`. ' + graphlib._tree_mode_suggestion_transform('vmap') ) if any(isinstance(x, StateAxes) for x in jax.tree.leaves(out_axes)): raise ValueError( '`out_axes` cannot contain `StateAxes` objects ' 'when `graph=False`. ' + graphlib._tree_mode_suggestion_transform('vmap') ) if graph: extract.check_prefix(in_axes, 'in_axes', 'vmap') extract.check_prefix(out_axes, 'out_axes', 'vmap') vmapped_fn = jax.vmap( SimpleVmapFn(f_unbound, graph=graph, out_axes=out_axes), in_axes=in_axes, out_axes=(out_axes, (in_axes, 0)), axis_name=axis_name, axis_size=axis_size, spmd_axis_name=spmd_axis_name, ) @functools.wraps(f_unbound) def simple_vmap_wrapper(*args, **kwargs): if graph: args, kwargs = extract.to_tree2( (args, kwargs), prefix=(in_axes, None) if in_axes is not None else None, check_aliasing=in_axes is not None, ) extract.check_no_aliases('vmap', args=args, kwargs=kwargs) out, updates = vmapped_fn(*args, **kwargs) extract.apply_variable_updates((args, kwargs), updates) if graph: out = extract.from_tree2(out) return out return simple_vmap_wrapper # type: ignore[return-value] jax_in_axes = jax.tree.map( lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x) if isinstance(x, StateAxes) else x, in_axes, ) jax_out_axes = jax.tree.map( lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x) if isinstance(x, StateAxes) else x, out_axes, ) vmapped_fn = jax.vmap( # type: ignore[assignment] VmapFn(f_unbound, transform_metadata, in_axes, out_axes), in_axes=jax_in_axes, out_axes=(jax_in_axes, jax_out_axes), axis_name=axis_name, axis_size=axis_size, spmd_axis_name=spmd_axis_name, ) @functools.wraps(f) @graphlib.update_context('vmap') def vmap_wrapper(*args, **kwargs): args = resolve_kwargs(f, args, kwargs) pure_args = extract.to_tree( args, prefix=in_axes, split_fn=_vmap_split_fn, ctxtag='vmap' ) pure_args_out, pure_out = vmapped_fn(*pure_args) _args_out, out = extract.from_tree( (pure_args_out, pure_out), ctxtag='vmap', is_inner=False ) return out return vmap_wrapper # type: ignore # ------------------------------- # pmap # ------------------------------- @dataclasses.dataclass(eq=False) class PmapFn: f: tp.Callable[..., tp.Any] transform_metadata: tp.Mapping[str, tp.Any] in_axes: tp.Any out_axes: tp.Any def __post_init__(self): functools.update_wrapper(self, self.f) def __call__(self, *pure_args: tuple[tp.Any, ...]): if spmd.PARTITION_NAME in self.transform_metadata: pure_args = _update_variable_sharding_metadata( pure_args, self.transform_metadata, spmd.remove_axis ) args = extract.from_tree(pure_args, ctxtag='pmap', is_inner=True) out = self.f(*args) args_out = extract.clear_non_graph_nodes(args) pure_args_out, pure_out = extract.to_tree( (args_out, out), prefix=(self.in_axes, self.out_axes), split_fn=_vmap_split_fn, ctxtag='pmap', ) if spmd.PARTITION_NAME in self.transform_metadata: pure_args_out, pure_out = _update_variable_sharding_metadata( (pure_args_out, pure_out), self.transform_metadata, spmd.add_axis ) return pure_args_out, pure_out @tp.overload def pmap( *, axis_name: AxisName | None = None, in_axes: tp.Any = 0, out_axes: tp.Any = 0, static_broadcasted_argnums: int | tp.Iterable[int] = (), devices: tp.Sequence[jax.Device] | None = None, # noqa: F811 backend: str | None = None, axis_size: int | None = None, donate_argnums: int | tp.Iterable[int] = (), # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, graph_updates: bool | None = None, ) -> tp.Callable[[F], F]: ... @tp.overload def pmap( f: F, *, axis_name: AxisName | None = None, in_axes: tp.Any = 0, out_axes: tp.Any = 0, static_broadcasted_argnums: int | tp.Iterable[int] = (), devices: tp.Sequence[jax.Device] | None = None, # noqa: F811 backend: str | None = None, axis_size: int | None = None, donate_argnums: int | tp.Iterable[int] = (), # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, graph_updates: bool | None = None, ) -> F: ... def pmap( f: F | type[Missing] = Missing, *, axis_name: AxisName | None = None, in_axes: tp.Any = 0, out_axes: tp.Any = 0, static_broadcasted_argnums: int | tp.Iterable[int] = (), devices: tp.Sequence[jax.Device] | None = None, # noqa: F811 backend: str | None = None, axis_size: int | None = None, donate_argnums: int | tp.Iterable[int] = (), # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, graph_updates: bool | None = None, ) -> F | tp.Callable[[F], F]: """Reference-aware version of `jax.pmap `__. Args: f: Function to be mapped over argument axes. Its arguments and return value should be arrays, scalars, graph nodes, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by ``static_broadcasted_argnums`` can be anything at all, provided they are hashable and have an equality operation defined. axis_name: Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied. in_axes: A non-negative integer, None, or nested Python container thereof that specifies which axes of positional arguments to map over. Arguments passed as keywords are always mapped over their leading axis (i.e. axis index 0). In addition to integers and None, :class:`StateAxes` can be used to control how graph nodes like Modules are vectorized by specifying the axes to be applied to substates of the graph node given a `Filter `__. out_axes: A non-negative integer, None, or nested Python container thereof indicating where the mapped axis should appear in the output. All outputs with a mapped axis must have a non-None ``out_axes`` specification. static_broadcasted_argnums: An int or collection of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded. Calling the pmapped function with different values for these constants will trigger recompilation. If the pmapped function is called with fewer positional arguments than indicated by ``static_broadcasted_argnums`` then an error is raised. Each of the static arguments will be broadcasted to all devices. Arguments that are not arrays or containers thereof must be marked as static. Defaults to (). Static arguments must be hashable, meaning both ``__hash__`` and ``__eq__`` are implemented, and should be immutable. devices: This is an experimental feature and the API is likely to change. Optional, a sequence of Devices to map over. (Available devices can be retrieved via jax.devices()). Must be given identically for each process in multi-process settings (and will therefore include devices across processes). If specified, the size of the mapped axis must be equal to the number of devices in the sequence local to the given process. Nested ``pmap`` s with ``devices`` specified in either the inner or outer ``pmap`` are not yet supported. backend: This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend. 'cpu', 'gpu', or 'tpu'. axis_size: Optional; the size of the mapped axis. donate_argnums: Specify which positional argument buffers are "donated" to the computation. It is safe to donate argument buffers if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation, JAX will raise an error if you try to. Note that donate_argnums only work for positional arguments, and keyword arguments will not be donated. transform_metadata: Optional mapping of metadata for the transform. graph: if True, use graph-mode (default). If False, use tree-mode. If None, uses the value of ``nnx_graph_mode`` config. graph_updates: If ``True``, propagates updates on graph structure that happen inside the transform to the input graphs, has no effect when ``graph=False``. When ``False``, using ``StateAxes`` is not supported. Returns: A parallelized version of ``f`` with arguments that correspond to those of ``f`` but with extra array axes at positions indicated by ``in_axes`` and with output that has an additional leading array axis (with the same size). """ if graph is None: graph = graphlib.set_graph_mode.current_value() if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() if f is Missing: return functools.partial( pmap, axis_name=axis_name, in_axes=in_axes, out_axes=out_axes, static_broadcasted_argnums=static_broadcasted_argnums, devices=devices, backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, transform_metadata=transform_metadata, graph=graph, graph_updates=graph_updates, ) # type: ignore[return-value] f_unbound, _, was_bound = _resolve_bound_callable(f) if was_bound: _raise_bound_method_error('pmap') if not graph or not graph_updates: if any(isinstance(x, StateAxes) for x in jax.tree.leaves(in_axes)): raise ValueError( '`in_axes` cannot contain `StateAxes` objects ' 'when `graph=False`. ' + graphlib._tree_mode_suggestion_transform('pmap') ) if any(isinstance(x, StateAxes) for x in jax.tree.leaves(out_axes)): raise ValueError( '`out_axes` cannot contain `StateAxes` objects ' 'when `graph=False`. ' + graphlib._tree_mode_suggestion_transform('pmap') ) if graph: extract.check_prefix(in_axes, 'in_axes', 'pmap') extract.check_prefix(out_axes, 'out_axes', 'pmap') pmapped_fn = jax.pmap( SimplePmapFn(f_unbound, graph=graph, out_axes=out_axes), axis_name=axis_name, in_axes=in_axes, out_axes=(out_axes, (in_axes, 0)), static_broadcasted_argnums=static_broadcasted_argnums, devices=devices, backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, ) @functools.wraps(f_unbound) def simple_pmap_wrapper(*args, **kwargs): if graph: args, kwargs = extract.to_tree2( (args, kwargs), prefix=(in_axes, None) if in_axes is not None else None, check_aliasing=in_axes is not None, ) extract.check_no_aliases('pmap', args=args, kwargs=kwargs) out, updates = pmapped_fn(*args, **kwargs) extract.apply_variable_updates((args, kwargs), updates) if graph: out = extract.from_tree2(out) return out return simple_pmap_wrapper # type: ignore[return-value] jax_in_axes = jax.tree.map( lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x) if isinstance(x, StateAxes) else x, in_axes, ) jax_out_axes = jax.tree.map( lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x) if isinstance(x, StateAxes) else x, out_axes, ) pmapped_fn = jax.pmap( PmapFn(f_unbound, transform_metadata, in_axes, out_axes), axis_name=axis_name, in_axes=jax_in_axes, out_axes=(jax_in_axes, jax_out_axes), static_broadcasted_argnums=static_broadcasted_argnums, devices=devices, backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, ) @functools.wraps(f) @graphlib.update_context('pmap') def vmap_wrapper(*args): pure_args = extract.to_tree( args, prefix=in_axes, split_fn=_vmap_split_fn, ctxtag='pmap' ) pure_args_out, pure_out = pmapped_fn(*pure_args) _args_out, out = extract.from_tree( (pure_args_out, pure_out), ctxtag='pmap', is_inner=False ) return out return vmap_wrapper # type: ignore # ------------------------------- # scan # ------------------------------- class Broadcasted(struct.PyTreeNode): data: tp.Any def _get_carry_argnum(axes, is_in_axes: bool): if axes is Carry: return 'all' elif isinstance(axes, int) or axes is None: return None obj_repr = 'in_axes' if is_in_axes else 'out_axes' carry_argnum: int | None = None prev_key: tp.Any = None for key, x in jax.tree_util.tree_leaves_with_path(axes): if x is not Carry: continue assert isinstance(key[0], jax.tree_util.SequenceKey) i = key[0].idx if len(key) >= 2: raise ValueError( f'Carry must at the top-level, it cannot be nested. Found {axes=}' ) if carry_argnum is not None: raise ValueError( f'Found multiple Carry definitions at ' f'{obj_repr}{jax.tree_util.keystr(prev_key)} and ' f'{obj_repr}{jax.tree_util.keystr(key)}' ) carry_argnum = i prev_key = key return carry_argnum def _check_out_axes(out_axes): for key, x in jax.tree_util.tree_leaves_with_path( out_axes, is_leaf=lambda x: x is None ): if x is None: raise ValueError( f'Cannot broadcast output state. ' f'Got out_axes=None at: out_axes{jax.tree_util.keystr(key)}' ) elif isinstance(x, StateAxes): for filter, value in x.items(): if value is None: raise ValueError( f'Cannot broadcast output state. ' f'Got StateAxes({{{filter}: None}}) at: out_axes' f'{jax.tree_util.keystr(key)}' ) elif value is Carry: raise ValueError( f'Cannot carry output state. ' f'Got StateAxes({{{filter}: Carry}}) at: out_axes' f'{jax.tree_util.keystr(key)}' ) def _check_carry_same_references(carry_arg, carry_arg_out): def check_carry_same_references(key_path, arg, out): if ( not isinstance(arg, jax.Array) or not isinstance(out, jax.Array) ) and arg is not out: raise ValueError( 'Carry references must be the same between iterations. ' f'Got {arg=} with id={id(arg)} and {out=} with id={id(out)} ' f'at carry{jax.tree_util.keystr(key_path)}' ) jax.tree_util.tree_map_with_path( check_carry_same_references, carry_arg, carry_arg_out, is_leaf=lambda x: graphlib.is_graph_node(x) and not isinstance(x, variablelib.Variable), ) def _scan_split_in( carry_deque: PytreeDeque[list[State | variablelib.Variable]], graphdefs_deque: PytreeDeque[graphlib.GraphDef], broadcast_deque: PytreeDeque[list[State | variablelib.Variable]], broadcast_arrays: PytreeDeque[Broadcasted], /, ctx: graphlib.SplitContext, path, prefix, x, ): if graphlib.is_graph_node(x) or isinstance(x, variablelib.Variable): vectorized_states: list[State | variablelib.Variable] = [] carry_states: list[State | variablelib.Variable] = [] broadcast_states: list[State | variablelib.Variable] = [] if isinstance(prefix, StateAxes): graphdef, *states = ctx.split(x, *prefix.filters) for state, axis in zip(states, prefix.axes): if axis is None: broadcast_states.append(state) elif isinstance(axis, int): if axis != 0: state = jax.tree.map(lambda x: jnp.moveaxis(x, axis, 0), state) vectorized_states.append(state) else: # axis is Carry carry_states.append(state) if not vectorized_states: vectorized_states.append(State({})) carry_deque.append(carry_states) graphdefs_deque.append(graphdef) broadcast_deque.append(broadcast_states) return extract.NodeStates.from_split( None, *vectorized_states, metadata=prefix ) elif isinstance(prefix, int): graphdef, state = ctx.split(x) if prefix != 0: state = jax.tree.map(lambda x: jnp.moveaxis(x, prefix, 0), state) vectorized_states.append(state) elif prefix is None: graphdef, state = ctx.split(x) broadcast_states.append(state) vectorized_states.append(State({})) elif prefix is Carry: graphdef, state = ctx.split(x) carry_states.append(state) vectorized_states.append(State({})) else: raise ValueError( f'Invalid axes {prefix} args{jax.tree_util.keystr(path)}' ) if not vectorized_states: vectorized_states.append(State({})) carry_deque.append(carry_states) graphdefs_deque.append(graphdef) broadcast_deque.append(broadcast_states) return extract.NodeStates.from_split( None, *vectorized_states, metadata=prefix ) else: if isinstance(prefix, StateAxes): raise ValueError( 'Cannot use StateAxes on non-graph nodes, ' f'found {prefix} args{jax.tree_util.keystr(path)}' ) elif prefix is Carry: return x elif prefix is None: broadcast_arrays.append(Broadcasted(x)) return Broadcasted(None) elif isinstance(prefix, int): if not isinstance(x, (jax.Array, np.ndarray)): raise ValueError( f'Expected an array, got {type(x).__name__} args' f'{jax.tree_util.keystr(path)}' ) if prefix != 0: x = jnp.moveaxis(x, prefix, 0) return x else: raise ValueError( f'Invalid axes {prefix} args{jax.tree_util.keystr(path)}' ) def _scan_split_out( carry_deque: PytreeDeque[list[State | variablelib.Variable]], graphdefs_deque: PytreeDeque[graphlib.GraphDef], broadcast_deque: PytreeDeque[list[State | variablelib.Variable]], /, ctx: graphlib.SplitContext, path: extract.KeyPath, prefix, x, ): assert isinstance(path[0], jax.tree_util.SequenceKey) is_input_arg = path[0].idx == 0 if graphlib.is_graph_node(x) or isinstance(x, variablelib.Variable): vectorized_states: list[State | variablelib.Variable] = [] carry_states: list[State | variablelib.Variable] = [] broadcast_states: list[State | variablelib.Variable] = [] if isinstance(prefix, StateAxes): graphdef, *states = ctx.split(x, *prefix.filters) for state, filter, axis in zip(states, prefix.filters, prefix.axes): if axis is None: assert is_input_arg # validated by _check_out_axes broadcast_states.append(state) elif isinstance(axis, int): vectorized_states.append(state) elif axis is Carry: assert is_input_arg # validated by _check_out_axes carry_states.append(state) else: obj_repr = 'args' if is_input_arg else 'out' raise ValueError( f'Invalid axes {axis} for filter {filter} at ' f'{obj_repr}{jax.tree_util.keystr(path)}' ) if not vectorized_states: vectorized_states.append(State({})) if is_input_arg: carry_deque.append(carry_states) broadcast_deque.append(broadcast_states) graphdefs_deque.append(graphdef) return extract.NodeStates.from_split( None, *vectorized_states, metadata=prefix ) elif isinstance(prefix, int): graphdef, state = ctx.split(x) vectorized_states.append(state) elif prefix is None: assert is_input_arg # validated by _check_out_axes graphdef, state = ctx.split(x) broadcast_states.append(state) vectorized_states.append(State({})) elif prefix is Carry: assert is_input_arg # validated by _check_out_axes graphdef, state = ctx.split(x) carry_states.append(state) vectorized_states.append(State({})) else: obj_repr = 'args' if is_input_arg else 'out' raise ValueError( f'Invalid axes {prefix} at {obj_repr}{jax.tree_util.keystr(path)}' ) if not vectorized_states: vectorized_states.append(State({})) if is_input_arg: carry_deque.append(carry_states) broadcast_deque.append(broadcast_states) graphdefs_deque.append(graphdef) return extract.NodeStates.from_split( None, *vectorized_states, metadata=prefix ) else: if isinstance(prefix, StateAxes): obj_repr = 'args' if is_input_arg else 'out' raise ValueError( 'Cannot use StateAxes on non-graph nodes, ' f'found {prefix} at {obj_repr}{jax.tree_util.keystr(path)}' ) elif prefix is Carry: return x elif prefix is None: assert not is_input_arg # validated by _check_out_axes return Broadcasted(None) elif isinstance(prefix, int): return x else: obj_repr = 'args' if is_input_arg else 'out' raise ValueError( f'Invalid axes {prefix} at {obj_repr}{jax.tree_util.keystr(path)}' ) def _scan_merge_in( carry_deque: PytreeDeque[list[State]], graphdefs_deque: PytreeDeque[graphlib.GraphDef], broadcast_deque: PytreeDeque[list[State]], broadcast_arrays: PytreeDeque[Broadcasted], /, ctx: graphlib.MergeContext, path, prefix, x, ): if isinstance(x, extract.NodeStates): carry_states = carry_deque.popleft() broadcast_states = broadcast_deque.popleft() graphdef = graphdefs_deque.popleft() return ctx.merge(graphdef, *x.states, *carry_states, *broadcast_states) elif isinstance(x, Broadcasted): assert x.data is None return broadcast_arrays.popleft().data else: return x def _scan_merge_out( carry_deque: PytreeDeque[list[State]], graphdefs_deque: PytreeDeque[graphlib.GraphDef], broadcast_deque: PytreeDeque[list[State]], /, ctx: graphlib.MergeContext, path, prefix, x, ): assert isinstance(path[0], jax.tree_util.SequenceKey) is_input_arg = path[0].idx == 0 if isinstance(x, extract.NodeStates): states: list[State] = [] graphdef = graphdefs_deque.popleft() if is_input_arg: carry_states = deque(carry_deque.popleft()) broadcast_states = deque(broadcast_deque.popleft()) else: carry_states = deque[State]() broadcast_states = deque[State]() if isinstance(prefix, StateAxes): vectorized_states = deque(x.states) for axis in prefix.axes: if isinstance(axis, int): state = vectorized_states.popleft() state = jax.tree.map( lambda x: jnp.moveaxis(x, 0, axis) if axis != 0 else x, state, ) states.append(state) elif axis is None: states.append(broadcast_states.popleft()) else: # axis is Carry states.append(carry_states.popleft()) assert not carry_states and not broadcast_states assert not vectorized_states or ( len(vectorized_states) == 1 and not vectorized_states[0] ) elif isinstance(prefix, int): state = jax.tree.map( lambda x: jnp.moveaxis(x, 0, prefix) if prefix != 0 else x, x.state ) states.extend((state, *carry_states, *broadcast_states)) elif prefix is None: assert is_input_arg states.extend(broadcast_states) elif prefix is Carry: assert is_input_arg states.extend(carry_states) else: obj_repr = 'args' if is_input_arg else 'out' raise ValueError( f'Invalid axes {prefix} at {obj_repr}{jax.tree_util.keystr(path)}' ) return ctx.merge(graphdef, *states) else: if isinstance(prefix, StateAxes): obj_repr = 'args' if is_input_arg else 'out' raise ValueError( 'Cannot use StateAxes on non-graph nodes, ' f'found {prefix} at {obj_repr}{jax.tree_util.keystr(path)}' ) elif prefix is Carry: return x elif prefix is None: return x elif isinstance(prefix, int): if not isinstance(x, (jax.Array, np.ndarray)): obj_repr = 'args' if is_input_arg else 'out' raise ValueError( f'Expected an array, got {type(x).__name__} at ' f'{obj_repr}{jax.tree_util.keystr(path)}' ) if prefix != 0: x = jnp.moveaxis(x, 0, prefix) return x else: obj_repr = 'args' if is_input_arg else 'out' raise ValueError( f'Invalid axes {prefix} at {obj_repr}{jax.tree_util.keystr(path)}' ) @dataclasses.dataclass(eq=False) class ScanFn: f: tp.Callable[..., tp.Any] input_carry_argnum: int | None | tp.Literal['all'] output_carry_argnum: int | None | tp.Literal['all'] in_axes: tp.Any out_axes: tp.Any transform_metadata: tp.Mapping[str, tp.Any] def __post_init__(self): functools.update_wrapper(self, self.f) def __call__( self, carry: tuple[ tp.Any, # carry_arg PytreeDeque[list[State]], # carry_deque PytreeDeque[list[State]], # broadcast_deque PytreeDeque[Broadcasted], # broadcast_arrays ], scan_in: tuple[tp.Any, ...], ): pure_carry_arg, carry_deque, broadcast_deque, broadcast_arrays = carry broadcast_deque_out = PytreeDeque(broadcast_deque) broadcast_arrays_out = PytreeDeque(broadcast_arrays) graphdefs_deque, pure_args = scan_in if self.input_carry_argnum == 'all': assert pure_args == () pure_args = (pure_carry_arg,) elif isinstance(self.input_carry_argnum, int): assert pure_args[self.input_carry_argnum] is None pure_args = extract.replace_at(pure_args, self.input_carry_argnum, pure_carry_arg) else: assert self.input_carry_argnum is None assert pure_carry_arg is None if spmd.PARTITION_NAME in self.transform_metadata: pure_args = _update_variable_sharding_metadata( pure_args, self.transform_metadata, spmd.remove_axis ) args: tuple = extract.from_tree( pure_args, prefix=self.in_axes, merge_fn=functools.partial( _scan_merge_in, carry_deque, graphdefs_deque, broadcast_deque, broadcast_arrays ), is_leaf=lambda x: isinstance(x, (extract.NodeStates, Broadcasted)), map_non_graph_nodes=True, ctxtag='scan', is_inner=True, ) assert not carry_deque and not broadcast_deque and not broadcast_arrays out = self.f(*args) # extract the carry from the args if self.input_carry_argnum == 'all': carry_arg = args[0] elif isinstance(self.input_carry_argnum, int): carry_arg = args[self.input_carry_argnum] else: assert self.input_carry_argnum is None carry_arg = None # extract the carry from the output if self.output_carry_argnum == 'all': carry_arg_out = out out = None elif isinstance(self.output_carry_argnum, int): assert isinstance(out, tuple) carry_arg_out = out[self.output_carry_argnum] out = extract.replace_at(out, self.output_carry_argnum, None) else: assert self.output_carry_argnum is None carry_arg_out = None # TODO(cgarciae): allowing new references might lead to inconsistencies with # scan's looping semantics and we would also need to propagate the input _check_carry_same_references(carry_arg, carry_arg_out) args_out: tuple = extract.clear_non_graph_nodes(args) # replace the carry from the input args with the carry from the output if self.input_carry_argnum == 'all': args_out = (carry_arg_out,) elif isinstance(self.input_carry_argnum, int): args_out = extract.replace_at(args_out, self.input_carry_argnum, carry_arg_out) else: assert self.input_carry_argnum is None assert carry_arg_out is None carry_deque_out = PytreeDeque[list[State | variablelib.Variable]]() graphdefs_out = PytreeDeque[graphlib.GraphDef]() _broadcast_deque_out_tmp = PytreeDeque[ list[State | variablelib.Variable] ]() # discarded pure_args_out: tuple pure_args_out, pure_out = extract.to_tree( (args_out, out), prefix=(self.in_axes, self.out_axes), split_fn=functools.partial( _scan_split_out, carry_deque_out, graphdefs_out, _broadcast_deque_out_tmp ), map_non_graph_nodes=True, ctxtag='scan', ) if spmd.PARTITION_NAME in self.transform_metadata: pure_args_out, pure_out = _update_variable_sharding_metadata( (pure_args_out, pure_out), self.transform_metadata, spmd.add_axis, ) # extract the pure carry from the pure args if self.input_carry_argnum == 'all': pure_carry_arg_out = pure_args_out[0] pure_args_out = () elif isinstance(self.input_carry_argnum, int): pure_carry_arg_out = pure_args_out[self.input_carry_argnum] pure_args_out = extract.replace_at(pure_args_out, self.input_carry_argnum, None) else: assert self.input_carry_argnum is None pure_carry_arg_out = None carry_arg_out = ( pure_carry_arg_out, carry_deque_out, broadcast_deque_out, broadcast_arrays_out, ) scan_out = ( graphdefs_out, pure_args_out, pure_out, ) return carry_arg_out, scan_out @dataclasses.dataclass(eq=False) class SimpleScanFn: f: tp.Callable[..., tp.Any] graph: bool in_axes: tp.Any out_axes: tp.Any out_is_tuple: bool carry_arg_index: int | None carry_out_index: int | None def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args): updates, snapshot = extract.updates_and_snapshot(args) updates = extract.mask_at(updates, self.carry_arg_index) snapshot = extract.mask_at(snapshot, self.carry_arg_index) if self.carry_arg_index is not None: carry_in = args[self.carry_arg_index] else: carry_in = None if self.graph: args = extract.from_tree2(args) out = self.f(*args) if self.graph: out = extract.to_tree2(out, prefix=self.out_axes) if self.carry_out_index is not None: carry_out = out[self.carry_out_index] if self.out_is_tuple else out extract.check_same_variables(carry_in, carry_out, 'scan') def keep_fn(path, prefix, cur, snap): changed = extract.variable_changed(cur, snap) if prefix is None and changed: raise ValueError( f'Broadcast (None axis) Variable at {jax.tree_util.keystr(path)} ' 'was mutated during scan. Only Carry and scanned Variables can be ' 'updated.' ) return changed extract.check_no_aliases('scan', args=updates, out=out) updates = extract.mask_variable_updates( updates, snapshot, prefix=self.in_axes, keep_fn=keep_fn, ) if self.out_is_tuple: return (*out, updates) return (out, updates) @tp.overload def scan( *, length: int | None = None, reverse: bool = False, unroll: int | bool = 1, _split_transpose: bool = False, # extended api in_axes: int | None | type[Carry] | tuple[tp.Any, ...] = (Carry, 0), out_axes: tp.Any = (Carry, 0), # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, graph_updates: bool | None = None, ) -> tp.Callable[[F], F]: ... @tp.overload def scan( f: F, *, length: int | None = None, reverse: bool = False, unroll: int | bool = 1, _split_transpose: bool = False, # extended api in_axes: int | None | type[Carry] | tuple[tp.Any, ...] = (Carry, 0), out_axes: tp.Any = (Carry, 0), # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, graph_updates: bool | None = None, ) -> F: ... def scan( f: F | type[Missing] = Missing, *, length: int | None = None, reverse: bool = False, unroll: int | bool = 1, _split_transpose: bool = False, # extended api in_axes: int | None | type[Carry] | tuple[tp.Any, ...] = (Carry, 0), out_axes: tp.Any = (Carry, 0), # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, graph_updates: bool | None = None, ) -> F | tp.Callable[[F], F]: """A Flax NNX transformation of `jax.lax.scan`_. Example:: import jax from flax import nnx class Block(nnx.Module): def __init__(self, input_dim, features, *, rngs): self.linear = nnx.Linear(input_dim, features, rngs=rngs) self.dropout = nnx.Dropout(0.1, rngs=rngs) def __call__(self, x: jax.Array): x = self.linear(x) x = self.dropout(x) x = jax.nn.relu(x) return x class Model(nnx.Module): def __init__(self, num_layers, features, *, rngs): # In this model implementation we create # multiple blocks using vmap # As Block contains dropout op, we prefer # to split RNG into num_layers of RNGs # using @nnx.split_rngs decorator. # Next, nnx.vmap creates a vectorized version of Block. # in_axes and out_axes define vectorization axis # of the input splitted rngs and the output Block instance. # Both axes should be 0. @nnx.split_rngs(splits=num_layers) @nnx.vmap(in_axes=(0,), out_axes=0) def create_block(rngs: nnx.Rngs): return Block(features, features, rngs=rngs) self.blocks = create_block(rngs) self.num_layers = num_layers def __call__(self, x): # Forward pass method implementation # We use nnx.scan to apply sequentially the blocks # on the input, for example with num_layers=3 # output = block[0](x) # output = block[1](output) # output = block[2](output) # # In `forward` function defined below: # - x represents the loop carry value # - model is the data to scan along the leading axis # nnx.scan args: # - in_axes marks the inputs: x is marked as carry # and the model is to scan along the axis 0 # - out_axes marks the output as carry @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry) def forward(x, model): x = model(x) return x return forward(x, self.blocks) # Alternatively, we can also decorate `self.__call__` method # @nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry) # def __call__(self, x): # return self.blocks(x) model = Model(2, 4, rngs=nnx.Rngs(0)) _, params, _ = nnx.split(model, nnx.Param, ...) print(params) # kernel of shape: (2, 4, 4) x = jnp.arange(5 * 4, dtype="float32").reshape((5, 4)) y = model(x) print(y.shape) # shape: (5, 4) Args: f: a Python function to be scanned length: optional integer specifying the number of loop iterations reverse: optional boolean specifying whether to run the scan iteration forward (the default) or in reverse unroll: optional positive int or bool specifying, in the underlying operation of the scan primitive, how many scan iterations to unroll within a single iteration of a loop. in_axes: integer, None, :class:`flax.nnx.Carry` or sequence of values specifying the kind of input args. Integer value would specify the axis of corresponding input data to scan along. :class:`flax.nnx.Carry` marks the input data as loop carry value. None marks the input data as auxiliary input. out_axes: integer, None, :class:`flax.nnx.Carry` or sequence of values specifying the kind of output args. See ``in_axes`` for details. Note that If ``in_axes`` contains :class:`flax.nnx.Carry` then ``out_axes`` must also contain :class:`flax.nnx.Carry`. graph_updates: If ``True``, propagates updates on graph structure that happen inside the transform to the input graphs, has no effect when ``graph=False``. When ``False``, using ``StateAxes`` is not supported. .. _jax.lax.scan: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html> """ if f is Missing: return functools.partial( scan, length=length, reverse=reverse, unroll=unroll, _split_transpose=_split_transpose, in_axes=in_axes, out_axes=out_axes, transform_metadata=transform_metadata, graph=graph, graph_updates=graph_updates, ) # type: ignore[return-value] f_unbound, _, was_bound = _resolve_bound_callable(f) if was_bound: _raise_bound_method_error('scan') if graph is None: graph = graphlib.set_graph_mode.current_value() if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() if not graph or not graph_updates: return _simple_scan( f, f_unbound, graph=graph, in_axes=in_axes, out_axes=out_axes, length=length, reverse=reverse, unroll=unroll, _split_transpose=_split_transpose, ) return _graph_updates_scan( f, f_unbound, in_axes=in_axes, out_axes=out_axes, length=length, reverse=reverse, unroll=unroll, _split_transpose=_split_transpose, transform_metadata=transform_metadata, ) def _simple_scan( f, f_unbound, *, graph, in_axes, out_axes, length, reverse, unroll, _split_transpose, ): if any(isinstance(x, StateAxes) for x in jax.tree.leaves(in_axes)): raise ValueError( '`in_axes` cannot contain `StateAxes` objects ' 'when `graph=False`. ' + graphlib._tree_mode_suggestion_transform('scan') ) if any(isinstance(x, StateAxes) for x in jax.tree.leaves(out_axes)): raise ValueError( '`out_axes` cannot contain `StateAxes` objects ' 'when `graph=False`. ' + graphlib._tree_mode_suggestion_transform('scan') ) if graph: extract.check_prefix(in_axes, 'in_axes', 'scan') extract.check_prefix(out_axes, 'out_axes', 'scan') out_is_tuple = isinstance(out_axes, tuple) if in_axes is Carry: in_axes = (Carry,) if isinstance(in_axes, tuple): carry_arg_index = next( (i for i, ax in enumerate(in_axes) if ax is Carry), None ) updates_out_axes = extract.mask_at(in_axes, carry_arg_index) else: carry_arg_index = None updates_out_axes = in_axes if isinstance(out_axes, tuple): carry_out_index = next( (i for i, ax in enumerate(out_axes) if ax is Carry), None ) else: carry_out_index = None simple_scan_fn = SimpleScanFn( f_unbound, graph=graph, in_axes=in_axes, out_axes=out_axes, out_is_tuple=out_is_tuple, carry_arg_index=carry_arg_index, carry_out_index=carry_out_index, ) if out_is_tuple: augmented_out_axes = (*out_axes, updates_out_axes) else: augmented_out_axes = (out_axes, updates_out_axes) @functools.wraps(f) def simple_scan_wrapper(*args): args = resolve_kwargs(f, args, {}) if graph: args = extract.to_tree2(args, prefix=in_axes) extract.check_no_aliases('scan', args=args) result = pure_jax_fancy_scan( simple_scan_fn, *args, length=length, reverse=reverse, unroll=unroll, _split_transpose=_split_transpose, in_axes=in_axes, out_axes=augmented_out_axes, ) if out_is_tuple: n = len(out_axes) out = result[:n] updates = result[n] else: out, updates = result masked_args = extract.mask_at(args, carry_arg_index) extract.apply_variable_updates(masked_args, updates) if carry_arg_index is not None: carry_in = args[carry_arg_index] carry_out = ( out[carry_out_index] if out_is_tuple else out ) extract.update_carry_variables(carry_in, carry_out) if out_is_tuple: out = extract.replace_at(out, carry_out_index, carry_in) else: out = carry_in if graph: out = extract.from_tree2(out) return out return simple_scan_wrapper def _graph_updates_scan( f, f_unbound, *, in_axes, out_axes, length, reverse, unroll, _split_transpose, transform_metadata, ): _check_out_axes(out_axes) input_carry_argnum = _get_carry_argnum(in_axes, is_in_axes=True) output_carry_argnum = _get_carry_argnum(out_axes, is_in_axes=False) if (input_carry_argnum is None and output_carry_argnum is not None) or ( input_carry_argnum is not None and output_carry_argnum is None ): raise ValueError( 'If one of in_axes or out_axes has Carry, the other must also have Carry. ' f'Got {in_axes=!r} and {out_axes=!r}' ) scan_fn = ScanFn( f_unbound, input_carry_argnum, output_carry_argnum, in_axes, out_axes, transform_metadata, ) @functools.wraps(f) @graphlib.update_context('scan') def scan_wrapper(*args, **kwargs): args = resolve_kwargs(f, args, kwargs) if in_axes is Carry and len(args) != 1: raise ValueError( f'When in_axes=Carry, the function must take exactly one argument, ' f'got {len(args)} arguments.' ) graphdefs_deque = PytreeDeque() carry_deque = PytreeDeque() broadcast_deque = PytreeDeque() broadcast_arrays = PytreeDeque() pure_args: tuple = extract.to_tree( args, prefix=in_axes, split_fn=functools.partial( _scan_split_in, carry_deque, graphdefs_deque, broadcast_deque, broadcast_arrays ), map_non_graph_nodes=True, ctxtag='scan', ) if isinstance(input_carry_argnum, int): pure_carry_arg = pure_args[input_carry_argnum] pure_args = extract.replace_at(pure_args, input_carry_argnum, None) elif input_carry_argnum == 'all': pure_carry_arg = pure_args[0] pure_args = () else: assert input_carry_argnum is None pure_carry_arg = None carry = (pure_carry_arg, carry_deque, broadcast_deque, broadcast_arrays) scan_in = (graphdefs_deque, pure_args) carry_out, scan_out = jax.lax.scan( scan_fn, carry, scan_in, length=length, reverse=reverse, unroll=unroll, _split_transpose=_split_transpose, ) ( pure_carry_arg_out, carry_deque_out, broadcast_deque_out, broadcast_arrays_out, ) = carry_out ( graphdefs_out, pure_args_out, pure_out, ) = scan_out if input_carry_argnum == 'all': pure_args_out = (pure_carry_arg_out,) elif isinstance(input_carry_argnum, int): pure_args_out = extract.replace_at(pure_args_out, input_carry_argnum, pure_carry_arg_out) else: assert input_carry_argnum is None assert pure_carry_arg_out is None args_out, out = extract.from_tree( (pure_args_out, pure_out), prefix=(in_axes, out_axes), merge_fn=functools.partial( _scan_merge_out, carry_deque_out, graphdefs_out, broadcast_deque_out ), is_leaf=lambda x: isinstance(x, (extract.NodeStates, Broadcasted)), map_non_graph_nodes=True, ctxtag='scan', is_inner=False, ) if input_carry_argnum == 'all': carry_arg = args_out[0] elif isinstance(input_carry_argnum, int): carry_arg = args_out[input_carry_argnum] else: assert input_carry_argnum is None carry_arg = None if output_carry_argnum == 'all': out = carry_arg elif isinstance(output_carry_argnum, int): out = extract.replace_at(out, output_carry_argnum, carry_arg) else: assert output_carry_argnum is None assert carry_arg is None return out return scan_wrapper def pure_jax_fancy_scan( f, *args, length: int | None = None, reverse: bool = False, unroll: int | bool = 1, _split_transpose: bool = False, in_axes: tp.Any = (Carry, 0), out_axes: tp.Any = (Carry, 0), ): if in_axes is Carry: in_axes = (Carry,) is_axis_leaf = lambda x: x is None or x is Carry if isinstance(in_axes, tuple): for i, ax in enumerate(in_axes): if ax is Carry or ax is None or isinstance(ax, int): continue for leaf in jax.tree.leaves(ax, is_leaf=is_axis_leaf): if leaf is Carry: raise ValueError( 'Carry must be a top-level argument, it cannot be nested. ' f'Found Carry inside in_axes[{i}]={ax}' ) if isinstance(out_axes, tuple): for i, ax in enumerate(out_axes): if ax is Carry or ax is None or isinstance(ax, int): continue for path, leaf in jax.tree_util.tree_leaves_with_path( ax, is_leaf=is_axis_leaf, ): if leaf is Carry: raise ValueError( 'Carry must be a top-level argument, it cannot be nested. ' f'Found Carry at out_axes[{i}]{jax.tree_util.keystr(path)}' ) in_has_carry = in_axes is Carry or ( isinstance(in_axes, tuple) and Carry in in_axes ) out_has_carry = out_axes is Carry or ( isinstance(out_axes, tuple) and Carry in out_axes ) if in_has_carry != out_has_carry: raise ValueError( 'If one of in_axes or out_axes has Carry, the other must also ' f'have Carry. Got {in_axes=}, {out_axes=}' ) args_flat, args_treedef = jax.tree.flatten(args) _, in_axes_flat = extract.broadcast_prefix2( in_axes, args, is_leaf=is_axis_leaf, ) carry_indices: list[int] = [] broadcast_indices: list[int] = [] scan_indices: list[int] = [] scan_in_axes: list[int] = [] carry_leaves: list[tp.Any] = [] broadcast_leaves: list[tp.Any] = [] scan_leaves: list[tp.Any] = [] for i, (leaf, ax) in enumerate(zip(args_flat, in_axes_flat, strict=True)): if ax is Carry: carry_indices.append(i) carry_leaves.append(leaf) elif ax is None: broadcast_indices.append(i) broadcast_leaves.append(leaf) elif isinstance(ax, int): scan_indices.append(i) scan_in_axes.append(ax) if ax != 0: leaf = jnp.moveaxis(leaf, ax, 0) scan_leaves.append(leaf) else: raise ValueError(f'Invalid in_axes leaf value: {ax}') n_in = len(args_flat) out_info: list[tuple[ jax.tree_util.PyTreeDef, list[int], list[int], list[int], ]] = [] in_broadcast = jax.tree.map(lambda x: x, broadcast_leaves) def body_fn(carry_state, scan_x): flat = [None] * n_in for idx, j in enumerate(carry_indices): flat[j] = carry_state[idx] for idx, j in enumerate(broadcast_indices): flat[j] = in_broadcast[idx] if scan_x is not None: for idx, j in enumerate(scan_indices): flat[j] = scan_x[idx] reconstructed = args_treedef.unflatten(flat) out = f(*reconstructed) out_flat, out_treedef = jax.tree.flatten(out) out_axes_paths, out_axes_flat = extract.broadcast_prefix2( out_axes, out, is_leaf=is_axis_leaf, ) if not out_info: out_carry_idx = [] out_scan_idx = [] out_scan_axes = [] out_broadcast_idx = [] for j, oax in enumerate(out_axes_flat): if oax is Carry: out_carry_idx.append(j) elif oax is None: out_broadcast_idx.append(j) elif isinstance(oax, int): out_scan_idx.append(j) out_scan_axes.append(oax) else: raise ValueError(f'Invalid out_axes leaf value: {oax}') if out_broadcast_idx: broadcast_paths = [ jax.tree_util.keystr(out_axes_paths[j]) for j in out_broadcast_idx ] broadcast_str = "\n\n ".join(broadcast_paths) raise ValueError( 'Scan does not support broadcast outputs (None axis). The following ' f'output leaves are broadcast:\n\n {broadcast_str}\n' ) out_info.append( (out_treedef, out_carry_idx, out_scan_idx, out_scan_axes), ) oci = out_info[0][1] osi = out_info[0][2] new_carry = [out_flat[j] for j in oci] new_ys = [out_flat[j] for j in osi] return new_carry, new_ys final_carry, stacked_ys = jax.lax.scan( body_fn, carry_leaves, scan_leaves if scan_leaves else None, length=length, reverse=reverse, unroll=unroll, _split_transpose=_split_transpose, ) out_treedef, out_carry_idx, out_scan_idx, out_scan_axes = ( out_info[0] ) n_out = out_treedef.num_leaves out_flat: list[tp.Any] = [None] * n_out for idx, j in enumerate(out_carry_idx): out_flat[j] = final_carry[idx] for idx, j in enumerate(out_scan_idx): y = stacked_ys[idx] ax = out_scan_axes[idx] if ax != 0: y = jnp.moveaxis(y, 0, ax) out_flat[j] = y return out_treedef.unflatten(out_flat) # ------------------------------- # while_loop # ------------------------------- @dataclasses.dataclass(eq=False) class SimpleWhileLoopBodyFn: f: tp.Callable[..., tp.Any] graph: bool def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, val): val_variables, _ = extract.updates_and_snapshot(val) if self.graph: val = extract.from_tree2(val) out = self.f(val) if self.graph: out = extract.to_tree2(out) extract.check_same_variables(val_variables, out, 'while_loop') return out @dataclasses.dataclass(eq=False) class SimpleWhileLoopCondFn: f: tp.Callable[..., tp.Any] graph: bool def __post_init__(self): functools.update_wrapper(self, self.f) def __call__(self, val): if self.graph: val = extract.from_tree2(val) return self.f(val) @dataclasses.dataclass(eq=False) class WhileLoopCondFn: f: tp.Callable[..., tp.Any] def __post_init__(self): functools.update_wrapper(self, self.f) def __call__(self, pure_val): val = extract.from_tree(pure_val) out = self.f(val) return out def _reconsile_index_mapping(tree_to_fix, example_tree): def f(a, b): if not isinstance(a, extract.NodeStates) or not isinstance( a._graphdef, graphlib.GraphDef ): return a return dataclasses.replace( a, _graphdef=a._graphdef.with_matching_outer_index(b._graphdef) ) return jax.tree.map(f, tree_to_fix, example_tree, is_leaf=lambda x: isinstance(x, extract.NodeStates)) def _add_fake_index_mapping(tree: tp.Any): def per_node_state(node_state: extract.NodeStates | tp.Any): if not isinstance(node_state, extract.NodeStates) or not isinstance( node_state._graphdef, graphlib.GraphDef ): return node_state return dataclasses.replace( node_state, _graphdef=node_state._graphdef.with_same_outer_index() ) return jax.tree.map(per_node_state, tree, is_leaf=lambda x: isinstance(x, extract.NodeStates)) def _remove_index_mapping(tree: tp.Any): """Remove a fake outer_index for the input to match that of the output.""" def per_node_state(node_state: extract.NodeStates | tp.Any): if not isinstance(node_state, extract.NodeStates) or not isinstance( node_state._graphdef, graphlib.GraphDef ): return node_state assert isinstance(node_state._graphdef, graphlib.GraphDef) node_state = dataclasses.replace( node_state, _graphdef=node_state._graphdef.with_no_outer_index() ) return node_state return jax.tree.map(per_node_state, tree, is_leaf=lambda x: isinstance(x, extract.NodeStates)) @dataclasses.dataclass(eq=False) class WhileLoopBodyFn: f: tp.Callable[..., tp.Any] def __post_init__(self): functools.update_wrapper(self, self.f) @graphlib.update_context('while_loop_body') def __call__(self, pure_val): # Removing the dummy index mapping being added outside of body function. pure_val_in = _remove_index_mapping(pure_val) val = extract.from_tree( pure_val_in, ctxtag='while_loop_body', is_inner=True ) out = self.f(val) pure_out = extract.to_tree(out, ctxtag='while_loop_body') try: jax.tree.map(lambda a, b: None, pure_val, pure_out) except ValueError as e: msg = ( "nnx.while_loop requires body function's input and output to " 'have the same reference and pytree structure, but they differ. ' 'If the mismatch comes from `outer_index` field, you might ' 'have modified reference structure within the body function, ' 'which is not allowed.' f'Detail of the mismatch: \n {str(e)}' ) raise ValueError(msg) return pure_out @graphlib.update_context('while_loop') def while_loop(cond_fun: tp.Callable[[T], tp.Any], body_fun: tp.Callable[[T], T], init_val: T, *, graph: bool | None = None, graph_updates: bool | None = None) -> T: """A Flax NNX transformation of `jax.lax.while_loop `_. Caution: for the NNX internal reference tracing mechanism to work, you cannot change the variable reference structure of ``init_val`` inside ``body_fun``. Example:: >>> import jax >>> from flax import nnx >>> def fwd_fn(input): ... module, x, count = input ... return module, module(x), count - 1.0 >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) >>> x = jax.random.normal(jax.random.key(0), (10,)) >>> # `module` will be called three times >>> _, y, _ = nnx.while_loop( ... lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0)) Args: cond_fun: A function for the continue condition of the while loop, taking a single input of type ``T`` and outputting a boolean. body_fun: A function that takes an input of type ``T`` and outputs an ``T``. Note that both data and modules of ``T`` must have the same reference structure between inputs and outputs. init_val: The initial input for ``cond_fun`` and ``body_fun``. Must be of type ``T``. graph: if True, use graph-mode (default). If False, use tree-mode. If None, uses the value of ``nnx_graph_mode`` config. graph_updates: If ``True``, propagates updates on graph structure that happen inside the transform to the input graphs, has no effect when ``graph=False``. """ if graph is None: graph = graphlib.set_graph_mode.current_value() if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() if not graph or not graph_updates: simple_body_fn = SimpleWhileLoopBodyFn(body_fun, graph=graph) simple_cond_fn = SimpleWhileLoopCondFn(cond_fun, graph=graph) if graph: init_val = extract.to_tree2(init_val) val_out = jax.lax.while_loop(simple_cond_fn, simple_body_fn, init_val) val_out = extract.update_carry_variables(init_val, val_out) if graph: val_out = extract.from_tree2(val_out) return val_out pure_init_val = extract.to_tree(init_val, ctxtag='while_loop') pure_init_val = _add_fake_index_mapping(pure_init_val) pure_out = jax.lax.while_loop( WhileLoopCondFn(cond_fun), WhileLoopBodyFn(body_fun), pure_init_val, ) out = extract.from_tree(pure_out, ctxtag='while_loop', is_inner=False) return out @dataclasses.dataclass(eq=False) class SimpleForiLoopBodyFn: f: tp.Callable[..., tp.Any] graph: bool def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, i, val): val_variables, _ = extract.updates_and_snapshot(val) if self.graph: val = extract.from_tree2(val) out = self.f(i, val) if self.graph: out = extract.to_tree2(out) extract.check_same_variables(val_variables, out, 'fori_loop') return out @dataclasses.dataclass(eq=False) class ForiLoopBodyFn: f: tp.Callable[..., tp.Any] def __post_init__(self): functools.update_wrapper(self, self.f) @graphlib.update_context('fori_loop_body') def __call__(self, i, pure_val_in): val = extract.from_tree(pure_val_in, ctxtag='fori_loop_body', is_inner=True) out = self.f(i, val) pure_out = extract.to_tree(out, ctxtag='fori_loop_body') return pure_out @graphlib.update_context('fori_loop') def fori_loop(lower: int, upper: int, body_fun: tp.Callable[[int, T], T], init_val: T, *, unroll: int | bool | None = None, graph: bool | None = None, graph_updates: bool | None = None) -> T: """A Flax NNX transformation of `jax.lax.fori_loop `_. Caution: for the NNX internal reference tracing mechanism to work, you cannot change the variable reference structure of `init_val` inside `body_fun`. Example:: >>> import jax >>> from flax import nnx >>> def fwd_fn(i, input): ... m, x = input ... m.kernel[...] = jnp.identity(10) * i ... return m, m(x) >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) >>> x = jax.random.normal(jax.random.key(0), (10,)) >>> _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x)) >>> np.testing.assert_array_equal(y, x * 2 * 3) Args: lower: An integer representing the loop index lower bound (inclusive). upper: An integer representing the loop index upper bound (exclusive). body_fun: a function that takes an input of type ``T`` and outputs an ``T``. Note that both data and modules of ``T`` must have the same reference structure between inputs and outputs. init_val: the initial input for body_fun. Must be of type ``T``. unroll: An optional integer or boolean that determines how much to unroll the loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. If a boolean is provided, it will determine if the loop is competely unrolled (i.e. ``unroll=True``) or left completely unrolled (i.e. ``unroll=False``). This argument is only applicable if the loop bounds are statically known. graph: if True, use graph-mode (default). If False, use tree-mode. If None, uses the value of ``nnx_graph_mode`` config. graph_updates: If ``True``, propagates updates on graph structure that happen inside the transform to the input graphs, has no effect when ``graph=False``. Returns: A loop value from the final iteration, of type ``T``. """ if graph is None: graph = graphlib.set_graph_mode.current_value() if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() if not graph or not graph_updates: simple_body_fn = SimpleForiLoopBodyFn(body_fun, graph=graph) if graph: init_val = extract.to_tree2(init_val) val_out = jax.lax.fori_loop( lower, upper, simple_body_fn, init_val, unroll=unroll, ) val_out = extract.update_carry_variables(init_val, val_out) if graph: val_out = extract.from_tree2(val_out) return val_out pure_init_val = extract.to_tree(init_val, ctxtag='fori_loop') body = ForiLoopBodyFn(body_fun) pure_out = jax.eval_shape(body, lower, pure_init_val) pure_init_val = _reconsile_index_mapping(pure_init_val, pure_out) pure_out = jax.lax.fori_loop(lower, upper, body, pure_init_val, unroll=unroll) out = extract.from_tree(pure_out, ctxtag='fori_loop', is_inner=False) return out ================================================ FILE: flax/nnx/transforms/transforms.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # pytype: skip-file from __future__ import annotations from abc import abstractmethod import dataclasses import functools import inspect import typing as tp from jax._src import checkify as checkify_lib from flax.nnx import ( extract, graphlib, variablelib, ) from flax.nnx.module import Module from flax.nnx.proxy_caller import ( CallableProxy, DelayedAccessor, ) from flax.nnx.transforms import general from flax.typing import MISSING, Leaf, Missing import jax import jax.core import jax.stages A = tp.TypeVar('A') C = tp.TypeVar('C') B = tp.TypeVar('B') F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any]) M = tp.TypeVar('M', bound=Module) MA = tp.TypeVar('MA', bound=Module) N = tp.TypeVar('N', bound=Module) StrInt = tp.TypeVar('StrInt', str, int) AxisName = tp.Hashable Leaves = list[Leaf] Index = int @tp.overload def resolve_kwargs( fun: tp.Callable[..., tp.Any], args: tuple, kwargs: dict[str, tp.Any], ) -> tuple: ... @tp.overload def resolve_kwargs() -> tp.Callable[[F], F]: ... def resolve_kwargs( fun: tp.Callable[..., tp.Any] | Missing = MISSING, args: tuple | Missing = MISSING, kwargs: dict[str, tp.Any] | Missing = MISSING, ) -> tuple | tp.Callable[[F], F]: if isinstance(fun, Missing): def resolve_kwargs_decorator(f): @functools.wraps(f) def resolve_kwargs_wrapper(*args, **kwargs): args = resolve_kwargs(f, args, kwargs) return f(*args) return resolve_kwargs_wrapper return resolve_kwargs_decorator # type: ignore if isinstance(args, Missing): raise ValueError('args must be provided') if isinstance(kwargs, Missing): raise ValueError('kwargs must be provided') if isinstance(fun, functools.partial): # functools.partial should have an opaque signature. fun = lambda *args, **kwargs: None ba = inspect.signature(fun).bind(*args, **kwargs) ba.apply_defaults() if ba.kwargs: raise TypeError('keyword arguments could not be resolved to positions') else: return ba.args # ------------------------------- # helper utilities for bound methods & indices # ------------------------------- def _resolve_bound_callable( f: tp.Callable[..., tp.Any], ) -> tuple[tp.Callable[..., tp.Any], tp.Any | None, bool]: """Detects and extracts bound methods from NNX Module callables. This function unwraps functools.partial layers to reach the underlying callable before checking if it's a bound method of an NNX Module. Args: f: A callable that may be a bound method of an NNX Module, potentially wrapped in functools.partial. Returns: A tuple of (unbound_fn, bound_self, was_bound) where: - unbound_fn: The unbound function (or original if not bound) - bound_self: The Module instance if f was bound, None otherwise - was_bound: True if f was a bound method, False otherwise Note: Preserves functools.partial wrappers around the callable and follows the same detection pattern as _get_unbound_fn in bridge/module.py. Detection occurs before any argnum shifting or index normalization. """ # Unwrap functools.partial layers to reach the underlying callable. partials: list[tuple[tuple[tp.Any, ...], dict[str, tp.Any] | None]] = [] g = f while isinstance(g, functools.partial): # type: ignore[arg-type] partials.append((g.args or (), g.keywords)) # type: ignore[attr-defined] g = g.func # type: ignore[attr-defined] bound_self = getattr(g, "__self__", None) was_bound = bool(inspect.ismethod(g) and isinstance(bound_self, Module)) if was_bound: g = g.__func__ # type: ignore[attr-defined] # Reapply partials in reverse unwrap order. for args, kwargs in reversed(partials): kwargs = {} if kwargs is None else kwargs g = functools.partial(g, *args, **kwargs) return g, (bound_self if was_bound else None), was_bound def _raise_bound_method_error(transform_name: str): """Raises a standardized error for bound method usage with NNX transforms. Args: transform_name: Name of the transform (e.g., 'grad', 'jit', 'remat'). """ raise ValueError( f"nnx.{transform_name} does not support bound methods. " f"Use the decorator form @nnx.{transform_name} or call " f"nnx.{transform_name}(MyClass.method)(instance, ...) with the unbound method." ) class LiftedModule(tp.Generic[M], Module): # type: ignore[ignored-abstractmethod] @abstractmethod def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> tp.Any: pass @property @abstractmethod def _submodule(self) -> M: pass # type: ignore[bad-return-type] # why pytype? def __call__(self, *args, **kwargs) -> tp.Any: return self.call(*args, **kwargs) # type: ignore @property def call(self) -> tp.Any: module = self def check_and_call(accessor: DelayedAccessor, *args, **kwargs): return self._call(accessor, *args, **kwargs) proxy = CallableProxy(check_and_call) # type: ignore[arg-type] while isinstance(module._submodule, LiftedModule): module = module._submodule proxy = proxy.call return proxy # type: ignore # ------------------------------- # simple transforms # ------------------------------- @dataclasses.dataclass(frozen=True) class ValueMetadata: var_type: type[variablelib.Variable] value: tp.Any metadata: dict[str, tp.Any] def _flatten_value_metadata( value_metadata: tp.Union[tp.Any, ValueMetadata], ): metadata = tuple(sorted(value_metadata.metadata.items())) return (value_metadata.value,), (value_metadata.var_type, metadata) def _unflatten_value_metadata(aux_data, children): var_type, metadata_items = aux_data metadata = dict(metadata_items) return ValueMetadata(var_type=var_type, value=children[0], metadata=metadata) jax.tree_util.register_pytree_node( ValueMetadata, _flatten_value_metadata, _unflatten_value_metadata, ) def _to_value_metadata(node): def to_value_metadata(x): if isinstance(x, variablelib.Variable): value = x.get_raw_value() if variablelib.is_array_ref(value): value = value[...] metadata = x.get_metadata() return ValueMetadata(var_type=x.var_type, value=value, metadata=metadata) return x return jax.tree.map( to_value_metadata, node, is_leaf=lambda x: isinstance(x, variablelib.Variable), ) def _to_variable(node): # import here to avoid circular imports from flax.nnx.spmd import get_var_pspec def to_variable(x): if isinstance(x, ValueMetadata): var = x.var_type._new(x.value, x.metadata) global_mesh = jax.sharding.get_abstract_mesh() if global_mesh.axis_sizes == (): global_mesh = None mesh = var.get_metadata("mesh", None) or global_mesh if mesh is not None and (not hasattr(var, 'sharding') or var.sharding is None): pspec = get_var_pspec(var) sharding = jax.sharding.NamedSharding(mesh=mesh, spec=pspec) var.set_value(jax.ShapeDtypeStruct(shape=var.shape, dtype=var.dtype, sharding=sharding)) return var return x return jax.tree.map( to_variable, node, is_leaf=lambda x: isinstance(x, ValueMetadata) ) @dataclasses.dataclass(eq=False) class SimpleEvalShapeFn: f: tp.Callable[..., tp.Any] graph: bool def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args, **kwargs): if self.graph: args, kwargs = extract.from_tree2((args, kwargs)) out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out) extract.check_no_aliases('eval_shape', args=args, kwargs=kwargs, out=out) return out def eval_shape( f: tp.Callable[..., A], *args: tp.Any, graph: bool | None = None, graph_updates: bool | None = None, **kwargs: tp.Any, ) -> A: """A \"lifted\" version of `jax.eval_shape `_ that can handle `flax.nnx.Module `_ / graph nodes as arguments. Similar to ``jax.eval_shape``, it computes the shape/dtype of a function `f` without performing any floating point operations (FLOPs) which can be expensive. This can be useful for performing shape inference, for example. Unlike `jax.eval_shape`, `nnx.eval_shape` will automatically compute the expected sharding based on Flax sharding metadata for all Variables not using explicit sharding. Args: f: the function to evaluate. *args: positional arguments to ``f``. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references and reference semantics. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. graph_updates: If ``True``, propagates updates on graph structure that happen inside the transform to the input graphs, has no effect when ``graph=False``. **kwargs: keyword arguments to ``f``. """ f_call, _, was_bound = _resolve_bound_callable(f) if was_bound: _raise_bound_method_error('eval_shape') if graph is None: graph = graphlib.set_graph_mode.current_value() if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() if not graph or not graph_updates: if graph: args, kwargs = extract.to_tree2((args, kwargs)) extract.check_no_aliases('eval_shape', args=args, kwargs=kwargs) out = jax.eval_shape( SimpleEvalShapeFn(f_call, graph=graph), *args, **kwargs ) if graph: out = extract.from_tree2(out) return out args, kwargs = extract.to_tree((args, kwargs)) @functools.wraps(f) def _eval_shape_fn(*args, **kwargs): args, kwargs = extract.from_tree((args, kwargs)) out = f_call(*args, **kwargs) return _to_value_metadata(extract.to_tree(out)) out = jax.eval_shape(_eval_shape_fn, *args, **kwargs) return extract.from_tree(_to_variable(out)) @dataclasses.dataclass(eq=False) class CheckifyFn: f: tp.Callable[..., tp.Any] def __post_init__(self): functools.update_wrapper(self, self.f) def __call__(self, *pure_args, **pure_kwargs): args, kwargs = extract.from_tree( (pure_args, pure_kwargs), ctxtag='checkify', is_inner=True ) out = self.f(*args, **kwargs) args_out, kwargs_out = extract.clear_non_graph_nodes((args, kwargs)) pure_args_out, pure_kwargs_out, pure_out = extract.to_tree( (args, kwargs, out), ctxtag='checkify' ) return pure_args_out, pure_kwargs_out, pure_out @dataclasses.dataclass(eq=False) class SimpleCheckifyFn: f: tp.Callable[..., tp.Any] graph: bool def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args): updates, snapshot = extract.updates_and_snapshot(args) if self.graph: args = extract.from_tree2(args) out = self.f(*args) if self.graph: out = extract.to_tree2(out) extract.check_no_aliases('checkify', args=updates, out=out) updates = extract.mask_variable_updates(updates, snapshot) return out, updates def checkify( f: tp.Callable[..., checkify_lib.Out], errors: frozenset[type[checkify_lib.JaxException]] = checkify_lib.user_checks, # type: ignore graph: bool | None = None, graph_updates: bool | None = None, ) -> tp.Callable[..., tuple[checkify_lib.Error, checkify_lib.Out]]: """Reference-aware version of `jax.experimental.checkify `_. Example:: >>> import jax >>> import jax.numpy as jnp >>> from jax.experimental import checkify >>> import dataclasses >>> from flax import nnx ... >>> class Foo(nnx.Module): ... def __init__(self, a): ... self.a = nnx.Param(a) ... >>> @nnx.jit ... def f(m): ... y = jnp.sin(m.a) # error ... return m.a + y ... >>> m = Foo(a=jnp.inf) >>> err, out = nnx.checkify(f, errors=checkify.float_checks)(m) >>> # err.throw() >>> print(err) Error(nan generated by primitive: sin.) Args: f: the function to checkify. errors: the set of error checks to enable. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references and reference semantics. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. graph_updates: If ``True``, propagates updates on graph structure that happen inside the transform to the input graphs, has no effect when ``graph=False``. """ f_call, _, was_bound = _resolve_bound_callable(f) if was_bound: _raise_bound_method_error('checkify') if graph is None: graph = graphlib.set_graph_mode.current_value() if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() if not graph or not graph_updates: checkify_fn = checkify_lib.checkify( SimpleCheckifyFn(f_call, graph=graph), errors, ) @functools.wraps(f) def simple_checkify_wrapper(*args): if graph: args = extract.to_tree2(args) extract.check_no_aliases('checkify', args=args) error, (out, updates) = checkify_fn(*args) if graph: out = extract.from_tree2(out) extract.apply_variable_updates(args, updates) return error, out return simple_checkify_wrapper # type: ignore checkify_fn = checkify_lib.checkify(CheckifyFn(f_call), errors) @functools.wraps(f) @graphlib.update_context('checkify') def checkify_wrapper(*args, **kwargs): pure_args, pure_kwargs = extract.to_tree( (args, kwargs), ctxtag='checkify', ) error, (pure_args_out, pure_kwargs_out, pure_out) = checkify_fn( *pure_args, **pure_kwargs ) args_out, kwargs_out, out = extract.from_tree( (pure_args_out, pure_kwargs_out, pure_out), ctxtag='checkify', is_inner=False, ) return error, out return checkify_wrapper # type: ignore @dataclasses.dataclass(eq=False) class SimpleCondFn: f: tp.Callable[..., tp.Any] graph: bool def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args): updates, _snapshot = extract.updates_and_snapshot(args) if self.graph: args = extract.from_tree2(args) out = self.f(*args) if self.graph: out = extract.to_tree2(out) extract.check_no_aliases('switch', args=updates, out=out) return out, updates def cond( pred, true_fun: tp.Callable[..., A], false_fun: tp.Callable[..., A], *operands, graph: bool | None = None, graph_updates: bool | None = None, ) -> A: """Conditionally apply ``true_fun`` or ``false_fun``. Wraps `jax.lax.cond `__ to support Flax NNX modules and variables. Args: pred: boolean scalar. If True, ``true_fun`` is applied, otherwise ``false_fun``. true_fun: function to apply if ``pred`` is True. false_fun: function to apply if ``pred`` is False. *operands: operands passed to whichever branch is selected. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references and reference semantics. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. graph_updates: If ``True``, propagates updates on graph structure that happen inside the transform to the input graphs, has no effect when ``graph=False``. """ if graph is None: graph = graphlib.set_graph_mode.current_value() if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() if not graph or not graph_updates: if graph: operands = extract.to_tree2(operands) extract.check_no_aliases('cond', operands=operands) out, updates = jax.lax.cond( pred, SimpleCondFn(true_fun, graph=graph), SimpleCondFn(false_fun, graph=graph), *operands, ) if graph: out = extract.from_tree2(out) extract.apply_variable_updates(operands, updates) return out @general.split_inputs(ctxtag='cond') def _cond(pred, true_fun, false_fun, *operands): return jax.lax.cond( pred, general.merge_inputs(true_fun, ctxtag='cond'), general.merge_inputs(false_fun, ctxtag='cond'), *operands, ) return _cond(pred, true_fun, false_fun, *operands) def switch( index, branches: tp.Sequence[tp.Callable[..., A]], *operands, graph: bool | None = None, graph_updates: bool | None = None, ) -> A: """Select and apply one of ``branches`` based on ``index``. Wraps `jax.lax.switch `__ to support Flax NNX modules and variables. Args: index: integer scalar indicating which branch to apply. branches: sequence of functions to select from. *operands: operands passed to the selected branch. graph: If ``True`` (default), uses graph-mode which supports the full NNX feature set including shared references and reference semantics. If ``False``, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. graph_updates: If ``True``, propagates updates on graph structure that happen inside the transform to the input graphs, has no effect when ``graph=False``. """ if graph is None: graph = graphlib.set_graph_mode.current_value() if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() if not graph or not graph_updates: if graph: operands = extract.to_tree2(operands) extract.check_no_aliases('switch', operands=operands) out, updates = jax.lax.switch( index, [SimpleCondFn(f, graph=graph) for f in branches], *operands, ) if graph: out = extract.from_tree2(out) extract.apply_variable_updates(operands, updates) return out @general.split_inputs(ctxtag='switch') def _switch(index, branches, *operands): return jax.lax.switch( index, [general.merge_inputs(f, ctxtag='switch') for f in branches], *operands, ) return _switch(index, branches, *operands) ================================================ FILE: flax/nnx/traversals.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for flattening and unflattening mappings. """ from __future__ import annotations from collections.abc import Callable, Mapping from collections.abc import Iterable, Sequence from typing import Any, overload from flax import struct # the empty node is a struct.dataclass to be compatible with JAX. @struct.dataclass class _EmptyNode: pass empty_node = _EmptyNode() # TODO: In Python 3.10, use TypeAlias. IsLeafCallable = Callable[[tuple[Any, ...], Mapping[Any, Any]], bool] @overload def flatten_mapping(xs: Mapping[Any, Any], /, *, keep_empty_nodes: bool = False, is_leaf: None | IsLeafCallable = None, sep: None = None ) -> dict[tuple[Any, ...], Any]: ... @overload def flatten_mapping(xs: Mapping[Any, Any], /, *, keep_empty_nodes: bool = False, is_leaf: None | IsLeafCallable = None, sep: str, ) -> dict[str, Any]: ... def flatten_mapping(xs: Mapping[Any, Any], /, *, keep_empty_nodes: bool = False, is_leaf: None | IsLeafCallable = None, sep: None | str = None ) -> dict[Any, Any]: """Flatten a nested mapping. The nested keys are flattened to a tuple. See ``unflatten_mapping`` on how to restore the nested mapping. Example:: >>> from flax import nnx >>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}} >>> flat_xs = nnx.traversals.flatten_mapping(xs) >>> flat_xs {('foo',): 1, ('bar', 'a'): 2} Note that empty mappings are ignored and will not be restored by ``unflatten_mapping``. Args: xs: a nested mapping keep_empty_nodes: replaces empty mappings with ``traverse_util.empty_node``. is_leaf: an optional function that takes the next nested mapping and nested keys and returns True if the nested mapping is a leaf (i.e., should not be flattened further). sep: if specified, then the keys of the returned mapping will be ``sep``-joined strings (if ``None``, then keys will be tuples). Returns: The flattened mapping. """ assert isinstance( xs, Mapping ), f'expected Mapping; got {type(xs).__qualname__}' def _key(path: tuple[Any, ...]) -> tuple[Any, ...] | str: if sep is None: return path return sep.join(path) def _flatten(xs: Any, prefix: tuple[Any, ...]) -> dict[Any, Any]: if not isinstance(xs, Mapping) or ( is_leaf and is_leaf(prefix, xs) ): return {_key(prefix): xs} result = {} is_empty = True for key, value in xs.items(): is_empty = False path = prefix + (key,) result.update(_flatten(value, path)) if keep_empty_nodes and is_empty: if prefix == (): # when the whole input is empty return {} return {_key(prefix): empty_node} return result return _flatten(xs, ()) def flatten_to_sequence( xs: Mapping[Any, Any], /, *, is_leaf: IsLeafCallable | None = None, ) -> list[tuple[Any, Any]]: """Flatten a nested mapping. The nested keys are flattened to a tuple. See ``unflatten_mapping`` on how to restore the nested mapping. Example:: >>> from flax import nnx >>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}} >>> flat_xs = nnx.traversals.flatten_to_sequence(xs) >>> flat_xs [(('foo',), 1), (('bar', 'a'), 2)] Note that empty mappings are ignored and will not be restored by ``unflatten_mapping``. Args: xs: a nested mapping is_leaf: an optional function that takes the next nested mapping and nested keys and returns True if the nested mapping is a leaf (i.e., should not be flattened further). Returns: The flattened mapping. """ assert isinstance( xs, Mapping ), f'expected Mapping; got {type(xs).__qualname__}' result = [] def _flatten(xs: Any, prefix: tuple[Any, ...]): if not isinstance(xs, Mapping) or (is_leaf and is_leaf(prefix, xs)): result.append((prefix, xs)) else: for key, value in xs.items(): _flatten(value, (*prefix, key)) _flatten(xs, ()) return result @overload def unflatten_mapping( xs: Sequence[tuple[tuple[Any, ...], Any]], /, *, sep: None = None ) -> dict[Any, Any]: ... @overload def unflatten_mapping( xs: Mapping[tuple[Any, ...], Any], /, *, sep: None = None ) -> dict[Any, Any]: ... @overload def unflatten_mapping(xs: Mapping[str, Any], /, *, sep: str) -> dict[Any, Any]: ... def unflatten_mapping(xs: Any, /, *, sep: str | None = None) -> dict[Any, Any]: """Unflatten a mapping. See ``flatten_mapping`` Example:: >>> from flax import nnx >>> flat_xs = { ... ('foo',): 1, ... ('bar', 'a'): 2, ... } >>> xs = nnx.traversals.unflatten_mapping(flat_xs) >>> xs {'foo': 1, 'bar': {'a': 2}} Args: xs: a flattened mapping. sep: separator (same as used with ``flatten_mapping()``). Returns: The nested mapping. """ if isinstance(xs, Mapping): xs = xs.items() if not isinstance(xs, Iterable): raise TypeError( f'expected Mapping or Iterable; got {type(xs).__qualname__}' ) result: dict[Any, Any] = {} for path, value in xs: if sep is not None: path = path.split(sep) # type: ignore if value is empty_node: value = {} cursor = result for key in path[:-1]: if key not in cursor: cursor[key] = {} cursor = cursor[key] cursor[path[-1]] = value return result ================================================ FILE: flax/nnx/variablelib.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # pytype: skip-file from __future__ import annotations import dataclasses import functools from functools import partial import itertools as it import threading import typing as tp from typing import Any import warnings from flax import config from flax import errors from flax.core import spmd as core_spmd from flax.nnx import reprlib, tracers, visualization from flax.typing import BaseConfigContext, MISSING, Missing, SizeBytes import jax from jax._src.state.types import AbstractRef import jax.experimental from jax.experimental import hijax as hjx import jax.tree_util as jtu import treescope # type: ignore[import-untyped] A = tp.TypeVar('A') B = tp.TypeVar('B') C = tp.TypeVar('C') F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) P = tp.TypeVar('P', bound=property) V = tp.TypeVar('V', bound='Variable[Any]') GetValueHook = tp.Callable[['Variable[A]', A], A] SetValueHook = tp.Callable[['Variable[A]', A], A] CreateValueHook = tp.Callable[['Variable[A]', A], A] AxisName = str AxisIndex = int AddAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None] RemoveAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None] # JAX array refs were renamed a few times between JAX v0.7.0 and v0.8.0. # The following ensures we avoid an ImportError or DeprecationWarning. if hasattr(jax, 'new_ref') and hasattr(jax, 'Ref'): # JAX v0.7.2 or newer from jax import Ref elif hasattr(jax, 'array_ref') and hasattr(jax, 'ArrayRef'): # JAX v0.7.1 from jax import ArrayRef as Ref # type: ignore[import-untyped,no-redef] else: # JAX v0.7.0 or older from jax.experimental import MutableArray as Ref # type: ignore[no-redef] @dataclasses.dataclass class VariableContext(threading.local): variable_hijax_stack: list[bool] = dataclasses.field(default_factory=list) variable_ref_stack: list[bool] = dataclasses.field(default_factory=list) eager_shard_stack: list[bool] = dataclasses.field(default_factory=list) VARIABLE_CONTEXT = VariableContext() class use_eager_sharding(BaseConfigContext): """Sets whether Variables should use eager sharding by default or not. Example usage:: >>> from flax import nnx >>> # Use eager sharding by default >>> nnx.use_eager_sharding(True) <...> >>> # Variable will now use eager sharding >>> nnx.using_eager_sharding() True It can also be used as a context manager to temporarily change the default behavior for a block of code:: >>> nnx.use_eager_sharding(False) <...> >>> with nnx.use_eager_sharding(True): ... nnx.using_eager_sharding() True >>> # it will reset outside >>> v = nnx.Variable(jax.numpy.ones((2, 3))) >>> nnx.using_eager_sharding() False Args: value: A boolean indicating if Variables should use eager sharding by default. Returns: A context manager that resets the context to the previous value. """ get_default = classmethod(lambda cls: config.flax_always_shard_variable) get_stack = classmethod(lambda cls: VARIABLE_CONTEXT.eager_shard_stack) def using_eager_sharding() -> bool: """Returns whether Variables are using eager sharding by default. Example:: >>> from flax import nnx >>> nnx.use_eager_sharding(True) <...> >>> nnx.using_eager_sharding() True >>> nnx.use_eager_sharding(False) <...> >>> nnx.using_eager_sharding() False Returns: A boolean indicating if Variables are using eager sharding by default. """ return use_eager_sharding.current_value() @dataclasses.dataclass(frozen=True) class VarDefaults(tp.Mapping[str, tp.Any]): hijax: bool ref: bool def __getitem__(self, key: str) -> tp.Any: return getattr(self, key) def __iter__(self) -> tp.Iterator[str]: return iter(dataclasses.asdict(self)) def __len__(self) -> int: return len(dataclasses.fields(self)) @tp.overload def var_defaults() -> VarDefaults: ... @tp.overload def var_defaults( *, hijax: bool | None = None, ref: bool | None = None ) -> VarDefaultsContext: ... def var_defaults( *, hijax: bool | None = None, ref: bool | None = None ) -> VarDefaultsContext | VarDefaults: if hijax is None and ref is None: return VarDefaults( hijax=VARIABLE_CONTEXT.variable_hijax_stack[-1] if VARIABLE_CONTEXT.variable_hijax_stack else config.flax_hijax_variable, ref=VARIABLE_CONTEXT.variable_ref_stack[-1] if VARIABLE_CONTEXT.variable_ref_stack else False, ) hijax_prev = None if hijax is not None: if VARIABLE_CONTEXT.variable_hijax_stack: hijax_prev = VARIABLE_CONTEXT.variable_hijax_stack[-1] VARIABLE_CONTEXT.variable_hijax_stack[-1] = hijax else: VARIABLE_CONTEXT.variable_hijax_stack.append(hijax) ref_prev = None if ref is not None: if VARIABLE_CONTEXT.variable_ref_stack: ref_prev = VARIABLE_CONTEXT.variable_ref_stack[-1] VARIABLE_CONTEXT.variable_ref_stack[-1] = ref else: VARIABLE_CONTEXT.variable_ref_stack.append(ref) return VarDefaultsContext( hijax_prev=hijax_prev, hijax_new=hijax, ref_prev=ref_prev, ref_new=ref, ) class VarDefaultsContext: def __init__( self, *, hijax_prev: bool | None, hijax_new: bool | None, ref_prev: bool | None, ref_new: bool | None, ): self.hijax_prev = hijax_prev self.hijax_new = hijax_new self.ref_prev = ref_prev self.ref_new = ref_new def __enter__(self): if self.hijax_new is not None and self.hijax_prev is not None: VARIABLE_CONTEXT.variable_hijax_stack.insert(-1, self.hijax_prev) if self.ref_new is not None and self.ref_prev is not None: VARIABLE_CONTEXT.variable_ref_stack.insert(-1, self.ref_prev) def __exit__(self, exc_type, exc_value, traceback): if self.hijax_new is not None: VARIABLE_CONTEXT.variable_hijax_stack.pop() if self.ref_new is not None: VARIABLE_CONTEXT.variable_ref_stack.pop() def __call__(self, f: F) -> F: # undo stack change for decorator usage if self.hijax_new is not None: VARIABLE_CONTEXT.variable_hijax_stack.pop() if self.hijax_prev is not None: VARIABLE_CONTEXT.variable_hijax_stack.append(self.hijax_prev) if self.ref_new is not None: VARIABLE_CONTEXT.variable_ref_stack.pop() if self.ref_prev is not None: VARIABLE_CONTEXT.variable_ref_stack.append(self.ref_prev) @functools.wraps(f) def var_defaults_wrapper(*args, **kwargs): if self.hijax_new is not None: VARIABLE_CONTEXT.variable_hijax_stack.append(self.hijax_new) if self.ref_new is not None: VARIABLE_CONTEXT.variable_ref_stack.append(self.ref_new) try: return f(*args, **kwargs) finally: if self.hijax_new is not None: VARIABLE_CONTEXT.variable_hijax_stack.pop() if self.ref_new is not None: VARIABLE_CONTEXT.variable_ref_stack.pop() return var_defaults_wrapper # type: ignore[return-value] def is_array_ref(x) -> tp.TypeGuard[Ref]: return isinstance(x, jax.Array | AbstractRef | Ref) and isinstance( jax.typeof(x), AbstractRef | Ref ) @dataclasses.dataclass class VariableMetadata(tp.Generic[A]): raw_value: A set_value_hooks: tuple[SetValueHook[A], ...] = () get_value_hooks: tuple[GetValueHook[A], ...] = () create_value_hooks: tuple[CreateValueHook[A], ...] = () add_axis_hooks: tuple[AddAxisHook[Variable[A]], ...] = () remove_axis_hooks: tuple[RemoveAxisHook[Variable[A]], ...] = () metadata: tp.Mapping[str, tp.Any] = dataclasses.field(default_factory=dict) PyTreeDef = tp.Any Leaf = tp.Any # --------------------------------- # hijax # --------------------------------- @dataclasses.dataclass(frozen=True) class VariableQDD: leaf_avals: tuple[hjx.AbstractValue, ...] treedef: PyTreeDef var_type: type[Variable[Any]] def to_tangent_qdd(self): leaf_avals = tuple(a.to_tangent_aval() for a in self.leaf_avals) return VariableQDD(leaf_avals, self.treedef, self.var_type) def normalize(self): leaf_types = tuple(a.normalize() for a in self.leaf_avals) return VariableQDD(leaf_types, self.treedef, self.var_type) class VariableEffect(jax.core.Effect): ... variable_effect = VariableEffect() hjx.control_flow_allowed_effects.add_type(VariableEffect) def _bind_new_variable( *leaves, treedef, var_type, has_qdd, ref ) -> HijaxVariable: """Binds new_variable_p after instantiating any Zero tangents.""" leaves = tuple(hjx.instantiate_zeros(leaf) for leaf in leaves) return new_variable_p.bind( *leaves, treedef=treedef, var_type=var_type, has_qdd=has_qdd, ref=ref, ) def _new_hijax_from_variable(variable: Variable) -> HijaxVariable: has_qdd = not variable.ref leaves, treedef = jax.tree.flatten(variable) var_type = type(variable) hijax_var = _bind_new_variable( *leaves, treedef=treedef, var_type=var_type, has_qdd=has_qdd, ref=variable.ref, ) return hijax_var class NewVariable(hjx.HiPrimitive): def is_high(self, *leaves, treedef, var_type, has_qdd, ref) -> bool: return True # type: ignore def impl(self, *leaves, treedef, var_type, has_qdd, ref): return HijaxVariable._new( leaves, treedef, var_type, has_qdd, ref=ref ) def abstract_eval(self, *leaves, treedef, var_type, has_qdd, ref): aval = AbstractVariable( var_type, treedef, leaves, has_qdd, ref=ref ) if has_qdd: qdd = VariableQDD(tuple(leaves), treedef, var_type) aval_qdd = hjx.AvalQDD(aval, qdd) # type: ignore return aval_qdd, {variable_effect} else: return aval, set() def to_lojax(self, *leaves, treedef, var_type, has_qdd, ref): return HijaxVariable._new(leaves, treedef, var_type, has_qdd, ref=ref) def jvp(_, primals, tangents, *, treedef, var_type, has_qdd, ref): if has_qdd: raise NotImplementedError( "jvp not implemented for 'new_variable' with QDD" ) primal_hijax_var = _bind_new_variable( *primals, treedef=treedef, var_type=var_type, has_qdd=has_qdd, ref=ref ) tangent_hijax_var = _bind_new_variable( *tangents, treedef=treedef, var_type=var_type, has_qdd=has_qdd, ref=ref ) return primal_hijax_var, tangent_hijax_var def transpose( _, out_var: HijaxVariable, *input_leaves, treedef, var_type, has_qdd, ref ): if has_qdd: raise NotImplementedError( "transpose not implemented for 'new_variable' with QDD" ) avals = tuple( map( lambda x: x.aval if hjx.is_undefined_primal(x) else jax.typeof(x), input_leaves, ) ) leaves_dot = get_variable_p.bind( out_var, treedef=treedef, avals=avals, var_type=var_type, has_qdd=has_qdd, ) return leaves_dot new_variable_p = NewVariable(f'new_variable') def _set_hijax_state(hijax_var, variable: Variable): leaves, treedef = jax.tree.flatten(variable) set_variable_p.bind( hijax_var, *leaves, treedef=treedef, var_type=type(variable) ) class SetVariable(hjx.HiPrimitive): multiple_results = True def is_high(_, *leaf_avals, treedef, var_type) -> bool: return True # type: ignore # TODO: upstream this to Box def impl(_, hijax_var: HijaxVariable, *leaves, treedef, var_type): if not hijax_var.has_qdd: raise errors.ImmutableVariableError( "Trying to update Variable with 'has_qdd=False'." ) assert var_type is hijax_var._var_type object.__setattr__(hijax_var, '_leaves', leaves) object.__setattr__(hijax_var, '_treedef', treedef) return [] def abstract_eval( _, aval_mutable_qdd: hjx.AvalMutableQDD, *leaf_avals, treedef, var_type ): hijax_var: AbstractVariable = aval_mutable_qdd.aval # type: ignore assert isinstance(hijax_var, AbstractVariable) if not hijax_var.has_qdd: raise errors.ImmutableVariableError( "Trying to update Variable with 'has_qdd=False'." ) assert var_type is hijax_var._var_type aval_mutable_qdd.mutable_qdd.update( VariableQDD(leaf_avals, treedef, var_type) ) effects = {variable_effect} if hijax_var.has_qdd else set() return [], effects # TODO better typechecking... def to_lojax(_, hijax_var: HijaxVariable, *leaves, treedef, var_type): if not hijax_var.has_qdd: raise errors.ImmutableVariableError( "Trying to update Variable with 'has_qdd=False'." ) assert var_type is hijax_var._var_type object.__setattr__(hijax_var, '_leaves', leaves) object.__setattr__(hijax_var, '_treedef', treedef) return [] def jvp(_, primals, tangents, *, treedef, var_type): variable: Variable variable, *vals = primals variable_dot: Variable variable_dot, *val_dots = tangents if type(variable_dot._raw_value) is hjx.Zero: raise Exception( "can't differentiate Variable._set operation, " 'did you forget jax.lax.stop_gradient?' ) set_variable_p.bind( variable, *vals, treedef=treedef, var_type=type(variable) ) set_variable_p.bind( variable_dot, *val_dots, treedef=treedef, var_type=type(variable_dot) ) return [], [] def transpose(_, *args, treedef, var_type): raise NotImplementedError('transpose not implemented for SetHijaxVariable') set_variable_p = SetVariable(f'set_variable') def _get_hijax_state(hijax_var: HijaxVariable | AbstractVariable) -> Variable: if hijax_var.has_qdd: tys: VariableQDD = jax.experimental.cur_qdd(hijax_var) leaf_vals = get_variable_p.bind( hijax_var, treedef=tys.treedef, avals=tuple(tys.leaf_avals), var_type=hijax_var._var_type, has_qdd=hijax_var.has_qdd, ) variable = jax.tree.unflatten(tys.treedef, leaf_vals) else: assert hijax_var._treedef is not None assert hijax_var._leaves is not None if isinstance(hijax_var, (jax.core.Tracer, AbstractVariable)): leaf_avals = hijax_var._leaves else: leaf_avals = tuple(map(jax.typeof, hijax_var._leaves)) leaf_vals = get_variable_p.bind( hijax_var, treedef=hijax_var._treedef, avals=leaf_avals, var_type=hijax_var._var_type, has_qdd=hijax_var.has_qdd, ) variable = jax.tree.unflatten(hijax_var._treedef, leaf_vals) return variable class GetVariable(hjx.HiPrimitive): multiple_results = True def impl( self, hijax_var: HijaxVariable, *, treedef, avals, var_type, has_qdd ): return hijax_var._leaves def abstract_eval(self, abstract_var, *, treedef, avals, var_type, has_qdd): if has_qdd: return avals, {variable_effect} else: return avals, set() def to_lojax( _, hijax_var: HijaxVariable, *, treedef, avals, var_type, has_qdd ): return hijax_var._leaves def jvp(_, primals, tangents, *, treedef, avals, var_type, has_qdd): if has_qdd: raise NotImplementedError( "jvp not implemented for 'get_variable' with QDD" ) (hijax_var,), (hijax_var_dot,) = primals, tangents return ( get_variable_p.bind( hijax_var, treedef=treedef, avals=avals, var_type=var_type, has_qdd=has_qdd, ), get_variable_p.bind( hijax_var_dot, treedef=treedef, avals=tuple(a.to_tangent_aval() for a in avals), var_type=var_type, has_qdd=has_qdd, ), ) def transpose(_, out, hijax_var, *, treedef, avals, var_type, has_qdd): if has_qdd: raise NotImplementedError( "transpose not implemented for 'get_variable' with QDD" ) abstract_var: AbstractVariable = ( hijax_var.aval if hjx.is_undefined_primal(hijax_var) else jax.typeof(hijax_var) ) hijax_var_dot = _bind_new_variable( *out, treedef=abstract_var._treedef, var_type=var_type, has_qdd=has_qdd, ref=abstract_var.ref, ) return (hijax_var_dot,) get_variable_p = GetVariable(f'get_variable') # --------------------------------- # HijaxVariable # --------------------------------- def _variable_has_changed(old: Variable, new: Variable) -> bool: old_structure = jax.tree.structure(old) new_structure = jax.tree.structure(new) if old_structure != new_structure: # type: ignore[operator] return True old_leaves = jax.tree.leaves(old) new_leaves = jax.tree.leaves(new) return any(o is not n for o, n in zip(old_leaves, new_leaves)) def _as_hijax_property(name: str, *, get: bool, set: bool) -> property: """Creates a property that operates on the hijax type.""" def _getter_wrapper(hijax_var): variable = _get_hijax_state(hijax_var) old_state = jax.tree.map(lambda x: x, variable) out = getattr(variable, name) if _variable_has_changed(old_state, variable): _set_hijax_state(hijax_var, variable) return out def _setter_wrapper(hijax_var, value): variable = _get_hijax_state(hijax_var) setattr(variable, name, value) _set_hijax_state(hijax_var, variable) _hijax_property = property( fget=_getter_wrapper if get else None, fset=_setter_wrapper if set else None, ) return _hijax_property # type: ignore[return] def _as_aval_property(p: property) -> hjx.aval_property: """Wraps a property `p` operate on the aval type.""" _aval_property = hjx.aval_property(fget=p.fget) return _aval_property # type: ignore[return] def _as_hijax_attribute(name: str) -> property: """Creates a property that operates on the hijax type.""" def _getter_wrapper(hijax_var): variable = _get_hijax_state(hijax_var) old_state = jax.tree.map(lambda x: x, variable) out = getattr(variable, name) if _variable_has_changed(old_state, variable): _set_hijax_state(hijax_var, variable) return out _getter_wrapper.__name__ = name _hijax_property = property(fget=_getter_wrapper) return _hijax_property # type: ignore[return] def _as_hijax_method(name: str) -> tp.Any: """Creates a method that operates on the hijax type.""" def hijax_method_wrapper(hijax_var, *args, **kwargs): variable = _get_hijax_state(hijax_var) old_state = jax.tree.map(lambda x: x, variable) method = getattr(variable, name) out = method(*args, **kwargs) if _variable_has_changed(old_state, variable): _set_hijax_state(hijax_var, variable) return out hijax_method_wrapper.__name__ = name return hijax_method_wrapper def _as_tracer_method(name: str): def op(self, hijax_var, *args, **kwargs): variable = _get_hijax_state(hijax_var) old_state = jax.tree.map(lambda x: x, variable) out = getattr(variable, name)(*args, **kwargs) if _variable_has_changed(old_state, variable): _set_hijax_state(hijax_var, variable) return out op.__name__ = name return op def _not_an_attribute_property(name: str): def _op(self): raise AttributeError( f"'{type(self).__name__}' object has no attribute '{name}'" ) return property(_op) class HijaxVariableMeta(type): def __instancecheck__(self, instance): if super().__instancecheck__(instance): return True if isinstance(instance, jax.core.Tracer): ty = jax.typeof(instance) return isinstance(ty, AbstractVariable) return False class HijaxVariable( tp.Generic[A], reprlib.Representable, metaclass=HijaxVariableMeta ): # type: ignore __slots__ = ('_treedef', '_leaves', '_var_type', 'has_qdd', '_ref') _treedef: PyTreeDef _leaves: tuple[Leaf, ...] _var_type: type[Variable[tp.Any]] has_qdd: bool _ref: bool @classmethod def _new( cls, leaves: tuple[Leaf, ...], treedef: PyTreeDef, var_type: type[Variable[A]], has_qdd: bool, *, ref: bool = False, ): hijax_var = object.__new__(cls) object.__setattr__(hijax_var, '_treedef', treedef) object.__setattr__(hijax_var, '_leaves', leaves) object.__setattr__(hijax_var, '_var_type', var_type) object.__setattr__(hijax_var, 'has_qdd', has_qdd) object.__setattr__(hijax_var, '_ref', ref) return hijax_var __init__ = _as_hijax_method('__init__') @property def value(self) -> A: raise NotImplementedError( 'HijaxVariable.value property is not implemented. For Variable[Array] instances use:\n\n' ' variable[...]\n\n' 'For other Variable types use:\n\n' ' variable.get_value()\n' ) @value.setter def value(self, new_value: A): raise NotImplementedError( 'HijaxVariable.value property is not implemented. For Variable[Array] instances use:\n\n' ' variable[...] = new_value\n\n' 'For other Variable types use:\n\n' ' variable.set_value(new_value)\n' ) @property def var_type(self) -> type[Variable[A]]: return self._var_type _trace_state = _as_hijax_property('_trace_state', get=True, set=False) _can_update = _as_hijax_property('_can_update', get=True, set=False) _check_can_update = _as_hijax_method('_check_can_update') __getattr__ = _as_hijax_method('__getattr__') __setattr__ = _as_hijax_method('__setattr__') __delattr__ = _as_hijax_method('__delattr__') type = _as_hijax_property('type', get=True, set=False) type = _as_hijax_property('type', get=True, set=False) hijax = _as_hijax_property('hijax', get=True, set=False) @property def ref(self) -> bool: return self._ref get_metadata = _as_hijax_method('get_metadata') set_metadata = _as_hijax_method('set_metadata') def copy_from(self, other: Variable[A] | HijaxVariable[A]) -> None: if isinstance(other, HijaxVariable): other = _get_hijax_state(other) variable = _get_hijax_state(self) variable.copy_from(other) # type: ignore[arg-type] _set_hijax_state(self, variable) def update_from_state(self, variable_state: Variable[A] | HijaxVariable[A]): if isinstance(variable_state, HijaxVariable): variable_state = _get_hijax_state(variable_state) variable = _get_hijax_state(self) variable.update_from_state(variable_state) # type: ignore[arg-type] _set_hijax_state(self, variable) get_raw_value = _as_hijax_method('get_raw_value') set_raw_value = _as_hijax_method('set_raw_value') set_value = _as_hijax_method('set_value') get_value = _as_hijax_method('get_value') create_value = _as_hijax_method('create_value') set_raw_value = _as_hijax_method('set_raw_value') add_axis = _as_hijax_method('add_axis') remove_axis = _as_hijax_method('remove_axis') copy = _as_hijax_method('copy') replace = _as_hijax_method('replace') to_state = _as_hijax_method('to_state') @classmethod def from_metadata(cls, value: A, metadata: dict[str, tp.Any]): return cls._var_type.from_metadata(value, metadata) # type: ignore[misc] __nnx_repr__ = _as_hijax_method('__nnx_repr__') __treescope_repr__ = _as_hijax_method('__treescope_repr__') # -------------------------------------------- # proxy methods # -------------------------------------------- __jax_array__ = _as_hijax_method('__jax_array__') __getitem__ = _as_hijax_method('__getitem__') __setitem__ = _as_hijax_method('__setitem__') __delitem__ = _as_hijax_method('__delitem__') __call__ = _as_hijax_method('__call__') __len__ = _as_hijax_method('__len__') __iter__ = _as_hijax_method('__iter__') __contains__ = _as_hijax_method('__contains__') __add__ = _as_hijax_method('__add__') __sub__ = _as_hijax_method('__sub__') __mul__ = _as_hijax_method('__mul__') __matmul__ = _as_hijax_method('__matmul__') __truediv__ = _as_hijax_method('__truediv__') __floordiv__ = _as_hijax_method('__floordiv__') __mod__ = _as_hijax_method('__mod__') __divmod__ = _as_hijax_method('__divmod__') __pow__ = _as_hijax_method('__pow__') __lshift__ = _as_hijax_method('__lshift__') __rshift__ = _as_hijax_method('__rshift__') __and__ = _as_hijax_method('__and__') __xor__ = _as_hijax_method('__xor__') __or__ = _as_hijax_method('__or__') __radd__ = _as_hijax_method('__radd__') __rsub__ = _as_hijax_method('__rsub__') __rmul__ = _as_hijax_method('__rmul__') __rmatmul__ = _as_hijax_method('__rmatmul__') __rtruediv__ = _as_hijax_method('__rtruediv__') __rfloordiv__ = _as_hijax_method('__rfloordiv__') __rmod__ = _as_hijax_method('__rmod__') __rdivmod__ = _as_hijax_method('__rdivmod__') __rpow__ = _as_hijax_method('__rpow__') __rlshift__ = _as_hijax_method('__rlshift__') __rrshift__ = _as_hijax_method('__rrshift__') __rand__ = _as_hijax_method('__rand__') __rxor__ = _as_hijax_method('__rxor__') __ror__ = _as_hijax_method('__ror__') __iadd__ = _as_hijax_method('__iadd__') __isub__ = _as_hijax_method('__isub__') __imul__ = _as_hijax_method('__imul__') __imatmul__ = _as_hijax_method('__imatmul__') __itruediv__ = _as_hijax_method('__itruediv__') __ifloordiv__ = _as_hijax_method('__ifloordiv__') __imod__ = _as_hijax_method('__imod__') __ipow__ = _as_hijax_method('__ipow__') __ilshift__ = _as_hijax_method('__ilshift__') __irshift__ = _as_hijax_method('__irshift__') __iand__ = _as_hijax_method('__iand__') __ixor__ = _as_hijax_method('__ixor__') __ior__ = _as_hijax_method('__ior__') __neg__ = _as_hijax_method('__neg__') __pos__ = _as_hijax_method('__pos__') __abs__ = _as_hijax_method('__abs__') __invert__ = _as_hijax_method('__invert__') __complex__ = _as_hijax_method('__complex__') __int__ = _as_hijax_method('__int__') __float__ = _as_hijax_method('__float__') __index__ = _as_hijax_method('__index__') __round__ = _as_hijax_method('__round__') __trunc__ = _as_hijax_method('__trunc__') __floor__ = _as_hijax_method('__floor__') __ceil__ = _as_hijax_method('__ceil__') # -------------------------------------------- # hijax interface # -------------------------------------------- def cur_qdd(self): return self.type_state() def type_state(self): leaf_avals = tuple(map(jax.typeof, self._leaves)) return VariableQDD(leaf_avals, self._treedef, self._var_type) def _to_abstract_variable(hijax_var: HijaxVariable): if hijax_var.has_qdd: treedef = None leaves = None else: leaves = tuple(map(jax.typeof, hijax_var._leaves)) treedef = hijax_var._treedef return AbstractVariable( hijax_var._var_type, treedef, leaves, hijax_var.has_qdd, ref=hijax_var.ref, ) hjx.register_hitype(HijaxVariable, _to_abstract_variable) # --------------------------------- # AbstractVariable # --------------------------------- class AbstractVariable(tp.Generic[A], hjx.MutableHiType): __slots__ = ['_var_type', '_treedef', '_leaves', 'has_qdd', '_ref'] _var_type: type[Variable[A]] _treedef: PyTreeDef | None _leaves: tuple[hjx.AbstractValue, ...] | None has_qdd: bool _ref: bool @property def ref(self) -> bool: return self._ref @property def hijax(self): return True _check_can_update = hjx.aval_method(HijaxVariable._check_can_update) def __init__( self, var_type: type[Variable[A]], treedef: PyTreeDef | None, leaves: tuple[hjx.AbstractValue, ...] | None, has_qdd: bool, *, ref: bool = False, ): if (treedef is None) ^ (leaves is None): raise ValueError('treedef and leaves must be both provided or both None') object.__setattr__(self, '_treedef', treedef) object.__setattr__(self, '_leaves', leaves) object.__setattr__(self, '_var_type', var_type) object.__setattr__(self, 'has_qdd', has_qdd) object.__setattr__(self, '_ref', ref) @property def dtype(self): raise AttributeError @property def ndim(self): raise AttributeError @property def size(self): raise AttributeError @property def shape(self): raise AttributeError def __getattr__(self, name: str): # Forward unknown attributes to the value if hasattr(AbstractVariable, name): raise AttributeError( f"'{type(self).__name__}' object has no attribute '{name}'" ) if name.startswith('_'): raise AttributeError( f"'{type(self).__name__}' object has no attribute '{name}'" ) return _as_aval_property(_as_hijax_attribute(name)) # __setattr__ supported via __getattr__ # __delattr__ CURRENTLY NOT SUPPORTED type = _as_aval_property(HijaxVariable.type) get_metadata = hjx.aval_method(HijaxVariable.get_metadata) set_metadata = hjx.aval_method(HijaxVariable.set_metadata) copy_from = hjx.aval_method(HijaxVariable.copy_from) update_from_state = hjx.aval_method(HijaxVariable.update_from_state) get_raw_value = hjx.aval_method(HijaxVariable.get_raw_value) set_raw_value = hjx.aval_method(HijaxVariable.set_raw_value) set_value = hjx.aval_method(HijaxVariable.set_value) get_value = hjx.aval_method(HijaxVariable.get_value) create_value = hjx.aval_method(HijaxVariable.create_value) set_raw_value = hjx.aval_method(HijaxVariable.set_raw_value) add_axis = hjx.aval_method(HijaxVariable.add_axis) remove_axis = hjx.aval_method(HijaxVariable.remove_axis) replace = hjx.aval_method(HijaxVariable.replace) @hjx.aval_method def from_metadata(self, value, metadata: dict[str, tp.Any]): aval: AbstractVariable = self.aval # type: ignore variable = aval._var_type.from_metadata(value, metadata) return variable copy = hjx.aval_method(HijaxVariable.copy) replace = hjx.aval_method(HijaxVariable.replace) to_state = hjx.aval_method(HijaxVariable.to_state) def __str__(self): return f'{self._var_type.__name__}()' def __repr__(self): return f'{self._var_type.__name__}()' @hjx.aval_method def __treescope_repr__(self, path, subtree_renderer): raise NotImplementedError # --------------------------------- # proxy methods # --------------------------------- __jax_array__ = hjx.aval_method(HijaxVariable.__jax_array__) _getitem = _as_tracer_method('__getitem__') _setitem = _as_tracer_method('__setitem__') # __delitem__ CURRENTLY NOT SUPPORTED # __call__ CURRENTLY NOT SUPPORTED _len = _as_tracer_method('__len__') _iter = _as_tracer_method('__iter__') # __contains__ CURRENTLY NOT SUPPORTED _add = _as_tracer_method('__add__') _sub = _as_tracer_method('__sub__') _mul = _as_tracer_method('__mul__') _matmul = _as_tracer_method('__matmul__') _truediv = _as_tracer_method('__truediv__') _floordiv = _as_tracer_method('__floordiv__') _mod = _as_tracer_method('__mod__') _divmod = _as_tracer_method('__divmod__') _pow = _as_tracer_method('__pow__') _lshift = _as_tracer_method('__lshift__') _rshift = _as_tracer_method('__rshift__') _and = _as_tracer_method('__and__') _xor = _as_tracer_method('__xor__') _or = _as_tracer_method('__or__') _radd = _as_tracer_method('__radd__') _rsub = _as_tracer_method('__rsub__') _rmul = _as_tracer_method('__rmul__') _rmatmul = _as_tracer_method('__rmatmul__') _rtruediv = _as_tracer_method('__rtruediv__') _rfloordiv = _as_tracer_method('__rfloordiv__') _rmod = _as_tracer_method('__rmod__') _rdivmod = _as_tracer_method('__rdivmod__') _rpow = _as_tracer_method('__rpow__') _rlshift = _as_tracer_method('__rlshift__') _rrshift = _as_tracer_method('__rrshift__') _rand = _as_tracer_method('__rand__') _rxor = _as_tracer_method('__rxor__') _ror = _as_tracer_method('__ror__') # _iadd CURRENTLY NOT SUPPORTED # _isub CURRENTLY NOT SUPPORTED # _imul CURRENTLY NOT SUPPORTED # _imatmul CURRENTLY NOT SUPPORTED # _itruediv CURRENTLY NOT SUPPORTED # _ifloordiv CURRENTLY NOT SUPPORTED # _imod CURRENTLY NOT SUPPORTED # _ipow CURRENTLY NOT SUPPORTED # _ilshift CURRENTLY NOT SUPPORTED # _irshift CURRENTLY NOT SUPPORTED # _iand CURRENTLY NOT SUPPORTED # _ixor CURRENTLY NOT SUPPORTED # _ior CURRENTLY NOT SUPPORTED _neg = _as_tracer_method('__neg__') _pos = _as_tracer_method('__pos__') _abs = _as_tracer_method('__abs__') _invert = _as_tracer_method('__invert__') _complex = _as_tracer_method('__complex__') _int = _as_tracer_method('__int__') _float = _as_tracer_method('__float__') _index = _as_tracer_method('__index__') _round = _as_tracer_method('__round__') _trunc = _as_tracer_method('__trunc__') _floor = _as_tracer_method('__floor__') _ceil = _as_tracer_method('__ceil__') # -------------------------------- # hijax interface # -------------------------------- cur_qdd = _not_an_attribute_property('cur_qdd') def __hash__(self): if self._leaves is not None and self._treedef is not None: return hash( (AbstractVariable, self._var_type, self._treedef, self._leaves) ) else: assert self._leaves is None and self._treedef is None return hash((AbstractVariable, self._var_type)) def __eq__(self, other): return ( isinstance(other, AbstractVariable) and self._var_type == other._var_type ) def str_short(self, short_dtypes=False, **_) -> str: # type: ignore return f'{self._var_type.__name__}()' # mutable interface def lo_ty_qdd(self, variable_state: VariableQDD) -> list: # type: ignore return [lo_ty for t in variable_state.leaf_avals for lo_ty in t.lo_ty()] def new_from_loval( # type: ignore[override] self, variable_state: VariableQDD, *lo_vals ) -> HijaxVariable: lo_vals_ = iter(lo_vals) hi_vals = [ hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty()))) # type: ignore for hi_ty in variable_state.leaf_avals ] assert next(lo_vals_, None) is None variable: Variable = jax.tree.unflatten(variable_state.treedef, hi_vals) return HijaxVariable._new( hi_vals, variable_state.treedef, self._var_type, has_qdd=self.has_qdd, ref=self.ref, ) # will be mutated def read_loval(self, variable_state: VariableQDD, variable) -> list: # type: ignore leaf_vals, treedef = jax.tree.flatten(_get_hijax_state(variable)) assert treedef == variable_state.treedef return [ lo_val for hi_ty, hi_val in zip(variable_state.leaf_avals, leaf_vals) for lo_val in hi_ty.lower_val(hi_val) ] # type: ignore def update_from_loval( # type: ignore[override] self, box_state: VariableQDD, variable, *lo_vals ) -> None: lo_vals_ = iter(lo_vals) hi_vals = [ hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty()))) # type: ignore for hi_ty in box_state.leaf_avals ] assert next(lo_vals_, None) is None _set_hijax_state(variable, jax.tree.unflatten(box_state.treedef, hi_vals)) def to_tangent_aval(self): return AbstractVariable( self._var_type, self._treedef, self._leaves, self.has_qdd, ref=self.ref, ) # -------------------------------------------- # Variable # -------------------------------------------- def _remap_sharding_metadata(metadata: dict[str, tp.Any]) -> None: if 'sharding' in metadata: warnings.warn( "'sharding' is deprecated, use 'out_sharding' instead.", DeprecationWarning, stacklevel=3, ) metadata['out_sharding'] = metadata.pop('sharding') if 'sharding_names' in metadata: warnings.warn( "'sharding_names' is deprecated, use 'out_sharding' instead.", DeprecationWarning, stacklevel=3, ) metadata['out_sharding'] = metadata.pop('sharding_names') def _variable_operator(name: str) -> tp.Callable[[Variable[A], tp.Any], A]: def variable_operator_method(self, other): value = self.get_value() if isinstance(other, Variable): other = other.get_value() return getattr(value, name)(other) variable_operator_method.__name__ = name return variable_operator_method def _variable_unary_operator(name: str) -> tp.Callable[[Variable[A]], A]: def variable_unary_operator_method(self): value = self.get_value() return getattr(value, name)() variable_unary_operator_method.__name__ = name return variable_unary_operator_method class VariableMeta(type): def __new__(cls, cls_name, bases, attrs): if '__slots__' not in attrs: attrs['__slots__'] = () return super().__new__(cls, cls_name, bases, attrs) def __instancecheck__(self, instance): if super().__instancecheck__(instance): return True if isinstance(instance, jax.core.Tracer): ty = jax.typeof(instance) if isinstance(ty, AbstractVariable): return issubclass(ty._var_type, self) if isinstance(instance, HijaxVariable): return issubclass(instance._var_type, self) return False if not tp.TYPE_CHECKING: def __call__(cls, *args, **kwargs): return cls._variable_meta_call(*args, **kwargs) def _variable_meta_call(cls, *args, **kwargs): variable = super().__call__(*args, **kwargs) if variable.hijax: return _new_hijax_from_variable(variable) return variable class Variable(tp.Generic[A], reprlib.Representable, metaclass=VariableMeta): """The base class for all ``Variable`` types. Create custom ``Variable`` types by subclassing this class. Numerous NNX graph functions can filter for specific ``Variable`` types, for example, :func:`split`, :func:`state`, :func:`pop`, and :func:`State.filter`. Example usage:: >>> from flax import nnx >>> import jax, jax.numpy as jnp >>> class CustomVariable(nnx.Variable): ... pass >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... self.custom_variable = CustomVariable(jnp.ones((1, 3))) ... def __call__(self, x): ... return self.linear(x) + self.custom_variable >>> model = Model(rngs=nnx.Rngs(0)) >>> linear_variables = nnx.state(model, nnx.Param) >>> jax.tree.map(jnp.shape, linear_variables) State({ 'linear': { 'bias': Param( value=(3,) ), 'kernel': Param( value=(2, 3) ) } }) >>> custom_variable = nnx.state(model, CustomVariable) >>> jax.tree.map(jnp.shape, custom_variable) State({ 'custom_variable': CustomVariable( value=(1, 3) ) }) >>> variables = nnx.state(model) >>> jax.tree.map(jnp.shape, variables) State({ 'custom_variable': CustomVariable( value=(1, 3) ), 'linear': { 'bias': Param( value=(3,) ), 'kernel': Param( value=(2, 3) ) } }) """ __slots__ = ('_raw_value', '_trace_state', '_var_metadata') _raw_value: A _trace_state: tracers.TraceState _var_metadata: dict[str, tp.Any] required_metadata = frozenset( ['hijax', 'ref', 'eager_sharding'] ) @property def var_type(self): return type(self) @property def hijax(self) -> bool: return self._var_metadata['hijax'] @property def ref(self) -> bool: return self._var_metadata['ref'] @property def shape(self: Variable[jax.Array]) -> tuple[int, ...]: return self.get_value().shape @property def sharding_names(self): warnings.warn( "'sharding_names' is deprecated, use 'out_sharding' instead.", DeprecationWarning, stacklevel=2, ) return self.get_metadata('out_sharding', None) def __init__( self, value: A | VariableMetadata[A], *, hijax: bool | None = None, ref: bool | None = None, eager_sharding: bool | None = None, **metadata: tp.Any, ): var_t = type(self) if isinstance(value, VariableMetadata): aux_metadata = dict(value.metadata) if 'hijax' in aux_metadata: if hijax is not None and hijax != aux_metadata['hijax']: raise ValueError( 'Cannot specify hijax both in VariableMetadata and as an ' 'argument to Variable constructor.' ) hijax = aux_metadata.pop('hijax') if 'ref' in aux_metadata: if ref is not None and ref != aux_metadata['ref']: raise ValueError( 'Cannot specify ref both in VariableMetadata and as an ' 'argument to Variable constructor.' ) ref = aux_metadata.pop('ref') if 'eager_sharding' in aux_metadata: if ( eager_sharding is not None and eager_sharding != aux_metadata['eager_sharding'] ): raise ValueError( 'Cannot specify eager_sharding both in VariableMetadata and as ' 'an argument to Variable constructor.' ) eager_sharding = aux_metadata['eager_sharding'] metadata.update(aux_metadata) value = tp.cast(A, value.raw_value) if hijax is None: hijax = var_defaults().hijax if ref is None: ref = var_defaults().ref if eager_sharding is None: eager_sharding = using_eager_sharding() if any(is_array_ref(v) for v in jax.tree.leaves(value)): raise ValueError('Cannot pass a Ref directly into Variable constructor.') metadata['hijax'] = hijax metadata['ref'] = ref metadata['eager_sharding'] = eager_sharding object.__setattr__(self, '_trace_state', tracers.TraceState()) object.__setattr__(self, '_var_metadata', metadata) object.__setattr__(self, '_raw_value', value) if hasattr(var_t, 'on_get_value') and 'on_get_value' not in metadata: metadata['on_get_value'] = var_t.on_get_value if hasattr(var_t, 'on_set_value') and 'on_set_value' not in metadata: metadata['on_set_value'] = var_t.on_set_value if hasattr(var_t, 'on_create_value') and 'on_create_value' not in metadata: metadata['on_create_value'] = var_t.on_create_value if hasattr(var_t, 'on_add_axis') and 'on_add_axis' not in metadata: metadata['on_add_axis'] = var_t.on_add_axis if hasattr(var_t, 'on_remove_axis') and 'on_remove_axis' not in metadata: metadata['on_remove_axis'] = var_t.on_remove_axis _remap_sharding_metadata(metadata) # run create_value hooks if 'on_create_value' in metadata: value = metadata['on_create_value'](self, value) object.__setattr__(self, '_raw_value', value) # run create_value hook value = self.create_value(value) # type: ignore # shard the _value if applicable if eager_sharding and 'out_sharding' in metadata: value = core_spmd.shard_value( value, metadata['out_sharding'], metadata.get('sharding_rules', None), metadata.get('mesh', None), ) if ref: value = jax.new_ref(value) # type: ignore object.__setattr__(self, '_raw_value', value) @property def _can_update(self) -> bool: """Whether the Variable can be updated in-place in the current trace context.""" if self.hijax: return True else: return self._trace_state.is_valid() def _check_can_update(self): if not self.hijax and not self._trace_state.is_valid(): raise errors.TraceContextError( f'Cannot mutate {type(self).__name__} from a different trace level' ) def __getattr__(self, name: str) -> tp.Any: if name in object.__getattribute__(self, '_var_metadata'): return self._var_metadata[name] return getattr(object.__getattribute__(self, '_raw_value'), name) def __setattr__(self, name: str, value: tp.Any): self._check_can_update() try: object.__setattr__(self, name, value) except AttributeError as e: raise AttributeError( f'Cannot set attribute {name}. ' f'To set Variable metadata use either:\n\n' f' variable.set_metadata({name}=value)\n\nor\n\n' f" variable.set_metadata('{name}', value)" ) from e def __delattr__(self, name: str): self._check_can_update() try: object.__delattr__(self, name) except AttributeError as e: raise AttributeError( f'Cannot delete attribute {name}. ' f'To delete Variable metadata use:\n\n' f" variable.del_metadata('{name}')" ) from e # NOTE(cgarciae): adding this for backward compatibility with VariableState @property def type(self): """The type of the variable.""" return type(self) @tp.overload def get_metadata( self, *, exclude_required: bool = False ) -> dict[str, tp.Any]: ... @tp.overload def get_metadata(self, name: str, default: tp.Any = MISSING) -> tp.Any: ... def get_metadata( self, name: str | None = None, default: tp.Any = MISSING, *, exclude_required: bool | None = None, ) -> tp.Any: """Get metadata for the Variable. Args: name: The key of the metadata element to get. If not provided, returns the full metadata dictionary. default: The default value to return if the metadata key is not found. If not provided and the key is not found, raises a KeyError. """ if name is not None and exclude_required is not None: raise TypeError( "Cannot specify both 'name' and 'exclude_required' arguments." ) metadata = self._var_metadata.copy() if name is None: if not isinstance(default, Missing): raise TypeError( "Cannot provide a default value when 'name' is not provided. " f'Got default={default}' ) if exclude_required: for key in self.required_metadata: metadata.pop(key, None) return metadata if name not in metadata and not isinstance(default, Missing): return default return metadata[name] @tp.overload def set_metadata(self, metadata: dict[str, tp.Any], /) -> None: ... @tp.overload def set_metadata(self, name: str, value: tp.Any, /) -> None: ... @tp.overload def set_metadata(self, **metadata: tp.Any) -> None: ... def set_metadata(self, *args, **kwargs) -> None: """Set metadata for the Variable. `set_metadata` can be called in 3 ways: 1. By passing a dictionary of metadata as the first argument, this will replace the entire Variable's metadata. 2. By passing a name and value as the first two arguments, this will set the metadata entry for the given name to the given value. 3. By using keyword arguments, this will update the Variable's metadata with the provided key-value pairs. """ self._check_can_update() if args and kwargs: raise TypeError( 'Cannot mix positional and keyword arguments in set_metadata' ) if len(args) == 1: metadata = dict(args[0]) _remap_sharding_metadata(metadata) if 'hijax' not in metadata: metadata['hijax'] = self.hijax if metadata['hijax'] != self.hijax: raise ValueError( f'Cannot change `hijax` metadata, expected {self.hijax}, ' f'got {metadata["hijax"]}' ) if 'ref' not in metadata: metadata['ref'] = self.ref if metadata['ref'] != self.ref: raise ValueError( f'Cannot change `ref` metadata, expected {self.ref}, ' f'got {metadata["ref"]}' ) if 'eager_sharding' not in metadata: metadata['eager_sharding'] = self.eager_sharding if metadata['eager_sharding'] != self.eager_sharding: raise ValueError( f'Cannot change `eager_sharding` metadata, expected ' f'{self.eager_sharding}, got {metadata["eager_sharding"]}' ) self._var_metadata = metadata elif len(args) == 2: name, value = args if name == 'sharding_names': warnings.warn( "'sharding_names' is deprecated, use 'out_sharding' instead.", DeprecationWarning, stacklevel=2, ) name = 'out_sharding' elif name == 'sharding': warnings.warn( "'sharding' is deprecated, use 'out_sharding' instead.", DeprecationWarning, stacklevel=2, ) name = 'out_sharding' if name == 'hijax' and value != self.hijax: raise ValueError( f'Cannot change `hijax` metadata, expected {self.hijax}, got {value}' ) if name == 'ref' and value != self.ref: raise ValueError( f'Cannot change `ref` metadata, expected {self.ref}, got {value}' ) self._var_metadata[name] = value elif kwargs: _remap_sharding_metadata(kwargs) if 'hijax' in kwargs and kwargs['hijax'] != self.hijax: raise ValueError( f'Cannot change `hijax` metadata, expected {self.hijax}, ' f'got {kwargs["hijax"]}' ) if 'ref' in kwargs and kwargs['ref'] != self.ref: raise ValueError( f'Cannot change `ref` metadata, expected {self.ref}, ' f'got {kwargs["ref"]}' ) self._var_metadata.update(kwargs) else: raise TypeError( f'set_metadata takes either 1 or 2 arguments, or at least 1 keyword argument, ' f'got args={args}, kwargs={kwargs}' ) def has_metadata(self, name: str) -> bool: """Check if the Variable has a metadata entry for the given name. Args: name: The key of the metadata element to check. Returns: True if the metadata entry exists, False otherwise. """ return name in self._var_metadata def del_metadata(self, name: str) -> None: """Delete a metadata entry for the Variable. Args: name: The key of the metadata element to delete. """ self._check_can_update() if name in ('hijax', 'ref'): raise ValueError(f'Cannot delete `{name}` metadata') del self._var_metadata[name] def copy_from(self, other: Variable[A]) -> None: if type(self) is not type(other): raise ValueError( f'Cannot copy from incompatible container, ' f'expected {type(self).__name__}, got {type(other).__name__}' ) if self is other: return self._raw_value = other._raw_value self._var_metadata.clear() self._var_metadata.update(other.get_metadata()) def update_from_state(self, variable_state: Variable[A]): self._raw_value = variable_state._raw_value if self._var_metadata != variable_state._var_metadata: metadata = variable_state.get_metadata() metadata['hijax'] = self.hijax metadata['ref'] = self.ref self._var_metadata = metadata @tp.final def get_raw_value(self) -> A: return self._raw_value # @tp.final def set_raw_value(self, value: A, *, _unsafe_bypass_check: bool = False): if not _unsafe_bypass_check: self._check_can_update() self._raw_value = value @property def raw_value(self) -> A: warnings.warn( "'.raw_value' access is now deprecated. Use:\n\n" ' variable.get_raw_value()\n', DeprecationWarning, stacklevel=2, ) return self.get_raw_value() @raw_value.setter def raw_value(self, value: A): warnings.warn( "'.raw_value' setter is now deprecated. Use:\n\n" ' variable.set_raw_value(value)\n', DeprecationWarning, stacklevel=2, ) self.set_raw_value(value) @property def value(self) -> A: warnings.warn( "'.value' access is now deprecated. For Variable[Array] instances use:\n\n" ' variable[...]\n\n' 'For other Variable types use:\n\n' ' variable.get_value()\n', DeprecationWarning, stacklevel=2, ) return self.get_value() @value.setter def value(self, value: A): warnings.warn( "'.value' setter is now deprecated. For Variable[Array] instances use:\n\n" ' variable[...] = value\n\n' 'For other Variable types use:\n\n' ' variable.set_value(value)\n', DeprecationWarning, stacklevel=2, ) self.set_value(value) def create_value(self, value: A): return value def get_value(self, *, index: tp.Any = MISSING) -> A: value = jax.tree.map(lambda x: x, self._raw_value) # make a copy if not isinstance(index, Missing): if is_array_ref(value): value = value[index] elif isinstance(value, jax.Array) and index is ...: pass # skip trivial access else: value = value[index] elif is_array_ref(value): value = value[...] if 'on_get_value' in self._var_metadata: value = self._var_metadata['on_get_value'](self, value) return value # type: ignore def set_value(self, value: A, *, index: tp.Any = MISSING): value = jax.tree.map(lambda x: x, value) # make a copy if 'on_set_value' in self._var_metadata: value = self._var_metadata['on_set_value'](self, value) # update _raw_value if is_array_ref(self._raw_value): if isinstance(index, Missing): self._raw_value[...] = value else: self._raw_value[index] = value elif isinstance(self._raw_value, jax.Array) and ( not isinstance(index, Missing) ): # check if its a full replace to av if ( index == ... and isinstance(value, jax.Array) and value.shape == self._raw_value[index].shape and value.dtype == self._raw_value.dtype and ( getattr(value, 'sharding', None) == getattr(self._raw_value, 'sharding', None) ) ): self._raw_value = value else: self._raw_value = self._raw_value.at[index].set(value) # type: ignore else: if isinstance(index, Missing): self._raw_value = value else: self._raw_value[index] = value # type: ignore def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): if 'on_add_axis' in self._var_metadata: self._var_metadata['on_add_axis'](self, axis_index, axis_name) def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): if 'on_remove_axis' in self._var_metadata: self._var_metadata['on_remove_axis'](self, axis_index, axis_name) @tp.overload def copy(self, value: B, **kwargs) -> Variable[B]: ... @tp.overload def copy(self, **kwargs) -> Variable[A]: ... def copy( self, value: tp.Any = MISSING, *, _copy_ref: bool = True, **updates, ) -> Variable[tp.Any]: assert 'raw_value' not in updates new_metadata = self.get_metadata() | updates if not isinstance(value, Missing): pass elif 'value' in updates: value = updates.pop('value') else: value = self.get_raw_value() if _copy_ref and is_array_ref(value): value = value[...] if _copy_ref and new_metadata['ref']: value = jax.new_ref(value) new_metadata['ref'] = True value = jax.tree.map(lambda x: x, value) # make a copy obj = self.from_metadata(value, new_metadata) return obj @classmethod def _new( cls, value: A, metadata: dict[str, tp.Any], ) -> Variable[A]: obj = object.__new__(cls) # skip __setattr__ for trace_state initialization object.__setattr__(obj, '_trace_state', tracers.TraceState()) object.__setattr__(obj, '_var_metadata', metadata) object.__setattr__(obj, '_raw_value', value) return obj @classmethod def from_metadata( cls, value: A, attributes: dict[str, tp.Any], ) -> Variable[A]: variable = cls._new(value, dict(attributes)) if attributes['hijax']: variable = _new_hijax_from_variable(variable) # type: ignore[assignment] return variable # type: ignore[return-value] replace = copy to_state = copy def __nnx_repr__(self): stats = SizeBytes.from_any(self._raw_value) if stats: comment = f' # {stats}' else: comment = '' yield reprlib.Object(type=type(self).__name__, comment=comment) yield reprlib.Attr('value', self.get_value()) for name, value in self._var_metadata.items(): if name == 'hijax' and value == config.flax_hijax_variable: continue if name == 'ref' and not value: continue if name == 'eager_sharding' and value == config.flax_always_shard_variable: continue yield reprlib.Attr(name, value) def __treescope_repr__(self, path, subtree_renderer): size_bytes = SizeBytes.from_any(self.get_value()) if size_bytes: stats_repr = f' # {size_bytes}' first_line_annotation = treescope.rendering_parts.comment_color( treescope.rendering_parts.text(f'{stats_repr}') ) else: first_line_annotation = None metadata = { name: value for name, value in self._var_metadata.items() if not (name == 'hijax' and value == config.flax_hijax_variable) and not (name == 'ref' and not value) and not (name == 'eager_sharding' and value == config.flax_always_shard_variable) } children = {'value': self.get_value(), **metadata} return visualization.render_object_constructor( object_type=type(self), attributes=children, path=path, subtree_renderer=subtree_renderer, first_line_annotation=first_line_annotation, ) # hooks API if tp.TYPE_CHECKING: def on_get_value(self, value: A) -> A: ... def on_set_value(self, value: A) -> A: ... def on_create_value(self, value: A) -> A: ... def on_add_axis( self: V, axis_index: AxisIndex, axis_name: AxisName | None ) -> V: ... def on_remove_axis( self: V, axis_index: AxisIndex, axis_name: AxisName | None ) -> V: ... def __jax_array__(self): return self.get_value() # pickle support def __getstate__(self): return { '_raw_value': self._raw_value, '_trace_state': self._trace_state, '_var_metadata': self._var_metadata, } def __setstate__(self, state): # skip __setattr__ for trace_state initialization object.__setattr__(self, '_trace_state', state['_trace_state']) object.__setattr__(self, '_var_metadata', state['_var_metadata']) object.__setattr__(self, '_raw_value', state['_raw_value']) # -------------------------------------------- # proxy methods # -------------------------------------------- @tp.overload def __getitem__(self: Variable[jax.Array], key) -> jax.Array: ... @tp.overload def __getitem__(self: Variable[dict[tp.Any, B]], key) -> B: ... @tp.overload def __getitem__(self: Variable[list[B]], key: int) -> B: ... @tp.overload def __getitem__(self: Variable[tuple[B, ...]], key: int) -> B: ... @tp.overload def __getitem__(self, key) -> tp.Any: ... def __getitem__(self, key): return self.get_value(index=key) def __setitem__(self, key, value) -> None: self.set_value(value, index=key) def __delitem__(self, key) -> None: value = self.get_value() del value[key] # type: ignore self.set_value(value) # type: ignore def __call__(self, *args, **kwargs) -> tp.Any: return self.get_value()(*args, **kwargs) # type: ignore def __len__(self) -> int: return len(self.get_value()) # type: ignore def __iter__(self) -> tp.Iterator: return iter(self.get_value()) # type: ignore def __contains__(self, item) -> bool: return item in self.get_value() # type: ignore __add__ = _variable_operator('__add__') __sub__ = _variable_operator('__sub__') __mul__ = _variable_operator('__mul__') __matmul__ = _variable_operator('__matmul__') __truediv__ = _variable_operator('__truediv__') __floordiv__ = _variable_operator('__floordiv__') __mod__ = _variable_operator('__mod__') __pow__ = _variable_operator('__pow__') __lshift__ = _variable_operator('__lshift__') __rshift__ = _variable_operator('__rshift__') __and__ = _variable_operator('__and__') __xor__ = _variable_operator('__xor__') __or__ = _variable_operator('__or__') __radd__ = _variable_operator('__radd__') __rsub__ = _variable_operator('__rsub__') __rmul__ = _variable_operator('__rmul__') __rmatmul__ = _variable_operator('__rmatmul__') __rtruediv__ = _variable_operator('__rtruediv__') __rfloordiv__ = _variable_operator('__rfloordiv__') __rmod__ = _variable_operator('__rmod__') __rpow__ = _variable_operator('__rpow__') __rlshift__ = _variable_operator('__rlshift__') __rrshift__ = _variable_operator('__rrshift__') __rand__ = _variable_operator('__rand__') __rxor__ = _variable_operator('__rxor__') __ror__ = _variable_operator('__ror__') def __eq__(self, other) -> bool: if isinstance(other, Variable): other = other.get_value() return self.get_value().__eq__(other) # type: ignore def __iadd__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' 'Use `variable[...] += x` instead.' ) def __isub__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' 'Use `variable[...] -= x` instead.' ) def __imul__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' 'Use `variable[...] *= x` instead.' ) def __imatmul__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' 'Use `variable.value @= x` instead.' ) def __itruediv__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' 'Use `variable.value /= x` instead.' ) def __ifloordiv__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' 'Use `variable.value //= x`` instead.' ) def __imod__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' 'Use `variable.value %= x` instead.' ) def __ipow__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' 'Use `variable.value **= x`` instead.' ) def __ilshift__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' 'Use `variable.value <<= x`` instead.' ) def __irshift__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' 'Use `variable.value >>= x`` instead.' ) def __iand__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' 'Use `variable.value &= x` instead.' ) def __ixor__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' 'Use `variable.value ^= x` instead.' ) def __ior__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' 'Use `variable.value |= x` instead.' ) __neg__ = _variable_unary_operator('__neg__') __pos__ = _variable_unary_operator('__pos__') __abs__ = _variable_unary_operator('__abs__') __invert__ = _variable_unary_operator('__invert__') __complex__ = _variable_unary_operator('__complex__') __int__ = _variable_unary_operator('__int__') __float__ = _variable_unary_operator('__float__') __index__ = _variable_unary_operator('__index__') __trunc__ = _variable_unary_operator('__trunc__') __floor__ = _variable_unary_operator('__floor__') __ceil__ = _variable_unary_operator('__ceil__') def __round__(self, ndigits: int = 0) -> A: return self.get_value().__round__(ndigits) # type: ignore # -------------------------------------------- def __init_subclass__(cls) -> None: if '__slots__' not in vars(cls): cls.__slots__ = () # type: ignore[assignment] super().__init_subclass__() jax.tree_util.register_pytree_with_keys( cls, flatten_with_keys=_variable_flatten_with_keys, unflatten_func=partial(_variable_unflatten, cls), # type: ignore flatten_func=_variable_flatten, ) def _variable_flatten_with_keys(x: Variable[tp.Any]): metadata = tuple(sorted(x._var_metadata.items())) node = (jtu.GetAttrKey('value'), x._raw_value) return (node,), metadata def _variable_flatten(x: Variable[tp.Any]): metadata = tuple(sorted(x._var_metadata.items())) return (x._raw_value,), metadata def _variable_unflatten( cls: type[Variable[tp.Any]], static: tuple[tuple[str, tp.Any], ...], children: tuple[tp.Any], ): return cls._new(children[0], dict(static)) jax.tree_util.register_pytree_with_keys( Variable, flatten_with_keys=_variable_flatten_with_keys, unflatten_func=partial(_variable_unflatten, Variable), # type: ignore flatten_func=_variable_flatten, ) VariableState = Variable class Param(Variable[A]): """The canonical learnable parameter. All learnable parameters in NNX layer modules will have the ``Param`` :class:`Variable` type:: >>> from flax import nnx >>> import jax, jax.numpy as jnp >>> layer = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(layer)) State({ 'bias': Param( value=(3,) ), 'kernel': Param( value=(2, 3) ) }) """ pass class BatchStat(Variable[A]): """The mean and variance batch statistics stored in the :class:`BatchNorm` layer. Note, these are not the learnable scale and bias parameters, but rather the running average statistics that are typically used during post-training inference:: >>> from flax import nnx >>> import jax, jax.numpy as jnp >>> layer = nnx.BatchNorm(3, rngs=nnx.Rngs(0)) >>> jax.tree.map(jnp.shape, nnx.state(layer)) State({ 'bias': Param( value=(3,) ), 'mean': BatchStat( value=(3,) ), 'scale': Param( value=(3,) ), 'var': BatchStat( value=(3,) ) }) """ pass class Cache(Variable[A]): """Autoregressive cache in :class:`MultiHeadAttention`:: >>> from flax import nnx >>> import jax, jax.numpy as jnp >>> layer = nnx.MultiHeadAttention( ... num_heads=2, ... in_features=3, ... qkv_features=6, ... out_features=6, ... decode=True, ... rngs=nnx.Rngs(0), ... ) >>> layer.init_cache((1, 3)) >>> jax.tree.map(jnp.shape, nnx.state(layer, nnx.Cache)) State({ 'cache_index': Cache( value=() ), 'cached_key': Cache( value=(1, 2, 3) ), 'cached_value': Cache( value=(1, 2, 3) ) }) """ pass class Intermediate(Variable[A]): """:class:`Variable` type that is typically used for :func:`Module.sow`. Use :func:`nnx.capture` to retrieve the sowed values:: >>> from flax import nnx >>> import jax, jax.numpy as jnp >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear1 = nnx.Linear(2, 3, rngs=rngs) ... self.linear2 = nnx.Linear(3, 4, rngs=rngs) ... def __call__(self, x): ... x = self.linear1(x) ... self.sow(nnx.Intermediate, 'i', x) ... x = self.linear2(x) ... return x >>> model = Model(rngs=nnx.Rngs(0)) >>> x = jnp.ones((1, 2)) >>> y, intms = nnx.capture(model, nnx.Intermediate)(x) >>> jax.tree.map(jnp.shape, intms) State({ 'i': Intermediate( value=((1, 3),) ) }) """ pass class Perturbation(Intermediate[A]): """:class:`Variable` type that is typically used for :func:`Module.perturb`. Use :func:`nnx.capture` to retrieve the perturbation values:: >>> from flax import nnx >>> import jax, jax.numpy as jnp >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear1 = nnx.Linear(2, 3, rngs=rngs) ... self.linear2 = nnx.Linear(3, 4, rngs=rngs) ... def __call__(self, x): ... x = self.linear1(x) ... x = self.perturb('i', x) ... x = self.linear2(x) ... return x >>> model = Model(rngs=nnx.Rngs(0)) >>> x = jnp.ones((1, 2)) >>> y, perturbations = nnx.capture(model, nnx.Perturbation)(x) >>> jax.tree.map(jnp.shape, perturbations) State({ 'i': Perturbation( value=(1, 3) ) }) """ pass def with_metadata( initializer: F, set_value_hooks: tp.Union[SetValueHook[A], tp.Sequence[SetValueHook[A]]] = (), get_value_hooks: tp.Union[SetValueHook[A], tp.Sequence[SetValueHook[A]]] = (), create_value_hooks: tp.Union[ CreateValueHook[A], tp.Sequence[CreateValueHook[A]] ] = (), add_axis_hooks: tp.Union[ AddAxisHook[Variable[A]], tp.Sequence[AddAxisHook[Variable[A]]] ] = (), remove_axis_hooks: tp.Union[ RemoveAxisHook[Variable[A]], tp.Sequence[RemoveAxisHook[Variable[A]]], ] = (), **metadata: tp.Any, ) -> F: if set_value_hooks: if callable(set_value_hooks): set_value_hooks = (set_value_hooks,) else: set_value_hooks = tuple(set_value_hooks) else: set_value_hooks = () if get_value_hooks: if callable(get_value_hooks): get_value_hooks = (get_value_hooks,) else: get_value_hooks = tuple(get_value_hooks) else: get_value_hooks = () if create_value_hooks: if callable(create_value_hooks): create_value_hooks = (create_value_hooks,) else: create_value_hooks = tuple(create_value_hooks) else: create_value_hooks = () if add_axis_hooks: if callable(add_axis_hooks): add_axis_hooks = (add_axis_hooks,) else: add_axis_hooks = tuple(add_axis_hooks) else: add_axis_hooks = () if remove_axis_hooks: if callable(remove_axis_hooks): remove_axis_hooks = (remove_axis_hooks,) else: remove_axis_hooks = tuple(remove_axis_hooks) else: remove_axis_hooks = () @functools.wraps(initializer) def wrapper(*args): return VariableMetadata( initializer(*args), set_value_hooks=set_value_hooks, get_value_hooks=get_value_hooks, create_value_hooks=create_value_hooks, add_axis_hooks=add_axis_hooks, remove_axis_hooks=remove_axis_hooks, metadata=metadata, ) return wrapper # type: ignore ################################################### ### Variable type/class <-> string name mapping ### ################################################### # Assumption: the mapping is 1-1 and unique. VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {} def variable_type_from_name( name: str, /, *, base: type[Variable[tp.Any]] = Variable, allow_register: bool = False, ) -> tp.Type[Variable[tp.Any]]: """Given a Linen-style collection name, get or create its NNX Variable class.""" if name not in VariableTypeCache: if not allow_register: raise ValueError( f'Name {name} is not registered in the registry. ' 'To register a new name, use register_variable_name() ' 'or set allow_register=True.' ) VariableTypeCache[name] = type(name, (base,), {}) return VariableTypeCache[name] def variable_name_from_type( typ: tp.Type[Variable[tp.Any]], /, *, allow_register: bool = False ) -> str: """Given an NNX Variable type, get its Linen-style collection name. Should output the exact inversed result of `variable_type_from_name()`.""" for name, t in VariableTypeCache.items(): if typ == t: return name if not allow_register: raise ValueError( f'Type {typ} is not registered in the registry. ' 'To register a new type, use register_variable_name() ' 'or set allow_register=True.' ) name = typ.__name__ if name in VariableTypeCache: raise ValueError( 'Name {name} is already registered in the registry as {VariableTypeCache[name]}. ' 'It cannot be linked with this type {typ}.' ) register_variable_name(name, typ) return name @tp.overload def register_variable_name( name: str, typ: type[Variable[tp.Any]], *, overwrite: bool = False, ) -> type[Variable[tp.Any]]: ... @tp.overload def register_variable_name( name: str, *, overwrite: bool = False, ) -> tp.Callable[[type[Variable[tp.Any]]], type[Variable[tp.Any]]]: ... def register_variable_name( name: str, typ: type[Variable[A]] | Missing = MISSING, *, overwrite=False, ) -> type[Variable[A]] | tp.Callable[[type[Variable[A]]], type[Variable[A]]]: """Register a pair of Linen collection name and its NNX type.""" if isinstance(typ, Missing): return partial(register_variable_name, name, overwrite=overwrite) typ = tp.cast(type[Variable[A]], typ) if not overwrite and name in VariableTypeCache: raise ValueError( f'Name {name} already mapped to type {VariableTypeCache[name]}. ' 'To overwrite, call register_variable_name() with `overwrite=True`.' ) VariableTypeCache[name] = typ return typ # add known variable type names register_variable_name('params', Param) register_variable_name('batch_stats', BatchStat) register_variable_name('cache', Cache) register_variable_name('intermediates', Intermediate) register_variable_name('perturbations', Perturbation) ================================================ FILE: flax/nnx/visualization.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import typing as tp import treescope # type: ignore[import-untyped] from treescope import rendering_parts, renderers try: from IPython import get_ipython in_ipython = get_ipython() is not None except ImportError: in_ipython = False def display(*args): """Display the given objects using the Treescope pretty-printer. If treescope is not installed or the code is not running in IPython, ``display`` will print the objects instead. """ if not in_ipython: for x in args: print(x) return for x in args: treescope.display(x, ignore_exceptions=True, autovisualize=True) def render_object_constructor( object_type: type[tp.Any], attributes: tp.Mapping[str, tp.Any], path: str | None, subtree_renderer: renderers.TreescopeSubtreeRenderer, roundtrippable: bool = False, color: str | None = None, first_line_annotation: rendering_parts.RenderableTreePart | None = None, ) -> rendering_parts.Rendering: """Renders an object in "constructor format", similar to a dataclass. This produces a rendering like `Foo(bar=1, baz=2)`, where Foo identifies the type of the object, and bar and baz are the names of the attributes of the object. It is a *requirement* that these are the actual attributes of the object, which can be accessed via `obj.bar` or similar; otherwise, the path renderings will break. This can be used from within a `__treescope_repr__` implementation via :: def __treescope_repr__(self, path, subtree_renderer): return repr_lib.render_object_constructor( object_type=type(self), attributes=, path=path, subtree_renderer=subtree_renderer, ) Args: object_type: The type of the object. attributes: The attributes of the object, which will be rendered as keyword arguments to the constructor. path: The path to the object. When `render_object_constructor` is called from `__treescope_repr__`, this should come from the `path` argument to `__treescope_repr__`. subtree_renderer: The renderer to use to render subtrees. When `render_object_constructor` is called from `__treescope_repr__`, this should come from the `subtree_renderer` argument to `__treescope_repr__`. roundtrippable: Whether evaluating the rendering as Python code will produce an object that is equal to the original object. This implies that the keyword arguments are actually the keyword arguments to the constructor, and not some other attributes of the object. color: The background color to use for the object rendering. If None, does not use a background color. A utility for assigning a random color based on a string key is given in `treescope.formatting_util`. first_line_annotation: An annotation for the first line of the node when it is expanded. Returns: A rendering of the object, suitable for returning from `__treescope_repr__`. """ if roundtrippable: constructor = rendering_parts.siblings( rendering_parts.maybe_qualified_type_name(object_type), '(' ) closing_suffix = rendering_parts.text(')') else: constructor = rendering_parts.siblings( rendering_parts.roundtrip_condition(roundtrip=rendering_parts.text('<')), rendering_parts.maybe_qualified_type_name(object_type), '(', ) closing_suffix = rendering_parts.siblings( ')', rendering_parts.roundtrip_condition(roundtrip=rendering_parts.text('>')), ) children = [] for i, (name, value) in enumerate(attributes.items()): child_path = None if path is None else f'{path}.{name}' if i < len(attributes) - 1: # Not the last child. Always show a comma, and add a space when # collapsed. comma_after = rendering_parts.siblings( ',', rendering_parts.fold_condition(collapsed=rendering_parts.text(' ')), ) else: # Last child: only show the comma when the node is expanded. comma_after = rendering_parts.fold_condition( expanded=rendering_parts.text(',') ) child_line = rendering_parts.build_full_line_with_annotations( rendering_parts.siblings_with_annotations( f'{name}=', subtree_renderer(value, path=child_path), ), comma_after, ) children.append(child_line) return rendering_parts.build_foldable_tree_node_from_children( prefix=constructor, children=children, suffix=closing_suffix, path=path, background_color=color, first_line_annotation=first_line_annotation, ) ================================================ FILE: flax/oss/ .git-blame-ignore-revs ================================================ # .git-blame-ignore-revs # # These commits will be ignored by the github blame view. # The git blame CLI can ignore them as well by doing: # git blame --ignore-revs-file .git-blame-ignore-revs # or via global config: # git config --global blame.ignoreRevsFile .git-blame-ignore-revs # see the blame.markIgnoredLines and blame.markUnblamableLines options. # # remove all trailing whitespaces 442df07ca1a90f04c685cdae9f8e488bbffc2f83 ================================================ FILE: flax/py.typed ================================================ # Marker file for PEP 561. The package uses inline types. ================================================ FILE: flax/serialization.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Serialization utilities for Jax. All Flax classes that carry state (e.g., Optimizer) can be turned into a state dict of numpy arrays for easy serialization. """ import enum import threading from contextlib import contextmanager from typing import Any import jax import msgpack import numpy as np _STATE_DICT_REGISTRY: dict[Any, Any] = {} class _ErrorContext(threading.local): """Context for deserialization error messages.""" def __init__(self): self.path = [] _error_context = _ErrorContext() @contextmanager def _record_path(name): try: _error_context.path.append(name) yield finally: _error_context.path.pop() def current_path(): """Current state_dict path during deserialization for error messages.""" return '/'.join(_error_context.path) class _NamedTuple: """Fake type marker for namedtuple for registry.""" pass def _is_namedtuple(x): """Duck typing test for namedtuple factory-generated objects.""" return isinstance(x, tuple) and hasattr(x, '_fields') def from_state_dict(target, state: dict[str, Any], name: str = '.'): """Restores the state of the given target using a state dict. This function takes the current target as an argument. This lets us know the exact structure of the target, as well as lets us add assertions that shapes and dtypes don't change. In practice, none of the leaf values in ``target`` are actually used. Only the tree structure, shapes and dtypes. Args: target: the object of which the state should be restored. state: a dictionary generated by ``to_state_dict`` with the desired new state for ``target``. name: name of branch taken, used to improve deserialization error messages. Returns: A copy of the object with the restored state. """ if _is_namedtuple(target): ty = _NamedTuple else: ty = type(target) if ty not in _STATE_DICT_REGISTRY: return state ty_from_state_dict = _STATE_DICT_REGISTRY[ty][1] with _record_path(name): return ty_from_state_dict(target, state) def to_state_dict(target) -> dict[str, Any]: """Returns a dictionary with the state of the given target.""" if _is_namedtuple(target): ty = _NamedTuple else: ty = type(target) if ty not in _STATE_DICT_REGISTRY: return target ty_to_state_dict = _STATE_DICT_REGISTRY[ty][0] state_dict = ty_to_state_dict(target) if isinstance(state_dict, dict): for key in state_dict.keys(): assert isinstance(key, str), 'A state dict must only have string keys.' return state_dict def is_serializable(target): if not isinstance(target, type): target = type(target) return target in _STATE_DICT_REGISTRY def register_serialization_state( ty, ty_to_state_dict, ty_from_state_dict, override=False ): """Register a type for serialization. Args: ty: the type to be registered ty_to_state_dict: a function that takes an instance of ty and returns its state as a dictionary. ty_from_state_dict: a function that takes an instance of ty and a state dict, and returns a copy of the instance with the restored state. override: override a previously registered serialization handler (default: False). """ if ty in _STATE_DICT_REGISTRY and not override: raise ValueError( f'a serialization handler for "{ty.__name__}" is already registered' ) _STATE_DICT_REGISTRY[ty] = (ty_to_state_dict, ty_from_state_dict) def _list_state_dict(xs: list[Any]) -> dict[str, Any]: return {str(i): to_state_dict(x) for i, x in enumerate(xs)} def _restore_list(xs, state_dict: dict[str, Any]) -> list[Any]: if len(state_dict) != len(xs): raise ValueError( 'The size of the list and the state dict do not match,' f' got {len(xs)} and {len(state_dict)} ' f'at path {current_path()}' ) ys = [] for i in range(len(state_dict)): y = from_state_dict(xs[i], state_dict[str(i)], name=str(i)) ys.append(y) return ys def _dict_state_dict(xs: dict[str, Any]) -> dict[str, Any]: str_keys = {str(k) for k in xs.keys()} if len(str_keys) != len(xs): raise ValueError( 'Dict keys do not have a unique string representation: ' f'{str_keys} vs given: {xs}' ) return {str(key): to_state_dict(value) for key, value in xs.items()} def _restore_dict(xs, states: dict[str, Any]) -> dict[str, Any]: diff = set(map(str, xs.keys())).difference(states.keys()) if diff: raise ValueError( 'The target dict keys and state dict keys do not match, target dict' f' contains keys {diff} which are not present in state dict at path' f' {current_path()}' ) return { key: from_state_dict(value, states[str(key)], name=str(key)) for key, value in xs.items() } def _namedtuple_state_dict(nt) -> dict[str, Any]: return {key: to_state_dict(getattr(nt, key)) for key in nt._fields} def _restore_namedtuple(xs, state_dict: dict[str, Any]): """Rebuild namedtuple from serialized dict.""" if set(state_dict.keys()) == {'name', 'fields', 'values'}: # TODO(jheek): remove backward compatible named tuple restoration early 2022 state_dict = { state_dict['fields'][str(i)]: state_dict['values'][str(i)] for i in range(len(state_dict['fields'])) } sd_keys = set(state_dict.keys()) nt_keys = set(xs._fields) if sd_keys != nt_keys: raise ValueError( 'The field names of the state dict and the named tuple do not match,' f' got {sd_keys} and {nt_keys} at path {current_path()}' ) fields = { k: from_state_dict(getattr(xs, k), v, name=k) for k, v in state_dict.items() } return type(xs)(**fields) register_serialization_state(dict, _dict_state_dict, _restore_dict) register_serialization_state(list, _list_state_dict, _restore_list) register_serialization_state( tuple, _list_state_dict, lambda xs, state_dict: tuple(_restore_list(list(xs), state_dict)), ) register_serialization_state( _NamedTuple, _namedtuple_state_dict, _restore_namedtuple ) register_serialization_state( jax.tree_util.Partial, lambda x: ( { 'args': to_state_dict(x.args), 'keywords': to_state_dict(x.keywords), } ), lambda x, sd: jax.tree_util.Partial( x.func, *from_state_dict(x.args, sd['args']), **from_state_dict(x.keywords, sd['keywords']), ), ) # On-the-wire / disk serialization format # We encode state-dicts via msgpack, using its custom type extension. # https://github.com/msgpack/msgpack/blob/master/spec.md # # - ndarrays and DeviceArrays are serialized to nested msgpack-encoded string # of (shape-tuple, dtype-name (e.g. 'float32'), row-major array-bytes). # Note: only simple ndarray types are supported, no objects or fields. # # - native complex scalars are converted to nested msgpack-encoded tuples # (real, imag). def _ndarray_to_bytes(arr) -> bytes: """Save ndarray to simple msgpack encoding.""" if isinstance(arr, jax.Array): arr = np.array(arr) if arr.dtype.hasobject or arr.dtype.isalignedstruct: raise ValueError( 'Object and structured dtypes not supported ' 'for serialization of ndarrays.' ) tpl = (arr.shape, arr.dtype.name, arr.tobytes('C')) return msgpack.packb(tpl, use_bin_type=True) def _dtype_from_name(name: str): """Handle JAX bfloat16 dtype correctly.""" if name == b'bfloat16': return jax.numpy.bfloat16 else: return np.dtype(name) def _ndarray_from_bytes(data: bytes) -> np.ndarray: """Load ndarray from simple msgpack encoding.""" shape, dtype_name, buffer = msgpack.unpackb(data, raw=True) return np.frombuffer( buffer, dtype=_dtype_from_name(dtype_name), count=-1, offset=0 ).reshape(shape, order='C') class _MsgpackExtType(enum.IntEnum): """Messagepack custom type ids.""" ndarray = 1 native_complex = 2 npscalar = 3 def _msgpack_ext_pack(x): """Messagepack encoders for custom types.""" # TODO(flax-dev): Array here only work when they are fully addressable. # If they are not fully addressable, use the GDA path for checkpointing. if isinstance(x, (np.ndarray, jax.Array)): return msgpack.ExtType(_MsgpackExtType.ndarray, _ndarray_to_bytes(x)) if isinstance(x, np.generic): # pack scalar as ndarray return msgpack.ExtType( _MsgpackExtType.npscalar, _ndarray_to_bytes(np.asarray(x)) ) elif isinstance(x, complex): return msgpack.ExtType( _MsgpackExtType.native_complex, msgpack.packb((x.real, x.imag)) ) return x def _msgpack_ext_unpack(code, data): """Messagepack decoders for custom types.""" if code == _MsgpackExtType.ndarray: return _ndarray_from_bytes(data) elif code == _MsgpackExtType.native_complex: complex_tuple = msgpack.unpackb(data) return complex(complex_tuple[0], complex_tuple[1]) elif code == _MsgpackExtType.npscalar: ar = _ndarray_from_bytes(data) return ar[()] # unpack ndarray to scalar return msgpack.ExtType(code, data) # Chunking array leaves # msgpack has a hard limit of 2**31 - 1 bytes per object leaf. To circumvent # this limit for giant arrays (e.g. embedding tables), we traverse the tree # and break up arrays near the limit into flattened array chunks. # True limit is 2**31 - 1, but leave a margin for encoding padding. MAX_CHUNK_SIZE = 2**30 def _np_convert_in_place(d): """Convert any jax devicearray leaves to numpy arrays in place.""" if isinstance(d, dict): for k, v in d.items(): if isinstance(v, jax.Array): d[k] = np.array(v) elif isinstance(v, dict): _np_convert_in_place(v) elif isinstance(d, jax.Array): return np.array(d) return d _tuple_to_dict = lambda tpl: {str(x): y for x, y in enumerate(tpl)} _dict_to_tuple = lambda dct: tuple(dct[str(i)] for i in range(len(dct))) def _chunk(arr) -> dict[str, Any]: """Convert array to a canonical dictionary of chunked arrays.""" chunksize = max(1, int(MAX_CHUNK_SIZE / arr.dtype.itemsize)) data = {'__msgpack_chunked_array__': True, 'shape': _tuple_to_dict(arr.shape)} flatarr = arr.reshape(-1) chunks = [ flatarr[i : i + chunksize] for i in range(0, flatarr.size, chunksize) ] data['chunks'] = _tuple_to_dict(chunks) return data def _unchunk(data: dict[str, Any]): """Convert canonical dictionary of chunked arrays back into array.""" assert '__msgpack_chunked_array__' in data shape = _dict_to_tuple(data['shape']) flatarr = np.concatenate(_dict_to_tuple(data['chunks'])) return flatarr.reshape(shape) def _chunk_array_leaves_in_place(d): """Convert oversized array leaves to safe chunked form in place.""" if isinstance(d, dict): for k, v in d.items(): if isinstance(v, np.ndarray): if v.size * v.dtype.itemsize > MAX_CHUNK_SIZE: d[k] = _chunk(v) elif isinstance(v, dict): _chunk_array_leaves_in_place(v) elif isinstance(d, np.ndarray): if d.size * d.dtype.itemsize > MAX_CHUNK_SIZE: return _chunk(d) return d def _unchunk_array_leaves_in_place(d): """Convert chunked array leaves back into array leaves, in place.""" if isinstance(d, dict): if '__msgpack_chunked_array__' in d: return _unchunk(d) else: for k, v in d.items(): if isinstance(v, dict) and '__msgpack_chunked_array__' in v: d[k] = _unchunk(v) elif isinstance(v, dict): _unchunk_array_leaves_in_place(v) return d # User-facing API calls: def msgpack_serialize(pytree, in_place: bool = False) -> bytes: """Save data structure to bytes in msgpack format. Low-level function that only supports python trees with array leaves, for custom objects use ``to_bytes``. It splits arrays above MAX_CHUNK_SIZE into multiple chunks. Args: pytree: python tree of dict, list, tuple with python primitives and array leaves. in_place: boolean specifying if pytree should be modified in place. Returns: msgpack-encoded bytes of pytree. """ if not in_place: pytree = jax.tree_util.tree_map(lambda x: x, pytree) pytree = _np_convert_in_place(pytree) pytree = _chunk_array_leaves_in_place(pytree) return msgpack.packb(pytree, default=_msgpack_ext_pack, strict_types=True) def msgpack_restore(encoded_pytree: bytes): """Restore data structure from bytes in msgpack format. Low-level function that only supports python trees with array leaves, for custom objects use ``from_bytes``. Args: encoded_pytree: msgpack-encoded bytes of python tree. Returns: Python tree of dict, list, tuple with python primitive and array leaves. """ state_dict = msgpack.unpackb( encoded_pytree, ext_hook=_msgpack_ext_unpack, raw=False ) return _unchunk_array_leaves_in_place(state_dict) def from_bytes(target, encoded_bytes: bytes): """Restore optimizer or other object from msgpack-serialized state-dict. Args: target: template object with state-dict registrations that matches the structure being deserialized from ``encoded_bytes``. encoded_bytes: msgpack serialized object structurally isomorphic to ``target``. Typically a flax model or optimizer. Returns: A new object structurally isomorphic to ``target`` containing the updated leaf data from saved data. """ state_dict = msgpack_restore(encoded_bytes) return from_state_dict(target, state_dict) def to_bytes(target) -> bytes: """Save optimizer or other object as msgpack-serialized state-dict. Args: target: template object with state-dict registrations to be serialized to msgpack format. Typically a flax model or optimizer. Returns: Bytes of msgpack-encoded state-dict of ``target`` object. """ state_dict = to_state_dict(target) return msgpack_serialize(state_dict, in_place=True) ================================================ FILE: flax/struct.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for defining custom classes that can be used with jax transformations.""" from collections.abc import Callable import dataclasses import functools from typing import TypeVar, overload import jax from typing_extensions import ( dataclass_transform, # pytype: disable=not-supported-yet ) from . import serialization _T = TypeVar('_T') def field(pytree_node=True, *, metadata=None, **kwargs): return dataclasses.field(metadata=(metadata or {}) | {'pytree_node': pytree_node}, **kwargs) @dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required] @overload def dataclass(clz: _T, **kwargs) -> _T: ... @dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required] @overload def dataclass(**kwargs) -> Callable[[_T], _T]: ... @dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required] def dataclass( clz: _T | None = None, **kwargs, ) -> _T | Callable[[_T], _T]: """Create a class which can be passed to functional transformations. .. note:: Inherit from ``PyTreeNode`` instead to avoid type checking issues when using PyType. Jax transformations such as ``jax.jit`` and ``jax.grad`` require objects that are immutable and can be mapped over using the ``jax.tree_util`` methods. The ``dataclass`` decorator makes it easy to define custom classes that can be passed safely to Jax. Define JAX data as normal attribute fields, and use ``pytree_node=False`` to define static metadata. See example:: >>> from flax import struct >>> import jax >>> from typing import Any, Callable >>> @struct.dataclass ... class Model: ... params: Any ... # use pytree_node=False to indicate an attribute should not be touched ... # by Jax transformations. ... apply_fn: Callable = struct.field(pytree_node=False) ... def __apply__(self, *args): ... return self.apply_fn(*args) >>> params = {} >>> params_b = {} >>> apply_fn = lambda v, x: x >>> model = Model(params, apply_fn) >>> # model.params = params_b # Model is immutable. This will raise an error. >>> model_b = model.replace(params=params_b) # Use the replace method instead. >>> # This class can now be used safely in Jax to compute gradients w.r.t. the >>> # parameters. >>> model = Model(params, apply_fn) >>> loss_fn = lambda model: 3. >>> model_grad = jax.grad(loss_fn)(model) Note that dataclasses have an auto-generated ``__init__`` where the arguments of the constructor and the attributes of the created instance match 1:1. If you desire a "smart constructor", for example to optionally derive some of the attributes from others, make an additional static or class method. Consider the following example:: >>> @struct.dataclass ... class DirectionAndScaleKernel: ... direction: jax.Array ... scale: jax.Array ... @classmethod ... def create(cls, kernel): ... scale = jax.numpy.linalg.norm(kernel, axis=0, keepdims=True) ... direction = direction / scale ... return cls(direction, scale) Args: clz: the class that will be transformed by the decorator. **kwargs: arguments to pass to the dataclass constructor. Returns: The new class. """ # Support passing arguments to the decorator (e.g. @dataclass(kw_only=True)) if clz is None: return functools.partial(dataclass, **kwargs) # type: ignore[bad-return-type] # check if already a flax dataclass if '_flax_dataclass' in clz.__dict__: return clz if 'frozen' not in kwargs.keys(): kwargs['frozen'] = True data_clz = dataclasses.dataclass(**kwargs)(clz) # type: ignore meta_fields = [] data_fields = [] for field_info in dataclasses.fields(data_clz): is_pytree_node = field_info.metadata.get('pytree_node', True) if is_pytree_node: data_fields.append(field_info.name) else: meta_fields.append(field_info.name) def replace(self, **updates): """Returns a new object replacing the specified fields with new values.""" return dataclasses.replace(self, **updates) data_clz.replace = replace jax.tree_util.register_dataclass(data_clz, data_fields, meta_fields) def to_state_dict(x): state_dict = { name: serialization.to_state_dict(getattr(x, name)) for name in data_fields } return state_dict def from_state_dict(x, state): """Restore the state of a data class.""" state = state.copy() # copy the state so we can pop the restored fields. updates = {} for name in data_fields: if name not in state: raise ValueError( f'Missing field {name} in state dict while restoring' f' an instance of {clz.__name__},' f' at path {serialization.current_path()}' ) value = getattr(x, name) value_state = state.pop(name) updates[name] = serialization.from_state_dict( value, value_state, name=name ) if state: names = ','.join(state.keys()) raise ValueError( f'Unknown field(s) "{names}" in state dict while' f' restoring an instance of {clz.__name__}' f' at path {serialization.current_path()}' ) return x.replace(**updates) serialization.register_serialization_state( data_clz, to_state_dict, from_state_dict ) # add a _flax_dataclass flag to distinguish from regular dataclasses data_clz._flax_dataclass = True # type: ignore[attr-defined] return data_clz # type: ignore TNode = TypeVar('TNode', bound='PyTreeNode') @dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required] class PyTreeNode: """Base class for dataclasses that should act like a JAX pytree node. See ``flax.struct.dataclass`` for the ``jax.tree_util`` behavior. This base class additionally avoids type checking errors when using PyType. Example:: >>> from flax import struct >>> import jax >>> from typing import Any, Callable >>> class Model(struct.PyTreeNode): ... params: Any ... # use pytree_node=False to indicate an attribute should not be touched ... # by Jax transformations. ... apply_fn: Callable = struct.field(pytree_node=False) ... def __apply__(self, *args): ... return self.apply_fn(*args) >>> params = {} >>> params_b = {} >>> apply_fn = lambda v, x: x >>> model = Model(params, apply_fn) >>> # model.params = params_b # Model is immutable. This will raise an error. >>> model_b = model.replace(params=params_b) # Use the replace method instead. >>> # This class can now be used safely in Jax to compute gradients w.r.t. the >>> # parameters. >>> model = Model(params, apply_fn) >>> loss_fn = lambda model: 3. >>> model_grad = jax.grad(loss_fn)(model) """ def __init_subclass__(cls, **kwargs): super().__init_subclass__() dataclass(cls, **kwargs) # pytype: disable=wrong-arg-types def __init__(self, *args, **kwargs): # stub for pytype raise NotImplementedError def replace(self: TNode, **overrides) -> TNode: # stub for pytype raise NotImplementedError ================================================ FILE: flax/testing/__init__.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Flax testing utilities.""" from .benchmark import Benchmark ================================================ FILE: flax/testing/benchmark.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Benchmark class for Flax regression and integration testing. This file defines utility functions for collecting model training results from TensorBoard summaries, and for reporting benchmarks in a JSON format for pickup by continuous integration / monitoring frameworks. When the `benchmark_output_dir` is provided, the benchmark results are saved in this directory in the JSON format with a single file per benchmark. """ import functools import inspect import itertools import json import os import tempfile from absl import flags, logging from absl.testing import absltest from tensorboard.backend.event_processing import ( directory_watcher, event_file_loader, io_wrapper, ) from tensorboard.summary import v1 as summary_lib from tensorboard.util import tensor_util from flax import io _BENCHMARK_OUTPUT_DIR = flags.DEFINE_string( 'benchmark_output_dir', default=None, help='Benchmark output directory.' ) _SCALAR_PLUGIN_NAME = ( summary_lib.scalar_pb('', 0).value[0].metadata.plugin_data.plugin_name ) def _make_events_generator(path): """Makes a generator yielding TensorBoard events from files in `path`.""" return directory_watcher.DirectoryWatcher( path, event_file_loader.EventFileLoader, io_wrapper.IsSummaryEventsFile ).Load() def _is_scalar_value(value): if value.HasField('metadata') and value.metadata.HasField('plugin_data'): plugin_data = value.metadata.plugin_data return plugin_data.plugin_name == _SCALAR_PLUGIN_NAME return False def _process_event(event): """Parse TensorBoard scalars into a (tag, wall_time, step, scalar) tuple.""" for value in event.summary.value: if not _is_scalar_value(value): continue if value.HasField('tensor'): yield ( value.tag, event.wall_time, event.step, tensor_util.make_ndarray(value.tensor).item(), ) def _get_tensorboard_scalars(path): """Read and parse scalar TensorBoard summaries. Args: path: str. Path containing TensorBoard event files. Returns: Dictionary mapping summary tags (str) to lists of (wall_time, step, scalar) tuples. """ gen = _make_events_generator(path) data = filter(lambda x: x.HasField('summary'), gen) data = itertools.chain.from_iterable(map(_process_event, data)) data_by_key = {} for tag, wall_time, step, value in data: if not tag in data_by_key: data_by_key[tag] = [] data_by_key[tag].append((wall_time, step, value)) return data_by_key class Benchmark(absltest.TestCase): """Benchmark class for Flax examples. This class overrides the behaviour of `self.assert*` methods to be deferred instead of failing immediately. This allows for using absltest assert methods for checking benchmark target metrics. This is also necessary for correctly reporting benchmark results and determining its success. """ def __init__(self, *args, **kwargs): """Wrap test methods in a try-except decorator to delay exceptions.""" super().__init__(*args, **kwargs) for func_name in dir(self): if func_name.startswith('assert'): func = getattr(self, func_name) patched_func = functools.partial(self._collect_assert_wrapper, fn=func) setattr(self, func_name, patched_func) self._benchmark_output_dir = _BENCHMARK_OUTPUT_DIR.value # Create target directory if defined. if self._benchmark_output_dir and not io.exists(self._benchmark_output_dir): io.makedirs(self._benchmark_output_dir) # pylint: disable=invalid-name def _collect_assert_wrapper(self, *args, fn=None, **kwargs): """Wrapper around assert methods that caputres and collects failures.""" try: return fn(*args, **kwargs) except self.failureException as ex: self._outstanding_fails.append(ex) def setUp(self): """Setup ran before each test.""" super().setUp() self._reported_name = None self._reported_wall_time = None self._reported_metrics = {} self._reported_extras = {} self._outstanding_fails = [] def tearDown(self): """Tear down after each test.""" super().tearDown() self._report_benchmark_results() for message in self._outstanding_fails: raise self.failureException(message) def get_tmp_model_dir(self): """Returns an unique temporary directory for storing model data. Returns path by appending Classname.testname to `benchmark_output_dir` flag if defined else uses a temporary directory. This helps to export summary files to tensorboard as multiple separate runs for each test method. """ if self._benchmark_output_dir: model_dir = self._benchmark_output_dir else: model_dir = tempfile.mkdtemp() model_dir_path = os.path.join( model_dir, self._reported_name or self._get_test_name() ) # Create directories if they don't exist. if not io.exists(model_dir_path): io.makedirs(model_dir_path) return model_dir_path def has_outstanding_fails(self): """Determine whether the benchmark failed, but the error is deferred.""" return len(self._outstanding_fails) > 0 def read_summaries(self, path): """Read TensorBoard summaries.""" return _get_tensorboard_scalars(path) def report_wall_time(self, wall_time: float): """Report wall time for the benchmark.""" self._update_reported_name() self._reported_wall_time = wall_time def report_metrics(self, metrics: dict[str, float]): """Report metrics for the benchmark.""" self._update_reported_name() self._reported_metrics.update(metrics) def report_metric(self, name: str, value: float): """Report a single metric for the benchmark.""" self.report_metrics({name: value}) def report_extras(self, extras: dict[str, str]): """Report extras for the benchmark.""" self._update_reported_name() self._reported_extras.update(extras) def report_extra(self, name: str, value: str): """Report a single extra for the benchmark.""" self.report_extras({name: value}) def _get_test_name(self, prefix='test_'): """Returns full name of test class and method calling report_benchmark. The name is based on the *outermost* Benchmark class in the class stack. Based on tensorflow/python/platform/benchmark.py Args: prefix: str. Prefix that the caller method must have. Returns: Resolved test name as `ClassName.test_name`. """ # Find the caller method (outermost Benchmark class). stack = inspect.stack() calling_class, name = None, None for frame_info in stack[::-1]: f_locals = frame_info.frame.f_locals f_self = f_locals.get('self', None) if isinstance(f_self, Benchmark): name = frame_info.function if name.startswith(prefix): calling_class = f_self break if calling_class is None: raise ValueError('Unable to determine the calling Benchmark class.') # Prefix the name with the class name. class_name = type(calling_class).__name__ name = f'{class_name}.{name}' return name def _update_reported_name(self): """Record / update test name for the benchmark.""" self._reported_name = self._reported_name or self._get_test_name() def _report_benchmark_results(self): """Produce benchmark results report. Results are reported as a JSON string with the following schema: ``` { "name": "succeeded": true / false "wall_time": float (containing wall-time for the benchmark) "metrics": { "string" -> float map of other performance metrics } "extras": { "string" -> "string" map containing anything else of interest } } ``` """ name = self._reported_name if not name: raise ValueError( 'Unable to determine test name for reporting ' 'benchmark results. Make sure you are using ' '`self.report_*` methods.' ) succeeded = not self.has_outstanding_fails() results = { 'name': name, 'succeeded': succeeded, 'metrics': self._reported_metrics, 'extras': self._reported_extras, } if self._reported_wall_time is not None: results['wall_time'] = self._reported_wall_time if not succeeded: msg = '\n'.join([str(fail) for fail in self._outstanding_fails]) results['extras']['failed_assertions'] = msg results_str = json.dumps(results) logging.info(results_str) # Maybe save results as a file for pickup by CI / monitoring frameworks. if self._benchmark_output_dir: filename = os.path.join(self._benchmark_output_dir, name + '.json') with io.GFile(filename, 'w') as fout: fout.write(results_str) ================================================ FILE: flax/traceback_util.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Flax specific traceback_util functions.""" from jax._src import traceback_util as jax_traceback_util from jax.extend import source_info_util from flax import config # pylint: disable=protected-access # Globals: # Whether to filter flax frames from traceback. _flax_filter_tracebacks = config.flax_filter_frames # Flax specific set of paths to exclude from tracebacks. _flax_exclusions = set() # re-import JAX symbol for convenience. api_boundary = jax_traceback_util.api_boundary def register_exclusion(path): """Marks a Flax source file for exclusion.""" global _flax_exclusions, _flax_filter_tracebacks # Record flax exclusions so we can dynamically add and remove them. _flax_exclusions.add(path) if _flax_filter_tracebacks: jax_traceback_util.register_exclusion(path) source_info_util.register_exclusion(path) def hide_flax_in_tracebacks(): """Hides Flax internal stack frames in tracebacks.""" global _flax_exclusions, _flax_filter_tracebacks _flax_filter_tracebacks = True for exclusion in _flax_exclusions: if exclusion not in jax_traceback_util._exclude_paths: jax_traceback_util._exclude_paths.append(exclusion) def show_flax_in_tracebacks(): """Shows Flax internal stack frames in tracebacks.""" global _flax_exclusions, _flax_filter_tracebacks _flax_filter_tracebacks = False for exclusion in _flax_exclusions: if exclusion in jax_traceback_util._exclude_paths: jax_traceback_util._exclude_paths.remove(exclusion) ================================================ FILE: flax/training/__init__.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Flax training utilities.""" ================================================ FILE: flax/training/checkpoints.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Checkpointing helper functions. Handles saving and restoring optimizer checkpoints based on step-number or other numerical metric in filename. Cleans up older / worse-performing checkpoint files. """ import functools import os import pathlib import re import time import warnings from concurrent.futures import thread from typing import ( Any, ) from collections.abc import Callable, Iterable import jax import orbax.checkpoint as ocp from absl import logging from jax import monitoring, process_index from jax import tree_util as jtu from jax.experimental.array_serialization.serialization import ( GlobalAsyncCheckpointManager, get_tensorstore_spec, ) from jax.experimental.multihost_utils import sync_global_devices from flax import config, core, errors, io, serialization, traverse_util from flax.training import orbax_utils _READ_CHECKPOINT_EVENT: str = '/jax/checkpoint/read/durations_sec' _WRITE_CHECKPOINT_EVENT: str = '/jax/checkpoint/write/durations_sec' # Single-group reg-exps for int or float numerical substrings. # captures sign: SIGNED_FLOAT_RE = re.compile(r'([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)') # does not capture sign: UNSIGNED_FLOAT_RE = re.compile( r'[-+]?((?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)' ) # Module name followed by number. MODULE_NUM_RE = re.compile(r'(.*)_\d+$') # Alternative schemes handled by `gfile`, e.g. on Google Cloud Storage (GCS). SCHEME_RE = re.compile('^(?P[a-z][a-z0-9.+-]+://)?(?P.*)', re.I) # Multiprocess arrays (GlobalDeviceArray, or JAX array with multiprocess # sharding) is across processes and will be stored in directories with this # postfix, separated from the non-distributed data (e.g. the larger pytree) MP_ARRAY_POSTFIX = '_gda' # Occurrences of multiprocess arrays in the target pytree will be # replaced by this string placeholder. MP_ARRAY_PH = '//GDAPlaceholder:' # Add a copy-success file to a distributed array directory to indicate the # array save is complete. # We need this for GCS because GCS's directory move is not atomic. COMMIT_SUCCESS_FILE = 'commit_success.txt' # Orbax main checkpoint file name. ORBAX_CKPT_FILENAME = 'checkpoint' ORBAX_MANIFEST_OCDBT = 'manifest.ocdbt' ORBAX_METADATA_FILENAME = '_METADATA' PyTree = Any # TODO(flax-dev): Remove this once flax is using the latest jax release # containing jax.Array attribute. MultiprocessArrayType = Any def _is_multiprocess_array(value: Any) -> bool: """Use GlobalAsyncCheckpointManager to save the array if it's only partially available on this host.""" if isinstance(value, jax.Array): return not value.is_fully_addressable return False def _checkpoint_path( ckpt_dir: str, step: int | float | str, prefix: str = 'checkpoint_' ) -> str: return os.path.join(ckpt_dir, f'{prefix}{step}') def _checkpoint_path_step(path: str) -> float | None: """Returns the step number of a checkpoint path.""" for s in SIGNED_FLOAT_RE.split(path)[::-1]: if SIGNED_FLOAT_RE.match(s): return float(s) return None def _allowempty_listdir(path: str): try: return io.listdir(path) except io.NotFoundError: return [] def _safe_remove(path: str): """Identify whether a path is a dir or list and choose the correct remove method.""" if io.isdir(path): io.rmtree(path) else: io.remove(path) def _is_orbax_checkpoint(path: str) -> bool: return ( io.exists(os.path.join(path, ORBAX_CKPT_FILENAME)) or io.exists(os.path.join(path, ORBAX_METADATA_FILENAME)) or io.exists(os.path.join(path, ORBAX_MANIFEST_OCDBT)) ) class AsyncManager: """A simple object to track async checkpointing. How to use: create an instance and pass to save_checkpoint() calls: am = AsyncManager() save_checkpoint(..., async_manager=am) """ def __init__(self, max_workers: int = 1): self.executor = thread.ThreadPoolExecutor(max_workers=max_workers) self.save_future = None def wait_previous_save(self): """Block until the previous save finishes, to keep files' consistency.""" if self.save_future and not self.save_future.done(): logging.warning( 'The previous async save_checkpoint has not finished yet. Waiting ' 'for it to complete before the next save.' ) self.save_future.result() def save_async(self, task: Callable[[], Any]): """Run a task async. The future will be tracked as self.save_future. Args: task: The callable to be executed asynchronously. """ self.wait_previous_save() self.save_future = self.executor.submit(task) # type: ignore def _split_mp_arrays( target: dict[str, Any] ) -> tuple[dict[str, Any], list[tuple[MultiprocessArrayType, str]]]: """Split out the multiprocess arrays from the target pytree to save.""" # When target is a single leaf instead of a pytree dict. if not isinstance(target, (core.FrozenDict, dict)): if _is_multiprocess_array(target): return MP_ARRAY_PH, [(target, '')] return target, [] # Traverse the target and handle distributed arrays. flattened = traverse_util.flatten_dict(target, keep_empty_nodes=True) mpa_targets = [] for key, value in flattened.items(): if _is_multiprocess_array(value): subpath = '/'.join(key) mpa_targets.append((value, subpath)) flattened[key] = MP_ARRAY_PH + subpath target = traverse_util.unflatten_dict(flattened) return target, mpa_targets def _make_mpa_dirs( mpa_targets: list[tuple[MultiprocessArrayType, str]], tmp_path: str ): # Temporary array path is not used in GCS. if tmp_path.startswith('gs://'): return mpa_tmp_path = tmp_path + MP_ARRAY_POSTFIX # Clean up the previous MPA dir, in case some leftover from last preemption # lingers. if io.exists(mpa_tmp_path): logging.info('Removing outdated MPA temporary files at %s', mpa_tmp_path) io.rmtree(mpa_tmp_path) _, mpa_subpaths = zip(*mpa_targets) for subpath in mpa_subpaths: io.makedirs(os.path.join(mpa_tmp_path, subpath)) def _save_mpas( gda_manager, mpa_targets: list[tuple[MultiprocessArrayType, str]], tmp_path: str, final_path: str, base_path: str, keep: int, overwrite: bool, keep_every_n_steps: int | None, ckpt_start_time: float, async_manager: AsyncManager | None = None, ): """Save the multiprocess arrays given the paths.""" mpa_list, mpa_subpaths = zip(*mpa_targets) mpa_tmp_path, mpa_final_path = ( tmp_path + MP_ARRAY_POSTFIX, final_path + MP_ARRAY_POSTFIX, ) write_commit_success = False # If the checkpoint directory is a GCS directory, then keep the final # checkpoint directory as the temporary checkpoint directory. This is because # renames are not atomic on GCS. When restoring check for the existence of a # success file. # TODO: figure out a way to unit-test the behavior. if tmp_path.startswith('gs://'): mpa_tmp_path = mpa_final_path write_commit_success = True mpa_paths = [os.path.join(mpa_tmp_path, x) for x in mpa_subpaths] ts_specs = [get_tensorstore_spec(x) for x in mpa_paths] gda_manager.serialize( list(mpa_list), ts_specs, on_commit_callback=functools.partial( _save_commit, tmp_path, final_path, base_path, keep, overwrite, keep_every_n_steps, ckpt_start_time, has_mpa=True, write_commit_success=write_commit_success, async_manager=async_manager, ), ) def _restore_mpas( state_dict, target: Any | None, ckpt_path: str, step: int | float | None, gda_manager: GlobalAsyncCheckpointManager | None, allow_partial: bool = False, ): """Restore the multiprocess arrays given the target structure and type.""" def _check_mpa_errors(): if not gda_manager: raise errors.MPACheckpointingRequiredError(ckpt_path, step) if not target and not allow_partial: raise errors.MPARestoreTargetRequiredError(ckpt_path, step) def _safe_deserialize( target_mpas: list[tuple[tuple[Any, ...], MultiprocessArrayType, str]], gda_manager: Any, ) -> list[MultiprocessArrayType]: gda_manager.wait_until_finished() # Check if reading from GCS and the array dir is potentially corrupted. if ckpt_path.startswith('gs://') and not io.exists( os.path.join(ckpt_path + MP_ARRAY_POSTFIX, COMMIT_SUCCESS_FILE) ): raise errors.MPARestoreDataCorruptedError(step, ckpt_path) # Check if the given target array types are valid. shardings = [] for _, arr, path in target_mpas: if isinstance(arr, jax.Array): shardings.append(arr.sharding) # Restore the arrays. ts_specs = [get_tensorstore_spec(path) for _, _, path in target_mpas] return gda_manager.deserialize(shardings, ts_specs) # When target is a single leaf instead of a pytree dict. if not isinstance(state_dict, (core.FrozenDict, dict)): if ( _is_multiprocess_array(target) and isinstance(state_dict, str) and state_dict.startswith(MP_ARRAY_PH) ): _check_mpa_errors() return _safe_deserialize( [((), target, ckpt_path + MP_ARRAY_POSTFIX)], gda_manager )[0] return state_dict # Go through the restored checkpoint pytree for all MPAs flattened = traverse_util.flatten_dict(state_dict, keep_empty_nodes=True) target_flattened = {} if target: target_flattened = traverse_util.flatten_dict( serialization.to_state_dict(target), keep_empty_nodes=True ) # A list of (state_dict_key, target_array, array_file_path) for every array # to be restored target_mpas = [] for key, value in flattened.items(): if isinstance(value, str) and value.startswith(MP_ARRAY_PH): _check_mpa_errors() if ( not target or (key not in target_flattened) or (not _is_multiprocess_array(target_flattened[key])) ): if allow_partial: logging.warning( 'Multiprocess array %s could not be restored because a valid' ' array is not found in target at the corresponding location.' ' Proceed to restore other arrays because' ' allow_partial_restoration=True', key, ) else: raise errors.MPARestoreTargetRequiredError(ckpt_path, step, key) else: mpa_path = os.path.join( ckpt_path + MP_ARRAY_POSTFIX, value[len(MP_ARRAY_PH) :] ) target_mpas.append((key, target_flattened[key], mpa_path)) # If any MPA needs to be restored, call deserialize if target_mpas: mpa_list = _safe_deserialize(target_mpas, gda_manager) for mpa, (key, _, _) in zip(mpa_list, target_mpas): flattened[key] = mpa state_dict = traverse_util.unflatten_dict(flattened) return state_dict def natural_sort(file_list: Iterable[str], signed: bool = True) -> list[str]: """Natural sort for filenames with numerical substrings. Args: file_list: list of paths to sort containing numerical substrings. signed: bool: if leading '-' (or '+') signs should be included in numerical substrings as a sign or treated as a separator. Returns: List of filenames sorted 'naturally', not lexicographically: any integer substrings are used to subsort numerically. e.g. file_1, file_10, file_2 --> file_1, file_2, file_10 file_0.1, file_-0.2, file_2.0 --> file_-0.2, file_0.1, file_2.0 """ float_re = SIGNED_FLOAT_RE if signed else UNSIGNED_FLOAT_RE def maybe_num(s): if float_re.match(s): return float(s) else: return s def split_keys(s): return [maybe_num(c) for c in float_re.split(s)] return sorted(file_list, key=split_keys) def safe_normpath(path: str) -> str: """Normalizes path safely to get around `io.glob()` limitations.""" match = SCHEME_RE.match(path) assert match is not None d = match.groupdict() return (d['scheme'] or '') + os.path.normpath(d['path']) def _remove_invalid_ckpts( ckpt_path: str, base_path: str, keep: int, overwrite: bool, keep_every_n_steps: int | None, has_mpa: bool, ) -> None: """Clean up the checkpoint space according to `overwrite`, `keep`, and `keep_every_n_steps` parameters.""" dir_path, prefix = os.path.split(base_path) checkpoint_files: list[Any] = [ pathlib.PurePath(c) for c in _allowempty_listdir(dir_path) ] checkpoint_files = [ os.path.join(dir_path, c) for c in checkpoint_files if c.match(f'{prefix}*') and not c.match(f'*{MP_ARRAY_POSTFIX}') ] checkpoint_files = natural_sort(checkpoint_files) # Remove newer checkpoints if overwrite and ckpt_path in checkpoint_files: ind = checkpoint_files.index(ckpt_path) + 1 newer_ckpts = checkpoint_files[ind:] checkpoint_files = checkpoint_files[:ind] for path in newer_ckpts: logging.info('Removing checkpoint at %s', path) if has_mpa: # MPA might be removed already but the main ckpt is still there. This # can happen if the job is previously preempted after deleting the MPA # checkpoint folder and before deleting the main checkpoint. if io.exists(path + MP_ARRAY_POSTFIX): io.rmtree(path + MP_ARRAY_POSTFIX) _safe_remove(path) # Remove old checkpoint files. last_kept = -float('inf') if len(checkpoint_files) > keep: old_ckpts = checkpoint_files[:-keep] # Note: old_ckpts is sorted from oldest to newest. for path in old_ckpts: if keep_every_n_steps: step_number = _checkpoint_path_step(path) if step_number and (step_number - last_kept) >= keep_every_n_steps: logging.debug( 'Not deleting %s, because last_kept=%f and keeping ' 'every %d steps.', path, last_kept, keep_every_n_steps, ) last_kept = step_number continue logging.info('Removing checkpoint at %s', path) if has_mpa: # MPA might be removed already but the main ckpt is still there. if io.exists(path + MP_ARRAY_POSTFIX): io.rmtree(path + MP_ARRAY_POSTFIX) _safe_remove(path) def _save_commit( ckpt_tmp_path: str, ckpt_path: str, base_path: str, keep: int, overwrite: bool, keep_every_n_steps: int | None, ckpt_start_time: float, has_mpa: bool, write_commit_success: bool, async_manager: AsyncManager | None = None, ) -> None: """Commit changes after saving checkpoints to disk. This function does the following, sequentially: 1. Make sure all ckpt writing finishes, and rename them from temp path to the final path. 2. Remove newer checkpoints (files that ordered larger than this save) if `overwrite=True`. 3. Remove old checkpoint files based on `keep` and `keep_every_n_steps`. 4. Record program duration saved by this checkpoint. """ mpa_ckpt_tmp_path, mpa_ckpt_path = ( ckpt_tmp_path + MP_ARRAY_POSTFIX, ckpt_path + MP_ARRAY_POSTFIX, ) # Rename the multiprocess array path once serialization and writing finished. if has_mpa: if write_commit_success: commit_success_path = os.path.join(mpa_ckpt_path, COMMIT_SUCCESS_FILE) with io.GFile(commit_success_path, 'w') as f: f.write(f'Checkpoint commit was successful to {mpa_ckpt_path}') else: # Commits are a two stage process (renaming the array folder and renaming # the main ckpt file in sequential order). We always try to overwrite # here because the array ckpt might be already renamed in a previously # interrupted commit. NOTE: io.rename does not support overwriting # directories via `rename` so we manually overwrite it. if io.exists(mpa_ckpt_path): logging.info('Removing outdated checkpoint at %s', mpa_ckpt_path) io.rmtree(mpa_ckpt_path) io.rename(mpa_ckpt_tmp_path, mpa_ckpt_path) # Commit the main checkpoint file after arrays (if any) are committed if async_manager: async_manager.wait_previous_save() io.rename(ckpt_tmp_path, ckpt_path, overwrite=overwrite) logging.info('Saved checkpoint at %s', ckpt_path) # Remove newer and older invalid checkpoints. _remove_invalid_ckpts( ckpt_path, base_path, keep, overwrite, keep_every_n_steps, has_mpa ) # Record checkpoint-related metrics. ocp.utils.record_saved_duration(ckpt_start_time) if async_manager: jax.monitoring.record_event_duration_secs( '/jax/checkpoint/write/async/total_duration_secs', time.time() - ckpt_start_time, ) def _check_overwrite_error( ckpt_tmp_path: str, ckpt_path: str, base_path: str, step: int ): """Throw error if a ckpt file of this step or higher already exists.""" dir_path, prefix = os.path.split(base_path) checkpoint_files: list[Any] = [ pathlib.PurePath(c) for c in _allowempty_listdir(dir_path) ] checkpoint_files = [ os.path.join(dir_path, c) for c in checkpoint_files if c.match(f'{prefix}*') and not c.match(f'*{MP_ARRAY_POSTFIX}') ] if ckpt_path in checkpoint_files: raise errors.InvalidCheckpointError(ckpt_path, step) checkpoint_files.append(ckpt_path) checkpoint_files = natural_sort(checkpoint_files) # Handle the case if the job was preempted after the temporary checkpoint # was written, but before it was renamed to the final checkpoint name if checkpoint_files[-1] == ckpt_tmp_path: checkpoint_files.pop() if ckpt_path != checkpoint_files[-1]: raise errors.InvalidCheckpointError(ckpt_path, step) def _save_main_ckpt_file( target: bytes, has_mpa: bool, paths: tuple[str, str], base_path: str, step: int, keep: int, overwrite: bool, keep_every_n_steps: int | None, ckpt_start_time: float, ): """Save the main checkpoint file via file system.""" ckpt_tmp_path, ckpt_path = paths io.makedirs(os.path.dirname(ckpt_path)) with io.GFile(ckpt_tmp_path, 'wb') as fp: fp.write(target) # Postpone the commitment of checkpoint to after MPA writes are done. if not has_mpa: _save_commit( ckpt_tmp_path, ckpt_path, base_path, keep, overwrite, keep_every_n_steps, ckpt_start_time, has_mpa=False, write_commit_success=False, ) def _get_checkpoint_paths( ckpt_dir: str | os.PathLike, step: int | float, prefix: str = 'checkpoint_', ) -> tuple[str, str, str]: """Generate the checkpoint paths used in this save operation.""" ckpt_dir = os.fspath(ckpt_dir) # Pathlib -> str logging.info('Saving checkpoint at step: %s', step) # normalize path because io.glob() can modify path './', '//' ... ckpt_dir = safe_normpath(ckpt_dir) ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix) ckpt_path = _checkpoint_path(ckpt_dir, step, prefix) base_path = os.path.join(ckpt_dir, prefix) return ckpt_path, ckpt_tmp_path, base_path def save_checkpoint( ckpt_dir: str | os.PathLike, target: PyTree, step: int | float, prefix: str = 'checkpoint_', keep: int = 1, overwrite: bool = False, keep_every_n_steps: int | None = None, async_manager: AsyncManager | None = None, orbax_checkpointer: ocp.Checkpointer | None = None, ) -> str: """Save a checkpoint of the model. Suitable for single-host. In this method, every JAX process saves the checkpoint on its own. Do not use it if you have multiple processes and you intend for them to save data to a common directory (e.g., a GCloud bucket). To save multi-process checkpoints to a shared storage or to save ``GlobalDeviceArray``s, use ``save_checkpoint_multiprocess()`` instead. Pre-emption safe by writing to temporary before a final rename and cleanup of past files. However, if async_manager is used, the final commit will happen inside an async callback, which can be explicitly waited by calling ``async_manager.wait_previous_save()``. Example usage:: >>> from flax.training import checkpoints >>> import jax.numpy as jnp >>> import tempfile >>> with tempfile.TemporaryDirectory() as dir_path: ... test_object = { ... 'a': jnp.array([1, 2, 3], jnp.int32), ... 'b': jnp.array([1, 1, 1], jnp.int32), ... } ... file_path = checkpoints.save_checkpoint( ... dir_path, target=test_object, step=0, prefix='test_', keep=1 ... ) ... restored_object = checkpoints.restore_checkpoint( ... file_path, target=None ... ) >>> restored_object {'a': Array([1, 2, 3], dtype=int32), 'b': Array([1, 1, 1], dtype=int32)} Args: ckpt_dir: str or pathlib-like path to store checkpoint files in. target: serializable flax object, usually a flax optimizer. step: int or float: training step number or other metric number. prefix: str: checkpoint file name prefix. keep: number of past checkpoint files to keep. overwrite: overwrite existing checkpoint files if a checkpoint at the current or a later step already exits (default: False). keep_every_n_steps: if defined, keep every checkpoints every n steps (in addition to keeping the last 'keep' checkpoints). async_manager: if defined, the save will run without blocking the main thread. Only works for single host. Note that an ongoing save will still block subsequent saves, to make sure overwrite/keep logic works correctly. orbax_checkpointer: if defined, the save will be done by ocp. In the future, all Flax checkpointing features will be migrated to Orbax, and starting to use an ``orbax_checkpointer`` is recommended. Please check out the checkpointing guide (https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#save-checkpoints) for how to use Orbax checkpointers. Returns: Filename of saved checkpoint. """ jax.monitoring.record_event('/jax/flax/checkpoint/save') start_time = time.time() # Make sure all saves are finished before the logic of checking and removing # outdated checkpoints happens. if async_manager: async_manager.wait_previous_save() ckpt_path, ckpt_tmp_path, base_path = _get_checkpoint_paths( ckpt_dir, step, prefix ) if config.flax_use_orbax_checkpointing or orbax_checkpointer: logging.info( 'Using Orbax as backend to save Flax checkpoints. For potential' ' troubleshooting see:' ' https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#orbax-as-backend-troubleshooting' ) if jax.process_count() > 1: logging.warning( 'Multiple JAX processes detected when calling single-process' ' `save_checkpoint`. Your devices will HANG if this function is only' ' called on process 0! Troubleshoot at:' ' https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#if-your-devices-hang-when-writing-checkpoints' ) # Make sure any previous work is done before making file changes. if orbax_checkpointer and isinstance( orbax_checkpointer, ocp.AsyncCheckpointer ): orbax_checkpointer.wait_until_finished() # If no checkpointer provided, save synchronously with default setting. if not orbax_checkpointer: orbax_checkpointer = ocp.Checkpointer( ocp.PyTreeCheckpointHandler() ) # Check singular target. if jtu.treedef_is_leaf(jtu.tree_structure(target)) and not isinstance( orbax_checkpointer._handler, ocp.ArrayCheckpointHandler, # pylint: disable=protected-access ): raise ValueError( 'Orbax backend only accept pytree as save target. To save singular' ' objects like numbers or Numpy arrays, checkout' ' https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#if-you-don-t-save-pytrees' ) orbax_checkpointer.save( ckpt_path, target, force=overwrite ) # Do a process check here in case people call this for multihost. if process_index() == 0: _remove_invalid_ckpts( ckpt_path, base_path, keep, overwrite, keep_every_n_steps, True ) end_time = time.time() monitoring.record_event_duration_secs( _WRITE_CHECKPOINT_EVENT, end_time - start_time ) return ckpt_path warnings.warn( ( 'Flax Checkpointing will soon be deprecated in favor of Orbax' ' (https://github.com/google/orbax). Please refer to the Checkpoint' ' Upgrade Guide' ' (https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/orbax_upgrade_guide.html)' ' to self-migrate your code to ocp.' ), DeprecationWarning, ) if not overwrite: _check_overwrite_error(ckpt_tmp_path, ckpt_path, base_path, step) # type: ignore target = serialization.to_bytes(target) # Save the files via I/O sync or async. def save_main_ckpt_task(): jax.monitoring.record_event('/jax/flax/checkpoint/save_main_ckpt_task') return _save_main_ckpt_file( target, False, (ckpt_tmp_path, ckpt_path), base_path, step, keep, overwrite, keep_every_n_steps, start_time, ) if async_manager: async_manager.save_async(save_main_ckpt_task) else: save_main_ckpt_task() end_time = time.time() monitoring.record_event_duration_secs( _WRITE_CHECKPOINT_EVENT, end_time - start_time ) return ckpt_path def save_checkpoint_multiprocess( ckpt_dir: str | os.PathLike, target: PyTree, step: int | float, prefix: str = 'checkpoint_', keep: int = 1, overwrite: bool = False, keep_every_n_steps: int | None = None, async_manager: AsyncManager | None = None, gda_manager: GlobalAsyncCheckpointManager | None = None, orbax_checkpointer: ocp.Checkpointer | None = None, ) -> str: """Save a checkpoint of the model in multi-process environment. Use this method to save ``GlobalDeviceArray``s, or to save data to a common directory. Only process 0 will save the main checkpoint file and remove old checkpoint files. Pre-emption safe by writing to temporary before a final rename and cleanup of past files. However, if async_manager or gda_manager is used, the final commit will happen inside an async callback, which can be explicitly waited by calling ``async_manager.wait_previous_save()`` or ``gda_manager.wait_until_finished()``. Args: ckpt_dir: str or pathlib-like path to store checkpoint files in. target: serializable flax object, usually a flax optimizer. step: int or float: training step number or other metric number. prefix: str: checkpoint file name prefix. keep: number of past checkpoint files to keep. overwrite: overwrite existing checkpoint files if a checkpoint at the current or a later step already exits (default: False). keep_every_n_steps: if defined, keep every checkpoints every n steps (in addition to keeping the last 'keep' checkpoints). async_manager: if defined, the save will run without blocking the main thread. Only works for single host. Note that an ongoing save will still block subsequent saves, to make sure overwrite/keep logic works correctly. gda_manager: required if target contains a JAX GlobalDeviceArray. Will save the GDAs to a separate subdirectory with postfix "_gda" asynchronously. Same as async_manager, this will block subsequent saves. orbax_checkpointer: if defined, the save will be done by Orbax. In the future, all Flax checkpointing features will be migrated to Orbax, and starting to use an ``orbax_checkpointer`` is recommended. Please check out the checkpointing guide (https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#save-checkpoints) for how to use Orbax checkpointers. Returns: Filename of saved checkpoint. """ jax.monitoring.record_event('/jax/flax/checkpoint/save') start_time = time.time() # Make sure all saves are finished before the logic of checking and removing # outdated checkpoints happens. sync_global_devices('Flax:Checkpoint:StartSave') if async_manager: async_manager.wait_previous_save() if gda_manager: gda_manager.wait_until_finished() sync_global_devices('Flax:Checkpoint:WaitLastSaveDone') ckpt_path, ckpt_tmp_path, base_path = _get_checkpoint_paths( ckpt_dir, step, prefix ) if config.flax_use_orbax_checkpointing or orbax_checkpointer: logging.info( 'Using Orbax as backend to save Flax checkpoints. For potential' ' troubleshooting see:' ' https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#orbax-as-backend-troubleshooting' ) # Make sure any previous work is done before making file changes. if orbax_checkpointer and isinstance( orbax_checkpointer, ocp.AsyncCheckpointer ): orbax_checkpointer.wait_until_finished() # If no checkpointer provided, save synchronously with default setting. if not orbax_checkpointer: orbax_checkpointer = ocp.Checkpointer( ocp.PyTreeCheckpointHandler() ) # Check singular target. if jtu.treedef_is_leaf(jtu.tree_structure(target)) and not isinstance( orbax_checkpointer._handler, ocp.ArrayCheckpointHandler, # pylint: disable=protected-access ): raise ValueError( 'Orbax backend only accept pytree as save target. To save singular' ' objects like numbers or Numpy arrays, checkout' ' https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#if-you-don-t-save-pytrees' ) if process_index() == 0: _remove_invalid_ckpts( ckpt_path, base_path, keep, overwrite, keep_every_n_steps, True ) orbax_checkpointer.save( ckpt_path, target, force=overwrite ) end_time = time.time() monitoring.record_event_duration_secs( _WRITE_CHECKPOINT_EVENT, end_time - start_time ) return ckpt_path warnings.warn( ( 'Flax Checkpointing will soon be deprecated in favor of Orbax' ' (https://github.com/google/orbax). Please refer to the Checkpoint' ' Upgrade Guide' ' (https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/orbax_upgrade_guide.html)' ' to self-migrate your code to ocp.' ), DeprecationWarning, ) target = serialization.to_state_dict(target) target, mpa_targets = _split_mp_arrays(target) target = serialization.msgpack_serialize(target) has_mpa = bool(mpa_targets) if not overwrite: _check_overwrite_error(ckpt_tmp_path, ckpt_path, base_path, step) # type: ignore sync_global_devices('Flax:Checkpoint:CheckOverwriteBeforeSave') # Save the files via I/O sync or async. def save_main_ckpt_task(): jax.monitoring.record_event('/jax/flax/checkpoint/save_main_ckpt_task') return _save_main_ckpt_file( target, has_mpa, (ckpt_tmp_path, ckpt_path), base_path, step, keep, overwrite, keep_every_n_steps, start_time, ) # Write the main checkpoint file only via process 0, to avoid race condition. if process_index() == 0: if async_manager: async_manager.save_async(save_main_ckpt_task) else: save_main_ckpt_task() if has_mpa: if not gda_manager: raise errors.MPACheckpointingRequiredError(ckpt_path, step) # Creating the directory containing GDAs explicitly. This should happen only # on process 0 and before any worker starts to write GDA data. if process_index() == 0: _make_mpa_dirs(mpa_targets, ckpt_tmp_path) sync_global_devices('Flax:Checkpoint:AfterCreateMPADir') _save_mpas( gda_manager, mpa_targets, ckpt_tmp_path, ckpt_path, base_path, keep, overwrite, keep_every_n_steps, start_time, async_manager, ) end_time = time.time() monitoring.record_event_duration_secs( _WRITE_CHECKPOINT_EVENT, end_time - start_time ) return ckpt_path def _all_checkpoints( ckpt_dir: str | os.PathLike, prefix: str = 'checkpoint_' ) -> list[str]: """Retrieve all checkpoint paths in directory. Args: ckpt_dir: str: directory of checkpoints to restore from. prefix: str: name prefix of checkpoint files. Returns: Sorted list of checkpoint paths or empty list if no checkpoints were found. """ ckpt_dir = os.fspath(ckpt_dir) # Pathlib -> str checkpoint_files: list[Any] = [ pathlib.PurePath(c) for c in _allowempty_listdir(ckpt_dir) ] checkpoint_files = [ os.path.join(ckpt_dir, c) for c in checkpoint_files if c.match(f'{prefix}*') and not c.match(f'{prefix}tmp') and not c.match(f'*{MP_ARRAY_POSTFIX}') and not c.match(f'*{ocp.utils.TMP_DIR_SUFFIX}*') ] checkpoint_files = natural_sort(checkpoint_files) if checkpoint_files: return checkpoint_files else: return [] def latest_checkpoint( ckpt_dir: str | os.PathLike, prefix: str = 'checkpoint_' ) -> str | None: """Retrieve the path of the latest checkpoint in a directory. Args: ckpt_dir: str: directory of checkpoints to restore from. prefix: str: name prefix of checkpoint files. Returns: The latest checkpoint path or None if no checkpoints were found. """ checkpoint_files = _all_checkpoints(ckpt_dir, prefix) if checkpoint_files: return checkpoint_files[-1] else: return None def available_steps( ckpt_dir: str | os.PathLike, prefix: str = 'checkpoint_', step_type: type = int, ) -> list[int | float]: """Return step numbers of available checkpoints in a directory. Args: ckpt_dir: str: directory of checkpoints to restore from. prefix: str: name prefix of checkpoint files. step_type: type: type for steps, int (default) or float. Returns: Sorted list of available steps or empty list if no checkpoints were found. """ checkpoint_files = _all_checkpoints(ckpt_dir, prefix) checkpoint_steps = [] for file in checkpoint_files: prefix_idx = file.rfind(prefix) checkpoint_steps += [step_type(file[prefix_idx + len(prefix) :])] return checkpoint_steps def restore_checkpoint( ckpt_dir: str | os.PathLike, target: Any | None, step: int | float | None = None, prefix: str = 'checkpoint_', parallel: bool = True, gda_manager: GlobalAsyncCheckpointManager | None = None, allow_partial_mpa_restoration: bool = False, orbax_checkpointer: ocp.Checkpointer | None = None, orbax_transforms: dict | None = None, ) -> PyTree: """Restore last/best checkpoint from checkpoints in path. Sorts the checkpoint files naturally, returning the highest-valued file, e.g.: * ``ckpt_1, ckpt_2, ckpt_3 --> ckpt_3`` * ``ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1`` * ``ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5`` Example usage:: >>> from flax.training import checkpoints >>> import jax.numpy as jnp >>> import tempfile ... >>> with tempfile.TemporaryDirectory() as dir_path: ... test_object = { ... 'a': jnp.array([1, 2, 3], jnp.int32), ... 'b': jnp.array([1, 1, 1], jnp.int32), ... } ... file_path = checkpoints.save_checkpoint( ... dir_path, target=test_object, step=0, prefix='test_', keep=1 ... ) ... restored_object = checkpoints.restore_checkpoint( ... file_path, target=None ... ) >>> restored_object {'a': Array([1, 2, 3], dtype=int32), 'b': Array([1, 1, 1], dtype=int32)} Args: ckpt_dir: str: checkpoint file or directory of checkpoints to restore from. target: matching object to rebuild via deserialized state-dict. If None, the deserialized state-dict is returned as-is. step: int or float: step number to load or None to load latest. If specified, ckpt_dir must be a directory. prefix: str: name prefix of checkpoint files. parallel: bool: whether to load seekable checkpoints in parallel, for speed. gda_manager: required if checkpoint contains a multiprocess array (GlobalDeviceArray or jax Array from pjit). Will read the arrays from the separate subdirectory with postfix "_gda". allow_partial_mpa_restoration: If true, the given ``target`` doesn't have to contain all valid multiprocess arrays. As a result, the restored Pytree may have some MPAs not restored correctly. Use this if you cannot provide a fully valid ``target`` and don't need all the MPAs in the checkpoint to be restored. orbax_checkpointer: the ``ocp.Checkpointer`` that handles the underlying restore, if the given checkpoint is saved with ocp. orbax_transforms: the Orbax transformations that will be passed into ``orbax_checkpointer.restore()`` call. Returns: Restored ``target`` updated from checkpoint file, or if no step specified and no checkpoint files present, returns the passed-in ``target`` unchanged. If a file path is specified and is not found, the passed-in ``target`` will be returned. This is to match the behavior of the case where a directory path is specified but the directory has not yet been created. """ jax.monitoring.record_event('/jax/flax/checkpoint/restore') start_time = time.time() # Make sure any previous work is done before checking files. if orbax_checkpointer and isinstance( orbax_checkpointer, ocp.AsyncCheckpointer ): orbax_checkpointer.wait_until_finished() ckpt_dir = os.fspath(ckpt_dir) # Pathlib -> str ckpt_dir = safe_normpath(ckpt_dir) if step is not None: ckpt_path = _checkpoint_path(ckpt_dir, step, prefix) if not io.exists(ckpt_path): raise ValueError(f'Matching checkpoint not found: {ckpt_path}') else: if not io.exists(ckpt_dir): logging.info('Found no checkpoint directory at %s', ckpt_dir) return target if io.isdir(ckpt_dir): # This means the given dir is an orbax checkpoint. if _is_orbax_checkpoint(ckpt_dir): ckpt_path = ckpt_dir else: ckpt_path = latest_checkpoint(ckpt_dir, prefix) # type: ignore if not ckpt_path: logging.info( 'Found no checkpoint files in %s with prefix %s', ckpt_dir, prefix ) return target else: ckpt_path = ckpt_dir # Restore the checkpoint with Orbax if needed. is_orbax = _is_orbax_checkpoint(ckpt_path) ckpt_type = 'orbax' if is_orbax else 'legacy Flax' logging.info(f'Restoring {ckpt_type} checkpoint from {ckpt_path}') if is_orbax: if not orbax_checkpointer: orbax_checkpointer = ocp.Checkpointer( ocp.PyTreeCheckpointHandler() ) restore_kwargs = {} if target is not None: restore_kwargs['restore_args'] = orbax_utils.restore_args_from_target( target ) if isinstance(orbax_checkpointer._handler, ocp.PyTreeCheckpointHandler): # pylint: disable=protected-access restore_kwargs[ 'transforms' ] = orbax_utils.maybe_construct_transformations( target, orbax_transforms ) restored = orbax_checkpointer.restore( ckpt_path, item=target, **restore_kwargs ) restored = serialization.to_state_dict(restored) if target is not None: restored = serialization.from_state_dict(target, restored) end_time = time.time() monitoring.record_event_duration_secs( _READ_CHECKPOINT_EVENT, end_time - start_time ) return restored # Legacy Flax checkpoint restoration. ckpt_size = io.getsize(ckpt_path) with io.GFile(ckpt_path, 'rb') as fp: if parallel and fp.seekable(): buf_size = 128 << 20 # 128M buffer. num_bufs = ckpt_size / buf_size logging.debug('num_bufs: %d', num_bufs) checkpoint_contents = bytearray(ckpt_size) def read_chunk(i): # NOTE: We have to re-open the file to read each chunk, otherwise the # parallelism has no effect. But we could reuse the file pointers # within each thread. with io.GFile(ckpt_path, 'rb') as f: f.seek(i * buf_size) buf = f.read(buf_size) if buf: checkpoint_contents[i * buf_size : i * buf_size + len(buf)] = buf return len(buf) / buf_size pool_size = 32 pool = thread.ThreadPoolExecutor(pool_size) results = pool.map(read_chunk, range(int(num_bufs) + 1)) pool.shutdown(wait=False) logging.debug(f'results: {list(results)}') else: checkpoint_contents = fp.read() state_dict = serialization.msgpack_restore(checkpoint_contents) state_dict = _restore_mpas( state_dict, target, ckpt_path, step, gda_manager, allow_partial_mpa_restoration, ) if target is None: restored_checkpoint = state_dict else: restored_checkpoint = serialization.from_state_dict(target, state_dict) end_time = time.time() monitoring.record_event_duration_secs( _READ_CHECKPOINT_EVENT, end_time - start_time ) return restored_checkpoint def convert_pre_linen(params: PyTree) -> PyTree: """Converts a pre-Linen parameter pytree. In pre-Linen API submodules were numbered incrementally, independent of the submodule class. With Linen this behavior has changed to keep separate submodule counts per module class. Consider the following module:: class Model(nn.Module): @nn.compact def __call__(self, x): x = nn.Conv(1, 1)(x) x = nn.Dense(1)(x) return x In pre-Linen the resulting params would have had the structure: ``{'Conv_0': { ... }, 'Dense_1': { ... } }`` With Linen the resulting params would instead have had the structure: ``{'Conv_0': { ... }, 'Dense_0': { ... } }`` To convert from pre-Linen format to Linen simply call:: params = convert_pre_linen(pre_linen_params) Note that you can also use this utility to convert pre-Linen collections because they're following the same module naming. Note though that collections were "flat" in pre-Linen and first need to be unflattened before they can be used with this function:: batch_stats = convert_pre_linen(flax.traverse_util.unflatten_dict({ tuple(k.split('/')[1:]): v for k, v in pre_linen_model_state.as_dict().items() })) Then Linen variables can be defined from these converted collections:: variables = {'params': params, 'batch_stats': batch_stats} Args: params: Parameter pytree in pre-Linen format. If the pytree is already in Linen format, then the returned pytree is unchanged (i.e. this function can safely be called on any loaded checkpoint for use with Linen). Returns: Parameter pytree with Linen submodule naming. """ if not isinstance(params, (dict, core.FrozenDict)): return params params_renamed = {} counts: dict[Any, Any] = {} names = natural_sort(params.keys()) for name in names: value = params[name] match = MODULE_NUM_RE.match(name) if match: module = match.group(1) num = counts.get(module, 0) name = f'{module}_{num}' counts[module] = num + 1 params_renamed[name] = convert_pre_linen(value) if isinstance(params, core.FrozenDict): params_renamed = core.freeze(params_renamed) # type: ignore return params_renamed ================================================ FILE: flax/training/common_utils.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Common utility functions used in data-parallel Flax examples. This module is a historical grab-bag of utility functions primarily concerned with helping write pmap-based data-parallel training loops. """ import jax import jax.numpy as jnp from jax import lax def shard(xs): """Helper for pmap to shard a pytree of arrays by local_device_count. Args: xs: a pytree of arrays. Returns: A matching pytree with arrays' leading dimensions sharded by the local device count. """ local_device_count = jax.local_device_count() return jax.tree_util.tree_map( lambda x: x.reshape((local_device_count, -1) + x.shape[1:]), xs ) def shard_prng_key(prng_key): """Helper to shard (aka split) a PRNGKey for use with pmap'd functions. PRNG keys can be used at train time to drive stochastic modules e.g. Dropout. We would like a different PRNG key for each local device so that we end up with different random numbers on each one, hence we split our PRNG key. Args: prng_key: JAX PRNGKey Returns: A new array of PRNGKeys with leading dimension equal to local device count. """ return jax.random.split(prng_key, num=jax.local_device_count()) def stack_forest(forest): """Helper function to stack the leaves of a sequence of pytrees. Args: forest: a sequence of pytrees (e.g tuple or list) of matching structure whose leaves are arrays with individually matching shapes. Returns: A single pytree of the same structure whose leaves are individually stacked arrays. """ stack_args = lambda *args: jnp.stack(args) return jax.tree_util.tree_map(stack_args, *forest) def get_metrics(device_metrics): """Helper utility for pmap, gathering replicated timeseries metric data. Args: device_metrics: replicated, device-resident pytree of metric data, whose leaves are presumed to be a sequence of arrays recorded over time. Returns: A pytree of unreplicated, host-resident, stacked-over-time arrays useful for computing host-local statistics and logging. """ # We select the first element of x in order to get a single copy of a # device-replicated metric. device_metrics = jax.tree_util.tree_map( lambda x: x.addressable_shards[0].data.squeeze(0), device_metrics ) metrics_np = jax.device_get(device_metrics) return stack_forest(metrics_np) def onehot(labels, num_classes, on_value=1.0, off_value=0.0): """Create a dense one-hot version of an indexed array. NB: consider using the more standard ``jax.nn.one_hot`` instead. Args: labels: an n-dim JAX array whose last dimension contains integer indices. num_classes: the maximum possible index. on_value: the "on" value for the one-hot array, defaults to 1.0. off_value: the "off" value for the one-hot array, defaults to 0.0. Returns: A (n+1)-dim array whose last dimension contains one-hot vectors of length num_classes. """ x = labels[..., None] == jnp.arange(num_classes).reshape( (1,) * labels.ndim + (-1,) ) x = lax.select(x, jnp.full(x.shape, on_value), jnp.full(x.shape, off_value)) return x.astype(jnp.float32) ================================================ FILE: flax/training/dynamic_scale.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Dynamic loss scaling for mixed precision gradients.""" import functools from typing import Any, NamedTuple from collections.abc import Callable, Sequence import jax import jax.numpy as jnp from jax import lax from flax import struct from flax.typing import Array class DynamicScaleResult(NamedTuple): dynamic_scale: 'DynamicScale' finite: Array aux: Any grad: Any class DynamicScale(struct.PyTreeNode): """Dynamic loss scaling for mixed precision gradients. For many models gradient computations in float16 will result in numerical issues because small/large gradients being flushed to zero/infinity. Dynamic loss scaling is an algorithm that aims to find the largest scalar multiple for which the gradient does not overflow. This way the risk of underflow is minimized. the `value_and_grad` method mimicks `jax.value_and_grad`. Beside the loss and gradients it also ouputs and updated `DynamicScale` instance with the current loss scale factor. This method also returns a boolean value indicating whether the gradients are finite. Example:: from flax.training.dynamic_scale import DynamicScale def loss_fn(p): return jnp.asarray(p, jnp.float16) ** 2 p = jnp.array(1., jnp.float32) dyn_scale = DynamicScale(growth_interval=10) compute_grad = jax.jit(lambda ds, p: ds.value_and_grad(loss_fn)(p)) for _ in range(100): dyn_scale, is_fin, loss, grad = compute_grad(dyn_scale, p) p += jnp.where(is_fin, 0.01 * grad, 0.) print(loss) Jax currently cannot execute conditionals efficiently on GPUs therefore we selectively ignore the gradient update using `jax.numpy.where` in case of non-finite gradients. Attributes: growth_factor: how much to grow the scalar after a period of finite gradients (default: 2.). backoff_factor: how much to shrink the scalar after a non-finite gradient (default: 0.5). growth_interval: after how many steps of finite gradients the scale should be increased (default: 2000). fin_steps: indicates how many gradient steps in a row have been finite. scale: the current scale by which the loss is multiplied. minimum_scale: the minimum value that the scale can take (default: the smallest positive number representable in floating point). """ growth_factor: float = struct.field(pytree_node=False, default=2.0) backoff_factor: float = struct.field(pytree_node=False, default=0.5) growth_interval: int = struct.field(pytree_node=False, default=2000) fin_steps: int = 0 scale: float = 65536.0 minimum_scale: float | None = struct.field( pytree_node=False, default=jnp.finfo(jnp.float32).tiny ) def value_and_grad( self, fun: Callable[..., Any], argnums: int | Sequence[int] = 0, has_aux: bool = False, axis_name: str | None = None, ) -> Callable[..., DynamicScaleResult]: """Wrapper around `jax.value_and_grad`. Args: fun: Function to be differentiated. Its arguments at positions specified by ``argnums`` should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape ``()`` but not arrays with shape ``(1,)`` etc.) argnums: Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0). has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False. axis_name: If an axis is given the gradients will be averaged across replicas (default: None). Note, this is only used for pmap and shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives. Returns: A function that takes the same arguments as `fun` and returns a DynamicScaleResult """ @functools.wraps(fun) def loss_wrapper(*args): aux = fun(*args) if has_aux: return (self.scale * aux[0], aux[1]) else: return self.scale * aux grad_fn = jax.value_and_grad(loss_wrapper, argnums, has_aux) def grad_fn_wrapper(*args): aux, grad = grad_fn(*args) aux = (aux[0] / self.scale, aux[1]) if has_aux else aux / self.scale grad = jax.tree_util.tree_map( lambda g: jnp.asarray(g, jnp.float32) / self.scale, grad ) if axis_name is not None: grad = lax.pmean(grad, axis_name) finite = jnp.array(True) for g in jax.tree_util.tree_leaves(grad): finite &= jnp.all(lax.is_finite(g)) grow = self.fin_steps == self.growth_interval fin_scale = jnp.where( grow & finite, jnp.minimum( self.scale * self.growth_factor, jnp.finfo(jnp.float32).max ), self.scale, ) inf_scale = self.scale * self.backoff_factor if self.minimum_scale is not None: inf_scale = jnp.maximum(inf_scale, self.minimum_scale) new_scale = jnp.where(finite, fin_scale, inf_scale) new_fin_steps = jnp.where(grow | (~finite), 0, self.fin_steps + 1) new_self = self.replace(fin_steps=new_fin_steps, scale=new_scale) return DynamicScaleResult(new_self, finite, aux, grad) return grad_fn_wrapper ================================================ FILE: flax/training/early_stopping.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Early stopping.""" import math from flax import struct class EarlyStopping(struct.PyTreeNode): """Early stopping to avoid overfitting during training. The following example stops training early if the difference between losses recorded in the current epoch and previous epoch is less than 1e-3 consecutively for 2 times:: >>> from flax.training.early_stopping import EarlyStopping >>> def train_epoch(optimizer, train_ds, batch_size, epoch, input_rng): ... ... ... loss = [4, 3, 3, 3, 2, 2, 2, 2, 1, 1][epoch] ... return None, {'loss': loss} >>> early_stop = EarlyStopping(min_delta=1e-3, patience=2) >>> optimizer = None >>> for epoch in range(10): ... optimizer, train_metrics = train_epoch( ... optimizer=optimizer, train_ds=None, batch_size=None, epoch=epoch, input_rng=None) ... early_stop = early_stop.update(train_metrics['loss']) ... if early_stop.should_stop: ... print(f'Met early stopping criteria, breaking at epoch {epoch}') ... break Met early stopping criteria, breaking at epoch 7 Attributes: min_delta: Minimum delta between updates to be considered an improvement. patience: Number of steps of no improvement before stopping. best_metric: Current best metric value. patience_count: Number of steps since last improving update. should_stop: Whether the training loop should stop to avoid overfitting. has_improved: Whether the metric has improved greater or equal to the min_delta in the last ``.update`` call. """ min_delta: float = 0 patience: int = 0 best_metric: float = float('inf') patience_count: int = 0 should_stop: bool = False has_improved: bool = False def reset(self): return self.replace( best_metric=float('inf'), patience_count=0, should_stop=False, has_improved=False, ) def update(self, metric): """Update the state based on metric. Returns: The updated EarlyStopping class. The ``.has_improved`` attribute is True when there was an improvement greater than ``min_delta`` from the previous ``best_metric``. """ if ( math.isinf(self.best_metric) or self.best_metric - metric > self.min_delta ): return self.replace( best_metric=metric, patience_count=0, has_improved=True ) else: should_stop = self.patience_count >= self.patience or self.should_stop return self.replace( patience_count=self.patience_count + 1, should_stop=should_stop, has_improved=False, ) ================================================ FILE: flax/training/lr_schedule.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Learning rate schedules used in FLAX image classification examples. Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are **effectively deprecated** in favor of Optax_ schedules. Please refer to `Optimizer Schedules`_ for more information. .. _FLIP #1009: https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md .. _Optax: https://github.com/deepmind/optax .. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules """ import jax.numpy as jnp import numpy as np from absl import logging def _piecewise_constant(boundaries, values, t): index = jnp.sum(boundaries < t) return jnp.take(values, index) def create_constant_learning_rate_schedule( base_learning_rate, steps_per_epoch, warmup_length=0.0 ): """Create a constant learning rate schedule with optional warmup. Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are **effectively deprecated** in favor of Optax_ schedules. Please refer to `Optimizer Schedules`_ for more information. .. _FLIP #1009: https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md .. _Optax: https://github.com/deepmind/optax .. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules Holds the learning rate constant. This function also offers a learing rate warmup as per https://arxiv.org/abs/1706.02677, for the purpose of training with large mini-batches. Args: base_learning_rate: the base learning rate steps_per_epoch: the number of iterations per epoch warmup_length: if > 0, the learning rate will be modulated by a warmup factor that will linearly ramp-up from 0 to 1 over the first ``warmup_length`` epochs Returns: Function ``f(step) -> lr`` that computes the learning rate for a given step. """ logging.warning( 'Learning rate schedules in ``flax.training`` are effectively deprecated ' 'in favor of Optax schedules. Please refer to ' 'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules' ' for alternatives.' ) def learning_rate_fn(step): lr = base_learning_rate if warmup_length > 0.0: lr = lr * jnp.minimum(1.0, step / float(warmup_length) / steps_per_epoch) return lr return learning_rate_fn def create_stepped_learning_rate_schedule( base_learning_rate, steps_per_epoch, lr_sched_steps, warmup_length=0.0 ): """Create a stepped learning rate schedule with optional warmup. Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are **effectively deprecated** in favor of Optax_ schedules. Please refer to `Optimizer Schedules`_ for more information. .. _FLIP #1009: https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md .. _Optax: https://github.com/deepmind/optax .. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules A stepped learning rate schedule decreases the learning rate by specified amounts at specified epochs. The steps are given as the ``lr_sched_steps`` parameter. A common ImageNet schedule decays the learning rate by a factor of 0.1 at epochs 30, 60 and 80. This would be specified as:: [ [30, 0.1], [60, 0.01], [80, 0.001] ] This function also offers a learing rate warmup as per https://arxiv.org/abs/1706.02677, for the purpose of training with large mini-batches. Args: base_learning_rate: the base learning rate steps_per_epoch: the number of iterations per epoch lr_sched_steps: the schedule as a list of steps, each of which is a ``[epoch, lr_factor]`` pair; the step occurs at epoch ``epoch`` and sets the learning rate to ``base_learning_rage * lr_factor`` warmup_length: if > 0, the learning rate will be modulated by a warmup factor that will linearly ramp-up from 0 to 1 over the first ``warmup_length`` epochs Returns: Function ``f(step) -> lr`` that computes the learning rate for a given step. """ logging.warning( 'Learning rate schedules in ``flax.training`` are effectively deprecated ' 'in favor of Optax schedules. Please refer to ' 'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules' ' for alternatives.' ) boundaries = [step[0] for step in lr_sched_steps] decays = [step[1] for step in lr_sched_steps] boundaries = np.array(boundaries) * steps_per_epoch boundaries = np.round(boundaries).astype(int) values = np.array([1.0] + decays) * base_learning_rate def learning_rate_fn(step): lr = _piecewise_constant(boundaries, values, step) if warmup_length > 0.0: lr = lr * jnp.minimum(1.0, step / float(warmup_length) / steps_per_epoch) return lr return learning_rate_fn def create_cosine_learning_rate_schedule( base_learning_rate, steps_per_epoch, halfcos_epochs, warmup_length=0.0 ): """Create a cosine learning rate schedule with optional warmup. Note that with `FLIP #1009`_ learning rate schedules in ``flax.training`` are **effectively deprecated** in favor of Optax_ schedules. Please refer to `Optimizer Schedules`_ for more information. .. _FLIP #1009: https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md .. _Optax: https://github.com/deepmind/optax .. _Optimizer Schedules: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules A cosine learning rate schedule modules the learning rate with half a cosine wave, gradually scaling it to 0 at the end of training. This function also offers a learing rate warmup as per https://arxiv.org/abs/1706.02677, for the purpose of training with large mini-batches. Args: base_learning_rate: the base learning rate steps_per_epoch: the number of iterations per epoch halfcos_epochs: the number of epochs to complete half a cosine wave; normally the number of epochs used for training warmup_length: if > 0, the learning rate will be modulated by a warmup factor that will linearly ramp-up from 0 to 1 over the first ``warmup_length`` epochs Returns: Function ``f(step) -> lr`` that computes the learning rate for a given step. """ logging.warning( 'Learning rate schedules in ``flax.training`` are effectively deprecated ' 'in favor of Optax schedules. Please refer to ' 'https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules' ' for alternatives.' ) halfwavelength_steps = halfcos_epochs * steps_per_epoch def learning_rate_fn(step): scale_factor = jnp.cos(step * jnp.pi / halfwavelength_steps) * 0.5 + 0.5 lr = base_learning_rate * scale_factor if warmup_length > 0.0: lr = lr * jnp.minimum(1.0, step / float(warmup_length) / steps_per_epoch) return lr return learning_rate_fn ================================================ FILE: flax/training/orbax_utils.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utils for Orbax Checkpointing, available even after Flax Checkpointing is deprecated.""" import warnings from typing import Any import jax import numpy as np import orbax.checkpoint as ocp from jax.sharding import Mesh PyTree = Any def is_multi_device_array(value: Any) -> bool: """Instruct Orbax to save this array with Tensorstore instead of msgpack.""" if isinstance(value, jax.Array): return not value.is_fully_replicated return False def save_args_from_target(target: Any) -> Any: return jax.tree_util.tree_map(lambda _: ocp.SaveArgs(), target) def maybe_construct_transformations( target: Any, transforms: Any | None ) -> Any: if transforms is not None: return transforms flat_transforms = {} flat_target = ocp.utils.to_flat_dict(target, sep='/', keep_empty_nodes=True) for k, v in flat_target.items(): if v is None: flat_transforms[k] = ocp.Transform(use_fallback=True) return flat_transforms def restore_args_from_target(target: Any, mesh: Mesh | None = None) -> Any: """Creates Orbax `restore_args` given a target Pytree. Args: target: The Pytree that has the same structure as the checkpoint. The arrays restored from checkpoint will have the same `sharding` as the target Pytree's corresponding arrays. mesh: DEPRECATED ARG. Please simply use your mesh to create the arrays in your `target`, no need to pass it here. Returns: A Pytree of Orbax `RestoreArgs` or `ArrayRestoreArgs` """ def find_sharding(x): if hasattr(x, 'sharding'): return x.sharding return None # Simpler case: no JAX arrays if not any( jax.tree_util.tree_flatten(jax.tree_util.tree_map(find_sharding, target))[ 0 ] ): return jax.tree_util.tree_map( lambda x: ocp.RestoreArgs(restore_type=np.ndarray), target ) # JAX arrays: find sharding from the given target and create RestoreArgs sharding_tree = jax.tree_util.tree_map(find_sharding, target) if mesh is not None: warnings.warn( ( 'restore_args_from_target(): `mesh` arg is deprecated. Simply' ' calling the function with target pytree should suffice.' ), DeprecationWarning, ) def substitute_embedding(s): return jax.sharding.NamedSharding(mesh, s.spec) sharding_tree = jax.tree_util.tree_map(substitute_embedding, sharding_tree) restore_args = ocp.checkpoint_utils.construct_restore_args( target, sharding_tree, set_global_shape=False ) return restore_args ================================================ FILE: flax/training/prefetch_iterator.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utility for constructing an iterator which prefetches data asynchronously.""" import threading import warnings class PrefetchIterator: """Wraps an iterator to provide async prefetching. DEPRECATION WARNING: TensorFlow datasets no longer require manual prefetching. Previously this class was used to make data loading using TensorFlow datasets more efficient. Now TF data handles prefetching with NumPy iterators correctly. Example:: tf_iter = dataset.as_numpy_iterator() # only loads data while calling next tf_iter = PrefetchIterator(tf_iter) # prefetches data in the background """ def __init__(self, data_iter, buffer_size=1): """Construct a PrefetchIterator. Args: data_iter: the Iterator that should be prefetched. buffer_size: how many items to prefetch (default: 1). """ warnings.warn( 'PrefetchIterator is deprecated. Use the standard `tf.data`' ' prefetch method instead', DeprecationWarning, ) self._data_iter = data_iter self.buffer_size = buffer_size self._cond = threading.Condition() self._buffer = [] self._active = True self._thread = threading.Thread(target=self._prefetch_loop, daemon=True) self._thread.start() self._error = None def __iter__(self): return self def __next__(self): with self._cond: self._cond.wait_for(lambda: self._buffer or not self._active) if self._buffer: item = self._buffer.pop(0) self._cond.notify_all() return item if self._error: raise self._error # pylint: disable=raising-bad-type assert not self._active raise StopIteration() def close(self): with self._cond: self._active = False self._cond.notify_all() def _prefetch_loop(self): """Prefetch loop that prefetches a tf dataset.""" def _predicate(): return len(self._buffer) < self.buffer_size or not self._active while True: try: item = next(self._data_iter) with self._cond: self._buffer.append(item) self._cond.notify_all() self._cond.wait_for(_predicate) if not self._active: return except Exception as e: # pylint: disable=broad-except with self._cond: self._error = e self._active = False self._cond.notify_all() return ================================================ FILE: flax/training/train_state.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any from collections.abc import Callable import optax import jax from flax import core, struct from flax.linen.fp8_ops import OVERWRITE_WITH_GRADIENT class TrainState(struct.PyTreeNode): """Simple train state for the common case with a single Optax optimizer. Example usage:: >>> import flax.linen as nn >>> from flax.training.train_state import TrainState >>> import jax, jax.numpy as jnp >>> import optax >>> x = jnp.ones((1, 2)) >>> y = jnp.ones((1, 2)) >>> model = nn.Dense(2) >>> variables = model.init(jax.random.key(0), x) >>> tx = optax.adam(1e-3) >>> state = TrainState.create( ... apply_fn=model.apply, ... params=variables['params'], ... tx=tx) >>> def loss_fn(params, x, y): ... predictions = state.apply_fn({'params': params}, x) ... loss = optax.l2_loss(predictions=predictions, targets=y).mean() ... return loss >>> loss_fn(state.params, x, y) Array(1.8136346, dtype=float32) >>> grads = jax.grad(loss_fn)(state.params, x, y) >>> state = state.apply_gradients(grads=grads) >>> loss_fn(state.params, x, y) Array(1.8079796, dtype=float32) Note that you can easily extend this dataclass by subclassing it for storing additional data (e.g. additional variable collections). For more exotic usecases (e.g. multiple optimizers) it's probably best to fork the class and modify it. Args: step: Counter starts at 0 and is incremented by every call to ``.apply_gradients()``. apply_fn: Usually set to ``model.apply()``. Kept in this dataclass for convenience to have a shorter params list for the ``train_step()`` function in your training loop. params: The parameters to be updated by ``tx`` and used by ``apply_fn``. tx: An Optax gradient transformation. opt_state: The state for ``tx``. """ step: int | jax.Array apply_fn: Callable = struct.field(pytree_node=False) params: core.FrozenDict[str, Any] = struct.field(pytree_node=True) tx: optax.GradientTransformation = struct.field(pytree_node=False) opt_state: optax.OptState = struct.field(pytree_node=True) def apply_gradients(self, *, grads, **kwargs): """Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value. Note that internally this function calls ``.tx.update()`` followed by a call to ``optax.apply_updates()`` to update ``params`` and ``opt_state``. Args: grads: Gradients that have the same pytree structure as ``.params``. **kwargs: Additional dataclass attributes that should be ``.replace()``-ed. Returns: An updated instance of ``self`` with ``step`` incremented by one, ``params`` and ``opt_state`` updated by applying ``grads``, and additional attributes replaced as specified by ``kwargs``. """ if OVERWRITE_WITH_GRADIENT in grads: grads_with_opt = grads['params'] params_with_opt = self.params['params'] else: grads_with_opt = grads params_with_opt = self.params updates, new_opt_state = self.tx.update( grads_with_opt, self.opt_state, params_with_opt ) new_params_with_opt = optax.apply_updates(params_with_opt, updates) # As implied by the OWG name, the gradients are used directly to update the # parameters. if OVERWRITE_WITH_GRADIENT in grads: new_params = { 'params': new_params_with_opt, OVERWRITE_WITH_GRADIENT: grads[OVERWRITE_WITH_GRADIENT], } else: new_params = new_params_with_opt return self.replace( step=self.step + 1, params=new_params, opt_state=new_opt_state, **kwargs, ) @classmethod def create(cls, *, apply_fn, params, tx, **kwargs): """Creates a new instance with ``step=0`` and initialized ``opt_state``.""" # We exclude OWG params when present because they do not need opt states. params_with_opt = ( params['params'] if OVERWRITE_WITH_GRADIENT in params else params ) opt_state = tx.init(params_with_opt) return cls( step=0, apply_fn=apply_fn, params=params, tx=tx, opt_state=opt_state, **kwargs, ) ================================================ FILE: flax/traverse_util.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """A utility for traversing immutable datastructures. A Traversal can be used to iterate and update complex data structures. Traversals take in an object and return a subset of its contents. For example, a Traversal could select an attribute of an object:: >>> from flax import traverse_util >>> import dataclasses >>> @dataclasses.dataclass ... class Foo: ... foo: int = 0 ... bar: int = 0 ... >>> x = Foo(foo=1) >>> iterator = traverse_util.TraverseAttr('foo').iterate(x) >>> list(iterator) [1] More complex traversals can be constructed using composition. It is often useful to start from the identity traversal and use a method chain to construct the intended Traversal:: >>> data = [{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 4}] >>> traversal = traverse_util.t_identity.each()['foo'] >>> iterator = traversal.iterate(data) >>> list(iterator) [1, 3] Traversals can also be used to make changes using the ``update`` method:: >>> data = {'foo': Foo(bar=2)} >>> traversal = traverse_util.t_identity['foo'].bar >>> data = traversal.update(lambda x: x + x, data) >>> data {'foo': Foo(foo=0, bar=4)} Traversals never mutate the original data. Therefore, an update essentially returns a copy of the data including the provided updates. """ import abc import copy import dataclasses import warnings from typing import Any from collections.abc import Callable import jax import flax from flax.core.scope import VariableDict from flax.typing import PathParts from . import struct # the empty node is a struct.dataclass to be compatible with JAX. @struct.dataclass class _EmptyNode: pass empty_node = _EmptyNode() def _flatten(xs, prefix, keep_empty_nodes, is_leaf, sep): def _key(path): if sep is None: return path return sep.join(path) if not isinstance(xs, (flax.core.FrozenDict, dict)) or ( is_leaf and is_leaf(prefix, xs) ): return {_key(prefix): xs} result = {} is_empty = True for key, value in xs.items(): is_empty = False path = prefix + (key,) result.update(_flatten(value, path, keep_empty_nodes, is_leaf, sep)) if keep_empty_nodes and is_empty: if prefix == (): # when the whole input is empty return {} return {_key(prefix): empty_node} return result def flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None): """Flatten a nested dictionary. The nested keys are flattened to a tuple. See ``unflatten_dict`` on how to restore the nested dictionary structure. Example:: >>> from flax.traverse_util import flatten_dict >>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}} >>> flat_xs = flatten_dict(xs) >>> flat_xs {('foo',): 1, ('bar', 'a'): 2} Note that empty dictionaries are ignored and will not be restored by ``unflatten_dict``. Args: xs: a nested dictionary keep_empty_nodes: replaces empty dictionaries with ``traverse_util.empty_node``. is_leaf: an optional function that takes the next nested dictionary and nested keys and returns True if the nested dictionary is a leaf (i.e., should not be flattened further). sep: if specified, then the keys of the returned dictionary will be ``sep``-joined strings (if ``None``, then keys will be tuples). Returns: The flattened dictionary. """ assert isinstance( xs, (flax.core.FrozenDict, dict) ), f'expected (frozen)dict; got {type(xs)}' return _flatten(xs, (), keep_empty_nodes, is_leaf, sep) def unflatten_dict(xs, sep=None): """Unflatten a dictionary. See ``flatten_dict`` Example:: >>> flat_xs = { ... ('foo',): 1, ... ('bar', 'a'): 2, ... } >>> xs = unflatten_dict(flat_xs) >>> xs {'foo': 1, 'bar': {'a': 2}} Args: xs: a flattened dictionary sep: separator (same as used with ``flatten_dict()``). Returns: The nested dictionary. """ assert isinstance(xs, dict), f'input is not a dict; it is a {type(xs)}' result = {} for path, value in xs.items(): if sep is not None: path = path.split(sep) if value is empty_node: value = {} cursor = result for key in path[:-1]: if key not in cursor: cursor[key] = {} cursor = cursor[key] cursor[path[-1]] = value return result def path_aware_map( f: Callable[[PathParts, Any], Any], nested_dict: VariableDict ) -> VariableDict: """A map function that operates over nested dictionary structures while taking the path to each leaf into account. Example:: >>> import jax.numpy as jnp >>> from flax import traverse_util >>> params = {'a': {'x': 10, 'y': 3}, 'b': {'x': 20}} >>> f = lambda path, x: x + 5 if 'x' in path else -x >>> traverse_util.path_aware_map(f, params) {'a': {'x': 15, 'y': -3}, 'b': {'x': 25}} Args: f: A callable that takes in ``(path, value)`` arguments and maps them to a new value. Here ``path`` is a tuple of strings. nested_dict: A nested dictionary structure. Returns: A new nested dictionary structure with the mapped values. """ flat = flatten_dict(nested_dict, keep_empty_nodes=True) return unflatten_dict( {k: f(k, v) if v is not empty_node else v for k, v in flat.items()} ) class Traversal(abc.ABC): """Base class for all traversals.""" def __new__(cls, *args, **kwargs): # Must override __new__ instead of __init__ since this is an ABC warnings.warn( '`flax.traverse_util.Traversal` will be deprecated. If you are using ' 'it for `flax.optim`, use `optax` instead. Refer to the update guide ' 'https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/optax_update_guide.html ' 'for detailed instructions.', DeprecationWarning, ) return super().__new__(cls) @abc.abstractmethod def update(self, fn, inputs): """Update the focused items. Args: fn: the callback function that maps each traversed item to its updated value. inputs: the object that should be traversed. Returns: A new object with the updated values. """ pass @abc.abstractmethod def iterate(self, inputs): """Iterate over the values selected by this ``Traversal``. Args: inputs: the object that should be traversed. Returns: An iterator over the traversed values. """ pass def set(self, values, inputs): """Overrides the values selected by the ``Traversal``. Args: values: a list containing the new values. inputs: the object that should be traversed. Returns: A new object with the updated values. """ def update_fn(_): if not values: raise ValueError('Not enough values provided') return values.pop(0) y = self.update(update_fn, inputs) if values: raise ValueError('Too many values provided') return y def compose(self, other): """Compose two traversals.""" return TraverseCompose(self, other) def merge(self, *traversals): """Compose an arbitrary number of traversals and merge the results.""" return self.compose(TraverseMerge(*traversals)) def each(self): """Traverse each item in the selected containers.""" return self.compose(TraverseEach()) def tree(self): """Traverse each item in a pytree.""" return self.compose(TraverseTree()) def filter(self, fn): """Filter the selected values.""" return self.compose(TraverseFilter(fn)) def __getattr__(self, attr): return self.compose(TraverseAttr(attr)) def __getitem__(self, key): return self.compose(TraverseItem(key)) class TraverseId(Traversal): """The identity Traversal.""" def update(self, fn, inputs): return fn(inputs) def iterate(self, inputs): yield inputs with warnings.catch_warnings(): warnings.simplefilter('ignore', DeprecationWarning) t_identity = TraverseId() class TraverseMerge(Traversal): """Merges the selection from a set of traversals.""" def __init__(self, *traversals): self._traversals = traversals def update(self, fn, inputs): for traversal in self._traversals: inputs = traversal.update(fn, inputs) return inputs def iterate(self, inputs): for traversal in self._traversals: yield from traversal.iterate(inputs) class TraverseCompose(Traversal): """Compose two traversals.""" def __init__(self, x, y): self._x = x self._y = y def update(self, fn, inputs): def update_fn(x): return self._y.update(fn, x) return self._x.update(update_fn, inputs) def iterate(self, inputs): for x in self._x.iterate(inputs): yield from self._y.iterate(x) class TraverseFilter(Traversal): """Filter selected values based on a predicate.""" def __init__(self, fn): self._fn = fn def update(self, fn, inputs): if self._fn(inputs): return fn(inputs) else: return inputs def iterate(self, inputs): if self._fn(inputs): yield inputs def _is_namedtuple(t): return issubclass(t, tuple) and hasattr(t, '_fields') class TraverseAttr(Traversal): """Traverse the attribute of an object.""" def __init__(self, attr): self._attr = attr def update(self, fn, inputs): value = fn(getattr(inputs, self._attr)) if _is_namedtuple(type(inputs)): return inputs._replace(**{self._attr: value}) elif dataclasses.is_dataclass(inputs): return dataclasses.replace(inputs, **{self._attr: value}) else: inputs = copy.copy(inputs) setattr(inputs, self._attr, value) return inputs def iterate(self, inputs): yield getattr(inputs, self._attr) class TraverseItem(Traversal): """Traverse the item of an object.""" def __init__(self, key): self._key = key def update(self, fn, inputs): if isinstance(inputs, tuple): ty = type(inputs) if isinstance(self._key, slice): sl = self._key else: sl = slice(self._key, self._key + 1) indices = set(range(*sl.indices(len(inputs)))) args = [ fn(inputs[i]) if i in indices else inputs[i] for i in range(len(inputs)) ] if _is_namedtuple(ty): return ty(*args) else: return ty(args) else: xs = copy.copy(inputs) xs[self._key] = fn(xs[self._key]) return xs def iterate(self, inputs): if isinstance(self._key, slice): yield from inputs[self._key] else: yield inputs[self._key] class TraverseEach(Traversal): """Traverse each item of a container.""" def update(self, fn, inputs): ty = type(inputs) if ty is dict: return {key: fn(val) for key, val in inputs.items()} if ty not in {list, tuple}: raise ValueError('Only the entries of a list or tuple can be traversed.') return ty(fn(x) for x in inputs) def iterate(self, inputs): if isinstance(inputs, dict): yield from inputs.values() else: yield from inputs class TraverseTree(Traversal): """Traverse every item in a pytree.""" def update(self, fn, inputs): return jax.tree_util.tree_map(fn, inputs) def iterate(self, inputs): yield from jax.tree_util.tree_leaves(inputs) def _get_params_dict(inputs): if isinstance(inputs, (dict, flax.core.FrozenDict)): return flax.core.unfreeze(inputs) else: raise ValueError( 'Can only traverse a flax Model instance or a nested dict, not ' f'{type(inputs)}' ) def _sorted_items(x): """Returns items of a dict ordered by keys.""" return sorted(x.items(), key=lambda x: x[0]) class ModelParamTraversal(Traversal): """Select model parameters using a name filter. This traversal operates on a nested dictionary of parameters and selects a subset based on the ``filter_fn`` argument. See :class:`flax.optim.MultiOptimizer` for an example of how to use :class:`ModelParamTraversal` to update subsets of the parameter tree with a specific optimizer. """ def __init__(self, filter_fn): """Constructor a new ModelParamTraversal. Args: filter_fn: a function that takes a parameter's full name and its value and returns whether this parameter should be selected or not. The name of a parameter is determined by the module hierarchy and the parameter name (for example: '/module/sub_module/parameter_name'). """ self._filter_fn = filter_fn def iterate(self, inputs): params = _get_params_dict(inputs) flat_dict = flatten_dict(params) for key, value in _sorted_items(flat_dict): path = '/' + '/'.join(key) if self._filter_fn(path, value): yield value def update(self, fn, inputs): params = _get_params_dict(inputs) flat_dict = flatten_dict(params, keep_empty_nodes=True) new_dict = {} for key, value in _sorted_items(flat_dict): # empty_node is not an actual leave. It's just a stub for empty nodes # in the nested dict. if value is not empty_node: path = '/' + '/'.join(key) if self._filter_fn(path, value): value = fn(value) new_dict[key] = value new_params = unflatten_dict(new_dict) if isinstance(inputs, flax.core.FrozenDict): return flax.core.FrozenDict(new_params) else: return new_params ================================================ FILE: flax/typing.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import abc from collections import deque import functools from functools import partial from typing import ( Any, Generic, Optional, Protocol, TypeGuard, TypeVar, Union, ) from collections.abc import Iterator from collections.abc import Callable, Hashable, Mapping, Sequence import jax import jax.numpy as jnp import numpy as np from flax.core import FrozenDict import dataclasses import jax.tree_util as jtu # General Array = Union[jax.Array, Any] PRNGKey = jax.Array RNGSequences = dict[str, PRNGKey] Dtype = Union[jax.typing.DTypeLike, Any] Shape = Sequence[int] K = TypeVar('K') class Key(Hashable, Protocol): def __lt__(self: K, value: K, /) -> bool: ... def is_key_like(x: Any) -> TypeGuard[Key]: return hasattr(x, '__hash__') and hasattr(x, '__lt__') Path = str PathParts = tuple[Key, ...] Leaf = Any # Linear PrecisionLike = Union[ None, str, jax.lax.Precision, tuple[str, str], tuple[jax.lax.Precision, jax.lax.Precision], ] DotGeneralT = Callable[..., Array] ConvGeneralDilatedT = Callable[..., Array] EinsumT = Callable[..., Array] PaddingLike = Union[str, int, Sequence[Union[int, tuple[int, int]]]] LaxPadding = Union[str, Sequence[tuple[int, int]]] # Initializers Initializer = Union[jax.nn.initializers.Initializer, Callable[..., Any]] # Collections Collection = Mapping[str, Any] MutableCollection = dict[str, Any] # Dicts VariableDict = Mapping[str, Collection] FrozenVariableDict = FrozenDict[str, Collection] MutableVariableDict = dict[str, MutableCollection] PRNGFoldable = Union[int, str] # Axes T = TypeVar('T') @dataclasses.dataclass(frozen=True) class In(Generic[T]): """Specifies a variable collection should only be lifted as input.""" axis: T @dataclasses.dataclass(frozen=True) class Out(Generic[T]): """Specifies a variable collection should only be lifted as output.""" axis: T Axis = Optional[int] InOutAxis = Union[Axis, In[Axis], Out[Axis]] ScanAxis = int InOutScanAxis = Union[ScanAxis, In[ScanAxis], Out[ScanAxis]] Axes = Union[int, Sequence[int]] # SPMD LogicalNames = tuple[Union[str, None], ...] AxisName = str | tuple[str, ...] | None # Maps each logical axis to physical mesh, can be either None (replicated), # one physical axis or a tuple of physical axes. LogicalRules = Sequence[tuple[str, AxisName]] ArrayPytree = Any # pylint: disable=invalid-name LogicalPartitionSpec = Any # pylint: disable=invalid-name LogicalPartitionSpecPytree = Any # pylint: disable=invalid-name PartitionSpecPytree = Any # pylint: disable=invalid-name Sharding = tuple[AxisName, ...] A = TypeVar('A') HA = TypeVar('HA', bound=Hashable) HB = TypeVar('HB') class PytreeDeque(deque[A]): pass def _pytree_deque_flatten(xs: PytreeDeque, *, with_path: bool): if with_path: nodes = tuple((jtu.SequenceKey(i), x) for i, x in enumerate(xs)) return nodes, () else: return xs, () def _pytree_deque_unflatten(_, nodes): return PytreeDeque(nodes) jtu.register_pytree_with_keys( PytreeDeque, partial(_pytree_deque_flatten, with_path=True), _pytree_deque_unflatten, flatten_func=partial(_pytree_deque_flatten, with_path=False), ) class Missing: pass MISSING = Missing() def _bytes_repr(num_bytes): count, units = ( (f'{num_bytes / 1e9:,.1f}', 'GB') if num_bytes > 1e9 else (f'{num_bytes / 1e6:,.1f}', 'MB') if num_bytes > 1e6 else (f'{num_bytes / 1e3:,.1f}', 'KB') if num_bytes > 1e3 else (f'{num_bytes:,}', 'B') ) return f'{count} {units}' class ShapeDtype(Protocol): shape: Shape dtype: Dtype def has_shape_dtype(x: Any) -> TypeGuard[ShapeDtype]: return hasattr(x, 'shape') and hasattr(x, 'dtype') @dataclasses.dataclass(frozen=True, slots=True) class SizeBytes: # type: ignore[misc] size: int bytes: int @classmethod def from_array(cls, x: ShapeDtype): size = int(np.prod(x.shape)) dtype: jnp.dtype if isinstance(x.dtype, str): dtype = jnp.dtype(x.dtype) else: dtype = x.dtype # type: ignore bytes = size * dtype.itemsize # type: ignore return cls(size, bytes) def __add__(self, other: SizeBytes): return type(self)(self.size + other.size, self.bytes + other.bytes) def __bool__(self) -> bool: return bool(self.size) def __repr__(self) -> str: bytes_repr = _bytes_repr(self.bytes) return f'{self.size:,} ({bytes_repr})' @classmethod def from_any(cls, x): leaves = jax.tree.leaves(x) size_bytes = cls(0, 0) for leaf in leaves: if has_shape_dtype(leaf): size_bytes += cls.from_array(leaf) return size_bytes TupleArg = TypeVar('TupleArg', bound=tuple) class PromoteDtypeFn(Protocol): def __call__( self, args: TupleArg, /, *, dtype: Any = None, inexact: bool = True ) -> TupleArg: ... class HashableMapping(Mapping[HA, HB], Hashable): _mapping: dict[HA, HB] | Mapping[HA, HB] def __init__(self, mapping: Mapping[HA, HB], copy: bool = True): self._mapping = dict(mapping) if copy else mapping def __contains__(self, key: object) -> bool: return key in self._mapping def __getitem__(self, key: HA) -> HB: return self._mapping[key] def __iter__(self) -> Iterator[HA]: return iter(self._mapping) def __len__(self) -> int: return len(self._mapping) def __hash__(self) -> int: # use type-aware sorting to support int keys def _pytree__key_sort_fn(item: tuple[Any, Any]) -> tuple[int, Any]: key, _ = item if isinstance(key, int): return (0, key) elif isinstance(key, str): return (1, key) else: raise ValueError(f'Unsupported key type: {type(key)!r}') return hash(tuple(sorted(self._mapping.items(), key=_pytree__key_sort_fn))) def __eq__(self, other: Any) -> bool: return ( isinstance(other, HashableMapping) and self._mapping == other._mapping ) def __repr__(self) -> str: return repr(self._mapping) def update(self, other: Mapping[HA, HB]) -> HashableMapping[HA, HB]: """Updates the mapping with another mapping.""" mapping = dict(self._mapping) mapping.update(other) return HashableMapping(mapping, copy=False) F = TypeVar('F', bound=Callable[..., Any]) class BaseConfigContext(abc.ABC): @classmethod @abc.abstractmethod def get_default(cls): ... @classmethod @abc.abstractmethod def get_stack(cls) -> list: ... def __init__(self, value, /): stack = self.get_stack() if stack: self.prev_value = stack[-1] stack[-1] = value else: self.prev_value = None stack.append(value) self.new_value = value @classmethod def current_value(cls): stack = cls.get_stack() if stack: return stack[-1] return cls.get_default() def __enter__(self): if self.prev_value is not None: self.get_stack().insert(-1, self.prev_value) def __exit__(self, exc_type, exc_value, traceback): self.get_stack().pop() def __call__(self, f: F) -> F: self.get_stack().pop() if self.prev_value is not None: self.get_stack().append(self.prev_value) @functools.wraps(f) def wrapper(*args, **kwargs): self.get_stack().append(self.new_value) try: return f(*args, **kwargs) finally: self.get_stack().pop() return wrapper # type: ignore[return-value] ================================================ FILE: flax/version.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Current Flax version at head on Github.""" __version__ = '0.12.6' ================================================ FILE: flaxlib_src/.gitignore ================================================ /target # Byte-compiled / optimized / DLL files __pycache__/ .pytest_cache/ *.py[cod] # C extensions *.so # Distribution / packaging .Python .venv/ env/ bin/ build/ develop-eggs/ dist/ eggs/ lib/ lib64/ parts/ sdist/ var/ include/ man/ venv/ *.egg-info/ .installed.cfg *.egg # Installer logs pip-log.txt pip-delete-this-directory.txt pip-selfcheck.json # Unit test / coverage reports htmlcov/ .tox/ .coverage .cache nosetests.xml coverage.xml # Translations *.mo # Mr Developer .mr.developer.cfg .project .pydevproject # Rope .ropeproject # Django stuff: *.log *.pot .DS_Store # Sphinx documentation docs/_build/ # PyCharm .idea/ # VSCode .vscode/ # Pyenv .python-version # cibuildwheel /wheelhouse ================================================ FILE: flaxlib_src/CMakeLists.txt ================================================ # Set the minimum CMake version and policies for highest tested version cmake_minimum_required(VERSION 3.15...3.27) # Set up the project and ensure there is a working C++ compiler project(flaxlib LANGUAGES CXX) # Warn if the user invokes CMake directly if (NOT SKBUILD) message(WARNING "\ This CMake file is meant to be executed using 'scikit-build-core'. Running it directly will almost certainly not produce the desired result. If you are a user trying to install this package, use the command below, which will install all necessary build dependencies, compile the package in an isolated environment, and then install it. ===================================================================== $ pip install . ===================================================================== If you are a software developer, and this is your own package, then it is usually much more efficient to install the build dependencies in your environment once and use the following command that avoids a costly creation of a new virtual environment at every compilation: ===================================================================== $ pip install nanobind scikit-build-core[pyproject] $ pip install --no-build-isolation -ve . ===================================================================== You may optionally add -Ceditable.rebuild=true to auto-rebuild when the package is imported. Otherwise, you need to rerun the above after editing C++ files.") endif() # Try to import all Python components potentially needed by nanobind find_package(Python 3.8 REQUIRED COMPONENTS Interpreter Development.Module OPTIONAL_COMPONENTS Development.SABIModule) # Import nanobind through CMake's find_package mechanism find_package(nanobind CONFIG REQUIRED) # We are now ready to compile the actual extension module nanobind_add_module( # Name of the extension flaxlib_cpp # Target the stable ABI for Python 3.12+, which reduces # the number of binary wheels that must be built. This # does nothing on older Python versions STABLE_ABI # Source code goes here src/lib.cc ) # Install directive for scikit-build-core install(TARGETS flaxlib_cpp LIBRARY DESTINATION flaxlib) ================================================ FILE: flaxlib_src/Cargo.toml ================================================ [package] name = "flaxlib" version = "0.0.1-a1" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lib] name = "flaxlib" crate-type = ["cdylib"] [dependencies] pyo3 = "0.21.2" ================================================ FILE: flaxlib_src/LICENSE ================================================ Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: flaxlib_src/README.md ================================================ # flaxlib ## Build flaxlib from source Install necessary dependencies to build the C++ based package. ```shell pip install meson-python ninja build ``` Clone the Flax repository, navigate to the flaxlib source directory. ```shell git clone git@github.com:google/flax.git cd flax/flaxlib_src ``` Configure the build. ```shell mkdir -p subprojects meson wrap install robin-map meson wrap install nanobind meson setup builddir ``` Compile the code. You'll need to run this repeatedly if you modify the source code. Note that the actual wheel name will differ depending on your system. ```shell meson compile -C builddir python -m build . -w pip install dist/flaxlib-0.0.1-cp311-cp311-macosx_14_0_arm64.whl --force-reinstall ``` ================================================ FILE: flaxlib_src/pyproject.toml ================================================ [build-system] requires = ["scikit-build-core >=0.4.3", "nanobind >=1.3.2"] build-backend = "scikit_build_core.build" [project] name = "flaxlib" version = "0.0.1" requires-python = ">=3.10" classifiers = [ "Programming Language :: C++", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] [project.optional-dependencies] tests = [ "pytest", ] [tool.scikit-build] # Protect the configuration against future changes in scikit-build-core minimum-version = "0.4" # Setuptools-style build caching in a local directory build-dir = "build/{wheel_tag}" # Build stable ABI wheels for CPython 3.12+ wheel.py-api = "cp312" ================================================ FILE: flaxlib_src/src/flaxlib/__init__.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from .flaxlib_cpp import RefMap as RefMap from .flaxlib_cpp import IndexMap as IndexMap from .flaxlib_cpp import NodeDef as NodeDef from .flaxlib_cpp import VariableDef as VariableDef from .flaxlib_cpp import NodeRef as NodeRef ================================================ FILE: flaxlib_src/src/flaxlib/flaxlib_cpp.pyi ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import typing as tp RefMap = tp.MutableMapping[tp.Any, int] IndexMap = dict[int, tp.Any] class NodeDef: type: type index: int | None outer_index: int | None num_attributes: int metadata: tp.Any def with_no_outer_index(self) -> NodeDef: ... def with_same_outer_index(self) -> NodeDef: ... def __eq__(self, other: tp.Any) -> bool: ... def __hash__(self) -> int: ... def __getstate__( self, ) -> tuple[tp.Any, tp.Any, tp.Any, tp.Any, tp.Any]: ... @staticmethod def __setstate__( nodedef: NodeDef, state: tuple[tp.Any, tp.Any, tp.Any, tp.Any, tp.Any] ) -> None: ... class VariableDef: type: type index: int outer_index: int | None metadata: tp.Any def with_no_outer_index(self) -> VariableDef: ... def with_same_outer_index(self) -> VariableDef: ... def __eq__(self, other: tp.Any) -> bool: ... def __hash__(self) -> int: ... def __getstate__( self, ) -> tuple[tp.Any, int, tp.Any, tp.Any]: ... @staticmethod def __setstate__( variabledef: 'VariableDef', state: tuple[tp.Any, int, tp.Any, tp.Any] ) -> None: ... class NodeRef: index: int def __eq__(self, other: tp.Any) -> bool: ... def __hash__(self) -> int: ... def __getstate__(self) -> tuple[int]: ... @staticmethod def __setstate__(noderef: NodeRef, state: tuple[int]) -> None: ... ================================================ FILE: flaxlib_src/src/lib.cc ================================================ // Copyright 2024 The Flax Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include #include #include #include #include #include #include #include namespace nb = nanobind; using namespace nb::literals; // ----------------------------------- // helper functions // ----------------------------------- intptr_t nb_id(const nb::object &obj) { // Get the object ID return reinterpret_cast(obj.ptr()); } nb::tuple vector_to_tuple(const std::vector &vec) { if (vec.empty()) { return nb::tuple(); } else { return nb::tuple(nb::cast(vec)); } } // 1. Hash function for nb::object struct NbObjectHash { std::size_t operator()(const nb::object &obj) const { return nb::hash(obj); } }; // 2. Equality function for nb::object (Important!) struct NbObjectEqual { bool operator()(const nb::object &a, const nb::object &b) const { return a.equal(b); } }; namespace flaxlib { //--------------------------------------------------------------- // NNXContext //--------------------------------------------------------------- struct PythonContext { nb::object nnx; nb::object graph; nb::object jax; nb::object np; nb::object jax_Array; nb::object np_ndarray; nb::type_object GraphNodeImpl; nb::type_object PytreeNodeImpl; nb::type_object Object; nb::type_object Variable; nb::object get_node_impl; PythonContext() { nnx = nb::module_::import_("flax.nnx"); graph = nb::module_::import_("flax.nnx.graph"); jax = nb::module_::import_("jax"); np = nb::module_::import_("numpy"); jax_Array = jax.attr("Array"); np_ndarray = np.attr("ndarray"); GraphNodeImpl = graph.attr("GraphNodeImpl"); PytreeNodeImpl = graph.attr("PytreeNodeImpl"); Object = nnx.attr("Object"); Variable = graph.attr("Variable"); get_node_impl = graph.attr("get_node_impl"); } ~PythonContext() { graph.release(); jax.release(); np.release(); jax_Array.release(); np_ndarray.release(); GraphNodeImpl.release(); PytreeNodeImpl.release(); Variable.release(); get_node_impl.release(); } }; static std::optional _python_context; PythonContext &get_python_context() { if (!_python_context) { _python_context.emplace(); } return *_python_context; } //--------------------------------------------------------------- // IndexMap //--------------------------------------------------------------- struct IndexMap : public std::unordered_map { }; //--------------------------------------------------------------- // RefMap //--------------------------------------------------------------- struct RefMapKeysIterator { std::unordered_map>::iterator it; std::unordered_map>::iterator end; RefMapKeysIterator(std::unordered_map>::iterator it, std::unordered_map>::iterator end) : it(it), end(end) {} nb::object __next__() { if (it == end) { throw nb::stop_iteration(); } auto elem = it->second; ++it; return std::get<0>(elem); } }; struct RefMapItemsIterator { std::unordered_map>::iterator it; std::unordered_map>::iterator end; RefMapItemsIterator(std::unordered_map>::iterator it, std::unordered_map>::iterator end) : it(it), end(end) {} RefMapItemsIterator __iter__() { return *this; } nb::tuple __next__() { if (it == end) { throw nb::stop_iteration(); } auto elem = it->second; ++it; return nb::make_tuple(std::get<0>(elem), std::get<1>(elem)); } }; struct RefMap { std::unordered_map> mapping; RefMap() {} RefMap(const nb::object &iterable) : RefMap() { for (auto item : iterable) { nb::object obj = item[0]; auto value = nb::cast(item[1]); mapping[nb_id(obj)] = {obj, value}; } } void update(const RefMap &other) { for (const auto &[key_id, value_tuple] : other.mapping) { mapping[key_id] = value_tuple; } } int __getitem__(const nb::object &key) { return std::get<1>(mapping[nb_id(key)]); } void __setitem__(const nb::object &key, int value) { mapping[nb_id(key)] = std::make_tuple(key, value); } int __len__() const { return mapping.size(); } bool __contains__(const nb::object &key) const { return mapping.find(nb_id(key)) != mapping.end(); } RefMapKeysIterator __iter__() { return RefMapKeysIterator(mapping.begin(), mapping.end()); }; RefMapItemsIterator items() { return RefMapItemsIterator(mapping.begin(), mapping.end()); } std::optional get(const nb::object &key, std::optional default_value = std::nullopt) { auto it = mapping.find(nb_id(key)); if (it != mapping.end()) { return std::get<1>(it->second); } return default_value; } }; static IndexMap indexmap_from_refmap(const RefMap &refmap) { IndexMap indexmap; for (const auto &[_, value_index] : refmap.mapping) { nb::object value = std::get<0>(value_index); int index = std::get<1>(value_index); indexmap[index] = value; } return indexmap; }; static RefMap refmap_from_indexmap(const IndexMap &indexmap) { RefMap refmap; for (const auto &[index, value] : indexmap) { refmap.mapping[nb_id(value)] = std::make_tuple(value, index); } return refmap; }; //--------------------------------------------------------------- // NodeDef //--------------------------------------------------------------- struct NodeDef { nb::object type; std::optional index; std::optional outer_index; int num_attributes; nb::object metadata; NodeDef(nb::object type, std::optional index, std::optional outer_index, int num_attributes, nb::object metadata) : type(type), index(index), outer_index(outer_index), num_attributes(num_attributes), metadata(metadata) {} NodeDef with_no_outer_index() const { return NodeDef(type, index, std::nullopt, num_attributes, metadata); } NodeDef with_same_outer_index() const { return NodeDef(type, index, index, num_attributes, metadata); } bool __eq__(const nb::object &other_obj) const { if (!nb::isinstance(other_obj)) { return false; } NodeDef other = nb::cast(other_obj); return type.equal(other.type) && index == other.index && outer_index == other.outer_index && num_attributes == other.num_attributes && metadata.equal(other.metadata); } int __hash__() const { // return nb::hash(type) ^ nb::hash(nb::cast(index)) ^ nb::hash(nb::cast(outer_index)) ^ nb::hash(nb::cast(num_attributes)) ^ nb::hash(metadata); return nb::hash(nb::make_tuple(type, index, outer_index, num_attributes, metadata)); } nb::tuple __getstate__() const { return nb::make_tuple(type, index, outer_index, num_attributes, metadata); } static void __setstate__(NodeDef &nodedef, nb::tuple &t) { new (&nodedef) NodeDef(t[0], nb::cast>(t[1]), nb::cast>(t[2]), nb::cast(t[3]), t[4]); } }; //--------------------------------------------------------------- // VariableDef //--------------------------------------------------------------- struct VariableDef { nb::object type; int index; std::optional outer_index; nb::object metadata; VariableDef(nb::object type, int index, std::optional outer_index, nb::object metadata) : type(type), index(index), outer_index(outer_index), metadata(metadata) {} VariableDef with_no_outer_index() const { return VariableDef(type, index, std::nullopt, metadata); } VariableDef with_same_outer_index() const { return VariableDef(type, index, index, metadata); } bool __eq__(const nb::object &other_obj) const { if (!nb::isinstance(other_obj)) { return false; } VariableDef other = nb::cast(other_obj); return type.equal(other.type) && index == other.index && outer_index == other.outer_index && metadata.equal(other.metadata); } int __hash__() const { // return nb::hash(type) ^ nb::hash(nb::cast(index)) ^ nb::hash(nb::cast(outer_index)) ^ nb::hash(metadata); return nb::hash(nb::make_tuple(type, index, outer_index, metadata)); } nb::tuple __getstate__() const { return nb::make_tuple(type, index, outer_index, metadata); } static void __setstate__(VariableDef &variabledef, nb::tuple &t) { new (&variabledef) VariableDef(t[0], nb::cast(t[1]), nb::cast>(t[2]), t[3]); } }; //--------------------------------------------------------------- // NodeRef //--------------------------------------------------------------- struct NodeRef { int index; NodeRef(int index) : index(index) {} bool __eq__(const nb::object &other_obj) const { if (!nb::isinstance(other_obj)) { return false; } NodeRef other = nb::cast(other_obj); return index == other.index; } int __hash__() const { return nb::hash(nb::cast(index)); } nb::tuple __getstate__() const { return nb::make_tuple(index); } static void __setstate__(NodeRef &noderef, nb::tuple &t) { new (&noderef) NodeRef(nb::cast(t[0])); } }; NB_MODULE(flaxlib_cpp, m) { nb::bind_map(m, "IndexMap") .def_static("from_refmap", &indexmap_from_refmap); nb::class_(m, "RefMapKeysIterator") .def("__next__", &flaxlib::RefMapKeysIterator::__next__); nb::class_(m, "RefMapItemsIterator") .def("__iter__", &flaxlib::RefMapItemsIterator::__iter__) .def("__next__", &flaxlib::RefMapItemsIterator::__next__); nb::class_(m, "RefMap") .def(nb::init<>()) .def(nb::init(), nb::arg("iterable")) .def("update", &flaxlib::RefMap::update) .def_static("from_indexmap", &refmap_from_indexmap) .def("__getitem__", &flaxlib::RefMap::__getitem__, nb::arg("key").none()) .def("__setitem__", &flaxlib::RefMap::__setitem__, nb::arg("key").none(), nb::arg("value")) .def("__len__", &flaxlib::RefMap::__len__) .def("__contains__", &flaxlib::RefMap::__contains__, nb::arg("key").none()) .def("__iter__", &flaxlib::RefMap::__iter__) .def("items", &flaxlib::RefMap::items) .def("get", &flaxlib::RefMap::get, nb::arg("key").none(), nb::arg("default_value").none()); nb::class_(m, "NodeDef") .def(nb::init, std::optional, int, nb::object>(), nb::arg("type"), nb::arg("index").none(), nb::arg("outer_index").none(), nb::arg("num_attributes"), nb::arg("metadata").none()) .def_prop_ro("type", [](const flaxlib::NodeDef &n) { return n.type; }) .def_prop_ro("index", [](const flaxlib::NodeDef &n) { return n.index; }) .def_prop_ro("outer_index", [](const flaxlib::NodeDef &n) { return n.outer_index; }) .def_prop_ro("num_attributes", [](const flaxlib::NodeDef &n) { return n.num_attributes; }) .def_prop_ro("metadata", [](const flaxlib::NodeDef &n) { return n.metadata; }) .def("with_no_outer_index", &flaxlib::NodeDef::with_no_outer_index) .def("with_same_outer_index", &flaxlib::NodeDef::with_same_outer_index) .def("__eq__", &flaxlib::NodeDef::__eq__, nb::arg().none()) .def("__hash__", &flaxlib::NodeDef::__hash__) .def("__getstate__", &flaxlib::NodeDef::__getstate__) .def("__setstate__", &flaxlib::NodeDef::__setstate__); nb::class_(m, "VariableDef") .def(nb::init, nb::object>(), nb::arg("type"), nb::arg("index"), nb::arg("outer_index").none(), nb::arg("metadata").none()) .def_prop_ro("type", [](const flaxlib::VariableDef &n) { return n.type; }) .def_prop_ro("index", [](const flaxlib::VariableDef &n) { return n.index; }) .def_prop_ro("outer_index", [](const flaxlib::VariableDef &n) { return n.outer_index; }) .def_prop_ro("metadata", [](const flaxlib::VariableDef &n) { return n.metadata; }) .def("with_no_outer_index", &flaxlib::VariableDef::with_no_outer_index) .def("with_same_outer_index", &flaxlib::VariableDef::with_same_outer_index) .def("__eq__", &flaxlib::VariableDef::__eq__, nb::arg().none()) .def("__hash__", &flaxlib::VariableDef::__hash__) .def("__getstate__", &flaxlib::VariableDef::__getstate__) .def("__setstate__", &flaxlib::VariableDef::__setstate__); nb::class_(m, "NodeRef") .def(nb::init(), nb::arg("index")) .def_prop_ro("index", [](const flaxlib::NodeRef &n) { return n.index; }) .def("__eq__", &flaxlib::NodeRef::__eq__, nb::arg().none()) .def("__hash__", &flaxlib::NodeRef::__hash__) .def("__getstate__", &flaxlib::NodeRef::__getstate__) .def("__setstate__", &flaxlib::NodeRef::__setstate__); } } // namespace flaxlib ================================================ FILE: nnx.py ================================================ # Copyright 2026 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Standalone nnx module shim. This module allows importing nnx directly as a standalone module: import nnx instead of: from flax import nnx Both imports refer to the exact same module, ensuring that `nnx.Module` and `flax.nnx.Module` are the same class in memory. """ from flax.nnx import * from flax.version import __version__ as __version__ from flax import nnx as _nnx # Re-export the module's metadata __all__: list[str] = _nnx.__all__ if hasattr(_nnx, '__all__') else [] ================================================ FILE: pylintrc ================================================ # This Pylint rcfile contains a best-effort configuration to uphold the # best-practices and style described in the Google Python style guide: # https://google.github.io/styleguide/pyguide.html # # Its canonical open-source location is: # https://google.github.io/styleguide/pylintrc [MASTER] # Add files or directories to the blacklist. They should be base names, not # paths. ignore=third_party # Add files or directories matching the regex patterns to the blacklist. The # regex matches against base names, not paths. ignore-patterns= # Pickle collected data for later comparisons. persistent=no # List of plugins (as comma separated values of python modules names) to load, # usually to register additional checkers. load-plugins= # Use multiple processes to speed up Pylint. jobs=4 # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code extension-pkg-whitelist= [MESSAGES CONTROL] # Only show warnings with the listed confidence levels. Leave empty to show # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED confidence= # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option # multiple time (only on the command line, not in the configuration file where # it should appear only once). See also the "--disable" option for examples. #enable= # Disable the message, report, category or checker with the given id(s). You # can either give multiple identifiers separated by comma (,) or put this # option multiple times (only on the command line, not in the configuration # file where it should appear only once).You can also use "--disable=all" to # disable everything first and then reenable specific checks. For example, if # you want to run only the similarities checker, you can use "--disable=all # --enable=similarities". If you want to run only the classes checker, but have # no Warning level messages displayed, use"--disable=all --enable=classes # --disable=W" disable=apply-builtin, backtick, bad-option-value, buffer-builtin, c-extension-no-member, cmp-builtin, cmp-method, coerce-builtin, coerce-method, delslice-method, div-method, duplicate-code, eq-without-hash, execfile-builtin, file-builtin, filter-builtin-not-iterating, fixme, getslice-method, global-statement, hex-method, idiv-method, implicit-str-concat-in-sequence, import-error, import-self, import-star-module-level, input-builtin, intern-builtin, invalid-str-codec, locally-disabled, long-builtin, long-suffix, map-builtin-not-iterating, metaclass-assignment, next-method-called, next-method-defined, no-absolute-import, no-else-break, no-else-continue, no-else-raise, no-else-return, no-member, no-self-use, nonzero-method, oct-method, old-division, old-ne-operator, old-octal-literal, old-raise-syntax, parameter-unpacking, print-statement, raising-string, range-builtin-not-iterating, raw_input-builtin, rdiv-method, reduce-builtin, relative-import, reload-builtin, round-builtin, setslice-method, signature-differs, standarderror-builtin, suppressed-message, sys-max-int, too-few-public-methods, too-many-ancestors, too-many-arguments, too-many-boolean-expressions, too-many-branches, too-many-instance-attributes, too-many-locals, too-many-public-methods, too-many-return-statements, too-many-statements, trailing-newlines, unichr-builtin, unicode-builtin, unpacking-in-except, useless-else-on-loop, useless-suppression, using-cmp-argument, xrange-builtin, zip-builtin-not-iterating, [REPORTS] # Set the output format. Available formats are text, parseable, colorized, msvs # (visual studio) and html. You can also give a reporter class, eg # mypackage.mymodule.MyReporterClass. output-format=text # Put messages in a separate file for each module / package specified on the # command line instead of printing them on stdout. Reports (if any) will be # written in a file name "pylint_global.[txt|html]". This option is deprecated # and it will be removed in Pylint 2.0. files-output=no # Tells whether to display a full report or only the messages reports=no # Python expression which should return a note less than 10 (10 is the highest # note). You have access to the variables errors warning, statement which # respectively contain the number of errors / warnings messages and the total # number of statements analyzed. This is used by the global evaluation report # (RP0004). evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) # Template used to display messages. This is a python new-style format string # used to format the message information. See doc for all details #msg-template= [BASIC] # Good variable names which should always be accepted, separated by a comma good-names=main,_ # Bad variable names which should always be refused, separated by a comma bad-names= # Colon-delimited sets of names that determine each other's naming style when # the name regexes allow several styles. name-group= # Include a hint for the correct naming format with invalid-name include-naming-hint=no # List of decorators that produce properties, such as abc.abstractproperty. Add # to this list to register other decorators that produce valid properties. property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl # Regular expression matching correct function names function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ # Regular expression matching correct variable names variable-rgx=^[a-z][a-z0-9_]*$ # Regular expression matching correct constant names const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ # Regular expression matching correct attribute names attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ # Regular expression matching correct argument names argument-rgx=^[a-z][a-z0-9_]*$ # Regular expression matching correct class attribute names class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ # Regular expression matching correct inline iteration names inlinevar-rgx=^[a-z][a-z0-9_]*$ # Regular expression matching correct class names class-rgx=^_?[A-Z][a-zA-Z0-9]*$ # Regular expression matching correct module names module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ # Regular expression matching correct method names method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ # Regular expression which should only match function or class names that do # not require a docstring. no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ # Minimum line length for functions/classes that require docstrings, shorter # ones are exempt. docstring-min-length=10 [TYPECHECK] # List of decorators that produce context managers, such as # contextlib.contextmanager. Add to this list to register other decorators that # produce valid context managers. contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager # Tells whether missing members accessed in mixin class should be ignored. A # mixin class is detected if its name ends with "mixin" (case insensitive). ignore-mixin-members=yes # List of module names for which member attributes should not be checked # (useful for modules/projects where namespaces are manipulated during runtime # and thus existing member attributes cannot be deduced by static analysis. It # supports qualified module names, as well as Unix pattern matching. ignored-modules= # List of class names for which member attributes should not be checked (useful # for classes with dynamically set attributes). This supports the use of # qualified names. ignored-classes=optparse.Values,thread._local,_thread._local # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. generated-members= [FORMAT] # Maximum number of characters on a single line. max-line-length=80 # TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt # lines made too long by directives to pytype. # Regexp for a line that is allowed to be longer than the limit. ignore-long-lines=(?x)( ^\s*(\#\ )??$| ^\s*(from\s+\S+\s+)?import\s+.+$) # Allow the body of an if to be on the same line as the test if there is no # else. single-line-if-stmt=yes # List of optional constructs for which whitespace checking is disabled. `dict- # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. # `trailing-comma` allows a space between comma and closing bracket: (a, ). # `empty-line` allows space-only lines. no-space-check= # Maximum number of lines in a module max-module-lines=99999 # String used as indentation unit. The internal Google style guide mandates 2 # spaces. Google's externaly-published style guide says 4, consistent with # PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google # projects (like TensorFlow). indent-string=' ' # Number of spaces of indent required inside a hanging or continued line. indent-after-paren=4 # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. expected-line-ending-format= [MISCELLANEOUS] # List of note tags to take in consideration, separated by a comma. notes=TODO [VARIABLES] # Tells whether we should check for unused import in __init__ files. init-import=no # A regular expression matching the name of dummy variables (i.e. expectedly # not used). dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) # List of additional names supposed to be defined in builtins. Remember that # you should avoid to define new builtins when possible. additional-builtins= # List of strings which can identify a callback function by name. A callback # name must start or end with one of those strings. callbacks=cb_,_cb # List of qualified module names which can have objects that can redefine # builtins. redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools [LOGGING] # Logging modules to check that the string format arguments are in logging # function parameter format logging-modules=logging,absl.logging [SIMILARITIES] # Minimum lines number of a similarity. min-similarity-lines=4 # Ignore comments when computing similarities. ignore-comments=yes # Ignore docstrings when computing similarities. ignore-docstrings=yes # Ignore imports when computing similarities. ignore-imports=no [SPELLING] # Spelling dictionary name. Available dictionaries: none. To make it working # install python-enchant package. spelling-dict= # List of comma separated words that should not be checked. spelling-ignore-words= # A path to a file that contains private dictionary; one word per line. spelling-private-dict-file= # Tells whether to store unknown words to indicated private dictionary in # --spelling-private-dict-file option instead of raising a message. spelling-store-unknown-words=no [IMPORTS] # Deprecated modules which should not be used, separated by a comma deprecated-modules=regsub, TERMIOS, Bastion, rexec, sets # Create a graph of every (i.e. internal and external) dependencies in the # given file (report RP0402 must not be disabled) import-graph= # Create a graph of external dependencies in the given file (report RP0402 must # not be disabled) ext-import-graph= # Create a graph of internal dependencies in the given file (report RP0402 must # not be disabled) int-import-graph= # Force import order to recognize a module as part of the standard # compatibility libraries. known-standard-library= # Force import order to recognize a module as part of a third party library. known-third-party=enchant, absl # Analyse import fallback blocks. This can be used to support both Python 2 and # 3 compatible code, which means that the block might have code that exists # only in one or another interpreter, leading to false positives when analysed. analyse-fallback-blocks=no [CLASSES] # List of method names used to declare (i.e. assign) instance attributes. defining-attr-methods=__init__, __new__, setUp # List of member names, which should be excluded from the protected access # warning. exclude-protected=_asdict, _fields, _replace, _source, _make # List of valid names for the first argument in a class method. valid-classmethod-first-arg=cls, class_ # List of valid names for the first argument in a metaclass class method. valid-metaclass-classmethod-first-arg=mcs [EXCEPTIONS] # Exceptions that will emit a warning when being caught. Defaults to # "Exception" overgeneral-exceptions=StandardError, Exception, BaseException ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["setuptools", "setuptools-scm"] build-backend = "setuptools.build_meta" [project] name = "flax" requires-python = ">=3.11" description = "Flax: A neural network library for JAX designed for flexibility" keywords = [] authors = [ {name = "Flax team", email = "flax-dev@google.com"}, ] dependencies = [ "numpy>=1.23.2", # keep in sync with jax-version in .github/workflows/build.yml "jax>=0.8.1", "msgpack", "optax", "orbax-checkpoint", "tensorstore", "rich>=11.1", "typing_extensions>=4.2", "PyYAML>=5.4.1", "treescope>=0.1.7", "orbax-export>=0.0.8", ] classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dynamic = ["version", "readme"] [project.optional-dependencies] testing = [ "clu", "clu<=0.0.9; python_version<'3.10'", "einops", "gymnasium[atari]; python_version<'3.14'", "jaxlib", "jaxtyping", "jraph>=0.0.6dev0", "ml-collections", "mypy", "opencv-python", # Set protobuf version for python 3.13+ to prevent error in # examples/mnist/train_test.py::TrainTest::test_train_and_evaluate # Failed to construct dataset "mnist", builder_kwargs "{}": Value out of range: 11594722 "protobuf<6; python_version>='3.13'", "pytest", "pytest-cov", "pytest-custom_exit_code", "pytest-xdist", "pytype", # WMT/LM1B examples "sentencepiece==0.2.0", # Segfault bug in 0.2.1 "tensorflow_text>=2.11.0; platform_system!='Darwin' and python_version < '3.13'", "tensorflow_datasets", "tensorflow>=2.12.0; python_version<'3.13'", # to fix Numpy np.bool8 deprecation error "tensorflow>=2.20.0; python_version>='3.13' and python_version<'3.14'", # Temporary fix for https://github.com/google/flax/issues/5143 "keras<3.13", "torch", "treescope>=0.1.1; python_version>='3.10'", "cloudpickle>=3.0.0", "ale-py>=0.10.2; python_version<'3.14'", ] docs = [ "sphinx==6.2.1", "sphinx-book-theme", "Pygments>=2.6.1", "ipykernel", "tqdm==4.67.1", "myst_nb", "nbstripout", "recommonmark", "ipython_genutils", "sphinx-design", "jupytext==1.13.8", "dm-haiku>=0.0.14", "docutils", # The next packages are for notebooks. "matplotlib", "scikit-learn", # The next packages are used in testcode blocks. "ml_collections", # notebooks "einops", "kagglehub>=0.3.3", "ipywidgets>=8.1.5", ] dev = [ "nanobind>=2.5.0", "pre-commit>=3.8.0", "scikit-build-core[pyproject]>=0.11.0", ] [project.urls] homepage = "https://github.com/google/flax" [tool.setuptools.dynamic] readme = {file = ["README.md"], content-type = "text/markdown"} version = {attr = "flax.version.__version__"} [tool.setuptools.packages.find] include = ["flax*"] [tool.setuptools.package-data] flax = ["*py.typed"] [tool.yapf] based_on_style = "yapf" [tool.pytype] # TODO(levskaya): figure out why we get pyi-error from flax's root __init__.py # could be a pytype bug. disable = "pyi-error" [tool.mypy] show_error_codes = true no_implicit_optional = true disable_error_code = "attr-defined" [[tool.mypy.overrides]] module = [ "tensorflow.*", "tensorboard.*", "absl.*", "jax.*", "rich.*", "flax.*", "jaxlib.cuda.*", "jaxlib.cpu.*", "msgpack", "numpy.*", "optax.*", "orbax.*", "opt_einsum.*", "scipy.*", "libtpu.*", "jaxlib.mlir.*", "yaml", ] ignore_missing_imports = true disable_error_code = "annotation-unchecked" # exclude nnx examples [[tool.mypy.overrides]] module = "flax.nnx.examples.*" ignore_errors = true [tool.pytest.ini_options] filterwarnings = [ # By default error out on any warnings. "error", # Jax warning when no gpu/tpu found. "ignore:No GPU/TPU found, falling back to CPU.*:UserWarning", # traverse_util.Traversal will be removed soon. "ignore:`flax.traverse_util.Traversal` will be deprecated.*:DeprecationWarning", # Deprecated legacy checkpoint - just want to keep the tests running for a while "ignore:Flax Checkpointing will soon be deprecated in favor of Orbax.*:DeprecationWarning", # DeprecationWarning: The inputs_kv arg will be deprecated soon. Use inputs_k and inputs_v instead. "ignore:.*The inputs_kv arg will be deprecated soon. Use inputs_k and inputs_v instead.*:DeprecationWarning", # DeprecationWarning: the function signature of MultiHeadDotProductAttention's `__call__` method has changed "ignore:.*the function signature of MultiHeadDotProductAttention's `__call__` method has changed.*:DeprecationWarning", # DeprecationWarning: ml_dtypes.float8_e4m3b11 is deprecated. "ignore:.*ml_dtypes.float8_e4m3b11 is deprecated.*:DeprecationWarning", # pytest-cov uses a deprecated feature of pytest-xdist. (2023-11-06) "ignore:The --rsyncdir command line argument and rsyncdirs config variable are deprecated.:DeprecationWarning", # DeprecationWarning: jax.random.KeyArray is deprecated. "ignore:.*jax.random.KeyArray is deprecated.*:DeprecationWarning", # DeprecationWarning: jax.core.Shape is deprecated. "ignore:.*jax.core.Shape is deprecated.*:DeprecationWarning", # DeprecationWarning: pkg_resources is deprecated as an API. "ignore:.*pkg_resources is deprecated as an API.*:DeprecationWarning", # DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`. "ignore:.*Deprecated call to.*pkg_resources.declare_namespace.*:DeprecationWarning", # jax.xla_computation is deprecated but TF still uses it. "ignore:.*jax.xla_computation is deprecated.*:DeprecationWarning", # Orbax warnings inside deprecated `flax.training` package. "ignore:.*Couldn't find sharding info under RestoreArgs.*:UserWarning", # RuntimeWarning: invalid value encountered in cast "ignore:.*invalid value encountered in cast.*:RuntimeWarning", # RuntimeWarning: divide by zero encountered in equal/not_equal "ignore:.*divide by zero encountered in.*:RuntimeWarning", # DeprecationWarning: numpy.core is deprecated "ignore:.*numpy.core is deprecated.*:DeprecationWarning", # DeprecationWarning: shape requires ndarray or scalar arguments "ignore:.*shape requires ndarray or scalar arguments.*:DeprecationWarning", # UserWarning: Sharding info not provided when restoring "ignore:.*Sharding info not provided when restoring.*:UserWarning", # UserWarning: pkg_resources is deprecated as an API. "ignore:.*pkg_resources is deprecated as an API.*:UserWarning", # DeprecationWarning: 'Data' is deprecated, please replace: "ignore:.*[Data|Static]' is deprecated, please replace.*:DeprecationWarning", # DeprecationWarning: Implicit conversion of an array to a dtype is deprecated; rather than dtype=arr use dtype=arr.dtype. "ignore:.*Implicit conversion of an array to a dtype is deprecated; rather than dtype=arr use dtype=arr.dtype.*:DeprecationWarning", # DeprecationWarning: Setting `jax_pmap_shmap_merge` is deprecated in JAX v0.9.0 and will be removed in JAX v0.10.0 "ignore:.*Setting `jax_pmap_shmap_merge` is deprecated in JAX.*:DeprecationWarning", ] [tool.coverage.report] exclude_lines = [ "@abc.abstractmethod", "raise NotImplementedError", ] [tool.pyink] pyink-indentation = 2 pyink-use-majority-quotes = true line-length = 80 preview = true [tool.ruff] # Exclude a variety of commonly ignored directories. exclude = [ "__init__.py", "activation.py", "partitioning.py", "flax/core/variables.py", "examples/", ] line-length = 80 indent-width = 2 [tool.ruff.lint] # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. select = ["F401"] ignore = [] # Allow fix for all enabled rules (when `--fix`) is provided. # Full list of rules: https://docs.astral.sh/ruff/rules/ fixable = ["ALL"] unfixable = [] [tool.ruff.format] indent-style = "space" quote-style = "single" [tool.uv] # Ignore uv.lock and always upgrade the package to the latest upgrade-package = ["jax", "jaxlib", "orbax-checkpoint"] ================================================ FILE: tests/checkpoints_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.training.checkpoints.""" import copy import os import pathlib from typing import Any import jax import numpy as np import orbax.checkpoint as orbax from absl.testing import absltest, parameterized from jax import numpy as jnp from flax import config, core, errors, io, struct from flax import linen as nn from flax.training import checkpoints # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() PyTree = Any def check_eq(xs, ys): return jax.tree_util.tree_all( jax.tree_util.tree_map(np.testing.assert_allclose, xs, ys) ) def shuffle(l): """Functional shuffle.""" l = copy.copy(l) np.random.shuffle(l) return l class Inner(nn.Module): """Inner class based on nn.""" @nn.compact def __call__(self, x): x = nn.Conv(10, (2, 2))(x) x = nn.normalization.BatchNorm(True)(x) return x class Model(nn.Module): """Simple model based on nn.""" @nn.compact def __call__(self, inputs): x = nn.Conv(10, (2, 2))(inputs) x = Inner()(x) x = x.reshape([x.shape[0], -1]) x = nn.normalization.BatchNorm(True)(x) x = nn.Dense(10)(x) x = nn.log_softmax(x) return x @struct.dataclass class CustomDC: foo: Any bar: Any class CheckpointsTest(parameterized.TestCase): def setUp(self): super().setUp() config.update('flax_use_orbax_checkpointing', False) # default value def test_naturalsort(self): np.random.seed(0) tests = [ ['file_1', 'file_2', 'file_10', 'file_11', 'file_21'], ['file_0.001', 'file_0.01', 'file_0.1', 'file_1'], ['file_-3.0', 'file_-2', 'file_-1', 'file_0.0'], ['file_1e1', 'file_1.0e2', 'file_1e3', 'file_1.0e4'], ['file_1', 'file_2', 'file_9', 'file_1.0e1', 'file_11'], ] for test in tests: self.assertEqual(test, checkpoints.natural_sort(shuffle(test))) def test_safe_normpath(self): tests = ['./a/b/c', '/a//b/c', '/a/../b/c', 'a/b/./c', 'gs://a//b/c'] expected = ['a/b/c', '/a/b/c', '/b/c', 'a/b/c', 'gs://a/b/c'] for test, expect in zip(tests, expected): self.assertEqual(expect, checkpoints.safe_normpath(test)) @parameterized.parameters({'use_orbax': True}, {'use_orbax': False}) def test_save_restore_checkpoints(self, use_orbax): config.update('flax_use_orbax_checkpointing', use_orbax) tmp_dir = pathlib.Path(self.create_tempdir().full_path) test_object0 = { 'a': np.array([0, 0, 0], np.int32), 'b': np.array([0, 0, 0], np.int32), } test_object1 = { 'a': np.array([1, 2, 3], np.int32), 'b': np.array([1, 1, 1], np.int32), } test_object2 = { 'a': np.array([4, 5, 6], np.int32), 'b': np.array([2, 2, 2], np.int32), } new_object = checkpoints.restore_checkpoint( tmp_dir, test_object0, prefix='test_' ) check_eq(new_object, test_object0) checkpoints.save_checkpoint( tmp_dir, test_object1, 0, prefix='test_', keep=1 ) self.assertIn('test_0', os.listdir(tmp_dir)) new_object = checkpoints.restore_checkpoint( tmp_dir, test_object0, prefix='test_' ) check_eq(new_object, test_object1) checkpoints.save_checkpoint( tmp_dir, test_object1, 1, prefix='test_', keep=1 ) checkpoints.save_checkpoint( tmp_dir, test_object2, 2, prefix='test_', keep=1 ) new_object = checkpoints.restore_checkpoint( tmp_dir, test_object0, prefix='test_' ) check_eq(new_object, test_object2) checkpoints.save_checkpoint( tmp_dir, test_object2, 3, prefix='test_', keep=2 ) checkpoints.save_checkpoint( tmp_dir, test_object1, 4, prefix='test_', keep=2 ) new_object = checkpoints.restore_checkpoint( tmp_dir, test_object0, prefix='test_' ) check_eq(new_object, test_object1) new_object = checkpoints.restore_checkpoint( tmp_dir, test_object0, step=3, prefix='test_' ) check_eq(new_object, test_object2) # Restore with a specific checkpoint path, not the directory path. new_object = checkpoints.restore_checkpoint( os.path.join(tmp_dir, 'test_3'), test_object0 ) check_eq(new_object, test_object2) # If a specific path is specified, but it does not exist, the same behavior # as when a directory is empty should apply: the target is returned # unchanged. new_object = checkpoints.restore_checkpoint( os.path.join(tmp_dir, 'test_not_there'), test_object0 ) check_eq(new_object, test_object0) with self.assertRaises(ValueError): checkpoints.restore_checkpoint( tmp_dir, test_object0, step=5, prefix='test_' ) @parameterized.parameters({'use_orbax': True}, {'use_orbax': False}) def test_overwrite_checkpoints(self, use_orbax): config.update('flax_use_orbax_checkpointing', use_orbax) overwrite_error = ValueError if use_orbax else errors.InvalidCheckpointError tmp_dir = self.create_tempdir().full_path test_object0 = {'a': np.array([0, 0, 0], np.int32)} test_object = {'a': np.array([1, 2, 3], np.int32)} checkpoints.save_checkpoint(tmp_dir, test_object0, 0, keep=1) with self.assertRaises(overwrite_error): checkpoints.save_checkpoint(tmp_dir, test_object, 0, keep=1) checkpoints.save_checkpoint(tmp_dir, test_object, 0, keep=1, overwrite=True) new_object = checkpoints.restore_checkpoint(tmp_dir, test_object0) check_eq(new_object, test_object) non_norm_dir_path = tmp_dir + '//' checkpoints.save_checkpoint(non_norm_dir_path, test_object, 4, keep=1) new_object = checkpoints.restore_checkpoint(non_norm_dir_path, test_object0) check_eq(new_object, test_object) @parameterized.parameters( {'use_orbax': True, 'keep_every_n_steps': None}, {'use_orbax': False, 'keep_every_n_steps': 7}, ) def test_keep(self, use_orbax, keep_every_n_steps): config.update('flax_use_orbax_checkpointing', use_orbax) tmp_dir = self.create_tempdir().full_path test_object = {'a': np.array([1, 2, 3], np.int32)} steps_start = 17 steps_end = 37 keep = 3 increment = 5 for step in range(steps_start, steps_end, increment): checkpoints.save_checkpoint( tmp_dir, test_object, step=step, keep=keep, keep_every_n_steps=keep_every_n_steps, ) last_checkpoint = -float('inf') for step in range(steps_start, steps_end, increment): if ((steps_end - step) / increment <= keep) or ( keep_every_n_steps and (step - last_checkpoint) >= keep_every_n_steps ): restored = checkpoints.restore_checkpoint( tmp_dir, target=None, step=step ) check_eq(restored, test_object) last_checkpoint = step else: with self.assertRaises(ValueError): checkpoints.restore_checkpoint(tmp_dir, target=None, step=step) @parameterized.parameters({'use_orbax': True}, {'use_orbax': False}) def test_save_restore_checkpoints_w_float_steps(self, use_orbax): config.update('flax_use_orbax_checkpointing', use_orbax) tmp_dir = self.create_tempdir().full_path test_object0 = { 'a': np.array([0, 0, 0], np.int32), 'b': np.array([0, 0, 0], np.int32), } test_object1 = { 'a': np.array([1, 2, 3], np.int32), 'b': np.array([1, 1, 1], np.int32), } test_object2 = { 'a': np.array([4, 5, 6], np.int32), 'b': np.array([2, 2, 2], np.int32), } checkpoints.save_checkpoint( tmp_dir, test_object1, 0.0, prefix='test_', keep=1 ) self.assertIn('test_0.0', os.listdir(tmp_dir)) new_object = checkpoints.restore_checkpoint( tmp_dir, test_object0, prefix='test_' ) check_eq(new_object, test_object1) checkpoints.save_checkpoint( tmp_dir, test_object1, 2.0, prefix='test_', keep=1 ) checkpoints.save_checkpoint( tmp_dir, test_object2, 3.0, prefix='test_', keep=2 ) self.assertIn('test_3.0', os.listdir(tmp_dir)) self.assertIn('test_2.0', os.listdir(tmp_dir)) check_eq(new_object, test_object1) @parameterized.parameters({'use_orbax': True}, {'use_orbax': False}) def test_save_restore_checkpoints_target_none(self, use_orbax): config.update('flax_use_orbax_checkpointing', use_orbax) tmp_dir = self.create_tempdir().full_path test_object0 = { 'a': np.array([0, 0, 0], np.int32), 'b': np.array([0, 0, 0], np.int32), } # Target pytree is a dictionary, so it's equal to a restored state_dict. checkpoints.save_checkpoint(tmp_dir, test_object0, 0) new_object = checkpoints.restore_checkpoint(tmp_dir, target=None) check_eq(new_object, test_object0) # Target pytree it's a tuple, check the expected state_dict is recovered. test_object1 = ( np.array([0, 0, 0], np.int32), np.array([1, 1, 1], np.int32), ) checkpoints.save_checkpoint(tmp_dir, test_object1, 1) new_object = checkpoints.restore_checkpoint(tmp_dir, target=None) expected_new_object = {str(k): v for k, v in enumerate(test_object1)} check_eq(new_object, expected_new_object) @parameterized.parameters({'use_orbax': True}, {'use_orbax': False}) def test_save_restore_checkpoints_target_singular(self, use_orbax): config.update('flax_use_orbax_checkpointing', use_orbax) tmp_dir = self.create_tempdir().full_path test_object0 = np.array([0, 0, 0], np.int32) test_object1 = np.array([1, 1, 1], np.int32) # Orbax backend returns error if target is singular. Orbax user need to use # ArrayCheckpointHandler instead. if use_orbax: with self.assertRaises(ValueError): checkpoints.save_checkpoint(tmp_dir, test_object1, 0) else: checkpoints.save_checkpoint(tmp_dir, test_object1, 0) new_object = checkpoints.restore_checkpoint(tmp_dir, target=None) check_eq(new_object, test_object1) checkpoints.save_checkpoint(tmp_dir, test_object0, 1) new_object = checkpoints.restore_checkpoint(tmp_dir, target=test_object1) check_eq(new_object, test_object0) @parameterized.parameters({'use_orbax': True}, {'use_orbax': False}) def test_save_restore_checkpoints_target_empty(self, use_orbax): config.update('flax_use_orbax_checkpointing', use_orbax) tmp_dir = self.create_tempdir().full_path test_object0 = {} test_object1 = [] # Orbax returns ValueError if the target is empty, but legacy Flax doesn't. if use_orbax: with self.assertRaises(ValueError): checkpoints.save_checkpoint(tmp_dir, test_object1, 0) else: checkpoints.save_checkpoint(tmp_dir, test_object1, 0) new_object = checkpoints.restore_checkpoint(tmp_dir, target=None) check_eq(new_object, test_object0) checkpoints.save_checkpoint(tmp_dir, test_object0, 1) new_object = checkpoints.restore_checkpoint(tmp_dir, target=test_object1) check_eq(new_object, test_object1) def test_async_save_checkpoints(self): tmp_dir = pathlib.Path(self.create_tempdir().full_path) test_object0 = { 'a': np.array([0, 0, 0], np.int32), 'b': np.array([0, 0, 0], np.int32), } test_object1 = { 'a': np.random.normal(size=(1000, 1000)), 'b': np.random.normal(size=(1000, 1000)), } test_object2 = { 'a': np.random.normal(size=(1000, 1000)), 'b': np.random.normal(size=(1000, 1000)), } test_object3 = { 'a': np.random.normal(size=(1000, 1000)), 'b': np.random.normal(size=(1000, 1000)), } am = checkpoints.AsyncManager() checkpoints.save_checkpoint( tmp_dir, test_object1, 0, prefix='test_', keep=1, async_manager=am ) # Hard-wait the write to be done, then check its content. am.save_future.result() self.assertIn('test_0', os.listdir(tmp_dir)) new_object = checkpoints.restore_checkpoint( tmp_dir, test_object1, prefix='test_' ) check_eq(new_object, test_object1) # Check two consecutive saves happen in the right order. checkpoints.save_checkpoint( tmp_dir, test_object2, 1, prefix='test_', keep=1, async_manager=am ) checkpoints.save_checkpoint( tmp_dir, test_object3, 2, prefix='test_', keep=1, async_manager=am ) am.save_future.result() self.assertIn('test_2', os.listdir(tmp_dir)) new_object = checkpoints.restore_checkpoint( tmp_dir, test_object1, prefix='test_' ) check_eq(new_object, test_object3) def test_last_checkpoint(self): tmp_dir = pathlib.Path(self.create_tempdir().full_path) with io.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w') as f: f.write('test_tmp') io.makedirs(os.path.join(tmp_dir, 'test_tmp_gda')) self.assertEqual(checkpoints.latest_checkpoint(tmp_dir, 'test_'), None) with io.GFile(os.path.join(tmp_dir, 'test_0'), 'w') as f: f.write('test_0') io.makedirs(os.path.join(tmp_dir, 'test_0_gda')) self.assertEqual( checkpoints.latest_checkpoint(tmp_dir, 'test_'), os.path.join(tmp_dir, 'test_0'), ) with io.GFile(os.path.join(tmp_dir, 'test_10'), 'w') as f: f.write('test_10') self.assertEqual( checkpoints.latest_checkpoint(tmp_dir, 'test_'), os.path.join(tmp_dir, 'test_10'), ) self.assertEqual(checkpoints.latest_checkpoint(tmp_dir, 'ckpt_'), None) path = f'orbaxtest_{orbax.utils.TMP_DIR_SUFFIX}_10' with io.GFile(os.path.join(tmp_dir, path), 'w') as f: f.write('orbaxtest_10') self.assertIsNone(checkpoints.latest_checkpoint(tmp_dir, 'orbaxtest_')) @parameterized.parameters( {'step_type': int, 'steps': [1, 5, 112]}, {'step_type': float, 'steps': [1.0, 4.5, 5.6]}, ) def test_available_steps(self, step_type, steps): tmp_dir = pathlib.Path(self.create_tempdir().full_path) with io.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w') as f: f.write('test_tmp') io.makedirs(os.path.join(tmp_dir, 'test_tmp_gda')) for step in steps: with io.GFile(os.path.join(tmp_dir, 'test_' + str(step)), 'w') as f: f.write('test_' + str(step)) io.makedirs(os.path.join(tmp_dir, 'test_' + str(step) + '_gda')) self.assertEqual( checkpoints.available_steps(tmp_dir, 'test_', step_type=step_type), steps, ) @parameterized.parameters({'use_orbax': True}, {'use_orbax': False}) def test_complex_pytree(self, use_orbax): config.update('flax_use_orbax_checkpointing', use_orbax) tmp_dir = self.create_tempdir().full_path to_save = [ CustomDC(foo=12, bar=core.freeze({'x': jnp.array((1, 4))})), np.array((2, 3)), ] target = [ CustomDC(foo=0, bar=core.freeze({'x': jnp.array((0, 0))})), np.array((0, 0)), ] checkpoints.save_checkpoint(tmp_dir, to_save, 0) restored = checkpoints.restore_checkpoint(tmp_dir, target=target) check_eq(restored, to_save) # restore_checkpoint can automatically restore either orbax or legacy files. def test_auto_restore(self): tmp_dir = self.create_tempdir().full_path to_save = [CustomDC(foo=12, bar={'x': jnp.array((1, 4))}), np.array((2, 3))] target = [CustomDC(foo=0, bar={'x': jnp.array((0, 0))}), np.array((0, 0))] # Store an orbax ckpt config.update('flax_use_orbax_checkpointing', True) checkpoints.save_checkpoint(tmp_dir, to_save, 0, prefix='test_') # And a legacy ckpt config.update('flax_use_orbax_checkpointing', False) checkpoints.save_checkpoint(tmp_dir, to_save, 1, prefix='test_', keep=2) # Both gets restored with same API. restored = checkpoints.restore_checkpoint( os.path.join(tmp_dir, 'test_0'), target=target ) check_eq(restored, to_save) restored = checkpoints.restore_checkpoint( os.path.join(tmp_dir, 'test_1'), target=target ) check_eq(restored, to_save) @parameterized.parameters({'use_orbax': True}, {'use_orbax': False}) def test_smaller_target(self, use_orbax): config.update('flax_use_orbax_checkpointing', use_orbax) tmp_dir = self.create_tempdir().full_path to_save = {'a': jnp.ones((16, 256, 1024))} target = {'a': jnp.zeros((2, 3))} checkpoints.save_checkpoint(tmp_dir, to_save, 0, keep=1) new_object = checkpoints.restore_checkpoint(tmp_dir, target) check_eq(new_object, to_save) def test_convert_pre_linen(self): params = checkpoints.convert_pre_linen( { 'mod_0': { 'submod1_0': {}, 'submod2_1': {}, 'submod1_2': {}, }, 'mod2_2': {'submod2_2_0': {}}, 'mod2_11': {'submod2_11_0': {}}, 'mod2_1': {'submod2_1_0': {}}, } ) self.assertDictEqual( core.unfreeze(params), { 'mod_0': { 'submod1_0': {}, 'submod1_1': {}, 'submod2_0': {}, }, 'mod2_0': {'submod2_1_0': {}}, 'mod2_1': {'submod2_2_0': {}}, 'mod2_2': {'submod2_11_0': {}}, }, ) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/colab_tpu_jax_version.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# JAX/jaxlib should be both 0.3.25\n", "# because newer JAX versions are *not* supported on TPU runtimes\n", "# Flax should be included in a ƒresh kernel.\n", "!pip freeze | egrep 'jax|flax'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# should show 8 TPU devices\n", "import jax, jax.tools.colab_tpu\n", "jax.tools.colab_tpu.setup_tpu()\n", "jax.devices()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# sometimes it's necessary to install additional packages; but we need to keep\n", "# JAX/jaxlib versions pinned to what is supported by the TPU runtime...\n", "!pip install jax==0.3.25 jaxlib==0.3.25 flax" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# in case JAX version has changed after the '!pip install`, below command should\n", "# show the offending packages\n", "!pip install -qq pipdeptree\n", "!pipdeptree -w silence -r -p jax" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# it's possible to get dependency tree without installing packages, but this\n", "# usually takes some 2-3 minutes...\n", "!pip install -qq pipgrip\n", "!pipgrip --tree flax==0.6.4" ] } ], "metadata": { "accelerator": "TPU", "gpuClass": "standard", "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: tests/configurations_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from unittest import mock from absl.testing import absltest from flax.configurations import bool_flag, config class MyTestCase(absltest.TestCase): def setUp(self): super().setUp() self.enter_context(mock.patch.object(config, '_values', {})) self._flag = bool_flag('flax_test', default=False, help='Just a test flag.') def test_duplicate_flag(self): with self.assertRaisesRegex(RuntimeError, 'already defined'): bool_flag(self._flag.name, default=False, help='Another test flag.') def test_default(self): self.assertFalse(self._flag.value) self.assertFalse(config.flax_test) def test_typed_update(self): config.update(self._flag, True) self.assertTrue(self._flag.value) self.assertTrue(config.flax_test) def test_untyped_update(self): config.update(self._flag.name, True) self.assertTrue(self._flag.value) self.assertTrue(config.flax_test) def test_update_unknown_flag(self): with self.assertRaisesRegex(LookupError, 'Unrecognized config option'): config.update('unknown', True) def test_temp_flip(self): self.assertFalse(self._flag.value) with config.temp_flip_flag('test', True): self.assertTrue(self._flag.value) self.assertFalse(self._flag.value) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/core/core_frozen_dict_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax from absl.testing import absltest, parameterized from flax.core import FrozenDict, copy, freeze, pop, pretty_repr, unfreeze class FrozenDictTest(parameterized.TestCase): def test_frozen_dict_copies(self): xs = {'a': 1, 'b': {'c': 2}} frozen = freeze(xs) xs['a'] += 1 xs['b']['c'] += 1 self.assertEqual(unfreeze(frozen), {'a': 1, 'b': {'c': 2}}) def test_frozen_dict_maps(self): xs = {'a': 1, 'b': {'c': 2}} frozen = FrozenDict(xs) frozen2 = jax.tree_util.tree_map(lambda x: x + x, frozen) self.assertEqual(unfreeze(frozen2), {'a': 2, 'b': {'c': 4}}) def test_frozen_dict_pop(self): xs = {'a': 1, 'b': {'c': 2}} b, a = FrozenDict(xs).pop('a') self.assertEqual(a, 1) self.assertEqual(unfreeze(b), {'b': {'c': 2}}) def test_frozen_dict_partially_maps(self): x = jax.tree_util.tree_map( lambda a, b: (a, b), freeze({'a': 2}), freeze({'a': {'b': 1}}) ) self.assertEqual(unfreeze(x), {'a': (2, {'b': 1})}) def test_frozen_dict_hash(self): xs = {'a': 1, 'b': {'c': 2}} ys = {'a': 1, 'b': {'c': 3}} self.assertNotEqual(hash(freeze(xs)), hash(freeze(ys))) def test_frozen_items(self): xs = {'a': 1, 'b': {'c': 2}} items = list(freeze(xs).items()) self.assertEqual(items, [('a', 1), ('b', freeze(xs['b']))]) def test_frozen_dict_repr(self): expected = """FrozenDict({ a: 1, b: { c: 2, d: {}, }, })""" xs = FrozenDict({'a': 1, 'b': {'c': 2, 'd': {}}}) self.assertEqual(repr(xs), expected) self.assertEqual(repr(FrozenDict()), 'FrozenDict({})') def test_frozen_dict_reduce(self): before = FrozenDict(a=FrozenDict(b=1, c=2)) cl, data = before.__reduce__() after = cl(*data) self.assertIsNot(before, after) self.assertEqual(before, after) self.assertEqual(after, {'a': {'b': 1, 'c': 2}}) def test_frozen_dict_copy_reserved_name(self): result = FrozenDict({'a': 1}).copy({'cls': 2}) self.assertEqual(result, {'a': 1, 'cls': 2}) @parameterized.parameters( { 'x': {'a': 1, 'b': {'c': 2}}, 'key': 'b', 'actual_new_x': {'a': 1}, 'actual_value': {'c': 2}, }, { 'x': FrozenDict({'a': 1, 'b': {'c': 2}}), 'key': 'b', 'actual_new_x': FrozenDict({'a': 1}), 'actual_value': FrozenDict({'c': 2}), }, ) def test_utility_pop(self, x, key, actual_new_x, actual_value): new_x, value = pop(x, key) self.assertTrue( new_x == actual_new_x and isinstance(new_x, type(actual_new_x)) ) self.assertTrue( value == actual_value and isinstance(value, type(actual_value)) ) @parameterized.parameters( { 'x': {'a': 1, 'b': {'c': 2}}, 'add_or_replace': {'b': {'c': -1, 'd': 3}}, 'actual_new_x': {'a': 1, 'b': {'c': -1, 'd': 3}}, }, { 'x': FrozenDict({'a': 1, 'b': {'c': 2}}), 'add_or_replace': FrozenDict({'b': {'c': -1, 'd': 3}}), 'actual_new_x': FrozenDict({'a': 1, 'b': {'c': -1, 'd': 3}}), }, ) def test_utility_copy(self, x, add_or_replace, actual_new_x): new_x = copy(x, add_or_replace=add_or_replace) self.assertTrue( new_x == actual_new_x and isinstance(new_x, type(actual_new_x)) ) @parameterized.parameters( { 'x': {'a': 1, 'b': {'c': 2}}, }, { 'x': FrozenDict({'a': 1, 'b': {'c': 2}}), }, ) def test_utility_copy_singlearg(self, x): new_x = copy(x) self.assertTrue(new_x == x and isinstance(new_x, type(x))) @parameterized.parameters( { 'x': {'a': 1, 'b': {'c': 2}}, 'pretty_str': '{\n a: 1,\n b: {\n c: 2,\n },\n}', }, { 'x': FrozenDict({'a': 1, 'b': {'c': 2}}), 'pretty_str': ( 'FrozenDict({\n a: 1,\n b: {\n c: 2,\n },\n})' ), }, { 'x': 345, 'pretty_str': '345', }, ) def test_utility_pretty_repr(self, x, pretty_str): self.assertEqual(pretty_repr(x), pretty_str) def test_flatten(self): frozen = freeze({'c': 1, 'b': {'a': 2}}) flat_leaves, tdef = jax.tree_util.tree_flatten(frozen) self.assertEqual(flat_leaves, [2, 1]) self.assertEqual( jax.tree_util.tree_unflatten(tdef, flat_leaves), frozen, ) flat_path_leaves, tdef = jax.tree_util.tree_flatten_with_path(frozen) self.assertEqual( flat_path_leaves, [ ((jax.tree_util.DictKey('b'), jax.tree_util.DictKey('a')), 2), ((jax.tree_util.DictKey('c'),), 1), ], ) self.assertEqual( jax.tree_util.tree_unflatten(tdef, [l for _, l in flat_path_leaves]), frozen, ) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/core/core_lift_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax import numpy as np from absl.testing import absltest from jax import numpy as jnp from jax import random from flax import errors from flax.core import FrozenDict, apply, copy, init, lift, nn class LiftTest(absltest.TestCase): def test_aliasing(self): def f(scope): a = scope.push('a') def g(scopes, _): scope, a = scopes self.assertEqual(a.parent, scope) lift.vmap(g, variable_axes={}, split_rngs={})((scope, a), jnp.ones((1,))) init(f)(random.key(0)) def test_undefined_param(self): def f(scope): dense = lift.vmap( nn.dense, in_axes=(0, None), out_axes=0, variable_axes={'params': 0}, split_rngs={'params': True}, ) dense(scope.push('dense'), np.ones((3, 2)), 2) msg = r'Could not find parameter named "kernel" in scope "/vmap\(dense\)".' with self.assertRaisesRegex(errors.ScopeParamNotFoundError, msg): apply(f)({'params': {'dense': {'abc': np.ones((3, 3))}}}) def test_jit_cache(self): compiles = 0 @lift.jit def f(scope, _module_hash, x): nonlocal compiles compiles += 1 if scope.is_mutable_collection( 'intermediates' ) and not scope.is_mutable_collection('params'): scope.put_variable('intermediates', 'x', x + 1) return nn.dense(scope, x, 1) x = np.ones((3, 2)) module_hash = 1 _, params = init(f)(random.key(0), module_hash, x) init(f)(random.key(0), module_hash, x) self.assertEqual(compiles, 1) apply(f)(params, module_hash, x) self.assertEqual(compiles, 2) # apply should cause a compile apply(f)(params, module_hash, x) self.assertEqual(compiles, 2) # applying again should not # edge case where only the implicit return of the jitted functions changes. # this should not use the previously cached apply. _, state = apply(f, mutable='intermediates')(params, module_hash, x) self.assertEqual(compiles, 3) # applying again should not self.assertEqual(state['intermediates']['x'].sum(), 3 * 2 * 2) def test_vjp(self): def g(scope, x, y): p = scope.param('test', nn.initializers.constant(0.5), ()) scope.variable('state', 'counter', lambda: 0) return p * x * y def f(scope, x, y): z, bwd = lift.vjp(g, scope, x, y) return bwd(jnp.ones(y.shape)) x = jnp.array([1.0, 2.0, 3.0]) y = jnp.array([4.0, 5.0, 6.0]) _, params = init(f)(random.key(0), x, y) params_grad, x_grad, y_grad = apply(f)(params, x, y) self.assertEqual( params_grad, { 'params': FrozenDict({'test': 32.0}), }, ) np.testing.assert_allclose(x_grad, [2.0, 2.5, 3.0]) np.testing.assert_allclose(y_grad, [0.5, 1.0, 1.5]) def test_jvp(self): def g(scope, x): p = scope.param('test', nn.initializers.zeros_init(), ()) scope.variable('state', 'counter', lambda: 0) return p * x def f(scope, x): vars_t = jax.tree_util.tree_map( jnp.ones_like, scope.variables().get('params', {}) ) _, out_t = lift.jvp( g, scope, (x,), (jnp.zeros_like(x),), {'params': vars_t} ) return out_t x = jnp.ones((3,)) _, params = init(f)(random.key(0), x) y_t = apply(f)(params, x) np.testing.assert_allclose(y_t, jnp.ones_like(x)) def test_while_loop(self): def f(scope, x): key_zero = random.key(0) key_zero = jnp.broadcast_to(key_zero, (2, *key_zero.shape)) scope.param('inc', lambda _: 1) scope.put_variable('state', 'acc', 0) scope.put_variable('state', 'rng_params', key_zero) scope.put_variable('state', 'rng_loop', jax.random.clone(key_zero)) def cond_fn(scope, c): acc = scope.get_variable('state', 'acc') return acc < x def body_fn(scope, c): i = scope.get_variable('state', 'acc') p_rng = scope.make_rng('params') l_rng = scope.make_rng('loop') scope.put_variable( 'state', 'rng_params', scope.get_variable('state', 'rng_params').at[i].set(p_rng), ) scope.put_variable( 'state', 'rng_loop', scope.get_variable('state', 'rng_loop').at[i].set(l_rng), ) inc = scope.get_variable('params', 'inc') scope.put_variable('state', 'acc', i + inc) return c + 2 return lift.while_loop( cond_fn, body_fn, scope, 0, carry_variables='state', split_rngs={'params': False, 'loop': True}, ) x = 2 c, vars = apply(f, mutable=True)( {}, x, rngs={'params': random.key(1), 'loop': random.key(2)} ) self.assertEqual(vars['state']['acc'], x) self.assertEqual(c, 2 * x) self.assertEqual( vars['state']['rng_params'][0], vars['state']['rng_params'][1] ) with jax.debug_key_reuse(False): self.assertNotEqual( vars['state']['rng_loop'][0], vars['state']['rng_loop'][1], ) def test_cond(self): def f(scope, x, pred): scope.variable('state', 'true_count', lambda: 0) scope.variable('state', 'false_count', lambda: 0) def true_fn(scope, x): scope.variable('state', 'true_count').value += 1 return scope.child(nn.dense)(x, 2) def false_fn(scope, x): scope.variable('state', 'false_count').value += 1 return -scope.child(nn.dense)(x, 2) return lift.cond(pred, true_fn, false_fn, scope, x) x = jnp.ones((1, 3)) y1, vars = init(f)(random.key(0), x, True) self.assertEqual(vars['state'], {'true_count': 1, 'false_count': 0}) y2, vars = apply(f, mutable='state')(vars, x, False) self.assertEqual(vars['state'], {'true_count': 1, 'false_count': 1}) np.testing.assert_allclose(y1, -y2) def test_switch(self): def f(scope, x, index): scope.variable('state', 'a_count', lambda: 0) scope.variable('state', 'b_count', lambda: 0) scope.variable('state', 'c_count', lambda: 0) def a_fn(scope, x): scope.variable('state', 'a_count').value += 1 return scope.child(nn.dense)(x, 2) def b_fn(scope, x): scope.variable('state', 'b_count').value += 1 return -scope.child(nn.dense)(x, 2) def c_fn(scope, x): scope.variable('state', 'c_count').value += 1 return scope.child(nn.dense)(x, 2) return lift.switch(index, [a_fn, b_fn, c_fn], scope, x) x = jnp.ones((1, 3)) y1, vars = init(f)(random.key(0), x, 0) self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 0, 'c_count': 0}) y2, updates = apply(f, mutable='state')(vars, x, 1) vars = copy(vars, updates) self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 1, 'c_count': 0}) np.testing.assert_allclose(y1, -y2) y3, updates = apply(f, mutable='state')(vars, x, 2) vars = copy(vars, updates) self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 1, 'c_count': 1}) np.testing.assert_allclose(y1, y3) def test_subscope_var_aliasing(self): def test(scope, x): subscope = scope.push(name='a') subscope.put_variable('state', 'x', 0.0) _ = lift.while_loop( lambda scope, x: False, lambda scope, x: x, scope, jnp.array(0, jnp.int32), carry_variables=['state'], ) subscope.put_variable('state', 'x', 1.0) val0 = scope.variables()['state']['a']['x'] val1 = subscope.variables()['state']['x'] self.assertEqual(val0, val1) return x init(test)(random.key(0), 1.0) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/core/core_meta_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax from absl.testing import absltest from jax import numpy as jnp from jax import random, sharding from jax.experimental import mesh_utils from flax import errors from flax.core import init, lift, meta, nn class MetaTest(absltest.TestCase): def test_boxed_param(self): def f(scope, xs): def g(scope, x): kernel_init = meta.with_partitioning( nn.initializers.ones_init(), ('in', 'out') ) kernel = scope.param('kernel', kernel_init, (x.shape[-1], 2)) kernel_box = scope.get_variable('params', 'kernel') self.assertIsInstance(kernel_box, meta.Partitioned) self.assertEqual(kernel_box.names, ('in', 'out')) return x @ kernel lift.vmap( g, in_axes=0, variable_axes={'params': 0}, split_rngs={'params': True}, metadata_params={meta.PARTITION_NAME: 'batch'}, )(scope, xs) _, variables = init(f)(random.key(0), jnp.zeros((8, 3))) self.assertEqual( variables['params']['kernel'].names, ('batch', 'in', 'out') ) def test_boxed_variable(self): def f(scope, xs): def g(scope, x): kernel_init = meta.with_partitioning( nn.initializers.ones_init(), ('in', 'out') ) kernel = scope.variable( 'params', 'kernel', kernel_init, scope.make_rng('params'), (x.shape[-1], 2), ) kernel.value += 1.0 self.assertEqual(kernel.value.sum(), kernel.value.size * 2) kernel_box = scope.get_variable('params', 'kernel') self.assertIsInstance(kernel_box, meta.Partitioned) self.assertEqual(kernel_box.names, ('in', 'out')) return x @ kernel.value lift.vmap( g, in_axes=0, variable_axes={'params': 0}, split_rngs={'params': True}, metadata_params={meta.PARTITION_NAME: 'batch'}, )(scope, xs) _, variables = init(f)(random.key(0), jnp.zeros((8, 3))) self.assertEqual( variables['params']['kernel'].names, ('batch', 'in', 'out') ) def test_partition_axis_unspecified(self): def f(scope, xs): def g(scope, x): kernel_init = meta.with_partitioning( nn.initializers.ones_init(), ('in', 'out') ) scope.param('kernel', kernel_init, (3, 2)) return x with self.assertRaises(errors.PartitioningUnspecifiedError): lift.vmap( g, in_axes=0, variable_axes={'params': 0}, split_rngs={'params': True}, metadata_params={}, )(scope, xs) init(f)(random.key(0), jnp.zeros((8, 3))) def test_unbox(self): xs = { 'kernel': meta.Partitioned(jnp.zeros((3, 2)), ('in', 'out')), 'complex': meta.Partitioned( {'K': jnp.zeros((3, 2)), 'b': jnp.zeros((3,))}, ('data',) ), } unboxed = meta.unbox(xs) unboxed_shapes = jax.tree_util.tree_map(jnp.shape, unboxed) self.assertEqual( unboxed_shapes, { 'kernel': (3, 2), 'complex': { 'K': (3, 2), 'b': (3,), }, }, ) def test_scan_over_layers(self): def f(scope, x): def body(scope, x): kernel_init = meta.with_partitioning( nn.initializers.ones_init(), ('in', 'out') ) y = nn.dense(scope, x, 3, kernel_init=kernel_init) return y, () c, _ = lift.scan( body, variable_axes={'params': 0}, split_rngs={'params': True}, length=8, metadata_params={meta.PARTITION_NAME: 'layers'}, )(scope, x) return c _, variables = init(f)(random.key(0), jnp.zeros((8, 3))) boxed_shapes = jax.tree_util.tree_map(jnp.shape, variables['params']) self.assertEqual( boxed_shapes, { 'kernel': meta.Partitioned((8, 3, 3), ('layers', 'in', 'out')), 'bias': (8, 3), }, ) def test_get_partition_spec(self): xs = { 'kernel': meta.Partitioned(jnp.zeros((8, 3, 3)), ('layers', 'in', 'out')), 'bias': jnp.zeros((8, 3)), 'step': jnp.array(100), } ps = meta.get_partition_spec(xs) self.assertEqual( ps, { 'kernel': jax.sharding.PartitionSpec('layers', 'in', 'out'), 'bias': jax.sharding.PartitionSpec(), 'step': jax.sharding.PartitionSpec(), }, ) def test_get_sharding(self): devices = mesh_utils.create_device_mesh((jax.local_device_count(), 1)) mesh = sharding.Mesh(devices, ('in', 'out')) xs = { 'kernel': meta.Partitioned(jnp.zeros((8, 3)), ('in', 'out')), 'bias': jnp.zeros((8, 3)), 'step': jnp.array(100), } ps = meta.get_sharding(xs, mesh) self.assertEqual( ps, { 'kernel': jax.sharding.NamedSharding( mesh, jax.sharding.PartitionSpec('in', 'out') ), 'bias': jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()), 'step': jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()), }, ) def test_boxed_param_with_mesh(self): devices = mesh_utils.create_device_mesh((jax.local_device_count(), 1)) mesh = sharding.Mesh(devices, ('in', 'out')) def f(scope, x): kernel_init = meta.with_partitioning( nn.initializers.ones_init(), ('in', 'out'), mesh=mesh ) kernel = scope.param('kernel', kernel_init, (x.shape[-1], 2)) kernel_box = scope.get_variable('params', 'kernel') self.assertIsInstance(kernel_box, meta.Partitioned) self.assertEqual(kernel_box.names, ('in', 'out')) return x @ kernel @jax.jit def create_state(): y, variables = init(f)(random.key(0), jnp.zeros((8, 4))) spec = meta.get_partition_spec(variables) shardings = jax.tree_util.tree_map( lambda s: sharding.NamedSharding(mesh, s), spec ) variables = jax.lax.with_sharding_constraint(variables, shardings) return variables variables = create_state() self.assertEqual(variables['params']['kernel'].names, ('in', 'out')) self.assertIs(variables['params']['kernel'].mesh, mesh) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/core/core_scope_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax import numpy as np from absl.testing import absltest from jax import numpy as jnp from jax import random from flax import errors from flax import config from flax.core import Scope, apply, freeze, init, lazy_init, nn, scope from flax.core.scope import LazyRng class ScopeTest(absltest.TestCase): def test_rng(self): def f(scope): self.assertTrue(scope.has_rng('params')) self.assertFalse(scope.has_rng('dropout')) rng = scope.make_rng('params') self.assertTrue( np.all(rng == LazyRng.create(random.key(0), 1).as_jax_rng()) ) init(f)(random.key(0)) def test_in_filter(self): filter_true = lambda x, y: self.assertTrue(scope.in_filter(x, y)) filter_false = lambda x, y: self.assertFalse(scope.in_filter(x, y)) filter_true(True, 'any_string1') filter_false(False, 'any_string2') filter_true('exact_match', 'exact_match') filter_false('no_match1', 'no_match2') filter_true(['one', 'two'], 'one') filter_false(['one', 'two'], 'three') filter_false([], 'one') filter_false([], None) def test_union_filter(self): def union_check(a, b, ans): self.assertEqual(scope.union_filters(a, b), ans) self.assertEqual(scope.union_filters(b, a), ans) union_check(['a', 'b'], ['b', 'c'], {'a', 'b', 'c'}) union_check(True, False, True) union_check(False, False, set()) union_check(True, True, True) union_check( scope.DenyList(['a', 'b']), scope.DenyList(['b', 'c']), scope.DenyList({'b'}), ) union_check( scope.DenyList(['a', 'b']), ['b', 'c'], scope.DenyList({'a'}) ) def test_intersect_filter(self): def intersect_check(a, b, ans): self.assertEqual(scope.intersect_filters(a, b), ans) self.assertEqual(scope.intersect_filters(b, a), ans) intersect_check(['a', 'b'], ['b', 'c'], {'b'}) intersect_check(True, False, False) intersect_check(False, False, set()) intersect_check(True, True, True) intersect_check( scope.DenyList(['a', 'b']), scope.DenyList(['b', 'c']), scope.DenyList({'a', 'b', 'c'}), ) intersect_check(scope.DenyList(['a', 'b']), ['b', 'c'], {'c'}) def test_subtract_filter(self): def subtract_check(a, b, ans): self.assertEqual(scope.subtract_filters(a, b), ans) subtract_check(['a', 'b'], ['b', 'c'], {'a'}) subtract_check(True, False, scope.DenyList(False)) subtract_check(False, False, set()) subtract_check(True, True, False) subtract_check(True, 'a', scope.DenyList('a')) subtract_check( scope.DenyList(['a', 'b']), scope.DenyList(['b', 'c']), {'c'} ) subtract_check( scope.DenyList(['a', 'b']), ['b', 'c'], scope.DenyList({'a', 'b', 'c'}), ) def test_group_collections(self): params = {'dense1': {'x': [10, 20]}} batch_stats = {'dense1': {'ema': 5}} xs = {'params': params, 'batch_stats': batch_stats} # Retrieve all keys only once. group = scope.group_collections(xs, ['params', 'params']) self.assertEqual(group, ({'params': params}, {})) # Ignore non-existing keys. self.assertEqual(scope.group_collections(xs, ['vars']), ({},)) # False gets nothing and True retrieves all keys once. self.assertEqual( scope.group_collections(xs, [False, True, True]), ({}, xs, {}) ) def test_inconsistent_param_shapes(self): def f(scope): scope.param('test', nn.initializers.ones_init(), (4,)) msg = ( r'For parameter "test" in "/", the given initializer is expected to' r' generate shape \(4,\), but the existing parameter it received has' r' shape \(2,\).' ) with self.assertRaisesRegex(errors.ScopeParamShapeError, msg): apply(f)(freeze({'params': {'test': np.ones((2,))}})) def test_apply_variables_bad_pytree(self): def f(scope): scope.param('kernel', nn.initializers.ones_init(), (4,)) params = freeze( { 'params': { 'kernel': np.ones((4,)), }, } ) apply(f)(params) # Valid. msg = 'but got a dict with an extra params layer' with self.assertRaisesRegex( errors.ApplyScopeInvalidVariablesStructureError, msg ): apply(f)({'params': params}) def test_mutate_undefined_collection(self): def f(scope): scope.put_variable('state', 'test', 123) msg = ( r'Cannot update variable "test" in "/" because collection "state" is' r' immutable.' ) with self.assertRaisesRegex(errors.ModifyScopeVariableError, msg): init(f, mutable='params')(random.key(0)) def test_undefined_param(self): def f(scope): nn.dense(scope.push('dense'), np.ones((1, 2)), 2) msg = r'Could not find parameter named "kernel" in scope "/dense".' with self.assertRaisesRegex(errors.ScopeParamNotFoundError, msg): apply(f)({'params': {'abc': 1}}) def test_variable_is_mutable(self): def f(scope, should_be_mutable): test = scope.variable('state', 'test', lambda: 1) self.assertEqual(test.is_mutable(), should_be_mutable) _, variables = apply(f, mutable='state')({}, True) apply(f, mutable=False)(variables, False) def test_rngs_check_w_frozen_dict(self): def f(scope, x): return x _ = apply(f)({}, np.array([0.0]), rngs=freeze({'a': random.key(0)})) def test_rng_check_w_old_and_new_keys(self): # random.key always returns a new-style typed PRNG key. key = random.key(0) self.assertTrue(scope._is_valid_rng(key)) self.assertFalse(scope._is_valid_rng(random.split(key))) # random.PRNGKey returns an old-style uint32 key by default. old_key = random.PRNGKey(0) self.assertTrue(scope._is_valid_rng(old_key)) self.assertFalse(scope._is_valid_rng(random.split(old_key))) # Also explicitly test raw key data, because the jax_enable_custom_prng # flag can make PRNGKey return new-style keys. raw_key = random.key_data(key) self.assertTrue(scope._is_valid_rng(raw_key)) self.assertFalse(scope._is_valid_rng(random.split(raw_key))) def test_rng_check_w_lazy_rng(self): key = random.key(0) self.assertTrue(scope._is_valid_rng(scope.LazyRng.create(key, 1))) def test_jax_leak_detector(self): with jax.check_tracer_leaks(True): def f(scope): def g(scope): pass scope.child(g)() jax.jit(init(f))(random.key(0)) def test_rng_counter_reuse(self): root = Scope({}, {'dropout': random.key(0)}) def f(scope): return scope.make_rng('dropout') a = root.child(f)() root = root.rewound() b = root.child(f)() self.assertFalse(jnp.allclose(a, b)) def test_empty_col_error(self): root = Scope({}) with self.assertRaises(errors.ScopeCollectionNotFound): root.param('test', nn.initializers.zeros_init(), ()) root = Scope({'params': {}}) with self.assertRaises(errors.ScopeCollectionNotFound): root.param('test', nn.initializers.zeros_init(), ()) root = Scope({'params': {'abc': 1}}) with self.assertRaises(errors.ScopeCollectionNotFound): root.variable('state', 'test', jnp.zeros, ()) root = Scope({'state': {}}) with self.assertRaises(errors.ScopeCollectionNotFound): root.variable('state', 'test', jnp.zeros, ()) def test_variable_no_init(self): root = Scope({}, mutable='state') with self.assertRaises(errors.ScopeCollectionNotFound): root.variable('state', 'test') root = Scope({'state': {'abc': 1}}, mutable='state') abc = root.variable('state', 'abc') self.assertEqual(abc.value, 1) with self.assertRaises(errors.ScopeVariableNotFoundError): root.variable('state', 'test') def test_variable_alias(self): scope = Scope({}, mutable='state') subscope = scope.push(name='a') subscope.put_variable('state', 'x', 0.0) scope.put_variable('state', 'a', {'x': jnp.array(1.0, jnp.float32)}) self.assertEqual( scope.variables()['state']['a']['x'], subscope.variables()['state']['x'] ) def test_lazy_init(self): def f(scope, x): k = scope.param( 'kernel', nn.initializers.lecun_normal(), (x.shape[-1], x.shape[-1]) ) return x @ k init_fn = lazy_init(f) # provide a massive input message which would OOM if any compute ops were actually executed variables = init_fn( random.key(0), jax.ShapeDtypeStruct((1024 * 1024 * 1024, 128), jnp.float32), ) self.assertEqual(variables['params']['kernel'].shape, (128, 128)) def test_lazy_init_fails_on_data_dependence(self): def f(scope, x): # kernel is initialized with x so params are now dependent on the input k = scope.param('kernel', lambda _: x) return x * k init_fn = lazy_init(f) with self.assertRaises(errors.LazyInitError): init_fn(random.key(0), jax.ShapeDtypeStruct((8, 4), jnp.float32)) @config.temp_flip_flag('fix_rng_separator', True) def test_fold_in_static_seperator(self): x = LazyRng(random.key(0), ('ab', 'c')) y = LazyRng(random.key(0), ('a', 'bc')) self.assertFalse(np.all(x.as_jax_rng() == y.as_jax_rng())) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/core/design/core_attention_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from functools import partial from collections.abc import Callable, Sequence import jax from absl.testing import absltest from jax import lax, random from jax import numpy as jnp from flax.core import Array, Scope, init, lift, nn, unfreeze def softmax_attn(scope: Scope, weights: Array): del scope norm_dims = tuple(range(weights.ndim // 2, weights.ndim)) log_norms = jax.scipy.special.logsumexp( weights, axis=norm_dims, keepdims=True ) return jnp.exp(weights - log_norms) def with_dropout(fn, rate: float, deterministic: bool = False): def attn_fn(scope: Scope, weights: Array): attn_weights = fn(scope, weights) return nn.dropout( scope, attn_weights, deterministic=deterministic, rate=rate ) return attn_fn def _dot_product_attention( scope: Scope, query: Array, key: Array, value: Array, bias: Array | None = None, attn_fn: Callable = softmax_attn, dtype=jnp.float32, ): assert key.ndim == query.ndim assert key.ndim == value.ndim n = query.ndim attn_weights = lax.dot_general(query, key, (((n - 1,), (n - 1,)), ((), ()))) if bias is not None: attn_weights += bias attn_weights = attn_fn(scope, attn_weights) attn_weights = attn_weights.astype(dtype) contract_dims = ( tuple(range(n - 1, attn_weights.ndim)), tuple(range(0, n - 1)), ) y = lax.dot_general(attn_weights, value, (contract_dims, ((), ()))) return y def dot_product_attention( scope: Scope, inputs_q: Array, inputs_kv: Array, bias: Array | None = None, qkv_features: int | None = None, out_features: int | None = None, attn_fn: Callable = softmax_attn, dtype=jnp.float32, ): if qkv_features is None: qkv_features = inputs_q.shape[-1] if out_features is None: out_features = inputs_q.shape[-1] dense = partial(nn.dense, features=qkv_features, bias=False, dtype=dtype) query = scope.child(dense, 'query')(inputs_q) key = scope.child(dense, 'key')(inputs_kv) value = scope.child(dense, 'value')(inputs_kv) y = _dot_product_attention( scope, query, key, value, bias=bias, attn_fn=attn_fn, dtype=dtype ) return scope.child(nn.dense, 'out')(y, features=out_features, dtype=dtype) def multi_head_dot_product_attention( scope: Scope, inputs_q: Array, inputs_kv: Array, bias: Array | None = None, qkv_features: int | None = None, out_features: int | None = None, attn_fn: Callable = softmax_attn, batch_axes: Sequence[int] = (0,), num_heads: int = 1, dtype=jnp.float32, broadcast_dropout=False, ): if qkv_features is None: qkv_features = inputs_q.shape[-1] if out_features is None: out_features = inputs_q.shape[-1] attn_fn = partial( dot_product_attention, attn_fn=attn_fn, qkv_features=qkv_features // num_heads, out_features=out_features, dtype=dtype, ) attn_fn = lift.vmap( attn_fn, in_axes=(None, None, None), out_axes=-2, axis_size=num_heads, variable_axes={'params': 0}, split_rngs={'params': True, 'dropout': not broadcast_dropout}, ) for axis in reversed(sorted(batch_axes)): attn_fn = lift.vmap( attn_fn, in_axes=(axis, axis, axis), out_axes=axis, variable_axes={'params': None}, split_rngs={'params': False, 'dropout': not broadcast_dropout}, ) y = attn_fn(scope, inputs_q, inputs_kv, bias) return y.mean(axis=-2) class AttentionTest(absltest.TestCase): def test_attention(self): inputs = jnp.ones((2, 7, 16)) model = partial( multi_head_dot_product_attention, num_heads=2, batch_axes=(0,), attn_fn=with_dropout(softmax_attn, 0.1, deterministic=False), ) rngs = {'params': random.key(0), 'dropout': random.key(1)} y, variables = jax.jit(init(model))(rngs, inputs, inputs) variable_shapes = jax.tree_util.tree_map(jnp.shape, variables['params']) self.assertEqual(y.shape, (2, 7, 16)) self.assertEqual( unfreeze(variable_shapes), { 'key': {'kernel': (2, 16, 8)}, 'value': {'kernel': (2, 16, 8)}, 'query': {'kernel': (2, 16, 8)}, 'out': {'bias': (2, 16), 'kernel': (2, 8, 16)}, }, ) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/core/design/core_auto_encoder_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from collections.abc import Callable import jax from absl.testing import absltest from jax import numpy as jnp from jax import random from flax.core import Array, Scope, init, nn, unfreeze def mlp(scope: Scope, x: Array, hidden: int, out: int): x = scope.child(nn.dense, 'hidden')(x, hidden) x = nn.relu(x) return scope.child(nn.dense, 'out')(x, out) @dataclass class AutoEncoder: latents: int features: int hidden: int def __call__(self, scope, x): z = self.encode(scope, x) return self.decode(scope, z) def encode(self, scope, x): return scope.child(mlp, 'encoder')(x, self.hidden, self.latents) def decode(self, scope, z): return scope.child(mlp, 'decoder')(z, self.hidden, self.features) def module_method(fn, name=None): if name is None: name = fn.__name__ if hasattr(fn, '__name__') else None def wrapper(self, *args, **kwargs): scope = self.scope.rewound() mod_fn = lambda scope: fn(self, scope, *args, **kwargs) return scope.child(mod_fn, name)() return wrapper @dataclass class AutoEncoder2: scope: Scope latents: int features: int hidden: int def __call__(self, x): z = self.encode(x) return self.decode(z) @module_method def encode(self, scope, x): return mlp(scope, x, self.hidden, self.latents) @module_method def decode(self, scope, z): return mlp(scope, z, self.hidden, self.features) @dataclass class AutoEncoder3: encode: Callable decode: Callable @staticmethod def create(scope, hidden: int, latents: int, features: int): enc = scope.child(mlp, 'encode', hidden=hidden, out=latents) dec = scope.child(mlp, 'decode', hidden=hidden, out=features) return AutoEncoder3(enc, dec) def __call__(self, x): z = self.encode(x) return self.decode(z) class AutoEncoderTest(absltest.TestCase): def test_auto_encoder_hp_struct(self): ae = AutoEncoder(latents=2, features=4, hidden=3) x = jnp.ones((1, 4)) x_r, variables = init(ae)(random.key(0), x) self.assertEqual(x.shape, x_r.shape) variable_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) self.assertEqual( variable_shapes, { 'encoder': { 'hidden': {'kernel': (4, 3), 'bias': (3,)}, 'out': {'kernel': (3, 2), 'bias': (2,)}, }, 'decoder': { 'hidden': {'kernel': (2, 3), 'bias': (3,)}, 'out': {'kernel': (3, 4), 'bias': (4,)}, }, }, ) def test_auto_encoder_with_scope(self): ae = lambda scope, x: AutoEncoder2(scope, latents=2, features=4, hidden=3)( x ) x = jnp.ones((1, 4)) x_r, variables = init(ae)(random.key(0), x) self.assertEqual(x.shape, x_r.shape) variable_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) self.assertEqual( variable_shapes, { 'encode': { 'hidden': {'kernel': (4, 3), 'bias': (3,)}, 'out': {'kernel': (3, 2), 'bias': (2,)}, }, 'decode': { 'hidden': {'kernel': (2, 3), 'bias': (3,)}, 'out': {'kernel': (3, 4), 'bias': (4,)}, }, }, ) def test_auto_encoder_bind_method(self): ae = lambda scope, x: AutoEncoder3.create( scope, latents=2, features=4, hidden=3 )(x) x = jnp.ones((1, 4)) x_r, variables = init(ae)(random.key(0), x) self.assertEqual(x.shape, x_r.shape) variable_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) self.assertEqual( variable_shapes, { 'encode': { 'hidden': {'kernel': (4, 3), 'bias': (3,)}, 'out': {'kernel': (3, 2), 'bias': (2,)}, }, 'decode': { 'hidden': {'kernel': (2, 3), 'bias': (3,)}, 'out': {'kernel': (3, 4), 'bias': (4,)}, }, }, ) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/core/design/core_big_resnets_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from functools import partial import jax import numpy as np from absl.testing import absltest from jax import numpy as jnp from jax import random from flax.core import Array, Scope, init, lift, nn, unfreeze default_norm = partial(nn.batch_norm) def residual_block(scope: Scope, x: Array, conv, norm, act, features: int): residual = x x = scope.child(conv, 'conv_1')(x, features, (3, 3)) x = scope.child(norm, 'bn_1')(x) x = act(x) x = scope.child(conv, 'conv_2')(x, features, (3, 3)) x = scope.child(norm, 'bn_2')(x) return act(residual + x) def big_resnet( scope: Scope, x, blocks=(10, 5), dtype=jnp.float32, norm=default_norm, act=nn.relu, ): conv = partial(nn.conv, bias=False, dtype=dtype) norm = partial(norm, dtype=dtype) # a two stage resnet where inner blocks are rematerialized to make sure # memory consumtion grows as O(sqrt(N)) and compute is O(N) where N is the number of blocks.. # we use a double scan such that the compiled binary is of size O(1). print('total residual blocks:', np.prod(blocks)) def body_fn(scope, x): return residual_block(scope, x, conv, norm, act, features=x.shape[-1]) return lift.remat_scan( body_fn, lengths=blocks, variable_axes={'params': 0, 'batch_stats': 0}, split_rngs={'params': True}, policy=None, )(scope, x) class BigResnetTest(absltest.TestCase): def test_big_resnet(self): x = random.normal(random.key(0), (1, 8, 8, 8)) y, variables = init(big_resnet)(random.key(1), x) self.assertEqual(y.shape, (1, 8, 8, 8)) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) batch_stats_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['batch_stats']) ) self.assertEqual( param_shapes, { 'conv_1': {'kernel': (10, 5, 3, 3, 8, 8)}, 'conv_2': {'kernel': (10, 5, 3, 3, 8, 8)}, 'bn_1': {'scale': (10, 5, 8), 'bias': (10, 5, 8)}, 'bn_2': {'scale': (10, 5, 8), 'bias': (10, 5, 8)}, }, ) self.assertEqual( batch_stats_shapes, { 'bn_1': {'var': (10, 5, 8), 'mean': (10, 5, 8)}, 'bn_2': {'var': (10, 5, 8), 'mean': (10, 5, 8)}, }, ) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/core/design/core_custom_vjp_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from functools import partial from collections.abc import Callable, Sequence import jax import numpy as np from absl.testing import absltest from jax import numpy as jnp from jax import random from flax.core import Array, Scope, apply, init, lift, nn, unfreeze def mlp_custom_grad( scope: Scope, x: Array, sizes: Sequence[int] = (8, 1), act_fn: Callable[[Array], Array] = nn.relu, ): f = nn.dense def fwd(scope, x, features): y, vjp_fn = lift.vjp(partial(f, features=features), scope, x) return y, vjp_fn def bwd(features, res, y_t): del features vjp_fn = res params_t, *input_t = vjp_fn(y_t) params_t = jax.tree_util.tree_map(jnp.sign, params_t) return (params_t, *input_t) dense_custom_grad = lift.custom_vjp( f, forward_fn=fwd, backward_fn=bwd, nondiff_argnums=(2,) ) # hidden layers for size in sizes[:-1]: x = scope.child(dense_custom_grad, prefix='hidden_')(x, size) x = act_fn(x) # output layer return scope.child(dense_custom_grad, 'out')(x, sizes[-1]) class CustomVJPTest(absltest.TestCase): def test_custom_vjp(self): x = random.normal(random.key(0), (1, 4)) y, variables = init(mlp_custom_grad)(random.key(1), x) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) loss_fn = lambda p, x: jnp.mean(apply(mlp_custom_grad)(p, x) ** 2) grad = jax.grad(loss_fn)(variables, x) grad_shapes = unfreeze(jax.tree_util.tree_map(jnp.shape, grad['params'])) self.assertEqual(y.shape, (1, 1)) expected_param_shapes = { 'hidden_0': {'kernel': (4, 8), 'bias': (8,)}, 'out': {'kernel': (8, 1), 'bias': (1,)}, } self.assertEqual(param_shapes, expected_param_shapes) self.assertEqual(grad_shapes, expected_param_shapes) for g in jax.tree_util.tree_leaves(grad): self.assertTrue(np.all(g == np.sign(g))) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/core/design/core_dense_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Any import jax from absl.testing import absltest from jax import numpy as jnp from jax import random from flax import struct from flax.core import Array, init, nn, unfreeze @dataclass class Dense: features: int bias: bool = True kernel_init: Any = nn.linear.default_kernel_init bias_init: Any = nn.initializers.zeros_init() def __call__(self, scope, x): kernel = scope.param( 'kernel', self.kernel_init, (x.shape[-1], self.features) ) y = x @ kernel if self.bias: bias = scope.param('bias', self.bias_init, (self.features,)) y += bias.reshape((1,) * (y.ndim - 1) + (-1,)) return y @struct.dataclass class ExplicitDense: kernel: Array bias: Array | None # a fully explicit "scope free" version @staticmethod def create( rng, in_size, out_size, bias=True, kernel_init=nn.linear.default_kernel_init, bias_init=nn.initializers.zeros_init(), ): k1, k2 = random.split(rng, 2) kernel = kernel_init(k1, (in_size, out_size)) if bias: bias = bias_init(k2, (out_size,)) else: bias = None return ExplicitDense(kernel, bias) # a semi-explicit version where a scope is used to create explicit params @staticmethod def create_in_scope( scope, in_size, out_size, bias=True, kernel_init=nn.linear.default_kernel_init, bias_init=nn.initializers.zeros_init(), ): kernel = scope.param('kernel', kernel_init, (in_size, out_size)) if bias: bias = scope.param('bias', bias_init, (out_size,)) else: bias = None return ExplicitDense(kernel, bias) def __call__(self, x): y = x @ self.kernel if self.bias is not None: y += self.bias.reshape((1,) * (y.ndim - 1) + (-1,)) return y def explicit_mlp(scope, x, sizes=(3, 1)): for i, size in enumerate(sizes): dense = scope.param(f'dense_{i}', ExplicitDense.create, x.shape[-1], size) x = dense(x) if i + 1 < len(sizes): x = nn.relu(x) return x def semi_explicit_mlp(scope, x, sizes=(3, 1)): for i, size in enumerate(sizes): dense = scope.child(ExplicitDense.create_in_scope, prefix='dense_')( x.shape[-1], size ) x = dense(x) if i + 1 < len(sizes): x = nn.relu(x) return x class DenseTest(absltest.TestCase): def test_dense(self): model = Dense(features=4) x = jnp.ones((1, 3)) y, variables = init(model)(random.key(0), x) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) self.assertEqual(y.shape, (1, 4)) self.assertEqual( param_shapes, { 'kernel': (3, 4), 'bias': (4,), }, ) def test_explicit_dense(self): x = jnp.ones((1, 3)) y, variables = init(explicit_mlp)(random.key(0), x) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) self.assertEqual(y.shape, (1, 4)) self.assertEqual( param_shapes, { 'kernel': (3, 4), 'bias': (4,), }, ) def test_explicit_dense(self): x = jnp.ones((1, 4)) y, variables = init(explicit_mlp)(random.key(0), x) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) self.assertEqual(y.shape, (1, 1)) self.assertEqual( param_shapes, { 'dense_0': ExplicitDense((4, 3), (3,)), 'dense_1': ExplicitDense((3, 1), (1,)), }, ) def test_semi_explicit_dense(self): x = jnp.ones((1, 4)) y, variables = init(semi_explicit_mlp)(random.key(0), x) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) self.assertEqual(y.shape, (1, 1)) self.assertEqual( param_shapes, { 'dense_0': {'kernel': (4, 3), 'bias': (3,)}, 'dense_1': {'kernel': (3, 1), 'bias': (1,)}, }, ) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/core/design/core_flow_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Any from collections.abc import Sequence import jax from absl.testing import absltest from jax import numpy as jnp from jax import random from jax.scipy.linalg import expm from flax.core import Array, Scope, apply, init, nn, unfreeze Initializer = Any Flow = Any @dataclass class DenseFlow: kernel_init: Initializer = nn.linear.default_kernel_init bias_init: Initializer = nn.initializers.zeros_init() def params(self, scope: Scope, features: int): kernel = scope.param('kernel', self.kernel_init, (features, features)) bias = scope.param('bias', self.bias_init, (features,)) return kernel, bias def forward(self, scope: Scope, x: Array): kernel, bias = self.params(scope, x.shape[-1]) return jnp.dot(x, expm(kernel)) + bias.reshape((1,) * (x.ndim - 1) + (-1,)) def backward(self, scope: Scope, y: Array): kernel, bias = self.params(scope, y.shape[-1]) return jnp.dot(y - bias.reshape((1,) * (y.ndim - 1) + (-1,)), expm(-kernel)) @dataclass class StackFlow: flows: Sequence[Flow] def forward(self, scope: Scope, x: Array): for i, f in enumerate(self.flows): x = scope.child(f.forward, name=str(i))(x) return x def backward(self, scope: Scope, x: Array): for i, f in reversed(tuple(enumerate(self.flows))): x = scope.child(f.backward, name=str(i))(x) return x class FlowTest(absltest.TestCase): def test_flow(self): x = jnp.ones((1, 3)) flow = StackFlow((DenseFlow(),) * 3) y, variables = init(flow.forward)(random.key(0), x) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) self.assertEqual(y.shape, (1, 3)) self.assertEqual( param_shapes, { '0': {'kernel': (3, 3), 'bias': (3,)}, '1': {'kernel': (3, 3), 'bias': (3,)}, '2': {'kernel': (3, 3), 'bias': (3,)}, }, ) x_restored = apply(flow.backward)(variables, y) self.assertTrue(jnp.allclose(x, x_restored)) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/core/design/core_resnet_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from functools import partial import jax from absl.testing import absltest from jax import numpy as jnp from jax import random from flax.core import Array, Scope, init, nn, unfreeze default_norm = partial(nn.batch_norm) def residual_block( scope: Scope, x: Array, conv, norm, act, features: int, strides=(1, 1) ): residual = x x = scope.child(conv, 'conv_1')(x, features, (1, 1)) x = scope.child(norm, 'bn_1')(x) x = act(x) x = scope.child(conv, 'conv_2')(x, 4 * features, (3, 3), strides=strides) x = scope.child(norm, 'bn_2')(x) x = act(x) x = scope.child(conv, 'conv_3')(x, 4 * features, (1, 1)) x = scope.child(norm, 'bn_3')(x) if x.shape != residual.shape: residual = scope.child(conv, 'proj_conv')( residual, 4 * features, (1, 1), strides=strides ) residual = scope.child(norm, 'proj_bn')(residual) return act(residual + x) def resnet( scope: Scope, x, block_sizes=(3, 4, 6, 3), features=16, num_classes=1000, dtype=jnp.float32, norm=default_norm, act=nn.relu, ): conv = partial(nn.conv, bias=False, dtype=dtype) norm = partial(norm, dtype=dtype) x = scope.child(conv, 'init_conv')(x, 16, (7, 7), padding=((3, 3), (3, 3))) x = scope.child(norm, 'init_bn')(x) x = act(x) x = nn.max_pool(x, (2, 2), (2, 2), 'SAME') for i, size in enumerate(block_sizes): for j in range(size): strides = (1, 1) if i > 0 and j == 0: strides = (2, 2) block_features = features * 2**i block_scope = scope.push(f'block_{i}_{j}') x = residual_block( block_scope, x, conv, norm, act, block_features, strides ) # we can access parameters of the sub module by operating on the scope # Example: # block_scope.get_kind('params')['conv_1']['kernel'] x = jnp.mean(x, (1, 2)) x = scope.child(nn.dense, 'out')(x, num_classes) return x class ResNetTest(absltest.TestCase): def test_resnet(self): block_sizes = (2, 2) x = random.normal(random.key(0), (1, 64, 64, 3)) y, variables = init(resnet)( random.key(1), x, block_sizes=block_sizes, features=16 ) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) self.assertEqual(y.shape, (1, 1000)) self.assertEqual( param_shapes, { 'init_conv': {'kernel': (7, 7, 3, 16)}, 'init_bn': {'bias': (16,), 'scale': (16,)}, 'out': {'kernel': (128, 1000), 'bias': (1000,)}, 'block_0_0': { 'conv_1': {'kernel': (1, 1, 16, 16)}, 'conv_2': {'kernel': (3, 3, 16, 64)}, 'conv_3': {'kernel': (1, 1, 64, 64)}, 'bn_1': {'bias': (16,), 'scale': (16,)}, 'bn_2': {'bias': (64,), 'scale': (64,)}, 'bn_3': {'bias': (64,), 'scale': (64,)}, 'proj_conv': {'kernel': (1, 1, 16, 64)}, 'proj_bn': {'bias': (64,), 'scale': (64,)}, }, 'block_0_1': { 'conv_1': {'kernel': (1, 1, 64, 16)}, 'conv_2': {'kernel': (3, 3, 16, 64)}, 'conv_3': {'kernel': (1, 1, 64, 64)}, 'bn_1': {'bias': (16,), 'scale': (16,)}, 'bn_2': {'bias': (64,), 'scale': (64,)}, 'bn_3': {'bias': (64,), 'scale': (64,)}, }, 'block_1_0': { 'conv_1': {'kernel': (1, 1, 64, 32)}, 'conv_2': {'kernel': (3, 3, 32, 128)}, 'conv_3': {'kernel': (1, 1, 128, 128)}, 'bn_1': {'bias': (32,), 'scale': (32,)}, 'bn_2': {'bias': (128,), 'scale': (128,)}, 'bn_3': {'bias': (128,), 'scale': (128,)}, 'proj_conv': {'kernel': (1, 1, 64, 128)}, 'proj_bn': {'bias': (128,), 'scale': (128,)}, }, 'block_1_1': { 'conv_1': {'kernel': (1, 1, 128, 32)}, 'conv_2': {'kernel': (3, 3, 32, 128)}, 'conv_3': {'kernel': (1, 1, 128, 128)}, 'bn_1': {'bias': (32,), 'scale': (32,)}, 'bn_2': {'bias': (128,), 'scale': (128,)}, 'bn_3': {'bias': (128,), 'scale': (128,)}, }, }, ) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/core/design/core_scan_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax from absl.testing import absltest from jax import numpy as jnp from jax import random from flax.core import Array, Scope, init, lift, nn, unfreeze def mlp_scan(scope: Scope, xs: Array, share_params: bool = False): scope.variable('counter', 'i', jnp.zeros, ()) def body_fn(scope, c, x): counter = scope.variable('counter', 'i', jnp.zeros, ()) counter.value += 1 x = scope.child(nn.dense)(x, 1) return c, x if share_params: _, ys = lift.scan( body_fn, variable_carry='counter', variable_broadcast='params', split_rngs={'params': False}, )(scope, (), xs) else: _, ys = lift.scan( body_fn, variable_carry='counter', variable_axes={'params': 0}, split_rngs={'params': True}, )(scope, (), xs) # output layer return ys class ScanTest(absltest.TestCase): def test_scan_unshared_params(self): x = random.normal(random.key(0), (1, 4)) x = jnp.concatenate([x, x], 0) y, variables = init(mlp_scan)(random.key(1), x, share_params=False) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) self.assertEqual(variables['counter']['i'], 2) self.assertEqual( param_shapes, { 'dense_0': {'kernel': (2, 4, 1), 'bias': (2, 1)}, }, ) self.assertNotEqual(y[0], y[1]) k1, k2 = variables['params']['dense_0']['kernel'] self.assertFalse(jnp.allclose(k1, k2)) def test_scan_shared_params(self): x = random.normal(random.key(0), (1, 4)) x = jnp.concatenate([x, x], 0) y, variables = init(mlp_scan)(random.key(1), x, share_params=True) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) self.assertEqual(variables['counter']['i'], 2) self.assertEqual( param_shapes, { 'dense_0': {'kernel': (4, 1), 'bias': (1,)}, }, ) self.assertEqual(y[0], y[1]) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/core/design/core_tied_autoencoder_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass import jax from absl.testing import absltest from jax import numpy as jnp from jax import random from flax.core import init, lift, nn, unfreeze def transpose(fn): def trans(variables): return jax.tree_util.tree_map(lambda x: x.T, variables) return lift.map_variables( fn, 'params', map_in_fn=trans, map_out_fn=trans, mutable=True ) @dataclass class TiedAutoEncoder: latents: int features: int def __call__(self, scope, x): z = self.encode(scope, x) return self.decode(scope, z) def encode(self, scope, x): return nn.dense(scope, x, self.latents, bias=False) def decode(self, scope, z): return transpose(nn.dense)(scope, z, self.features, bias=False) class TiedAutoEncoderTest(absltest.TestCase): def test_tied_auto_encoder(self): ae = TiedAutoEncoder(latents=2, features=4) x = jnp.ones((1, ae.features)) x_r, variables = init(ae)(random.key(0), x) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) self.assertEqual( param_shapes, { 'kernel': (4, 2), }, ) self.assertEqual(x.shape, x_r.shape) def test_init_from_decoder(self): ae = TiedAutoEncoder(latents=2, features=4) z = jnp.ones((1, ae.latents)) x_r, variables = init(ae.decode)(random.key(0), z) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) self.assertEqual( param_shapes, { 'kernel': (4, 2), }, ) self.assertEqual(x_r.shape, (1, 4)) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/core/design/core_vmap_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Callable, Sequence import jax from absl.testing import absltest from jax import numpy as jnp from jax import random from flax.core import Array, Scope, init, lift, nn, unfreeze def mlp_vmap( scope: Scope, x: Array, sizes: Sequence[int] = (8, 1), act_fn: Callable[[Array], Array] = nn.relu, share_params: bool = False, ): if share_params: dense_vmap = lift.vmap( nn.dense, in_axes=(0, None), variable_axes={'params': None}, split_rngs={'params': False}, ) else: dense_vmap = lift.vmap( nn.dense, in_axes=(0, None), variable_axes={'params': 0}, split_rngs={'params': True}, ) # hidden layers for size in sizes[:-1]: x = scope.child(dense_vmap, prefix='hidden_')(x, size) x = act_fn(x) # output layer return scope.child(dense_vmap, 'out')(x, sizes[-1]) class VMapTest(absltest.TestCase): def test_vmap_shared(self): x = random.normal(random.key(0), (1, 4)) x = jnp.concatenate([x, x], 0) y, variables = init(mlp_vmap)(random.key(1), x, share_params=True) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) self.assertEqual( param_shapes, { 'hidden_0': {'kernel': (4, 8), 'bias': (8,)}, 'out': {'kernel': (8, 1), 'bias': (1,)}, }, ) self.assertEqual(y.shape, (2, 1)) self.assertTrue(jnp.allclose(y[0], y[1])) def test_vmap_unshared(self): x = random.normal(random.key(0), (1, 4)) x = jnp.concatenate([x, x], 0) y, variables = init(mlp_vmap)(random.key(1), x, share_params=False) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) self.assertEqual( param_shapes, { 'hidden_0': {'kernel': (2, 4, 8), 'bias': (2, 8)}, 'out': {'kernel': (2, 8, 1), 'bias': (2, 1)}, }, ) self.assertEqual(y.shape, (2, 1)) self.assertFalse(jnp.allclose(y[0], y[1])) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/core/design/core_weight_std_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from functools import partial from collections.abc import Sequence import jax from absl.testing import absltest from jax import numpy as jnp from jax import random from flax.core import Array, Scope, apply, init, lift, nn, unfreeze def weight_std(fn, kernel_name='kernel', eps=1e-8): def std(variables): params = variables['params'] assert kernel_name in params kernel = params[kernel_name] redux = tuple(range(kernel.ndim - 1)) norm = jnp.square(kernel).sum(redux, keepdims=True) std_kernel = kernel / jnp.sqrt(norm + eps) params[kernel_name] = std_kernel return variables # map_variables handles a few of nasty edge cases here... # the transformed kind will be immutable inside fn # this way we avoid lost mutations to param # map_variables also avoids accidental reuse of rngs # and it makes sure that other state is updated correctly (not twice during init!) return lift.map_variables(fn, 'params', std, init=True) def mlp(scope: Scope, x: Array, sizes: Sequence[int] = (8, 1)): std_dense = weight_std( partial(nn.dense, kernel_init=nn.initializers.normal(stddev=1e5)) ) for size in sizes[:-1]: x = scope.child(std_dense, prefix='hidden_')(x, size) return scope.child(nn.dense, 'out')(x, sizes[-1]) class WeightStdTest(absltest.TestCase): def test_weight_std(self): x = random.normal( random.key(0), ( 1, 4, ), ) y, variables = init(mlp)(random.key(1), x) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) self.assertEqual( param_shapes, { 'hidden_0': {'kernel': (4, 8), 'bias': (8,)}, 'out': {'kernel': (8, 1), 'bias': (1,)}, }, ) self.assertEqual(y.shape, (1, 1)) self.assertTrue(y.ravel() < 1.0) y2 = apply(mlp)(variables, x) self.assertTrue(jnp.allclose(y, y2)) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/cursor_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.struct.""" import dataclasses from typing import Any, NamedTuple import jax import jax.numpy as jnp import optax from absl.testing import absltest import flax import flax.linen as nn from flax.core import freeze from flax.cursor import AccessType, _traverse_tree, cursor from flax.errors import CursorFindError, TraverseTreeError from flax.training import train_state # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() class GenericTuple(NamedTuple): x: Any y: Any = None z: Any = None @dataclasses.dataclass class GenericDataClass: x: Any y: Any = None z: Any = None class CursorTest(absltest.TestCase): def test_repr(self): g = GenericTuple(1, 'a', (2, 'b')) c = cursor( {'a': {1: {(2, 3): 'z', 4: g, '6': (7, 8)}, 'b': [1, 2, 3]}, 'z': -1} ) self.assertEqual( repr(c), """Cursor( _obj={'a': {1: {(2, 3): 'z', 4: GenericTuple(x=1, y='a', z=(2, 'b')), '6': (7, 8)}, 'b': [1, 2, 3]}, 'z': -1}, _changes={} )""", ) # test overwriting c['z'] = -2 c['z'] = -3 c['a']['b'][1] = -2 c['a']['b'] = None # test deep mutation c['a'][1][4].x = (2, 4, 6) c['a'][1][4].z[0] = flax.core.freeze({'a': 1, 'b': {'c': 2, 'd': 3}}) self.assertEqual( repr(c), """Cursor( _obj={'a': {1: {(2, 3): 'z', 4: GenericTuple(x=1, y='a', z=(2, 'b')), '6': (7, 8)}, 'b': [1, 2, 3]}, 'z': -1}, _changes={ 'z': Cursor( _obj=-3, _changes={} ), 'a': Cursor( _obj={1: {(2, 3): 'z', 4: GenericTuple(x=1, y='a', z=(2, 'b')), '6': (7, 8)}, 'b': [1, 2, 3]}, _changes={ 'b': Cursor( _obj=None, _changes={} ), 1: Cursor( _obj={(2, 3): 'z', 4: GenericTuple(x=1, y='a', z=(2, 'b')), '6': (7, 8)}, _changes={ 4: Cursor( _obj=GenericTuple(x=1, y='a', z=(2, 'b')), _changes={ 'x': Cursor( _obj=(2, 4, 6), _changes={} ), 'z': Cursor( _obj=(2, 'b'), _changes={ 0: Cursor( _obj=FrozenDict({ a: 1, b: { c: 2, d: 3, }, }), _changes={} ) } ) } ) } ) } ) } )""", ) def test_magic_methods(self): def same_value(v1, v2): if isinstance(v1, tuple): return all([ jnp.all(jax.tree_util.tree_map(lambda x, y: x == y, e1, e2)) for e1, e2 in zip(v1, v2) ]) return jnp.all(jax.tree_util.tree_map(lambda x, y: x == y, v1, v2)) list_obj = [(1, 2), (3, 4)] for l, tuple_wrap in ((list_obj, lambda x: x), (tuple(list_obj), tuple)): c = cursor(l) # test __len__ self.assertTrue(same_value(len(c), len(l))) # test __iter__ for i, child_c in enumerate(c): child_c[1] += i + 1 self.assertEqual(c.build(), tuple_wrap([(1, 3), (3, 6)])) # test __reversed__ for i, child_c in enumerate(reversed(c)): child_c[1] += i + 1 self.assertEqual(c.build(), tuple_wrap([(1, 5), (3, 7)])) # test __iter__ error with self.assertRaisesRegex( NotImplementedError, '__iter__ method only implemented for tuples and lists, not type ", ): c = cursor({'a': 1, 'b': 2}) for key in c: c[key] *= -1 # test __iter__ error with self.assertRaisesRegex( NotImplementedError, '__reversed__ method only implemented for tuples and lists, not type' " ", ): c = cursor({'a': 1, 'b': 2}) for key in reversed(c): c[key] *= -1 for obj_value in (2, jnp.array([[1, -2], [3, 4]])): for c in ( cursor(obj_value), cursor([obj_value])[0], cursor((obj_value,))[0], cursor({0: obj_value})[0], cursor(flax.core.freeze({0: obj_value}))[0], cursor(GenericTuple(x=obj_value)).x, cursor(GenericDataClass(x=obj_value)).x, ): # test __neg__ self.assertTrue(same_value(-c, -obj_value)) # test __pos__ self.assertTrue(same_value(+c, +obj_value)) # test __abs__ self.assertTrue(same_value(abs(-c), abs(-obj_value))) # test __invert__ self.assertTrue(same_value(~c, ~obj_value)) # test __round__ self.assertTrue(same_value(round(c + 0.123), round(obj_value + 0.123))) self.assertTrue( same_value(round(c + 0.123, 2), round(obj_value + 0.123, 2)) ) for other_value in (3, jnp.array([[5, 6], [7, 8]])): # test __add__ self.assertTrue(same_value(c + other_value, obj_value + other_value)) # test __radd__ self.assertTrue(same_value(other_value + c, other_value + obj_value)) # test __sub__ self.assertTrue(same_value(c - other_value, obj_value - other_value)) # test __rsub__ self.assertTrue(same_value(other_value - c, other_value - obj_value)) # test __mul__ self.assertTrue(same_value(c * other_value, obj_value * other_value)) # test __rmul__ self.assertTrue(same_value(other_value * c, other_value * obj_value)) # test __truediv__ self.assertTrue(same_value(c / other_value, obj_value / other_value)) # test __rtruediv__ self.assertTrue(same_value(other_value / c, other_value / obj_value)) # test __floordiv__ self.assertTrue( same_value(c // other_value, obj_value // other_value) ) # test __rfloordiv__ self.assertTrue( same_value(other_value // c, other_value // obj_value) ) # test __mod__ self.assertTrue(same_value(c % other_value, obj_value % other_value)) # test __rmod__ self.assertTrue(same_value(other_value % c, other_value % obj_value)) # test __divmod__ self.assertTrue( same_value(divmod(c, other_value), divmod(obj_value, other_value)) ) # test __rdivmod__ self.assertTrue( same_value(divmod(other_value, c), divmod(other_value, obj_value)) ) # test __pow__ self.assertTrue( same_value(pow(c, other_value), pow(obj_value, other_value)) ) # test __rpow__ self.assertTrue( same_value(pow(other_value, c), pow(other_value, obj_value)) ) # test __lshift__ self.assertTrue( same_value(c << other_value, obj_value << other_value) ) # test __rlshift__ self.assertTrue( same_value(other_value << c, other_value << obj_value) ) # test __rshift__ self.assertTrue( same_value(c >> other_value, obj_value >> other_value) ) # test __rrshift__ self.assertTrue( same_value(other_value >> c, other_value >> obj_value) ) # test __and__ self.assertTrue(same_value(c & other_value, obj_value & other_value)) # test __rand__ self.assertTrue(same_value(other_value & c, other_value & obj_value)) # test __xor__ self.assertTrue(same_value(c ^ other_value, obj_value ^ other_value)) # test __rxor__ self.assertTrue(same_value(other_value ^ c, other_value ^ obj_value)) # test __or__ self.assertTrue(same_value(c | other_value, obj_value | other_value)) # test __ror__ self.assertTrue(same_value(other_value | c, other_value | obj_value)) if isinstance(obj_value, jax.Array) and isinstance( other_value, jax.Array ): # test __matmul__ self.assertTrue( same_value(c @ other_value, obj_value @ other_value) ) # test __rmatmul__ self.assertTrue( same_value(other_value @ c, other_value @ obj_value) ) # test __lt__ self.assertTrue(same_value(c < other_value, obj_value < other_value)) self.assertTrue(same_value(other_value < c, other_value < obj_value)) # test __le__ self.assertTrue( same_value(c <= other_value, obj_value <= other_value) ) self.assertTrue( same_value(other_value <= c, other_value <= obj_value) ) # test __eq__ self.assertTrue( same_value(c == other_value, obj_value == other_value) ) self.assertTrue( same_value(other_value == c, other_value == obj_value) ) # test __ne__ self.assertTrue( same_value(c != other_value, obj_value != other_value) ) self.assertTrue( same_value(other_value != c, other_value != obj_value) ) # test __gt__ self.assertTrue(same_value(c > other_value, obj_value > other_value)) self.assertTrue(same_value(other_value > c, other_value > obj_value)) # test __ge__ self.assertTrue( same_value(c >= other_value, obj_value >= other_value) ) self.assertTrue( same_value(other_value >= c, other_value >= obj_value) ) def test_path(self): c = cursor( GenericTuple( x=[ 0, {'a': 1, 'b': (2, 3), ('c', 'd'): [4, 5]}, (100, 200), [3, 4, 5], ], y=train_state.TrainState.create( apply_fn=lambda x: x, params=freeze({'a': 1, 'b': (2, 3), 'c': [4, 5]}), tx=optax.adam(1e-3), ), ) ) self.assertEqual(c.x[1][('c', 'd')][0]._path, ".x[1][('c', 'd')][0]") self.assertEqual(c.x[2][1]._path, '.x[2][1]') self.assertEqual(c.y.params['b'][1]._path, ".y.params['b'][1]") # test path when first access type is item access c = cursor([1, GenericTuple('a', 2), (3, 4)]) self.assertEqual(c[1].x._path, '[1].x') self.assertEqual(c[2][0]._path, '[2][0]') def test_traverse_tree(self): c = cursor( GenericTuple( x=[ 0, {'a': 1, 'b': (2, 3), ('c', 'd'): [4, 5]}, (100, 200), [3, 4, 5], ], y=3, ) ) def update_fn(path, value): if value == 4: return -4 return value def cond_fn(path, value): return value == 3 with self.assertRaisesRegex( TraverseTreeError, 'Both update_fn and cond_fn are None. Exactly one of them must be' ' None.', ): next(_traverse_tree((), c._obj)) with self.assertRaisesRegex( TraverseTreeError, 'Both update_fn and cond_fn are not None. Exactly one of them must be' ' not None.', ): next(_traverse_tree((), c._obj, update_fn=update_fn, cond_fn=cond_fn)) (p, v), (p2, v2) = _traverse_tree((), c._obj, update_fn=update_fn) self.assertEqual( p, ( ('x', AccessType.ATTR), (1, AccessType.ITEM), (('c', 'd'), AccessType.ITEM), (0, AccessType.ITEM), ), ) self.assertEqual(v, -4) self.assertEqual( p2, (('x', AccessType.ATTR), (3, AccessType.ITEM), (1, AccessType.ITEM)) ) self.assertEqual(v2, -4) p, p2, p3 = _traverse_tree((), c._obj, cond_fn=cond_fn) self.assertEqual( p, ( ('x', AccessType.ATTR), (1, AccessType.ITEM), ('b', AccessType.ITEM), (1, AccessType.ITEM), ), ) self.assertEqual( p2, (('x', AccessType.ATTR), (3, AccessType.ITEM), (0, AccessType.ITEM)) ) self.assertEqual(p3, (('y', AccessType.ATTR),)) def test_set_and_build(self): # test regular dict and FrozenDict dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]} for d, freeze_wrap in ((dict_obj, lambda x: x), (freeze(dict_obj), freeze)): # set API self.assertEqual( cursor(d)['b'][0].set(10), freeze_wrap({'a': 1, 'b': (10, 3), 'c': [4, 5]}), ) # build API c = cursor(d) c['b'][0] = 20 c['a'] = (100, 200) d2 = c.build() self.assertEqual( d2, freeze_wrap({'a': (100, 200), 'b': (20, 3), 'c': [4, 5]}) ) self.assertEqual( dict_obj, {'a': 1, 'b': (2, 3), 'c': [4, 5]} ) # make sure original object is unchanged # test list and tuple list_obj = [0, dict_obj, (1, 2), [3, 4, 5]] for l, tuple_wrap in ((list_obj, lambda x: x), (tuple(list_obj), tuple)): # set API self.assertEqual( cursor(l)[1]['b'][0].set(10), tuple_wrap([0, {'a': 1, 'b': (10, 3), 'c': [4, 5]}, (1, 2), [3, 4, 5]]), ) # build API c = cursor(l) c[1]['b'][0] = 20 c[2] = (100, 200) l2 = c.build() self.assertEqual( l2, tuple_wrap( [0, {'a': 1, 'b': (20, 3), 'c': [4, 5]}, (100, 200), [3, 4, 5]] ), ) self.assertEqual( list_obj, [0, {'a': 1, 'b': (2, 3), 'c': [4, 5]}, (1, 2), [3, 4, 5]] ) # make sure original object is unchanged # test TrainState state = train_state.TrainState.create( apply_fn=lambda x: x, params=dict_obj, tx=optax.adam(1e-3), ) # set API self.assertEqual( cursor(state).params['b'][0].set(10).params, {'a': 1, 'b': (10, 3), 'c': [4, 5]}, ) # build API new_fn = lambda x: x + 1 c = cursor(state) c.apply_fn = new_fn c.params['b'][0] = 20 c.params['a'] = (100, 200) state2 = c.build() self.assertEqual(state2.apply_fn, new_fn) self.assertEqual( state2.params, {'a': (100, 200), 'b': (20, 3), 'c': [4, 5]} ) self.assertEqual( dict_obj, {'a': 1, 'b': (2, 3), 'c': [4, 5]} ) # make sure original object is unchanged # test NamedTuple # set API t = GenericTuple(GenericTuple(0)) self.assertEqual(cursor(t).x.x.set(1), GenericTuple(GenericTuple(1))) # build API c = cursor(t) c.x.x = 2 c.x.y = 3 c.y = 4 t2 = c.build() self.assertEqual(t2, GenericTuple(GenericTuple(2, 3), 4)) self.assertEqual( t, GenericTuple(GenericTuple(0)) ) # make sure original object is unchanged def test_apply_update(self): # test list and tuple def update_fn(path, value): """Multiply the first element of all leaf nodes of the pytree by -1.""" if path[-1] == '0' and isinstance(value, int): return value * -1 return value for tuple_wrap in (lambda x: x, tuple): l = tuple_wrap([tuple_wrap([1, 2]), tuple_wrap([3, 4])]) c = cursor(l) l2 = c.apply_update(update_fn).build() self.assertEqual( l2, tuple_wrap([tuple_wrap([-1, 2]), tuple_wrap([-3, 4])]) ) self.assertEqual( l, tuple_wrap([tuple_wrap([1, 2]), tuple_wrap([3, 4])]) ) # make sure the original object is unchanged # test regular dict and FrozenDict def update_fn(path, value): """Multiply all dense kernel params by 2 and add 1. Subtract the Dense_1 bias param by 1.""" if 'kernel' in path: return value * 2 + 1 elif 'Dense_1' in path and 'bias' in path: return value - 1 return value class Model(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense(3)(x) x = nn.relu(x) x = nn.Dense(3)(x) x = nn.relu(x) x = nn.Dense(3)(x) x = nn.relu(x) return x for freeze_wrap in (lambda x: x, freeze): params = freeze_wrap( Model().init(jax.random.key(0), jnp.empty((1, 2)))['params'] ) c = cursor(params) params2 = c.apply_update(update_fn).build() for layer in ('Dense_0', 'Dense_1', 'Dense_2'): self.assertTrue( (params2[layer]['kernel'] == 2 * params[layer]['kernel'] + 1).all() ) if layer == 'Dense_1': self.assertTrue( (params2[layer]['bias'] == jnp.array([-1, -1, -1])).all() ) else: self.assertTrue( (params2[layer]['bias'] == params[layer]['bias']).all() ) self.assertTrue( jax.tree_util.tree_all( jax.tree_util.tree_map( lambda x, y: (x == y).all(), params, freeze_wrap( Model().init(jax.random.key(0), jnp.empty((1, 2)))['params'] ), ) ) ) # make sure original params are unchanged # test TrainState def update_fn(path, value): """Replace params with empty dictionary.""" if 'params' in path: return {} return value state = train_state.TrainState.create( apply_fn=lambda x: x, params={'a': 1, 'b': 2}, tx=optax.adam(1e-3), ) c = cursor(state) state2 = c.apply_update(update_fn).build() self.assertEqual(state2.params, {}) self.assertEqual( state.params, {'a': 1, 'b': 2} ) # make sure original params are unchanged # test NamedTuple def update_fn(path, value): """Add 5 to all x-attribute values that are ints.""" if path[-1] == 'x' and isinstance(value, int): return value + 5 return value t = GenericTuple(GenericTuple(0, 1), GenericTuple(2, 3)) c = cursor(t) t2 = c.apply_update(update_fn).build() self.assertEqual(t2, GenericTuple(GenericTuple(5, 1), GenericTuple(7, 3))) self.assertEqual( t, GenericTuple(GenericTuple(0, 1), GenericTuple(2, 3)) ) # make sure original object is unchanged def test_apply_update_root_node_unmodified(self): def update_fn(path, value): if isinstance(value, list): value = value.copy() value.append(-1) return value l = [[1, 2], [3, 4], 5] l2 = cursor(l).apply_update(update_fn).build() self.assertEqual(l2, [[1, 2, -1], [3, 4, -1], 5]) def test_multi_modify(self): d = {'a': 1, 'b': (2, 3), 'c': [4, 5]} c = cursor(d) # test multiple changes on same element c['b'][0] = 6 c['b'][0] = 7 d2 = c.build() self.assertEqual(d2, {'a': 1, 'b': (7, 3), 'c': [4, 5]}) # test nested changes c['a'] = (100, 200) c['a'][1] = -1 d3 = c.build() self.assertEqual(d3, {'a': (100, -1), 'b': (7, 3), 'c': [4, 5]}) def test_hidden_change(self): # test list l = [1, 2] c = cursor(l) c[0] = 100 l[1] = -1 l2 = c.build() self.assertEqual( l2, [100, -1] ) # change in l affects l2 (this is expected behavior) self.assertEqual(l, [1, -1]) # test regular dict d = {'a': 1, 'b': 2} c = cursor(d) c['a'] = 100 d['b'] = -1 d2 = c.build() self.assertEqual( d2, {'a': 100, 'b': -1} ) # change in d affects d2 (this is expected behavior) self.assertEqual(d, {'a': 1, 'b': -1}) # test TrainState params = {'a': 1, 'b': 2} state = train_state.TrainState.create( apply_fn=lambda x: x, params=params, tx=optax.adam(1e-3), ) c = cursor(state) c.params['a'] = 100 params['b'] = -1 state2 = c.build() self.assertEqual( state2.params, {'a': 100, 'b': -1} ) # change in state affects state2 (this is expected behavior) self.assertEqual(state.params, {'a': 1, 'b': -1}) def test_named_tuple_multi_access(self): t = GenericTuple(GenericTuple(0, 1), GenericTuple(2, 3)) c = cursor(t) c.x.x = 4 c[0].y = 5 c.y.x = 6 c.y[1] = 7 self.assertEqual( c.build(), GenericTuple(GenericTuple(4, 5), GenericTuple(6, 7)) ) c[0][1] = -5 self.assertEqual( c.build(), GenericTuple(GenericTuple(4, -5), GenericTuple(6, 7)) ) c.x[1] = -6 self.assertEqual( c.build(), GenericTuple(GenericTuple(4, -6), GenericTuple(6, 7)) ) c.x.y = -7 self.assertEqual( c.build(), GenericTuple(GenericTuple(4, -7), GenericTuple(6, 7)) ) c[0].y = -8 self.assertEqual( c.build(), GenericTuple(GenericTuple(4, -8), GenericTuple(6, 7)) ) def test_find(self): c = cursor( GenericTuple( x=[ 0, {'a': 1, 'b': (2, 3), ('c', 'd'): [4, 5]}, (100, 200), [3, 4, 5], ], y=train_state.TrainState.create( apply_fn=lambda x: x, params=freeze({'a': 1, 'b': (2, 3), 'c': [4, 5]}), tx=optax.adam(1e-3), ), ) ) with self.assertRaisesRegex( CursorFindError, 'More than one object found given the conditions of the cond_fn\\. ' 'The first two objects found have the following paths: ' "\\.x\\[1]\\['b'] and \\.y\\.params\\['b'] ", ): c.find(lambda path, value: 'b' in path and isinstance(value, tuple)) with self.assertRaisesRegex( CursorFindError, 'No object found given the conditions of the cond_fn\\.', ): c.find(lambda path, value: 'b' in path and isinstance(value, str)) self.assertEqual( c.find(lambda path, value: path.endswith('params/b'))[1].set(30).y.params, freeze({'a': 1, 'b': (2, 30), 'c': [4, 5]}), ) def test_find_all(self): # test list and tuple def cond_fn(path, value): """Get all lists that are not the first element in its parent.""" return path[-1] != '0' and isinstance(value, (tuple, list)) for tuple_wrap in (lambda x: x, tuple): l = tuple_wrap( [tuple_wrap([1, 2]), tuple_wrap([3, 4]), tuple_wrap([5, 6])] ) c = cursor(l) c2, c3 = c.find_all(cond_fn) c2[0] *= -1 c3[1] *= -2 self.assertEqual( c.build(), tuple_wrap( [tuple_wrap([1, 2]), tuple_wrap([-3, 4]), tuple_wrap([5, -12])] ), ) self.assertEqual( l, tuple_wrap( [tuple_wrap([1, 2]), tuple_wrap([3, 4]), tuple_wrap([5, 6])] ), ) # make sure the original object is unchanged # test regular dict and FrozenDict def cond_fn(path, value): """Get the second and third dense params.""" return 'Dense_1' in path or 'Dense_2' in path class Model(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense(3)(x) x = nn.relu(x) x = nn.Dense(3)(x) x = nn.relu(x) x = nn.Dense(3)(x) x = nn.relu(x) return x for freeze_wrap in (lambda x: x, freeze): params = freeze_wrap( Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params'] ) c = cursor(params) for i, c2 in enumerate(c.find_all(cond_fn)): self.assertEqual( c2['kernel'].set(123)[f'Dense_{i+1}'], freeze_wrap({'kernel': 123, 'bias': params[f'Dense_{i+1}']['bias']}), ) self.assertTrue( jax.tree_util.tree_all( jax.tree_util.tree_map( lambda x, y: (x == y).all(), params, freeze_wrap( Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params'] ), ) ) ) # make sure original params are unchanged # test TrainState def cond_fn(path, value): """Find TrainState params.""" return 'params' in path state = train_state.TrainState.create( apply_fn=lambda x: x, params={'a': 1, 'b': 2}, tx=optax.adam(1e-3), ) c = cursor(state) c2 = list(c.find_all(cond_fn)) self.assertEqual(len(c2), 1) c2 = c2[0] self.assertEqual(c2['b'].set(-1).params, {'a': 1, 'b': -1}) self.assertEqual( state.params, {'a': 1, 'b': 2} ) # make sure original params are unchanged # test NamedTuple def cond_fn(path, value): """Get all GenericTuples that have int x-attribute values.""" return isinstance(value, GenericTuple) and isinstance(value.x, int) t = GenericTuple( GenericTuple(0, 'a'), GenericTuple(1, 'b'), GenericTuple('c', 2) ) c = cursor(t) c2, c3 = c.find_all(cond_fn) c2.x += 5 c3.x += 6 self.assertEqual( c.build(), GenericTuple( GenericTuple(5, 'a'), GenericTuple(7, 'b'), GenericTuple('c', 2) ), ) self.assertEqual( t, GenericTuple( GenericTuple(0, 'a'), GenericTuple(1, 'b'), GenericTuple('c', 2) ), ) # make sure original object is unchanged if __name__ == '__main__': absltest.main() ================================================ FILE: tests/download_dataset_metadata.sh ================================================ #!/bin/bash # If you get an error like: # Cloning into 'datasets'... # fatal: cannot change to 'https://github.com/tensorflow/datasets/': No such file or directory # error: failed to initialize sparse-checkout # This mean your git version is outdated. Just update it. set -e # Download TFDS metadata to flax/.tfds/metadata directory. # This allows the tests to specify the `data_dir` when using tfds.testing.mock_data(). cd "$( dirname "$0" )" if [ -d "../.tfds/metadata" ]; then echo 'TFDS metadata already exists.'; else echo 'TFDS metadata does not exist. Downloading...'; git clone --branch v4.8.2 --depth 3 --filter=blob:none --sparse https://github.com/tensorflow/datasets/ cd datasets git sparse-checkout set tensorflow_datasets/testing/metadata mkdir ../../.tfds mv tensorflow_datasets/testing/metadata/ ../../.tfds/metadata/ cd .. rm -rf datasets fi ================================================ FILE: tests/early_stopping_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.training.early_stopping.""" import jax from absl.testing import absltest from flax.training import early_stopping # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() class EarlyStoppingTests(absltest.TestCase): def test_update(self): es = early_stopping.EarlyStopping(min_delta=0, patience=0) for i in range(2): improve_steps = 0 for step in range(10): metric = 1.0 es = es.update(metric) if not es.has_improved: improve_steps += 1 if es.should_stop: break self.assertEqual(improve_steps, 1) self.assertEqual(step, 1) es = es.reset() # ensure object is reusable if reset. def test_patience(self): es = early_stopping.EarlyStopping(min_delta=0, patience=0) patient_es = early_stopping.EarlyStopping(min_delta=0, patience=6) for step in range(10): metric = 1.0 es = es.update(metric) if es.should_stop: break self.assertEqual(step, 1) for patient_step in range(10): metric = 1.0 patient_es = patient_es.update(metric) if patient_es.should_stop: break self.assertEqual(patient_step, 7) def test_delta(self): es = early_stopping.EarlyStopping(min_delta=0, patience=0) delta_es = early_stopping.EarlyStopping(min_delta=1e-3, patience=0) delta_patient_es = early_stopping.EarlyStopping(min_delta=1e-3, patience=1) metric = 1.0 for step in range(100): metric -= 1e-4 es = es.update(metric) if es.should_stop: break self.assertEqual(step, 99) metric = 1.0 for step in range(100): metric -= 1e-4 delta_es = delta_es.update(metric) if delta_es.should_stop: break self.assertEqual(step, 1) metrics = [ 0.01, 0.005, 0.0033, 0.0025, 0.002, 0.0017, 0.0014, 0.0012, 0.0011, 0.001, ] improvement_steps = 0 for step in range(10): metric = metrics[step] delta_patient_es = delta_patient_es.update(metric) if delta_patient_es.has_improved: improvement_steps += 1 if delta_patient_es.should_stop: break self.assertEqual(improvement_steps, 4) # steps 0, 1, 2, 4 self.assertEqual(step, 6) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/flaxlib_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # TODO: Re-enable this test after setting up CI build for flaxlib CC. # from absl.testing import absltest # import flaxlib # class TestFlaxlib(absltest.TestCase): # def test_flaxlib(self): # self.assertEqual(flaxlib.sum_as_string(1, 2), '3') ================================================ FILE: tests/import_test.ipynb ================================================ { "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Test Import in Colab\n", "\n", "\"Run all\" to test that all the Flax imports work in head.\n", "\n", "Change runtime type as needed." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Colab runtimes are pre-built with JAX/Flax:\n", "!pip freeze | egrep 'jax|flax'" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "skip-execution" ] }, "outputs": [], "source": [ "# Install from head\n", "!pip install git+https://github.com/google/flax.git" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Check versions after installing Flax from Github:\n", "!pip freeze | egrep 'jax|flax'" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Verify we can import everything.\n", "import flax\n", "from flax.training import (checkpoints, dynamic_scale, early_stopping, lr_schedule,\n", " orbax_utils, prefetch_iterator, train_state, common_utils)\n", "from flax.metrics import tensorboard" ] } ], "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.15" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: tests/io_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.io.""" import os import tempfile import jax import tensorflow as tf from absl.testing import absltest, parameterized from flax import errors, io # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() class IOTest(parameterized.TestCase): @parameterized.parameters( {'backend_mode': io.BackendMode.DEFAULT}, {'backend_mode': io.BackendMode.TF}, ) def test_override(self, backend_mode): with io.override_mode(backend_mode): self.assertEqual(io.io_mode, backend_mode) @parameterized.parameters( {'write_mode': io.BackendMode.DEFAULT, 'read_mode': io.BackendMode.TF}, {'write_mode': io.BackendMode.TF, 'read_mode': io.BackendMode.DEFAULT}, ) def test_GFile(self, write_mode, read_mode): test_string = b'testing write and read' with tempfile.TemporaryDirectory() as temp_dir_path: test_path = os.path.join(temp_dir_path, 'test') with io.override_mode(write_mode): with io.GFile(test_path, 'wb') as file: file.write(test_string) with io.override_mode(read_mode): with io.GFile(test_path, 'rb') as file: self.assertEqual(file.read(), test_string) def test_listdir(self): with tempfile.TemporaryDirectory() as temp_dir_path: os.mkdir(os.path.join(temp_dir_path, 'a')) os.mkdir(os.path.join(temp_dir_path, 'as')) os.mkdir(os.path.join(temp_dir_path, 'af')) os.mkdir(os.path.join(temp_dir_path, 'test')) os.mkdir(os.path.join(temp_dir_path, 'at')) with io.override_mode(io.BackendMode.DEFAULT): default_dir_set = set(io.listdir(temp_dir_path)) with io.override_mode(io.BackendMode.TF): tf_dir_set = set(io.listdir(temp_dir_path)) self.assertEqual(default_dir_set, tf_dir_set) @parameterized.parameters( {'create_temp_fn': tempfile.TemporaryDirectory}, {'create_temp_fn': tempfile.NamedTemporaryFile}, ) def test_isdir(self, create_temp_fn): with create_temp_fn() as temp: path = temp.name if hasattr(temp, 'name') else temp with io.override_mode(io.BackendMode.DEFAULT): default_isdir = io.isdir(path) with io.override_mode(io.BackendMode.TF): tf_isdir = io.isdir(path) self.assertEqual(default_isdir, tf_isdir) def test_copy(self): test_string = b'testing copy' with tempfile.TemporaryDirectory() as temp_dir_path: test_path = os.path.join(temp_dir_path, 'test') copy1_path = os.path.join(temp_dir_path, 'copy1') copy2_path = os.path.join(temp_dir_path, 'copy2') with io.GFile(test_path, 'wb') as file: file.write(test_string) with io.override_mode(io.BackendMode.DEFAULT): io.copy(test_path, copy1_path) with io.override_mode(io.BackendMode.TF): io.copy(copy1_path, copy2_path) with io.GFile(copy2_path, 'rb') as file: self.assertEqual(file.read(), test_string) @parameterized.parameters( { 'backend_mode': io.BackendMode.DEFAULT, 'error_type': errors.AlreadyExistsError, }, { 'backend_mode': io.BackendMode.TF, 'error_type': tf.errors.AlreadyExistsError, }, ) def test_copy_raises_error(self, backend_mode, error_type): with tempfile.NamedTemporaryFile() as temp_file: with io.override_mode(backend_mode): with self.assertRaises(error_type): io.copy(temp_file.name, temp_file.name) def test_rename(self): with tempfile.TemporaryDirectory() as temp_dir_path: test_path = os.path.join(temp_dir_path, 'test') rename1_path = os.path.join(temp_dir_path, 'rename1') rename2_path = os.path.join(temp_dir_path, 'rename2') with io.GFile(test_path, 'wb') as file: file.write(b'placeholder text') with io.override_mode(io.BackendMode.DEFAULT): io.rename(test_path, rename1_path) with io.override_mode(io.BackendMode.TF): io.rename(rename1_path, rename2_path) with io.GFile(rename2_path, 'rb') as file: self.assertTrue(os.path.exists(rename2_path)) @parameterized.parameters( { 'backend_mode': io.BackendMode.DEFAULT, 'error_type': errors.AlreadyExistsError, }, { 'backend_mode': io.BackendMode.TF, 'error_type': tf.errors.AlreadyExistsError, }, ) def test_rename_raises_error(self, backend_mode, error_type): with tempfile.NamedTemporaryFile() as temp_file: with io.override_mode(backend_mode): with self.assertRaises(error_type): io.rename(temp_file.name, temp_file.name) def test_exists(self): with tempfile.NamedTemporaryFile() as temp_file: with io.override_mode(io.BackendMode.DEFAULT): default_exists = io.exists(temp_file.name) with io.override_mode(io.BackendMode.TF): tf_exists = io.exists(temp_file.name) self.assertEqual(default_exists, tf_exists) @parameterized.parameters( {'backend_mode': io.BackendMode.DEFAULT}, {'backend_mode': io.BackendMode.TF}, ) def test_makedirs(self, backend_mode): with tempfile.TemporaryDirectory() as temp_dir_path: test_dir_path = os.path.join(temp_dir_path, 'test_dir') with io.override_mode(backend_mode): io.makedirs(test_dir_path) self.assertTrue( os.path.exists(test_dir_path) and (os.path.isdir(test_dir_path)) ) def test_glob(self): with tempfile.TemporaryDirectory() as temp_dir_path: os.mkdir(os.path.join(temp_dir_path, 'a')) os.mkdir(os.path.join(temp_dir_path, 'as')) os.mkdir(os.path.join(temp_dir_path, 'af')) os.mkdir(os.path.join(temp_dir_path, 'test')) os.mkdir(os.path.join(temp_dir_path, 'at')) with io.override_mode(io.BackendMode.DEFAULT): default_glob_set = set(io.glob('a*/')) with io.override_mode(io.BackendMode.TF): tf_glob_set = set(io.glob('a*/')) self.assertEqual(default_glob_set, tf_glob_set) @parameterized.parameters( {'backend_mode': io.BackendMode.DEFAULT}, {'backend_mode': io.BackendMode.TF}, ) def test_remove(self, backend_mode): with tempfile.TemporaryDirectory() as temp_dir_path: test_path = os.path.join(temp_dir_path, 'test') with io.GFile(test_path, 'wb') as file: file.write(b'placeholder text') with io.override_mode(backend_mode): io.remove(test_path) self.assertTrue(not os.path.exists(test_path)) @parameterized.parameters( {'backend_mode': io.BackendMode.DEFAULT}, {'backend_mode': io.BackendMode.TF}, ) def test_rmtree(self, backend_mode): with tempfile.TemporaryDirectory() as temp_dir_path: dir0_path = os.path.join(temp_dir_path, 'dir0') os.mkdir(dir0_path) os.mkdir(os.path.join(dir0_path, 'dir1')) os.mkdir(os.path.join(dir0_path, 'dir1', 'dir2')) os.mkdir(os.path.join(dir0_path, 'dir1', 'dir3')) os.mkdir(os.path.join(dir0_path, 'dir4')) os.mkdir(os.path.join(dir0_path, 'dir4', 'dir5')) os.mkdir(os.path.join(dir0_path, 'dir6')) with io.override_mode(backend_mode): io.rmtree(dir0_path) self.assertTrue(not os.path.exists(dir0_path)) @parameterized.parameters( {'backend_mode': io.BackendMode.DEFAULT}, {'backend_mode': io.BackendMode.TF}, ) def test_getsize(self, backend_mode): with tempfile.TemporaryDirectory() as temp_dir_path: test_path = os.path.join(temp_dir_path, 'test') content = b'placeholder text' with io.GFile(test_path, 'wb') as file: file.write(content) with io.override_mode(backend_mode): size = io.getsize(test_path) self.assertEqual(size, len(content)) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/jax_utils_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.jax_utils.""" from functools import partial import os os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4' from absl.testing import absltest from absl.testing import parameterized from flax import jax_utils import jax import jax.numpy as jnp import numpy as np NDEV = 4 def assert_max_traces(n): """Decorator to assert that a function is traced at most n times.""" from functools import wraps def decorator(fn): trace_count = {'count': 0} @wraps(fn) def wrapped(*args, **kwargs): trace_count['count'] += 1 if trace_count['count'] > n: raise AssertionError( f"Function was traced {trace_count['count']} times, " f"expected at most {n} traces" ) return fn(*args, **kwargs) wrapped.trace_count = trace_count return wrapped return decorator class PadShardUnpadTest(parameterized.TestCase): BATCH_SIZES = [NDEV, NDEV + 1, NDEV - 1, 5 * NDEV, 5 * NDEV + 1, 5 * NDEV - 1] DTYPES = [np.float32, np.uint8, jax.numpy.bfloat16, np.int32] @parameterized.product(dtype=DTYPES, bs=BATCH_SIZES) def test_basics(self, dtype, bs): # Just tests that basic calling works without exploring caveats. @partial(jax_utils.pad_shard_unpad, static_argnums=()) def add(a, b): b = jnp.asarray(b, dtype=dtype) return a + b x = np.arange(bs, dtype=dtype) y = add(x, 10 * x) self.assertEqual(y.dtype, x.dtype) np.testing.assert_allclose(np.float64(y), np.float64(x + 10 * x)) @parameterized.product(dtype=DTYPES, bs=BATCH_SIZES) def test_trees(self, dtype, bs): # Just tests that basic calling works without exploring caveats. @partial(jax_utils.pad_shard_unpad, static_argnums=()) def add(a, b): return a['a'] + b[0] x = jnp.arange(bs, dtype=dtype) y = add(dict(a=x), (10 * x,)) self.assertEqual(y.dtype, x.dtype) np.testing.assert_allclose(np.float64(y), np.float64(x + 10 * x)) @parameterized.parameters(DTYPES) def test_min_device_batch_avoids_recompile(self, dtype): @partial(jax_utils.pad_shard_unpad, static_argnums=()) @jax.jit @assert_max_traces(n=1) def add(a, b): b = jnp.asarray(b, dtype=dtype) return a + b for bs in self.BATCH_SIZES: x = jnp.arange(bs, dtype=dtype) y = add(x, 10 * x, min_device_batch=9) # pylint: disable=unexpected-keyword-arg self.assertEqual(y.dtype, x.dtype) np.testing.assert_allclose(np.float64(y), np.float64(x + 10 * x)) @parameterized.product(dtype=DTYPES, bs=BATCH_SIZES) def test_static_argnum(self, dtype, bs): @partial(jax_utils.pad_shard_unpad, static_argnums=(1,)) def add(a, b): return a + jnp.asarray(b, dtype=dtype) x = jnp.arange(bs, dtype=dtype) y = add(x, 10) self.assertEqual(y.dtype, x.dtype) np.testing.assert_allclose(np.float64(y), np.float64(x + 10)) @parameterized.product(dtype=DTYPES, bs=BATCH_SIZES) def test_static_argnames(self, dtype, bs): # In this test, leave static_argnums at the default value too, in order to # test the default/most canonical path where `params` are the first arg. @partial(jax_utils.pad_shard_unpad, static_argnames=('b',)) def add(params, a, *, b): params = jnp.asarray(params, dtype=dtype) b = jnp.asarray(b, dtype=dtype) return params * a + b x = jnp.arange(bs, dtype=dtype) y = add(5, x, b=10) self.assertEqual(y.dtype, x.dtype) np.testing.assert_allclose(np.float64(y), np.float64(5 * x + 10)) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/linen/initializers_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.linen.initializers.""" import jax import jax.numpy as jnp import numpy as np from absl.testing import absltest, parameterized from jax import random from flax import linen as nn from flax.linen import initializers # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() class InitializersTest(parameterized.TestCase): @parameterized.parameters( { 'builder_fn': initializers.zeros_init, 'params_shape': (2, 3), 'expected_params': jnp.zeros((2, 3)), }, { 'builder_fn': initializers.ones_init, 'params_shape': (3, 2), 'expected_params': jnp.ones((3, 2)), }, ) def test_call_builder(self, builder_fn, params_shape, expected_params): params = builder_fn()(random.key(42), params_shape, jnp.float32) np.testing.assert_allclose(params, expected_params) @parameterized.parameters( { 'builder_fn': initializers.zeros_init, 'expected_params': jnp.zeros((2, 5)), }, { 'builder_fn': initializers.ones_init, 'expected_params': jnp.ones((2, 5)), }, ) def test_kernel_builder(self, builder_fn, expected_params): layer = nn.Dense(5, kernel_init=builder_fn()) params = layer.init(random.key(42), jnp.empty((3, 2)))['params'] np.testing.assert_allclose(params['kernel'], expected_params) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/linen/kw_only_dataclasses_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for kw_only_dataclasses.""" import dataclasses import inspect from absl.testing import absltest from flax.linen import kw_only_dataclasses class KwOnlyDataclassesTest(absltest.TestCase): def test_kwonly_args_moved_to_end(self): @kw_only_dataclasses.dataclass class TestClass: a: int = 1 b: int = kw_only_dataclasses.field(default=2, kw_only=True) c: int = 3 params = inspect.signature(TestClass.__init__).parameters self.assertEqual(list(params), ['self', 'a', 'c', 'b']) self.assertEqual(params['a'].default, 1) self.assertEqual(params['b'].default, 2) self.assertEqual(params['c'].default, 3) v1 = TestClass() self.assertDictEqual(dataclasses.asdict(v1), dict(a=1, b=2, c=3)) v2 = TestClass(b=20) self.assertDictEqual(dataclasses.asdict(v2), dict(a=1, b=20, c=3)) v3 = TestClass(1, 30) self.assertDictEqual(dataclasses.asdict(v3), dict(a=1, b=2, c=30)) def test_base_optional_subclass_required(self): @kw_only_dataclasses.dataclass class Parent: a: int = kw_only_dataclasses.field(default=2, kw_only=True) @kw_only_dataclasses.dataclass class Child(Parent): b: int child_params = inspect.signature(Child.__init__).parameters self.assertEqual(list(child_params), ['self', 'b', 'a']) self.assertEqual(child_params['a'].default, 2) self.assertEqual(child_params['b'].default, inspect.Parameter.empty) v1 = Child(4) self.assertDictEqual(dataclasses.asdict(v1), dict(a=2, b=4)) v2 = Child(4, a=5) # pylint: disable=too-many-function-args self.assertDictEqual(dataclasses.asdict(v2), dict(a=5, b=4)) def test_subclass_overrides_base(self): # Note: if a base class declares a field as keyword-only, then # subclasses don't need to also declare it as keyword-only. @kw_only_dataclasses.dataclass class A: x: int = kw_only_dataclasses.field(default=1, kw_only=True) @kw_only_dataclasses.dataclass class B(A): size: float y: int = kw_only_dataclasses.field(default=3, kw_only=True) x: int = 2 @kw_only_dataclasses.dataclass class C(B): name: str a_params = inspect.signature(A.__init__).parameters b_params = inspect.signature(B.__init__).parameters c_params = inspect.signature(C.__init__).parameters self.assertEqual(list(a_params), ['self', 'x']) self.assertEqual(list(b_params), ['self', 'size', 'x', 'y']) self.assertEqual(list(c_params), ['self', 'size', 'name', 'x', 'y']) self.assertEqual(a_params['x'].default, 1) self.assertEqual(b_params['x'].default, 2) self.assertEqual(b_params['y'].default, 3) self.assertEqual(b_params['size'].default, inspect.Parameter.empty) self.assertEqual(c_params['x'].default, 2) self.assertEqual(c_params['y'].default, 3) self.assertEqual(c_params['name'].default, inspect.Parameter.empty) self.assertEqual(c_params['size'].default, inspect.Parameter.empty) value = C(4, 'foo') # pylint: disable=too-many-function-args self.assertDictEqual( dataclasses.asdict(value), dict(name='foo', size=4, x=2, y=3) ) def test_kwonly_marker(self): @kw_only_dataclasses.dataclass class A: x: float _: kw_only_dataclasses.KW_ONLY a: int = 5 b: int = kw_only_dataclasses.field(default=2) c: int = kw_only_dataclasses.field(default=2, kw_only=True) @kw_only_dataclasses.dataclass class B(A): z: str a_params = inspect.signature(A.__init__).parameters b_params = inspect.signature(B.__init__).parameters self.assertEqual(list(a_params), ['self', 'x', 'a', 'b', 'c']) self.assertEqual(list(b_params), ['self', 'x', 'z', 'a', 'b', 'c']) def test_whatever(self): import abc from collections.abc import Iterator, Iterable import io from typing import Protocol, TypeVar T = TypeVar("T") class CheckpointableIterator(Iterator[T], Protocol[T]): pass isinstance(io.TextIOBase, Iterable) from flax import linen as nn class Steppable(metaclass=abc.ABCMeta): path: str class SequenceLayer(nn.Module, Steppable): pass if __name__ == '__main__': absltest.main() ================================================ FILE: tests/linen/linen_activation_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.linen.activation.""" from absl.testing import absltest from flax import linen as nn import jax from jax import random import jax.numpy as jnp import numpy as np # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() class ActivationTest(absltest.TestCase): def test_prelu(self): rng = random.key(0) key, skey_1, skey_2 = jax.random.split(rng, 3) x = jax.random.uniform(skey_1, (4, 6, 5)) - 0.5 act = nn.PReLU() y, params = act.init_with_output(skey_2, x) expected_y = jnp.where(x < 0, x * act.negative_slope_init, x) init_negative_slope = params['params']['negative_slope'] expected_negative_slope = jnp.array( act.negative_slope_init, dtype=jnp.float32 ) self.assertEqual(y.shape, x.shape) np.testing.assert_array_almost_equal(expected_y, y) np.testing.assert_array_equal(init_negative_slope, expected_negative_slope) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/linen/linen_attention_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.linen.attention.""" from absl.testing import absltest, parameterized from flax import errors, jax_utils from flax import linen as nn from flax.core import pop import jax from jax import lax, random from jax.nn import initializers import jax.numpy as jnp import numpy as np try: # JAX v0.8.0 and newer from jax import enable_x64 except ImportError: from jax.experimental import enable_x64 # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() class AttentionTest(parameterized.TestCase): def test_multihead_self_attention(self): rng = random.key(0) x = jnp.ones((4, 6, 5)) sa_module = nn.MultiHeadDotProductAttention( num_heads=8, qkv_features=16, kernel_init=initializers.ones, bias_init=initializers.zeros, deterministic=False, ) y, _ = sa_module.init_with_output(rng, x) self.assertEqual(y.shape, x.shape) self.assertEqual(y.dtype, jnp.float32) def test_dtype_infer(self): rng = random.key(0) x = jnp.ones((4, 6, 5), jnp.complex64) sa_module = nn.MultiHeadDotProductAttention( num_heads=8, qkv_features=16, kernel_init=initializers.ones, bias_init=initializers.zeros, deterministic=False, ) y, _ = sa_module.init_with_output(rng, x) self.assertEqual(y.shape, x.shape) self.assertEqual(y.dtype, jnp.complex64) def test_multihead_encoder_decoder_attention(self): rng = random.key(0) q = jnp.ones((4, 2, 3, 5)) sa_module = nn.MultiHeadDotProductAttention( num_heads=8, qkv_features=16, kernel_init=initializers.ones, bias_init=initializers.zeros, deterministic=False, ) y, _ = sa_module.init_with_output(rng, q) self.assertEqual(y.shape, q.shape) def test_mha_out_initializers(self): rng = random.key(0) q = jnp.ones((4, 2, 3, 5)) sa_module = nn.MultiHeadDotProductAttention( num_heads=8, qkv_features=16, kernel_init=initializers.ones, out_kernel_init=initializers.zeros, bias_init=initializers.zeros, out_bias_init=initializers.ones, deterministic=False, ) variables = sa_module.init(rng, q) params = variables['params'] # test kernels np.testing.assert_allclose(params['query']['kernel'], 1.0) np.testing.assert_allclose(params['key']['kernel'], 1.0) np.testing.assert_allclose(params['value']['kernel'], 1.0) np.testing.assert_allclose(params['out']['kernel'], 0.0) # test biases np.testing.assert_allclose(params['query']['bias'], 0.0) np.testing.assert_allclose(params['key']['bias'], 0.0) np.testing.assert_allclose(params['value']['bias'], 0.0) np.testing.assert_allclose(params['out']['bias'], 1.0) def test_multihead_self_attention_w_dropout(self): rng = random.key(0) x = jnp.ones((4, 2, 3, 5)) sa_module = nn.MultiHeadDotProductAttention( num_heads=8, qkv_features=16, kernel_init=initializers.ones, bias_init=initializers.zeros, dropout_rate=0.1, deterministic=False, ) rng1, rng2 = random.split(rng) rngs = {'params': rng1, 'dropout': rng2} y, _ = sa_module.init_with_output(rngs, x) self.assertEqual(y.shape, x.shape) def test_multihead_self_attention_explicit_dropout(self): def clone(key): return jax.tree.map(jax.random.clone, key) class Foo(nn.Module): attention_kwargs: dict @nn.compact def __call__(self, x, dropout_rng=None): a = nn.MultiHeadDotProductAttention(**self.attention_kwargs)( x, x, dropout_rng=dropout_rng ) if dropout_rng is not None: dropout_rng = clone(dropout_rng) b = nn.MultiHeadDotProductAttention(**self.attention_kwargs)( x, x, dropout_rng=dropout_rng ) return a, b module = Foo( dict( num_heads=8, qkv_features=16, kernel_init=initializers.ones, bias_init=initializers.zeros, dropout_rate=0.5, deterministic=False, ) ) rng1, rng2, rng3, rng4 = random.split(random.key(0), 4) x = jnp.ones((4, 2, 3, 5)) rngs = {'params': rng1, 'dropout': rng2} v = module.init(rngs, x) a, b = module.apply(v, x, rngs=clone(rngs)) c, d = module.apply(v, x, rngs={'dropout': clone(rng2)}) e, f = module.apply(v, x, rngs={'dropout': rng3}) self.assertFalse((a == b).all()) self.assertTrue((a == c).all()) self.assertTrue((b == d).all()) self.assertFalse((a == e).all()) self.assertFalse((b == f).all()) a, b = module.apply(v, x, rngs=clone(rngs), dropout_rng=rng4) self.assertTrue((a == b).all()) a, b = module.apply(v, x, dropout_rng=clone(rng4)) self.assertTrue((a == b).all()) self.assertTrue(a.shape == b.shape == x.shape) def test_multihead_self_attention_w_dropout_disabled(self): rng = random.key(0) x = jnp.ones((4, 2, 3, 5)) sa_module0 = nn.MultiHeadDotProductAttention( num_heads=8, qkv_features=16, kernel_init=initializers.ones, bias_init=initializers.zeros, dropout_rate=0.0, deterministic=True, ) rng1, rng2, rng3, rng4 = random.split(rng, 4) rngs1 = {'params': rng1, 'dropout': rng2} rngs2 = {'params': rng3, 'dropout': rng4} y1, vs = sa_module0.init_with_output(rngs1, x) y2, _ = sa_module0.init_with_output(rngs2, x) np.testing.assert_allclose(y1, y2) y3 = sa_module0.apply(vs, x, rngs=rngs1) y4 = sa_module0.apply(vs, x, rngs=rngs2) np.testing.assert_allclose(y3, y4) sa_module1 = nn.MultiHeadDotProductAttention( num_heads=8, qkv_features=16, kernel_init=initializers.ones, bias_init=initializers.zeros, dropout_rate=0.0, ) y5 = sa_module1.apply(vs, x, deterministic=True, rngs=rngs1) y6 = sa_module1.apply(vs, x, deterministic=True, rngs=rngs2) np.testing.assert_allclose(y5, y6) sa_module2 = nn.MultiHeadDotProductAttention( num_heads=8, qkv_features=16, kernel_init=initializers.ones, bias_init=initializers.zeros, dropout_rate=0.5, ) y7 = sa_module2.apply(vs, x, deterministic=True, rngs=rngs1) y8 = sa_module2.apply(vs, x, deterministic=True, rngs=rngs2) np.testing.assert_allclose(y7, y8) def test_causal_mask_1d(self): """Tests autoregressive masking for 1d attention.""" x = jnp.ones((3, 16)) # (bs1, length) mask_1d = nn.attention.make_causal_mask(x) ts = np.arange(16) mask_1d_simple = (ts[:, None] >= ts[None, :])[None, None, :, :] mask_1d_simple = jnp.broadcast_to(mask_1d_simple, (3, 1, 16, 16)) np.testing.assert_allclose( mask_1d, mask_1d_simple, ) @parameterized.parameters([((5,), (1,)), ((6, 5), (2,))]) def test_decoding(self, spatial_shape, attn_dims): bs = 2 num_heads = 3 num_features = 4 rng = random.key(0) key1, key2 = random.split(rng) inputs = random.normal( key1, (bs,) + spatial_shape + (num_heads * num_features,) ) module = nn.MultiHeadDotProductAttention( num_heads=num_heads, qkv_features=num_heads * num_features, precision=lax.Precision.HIGHEST, deterministic=False, decode=False, ) decode_module = module.clone(decode=True) initial_vars = decode_module.init(key2, inputs) state, params = pop(initial_vars, 'params') causal_mask = nn.attention.make_causal_mask(jnp.ones((bs,) + spatial_shape)) y_ref = jax.jit(lambda x, y: module.apply(initial_vars, x, mask=y))( inputs, causal_mask ) # feed the inputs sequentially to simulate decoding def body_fn(state, x): y, state = decode_module.apply( {'params': params, **state}, x, mutable=['cache'] ) return state, y # scan_in_dim supports scanning multiple dims _, y = jax_utils.scan_in_dim( body_fn, state, inputs, axis=attn_dims, keepdims=True ) np.testing.assert_allclose(y_ref, y, atol=1e-5) def test_autoregressive_receptive_field_1d(self): """Tests the autoregressive self-attention receptive field.""" rng = random.key(0) rng1, rng2 = random.split(rng, num=2) length = 10 dim = 1 num_heads = 1 input_shape = (1, length, dim) inputs = random.normal(rng2, input_shape) module = nn.MultiHeadDotProductAttention( num_heads=num_heads, kernel_init=jax.nn.initializers.ones, deterministic=False, ) initial_vars = module.init(rng1, inputs) causal_mask = nn.attention.make_causal_mask(jnp.ones(input_shape[:-1])) def model_loss(inputs, pos): out = module.apply(initial_vars, inputs, mask=causal_mask) assert out.shape == input_shape assert len(out.shape) == 3 return out[0, pos, :].sum() grad_fn = jax.jit(jax.grad(model_loss)) def get_receptive_field_1d(pos): g = grad_fn(inputs, pos)[0, :, :] return jnp.any((jnp.abs(g) > 1e-5).astype(jnp.uint32), axis=-1) for i in range(length): deps = get_receptive_field_1d(i) assert (deps[:i] == 1).all(), ( 'Receptive Field Error: Some of the ' 'previous postions are not reachable ' 'in autoregressive self-attention.' ) if i != length - 1: k = i + 1 assert (deps[k:] == 0).all(), ( 'Receptive Field Error: Some of the ' 'future postions are reachable in ' 'autoregressive self-attention.' ) def test_multihead_kv_args(self): key1, key2 = random.split(random.key(0), 2) query = random.uniform(key1, (3, 5)) key_value = random.uniform(key2, (9, 5)) module = nn.MultiHeadDotProductAttention( num_heads=8, qkv_features=16, kernel_init=initializers.ones, bias_init=initializers.zeros, deterministic=False, ) key = lambda: random.key(43279) y0, v0 = module.init_with_output( key(), query, inputs_k=key_value, inputs_v=key_value ) y1, v1 = module.init_with_output(key(), query, inputs_k=key_value) with self.assertWarnsRegex( DeprecationWarning, 'The inputs_kv arg will be deprecated soon.' ): y2, v2 = module.init_with_output(key(), query, inputs_kv=key_value) self.assertTrue((y0 == y1).all() and (y1 == y2).all()) self.assertTrue( jax.tree_util.tree_all( jax.tree_util.tree_map( lambda x, y, z: (x == y).all() and (y == z).all(), v0, v1, v2 ) ) ) with self.assertRaisesRegex( ValueError, '`inputs_k` cannot be None if `inputs_v` is not None.' ): y3, v3 = module.init_with_output(key(), query, inputs_v=key_value) with self.assertRaisesRegex( ValueError, 'If either `inputs_k` or `inputs_v` is not None, `inputs_kv` must be None.', ): y3, v3 = module.init_with_output( key(), query, inputs_kv=key_value, inputs_v=key_value ) with self.assertRaisesRegex( ValueError, 'If either `inputs_k` or `inputs_v` is not None, `inputs_kv` must be None.', ): y3, v3 = module.init_with_output( key(), query, key_value, key_value, inputs_kv=key_value ) def test_multihead_mask_warning(self): rng = random.key(0) rng1, rng2 = random.split(rng, num=2) length = 10 dim = 1 num_heads = 1 input_shape = (1, length, dim) query = key = random.normal(rng2, input_shape) module = nn.MultiHeadDotProductAttention( num_heads=num_heads, kernel_init=jax.nn.initializers.ones, deterministic=False, ) initial_vars = module.init(rng1, query, key) causal_mask = nn.attention.make_causal_mask(jnp.ones(input_shape[:-1])) module.apply(initial_vars, query, key, mask=causal_mask) with self.assertWarnsRegex( DeprecationWarning, "the function signature of MultiHeadDotProductAttention's `__call__` method has changed", ): with self.assertRaises(errors.ScopeParamShapeError): module.apply(initial_vars, query, key, causal_mask) def test_multihead_sow_attention_weights(self): rng = random.key(0) x = jnp.ones((4, 6, 5)) class Model(nn.Module): attention_kwargs: dict @nn.compact def __call__(self, x, sow_weights=False): x = nn.MultiHeadDotProductAttention(**self.attention_kwargs)( x, sow_weights=sow_weights ) x = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x) x = nn.MultiHeadDotProductAttention(**self.attention_kwargs)( x, sow_weights=sow_weights ) return x module = Model( dict( num_heads=8, qkv_features=16, kernel_init=initializers.ones, bias_init=initializers.zeros, deterministic=False, ) ) v = module.init(rng, x) _, intermediates = module.apply( v, x, mutable=['intermediates'], sow_weights=True ) self.assertEqual( intermediates['intermediates']['MultiHeadDotProductAttention_0'][ 'attention_weights' ][0].shape, (4, 8, 6, 6), ) self.assertNotIn( 'MultiHeadDotProductAttention_1', intermediates['intermediates'] ) self.assertEqual( intermediates['intermediates']['MultiHeadDotProductAttention_2'][ 'attention_weights' ][0].shape, (4, 8, 6, 6), ) _, intermediates = module.apply( v, x, mutable=['intermediates'], sow_weights=False ) self.assertNotIn('intermediates', intermediates) def test_autoregressive_decode_with_x64(self): with enable_x64(): x = jnp.ones((1, 4, 4)) module = nn.MultiHeadDotProductAttention( num_heads=2, qkv_features=4, decode=True ) rng = random.PRNGKey(0) variables = module.init(rng, x, x, x) params, cache = variables['params'], variables['cache'] y1, updates = module.apply( { 'params': params, 'cache': cache }, x[:, :1, :], mutable=['cache'] ) cache = updates['cache'] y2, updates = module.apply( { 'params': params, 'cache': cache }, x[:, 1:2, :], mutable=['cache'] ) assert y1.shape == (1, 1, 4) assert y2.shape == (1, 1, 4) def test_attention_alias_equivalence(self): key1, key2 = random.split(random.key(0), 2) query = random.uniform(key1, (3, 5)) key_value = random.uniform(key2, (9, 5)) attention_kwargs = dict( num_heads=8, qkv_features=16, kernel_init=initializers.lecun_normal(), bias_init=initializers.uniform(), deterministic=False, ) module1 = nn.MultiHeadDotProductAttention(**attention_kwargs) module2 = nn.MultiHeadAttention(**attention_kwargs) key = lambda: random.key(43279) out1, v1 = module1.init_with_output(key(), query, key_value) out2, v2 = module2.init_with_output(key(), query, key_value, key_value) self.assertTrue((out1 == out2).all()) self.assertTrue( jax.tree_util.tree_all( jax.tree_util.tree_map(lambda x, y: (x == y).all(), v1, v2) ) ) def test_attention_alias_submodule(self): key1, key2 = random.split(random.key(0), 2) query = random.uniform(key1, (3, 5)) key_value = random.uniform(key2, (9, 5)) attention_kwargs = dict( num_heads=8, qkv_features=16, kernel_init=initializers.lecun_normal(), bias_init=initializers.uniform(), deterministic=False, ) class Foo1(nn.Module): attention_kwargs: dict @nn.compact def __call__(self, query, key): return nn.MultiHeadDotProductAttention(**self.attention_kwargs)( query, key ) class Foo2(nn.Module): attention_kwargs: dict @nn.compact def __call__(self, query, key, value): return nn.MultiHeadAttention(**self.attention_kwargs)(query, key, value) key = lambda: random.key(5478392) module1 = Foo1(attention_kwargs) module2 = Foo2(attention_kwargs) out1, v1 = module1.init_with_output(key(), query, key_value) out2, v2 = module2.init_with_output(key(), query, key_value, key_value) # test different output and variables if layer names are different self.assertTrue((out1 != out2).all()) v2['params']['MultiHeadDotProductAttention_0'] = v2['params'][ 'MultiHeadAttention_0' ] del v2['params']['MultiHeadAttention_0'] self.assertTrue( jax.tree_util.tree_all( jax.tree_util.tree_map(lambda x, y: (x != y).all(), v1, v2) ) ) # test same output if variables are the same v2 = jax.tree_util.tree_map(lambda x: x, v1) v2['params']['MultiHeadAttention_0'] = v2['params'][ 'MultiHeadDotProductAttention_0' ] del v2['params']['MultiHeadDotProductAttention_0'] out2 = module2.apply(v2, query, key_value, key_value) self.assertTrue((out1 == out2).all()) # test same output and variables if names are the same class Foo2(nn.Module): attention_kwargs: dict @nn.compact def __call__(self, query, key, value): return nn.MultiHeadAttention( **self.attention_kwargs, name='MultiHeadDotProductAttention_0' )(query, key, value) module2 = Foo2(attention_kwargs) out2, v2 = module2.init_with_output(key(), query, key_value, key_value) self.assertTrue((out1 == out2).all()) self.assertTrue( jax.tree_util.tree_all( jax.tree_util.tree_map(lambda x, y: (x == y).all(), v1, v2) ) ) @parameterized.parameters( {'force_fp32': True, 'attn_weights_dtype': jnp.float32}, {'force_fp32': False, 'attn_weights_dtype': jnp.bfloat16}, ) def test_mixed_precision_multihead_attention( self, force_fp32, attn_weights_dtype ): input_key, params_key, dropout_key = random.split(random.key(0), 3) x = random.uniform(input_key, (2, 4)) attention_kwargs = dict( num_heads=2, qkv_features=4, kernel_init=initializers.lecun_normal(), bias_init=initializers.uniform(), force_fp32_for_softmax=force_fp32, deterministic=False, dtype=jnp.bfloat16, ) mha = nn.MultiHeadDotProductAttention(**attention_kwargs) init_vars = mha.init({'params': params_key, 'dropout': dropout_key}, x) _, updated_vars = mha.apply( init_vars, x, mutable=['intermediates'], sow_weights=True ) self.assertEqual( updated_vars['intermediates']['attention_weights'][0].dtype, attn_weights_dtype, ) @parameterized.parameters( (lax.Precision.DEFAULT, None), (None, jax.lax.dot_general), ) def test_dot_product_attention_precision_and_einsum_override( self, precision, einsum_dot_general ): # Test that we raise a ValueError if the user specifies both # `precision` and/or `einsum_dot_general` and `qk_attn_weights_einsum`. einsum_cls = lambda: jnp.einsum self.assertRaises( ValueError, nn.dot_product_attention, query=jnp.ones((1, 4, 2)), key=jnp.ones((1, 4, 2)), value=jnp.ones((1, 4, 2)), precision=precision, einsum_dot_general=einsum_dot_general, qk_attn_weights_einsum=einsum_cls, attn_weights_value_einsum=einsum_cls, ) @parameterized.parameters( (lambda: jax.lax.dot_general, None), (None, lambda: jax.lax.dot_general), ) def test_dot_product_attention_specify_einsums_together( self, qk_attn_weights_einsum, attn_weights_value_einsum ): # Test that we raise a ValueError if the user specifies only one of # `qk_attn_weights_einsum` and `attn_weights_value_einsum`. self.assertRaises( ValueError, nn.dot_product_attention, query=jnp.ones((1, 4, 2)), key=jnp.ones((1, 4, 2)), value=jnp.ones((1, 4, 2)), qk_attn_weights_einsum=qk_attn_weights_einsum, attn_weights_value_einsum=attn_weights_value_einsum, ) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/linen/linen_batch_apply_test.py ================================================ # Copyright 2023 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.linen.batch_apply.""" import jax import jax.numpy as jnp import numpy as np from absl.testing import absltest, parameterized from flax import linen as nn # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() class BatchApplyTest(parameterized.TestCase): @parameterized.parameters( {'fn': lambda a, b: a + b.reshape(1, -1)}, {'fn': lambda a, b: jnp.dot(a, b)}, ) def test_batchapply(self, fn): a = jax.random.normal(jax.random.key(0), [2, 3, 4]) b = jax.random.normal(jax.random.key(1), [4]) def raises(a, b): if len(a.shape) != 2: raise ValueError('a must be shape 2') if len(b.shape) != 1: raise ValueError('b must be shape 1') return fn(a, b) out = nn.BatchApply(raises)(a, b) expected_merged_leading = raises(a.reshape(2 * 3, 4), b) expected = expected_merged_leading.reshape( (2, 3) + expected_merged_leading.shape[1:] ) np.testing.assert_array_equal(out, expected) def test_batchapply_accepts_float(self): def raises(a, b): if len(a.shape) != 2: raise ValueError('a must be shape 2') return a + b out = nn.BatchApply(raises)(jnp.ones([2, 3, 4]), 2.0) np.testing.assert_array_equal(out, 3 * jnp.ones([2, 3, 4])) def test_batchapply_accepts_none(self): def raises(a, b): if a is not None: raise ValueError('a must be None.') if len(b.shape) != 2: raise ValueError('b must be shape 2') return 3 * b out = nn.BatchApply(raises)(None, jnp.ones([2, 3, 4])) np.testing.assert_array_equal(out, 3 * jnp.ones([2, 3, 4])) def test_batchapply_raises(self): with self.assertRaisesRegex(ValueError, 'requires at least one input'): nn.BatchApply(lambda: 1)() if __name__ == '__main__': absltest.main() ================================================ FILE: tests/linen/linen_combinators_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.linen.combinators.""" from typing import Any from collections.abc import Sequence import jax import numpy as np from absl.testing import absltest from jax import numpy as jnp from jax import random from flax import linen as nn # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() class MLP(nn.Module): layer_sizes: Sequence[int] activation: Any | None = None activation_final: Any | None = None @nn.compact def __call__(self, inputs): x = inputs for layer_size in self.layer_sizes[:-1]: x = nn.Dense( features=layer_size, kernel_init=nn.initializers.ones_init() )(x) if self.activation is not None: x = self.activation(x) x = nn.Dense( features=self.layer_sizes[-1], kernel_init=nn.initializers.ones_init() )(x) if self.activation_final is None: return x return self.activation_final(x) class AttentionTuple(nn.Module): num_heads: int = 2 qkv_features: int = 16 @nn.compact def __call__(self, query, key_value): output = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, qkv_features=self.qkv_features )(query, key_value) return output, key_value class AttentionDict(nn.Module): num_heads: int = 2 qkv_features: int = 16 @nn.compact def __call__(self, query, key_value): output = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, qkv_features=self.qkv_features )(query, key_value) return dict(query=output, key_value=key_value) class SequentialTest(absltest.TestCase): def test_construction(self): sequential = nn.Sequential([nn.Dense(4), nn.Dense(2)]) key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (3, 1, 5)) params = sequential.init(key2, x) output = sequential.apply(params, x) self.assertEqual(output.shape, (3, 1, 2)) def test_fails_if_layers_empty(self): sequential = nn.Sequential([]) with self.assertRaisesRegex(ValueError, 'Empty Sequential module'): sequential.init(random.key(42), jnp.ones((3, 5))) def test_same_output_as_mlp(self): sequential = nn.Sequential( [ nn.Dense(4, kernel_init=nn.initializers.ones_init()), nn.Dense(8, kernel_init=nn.initializers.ones_init()), nn.Dense(2, kernel_init=nn.initializers.ones_init()), ] ) mlp = MLP(layer_sizes=[4, 8, 2]) key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (3, 5)) params_1 = sequential.init(key2, x) params_2 = mlp.init(key2, x) output_1 = sequential.apply(params_1, x) output_2 = mlp.apply(params_2, x) np.testing.assert_array_equal(output_1, output_2) def test_same_output_as_mlp_with_activation(self): sequential = nn.Sequential( [ nn.Dense(4, kernel_init=nn.initializers.ones_init()), nn.relu, nn.Dense(8, kernel_init=nn.initializers.ones_init()), nn.relu, nn.Dense(2, kernel_init=nn.initializers.ones_init()), nn.log_softmax, ] ) mlp = MLP( layer_sizes=[4, 8, 2], activation=nn.relu, activation_final=nn.log_softmax, ) key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (3, 5)) params_1 = sequential.init(key2, x) params_2 = mlp.init(key2, x) output_1 = sequential.apply(params_1, x) output_2 = mlp.apply(params_2, x) np.testing.assert_array_equal(output_1, output_2) def test_tuple_output(self): sequential = nn.Sequential( [ AttentionTuple(), AttentionTuple(), ] ) key1, key2, key3 = random.split(random.key(0), 3) query = random.uniform(key1, (3, 5)) key_value = random.uniform(key2, (9, 5)) params_1 = sequential.init(key3, query, key_value) outputs = sequential.apply(params_1, query, key_value) np.testing.assert_equal(len(outputs), 2) out_query, out_key_value = outputs np.testing.assert_equal(out_query.shape, (3, 5)) np.testing.assert_equal(out_key_value.shape, (9, 5)) def test_dict_output(self): sequential = nn.Sequential( [ AttentionDict(), AttentionDict(), ] ) key1, key2, key3 = random.split(random.key(0), 3) query = random.uniform(key1, (3, 5)) key_value = random.uniform(key2, (9, 5)) params_1 = sequential.init(key3, query, key_value) outputs = sequential.apply(params_1, query, key_value) np.testing.assert_equal(len(outputs), 2) out_query, out_key_value = outputs['query'], outputs['key_value'] np.testing.assert_equal(out_query.shape, (3, 5)) np.testing.assert_equal(out_key_value.shape, (9, 5)) def test_sequential_compact(self): mlp = nn.Sequential([ lambda x: nn.Dense(x.shape[-1])(x), nn.relu, lambda x: nn.Dense(x.shape[-1])(x), nn.relu, lambda x: nn.Dense(x.shape[-1])(x), ]) params = mlp.init(random.key(0), jnp.ones((3, 5)))['params'] self.assertIn('Dense_0', params) self.assertIn('Dense_1', params) self.assertIn('Dense_2', params) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/linen/linen_dtypes_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.linen.dtypes.""" from absl.testing import absltest from jax import numpy as jnp from flax.linen import dtypes try: # JAX v0.8.0 and newer from jax import enable_x64 except ImportError: from jax.experimental import enable_x64 default_float_dtype = jnp.result_type(1.0) class DtypesTest(absltest.TestCase): def test_no_inexact_dtype(self): i32 = jnp.int32(1.0) self.assertEqual(dtypes.canonicalize_dtype(i32, inexact=False), jnp.int32) def test_inexact_dtype(self): with enable_x64(): i64 = jnp.int64(1) self.assertEqual(dtypes.canonicalize_dtype(i64), jnp.float32) i32 = jnp.int32(1) self.assertEqual(dtypes.canonicalize_dtype(i32), jnp.float32) i16 = jnp.int16(1.0) self.assertEqual(dtypes.canonicalize_dtype(i16), jnp.float32) def test_explicit_downcast(self): f32 = jnp.float32(1.0) (x,) = dtypes.promote_dtype(f32, dtype=jnp.float16) self.assertEqual(x.dtype, jnp.float16) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/linen/linen_linear_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.linen.linear.""" import functools import jax import jax.numpy as jnp import numpy as np from absl.testing import absltest, parameterized from jax import random from jax.nn import initializers from flax import linen as nn # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() class LinearTest(parameterized.TestCase): def test_dense(self): rng = dict(params=random.key(0)) x = jnp.ones((1, 3)) dense_module = nn.Dense( features=4, kernel_init=initializers.ones, bias_init=initializers.ones, ) y, _ = dense_module.init_with_output(rng, x) self.assertEqual(y.shape, (1, 4)) self.assertEqual(y.dtype, jnp.float32) np.testing.assert_allclose(y, np.full((1, 4), 4.0)) def test_dense_extra_batch_dims(self): rng = dict(params=random.key(0)) x = jnp.ones((1, 2, 3)) dense_module = nn.Dense( features=4, kernel_init=initializers.ones, bias_init=initializers.ones, ) y, _ = dense_module.init_with_output(rng, x) np.testing.assert_allclose(y, np.full((1, 2, 4), 4.0)) def test_dense_no_bias(self): rng = dict(params=random.key(0)) x = jnp.ones((1, 3)) dense_module = nn.Dense( features=4, use_bias=False, kernel_init=initializers.ones, ) y, _ = dense_module.init_with_output(rng, x) np.testing.assert_allclose(y, np.full((1, 4), 3.0)) def test_dense_is_dense_general(self): x = jax.random.normal(random.key(0), (5, 3)) dense_module = nn.Dense( features=4, use_bias=True, bias_init=initializers.normal(), ) y1, _ = dense_module.init_with_output(dict(params=random.key(1)), x) dg_module = nn.DenseGeneral( features=4, use_bias=True, bias_init=initializers.normal(), ) y2, _ = dg_module.init_with_output(dict(params=random.key(1)), x) np.testing.assert_allclose(y1, y2) def test_dense_general_batch_dim_raises(self): rng = dict(params=random.key(0)) x = jnp.ones((1, 3, 2, 5)) with self.assertRaises(ValueError): dg_module = nn.DenseGeneral( features=4, batch_dims=(0, 2), kernel_init=initializers.ones, bias_init=initializers.ones, ) dg_module.init_with_output(rng, x) def test_dense_general_two_out(self): rng = dict(params=random.key(0)) x = jnp.ones((1, 3)) dg_module = nn.DenseGeneral( features=(2, 2), kernel_init=initializers.ones, bias_init=initializers.ones, ) y, _ = dg_module.init_with_output(rng, x) np.testing.assert_allclose(y, np.full((1, 2, 2), 4.0)) def test_dense_general_two_in(self): rng = dict(params=random.key(0)) x = jnp.ones((1, 2, 2)) dg_module = nn.DenseGeneral( features=3, axis=(-2, 2), kernel_init=initializers.ones, bias_init=initializers.ones, ) y, _ = dg_module.init_with_output(rng, x) np.testing.assert_allclose(y, np.full((1, 3), 5.0)) def test_dense_general_batch_dim(self): rng = dict(params=random.key(0)) x = jnp.ones((2, 1, 3, 5)) state = {'counter': 0.0} def _counter_init(rng, shape, dtype, state): del rng, dtype state['counter'] += 1.0 return jnp.full(shape, state['counter']) counter_init = functools.partial(_counter_init, state=state) dg_module = nn.DenseGeneral( features=7, axis=(3, -2), batch_dims=0, bias_init=initializers.ones, kernel_init=counter_init, ) y, _ = dg_module.init_with_output(rng, x) target = np.full((2, 1, 7), 16.0) np.testing.assert_allclose(y, target) @parameterized.parameters( [ ((-2, 3), (), 'bijk,jklm->bilm'), ((3, -2), (), 'bijk,jklm->bilm'), ((-2, 3), (0,), 'bijk,bjklm->bilm'), ] ) def test_dense_general_vs_numpy(self, axis, batch_dims, einsum_expr): rng = dict(params=random.key(0)) x = jnp.ones((16, 8, 9, 10)) dg_module = nn.DenseGeneral( features=(11, 12), axis=axis, batch_dims=batch_dims, bias_init=initializers.ones, kernel_init=initializers.normal(), ) y, initial_params = dg_module.init_with_output(rng, x) target = np.einsum(einsum_expr, x, initial_params['params']['kernel']) + 1.0 np.testing.assert_allclose(y, target, atol=1e-6) def test_complex_params_dense(self): dense = nn.Dense(features=2, param_dtype=jnp.complex64) x = jnp.ones((1, 2), jnp.float32) variables = dense.init(random.key(0), x) self.assertEqual(variables['params']['kernel'].dtype, jnp.complex64) self.assertEqual(variables['params']['bias'].dtype, jnp.complex64) y = dense.apply(variables, x) self.assertEqual(y.dtype, jnp.complex64) def test_complex_input_dense(self): dense = nn.Dense(features=2) x = jnp.ones((1, 2), jnp.complex64) variables = dense.init(random.key(0), x) self.assertEqual(variables['params']['kernel'].dtype, jnp.float32) self.assertEqual(variables['params']['bias'].dtype, jnp.float32) y = dense.apply(variables, x) self.assertEqual(y.dtype, jnp.complex64) @parameterized.parameters( [ ( 'abc,cde->abde', (3, 4, 5), (5, 6, 7), (3, 4, 6, 7), (6, 7), (1, 1, 6, 7), ), ( 'abcd,abcd->abcd', (3, 4, 5, 6), (3, 4, 5, 6), (3, 4, 5, 6), (3, 4, 5, 6), (3, 4, 5, 6), ), ( 'abcd,abcd->abd', (3, 4, 5, 6), (3, 4, 5, 6), (3, 4, 6), (3, 4, 6), (3, 4, 6), ), ( 'abcd,cdef->abef', (3, 4, 5, 6), (5, 6, 7, 8), (3, 4, 7, 8), (7, 8), (1, 1, 7, 8), ), ( 'abcd,eafc->bfed', (3, 4, 5, 6), (7, 3, 8, 5), (4, 8, 7, 6), (8, 7), (1, 8, 7, 1), ), ( 'abcd,cbedf->abfe', (3, 4, 5, 6), (5, 4, 7, 6, 8), (3, 4, 8, 7), (4, 8, 7), (1, 4, 8, 7), ), ( 'ab...,bc...->ac...', (3, 4, 6), (4, 5, 6), (3, 5, 6), (5, 6), (1, 5, 6), ), ( 'd...ab,bc...->ad...c', (8, 6, 7, 3, 4), (4, 5, 6, 7), (3, 8, 6, 7, 5), (6, 7, 5), (1, 1, 6, 7, 5), ), ( 'd...ab,bc...->adc', (8, 6, 7, 3, 4), (4, 5, 6, 7), (3, 8, 5), (5,), (1, 1, 5), ), ( 'abd...,bc...->ac...', (3, 4, 6), (4, 5, 6), (3, 5, 6), (5, 6), (1, 5, 6), ), ( 'a...d,ej...f->adef', (3, 4, 5, 6), (7, 4, 5, 8), (3, 6, 7, 8), (7, 8), (1, 1, 7, 8), ), ( 'ab...d,ej...f->ad...f', (3, 4, 5, 6), (7, 4, 5, 8), (3, 6, 5, 8), (5, 8), (1, 1, 5, 8), ), ] ) def test_einsum_init_apply( self, einsum_str, lhs_shape, rhs_shape, expected_result_shape, expected_bias_shape, bias_broadcast_shape, ): layer = nn.Einsum(rhs_shape, einsum_str, bias_init=nn.initializers.normal()) x = jax.random.normal(jax.random.key(0), lhs_shape) v = layer.init(jax.random.key(1), x) self.assertEqual(rhs_shape, v['params']['kernel'].shape) self.assertEqual(expected_bias_shape, v['params']['bias'].shape) out = layer.apply(v, x) self.assertEqual(out.shape, expected_result_shape) expected_out = jnp.einsum(einsum_str, x, v['params']['kernel']) + v[ 'params' ]['bias'].reshape(bias_broadcast_shape) np.testing.assert_allclose(out, expected_out) @parameterized.parameters( [ ( ('abd,bce->ace', 'abd,bc...->ac...', 'abd...,bc...->ac...'), (3, 4, 6), (4, 5, 6), ), ( ( 'abcd,ejcf->adef', 'ab...d,ej...f->adef', 'ab...d,e...f->adef', 'a...d,ej...f->adef', ), (3, 4, 5, 6), (7, 4, 5, 8), ), ( ('abcd,ejcf->adcf', 'ab...d,ej...f->ad...f'), (3, 4, 5, 6), (7, 4, 5, 8), ), ] ) def test_einsum_ellipsis_equivalence( self, einsum_str_list, lhs_shape, rhs_shape ): x = jax.random.uniform(jax.random.key(0), lhs_shape) layer = nn.Einsum( rhs_shape, einsum_str_list[0], bias_init=nn.initializers.normal() ) v = layer.init(jax.random.key(1), x) out = layer.apply(v, x) for einsum_str in einsum_str_list[1:]: layer2 = nn.Einsum( rhs_shape, einsum_str, bias_init=nn.initializers.normal() ) v2 = layer2.init(jax.random.key(1), x) np.testing.assert_allclose(v['params']['kernel'], v2['params']['kernel']) np.testing.assert_allclose(v['params']['bias'], v2['params']['bias']) np.testing.assert_allclose(out, layer2.apply(v2, x)) def test_einsum_str_arg(self): einsum_str = 'abc,cde->abde' x = jax.random.normal(jax.random.key(0), (3, 4, 5)) constructed_layer = nn.Einsum( (5, 6, 7), einsum_str, bias_init=nn.initializers.normal() ) constructed_v = constructed_layer.init(jax.random.key(1), x) constructed_out = constructed_layer.apply(constructed_v, x) called_layer = nn.Einsum((5, 6, 7), bias_init=nn.initializers.normal()) called_v = called_layer.init(jax.random.key(1), x, einsum_str) called_out = called_layer.apply(called_v, x, einsum_str) np.testing.assert_allclose( constructed_v['params']['kernel'], called_v['params']['kernel'] ) np.testing.assert_allclose( constructed_v['params']['bias'], called_v['params']['bias'] ) np.testing.assert_allclose(constructed_out, called_out) with self.assertRaisesWithLiteralMatch( ValueError, 'Parameter "einsum_str" was passed to the constructor and at call time. Should be passed just once.', ): constructed_layer.init(jax.random.key(1), x, einsum_str) with self.assertRaisesWithLiteralMatch( ValueError, 'Parameter "einsum_str" must be passed to the constructor or at call time.', ): called_layer.init(jax.random.key(1), x) def test_einsum_space_str(self): x = jax.random.normal(jax.random.key(0), (3, 4, 5)) layer1 = nn.Einsum( (5, 6, 7), 'abc,cde->abde', bias_init=nn.initializers.normal() ) v1 = layer1.init(jax.random.key(1), x) out1 = layer1.apply(v1, x) layer2 = nn.Einsum( (5, 6, 7), ' ab c , c d e - >a b d e ', bias_init=nn.initializers.normal(), ) v2 = layer2.init(jax.random.key(1), x) out2 = layer2.apply(v1, x) np.testing.assert_allclose(v1['params']['kernel'], v2['params']['kernel']) np.testing.assert_allclose(v1['params']['bias'], v2['params']['bias']) np.testing.assert_allclose(out1, out2) @parameterized.parameters( [ ('abc,cde', '`einsum_str` equation must be explicit and include "->".'), ( 'abc->', f'`einsum_str` equation must have exactly two operands and therefore, exactly one comma character, instead of 0', ), ( 'abc,cde,efg->abdfg', f'`einsum_str` equation must have exactly two operands and therefore, exactly one comma character, instead of 2', ), ] ) def test_einsum_error(self, einsum_str, error_msg): x = jax.random.normal(jax.random.key(0), (3, 4, 5)) layer = nn.Einsum((5, 6, 7), einsum_str) with self.assertRaisesRegex(ValueError, error_msg): layer.init(jax.random.key(1), x) @parameterized.product(use_bias=(True, False)) def test_conv(self, use_bias): rng = dict(params=random.key(0)) x = jnp.ones((1, 8, 3)) conv_module = nn.Conv( features=4, use_bias=use_bias, kernel_size=(3,), padding='VALID', kernel_init=initializers.ones, bias_init=initializers.ones, ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4)) expected = 10.0 if use_bias else 9.0 np.testing.assert_allclose(y, np.full((1, 6, 4), expected)) @parameterized.product(use_bias=(True, False)) def test_multibatch_input_conv(self, use_bias): rng = dict(params=random.key(0)) x = jnp.ones((2, 5, 8, 3)) conv_module = nn.Conv( features=4, use_bias=use_bias, kernel_size=(3,), padding='VALID', kernel_init=initializers.ones, bias_init=initializers.ones, ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4)) expected = 10.0 if use_bias else 9.0 np.testing.assert_allclose(y, np.full((2, 5, 6, 4), expected)) def test_conv_local(self): rng = dict(params=random.key(0)) x = jnp.ones((1, 8, 2)) conv_module = nn.ConvLocal( features=4, kernel_size=(3,), padding='VALID', kernel_init=initializers.ones, bias_init=initializers.ones, ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (6, 3 * 2, 4)) np.testing.assert_allclose(y, np.full((1, 6, 4), 7.0)) def test_single_input_conv(self): rng = dict(params=random.key(0)) x = jnp.ones((8, 3)) conv_module = nn.Conv( features=4, kernel_size=(3,), padding='VALID', kernel_init=initializers.ones, bias_init=initializers.ones, ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4)) np.testing.assert_allclose(y, np.full((6, 4), 10.0)) def test_single_input_masked_conv(self): rng = dict(params=random.key(0)) x = jnp.ones((8, 3)) m = jnp.tril(jnp.ones((3, 3, 4))) conv_module = nn.Conv( features=4, kernel_size=(3,), padding='VALID', mask=m, kernel_init=initializers.ones, bias_init=initializers.ones, ) expected = jnp.array( [ [10.0, 7.0, 4.0, 1.0], [10.0, 7.0, 4.0, 1.0], [10.0, 7.0, 4.0, 1.0], [10.0, 7.0, 4.0, 1.0], [10.0, 7.0, 4.0, 1.0], [10.0, 7.0, 4.0, 1.0], ] ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4)) np.testing.assert_allclose(y, expected) def test_single_input_conv_local(self): rng = dict(params=random.key(0)) x = jnp.ones((8, 2)) conv_module = nn.ConvLocal( features=4, kernel_size=(3,), padding='VALID', kernel_init=initializers.ones, bias_init=initializers.ones, ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (6, 3 * 2, 4)) np.testing.assert_allclose(y, np.full((6, 4), 7.0)) def test_group_conv(self): rng = dict(params=random.key(0)) x = jnp.ones((1, 8, 4)) conv_module = nn.Conv( features=4, kernel_size=(3,), feature_group_count=2, padding='VALID', kernel_init=initializers.ones, bias_init=initializers.ones, ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 2, 4)) np.testing.assert_allclose(y, np.full((1, 6, 4), 7.0)) @parameterized.product( n_batch=(1, 3), n_features=(1, 2), kernel_size=(1, 2, 3, 9), n_input_features=(1, 3), input_size=(1, 8, 16), module=(nn.Conv, nn.ConvLocal), ) def test_circular_conv_1d_constant( self, n_batch, n_features, kernel_size, n_input_features, input_size, module, ): """ Test 1D convolution with circular padding: filter with all elements equal to 1 applied on an input with all elements equal to 1. Result should have the same shape as input (except for the feature dimension) and have all elements equal to `n_input_features * kernel_lin_size`. """ rng = dict(params=random.key(0)) x = jnp.ones((n_batch, input_size, n_input_features)) conv_module = module( features=n_features, kernel_size=(kernel_size,), padding='CIRCULAR', kernel_init=initializers.ones, bias_init=initializers.zeros, ) y, initial_params = conv_module.init_with_output(rng, x) kernel_shape = self._get_kernel_shape( x.shape, (kernel_size,), module, n_features ) self.assertEqual( initial_params['params']['kernel'].shape, kernel_shape, ) correct_ans = np.full( (n_batch, input_size, n_features), kernel_size * n_input_features ) np.testing.assert_allclose(y, correct_ans) def _get_kernel_shape(self, input_shape, kernel_size, module, n_features): if module == nn.Conv: kernel_shape = kernel_size + (input_shape[-1], n_features) elif module == nn.ConvLocal: kernel_shape = input_shape[1:-1] + ( input_shape[-1] * np.prod(kernel_size), n_features, ) else: raise ValueError(module) return kernel_shape @parameterized.product( n_batch=(1, 3), n_features=(1, 2, 10), kernel_lin_size=(1, 2, 3, 9), n_input_features=(1, 5), input_x_size=(14,), input_y_size=(5, 10), module=(nn.Conv, nn.ConvLocal), ) def test_circular_conv_2d_constant( self, n_batch, n_features, kernel_lin_size, n_input_features, input_x_size, input_y_size, module, ): """ Test 2D convolution with circular padding: square filter with all elements equal to 1 applied on an input with all elements equal to 1. Result should have the same shape as input (except for the feature dimension) and have all elements equal to `n_input_features * kernel_lin_size^2`. """ rng = dict(params=random.key(0)) x = jnp.ones((n_batch, input_x_size, input_y_size, n_input_features)) kernel_size = (kernel_lin_size, kernel_lin_size) conv_module = module( features=n_features, kernel_size=kernel_size, padding='CIRCULAR', kernel_init=initializers.ones, bias_init=initializers.zeros, ) y, initial_params = conv_module.init_with_output(rng, x) kernel_shape = self._get_kernel_shape( x.shape, kernel_size, module, n_features ) self.assertEqual( initial_params['params']['kernel'].shape, kernel_shape, ) correct_ans = np.full( (n_batch, input_x_size, input_y_size, n_features), kernel_lin_size * kernel_lin_size * n_input_features, ) np.testing.assert_allclose(y, correct_ans) def test_circular_conv_1d_custom(self): """Test 1d convolution with circular padding and a stride.""" rng = dict(params=random.key(0)) x = np.arange(1, 6) x = np.expand_dims(x, (0, 2)) kernel = np.array((1, 2, 1)) kernel = np.expand_dims(kernel, (1, 2)) conv_module = nn.Conv( features=1, kernel_size=(3,), strides=(3,), padding='CIRCULAR', kernel_init=lambda *_: kernel, bias_init=initializers.zeros, ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 1, 1)) # Compare with manually computed convolution correct_ans = np.array((5 + 2 * 1 + 2, 3 + 2 * 4 + 5)) correct_ans = np.expand_dims(correct_ans, (0, 2)) np.testing.assert_allclose(y, correct_ans) def test_circular_conv_local_1d_custom(self): """ Test 1d local convolution with circular padding and a stride """ rng = dict(params=random.key(0)) x = np.arange(1, 6) x = np.expand_dims(x, (0, 2)) kernel = np.array(((-1, 2, 3), (4, 5, 6))) kernel = np.expand_dims(kernel, (2,)) conv_module = nn.ConvLocal( features=1, kernel_size=(3,), strides=(3,), padding='CIRCULAR', kernel_init=lambda *_: kernel, bias_init=initializers.zeros, ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (2, 3, 1)) # Compare with manually computed convolution correct_ans = np.array((-1 * 5 + 2 * 1 + 3 * 2, 4 * 3 + 5 * 4 + 6 * 5)) correct_ans = np.expand_dims(correct_ans, (0, 2)) np.testing.assert_allclose(y, correct_ans) def test_circular_conv_1d_dilation(self): """Test 1d convolution with circular padding and kernel dilation.""" rng = dict(params=random.key(0)) x = np.arange(1, 6) x = np.expand_dims(x, (0, 2)) kernel = np.array((1, 2, 1)) kernel = np.expand_dims(kernel, (1, 2)) conv_module = nn.Conv( features=1, kernel_size=(3,), padding='CIRCULAR', kernel_init=lambda *_: kernel, bias_init=initializers.zeros, kernel_dilation=(3,), ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 1, 1)) # Compare with manually computed convolution correct_ans = np.array( ( 3 + 2 * 1 + 4, 4 + 2 * 2 + 5, 5 + 2 * 3 + 1, 1 + 2 * 4 + 2, 2 + 2 * 5 + 3, ) ) correct_ans = np.expand_dims(correct_ans, (0, 2)) np.testing.assert_allclose(y, correct_ans) def test_circular_conv_local_1d_dilation(self): """ Test 1d local convolution with circular padding and kernel dilation """ rng = dict(params=random.key(0)) x = np.arange(1, 6) x = np.expand_dims(x, (0, 2)) kernel = np.array( ((1, 2, 1), (3, 4, 5), (-1, 1, 2), (2, 3, 4), (-1, -2, -3)) ) kernel = np.expand_dims(kernel, (2,)) conv_module = nn.ConvLocal( features=1, kernel_size=(3,), padding='CIRCULAR', kernel_init=lambda *_: kernel, bias_init=initializers.zeros, kernel_dilation=(3,), ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (5, 3, 1)) # Compare with manually computed convolution correct_ans = np.array( ( 1 * 3 + 2 * 1 + 1 * 4, 3 * 4 + 4 * 2 + 5 * 5, -1 * 5 + 1 * 3 + 2 * 1, 2 * 1 + 3 * 4 + 4 * 2, -1 * 2 + -2 * 5 + -3 * 3, ) ) correct_ans = np.expand_dims(correct_ans, (0, 2)) np.testing.assert_allclose(y, correct_ans) def test_circular_conv_2d_custom(self): """Test 2d convolution with circular padding on a 3x3 example.""" rng = dict(params=random.key(0)) x = np.array(((1, 2, 3), (4, 5, 6), (7, 8, 9))) x = np.expand_dims(x, (0, 3)) kernel = np.array(((0, 1, 0), (1, 2, 1), (0, 1, 0))) kernel = np.expand_dims(kernel, (2, 3)) conv_module = nn.Conv( features=1, kernel_size=(3, 3), padding='CIRCULAR', kernel_init=lambda *_: kernel, bias_init=initializers.zeros, ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 1, 1)) # Compare with manually computed convolution correct_ans = np.array( ( (2 * 1 + 7 + 2 + 4 + 3, 2 * 2 + 8 + 3 + 5 + 1, 2 * 3 + 9 + 1 + 6 + 2), (2 * 4 + 1 + 5 + 7 + 6, 2 * 5 + 2 + 6 + 8 + 4, 2 * 6 + 3 + 4 + 9 + 5), (2 * 7 + 4 + 8 + 1 + 9, 2 * 8 + 5 + 9 + 2 + 7, 2 * 9 + 6 + 7 + 3 + 8), ) ) correct_ans = np.expand_dims(correct_ans, (0, 3)) np.testing.assert_allclose(y, correct_ans) def test_circular_conv_local_2d_custom(self): """ Test 2d local convolution with circular padding on a 3x3 example """ rng = dict(params=random.key(0)) x = np.array(((1, 2, 3), (4, 5, 6), (7, 8, 9))) x = np.expand_dims(x, (0, 3)) kernel = np.array( ( ( ((0, 1, 0), (1, 2, 1), (0, 1, 0)), ((0, 1, 0), (1, 3, 1), (0, 1, 0)), ((0, 1, 0), (1, 4, 1), (0, 1, 0)), ), ( ((0, 1, 0), (1, 5, 1), (0, 1, 0)), ((0, 1, 0), (1, 6, 1), (0, 1, 0)), ((0, 1, 0), (1, 7, 1), (0, 1, 0)), ), ( ((0, 1, 0), (1, 8, 1), (0, 1, 0)), ((0, 1, 0), (1, 9, 1), (0, 1, 0)), ((0, 1, 0), (1, 10, 1), (0, 1, 0)), ), ) ) kernel = np.expand_dims(kernel, (3,)) kernel = np.reshape(kernel, (3, 3, 9, 1)) conv_module = nn.ConvLocal( features=1, kernel_size=(3, 3), padding='CIRCULAR', kernel_init=lambda *_: kernel, bias_init=initializers.zeros, ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 9, 1)) # Compare with manually computed convolution correct_ans = np.array( ( (2 * 1 + 7 + 2 + 4 + 3, 3 * 2 + 8 + 3 + 5 + 1, 4 * 3 + 9 + 1 + 6 + 2), (5 * 4 + 1 + 5 + 7 + 6, 6 * 5 + 2 + 6 + 8 + 4, 7 * 6 + 3 + 4 + 9 + 5), (8 * 7 + 4 + 8 + 1 + 9, 9 * 8 + 5 + 9 + 2 + 7, 10 * 9 + 6 + 7 + 3 + 8), ) ) correct_ans = np.expand_dims(correct_ans, (0, 3)) np.testing.assert_allclose(y, correct_ans) def test_reflect_conv_1d_custom(self): """Test 1d convolution with reflection padding and a stride.""" rng = dict(params=random.key(0)) x = np.arange(1, 6) x = np.expand_dims(x, (0, 2)) kernel = np.array((1, 2, 1)) kernel = np.expand_dims(kernel, (1, 2)) conv_module = nn.Conv( features=1, kernel_size=(3,), strides=(2,), padding='REFLECT', kernel_init=lambda *_: kernel, bias_init=initializers.zeros, ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 1, 1)) # Compare with manually computed convolution correct_ans = np.array((2 + 2 * 1 + 2, 2 + 2 * 3 + 4, 4 + 2 * 5 + 4)) correct_ans = np.expand_dims(correct_ans, (0, 2)) np.testing.assert_allclose(y, correct_ans) def test_reflect_conv_2d_custom(self): """Test 2d convolution with reflect padding on a 3x3 example.""" rng = dict(params=random.key(0)) x = np.array(((1, 2, 3), (4, 5, 6), (7, 8, 9))) x = np.expand_dims(x, (0, 3)) kernel = np.array(((0, 1, 0), (1, 2, 1), (0, 1, 0))) kernel = np.expand_dims(kernel, (2, 3)) conv_module = nn.Conv( features=1, kernel_size=(3, 3), padding='REFLECT', kernel_init=lambda *_: kernel, bias_init=initializers.zeros, ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 1, 1)) # Compare with manually computed convolution correct_ans = np.array( ( (2 * 1 + 4 + 2 + 4 + 2, 2 * 2 + 5 + 3 + 5 + 1, 2 * 3 + 6 + 2 + 6 + 2), (2 * 4 + 1 + 5 + 7 + 5, 2 * 5 + 2 + 6 + 8 + 4, 2 * 6 + 3 + 5 + 9 + 5), (2 * 7 + 4 + 8 + 8 + 4, 2 * 8 + 5 + 9 + 5 + 7, 2 * 9 + 6 + 8 + 6 + 8), ) ) correct_ans = np.expand_dims(correct_ans, (0, 3)) np.testing.assert_allclose(y, correct_ans) def test_causal_conv1d(self): rng = dict(params=random.key(0)) x = jnp.ones((1, 8, 4)) conv_module = nn.Conv( features=4, kernel_size=(3,), padding='CAUSAL', kernel_init=initializers.ones, bias_init=initializers.ones, ) y, _ = conv_module.init_with_output(rng, x) correct_ans = np.array( [ [ [5.0, 5.0, 5.0, 5.0], [9.0, 9.0, 9.0, 9.0], [13.0, 13.0, 13.0, 13.0], [13.0, 13.0, 13.0, 13.0], [13.0, 13.0, 13.0, 13.0], [13.0, 13.0, 13.0, 13.0], [13.0, 13.0, 13.0, 13.0], [13.0, 13.0, 13.0, 13.0], ] ] ) np.testing.assert_allclose(y, correct_ans) np.testing.assert_array_equal(correct_ans.shape, y.shape) @parameterized.product( use_bias=(True, False), ) def test_conv_transpose(self, use_bias): rng = dict(params=random.key(0)) x = jnp.ones((1, 8, 3)) conv_transpose_module = nn.ConvTranspose( features=4, use_bias=use_bias, kernel_size=(3,), padding='VALID', kernel_init=initializers.ones, bias_init=initializers.ones, ) y, initial_params = conv_transpose_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4)) correct_ans = np.array( [ [ [4.0, 4.0, 4.0, 4.0], [7.0, 7.0, 7.0, 7.0], [10.0, 10.0, 10.0, 10.0], [10.0, 10.0, 10.0, 10.0], [10.0, 10.0, 10.0, 10.0], [10.0, 10.0, 10.0, 10.0], [10.0, 10.0, 10.0, 10.0], [10.0, 10.0, 10.0, 10.0], [7.0, 7.0, 7.0, 7.0], [4.0, 4.0, 4.0, 4.0], ] ] ) if not use_bias: correct_ans -= 1.0 np.testing.assert_allclose(y, correct_ans) @parameterized.product( use_bias=(True, False), ) def test_multibatch_input_conv_transpose(self, use_bias): rng = dict(params=random.key(0)) x = jnp.ones((2, 5, 8, 3)) conv_transpose_module = nn.ConvTranspose( features=4, use_bias=use_bias, kernel_size=(3,), padding='VALID', kernel_init=initializers.ones, bias_init=initializers.ones, ) y, initial_params = conv_transpose_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4)) correct_ans = np.array( [ [ [4.0, 4.0, 4.0, 4.0], [7.0, 7.0, 7.0, 7.0], [10.0, 10.0, 10.0, 10.0], [10.0, 10.0, 10.0, 10.0], [10.0, 10.0, 10.0, 10.0], [10.0, 10.0, 10.0, 10.0], [10.0, 10.0, 10.0, 10.0], [10.0, 10.0, 10.0, 10.0], [7.0, 7.0, 7.0, 7.0], [4.0, 4.0, 4.0, 4.0], ] ] ) correct_ans = np.repeat(correct_ans[None], repeats=2, axis=0) correct_ans = np.repeat(correct_ans, repeats=5, axis=1) if not use_bias: correct_ans -= 1.0 np.testing.assert_allclose(y, correct_ans) def test_single_input_conv_transpose(self): rng = dict(params=random.key(0)) x = jnp.ones((8, 3)) conv_transpose_module = nn.ConvTranspose( features=4, kernel_size=(3,), padding='VALID', kernel_init=initializers.ones, bias_init=initializers.ones, ) y, initial_params = conv_transpose_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4)) correct_ans = np.array( [ [4.0, 4.0, 4.0, 4.0], [7.0, 7.0, 7.0, 7.0], [10.0, 10.0, 10.0, 10.0], [10.0, 10.0, 10.0, 10.0], [10.0, 10.0, 10.0, 10.0], [10.0, 10.0, 10.0, 10.0], [10.0, 10.0, 10.0, 10.0], [10.0, 10.0, 10.0, 10.0], [7.0, 7.0, 7.0, 7.0], [4.0, 4.0, 4.0, 4.0], ] ) np.testing.assert_allclose(y, correct_ans) def test_single_input_masked_conv_transpose(self): rng = dict(params=random.key(0)) x = jnp.ones((8, 3)) m = jnp.tril(jnp.ones((3, 3, 4))) conv_transpose_module = nn.ConvTranspose( features=4, kernel_size=(3,), padding='VALID', mask=m, kernel_init=initializers.ones, bias_init=initializers.ones, ) y, initial_params = conv_transpose_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4)) correct_ans = np.array( [ [4.0, 3.0, 2.0, 1.0], [7.0, 5.0, 3.0, 1.0], [10.0, 7.0, 4.0, 1.0], [10.0, 7.0, 4.0, 1.0], [10.0, 7.0, 4.0, 1.0], [10.0, 7.0, 4.0, 1.0], [10.0, 7.0, 4.0, 1.0], [10.0, 7.0, 4.0, 1.0], [7.0, 5.0, 3.0, 1.0], [4.0, 3.0, 2.0, 1.0], ] ) np.testing.assert_allclose(y, correct_ans) @parameterized.product( n_batch=(1, 3), n_features=(1, 2), kernel_size=(1, 2, 3, 9), n_input_features=(1, 3), input_size=(1, 8, 16), ) def test_circular_conv_transpose_1d_constant( self, n_batch, n_features, kernel_size, n_input_features, input_size ): """ Test 1D transposed convolution with circular padding: filter with all elements equal to 1 applied on an input with all elements equal to 1. Result should have the same shape as input (except for the feature dimension) and have all elements equal to `n_input_features * kernel_lin_size`. """ rng = dict(params=random.key(0)) x = jnp.ones((n_batch, input_size, n_input_features)) conv_module = nn.ConvTranspose( features=n_features, kernel_size=(kernel_size,), padding='CIRCULAR', kernel_init=initializers.ones, bias_init=initializers.zeros, ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual( initial_params['params']['kernel'].shape, (kernel_size, n_input_features, n_features), ) correct_ans = np.full( (n_batch, input_size, n_features), kernel_size * n_input_features ) np.testing.assert_allclose(y, correct_ans) @parameterized.product( n_batch=(1, 3), n_features=(1, 2, 10), kernel_lin_size=(1, 2, 3, 9), n_input_features=(1, 5), input_x_size=(14,), input_y_size=(5, 10), ) def test_circular_conv_transpose_2d_constant( self, n_batch, n_features, kernel_lin_size, n_input_features, input_x_size, input_y_size, ): """ Test 2D transposed convolution with circular padding: square filter with all elements equal to 1 applied on an input with all elements equal to 1. Result should have the same shape as input (except for the feature dimension) and have all elements equal to `n_input_features * kernel_lin_size^2`. """ rng = dict(params=random.key(0)) x = jnp.ones((n_batch, input_x_size, input_y_size, n_input_features)) conv_module = nn.ConvTranspose( features=n_features, kernel_size=(kernel_lin_size, kernel_lin_size), padding='CIRCULAR', kernel_init=initializers.ones, bias_init=initializers.zeros, ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual( initial_params['params']['kernel'].shape, (kernel_lin_size, kernel_lin_size, n_input_features, n_features), ) correct_ans = np.full( (n_batch, input_x_size, input_y_size, n_features), kernel_lin_size * kernel_lin_size * n_input_features, ) np.testing.assert_allclose(y, correct_ans) def test_circular_conv_transpose_2d_with_vmap(self): layer = nn.ConvTranspose(features=5, kernel_size=(3,), padding='CIRCULAR') # this is ok sample_input = jnp.ones((1, 32, 2)) out, vars = layer.init_with_output(jax.random.key(0), sample_input) self.assertEqual(out.shape, (1, 32, 5)) batch_input = jnp.ones((8, 32, 2)) batch_apply = jax.vmap(layer.apply, in_axes=(None, 0)) # this breaks with the error provided batch_out = batch_apply(vars, batch_input) self.assertEqual(batch_out.shape, (8, 32, 5)) def test_circular_conv_transpose_1d_custom(self): """Test 1d transposed convolution with circular padding and a stride.""" rng = dict(params=random.key(0)) x = np.arange(1, 6) x = np.expand_dims(x, (0, 2)) kernel = np.array((1, 2, 1)) kernel = np.expand_dims(kernel, (1, 2)) conv_module = nn.ConvTranspose( features=1, kernel_size=(3,), strides=(3,), padding='CIRCULAR', kernel_init=lambda *_: kernel, bias_init=initializers.zeros, ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 1, 1)) # Compare with manually computed convolution correct_ans = np.array( ( # pyformat: disable 1 * 1, 1 * 2, 1 * 1, 2 * 1, 2 * 2, 2 * 1, 3 * 1, 3 * 2, 3 * 1, 4 * 1, 4 * 2, 4 * 1, 5 * 1, 5 * 2, 5 * 1, ) ) correct_ans = np.expand_dims(correct_ans, (0, 2)) np.testing.assert_allclose(y, correct_ans) def test_circular_conv_transpose_2d_custom(self): """Test 2d transposed convolution with circular padding on a 3x3 example.""" rng = dict(params=random.key(0)) x = np.array( ( (1, 2, 3), (4, 5, 6), (7, 8, 9), ) ) x = np.expand_dims(x, (0, 3)) kernel = np.array(((0, 1, 0), (1, 2, 1), (0, 1, 0))) kernel = np.expand_dims(kernel, (2, 3)) conv_module = nn.ConvTranspose( features=1, kernel_size=(3, 3), padding='CIRCULAR', kernel_init=lambda *_: kernel, bias_init=initializers.zeros, ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 1, 1)) # Compare with manually computed convolution correct_ans = np.array( ( (18, 21, 24), (27, 30, 33), (36, 39, 42), ) ) correct_ans = np.expand_dims(correct_ans, (0, 3)) np.testing.assert_allclose(y, correct_ans) def test_circular_conv_transpose_2d_custom_bias(self): """Test 2d transposed convolution with circular padding on a 2x2 example with bias.""" rng = dict(params=random.key(0)) x = np.array(((1, 2), (3, 4))) x = np.expand_dims(x, (0, 3)) kernel = np.array( ( (1, 2), (3, 4), ) ) kernel = np.expand_dims(kernel, (2, 3)) conv_module = nn.ConvTranspose( features=1, kernel_size=(2, 2), padding='CIRCULAR', kernel_init=lambda *_: kernel, bias_init=initializers.ones, ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (2, 2, 1, 1)) # Compare with manually computed convolution correct_ans = np.array( ( (21, 23), (29, 31), ) ) correct_ans = np.expand_dims(correct_ans, (0, 3)) np.testing.assert_allclose(y, correct_ans) @parameterized.product(use_bias=(True, False)) def test_transpose_kernel_conv_transpose(self, use_bias): rng = dict(params=random.key(0)) x = jnp.ones((1, 15, 15, 3)) conv_module = nn.ConvTranspose( features=4, use_bias=use_bias, strides=2, kernel_size=(6, 6), padding='CIRCULAR', transpose_kernel=True, ) y, initial_params = conv_module.init_with_output(rng, x) self.assertEqual(initial_params['params']['kernel'].shape, (6, 6, 4, 3)) self.assertEqual(y.shape, (1, 30, 30, 4)) @parameterized.product(module=(nn.Conv, nn.ConvLocal, nn.ConvTranspose)) def test_int_kernel_equality(self, module): conv_int = module(features=4, kernel_size=3) conv_seq = module(features=4, kernel_size=(3,)) x = jnp.ones((8, 3)) self.assertTrue( jax.tree_util.tree_all( jax.tree_util.tree_map( lambda x, y: (x == y).all(), conv_int.init(random.key(0), x), conv_seq.init(random.key(0), x), ) ) ) def test_embed(self): rng = dict(params=random.key(0)) x = jnp.arange(4)[None] dummy_embedding = jnp.broadcast_to(jnp.arange(4)[..., None], (4, 3)).astype( jnp.float32 ) embed_module = nn.Embed( num_embeddings=4, features=3, embedding_init=lambda rng, shape, dtype: dummy_embedding, ) y, initial_params = embed_module.init_with_output(rng, x) np.testing.assert_allclose(y, dummy_embedding[None]) z = embed_module.apply( initial_params, jnp.ones((3,)), method=embed_module.attend ) np.testing.assert_allclose(z, 3.0 * jnp.arange(4)) def test_embed_numpy(self): rng = dict(params=random.key(0)) x = jnp.arange(4)[None] dummy_embedding = np.broadcast_to(np.arange(4)[..., None], (4, 3)).astype( np.float32 ) embed_module = nn.Embed( num_embeddings=4, features=3, embedding_init=lambda rng, shape, dtype: dummy_embedding, ) y, initial_params = embed_module.init_with_output(rng, x) np.testing.assert_allclose(y, dummy_embedding[None]) z = embed_module.apply( initial_params, jnp.ones((3,)), method=embed_module.attend ) np.testing.assert_allclose(z, 3.0 * jnp.arange(4)) def test_embed_hash(self): self.assertEqual(hash(nn.Embed(2, 3)), hash(nn.Embed(2, 3))) self.assertNotEqual(hash(nn.Embed(3, 4)), hash(nn.Embed(2, 3))) def test_non_final_axis(self): class Foo(nn.Module): @nn.compact def __call__(self, x): return nn.DenseGeneral(features=6, axis=1, name='dense')(x) x = jnp.ones((2, 4, 8)) y, variables = Foo().init_with_output(random.key(0), x) self.assertEqual( jax.tree_util.tree_map(jnp.shape, variables['params']), {'dense': {'kernel': (4, 6), 'bias': (6,)}}, ) self.assertEqual(y.shape, (2, 8, 6)) def test_non_final_axes(self): class Foo(nn.Module): @nn.compact def __call__(self, x): return nn.DenseGeneral(features=6, axis=(0, 1), name='dense')(x) x = jnp.ones((2, 4, 8)) y, variables = Foo().init_with_output(random.key(0), x) self.assertEqual( jax.tree_util.tree_map(jnp.shape, variables['params']), {'dense': {'kernel': (2, 4, 6), 'bias': (6,)}}, ) self.assertEqual(y.shape, (8, 6)) def test_canonicalize_padding(self): def test_pad(pad, rank, expected=None): if expected is None: with self.assertRaises(ValueError): nn.linear.canonicalize_padding(pad, rank) else: self.assertEqual(nn.linear.canonicalize_padding(pad, rank), expected) test_pad('SAME', 2, 'SAME') test_pad(2, 3, [(2, 2), (2, 2), (2, 2)]) test_pad((2, 2), 3) test_pad((2, 2), 1) test_pad([1, (2, 3)], 2, [(1, 1), (2, 3)]) test_pad([None, (1, 2)], 2) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/linen/linen_meta_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for linen_meta.""" import jax from absl.testing import absltest from jax import numpy as jnp from jax import random from jax.experimental import mesh_utils from jax.sharding import Mesh, PartitionSpec from flax import linen as nn class LinenMetaTest(absltest.TestCase): def test_boxed_param(self): class Bar(nn.Module): @nn.compact def __call__(mdl_self, x): # pylint: disable=no-self-argument kernel_init = nn.with_partitioning( nn.initializers.ones_init(), ('in', 'out') ) kernel = mdl_self.param('kernel', kernel_init, (x.shape[-1], 2)) kernel_box = mdl_self.get_variable('params', 'kernel') self.assertIsInstance(kernel_box, nn.Partitioned) self.assertEqual(kernel_box.names, ('in', 'out')) return x @ kernel class Foo(nn.Module): @nn.compact def __call__(self, xs): return nn.vmap( Bar, in_axes=0, variable_axes={'params': 0}, split_rngs={'params': True}, metadata_params={nn.PARTITION_NAME: 'batch'}, )(name='bar')(xs) m = Foo() variables = m.init(random.key(0), jnp.zeros((8, 3))) self.assertEqual( variables['params']['bar']['kernel'].names, ('batch', 'in', 'out') ) def test_boxed_variable(self): class Bar(nn.Module): @nn.compact def __call__(mdl_self, x): # pylint: disable=no-self-argument kernel_init = nn.with_partitioning( nn.initializers.ones_init(), ('in', 'out') ) kernel = mdl_self.variable( 'params', 'kernel', kernel_init, mdl_self.make_rng('params'), (x.shape[-1], 2), ) kernel.value += 1.0 self.assertEqual(kernel.value.sum(), kernel.value.size * 2) kernel_box = mdl_self.get_variable('params', 'kernel') self.assertIsInstance(kernel_box, nn.Partitioned) self.assertEqual(kernel_box.names, ('in', 'out')) return x @ kernel.value class Foo(nn.Module): @nn.compact def __call__(self, xs): return nn.vmap( Bar, in_axes=0, variable_axes={'params': 0}, split_rngs={'params': True}, metadata_params={nn.PARTITION_NAME: 'batch'}, )(name='bar')(xs) m = Foo() variables = m.init(random.key(0), jnp.zeros((8, 3))) self.assertEqual( variables['params']['bar']['kernel'].names, ('batch', 'in', 'out') ) # def test_boxed_variable(self): # def f(scope, xs): # def g(scope, x): # kernel_init = nn.with_partitioning(nn.initializers.ones_init(), # ('in', 'out')) # kernel = scope.variable('params', 'kernel', kernel_init, # scope.make_rng('params'), (x.shape[-1], 2)) # kernel.value += 1. # self.assertEqual(kernel.value.sum(), kernel.value.size * 2) # kernel_box = scope.get_variable('params', 'kernel') # self.assertIsInstance(kernel_box, nn.Partitioned) # self.assertEqual(kernel_box.names, ('in', 'out')) # return x @ kernel.value # nn.vmap( # g, in_axes=0, # variable_axes={'params': 0}, split_rngs={'params': True}, # metadata_params={nn.PARTITION_NAME: 'batch'})(scope, xs) # _, variables = init(f)(random.key(0), jnp.zeros((8, 3))) # self.assertEqual(variables['params']['kernel'].names, # ('batch', 'in', 'out')) def test_pjit_scan_over_layers(self): class MLP(nn.Module): hidden_size: int @nn.compact def __call__(self, x): ki = nn.linear.default_kernel_init h = nn.Dense( self.hidden_size, kernel_init=nn.with_partitioning(ki, ('data', 'model')), )(x) h = nn.relu(h) return nn.Dense( x.shape[-1], kernel_init=nn.with_partitioning(ki, ('model', 'data')) )(h) class Model(nn.Module): @nn.compact def __call__(self, x): def body(_, c): c = MLP(512)(c) return c, () c, _ = nn.scan( body, variable_axes={'params': 0}, split_rngs={'params': 0}, length=8, metadata_params={nn.PARTITION_NAME: None}, )(self, x) return c devs = mesh_utils.create_device_mesh((jax.device_count(), 1)) mesh = Mesh(devs, ['data', 'model']) model = Model() x = jnp.ones((8, 128)) spec = nn.get_partition_spec(jax.eval_shape(model.init, random.key(0), x)) self.assertEqual( spec, { 'params': { 'MLP_0': { 'Dense_0': { 'bias': PartitionSpec(), 'kernel': PartitionSpec(None, 'data', 'model'), }, 'Dense_1': { 'bias': PartitionSpec(), 'kernel': PartitionSpec(None, 'model', 'data'), }, }, }, }, ) x_spec = PartitionSpec('data', 'model') f = lambda x: jax.sharding.NamedSharding(mesh, x) key_spec = PartitionSpec() init_fn = jax.jit( model.init, in_shardings=jax.tree_util.tree_map(f, (key_spec, x_spec)), out_shardings=jax.tree_util.tree_map(f, spec), ) variables = init_fn(random.key(0), x) apply_fn = jax.jit( model.apply, in_shardings=jax.tree_util.tree_map(f, (spec, x_spec)), out_shardings=jax.tree_util.tree_map(f, x_spec), ) y = apply_fn(variables, x) self.assertEqual(y.shape, (8, 128)) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/linen/linen_module_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.linen.""" import contextlib import copy import dataclasses import enum import functools import gc import inspect import operator import sys from tempfile import TemporaryDirectory from typing import ( Any, Generic, NamedTuple, TypeVar, get_type_hints, ) from collections.abc import Callable, Mapping, Sequence from unittest.mock import patch import jax import jax.numpy as jnp import numpy as np from absl.testing import absltest from jax import random from jax.nn import initializers from flax import config, errors, struct from flax import linen as nn from flax.core import FrozenDict, Scope, freeze from flax.linen import compact # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() def tree_equals(x, y): return jax.tree_util.tree_all(jax.tree_util.tree_map(operator.eq, x, y)) @contextlib.contextmanager def set_config(option: str, value: bool): old_value = getattr(config, option) try: config.update(option, value) yield None finally: config.update(option, old_value) class DummyModule(nn.Module): @compact def __call__(self, x): bias = self.param('bias', initializers.ones, x.shape) return x + bias class Dense(nn.Module): features: int @compact def __call__(self, x): kernel = self.param( 'kernel', initializers.lecun_normal(), (x.shape[-1], self.features) ) y = jnp.dot(x, kernel) return y class IdentityModule(nn.Module): def __call__(self, x): return x class RaisesModule(nn.Module): def __call__(self): assert False class ModuleTest(absltest.TestCase): def test_init_module(self): rngkey = jax.random.key(0) x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) y = DummyModule(parent=scope)(x) params = scope.variables()['params'] y2 = DummyModule(parent=scope.rewound())(x) np.testing.assert_allclose(y, y2) np.testing.assert_allclose(y, jnp.array([2.0])) self.assertEqual(params, {'bias': jnp.array([1.0])}) def test_lazy_init(self): class Foo(nn.Module): @compact def __call__(self, x): k = self.param( 'kernel', nn.initializers.lecun_normal(), (x.shape[-1], x.shape[-1]) ) return x @ k # provide a massive input message which would OOM if any compute ops were actually executed variables = Foo().lazy_init( random.key(0), jax.ShapeDtypeStruct((1024 * 1024 * 1024, 128), jnp.float32), ) self.assertEqual(variables['params']['kernel'].shape, (128, 128)) def test_lazy_init_fails_on_data_dependence(self): class Foo(nn.Module): @compact def __call__(self, x): k = self.param('kernel', lambda _: x) return x * k with self.assertRaises(errors.LazyInitError): Foo().lazy_init(random.key(0), jax.ShapeDtypeStruct((8, 4), jnp.float32)) def test_arg_module(self): rngkey = jax.random.key(0) x = jnp.ones((10,)) scope = Scope({}, {'params': rngkey}, mutable=['params']) y = Dense(3, parent=scope)(x) params = scope.variables()['params'] y2 = Dense(3, parent=scope.rewound())(x) np.testing.assert_allclose(y, y2) self.assertEqual(params['kernel'].shape, (10, 3)) def test_util_fun(self): rngkey = jax.random.key(0) class MLP(nn.Module): @compact def __call__(self, x): x = self._mydense(x) x = self._mydense(x) return x def _mydense(self, x): return Dense(3)(x) x = jnp.ones((10,)) scope = Scope({}, {'params': rngkey}, mutable=['params']) y = MLP(parent=scope)(x) params = scope.variables()['params'] y2 = MLP(parent=scope.rewound())(x) np.testing.assert_allclose(y, y2) param_shape = jax.tree_util.tree_map(jnp.shape, params) self.assertEqual( param_shape, {'Dense_0': {'kernel': (10, 3)}, 'Dense_1': {'kernel': (3, 3)}}, ) def test_nested_module_reuse(self): rngkey = jax.random.key(0) class MLP(nn.Module): @compact def __call__(self, x): x = self._mydense(x) x = self._mydense(x) return x def _mydense(self, x): return Dense(3)(x) class Top(nn.Module): @compact def __call__(self, x): mlp = MLP() y = mlp(x) z = mlp(x) return y + z x = jnp.ones((10,)) scope = Scope({}, {'params': rngkey}, mutable=['params']) y = Top(parent=scope)(x) params = scope.variables()['params'] y2 = Top(parent=scope.rewound())(x) np.testing.assert_allclose(y, y2) param_shape = jax.tree_util.tree_map(jnp.shape, params) self.assertEqual( param_shape, { 'MLP_0': { 'Dense_0': {'kernel': (10, 3)}, 'Dense_1': {'kernel': (3, 3)}, } }, ) def test_setup_dict_assignment(self): rngkey = jax.random.key(0) class MLP(nn.Module): def setup(self): self.lyrs1 = { 'a': Dense(3), 'b': Dense(3), } self.lyrs2 = [Dense(3), Dense(3)] def __call__(self, x): y = self.lyrs1['a'](x) z = self.lyrs1['b'](y) return z x = jnp.ones((10,)) scope = Scope({}, {'params': rngkey}, mutable=['params']) y = MLP(parent=scope)(x) params = scope.variables()['params'] y2 = MLP(parent=scope.rewound())(x) np.testing.assert_allclose(y, y2) param_shape = jax.tree_util.tree_map(jnp.shape, params) self.assertEqual( param_shape, {'lyrs1_a': {'kernel': (10, 3)}, 'lyrs1_b': {'kernel': (3, 3)}}, ) def test_setup_dict_nonstring_keys(self): class Foo(nn.Module): def setup(self): self.a = {(1, 2): nn.Dense(2)} # Tuple as key. @nn.compact def __call__(self, x): return self.a[(1, 2)](x) foo = Foo() x = jnp.ones(shape=(1, 3)) params = foo.init(random.key(0), x)['params'] param_shape = jax.tree_util.tree_map(jnp.shape, params) self.assertEqual( param_shape, {'a_(1, 2)': {'kernel': (3, 2), 'bias': (2,)}} ) def test_setup_frozen_dict_nonstring_keys(self): a = {1: 2} class Foo(nn.Module): def setup(self): self.a = FrozenDict(a) # int as key. @nn.compact def __call__(self, x): return self.a[x] foo = Foo() x = 1 params = foo.init(random.key(0), x) assert foo.apply(params, x) == a[x] def test_setup_dict_nonstring_keys_in_state(self): class Foo(nn.Module): a: dict[int, int] # int as key. @nn.compact def __call__(self, x): return self.a[x] a = {1: 2} foo = Foo(a) x = 1 params = foo.init(random.key(0), x) assert foo.apply(params, x) == a[x] def test_setup_cloning(self): class MLP(nn.Module): def setup(self): self.dense = Dense(3) scope = Scope({}) unused_clone = MLP(parent=scope).clone() def test_submodule_attr(self): rngkey = jax.random.key(0) class Inner(nn.Module): @compact def __call__(self): self.param('x', lambda rng: 40) class Outer(nn.Module): inner: nn.Module def __call__(self): return self.inner() class Wrapper(nn.Module): def setup(self): self.inner = Inner() self.outer = Outer(self.inner) def __call__(self): return self.outer() scope = Scope({'params': {}}, rngs={'params': rngkey}, mutable=['params']) # Make sure this doesn't raise "Can't attach to remote parent" wrapper = Wrapper(parent=scope) wrapper() # Make sure that variables are registered at the level of the # Wrapper submodule, not the Outer submodule. self.assertEqual(40, scope.variables()['params']['inner']['x']) def test_param_in_setup(self): rngkey = jax.random.key(0) class DummyModuleWithoutCompact(nn.Module): xshape: tuple[int, ...] def setup(self): self.bias = self.param('bias', initializers.ones, self.xshape) def __call__(self, x): return x + self.bias x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) y = DummyModuleWithoutCompact(x.shape, parent=scope)(x) params = scope.variables()['params'] y2 = DummyModuleWithoutCompact(x.shape, parent=scope.rewound())(x) np.testing.assert_allclose(y, y2) np.testing.assert_allclose(y, jnp.array([2.0])) self.assertEqual(params, {'bias': jnp.array([1.0])}) def test_init_outside_setup_without_compact(self): rngkey = jax.random.key(0) class DummyModuleWithoutCompact(nn.Module): def __call__(self, x): bias = self.param('bias', initializers.ones, x.shape) return x + bias x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) with self.assertRaisesRegex(ValueError, 'must be initialized.*setup'): unused_y = DummyModuleWithoutCompact(parent=scope)(x) def test_init_outside_call(self): rngkey = jax.random.key(0) class Dummy(nn.Module): @compact def __call__(self, x): bias = self.param('bias', initializers.ones, x.shape) return x + bias def foo(self, x): bias = self.param('bias', initializers.ones, x.shape) return x + bias x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) with self.assertRaisesRegex(ValueError, 'must be initialized.*setup'): unused_y = Dummy(parent=scope).foo(x) def test_setup_call_var_collision(self): rngkey = jax.random.key(0) class Dummy(nn.Module): xshape: tuple[int, ...] def setup(self): self.bias = self.param('bias', initializers.ones, self.xshape) @compact def __call__(self, x): unused_bias = self.param('bias', initializers.ones, x.shape) return x + self.bias x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) msg = 'Could not create param "bias" in Module Dummy: Name in use' with self.assertRaisesRegex(errors.NameInUseError, msg): unused_y = Dummy(x.shape, parent=scope)(x) def test_call_var_collision(self): rngkey = jax.random.key(0) class Dummy(nn.Module): xshape: tuple[int, ...] @compact def __call__(self, x): bias = self.param('bias', initializers.ones, self.xshape) bias = self.param('bias', initializers.ones, self.xshape) return x + bias x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) msg = 'Could not create param "bias" in Module Dummy: Name in use' with self.assertRaisesRegex(errors.NameInUseError, msg): unused_y = Dummy(x.shape, parent=scope)(x) def test_setup_var_collision(self): rngkey = jax.random.key(0) class Dummy(nn.Module): xshape: tuple[int, ...] def setup(self): self.bias = self.param('bias', initializers.ones, self.xshape) self.bias = self.param('bias', initializers.ones, self.xshape) def __call__(self, x): return x + self.bias x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) msg = 'Could not create param "bias" in Module Dummy: Name in use' with self.assertRaisesRegex(errors.NameInUseError, msg): unused_y = Dummy(x.shape, parent=scope)(x) def test_setattr_name_var_disagreement_allowed_in_lists(self): rngkey = jax.random.key(0) class Dummy(nn.Module): xshape: tuple[int, ...] def setup(self): self.biases = [ self.param(f'bias_{i}', initializers.ones, self.xshape) for i in range(4) ] def __call__(self, x): return x + self.biases[0] x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) y = Dummy(x.shape, parent=scope)(x) self.assertEqual(y, jnp.array([2.0])) def test_setattr_name_var_disagreement_allowed_in_dicts(self): rngkey = jax.random.key(0) class Dummy(nn.Module): xshape: tuple[int, ...] def setup(self): self.biases = { # NOTE: Keys still must be strings. This is to make a possible # future transition to automatically derived parameter names when # assigned as a dict easier (like we currently have with # submodules). See a bit of discussion here: # https://github.com/google/flax/issues/705#issuecomment-738761853 str(i): self.param(f'bias_{i}', initializers.ones, self.xshape) for i in range(4) } def __call__(self, x): return x + self.biases['0'] x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) y = Dummy(x.shape, parent=scope)(x) self.assertEqual(y, jnp.array([2.0])) def test_submodule_var_collision_with_scope(self): rngkey = jax.random.key(0) class Dummy(nn.Module): xshape: tuple[int, ...] def setup(self): self.bias = self.param('bias', initializers.ones, self.xshape) self.bias = DummyModule() def __call__(self, x): return x + self.bias x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) with self.assertRaises(errors.NameInUseError): unused_y = Dummy(x.shape, parent=scope)(x) def test_submodule_var_collision_with_submodule(self): rngkey = jax.random.key(0) class Dummy(nn.Module): xshape: tuple[int, ...] def setup(self): self.bias = self.param('bias', initializers.ones, self.xshape) @compact def __call__(self, x): unused_bias = DummyModule(name='bias') return x + self.bias x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) msg = 'Could not create submodule "bias" in Module Dummy: Name in use' with self.assertRaisesRegex(errors.NameInUseError, msg): unused_y = Dummy(x.shape, parent=scope)(x) def test_submodule_var_collision_with_params(self): rngkey = jax.random.key(0) class Dummy(nn.Module): xshape: tuple[int, ...] def setup(self): self.bias = DummyModule() @compact def __call__(self, x): unused_bias = self.param('bias', initializers.ones, self.xshape) return x + self.bias x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) msg = 'Could not create param "bias" in Module Dummy: Name in use' with self.assertRaisesRegex(errors.NameInUseError, msg): unused_y = Dummy(x.shape, parent=scope)(x) def test_attr_empty_container(self): class Foo(nn.Module): bar: Mapping[str, Any] @compact def __call__(self): pass Foo({'a': ()}).apply({}) def test_multiple_compact_methods(self): """Test that multiple methods with the @compact decorator can be used. NOTE: in the near future we might want to have compact methods reset the autoname_cursor such that Dense would be reused in the second method. """ class MultipleCompactMethods(nn.Module): @compact def __call__(self, x): x = nn.Dense(1)(x) return self.method(x) @compact def method(self, x): x = nn.Dense(1)(x) return x m = MultipleCompactMethods() variables = m.init(random.key(0), jnp.ones((1, 1))) params = variables['params'] self.assertIn('Dense_0', params) self.assertIn('Dense_1', params) def test_only_one_compact_method_subclass(self): class Dummy(nn.Module): @nn.compact def __call__(self): pass class SubDummy(Dummy): @nn.compact def __call__(self): super().__call__() scope = Scope(variables={}) subdummy = SubDummy(parent=scope) # Make sure the @compact annotation is valid on both base class and # subclass, as long as its on the same method. subdummy() def test_forgotten_compact_annotation(self): class Bar(nn.Module): # user forgot to add @compact def __call__(self, x): return nn.Dense(1)(x) class Foo(nn.Module): @nn.compact def __call__(self, x): bar = Bar() x = bar(x) x = bar(x) return x msg = ( r'Submodule Dense must be defined in `setup\(\)` or in a method ' 'wrapped in `@compact`' ) with self.assertRaisesRegex(errors.AssignSubModuleError, msg): Foo().init(random.key(0), jnp.ones((1, 3))) def test_forgotten_compact_annotation_with_explicit_parent(self): class Bar(nn.Module): def __call__(self, x): return nn.Dense(1, parent=self)(x) class Foo(nn.Module): @nn.compact def __call__(self, x): bar = Bar() x = bar(x) x = bar(x) return x msg = ( r'Submodule Dense must be defined in `setup\(\)` or in a method ' 'wrapped in `@compact`' ) with self.assertRaisesRegex(errors.AssignSubModuleError, msg): Foo().init(random.key(0), jnp.ones((1, 3))) def test_numpy_array_shape_class_args(self): class MLP(nn.Module): widths: Sequence[int] @nn.compact def __call__(self, x): for width in self.widths[:-1]: x = nn.relu(nn.Dense(width)(x)) return nn.Dense(self.widths[-1])(x) test = MLP(np.array([3, 3], np.int32)) params = test.init({'params': random.key(42)}, jnp.ones((3, 3))) _ = test.apply(params, jnp.ones((3, 3))) def test_get_local_methods(self): class Base: @staticmethod def bar(x): return x @classmethod def baz(cls, x): return x def bleep(self, x): return x class Derived1(Base): @staticmethod def bar2(x): return x @classmethod def baz2(cls, x): return x def bloop(self, x): return x class Derived2(Derived1): pass self.assertEqual(nn.module._get_local_method_names(Base), ('bleep',)) self.assertEqual(nn.module._get_local_method_names(Derived1), ('bloop',)) self.assertEqual( nn.module._get_local_method_names(Derived1, exclude=('bloop',)), () ) self.assertEqual(nn.module._get_local_method_names(Derived2), ()) def test_inheritance_dataclass_attribs(self): class Test(nn.Module): bar: int def __call__(self, x): return x class Test2(Test): baz: int def __call__(self, x): return x class Test3(Test): baz: int def __call__(self, x): return x class Test4(Test2): def __call__(self, x): return x key = random.key(0) x = jnp.ones((5,)) test1 = Test(bar=4) test2 = Test2(bar=4, baz=2) test3 = Test3(bar=4, baz=2) test4 = Test4(bar=5, baz=3) self.assertEqual(test1.init_with_output(key, x), (x, freeze({}))) self.assertEqual(test2.init_with_output(key, x), (x, freeze({}))) self.assertEqual(test3.init_with_output(key, x), (x, freeze({}))) self.assertEqual(test4.init_with_output(key, x), (x, freeze({}))) self.assertTrue(hasattr(test1, 'bar')) self.assertTrue(hasattr(test1, 'name')) self.assertTrue(hasattr(test1, 'parent')) self.assertTrue(hasattr(test2, 'bar')) self.assertTrue(hasattr(test2, 'baz')) self.assertTrue(hasattr(test2, 'name')) self.assertTrue(hasattr(test2, 'parent')) self.assertTrue(hasattr(test3, 'bar')) self.assertTrue(hasattr(test3, 'baz')) self.assertTrue(hasattr(test3, 'name')) self.assertTrue(hasattr(test3, 'parent')) self.assertTrue(hasattr(test4, 'bar')) self.assertTrue(hasattr(test4, 'baz')) self.assertTrue(hasattr(test4, 'name')) self.assertTrue(hasattr(test4, 'parent')) self.assertEqual( list(Test.__dataclass_fields__.keys()), ['bar', 'parent', 'name'] ) self.assertEqual( list(Test2.__dataclass_fields__.keys()), ['bar', 'baz', 'parent', 'name'], ) self.assertEqual( list(Test3.__dataclass_fields__.keys()), ['bar', 'baz', 'parent', 'name'], ) self.assertEqual( list(Test4.__dataclass_fields__.keys()), ['bar', 'baz', 'parent', 'name'], ) def test_get_suffix_value_pairs(self): for x in [(), [], {}, None, 0, set()]: self.assertEqual(nn.module._get_suffix_value_pairs(x), [('', x)]) self.assertEqual( nn.module._get_suffix_value_pairs({'a': 1, 'b': 2}), [('_a', 1), ('_b', 2)], ) self.assertEqual( nn.module._get_suffix_value_pairs([1, 2, 3]), [('_0', 1), ('_1', 2), ('_2', 3)], ) x1 = [nn.Dense(10), nn.relu, nn.Dense(10)] y1 = nn.module._get_suffix_value_pairs(x1) self.assertEqual(y1, [('_0', x1[0]), ('_1', x1[1]), ('_2', x1[2])]) x2 = {'a': 1, 'b': {'c': nn.Dense(10), 'd': nn.relu}} y2 = nn.module._get_suffix_value_pairs(x2) self.assertEqual( y2, [('_a', 1), ('_b_c', x2['b']['c']), ('_b_d', x2['b']['d'])] ) def test_mixed_list_assignment_in_setup(self): class Test(nn.Module): def setup(self): self.layers = [nn.Dense(10), nn.relu, nn.Dense(10)] def __call__(self, x): for lyr in self.layers: x = lyr(x) return x x = random.uniform(random.key(0), (5, 5)) variables = Test().init(random.key(0), jnp.ones((5, 5))) y = Test().apply(variables, x) m0 = variables['params']['layers_0']['kernel'] m1 = variables['params']['layers_2']['kernel'] self.assertTrue(jnp.all(y == jnp.dot(nn.relu(jnp.dot(x, m0)), m1))) def test_module_is_hashable(self): module_a = nn.Dense(10) module_a_2 = nn.Dense(10) module_b = nn.Dense(5) self.assertEqual(hash(module_a), hash(module_a_2)) self.assertNotEqual(hash(module_a), hash(module_b)) def test_module_custom_hash(self): class Test(nn.Module): x: int = 3 y: int = 5 def __hash__(self): return 42 + self.x module_a = Test(1, 2) module_a_2 = Test(1, 5) module_b = Test(2, 2) self.assertEqual(hash(module_a), hash(module_a_2)) self.assertNotEqual(hash(module_a), hash(module_b)) def test_module_with_scope_is_not_hashable(self): module_a = nn.Dense(10, parent=Scope({})) msg = "Can't call __hash__ on modules that hold variables." with self.assertRaisesWithLiteralMatch(TypeError, msg): hash(module_a) def test_module_trace(self): class MLP(nn.Module): act: Callable = nn.relu sizes: Sequence[int] = (3, 2) @nn.compact def __call__(self, x): for size in self.sizes: x = nn.Dense(size)(x) x = self.act(x) return repr(self) mlp = MLP() expected_trace = """MLP( # attributes act = relu sizes = (3, 2) # children Dense_0 = Dense( # attributes features = 3 use_bias = True dtype = None param_dtype = float32 precision = None kernel_init = init bias_init = zeros promote_dtype = promote_dtype dot_general = None dot_general_cls = None ) Dense_1 = Dense( # attributes features = 2 use_bias = True dtype = None param_dtype = float32 precision = None kernel_init = init bias_init = zeros promote_dtype = promote_dtype dot_general = None dot_general_cls = None ) )""" x = jnp.ones((1, 2)) trace, variables = mlp.init_with_output(random.key(0), x) self.assertEqual(trace, expected_trace) trace = mlp.apply(variables, x) self.assertEqual(trace, expected_trace) def test_default_params_rng_equivalence(self): class Model(nn.Module): @nn.compact def __call__(self, x, add_dropout=False, add_noise=False): x = nn.Dense(16)(x) x = nn.Dropout(0.5)(x, deterministic=not add_dropout) if add_noise: x += jax.random.normal(self.make_rng('params')) return x model = Model() key0, key1, key2 = jax.random.split(jax.random.key(0), 3) x = jax.random.normal(key0, (10, 8)) with self.assertRaisesRegex( ValueError, 'First argument passed to an init function should be a ``jax.PRNGKey``', ): model.init({'params': 'test'}, x) with self.assertRaisesRegex( errors.InvalidRngError, 'RNGs should be of shape \\(2,\\) or PRNGKey in Module Model, but rngs are: test', ): model.init('test', x) # should not throw an error, since nn.Dropout will get an RNG key from the 'params' stream model.init(key1, x, add_dropout=True) v = model.init({'params': key1}, x) v2 = model.init(key1, x) jax.tree_util.tree_map(np.testing.assert_allclose, v, v2) for add_dropout, add_noise in [[True, False], [False, True], [True, True]]: out = model.apply( v, x, add_dropout=add_dropout, add_noise=add_noise, rngs={'params': key2}, ) out2 = model.apply( v, x, add_dropout=add_dropout, add_noise=add_noise, rngs=key2 ) np.testing.assert_allclose(out, out2) with self.assertRaisesRegex( ValueError, 'The ``rngs`` argument passed to an apply function should be a ``jax.PRNGKey``', ): model.apply(v, x, rngs={'params': 'test'}) with self.assertRaisesRegex( errors.InvalidRngError, 'RNGs should be of shape \\(2,\\) or PRNGKey in Module Model, but rngs are: test', ): model.apply(v, x, rngs='test') def test_module_apply_method(self): class Foo(nn.Module): not_callable: int = 1 @nn.compact def __call__(self): pass def test(self): pass # We can use both instance and class methods in apply. Foo().apply({}, method=Foo.test) Foo().apply({}, method=Foo().test) # We also use a function that is not in the provided Module, although it # should have a first argument representing an instance of the Module (Foo # in this case). x = Foo().apply({}, method=lambda foo_instance: foo_instance) self.assertEqual(type(x), type(Foo())) # This is not allowed. msg = 'Cannot call apply()' with self.assertRaisesRegex(errors.ApplyModuleInvalidMethodError, msg): Foo().apply({}, method=lambda: True) # string method names are also allowed. Foo().apply({}, method='test') # test same for init. Foo().init({}, method='test') # non-existent attribute names will yield AttributeError. with self.assertRaisesRegex(AttributeError, 'allowed_apply_fn'): Foo().apply({}, method='allowed_apply_fn') # test same for init. Foo().init({}, method='allowed_apply_fn') # attributes which are not callables yield TypeError. with self.assertRaisesRegex( TypeError, "'Foo.not_callable' must be a callable" ): Foo().apply({}, method='not_callable') # test same for init. Foo().init({}, method='not_callable') def test_module_apply_method_submodule(self): class Foo(nn.Module): bar: nn.Module @nn.compact def __call__(self, x): return self.bar(x) foo = Foo(nn.Dense(3)) variables = foo.init(jax.random.PRNGKey(0), jnp.zeros(3)) foo.apply(variables, jnp.zeros(3), method='bar') def test_call_unbound_compact_module_methods(self): dense = Dense(3) msg = r'Can\'t call compact methods on unbound modules' with self.assertRaisesRegex(errors.CallCompactUnboundModuleError, msg): dense(jnp.ones((1,))) def test_call_unbound_has_variable(self): class EmptyModule(nn.Module): def foo(self): self.has_variable('bar', 'baz') empty = EmptyModule() with self.assertRaisesRegex(ValueError, 'variable.*unbound module'): empty.foo() def test_call_unbound_make_rng(self): class EmptyModule(nn.Module): def foo(self): self.make_rng('bar') empty = EmptyModule() with self.assertRaisesRegex(ValueError, 'RNGs.*unbound module'): empty.foo() def test_call_unbound_variables(self): class EmptyModule(nn.Module): def foo(self): self.variables empty = EmptyModule() with self.assertRaisesRegex(ValueError, 'variables.*unbound module'): empty.foo() def test_call_unbound_noncompact_module_methods(self): class EmptyModule(nn.Module): foo: int = 3 def bar(self): return self.foo empty = EmptyModule() # It's fine to call methods of unbound methods that don't depend on # attributes defined during `setup`. self.assertEqual(empty.bar(), 3) def test_call_unbound_noncompact_module_methods_depending_on_setup(self): class EmptyModule(nn.Module): def setup(self): self.foo = 2 def bar(self): return self.foo empty = EmptyModule() msg = r'"EmptyModule" object has no attribute "foo"' with self.assertRaisesRegex(AttributeError, msg): empty.bar() def test_module_with_attrs(self): class Foo(nn.Module): bar: nn.Dense = dataclasses.field(init=False) def setup(self): self.bar = nn.Dense(3) def __call__(self, x): return self.bar(x) foo = Foo() x = jnp.ones((2,)) variables = foo.init(random.key(0), x) self.assertEqual(variables['params']['bar']['kernel'].shape, (2, 3)) def test_noncompact_module_frozen(self): class Foo(nn.Module): def setup(self): self.i = 1 # This is allowed (for assigning submodules). def __call__(self): self.i = 2 # This is not allowed. msg = ( "Can't set i=2 for Module of type Foo: Module instance is frozen " 'outside of setup method.' ) with self.assertRaisesRegex(errors.SetAttributeFrozenModuleError, msg): Foo().init(random.key(0)) def test_compact_module_frozen(self): class Foo(nn.Module): @nn.compact def __call__(self): self.i = 2 msg = ( "Can't set i=2 for Module of type Foo: Module instance is frozen " 'outside of setup method.' ) with self.assertRaisesRegex(errors.SetAttributeFrozenModuleError, msg): Foo().init(random.key(0)) def test_submodule_frozen(self): class Foo(nn.Module): @nn.compact def __call__(self): dense = nn.Dense(10) dense.features = 20 # <--- This is not allowed msg = ( "Can't set features=20 for Module of type Dense: Module instance " 'is frozen outside of setup method.' ) with self.assertRaisesRegex(errors.SetAttributeFrozenModuleError, msg): Foo().init(random.key(0)) def test_module_call_not_implemented(self): class Foo(nn.Module): pass msg = '"Foo" object has no attribute "__call__"' with self.assertRaisesRegex(AttributeError, msg): Foo().init(random.key(0)) def test_is_mutable_collection(self): class EmptyModule(nn.Module): def __call__(self): return self.is_mutable_collection('test') empty = EmptyModule() self.assertTrue(empty.apply({}, mutable=['test'])[0]) self.assertFalse(empty.apply({}, mutable=False)) def test_module_lazy_getattr_setup(self): class A(nn.Module): def setup(self): self.d = nn.Dense(2) def __call__(self, x): return self.d(x) class B(nn.Module): def setup(self): self.a = A() def __call__(self, x): y1 = self.a.d(x) y2 = self.a(x) return y1, y2 key = random.key(0) x = jnp.ones((2,)) (y1, y2), unused_vars = B().init_with_output(key, x) np.testing.assert_array_equal(y1, y2) def test_module_lazy_dir_setup(self): class A(nn.Module): def setup(self): self.d = nn.Dense(2) def __call__(self, x): return self.d(x) class B(nn.Module): def setup(self): self.a = A() def __call__(self, x): assert 'd' in dir(self.a) y1 = self.a.d(x) y2 = self.a(x) return y1, y2 key = random.key(0) x = jnp.ones((2,)) _ = B().init_with_output(key, x) def test_module_unbound_getattr(self): class A(nn.Module): def setup(self): b = B() b.c # B is unbound because it is not yet assigned to an attribute. self.b = b def __call__(self): pass class B(nn.Module): def setup(self): self.c = nn.Dense(2) msg = '"B" object has no attribute "c"' with self.assertRaisesRegex(AttributeError, msg): A().init(random.key(0)) def test_unbound_setup_call(self): setup_called = False class A(nn.Module): def setup(self): nonlocal setup_called setup_called = True def test(self): pass A().test() self.assertFalse(setup_called) def test_module_pass_as_attr(self): class A(nn.Module): def setup(self): self.b = B(nn.Dense(2)) def __call__(self, x): return self.b(x) class B(nn.Module): foo: Any def __call__(self, x): return self.foo(x) variables = A().init(random.key(0), jnp.ones((1,))) var_shapes = jax.tree_util.tree_map(jnp.shape, variables) ref_var_shapes = { 'params': { 'b': { 'foo': { 'bias': (2,), 'kernel': (1, 2), } }, }, } self.assertTrue(tree_equals(var_shapes, ref_var_shapes)) def test_module_pass_in_closure(self): a = nn.Dense(2) class B(nn.Module): def setup(self): self.foo = a def __call__(self, x): return self.foo(x) variables = B().init(random.key(0), jnp.ones((1,))) var_shapes = jax.tree_util.tree_map(jnp.shape, variables) ref_var_shapes = { 'params': { 'foo': { 'bias': (2,), 'kernel': (1, 2), } }, } self.assertTrue(tree_equals(var_shapes, ref_var_shapes)) self.assertIsNone(a.name) def test_toplevel_submodule_adoption(self): class Encoder(nn.Module): n_layers: int ch: int def setup(self): self.layers = [nn.Dense(self.ch) for _ in range(self.n_layers)] def __call__(self, x): for layer in self.layers: x = layer(x) x = nn.relu(x) return x class Model(nn.Module): encoder: nn.Module n_out: int def setup(self): self.dense_out = nn.Dense(self.n_out) def __call__(self, x): x = self.encoder(x) return self.dense_out(x) # Define model. encoder = Encoder(n_layers=1, ch=8) model = Model(encoder=encoder, n_out=5) # Initialize. key = jax.random.key(0) x = random.uniform(key, (4, 4)) variables = model.init(key, x) y = model.apply(variables, x) self.assertEqual(y.shape, (4, 5)) var_shapes = jax.tree_util.tree_map(jnp.shape, variables) ref_var_shapes = { 'params': { 'dense_out': { 'bias': (5,), 'kernel': (8, 5), }, 'encoder': { 'layers_0': { 'bias': (8,), 'kernel': (4, 8), }, }, }, } self.assertTrue(tree_equals(var_shapes, ref_var_shapes)) def test_toplevel_submodule_adoption_pytree(self): class A(nn.Module): @nn.compact def __call__(self, c, x): counter = self.variable('counter', 'i', jnp.zeros, ()) counter.value += 1 x = nn.Dense(1)(x) return c, x class B(nn.Module): A: Any @nn.compact def __call__(self, c, x): return self.A['foo'](*self.A['bar'](c, x)) unused_a = A() a_pytree = {'foo': A(), 'bar': A()} b = B(a_pytree) key = random.key(0) x = jnp.ones((2, 2)) params = B(a_pytree).init(key, x, x) unused_y, counters = b.apply(params, x, x, mutable='counter') ref_counters = { 'counter': { 'A_bar': { 'i': jnp.array(2.0), }, 'A_foo': { 'i': jnp.array(2.0), }, }, } self.assertTrue( jax.tree_util.tree_all( jax.tree_util.tree_map( lambda x, y: np.testing.assert_allclose(x, y, atol=1e-7), counters, ref_counters, ) ) ) def test_toplevel_submodule_adoption_sharing(self): dense = functools.partial(nn.Dense, use_bias=False) class A(nn.Module): @nn.compact def __call__(self, x): return dense(2)(x) class B(nn.Module): a: nn.Module @nn.compact def __call__(self, x): return dense(2)(x) + self.a(x) class C(nn.Module): a: nn.Module b: nn.Module @nn.compact def __call__(self, x): return dense(2)(x) + self.b(x) + self.a(x) key = random.key(0) x = jnp.ones((2, 2)) a = A() b = B(a) c = C(a, b) p = c.init(key, x) var_shapes = jax.tree_util.tree_map(jnp.shape, p) ref_var_shapes = { 'params': { 'Dense_0': { 'kernel': (2, 2), }, 'a': { 'Dense_0': { 'kernel': (2, 2), }, }, 'b': { 'Dense_0': { 'kernel': (2, 2), }, }, }, } self.assertTrue(tree_equals(var_shapes, ref_var_shapes)) def test_toplevel_named_submodule_adoption(self): dense = functools.partial(nn.Dense, use_bias=False) class A(nn.Module): def setup(self): self.dense = dense(4) def __call__(self, x): return self.dense(x) class B(nn.Module): a: A def setup(self): self.proj = dense(6) def __call__(self, x): return self.proj(self.a(x)) a = A(name='foo') b = B(a=a) k = jax.random.key(0) x = jnp.zeros((5, 5)) init_vars = b.init(k, x) var_shapes = jax.tree_util.tree_map(jnp.shape, init_vars) if config.flax_preserve_adopted_names: ref_var_shapes = { 'params': { 'foo': { 'dense': { 'kernel': (5, 4), }, }, 'proj': { 'kernel': (4, 6), }, }, } else: ref_var_shapes = { 'params': { 'a': { 'dense': { 'kernel': (5, 4), }, }, 'proj': { 'kernel': (4, 6), }, }, } self.assertTrue(tree_equals(var_shapes, ref_var_shapes)) def test_toplevel_submodule_pytree_adoption_sharing(self): class A(nn.Module): @nn.compact def __call__(self, x): counter = self.variable('counter', 'i', jnp.zeros, ()) counter.value += 1 x = nn.Dense(1)(x) return x class B(nn.Module): A: Any @nn.compact def __call__(self, x): return self.A['foo'](x) + self.A['bar'](x) + self.A['baz'](x) key = random.key(0) x = jnp.ones((2, 2)) a = A() a_pytree = {'foo': a, 'bar': a, 'baz': a} b = B(a_pytree) params = b.init(key, x) _, counters = b.apply(params, x, mutable='counter') ref_counters = { 'counter': { 'A_bar': { 'i': jnp.array(6.0), }, }, } self.assertTrue(tree_equals(counters, ref_counters)) def test_inner_class_def(self): class X(nn.Module): class Hyper(struct.PyTreeNode): a: int hyper: Hyper @nn.compact def __call__(self, x): return x + 1 self.assertIsInstance(X.Hyper(a=1), X.Hyper) def test_sow(self): class Foo(nn.Module): @nn.compact def __call__(self, x, **sow_args): self.sow('intermediates', 'h', x, **sow_args) self.sow('intermediates', 'h', 2 * x, **sow_args) return 3 * x variables = Foo().init(random.key(0), 1) # During init we should not collect intermediates by default... self.assertNotIn('intermediates', variables) # ...unless we override mutable. variables = Foo().init(random.key(0), 1, mutable=True) self.assertEqual(variables, {'intermediates': {'h': (1, 2)}}) _, state = Foo().apply({}, 1, mutable=['intermediates']) self.assertEqual(state, {'intermediates': {'h': (1, 2)}}) _, state = Foo().apply( {}, 1, init_fn=lambda: 0, reduce_fn=lambda a, b: a + b, mutable=['intermediates'], ) self.assertEqual(state, {'intermediates': {'h': 3}}) self.assertEqual(Foo().apply({}, 1), 3) def test_capture_intermediates(self): class Bar(nn.Module): def test(self, x): return x + 1 class Foo(nn.Module): @nn.compact def __call__(self, x): return Bar().test(x) + 1 _, state = Foo().apply({}, 1, capture_intermediates=True) self.assertEqual(state, {'intermediates': {'__call__': (3,)}}) fn = lambda mdl, _: isinstance(mdl, Bar) _, state = Foo().apply({}, 1, capture_intermediates=fn) self.assertEqual(state, {'intermediates': {'Bar_0': {'test': (2,)}}}) def test_perturb(self): class Foo(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense(10)(x) x = self.perturb('before_multiply', x) x = 4 * x x = self.perturb('after_multiply', x) return x def loss(params, perturbations, inputs, targets): variables = {'params': params, 'perturbations': perturbations} preds = Foo().apply(variables, inputs) return jnp.square(preds - targets).mean() x = jax.random.uniform(jax.random.key(1), shape=(10,)) y = jax.random.uniform(jax.random.key(2), shape=(10,)) variables = Foo().init(jax.random.key(0), x) intm_grads = jax.grad(loss, argnums=1)( variables['params'], variables['perturbations'], x, y ) # activation * 4 so reverse gradient also * 4 self.assertTrue( all(intm_grads['after_multiply'] * 4 == intm_grads['before_multiply']) ) def test_perturb_setup(self): class Foo(nn.Module): def setup(self): self.a = nn.Dense(10) def __call__(self, x): x = self.a(x) x = self.perturb('before_multiply', x) x = 4 * x x = self.perturb('after_multiply', x) return x def loss(params, perturbations, inputs, targets): variables = {'params': params, 'perturbations': perturbations} preds = Foo().apply(variables, inputs) return jnp.square(preds - targets).mean() x = jax.random.uniform(jax.random.key(1), shape=(10,)) y = jax.random.uniform(jax.random.key(2), shape=(10,)) variables = Foo().init(jax.random.key(0), x) intm_grads = jax.grad(loss, argnums=1)( variables['params'], variables['perturbations'], x, y ) # activation * 4 so reverse gradient also * 4 self.assertTrue( all(intm_grads['after_multiply'] * 4 == intm_grads['before_multiply']) ) def test_perturb_noop(self): class Foo(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense(10)(x) x = self.perturb('before_multiply', x) x = 4 * x x = self.perturb('after_multiply', x) return x x = jax.random.uniform(jax.random.key(1), shape=(10,)) module = Foo() variables = module.init(jax.random.key(0), x) params = variables['params'] perturbations = variables['perturbations'] # check no error if perturbations is not passed module.apply({'params': params}, x) # check errors if perturbations is passed but empty with self.assertRaisesRegex(ValueError, 'Perturbation collection'): module.apply({'params': params, 'perturbations': {}}, x) # check no error if perturbations is passed and not empty module.apply({'params': params, 'perturbations': perturbations}, x) def test_functional_apply(self): class Foo(nn.Module): def setup(self): self.a = nn.Dense(3) self.b = nn.Dense(1) def f(foo, x): x = foo.a(x) return foo.b(x) foo = Foo() x = jnp.ones((4,)) f_init = nn.init_with_output(f, foo) f_apply = nn.apply(f, foo) y1, variables = f_init(random.key(0), x) y2 = f_apply(variables, x) self.assertEqual(y1, y2) def test_bind(self): class Foo(nn.Module): def setup(self): self.a = nn.Dense(3) self.b = nn.Dense(1) def f(foo, x): x = foo.a(x) return foo.b(x) foo = Foo() x = jnp.ones((4,)) f_init = nn.init_with_output(f, foo) y1, variables = f_init(random.key(0), x) y2 = f(foo.bind(variables), x) self.assertEqual(y1, y2) def test_bind_stateful(self): class Foo(nn.Module): def setup(self): self.a = nn.Dense(3) self.bn = nn.BatchNorm() self.b = nn.Dense(1) def f(foo, x): x = foo.a(x) x = foo.bn(x, use_running_average=False) return foo.b(x) foo = Foo() x = jnp.ones((4,)) f_init = nn.init_with_output(f, foo) y1, variables = f_init(random.key(0), x) foo_b = foo.bind(variables, mutable='batch_stats') y2 = f(foo_b, x) y3, new_state = nn.apply(f, foo, mutable='batch_stats')(variables, x) self.assertEqual(y1, y2) self.assertEqual(y2, y3) bs_1 = new_state['batch_stats'] bs_2 = foo_b.variables['batch_stats'] for x, y in zip( jax.tree_util.tree_leaves(bs_1), jax.tree_util.tree_leaves(bs_2) ): np.testing.assert_allclose(x, y) def test_unbind(self): class Foo(nn.Module): def setup(self): self.encoder = nn.Dense(4) self.decoder = nn.Dense(2) def __call__(self, x): x = self.encoder(x) return self.decoder(x) foo = Foo() x = jnp.ones((2,)) variables = foo.init(random.key(0), x) encoder, encoder_vars = foo.bind(variables).encoder.unbind() decoder, decoder_vars = foo.bind(variables).decoder.unbind() self.assertIsInstance(encoder, nn.Dense) self.assertEqual(encoder.features, 4) self.assertIsInstance(decoder, nn.Dense) self.assertEqual(decoder.features, 2) self.assertTrue( jax.tree_util.tree_all( jax.tree_util.tree_map( lambda v1, v2: (v1 == v2).all(), variables['params']['encoder'], encoder_vars['params'], ) ) ) self.assertTrue( jax.tree_util.tree_all( jax.tree_util.tree_map( lambda v1, v2: (v1 == v2).all(), variables['params']['decoder'], decoder_vars['params'], ) ) ) def test_bind_unbind_equality(self): class Foo(nn.Module): sub_module: Any @nn.compact def __call__(self, x): x = nn.Dense(2)(x) return self.sub_module(x) sub_module = Foo(nn.Dense(3)) module = Foo(sub_module) x = jnp.ones((1, 2)) variables = module.init(jax.random.PRNGKey(0), x) bound_module = module.bind(variables) self.assertTrue((module.apply(variables, x) == bound_module(x)).all()) new_module, new_variables = bound_module.unbind() self.assertTrue( jax.tree_util.tree_all( jax.tree_util.tree_map( lambda v1, v2: (v1 == v2).all(), variables, new_variables ) ) ) self.assertEqual(module, new_module) def test_passing_mutable_variables(self): class Foo(nn.Module): @nn.compact def __call__(self, x): return nn.Dense(2)(x) x = jnp.ones((3,)) variables = Foo().init(random.key(0), x) y = Foo().apply(variables, x) self.assertEqual(y.shape, (2,)) def test_super_compact(self): class Foo(nn.Module): @nn.compact def __call__(self, x): return nn.Dense(4)(x) class Bar(Foo): @nn.compact def __call__(self, x): y = super().__call__(x) return nn.Dense(3)(y) k = random.key(0) x = jnp.ones((4, 7)) variables = Bar().init(k, x) shapes = jax.tree_util.tree_map(np.shape, variables['params']) self.assertEqual( shapes, { 'Dense_0': {'kernel': (7, 4), 'bias': (4,)}, 'Dense_1': {'kernel': (4, 3), 'bias': (3,)}, }, ) y = Bar().apply(variables, x) self.assertEqual(y.shape, (4, 3)) def test_super_setup(self): class Foo(nn.Module): def setup(self): self.a = nn.Dense(4) class Bar(Foo): def setup(self): super().setup() self.b = nn.Dense(3) def __call__(self, x): y = self.a(x) return self.b(y) k = random.key(0) x = jnp.ones((4, 7)) variables = Bar().init(k, x) y = Bar().apply(variables, x) self.assertEqual(y.shape, (4, 3)) def test_freeze_attr(self): class Foo(NamedTuple): a: int b: int self.assertEqual(nn.module._freeze_attr([1, 2]), (1, 2)) xs = nn.module._freeze_attr(Foo(1, 2)) self.assertEqual(xs, (1, 2)) self.assertEqual( type(xs), Foo ) # equality test for NamedTuple doesn't check class! def test_generic_multiple_inheritance(self): T = TypeVar('T') class MyComponent(nn.Module, Generic[T]): pass class MyModule(nn.Module): submodule: MyComponent[jnp.ndarray] class MyComponent2(Generic[T], nn.Module): pass class MyModule2(nn.Module): submodule: MyComponent2[jnp.ndarray] def test_jit_rng_equivalance(self): model = nn.fold_rngs(nn.Dense)(1, use_bias=False) jit_model = nn.jit(nn.Dense)(1, use_bias=False) param = model.init(random.key(0), np.ones((1, 1)))['params']['kernel'] param_2 = jit_model.init(random.key(0), np.ones((1, 1)))['params']['kernel'] self.assertEqual(param, param_2) def test_rng_reuse_after_rewind(self): class C(nn.Module): @nn.compact def __call__(self): # Some module that has dropouts in it, in general, # it does more than just dropout! return self.make_rng('dropout') class A(nn.Module): @nn.compact def __call__(self): # Some module that has dropouts in it, in general, # it does more than just dropout! return C()() class B(nn.Module): @nn.compact def __call__(self): a = A() x0 = a() x1 = a() return jnp.all(x0 == x1) k = random.key(0) rng_equals = B().apply({}, rngs={'dropout': k}) self.assertFalse(rng_equals) def test_module_get_put_has_variable(self): class A(nn.Module): @nn.compact def __call__(self, x): self.put_variable('test_col', 'a', x) assert self.has_variable('test_col', 'a') return self.get_variable('test_col', 'a') class B(nn.Module): def __call__(self, x): self.put_variable('test_col', 'a', x) assert self.has_variable('test_col', 'a') return self.get_variable('test_col', 'a') class C(nn.Module): def setup(self): self.put_variable( 'test_col', 'a', jnp.ones( 2, ), ) assert self.has_variable('test_col', 'a') def __call__(self): return self.get_variable('test_col', 'a') x = jnp.ones((2,)) y, vs = A().apply({}, x, mutable=['test_col']) np.testing.assert_array_equal(x, y) np.testing.assert_array_equal(x, vs['test_col']['a']) y, vs = B().apply({}, x, mutable=['test_col']) np.testing.assert_array_equal(x, y) np.testing.assert_array_equal(x, vs['test_col']['a']) y, vs = C().apply({}, mutable=['test_col']) np.testing.assert_array_equal(y, jnp.ones((2,))) np.testing.assert_array_equal(y, vs['test_col']['a']) def test_generic_module(self): # See https://github.com/google/flax/issues/1899 T = TypeVar('T') class C(nn.Module, Generic[T]): def f(self, t: T) -> T: return t class D(nn.Module): def setup(self): unused_c = C[Any]() def __call__(self) -> None: pass rngs = {} D().init(rngs) def test_modifying_attribs_in_post_init(self): class Foo(nn.Module): love: int = 99 def __post_init__(self): self.hate = 100 - self.love super().__post_init__() foo = Foo() self.assertEqual(foo.love, 99) self.assertEqual(foo.hate, 1) class Bar(nn.Module): love: int = 99 def __post_init__(self): self.love = 101 super().__post_init__() bar = Bar() self.assertEqual(bar.love, 101) def test_has_rng(self): class Foo(nn.Module): def __call__(self): return self.has_rng('bar') foo = Foo() with self.assertRaisesRegex(ValueError, 'RNGs.*unbound module'): foo() k = random.key(0) self.assertTrue(foo.apply({}, rngs={'bar': k})) self.assertFalse(foo.apply({}, rngs={'baz': k})) def test_is_initializing(self): class Foo(nn.Module): def __call__(self): return self.is_initializing() foo = Foo() k = random.key(0) self.assertTrue(foo.init_with_output(k)[0]) self.assertFalse(foo.apply({})) def test_throws_invalid_instance_module_error(self): class B(nn.Module): @nn.compact def __call__(self, x): return x k = random.key(0) x = random.uniform(random.key(1), (2,)) with self.assertRaises(errors.InvalidInstanceModuleError): B.init(k, x) # B is module class, not B() a module instance with self.assertRaises(errors.InvalidInstanceModuleError): B.init_with_output(k, x) with self.assertRaises(errors.InvalidInstanceModuleError): B.apply( {}, x ) # similar issue w. apply called on class instead of instance. with self.assertRaises(errors.InvalidInstanceModuleError): B.bind( {}, x ) # similar issue w. apply called on class instead of instance. def test_throws_incorrect_post_init_override_error(self): class A(nn.Module): x: float def __post_init__(self): self.x_square = self.x**2 @nn.compact def __call__(self, input): return input + 3 r = A(x=3) with self.assertRaises(errors.IncorrectPostInitOverrideError): r.init(jax.random.key(2), jnp.ones(3)) def test_deepcopy_unspecified_parent(self): parent_parameter = inspect.signature(DummyModule).parameters['parent'] unspecified_parent = parent_parameter.default self.assertIs(unspecified_parent, copy.copy(unspecified_parent)) self.assertIs(unspecified_parent, copy.deepcopy(unspecified_parent)) def test_type_hints(self): class Network(nn.Module): layers: int type_hints = get_type_hints(Network) self.assertEqual(type_hints['layers'], int) def test_incorrect_property(self): class Foo(nn.Module): @property def prop(self): return self.non_existent def __call__(self): return self.prop foo = Foo() with self.assertRaisesRegex( errors.DescriptorAttributeError, 'Trying to access a property that' ): foo.apply({}) def test_custom_descriptor(self): class Descriptor: def __get__(self, obj, objtype=None): return 10 class Foo(nn.Module): prop = Descriptor() def __call__(self): return self.prop foo = Foo() res = foo.apply({}) self.assertEqual(res, 10) def test_custom_descriptor_error(self): class Descriptor: def __get__(self, obj, objtype=None): return obj.non_existent class Foo(nn.Module): prop = Descriptor() def __call__(self): return self.prop foo = Foo() with self.assertRaisesRegex( errors.DescriptorAttributeError, 'Trying to access a property that' ): foo.apply({}) def test_nested_external_modules(self): class Baz(nn.Module): a: int def setup(self): self.b = self.param('b', lambda k: 2) def __call__(self, x): return x + self.a * self.b class Bar(nn.Module): baz: Baz def __call__(self, x): return self.baz(x) class Foo(nn.Module): def setup(self): self.bar = Bar(baz=Baz(a=1)) def __call__(self, x): return self.bar.baz(x) module = Foo() y, variables = module.init_with_output(jax.random.key(0), 1) self.assertEqual(y, 3) def test_getattribute_triggers_setup(self): class B(nn.Module): def setup(self): self.p1 = self.param('p1', lambda k: jnp.ones((2,))) def fn1(self, x): return self.p1 + x class A(nn.Module): b: nn.Module def __call__(self, x): return self.b.fn1(x) a = A(b=B()) k = random.key(0) x = jnp.zeros((2,)) vs = nn.init(lambda a, x: a(x), a)(k, x) y = nn.apply(lambda a, x: a.b.fn1(x), a)(vs, x) np.testing.assert_array_equal(y, jnp.ones((2,))) def test_nested_sequential_in_call(self): class Foo(nn.Module): def setup(self): self.seq = nn.Sequential([nn.Dense(10) for i in range(10)]) def __call__(self, x): # try calling only the first layer return self.seq.layers[0](x) module = Foo() variables = module.init(jax.random.key(0), jnp.ones((1, 10))) def test_setup_called_bounded_submodules(self): module = nn.Sequential( [ nn.Sequential( [ nn.Dense(2), nn.relu, nn.Dense(2), ] ), nn.relu, nn.Dense(2), ] ) x = jnp.ones((1, 3)) variables = module.init(jax.random.key(0), x) bound_module = module.bind(variables) self.assertIsNotNone(bound_module.layers[0].layers[0].scope) self.assertIsNotNone(bound_module.layers[0].layers[2].scope) self.assertIsNotNone(bound_module.layers[2].scope) def test_call_bounded_toplevel_mutable(self): class Bar(nn.Module): a: int def setup(self): self.b = self.param('b', lambda k: 1) def __call__(self, x): return x + self.a * self.b class Foo(nn.Module): bars: Sequence[Bar] def __call__(self, x): for bar in self.bars: x = bar(x) return x module = Foo(bars=[]) module.bars = [Bar(a=1)] variables = module.init(jax.random.key(0), jnp.ones(())) bound_module = module.bind(variables) bar1 = bound_module.bars[0] self.assertIsNotNone(bar1.scope) def test_nested_init(self): class Baz(nn.Module): a: int def setup(self): self.b = self.param('b', lambda k: jnp.ones(())) def __call__(self, x): return x + self.a * self.b class Bar(nn.Module): baz: Baz def setup(self): a = 1 def __call__(self, x): return self.baz(x) class Foo(nn.Module): def setup(self): self.bar: Bar = Bar(baz=Baz(a=1)) def __call__(self, x): # y = self.bar(x) y, bar_vars = self.bar.init_with_output(jax.random.key(0), x) return y, bar_vars # create foo module = Foo() # run foo (y, bar_vars), variables = module.init_with_output( jax.random.key(0), jnp.ones(()) ) self.assertIn('params', bar_vars) def test_nested_shared(self): class Shared(nn.Module): @nn.compact def __call__(self, x): return nn.Dense(1)(x) class Unshared(nn.Module): shared: nn.Module def __call__(self, x): return self.shared(x) class Super(nn.Module): a: nn.Module b: nn.Module def run_a(self, x): return self.a(x) def run_b(self, x): return self.b(x) def __call__(self, x): return self.a(x) + self.b(x) sh = Shared() a = Unshared(shared=sh) b = Unshared(shared=sh) module = Super(a=a, b=b) rng = jax.random.key(0) params = module.init(rng, jnp.ones(1))['params'] module.apply({'params': params}, jnp.ones(1)) # works as expected module.apply( {'params': params}, jnp.ones(1), method='run_a' ) # works as expected module.apply( {'params': params}, jnp.ones(1), method='run_b' ) # ScopeParamNotFoundError: Could not find parameter named "kernel" in scope "/b/shared/Dense_0" def test_repr(self): class Base1(nn.Module): a: int class Base2(nn.Module): b: str class Foo(Base2, Base1): c: float module = Foo(a=1, b='ok', c=3.0) str_rep = repr(module) self.assertIn('a = 1', str_rep) self.assertIn("b = 'ok'", str_rep) self.assertIn('c = 3.0', str_rep) def test_repr_should_not_cause_setup(self): class MLP(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense(1)(x) return repr(self) class Foo(nn.Module): a: float b: MLP scope = Scope({}) module = Foo(parent=scope, a=1, b=MLP(parent=scope)) str_rep = repr(module) self.assertIn('a = 1', str_rep) self.assertEqual(module._state.setup_called, nn.module.SetupState.NEW) # repr() on a module should not cause inadvertent setup of submodules # i.e. module.b._state.setup_called should remain nn.module.SetupState.NEW # and not nn.module.SetupState.DONE self.assertEqual(module.b._state.setup_called, nn.module.SetupState.NEW) def test_kw_only(self): def create_kw_layers(): class BaseLayer(nn.Module, kw_only=True): base_multiplier: int | None = -1 class ChildLayer(BaseLayer): child_multiplier: int # Don't want to have to set a default argument! def __call__(self, x): return x * self.child_multiplier * self.base_multiplier return BaseLayer, ChildLayer if tuple(sys.version_info)[:3] < (3, 10, 0): with self.assertRaisesRegex(TypeError, 'not available before Py 3.10'): BaseLayer, ChildLayer = create_kw_layers() else: BaseLayer, ChildLayer = create_kw_layers() with self.assertRaisesRegex(TypeError, 'positional argument'): _ = BaseLayer(2) # Like in Python dataclass, `kw_only` is not inherited, so ChildLayer can # take positional arg. It takes BaseLayer's default kwargs though. np.testing.assert_equal(ChildLayer(8)(np.ones(10)), -8 * np.ones(10)) def test_positional_cannot_be_kw_only(self): class Foo(nn.Module): a: int Foo(1) # ok Foo(a=1) # ok with self.assertRaisesRegex( TypeError, r'takes 2 positional arguments but 3 were' ): Foo(1, None) Foo(a=1, parent=None) # type: ignore[call-arg] def test_module_path_empty(self): rngkey = jax.random.key(0) scope = Scope({}, {'params': rngkey}, mutable=['params']) m1 = DummyModule(parent=scope) self.assertEqual(m1.path, ()) scope = Scope({}, {'params': rngkey}, mutable=['params'], path=['root']) m2 = DummyModule(parent=scope) self.assertEqual(m2.path, ('root',)) m3 = DummyModule(parent=scope.rewound()) self.assertEqual(m3.path, ('root',)) def test_module_path_unbound_module_error(self): m1 = DummyModule() with self.assertRaisesRegex(ValueError, 'unbound module'): _ = m1.path def test_module_path_in_nested_module(self): module_paths = [] debug_paths = [] class A(nn.Module): def setup(self): self.b1 = B() self.b2 = B() self.c1 = C() module_paths.append(self.path) debug_paths.append(self.scope.debug_path) def __call__(self, x): return self.b1(x) + self.b2(x) + self.c1(x) class B(nn.Module): def setup(self): self.c1 = nn.remat(nn.remat(C))() self.c2 = C() module_paths.append(self.path) debug_paths.append(self.scope.debug_path) def __call__(self, x): return self.c1(x) + self.c2(x) class C(nn.Module): def setup(self): super().setup() if self.scope.__class__.__name__ != 'TestScope': module_paths.append(self.path) debug_paths.append(self.scope.debug_path) @nn.compact def __call__(self, x): return x a = A() k = random.key(0) x = random.uniform(random.key(42), (2,)) _ = a.init(k, x) expected_module_paths = [ (), ('b1',), ('b1', 'c1'), ('b1', 'c2'), ('b2',), ('b2', 'c1'), ('b2', 'c2'), ('c1',), ] expected_debug_paths = [ (), ('b1',), ('b1', 'remat(remat(c1))'), ('b1', 'c2'), ('b2',), ('b2', 'remat(remat(c1))'), ('b2', 'c2'), ('c1',), ] self.assertEqual(module_paths, expected_module_paths) self.assertEqual(debug_paths, expected_debug_paths) def test_intercept_methods(self): mod = IdentityModule(parent=None) x = jnp.ones([]) call_count = [] def add_one_interceptor(f, args, kwargs, context): call_count.append(None) self.assertLen(dataclasses.fields(context), 3) self.assertIs(context.module, mod) self.assertEqual(context.method_name, '__call__') self.assertEqual(context.orig_method(3), 3) self.assertEqual(args, (x,)) self.assertEmpty(kwargs) y = f(*args, **kwargs) return y + 1 y1 = mod(x) with nn.intercept_methods(add_one_interceptor): y2 = mod(x) y3 = mod(x) self.assertLen(call_count, 1) self.assertEqual(y1, 1) self.assertEqual(y2, 2) self.assertEqual(y3, 1) def test_intercept_methods_compact(self): class CompactModule(nn.Module): @compact def __call__(self, x): return nn.Dense(2)(x) mod = CompactModule() x = jnp.ones(shape=(1, 3)) variables = mod.init(jax.random.key(0), x) call_modules = [] def log_interceptor(f, args, kwargs, context): call_modules.append(context.module) self.assertLen(dataclasses.fields(context), 3) self.assertEqual(context.method_name, '__call__') self.assertEqual(args, (x,)) self.assertEmpty(kwargs) return f(*args, **kwargs) with nn.intercept_methods(log_interceptor): _ = mod.apply(variables, x) self.assertLen(call_modules, 2) self.assertIsInstance(call_modules[0], CompactModule) self.assertIsInstance(call_modules[1], nn.Dense) def test_intercept_methods_setup(self): class SetupModule(nn.Module): def setup(self): self.layer = nn.Dense(2) def __call__(self, x): return self.layer(x) mod = SetupModule() x = jnp.ones(shape=(1, 3)) variables = mod.init(jax.random.key(0), x) call_modules = [] log = [] def log_interceptor(f, args, kwargs, context): call_modules.append(context.module) log.append((context.method_name, args, kwargs)) return f(*args, **kwargs) with nn.intercept_methods(log_interceptor): _ = mod.apply(variables, x) self.assertLen(call_modules, 3) self.assertIsInstance(call_modules[0], SetupModule) self.assertIsInstance(call_modules[1], SetupModule) self.assertIsInstance(call_modules[2], nn.Dense) self.assertEqual( log, [('setup', (), {}), ('__call__', (x,), {}), ('__call__', (x,), {})] ) def test_intercept_methods_calling_underlying_optional(self): def do_nothing_interceptor(f, args, kwargs, context): del f, context self.assertEmpty(args) self.assertEmpty(kwargs) m = RaisesModule() with nn.intercept_methods(do_nothing_interceptor): m() with self.assertRaises(AssertionError): m() with nn.intercept_methods(do_nothing_interceptor): m() def test_intercept_methods_run_in_lifo_order(self): def op_interceptor(op): def _interceptor(f, args, kwargs, context): del context y = f(*args, **kwargs) return op(y) return _interceptor mod = IdentityModule(parent=None) x = 7 with ( nn.intercept_methods(op_interceptor(lambda a: a + 1)), nn.intercept_methods(op_interceptor(lambda a: a**2)), ): y = mod(x) self.assertEqual(y, (x**2) + 1) with ( nn.intercept_methods(op_interceptor(lambda a: a**2)), nn.intercept_methods(op_interceptor(lambda a: a + 1)), ): y = mod(x) self.assertEqual(y, (x + 1) ** 2) def test_intercept_methods_subclasses(self): class Foo(IdentityModule): def __call__(self, x): # pylint: disable=useless-parent-delegation return super().__call__(x) class Bar(Foo): def __call__(self, x): # pylint: disable=useless-parent-delegation return super().__call__(x) bar = Bar(parent=None) x = jnp.ones([]) called = [] def record_interceptor(f, args, kwargs, context): called.append(None) self.assertIs(context.module, bar) self.assertEqual(context.method_name, '__call__') self.assertEqual(args, (x,)) self.assertEmpty(kwargs) return f(*args, **kwargs) with nn.intercept_methods(record_interceptor): bar(x) # Bar.__call__, Foo.__call__ and IdenityModule.__call__ self.assertLen(called, 3) def test_intercept_methods_nested_module(self): class Foo(nn.Module): def __call__(self, x): return x class Bar(nn.Module): sub: nn.Module def __call__(self, x): return self.sub(x) foo = Foo() bar = Bar(sub=foo) x = jnp.ones([]) called = [] def record_interceptor(f, args, kwargs, context): called.append(context.module) self.assertEqual(context.method_name, '__call__') self.assertEqual(args, (x,)) self.assertEmpty(kwargs) return f(*args, **kwargs) with nn.intercept_methods(record_interceptor): bar(x) # bar.__call__ and foo.__call__ self.assertLen(called, 2) self.assertIs(called[0], bar) self.assertIs(called[1], foo) def test_cloudpickle_class(self): import cloudpickle class MyModule(nn.Module): pass a = MyModule() UnpickledMyModule = cloudpickle.loads(cloudpickle.dumps(MyModule)) b = UnpickledMyModule() def test_cloudpickle_module(self): from cloudpickle import cloudpickle_fast class NNModuleWithProperty(nn.Module): a: int b: str @property def my_property(self): return self.b * self.a m = NNModuleWithProperty(a=2, b='ok') with TemporaryDirectory() as tmpdir: filename = f'{tmpdir}/module.pkl' with open(filename, 'wb') as f: cloudpickle_fast.dump(m, f) with open(filename, 'rb') as f: obj_loaded = cloudpickle_fast.load(f) self.assertEqual(obj_loaded.a, 2) self.assertEqual(obj_loaded.b, 'ok') self.assertEqual(obj_loaded.my_property, 'okok') def test_module_paths(self): class Bar(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense(3)(x) x = nn.Dense(4)(x) return x class Foo(nn.Module): @nn.compact def __call__(self, x): x = Bar()(x) x = nn.Dense(5)(x) return x x = jnp.ones((1, 2)) m = Foo() module_paths = m.module_paths(random.key(0), x) # assert all module are unbounded for module in module_paths.values(): self.assertIsNone(module.scope) # test paths self.assertIn('', module_paths) self.assertEqual(type(module_paths['']), Foo) self.assertIn('Dense_0', module_paths) self.assertEqual(type(module_paths['Dense_0']), nn.Dense) self.assertIn('Bar_0', module_paths) self.assertEqual(type(module_paths['Bar_0']), Bar) self.assertIn('Bar_0/Dense_0', module_paths) self.assertEqual(type(module_paths['Bar_0/Dense_0']), nn.Dense) self.assertIn('Bar_0/Dense_1', module_paths) self.assertEqual(type(module_paths['Bar_0/Dense_1']), nn.Dense) def test_init_apply_default_rng(self): class SubModel(nn.Module): @nn.compact def __call__(self, x, apply_dropout): x = nn.Dense(8)(x) x = nn.Dropout(0.8)(x, deterministic=not apply_dropout) p = self.param( 'parameter', lambda key, shape: jax.random.normal(key, shape), x.shape ) noise = jax.random.normal(self.make_rng('noise'), x.shape) return x * p + noise class Model(nn.Module): @nn.compact def __call__(self, x, apply_dropout): x = nn.Dense(16)(x) x = SubModel()(x, apply_dropout) x = nn.Dropout(0.5)(x, deterministic=not apply_dropout) v = self.variable( 'var_collection', 'variable', lambda shape: jax.random.normal(self.make_rng('var_rng'), shape), x.shape, ) noise = jax.random.normal(self.make_rng('noise'), x.shape) return x * v.value + noise key0, key1, key2 = jax.random.split(jax.random.key(0), 3) x = jax.random.normal(key0, (10, 4)) model = Model() # test init equality default_variables = model.init({'params': key1}, x, apply_dropout=False) rngs = {'params': key1, 'var_rng': key1, 'noise': key1} explicit_variables = model.init(rngs, x, apply_dropout=False) self.assertTrue( jax.tree_util.tree_all( jax.tree_util.tree_map( lambda v1, v2: (v1 == v2).all(), default_variables, explicit_variables, ) ) ) # test init inequality for rng_name in ('params', 'var_rng'): rngs[rng_name] = key2 explicit_variables = model.init(rngs, x, apply_dropout=False) self.assertFalse( jax.tree_util.tree_all( jax.tree_util.tree_map( lambda v1, v2: (v1 == v2).all(), default_variables, explicit_variables, ) ) ) rngs[rng_name] = key1 # test apply equality default_out = model.apply( default_variables, x, apply_dropout=True, rngs={'params': key1} ) rngs = {'dropout': key1, 'noise': key1} explicit_out = model.apply( default_variables, x, apply_dropout=True, rngs=rngs ) np.testing.assert_allclose(default_out, explicit_out) # test apply inequality for rng_name in ('dropout', 'noise'): rngs[rng_name] = key2 explicit_out = model.apply( default_variables, x, apply_dropout=True, rngs=rngs ) with self.assertRaises(AssertionError): np.testing.assert_allclose(default_out, explicit_out, atol=1e-1) rngs[rng_name] = key1 def test_default_make_rng(self): class SubModel(nn.Module): @nn.compact def __call__(self, x): noise = jax.random.normal(self.make_rng(), x.shape) return x + noise class Model(nn.Module): @nn.compact def __call__(self, x): x = SubModel()(x) noise = jax.random.normal(self.make_rng(), x.shape) return x + noise key0, key1 = jax.random.split(jax.random.key(0), 2) x = jax.random.normal(key0, (10, 4)) default_out = Model().apply({}, x, rngs={'params': key1}) class SubModel(nn.Module): @nn.compact def __call__(self, x): noise = jax.random.normal(self.make_rng('params'), x.shape) return x + noise class Model(nn.Module): @nn.compact def __call__(self, x): x = SubModel()(x) noise = jax.random.normal(self.make_rng('params'), x.shape) return x + noise explicit_out = Model().apply({}, x, rngs={'params': key1}) np.testing.assert_allclose(default_out, explicit_out) def test_default_rng_error(self): class Model(nn.Module): @nn.compact def __call__(self, x): return nn.Dense(2)(x) model = Model() with self.assertRaisesRegex( errors.InvalidRngError, 'Dense_0 needs PRNG for "params"' ): model.init({'other_rng_stream': jax.random.key(0)}, jnp.ones((1, 3))) class Model(nn.Module): @nn.compact def __call__(self, x): return x + jax.random.normal(self.make_rng(), x.shape) model = Model() with self.assertRaisesRegex( errors.InvalidRngError, 'None needs PRNG for "params"' ): model.init({'other_rng_stream': jax.random.key(0)}, jnp.ones((1, 3))) def test_compact_name_scope(self): class Foo(nn.Module): @nn.compact_name_scope def up(self, x): return nn.Dense(3)(x) @nn.compact_name_scope def down(self, x): return nn.Dense(3)(x) @nn.compact def __call__(self, x): return self.up(x) + self.down(x) + nn.Dense(3)(x) m = Foo() x = jnp.ones((1, 2)) self.assertEqual(set(m._compact_name_scope_methods), {'up', 'down'}) variables = m.init(random.key(0), x) params = variables['params'] self.assertIn('Dense_0', params) self.assertIn('down', params) self.assertIn('up', params) self.assertIn('Dense_0', params['down']) self.assertIn('Dense_0', params['up']) y = m.apply(variables, x) y_up = m.apply(variables, x, method='up') y_down = m.apply(variables, x, method='down') assert y.shape == (1, 3) assert y_up.shape == (1, 3) assert y_down.shape == (1, 3) def test_compact_name_scope_outside_compact(self): class Foo(nn.Module): @nn.compact_name_scope def up(self, x): return nn.Dense(3)(x) @nn.compact_name_scope def down(self, x): return nn.Dense(3)(x) def __call__(self, x): return self.up(x) + self.down(x) m = Foo() x = jnp.ones((1, 2)) self.assertEqual(set(m._compact_name_scope_methods), {'up', 'down'}) variables = m.init(random.key(0), x) params = variables['params'] self.assertIn('down', params) self.assertIn('up', params) self.assertIn('Dense_0', params['down']) self.assertIn('Dense_0', params['up']) y = m.apply(variables, x) y_up = m.apply(variables, x, method='up') y_down = m.apply(variables, x, method='down') assert y.shape == (1, 3) assert y_up.shape == (1, 3) assert y_down.shape == (1, 3) class LeakTests(absltest.TestCase): def test_tracer_leaks(self): model = nn.Sequential([nn.Dense(50)]) @jax.jit @functools.partial(jax.vmap, in_axes=(0, None)) def sample_from_prior(rng, inp): params = model.init(rng, np.zeros((10, 50))) out = model.apply(params, inp) del params return out # disable manual gc.collect call in jax leak checker # so that we can test tracer leaks in ref-cycles. This is a # reasonable proxy for transiently leaked memory during # eager execution. with patch.object(gc, 'collect', return_value=0): with jax.checking_leaks(): for i in range(5): rngs = jax.random.split(jax.random.key(23), 100) out = sample_from_prior(rngs, np.ones((4, 50))) out.block_until_ready() del out, rngs class RelaxedNamingTests(absltest.TestCase): def test_relaxed_adoption(self): class Foo(nn.Module): @nn.compact def __call__(self, x): p = self.param('p', nn.initializers.zeros, x.shape) return x + p class Bar(nn.Module): sub: nn.Module def __call__(self, x): return self.sub(x) with set_config('flax_preserve_adopted_names', True): foo = Foo(name='foo') bar = Bar(sub=foo) k = random.key(0) x = jnp.zeros((1,)) vs = bar.init(k, x) self.assertTrue('foo' in vs['params'], 'relaxed naming failure') y = bar.apply(vs, x) with set_config('flax_preserve_adopted_names', False): foo = Foo(name='foo') bar = Bar(sub=foo) k = random.key(0) x = jnp.zeros((1,)) vs = bar.init(k, x) self.assertTrue('sub' in vs['params'], 'old policy naming failure') y = bar.apply(vs, x) def test_class_optional_adoption_name_preservation(self): class Foo(nn.Module): @nn.compact def __call__(self, x): p = self.param('p', nn.initializers.zeros, x.shape) return x + p class Bar1(nn.Module): sub: nn.Module preserve_adopted_names = True def __call__(self, x): return self.sub(x) class Bar2(nn.Module): sub: nn.Module preserve_adopted_names = False def __call__(self, x): return self.sub(x) with set_config('flax_preserve_adopted_names', False): foo = Foo(name='foo') bar = Bar1(sub=foo) k = random.key(0) x = jnp.zeros((1,)) vs = bar.init(k, x) self.assertTrue('foo' in vs['params'], 'adoption naming failure') y = bar.apply(vs, x) with set_config('flax_preserve_adopted_names', True): foo = Foo(name='foo') bar = Bar2(sub=foo) k = random.key(0) x = jnp.zeros((1,)) vs = bar.init(k, x) self.assertTrue('sub' in vs['params'], 'adoption naming failure') y = bar.apply(vs, x) def test_nested_class_optional_adoption_name_preservation(self): class Foo(nn.Module): @nn.compact def __call__(self, x): p = self.param('p', nn.initializers.zeros, x.shape) return x + p class Bar(nn.Module): sub: nn.Module preserve_adopted_names = True def __call__(self, x): return self.sub(x) class Baz(nn.Module): sub: nn.Module preserve_adopted_names = True def __call__(self, x): return self.sub(x) with set_config('flax_preserve_adopted_names', False): foo = Foo(name='foo') bar = Bar(sub=foo, name='bar') baz = Baz(sub=bar) k = random.key(0) x = jnp.zeros((1,)) vs = baz.init(k, x) self.assertTrue('bar' in vs['params'], 'adoption naming failure') self.assertTrue('foo' in vs['params']['bar'], 'adoption naming failure') y = baz.apply(vs, x) def test_relaxed_adoption_still_conflict_checks(self): class Foo(nn.Module): @nn.compact def __call__(self, x): p = self.param('p', nn.initializers.zeros, x.shape) return x + p class Bar(nn.Module): sub1: nn.Module sub2: nn.Module def __call__(self, x): return self.sub(x) with set_config('flax_preserve_adopted_names', True): foo1 = Foo(name='foo') foo2 = Foo(name='foo') bar = Bar(sub1=foo1, sub2=foo2) k = random.key(0) x = jnp.zeros((1,)) with self.assertRaises(errors.NameInUseError): vs = bar.init(k, x) def test_relaxed_adoption_unnamed_adoptee(self): class Foo(nn.Module): @nn.compact def __call__(self, x): p = self.param('p', nn.initializers.zeros, x.shape) return x + p class Bar(nn.Module): sub: nn.Module def __call__(self, x): return self.sub(x) with set_config('flax_preserve_adopted_names', True): foo = Foo(name=None) bar = Bar(sub=foo) k = random.key(0) x = jnp.zeros((1,)) vs = bar.init(k, x) self.assertTrue('sub' in vs['params'], 'relaxed naming failure') y = bar.apply(vs, x) with set_config('flax_preserve_adopted_names', False): foo = Foo(name='foo') bar = Bar(sub=foo) k = random.key(0) x = jnp.zeros((1,)) vs = bar.init(k, x) self.assertTrue('sub' in vs['params'], 'old policy naming failure') y = bar.apply(vs, x) def test_relaxed_python_conflict(self): class Foo(nn.Module): dummy = 0 @nn.compact def __call__(self, x): p = self.param('dummy', nn.initializers.zeros, x.shape) return x + p foo = Foo(name='foo') k = random.key(0) x = jnp.zeros((1,)) vs = foo.init(k, x) def test_relaxed_intercollection_conflict(self): class Foo(nn.Module): @nn.compact def __call__(self, x): v1 = self.variable('col1', 'v', lambda x: jnp.zeros(x), x.shape) v2 = self.variable('col2', 'v', lambda x: jnp.zeros(x), x.shape) return x + v1.value + v2.value foo = Foo(name='foo') k = random.key(0) x = jnp.zeros((1,)) vs = foo.init(k, x) def test_relaxed_intercollection_conflict_set(self): class Foo(nn.Module): @nn.compact def __call__(self, x): v1 = self.variable('col1', 'v', lambda x: jnp.zeros(x), x.shape) v2 = self.variable('col2', 'v', lambda x: jnp.zeros(x), x.shape) v3 = self.variable('col1', 'v', lambda x: jnp.zeros(x), x.shape) return x + v1.value + v2.value + v3.value foo = Foo(name='foo') k = random.key(0) x = jnp.zeros((1,)) with self.assertRaises(errors.NameInUseError): vs = foo.init(k, x) def test_internal_deep_clone(self): class Child(nn.Module): @nn.compact def __call__(self, x): w = self.param('w', nn.initializers.zeros, (5, x.shape[1])) return x @ w class Parent(nn.Module): num_layers: int child_template: Child @nn.compact def __call__(self, x): for i in range(self.num_layers): x = self.child_template.clone( parent=self, _deep_clone=True, name=None )(x) return x model = Parent(num_layers=2, child_template=Child()) x = jnp.ones((32, 5)) variables = model.init(jax.random.key(0), x) output = model.apply(variables, x) self.assertTrue( variables['params']['Child_0']['w'].shape == variables['params']['Child_1']['w'].shape ) def test_copy_method(self): class Parent(nn.Module): @nn.compact def __call__(self, x): child = nn.Dense( 2, ) x = child(x) x = child.copy()(x) return x model = Parent() x = jnp.ones((2, 2)) variables = model.init(jax.random.key(0), x) output = model.apply(variables, x) self.assertTrue( variables['params']['Dense_0']['kernel'].shape == variables['params']['Dense_1']['kernel'].shape ) def test_copy_from_template(self): class Child(nn.Module): @nn.compact def __call__(self, x): w = self.param('w', nn.initializers.zeros, (5, x.shape[1])) return x @ w class Parent(nn.Module): num_layers: int child_template: Child @nn.compact def __call__(self, x): for i in range(self.num_layers): x = self.child_template.copy()(x) for i in range(self.num_layers): x = self.child_template.copy(name=f'next_layer_{i}')(x) return x model = Parent(num_layers=2, child_template=Child()) x = jnp.ones((32, 5)) variables = model.init(jax.random.key(0), x) output = model.apply(variables, x) self.assertTrue( variables['params']['Child_0']['w'].shape == variables['params']['Child_1']['w'].shape ) self.assertIn('Child_0', variables['params']) self.assertIn('Child_1', variables['params']) self.assertIn('next_layer_0', variables['params']) self.assertIn('next_layer_1', variables['params']) self.assertNotIn('child_template', variables['params']) def test_nonstring_keys_in_dict_on_module(self): class MyEnum(str, enum.Enum): a = 'a' b = 'b' class MyModule(nn.Module): config: dict[MyEnum, int] def __call__(self, inputs): return inputs module = MyModule(config={MyEnum.a: 1, MyEnum.b: 2}) variables = module.init(jax.random.key(0), jnp.zeros([0])) class FrozenDictTests(absltest.TestCase): def test_frozendict_flag(self): with set_config('flax_return_frozendict', True): x = jnp.zeros((2, 3)) layer = nn.Dense(5) params = layer.init(random.key(0), x) self.assertTrue(isinstance(params, FrozenDict)) with set_config('flax_return_frozendict', False): x = jnp.zeros((2, 3)) layer = nn.Dense(5) params = layer.init(random.key(0), x) self.assertTrue(isinstance(params, dict)) class ShareScopeTest(absltest.TestCase): def test_basic(self): class DenseLoRA(nn.Module): inner: nn.Dense rank: int def setup(self): nn.share_scope(self, self.inner) @nn.compact def __call__(self, x: jax.Array): din, dout = x.shape[-1], self.inner.features A = self.param('A', nn.zeros_init(), (din, self.rank)) B = self.param('B', nn.zeros_init(), (self.rank, dout)) return self.inner(x) + x @ A @ B dense_lora = DenseLoRA(nn.Dense(10), rank=2) params = dense_lora.init(random.key(0), jnp.ones((1, 5)))['params'] self.assertIn('kernel', params) self.assertIn('bias', params) self.assertIn('A', params) self.assertIn('B', params) def test_child_scope(self): class DenseLoRA(nn.Module): rank: int def setup(self): self.child = nn.Dense(10) nn.share_scope(self, self.child) @nn.compact def __call__(self, x: jax.Array): din, dout = x.shape[-1], self.child.features A = self.param('A', nn.zeros_init(), (din, self.rank)) B = self.param('B', nn.zeros_init(), (self.rank, dout)) return self.child(x) + x @ A @ B dense_lora = DenseLoRA(rank=2) params = dense_lora.init(random.key(0), jnp.ones((1, 5)))['params'] self.assertIn('kernel', params) self.assertIn('bias', params) self.assertIn('A', params) self.assertIn('B', params) def test_in_compact(self): class DenseLoRA(nn.Module): rank: int def setup(self): self.child = nn.Dense(10) nn.share_scope(self, self.child) @nn.compact def __call__(self, x: jax.Array): din, dout = x.shape[-1], self.child.features A = self.param('A', nn.zeros_init(), (din, self.rank)) B = self.param('B', nn.zeros_init(), (self.rank, dout)) return self.child(x) + x @ A @ B class Model(nn.Module): @nn.compact def __call__(self, x: jax.Array): return DenseLoRA(rank=2)(x) model = Model() params = model.init(random.key(0), jnp.ones((1, 5)))['params'] self.assertIn('kernel', params['DenseLoRA_0']) self.assertIn('bias', params['DenseLoRA_0']) self.assertIn('A', params['DenseLoRA_0']) self.assertIn('B', params['DenseLoRA_0']) def test_adopt_child_name(self): class DenseLoRA(nn.Module): inner: nn.Dense rank: int def setup(self): nn.share_scope(self, self.inner) @nn.compact def __call__(self, x: jax.Array): din, dout = x.shape[-1], self.inner.features A = self.param('A', nn.zeros_init(), (din, self.rank)) B = self.param('B', nn.zeros_init(), (self.rank, dout)) return self.inner(x) + x @ A @ B class Model(nn.Module): @nn.compact def __call__(self, x: jax.Array): return DenseLoRA(nn.Dense(10), rank=2)(x) model = Model() params = model.init(random.key(0), jnp.ones((1, 5)))['params'] self.assertIn('kernel', params['Dense_0']) self.assertIn('bias', params['Dense_0']) self.assertIn('A', params['Dense_0']) self.assertIn('B', params['Dense_0']) def test_other_scope_is_none(self): class DenseLoRA(nn.Module): inner: nn.Dense rank: int def setup(self): nn.share_scope(self, self.inner) @nn.compact def __call__(self, x: jax.Array): din, dout = x.shape[-1], self.inner.features A = self.param('A', nn.zeros_init(), (din, self.rank)) B = self.param('B', nn.zeros_init(), (self.rank, dout)) return self.inner(x) + x @ A @ B class Model(nn.Module): def setup(self): # here Dense doesn't have a scope yet self.dense_lora = DenseLoRA(nn.Dense(10), rank=2) @nn.compact def __call__(self, x: jax.Array): return self.dense_lora(x) model = Model() params = model.init(random.key(0), jnp.ones((1, 5)))['params'] self.assertIn('kernel', params['dense_lora']) self.assertIn('bias', params['dense_lora']) self.assertIn('A', params['dense_lora']) self.assertIn('B', params['dense_lora']) def test_external_grandchild_scope_correct(self): class GrandChild(nn.Module): @nn.compact def __call__(self): return nn.Dense(50)(jnp.zeros(10)) class Child(nn.Module): child: GrandChild @nn.compact def __call__(self, *args: Any, **kwargs: Any) -> Any: return self.child(*args, **kwargs) class Parent(nn.Module): main_child: Child def setup(self): nn.share_scope(self, self.main_child) @nn.compact def __call__(self, *args: Any, **kwargs: Any) -> Any: nn.Dense(10)(jnp.zeros(10)) r = self.main_child(*args, **kwargs) return r params = Parent(Child(GrandChild())).init(jax.random.key(0)) self.assertNotIn('main_child', params['params']) self.assertIn('child', params['params']) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/linen/linen_recurrent_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Recurrent tests.""" import jax import jax.numpy as jnp import numpy as np from absl.testing import absltest from flax import linen as nn from flax.linen.recurrent import flip_sequences # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() class RNNTest(absltest.TestCase): def test_rnn_basic_forward(self): batch_size = 10 seq_len = 40 channels_in = 5 channels_out = 15 rnn = nn.RNN(nn.LSTMCell(channels_out), return_carry=True) xs = jnp.ones((batch_size, seq_len, channels_in)) variables = rnn.init(jax.random.key(0), xs) ys: jnp.ndarray carry, ys = rnn.apply(variables, xs) self.assertEqual(ys.shape, (batch_size, seq_len, channels_out)) for layer_params in variables['params']['cell'].values(): if 'bias' in layer_params: self.assertEqual(layer_params['bias'].shape, (channels_out,)) self.assertIn( layer_params['kernel'].shape[0], [channels_in, channels_out] ) self.assertEqual(layer_params['kernel'].shape[1], channels_out) def test_rnn_multiple_batch_dims(self): batch_dims = (10, 11) seq_len = 40 channels_in = 5 channels_out = 15 rnn = nn.RNN(nn.LSTMCell(channels_out), return_carry=True) xs = jnp.ones((*batch_dims, seq_len, channels_in)) variables = rnn.init(jax.random.key(0), xs) ys: jnp.ndarray carry, ys = rnn.apply(variables, xs) self.assertEqual(ys.shape, (*batch_dims, seq_len, channels_out)) for layer_params in variables['params']['cell'].values(): if 'bias' in layer_params: self.assertEqual(layer_params['bias'].shape, (channels_out,)) self.assertIn( layer_params['kernel'].shape[0], [channels_in, channels_out] ) self.assertEqual(layer_params['kernel'].shape[1], channels_out) def test_rnn_unroll(self): batch_size = 10 seq_len = 40 channels_in = 5 channels_out = 15 rnn = nn.RNN(nn.LSTMCell(channels_out), unroll=10, return_carry=True) xs = jnp.ones((batch_size, seq_len, channels_in)) variables = rnn.init(jax.random.key(0), xs) ys: jnp.ndarray carry, ys = rnn.apply(variables, xs) self.assertEqual(ys.shape, (batch_size, seq_len, channels_out)) for layer_params in variables['params']['cell'].values(): if 'bias' in layer_params: self.assertEqual(layer_params['bias'].shape, (channels_out,)) self.assertIn( layer_params['kernel'].shape[0], [channels_in, channels_out] ) self.assertEqual(layer_params['kernel'].shape[1], channels_out) def test_rnn_time_major(self): seq_len = 40 batch_size = 10 channels_in = 5 channels_out = 15 rnn = nn.RNN(nn.LSTMCell(channels_out), time_major=True, return_carry=True) xs = jnp.ones((seq_len, batch_size, channels_in)) variables = rnn.init(jax.random.key(0), xs) ys: jnp.ndarray carry, ys = rnn.apply(variables, xs) # carry state should not be zeros after apply for leaf in jax.tree_util.tree_leaves(carry): assert not np.allclose(leaf, jnp.zeros_like(leaf)) self.assertEqual(leaf.shape, (batch_size, channels_out)) self.assertEqual(ys.shape, (seq_len, batch_size, channels_out)) for layer_params in variables['params']['cell'].values(): if 'bias' in layer_params: self.assertEqual(layer_params['bias'].shape, (channels_out,)) self.assertIn( layer_params['kernel'].shape[0], [channels_in, channels_out] ) self.assertEqual(layer_params['kernel'].shape[1], channels_out) def test_rnn_with_spatial_dimensions(self): batch_size = 10 seq_len = 40 kernel_size = (3, 3) image_size = (32, 32) channels_in = 5 channels_out = 15 rnn = nn.RNN( nn.ConvLSTMCell(channels_out, kernel_size), ) xs = jnp.ones((batch_size, seq_len, *image_size, channels_in)) variables = rnn.init(jax.random.key(0), xs) ys: jnp.ndarray carry, ys = rnn.apply(variables, xs, return_carry=True) # carry state should not be zeros after apply for leaf in jax.tree_util.tree_leaves(carry): assert not np.allclose(leaf, jnp.zeros_like(leaf)) self.assertEqual(leaf.shape[:-1], (batch_size, *image_size)) self.assertIn(leaf.shape[-1], [channels_in, channels_out]) self.assertEqual(ys.shape, (batch_size, seq_len, *image_size, channels_out)) for layer_params in variables['params']['cell'].values(): if 'bias' in layer_params: self.assertEqual(layer_params['bias'].shape, (channels_out * 4,)) self.assertIn( layer_params['kernel'].shape[2], [channels_in, channels_out, channels_out * 4], ) self.assertEqual(layer_params['kernel'].shape[3], channels_out * 4) def test_numerical_equivalence(self): batch_size = 3 seq_len = 4 channels_in = 5 channels_out = 6 rnn = nn.RNN(nn.LSTMCell(channels_out), return_carry=True) xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray (carry, ys), variables = rnn.init_with_output(jax.random.key(0), xs) cell_carry = rnn.cell.initialize_carry(jax.random.key(0), xs[:, 0].shape) cell_params = variables['params']['cell'] for i in range(seq_len): cell_carry, y = rnn.cell.apply( {'params': cell_params}, cell_carry, xs[:, i, :] ) np.testing.assert_allclose(y, ys[:, i, :], rtol=1e-5) np.testing.assert_allclose(cell_carry, carry, rtol=1e-5) def test_numerical_equivalence_with_mask(self): batch_size = 3 seq_len = 4 channels_in = 5 channels_out = 6 key = jax.random.key(0) seq_lengths = jax.random.randint( key, (batch_size,), minval=1, maxval=seq_len + 1 ) rnn = nn.RNN(nn.LSTMCell(channels_out), return_carry=True) xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray (carry, ys), variables = rnn.init_with_output( jax.random.key(0), xs, seq_lengths=seq_lengths ) cell_carry = rnn.cell.initialize_carry(jax.random.key(0), xs[:, 0].shape) cell_params = variables['params']['cell'] carries = [] for i in range(seq_len): cell_carry, y = rnn.cell.apply( {'params': cell_params}, cell_carry, xs[:, i, :] ) np.testing.assert_allclose(y, ys[:, i, :], rtol=1e-5) carries.append(cell_carry) for batch_idx, length in enumerate(seq_lengths): t = int(length) - 1 for carries_t_, carry_ in zip(carries[t], carry): np.testing.assert_allclose( carries_t_[batch_idx], carry_[batch_idx], rtol=1e-5 ) def test_numerical_equivalence_single_batch(self): batch_size = 3 seq_len = 4 channels_in = 5 channels_out = 6 rnn = nn.RNN(nn.LSTMCell(channels_out), return_carry=True) xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray (carry, ys), variables = rnn.init_with_output(jax.random.key(0), xs) cell_params = variables['params']['cell'] for batch_idx in range(batch_size): cell_carry = rnn.cell.initialize_carry(jax.random.key(0), xs[:1, 0].shape) for i in range(seq_len): cell_carry, y = rnn.cell.apply( {'params': cell_params}, cell_carry, xs[batch_idx, i, :][None] ) np.testing.assert_allclose(y[0], ys[batch_idx, i, :], rtol=1e-6) carry_i = jax.tree_util.tree_map( lambda x: x[batch_idx : batch_idx + 1], carry ) np.testing.assert_allclose(cell_carry, carry_i, rtol=1e-6) def test_numerical_equivalence_single_batch_nn_scan(self): batch_size = 3 seq_len = 4 channels_in = 5 channels_out = 6 cell: nn.LSTMCell = nn.LSTMCell(channels_out) rnn: nn.LSTMCell = nn.scan( nn.LSTMCell, in_axes=1, out_axes=1, variable_broadcast='params', split_rngs={'params': False}, )(channels_out) xs = jnp.ones((batch_size, seq_len, channels_in)) carry = rnn.initialize_carry(jax.random.key(0), xs[:, 0].shape) ys: jnp.ndarray (carry, ys), variables = rnn.init_with_output(jax.random.key(0), carry, xs) cell_params = variables['params'] for batch_idx in range(batch_size): cell_carry = cell.initialize_carry(jax.random.key(0), xs[:1, 0].shape) for i in range(seq_len): cell_carry, y = cell.apply( {'params': cell_params}, cell_carry, xs[batch_idx : batch_idx + 1, i, :], ) np.testing.assert_allclose(y[0], ys[batch_idx, i, :], rtol=1e-5) carry_i = jax.tree_util.tree_map( lambda x: x[batch_idx : batch_idx + 1], carry ) np.testing.assert_allclose(cell_carry, carry_i, rtol=1e-5) def test_numerical_equivalence_single_batch_jax_scan(self): batch_size = 3 seq_len = 4 channels_in = 5 channels_out = 6 xs = jax.random.uniform( jax.random.key(0), (batch_size, seq_len, channels_in) ) cell: nn.LSTMCell = nn.LSTMCell(channels_out) carry = cell.initialize_carry(jax.random.key(0), xs[:, 0].shape) variables = cell.init(jax.random.key(0), carry, xs[:, 0]) cell_params = variables['params'] def scan_fn(carry, x): return cell.apply({'params': cell_params}, carry, x) ys: jnp.ndarray carry, ys = jax.lax.scan(scan_fn, carry, xs.swapaxes(0, 1)) ys = ys.swapaxes(0, 1) cell_carry = cell.initialize_carry(jax.random.key(0), xs[:, 0].shape) for i in range(seq_len): cell_carry, y = cell.apply( {'params': cell_params}, cell_carry, xs[:, i, :] ) np.testing.assert_allclose(y, ys[:, i, :], rtol=1e-4) np.testing.assert_allclose(cell_carry, carry, rtol=1e-4) def test_reverse(self): batch_size = 3 seq_len = 4 channels_in = 5 channels_out = 6 rnn = nn.RNN(nn.LSTMCell(channels_out), return_carry=True, reverse=True) xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray (carry, ys), variables = rnn.init_with_output(jax.random.key(0), xs) cell_params = variables['params']['cell'] for batch_idx in range(batch_size): cell_carry = rnn.cell.initialize_carry(jax.random.key(0), xs[:1, 0].shape) for i in range(seq_len): cell_carry, y = rnn.cell.apply( {'params': cell_params}, cell_carry, xs[batch_idx, seq_len - i - 1, :][None], ) np.testing.assert_allclose(y[0], ys[batch_idx, i, :], rtol=1e-5) np.testing.assert_allclose( cell_carry, jax.tree_util.tree_map(lambda x: x[batch_idx : batch_idx + 1], carry), rtol=1e-5, ) def test_reverse_but_keep_order(self): batch_size = 3 seq_len = 4 channels_in = 5 channels_out = 6 rnn = nn.RNN( nn.LSTMCell(channels_out), return_carry=True, reverse=True, keep_order=True, ) xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray (carry, ys), variables = rnn.init_with_output(jax.random.key(0), xs) cell_params = variables['params']['cell'] for batch_idx in range(batch_size): cell_carry = rnn.cell.initialize_carry(jax.random.key(0), xs[:1, 0].shape) for i in range(seq_len): cell_carry, y = rnn.cell.apply( {'params': cell_params}, cell_carry, xs[batch_idx, seq_len - i - 1, :][None], ) np.testing.assert_allclose( y[0], ys[batch_idx, seq_len - i - 1, :], rtol=1e-5 ) np.testing.assert_allclose( cell_carry, jax.tree_util.tree_map(lambda x: x[batch_idx : batch_idx + 1], carry), rtol=1e-5, ) def test_flip_sequence(self): x = jnp.arange(2 * 5).reshape((2, 5)) seq_lengths = jnp.array([4, 2]) flipped = flip_sequences(x, seq_lengths, num_batch_dims=1, time_major=False) self.assertEqual(flipped.shape, (2, 5)) np.testing.assert_allclose(flipped[0, :4], [3, 2, 1, 0]) np.testing.assert_allclose(flipped[1, :2], [6, 5]) def test_flip_sequence_more_feature_dims(self): x = jnp.arange(2 * 5 * 3).reshape((2, 5, 3)) seq_lengths = jnp.array([4, 2]) flipped = flip_sequences(x, seq_lengths, num_batch_dims=1, time_major=False) self.assertEqual(flipped.shape, (2, 5, 3)) np.testing.assert_allclose(flipped[0, :4], x[0, :4][::-1]) np.testing.assert_allclose(flipped[1, :2], x[1, :2][::-1]) def test_flip_sequence_time_major(self): x = jnp.arange(2 * 5).reshape((5, 2)) seq_lengths = jnp.array([4, 2]) flipped = flip_sequences(x, seq_lengths, num_batch_dims=1, time_major=True) self.assertEqual(flipped.shape, (5, 2)) np.testing.assert_allclose(flipped[:4, 0], x[:4, 0][::-1]) np.testing.assert_allclose(flipped[:2, 1], x[:2, 1][::-1]) def test_flip_sequence_time_major_more_feature_dims(self): x = jnp.arange(2 * 5 * 3).reshape((5, 2, 3)) seq_lengths = jnp.array([4, 2]) flipped = flip_sequences(x, seq_lengths, num_batch_dims=1, time_major=True) self.assertEqual(flipped.shape, (5, 2, 3)) np.testing.assert_allclose(flipped[:4, 0], x[:4, 0][::-1]) np.testing.assert_allclose(flipped[:2, 1], x[:2, 1][::-1]) def test_basic_seq_lengths(self): x = jnp.ones((2, 10, 6)) lstm = nn.RNN(nn.LSTMCell(265)) variables = lstm.init(jax.random.key(0), x) y = lstm.apply(variables, x, seq_lengths=jnp.array([5, 5])) class BidirectionalTest(absltest.TestCase): def test_bidirectional(self): batch_size = 3 seq_len = 4 channels_in = 5 channels_out = 6 bdirectional = nn.Bidirectional( nn.RNN(nn.LSTMCell(channels_out)), nn.RNN(nn.LSTMCell(channels_out)) ) xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray ys, variables = bdirectional.init_with_output(jax.random.key(0), xs) self.assertEqual(ys.shape, (batch_size, seq_len, channels_out * 2)) def test_shared_cell(self): batch_size = 3 seq_len = 4 channels_in = 5 channels_out = 6 cell = nn.LSTMCell(channels_out) bdirectional = nn.Bidirectional(nn.RNN(cell), nn.RNN(cell)) xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray ys, variables = bdirectional.init_with_output(jax.random.key(0), xs) self.assertEqual(ys.shape, (batch_size, seq_len, channels_out * 2)) def test_custom_merge_fn(self): batch_size = 3 seq_len = 4 channels_in = 5 channels_out = 6 bdirectional = nn.Bidirectional( nn.RNN(nn.LSTMCell(channels_out)), nn.RNN(nn.LSTMCell(channels_out)), merge_fn=lambda x, y: x + y, ) xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray ys, variables = bdirectional.init_with_output(jax.random.key(0), xs) self.assertEqual(ys.shape, (batch_size, seq_len, channels_out)) def test_return_carry(self): batch_size = 3 seq_len = 4 channels_in = 5 channels_out = 6 bdirectional = nn.Bidirectional( nn.RNN(nn.LSTMCell(channels_out)), nn.RNN(nn.LSTMCell(channels_out)), return_carry=True, ) xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray (carry, ys), variables = bdirectional.init_with_output( jax.random.key(0), xs ) carry_forward, carry_backward = carry self.assertEqual(ys.shape, (batch_size, seq_len, channels_out * 2)) self.assertEqual( jax.tree_util.tree_map(jnp.shape, carry_forward), ((batch_size, channels_out), (batch_size, channels_out)), ) self.assertEqual( jax.tree_util.tree_map(jnp.shape, carry_backward), ((batch_size, channels_out), (batch_size, channels_out)), ) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/linen/linen_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.linen.""" import copy import functools from typing import Any from absl.testing import absltest, parameterized from flax import ids from flax import linen as nn from flax.linen import fp8_ops from flax.training import train_state import jax from jax import random import jax.numpy as jnp import numpy as np import optax # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() def check_eq(xs, ys): return jax.tree_util.tree_all( jax.tree_util.tree_map(np.testing.assert_allclose, xs, ys) ) class PoolTest(parameterized.TestCase): def test_pool_custom_reduce(self): x = jnp.full((1, 3, 3, 1), 2.0) mul_reduce = lambda x, y: x * y y = nn.pooling.pool(x, 1.0, mul_reduce, (2, 2), (1, 1), 'VALID') np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0**4)) @parameterized.parameters( {'count_include_pad': True}, {'count_include_pad': False} ) def test_avg_pool(self, count_include_pad): x = jnp.full((1, 3, 3, 1), 2.0) pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad) y = pool(x) np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0)) y_grad = jax.grad(lambda x: pool(x).sum())(x) expected_grad = jnp.array( [ [0.25, 0.5, 0.25], [0.5, 1.0, 0.5], [0.25, 0.5, 0.25], ] ).reshape((1, 3, 3, 1)) np.testing.assert_allclose(y_grad, expected_grad) @parameterized.parameters( {'count_include_pad': True}, {'count_include_pad': False} ) def test_avg_pool_no_batch(self, count_include_pad): x = jnp.full((3, 3, 1), 2.0) pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad) y = pool(x) np.testing.assert_allclose(y, np.full((2, 2, 1), 2.0)) y_grad = jax.grad(lambda x: pool(x).sum())(x) expected_grad = jnp.array( [ [0.25, 0.5, 0.25], [0.5, 1.0, 0.5], [0.25, 0.5, 0.25], ] ).reshape((3, 3, 1)) np.testing.assert_allclose(y_grad, expected_grad) def test_max_pool(self): x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32) pool = lambda x: nn.max_pool(x, (2, 2)) expected_y = jnp.array( [ [4.0, 5.0], [7.0, 8.0], ] ).reshape((1, 2, 2, 1)) y = pool(x) np.testing.assert_allclose(y, expected_y) y_grad = jax.grad(lambda x: pool(x).sum())(x) expected_grad = jnp.array( [ [0.0, 0.0, 0.0], [0.0, 1.0, 1.0], [0.0, 1.0, 1.0], ] ).reshape((1, 3, 3, 1)) np.testing.assert_allclose(y_grad, expected_grad) @parameterized.parameters( {'count_include_pad': True}, {'count_include_pad': False} ) def test_avg_pool_padding_same(self, count_include_pad): x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1)) pool = lambda x: nn.avg_pool( x, (2, 2), padding='SAME', count_include_pad=count_include_pad ) y = pool(x) if count_include_pad: expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape( (1, 2, 2, 1) ) else: expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape( (1, 2, 2, 1) ) np.testing.assert_allclose(y, expected_y) def test_pooling_variable_batch_dims(self): x = jnp.zeros((1, 8, 32, 32, 3), dtype=jnp.float32) y = nn.max_pool(x, (2, 2), (2, 2)) assert y.shape == (1, 8, 16, 16, 3) def test_pooling_no_batch_dims(self): x = jnp.zeros((32, 32, 3), dtype=jnp.float32) y = nn.max_pool(x, (2, 2), (2, 2)) assert y.shape == (16, 16, 3) class NormalizationTest(parameterized.TestCase): def test_layer_norm_mask(self): key = random.key(0) keys = random.split(key) x = random.normal(keys[0], (3, 4, 5)) m = random.choice(keys[1], 2, x.shape).astype(bool) m = m.at[..., :2].set(True) # guarantee at least 2 elements x = jnp.where(m, x, jnp.nan) module = nn.LayerNorm() y, w = module.init_with_output(key, x, mask=m) z = y.mean(-1, where=m) np.testing.assert_allclose(z, 0, atol=1e-4) z = y.var(-1, where=m) np.testing.assert_allclose(z, 1, atol=1e-4) def test_rms_norm_mask(self): key = random.key(0) keys = random.split(key) x = random.normal(keys[0], (3, 4, 5)) m = random.choice(keys[1], 2, x.shape).astype(bool) m = m.at[..., :1].set(True) # guarantee at least 1 element x = jnp.where(m, x, jnp.nan) module = nn.RMSNorm() y, w = module.init_with_output(key, x, mask=m) z = np.square(y).mean(-1, where=m) np.testing.assert_allclose(z, 1, atol=1e-4) def test_group_norm_mask(self): key = random.key(0) keys = random.split(key) x = random.normal(keys[0], (13, 3, 5, 7 * 11)) m = random.choice(keys[1], 2, x.shape).astype(bool) m = m.at[..., :2].set(True) # guarantee at least 2 elements x = jnp.where(m, x, jnp.nan) module = nn.GroupNorm(7, use_bias=False, use_scale=False) y, w = module.init_with_output(key, x, mask=m) yr = y.reshape((13, 3, 5, 7, 11)) mr = m.reshape((13, 3, 5, 7, 11)) axes = list(range(1, x.ndim - 1)) + [-1] z = yr.mean(axes, where=mr) np.testing.assert_allclose(z, 0, atol=1e-4) z = yr.var(axes, where=mr) np.testing.assert_allclose(z, 1, atol=1e-4) @parameterized.parameters({'test_mask': True}, {'test_mask': False}) def test_batch_norm(self, test_mask): rng = random.key(0) key1, key2, key3 = random.split(rng, 3) x = random.normal(key1, (4, 3, 2)) if test_mask: m = random.randint( key2, (4, 3, 1), minval=0, maxval=2, dtype=jnp.int32 ).astype(jnp.bool_) x = jnp.where(m, x, jnp.ones_like(x) * jnp.nan) else: m = None model_cls = nn.BatchNorm(momentum=0.9, use_running_average=False) y, initial_params = model_cls.init_with_output(key3, x, mask=m) mean = y.mean((0, 1), where=m) var = y.var((0, 1), where=m) np.testing.assert_allclose(mean, np.array([0.0, 0.0]), atol=1e-4) np.testing.assert_allclose(var, np.array([1.0, 1.0]), rtol=1e-4) _, vars_out = model_cls.apply( initial_params, x, mutable=['batch_stats'], mask=m ) ema = vars_out['batch_stats'] np.testing.assert_allclose( ema['mean'], 0.1 * x.mean((0, 1), keepdims=False, where=m), atol=1e-4 ) np.testing.assert_allclose( ema['var'], 0.9 + 0.1 * x.var((0, 1), keepdims=False, where=m), rtol=1e-4, ) @parameterized.parameters({'test_mask': True}, {'test_mask': False}) def test_batch_norm_complex(self, test_mask): rng = random.key(0) key1, key2, key3 = random.split(rng, 3) x = random.normal(key1, (4, 3, 2), dtype=jnp.complex64) if test_mask: m = random.randint( key2, (4, 3, 1), minval=0, maxval=2, dtype=jnp.int32 ).astype(jnp.bool_) x = jnp.where(m, x, jnp.ones_like(x) * jnp.nan) else: m = None model_cls = nn.BatchNorm( momentum=0.9, use_running_average=False, dtype=jnp.complex64 ) y, initial_params = model_cls.init_with_output(key3, x, mask=m) mean = y.mean((0, 1), where=m) var = y.var((0, 1), where=m) np.testing.assert_allclose(mean, np.array([0.0, 0.0]), atol=1e-4) np.testing.assert_allclose(var, np.array([1.0, 1.0]), rtol=1e-4) self.assertEqual(mean.dtype, jnp.complex64) _, vars_out = model_cls.apply( initial_params, x, mutable=['batch_stats'], mask=m ) ema = vars_out['batch_stats'] np.testing.assert_allclose( ema['mean'], 0.1 * x.mean((0, 1), keepdims=False, where=m), atol=1e-4 ) np.testing.assert_allclose( ema['var'], 0.9 + 0.1 * x.var((0, 1), keepdims=False, where=m), rtol=1e-4, ) @parameterized.parameters( {'reduction_axes': -1}, {'reduction_axes': 1}, {'reduction_axes': (1, 2)}, {'reduction_axes': (0, 1, 2)}, {'reduction_axes': -1, 'use_fast_variance': False}, ) def test_layer_norm(self, reduction_axes, use_fast_variance=True): rng = random.key(0) key1, key2 = random.split(rng) e = 1e-5 x = random.normal(key1, (2, 3, 4)) if not use_fast_variance: x += 1e6 # This blows up fast variance, but should work otherwise. model_cls = nn.LayerNorm( use_bias=False, use_scale=False, epsilon=e, reduction_axes=reduction_axes, use_fast_variance=use_fast_variance, ) y, _ = model_cls.init_with_output(key2, x) self.assertEqual(x.dtype, y.dtype) self.assertEqual(x.shape, y.shape) y_one_liner = ( x - x.mean(axis=reduction_axes, keepdims=True) ) * jax.lax.rsqrt(x.var(axis=reduction_axes, keepdims=True) + e) np.testing.assert_allclose(y_one_liner, y, atol=1e-3, rtol=1e-3) @parameterized.parameters( {'reduction_axes': -1}, {'reduction_axes': 1}, {'reduction_axes': (1, 2)} ) def test_rms_norm(self, reduction_axes): rng = random.key(0) key1, key2 = random.split(rng) e = 1e-5 x = random.normal(key1, (2, 3, 4)) model_cls = nn.RMSNorm( use_scale=False, epsilon=e, reduction_axes=reduction_axes ) y, _ = model_cls.init_with_output(key2, x) self.assertEqual(x.dtype, y.dtype) self.assertEqual(x.shape, y.shape) y_one_liner = x * jax.lax.rsqrt( jnp.mean(jax.lax.square(x), axis=reduction_axes, keepdims=True) + e ) np.testing.assert_allclose(y_one_liner, y, atol=1e-4) def test_group_norm(self): rng = random.key(0) key1, key2 = random.split(rng) e = 1e-5 x = random.normal(key1, (2, 5, 4, 4, 32)) model_cls = nn.GroupNorm( num_groups=2, use_bias=False, use_scale=False, epsilon=e ) y, _ = model_cls.init_with_output(key2, x) self.assertEqual(x.dtype, y.dtype) self.assertEqual(x.shape, y.shape) x_gr = x.reshape([2, 5, 4, 4, 2, 16]) y_test = ( x_gr - x_gr.mean(axis=[1, 2, 3, 5], keepdims=True) ) * jax.lax.rsqrt(x_gr.var(axis=[1, 2, 3, 5], keepdims=True) + e) y_test = y_test.reshape([2, 5, 4, 4, 32]) np.testing.assert_allclose(y_test, y, atol=1e-4) def test_group_norm_unbatched(self): rng = random.key(0) key1, key2 = random.split(rng) e = 1e-5 x = random.normal(key1, (2, 5, 4, 4, 32)) model_cls = nn.GroupNorm( num_groups=2, use_bias=False, use_scale=False, epsilon=e, reduction_axes=(0, 1, 3, 4), ) y, _ = model_cls.init_with_output(key2, x) self.assertEqual(x.dtype, y.dtype) self.assertEqual(x.shape, y.shape) x_gr = x.reshape([2, 5, 4, 4, 2, 16]) y_test = ( x_gr - x_gr.mean(axis=[0, 1, 3, 5], keepdims=True) ) * jax.lax.rsqrt(x_gr.var(axis=[0, 1, 3, 5], keepdims=True) + e) y_test = y_test.reshape([2, 5, 4, 4, 32]) np.testing.assert_allclose(y_test, y, atol=1e-4) def test_group_norm_batched(self): rng = random.key(0) key1, key2 = random.split(rng) e = 1e-5 x = random.normal(key1, (3, 4, 32)) model_cls = nn.GroupNorm( num_groups=2, use_bias=False, use_scale=False, epsilon=e, reduction_axes=(-3, -2, -1), ) y, _ = model_cls.init_with_output(key2, x) self.assertEqual(x.dtype, y.dtype) self.assertEqual(x.shape, y.shape) x_stacked = jnp.stack([x] * 5) y_stacked = model_cls.apply({}, x_stacked) np.testing.assert_allclose(y, y_stacked[0, ...], atol=1e-4) x_gr = x_stacked.reshape([5, 3, 4, 2, 16]) y_test = (x_gr - x_gr.mean(axis=[1, 2, 4], keepdims=True)) * jax.lax.rsqrt( x_gr.var(axis=[1, 2, 4], keepdims=True) + e ) y_test = y_test.reshape([5, 3, 4, 32]) np.testing.assert_allclose(y_test, y_stacked, atol=1e-4) def test_group_norm_raises(self): rng = random.key(0) key1, key2 = random.split(rng) e = 1e-5 x = random.normal(key1, (2, 5, 4, 4, 32)) model_cls = nn.GroupNorm( num_groups=3, use_bias=False, use_scale=False, epsilon=e ) with self.assertRaises(ValueError): model_cls.init_with_output(key2, x) def test_group_norm_raises_incorrect_reduction_axes(self): rng = random.key(0) key1, key2 = random.split(rng) e = 1e-5 x = random.normal(key1, (2, 5, 4, 4, 32)) model_cls = nn.GroupNorm( num_groups=3, use_bias=False, use_scale=False, epsilon=e, reduction_axes=(0, 1, 2, 3), ) with self.assertRaises(ValueError): model_cls.init_with_output(key2, x) def test_batch_norm_multi_init(self): class Foo(nn.Module): @nn.compact def __call__(self, x): norm = nn.BatchNorm( name='norm', use_running_average=False, axis_name='batch', ) x = norm(x) return x, norm(x) key = random.key(0) model = Foo() x = random.normal(random.key(1), (2, 4)) (y1, y2), _ = model.init_with_output(key, x) np.testing.assert_allclose(y1, y2, rtol=0.005) @parameterized.parameters( {'feature_axes': -1}, {'feature_axes': (1, 2)}, {'feature_axes': (1, 2, 3)}, {'feature_axes': -1, 'use_fast_variance': False}, ) def test_instance_norm(self, feature_axes, use_fast_variance=True): rng = random.key(0) key1, key2 = random.split(rng) e = 1e-5 x = random.normal(key1, (2, 3, 4, 5)) if not use_fast_variance: x += 1e4 # This blows up fast variance, but should work otherwise. model_cls = nn.InstanceNorm( use_bias=False, use_scale=False, epsilon=e, feature_axes=feature_axes, use_fast_variance=use_fast_variance, ) y, _ = model_cls.init_with_output(key2, x) self.assertEqual(x.dtype, y.dtype) self.assertEqual(x.shape, y.shape) canonicalized_feature_axes = [ i if i >= 0 else (x.ndim + i) for i in ( feature_axes if isinstance(feature_axes, tuple) else (feature_axes,) ) ] reduction_axes = [ i for i in range(1, x.ndim) if i not in canonicalized_feature_axes ] y_one_liner = ( x - x.mean(axis=reduction_axes, keepdims=True) ) * jax.lax.rsqrt(x.var(axis=reduction_axes, keepdims=True) + e) np.testing.assert_allclose(y_one_liner, y, atol=1e-6) @parameterized.parameters( {'feature_axes': 0}, {'feature_axes': -4}, {'feature_axes': (0, 3)}, {'feature_axes': (2, -4)}, ) def test_instance_norm_raise_error(self, feature_axes): with self.assertRaisesRegex( ValueError, 'The channel axes cannot include the leading dimension ' 'as this is assumed to be the batch axis.', ): x = jax.random.normal(jax.random.key(0), (2, 3, 4, 5)) layer = nn.InstanceNorm(feature_axes=feature_axes) _ = layer.init(jax.random.key(1), x) @parameterized.parameters( { 'layer1': nn.LayerNorm(feature_axes=(1, 2)), 'layer2': nn.InstanceNorm(feature_axes=(1, 2)), }, { 'layer1': nn.LayerNorm(reduction_axes=(1, 2), feature_axes=-1), 'layer2': nn.InstanceNorm(feature_axes=-1), }, { 'layer1': nn.LayerNorm( reduction_axes=-2, feature_axes=(1, 3), bias_init=nn.initializers.uniform(), scale_init=nn.initializers.uniform(), ), 'layer2': nn.InstanceNorm( feature_axes=(1, -1), bias_init=nn.initializers.uniform(), scale_init=nn.initializers.uniform(), ), }, { 'layer1': nn.LayerNorm( reduction_axes=(1, 2, 3), bias_init=nn.initializers.uniform(), scale_init=nn.initializers.uniform(), ), 'layer2': nn.GroupNorm( num_groups=1, bias_init=nn.initializers.uniform(), scale_init=nn.initializers.uniform() ), }, { 'layer1': nn.InstanceNorm( bias_init=nn.initializers.uniform(), scale_init=nn.initializers.uniform(), ), 'layer2': nn.GroupNorm( num_groups=None, group_size=1, bias_init=nn.initializers.uniform(), scale_init=nn.initializers.uniform(), ), }, ) def test_normalization_equivalence(self, layer1, layer2): x = jax.random.normal(jax.random.key(0), (2, 3, 4, 5)) layer1_variables = layer1.init(jax.random.key(1), x) layer2_variables = layer2.init(jax.random.key(1), x) self.assertTrue( jax.tree_util.tree_all( jax.tree_util.tree_map( lambda v1, v2: (v1 == v2).all(), layer1_variables, layer2_variables, ) ) ) layer1_y = layer1.apply(layer1_variables, x) layer2_y = layer2.apply(layer2_variables, x) np.testing.assert_allclose(layer1_y, layer2_y, atol=1e-7) @parameterized.parameters( { 'model_index': 0, 'key_paths': {'Dense_1/kernel/u', 'Dense_1/kernel/sigma'}, }, { 'model_index': 1, 'key_paths': {'Conv_0/kernel/u', 'Conv_0/kernel/sigma'}, }, { 'model_index': 2, 'key_paths': { 'MultiHeadDotProductAttention_0/key/bias/u', 'MultiHeadDotProductAttention_0/key/kernel/u', 'MultiHeadDotProductAttention_0/out/kernel/u', 'MultiHeadDotProductAttention_0/query/bias/u', 'MultiHeadDotProductAttention_0/query/kernel/u', 'MultiHeadDotProductAttention_0/value/bias/u', 'MultiHeadDotProductAttention_0/value/kernel/u', 'MultiHeadDotProductAttention_0/key/bias/sigma', 'MultiHeadDotProductAttention_0/key/kernel/sigma', 'MultiHeadDotProductAttention_0/out/kernel/sigma', 'MultiHeadDotProductAttention_0/query/bias/sigma', 'MultiHeadDotProductAttention_0/query/kernel/sigma', 'MultiHeadDotProductAttention_0/value/bias/sigma', 'MultiHeadDotProductAttention_0/value/kernel/sigma', }, }, ) def test_spectral_norm_train(self, model_index, key_paths): class FooDense(nn.Module): @nn.compact def __call__(self, x, train): x = nn.Dense(8)(x) x = nn.SpectralNorm(nn.Dense(6))(x, update_stats=train) x = nn.Dense(4)(x) return x class FooConv(nn.Module): @nn.compact def __call__(self, x, train): x = nn.Dense(9)(x) x = x.reshape((1, 3, 3)) x = nn.SpectralNorm(nn.Conv(2, kernel_size=(2, 2)))( x, update_stats=train ) x = x.reshape(1, -1) x = nn.Dense(4)(x) return x class FooAttention(nn.Module): @nn.compact def __call__(self, x, train): a = nn.Dense(4)(x) b = nn.Dense(4)(x) x = nn.SpectralNorm(nn.attention.MultiHeadDotProductAttention(4))( a, b, update_stats=train ) x = nn.Dense(4)(x) return x key1, key2, key3 = random.split(random.PRNGKey(0), 3) x = random.normal(key1, (1, 4)) y = random.normal(key2, (1, 4)) model_cls = (FooDense, FooConv, FooAttention)[model_index] variables = model_cls().init(key3, x, train=False) params, batch_stats = variables['params'], variables['batch_stats'] self.assertEqual(key_paths, batch_stats['SpectralNorm_0'].keys()) class TrainState(train_state.TrainState): batch_stats: Any state = TrainState.create( apply_fn=model_cls().apply, params=params, batch_stats=batch_stats, tx=optax.adam(1e-3), ) @jax.jit def train_step(state, batch): def loss_fn(params): logits, updates = state.apply_fn( {'params': params, 'batch_stats': state.batch_stats}, x=batch['image'], train=True, mutable=['batch_stats'], ) loss = jnp.mean( optax.l2_loss(predictions=logits, targets=batch['label']) ) return loss, updates grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (loss, updates), grads = grad_fn(state.params) state = state.apply_gradients(grads=grads) state = state.replace(batch_stats=updates['batch_stats']) return state, loss prev_loss = float('inf') for _ in range(10): state, loss = train_step(state, {'image': x, 'label': y}) self.assertLess(loss, prev_loss) prev_loss = loss @parameterized.parameters( {'n_steps': 1, 'update_stats': True, 'result': 4.0}, {'n_steps': 3, 'update_stats': True, 'result': 4.0}, {'n_steps': 10, 'update_stats': True, 'result': 4.0}, {'n_steps': 1, 'update_stats': False, 'result': 1.0}, ) def test_spectral_norm_sigma(self, n_steps, update_stats, result): class Foo(nn.Module): @nn.compact def __call__(self, x, train): x = nn.SpectralNorm(nn.Dense(8, use_bias=False), n_steps=n_steps)( x, update_stats=train ) return x x = jnp.ones((1, 8)) model_cls = Foo() variables = model_cls.init(random.PRNGKey(0), x, train=False) params, batch_stats = variables['params'], variables['batch_stats'] params = jax.tree_util.tree_map(lambda x: 4 * jnp.eye(*x.shape), params) _, updates = model_cls.apply( {'params': params, 'batch_stats': batch_stats}, x=x, train=update_stats, mutable=True, ) np.testing.assert_allclose( updates['batch_stats']['SpectralNorm_0']['Dense_0/kernel/sigma'], result, atol=1e-3, ) @parameterized.parameters( {'error_on_non_matrix': True}, {'error_on_non_matrix': False} ) def test_spectral_norm_3d_tensor(self, error_on_non_matrix): class Foo(nn.Module): @nn.compact def __call__(self, x, train): x = nn.SpectralNorm( nn.DenseGeneral((3, 4), use_bias=False), error_on_non_matrix=error_on_non_matrix, )(x, update_stats=train) return x x = jnp.ones((1, 2)) model_cls = Foo() if error_on_non_matrix: with self.assertRaisesRegex( ValueError, 'Input is 3D but error_on_non_matrix is True' ): _ = model_cls.init(random.PRNGKey(0), x, train=False) else: _ = model_cls.init(random.PRNGKey(0), x, train=False) @parameterized.parameters( {'feature_axes': -1, 'reduction_axes': 0, 'variable_filter': {'kernel'}}, {'feature_axes': 0, 'reduction_axes': 1, 'variable_filter': {'kernel'}}, { 'feature_axes': (0, 1), 'reduction_axes': (), 'variable_filter': {'kernel'}, }, { 'feature_axes': (), 'reduction_axes': (0, 1), 'variable_filter': {'kernel'}, }, { 'feature_axes': None, 'reduction_axes': (0, 1), 'variable_filter': {'kernel'}, }, {'feature_axes': 0, 'reduction_axes': (), 'variable_filter': {'bias'}}, {'feature_axes': (), 'reduction_axes': -1, 'variable_filter': {'bias'}}, ) def test_manual_weight_norm( self, feature_axes, reduction_axes, variable_filter ): class Foo(nn.Module): @nn.compact def __call__(self, x): return nn.WeightNorm( nn.Dense(2, bias_init=nn.initializers.normal()), feature_axes=feature_axes, variable_filter=variable_filter, )(x) key1, key2 = jax.random.split(jax.random.key(1)) x = jax.random.normal(key1, (1, 3)) module = Foo() v = module.init(key2, x) v = jax.tree_util.tree_map(lambda x: x + 0.5, v) out = module.apply(v, x) kernel = v['params']['Dense_0']['kernel'] if 'kernel' in variable_filter: kernel /= jnp.sqrt(jnp.sum(kernel**2, axis=reduction_axes, keepdims=True)) kernel_scale = jnp.expand_dims( v['params']['WeightNorm_0']['Dense_0/kernel/scale'], axis=reduction_axes, ) else: kernel_scale = 1 bias = v['params']['Dense_0']['bias'] if 'bias' in variable_filter: bias /= jnp.sqrt(jnp.sum(bias**2, axis=reduction_axes, keepdims=True)) bias_scale = jnp.expand_dims( v['params']['WeightNorm_0']['Dense_0/bias/scale'], axis=reduction_axes ) else: bias_scale = 1 manual_out = jnp.dot(x, kernel_scale * kernel) + ( bias_scale * bias ).reshape(1, -1) self.assertTrue(jnp.allclose(out, manual_out)) @parameterized.parameters( { 'variable_filters': ({}, None, {'kernel', 'bias'}, {'Bar'}), 'key_paths': { 'Bar_0/Baz_0/Dense_0/kernel/scale', 'Bar_0/Baz_0/Dense_0/bias/scale', 'Bar_0/Dense_0/kernel/scale', 'Bar_0/Dense_0/bias/scale', 'Bar_0/Baz_1/Dense_0/kernel/scale', 'Bar_0/Baz_1/Dense_0/bias/scale', 'Bar_0/Dense_1/kernel/scale', 'Bar_0/Dense_1/bias/scale', }, }, { 'variable_filters': ({'kernel'},), 'key_paths': { 'Bar_0/Baz_0/Dense_0/kernel/scale', 'Bar_0/Dense_0/kernel/scale', 'Bar_0/Baz_1/Dense_0/kernel/scale', 'Bar_0/Dense_1/kernel/scale', }, }, { 'variable_filters': ({'Baz', 'kernel'},), 'key_paths': { 'Bar_0/Baz_0/Dense_0/kernel/scale', 'Bar_0/Baz_0/Dense_0/bias/scale', 'Bar_0/Dense_0/kernel/scale', 'Bar_0/Baz_1/Dense_0/kernel/scale', 'Bar_0/Baz_1/Dense_0/bias/scale', 'Bar_0/Dense_1/kernel/scale', }, }, ) def test_weight_norm_variable_filter(self, variable_filters, key_paths): class Baz(nn.Module): @nn.compact def __call__(self, x): return nn.Dense(2)(x) class Bar(nn.Module): @nn.compact def __call__(self, x): x = Baz()(x) x = nn.Dense(3)(x) x = Baz()(x) x = nn.Dense(3)(x) return x for variable_filter in variable_filters: class Foo(nn.Module): @nn.compact def __call__(self, x): return nn.WeightNorm(Bar(), variable_filter=variable_filter)(x) v = Foo().init(jax.random.key(0), jnp.ones((1, 4))) self.assertEqual(key_paths, v['params']['WeightNorm_0'].keys()) @parameterized.parameters( {'model_index': 0, 'key_paths': {'Dense_1/kernel/scale'}}, {'model_index': 1, 'key_paths': {'Conv_0/kernel/scale'}}, { 'model_index': 2, 'key_paths': { 'MultiHeadDotProductAttention_0/key/kernel/scale', 'MultiHeadDotProductAttention_0/out/kernel/scale', 'MultiHeadDotProductAttention_0/query/kernel/scale', 'MultiHeadDotProductAttention_0/value/kernel/scale', }, }, ) def test_weight_norm_train(self, model_index, key_paths): class FooDense(nn.Module): @nn.compact def __call__( self, x, ): x = nn.Dense(8)(x) x = nn.WeightNorm(nn.Dense(6))(x) x = nn.Dense(4)(x) return x class FooConv(nn.Module): @nn.compact def __call__( self, x, ): x = nn.Dense(9)(x) x = x.reshape((1, 3, 3)) x = nn.WeightNorm(nn.Conv(2, kernel_size=(2, 2)))(x) x = x.reshape(1, -1) x = nn.Dense(4)(x) return x class FooAttention(nn.Module): @nn.compact def __call__(self, x): a = nn.Dense(4)(x) b = nn.Dense(4)(x) x = nn.WeightNorm(nn.attention.MultiHeadDotProductAttention(4))(a, b) x = nn.Dense(4)(x) return x key1, key2, key3 = random.split(random.PRNGKey(0), 3) x = random.normal(key1, (1, 4)) y = random.normal(key2, (1, 4)) model_cls = (FooDense, FooConv, FooAttention)[model_index] params = model_cls().init(key3, x)['params'] self.assertEqual(key_paths, params['WeightNorm_0'].keys()) state = train_state.TrainState.create( apply_fn=model_cls().apply, params=params, tx=optax.adam(1e-3), ) @jax.jit def train_step(state, batch): def loss_fn(params): logits = state.apply_fn( {'params': params}, x=batch['image'], ) loss = jnp.mean( optax.l2_loss(predictions=logits, targets=batch['label']) ) return loss grad_fn = jax.value_and_grad(loss_fn) loss, grads = grad_fn(state.params) state = state.apply_gradients(grads=grads) return state, loss prev_loss = float('inf') for _ in range(10): state, loss = train_step(state, {'image': x, 'label': y}) self.assertLess(loss, prev_loss) prev_loss = loss def test_weight_norm_compatibility_with_partitioning(self): replicated_module = nn.WeightNorm(nn.Dense(10)) @jax.jit def _init_replicated(x): return replicated_module.init(jax.random.key(0), x) expected = _init_replicated(jnp.ones((10, 10))) sharded_module = nn.WeightNorm( nn.Dense( 10, kernel_init=nn.with_partitioning( nn.initializers.lecun_normal(), ('x',) ), ) ) @jax.jit def _init_sharded(x): return sharded_module.init(jax.random.key(0), x) got = _init_sharded(jnp.ones((10, 10))) def _strip_partitioning(x): if isinstance(x, nn.Partitioned): return x.value return x got = jax.tree.map( _strip_partitioning, got, is_leaf=lambda x: isinstance(x, nn.Partitioned), ) expected_scale = expected['params'].pop('layer_instance/kernel/scale') # NOTE: `value` is from `nn.Partitioned.value` got_scale = got['params'].pop('layer_instance/kernel/value/scale') np.testing.assert_array_equal(got_scale, expected_scale) # Compares the rest of PyTree nodes jax.tree.map(lambda x, y: np.testing.assert_allclose(x, y), expected, got) class StochasticTest(parameterized.TestCase): def test_dropout(self): rng = random.key(0) key1, key2 = random.split(rng) module = nn.Dropout(rate=0.5) y1 = module.apply( {}, jnp.ones((20, 20)), deterministic=False, rngs={'dropout': key1} ) y2 = module.apply( {}, jnp.ones((20, 20)), deterministic=False, rngs={'dropout': key2} ) self.assertFalse(np.all(y1 == y2)) y1 = module.apply( {}, jnp.ones((20, 20)), deterministic=True, rngs={'dropout': key1} ) y2 = module.apply( {}, jnp.ones((20, 20)), deterministic=True, rngs={'dropout': key2} ) self.assertTrue(np.all(y1 == y2)) def test_dropout_rate_stats(self): rootkey = random.key(0) for rate in np.arange(0.1, 1.0, 0.1): rootkey, subkey = random.split(rootkey) module = nn.Dropout(rate=rate) n_trials = 10 nonzero_counts = 0 for key in random.split(subkey, n_trials): y = module.apply( {}, jnp.ones((100, 100)), deterministic=False, rngs={'dropout': key} ) nonzero_counts += np.sum(y > 0.0) all_counts = np.prod((100, 100, n_trials)) frac = np.sum(nonzero_counts) / all_counts keep_rate = 1.0 - rate # just check within 4 sigma. delta = 4 * np.sqrt(rate * keep_rate) / np.sqrt(all_counts) self.assertTrue(keep_rate - delta < frac < keep_rate + delta) def test_dropout_rate_limits(self): rng = random.key(0) key1, key2, key3 = random.split(rng, 3) inputs = jnp.ones((20, 20)) d0 = nn.Dropout(rate=0.0) y1 = d0.apply({}, inputs, deterministic=False, rngs={'dropout': key1}) np.testing.assert_array_equal(y1, inputs) d1 = nn.Dropout(rate=1.0) y2 = d1.apply({}, inputs, deterministic=False, rngs={'dropout': key2}) np.testing.assert_array_equal(y2, np.zeros_like(inputs)) # ensure gradient of rate==1.0 case is non-NaN fn = lambda x, k: d1.apply({}, x, rngs={'dropout': k}, deterministic=False) res = jax.grad(lambda x, k: jnp.sum(fn(x, k)))(inputs, key3) self.assertFalse(np.isnan(res).any()) @parameterized.parameters( { 'num_dims': 2, 'broadcast_dims': (1,), 'slice_fn': lambda out, i: out[i, :], 'summed_total': 2 * 10, }, { 'num_dims': 2, 'broadcast_dims': (0,), 'slice_fn': lambda out, i: out[:, i], 'summed_total': 2 * 10, }, { 'num_dims': 3, 'broadcast_dims': (1, 2), 'slice_fn': lambda out, i: out[i, :, :], 'summed_total': 2 * 10 * 10, }, { 'num_dims': 3, 'broadcast_dims': (1,), 'slice_fn': lambda out, i, j: out[i, :, j], 'summed_total': 2 * 10, }, { 'num_dims': 4, 'broadcast_dims': (0, 2, 3), 'slice_fn': lambda out, i: out[:, i, :, :], 'summed_total': 2 * 10 * 10 * 10, }, { 'num_dims': 4, 'broadcast_dims': (0, 1), 'slice_fn': lambda out, i, j: out[:, :, i, j], 'summed_total': 2 * 10 * 10, }, { 'num_dims': 4, 'broadcast_dims': (3,), 'slice_fn': lambda out, i, j, k: out[i, j, k, :], 'summed_total': 2 * 10, }, ) def test_dropout_broadcast( self, num_dims, broadcast_dims, slice_fn, summed_total ): module = nn.Dropout( rate=0.5, broadcast_dims=broadcast_dims, deterministic=False ) x = jnp.ones((10,) * num_dims) out = module.apply({}, x, rngs={'dropout': random.key(0)}) for i in range(10): if num_dims - len(broadcast_dims) >= 2: for j in range(10): if num_dims - len(broadcast_dims) >= 3: for k in range(10): self.assertTrue(slice_fn(out, i, j, k).sum() in (0, summed_total)) else: self.assertTrue(slice_fn(out, i, j).sum() in (0, summed_total)) else: self.assertTrue(slice_fn(out, i).sum() in (0, summed_total)) def test_dropout_manual_rng(self): class Foo(nn.Module): @nn.compact def __call__(self, x): key = self.make_rng('dropout') x1 = nn.Dropout(rate=0.5, deterministic=False)(x, rng=key) x2 = nn.Dropout(rate=0.5, deterministic=False)(x, rng=jax.random.clone(key)) return x1, x2 module = Foo() x1, x2 = module.apply( {}, jnp.ones((20, 20)), rngs={'dropout': random.key(0)} ) np.testing.assert_array_equal(x1, x2) # TODO(flax-dev): add integration tests for RNN cells class RecurrentTest(parameterized.TestCase): def test_lstm(self): lstm = nn.LSTMCell(features=4) rng = random.key(0) rng, key1, key2 = random.split(rng, 3) x = random.normal(key1, (2, 3)) c0, h0 = lstm.initialize_carry(rng, x.shape) self.assertEqual(c0.shape, (2, 4)) self.assertEqual(h0.shape, (2, 4)) (carry, y), initial_params = lstm.init_with_output(key2, (c0, h0), x) self.assertEqual(carry[0].shape, (2, 4)) self.assertEqual(carry[1].shape, (2, 4)) np.testing.assert_allclose(y, carry[1]) param_shapes = jax.tree_util.tree_map(np.shape, initial_params['params']) self.assertEqual( param_shapes, { 'ii': {'kernel': (3, 4)}, 'if': {'kernel': (3, 4)}, 'ig': {'kernel': (3, 4)}, 'io': {'kernel': (3, 4)}, 'hi': {'kernel': (4, 4), 'bias': (4,)}, 'hf': {'kernel': (4, 4), 'bias': (4,)}, 'hg': {'kernel': (4, 4), 'bias': (4,)}, 'ho': {'kernel': (4, 4), 'bias': (4,)}, }, ) @parameterized.parameters( { 'module_cls': nn.SimpleCell, 'expected_param_shapes': { 'i': {'kernel': (3, 4), 'bias': (4,)}, 'h': {'kernel': (4, 4)}, }, }, { 'module_cls': nn.GRUCell, 'expected_param_shapes': { 'ir': {'kernel': (3, 4), 'bias': (4,)}, 'iz': {'kernel': (3, 4), 'bias': (4,)}, 'in': {'kernel': (3, 4), 'bias': (4,)}, 'hr': {'kernel': (4, 4)}, 'hz': {'kernel': (4, 4)}, 'hn': {'kernel': (4, 4), 'bias': (4,)}, }, }, { 'module_cls': nn.MGUCell, 'expected_param_shapes': { 'if': {'kernel': (3, 4), 'bias': (4,)}, 'in': {'kernel': (3, 4), 'bias': (4,)}, 'hf': {'kernel': (4, 4)}, 'hn': {'kernel': (4, 4), 'bias': (4,)}, }, }, ) def test_gated_units(self, module_cls, expected_param_shapes): module = module_cls(features=4) rng = random.key(0) rng, key1, key2 = random.split(rng, 3) x = random.normal(key1, (2, 3)) carry0 = module.initialize_carry(rng, x.shape) self.assertEqual(carry0.shape, (2, 4)) (carry, y), initial_params = module.init_with_output(key2, carry0, x) self.assertEqual(carry.shape, (2, 4)) np.testing.assert_allclose(y, carry) param_shapes = jax.tree_util.tree_map(np.shape, initial_params['params']) self.assertEqual( param_shapes, expected_param_shapes, ) if module_cls == nn.MGUCell: self.assertTrue( (initial_params['params']['if']['bias'] == jnp.ones((4,))).all() ) self.assertTrue( (initial_params['params']['in']['bias'] == jnp.zeros((4,))).all() ) self.assertTrue( (initial_params['params']['hn']['bias'] == jnp.zeros((4,))).all() ) @parameterized.parameters( {'module_cls': nn.SimpleCell}, {'module_cls': nn.GRUCell}, {'module_cls': nn.MGUCell}, ) def test_complex_input_gated_units(self, module_cls): module_instance = module_cls(features=4) rng = random.key(0) rng, key1, key2 = random.split(rng, 3) x = random.normal(key1, (2, 3), dtype=jnp.complex64) carry0 = module_instance.initialize_carry(rng, x.shape) self.assertEqual(carry0.shape, (2, 4)) (carry, y), _ = module_instance.init_with_output(key2, carry0, x) self.assertEqual(carry.dtype, jnp.complex64) self.assertEqual(y.dtype, jnp.complex64) def test_convlstm(self): lstm = nn.ConvLSTMCell(features=6, kernel_size=(3, 3)) rng = random.key(0) rng, key1, key2 = random.split(rng, 3) x = random.normal(key1, (2, 4, 4, 3)) c0, h0 = lstm.initialize_carry(rng, x.shape) self.assertEqual(c0.shape, (2, 4, 4, 6)) self.assertEqual(h0.shape, (2, 4, 4, 6)) (carry, y), initial_params = lstm.init_with_output(key2, (c0, h0), x) self.assertEqual(carry[0].shape, (2, 4, 4, 6)) self.assertEqual(carry[1].shape, (2, 4, 4, 6)) np.testing.assert_allclose(y, carry[1]) param_shapes = jax.tree_util.tree_map(np.shape, initial_params['params']) self.assertEqual( param_shapes, { 'hh': {'bias': (6 * 4,), 'kernel': (3, 3, 6, 6 * 4)}, 'ih': {'bias': (6 * 4,), 'kernel': (3, 3, 3, 6 * 4)}, }, ) def test_optimized_lstm_cell_matches_regular(self): # Create regular LSTMCell. lstm = nn.LSTMCell(features=4) rng = random.key(0) rng, key1, key2 = random.split(rng, 3) x = random.normal(key1, (2, 3)) c0, h0 = lstm.initialize_carry(rng, x.shape) self.assertEqual(c0.shape, (2, 4)) self.assertEqual(h0.shape, (2, 4)) (_, y), lstm_params = lstm.init_with_output(key2, (c0, h0), x) # Create OptimizedLSTMCell. lstm_opt = nn.OptimizedLSTMCell(features=4) rng = random.key(0) rng, key1, key2 = random.split(rng, 3) x = random.normal(key1, (2, 3)) c0, h0 = lstm_opt.initialize_carry(rng, x.shape) self.assertEqual(c0.shape, (2, 4)) self.assertEqual(h0.shape, (2, 4)) (_, y_opt), lstm_opt_params = lstm_opt.init_with_output(key2, (c0, h0), x) np.testing.assert_allclose(y, y_opt, rtol=1e-6) check_eq(lstm_params, lstm_opt_params) def test_mgu_reset_gate(self): module = nn.MGUCell(features=4, reset_gate=False) rng = random.key(0) rng, key1, key2 = random.split(rng, 3) x = random.normal(key1, (2, 3)) carry0 = module.initialize_carry(rng, x.shape) (carry, y), v = module.init_with_output(key2, carry0, x) self.assertIn('kernel', v['params']['hn']) self.assertNotIn('bias', v['params']['hn']) f = jax.nn.sigmoid( jnp.dot(x, v['params']['if']['kernel']) + v['params']['if']['bias'].reshape(1, -1) + jnp.dot(carry0, v['params']['hf']['kernel']) ) n = jax.nn.tanh( jnp.dot(x, v['params']['in']['kernel']) + v['params']['in']['bias'].reshape(1, -1) + jnp.dot(carry0, v['params']['hn']['kernel']) ) expected_out = (1 - f) * n + f * carry0 np.testing.assert_allclose(y, expected_out) class IdsTest(absltest.TestCase): def test_hashable(self): id1 = ids.uuid() id2 = ids.uuid() self.assertEqual(id1, id1) self.assertNotEqual(id1, id2) self.assertNotEqual(hash(id1), hash(id2)) id1c = copy.copy(id1) id1dc = copy.deepcopy(id1) self.assertNotEqual(hash(id1), hash(id1c)) self.assertNotEqual(hash(id1), hash(id1dc)) def get_fp8_dtypes(fp8_genre): assert fp8_genre in ('OCP', 'NANOO') if fp8_genre == 'OCP': e4m3_dtype = jnp.float8_e4m3fn e5m2_dtype = jnp.float8_e5m2 else: # fp8_genre == 'NANOO' e4m3_dtype = jnp.float8_e4m3fnuz e5m2_dtype = jnp.float8_e5m2fnuz return e4m3_dtype, e5m2_dtype class Fp8Test(parameterized.TestCase): @parameterized.parameters( {'x_shape': (16, 32), 'y_shape': (32, 64), 'g_shape': (16, 64), 'eqn': 'mk,kn->mn'}, {'x_shape': (2, 3, 32), 'y_shape': (64, 32), 'g_shape': (2, 3, 64), 'eqn': '...k,nk->...n'}, {'x_shape': (2, 3, 64), 'y_shape': (64, 32), 'g_shape': (2, 3, 32), 'eqn': '...k,kn->...n'}, ) def test_fp8_einsum(self, x_shape, y_shape, g_shape, eqn): rng, key1, key2, key3 = random.split(random.key(42), 4) x = random.normal(key1, x_shape) y = random.normal(key2, y_shape) g = random.normal(key3, g_shape) e4m3_dtype = jnp.float8_e4m3fn e5m2_dtype = jnp.float8_e5m2 cast_to_representable = functools.partial( fp8_ops.qdq, scale=jnp.ones((1,)), compute_dtype=jnp.float32, ) x = cast_to_representable(x, e4m3_dtype) y = cast_to_representable(y, e4m3_dtype) g = cast_to_representable(g, e5m2_dtype) p = nn.Fp8Einsum() vars = p.init(rng, eqn, x, y) def loss_fn(vars, x, y): out = p.apply(vars, eqn, x, y) return jnp.sum(out * g.astype(out.dtype)) step_fn = jax.value_and_grad(loss_fn, argnums=[1, 2]) out, grads = jax.jit(step_fn)(vars, x, y) def loss_fn_ref(x, y): out = jnp.einsum(eqn, x, y) return jnp.sum(out * g.astype(out.dtype)) step_fn_ref = jax.value_and_grad(loss_fn_ref, argnums=[0, 1]) out_ref, grads_ref = jax.jit(step_fn_ref)(x, y) np.testing.assert_allclose(out, out_ref, atol=1e-02, rtol=1e-02) np.testing.assert_allclose(grads[0], grads_ref[0], atol=1e-02, rtol=1e-02) np.testing.assert_allclose(grads[1], grads_ref[1], atol=1e-02, rtol=1e-02) @parameterized.parameters( {'fp8_genre': 'OCP'}, {'fp8_genre': 'NANOO'} ) def test_fp8_dot_general_injection(self, fp8_genre): # Used to cast the inputs to be representable in FP8, so that the difference # of the results from the original gemm and fp8 gemm is small. cast_to_representable = functools.partial( fp8_ops.qdq, scale=jnp.ones((1,)), compute_dtype=jnp.float32, ) e4m3_dtype, e5m2_dtype = get_fp8_dtypes(fp8_genre) init_key, random_key = random.split(random.PRNGKey(seed=123), 2) x = cast_to_representable( random.uniform(random_key, (16, 32)), e4m3_dtype ) dy = cast_to_representable( random.uniform(random_key, (16, 64)), e5m2_dtype ) quant_cls = nn.Fp8DotGeneral if fp8_genre == 'OCP' else nn.NANOOFp8DotGeneralOp def run(fp8_injection, expected_shapes): p = nn.DenseGeneral(features=64, name='dense') if fp8_injection: p.dot_general_cls = quant_cls init_fn = jax.jit(p.init_with_output) y, initial_vars = init_fn(init_key, x) var_shapes = jax.tree_util.tree_map(jnp.shape, initial_vars) self.assertEqual(var_shapes, expected_shapes) def _train(variables, x): y = p.apply(variables, x) loss = y * dy return jnp.mean(loss) train_fn = jax.jit(jax.value_and_grad(_train, argnums=[0, 1])) outputs, grads = train_fn(initial_vars, x) return outputs, grads expected_shapes_original = { 'params': {'kernel': (32, 64), 'bias': (64,)}, } expected_shapes_new = { 'params': {'kernel': (32, 64), 'bias': (64,)}, fp8_ops.OVERWRITE_WITH_GRADIENT: { f'{quant_cls.__name__}_0': { 'input_amax_history': (1024,), 'kernel_amax_history': (1024,), 'output_grad_amax_history': (1024,), 'input_scale': (1,), 'kernel_scale': (1,), 'output_grad_scale': (1,), } }, } output1a, output1b = run(False, expected_shapes_original) output2a, output2b = run(True, expected_shapes_new) dw1, dw2 = output1b[0]['params']['kernel'], output2b[0]['params']['kernel'] dx1, dx2 = output1b[1], output2b[1] np.testing.assert_allclose(output1a, output2a, atol=1e-02) np.testing.assert_allclose(dw1, dw2, atol=1e-04) np.testing.assert_allclose(dx1, dx2, atol=1e-04) @parameterized.parameters( {'fp8_genre': 'OCP'}, {'fp8_genre': 'NANOO'} ) def test_fp8_train_state(self, fp8_genre): key, init_key, random_key = random.split(random.PRNGKey(seed=123), 3) x = random.uniform(random_key, (16, 16), dtype=jnp.float32) quant_cls = nn.Fp8DotGeneral if fp8_genre == 'OCP' else nn.NANOOFp8DotGeneralOp dense = nn.DenseGeneral( features=32, use_bias=True, dot_general_cls=quant_cls ) init_fn = jax.jit(dense.init) variables = init_fn(init_key, x) opt = optax.adam(learning_rate=0.1) state = train_state.TrainState.create( params=variables, tx=opt, apply_fn=dense.apply ) def _roll_and_update(amax_h, update): return jnp.roll(amax_h, shift=-1, axis=0).at[0].set(update) def _train_loss(state, x, dy): def loss_fn(vars): y = state.apply_fn(vars, x) loss = y * dy.astype(y.dtype) return jnp.sum(loss) grad_fn = jax.grad(loss_fn) grads = grad_fn(state.params) state = state.apply_gradients(grads=grads) return state train_fn = jax.jit(_train_loss) scale_x, amax_history_x = jnp.ones(()), jnp.zeros((1024,)) scale_k, amax_history_k = jnp.ones(()), jnp.zeros((1024,)) scale_g, amax_history_g = jnp.ones(()), jnp.zeros((1024,)) e4m3_dtype, e5m2_dtype = get_fp8_dtypes(fp8_genre) e4m3_max = jnp.finfo(e4m3_dtype).max.astype(jnp.float32) e5m2_max = jnp.finfo(e5m2_dtype).max.astype(jnp.float32) for _ in range(5): key, random_key = random.split(key, 2) x = random.normal(random_key, (16, 16), dtype=jnp.float32) g = random.normal(random_key, (16, 32), dtype=jnp.float32) k = state.params['params']['kernel'] # Manually compute the expected amax history and scaling factors. amax_from_history_x = jnp.max(amax_history_x, axis=0) amax_from_history_k = jnp.max(amax_history_k, axis=0) amax_from_history_g = jnp.max(amax_history_g, axis=0) scale_x = fp8_ops.compute_scale(amax_from_history_x, scale_x, e4m3_max) scale_k = fp8_ops.compute_scale(amax_from_history_k, scale_k, e4m3_max) scale_g = fp8_ops.compute_scale(amax_from_history_g, scale_g, e5m2_max) amax_history_x = _roll_and_update(amax_history_x, jnp.max(jnp.abs(x))) amax_history_k = _roll_and_update(amax_history_k, jnp.max(jnp.abs(k))) amax_history_g = _roll_and_update(amax_history_g, jnp.max(jnp.abs(g))) state = train_fn(state, x, g) rtol, atol = 0.001, 0.001 fp8_vars = state.params[fp8_ops.OVERWRITE_WITH_GRADIENT][ f'{quant_cls.__name__}_0' ] np.testing.assert_allclose( fp8_vars['input_amax_history'], amax_history_x, rtol=rtol, atol=atol, ) np.testing.assert_allclose( fp8_vars['kernel_amax_history'], amax_history_k, rtol=rtol, atol=atol, ) np.testing.assert_allclose( fp8_vars['output_grad_amax_history'], amax_history_g, rtol=rtol, atol=atol, ) np.testing.assert_allclose(fp8_vars['input_scale'][0], scale_x) np.testing.assert_allclose(fp8_vars['kernel_scale'][0], scale_k) np.testing.assert_allclose(fp8_vars['output_grad_scale'][0], scale_g) @parameterized.parameters( {'fp8_genre': 'OCP', 'use_jit': True}, {'fp8_genre': 'OCP', 'use_jit': False}, {'fp8_genre': 'NANOO', 'use_jit': True}, {'fp8_genre': 'NANOO', 'use_jit': False} ) def test_fp8_meta_dtype(self, fp8_genre, use_jit): if not use_jit and not fp8_ops.CAN_USE_EARRAY: self.skipTest("TODO: requires newer jax that has earray") f32 = jnp.dtype('float32') fmax32 = fp8_ops.fp32_max_grad e4m3_dtype, _ = get_fp8_dtypes(fp8_genre) e4m3_max = 448 if fp8_genre == 'OCP' else 240 # Create a scan loop with reused ah_f32 and sf_f32. So, the autograd will # accumulate the grads of them. We expect the max op (rather than add op) # for the accumulation by converting them to fmax32 dtype. def outer(x, ah_f32, sf_f32): ah_fmax32 = jax.lax.convert_element_type(ah_f32, fmax32) sf_fmax32 = jax.lax.convert_element_type(sf_f32, fmax32) array_x = jnp.array([x], f32) def body_fun(carry, _): carry = fp8_ops.in_qdq(f32, e4m3_dtype, carry, sf_fmax32, ah_fmax32) return carry, None array_x, _ = jax.lax.scan(body_fun, array_x, None, length=3) return array_x[0] outer_fn = jax.grad(outer, (0, 1, 2)) if use_jit: outer_fn = jax.jit(outer_fn) ah = jnp.array([0., 0., 0.], f32) sf = jnp.array([1.], f32) # 1st iteration grads, new_ah, new_sf = outer_fn(2.0, ah, sf) np.testing.assert_allclose(new_ah, [2., 0., 0.]) np.testing.assert_allclose(new_sf, [1.]) # 2nd iteration grads, new_ah, new_sf = outer_fn(3., new_ah, new_sf) np.testing.assert_allclose(new_ah, [3., 0., 2.]) np.testing.assert_allclose(new_sf, [2. / e4m3_max]) # 3rd iteration grads, new_ah, new_sf = outer_fn(4., new_ah, new_sf) np.testing.assert_allclose(new_ah, [4., 2., 3.]) np.testing.assert_allclose(new_sf, [3. / e4m3_max]) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/linen/linen_transforms_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Transforms tests.""" from functools import partial import operator from typing import Any from collections.abc import Callable, Sequence import unittest from absl.testing import absltest, parameterized from flax import errors from flax import linen as nn from flax import serialization from flax import struct from flax.core import copy, freeze, AxisMetadata from flax.linen.transforms import _HashableProxy import jax from jax import random import jax.numpy as jnp import numpy as np # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() # pylint: disable=attribute-defined-outside-init,unused-variable,g-wrong-blank-lines,g-bare-generic def tree_equals(x, y): return jax.tree_util.tree_all(jax.tree_util.tree_map(operator.eq, x, y)) def tree_allclose(x, y): return jax.tree_util.tree_all( jax.tree_util.tree_map(lambda x, y: np.all(np.isclose(x, y)), x, y) ) id_fn = lambda x: x class TransformedMLP(nn.Module): features: Sequence[int] transform: Callable = id_fn @nn.compact def __call__(self, inputs): x = inputs for i, feat in enumerate(self.features): # JIT the Module (it's __call__ fn by default.) x = self.transform(nn.Dense)(feat, name=f'layers_{i}')(x) if i != len(self.features) - 1: x = nn.relu(x) return x def decorated_MLP(transform: Callable = id_fn): class MLP(nn.Module): features: Sequence[int] @transform @nn.compact def __call__(self, inputs): x = inputs for i, feat in enumerate(self.features): # JIT the Module (it's __call__ fn by default.) x = nn.Dense(feat, name=f'layers_{i}')(x) if i != len(self.features) - 1: x = nn.relu(x) return x return MLP class TransformTest(parameterized.TestCase): def assert_keys_equal(self, key1, key2): self.assertEqual(key1.dtype, key2.dtype) np.testing.assert_array_equal(random.key_data(key1), random.key_data(key2)) def test_jit(self): key1, key2 = random.split(random.key(3), 2) x = random.uniform(key1, (4, 4)) normal_model = TransformedMLP(features=[3, 4, 5]) jit_model = TransformedMLP(features=[3, 4, 5], transform=nn.jit) init_variables = normal_model.init(key2, x) y1 = normal_model.apply(init_variables, x) y2 = jit_model.apply(init_variables, x) self.assertTrue(np.all(y1 == y2)) def test_jit_decorated(self): key1, key2 = random.split(random.key(3), 2) x = random.uniform(key1, (4, 4)) normal_model = decorated_MLP()(features=[3, 4, 5]) jit_model = decorated_MLP(nn.jit)(features=[3, 4, 5]) init_variables = normal_model.init(key2, x) y1 = normal_model.apply(init_variables, x) y2 = jit_model.apply(init_variables, x) self.assertTrue(np.all(y1 == y2)) def test_jit_init_fn(self): class Foo(nn.Module): @nn.compact def __call__(self, x): return nn.Dense(2)(x) @nn.jit def init_with_output(self, rngs, *args, **kwargs): return super().init_with_output(rngs, *args, **kwargs) Foo().init_with_output(random.key(0), jnp.ones((2, 3))) def test_remat(self): key1, key2 = random.split(random.key(3), 2) x = random.uniform(key1, (4, 4)) normal_model = TransformedMLP(features=[3, 4, 5]) remat_model = TransformedMLP(features=[3, 4, 5], transform=nn.remat) init_variables = normal_model.init(key2, x) y1 = normal_model.apply(init_variables, x) y2 = remat_model.apply(init_variables, x) self.assertTrue(np.all(y1 == y2)) def test_remat_decorated(self): key1, key2 = random.split(random.key(3), 2) x = random.uniform(key1, (4, 4)) normal_model = decorated_MLP()(features=[3, 4, 5]) remat_model = decorated_MLP(nn.remat)(features=[3, 4, 5]) init_variables = normal_model.init(key2, x) y1 = normal_model.apply(init_variables, x) y2 = remat_model.apply(init_variables, x) self.assertTrue(np.all(y1 == y2)) def test_remat_kwargs(self): raise unittest.SkipTest('test breaks with grad') class ConditionalReLU(nn.Module): @nn.compact def __call__(self, input, apply_relu: bool = False): return nn.relu(input) if apply_relu else input key = random.key(0) x = jnp.ones((4, 4)) * -1 remat_model = nn.remat(ConditionalReLU)() p = remat_model.init(key, x) y = remat_model.apply(p, x, apply_relu=True) self.assertTrue(np.all(y == jnp.zeros_like(x))) # This next line crashes with a concretization error _ = jax.grad(lambda x: remat_model.apply(p, x, apply_relu=True))(x) def test_remat_static_argnums(self): test = self class Foo(nn.Module): train_is_static: bool @nn.compact def __call__(self, inputs, train: bool): if self.train_is_static: test.assertTrue(isinstance(train, bool)) else: test.assertTrue(isinstance(train, jnp.ndarray)) return nn.Dense(3, use_bias=False)(inputs) # set train as a static argument FooRemat = nn.remat(Foo, static_argnums=(2,)) foo = FooRemat(train_is_static=True) x = jnp.empty((1, 2)) variables = foo.init(random.key(0), x, True) y = foo.apply(variables, x, False) self.assertEqual(y.shape, (1, 3)) # set train as a non-static arguments FooRemat = nn.remat(Foo, static_argnums=()) foo = FooRemat(train_is_static=False) variables = foo.init(random.key(0), x, True) y = foo.apply(variables, x, False) self.assertEqual(y.shape, (1, 3)) def test_remat_decorator_static_argnums(self): test = self class FooTrainStatic(nn.Module): @partial(nn.remat, static_argnums=(2,)) @nn.compact def __call__(self, inputs, train: bool): test.assertTrue(isinstance(train, bool)) return nn.Dense(3, use_bias=False)(inputs) # set train as a static argument foo = FooTrainStatic() x = jnp.empty((1, 2)) variables = foo.init(random.key(0), x, True) y = foo.apply(variables, x, False) self.assertEqual(y.shape, (1, 3)) class FooTrainDynamic(nn.Module): @partial(nn.remat, static_argnums=()) @nn.compact def __call__(self, inputs, train: bool): test.assertTrue(isinstance(train, jnp.ndarray)) return nn.Dense(3, use_bias=False)(inputs) # set train as a non-static arguments foo = FooTrainDynamic() variables = foo.init(random.key(0), x, True) y = foo.apply(variables, x, False) self.assertEqual(y.shape, (1, 3)) def test_vmap(self): key1, key2, key3 = random.split(random.key(3), 3) x = random.uniform(key1, (4, 4)) x2 = random.uniform(key2, (5, 4, 4)) def vmap(cls): return nn.vmap( cls, in_axes=(0,), variable_axes={'params': None}, split_rngs={'params': False}, ) normal_model = TransformedMLP(features=[3, 4, 5]) vmap_model = TransformedMLP(features=[3, 4, 5], transform=vmap) init_variables = normal_model.init(key3, x) # simulate vmap in python for comparison: y1 = jnp.vstack([ normal_model.apply(init_variables, x2[i])[None, ...] for i in np.arange(x2.shape[0]) ]) y2 = vmap_model.apply(init_variables, x2) np.testing.assert_allclose(y1, y2, atol=1e-6) def test_vmap_decorated(self): key1, key2, key3 = random.split(random.key(3), 3) x = random.uniform(key1, (4, 4)) x2 = random.uniform(key2, (5, 4, 4)) def vmap(fn): return nn.vmap( fn, in_axes=(0,), variable_axes={'params': None}, split_rngs={'params': False}, ) normal_model = decorated_MLP()(features=[3, 4, 5]) vmap_model = decorated_MLP(vmap)(features=[3, 4, 5]) init_variables = normal_model.init(key3, x) # simulate vmap in python for comparison: y1 = jnp.vstack([ normal_model.apply(init_variables, x2[i])[None, ...] for i in np.arange(x2.shape[0]) ]) y2 = vmap_model.apply(init_variables, x2) np.testing.assert_allclose(y1, y2, atol=1e-6) def test_vmap_batchnorm(self): key1, key2, key3 = random.split(random.key(3), 3) x = random.uniform(key1, (4, 4)) x2 = random.uniform(key2, (5, 4, 4)) def vmap(cls): return nn.vmap( cls, in_axes=(0,), variable_axes={'params': None, 'batch_stats': None}, split_rngs={'params': False}, axis_name='batch', ) class MlpBn(nn.Module): axis_name: Any = None @nn.compact def __call__(self, x): x = nn.Dense(3)(x) x = nn.BatchNorm(axis_name=self.axis_name, use_running_average=False)(x) return x normal_model = MlpBn() vmap_model = vmap(MlpBn)(axis_name='batch') init_variables = normal_model.init(key3, x) y1 = normal_model.apply( init_variables, x2.reshape((-1, 4)), mutable=['batch_stats'] )[0] y1 = y1.reshape((5, 4, 3)) y2 = vmap_model.apply(init_variables, x2, mutable=['batch_stats'])[0] np.testing.assert_allclose(y1, y2, atol=1e-5) def test_scan(self): class SimpleScan(nn.Module): features: int @nn.compact def __call__(self, c, xs): LSTM = nn.scan( nn.LSTMCell, variable_broadcast='params', split_rngs={'params': False}, ) return LSTM(self.features, name='lstm_cell')(c, xs) key1, key2 = random.split(random.key(0), 2) xs = random.uniform(key1, (5, 3, 2)) dummy_rng = random.key(0) init_carry = nn.LSTMCell(2).initialize_carry(dummy_rng, xs[0].shape) model = SimpleScan(2) init_variables = model.init(key2, init_carry, xs) # simulate scan in python for comparison: c = init_carry ys = [] lstmcell_variables = freeze( {'params': init_variables['params']['lstm_cell']} ) for i in range(xs.shape[0]): c, y = nn.LSTMCell(2).apply(lstmcell_variables, c, xs[i]) ys.append(y[None, ...]) y1 = jnp.vstack(ys) c2, y2 = model.apply(init_variables, init_carry, xs) np.testing.assert_allclose(y1, y2, atol=1e-7) np.testing.assert_allclose(c[0], c2[0], atol=1e-7) np.testing.assert_allclose(c[1], c2[1], atol=1e-7) def test_scan_decorated(self): class SimpleScan(nn.Module): features: int @partial( nn.scan, variable_broadcast='params', in_axes=(nn.broadcast, 0), split_rngs={'params': False}, ) @nn.compact def __call__(self, c, b, xs): assert b.shape == (4,) return nn.LSTMCell(self.features, name='lstm_cell')(c, xs) key1, key2 = random.split(random.key(0), 2) xs = random.uniform(key1, (4, 3, 2)) b = jnp.ones((4,)) dummy_rng = random.key(0) init_carry = nn.LSTMCell(2).initialize_carry(dummy_rng, xs[0].shape) model = SimpleScan(2) init_variables = model.init(key2, init_carry, b, xs) # simulate scan in python for comparison: c = init_carry ys = [] lstmcell_variables = freeze( {'params': init_variables['params']['lstm_cell']} ) for i in range(xs.shape[0]): c, y = nn.LSTMCell(2).apply(lstmcell_variables, c, xs[i]) ys.append(y[None, ...]) y1 = jnp.vstack(ys) c2, y2 = model.apply(init_variables, init_carry, b, xs) np.testing.assert_allclose(y1, y2, atol=1e-7) np.testing.assert_allclose(c[0], c2[0], atol=1e-7) np.testing.assert_allclose(c[1], c2[1], atol=1e-7) def test_scan_negative_axes(self): class Foo(nn.Module): @nn.compact def __call__(self, _, x): x = nn.Dense(4)(x) return None, x class Bar(nn.Module): @nn.compact def __call__(self, x): _, x = nn.scan( Foo, variable_broadcast='params', split_rngs=dict(params=False), in_axes=1, out_axes=-1, )()(None, x) return x y, variables = Bar().init_with_output( {'params': jax.random.PRNGKey(0)}, jax.random.normal(jax.random.PRNGKey(1), shape=[1, 2, 3]), ) params = variables['params'] self.assertEqual(y.shape, (1, 4, 2)) self.assertEqual(params['ScanFoo_0']['Dense_0']['kernel'].shape, (3, 4)) self.assertEqual(params['ScanFoo_0']['Dense_0']['bias'].shape, (4,)) def test_multiscope_lifting_simple(self): class Counter(nn.Module): @nn.compact def __call__(self): v = self.variable('counter', 'foo', lambda: jnp.array([0])) v.value += jnp.array([1]) return v.value class Outer(nn.Module): @nn.compact def __call__(self, x): cntr = nn.jit(Counter)(name='cntr')() return x class Inner(nn.Module): outer_module: nn.Module @nn.compact def __call__(self, x): return self.outer_module(x) class Test(nn.Module): @nn.compact def __call__(self, x): outer_dense = nn.jit(Outer)(name='outer') # we share stateful outer module as arg to two different, transformed modules: inner = nn.jit(Inner)(outer_dense, name='inner1') inner2 = nn.jit(Inner)(outer_dense, name='inner2') res = inner(x) + inner2(x) return res x = jnp.ones((10, 10)) rngs = random.key(0) init_vars = Test(parent=None).init(rngs, x) _, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter']) self.assertEqual( init_vars['counter']['outer']['cntr']['foo'], jnp.array([2], jnp.int32) ) self.assertEqual( new_vars['counter']['outer']['cntr']['foo'], jnp.array([4], jnp.int32) ) def test_multiscope_lifting_simple_decorator(self): class Counter(nn.Module): @nn.jit @nn.compact def __call__(self): v = self.variable('counter', 'foo', lambda: jnp.array([0])) v.value += jnp.array([1]) return v.value class Outer(nn.Module): @nn.jit @nn.compact def __call__(self, x): cntr = Counter(name='cntr')() return x class Inner(nn.Module): outer_module: nn.Module @nn.jit @nn.compact def __call__(self, x): return self.outer_module(x) class Test(nn.Module): @nn.compact def __call__(self, x): outer_dense = Outer(name='outer') # we share stateful outer module as arg to two different, transformed modules: inner = Inner(outer_dense, name='inner1') inner2 = Inner(outer_dense, name='inner2') res = inner(x) + inner2(x) return res x = jnp.ones((1, 1)) rngs = random.key(0) init_vars = Test(parent=None).init(rngs, x) _, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter']) self.assertEqual( init_vars['counter']['outer']['cntr']['foo'], jnp.array([2], jnp.int32) ) self.assertEqual( new_vars['counter']['outer']['cntr']['foo'], jnp.array([4], jnp.int32) ) def test_multiscope_lifting_argtree(self): class Counter(nn.Module): @nn.compact def __call__(self): v = self.variable('counter', 'foo', lambda: jnp.array([0])) v.value += jnp.array([1]) return v.value class Outer(nn.Module): @nn.compact def __call__(self, x): cntr = nn.jit(Counter)(name='cntr')() return x class Inner(nn.Module): outer_module: Sequence[nn.Module] @nn.compact def __call__(self, x): return self.outer_module[0](x) + self.outer_module[1](x) class Test(nn.Module): @nn.compact def __call__(self, x): outer_dense1 = nn.jit(Outer)(name='outer1') outer_dense2 = nn.jit(Outer)(name='outer2') # we share stateful outer module as arg to two different, transformed modules: inner1 = nn.jit(Inner)((outer_dense1, outer_dense2), name='inner1') inner2 = nn.jit(Inner)((outer_dense1, outer_dense2), name='inner2') res = inner1(x) + inner2(x) return res x = jnp.ones((1, 1)) rngs = random.key(0) init_vars = Test(parent=None).init(rngs, x) _, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter']) self.assertEqual( init_vars['counter']['outer1']['cntr']['foo'], jnp.array([2], jnp.int32) ) self.assertEqual( new_vars['counter']['outer1']['cntr']['foo'], jnp.array([4], jnp.int32) ) self.assertEqual( init_vars['counter']['outer2']['cntr']['foo'], jnp.array([2], jnp.int32) ) self.assertEqual( new_vars['counter']['outer2']['cntr']['foo'], jnp.array([4], jnp.int32) ) def test_multiscope_lifting_argtree_decorator(self): class Counter(nn.Module): @nn.jit @nn.compact def __call__(self): v = self.variable('counter', 'foo', lambda: jnp.array([0])) v.value += jnp.array([1]) return v.value class Outer(nn.Module): @nn.jit @nn.compact def __call__(self, x): cntr = nn.jit(Counter)(name='cntr')() return x class Inner(nn.Module): outer_module: Sequence[nn.Module] @nn.jit @nn.compact def __call__(self, x): return self.outer_module[0](x) + self.outer_module[1](x) class Test(nn.Module): @nn.compact def __call__(self, x): outer_dense1 = Outer(name='outer1') outer_dense2 = Outer(name='outer2') # we share stateful outer module as arg to two different, transformed modules: inner1 = Inner((outer_dense1, outer_dense2), name='inner1') inner2 = Inner((outer_dense1, outer_dense2), name='inner2') res = inner1(x) + inner2(x) return res x = jnp.ones((1, 1)) rngs = random.key(0) init_vars = Test(parent=None).init(rngs, x) _, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter']) self.assertEqual( init_vars['counter']['outer1']['cntr']['foo'], jnp.array([2], jnp.int32) ) self.assertEqual( new_vars['counter']['outer1']['cntr']['foo'], jnp.array([4], jnp.int32) ) self.assertEqual( init_vars['counter']['outer2']['cntr']['foo'], jnp.array([2], jnp.int32) ) self.assertEqual( new_vars['counter']['outer2']['cntr']['foo'], jnp.array([4], jnp.int32) ) def test_multiscope_lifting_simple_decorator_w_jit(self): # TODO: actually test jaxpr on a simpler module. class Counter(nn.Module): @nn.jit @nn.compact def __call__(self): v = self.variable('counter', 'foo', lambda: jnp.array([0])) v.value += jnp.array([1]) return v.value class Outer(nn.Module): @nn.jit @nn.compact def __call__(self, x): cntr = Counter(name='cntr')() return x class Inner(nn.Module): outer_module: nn.Module @nn.jit @nn.compact def __call__(self, x): return self.outer_module(x) class Test(nn.Module): @nn.jit @nn.compact def __call__(self, x): outer_dense = Outer(name='outer') # we share stateful outer module as arg to two different, transformed modules: inner = Inner(outer_dense, name='inner1') inner2 = Inner(outer_dense, name='inner2') res = inner(x) + inner2(x) return res x = jnp.ones((1, 1)) rngs = random.key(0) init_vars = Test(parent=None).init(rngs, x) _, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter']) self.assertEqual( init_vars['counter']['outer']['cntr']['foo'], jnp.array([2], jnp.int32) ) self.assertEqual( new_vars['counter']['outer']['cntr']['foo'], jnp.array([4], jnp.int32) ) def test_vmapped_outer_module(self): class Outer(nn.Module): @nn.jit @nn.compact def __call__(self, x): return nn.Dense(5)(x) class Inner(nn.Module): outer_module: nn.Module @partial( nn.vmap, in_axes=(0,), variable_axes={'params': 0}, split_rngs={'params': True}, ) @nn.jit @nn.compact def __call__(self, x): return self.outer_module(x) class Test(nn.Module): @nn.compact def __call__(self, x): outer_dense = Outer(name='outer') inner = Inner(outer_dense, name='inner1') inner2 = Inner(outer_dense, name='inner2') res = inner(x) + inner2(x) return res x = jnp.ones((3, 1, 2)) rngs = random.key(0) init_vars = Test(parent=None).init(rngs, x) y = Test(parent=None).apply(init_vars, x) self.assertEqual( init_vars['params']['outer']['Dense_0']['kernel'].shape, (3, 2, 5) ) self.assertEqual( init_vars['params']['outer']['Dense_0']['bias'].shape, (3, 5) ) self.assertEqual(y.shape, (3, 1, 5)) def test_module_transform_with_setup(self): class Foo(nn.Module): def setup(self): self.test = self.param('test', nn.initializers.ones_init(), ()) def __call__(self, x): return x * self.test FooVmap = nn.vmap( Foo, in_axes=0, out_axes=0, variable_axes={'params': 0}, split_rngs={'params': True}, ) variables = FooVmap().init(random.key(0), jnp.ones((4,))) self.assertEqual(variables['params']['test'].shape, (4,)) def test_nested_module_args_vmap(self): class A(nn.Module): @nn.compact def __call__(self, x): return nn.Dense(3)(x) class B(nn.Module): A: nn.Module @nn.compact def __call__(self, x): return self.A(x) class C(nn.Module): B: nn.Module @partial( nn.vmap, variable_axes={'params': 0}, split_rngs={'params': True} ) @nn.compact def __call__(self, x): return self.B(x) class D(nn.Module): @nn.compact def __call__(self, x): a = A() b = B(a) c = C(b) return c(x) key = random.key(0) x = jnp.ones((10, 10)) p = D().init(key, x) variable_shapes = jax.tree_util.tree_map(jnp.shape, p) self.assertEqual( variable_shapes['params']['A_0']['Dense_0']['kernel'], (10, 10, 3) ) self.assertEqual( variable_shapes['params']['A_0']['Dense_0']['bias'], (10, 3) ) def test_nested_module_args_vmap_2(self): class A(nn.Module): @nn.compact def __call__(self, x): return nn.Dense(3)(x) class B(nn.Module): A: nn.Module @nn.compact def __call__(self, x): return self.A(x) class C(nn.Module): A: nn.Module B: nn.Module @partial( nn.vmap, variable_axes={'params': 0}, split_rngs={'params': True} ) @nn.compact def __call__(self, x): return self.B(x) + self.A(x) class D(nn.Module): @nn.compact def __call__(self, x): a1 = A() a2 = A() b = B(a1) c = C(a2, b) return c(x) key = random.key(0) x = jnp.ones((10, 10)) p = D().init(key, x) variable_shapes = jax.tree_util.tree_map(jnp.shape, p) self.assertEqual( variable_shapes['params']['A_0']['Dense_0']['kernel'], (10, 10, 3) ) self.assertEqual( variable_shapes['params']['A_0']['Dense_0']['bias'], (10, 3) ) self.assertEqual( variable_shapes['params']['A_1']['Dense_0']['kernel'], (10, 10, 3) ) self.assertEqual( variable_shapes['params']['A_1']['Dense_0']['bias'], (10, 3) ) def test_nested_setup_calls_count(self): D = 3 N = 4 setup_cntr = 0 call_cntr = 0 class Repeat(nn.Module): mdl_def: Any def setup(self): self.lyrs = [self.mdl_def() for _ in range(N)] @nn.remat # we just use remat as a convenient test of transform logic def __call__(self, x): for lyr in self.lyrs: lyr(x) return x class Counter(nn.Module): def setup(self): nonlocal setup_cntr setup_cntr += 1 self.dense = nn.Dense(2, use_bias=False) @nn.remat def __call__(self, x): nonlocal call_cntr call_cntr += 1 return self.dense(x) def nested_repeat(mdl): for _ in range(D): mdl = partial(Repeat, mdl) return mdl() _ = nested_repeat(Counter).init(random.key(0), jnp.ones((2,))) # setup_cntr == 128 due to 1 call in Counter.setup by _validate_setup # and 1 further "real" call. self.assertEqual(setup_cntr, 128) self.assertEqual(call_cntr, 64) def test_multimethod_setup_calls(self): cntr = 0 class A(nn.Module): def setup(self): nonlocal cntr cntr += 1 self.d = nn.Dense(2) @nn.remat def foo(self, x): return self.d(x) @nn.remat def bar(self, x): return self.d(x) class B(nn.Module): def setup(self): self.a = A() def __call__(self, x): y1 = self.a.foo(x) y2 = self.a.bar(x) return y1, y2 key = random.key(0) x = jnp.ones((2,)) (y1, y2), _ = B().init_with_output(key, x) np.testing.assert_array_equal(y1, y2) # cntr == 4 due to: # 1 call by _validate_setup # 1 call for the setup() outside transform boundary # and two further "real" calls in transform boundaries self.assertEqual(cntr, 4) def test_toplevel_submodule_adoption_transform(self): class A(nn.Module): @nn.compact def __call__(self, x): return nn.Dense(3)(x) class B(nn.Module): A: nn.Module @nn.compact def __call__(self, x): return self.A(x) class C(nn.Module): A: nn.Module B: nn.Module @partial( nn.vmap, variable_axes={'params': 0}, split_rngs={'params': True} ) @nn.compact def __call__(self, x): return self.B(x) + self.A(x) class Csimple(nn.Module): A: nn.Module B: nn.Module @nn.compact def __call__(self, x): return self.B(x) + self.A(x) class D(nn.Module): @nn.compact def __call__(self, x): a1 = A() a2 = A() b = B(a1) c = C(a2, b) return c(x) key = random.key(0) x = jnp.ones((10, 10)) p1 = D().init(key, x) y1 = D().apply(p1, x) a1 = A() a2 = A() b = B(a1) p2 = freeze({ 'params': { 'A': p1['params']['A_0'], 'B': { 'A': p1['params']['A_1'], }, } }) # Test method wrapper transform. y2 = C(a2, b).apply(p2, x) np.testing.assert_allclose(y1, y2, atol=1e-7) # Test class transform. Ctrafo = nn.vmap( Csimple, variable_axes={'params': 0}, split_rngs={'params': True} ) y3 = Ctrafo(a2, b).apply(p2, x) np.testing.assert_allclose(y1, y3, atol=1e-7) def test_toplevel_submodule_adoption_pytree_transform(self): class A(nn.Module): @nn.compact def __call__(self, c, x): counter = self.variable('counter', 'i', jnp.zeros, ()) counter.value += 1 x = nn.Dense(1)(x) return c, x class B(nn.Module): A: Any @nn.compact def __call__(self, c, x): return self.A['foo'](*self.A['bar'](c, x)) a = A() As = {'foo': A(), 'bar': A()} b = nn.scan( B, in_axes=0, variable_carry='counter', variable_broadcast='params', split_rngs={'params': False}, )(As) key = random.key(0) x = jnp.ones((10, 2)) p = B(As).init(key, x, x) y, cntrs = b.apply(p, x, x, mutable='counter') ref_cntrs = { 'counter': { 'A_bar': { 'i': jnp.array(11.0), }, 'A_foo': { 'i': jnp.array(11.0), }, }, } self.assertTrue( jax.tree_util.tree_all( jax.tree_util.tree_map( lambda x, y: np.testing.assert_allclose(x, y, atol=1e-7), cntrs, ref_cntrs, ) ) ) def test_partially_applied_module_constructor_transform(self): k = random.key(0) x = jnp.ones((3, 4, 4)) dense = partial(nn.Dense, use_bias=False) vmap_dense = nn.vmap( dense, variable_axes={'params': 0}, split_rngs={'params': True} )(4) init_vars = vmap_dense.init(k, x) init_vars_shapes = jax.tree_util.tree_map(jnp.shape, init_vars) ref_var_shapes = { 'params': { 'kernel': (3, 4, 4), }, } self.assertTrue(tree_equals(init_vars_shapes, ref_var_shapes)) def test_partial_module_method(self): k = random.key(0) x = jnp.ones((3, 4, 4)) class Foo(nn.Module): @nn.compact def inner(self, x): return nn.Dense(2, use_bias=False)(x) def __call__(self, x): return nn.vmap( partial(Foo.inner), variable_axes={'params': 0}, split_rngs={'params': True}, )(self, x) init_vars = Foo().init(k, x) init_vars_shapes = jax.tree_util.tree_map(jnp.shape, init_vars) ref_var_shapes = { 'params': {'Dense_0': {'kernel': (3, 4, 2)}}, } self.assertTrue(tree_equals(init_vars_shapes, ref_var_shapes)) def test_variable_in_args_transform(self): class Test(nn.Module): @nn.jit @nn.compact def __call__(self, x): baz = self.variable('test', 'baz', jnp.zeros, x.shape) y = self.mutate_variable_in_method(x, baz) return y @nn.jit def mutate_variable_in_method(self, x, baz): baz.value += x return baz.value k = random.key(0) x = jnp.ones((1,)) variables = Test().init(k, x) np.testing.assert_allclose( variables['test']['baz'], jnp.array([ 1.0, ]), atol=1e-7, ) y, variables = Test().apply(variables, x, mutable=['test']) np.testing.assert_allclose( variables['test']['baz'], jnp.array([ 2.0, ]), atol=1e-7, ) def test_module_instance_in_args_transform(self): class Inner(nn.Module): @nn.jit @nn.compact def __call__(self, x): baz = self.variable('test', 'baz', jnp.zeros, x.shape) baz.value += x return baz.value class Test(nn.Module): @nn.jit @nn.compact def __call__(self, x): inner = Inner(name='inner') y = self.call_instance_arg_in_method(x, inner) return y @nn.jit def call_instance_arg_in_method(self, x, inner): return inner(x) k = random.key(0) x = jnp.ones((1,)) variables = Test().init(k, x) np.testing.assert_allclose( variables['test']['inner']['baz'], jnp.array([ 1.0, ]), atol=1e-7, ) y, variables = Test().apply(variables, x, mutable=['test']) np.testing.assert_allclose( variables['test']['inner']['baz'], jnp.array([ 2.0, ]), atol=1e-7, ) def test_module_instance_in_args_transform_nested(self): class Inner(nn.Module): @nn.jit @nn.compact def __call__(self, x): baz = self.variable('test', 'baz', jnp.zeros, x.shape) baz.value += x return baz.value class Outer(nn.Module): @nn.jit @nn.compact def __call__(self, inner, x): y = self.call_instance_arg_in_method(x, inner) return y @nn.jit def call_instance_arg_in_method(self, x, inner): return inner(x) class Test(nn.Module): @nn.jit @nn.compact def __call__(self, x): inner = Inner(name='inner') outer = Outer(name='outer') return outer(inner, x) k = random.key(0) x = jnp.ones((1,)) variables = Test().init(k, x) np.testing.assert_allclose( variables['test']['inner']['baz'], jnp.array([ 1.0, ]), atol=1e-7, ) y, variables = Test().apply(variables, x, mutable=['test']) np.testing.assert_allclose( variables['test']['inner']['baz'], jnp.array([ 2.0, ]), atol=1e-7, ) def test_nested_variable_passing(self): class NestedVarUser(nn.Module): somevar: nn.Variable @nn.jit @nn.compact def __call__(self, x): self.somevar.value += x return x class VarUser(nn.Module): somevar: nn.Variable @nn.jit @nn.compact def __call__(self, x): return NestedVarUser(self.somevar)(x) class VarPasser(nn.Module): @nn.jit @nn.compact def __call__(self, x): baz = self.variable('test', 'baz', jnp.zeros, x.shape) y = VarUser(baz)(x) return y k = random.key(0) x = jnp.ones((1,)) variables = VarPasser().init(k, x) np.testing.assert_allclose( variables['test']['baz'], jnp.array([ 1.0, ]), atol=1e-7, ) y, variables = VarPasser().apply(variables, x, mutable=['test']) np.testing.assert_allclose( variables['test']['baz'], jnp.array([ 2.0, ]), atol=1e-7, ) def test_returned_module_warning(self): class Foo(nn.Module): @nn.compact def __call__(self, x): return x class Bar(nn.Module): @nn.compact def __call__(self, x): f = self._helper() return f(x) @nn.jit def _helper(self): return Foo() b = Bar() with self.assertRaises(errors.TransformedMethodReturnValueError): b.apply({}, jnp.ones(2)) def test_returned_variable_warning(self): class Bar(nn.Module): @nn.compact def __call__(self, x): f = self._helper() return f(x) @nn.jit def _helper(self): return nn.Variable(None, None, None, False) b = Bar() with self.assertRaises(errors.TransformedMethodReturnValueError): b.apply({}, jnp.ones(2)) def test_nowrap(self): class Bar(nn.Module): @nn.compact def __call__(self, x): return self._helper(x) @nn.nowrap def _helper(self, x): if len(nn.module._context.module_stack) > 2: # pylint: disable=protected-access raise ValueError('Module stack too deep.') return x b = Bar() b.apply({}, jnp.ones(2)) def test_map_variables_tied_autoencoder(self): def trans(variables): return jax.tree_util.tree_map(lambda x: x.T, variables) class TiedAutencoder(nn.Module): features: int latents: int @nn.compact def _call(self, x, decode): def f(self): return nn.Dense( self.features if decode else self.latents, use_bias=False )(x) if decode: map_fn = trans else: map_fn = lambda x: x return nn.map_variables(f, 'params', map_fn, map_fn, mutable=True)(self) def encode(self, x): return self._call(x, False) def decode(self, x): return self._call(x, True) def __call__(self, x): return self.decode(self.encode(x)) x = jnp.ones((2, 4)) ae = TiedAutencoder(4, 5) variables = ae.init(random.key(0), x) param_shapes = jax.tree_util.tree_map(jnp.shape, variables['params']) self.assertEqual(param_shapes, {'Dense_0': {'kernel': (4, 5)}}) def test_map_variables_bit_weights(self): class BitWeights(nn.Module): @nn.compact def __call__(self, x): def sign(x): return jax.tree_util.tree_map(jnp.sign, x) BitDense = nn.map_variables(nn.Dense, 'params', sign, init=True) return BitDense(4)(x) bw = BitWeights() x = jnp.ones((2, 4)) y, variables = bw.init_with_output(random.key(0), x) y_2 = bw.apply(variables, x) np.testing.assert_allclose(y, y_2) def test_remat_scan(self): class BigModel(nn.Module): @nn.compact def __call__(self, x): DenseStack = nn.remat_scan(nn.Dense, lengths=(100,)) return DenseStack(8, name='dense_stack')(x) x = jnp.ones((2, 8)) model = BigModel() variables = model.init(random.key(0), x) param_shapes = jax.tree_util.tree_map(jnp.shape, variables['params']) self.assertEqual(param_shapes['dense_stack']['kernel'], (100, 8, 8)) self.assertEqual(param_shapes['dense_stack']['bias'], (100, 8)) y = model.apply(variables, x) self.assertEqual(y.shape, (2, 8)) def test_vjp(self): class Bar(nn.Module): @nn.compact def __call__(self, x, y): p = self.param('test', nn.initializers.constant(0.5), ()) self.variable('state', 'counter', lambda: 0) return p * x * y class Foo(nn.Module): @nn.compact def __call__(self, x, y): z, bwd = nn.vjp(Bar.__call__, Bar(), x, y) return bwd(jnp.ones(z.shape)) x = jnp.array([1.0, 2.0, 3.0]) y = jnp.array([4.0, 5.0, 6.0]) params = Foo().init(random.key(0), x, y) params_grad, x_grad, y_grad = Foo().apply(params, x, y) self.assertEqual( params_grad, { 'params': nn.FrozenDict({'test': 32.0}), }, ) np.testing.assert_allclose(x_grad, [2.0, 2.5, 3.0]) np.testing.assert_allclose(y_grad, [0.5, 1.0, 1.5]) def test_jvp(self): class Bar(nn.Module): @nn.compact def __call__(self, x): p = self.param('test', nn.initializers.zeros, ()) self.variable('state', 'counter', lambda: 0) return p * x class Foo(nn.Module): @nn.compact def __call__(self, x): bar = Bar() vars_t = jax.tree_util.tree_map( jnp.ones_like, bar.variables.get('params', {}) ) _, out_t = nn.jvp( Bar.__call__, bar, (x,), (jnp.zeros_like(x),), {'params': vars_t} ) return out_t x = jnp.ones((3,)) params = Foo().init(random.key(0), x) y_t = Foo().apply(params, x) np.testing.assert_allclose(y_t, jnp.ones_like(x)) def test_complicated_alias_mutation(self): class A(nn.Module): b: nn.Module @nn.jit @nn.compact def __call__(self, x): return self.b(x) class B(nn.Module): c: nn.Module @nn.jit @nn.compact def __call__(self, x): y = C(name='outer_c')(x) z = self.c(x) return z class C(nn.Module): @nn.jit @nn.compact def __call__(self, x): initialized = self.has_variable('muts', 'v') v = self.variable('muts', 'v', lambda: jnp.zeros_like(x)) if initialized: v.value += x return x a = A(b=B(c=C())) k = random.key(0) x = jnp.ones((1,), jnp.float32) vs = a.init(k, x) y, vs_new = a.apply( vs, x, mutable=[ 'muts', ], ) np.testing.assert_array_equal( vs_new['muts']['b']['c']['v'], jnp.array([1.0], jnp.float32) ) np.testing.assert_array_equal( vs_new['muts']['b']['outer_c']['v'], jnp.array([1.0], jnp.float32) ) def test_custom_vjp(self): class Foo(nn.Module): @nn.compact def __call__(self, x): def f(mdl, x): return mdl(x) def fwd(mdl, x): return nn.vjp(f, mdl, x) def bwd(vjp_fn, y_t): params_t, input_t = vjp_fn(y_t) params_t = jax.tree_util.tree_map(jnp.sign, params_t) return params_t, input_t sign_grad = nn.custom_vjp(f, forward_fn=fwd, backward_fn=bwd) return sign_grad(nn.Dense(1), x).reshape(()) x = jnp.ones((2,)) variables = Foo().init(random.key(0), x) grad = jax.grad(Foo().apply)(variables, x) for grad_leaf in jax.tree_util.tree_leaves(grad): self.assertTrue(jnp.all(jnp.abs(grad_leaf) == 1.0)) def test_transform_with_setup_and_methods_on_submodules(self): # This is the archetypal example motivating the introduction of # SetupState as a triple-enum to handle multiple setup() calls # across transform boundaries and scope reuse. class Foo(nn.Module): def setup(self): self.inner = nn.Dense(2) def helper(self, x, m): return m(x) def __call__(self, x): return self.helper(x, self.inner) k = random.key(0) x = jnp.ones((2,)) vs_foo = Foo().init(k, x) class Bar(nn.Module): def setup(self): self.inner = nn.Dense(2) @nn.jit def helper(self, x, m): return m(x) @nn.jit def __call__(self, x): return self.helper(x, self.inner) vs_bar = Bar().init(k, x) self.assertTrue( tree_equals( jax.tree_util.tree_map(jnp.shape, vs_foo), jax.tree_util.tree_map(jnp.shape, vs_bar), ) ) def test_transform_methods_on_submodules_still_reserve_names(self): class Foo(nn.Module): @nn.jit def helper(self, x, m): conflicting_a = nn.Dense(2, name='a') return m(x) @nn.jit @nn.compact def __call__(self, x): a = nn.Dense(2, name='a') return self.helper(x, a) k = random.key(0) x = jnp.ones((2,)) with self.assertRaises(errors.NameInUseError): vs = Foo().init(k, x) def test_transform_setup_still_reserve_names(self): class Identity(nn.Module): @nn.compact def __call__(self, x): return x class Test(nn.Module): def setup(self): self.sub = Identity() self.sub = Identity() @nn.jit def __call__(self, x): return x k = random.key(0) x = jnp.array([1.0]) with self.assertRaises(errors.NameInUseError): y = Test().init(k, x) def test_transform_with_setup_and_methods_on_submodule_pytrees(self): class Foo(nn.Module): def setup(self): self.inners = [nn.Dense(2), nn.Dense(2)] def helper(self, x, ms): return ms[0](x) + ms[1](x) @nn.fold_rngs def __call__(self, x): return self.helper(x, self.inners) class JitFoo(nn.Module): def setup(self): self.inners = [nn.Dense(2), nn.Dense(2)] def helper(self, x, ms): return ms[0](x) + ms[1](x) @nn.jit def __call__(self, x): return self.helper(x, self.inners) k = random.key(0) x = jnp.ones((2,)) vs_0 = Foo().init(k, x) vs_1 = JitFoo().init(k, x) self.assertTrue(tree_allclose(vs_0, vs_1)) def test_transform_setup_still_reserve_names_pytrees(self): class Identity(nn.Module): @nn.compact def __call__(self, x): return x class Test(nn.Module): def setup(self): self.subs = [Identity(), Identity()] self.subs = [Identity(), Identity()] @nn.jit def __call__(self, x): return x k = random.key(0) x = jnp.array([1.0]) msg = r'Could not create submodule "subs_0".*' with self.assertRaisesRegex(errors.NameInUseError, msg): y = Test().init(k, x) def test_scan_of_setup_parameter(self): class Body(nn.Module): def setup(self): self.dense = nn.Dense(1) self.p = self.param('p', lambda k: jnp.ones((1,))) def __call__(self, x): return self.dense(x) + self.p, None scanbody = nn.scan( Body, variable_axes={'params': 0}, split_rngs={'params': True}, length=2 ) k = random.key(0) x = jnp.ones((1,)) vs = scanbody().init(k, x) y = scanbody().apply(vs, x) def test_multi_method_class_transform(self): class Foo(nn.Module): def setup(self): self.dense0 = nn.Dense(2) self.dense1 = nn.Dense(2) def method_0(self, x): return self.dense0(x), x def method_1(self, x, y): return self.dense1(x) + y, None class Bar(nn.Module): @nn.compact def __call__(self, x): ScanFoo = nn.scan( Foo, methods={ 'method_0': dict( variable_axes={'params': 0}, split_rngs={'params': True}, in_axes=nn.broadcast, out_axes=0, length=3, ), 'method_1': dict( variable_axes={'params': 0}, split_rngs={'params': True}, in_axes=0, length=3, ), }, ) sf = ScanFoo() y, ys = sf.method_0(x) z, _ = sf.method_1(y, ys) return z k = random.key(0) x = random.uniform(random.key(1), (2, 2)) vs = Bar().init(k, x) y = Bar().apply(vs, x) def test_compact_aliasing_collision(self): class Foo(nn.Module): m1: nn.Module m2: nn.Module @nn.compact def __call__(self, x): x = self.m2(self.m1(x)) return x class Bar(nn.Module): @nn.compact def __call__(self, x): dense = nn.Dense(2) x = nn.jit(Foo)(dense, dense)(x) return x k = random.key(0) x = jnp.zeros((2, 2)) _ = Bar().init(k, x) def test_compact_aliasing_collision_arg_and_attrib(self): class Foo(nn.Module): m1: nn.Module @nn.compact def __call__(self, x, m2): x = m2(self.m1(x)) return x class Bar(nn.Module): @nn.compact def __call__(self, x): dense = nn.Dense(2) x = nn.jit(Foo)(dense)(x, dense) return x k = random.key(0) x = jnp.zeros((2, 2)) _ = Bar().init(k, x) def test_jit_with_setup_helpers(self): class Foo(nn.Module): def setup(self): self.a = nn.Dense(2) self.setup_helper() def setup_helper(self): self.b = nn.Dense(2) @nn.fold_rngs def __call__(self, x): return self.b(self.a(x)) class JitFoo(nn.Module): def setup(self): self.a = nn.Dense(2) self.setup_helper() def setup_helper(self): self.b = nn.Dense(2) @nn.jit def __call__(self, x): return self.b(self.a(x)) k = random.key(0) x = jnp.ones((2, 2)) vs = JitFoo().init(k, x) y0 = JitFoo().apply(vs, x) vs = Foo().init(k, x) y1 = Foo().apply(vs, x) np.testing.assert_array_equal(y0, y1) def test_jit_kwargs(self): class Foo(nn.Module): @nn.jit def __call__(self, a: jax.Array, b: jax.Array): return a + b m = Foo() y = m.apply({}, jnp.array(1.0), b=jnp.array(2.0)) np.testing.assert_array_equal(y, jnp.array(3.0)) def test_jit_static_argnames(self): s = None class Foo(nn.Module): @partial(nn.jit, static_argnames=['b']) def __call__(self, a: jax.Array, b: str): nonlocal s s = b return a m = Foo() y = m.apply({}, jnp.array(1.0), b='hi') self.assertEqual(s, 'hi') np.testing.assert_array_equal(y, jnp.array(1.0)) def test_jit_and_sow(self): class Inner(nn.Module): @nn.compact def __call__(self, x): self.sow('intermediates', 'loss', jnp.sum(x)) return x + 1 class Outer(nn.Module): def setup(self): self.inner = Inner() @nn.jit def __call__(self, x): return self.inner(x) m = Outer() x = jnp.ones((2, 2)) vs = m.init(random.key(0), x) y, updates = m.apply(vs, x, mutable=['intermediates']) np.testing.assert_array_equal( updates['intermediates']['inner']['loss'], 4.0 ) np.testing.assert_array_equal(y, 2) def test_fold_rngs(self): class Foo(nn.Module): def __call__(self, use_jit: bool): def f(foo: Foo): return foo.make_rng('params') if use_jit: key = nn.jit(f)(self) else: key = nn.fold_rngs(f)(self) return key foo = Foo() key_jit = foo.apply({}, True, rngs={'params': random.key(0)}) key_fold_rngs = foo.apply({}, False, rngs={'params': random.key(0)}) self.assert_keys_equal(key_jit, key_fold_rngs) def test_same_key(self): class Block(nn.Module): @nn.jit @nn.compact def __call__(self, carry, inputs): # dump_rng_info(self) key = self.make_rng('params') # y = jax.random.uniform(self.make_rng('params'), (2,)) return carry, key class Transformer(nn.Module): @nn.compact def __call__(self): num_blocks = 10 carry, key = nn.scan( Block, variable_axes={'params': 0}, split_rngs={'params': True}, # length=num_blocks, )()(None, jnp.arange(num_blocks)) return key model = Transformer() keys1, _ = model.init_with_output(jax.random.key(1)) keys2, _ = model.init_with_output(jax.random.key(1)) keys3, _ = model.init_with_output(jax.random.key(1)) keys4, _ = model.init_with_output(jax.random.key(1)) self.assert_keys_equal(keys1, keys2) self.assert_keys_equal(keys2, keys3) self.assert_keys_equal(keys2, keys3) def test_jit_repr_hash(self): n = 0 @partial(jax.jit, static_argnums=0) def f(obj): nonlocal n n += 1 return None f(_HashableProxy.from_module(nn.Dense(10))) self.assertEqual(n, 1) f(_HashableProxy.from_module(nn.Dense(10))) self.assertEqual(n, 1) f(_HashableProxy.from_module(nn.Dense(20))) self.assertEqual(n, 2) f(_HashableProxy.from_module(nn.Dense(20))) self.assertEqual(n, 2) def test_jit_reuse(self): n = 0 class Foo(nn.Module): @nn.jit def __call__(self, x): nonlocal n n += 1 return x x = jnp.array(1.0) m = Foo() self.assertEqual(n, 0) y = m.apply({}, x) self.assertEqual(n, 1) y = m.apply({}, x) self.assertEqual(n, 1) def test_jit_recursive(self): n = 0 class Foo(nn.Module): @partial(nn.jit, static_argnames='recurse_once') def __call__(self, x, *, recurse_once: bool = True): nonlocal n n += 1 if recurse_once: x = self(x, recurse_once=False) return x + 1 x = jnp.array(1.0) m = Foo() self.assertEqual(n, 0) y = m.apply({}, x) self.assertEqual(n, 2) y = m.apply({}, x) self.assertEqual(n, 2) @parameterized.named_parameters(('class', True), ('method', False)) def test_jit_reuse_hash(self, jit_class: bool): n = 0 class Foo(nn.Module): key: int def __call__(self, x): nonlocal n n += 1 return x if jit_class: Foo = nn.jit(Foo) else: # jit method Foo.__call__ = nn.jit(Foo.__call__) x = jnp.array(1.0) self.assertEqual(n, 0) y = Foo(1).apply({}, x) self.assertEqual(n, 1) y = Foo(1).apply({}, x) self.assertEqual(n, 1) y = Foo(2).apply({}, x) self.assertEqual(n, 2) y = Foo(2).apply({}, x) self.assertEqual(n, 2) @parameterized.named_parameters(('class', True), ('method', False)) def test_jit_reuse_submodules(self, jit_class: bool): test = self n = 0 key = None name = None class Foo(nn.Module): key: int def __call__(self, x): nonlocal n, key, name n += 1 key = self.key name = self.name return x if jit_class: Foo = nn.jit(Foo) else: # jit method Foo.__call__ = nn.jit(Foo.__call__) class Parent(nn.Module): @nn.compact def __call__(self, x): for i in range(3): m = Foo(i) y = m(x) test.assertEqual(key, i) test.assertEndsWith(name, f'Foo_{i}') test.assertEqual(n, i + 1) x = jnp.array(1.0) self.assertEqual(n, 0) y = Parent().apply({}, x) @parameterized.named_parameters(('class', True), ('method', False)) def test_jit_stateful_submodules(self, jit_class: bool): n = 0 class Foo(nn.Module): key: int @nn.compact def __call__(self, x): nonlocal n n += 1 count = self.variable('counts', 'count', lambda: 0) if not self.is_initializing(): count.value += 1 return x if jit_class: Foo = nn.jit(Foo) else: # jit method Foo.__call__ = nn.jit(Foo.__call__) class Parent(nn.Module): @nn.compact def __call__(self, x): for _ in range(3): m = Foo(0) x = m(x) return x m = Parent() x = jnp.array(1.0) counts = m.init({}, x)['counts'] self.assertEqual(n, 1) y, updates = m.apply({'counts': counts}, x, mutable=['counts']) counts = updates['counts'] self.assertEqual(n, 2) for count in jax.tree.leaves(counts): self.assertEqual(count, 1) y, updates = m.apply({'counts': counts}, x, mutable=['counts']) counts = updates['counts'] self.assertEqual(n, 2) for count in jax.tree.leaves(counts): self.assertEqual(count, 2) @parameterized.named_parameters(('class', True), ('method', False)) def test_jit_reuse_nested_submodules(self, jit_class: bool): test = self n = 0 class Foo(nn.Module): key: int def __call__(self, x): return x class Parent(nn.Module): submodules: list[Foo] @nn.compact def __call__(self, x): nonlocal n n += 1 for i, m in enumerate(self.submodules): x = m(x) return x if jit_class: Parent = nn.jit(Parent) else: # jit method Parent.__call__ = nn.jit(Parent.__call__) x = jnp.array(1.0) self.assertEqual(n, 0) y = Parent([Foo(1), Foo(2)]).apply({}, x) self.assertEqual(n, 1) y = Parent([Foo(1), Foo(2)]).apply({}, x) self.assertEqual(n, 1) y = Parent([Foo(3), Foo(4)]).apply({}, x) self.assertEqual(n, 2) y = Parent([Foo(3), Foo(4)]).apply({}, x) self.assertEqual(n, 2) def test_jit_hashes_serializable_types(self): class Node: def __init__(self, a: int): self.a = a def __hash__(self): # test object is not being passed as static raise Exception('this should not be called') def __eq__(self, __value, /): raise Exception('this should not be called') def to_dict(node: Node): return {'a': node.a} def from_dict(node: Node, d: dict[str, Any]): node.a = d['a'] return node serialization.register_serialization_state(Node, to_dict, from_dict) try: n = 0 class Foo(nn.Module): node: Node @nn.jit @nn.compact def __call__(self, x): nonlocal n n += 1 return self.node.a + nn.Dense(2)(x) m = Foo(Node(1)) m.init_with_output(random.key(0), jnp.ones((2, 2))) self.assertEqual(n, 1) m.init_with_output(random.key(0), jnp.ones((2, 2))) self.assertEqual(n, 1) finally: del serialization._STATE_DICT_REGISTRY[Node] def test_while_loop(self): class Foo(nn.Module): @nn.compact def __call__(self, x): key_zero = random.key(0) key_zero = jnp.broadcast_to(key_zero, (2, *key_zero.shape)) self.param('inc', lambda _: 1) self.put_variable('state', 'acc', 0) self.put_variable('state', 'rng_params', key_zero) self.put_variable('state', 'rng_loop', key_zero) def cond_fn(mdl, c): acc = mdl.get_variable('state', 'acc') return acc < x def body_fn(mdl, c): i = mdl.get_variable('state', 'acc') p_rng = mdl.make_rng('params') l_rng = mdl.make_rng('loop') mdl.put_variable( 'state', 'rng_params', mdl.get_variable('state', 'rng_params').at[i].set(p_rng), ) mdl.put_variable( 'state', 'rng_loop', mdl.get_variable('state', 'rng_loop').at[i].set(l_rng), ) inc = mdl.get_variable('params', 'inc') mdl.put_variable('state', 'acc', i + inc) return c return nn.while_loop( cond_fn, body_fn, self, (), carry_variables='state', split_rngs={'params': False, 'loop': True}, ) x = 2 mdl = Foo() _, vars = mdl.apply( {}, x, mutable=True, rngs={'params': random.key(1), 'loop': random.key(2)}, ) self.assertEqual(vars['state']['acc'], x) self.assertTrue( jnp.equal(vars['state']['rng_params'][0], vars['state']['rng_params'][1]) ) with jax.debug_key_reuse(False): self.assertFalse( jnp.equal( vars['state']['rng_loop'][0], vars['state']['rng_loop'][1], ) ) def test_while_loop_denylist_split_rngs(self): def cond_fn(module, carry): del module return carry < 10 def body_fn(module, carry): rng = module.make_rng('random') return carry + 1 class Foo(nn.Module): def setup(self): pass f = Foo().bind({}, rngs={'random': jax.random.PRNGKey(0)}) init = jnp.zeros(()) result = nn.while_loop( cond_fn, body_fn, f, init, split_rngs={ 'params': False, nn.DenyList(('params',)): True, }, ) np.testing.assert_array_equal(result, jnp.array(10.0)) def test_cond(self): class Foo(nn.Module): @nn.compact def __call__(self, x, pred): self.variable('state', 'true_count', lambda: 0) self.variable('state', 'false_count', lambda: 0) def true_fn(mdl, x): mdl.variable('state', 'true_count').value += 1 return nn.Dense(2, name='dense')(x) def false_fn(mdl, x): mdl.variable('state', 'false_count').value += 1 return -nn.Dense(2, name='dense')(x) return nn.cond(pred, true_fn, false_fn, self, x) def test_switch(self): class Foo(nn.Module): @nn.compact def __call__(self, x, pred): self.variable('state', 'a_count', lambda: 0) self.variable('state', 'b_count', lambda: 0) self.variable('state', 'c_count', lambda: 0) def a_fn(mdl, x): mdl.variable('state', 'a_count').value += 1 return nn.Dense(2, name='dense')(x) def b_fn(mdl, x): mdl.variable('state', 'b_count').value += 1 return -nn.Dense(2, name='dense')(x) def c_fn(mdl, x): mdl.variable('state', 'c_count').value += 1 return nn.Dense(2, name='dense')(x) return nn.switch(pred, [a_fn, b_fn, c_fn], self, x) x = jnp.ones((1, 3)) foo = Foo() y1, vars = foo.init_with_output(random.key(0), x, 0) self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 0, 'c_count': 0}) y2, updates = foo.apply(vars, x, 1, mutable='state') vars = copy(vars, updates) self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 1, 'c_count': 0}) np.testing.assert_allclose(y1, -y2) y3, updates = foo.apply(vars, x, 2, mutable='state') vars = copy(vars, updates) self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 1, 'c_count': 1}) np.testing.assert_allclose(y1, y3) def test_switch_multihead(self): class Foo(nn.Module): def setup(self) -> None: self.heads = [ nn.Sequential([nn.Dense(10), nn.Dense(7), nn.Dense(5)]), nn.Sequential([nn.Dense(11), nn.Dense(5)]), nn.Dense(5), ] @nn.compact def __call__(self, x, index): def head_fn(i): def fn(mdl, x): mdl.variable('state', f'{i}_count', lambda: -1).value += 1 return mdl.heads[i](x) return fn branches = [head_fn(i) for i in range(len(self.heads))] if self.is_mutable_collection('params'): for branch in branches: _ = branch(self, x) return nn.switch(index, branches, self, x) x = jnp.ones((1, 3)) foo = Foo() y1, vars = foo.init_with_output(random.key(0), x, 0) self.assertEqual(vars['state'], {'0_count': 1, '1_count': 0, '2_count': 0}) y2, updates = foo.apply(vars, x, 1, mutable='state') vars = copy(vars, updates) self.assertEqual(vars['state'], {'0_count': 1, '1_count': 1, '2_count': 0}) y3, updates = foo.apply(vars, x, 2, mutable='state') vars = copy(vars, updates) self.assertEqual(vars['state'], {'0_count': 1, '1_count': 1, '2_count': 1}) self.assertEqual( vars['params']['heads_0']['layers_0']['kernel'].shape, (3, 10) ) self.assertEqual(vars['params']['heads_0']['layers_0']['bias'].shape, (10,)) self.assertEqual( vars['params']['heads_0']['layers_1']['kernel'].shape, (10, 7) ) self.assertEqual(vars['params']['heads_0']['layers_1']['bias'].shape, (7,)) self.assertEqual( vars['params']['heads_0']['layers_2']['kernel'].shape, (7, 5) ) self.assertEqual(vars['params']['heads_0']['layers_2']['bias'].shape, (5,)) self.assertEqual( vars['params']['heads_1']['layers_0']['kernel'].shape, (3, 11) ) self.assertEqual(vars['params']['heads_1']['layers_0']['bias'].shape, (11,)) self.assertEqual( vars['params']['heads_1']['layers_1']['kernel'].shape, (11, 5) ) self.assertEqual(vars['params']['heads_1']['layers_1']['bias'].shape, (5,)) self.assertEqual(vars['params']['heads_2']['kernel'].shape, (3, 5)) self.assertEqual(vars['params']['heads_2']['bias'].shape, (5,)) def test_lift_instance_error(self): class Foo(nn.Module): @nn.compact def __call__(self, x): return nn.checkpoint(nn.Dense(2))(x) with self.assertRaises(errors.TransformTargetError): Foo().init(random.key(0), jnp.zeros((2, 3))) def test_scan_compact_count(self): class Foo(nn.Module): num_layers: int = 5 @nn.compact def __call__(self, x): def body_fn(mdl, x): return nn.Dense(features=x.shape[-1])(x), () x, _ = nn.scan( body_fn, length=self.num_layers, variable_axes={'params': 0}, split_rngs={'params': True}, )(self, x) return x m = Foo() x = jnp.ones((3,)) v = m.init(jax.random.key(0), x) self.assertEqual(v['params']['Dense_0']['kernel'].shape, (5, 3, 3)) m.apply(v, x) def test_bound_methods_in_direct_transforms(self): class CondModel(nn.Module): def setup(self): self.dense = nn.Dense(3) def f1(self, arr): arr = self.dense(arr) return arr def f2(self, arr): _ = self.dense(arr) return arr def __call__(self, x): return nn.cond(x.sum() > 0, self.f1, self.f2, self, x) cond_model = CondModel() output, init_params = jax.jit(cond_model.init_with_output)( jax.random.key(0), x=jnp.ones(3) ) def test_add_metadata_axis(self): vars_copy = None class Foo(nn.Module): @nn.compact def __call__(self, x): nonlocal vars_copy kernel_init = nn.with_partitioning( nn.initializers.lecun_normal(), ('foo', 'bar') ) vars_copy = self.variables return nn.Dense( 4, kernel_init=kernel_init, use_bias=False, name='dense' )(x) class Test(nn.Module): @partial( nn.add_metadata_axis, variable_axes={'params': 0}, metadata_params={nn.PARTITION_NAME: 'baz'}, ) @nn.compact def __call__(self, x): return Foo(name='foo')(x) k = random.key(0) x = jnp.ones((4, 4), dtype=jnp.float32) vs = Test().init(k, x) y = Test().apply(vs, x) outer_expect = jax.tree_util.tree_map( jnp.shape, freeze({ 'params': { 'foo': { 'dense': { 'kernel': nn.Partitioned( jnp.ones((4, 4)), names=('baz', 'foo', 'bar') ) } } } }), ) inner_expect = jax.tree_util.tree_map( jnp.shape, freeze({ 'params': { 'dense': { 'kernel': nn.Partitioned( jnp.ones((4, 4)), names=('foo', 'bar') ) } } }), ) self.assertEqual(jax.tree_util.tree_map(jnp.shape, vs), outer_expect) self.assertEqual(jax.tree_util.tree_map(jnp.shape, vars_copy), inner_expect) def test_outer_setup_called_with_sharing_across_transforms(self): class A(nn.Module): def setup(self): self.foo = self.param('foo', nn.initializers.zeros, (2, 2), jnp.float32) def __call__(self, x): return self.foo class B(nn.Module): a: Any @nn.compact def __call__(self, x): return self.a(x) class C(nn.Module): def setup(self): self.a = A() self.b = nn.jit(B)(self.a) def __call__(self, x): b = self.b(x) a = self.a(x) return a + b k = random.key(0) x = random.randint(k, (2, 2), minval=0, maxval=10) vs = C().init(k, x) y = C().apply(vs, x) outer_expect = jax.tree_util.tree_map( jnp.shape, freeze({'params': {'a': {'foo': jnp.zeros((2, 2))}}}) ) self.assertEqual(jax.tree_util.tree_map(jnp.shape, vs), outer_expect) def test_grad_simple(self): class LearnScale(nn.Module): @nn.compact def __call__(self, x, y): p = self.param('scale', nn.initializers.ones_init(), ()) return jnp.sum(p * x * y) class Foo(nn.Module): @nn.compact def __call__(self, x, y): x_grad, y_grad = nn.grad( lambda mdl, x, y: mdl(x, y), LearnScale(), x, y ) return x_grad, y_grad x = random.uniform(random.key(1), (4,)) y = random.uniform(random.key(2), (4,)) vs = Foo().init(random.key(0), x, y) x_grad, y_grad = Foo().apply(vs, x, y) self.assertTrue(tree_allclose(x_grad, y)) self.assertTrue(tree_allclose(y_grad, x)) def test_grad_simple_with_aux(self): class LearnScale(nn.Module): @nn.compact def __call__(self, x, y): p = self.param('scale', nn.initializers.ones_init(), ()) return jnp.sum(p * x * y), p class Foo(nn.Module): @nn.compact def __call__(self, x, y): (x_grad, y_grad), aux = nn.grad( lambda mdl, x, y: mdl(x, y), LearnScale(), x, y, has_aux=True ) return aux, x_grad, y_grad x = random.uniform(random.key(1), (4,)) y = random.uniform(random.key(2), (4,)) vs = Foo().init(random.key(0), x, y) aux, x_grad, y_grad = Foo().apply(vs, x, y) self.assertTrue(tree_allclose(x_grad, y)) self.assertTrue(tree_allclose(y_grad, x)) self.assertTrue(tree_allclose(aux, vs['params']['LearnScale_0']['scale'])) def test_value_and_grad_simple(self): class LearnScale(nn.Module): @nn.compact def __call__(self, x, y): p = self.param('scale', nn.initializers.ones_init(), ()) return jnp.sum(p * x * y) class Foo(nn.Module): @nn.compact def __call__(self, x, y): z, (x_grad, y_grad) = nn.value_and_grad( lambda mdl, x, y: mdl(x, y), LearnScale(), x, y ) return z, x_grad, y_grad x = random.uniform(random.key(1), (4,)) y = random.uniform(random.key(2), (4,)) vs = Foo().init(random.key(0), x, y) z, x_grad, y_grad = Foo().apply(vs, x, y) self.assertTrue(tree_allclose(x_grad, y)) self.assertTrue(tree_allclose(y_grad, x)) def test_value_and_grad_simple_with_aux(self): class LearnScale(nn.Module): @nn.compact def __call__(self, x, y): p = self.param('scale', nn.initializers.ones_init(), ()) return jnp.sum(p * x * y), p class Foo(nn.Module): @nn.compact def __call__(self, x, y): (z, aux), (x_grad, y_grad) = nn.value_and_grad( lambda mdl, x, y: mdl(x, y), LearnScale(), x, y, has_aux=True ) return z, aux, x_grad, y_grad x = random.uniform(random.key(1), (4,)) y = random.uniform(random.key(2), (4,)) vs = Foo().init(random.key(0), x, y) z, aux, x_grad, y_grad = Foo().apply(vs, x, y) self.assertTrue(tree_allclose(x_grad, y)) self.assertTrue(tree_allclose(y_grad, x)) self.assertTrue(tree_allclose(aux, vs['params']['LearnScale_0']['scale'])) def test_value_and_grad_multiscope(self): class Foo(nn.Module): bar: nn.Module @nn.compact def __call__(self, x, y): def fn(self, x, y): qup = nn.Dense(y.shape[-1]) delta = y - self.bar(qup(x)) return jnp.sum(delta**2) z, (x_grad, y_grad) = nn.value_and_grad(fn, self, x, y) return z, x_grad, y_grad class Baz(nn.Module): @nn.compact def __call__(self, x, y): bar = nn.Dense(y.shape[-1]) return Foo(bar=bar)(x, y) x = random.uniform(random.key(1), (4,)) y = random.uniform(random.key(2), (4,)) vs = Baz().init(random.key(0), x, y) z, x_grad, y_grad = Baz().apply(vs, x, y) def comparison_fn(x, y): w1 = vs['params']['Foo_0']['Dense_0']['kernel'] w2 = vs['params']['Dense_0']['kernel'] delta = y - jnp.dot(jnp.dot(x, w1), w2) return jnp.sum(delta**2) self.assertTrue(tree_allclose(comparison_fn(x, y), z)) self.assertTrue(tree_allclose(jax.grad(comparison_fn, 0)(x, y), x_grad)) self.assertTrue(tree_allclose(jax.grad(comparison_fn, 1)(x, y), y_grad)) def test_value_and_grad_multiscope_adopted(self): class Foo(nn.Module): bar: nn.Module qup: nn.Module @nn.compact def __call__(self, x, y): def fn(self, x, y): delta = y - self.bar(self.qup(x)) return jnp.sum(delta**2) z, (x_grad, y_grad) = nn.value_and_grad(fn, self, x, y) return z, x_grad, y_grad x = random.uniform(random.key(1), (4,)) y = random.uniform(random.key(2), (4,)) vs = Foo(bar=nn.Dense(4), qup=nn.Dense(4)).init(random.key(0), x, y) z, x_grad, y_grad = Foo(bar=nn.Dense(4), qup=nn.Dense(4)).apply(vs, x, y) def comparison_fn(x, y): w1 = vs['params']['qup']['kernel'] w2 = vs['params']['bar']['kernel'] delta = y - jnp.dot(jnp.dot(x, w1), w2) return jnp.sum(delta**2) self.assertTrue(tree_allclose(comparison_fn(x, y), z)) self.assertTrue(tree_allclose(jax.grad(comparison_fn, 0)(x, y), x_grad)) self.assertTrue(tree_allclose(jax.grad(comparison_fn, 1)(x, y), y_grad)) def test_vmap_add_remove_axis_transforms(self): class BoxedData(struct.PyTreeNode, AxisMetadata): value: Any def unbox(self): return self.value def replace_boxed(self, val): return self.replace(value=val) def add_axis(self, index: int, params: dict[Any, Any]): value = jnp.mean(self.value, axis=index) return self.replace(value=value) def remove_axis(self, index: int, params: dict[Any, Any]): value_shape = list(self.value.shape) value_shape.insert(index, params['axis_size']) value = jnp.broadcast_to(self.value, value_shape) return self.replace(value=value) class Top(nn.Module): @nn.compact def __call__(self, x): VFoo = nn.vmap( Foo, in_axes=0, out_axes=0, variable_axes={'params':0, 'aux': 0}, metadata_params={'axis_size': x.shape[0]}, ) vfoo = VFoo(name="vfoo") y = vfoo(x) y = vfoo(x) assert vfoo.variables['aux']['v'].value.shape == () return y class Foo(nn.Module): @nn.compact def __call__(self, x): if self.has_variable('aux', 'v'): assert self.variables['aux']['v'].value.shape == () boxed_v = self.variable('aux', 'v', lambda: BoxedData(jnp.ones(()))) assert self.variables['aux']['v'].value.shape == () return x vs = Top().init(random.key(0), jnp.ones((2,5))) y = Top().apply(vs, jnp.ones((2, 5))) assert vs['aux']['vfoo']['v'].value.shape == () def test_vjp_tracer_leak(self): class LearnScale(nn.Module): @nn.compact def __call__(self, x): p = self.param('scale', nn.initializers.zeros, ()) return p * x class Foo(nn.Module): @nn.compact def __call__(self, x): y, bwd = nn.vjp(lambda mdl, x: mdl(x), LearnScale(), x) params_grad, x_grad = bwd(jnp.ones(y.shape)) return y, params_grad, x_grad key = jax.random.PRNGKey(0) x = jnp.ones((2, 3)) foo = Foo() with jax.check_tracer_leaks(): params = foo.init(key, x) foo.apply(params, x) @parameterized.named_parameters( ('retracing scan', True), ('simple scan', False) ) def test_jit_scan_retracing(self, retracing_scan: bool): num_blocks = 4 num_patterns = 4 features = 4 trace_counts = [0, 0] class Block(nn.Module): def setup(self): self.dense = nn.Dense(features, use_bias=False) @nn.jit def __call__(self, x): nonlocal trace_counts trace_counts[1] += 1 return self.dense(x) class BlockSequence(nn.Module): def setup(self): self.blocks = [Block() for _ in range(num_blocks)] @nn.jit def __call__(self, carry, inputs): nonlocal trace_counts trace_counts[0] += 1 for block in self.blocks: carry = block(carry) return carry, inputs class Transformer(nn.Module): retracing_scan: bool = True def setup(self): self.scan = nn.scan( BlockSequence, variable_axes={'params': 0}, split_rngs={'params': False}, length=num_patterns, check_constancy_invariants=retracing_scan, )() def __call__(self, inputs): return self.scan(jnp.zeros_like(inputs), inputs) model = Transformer(retracing_scan=retracing_scan) _ = model.init(random.key(0), jnp.ones((num_patterns, features,))) self.assertEqual(trace_counts[0], 2 if retracing_scan else 1) self.assertEqual(trace_counts[1], 2 if retracing_scan else 1) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/linen/partitioning_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.linen.partitioning.""" import jax import jax.numpy as jnp from absl.testing import absltest, parameterized from jax import random, sharding from jax.experimental import mesh_utils from flax import linen as nn from flax.core import freeze, unfreeze from flax.linen import partitioning mock = absltest.mock # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() # Testing constants. AXIS_RULES_1 = (('foo', 'data'), ('bar', 'model'), ('baz', None)) AXIS_RULES_2 = (('foo', 'model'), ('bar', None), ('baz', 'data')) class PartitioningTest(parameterized.TestCase): def test_axis_rules(self): self.assertEqual(nn.spmd.get_logical_axis_rules(), ()) partitioning.set_axis_rules(AXIS_RULES_1) self.assertEqual(nn.spmd.get_logical_axis_rules(), AXIS_RULES_1) self.assertEqual(partitioning.get_axis_rules(), AXIS_RULES_1) partitioning.set_axis_rules(()) def test_axis_rules_context(self): partitioning.set_axis_rules(AXIS_RULES_1) self.assertEqual(partitioning.get_axis_rules(), AXIS_RULES_1) with partitioning.axis_rules(AXIS_RULES_2): self.assertEqual(partitioning.get_axis_rules(), AXIS_RULES_2) self.assertEqual(partitioning.get_axis_rules(), AXIS_RULES_1) def test_logical_to_mesh_axes_resolves_to_none_or_unconstrained(self): unconstrained = jax.sharding.PartitionSpec.UNCONSTRAINED rules = ( ('foo', None), ('bad', None), ('bar', unconstrained), ('baz', unconstrained), ) self.assertEqual( partitioning.logical_to_mesh_axes( ('foo', 'bad', 'bar', 'baz'), rules=rules ), (None, None, unconstrained, unconstrained), ) def test_logical_to_mesh_axes(self): axes_0 = ('foo', 'bar') # direct rule assignment self.assertEqual( partitioning.logical_to_mesh_axes(axes_0, rules=AXIS_RULES_1), ('data', 'model'), ) # Repeated None and Unconstrained unconstrained = jax.sharding.PartitionSpec.UNCONSTRAINED axes_repeated = ('foo', unconstrained, unconstrained, None, None) self.assertEqual( partitioning.logical_to_mesh_axes(axes_repeated, rules=AXIS_RULES_1), ('data', unconstrained, unconstrained, None, None), ) # axis rules context with partitioning.axis_rules(AXIS_RULES_1): self.assertEqual( partitioning.logical_to_mesh_axes(axes_0), ('data', 'model') ) # nested context with partitioning.axis_rules(AXIS_RULES_2): self.assertEqual( partitioning.logical_to_mesh_axes(axes_0), ('model', None) ) # duplicated logical names with partitioning.axis_rules(AXIS_RULES_1): with self.assertRaises(ValueError): partitioning.logical_to_mesh_axes(('foo', 'foo', 'baz')) def test_logical_to_mesh_axes_priorities(self): p_rules = (('foo', 'model'), ('bar', 'model'), ('baz', 'data')) with partitioning.axis_rules(p_rules): self.assertEqual( partitioning.logical_to_mesh_axes(('foo', 'bar', 'baz')), ('model', None, 'data'), ) self.assertEqual( partitioning.logical_to_mesh_axes(('bar', 'foo', 'baz')), (None, 'model', 'data'), ) self.assertEqual( partitioning.logical_to_mesh_axes(('baz', 'bar', 'foo')), ('data', None, 'model'), ) self.assertEqual( partitioning.logical_to_mesh_axes(('baz', 'bar', 'foo', 'unassigned')), ('data', None, 'model', None), ) @parameterized.parameters( dict( rules=(('a', ('model', 'data')), ('b', 'data')), axes=('a', 'b'), expected=(('model', 'data'), None), ), dict( rules=(('a', ('model', 'replica')), ('b', 'data')), axes=('a', 'b'), expected=(('model', 'replica'), 'data'), ), dict( rules=(('a', ('model', 'replica')), ('b', ('data', 'model'))), axes=('a', 'b'), expected=(('model', 'replica'), None), ), dict( rules=(('a', ('model', 'replica')), ('b', 'model')), axes=('a', 'b', 'c'), expected=(('model', 'replica'), None, None), ), dict(rules=(), axes=('a', 'b', 'c'), expected=(None, None, None)), dict( rules=(('a', None), ('a', 'model')), axes=('a', 'b'), expected=(None, None), ), dict( rules=( ('baz', 'data'), ('bar', None), ('foo', 'model'), ('foo', 'data'), ), axes=('baz', 'bar', 'foo'), expected=('data', None, 'model'), ), dict( rules=(('baz', 'data'), ('foo', ('model', 'emb'))), axes=('baz', jax.sharding.PartitionSpec.UNCONSTRAINED, 'foo'), expected=( 'data', jax.sharding.PartitionSpec.UNCONSTRAINED, ('model', 'emb'), ), ), ) def test_logical_to_mesh_axes_cases(self, rules, axes, expected): with partitioning.axis_rules(rules): result = partitioning.logical_to_mesh_axes(axes) self.assertEqual(result, expected) @mock.patch('flax.linen.spmd._with_sharding_constraint') def test_with_sharding_constraint(self, wsc_fn): unconstrained = jax.sharding.PartitionSpec.UNCONSTRAINED arr = jnp.ones((2, 2)) axes = ('foo', 'bar') partitioning.set_axis_rules(()) _ = partitioning.with_sharding_constraint(arr, axes) wsc_fn.assert_not_called() with partitioning.axis_rules(AXIS_RULES_1): _ = partitioning.with_sharding_constraint(arr, None) wsc_fn.assert_not_called() _ = partitioning.with_sharding_constraint(arr, axes) wsc_fn.assert_called_with( arr, jax.sharding.PartitionSpec('data', 'model'), mesh=None ) _ = partitioning.with_sharding_constraint(arr, ('foo', unconstrained)) wsc_fn.assert_called_with( arr, jax.sharding.PartitionSpec('data', unconstrained), mesh=None ) @mock.patch('flax.linen.spmd._with_sharding_constraint') def test_with_sharding_constraint_fallback(self, wsc_fn): arr = jnp.ones((2, 2)) with partitioning.axis_rules(AXIS_RULES_1): _ = partitioning.with_sharding_constraint(arr, ('foo', 'not_recognized')) wsc_fn.assert_called_with( arr, jax.sharding.PartitionSpec('data', None), mesh=None ) wsc_fn.reset_mock() _ = partitioning.with_sharding_constraint( arr, ('foo', 'not_recognized'), fallback=partitioning.RulesFallback.AXIS_IS_UNSHARDED, ) wsc_fn.assert_called_with( arr, jax.sharding.PartitionSpec('data', None), mesh=None ) wsc_fn.reset_mock() with self.assertRaises(ValueError): _ = partitioning.with_sharding_constraint( arr, ('foo', 'not_recognized'), fallback=partitioning.RulesFallback.RAISE_ERROR, ) wsc_fn.assert_not_called() _ = partitioning.with_sharding_constraint( arr, ('foo', 'not_recognized'), fallback=partitioning.RulesFallback.NO_CONSTRAINT, ) wsc_fn.assert_not_called() @parameterized.parameters(dict(axes_spec=None), dict(axes_spec=())) def test_param_with_axes_no_axes(self, axes_spec): class ParamTest(nn.Module): @nn.compact def __call__(self, x): foo = partitioning.param_with_axes( 'foo', lambda k, s, d: jnp.zeros(s, d), (2, 2), x.dtype, axes=axes_spec, ) return x + foo k = random.key(0) x = jnp.ones((2, 2)) _ = ParamTest().init(k, x) def test_param_with_axes(self): class ParamTest(nn.Module): @nn.compact def __call__(self, x): foo = partitioning.param_with_axes( 'foo', lambda k, s, d: jnp.zeros(s, d), (2, 2), x.dtype, axes=('foo', 'bar'), ) return x + foo p_rules = (('foo', 'model'), ('bar', 'data'), ('baz', None)) k = random.key(0) x = jnp.ones((2, 2)) with partitioning.axis_rules(p_rules): variables = ParamTest().init(k, x) self.assertIn('params', variables) self.assertIn('params_axes', variables) self.assertEqual( variables['params_axes']['foo_axes'], partitioning.AxisMetadata(names=('foo', 'bar')), ) logical_axis_names = partitioning.get_axis_names(variables['params_axes']) self.assertEqual( logical_axis_names, {'foo': jax.sharding.PartitionSpec('foo', 'bar')} ) def test_param_pytree_with_axes(self): def init_fn(k, s, d): del k return {'a': jnp.zeros(s, d), 'b': (jnp.zeros(s, d), jnp.zeros(s, d))} axes = {'a': ('foo', 'bar'), 'b': (('foo', 'bar'), ('bar', 'foo'))} class ParamTest(nn.Module): @nn.compact def __call__(self, x): foo = partitioning.param_with_axes( 'foo', init_fn, (2, 2), x.dtype, axes=axes ) return x + foo['a'] p_rules = (('foo', 'model'), ('bar', 'data'), ('baz', None)) k = random.key(0) x = jnp.ones((2, 2)) with partitioning.axis_rules(p_rules): variables = ParamTest().init(k, x) self.assertIn('params', variables) self.assertIn('params_axes', variables) self.assertEqual( variables['params_axes']['foo_axes'], partitioning.AxisMetadata(names=axes), ) logical_axis_names = partitioning.get_axis_names(variables['params_axes']) expected = freeze( { 'foo': { 'a': jax.sharding.PartitionSpec('foo', 'bar'), 'b': ( jax.sharding.PartitionSpec('foo', 'bar'), jax.sharding.PartitionSpec('bar', 'foo'), ), } } ) self.assertEqual(logical_axis_names, expected) @parameterized.parameters(dict(axes_spec=None), dict(axes_spec=())) def test_variable_with_axes_no_axes(self, axes_spec): class VarTest(nn.Module): @nn.compact def __call__(self, x): foo = partitioning.variable_with_axes( 'test', 'foo', jnp.zeros, (2, 2), x.dtype, axes=axes_spec ) return x + foo.value k = random.key(0) x = jnp.ones((2, 2)) _ = VarTest().init(k, x) def test_variable_with_empty_tuple_has_empty_axes(self): class VarTest(nn.Module): @nn.compact def __call__(self, x): foo = partitioning.variable_with_axes( 'test', 'foo', jnp.zeros, (2, 2), x.dtype, axes=() ) return x + foo.value k = random.key(0) x = jnp.ones((2, 2)) variables = VarTest().init(k, x) logical_axis_names = partitioning.get_axis_names(variables['test_axes']) self.assertEqual(logical_axis_names, {'foo': jax.sharding.PartitionSpec()}) def test_variable_with_axes(self): class VarTest(nn.Module): @nn.compact def __call__(self, x): foo = partitioning.variable_with_axes( 'test', 'foo', jnp.zeros, (2, 2), x.dtype, axes=('foo', 'bar') ) return x + foo.value p_rules = (('foo', 'model'), ('bar', 'data'), ('baz', None)) k = random.key(0) x = jnp.ones((2, 2)) with partitioning.axis_rules(p_rules): variables = VarTest().init(k, x) self.assertIn('test', variables) self.assertIn('test_axes', variables) self.assertEqual( variables['test_axes']['foo_axes'], partitioning.AxisMetadata(names=('foo', 'bar')), ) logical_axis_names = partitioning.get_axis_names(variables['test_axes']) self.assertEqual( logical_axis_names, {'foo': jax.sharding.PartitionSpec('foo', 'bar')} ) @mock.patch('flax.linen.partitioning._with_sharding_constraint') def test_variable_with_axes_fallback(self, wsc_fn): class VarTest(nn.Module): @nn.compact def __call__(self, x): foo = partitioning.variable_with_axes( 'test', 'foo', jnp.zeros, (2, 2), x.dtype, axes=('foo', 'bar'), fallback=partitioning.RulesFallback.NO_CONSTRAINT, ) return x + foo.value p_rules = ( # No rule for 'foo': ('bar', 'data'), ('baz', None), ) k = random.key(0) x = jnp.ones((2, 2)) with partitioning.axis_rules(p_rules): variables = VarTest().init(k, x) wsc_fn.assert_not_called() self.assertIn('test', variables) self.assertIn('test_axes', variables) self.assertEqual( variables['test_axes']['foo_axes'], partitioning.AxisMetadata(names=('foo', 'bar')), ) logical_axis_names = partitioning.get_axis_names(variables['test_axes']) self.assertEqual( logical_axis_names, {'foo': jax.sharding.PartitionSpec('foo', 'bar')} ) def test_scan_with_axes(self): # MLP Hparams B, L, E = 8, 4, 32 # pylint: disable=invalid-name # fake inputs x = jnp.ones((B, E)) k = random.key(0) class SinDot(nn.Module): depth: int @nn.compact def __call__(self, x): W1 = partitioning.param_with_axes( # pylint: disable=invalid-name 'W1', nn.initializers.xavier_normal(), (x.shape[-1], self.depth), axes=('emb', 'mlp'), ) W2 = partitioning.param_with_axes( # pylint: disable=invalid-name 'W2', nn.initializers.xavier_normal(), (self.depth, x.shape[-1]), axes=('mlp', 'emb'), ) y = jnp.dot(jnp.sin(jnp.dot(x, W1)), W2) _ = partitioning.variable_with_axes( 'stats', 'y_st', lambda: y, axes=('batch', 'emb') ) # scan expects a (carry, out) return signature. return y, None class Scanned(nn.Module): num_layers: int depth: int @nn.compact def __call__(self, x): scanned_sindot = partitioning.scan_with_axes( SinDot, in_axes=(), variable_axes={'params': 0, 'stats': 1}, split_rngs={'params': True}, axis_name='layer', axes_collections=('params', 'stats'), length=self.num_layers, )(self.depth, name='scanned_layer') y, _ = scanned_sindot(x) # test calling again to test metadata compatibility across calls _, _ = scanned_sindot(x) return y p_rules = (('emb', 'data'), ('mlp', 'model'), ('batch', 'data')) with partitioning.axis_rules(p_rules): variables = Scanned(L, E).init(k, x) # Ensure that the module can be called when 'params_axes' is not mutable. Scanned(L, E).apply(variables, x) self.assertIn('params', variables) self.assertIn('params_axes', variables) self.assertIn('stats', variables) self.assertIn('stats_axes', variables) self.assertEqual( variables['params_axes']['scanned_layer']['W1_axes'], partitioning.AxisMetadata(names=('layer', 'emb', 'mlp')), ) logical_axis_names = partitioning.get_axis_names(variables['params_axes']) self.assertEqual( logical_axis_names, { 'scanned_layer': { 'W1': jax.sharding.PartitionSpec('layer', 'emb', 'mlp'), 'W2': jax.sharding.PartitionSpec('layer', 'mlp', 'emb'), } }, ) logical_axis_names = partitioning.get_axis_names(variables['stats_axes']) self.assertEqual( logical_axis_names, { 'scanned_layer': { 'y_st': jax.sharding.PartitionSpec('batch', 'layer', 'emb') } }, ) def test_vmap_with_axes(self): class Foo(nn.Module): @nn.compact def __call__(self, x): return ( partitioning.param_with_axes( 'w', jax.nn.initializers.uniform(), [4, 3], axes=('out', 'in') ) @ x ) class Vmapped(nn.Module): @nn.compact def __call__(self, x): FooVmapped = partitioning.vmap_with_axes( # pylint: disable=invalid-name Foo, variable_axes={ 'params': 1, }, split_rngs={'params': True}, partitioning_axis_names={'params': 'vmap_axis'}, ) return FooVmapped(name='foo_vmapped')(x) p_rules = (('out', None), ('in', 'data'), ('vmap_axis', 'model')) # check that regular Food module is correct with partitioning.axis_rules(p_rules): variables = Foo().init(jax.random.key(0), jnp.array([1, 2, 3])) variables = unfreeze(variables) variables['params'] = jax.tree_util.tree_map( lambda x: x.shape, variables['params'] ) self.assertDictEqual( variables, { 'params': {'w': (4, 3)}, 'params_axes': { 'w_axes': partitioning.AxisMetadata(names=('out', 'in')) }, }, ) # check that FooVmapped adds 'vmap_axis' to axis 1 with partitioning.axis_rules(p_rules): variables = Vmapped().init( jax.random.key(0), jnp.array([[1, 2, 3], [4, 5, 6]]) ) variables = unfreeze(variables) variables['params'] = jax.tree_util.tree_map( lambda x: x.shape, variables['params'] ) self.assertDictEqual( variables, { 'params': {'foo_vmapped': {'w': (4, 2, 3)}}, 'params_axes': { 'foo_vmapped': { 'w_axes': partitioning.AxisMetadata( names=('out', 'vmap_axis', 'in') ) } }, }, ) def test_logical_with_mesh_and_rules(self): devices = mesh_utils.create_device_mesh((jax.local_device_count(), 1)) mesh = sharding.Mesh(devices, ('in', 'out')) test = self rules = (('a', 'in'), ('b', 'out')) class Foo(nn.Module): @nn.compact def __call__(self, x): kernel_init = nn.with_logical_partitioning( nn.initializers.ones_init(), ('a', 'b'), mesh=mesh, rules=rules ) kernel = self.param('kernel', kernel_init, (x.shape[-1], 2)) kernel_box = self.get_variable('params', 'kernel') test.assertIsInstance(kernel_box, nn.Partitioned) test.assertEqual(kernel_box.names, ('a', 'b')) return x @ kernel @jax.jit def create_state(): module = Foo() variables = module.init(random.key(0), jnp.zeros((8, 4))) logical_spec = nn.get_partition_spec(variables) shardings = nn.logical_to_mesh_sharding(logical_spec, mesh, rules) variables = jax.lax.with_sharding_constraint(variables, shardings) return variables variables = create_state() self.assertEqual(variables['params']['kernel'].names, ('a', 'b')) self.assertIs(variables['params']['kernel'].mesh, mesh) self.assertEqual(variables['params']['kernel'].rules, rules) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/linen/summary_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import enum import jax import jax.numpy as jnp import numpy as np from absl.testing import absltest from jax import random from flax import linen as nn from flax import struct from flax.core.scope import Array from flax.linen import summary # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() CONSOLE_TEST_KWARGS = dict(force_terminal=False, no_color=True, width=10_000) def _get_shapes(pytree): return jax.tree_util.tree_map( lambda x: x.shape if hasattr(x, 'shape') else x, pytree ) def _get_obj_repr_value(x): if isinstance(x, summary._ObjectRepresentation): return x.obj return x class ConvBlock(nn.Module): features: int kernel_size: list[int] test_sow: bool def setup(self) -> None: self.conv = nn.Conv(self.features, self.kernel_size) self.bn = nn.BatchNorm() self.dropout = nn.Dropout(0.5) def block_method(self, x: Array, training: bool) -> Array: x = self.conv(x) if self.test_sow: self.sow('intermediates', 'INTERM', x) x = self.bn(x, use_running_average=not training) x = self.dropout(x, deterministic=not training) x = nn.relu(x) return x def __call__(self, x: Array, training: bool) -> Array: x = self.conv(x) if self.test_sow: self.sow('intermediates', 'INTERM', x) x = self.bn(x, use_running_average=not training) x = self.dropout(x, deterministic=not training) x = nn.relu(x) return x class CNN(nn.Module): test_sow: bool def setup(self) -> None: self.block1 = ConvBlock(32, [3, 3], test_sow=self.test_sow) self.block2 = ConvBlock(64, [3, 3], test_sow=self.test_sow) self.dense = nn.Dense(10) def cnn_method(self, x: Array, training: bool) -> Array: x = self.block1.block_method(x, training=training) x = self.block2.block_method(x, training=training) x = x.mean(axis=(1, 2)) if self.test_sow: self.sow('intermediates', 'INTERM', x) x = self.dense(x) return x, dict(a=x, b=x + 1.0) def __call__(self, x: Array, training: bool) -> Array: x = self.block1.block_method(x, training=training) x = self.block2.block_method(x, training=training) x = x.mean(axis=(1, 2)) if self.test_sow: self.sow('intermediates', 'INTERM', x) x = self.dense(x) return x, dict(a=x, b=x + 1.0) class SummaryTest(absltest.TestCase): def test_module_summary(self): """ This test creates a Table using `module_summary` and checks that it matches the expected output given the CNN model defined in `_get_tabulate_cnn`. """ batch_size = 32 x = jnp.ones((batch_size, 28, 28, 1)) module = CNN(test_sow=False) table = summary._get_module_table( module, depth=None, show_repeated=True, compute_flops=True, compute_vjp_flops=True, )( {'dropout': random.key(0), 'params': random.key(1)}, x, training=True, mutable=True, ) # get values for inputs and outputs from their _ValueRepresentation for row in table: row.inputs = jax.tree_util.tree_map(_get_obj_repr_value, row.inputs) row.outputs = jax.tree_util.tree_map(_get_obj_repr_value, row.outputs) # 10 rows = 1 CNN + 4 ConvBlock_0 + 4 ConvBlock_1 + 1 Dense_0 self.assertLen(table, 10) # check paths self.assertEqual(table[0].path, ()) self.assertEqual(table[1].path, ('block1',)) self.assertEqual(table[2].path, ('block1', 'conv')) self.assertEqual(table[3].path, ('block1', 'bn')) self.assertEqual(table[4].path, ('block1', 'dropout')) self.assertEqual(table[5].path, ('block2',)) self.assertEqual(table[6].path, ('block2', 'conv')) self.assertEqual(table[7].path, ('block2', 'bn')) self.assertEqual(table[8].path, ('block2', 'dropout')) self.assertEqual(table[9].path, ('dense',)) # check outputs shapes self.assertEqual( (table[0].inputs[0].shape, table[0].inputs[1]), (x.shape, dict(training=True)), ) self.assertEqual( _get_shapes(table[0].outputs), ((batch_size, 10), dict(a=(batch_size, 10), b=(batch_size, 10))), ) self.assertEqual( _get_shapes(table[1].inputs), ((batch_size, 28, 28, 1), {'training': True}), ) self.assertEqual(table[1].outputs.shape, (batch_size, 28, 28, 32)) self.assertEqual(table[2].inputs.shape, (batch_size, 28, 28, 1)) self.assertEqual(table[2].outputs.shape, (batch_size, 28, 28, 32)) self.assertEqual( _get_shapes(table[3].inputs), ((batch_size, 28, 28, 32), {'use_running_average': False}), ) self.assertEqual(table[3].outputs.shape, (batch_size, 28, 28, 32)) self.assertEqual( _get_shapes(table[4].inputs), ((batch_size, 28, 28, 32), {'deterministic': False}), ) self.assertEqual(table[4].outputs.shape, (batch_size, 28, 28, 32)) self.assertEqual( _get_shapes(table[5].inputs), ((batch_size, 28, 28, 32), {'training': True}), ) self.assertEqual(table[5].outputs.shape, (batch_size, 28, 28, 64)) self.assertEqual(table[6].inputs.shape, (batch_size, 28, 28, 32)) self.assertEqual(table[6].outputs.shape, (batch_size, 28, 28, 64)) self.assertEqual( _get_shapes(table[7].inputs), ((batch_size, 28, 28, 64), {'use_running_average': False}), ) self.assertEqual(table[7].outputs.shape, (batch_size, 28, 28, 64)) self.assertEqual( _get_shapes(table[8].inputs), ((batch_size, 28, 28, 64), {'deterministic': False}), ) self.assertEqual(table[8].outputs.shape, (batch_size, 28, 28, 64)) self.assertEqual(table[9].inputs.shape, (batch_size, 64)) self.assertEqual(table[9].outputs.shape, (batch_size, 10)) # check no summary is performed for row in table: self.assertEqual( row.module_variables, row.counted_variables, ) # Each module FLOPs >= sum of its submodule FLOPs. # Can be greater due to ops like `nn.relu` not belonging to any submodule. for r in table: flops, vjp_flops = r.flops, r.vjp_flops submodule_flops, submodule_vjp_flops = 0, 0 for s in table: if len(s.path) == len(r.path) + 1 and s.path[: len(r.path)] == r.path: submodule_flops += s.flops submodule_vjp_flops += s.vjp_flops self.assertGreaterEqual(flops, submodule_flops) self.assertGreaterEqual(vjp_flops, submodule_vjp_flops) def test_module_summary_with_depth(self): """ This test creates a Table using `module_summary` set the `depth` argument to `1`, table should have fewer rows as a consequence. """ batch_size = 32 x = jnp.ones((batch_size, 28, 28, 1)) module = CNN(test_sow=False) table = summary._get_module_table( module, depth=1, show_repeated=True, compute_flops=True, compute_vjp_flops=True, )( {'dropout': random.key(0), 'params': random.key(1)}, x, training=True, mutable=True, ) # get values for inputs and outputs from their _ValueRepresentation for row in table: row.inputs = jax.tree_util.tree_map(_get_obj_repr_value, row.inputs) row.outputs = jax.tree_util.tree_map(_get_obj_repr_value, row.outputs) # 4 rows = 1 CNN + 1 ConvBlock_0 + 1 ConvBlock_1 + 1 Dense_0 self.assertLen(table, 4) # check paths self.assertEqual(table[0].path, ()) self.assertEqual(table[1].path, ('block1',)) self.assertEqual(table[2].path, ('block2',)) self.assertEqual(table[3].path, ('dense',)) # check outputs shapes self.assertEqual( (table[0].inputs[0].shape, table[0].inputs[1]), (x.shape, dict(training=True)), ) self.assertEqual( _get_shapes(table[0].outputs), ((batch_size, 10), dict(a=(batch_size, 10), b=(batch_size, 10))), ) self.assertEqual( _get_shapes(table[1].inputs), ((batch_size, 28, 28, 1), {'training': True}), ) self.assertEqual(table[1].outputs.shape, (batch_size, 28, 28, 32)) self.assertEqual( _get_shapes(table[2].inputs), ((batch_size, 28, 28, 32), {'training': True}), ) self.assertEqual(table[2].outputs.shape, (batch_size, 28, 28, 64)) self.assertEqual(table[3].inputs.shape, (batch_size, 64)) self.assertEqual(table[3].outputs.shape, (batch_size, 10)) # check ConvBlock_0 and ConvBlock_1 are summarized self.assertNotEqual(table[1].module_variables, table[1].counted_variables) self.assertNotEqual(table[2].module_variables, table[2].counted_variables) # check CNN and Dense_0 output are not summarized self.assertEqual(table[0].module_variables, table[0].counted_variables) self.assertEqual(table[3].module_variables, table[3].counted_variables) # Top level FLOPs > sum of listed submodule FLOPs, since not all are listed. self.assertGreater(table[0].flops, sum(r.flops for r in table[1:])) self.assertGreater(table[0].vjp_flops, sum(r.vjp_flops for r in table[1:])) def test_tabulate(self): """ This test creates a string representation of a Module using `Module.tabulate` and checks that it matches the expected output given the CNN model defined in `_get_tabulate_cnn`. """ batch_size = 32 x = jnp.ones((batch_size, 28, 28, 1)) module = CNN(test_sow=False) module_repr = module.tabulate( {'dropout': random.key(0), 'params': random.key(1)}, x, training=True, console_kwargs=CONSOLE_TEST_KWARGS, compute_flops=True, compute_vjp_flops=True, ) # NOTE: it's tricky to validate the content of lines # because it seems to be shell-dependent, so we will # just check lines that won't change between environments lines = module_repr.split('\n') # check title module_name = module.__class__.__name__ self.assertIn(f'{module_name} Summary', lines[1]) # check headers are correct self.assertIn('path', lines[3]) self.assertIn('module', lines[3]) self.assertIn('inputs', lines[3]) self.assertIn('outputs', lines[3]) self.assertIn('params', lines[3]) self.assertIn('flops', lines[3]) self.assertIn('vjp_flops', lines[3]) self.assertIn('batch_stats', lines[3]) # collection counts self.assertIn('Total', lines[-6]) self.assertIn('192', lines[-6]) self.assertIn('768 B', lines[-6]) self.assertIn('19,658', lines[-6]) self.assertIn('78.6 KB', lines[-6]) # total counts self.assertIn('Total Parameters', lines[-3]) self.assertIn('19,850', lines[-3]) self.assertIn('79.4 KB', lines[-3]) def test_tabulate_with_sow(self): batch_size = 32 x = jnp.ones((batch_size, 28, 28, 1)) module = CNN(test_sow=True) module_repr = module.tabulate( {'dropout': random.key(0), 'params': random.key(1)}, x, training=True, mutable=True, console_kwargs=CONSOLE_TEST_KWARGS, compute_flops=True, compute_vjp_flops=True, ) self.assertIn('intermediates', module_repr) self.assertIn('INTERM', module_repr) def test_tabulate_with_method(self): batch_size = 32 x = jnp.ones((batch_size, 28, 28, 1)) module = CNN(test_sow=False) module_repr = module.tabulate( {'dropout': random.key(0), 'params': random.key(1)}, x, training=True, method=CNN.cnn_method, console_kwargs=CONSOLE_TEST_KWARGS, compute_flops=True, compute_vjp_flops=True, ) self.assertIn('(block_method)', module_repr) self.assertIn('(cnn_method)', module_repr) def test_tabulate_function(self): """ This test creates a string representation of a Module using `Module.tabulate` and checks that it matches the expected output given the CNN model defined in `_get_tabulate_cnn`. """ batch_size = 32 x = jnp.ones((batch_size, 28, 28, 1)) module = CNN(test_sow=False) module_repr = nn.tabulate( module, {'dropout': random.key(0), 'params': random.key(1)}, console_kwargs=CONSOLE_TEST_KWARGS, compute_flops=True, compute_vjp_flops=True, )(x, training=True) lines = module_repr.split('\n') # check title module_name = module.__class__.__name__ self.assertIn(f'{module_name} Summary', lines[1]) # check headers are correct self.assertIn('path', lines[3]) self.assertIn('module', lines[3]) self.assertIn('inputs', lines[3]) self.assertIn('outputs', lines[3]) self.assertIn('params', lines[3]) self.assertIn('flops', lines[3]) self.assertIn('batch_stats', lines[3]) # collection counts self.assertIn('Total', lines[-6]) self.assertIn('192', lines[-6]) self.assertIn('768 B', lines[-6]) self.assertIn('19,658', lines[-6]) self.assertIn('78.6 KB', lines[-6]) # total counts self.assertIn('Total Parameters', lines[-3]) self.assertIn('19,850', lines[-3]) self.assertIn('79.4 KB', lines[-3]) def test_lifted_transform(self): class LSTM(nn.Module): features: int @nn.compact def __call__(self, x): carry = nn.LSTMCell(self.features).initialize_carry( random.key(0), x[:, 0].shape ) ScanLSTM = nn.scan( nn.LSTMCell, variable_broadcast='params', split_rngs={'params': False}, in_axes=1, out_axes=1, ) return ScanLSTM(self.features, name='ScanLSTM')(carry, x) lstm = LSTM(features=128) with jax.check_tracer_leaks(True): module_repr = lstm.tabulate( random.key(0), x=jnp.ones((32, 128, 64)), console_kwargs=CONSOLE_TEST_KWARGS, compute_flops=True, compute_vjp_flops=True, ) lines = module_repr.splitlines() self.assertIn('LSTM', lines[5]) self.assertIn('ScanLSTM', lines[9]) self.assertIn('LSTMCell', lines[9]) self.assertIn('ScanLSTM/ii', lines[13]) self.assertIn('Dense', lines[13]) def test_lifted_transform_no_rename(self): class LSTM(nn.Module): features: int @nn.compact def __call__(self, x): carry = nn.LSTMCell(self.features).initialize_carry( random.key(0), x[:, 0].shape ) ScanLSTM = nn.scan( nn.LSTMCell, variable_broadcast='params', split_rngs={'params': False}, in_axes=1, out_axes=1, ) return ScanLSTM(self.features)(carry, x) lstm = LSTM(features=128) with jax.check_tracer_leaks(True): module_repr = lstm.tabulate( random.key(0), x=jnp.ones((32, 128, 64)), console_kwargs=CONSOLE_TEST_KWARGS, compute_flops=True, compute_vjp_flops=True, ) lines = module_repr.splitlines() self.assertIn('LSTM', lines[5]) self.assertIn('ScanLSTMCell_0', lines[9]) self.assertIn('LSTMCell', lines[9]) self.assertIn('ScanLSTMCell_0/ii', lines[13]) self.assertIn('Dense', lines[13]) def test_module_reuse(self): class ConvBlock(nn.Module): @nn.compact def __call__(self, x): x = nn.Conv(32, [3, 3])(x) x = nn.BatchNorm(use_running_average=True)(x) x = nn.Dropout(0.5, deterministic=True)(x) x = nn.relu(x) return x class CNN(nn.Module): @nn.compact def __call__(self, x): block = ConvBlock() x = block(x) x = block(x) x = block(x) return x x = jnp.ones((4, 28, 28, 32)) module_repr = CNN().tabulate( jax.random.key(0), x=x, show_repeated=True, console_kwargs=CONSOLE_TEST_KWARGS, compute_flops=True, compute_vjp_flops=True, ) lines = module_repr.splitlines() # first call self.assertIn('ConvBlock_0/Conv_0', lines[9]) self.assertIn('bias', lines[9]) self.assertIn('ConvBlock_0/BatchNorm_0', lines[14]) self.assertIn('mean', lines[14]) self.assertIn('bias', lines[14]) self.assertIn('ConvBlock_0/Dropout_0', lines[19]) # second call self.assertIn('ConvBlock_0/Conv_0', lines[23]) self.assertNotIn('bias', lines[23]) self.assertIn('ConvBlock_0/BatchNorm_0', lines[25]) self.assertNotIn('mean', lines[25]) self.assertNotIn('bias', lines[25]) self.assertIn('ConvBlock_0/Dropout_0', lines[27]) # third call self.assertIn('ConvBlock_0/Conv_0', lines[31]) self.assertNotIn('bias', lines[31]) self.assertIn('ConvBlock_0/BatchNorm_0', lines[33]) self.assertNotIn('mean', lines[33]) self.assertNotIn('bias', lines[33]) self.assertIn('ConvBlock_0/Dropout_0', lines[35]) # Test that CNN FLOPs are 3x ConvBlock FLOPs. args = ({'dropout': random.key(0), 'params': random.key(1)}, x) cnn = summary._get_module_table( CNN(), depth=1, show_repeated=True, compute_flops=True, compute_vjp_flops=True, )(*args, mutable=True) block = summary._get_module_table( ConvBlock(), depth=1, show_repeated=True, compute_flops=True, compute_vjp_flops=True, )(*args, mutable=True) # Total forward/backward FLOPs equal to their sums of sub blocks. self.assertEqual(cnn[0].flops, sum(r.flops for r in cnn[1:])) self.assertEqual(cnn[0].vjp_flops, sum(r.vjp_flops for r in cnn[1:])) # Each sub block has cost equal to ConvBlock instantiated separately. for r in cnn[1:]: self.assertEqual(r.flops, block[0].flops) self.assertEqual(r.vjp_flops, block[0].vjp_flops) def test_empty_input(self): class EmptyInput(nn.Module): @nn.compact def __call__(self): return 1 module = EmptyInput() module_repr = module.tabulate( {}, console_kwargs=CONSOLE_TEST_KWARGS, compute_flops=True, compute_vjp_flops=True, ) lines = module_repr.splitlines() # 1 output and 0 forward / backward FLOPs. self.assertRegex( lines[5], r'│\s*│\s*EmptyInput\s*│\s*│\s*1\s*│\s*0\s*│\s*0\s*│' ) def test_numpy_scalar(self): class Submodule(nn.Module): def __call__(self, x): return x + 1 class EmptyInput(nn.Module): @nn.compact def __call__(self): return Submodule()(x=np.pi) module = EmptyInput() module_repr = module.tabulate( {}, console_kwargs=CONSOLE_TEST_KWARGS, compute_flops=True, compute_vjp_flops=True, ) lines = module_repr.splitlines() self.assertIn('4.141592', lines[5]) self.assertIn('x: 3.141592', lines[7]) self.assertIn('4.141592', lines[7]) # 0 forward / backward FLOPs due to precomputed output values. self.assertIn('│ 0 │ 0', lines[5]) self.assertIn('│ 0 │ 0', lines[7]) def test_partitioned_params(self): class Classifier(nn.Module): @nn.compact def __call__(self, x): hidden = nn.Dense( features=1024, kernel_init=nn.with_partitioning( nn.initializers.lecun_normal(), (None, 'data') ), bias_init=nn.with_partitioning(nn.initializers.zeros, (None,)), name='hidden', ) x = x / 255.0 x = x.reshape((x.shape[0], -1)) # flatten x = nn.relu(hidden(x)) x = nn.Dense(features=10, name='head')(x) return x module = Classifier() lines = module.tabulate( jax.random.key(0), jnp.empty((1, 28, 28, 1)), console_kwargs=CONSOLE_TEST_KWARGS, compute_flops=True, compute_vjp_flops=True, ).splitlines() self.assertIn('P(None,)', lines[7]) self.assertIn('P(None, data)', lines[8]) # Per-layer forward FLOPs: self.assertIn('1606656', lines[7]) # 1 * (28 * 28 * 1) * 1024 * 2 + 1024 self.assertIn('20490', lines[12]) # 1 * (1024 ) * 10 * 2 + 10 # Total forward FLOPs: input division + ReLU + two dense layers above. # (1 * 28 * 28 * 1) + (1 * 1024) + 1606656 + 20490. self.assertIn('1628954', lines[5]) # Per-layer backward FLOPs: # [3x MMs: forward, input cotangent, weight cotangent] 1024 * 784 * 2 * 3 # + [forward bias addition] 1024 # + [`mutable=True`: weight and bias sizes] 1024 * 784 + 1024 self.assertIn('5621760', lines[7]) # [3x matmuls: forward, input cotangent, weight cotangent] 1024 * 10 * 2 * 3 # + [forward bias addition] 10 # + [`mutable=True`: weight and bias sizes] 1024 * 10 + 10 self.assertIn('71700', lines[12]) # Total backward FLOPs: input division + ReLU + two dense layers above. # 2 * (1 * 28 * 28 * 1) + 3 * (1 * 1024) + 5621760 + 71700. self.assertIn('5698100', lines[5]) def test_non_array_variables(self): class Metadata(struct.PyTreeNode): names: tuple = struct.field(pytree_node=False) class Foo(nn.Module): @nn.compact def __call__(self): self.sow('foo', 'bar', Metadata(('baz', 'qux'))) module = Foo() lines = module.tabulate( {}, console_kwargs=CONSOLE_TEST_KWARGS, compute_flops=True, compute_vjp_flops=True, ).splitlines() self.assertIn('names', lines[6]) self.assertIn('baz', lines[7]) self.assertIn('qux', lines[8]) # 0 forward and backward FLOPs. self.assertIn('│ 0 │ 0', lines[5]) def test_tabulate_param_count_and_flops(self): class Foo(nn.Module): @nn.compact def __call__(self, x): h = nn.Dense(4)(x) return nn.Dense(2)(h) module = Foo() rng = jax.random.key(0) x = jnp.ones((16, 9)) rep = module.tabulate( rng, x, console_kwargs=CONSOLE_TEST_KWARGS, compute_flops=True, compute_vjp_flops=True, ) lines = rep.splitlines() self.assertIn('Total Parameters: 50', lines[-2]) def test_tabulate_enum(self): class Net(nn.Module): @nn.compact def __call__(self, inputs): x = inputs['x'] x = nn.Dense(features=2)(x) return jnp.sum(x) class InputEnum(str, enum.Enum): x = 'x' inputs = {InputEnum.x: jnp.ones((1, 1))} # test args lines = Net().tabulate(jax.random.key(0), inputs).split('\n') self.assertIn('x: \x1b[2mfloat32\x1b[0m[1,1]', lines[5]) # test kwargs lines = Net().tabulate(jax.random.key(0), inputs=inputs).split('\n') self.assertIn('inputs:', lines[5]) self.assertIn('x: \x1b[2mfloat32\x1b[0m[1,1]', lines[6]) def test_tabulate_norm_wrapper(self): class SubModel(nn.Module): @nn.compact def __call__(self, x): x = nn.SpectralNorm(nn.Dense(5))(x, update_stats=False) x = nn.Dense(6)(x) x = nn.WeightNorm(nn.Dense(7))(x) return x class Model(nn.Module): @nn.compact def __call__(self, x): x = nn.WeightNorm(nn.Dense(3))(x) x = nn.Dense(4)(x) x = SubModel()(x) x = nn.Dense(8)(x) x = nn.SpectralNorm(nn.Dense(9))(x, update_stats=False) return x x = jnp.ones((1, 2)) key = jax.random.key(0) model = Model() lines = model.tabulate( key, x, console_kwargs=CONSOLE_TEST_KWARGS, compute_flops=True, compute_vjp_flops=True, ).splitlines() self.assertIn('Model', lines[5]) self.assertIn('WeightNorm_0', lines[7]) self.assertIn('Dense_0/kernel/scale', lines[7]) self.assertIn('Dense_0', lines[11]) self.assertIn('Dense_1', lines[16]) self.assertIn('SubModel_0', lines[21]) self.assertIn('SubModel_0/SpectralNorm_0', lines[23]) self.assertIn('Dense_0/kernel/sigma', lines[23]) self.assertIn('Dense_0/kernel/u', lines[24]) self.assertIn('SubModel_0/Dense_0', lines[28]) self.assertIn('SubModel_0/Dense_1', lines[33]) self.assertIn('SubModel_0/WeightNorm_0', lines[38]) self.assertIn('Dense_2/kernel/scale', lines[38]) self.assertIn('SubModel_0/Dense_2', lines[42]) self.assertIn('Dense_2', lines[47]) self.assertIn('SpectralNorm_0', lines[52]) self.assertIn('Dense_3/kernel/sigma', lines[52]) self.assertIn('Dense_3/kernel/u', lines[53]) self.assertIn('Dense_3', lines[57]) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/__init__.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: tests/nnx/bridge/module_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os from typing import Any from flax.linen.dtypes import promote_dtype os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4' import jax import jax.numpy as jnp from absl.testing import absltest import numpy as np from flax import linen as nn from flax import nnx from flax.nnx import bridge from flax.nnx.bridge.module import MODULE_CONTEXT class TestBridgeModule(absltest.TestCase): def test_update(self): class Foo(bridge.Module): a: int foo = Foo(1) state = {'b': {'c': nnx.Param(jnp.array(2))}} nnx.update(foo, state) def test_module_stack(self): """Test that apply set the module stack correctly.""" test = self class Foo(bridge.Module): def setup(self): current_ctx = MODULE_CONTEXT.module_stack[-1] test.assertIs(current_ctx.module, self) test.assertFalse(current_ctx.in_compact) def __call__(self): current_ctx = MODULE_CONTEXT.module_stack[-1] test.assertIs(current_ctx.module, self) test.assertFalse(current_ctx.in_compact) foo = Foo() foo.apply({}) def test_compact_basic(self): test = self class Linear(bridge.Module): dout: int @bridge.compact def __call__(self, x): w = self.param( 'w', nnx.initializers.uniform(), (x.shape[-1], self.dout) ) b = self.param('b', nn.initializers.zeros_init(), (self.dout,)) return x @ w + b[None] class Foo(bridge.Module): dout: int @bridge.compact def __call__(self, x): din = x.shape[-1] self.linear = Linear(self.dout) x = self.linear(x) # NNX graphdef, state = nnx.split(self) test.assertIn('Linear_0', state) test.assertIn('w', state['Linear_0']) test.assertIn('b', state['Linear_0']) return x foo = Foo(5) x = jnp.ones((3, 2)) self.assertIsInstance(foo, nnx.Module) variables = foo.init(0, x) params = variables['params'] self.assertIn('Linear_0', params) self.assertIn('w', params['Linear_0']) self.assertIn('b', params['Linear_0']) self.assertEqual(params['Linear_0']['w'].shape, (2, 5)) self.assertEqual(params['Linear_0']['b'].shape, (5,)) y: jax.Array = foo.apply(variables, x) self.assertEqual(y.shape, (3, 5)) def test_mutable_state(self): class FooLinen(nn.Module): @nn.compact def __call__(self): count = self.variable( 'counts', 'count', lambda: jnp.zeros((), jnp.int32) ) count.value += 1 model_linen = FooLinen() initial_vars_linen = model_linen.init({}) _, vars_linen = model_linen.apply(initial_vars_linen, mutable='counts') class FooNNX(bridge.Module): @bridge.compact def __call__(self): count = self.variable( 'counts', 'count', lambda: jnp.zeros((), jnp.int32) ) count[...] += 1 model_nnx = FooNNX() initial_vars_nnx = model_nnx.init({}) _, vars_nnx = model_nnx.apply(initial_vars_nnx, mutable='counts') self.assertEqual( initial_vars_linen['counts']['count'], initial_vars_nnx['counts']['count'] ) self.assertEqual(vars_linen['counts']['count'], vars_nnx['counts']['count']) def test_compact_parent_none(self): class Foo(bridge.Module): pass class Bar(bridge.Module): @bridge.compact def __call__(self): return Foo().scope bar = Bar() scope = bar.apply({}, rngs=1) self.assertIsNone(bar.scope) self.assertEqual(scope.rngs.default.key[...], jax.random.key(1)) self.assertEqual(scope.rngs.default.count[...], 0) class Baz(bridge.Module): @bridge.compact def __call__(self): return Foo(parent=None).scope baz = Baz() scope = baz.apply({}, rngs=1) self.assertIsNone(scope) def test_dense_port(self): class Dense(bridge.Module): features: int use_bias: bool = True dtype: Any = None param_dtype: Any = jnp.float32 precision: Any = None kernel_init: Any = nnx.initializers.lecun_normal() bias_init: Any = nnx.initializers.zeros_init() # Deprecated. Will be removed. dot_general: Any | None = None dot_general_cls: Any = None @bridge.compact def __call__(self, inputs: jax.Array) -> jax.Array: kernel = self.param( 'kernel', self.kernel_init, (jnp.shape(inputs)[-1], self.features), self.param_dtype, ) if self.use_bias: bias = self.param( 'bias', self.bias_init, (self.features,), self.param_dtype ) else: bias = None inputs, kernel, bias = promote_dtype( inputs, kernel, bias, dtype=self.dtype ) if self.dot_general_cls is not None: dot_general = self.dot_general_cls() elif self.dot_general is not None: dot_general = self.dot_general else: dot_general = jax.lax.dot_general y = dot_general( inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision, ) if bias is not None: y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) return y m = Dense(3) x = jnp.ones((1, 10, 2)) y, variables = m.init_with_output(0, x) self.assertEqual(y.shape, (1, 10, 3)) self.assertEqual(variables['params']['kernel'].shape, (2, 3)) self.assertEqual(variables['params']['bias'].shape, (3,)) y = m.apply(variables, x) self.assertEqual(y.shape, (1, 10, 3)) self.assertEqual(variables['params']['kernel'].shape, (2, 3)) self.assertEqual(variables['params']['bias'].shape, (3,)) @jax.jit def train_step(params, x, y): def loss_fn(params): y_pred = m.apply({'params': params}, x) return jnp.mean((y - y_pred) ** 2) grads = jax.grad(loss_fn)(params) params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads) return params params = variables['params'] x = jnp.ones((1, 10, 2)) y = jnp.ones((1, 10, 3)) params = train_step(params, x, y) def test_metadata(self): class Linear(bridge.Module): dout: int @bridge.compact def __call__(self, x): w = self.param( 'w', bridge.with_partitioning(nnx.initializers.uniform(), ('in', 'out')), (x.shape[-1], self.dout) ) b = self.param('b', nnx.initializers.zeros_init(), (self.dout,)) return x @ w + b[None] foo = Linear(6) x = jnp.ones((4, 2)) mesh = jax.make_mesh( (2, 2), ('in', 'out'), axis_types=(jax.sharding.AxisType.Auto,) * len(('in', 'out')), ) with jax.set_mesh(mesh): variables = foo.init(0, x) y: jax.Array = foo.apply(variables, x) params = variables['params'] self.assertIsInstance(params['w'], nn.Partitioned) self.assertEqual(params['w'].value.shape, (2, 6)) self.assertEqual(params['w'].names, ('in', 'out')) self.assertEqual(nn.get_partition_spec(variables)['params']['w'], jax.sharding.PartitionSpec('in', 'out')) self.assertIsInstance(params['b'], jax.Array) self.assertEqual(params['b'].shape, (6,)) self.assertEqual(y.shape, (4, 6)) def test_pure_nnx_submodule(self): class NNXLayer(nnx.Module): def __init__(self, dim, dropout, rngs): self.linear = nnx.Linear(dim, dim, use_bias=False, rngs=rngs) self.dropout = nnx.Dropout(dropout, rngs=rngs) self.count = nnx.Intermediate(jnp.array([0.])) def __call__(self, x): # Required check to avoid state update in `init()`. Can this be avoided? if not bridge.current_module().is_initializing(): self.count[...] += 1 x = self.linear(x) x = self.dropout(x) return x class BridgeMLP(bridge.Module): @bridge.compact def __call__(self, x): x = bridge.nnx_in_bridge_mdl(lambda r: NNXLayer(8, 0.3, rngs=r))(x) x = bridge.nnx_in_bridge_mdl( lambda r: NNXLayer(8, 0.3, rngs=r), name='another')(x) return x model = BridgeMLP() x = jax.random.normal(jax.random.key(0), (4, 8)) variables = model.init(jax.random.key(1), x) self.assertSameElements(variables['params'].keys(), ['NNXLayer_0', 'another']) self.assertFalse(jnp.array_equal( variables['params']['NNXLayer_0']['linear']['kernel'], variables['params']['another']['linear']['kernel'], )) self.assertEqual(variables['intermediates']['NNXLayer_0']['count'], 0) k1, k2, k3 = jax.random.split(jax.random.key(0), 3) y1 = model.apply(variables, x, rngs={'params': k1, 'dropout': k2}) y2 = model.apply(variables, x, rngs={'params': k1, 'dropout': k3}) assert not jnp.array_equal(y1, y2) _, updates = model.apply(variables, x, rngs={'params': k1, 'dropout': k3}, mutable=True) self.assertEqual(updates['intermediates']['NNXLayer_0']['count'], 1) class BridgeMLPSetup(bridge.Module): def setup(self): self.layer = bridge.nnx_in_bridge_mdl( lambda r: NNXLayer(8, 0.3, rngs=r)) def __call__(self, x): return self.layer(x) model = BridgeMLPSetup() variables = model.init(jax.random.key(1), x) self.assertSameElements(variables['params'].keys(), ['layer']) y1 = model.apply(variables, x, rngs={'params': k1, 'dropout': k2}) y2 = model.apply(variables, x, rngs={'params': k1, 'dropout': k3}) assert not jnp.array_equal(y1, y2) def test_pure_nnx_submodule_modified_rng(self): class FooStack(nnx.Module): def __init__(self, in_dim, key): keys = jax.random.split(key, in_dim) self.rngs = nnx.Rngs(keys) def __call__(self, x): @nnx.vmap def generate_weights(r): return jax.random.normal(r.default(), (2,)) w = generate_weights(self.rngs) return x @ w class BridgeFoo(bridge.Module): @bridge.compact def __call__(self, x): x = bridge.nnx_in_bridge_mdl(lambda r: FooStack(4, r.default()))(x) return x model = BridgeFoo() v = model.init(jax.random.key(1), jnp.ones((1, 4))) y = model.apply(v, jnp.ones((1, 4)), rngs=jax.random.key(1)) def test_linen_submodule(self): class LinenLayer(nn.Module): dim: int dropout_rate: float def setup(self): self.linear = nn.Dense(self.dim, use_bias=False) self.dropout = nn.Dropout(self.dropout_rate, deterministic=False) def __call__(self, x): if not self.is_initializing(): self.sow('intermediates', 'count', 1, init_fn=lambda: 0, reduce_fn=lambda a, b: a + b) x = self.linear(x) x = self.dropout(x) return x class BridgeMLP(bridge.Module): @bridge.compact def __call__(self, x): x = bridge.linen_in_bridge_mdl(LinenLayer(8, 0.3))(x) x = bridge.linen_in_bridge_mdl(LinenLayer(8, 0.3), name='another')(x) return x model = BridgeMLP() x = jax.random.normal(jax.random.key(0), (4, 8)) variables = model.init(jax.random.key(1), x) self.assertFalse(jnp.array_equal( variables['params']['LinenLayer_0']['linear']['kernel'], variables['params']['another']['linear']['kernel'], )) k1, k2, k3 = jax.random.split(jax.random.key(0), 3) y1 = model.apply(variables, x, rngs={'params': k1, 'dropout': k2}) y2 = model.apply(variables, x, rngs={'params': k1, 'dropout': k3}) assert not jnp.array_equal(y1, y2) _, updates = model.apply(variables, x, rngs={'params': k1, 'dropout': k3}, mutable=True) self.assertEqual(updates['intermediates']['LinenLayer_0']['count'], 1) class BridgeMLPSetup(bridge.Module): def setup(self): self.layer = bridge.linen_in_bridge_mdl(LinenLayer(8, 0.3)) def __call__(self, x): return self.layer(x) model = BridgeMLPSetup() variables = model.init(jax.random.key(1), x) self.assertSameElements(variables['params'].keys(), ['layer']) y1 = model.apply(variables, x, rngs={'params': k1, 'dropout': k2}) y2 = model.apply(variables, x, rngs={'params': k1, 'dropout': k3}) assert not jnp.array_equal(y1, y2) def test_name(self): class Foo(bridge.Module): name: str class Bar(bridge.Module): @bridge.compact def __call__(self): f = Foo(name='f') assert f.name == 'f' assert self.f == f Bar().init() def test_transforms(self): class Dense(bridge.Module): dout: int @bridge.compact def __call__(self, x: jax.Array) -> jax.Array: return x @ self.param('w', nn.initializers.normal(), (x.shape[-1], self.dout)) class MLP(bridge.Module): dim: int num_layers: int def setup(self): @nnx.split_rngs(splits=self.num_layers) @nnx.vmap( in_axes=(nnx.StateAxes({nnx.RngState: 0, ...: None}),), axis_size=self.num_layers, transform_metadata={nnx.PARTITION_NAME: None}, ) def create_block(parent): block = Dense(self.dim) parent.block = block create_block(self) def __call__(self, x): @nnx.split_rngs(splits=self.num_layers) @nnx.scan( in_axes=(0, nnx.Carry), out_axes=nnx.Carry, transform_metadata={nnx.PARTITION_NAME: None}, ) def forward_block(model, x): return model(x) x = forward_block(self.block, x) return x model = MLP(dim=32, num_layers=2) x = jnp.ones((4, 32)) variables = model.init(jax.random.key(0), x) y = model.apply(variables, x) w = variables['params']['block']['w'] np.testing.assert_array_equal(y, x @ w[0] @ w[1]) def test_shared_modules(self): class Dense(bridge.Module): dout: int @bridge.compact def __call__(self, x): return x @ self.param('w', nn.initializers.normal(), (x.shape[-1], self.dout)) class Bottom(bridge.Module): def setup(self): self.layer = Dense(4) def __call__(self, x): return self.layer(x) class Top(bridge.Module): def setup(self): self.zzz = Bottom() self.set_attr_priority('zzz', bridge.AttrPriority.HIGH) self.aaa = self.zzz # another reference self.dense = self.aaa.layer # and another reference def forward(self, x): return self.aaa(x) def __call__(self, x): forward = nnx.remat(self.__class__.forward) return forward(self, x) model = Top() x = jnp.ones((4, 32)) params = model.init(jax.random.key(0), x)['params'] self.assertSameElements(['zzz'], params.keys()) def test_linen_layer_naming(self): class Dense(bridge.Module): dout: int @bridge.compact def __call__(self, x): return x @ self.param('w', lambda _: jnp.ones((x.shape[-1], self.dout))) class MLP(bridge.Module): nlayers: int def setup(self): self.layers = [Dense(4, name=f'layer_{i}') for i in range(self.nlayers)] def __call__(self, x): for layer in self.layers: x = layer(x) return x model = MLP(nlayers=3) x = jnp.ones((2, 4)) params = model.init(jax.random.key(0), x)['params'] self.assertSameElements([f'layer_{i}' for i in range(3)], params.keys()) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/bridge/wrappers_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4' import jax import jax.numpy as jnp import numpy as np from absl.testing import absltest import flax from flax import linen as nn from flax import nnx from flax.nnx import bridge class TestCompatibility(absltest.TestCase): def setUp(self): super().setUp() dim1 = max(jax.device_count() // 2, 1) device_mesh = np.array(jax.devices()).reshape(dim1, jax.device_count() // dim1) self.mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=('in', 'out')) def test_functional(self): # Functional API for NNX Modules functional = bridge.functional(nnx.Linear)(32, 64) state = functional.init(rngs=nnx.Rngs(0)) x = jax.numpy.ones((1, 32)) y, updates = functional.apply(state)(x) ################## ### LinenToNNX ### ################## def test_linen_to_nnx(self): ## Wrapper API for Linen Modules linen_module = nn.Dense(features=64) x = jax.numpy.ones((1, 32)) model = bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x) # like linen init y = model(x) # like linen apply assert y.shape == (1, 64) self.assertIsInstance(model.kernel, nnx.Variable) # NNX automatically adds metadata box regardless of original Linen module. linen_vars = { 'params': {'kernel': model.kernel[...], 'bias': model.bias[...]} } linen_y = linen_module.apply(linen_vars, x) np.testing.assert_array_equal(y, linen_y) def test_linen_to_nnx_submodule(self): class NNXOuter(nnx.Module): def __init__(self, dout: int, *, rngs: nnx.Rngs): self.nn_dense1 = bridge.ToNNX(nn.Dense(dout, use_bias=False), rngs=rngs) self.b = nnx.Param(jax.random.uniform(rngs.params(), (1, dout,))) self.batchnorm = bridge.ToNNX(nn.BatchNorm(use_running_average=True), rngs=rngs) self.rngs = rngs def __call__(self, x): x = self.nn_dense1(x) + self.b return self.batchnorm(x) x = jax.random.normal(jax.random.key(0), (2, 4)) model = NNXOuter(3, rngs=nnx.Rngs(0)) gdef_before_lazy_init, _ = nnx.split(model) bridge.lazy_init(model, x) gdef_full, state = nnx.split(model) assert gdef_before_lazy_init != gdef_full assert 'nn_dense1' in state assert 'batchnorm' in state assert 'kernel' in state['nn_dense1'] y = model(x) k, b = state['nn_dense1']['kernel'][...], state['b'][...] np.testing.assert_allclose(y, x @ k + b, rtol=1e-5) assert gdef_full == nnx.graphdef(model) # static data is stable now def test_linen_to_nnx_noncall_method(self): class Foo(nn.Module): @nn.compact def __call__(self, x): b = self.param('b', nn.zeros_init(), (1, 3,)) return self.dot(x) + b @nn.compact def dot(self, x): w = self.param('w', nn.initializers.lecun_normal(), (4, 3)) return x @ w def rngs(self): raise ValueError('This should not be called because ToNNX has .rngs') x = jax.random.normal(jax.random.key(0), (2, 4)) model = bridge.ToNNX(Foo(), rngs=nnx.Rngs(0)) bridge.lazy_init(model.dot, x) y = model.dot(x) np.testing.assert_allclose(y, x @ nnx.state(model)['w'][...]) # lazy_init only initialized param w inside dot(), so calling __call__ should fail with self.assertRaises(flax.errors.ScopeParamNotFoundError): y = model(x) assert isinstance(model.to_nnx__rngs, nnx.Rngs) def test_linen_to_nnx_mutable(self): class Foo(nn.Module): def setup(self): self.count = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32)) def __call__(self, x): if not self.is_initializing(): self.count.value += 1 return x x = lambda: jnp.zeros((), jnp.int32) model = bridge.ToNNX(Foo(), rngs=nnx.Rngs(0)).lazy_init(x) self.assertEqual(nnx.state(model)['count'][...], 0) y = model(x, mutable=True) self.assertEqual(nnx.state(model)['count'][...], 1) def test_linen_to_nnx_transform(self): class NNXOuter(nnx.Module): def __init__(self, dout: int, rngs: nnx.Rngs): self.inner = nnx.bridge.ToNNX(nn.Dense(dout), rngs=rngs) self.rngs = rngs def __call__(self, x): @nnx.split_rngs(splits=5) @nnx.vmap(in_axes=(0, None), axis_size=5) def vmap_fn(inner, x): return inner(x) return vmap_fn(self.inner, x) x = jax.random.normal(jax.random.key(0), (2, 4)) model = NNXOuter(3, rngs=nnx.Rngs(0)) nnx.bridge.lazy_init(model, x) self.assertEqual(model.inner.kernel.shape, (5, 4, 3)) self.assertEqual(model.inner.bias.shape, (5, 3)) def test_linen_to_nnx_metadata(self): linen_module = nn.Dense( features=64, kernel_init=nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out')), bias_init=nn.with_logical_partitioning(nn.initializers.zeros_init(), ('out-alias',), rules=(('out-alias', 'out'),)), ) x = jax.numpy.ones((1, 32)) linen_vars = linen_module.init(jax.random.key(0), x) @nnx.jit def create_sharded_nnx_module(x): model = bridge.lazy_init(bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)), x) state = nnx.state(model) sharded_state = jax.lax.with_sharding_constraint(state, nnx.get_partition_spec(state)) nnx.update(model, sharded_state) return model with jax.set_mesh(self.mesh): nnx_model = create_sharded_nnx_module(x) # nn.Partitioned metadata boxes translated into valid nnx.Variable boxes. self.assertIsInstance(linen_vars['params']['kernel'], nn.Partitioned) self.assertIsInstance(linen_vars['params']['bias'], nn.LogicallyPartitioned) self.assertIsInstance(nnx_model.kernel, nnx.Variable) assert nnx_model.kernel.out_sharding == ('in', 'out') assert nnx_model.kernel[...].sharding.is_equivalent_to( jax.sharding.NamedSharding( self.mesh, jax.sharding.PartitionSpec('in', 'out') ), ndim=2, ), f'{nnx_model.kernel[...].sharding = }' assert nnx_model.bias.out_sharding == ('out-alias',) assert nnx_model.bias.sharding_rules == (('out-alias', 'out'),) assert nnx_model.bias[...].sharding.is_equivalent_to( jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('out')), ndim=1, ) def test_linen_to_nnx_state_structure_consistency(self): class LinenInner(nn.Module): dout: int @nn.compact def __call__(self, x): w = self.param('w', nn.initializers.lecun_normal(), (x.shape[-1], self.dout)) return nn.Dropout(rate=0.5, deterministic=False)(x @ w) class LinenMiddle(nn.Module): dout: int @nn.compact def __call__(self, x): dot = LinenInner(self.dout, name='dot') b = self.variable('bias', 'b', nn.initializers.zeros_init(), None, (1, self.dout)) return dot(x) + b.value @nnx.register_variable_name('bias') class Bias(nnx.Variable): pass class NNXMiddle(nnx.Module): def __init__(self, dout: int, *, rngs: nnx.Rngs): self.dot = bridge.ToNNX(LinenInner(dout), rngs=rngs) self.b = Bias(nnx.initializers.zeros_init()(rngs.params(), (1, dout))) def __call__(self, x): return self.dot(x) + self.b x = jax.random.normal(jax.random.key(42), (2, 4)) from_top = bridge.lazy_init( bridge.ToNNX(LinenMiddle(dout=3), rngs=nnx.Rngs(0, dropout=1)), x) from_middle = bridge.lazy_init( NNXMiddle(dout=3, rngs=nnx.Rngs(0, dropout=1)), x) # Remove the NNX-module-local RNG states, which will be different # because the NNX modules are on different level def get_weights(model): return nnx.split(model, nnx.RngCount, nnx.RngKey, ...)[3] from_top_weights = get_weights(from_top) from_middle_weights = get_weights(from_middle) # Confirm the rest of the state has the same structure. self.assertEqual(jax.tree.structure(from_top_weights), jax.tree.structure(from_middle_weights)) def test_adding_new_attributes(self): class LinenModule(nn.Module): @nn.compact def __call__(self): if self.is_initializing() and self.is_mutable_collection('cache'): self.put_variable('cache', 'x', 0) res = self.get_variable('cache', 'x') return res class NNXModule(nnx.Module): def __init__(self): self.module = nnx.bridge.ToNNX(LinenModule()).lazy_init() def __call__(self): result1 = self.module(mutable=['cache']) assert result1 == 0 result2 = self.module() assert result2 == 0, result2 # fails: result2 is None module = NNXModule() module() ################## ### NNXToLinen ### ################## def test_nnx_to_linen(self): model = bridge.to_linen(nnx.Linear, 32, out_features=64) x = jax.numpy.ones((1, 32)) y, variables = model.init_with_output(jax.random.key(0), x) assert y.shape == (1, 64) np.testing.assert_allclose(y, x @ variables['params']['kernel']) def test_nnx_to_linen_multiple_rngs(self): class NNXInner(nnx.Module): def __init__(self, din, dout, rngs): self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout))) self.dropout = nnx.Dropout(rate=0.5, rngs=rngs) def __call__(self, x): return self.dropout(x @ self.w[...]) class LinenOuter(nn.Module): @nn.compact def __call__(self, x): inner = bridge.to_linen(NNXInner, 4, 3) return inner(x) xkey, pkey, dkey1, dkey2 = jax.random.split(jax.random.key(0), 4) x = jax.random.normal(xkey, (2, 4)) model = LinenOuter() y1, var = model.init_with_output({'params': pkey, 'dropout': dkey1}, x) y2 = model.apply(var, x, rngs={'dropout': dkey2}) assert not jnp.allclose(y1, y2) # dropout keys are different def test_nnx_to_linen_multiple_collections(self): class NNXInner(nnx.Module): def __init__(self, din, dout, rngs): self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout))) self.bn = nnx.BatchNorm(dout, use_running_average=False, rngs=rngs) self.lora = nnx.LoRA(din, 3, dout, rngs=rngs) def __call__(self, x): return self.bn(x @ self.w[...]) + self.lora(x) xkey, pkey, dkey = jax.random.split(jax.random.key(0), 3) x = jax.random.normal(xkey, (2, 4)) model = bridge.to_linen(NNXInner, 4, 3) var = model.init({'params': pkey, 'dropout': dkey}, x) self.assertSameElements(var.keys(), ['LoRAParam', 'params', 'batch_stats']) y = model.apply(var, x) assert y.shape == (2, 3) def test_nnx_to_linen_mutable(self): @nnx.register_variable_name('Count', overwrite=True) class Count(nnx.Variable): pass class Counter(nnx.Module): def __init__(self): self.count = Count(jnp.array(0)) def __call__(self): self.count[...] += 1 model = bridge.ToLinen(Counter, skip_rng=True) variables = model.init(jax.random.key(0)) assert variables['Count']['count'] == 0 _, updates = model.apply(variables, mutable='Count') assert updates['Count']['count'] == 1 _ = model.apply(variables | updates) def test_to_linen_method_call(self): class Foo(nn.Module): def setup(self): self.embedding = nnx.bridge.to_linen(nnx.Embed, 2, 3) def __call__(self, x): return self.embedding(x) def attend(self, x): return self.embedding.attend(x) module = Foo() x = jnp.ones((1,), dtype=jnp.int32) z = jnp.ones((1, 3)) y , params = module.init_with_output(jax.random.key(0), x) assert y.shape == (1, 3) assert params['params']['embedding']['embedding'].shape == (2, 3) x_out = module.apply(params, z, method='attend') assert x_out.shape == (1, 2) def test_to_linen_nnx_method_arg(self): module = nnx.bridge.to_linen(nnx.Embed, 2, 3) x = jnp.ones((1,), dtype=jnp.int32) z = jnp.ones((1, 3)) y , params = module.init_with_output(jax.random.key(0), x) assert y.shape == (1, 3) assert params['params']['embedding'].shape == (2, 3) x_out = module.apply(params, z, nnx_method='attend') assert x_out.shape == (1, 2) def test_nnx_to_linen_mutated_static_data(self): @nnx.register_variable_name('Count', overwrite=True) class Count(nnx.Variable): pass class Counter(nnx.Module): def __init__(self): self.count = Count(jnp.array(0)) def __call__(self): self.count[...] += 1 self.count_nonzero = nnx.Intermediate(jnp.array(1)) model = bridge.ToLinen(Counter, skip_rng=True) variables = model.init(jax.random.key(0)) assert variables['Count']['count'] == 0 _, updates = model.apply(variables, mutable=['Count', 'intermediates']) assert updates['Count']['count'] == 1 assert updates['intermediates']['count_nonzero'] == 1 del updates['intermediates'] _ = model.apply(variables | updates) def test_nnx_to_linen_transforms(self): class LinenOuter(nn.Module): dout: int @nn.compact def __call__(self, x): inner = nn.vmap( bridge.ToLinen, variable_axes={'params': 0}, split_rngs={'params': True}, )(nnx.Linear, args=(x.shape[-1], self.dout)) return inner(x) xkey, pkey, _ = jax.random.split(jax.random.key(0), 3) x = jax.random.normal(xkey, (2, 4)) model = LinenOuter(dout=3) y, var = model.init_with_output(pkey, x) k = var['params']['VmapToLinen_0']['kernel'] assert k.shape == (2, 4, 3) np.testing.assert_allclose(y, jnp.einsum('ab,abc->ac', x, k)) def test_nnx_to_linen_metadata(self): model = bridge.to_linen( nnx.Linear, 32, 64, kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ('in', 'out'))) x = jax.numpy.ones((1, 32)) with jax.set_mesh(self.mesh): y, variables = model.init_with_output(jax.random.key(0), x) pspec_tree = nn.get_partition_spec(variables) assert y.shape == (1, 64) self.assertIsInstance(variables['params']['kernel'], nnx.bridge.NNXMeta) assert variables['params']['kernel'].metadata['out_sharding'] == ('in', 'out') self.assertEqual(pspec_tree['params']['kernel'], jax.sharding.PartitionSpec('in', 'out')) np.testing.assert_allclose(y, x @ variables['params']['kernel'].value) def test_nnx_to_linen_metadata_transform(self): # TODO: add support and testing after axis add/remove in transform is fixed. pass def test_nnx_to_linen_pytree_structure_consistency(self): class NNXInner(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout))) self.dropout = nnx.Dropout(rate=0.5, rngs=rngs) def __call__(self, x): return self.dropout(x @ self.w) @nnx.register_variable_name('bias', overwrite=True) class Bias(nnx.Variable): pass class NNXMiddle(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.dot = NNXInner(din, dout, rngs=rngs) self.b = Bias(nnx.initializers.zeros_init()(rngs.params(), (1, dout))) def __call__(self, x): return self.dot(x) + self.b class LinenMiddle(nn.Module): dout: int @nn.compact def __call__(self, x): dot = bridge.to_linen(NNXInner, x.shape[-1], self.dout, name='dot') b = self.variable('bias', 'b', nn.initializers.zeros_init(), None, (1, self.dout)) return dot(x) + b.value x = jax.random.normal(jax.random.key(42), (2, 4)) keys = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} from_top = bridge.to_linen(NNXMiddle, din=4, dout=3).init(keys, x) from_middle = LinenMiddle(dout=3).init(keys, x) # Remove the NNX-module-local RNG states, which will be different # because the NNX modules are on different level def get_weights(variables): non_rngs = {} for kp, v in flax.traverse_util.flatten_dict(variables).items(): if 'rngs' not in kp: non_rngs[kp] = v return flax.traverse_util.unflatten_dict(non_rngs) from_top_weights = get_weights(from_top) from_middle_weights = get_weights(from_middle) # Confirm the rest of the state has the same structure. self.assertEqual(jax.tree.structure(from_top_weights), jax.tree.structure(from_middle_weights)) ############################ ### Hybrid mix-and-match ### ############################ def test_nnx_linen_nnx(self): class NNXInner(nnx.Module): def __init__(self, din, dout, dropout_rate, rngs): self.w = nnx.Param( nnx.with_partitioning(nnx.initializers.lecun_normal(), sharding=('in', 'out') )(rngs.params(), (din, dout))) self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs) def __call__(self, x): return self.dropout(x @ self.w) class LinenMiddle(nn.Module): dout: int dropout_rate: float @nn.compact def __call__(self, x): dot = bridge.to_linen(NNXInner, x.shape[-1], self.dout, self.dropout_rate, name='dot') logical_init = nn.with_logical_partitioning( nn.initializers.lecun_normal(), ('out-alias',), rules=(('out-alias', 'out'),)) b = self.param('b', logical_init, (2, self.dout)) return dot(x) + b class NNXOuter(nnx.Module): def __init__(self, dout: int, dropout_rate: float, *, rngs: nnx.Rngs): self.inner = bridge.ToNNX(LinenMiddle(dout, dropout_rate), rngs=rngs) self.rngs = rngs def __call__(self, x): return self.inner(x) x = jax.random.normal(jax.random.key(0), (2, 4)) # Test the RNG with jax.set_mesh(self.mesh): model = bridge.lazy_init(NNXOuter(dout=6, dropout_rate=0.5, rngs=nnx.Rngs(default=1, dropout=2)), x) nnx.reseed(model, dropout=2) y1, y2 = model(x), model(x) # The dropout key of lowest NNX level still changes over stateful calls assert not jnp.allclose(y1, y2) # Another reseed resets the RNG key back nnx.reseed(model, dropout=2) np.testing.assert_array_equal(y1, model(x)) # Test the param value with disabled dropout with jax.set_mesh(self.mesh): model = bridge.lazy_init(NNXOuter(dout=6, dropout_rate=0., rngs=nnx.Rngs(default=1, dropout=2)), x) w, b = model.inner.dot['w'], model.inner.b np.testing.assert_allclose(model(x), x @ w + b) self.assertIsInstance(w, nnx.Param) assert hasattr(w, 'out_sharding') and w.out_sharding == ('in', 'out') assert hasattr(b, 'out_sharding') and b.out_sharding == ('out-alias', ) def test_linen_nnx_linen(self): # TODO: add when we can safely `lazy_init` the NNX module inside `ToLinen` without # messing up the stateful part of the NNX module. pass if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/containers_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from flax import nnx from absl.testing import absltest import jax.numpy as jnp class TestContainers(absltest.TestCase): def test_unbox(self): x = nnx.Param( jnp.array(1), on_get_value=lambda c, x: x + 3, # type: ignore ) assert x[...] == 4 def test_on_set_value(self): x = nnx.Param( jnp.array(1), # type: ignore on_set_value=lambda c, x: x + 7, # type: ignore ) x[...] = 5 assert x.get_raw_value() == 12 def test_module_unbox(self): class Foo(nnx.Module): def __init__(self) -> None: self.x = nnx.Param(1, on_get_value=lambda c, x: x + 3) module = Foo() assert module.x.get_value() == 4 assert vars(module)['x'].get_raw_value() == 1 def test_module_box(self): class Foo(nnx.Module): def __init__(self) -> None: self.x = nnx.Param( jnp.array(1), on_set_value=lambda c, x: x + 7, # type: ignore ) module = Foo() module.x[...] = 5 assert module.x[...] == 12 assert vars(module)['x'][...] == 12 if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/filters_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from absl.testing import absltest from flax import nnx class TestFilters(absltest.TestCase): def test_path_contains(self): class Model(nnx.Module): def __init__(self, rngs): self.backbone1 = nnx.Linear(2, 3, rngs=rngs) self.backbone2 = nnx.Linear(3, 3, rngs=rngs) self.head = nnx.Linear(3, 10, rngs=rngs) model = Model(nnx.Rngs(0)) head_state = nnx.state(model, nnx.PathContains('head')) backbones_state = nnx.state(model, nnx.PathContains('backbone', exact=False)) self.assertIn('head', head_state) self.assertNotIn('backbone', head_state) self.assertIn('backbone1', backbones_state) self.assertIn('backbone2', backbones_state) self.assertNotIn('head', backbones_state) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/graph_utils_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import dataclasses from functools import partial from threading import Thread from typing import Any from absl.testing import absltest, parameterized import numpy as np from flax import linen, nnx, struct import jax import jax.numpy as jnp class StatefulLinear(nnx.Module): def __init__(self, din, dout, rngs): self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32)) def increment(self): self.count[...] += 1 def __call__(self, x): self.count[...] += 1 return x @ self.w + self.b[None] class TestGraphUtils(parameterized.TestCase): def test_flatten(self): a = {'a': 1, 'b': nnx.Param(2)} g = [a, 3, a, nnx.Param(4)] refmap = nnx.graphlib.RefMap() graphdef, flat_state = nnx.graphlib.flatten(g, ref_index=refmap, graph=True) assert flat_state[0][1].get_value() == 2 assert flat_state[1][1].get_value() == 4 assert len(refmap) == 2 # 2 Variables assert a['b'] in refmap assert g[3] in refmap def test_flatten_no_paths(self): a = {'a': 1, 'b': nnx.Param(jnp.array(2))} g = [a, 3, a, nnx.Param(jnp.array(4))] refmap = nnx.graphlib.RefMap() graphdef, flat_state = nnx.graphlib.flatten( g, ref_index=refmap, with_paths=False, graph=True ) assert flat_state[0][...] == 2 assert flat_state[1][...] == 4 assert len(refmap) == 2 # 2 Variables assert a['b'] in refmap assert g[3] in refmap def test_unflatten(self): a = nnx.Dict(a=1, b=nnx.Param(2)) g = nnx.List([a, 3, a, nnx.Param(4)]) graphdef, state = nnx.split(g) g = nnx.merge(graphdef, state) assert g[0] is g[2] @parameterized.parameters(True, False) def test_flatten_unflatten_unkown_leaves(self, graph): x = jnp.array(1.0) graphdef, flat_state = nnx.graphlib.flatten(x, graph=graph) self.assertIs(flat_state[0][1], x) x1 = nnx.merge(graphdef, flat_state) self.assertIs(x1, x) @parameterized.parameters(True, False) def test_split_merge_unkown_leaves(self, graph): x = jnp.array(1.0) graphdef, state = nnx.graphlib.split(x, graph=graph) self.assertIs(state, x) x1 = nnx.merge(graphdef, state) self.assertIs(x1, x) @parameterized.parameters(True, False) def test_split_merge_unkown_leaves_with_filters(self, graph): x = jnp.array(1.0) graphdef, state, rest = nnx.graphlib.split(x, jax.Array, ..., graph=graph) self.assertIs(state, x) x1 = nnx.merge(graphdef, state, rest) self.assertIs(x1, x) def test_unflatten_pure_dict(self): a = nnx.Dict(a=1, b=nnx.Param(2)) g = nnx.List([a, 3, a, nnx.Param(4)]) graphdef, state = nnx.split(g) pure_state = nnx.to_pure_dict(state) g = nnx.merge(graphdef, pure_state) assert g[0] is g[2] def test_unflatten_pytree(self): a = {'a': 1, 'b': nnx.Param(2)} g = [a, 3, a, nnx.Param(4)] graphdef, state = nnx.split(g) g = nnx.merge(graphdef, state) assert g[0] is not g[2] def test_unflatten_empty(self): a = nnx.Dict({'a': 1, 'b': nnx.Param(2)}) g = nnx.List([a, 3, a, nnx.Param(4)]) graphdef, state = nnx.split(g) with self.assertRaisesRegex(ValueError, 'Incorrect number of leaves'): nnx.graphlib.unflatten(graphdef, nnx.State({})) def test_unflatten_return_variables(self): a = nnx.Dict({'a': 1, 'b': nnx.Param(2)}) g = nnx.List([a, 3, a, nnx.Param(4)]) graphdef, state = nnx.graphlib.flatten( g, with_paths=True, graph=True ) self.assertLen(state, 2) self.assertIsInstance(state, nnx.graphlib.FlatState) self.assertIsInstance(state[0][1], nnx.Param) self.assertIsInstance(state[1][1], nnx.Param) def test_update_dynamic(self): a = {'a': 1, 'b': nnx.Param(jnp.array(2))} g = [a, 3, a, nnx.Param(jnp.array(4))] graphdef, state = nnx.split(g) state[0]['b'][...] = 3 nnx.update(g, state) assert g[0]['b'][...] == 3 assert g[2]['b'][...] == 3 def test_update_from_pure_dict(self): a = {'a': 1, 'b': nnx.Param(jnp.array(2))} g = [a, 3, a, nnx.Param(jnp.array(4))] graphdef, state = nnx.split(g) pure_state = nnx.to_pure_dict(state) pure_state[0]['b'] = jnp.array(3) nnx.update(g, pure_state) assert g[0]['b'][...] == 3 assert g[2]['b'][...] == 3 @parameterized.parameters(True, False) def test_module_list(self, graph): rngs = nnx.Rngs(0) ls = [ nnx.Linear(2, 2, rngs=rngs), nnx.BatchNorm(2, rngs=rngs), ] graphdef, state = nnx.split(ls, graph=graph) assert state[0]['kernel'].shape == (2, 2) assert state[0]['bias'].shape == (2,) assert state[1]['scale'].shape == (2,) assert state[1]['bias'].shape == (2,) assert state[1]['mean'].shape == (2,) assert state[1]['var'].shape == (2,) def test_shared_variables(self): v = nnx.Param(1) g = [v, v] graphdef, state = nnx.split(g) assert len(nnx.to_flat_state(state)) == 1 g2 = nnx.merge(graphdef, state) assert g2[0] is g2[1] def test_tied_weights(self): class Foo(nnx.Module): def __init__(self, *, rngs: nnx.Rngs) -> None: self.bar = nnx.Linear(2, 2, rngs=rngs) self.baz = nnx.Linear(2, 2, rngs=rngs) # tie the weights self.baz.kernel = self.bar.kernel node = Foo(rngs=nnx.Rngs(0)) graphdef, state = nnx.split(node) assert len(nnx.to_flat_state(state)) == 3 # 2 bias + 1 kernel node2 = nnx.merge(graphdef, state) assert node2.bar.kernel is node2.baz.kernel def test_tied_weights_example(self): class LinearTranspose(nnx.Module): def __init__(self, dout: int, din: int, *, rngs: nnx.Rngs) -> None: self.kernel = nnx.Param( nnx.initializers.lecun_normal()(rngs(), (dout, din)) ) def __call__(self, x): return x @ self.kernel.T class Encoder(nnx.Module): def __init__(self, *, rngs: nnx.Rngs) -> None: self.embed = nnx.Embed(10, 2, rngs=rngs) ... self.linear_out = LinearTranspose(10, 2, rngs=rngs) # tie the weights self.linear_out.kernel = self.embed.embedding def __call__(self, x): x = self.embed(x) ... return self.linear_out(x) model = Encoder(rngs=nnx.Rngs(0)) graphdef, state = nnx.split(model) assert len(nnx.to_flat_state(state)) == 1 x = jax.random.randint(jax.random.key(0), (2,), 0, 10) y = model(x) assert y.shape == (2, 10) def test_state_variables_shared_with_graph(self): class Foo(nnx.Module): def __init__(self): self.a = nnx.Param(jnp.array(1)) m = Foo() graphdef, state = nnx.split(m) assert isinstance(m.a, nnx.Param) assert isinstance(state['a'], nnx.Param) assert m.a is state['a'] assert m.a[...] == state['a'][...] m2 = nnx.merge(graphdef, state) assert isinstance(m2.a, nnx.Param) assert isinstance(state['a'], nnx.Param) assert m2.a is state['a'] assert m2.a[...] == state['a'][...] def test_shared_state_variables_shared_with_graph(self): class Foo(nnx.Module): def __init__(self): p = nnx.Param(jnp.array(1)) self.a = p self.b = p m = Foo() graphdef, state = nnx.split(m) assert isinstance(m.a, nnx.Param) assert isinstance(m.b, nnx.Param) assert isinstance(state['a'], nnx.Param) assert 'b' not in state assert m.a is state['a'] assert m.b is state['a'] assert m.a[...] == state['a'][...] assert m.b[...] == state['a'][...] m2 = nnx.merge(graphdef, state) assert isinstance(m2.a, nnx.Param) assert isinstance(m2.b, nnx.Param) assert isinstance(state['a'], nnx.Param) assert m2.a is state['a'] assert m2.b is state['a'] assert m2.a[...] == state['a'][...] assert m2.b[...] == state['a'][...] assert m2.a is m2.b def test_pytree_flatten(self): @struct.dataclass class Tree: a: int b: str = struct.field(pytree_node=False) p = Tree(1, 'a') leaves, treedef = nnx.graphlib._flatten_pytree(p) fields = dict(leaves) assert 'a' in fields assert 'b' not in fields assert fields['a'] == 1 p2 = nnx.graphlib._unflatten_pytree(leaves, treedef) assert isinstance(p2, Tree) assert p2.a == 1 def test_pytree_node(self): @struct.dataclass class Tree: a: nnx.Param[int] b: str = struct.field(pytree_node=False) class Foo(nnx.Module): def __init__(self): self.tree = nnx.data(Tree(nnx.Param(1), 'a')) m = Foo() graphdef, state = nnx.split(m) assert 'tree' in state assert 'a' in state['tree'] m2 = nnx.merge(graphdef, state) assert isinstance(m2.tree, Tree) assert m2.tree.a.get_value() == 1 assert m2.tree.b == 'a' assert m2.tree.a is m.tree.a assert m2.tree is not m.tree def test_cached_unflatten(self): class Foo(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.a = nnx.Linear(2, 2, rngs=rngs) self.b = nnx.BatchNorm(2, rngs=rngs) def f(m: Foo): m.a, m.b = m.b, m.a # type: ignore m = Foo(rngs=nnx.Rngs(0)) a = m.a b = m.b ref_out_idx_out = nnx.graphlib.RefMap() graphdef: nnx.graphlib.GraphDef[Foo] graphdef, state = nnx.graphlib.flatten(m, ref_index=ref_out_idx_out, graph=True) state = state.to_nested_state() @partial(jax.jit, static_argnums=(0,)) def f_pure(graphdef: nnx.graphlib.GraphDef[Foo], state): idx_out_ref_in = nnx.graphlib.IndexMap() m = nnx.graphlib.unflatten(graphdef, state, index_ref=idx_out_ref_in) ref_in_idx_out = nnx.graphlib.RefMap.from_indexmap(idx_out_ref_in) f(m) ref_in_idx_in = nnx.graphlib.RefMap() graphdef, state = nnx.graphlib.flatten( m, ref_index=ref_in_idx_in, ref_outer_index=ref_in_idx_out, graph=True ) state = state.to_nested_state() return state, graphdef state, graphdef_out = f_pure(graphdef, state) idx_out_ref_out = nnx.graphlib.IndexMap.from_refmap(ref_out_idx_out) m2 = nnx.graphlib.unflatten( graphdef_out, state, outer_index_outer_ref=idx_out_ref_out ) assert m2 is m assert m2.a is b assert m2.b is a def test_cached_unflatten_swap_variables(self): class Foo(nnx.Module): def __init__(self): self.a = nnx.Param(1) self.b = nnx.Param(2) def f(m: Foo): m.a, m.b = m.b, m.a m = Foo() a = m.a b = m.b ref_out_idx_out = nnx.graphlib.RefMap() graphdef: nnx.graphlib.GraphDef[Foo] graphdef, state = nnx.graphlib.flatten(m, ref_index=ref_out_idx_out, graph=True) idx_out_ref_out = {v: k for k, v in ref_out_idx_out.items()} state = state.to_nested_state() @partial(jax.jit, static_argnums=(0,)) def f_pure(graphdef: nnx.graphlib.GraphDef[Foo], state): idx_out_ref_in = nnx.graphlib.IndexMap() m = nnx.graphlib.unflatten(graphdef, state, index_ref=idx_out_ref_in) ref_in_idx_out = nnx.graphlib.RefMap.from_indexmap(idx_out_ref_in) f(m) ref_in_idx_in = nnx.graphlib.RefMap() graphdef, state = nnx.graphlib.flatten( m, ref_index=ref_in_idx_in, ref_outer_index=ref_in_idx_out, graph=True ) state = state.to_nested_state() return state, graphdef state, graphdef = f_pure(graphdef, state) m2 = nnx.graphlib.unflatten( graphdef, state, outer_index_outer_ref=idx_out_ref_out ) assert m2 is m assert m2.a is b assert m2.b is a def test_cached_unflatten_add_self_reference(self): class Foo(nnx.Module): def __init__(self): self.ref = nnx.data(None) def f(m: Foo): m.ref = m m = Foo() ref_out_idx_out = nnx.graphlib.RefMap() graphdef: nnx.graphlib.GraphDef[Foo] graphdef, state = nnx.graphlib.flatten(m, ref_index=ref_out_idx_out, graph=True) idx_out_ref_out = nnx.graphlib.IndexMap.from_refmap(ref_out_idx_out) state = state.to_nested_state() @partial(jax.jit, static_argnums=(0,)) def f_pure(graphdef: nnx.graphlib.GraphDef[Foo], state): idx_out_ref_in = nnx.graphlib.IndexMap() m = nnx.graphlib.unflatten(graphdef, state, index_ref=idx_out_ref_in) ref_in_idx_out = nnx.graphlib.RefMap.from_indexmap(idx_out_ref_in) f(m) ref_in_idx_in = nnx.graphlib.RefMap() graphdef, state = nnx.graphlib.flatten( m, ref_index=ref_in_idx_in, ref_outer_index=ref_in_idx_out, graph=True ) state = state.to_nested_state() return state, graphdef state, graphdef_out = f_pure(graphdef, state) m2 = nnx.graphlib.unflatten( graphdef_out, state, outer_index_outer_ref=idx_out_ref_out ) assert m2 is m assert m2.ref is m2 def test_call_jit_update(self): class Counter(nnx.Module): def __init__(self): self.count = nnx.Param(jnp.zeros(())) def inc(self): self.count[...] += 1 return 1 graph_state = nnx.split(Counter()) @jax.jit def update(graph_state: nnx.PureState[Counter]): out, graph_state = nnx.call(graph_state).inc() self.assertEqual(out, 1) return graph_state graph_state = update(graph_state) graph_state = update(graph_state) counter = nnx.merge(*graph_state) self.assertEqual(counter.count[...], 2) def test_stateful_linear(self): linear = StatefulLinear(3, 2, nnx.Rngs(0)) linear_state = nnx.split(linear) @jax.jit def forward(x, pure_linear: nnx.PureState[StatefulLinear]): y, pure_linear = nnx.call(pure_linear)(x) return y, pure_linear x = jnp.ones((1, 3)) y, linear_state = forward(x, linear_state) y, linear_state = forward(x, linear_state) self.assertEqual(linear.count[...], 0) new_linear = nnx.merge(*linear_state) self.assertEqual(new_linear.count[...], 2) def test_getitem(self): rngs = nnx.Rngs(0) nodes = dict( a=StatefulLinear(3, 2, rngs), b=StatefulLinear(2, 1, rngs), ) node_state = nnx.split(nodes) _, node_state = nnx.call(node_state)['b'].increment() nodes = nnx.merge(*node_state) self.assertEqual(nodes['a'].count[...], 0) self.assertEqual(nodes['b'].count[...], 1) def test_object_state_propagation(self): test = self class Foo(nnx.Module): def __call__(self): test.assertTrue(self._pytree__state.initializing) self = nnx.merge(*nnx.split(self)) test.assertTrue(self._pytree__state.initializing) module = Foo() nnx.bridge.lazy_init(module) def test_object_state_propagation_nested(self): class NNXOuter(nnx.Module): def __init__(self, dout: int, rngs: nnx.Rngs): self.inner = nnx.bridge.ToNNX(linen.Dense(dout), rngs=rngs) self.rngs = rngs def __call__(self, x): @nnx.split_rngs(splits=5) @nnx.vmap(in_axes=(0, None), axis_size=5) def vmap_fn(inner, x): return inner(x) return vmap_fn(self.inner, x) x = jax.random.normal(jax.random.key(0), (2, 4)) model = NNXOuter(3, rngs=nnx.Rngs(0)) nnx.bridge.lazy_init(model, x) self.assertEqual(model.inner.kernel.shape, (5, 4, 3)) self.assertEqual(model.inner.bias.shape, (5, 3)) def test_split_merge_context(self): m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) with nnx.graphlib.split_context() as ctx: graphdef1, state1 = ctx.split(m) graphdef2, state2 = ctx.split(m) self.assertFalse(hasattr(ctx, 'ref_index')) self.assertFalse(hasattr(ctx, 'ctxtag')) self.assertIsInstance(graphdef1.nodes[0], nnx.graphlib.NodeDef) self.assertIsInstance(graphdef2.nodes[0], nnx.graphlib.NodeRef) self.assertLen(nnx.to_flat_state(state1), 2) self.assertLen(nnx.to_flat_state(state2), 0) with nnx.graphlib.merge_context() as ctx: m1 = ctx.merge(graphdef1, state1) m2 = ctx.merge(graphdef2, state2) self.assertIs(m1, m2) self.assertFalse(hasattr(ctx, 'index_ref')) self.assertFalse(hasattr(ctx, 'ctxtag')) def test_split_merge_context_example(self): m1 = nnx.Dict({}) with nnx.update_context('example'): with nnx.split_context('example') as ctx: graphdef, state = ctx.split(m1) @jax.jit def f(graphdef, state): with nnx.merge_context('example', True) as ctx: m2 = ctx.merge(graphdef, state) m2.a = 1 m2.ref = m2 # create a reference cycle with nnx.split_context('example') as ctx: return ctx.split(m2) graphdef_out, state_out = f(graphdef, state) with nnx.merge_context('example', False) as ctx: m3 = ctx.merge(graphdef_out, state_out) def test_split_merge_context_nested(self): m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) m1 = nnx.Sequential(m2) with nnx.graphlib.split_context() as ctx: graphdef1, state1 = ctx.split(m1) graphdef2, state2 = ctx.split(m2) self.assertIsInstance(graphdef1.nodes[0], nnx.graphlib.NodeDef) self.assertIsInstance(graphdef2.nodes[0], nnx.graphlib.NodeRef) self.assertLen(nnx.to_flat_state(state1), 2) self.assertLen(nnx.to_flat_state(state2), 0) with nnx.graphlib.merge_context() as ctx: m1 = ctx.merge(graphdef1, state1) m2 = ctx.merge(graphdef2, state2) self.assertIs(m2, m1.layers[0]) self.assertFalse(hasattr(ctx, 'index_ref')) self.assertFalse(hasattr(ctx, 'ctxtag')) def test_split_merge_update_context(self): class Foo(nnx.Module): def __init__(self): self.a = nnx.Param(1) self.b = nnx.data(2) m = Foo() ctxtag = 'test' with nnx.update_context(ctxtag): with nnx.graphlib.split_context(ctxtag) as ctx: graphdef1, state1 = ctx.split(m) graphdef2, state2 = ctx.split(m) self.assertFalse(hasattr(ctx, 'ref_index')) self.assertFalse(hasattr(ctx, 'ctxtag')) self.assertIsInstance(graphdef1.nodes[0], nnx.graphlib.NodeDef) self.assertIsInstance(graphdef2.nodes[0], nnx.graphlib.NodeRef) self.assertLen(nnx.to_flat_state(state1), 2) self.assertLen(nnx.to_flat_state(state2), 0) @jax.jit def f(graphdef1, state1, graphdef2, state2): with nnx.graphlib.merge_context(ctxtag, True) as ctx: m1 = ctx.merge(graphdef1, state1) m2 = ctx.merge(graphdef2, state2) self.assertIs(m1, m2) self.assertFalse(hasattr(ctx, 'index_ref')) self.assertFalse(hasattr(ctx, 'ctxtag')) # swap a and b m1.a, m1.b = m1.b, m1.a with nnx.graphlib.split_context(ctxtag) as ctx: graphdef1, state1 = ctx.split(m1) graphdef2, state2 = ctx.split(m2) return graphdef1, state1, graphdef2, state2 graphdef1, state1, graphdef2, state2 = f( graphdef1, state1, graphdef2, state2 ) with nnx.graphlib.merge_context(ctxtag, False) as ctx: m1_out = ctx.merge(graphdef1, state1) m2_out = ctx.merge(graphdef2, state2) self.assertIs(m, m1_out) self.assertIs(m, m2_out) self.assertEqual(m.a, 2) self.assertEqual(m.b[...], 1) # type: ignore self.assertFalse(hasattr(ctx, 'index_ref')) self.assertFalse(hasattr(ctx, 'ctxtag')) def test_to_tree_simple(self): m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) impure_tree = (m, 1, {'b': m}) pure_tree = nnx.to_tree(impure_tree) t1 = pure_tree[0] t2 = pure_tree[2]['b'] self.assertEqual(pure_tree[1], 1) self.assertIsInstance(t1, nnx.NodeStates) assert isinstance(t1, nnx.NodeStates) self.assertIsInstance(t2, nnx.NodeStates) assert isinstance(t2, nnx.NodeStates) self.assertIsInstance(t1.graphdef.nodes[0], nnx.graphlib.NodeDef) self.assertIsInstance(t2.graphdef.nodes[0], nnx.graphlib.NodeRef) self.assertLen(nnx.to_flat_state(t1.states[0]), 2) self.assertLen(nnx.to_flat_state(t2.states[0]), 0) impure_tree2 = nnx.from_tree(pure_tree) m1_out = impure_tree2[0] m2_out = impure_tree2[2]['b'] self.assertIs(m1_out, m2_out) self.assertEqual(impure_tree2[1], 1) def test_to_tree_update_context(self): class Foo(nnx.Module): def __init__(self): self.a = nnx.Param(1) self.b = nnx.data(2) m = Foo() impure_tree = (m, 1, {'b': m}) ctxtag = 'test' with nnx.update_context(ctxtag): pure_tree = nnx.to_tree(impure_tree, ctxtag=ctxtag) t1 = pure_tree[0] t2 = pure_tree[2]['b'] self.assertEqual(pure_tree[1], 1) self.assertIsInstance(t1, nnx.NodeStates) assert isinstance(t1, nnx.NodeStates) self.assertIsInstance(t2, nnx.NodeStates) assert isinstance(t2, nnx.NodeStates) self.assertIsInstance(t1.graphdef.nodes[0], nnx.graphlib.NodeDef) self.assertIsInstance(t2.graphdef.nodes[0], nnx.graphlib.NodeRef) self.assertLen(nnx.to_flat_state(t1.states[0]), 2) self.assertLen(nnx.to_flat_state(t2.states[0]), 0) @jax.jit def f(pure_tree): impure_tree2 = nnx.from_tree(pure_tree, ctxtag=ctxtag, is_inner=True) m1_out = impure_tree2[0] m2_out = impure_tree2[2]['b'] self.assertIs(m1_out, m2_out) # self.assertEqual(impure_tree2[1], 1) # swap a and b m1_out.a, m1_out.b = m1_out.b, m1_out.a pure_tree2 = nnx.to_tree(impure_tree2, ctxtag=ctxtag) t1 = pure_tree2[0] t2 = pure_tree2[2]['b'] # self.assertEqual(pure_tree2[1], 1) self.assertIsInstance(t1, nnx.NodeStates) assert isinstance(t1, nnx.NodeStates) self.assertIsInstance(t2, nnx.NodeStates) assert isinstance(t2, nnx.NodeStates) self.assertIsInstance(t1.graphdef.nodes[0], nnx.graphlib.NodeDef) self.assertIsInstance(t2.graphdef.nodes[0], nnx.graphlib.NodeRef) self.assertLen(nnx.to_flat_state(t1.states[0]), 2) self.assertLen(nnx.to_flat_state(t2.states[0]), 0) return pure_tree2 pure_tree2 = f(pure_tree) impure_tree2 = nnx.from_tree(pure_tree2, ctxtag=ctxtag, is_inner=False) m1_out = impure_tree2[0] m2_out = impure_tree2[2]['b'] self.assertIs(m, m1_out) self.assertIs(m, m2_out) self.assertEqual(m.a, 2) self.assertEqual(m.b[...], 1) # type: ignore self.assertEqual(impure_tree2[1], 1) def test_graph_flatten_with_data_wrapper(self): class Foo(nnx.Pytree): def __init__(self, data, static): self.data = nnx.data(data) self.static = nnx.static(static) tree = Foo(1, 2) state = nnx.state(tree) self.assertIn('data', state) self.assertIsInstance(state['data'], int) self.assertEqual(state['data'], 1) self.assertNotIn('static', state) def test_to_tree_consistent_prefix(self): m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) impure_tree = (m, 1, {'b': m}) prefix = (0, None, 0) pure_tree = nnx.to_tree(impure_tree, prefix=prefix) prefix = (0, None, 1) with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'): nnx.to_tree(impure_tree, prefix=prefix) def test_simple_vmap(self): @dataclasses.dataclass(frozen=True) class StateAxes: params: Any batch_stats: Any class Foo(nnx.Module): def __init__(self, a, b): self.a = nnx.Param(a) self.b = nnx.BatchStat(b) ctxtag = 'test' with nnx.update_context(ctxtag): m1 = Foo(a=jnp.array(0), b=jnp.arange(5)) m2 = Foo(a=jnp.array(1), b=jnp.array(2)) args = (m1, m2, {'b': m1}) m1_axes = StateAxes(None, 0) in_axes = (m1_axes, None, {'b': m1_axes}) jax_in_axes = jax.tree.map( lambda x: nnx.NodeStates.from_prefixes((x.params, x.batch_stats)) if isinstance(x, StateAxes) else x, in_axes, ) out_axes = 0 def split_fn(ctx: nnx.SplitContext, path, prefix, x): if isinstance(prefix, StateAxes): return nnx.NodeStates.from_split( *ctx.split(x, nnx.Param, nnx.BatchStat) ) return nnx.NodeStates.from_split(*ctx.split(x)) pure_args = nnx.to_tree( args, ctxtag=ctxtag, prefix=in_axes, split_fn=split_fn ) @partial(jax.vmap, in_axes=jax_in_axes, out_axes=(jax_in_axes, out_axes)) def f(*pure_args): args = nnx.from_tree(pure_args, ctxtag=ctxtag, is_inner=True) y = 0 self.assertIs(args[0], args[2]['b']) for path, m in nnx.iter_graph(args): if isinstance(m, Foo): self.assertEqual(m.a.shape, ()) self.assertEqual(m.b.shape, ()) y += m.a + m.b args_out = nnx.extract.clear_non_graph_nodes(args) pure_args_out, y = nnx.to_tree( (args_out, y), prefix=(in_axes, out_axes), ctxtag=ctxtag, split_fn=split_fn, ) return pure_args_out, y pure_args_out, y = f(*pure_args) args_out, y = nnx.from_tree( (pure_args_out, y), ctxtag=ctxtag, is_inner=False ) self.assertEqual(y.shape, (5,)) self.assertGreater(y.sum(), 5) self.assertIs(m1, args_out[0]) self.assertIs(m1, args_out[2]['b']) self.assertIs(m2, args_out[1]) @parameterized.parameters(True, False) def test_split_variable(self, graph): v = nnx.Param(1) graphdef, state = nnx.split(v, graph=graph) expected_type = nnx.graphlib.VariableDef if graph else nnx.graphlib.TreeNodeDef self.assertIsInstance(graphdef.nodes[0], expected_type) self.assertIsInstance(state, nnx.Variable) v2 = nnx.merge(graphdef, state) self.assertIsInstance(v2, nnx.Param) @parameterized.parameters(True, False) def test_split_filter_variable(self, graph): v = nnx.Param(1) graphdef, batch_stats, params, rest = nnx.split( v, nnx.BatchStat, nnx.Param, ..., graph=graph ) expected_type = nnx.graphlib.VariableDef if graph else nnx.graphlib.TreeNodeDef self.assertIsInstance(graphdef.nodes[0], expected_type) self.assertIsInstance(params, nnx.Variable) self.assertIsInstance(batch_stats, nnx.State) self.assertEmpty(batch_stats) self.assertIsInstance(rest, nnx.State) self.assertEmpty(rest) v2 = nnx.merge(graphdef, batch_stats, params, rest) self.assertIsInstance(v2, nnx.Param) @parameterized.parameters(True, False) def test_split_update_variable(self, graph): v = nnx.Param(jnp.array(1)) graphdef, state = nnx.split(v, graph=graph) expected_type = nnx.graphlib.VariableDef if graph else nnx.graphlib.TreeNodeDef self.assertIsInstance(graphdef.nodes[0], expected_type) self.assertIsInstance(state, nnx.Variable) state[...] = 2 nnx.update(v, state) self.assertEqual(v[...], 2) @parameterized.parameters(True, False) def test_split_update_filter_variable(self, graph): v = nnx.Param(jnp.array(1)) graphdef, batch_stats, params, rest = nnx.split( v, nnx.BatchStat, nnx.Param, ..., graph=graph ) expected_type = nnx.graphlib.VariableDef if graph else nnx.graphlib.TreeNodeDef self.assertIsInstance(graphdef.nodes[0], expected_type) self.assertIsInstance(params, nnx.Variable) self.assertIsInstance(batch_stats, nnx.State) self.assertEmpty(batch_stats) self.assertIsInstance(rest, nnx.State) self.assertEmpty(rest) params[...] = 2 nnx.update(v, batch_stats, params, rest) self.assertEqual(v[...], 2) @parameterized.parameters( (lambda: nnx.Param(1),), (lambda: 42,), (lambda: jnp.array([1, 2, 3]),), ) def test_split_leaf(self, leaf_fn): leaf = leaf_fn() graphdef, state = nnx.split(leaf) out = nnx.merge(graphdef, state) self.assertIs(out, state) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_jit_variable(self, graph, graph_updates): v = nnx.Param(1) @nnx.jit(graph=graph, graph_updates=graph_updates) def f(v): v[...] += 1 f(v) np.testing.assert_allclose(v[...], 2) def test_jit_pytree_of_variables(self): v1 = nnx.Param(jnp.array(1)) v2 = nnx.Param(jnp.array(2)) vs = [v1, v1, v2] @nnx.jit def f(vs): self.assertIs(vs[0], vs[1]) self.assertIsNot(vs[0], vs[2]) vs[0][...] += 10 f(vs) self.assertIs(vs[0], vs[1]) self.assertIsNot(vs[0], vs[2]) np.testing.assert_allclose(vs[0][...], 11) np.testing.assert_allclose(vs[2][...], 2) def test_variable_reference_in_module(self): class Foo(nnx.Module): def __init__(self, var): self.var = var var = nnx.Param(1) foo = Foo(var) @nnx.jit def increment_var(var, foo): self.assertIs(var, foo.var) var[...] += 1 increment_var(var, foo) self.assertEqual(foo.var[...], 2) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_variables_example(self, graph, graph_updates): def stateful_linear_init(din: int, dout: int, rngs: nnx.Rngs): w = nnx.Param(jax.random.normal(rngs(), (din, dout))) b = nnx.Param(jnp.zeros((dout,))) count = nnx.Variable(jnp.array(0)) return w, b, count rngs = nnx.Rngs(0) w, b, count = stateful_linear_init(2, 3, rngs=rngs) @nnx.jit(graph=graph, graph_updates=graph_updates) def stateful_linear(w, b, count, x): count[...] += 1 return x @ w + b[None] x = jax.random.normal(rngs(), (1, 2)) y = stateful_linear(w, b, count, x) self.assertEqual(count[...], 1) y = stateful_linear(w, b, count, x) self.assertEqual(count[...], 2) self.assertEqual(y.shape, (1, 3)) @parameterized.parameters(True, False) def test_array_attributes(self, graph): class Foo(nnx.Module): def __init__(self): self.a = jnp.array(1) self.b = 'yes' m = Foo() graphdef, state = nnx.split(m, graph=graph) self.assertLen(state, 1) self.assertIsInstance(state['a'], jax.Array) m2 = nnx.merge(graphdef, state) self.assertIsInstance(m2.a, jax.Array) self.assertEqual(m2.a, 1) self.assertEqual(m2.b, 'yes') def test_transform_array_attributes(self): class Foo(nnx.Module): def __init__(self): self.a = jnp.array(1) self.b = 'yes' m = Foo() @nnx.jit def f(m): m.a += 1 self.assertEqual(m.b, 'yes') f(m) self.assertEqual(m.a, 2) def test_data_after_init(self): test = self class Foo(nnx.Module): def __init__(self): self.ls = [] self.ls.append(jnp.array(1)) with self.assertRaisesRegex( ValueError, 'Found unexpected data on value of type' ): m = Foo() def test_update_dict(self): node = { 'a': { 'b': 1, 'c': nnx.Param(jnp.array(2)), 'd': 3, }, } updates = { 'a': { 'b': 4, 'c': jnp.array(10), }, } nnx.update(node, updates) self.assertEqual(node['a']['b'], 4) self.assertEqual(node['a']['c'][...], 10) self.assertEqual(node['a']['d'], 3) def test_pop_dict(self): node = { 'a': { 'b': jnp.array(1), 'c': nnx.Param(jnp.array(2)), 'd': jnp.array(3.0), }, } lt_2 = lambda _, x: x < 2 popped = nnx.pop(node, (nnx.Param, lt_2)) self.assertEqual(popped['a']['b'], 1) self.assertEqual(popped['a']['c'][...], 2) self.assertEqual(node['a']['d'], 3.0) self.assertLen(jax.tree.leaves(node), 1) self.assertLen(jax.tree.leaves(popped), 2) def test_iter_graph(self): arr0 = jnp.zeros(1) arr1 = jnp.zeros(1) var0 = nnx.Variable(jnp.zeros(1)) var1 = nnx.Variable(jnp.zeros(1)) child = nnx.Module() child.a = var0 child.b = arr0 child.c = var1 child.d = var0 child.e = arr1 child.f = arr0 root = nnx.Module() root.a = child root.b = var1 root.c = arr1 root.d = var1 root.e = child root.f = var0 root.g = arr1 nodes = [node for _, node in nnx.iter_graph(root)] count = lambda e: sum(node is e for node in nodes) # All internal nodes must be visited exactly once. self.assertEqual(count(var0), 1) self.assertEqual(count(var1), 1) self.assertEqual(count(child), 1) self.assertEqual(count(root), 1) # Arrays must not be deduplicated. self.assertEqual(count(arr0), 2) self.assertEqual(count(arr1), 3) # Nodes must be yielded in DFS order. expected = [var0, arr0, var1, arr1, arr0, child, arr1, arr1, root] unique = (arr0, arr1, var0, var1, child, root) index = lambda e: next(node is e for node in unique) actual = [node for node in nodes if any(node is e for e in unique)] self.assertEqual(list(map(index, actual)), list(map(index, expected))) def test_cached_partial_docstring_example(self): import optax model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=nnx.Param) @nnx.jit def train_step(model, optimizer, x, y): def loss_fn(model): return jnp.mean((model(x) - y) ** 2) loss, grads = nnx.value_and_grad(loss_fn)(model) optimizer.update(model, grads) return loss cached_train_step = nnx.cached_partial(train_step, model, optimizer) for step in range(2): x, y = jnp.ones((10, 2)), jnp.ones((10, 3)) loss = cached_train_step(x, y) self.assertIsInstance(loss, jax.Array) def test_find_duplicates(self): class SharedModules(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.a = nnx.Linear(1, 1, rngs=rngs) self.b = nnx.Linear(1, 1, rngs=rngs) self.c = self.a # shared Module model = SharedModules(nnx.Rngs(0)) duplicates = nnx.find_duplicates(model) self.assertLen(duplicates, 1) self.assertEqual(duplicates[0], [('a',), ('c',)]) def test_resursive_map(self): class Foo(nnx.Pytree): def __init__(self, d): self.d = d foo1 = Foo(10) foo2 = Foo(20) bar = [foo1, foo2, foo1] n = 0 def inc_d(path, node): nonlocal n if isinstance(node, Foo): n += 1 node.d += 1 return node bar2 = nnx.recursive_map(inc_d, bar) self.assertIs(bar2[0], bar2[2]) self.assertEqual(bar2[0].d, 11) self.assertEqual(bar2[1].d, 21) self.assertEqual(n, 2) def test_resursive_map_replace(self): class Foo(nnx.Pytree): def __init__(self, d): self.d = d foo1 = Foo(10) foo2 = Foo(20) bar = [foo1, foo2, foo1] n = 0 def swap(path, node): nonlocal n if isinstance(node, Foo): n += 1 node = Foo(-node.d) return node bar2 = nnx.recursive_map(swap, bar) self.assertIs(bar2[0], bar2[2]) self.assertEqual(bar2[0].d, -10) self.assertEqual(bar2[1].d, -20) self.assertEqual(n, 2) @parameterized.parameters(True, False) def test_recursive_map_with_list(self, graph): rngs = nnx.Rngs(0) model = nnx.Sequential(nnx.Linear(2, 3, rngs=rngs), nnx.relu, nnx.Linear(3, 4, rngs=rngs)) def add_rank2_lora(_, node): if isinstance(node, nnx.Linear): return nnx.LoRA(node.in_features, 2, node.out_features, base_module=node, rngs=rngs) return node self.assertEqual(len(nnx.recursive_map(add_rank2_lora, model, graph=graph).layers), 3) def test_graphdef_hash_with_sequential(self): rngs = nnx.Rngs(0) net = nnx.Sequential( nnx.Linear(2, 1, rngs=rngs), ) hash(nnx.graphdef(net)) @nnx.set_graph_mode(False) def test_split_graph_error(self): v = nnx.Variable(jnp.array(1.0)) with self.assertRaisesRegex( ValueError, 'found at paths' ): graphdef, state = nnx.split((v, v)) class SimpleModule(nnx.Module): pass class TestThreading(parameterized.TestCase): def test_threading(self): x = SimpleModule() class MyThread(Thread): def run(self) -> None: nnx.graphlib.split(x) thread = MyThread() thread.start() thread.join() class TestTreeFlatten(parameterized.TestCase): def test_tree_flatten_unflatten(self): a = {'a': 1, 'b': nnx.Param(jnp.array(2))} b = {'a': 5, 'b': nnx.Param(jnp.array(6))} g = [a, 3, b, nnx.Param(jnp.array(4))] graphdef, flat_state = nnx.graphlib.flatten(g, graph=False) self.assertIsInstance(graphdef.nodes[0], nnx.graphlib.TreeNodeDef) g2 = nnx.graphlib.unflatten(graphdef, flat_state) self.assertIsInstance(g2, list) self.assertLen(g2, 4) self.assertEqual(g2[0]['a'], 1) self.assertEqual(g2[1], 3) self.assertIsInstance(g2[0]['b'], nnx.Param) self.assertIsInstance(g2[3], nnx.Param) self.assertEqual(g2[0]['b'][...], 2) self.assertEqual(g2[3][...], 4) def test_tree_flatten_no_paths(self): a = {'a': 1, 'b': nnx.Param(jnp.array(2))} b = {'a': 5, 'b': nnx.Param(jnp.array(6))} g = [a, 3, b, nnx.Param(jnp.array(4))] graphdef, leaves = nnx.graphlib.flatten(g, with_paths=False, graph=False) self.assertIsInstance(graphdef.nodes[0], nnx.graphlib.TreeNodeDef) self.assertIsInstance(leaves, list) def test_tree_split_merge(self): a = {'a': 1, 'b': nnx.Param(jnp.array(2))} b = {'a': 5, 'b': nnx.Param(jnp.array(6))} g = [a, 3, b, nnx.Param(jnp.array(4))] graphdef, state = nnx.split(g, graph=False) g2 = nnx.merge(graphdef, state) self.assertIsInstance(g2, list) self.assertEqual(g2[0]['a'], 1) self.assertEqual(g2[1], 3) self.assertIsInstance(g2[0]['b'], nnx.Param) self.assertEqual(g2[0]['b'][...], 2) self.assertEqual(g2[3][...], 4) def test_tree_split_merge_module(self): m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) graphdef, state = nnx.split(m, graph=False) self.assertIsInstance(graphdef.nodes[0], nnx.graphlib.TreeNodeDef) m2 = nnx.merge(graphdef, state) self.assertIsInstance(m2, nnx.Linear) self.assertEqual(m2.kernel.shape, (2, 3)) self.assertEqual(m2.bias.shape, (3,)) def test_tree_shared_variables_raises(self): v = nnx.Param(jnp.array(1)) g = [v, v] with self.assertRaises(ValueError): nnx.split(g, graph=False) def test_tree_shared_refs_raises(self): ref = jax.new_ref(jnp.array(1.0)) g = [ref, ref] with self.assertRaises(ValueError): nnx.split(g, graph=False) def test_tree_shared_variables_state_raises(self): v = nnx.Param(jnp.array(1)) g = [v, v] with self.assertRaises(ValueError): nnx.state(g, graph=False) def test_tree_shared_variables_graphdef_raises(self): v = nnx.Param(jnp.array(1)) g = [v, v] with self.assertRaises(ValueError): nnx.graphdef(g, graph=False) def test_tree_shared_variables_clone_raises(self): v = nnx.Param(jnp.array(1)) g = [v, v] with self.assertRaises(ValueError): nnx.clone(g, graph=False) def test_tree_flatten_unflatten_ordering(self): m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) graphdef, state = nnx.split(m, graph=False) tree_nodedef = graphdef.nodes[0] self.assertIsInstance(tree_nodedef, nnx.graphlib.TreeNodeDef) paths = [p for p, _ in tree_nodedef.path_index] self.assertEqual(paths, sorted(paths)) m2 = nnx.merge(graphdef, state) np.testing.assert_array_equal(m.kernel[...], m2.kernel[...]) np.testing.assert_array_equal(m.bias[...], m2.bias[...]) def test_tree_flatten_dict(self): g = {'z': nnx.Param(jnp.array(1)), 'a': jnp.array(2)} graphdef, state = nnx.split(g, graph=False) g2 = nnx.merge(graphdef, state) self.assertEqual(g2['z'][...], 1) np.testing.assert_array_equal(g2['a'], jnp.array(2)) def test_tree_flatten_tuple(self): g = (nnx.Param(jnp.array(1)), jnp.array(2), 3) graphdef, state = nnx.split(g, graph=False) g2 = nnx.merge(graphdef, state) self.assertIsInstance(g2, tuple) self.assertEqual(g2[0][...], 1) np.testing.assert_array_equal(g2[1], jnp.array(2)) self.assertEqual(g2[2], 3) def test_tree_flatten_namedtuple(self): import collections Point = collections.namedtuple('Point', ['y', 'x']) g = Point( y=nnx.Param(jnp.array(1.0)), x=nnx.Param(jnp.array(2.0)), ) graphdef, flat_state = nnx.graphlib.flatten(g, graph=False) self.assertLen(flat_state, 2) path, value = flat_state[0] self.assertEqual(path, ('x',)) self.assertEqual(value, 2.0) path, value = flat_state[1] self.assertEqual(path, ('y',)) self.assertEqual(value, 1.0) g2 = nnx.graphlib.unflatten(graphdef, flat_state) self.assertIsInstance(g2, Point) self.assertEqual(g2.y[...], 1.0) self.assertEqual(g2.x[...], 2.0) def test_tree_flatten_registered_dataclass(self): @jax.tree_util.register_dataclass @dataclasses.dataclass class MyData: z_param: Any a_value: Any g = MyData( z_param=nnx.Param(jnp.array(10.0)), a_value=jnp.array(20.0), ) graphdef, flat_state = nnx.graphlib.flatten(g, graph=False) self.assertLen(flat_state, 2) path, value = flat_state[0] self.assertEqual(path, ('a_value',)) np.testing.assert_array_equal(value, 20.0) path, value = flat_state[1] self.assertEqual(path, ('z_param',)) self.assertEqual(value, 10.0) g2 = nnx.graphlib.unflatten(graphdef, flat_state) self.assertIsInstance(g2, MyData) self.assertEqual(g2.z_param[...], 10.0) np.testing.assert_array_equal(g2.a_value, jnp.array(20.0)) def test_tree_flatten_nested_mixed(self): g = { 'b': [nnx.Param(jnp.array(1)), jnp.array(2)], 'a': (nnx.Param(jnp.array(3)), 4), } graphdef, flat_state = nnx.graphlib.flatten(g, graph=False) self.assertLen(flat_state, 4) path, value = flat_state[0] self.assertEqual(path, ('a', 0)) self.assertEqual(value, 3) path, value = flat_state[1] self.assertEqual(path, ('a', 1)) self.assertEqual(value, 4) path, value = flat_state[2] self.assertEqual(path, ('b', 0)) self.assertEqual(value, 1) path, value = flat_state[3] self.assertEqual(path, ('b', 1)) np.testing.assert_array_equal(value, 2) g2 = nnx.graphlib.unflatten(graphdef, flat_state) self.assertIsInstance(g2, dict) self.assertEqual(g2['b'][0][...], 1) np.testing.assert_array_equal(g2['b'][1], jnp.array(2)) self.assertEqual(g2['a'][0][...], 3) self.assertEqual(g2['a'][1], 4) @parameterized.parameters(True, False) def test_iter_graph(self, graph): var0 = nnx.Variable(jnp.zeros(1)) var1 = nnx.Variable(jnp.zeros(1)) arr0 = jnp.zeros(1) child = nnx.Module() child.a = var0 child.b = arr0 child.c = var1 root = nnx.Module() root.x = child root.y = jnp.ones(2) node_ids = [id(node) for _, node in nnx.iter_graph(root, graph=graph)] self.assertIn(id(var0), node_ids) self.assertIn(id(var1), node_ids) self.assertIn(id(child), node_ids) self.assertIn(id(root), node_ids) def test_iter_graph_tree_mode_shared_variable_raises(self): var = nnx.Variable(jnp.zeros(1)) root = nnx.Module() root.a = var root.b = var with self.assertRaisesRegex( ValueError, 'found at paths' ): list(nnx.iter_graph(root, graph=False)) def test_iter_graph_tree_mode_cycle_raises(self): a = nnx.List([1]) b = nnx.List([2, a]) a.append(b) with self.assertRaisesRegex( ValueError, 'found at paths' ): list(nnx.iter_graph(a, graph=False)) @parameterized.parameters(True, False) def test_iter_modules(self, graph): model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) modules = list(nnx.iter_modules(model, graph=graph)) self.assertLen(modules, 1) path, m = modules[0] self.assertEqual(path, ()) self.assertIs(m, model) @parameterized.parameters(True, False) def test_iter_modules_nested(self, graph): class Block(nnx.Module): def __init__(self, rngs): self.linear = nnx.Linear(2, 3, rngs=rngs) self.dropout = nnx.Dropout(0.5) model = Block(nnx.Rngs(0)) modules = list(nnx.iter_modules(model, graph=graph)) module_types = [type(m).__name__ for _, m in modules] self.assertIn('Block', module_types) self.assertIn('Linear', module_types) self.assertIn('Dropout', module_types) self.assertLen(modules, 3) def test_recursive_map_tree_mode(self): class Foo(nnx.Pytree): def __init__(self, d): self.d = d foo1 = Foo(10) foo2 = Foo(20) bar = [foo1, foo2] n = 0 def inc_d(path, node): nonlocal n if isinstance(node, Foo): n += 1 node.d += 1 return node bar2 = nnx.recursive_map(inc_d, bar, graph=False) self.assertEqual(bar2[0].d, 11) self.assertEqual(bar2[1].d, 21) self.assertEqual(n, 2) def test_recursive_map_tree_mode_replace(self): class Foo(nnx.Pytree): def __init__(self, d): self.d = d foo1 = Foo(10) foo2 = Foo(20) bar = [foo1, foo2] n = 0 def swap(path, node): nonlocal n if isinstance(node, Foo): n += 1 node = Foo(-node.d) return node bar2 = nnx.recursive_map(swap, bar, graph=False) self.assertEqual(bar2[0].d, -10) self.assertEqual(bar2[1].d, -20) self.assertEqual(n, 2) def test_recursive_map_tree_mode_with_list(self): rngs = nnx.Rngs(0) model = nnx.Sequential( nnx.Linear(2, 3, rngs=rngs), nnx.relu, nnx.Linear(3, 4, rngs=rngs) ) def add_rank2_lora(_, node): if isinstance(node, nnx.Linear): return nnx.LoRA( node.in_features, 2, node.out_features, base_module=node, rngs=rngs, ) return node result = nnx.recursive_map(add_rank2_lora, model, graph=False) self.assertLen(result.layers, 3) def test_recursive_map_tree_mode_shared_variable_raises(self): v = nnx.Param(jnp.array(1)) g = [v, v] with self.assertRaisesRegex( ValueError, 'found at paths' ): nnx.recursive_map(lambda path, node: node, g, graph=False) def test_recursive_map_tree_mode_cycle_raises(self): a = nnx.List([1]) b = nnx.List([2, a]) a.append(b) with self.assertRaisesRegex( ValueError, 'found at paths' ): nnx.recursive_map(lambda path, node: node, a, graph=False) def test_check_valid_pytree_flatten(self): class NotAPytree(nnx.Pytree, pytree=False): def __init__(self): self.x = 1 node = [NotAPytree()] with self.assertRaisesRegex( ValueError, "pytree=False.*Found at path" ): nnx.graphlib.flatten(node, graph=False) def test_check_valid_pytree_iter_graph(self): class NotAPytree(nnx.Pytree, pytree=False): def __init__(self): self.x = 1 node = nnx.List([NotAPytree()]) with self.assertRaisesRegex( ValueError, "pytree=False.*Found at path" ): list(nnx.iter_graph(node, graph=False)) def test_check_valid_pytree_iter_children(self): class NotAPytree(nnx.Pytree, pytree=False): def __init__(self): self.x = 1 node = NotAPytree() with self.assertRaisesRegex( ValueError, "pytree=False" ): list(nnx.iter_children(node, graph=False)) def test_check_valid_pytree_recursive_map(self): class NotAPytree(nnx.Pytree, pytree=False): def __init__(self): self.x = 1 node = nnx.List([NotAPytree()]) with self.assertRaisesRegex( ValueError, "pytree=False.*Found at path" ): nnx.recursive_map(lambda path, node: node, node, graph=False) @parameterized.parameters(True, False) def test_map(self, graph): model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) new_model = nnx.map(lambda path, x: jnp.zeros_like(x), model, graph=graph) self.assertTrue(hasattr(new_model, 'kernel')) self.assertTrue(hasattr(new_model, 'bias')) np.testing.assert_array_equal(new_model.kernel, jnp.zeros((2, 3))) np.testing.assert_array_equal(new_model.bias, jnp.zeros((3,))) def test_map_with_path(self): model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) paths_seen = [] def record_path(path, x): paths_seen.append(path) return x nnx.map(record_path, model) self.assertLen(paths_seen, 2) path_last_parts = sorted(p[-1] for p in paths_seen) self.assertEqual(path_last_parts, ['bias', 'kernel']) def test_map_nested(self): class Model(nnx.Module): def __init__(self, rngs): self.linear = nnx.Linear(2, 3, rngs=rngs) model = Model(rngs=nnx.Rngs(0)) new_model = nnx.map(lambda path, x: jnp.ones_like(x), model) self.assertTrue(hasattr(new_model, 'linear')) np.testing.assert_array_equal(new_model.linear.kernel, jnp.ones((2, 3))) np.testing.assert_array_equal(new_model.linear.bias, jnp.ones((3,))) def test_map_replace(self): model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) new_model = nnx.map( lambda path, v: v.replace(jnp.zeros_like(v)), model ) self.assertTrue(hasattr(new_model, 'kernel')) self.assertTrue(hasattr(new_model, 'bias')) self.assertIsInstance(new_model.kernel, nnx.Param) np.testing.assert_array_equal(new_model.kernel[...], jnp.zeros((2, 3))) np.testing.assert_array_equal(new_model.bias[...], jnp.zeros((3,))) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/helpers_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax import jax.numpy as jnp import optax from absl.testing import absltest import numpy as np from flax import linen from flax import nnx class TrainState(nnx.TrainState): batch_stats: nnx.State class TestHelpers(absltest.TestCase): def test_train_state(self): m = nnx.Dict(a=nnx.Param(1), b=nnx.BatchStat(2)) graphdef, params, batch_stats = nnx.split(m, nnx.Param, nnx.BatchStat) state = TrainState.create( graphdef, params=params, tx=optax.sgd(1.0), batch_stats=batch_stats, ) leaves = jax.tree_util.tree_leaves(state) def test_train_state_methods(self): class Foo(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.linear = nnx.Linear(2, 4, rngs=rngs) self.batch_norm = nnx.BatchNorm(4, rngs=rngs) def __call__(self, x: jax.Array, train: bool) -> jax.Array: x = self.linear(x) x = self.batch_norm(x, use_running_average=not train) return x module = Foo(rngs=nnx.Rngs(0)) graphdef, params, batch_stats = nnx.split(module, nnx.Param, nnx.BatchStat) state = TrainState.create( graphdef, params=params, tx=optax.sgd(1.0), batch_stats=batch_stats, ) x = jax.numpy.ones((1, 2)) y, _updates = state.apply('params', 'batch_stats')(x, train=True) assert y.shape == (1, 4) # fake gradient grads = jax.tree.map(jnp.ones_like, state.params) # test apply_gradients state = state.apply_gradients(grads) def test_nnx_linen_sequential_equivalence(self): key1, key2 = jax.random.split(jax.random.key(0), 2) rngs = nnx.Rngs(0) x = jax.random.uniform(key1, (3, 1, 5)) model_nnx = nnx.Sequential( nnx.Linear(5, 4, rngs=rngs), nnx.Linear(4, 2, rngs=rngs) ) model = linen.Sequential([linen.Dense(4), linen.Dense(2)]) variables = model.init(key2, x) for layer_index in range(2): for param in ('kernel', 'bias'): variables['params'][f'layers_{layer_index}'][param] = getattr( model_nnx.layers[layer_index], param )[...] out_nnx = model_nnx(x) out = model.apply(variables, x) np.testing.assert_array_equal(out, out_nnx) variables = model.init(key2, x) for layer_index in range(2): for param in ('kernel', 'bias'): getattr(model_nnx.layers[layer_index], param)[...] = variables[ 'params' ][f'layers_{layer_index}'][param] out_nnx = model_nnx(x) out = model.apply(variables, x) np.testing.assert_array_equal(out, out_nnx) def test_nnx_empty_sequential_is_identity(self): iden = nnx.Sequential() assert iden(12) == 12 assert iden(12, 23) == (12, 23) assert iden() is None assert iden(k=2) == {'k': 2} def test_dict_mutable_mapping(self): d = nnx.Dict({'a': 1, 'b': 2}) self.assertEqual(d['a'], 1) self.assertEqual(d['b'], 2) self.assertEqual(len(d), 2) d['c'] = 3 self.assertEqual(d['c'], 3) self.assertEqual(len(d), 3) del d['a'] self.assertEqual(len(d), 2) with self.assertRaises(KeyError): _ = d['a'] self.assertSetEqual(set(d), {'b', 'c'}) def test_dict_setdefault(self): d = nnx.Dict({'a': 1, 'b': 2}) self.assertEqual(d.setdefault('a', 10), 1) self.assertEqual(d['a'], 1) self.assertEqual(d.setdefault('c', 3), 3) self.assertEqual(d['c'], 3) self.assertEqual(len(d), 3) def test_dict_contains(self): d = nnx.Dict({'a': 1, 'b': 2}) self.assertIn('a', d) self.assertIn('b', d) self.assertNotIn('c', d) d['c'] = 3 self.assertIn('c', d) del d['a'] self.assertNotIn('a', d) def test_list_mutable_sequence(self): l = nnx.List([1, 2, 3]) self.assertEqual(len(l), 3) self.assertEqual(l[0], 1) self.assertEqual(l[1], 2) self.assertEqual(l[2], 3) l.append(4) self.assertEqual(len(l), 4) self.assertEqual(l[3], 4) l.insert(1, 5) self.assertEqual(len(l), 5) self.assertEqual(l[0], 1) self.assertEqual(l[1], 5) self.assertEqual(l[2], 2) self.assertEqual(l[3], 3) self.assertEqual(l[4], 4) del l[2] self.assertEqual(len(l), 4) self.assertEqual(l[0], 1) self.assertEqual(l[1], 5) self.assertEqual(l[2], 3) self.assertEqual(l[3], 4) l[1:3] = [6, 7] self.assertEqual(l[1], 6) self.assertEqual(l[2], 7) self.assertEqual(l[1:3], [6, 7]) def test_list_fori_loop(self): class Foo(nnx.Module): def __init__(self): self.layers = nnx.List([ nnx.Linear(1, 1, rngs=nnx.Rngs(0)), nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ]) def batch_loop_body(i, carry): return carry net = Foo() jax.lax.fori_loop(0, 2, batch_loop_body, net) def test_list_pytree_default_behavior(self): ls = nnx.List([jnp.array(1), jnp.array(2), jnp.array(3)]) leaves = jax.tree_util.tree_leaves(ls) self.assertLen(leaves, 3) np.testing.assert_array_equal(leaves[0], jnp.array(1)) np.testing.assert_array_equal(leaves[1], jnp.array(2)) np.testing.assert_array_equal(leaves[2], jnp.array(3)) def test_list_pytree_static_elements(self): ls = nnx.List([nnx.static(10), nnx.static(20), nnx.static(30)]) leaves = jax.tree_util.tree_leaves(ls) self.assertEmpty(leaves) def test_list_pytree_data_elements(self): ls = nnx.List([nnx.data(1), nnx.data(2), nnx.data(3)]) leaves = jax.tree_util.tree_leaves(ls) self.assertLen(leaves, 3) self.assertEqual(leaves[0], 1) self.assertEqual(leaves[1], 2) self.assertEqual(leaves[2], 3) def test_list_pytree_mixed_static_data(self): ls = nnx.List([ nnx.data(jnp.array(1)), nnx.static(100), nnx.data(jnp.array(2)), nnx.static(200), ]) leaves = jax.tree_util.tree_leaves(ls) self.assertLen(leaves, 2) np.testing.assert_array_equal(leaves[0], jnp.array(1)) np.testing.assert_array_equal(leaves[1], jnp.array(2)) def test_list_pytree_flatten_unflatten(self): ls = nnx.List([nnx.data(10), nnx.static('hello'), nnx.data(20)]) leaves, treedef = jax.tree_util.tree_flatten(ls) self.assertLen(leaves, 2) self.assertEqual(leaves[0], 10) self.assertEqual(leaves[1], 20) new_leaves = [x * 2 for x in leaves] new_ls = jax.tree_util.tree_unflatten(treedef, new_leaves) self.assertEqual(new_ls[0], 20) self.assertEqual(new_ls[1], 'hello') self.assertEqual(new_ls[2], 40) def test_list_pytree_jit(self): ls = nnx.List([nnx.data(jnp.array(1.0)), nnx.static(999)]) @jax.jit def double(ls): return jax.tree.map(lambda x: x * 2, ls) result = double(ls) np.testing.assert_array_equal(result[0], jnp.array(2.0)) self.assertEqual(result[1], 999) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/ids_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy from absl.testing import absltest from flax.nnx import ids class TestIds(absltest.TestCase): def test_hashable(self): id1 = ids.uuid() id2 = ids.uuid() assert id1 == id1 assert id1 != id2 assert hash(id1) != hash(id2) id1c = copy.copy(id1) id1dc = copy.deepcopy(id1) assert hash(id1) != hash(id1c) assert hash(id1) != hash(id1dc) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/integration_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import tempfile import typing as tp from absl.testing import absltest from absl.testing import parameterized import jax import jax.numpy as jnp import numpy as np import orbax.checkpoint as ocp import optax from flax import nnx A = tp.TypeVar('A') class TestIntegration(parameterized.TestCase): @parameterized.parameters(True, False) def test_basic_view_example(self, graph_mode): class Model(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dmid, rngs=rngs) self.bn = nnx.BatchNorm(dmid, rngs=rngs) self.dropout = nnx.Dropout(0.2, rngs=rngs) self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x): x = nnx.relu(self.dropout(self.bn(self.linear(x)))) return self.linear_out(x) model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization train_model = nnx.view( model, deterministic=False, use_running_average=False ) eval_model = nnx.view( model, deterministic=True, use_running_average=True ) optimizer = nnx.Optimizer(train_model, optax.adam(1e-3), wrt=nnx.Param) self.assertEqual(train_model.dropout.deterministic, False) self.assertEqual(train_model.bn.use_running_average, False) self.assertEqual(eval_model.dropout.deterministic, True) self.assertEqual(eval_model.bn.use_running_average, True) self.assertIs(train_model.dropout.rngs.count, eval_model.dropout.rngs.count) @nnx.jit(graph=graph_mode) # automatic state management for JAX transforms def train_step(model, optimizer, x, y): def loss_fn(model): y_pred = model(x) return jnp.mean((y_pred - y) ** 2) loss, grads = nnx.value_and_grad(loss_fn)(model) optimizer.update(model, grads) # in-place updates return loss @nnx.jit(graph=graph_mode) def eval_step(model, x, y): y_pred = model(x) return jnp.mean((y_pred - y) ** 2) x = jax.random.normal(jax.random.key(0), (8, 2)) y = jax.random.normal(jax.random.key(1), (8, 3)) train_step(train_model, optimizer, x, y) self.assertEqual(train_model.dropout.rngs.count[...], 1) eval_step(eval_model, x, y) self.assertEqual(train_model.dropout.rngs.count[...], 1) def test_shared_modules(self): class Block(nnx.Module): def __init__(self, linear: nnx.Linear, *, rngs): self.linear = linear self.bn = nnx.BatchNorm(2, rngs=rngs) def __call__(self, x): x = self.linear(x) x = self.bn(x) return nnx.relu(x) class Model(nnx.Module): def __init__(self, *, rngs): shared = nnx.Linear(2, 2, rngs=rngs) self.block1 = Block(shared, rngs=rngs) self.block2 = Block(shared, rngs=rngs) def __call__(self, x): x = self.block1(x) x = self.block2(x) return x @nnx.jit def train_step(model: Model, x, y): @nnx.grad def loss_fn(model: Model): y_pred = model(x) return jnp.mean((y - y_pred) ** 2) grads = loss_fn(model) nnx.update( model, jax.tree.map( lambda w, g: w - 0.1 * g, nnx.state(model, nnx.Param), grads ), ) model = Model(rngs=nnx.Rngs(0)) x = np.random.uniform(size=(4, 2)) y = np.random.uniform(size=(4, 2)) model.set_attributes(use_running_average=False, graph=True) for _i in range(3): train_step(model, x, y) assert model.block1.linear is model.block2.linear assert model.block1.linear.bias is not None assert model.block1.bn is not model.block2.bn def test_shared_modules_view(self): class Block(nnx.Module): def __init__(self, linear: nnx.Linear, *, rngs): self.linear = linear self.bn = nnx.BatchNorm(2, rngs=rngs) def __call__(self, x): x = self.linear(x) x = self.bn(x) return nnx.relu(x) class Model(nnx.Module): def __init__(self, *, rngs): shared = nnx.Linear(2, 2, rngs=rngs) self.block1 = Block(shared, rngs=rngs) self.block2 = Block(shared, rngs=rngs) def __call__(self, x): x = self.block1(x) x = self.block2(x) return x @nnx.jit def train_step(model: Model, x, y): @nnx.grad def loss_fn(model: Model): y_pred = model(x) return jnp.mean((y - y_pred) ** 2) grads = loss_fn(model) nnx.update( model, jax.tree.map( lambda w, g: w - 0.1 * g, nnx.state(model, nnx.Param), grads ), ) model = Model(rngs=nnx.Rngs(0)) x = np.random.uniform(size=(4, 2)) y = np.random.uniform(size=(4, 2)) new_model = nnx.view(model, use_running_average=False) for _i in range(3): train_step(model, x, y) assert new_model.block1.linear is new_model.block2.linear assert new_model.block1.linear.bias is not None assert new_model.block1.bn is not new_model.block2.bn def test_shared_modules_pure(self): class Block(nnx.Module): def __init__(self, linear: nnx.Linear, *, rngs: nnx.Rngs): self.linear = linear self.bn = nnx.BatchNorm(2, rngs=rngs) def __call__(self, x): x = self.linear(x) x = self.bn(x) return nnx.relu(x) class Model(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): shared = nnx.Linear(2, 2, rngs=rngs) self.block1 = Block(shared, rngs=rngs) self.block2 = Block(shared, rngs=rngs) def __call__(self, x): x = self.block1(x) x = self.block2(x) return x @jax.jit def train_step(state: nnx.State, graphdef: nnx.GraphDef[Model], x, y): model = nnx.merge(graphdef, state) model.set_attributes(use_running_average=False, graph=True) @nnx.grad def loss_fn(model: Model): y_pred = model(x) return jnp.mean((y - y_pred) ** 2) grads = loss_fn(model) nnx.update( model, jax.tree.map( lambda w, g: w - 0.1 * g, nnx.state(model, nnx.Param), grads ), ) return nnx.split(model) graphdef: nnx.GraphDef[Model] graphdef, state = nnx.split(Model(rngs=nnx.Rngs(0))) x = np.random.uniform(size=(4, 2)) y = np.random.uniform(size=(4, 2)) for _i in range(3): graphdef, state = train_step(state, graphdef, x, y) model = nnx.merge(graphdef, state) assert model.block1.linear.bias is not None assert model.block2.linear.bias is not None assert model.block1.linear.kernel is model.block2.linear.kernel assert model.block1.linear.bias is model.block2.linear.bias assert model.block1.bn is not model.block2.bn def test_shared_modules_pure_view(self): class Block(nnx.Module): def __init__(self, linear: nnx.Linear, *, rngs: nnx.Rngs): self.linear = linear self.bn = nnx.BatchNorm(2, rngs=rngs) def __call__(self, x): x = self.linear(x) x = self.bn(x) return nnx.relu(x) class Model(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): shared = nnx.Linear(2, 2, rngs=rngs) self.block1 = Block(shared, rngs=rngs) self.block2 = Block(shared, rngs=rngs) def __call__(self, x): x = self.block1(x) x = self.block2(x) return x @jax.jit def train_step(state: nnx.State, graphdef: nnx.GraphDef[Model], x, y): model = nnx.merge(graphdef, state) new_model = nnx.view(model, use_running_average=False, graph=True) @nnx.grad def loss_fn(model: Model): y_pred = model(x) return jnp.mean((y - y_pred) ** 2) grads = loss_fn(new_model) nnx.update( new_model, jax.tree.map( lambda w, g: w - 0.1 * g, nnx.state(new_model, nnx.Param), grads ), ) return nnx.split(new_model) graphdef: nnx.GraphDef[Model] graphdef, state = nnx.split(Model(rngs=nnx.Rngs(0))) x = np.random.uniform(size=(4, 2)) y = np.random.uniform(size=(4, 2)) for _ in range(3): graphdef, state = train_step(state, graphdef, x, y) model = nnx.merge(graphdef, state) assert model.block1.linear.bias is not None assert model.block2.linear.bias is not None assert model.block1.linear.kernel is model.block2.linear.kernel assert model.block1.linear.bias is model.block2.linear.bias assert model.block1.bn is not model.block2.bn @parameterized.parameters(True, False) def test_stateful_example(self, graph_mode): class State(nnx.Variable[A]): pass class Linear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): key = rngs.params() self.w = nnx.Param(jax.random.uniform(key, (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) self.count = State(jnp.array(0)) def __call__(self, x): self.count[...] += 1 return x @ self.w + self.b[None] model = Linear(din=12, dout=2, rngs=nnx.Rngs(0)) # forward pass x = jnp.ones((8, 12)) y = model(x) assert model.count[...] == 1 @nnx.jit(graph=graph_mode) def train_step(model, x, y): def loss_fn(model): y_pred = model(x) return jax.numpy.mean((y_pred - y) ** 2) # compute gradient grads: nnx.State = nnx.grad(loss_fn)(model) # SGD update nnx.update( model, jax.tree.map( lambda w, g: w - 0.1 * g, nnx.state(model, nnx.Param), grads ), ) # execute the training step train_step(model, x, y) assert model.count[...] == 2 def test_functional_example(self): class Count(nnx.Variable[A]): pass class Linear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): key = rngs.params() self.w = nnx.Param(jax.random.uniform(key, (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) self.count = Count(jnp.array(0)) def __call__(self, x): self.count[...] += 1 return x @ self.w + self.b[None] model = Linear(din=12, dout=2, rngs=nnx.Rngs(0)) # forward pass x = jnp.ones((8, 12)) y = model(x) assert model.count[...] == 1 graphdef, params, counts = nnx.split(model, nnx.Param, Count) @jax.jit def train_step(params, counts, x, y): def loss_fn(params): model = nnx.merge(graphdef, params, counts, copy=True) loss = jax.numpy.mean((model(x) - y) ** 2) return loss, nnx.state(model, Count) # compute gradient grads, counts = jax.grad(loss_fn, has_aux=True)(params) # SGD update params = jax.tree.map(lambda w, g: w - 0.1 * g, params, grads) return params, counts # execute the training step params, counts = train_step(params, counts, x, y) model = nnx.merge(graphdef, params, counts) assert model.count[...] == 2 def test_intermediates_example(self): class Linear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): key = rngs.params() self.w = nnx.Param(jax.random.uniform(key, (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): y = x @ self.w + self.b[None] self.y = nnx.Intermediate(y) return y model = Linear(12, 2, rngs=nnx.Rngs(0)) y = model(jnp.ones((8, 12))) intermediates = nnx.pop(model, nnx.Intermediate) assert 'y' in intermediates def test_intermediates_example_functional(self): class Linear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): key = rngs.params() self.w = nnx.Param(jax.random.uniform(key, (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): y = x @ self.w + self.b[None] self.y = nnx.Intermediate(y) return y model = Linear(12, 2, rngs=nnx.Rngs(0)) graphdef, state = nnx.split(model) y, (_, state) = graphdef.apply(state)(jnp.ones((8, 12))) intermediates, state = nnx.split_state(state, nnx.Intermediate, ...) assert 'y' in intermediates def test_replace_by_pure_dict(self): class MLPs(nnx.Module): def __init__(self, dim, rngs: nnx.Rngs): self.layers = nnx.List() for _ in range(4): self.layers.append(nnx.Linear(dim, dim, rngs=rngs, use_bias=False)) def __call__(self, x): for layer in self.layers: x = layer(x) return x model = MLPs(4, rngs=nnx.Rngs(0)) x = jax.random.normal(jax.random.key(42), (3, 4)) assert model(x).shape == (3, 4) _, state = nnx.split(model) pure_dict_state = nnx.to_pure_dict(state) nnx.display(pure_dict_state) with tempfile.TemporaryDirectory() as tmpdir: ckpt_dir = ocp.test_utils.erase_and_create_empty( tmpdir + '/my-checkpoints/' ) checkpointer = ocp.StandardCheckpointer() # checkpointer.save(ckpt_dir / 'state', state) checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state) # Restore as a pure dictionary. restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict') restored_pure_dict = nnx.statelib.restore_int_paths(restored_pure_dict) model = nnx.eval_shape(lambda: MLPs(4, rngs=nnx.Rngs(0))) nnx.update(model, restored_pure_dict) assert model(x).shape == (3, 4) # The model still works! @nnx.var_defaults(hijax=True) def test_example_mutable_arrays(self): class Model(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dmid, rngs=rngs) self.bn = nnx.BatchNorm(dmid, rngs=rngs) self.dropout = nnx.Dropout(0.2, rngs=rngs) self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x): x = nnx.relu(self.dropout(self.bn(self.linear(x)))) return self.linear_out(x) model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) @jax.jit # automatic state management for JAX transforms def train_step(x, y): graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) def loss_fn(params): model = nnx.merge(graphdef, params, nondiff) return ((model(x) - y) ** 2).mean() # call methods directly loss, grads = jax.value_and_grad(loss_fn)( nnx.vars_as(params, hijax=False) ) optimizer.update(model, grads) # in-place updates return loss x = jax.random.normal(jax.random.key(0), (8, 2)) y = jax.random.normal(jax.random.key(1), (8, 3)) train_step(x, y) def test_tree_mode_train_step(self): model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) @nnx.jit(graph=False) def train_step(model, optimizer, x, y): def loss_fn(model): y_pred = model(x) return jnp.mean((y_pred - y) ** 2) loss, grads = nnx.value_and_grad(loss_fn, graph=False)(model) optimizer.update(model, grads) return loss x = jax.random.normal(jax.random.key(0), (4, 2)) y = jax.random.normal(jax.random.key(1), (4, 3)) loss0 = train_step(model, optimizer, x, y) loss1 = train_step(model, optimizer, x, y) self.assertLess(loss1, loss0) def test_tree_mode_multi_module(self): class Block(nnx.Module): def __init__(self, din, dout, *, rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.bn = nnx.BatchNorm(dout, rngs=rngs) def __call__(self, x): x = self.linear(x) x = self.bn(x) return nnx.relu(x) class Model(nnx.Module): def __init__(self, *, rngs): self.block1 = Block(2, 2, rngs=rngs) self.block2 = Block(2, 2, rngs=rngs) def __call__(self, x): x = self.block1(x) x = self.block2(x) return x @nnx.jit(graph=False) def train_step(model: Model, x, y): @nnx.grad(graph=False) def loss_fn(model: Model): y_pred = model(x) return jnp.mean((y - y_pred) ** 2) grads = loss_fn(model) model = jax.tree.map( lambda w, g: w - 0.1 * g, model, grads, ) return model model = Model(rngs=nnx.Rngs(0)) x = np.random.uniform(size=(4, 2)) y = np.random.uniform(size=(4, 2)) model.set_attributes(use_running_average=False) for _i in range(3): model = train_step(model, x, y) def test_tree_mode_stateful(self): class State(nnx.Variable[A]): pass class Linear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): key = rngs.params() self.w = nnx.Param(jax.random.uniform(key, (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) self.count = State(jnp.array(0)) def __call__(self, x): self.count[...] = self.count[...] + 1 return x @ self.w + self.b[None] model = Linear(din=12, dout=2, rngs=nnx.Rngs(0)) x = jnp.ones((8, 12)) y = model(x) assert model.count[...] == 1 @nnx.jit(graph=False) def train_step(model, x, y): def loss_fn(model): y_pred = model(x) return jax.numpy.mean((y_pred - y) ** 2) grads = nnx.grad(loss_fn, graph=False, allow_int=True)(model) params = jax.tree.map( lambda w, g: w - 0.1 * g, nnx.state(model, nnx.Param), nnx.state(grads, nnx.Param), ) nnx.update(model, params) train_step(model, x, y) assert model.count[...] == 2 if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/metrics_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from absl.testing import absltest from absl.testing import parameterized from flax import nnx import jax import jax.numpy as jnp import numpy as np class TestMetrics(parameterized.TestCase): def test_split_merge(self): logits = jnp.array( [[-1.0, 1.0], [-1.0, 1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, -1.0]] ) labels = jnp.array([1, 1, 1, 1, 1]) logits2 = jnp.array( [[-1.0, 1.0], [-1.0, 1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, -1.0]] ) labels2 = jnp.array([1, 1, 1, 1, 0]) accuracy = nnx.metrics.Accuracy() accuracy.update(logits=logits, labels=labels) graphdef, state = accuracy.split() accuracy = nnx.merge(graphdef, state) self.assertEqual(accuracy.compute(), 0.6) accuracy.update(logits=logits2, labels=labels2) self.assertEqual(accuracy.compute(), 0.7) def test_welford(self): values = jax.random.normal(jax.random.key(0), (5, 2)) welford = nnx.metrics.Welford() welford.update(values=values) graphdef, state = welford.split() welford = nnx.merge(graphdef, state) expected = nnx.metrics.Statistics( mean=values.mean(), standard_deviation=values.std(), standard_error_of_mean=values.std() / jnp.sqrt(values.size), ) computed = welford.compute() self.assertAlmostEqual(computed.mean, expected.mean, ) self.assertAlmostEqual(computed.standard_deviation, expected.standard_deviation) self.assertAlmostEqual( computed.standard_error_of_mean, expected.standard_error_of_mean ) def test_welford_large(self): values = jax.random.normal(jax.random.key(0), (5, 2)) + 1e16 welford = nnx.metrics.Welford() welford.update(values=values) graphdef, state = welford.split() welford = nnx.merge(graphdef, state) expected = nnx.metrics.Statistics( mean=values.mean(), standard_deviation=values.std(), standard_error_of_mean=values.std() / jnp.sqrt(values.size), ) computed = welford.compute() self.assertAlmostEqual(computed.mean, expected.mean) self.assertAlmostEqual(computed.standard_deviation, expected.standard_deviation) self.assertAlmostEqual( computed.standard_error_of_mean, expected.standard_error_of_mean ) def test_welford_many(self): values = jax.random.normal(jax.random.key(0), (50_000,)) welford = nnx.metrics.Welford() welford.update(values=values) graphdef, state = welford.split() welford = nnx.merge(graphdef, state) computed = welford.compute() self.assertAlmostEqual( computed.mean, 0.0, delta=3 * computed.standard_error_of_mean ) self.assertAlmostEqual(computed.standard_deviation, 1.0, places=2) @parameterized.product(with_mask=[True, False]) def test_multimetric(self, with_mask): logits = jnp.array( [[-1.0, 1.0], [-1.0, 1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, -1.0]] ) labels = jnp.array([1, 1, 0, 1, 0]) mask = labels > 0 if with_mask else None logits2 = jnp.array( [[-1.0, 1.0], [-1.0, 1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, -1.0]] ) labels2 = jnp.array([0, 1, 1, 1, 1]) mask2 = labels2 > 0 if with_mask else None batch_loss = jnp.array([1, 2, 3, 4]) batch_loss2 = jnp.array([3, 2, 1, 0]) loss_mask2 = batch_loss2 > 1 if with_mask else None metrics = nnx.MultiMetric( accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average() ) values = metrics.compute() self.assertTrue(jnp.isnan(values['accuracy'])) self.assertTrue(jnp.isnan(values['loss'])) metrics.update(logits=logits, labels=labels, values=batch_loss, mask={"accuracy": mask}) values = metrics.compute() self.assertEqual(values['accuracy'], 0.6 if not with_mask else 2 / 3) self.assertEqual(values['loss'], 2.5) metrics.update(logits=logits2, labels=labels2, values=batch_loss2, mask={"accuracy": mask2, "loss": loss_mask2}) values = metrics.compute() self.assertEqual(values['accuracy'], 0.5 if not with_mask else (2 + 2) / (3 + 4)) self.assertEqual(values['loss'], 2.0 if not with_mask else 15 / 6) metrics.reset() values = metrics.compute() self.assertTrue(jnp.isnan(values['accuracy'])) self.assertTrue(jnp.isnan(values['loss'])) @parameterized.product(with_mask=[True, False]) def test_multimetric_with_custom_metric(self, with_mask): class CustomAccuracy(nnx.metrics.Accuracy): # we use unused values arg is on purpose instead of using **kwargs # to reproduce the error if mask arg is injected to the custom metric without mask arg def update(self, y_preds, y_true, values): super().update(logits=y_preds, labels=y_true) logits = jnp.array( [[-1.0, 1.0], [-1.0, 1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, -1.0]] ) labels = jnp.array([1, 1, 0, 1, 0]) batch_loss = jnp.array([1, 2, 3, 4]) mask = batch_loss > 1 if with_mask else None metrics = nnx.MultiMetric( accuracy=CustomAccuracy(), loss=nnx.metrics.Average() ) metrics.update(y_preds=logits, y_true=labels, values=batch_loss, mask={"loss": mask}) values = metrics.compute() self.assertEqual(values['accuracy'], 0.6) self.assertEqual(values['loss'], 3.0 if with_mask else 2.5) def test_binary_classification_accuracy(self): logits = jnp.array([0.4, 0.7, 0.2, 0.6]) labels = jnp.array([0, 1, 1, 1]) logits2 = jnp.array([0.1, 0.9, 0.8, 0.3]) labels2 = jnp.array([0, 1, 1, 0]) accuracy = nnx.metrics.Accuracy(threshold=0.5) accuracy.update(logits=logits, labels=labels) self.assertEqual(accuracy.compute(), 0.75) accuracy.update(logits=logits2, labels=labels2) self.assertEqual(accuracy.compute(), 0.875) @parameterized.parameters( { 'logits': np.array([[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]), 'labels': np.array([0, 0, 0, 0]), 'threshold': None, 'error_msg': 'For multi-class classification' }, { 'logits': np.array([0.0, 0.0, 0.0, 0.0]), 'labels': np.array([[0, 0], [0, 0]]), 'threshold': 0.5, 'error_msg': 'For binary classification' } ) def test_accuracy_dims(self, logits, labels, threshold, error_msg): accuracy = nnx.metrics.Accuracy(threshold=threshold) with self.assertRaisesRegex(ValueError, error_msg): accuracy.update(logits=logits, labels=labels) @parameterized.product( with_mask=[True, False], scalar_values=[True, False], ) def test_average(self, with_mask, scalar_values): average = nnx.metrics.Average() value1 = 2 if scalar_values else jnp.arange(5 * 3, dtype=jnp.int32).reshape(5, 3) value2 = 3 if scalar_values else jnp.arange(5 * 3 * 4, dtype=jnp.float32).reshape(5, 3, 4) if with_mask: list_masks = [ jnp.ones(5) if scalar_values else jnp.where(value1 > 4, 1.0, 0.0), jnp.ones(5) if scalar_values else value2 > 10, ] else: list_masks = [None, None] list_values = [value1, value2] for mask, values in zip(list_masks, list_values): if with_mask and scalar_values: with self.assertRaisesRegex(ValueError, "should be a jax array"): average.update(mask=mask, values=values) else: average.update(mask=mask, values=values) if with_mask and scalar_values: return self.assertEqual( average.count, ( len(list_values) if scalar_values else sum(m.sum() for m in list_masks) if with_mask else sum(v.size for v in list_values) ) ) self.assertEqual( average.total, ( sum(list_values) if scalar_values else sum((v * m).sum() for m, v in zip(list_masks, list_values)) if with_mask else sum(v.sum() for v in list_values) ) ) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/module_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy import dataclasses import pickle import tempfile from typing import TypeVar from absl.testing import absltest from absl.testing import parameterized import cloudpickle from flax import errors, nnx import jax import jax.numpy as jnp import numpy as np import flax A = TypeVar('A') from contextlib import contextmanager @contextmanager def set_graph_mode(mode): old_mode = flax.config._read('nnx_graph_mode') try: flax.config.update('nnx_graph_mode', mode) yield finally: flax.config.update('nnx_graph_mode', old_mode) class PytreeTest(absltest.TestCase): def test_pytree(self): class Foo(nnx.Pytree): def __init__(self, a, b): self.a = nnx.data(a) self.b = nnx.static(b) foo = Foo(a=1, b=2) self.assertEqual(jax.tree.leaves(foo), [1]) def test_sequential_map(self): model = nnx.Sequential(nnx.Linear(2,8, rngs=nnx.Rngs(0))) jax.tree.map(lambda x: x + 1, model) # shouldn't error def test_sequential_has_leaves(self): model = nnx.Sequential(nnx.Linear(2,8, rngs=nnx.Rngs(0))) self.assertLen(jax.tree.leaves(model), 2) def test_consistent_attrs(self): class Foo(nnx.Pytree): def __init__(self, a, b, c): self.a = nnx.data(a) self.b = nnx.static(b) self.c = c foo = Foo(a=1, b=2, c=jnp.array(3)) self.assertLen(jax.tree.leaves(foo), 2) foo.a = 3 self.assertLen(jax.tree.leaves(foo), 2) foo.a = nnx.static(3) self.assertLen(jax.tree.leaves(foo), 1) foo.b = 4 # ok self.assertLen(jax.tree.leaves(foo), 1) foo.b = nnx.data(4) self.assertLen(jax.tree.leaves(foo), 2) foo.c = jnp.array(5) # ok self.assertLen(jax.tree.leaves(foo), 2) foo.c = nnx.static(5) self.assertLen(jax.tree.leaves(foo), 1) with self.assertRaisesRegex( ValueError, 'Found data on value of type', ): foo.a = ['hi', jnp.array(6)] with self.assertRaisesRegex( ValueError, 'Found data in value of type', ): foo.b = nnx.static(jnp.array(4)) def test_assing_pytree_with_data(self): class Foo(nnx.Pytree): pass foo = Foo() with self.assertRaisesRegex( ValueError, 'Found data on value of type', ): foo.a = [nnx.Variable(1)] def test_consistent_attrs_frozen_dataclass(self): @nnx.dataclass class Foo(nnx.Pytree): a: int = nnx.data() b: int = nnx.static() c: jax.Array foo = Foo(a=1, b=2, c=jnp.array(3)) self.assertLen(jax.tree.leaves(foo), 2) def test_consistent_attrs_dataclass_annotations(self): @dataclasses.dataclass class Foo(nnx.Pytree): a: nnx.Data[int] b: nnx.Static[int] c: jax.Array foo = Foo(a=1, b=2, c=jnp.array(3)) self.assertLen(jax.tree.leaves(foo), 2) foo.a = 3 self.assertLen(jax.tree.leaves(foo), 2) foo.a = nnx.static(3) self.assertLen(jax.tree.leaves(foo), 1) foo.b = 4 # ok self.assertLen(jax.tree.leaves(foo), 1) foo.b = nnx.data(4) self.assertLen(jax.tree.leaves(foo), 2) foo.c = jnp.array(5) # ok self.assertLen(jax.tree.leaves(foo), 2) foo.c = nnx.static(5) self.assertLen(jax.tree.leaves(foo), 1) with self.assertRaisesRegex( ValueError, 'Found data on value of type', ): foo.a = ['hi', jnp.array(6)] with self.assertRaisesRegex( ValueError, 'Found data in value of type', ): foo.b = nnx.static(jnp.array(4)) def test_explicit_dont_change(self): class Foo(nnx.Pytree): def __init__(self): self.b = nnx.data(2) foo = Foo() self.assertEqual(jax.tree.leaves(foo), [2]) foo.b = "hello" self.assertEqual(jax.tree.leaves(foo), ["hello"]) def test_no_data_in_static(self): class Foo(nnx.Pytree): def __init__(self): self.a = nnx.static(jnp.array(1)) with self.assertRaisesRegex( ValueError, 'Found data in value of type', ): foo = Foo() class TestCapture(parameterized.TestCase): def test_vmap(self): class Foo(nnx.Module): def __init__(self, dim): self.w = nnx.Param(jax.random.normal(jax.random.key(0), dim)) def __call__(self, x): x = self.perturb('grad_of_x', x) y = jnp.dot(x, self.w) self.sow(nnx.Intermediate, 'y', y) return y def pre_run(self, x): graphdef, intms, params = nnx.split(model, nnx.Intermediate, nnx.Param) def run(intms, params, x): return nnx.merge(graphdef, intms, params)(x) nnx.vmap(run, in_axes=(0, None, 0))(intms, params, x) @nnx.jit def train_step(model, perturbations, x): def loss_grad(model, perturbations, x): def loss(model, perturbations, x): loss, interms = nnx.capture(model, nnx.Intermediate, init=perturbations)(x) return loss, interms (grads, perturb_grads), sowed = nnx.grad(loss, argnums=(0, 1), has_aux=True)(model, perturbations, x) return grads, nnx.merge_state(perturb_grads, sowed) return nnx.vmap(loss_grad, in_axes=(None, 0, 0))(model, perturbations,x) model, x = Foo(4), jnp.ones((3, 4)) pre_run_capture = nnx.capture(model.pre_run, nnx.Perturbation) _, perturbations = pre_run_capture(x) _, intermediates = train_step(model, perturbations, x) np.testing.assert_allclose(intermediates['grad_of_x'].get_value(), jnp.broadcast_to(model.w.get_value()[None, :], (3, 4))) self.assertEqual(intermediates['y'].get_value()[0].shape, (3,)) @parameterized.parameters(True, False) def test_fwd_bwd(self, graph_mode): with set_graph_mode(graph_mode): class Foo(nnx.Module): @nnx.jit def __call__(self, x): x = self.perturb('grad_of_x', x) y = 3 * x self.sow(nnx.Intermediate, 'y', y) return y model = Foo() @nnx.jit def train_step(model, perturbations, x): def loss(model, perturbations, x): loss, interms = nnx.capture(model, nnx.Intermediate, init=perturbations)(x) return loss, interms (grads, perturb_grads), sowed = nnx.grad(loss, argnums=(0, 1), has_aux=True)(model, perturbations, x) return grads, nnx.merge_state(perturb_grads, sowed) x = 1.0 forward = nnx.capture(model, nnx.Perturbation) _, perturbations = forward(x) grads, intermediates = train_step(model, perturbations, x) self.assertEqual(intermediates['grad_of_x'], 3) self.assertEqual(intermediates['y'][0], 3) @parameterized.parameters(True, False) def test_nested_modules(self, graph_mode): with set_graph_mode(graph_mode): class Foo(nnx.Module): def __call__(self, x): x = self.perturb('grad_of_x', x) y = 3 * x self.sow(nnx.Intermediate, 'y', y) return y class Bar(nnx.Module): def __init__(self): self.foos = nnx.data([Foo() for _ in range(3)]) def __call__(self, x): for block in self.foos: x = block(x) return x model = Bar() @nnx.jit def train_step(model, perturbations, x): def loss(model, perturbations, x): loss, interms = nnx.capture(model, nnx.Intermediate, init=perturbations)(x) return loss, interms (grads, perturb_grads), sowed = nnx.grad(loss, argnums=(0, 1), has_aux=True)(model, perturbations, x) return grads, nnx.merge_state(perturb_grads, sowed) x = 1.0 forward = nnx.capture(model, nnx.Perturbation) _, perturbations = forward(x) _, intermediates = train_step(model, perturbations, x) for i in range(3): self.assertEqual(intermediates['foos'][i]['grad_of_x'], 3**(3-i)) self.assertEqual(intermediates['foos'][i]['y'][0], 3**(i+1)) def test_method_outputs_single_module(self): class Foo(nnx.Module): def __init__(self, dim): self.w = nnx.Param(jax.random.normal(jax.random.key(0), (dim, dim))) def __call__(self, x): return x @ self.w def helper(self, x): return jnp.sin(x) def run(self, x): y = self(x) z = self.helper(y) return (y, z) model = Foo(8) x = jnp.ones((4, 8)) run_with_capture = nnx.capture( model.run, nnx.Intermediate, method_outputs=nnx.Intermediate ) (y, z), intms = run_with_capture(x) self.assertIn('__call__', intms) self.assertIn('helper', intms) np.testing.assert_allclose(intms['__call__'][0], y) np.testing.assert_allclose(intms['helper'][0], z) def test_method_outputs_nested_modules(self): class Inner(nnx.Module): def __init__(self, dim, rngs): self.w = nnx.Param(jax.random.normal(rngs.params(), (dim, dim))) def __call__(self, x): return x @ self.w def process(self, x): return jnp.sin(x) class Outer(nnx.Module): def __init__(self, rngs): self.inner1 = Inner(8, rngs) self.inner2 = Inner(8, rngs) def __call__(self, x): x = self.inner1(x) x = self.inner2.process(x) return x model = Outer(nnx.Rngs(0)) x = jnp.ones((4, 8)) forward = nnx.capture( model, nnx.Intermediate, method_outputs=nnx.Intermediate ) y, intms = forward(x) self.assertIn('__call__', intms) self.assertIn('inner1', intms) self.assertIn('process', intms['inner2']) self.assertEqual(intms['inner1']['__call__'][0].shape, (4, 8)) self.assertEqual(intms['inner2']['process'][0].shape, (4, 8)) def test_method_outputs_mixed_with_sow(self): class Foo(nnx.Module): def __init__(self, dim): self.w = nnx.Param(jax.random.normal(jax.random.key(0), (dim, dim))) def __call__(self, x): x = x @ self.w self.sow(nnx.Intermediate, 'intermediate', x) return jnp.sin(x) model = Foo(8) x = jnp.ones((4, 8)) forward = nnx.capture( model, nnx.Intermediate, method_outputs=nnx.Intermediate ) y, intms = forward(x) self.assertIn('__call__', intms) self.assertIn('intermediate', intms) np.testing.assert_allclose(intms['__call__'][0], y) np.testing.assert_allclose(jnp.sin(intms['intermediate'][0]), y) class SowMod(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(4, 4, rngs=rngs) def __call__(self, x): y = self.linear(x) self.sow(nnx.Intermediate, "my_summary", y.mean()) return y * 2 class TestModule(parameterized.TestCase): def test_has_module_state(self): class Foo(nnx.Module): ... foo = Foo() assert hasattr(foo, '_pytree__state') def test_trace_level(self): m = nnx.Dict(a=nnx.Param(1)) @jax.jit def f(): with self.assertRaisesRegex( errors.TraceContextError, "Cannot mutate 'Dict' from different trace level", ): m.a = 2 f() def test_tree_map(self): m = nnx.Dict(a=nnx.Param(1)) graphdef, state = nnx.split(m) state = jax.tree.map(lambda x: x + 1, state) def test_split_2(self): m = nnx.Dict(a=nnx.Param(1)) graphdef, empty, some = nnx.split(m, None, ...) some = jax.tree.map(lambda x: x + 1, some) def test_split_merge(self): m = nnx.Dict(a=nnx.Param(1)) @jax.jit def g(graphdef: nnx.GraphDef[nnx.Dict], state: nnx.State): m = nnx.merge(graphdef, state) m.a = 2 return nnx.split(m) graphdef, state = g(*nnx.split(m)) m2 = nnx.merge(graphdef, state) assert m2.a == 2 def test_call(self): class Foo(nnx.Module): def __init__(self, c: float, *, rngs: nnx.Rngs): key = rngs.params() self.w = nnx.Param(jax.random.uniform(key, ())) self.c = c def __call__(self, x, *, rngs: nnx.Rngs): return self.w * x + rngs.e.normal(()) + self.c foo = Foo(c=1.0, rngs=nnx.Rngs(0)) y = foo(x=2.0, rngs=nnx.Rngs(e=1)) assert isinstance(y, jax.Array) def test_shared_module(self): m1 = nnx.Dict(a=nnx.Param(1), b=nnx.Param(2)) m2 = nnx.Dict(x=m1, y=m1, z=nnx.Param(3)) m3 = nnx.merge(*nnx.split(m2)) assert m3['x'] is m3['y'] assert m3['x']['a'] is m3['y']['a'] assert m3['x']['b'] is m3['y']['b'] def test_module_graph(self): class Foo(nnx.Module): def __init__(self): self.a = nnx.Param(1) self.sub = self m = Foo() graphdef, state = nnx.split(m) assert len(state) == 1 m2 = nnx.merge(graphdef, state) assert m2 is m2.sub def test_deref_through_jit(self): r1 = nnx.Variable(1) r2 = nnx.Variable(2) m = m0 = nnx.Dict({'a': nnx.List([r1, r2]), 'b': r1}) @jax.jit def f(graphdef: nnx.GraphDef[nnx.Dict], state: nnx.State): m = nnx.merge(graphdef, state) assert m['a'][0] is m['b'] assert m['a'][1] is not m['b'] return nnx.split(m) graphdef, state = f(*nnx.split(m)) m = nnx.merge(graphdef, state) assert m['a'][0] is m['b'] assert m['a'][1] is not m['b'] # compare with original assert m['a'][0] is not m0['a'][0] assert m['a'][1] is not m0['a'][1] assert m['b'] is not m0['b'] def test_cross_barrier(self): m = nnx.Dict(a=nnx.Param(jnp.array(1))) @jax.jit def g(graphdef: nnx.GraphDef[nnx.Dict], state: nnx.State): m = nnx.merge(graphdef, state) m.a[...] += 1 return nnx.split(m) graphdef, state = g(*nnx.split(m)) m2 = nnx.merge(graphdef, state) assert m2 is not m assert m.a[...] == 1 assert m2.a[...] == 2 def test_no_rejit(self): n = 0 m = nnx.Dict(a=nnx.Param(jnp.array(1))) @jax.jit def g(state_and_def): nonlocal n n += 1 m = nnx.merge(*state_and_def) m.a[...] += 1 return nnx.split(m) m2 = nnx.merge(*g(nnx.split(m))) assert n == 1 assert m2 is not m assert m.a[...] == 1 assert m2.a[...] == 2 g(nnx.split(m)) assert n == 1 g(nnx.split(m2)) assert n == 1 m2.b = nnx.Param(10) g(nnx.split(m2)) assert n == 2 def test_deref_number_of_fields(self): r1 = nnx.Variable(1) r2 = nnx.Variable(2) v1 = 3 m = nnx.Dict( { 'a': nnx.List([r1, r2, v1]), 'b': nnx.Dict({'c': r1, 'd': r2}), } ) graphdef, p = nnx.split(m) assert len(nnx.to_flat_state(p)) == 2 assert len(jax.tree_util.tree_leaves(p)) == 2 def test_clone(self): m = nnx.Dict( a=nnx.List([nnx.Param(1), nnx.Param(2), 3]), b=nnx.Dict(c=nnx.Param(1), d=nnx.Param(2)), ) m2 = nnx.clone(m) assert m is not m2 assert m2.a[0].get_value() == m2.b.c.get_value() assert m2.a[1].get_value() == m2.b.d.get_value() assert m.a[0].get_value() == m2.a[0].get_value() assert m.a[1].get_value() == m2.a[1].get_value() assert m.b.c.get_value() == m2.b.c.get_value() assert m.b.d.get_value() == m2.b.d.get_value() def test_sow_existing_non_variable_field(self): class Foo(nnx.Module): def __init__(self) -> None: self.y = 10 def __call__(self, x): y = x + 1 self.sow(nnx.Intermediate, 'y', y) return y m = Foo() with self.assertRaisesRegex(ValueError, 'to be a Variable, got'): m(2) def test_sow_wrong_collection(self): class Foo(nnx.Module): def __init__(self) -> None: self.y = nnx.Param(10) def __call__(self, x): y = x + 1 self.sow(nnx.Intermediate, 'y', y) return y m = Foo() with self.assertRaisesRegex(ValueError, 'to be of type'): m(2) def test_sow_pop(self): x = jnp.ones((2, 4)) model = SowMod(nnx.Rngs(42)) out, intermediates = nnx.capture(model, nnx.Intermediate)(x) attr_names = set(model._pytree__nodes) assert 'my_summary' not in attr_names def test_cached_partial(self): model = SowMod(nnx.Rngs(42)) x = jnp.ones((2, 4)) @nnx.jit def train_step(model, x): out, intermediates = nnx.capture(model, nnx.Intermediate)(x) return out, intermediates train_step_fn = nnx.cached_partial(train_step, model) train_step_fn(x) def test_update_static_state_submodules(self): class Bar(nnx.Module): def __init__(self) -> None: self.x = 1 def add_field(self): self.y = 2 class Foo(nnx.Module): def __init__(self) -> None: self.a = Bar() self.b = self.a m1 = Foo() with nnx.update_context('test'): with nnx.split_context('test') as ctx: graphdef, state = ctx.split(m1) with nnx.merge_context('test', inner=True) as ctx: m2 = ctx.merge(graphdef, state) m2.a.add_field() with nnx.split_context('test') as ctx: new_graphdef, state = ctx.split(m2) with nnx.merge_context('test', inner=False) as ctx: m3 = ctx.merge(new_graphdef, state) assert m3 is m1 assert m1.a.x == 1 assert m1.a.y == 2 assert m1.b.x == 1 assert m1.b.y == 2 def test_update_new_submodule(self): class Bar(nnx.Module): def __init__(self) -> None: self.x = 1 class Foo(nnx.Module): def __init__(self) -> None: self.a = Bar() def add_module(self): self.b = Bar() m1 = Foo() with nnx.update_context('test'): with nnx.split_context('test') as ctx: graphdef, state = ctx.split(m1) with nnx.merge_context('test', inner=True) as ctx: m2 = ctx.merge(graphdef, state) m2.add_module() with nnx.split_context('test') as ctx: new_graphdef, state = ctx.split(m2) with nnx.merge_context('test', inner=False) as ctx: m3 = ctx.merge(new_graphdef, state) assert m3 is m1 assert m1.a.x == 1 assert m1.b.x == 1 def test_update_update_submodule(self): class Bar(nnx.Module): def __init__(self) -> None: self.x = 1 class Foo(nnx.Module): def __init__(self) -> None: self.a = Bar() self.b = self.a m1 = Foo() with nnx.update_context('test'): with nnx.split_context('test') as ctx: graphdef, state = ctx.split(m1) with nnx.merge_context('test', inner=True) as ctx: m2 = ctx.merge(graphdef, state) m2.a.x = 2 with nnx.split_context('test') as ctx: new_graphdef, state = ctx.split(m2) with nnx.merge_context('test', inner=False) as ctx: m3 = ctx.merge(new_graphdef, state) assert m3 is m1 assert m1.a.x == 2 assert m1.b.x == 2 def test_update_add_shared(self): class Bar(nnx.Module): def __init__(self) -> None: self.x = 1 class Foo(nnx.Module): def __init__(self) -> None: self.a = Bar() self.b = self.a def add_submodule(self): self.c = self.a m1 = Foo() with nnx.update_context('test'): with nnx.split_context('test') as ctx: graphdef, state = ctx.split(m1) with nnx.merge_context('test', inner=True) as ctx: m2 = ctx.merge(graphdef, state) m2.add_submodule() with nnx.split_context('test') as ctx: new_graphdef, state = ctx.split(m2) with nnx.merge_context('test', inner=False) as ctx: m3 = ctx.merge(new_graphdef, state) assert m3 is m1 assert hasattr(m1, 'c') def test_create_abstract(self): linear = nnx.eval_shape(lambda: nnx.Linear(2, 3, rngs=nnx.Rngs(0))) assert linear.kernel.get_value() == jax.ShapeDtypeStruct((2, 3), jnp.float32) assert linear.bias.get_value() == jax.ShapeDtypeStruct((3,), jnp.float32) def test_create_abstract_stateful(self): linear = nnx.eval_shape(lambda: nnx.Dropout(0.5, rngs=nnx.Rngs(0))) assert linear.rngs.key.get_value() == jax.ShapeDtypeStruct( (), jax.random.key(0).dtype ) def test_partial_init(self): linear = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) state = nnx.state(linear) del state['bias'] @nnx.jit def partial_init(state: nnx.State): m = nnx.Linear( 2, 3, bias_init=nnx.initializers.ones_init(), rngs=nnx.Rngs(1) ) nnx.update(m, state) return m linear2 = partial_init(state) np.testing.assert_allclose(linear.kernel[...], linear2.kernel[...]) np.testing.assert_allclose(linear.bias[...], 0) np.testing.assert_allclose(linear2.bias[...], 1) def test_deepcopy(self): class Foo(nnx.Module): def __init__(self) -> None: self.a = nnx.Param(jnp.array(1)) self.b = [1, 2, 3] self.c = nnx.Param(jnp.array([1.0])) self.self = self m1 = Foo() m2 = deepcopy(m1) assert m1.a[...] == m2.a[...] assert vars(m1)['a'] is not vars(m2)['a'] assert m1.b is not m2.b assert m1.c is not m2.c assert m1.self is m1 def test_set_attributes(self): class Block(nnx.Module): def __init__(self, din, dout, *, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.dropout = nnx.Dropout(0.5, deterministic=False) self.batch_norm = nnx.BatchNorm( 10, use_running_average=False, rngs=rngs ) block = Block(2, 5, rngs=nnx.Rngs(0)) assert block.dropout.deterministic == False assert block.batch_norm.use_running_average == False block.set_attributes(deterministic=True, use_running_average=True) assert block.dropout.deterministic == True assert block.batch_norm.use_running_average == True block = Block(2, 5, rngs=nnx.Rngs(0)) block.set_attributes(nnx.Dropout, deterministic=True) # Only the dropout will be modified assert block.dropout.deterministic == True assert block.batch_norm.use_running_average == False def test_set_attribute_error(self): class Block(nnx.Module): def __init__(self, din, dout, *, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.dropout = nnx.Dropout(0.5, deterministic=False) self.batch_norm = nnx.BatchNorm( 10, use_running_average=False, rngs=rngs ) block = Block(2, 5, rngs=nnx.Rngs(0)) with self.assertRaisesRegex( ValueError, ( 'Could not find at least one instance of the following attributes:' " \\['unknown'\\]" ), ): block.set_attributes( deterministic=True, use_running_average=True, unknown=True ) block.set_attributes( deterministic=True, use_running_average=True, unknown=True, raise_if_not_found=False, ) @parameterized.parameters(True, False) def test_view(self, graph): class Block(nnx.Module): def __init__(self, din, dout, *, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) self.batch_norm = nnx.BatchNorm( 10, use_running_average=False, rngs=rngs ) block = Block(2, 5, rngs=nnx.Rngs(0)) assert block.dropout.deterministic == False assert block.batch_norm.use_running_average == False new_block = nnx.view(block, deterministic=True, use_running_average=True, graph=graph) assert new_block.dropout.deterministic == True assert new_block.batch_norm.use_running_average == True assert new_block.linear.kernel is block.linear.kernel block = Block(2, 5, rngs=nnx.Rngs(0)) new_block = nnx.view(block, only=nnx.Dropout, deterministic=True, graph=graph) assert new_block.dropout.deterministic == True assert new_block.batch_norm.use_running_average == False @parameterized.parameters(True, False) def test_with_attributes(self, graph): class Block(nnx.Module): def __init__(self, din, dout, *, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.dropout = nnx.Dropout(0.5, deterministic=False) self.batch_norm = nnx.BatchNorm( 10, use_running_average=False, rngs=rngs ) block = Block(2, 5, rngs=nnx.Rngs(0)) assert block.dropout.deterministic == False assert block.batch_norm.use_running_average == False new_block = nnx.with_attributes( block, deterministic=True, use_running_average=True, graph=graph ) assert new_block.dropout.deterministic == True assert new_block.batch_norm.use_running_average == True assert new_block.linear.kernel is block.linear.kernel assert block.dropout.deterministic == False assert block.batch_norm.use_running_average == False @parameterized.parameters(True, False) def test_with_attributes_filter(self, graph): class Block(nnx.Module): def __init__(self, din, dout, *, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.dropout = nnx.Dropout(0.5, deterministic=False) self.batch_norm = nnx.BatchNorm( 10, use_running_average=False, rngs=rngs ) block = Block(2, 5, rngs=nnx.Rngs(0)) new_block = nnx.with_attributes( block, only=nnx.Dropout, deterministic=True, graph=graph ) assert new_block.dropout.deterministic == True assert new_block.batch_norm.use_running_average == False @parameterized.parameters(True, False) def test_with_attributes_error(self, graph): class Block(nnx.Module): def __init__(self, din, dout, *, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.dropout = nnx.Dropout(0.5, deterministic=False) self.batch_norm = nnx.BatchNorm( 10, use_running_average=False, rngs=rngs ) block = Block(2, 5, rngs=nnx.Rngs(0)) with self.assertRaisesRegex( ValueError, ( 'Could not find at least one instance of the following attributes:' " \\['unknown'\\]" ), ): nnx.with_attributes( block, deterministic=True, use_running_average=True, unknown=True, graph=graph, ) new_block = nnx.with_attributes( block, deterministic=True, use_running_average=True, unknown=True, raise_if_not_found=False, graph=graph, ) assert new_block.dropout.deterministic == True assert new_block.batch_norm.use_running_average == True @parameterized.parameters(True, False) def test_view_error(self, graph): class Block(nnx.Module): def __init__(self, din, dout, *, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) self.batch_norm = nnx.BatchNorm( 10, use_running_average=False, rngs=rngs ) block = Block(2, 5, rngs=nnx.Rngs(0)) with self.assertRaisesRegex( ValueError, ( "Unused keys found in nnx.view: \\['unknown'\\]" ), ): nnx.view(block, deterministic=True, use_running_average=True, unknown=True, graph=graph) def test_cloud_pickle(self): import platform if platform.python_version().startswith('3.11'): self.skipTest("Cloudpickle cannot pickle PRNGKeyArray on python 3.11") class Model(nnx.Module): def __init__(self, din, dmid, dout, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dmid, rngs=rngs) self.bn = nnx.BatchNorm(dmid, rngs=rngs) self.dropout = nnx.Dropout(0.1, rngs=rngs) self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x): x = nnx.relu(self.dropout(self.bn(self.linear(x)))) return self.linear_out(x) model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization model.eval() y1 = model(jnp.ones((5, 2))) with tempfile.TemporaryDirectory() as tmpdir: path = f'{tmpdir}/model.pkl' with open(path, 'wb') as f: cloudpickle.dump(model, f) del model with open(path, 'rb') as f: model = pickle.load(f) self.assertIsInstance(model, Model) y2 = model(jnp.ones((5, 2))) np.testing.assert_allclose(y1, y2) def test_repr(self): class Block(nnx.Module): def __init__(self, din, dout, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.bn = nnx.BatchNorm(dout, rngs=rngs) self.dropout = nnx.Dropout(0.2, rngs=rngs) def __call__(self, x): return nnx.relu(self.dropout(self.bn(self.linear(x)))) class Foo(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.block1 = Block(32, 128, rngs=rngs) self.block2 = Block(128, 10, rngs=rngs) def __call__(self, x): return self.block2(self.block1(x)) obj = Foo(nnx.Rngs(0)) leaves = nnx.to_flat_state(nnx.state(obj)).leaves expected_total = sum(int(np.prod(x.shape)) for x in leaves) expected_total_params = sum( int(np.prod(x.shape)) for x in leaves if isinstance(x, nnx.Param) ) expected_total_batch_stats = sum( int(np.prod(x.shape)) for x in leaves if isinstance(x, nnx.BatchStat) ) expected_total_rng_states = sum( int(np.prod(x.shape)) for x in leaves if isinstance(x, nnx.RngState) ) foo_repr = repr(obj).replace(',', '').splitlines() self.assertIn(str(expected_total), foo_repr[0]) self.assertIn(str(expected_total_params), foo_repr[0]) self.assertIn(str(expected_total_batch_stats), foo_repr[0]) self.assertIn(str(expected_total_rng_states), foo_repr[0]) @parameterized.parameters(True, False) def test_view_info(self, graph): class Block(nnx.Module): def __init__(self, din, dout, *, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.bn = nnx.BatchNorm(dout, rngs=rngs) self.dropout = nnx.Dropout(0.2, rngs=rngs) def __call__(self, x): return nnx.relu(self.dropout(self.bn(self.linear(x)))) class Foo(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.block1 = Block(32, 128, rngs=rngs) self.block2 = Block(128, 10, rngs=rngs) def __call__(self, x): return self.block2(self.block1(x)) obj = Foo(rngs=nnx.Rngs(0)) info_str = nnx.view_info(obj, graph=graph) self.assertEqual(info_str.count("BatchNorm:"), 1) self.assertEqual(info_str.count("Dropout:"), 1) @parameterized.parameters(True, False) def test_view_info_with_filter(self, graph): class Block(nnx.Module): def __init__(self, din, dout, *, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.bn = nnx.BatchNorm(dout, rngs=rngs) self.dropout = nnx.Dropout(0.2, rngs=rngs) def __call__(self, x): return nnx.relu(self.dropout(self.bn(self.linear(x)))) obj = Block(4, 8, rngs=nnx.Rngs(0)) info_str = nnx.view_info(obj, only=nnx.Dropout, graph=graph) self.assertIn("Dropout:", info_str) self.assertNotIn("BatchNorm:", info_str) info_str = nnx.view_info(obj, only=nnx.MultiHeadAttention, graph=graph) self.assertEmpty(info_str) @parameterized.parameters(True, False) def test_view_info_with_custom_set_mode(self, graph): class Block(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): pass def __call__(self, x): return x def set_view(self, arg1: bool | None = None, arg2: int | None = None, **kwargs) -> dict: """Example set_view docstring. This follows Google style docstrings. Args: arg1: The first argument. arg2: The second argument. This has two lines. """ return kwargs obj = Block(rngs=nnx.Rngs(0)) info_str = nnx.view_info(obj, graph=graph) self.assertEqual(f"{obj.__class__.__qualname__}:\n arg1: bool | None = None\n The first argument.\n arg2: int | None = None\n The second argument.\n This has two lines.", info_str) class TestModuleDataclass(absltest.TestCase): def test_basic(self): @dataclasses.dataclass class Foo(nnx.Module): a: int b: nnx.Variable[int] c: nnx.Param[int] d: nnx.Variable[int] e: nnx.Variable[int] f: int m = Foo( a=1, # graphdef b=nnx.Variable(2), # node c=nnx.Param(3), # param d=nnx.Variable(4), # var e=nnx.BatchStat(5), # var f=6, # graphdef int ) graphdef, state = nnx.split(m) assert len(state) == 4 assert state['b'].get_value() == 2 assert isinstance(state['b'], nnx.Variable) assert state['c'].get_value() == 3 assert isinstance(state['c'], nnx.Param) assert state['d'].get_value() == 4 assert isinstance(state['d'], nnx.Variable) assert state['e'].get_value() == 5 assert isinstance(state['e'], nnx.BatchStat) def test_field_specifiers(self): @nnx.dataclass class Foo(nnx.Pytree): a: int = nnx.static() b: jax.Array = nnx.data() m = Foo(a=1, b=jnp.array(2)) leaves = jax.tree.leaves(m) assert len(leaves) == 1 assert leaves[0] == jnp.array(2) def test_field_specifiers_forced(self): @nnx.dataclass class Bar(nnx.Pytree): a: int = nnx.data() m = Bar(a=1) leaves = jax.tree.leaves(m) assert len(leaves) == 1 assert leaves[0] == 1 def test_field_specifiers_with_defaults(self): @nnx.dataclass class Bar(nnx.Pytree): a: int = nnx.data(default=3) m = Bar() leaves = jax.tree.leaves(m) assert len(leaves) == 1 assert leaves[0] == 3 def test_field_specifiers_array_in_static(self): @nnx.dataclass class Bar(nnx.Pytree): a: jax.Array = nnx.static() with self.assertRaisesRegex( ValueError, 'Found unexpected data on value of type', ): m = Bar(a=jnp.array(3)) def test_variable_in_static_list(self): @nnx.dataclass class Foo(nnx.Module): filters: list with self.assertRaisesRegex( ValueError, 'Found data on value of type', ): Foo([nnx.Variable(1)]) def test_module_in_static_list(self): class Bar(nnx.Module): pass @nnx.dataclass class Foo(nnx.Module): filters: list with self.assertRaisesRegex( ValueError, 'Found data on value of type', ): Foo([Bar()]) def test_post_init(self): @dataclasses.dataclass class DFoo(nnx.Module): din: int dout: int rngs: nnx.Rngs def __post_init__(self): self.bar = nnx.Linear(self.din, self.dout, rngs=self.rngs) def __call__(self, x): return self.bar(x) m = DFoo(1, 1, rngs=nnx.Rngs(0)) assert hasattr(m, 'bar') class TestModuleDef(parameterized.TestCase): def test_apply(self): class Foo(nnx.Module): def __init__(self, c: float, *, rngs: nnx.Rngs): self.w = nnx.Param(jax.random.uniform(rngs.params(), ())) self.c = c def __call__(self, x, *, rngs: nnx.Rngs): return self.w * x + rngs.e.normal(()) + self.c rngs = nnx.Rngs(0) foo = Foo(c=1.0, rngs=rngs) graphdef, states = nnx.split(foo) assert isinstance(states, nnx.State) assert isinstance(states['w'], nnx.Param) y, _updates = graphdef.apply(states)(x=2.0, rngs=nnx.Rngs(e=1)) assert isinstance(y, jax.Array) def test_derefed_mod_apply(self): class Foo(nnx.Module): def __init__(self, c: float, *, rngs: nnx.Rngs): self.w = nnx.Param( jax.random.uniform(rngs.params(), ()), ) self.c = nnx.Variable(c) def __call__(self, x, *, rngs: nnx.Rngs): return self.w * x + rngs.e.normal(()) + self.c foo = Foo(c=1.0, rngs=nnx.Rngs(0)) graphdef, state = nnx.split(foo) assert isinstance(graphdef.nodes[0], nnx.graphlib.NodeDef | nnx.graphlib.NodeRef) assert isinstance(state, nnx.State) assert isinstance(state['w'], nnx.Param) assert isinstance(state['c'], nnx.Variable) y, (graphdef, state) = graphdef.apply(state)(x=2.0, rngs=nnx.Rngs(e=1)) assert isinstance(y, jax.Array) def test_modules_iterator(self): class Foo(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.submodules = nnx.data([ {'a': nnx.Linear(1, 1, rngs=rngs)}, {'b': nnx.Conv(1, 1, 1, rngs=rngs)}, ]) self.linear = nnx.Linear(1, 1, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) module = Foo(rngs=nnx.Rngs(0)) modules = list(nnx.iter_modules(module)) assert len(modules) == 5 assert modules[0][0] == ('dropout',) assert isinstance(modules[0][1], nnx.Dropout) assert modules[1][0] == ('linear',) assert isinstance(modules[1][1], nnx.Linear) assert modules[2][0] == ('submodules', 0, 'a') assert isinstance(modules[2][1], nnx.Linear) assert modules[3][0] == ('submodules', 1, 'b') assert isinstance(modules[3][1], nnx.Conv) assert modules[4][0] == () assert isinstance(modules[4][1], Foo) @parameterized.parameters(True, False) def test_children_modules_iterator(self, graph): class Foo(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.submodules = nnx.data([ {'a': nnx.Linear(1, 1, rngs=rngs)}, {'b': nnx.Conv(1, 1, 1, rngs=rngs)}, ]) self.linear = nnx.Linear(1, 1, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) module = Foo(rngs=nnx.Rngs(0)) modules = list(nnx.iter_children(module, graph=graph)) assert len(modules) == 2 assert modules[0][0] == 'dropout' assert isinstance(modules[0][1], nnx.Dropout) assert modules[1][0] == 'linear' assert isinstance(modules[1][1], nnx.Linear) def test_state_in_module(self): class Foo(nnx.Module): def __init__(self): self.a = nnx.data(nnx.State({'b': nnx.Param(jnp.array(1.0))})) foo = Foo() graphdef, state = nnx.split(foo) assert isinstance(state, nnx.State) assert isinstance(state['a'], nnx.State) foo2 = nnx.merge(graphdef, state) assert isinstance(foo2.a, nnx.State) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/mutable_array_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import dataclasses from absl.testing import absltest, parameterized import optax from flax import nnx import flax.errors import jax import jax.numpy as jnp import numpy as np class TestPytree(absltest.TestCase): def test_pytree(self): class Foo(nnx.Module): def __init__(self): self.node = jnp.array(1) self.meta = 1 m = nnx.vars_as(Foo(), ref=True) m = jax.tree.map(lambda x: x + 1, m) assert m.node == 2 assert m.meta == 1 def test_pytree_data_typehint(self): class Foo(nnx.Module): node: jax.Array = nnx.data() def __init__(self): self.node = jnp.array(1) self.meta = 1 m = Foo() m = jax.tree.map(lambda x: x + 1, m) assert m.node == 2 assert m.meta == 1 def test_pytree_data_instance(self): class Foo(nnx.Module): def __init__(self): self.node = nnx.data(jnp.array(1)) self.meta = 1 m = Foo() m = jax.tree.map(lambda x: x + 1, m) assert m.node == 2 assert m.meta == 1 def test_pytree_dataclass(self): @nnx.dataclass class Foo(nnx.Module): node: jax.Array = nnx.data() meta: int meta2: int = 3 meta3: int = 4 meta4: int = 5 node2: int = nnx.data(default=6) m = Foo(node=jnp.array(1), meta=1) m: Foo = jax.tree.map(lambda x: x + 1, m) assert m.node == 2 assert m.meta == 1 assert m.meta2 == 3 assert m.meta3 == 4 assert m.meta4 == 5 assert m.node2 == 7 def test_data_example(self): class Foo(nnx.Pytree): def __init__(self): self.data_attr = nnx.data(42) # pytree data self.static_attr = 'hello' # static attribute foo = Foo() self.assertEqual(jax.tree.leaves(foo), [42]) def test_register_data_type(self): @dataclasses.dataclass(frozen=True) class MyType: value: int nnx.register_data_type(MyType) class Foo(nnx.Pytree): def __init__(self, a): self.a = MyType(a) # Automatically registered as data self.b = 'hello' # str not registered as data foo = Foo(42) self.assertTrue(nnx.is_data(foo.a)) self.assertEqual(jax.tree.leaves(foo), [MyType(value=42)]) class TestVariableRefMode(absltest.TestCase): def test_split_mutable_array(self): m = jax.new_ref(1) graphdef, state = nnx.split(m) self.assertIs(m, state) m2 = nnx.merge(graphdef, state) self.assertIs(m2, m) def test_to_arrays_example(self): node = [nnx.Variable(1.0), nnx.Variable(2.0, mode='ref')] mutable_node = nnx.vars_as(node, ref=True) assert isinstance(mutable_node[0].get_raw_value(), jax.Ref) assert isinstance(mutable_node[1].get_raw_value(), jax.Ref) shared_array = nnx.Variable(1.0, mode='pytree') node = [shared_array, shared_array] with self.assertRaisesRegex(ValueError, 'Found duplicate at path'): nnx.vars_as(node, ref=True) node = [nnx.Variable(1.0), nnx.Variable(2.0)] mutable_node = nnx.vars_as( node, ref=True, only=lambda path, x: path[0] == 0 ) assert isinstance(mutable_node[0].get_raw_value(), jax.Ref) assert isinstance(mutable_node[1].get_raw_value(), float) def test_freeze_and_mutable_with_filter(self): class Foo(nnx.Module): def __init__(self): self.a = nnx.Param(1) self.b = nnx.BatchStat(2) m = nnx.vars_as(Foo(), hijax=True, ref=True) self.assertEqual(m.a.ref, True) self.assertEqual(m.b.ref, True) m2 = nnx.vars_as(m, hijax=False, only=nnx.BatchStat) self.assertEqual(m2.a.ref, True) self.assertEqual(m2.a.hijax, True) self.assertEqual(m2.b.ref, True) self.assertEqual(m2.b.hijax, False) self.assertIsNot(m, m2) m3 = nnx.vars_as(m2, hijax=True, only=nnx.BatchStat) self.assertEqual(m3.a.ref, True) self.assertEqual(m3.b.ref, True) self.assertEqual(m3.b.hijax, True) self.assertIsNot(m2, m3) self.assertIs(m.a, m3.a) def test_freeze_duplicate_error(self): class Foo(nnx.Module): def __init__(self): self.a = nnx.Param(1, mode='ref') self.b = self.a m = Foo() with self.assertRaisesRegex(ValueError, 'Found duplicate at path'): nnx.vars_as(m, ref=True) def test_mutable_array_split(self): class Foo(nnx.Module): def __init__(self): self.a = jax.new_ref(1) self.b = self.a m = Foo() ref_map = nnx.graphlib.RefMap() graphdef, state = nnx.graphlib.flatten(m, ref_index=ref_map, graph=True) self.assertLen(state, 1) self.assertLen(ref_map, 2) # 1 Foo + 1 ArrayRef m1 = nnx.merge(graphdef, state) self.assertIs(m1.a, m1.b) self.assertIsInstance(m1.a, jax.Ref) def test_mutable_array_split_merge_in_variable(self): class Foo(nnx.Module): def __init__(self): self.a = nnx.Param(1, ref=True) self.b = self.a m = Foo() ref_map = nnx.graphlib.RefMap() graphdef, state = nnx.graphlib.flatten(m, ref_index=ref_map, graph=True) self.assertLen(state, 1) self.assertLen(ref_map, 3) # 1 Foo + 1 Param + 1 Ref m1 = nnx.merge(graphdef, state) self.assertIs(m1.a, m1.b) self.assertIsInstance(m1.a, nnx.Param) def test_mutable_array_split_merge_in_variable_shared_array(self): class Foo(nnx.Module): def __init__(self): m_array = 1 self.a = nnx.Param(m_array, ref=True) self.b = nnx.Param(m_array, ref=True) m = Foo() ref_map = nnx.graphlib.RefMap() graphdef, state = nnx.graphlib.flatten(m, ref_index=ref_map, graph=True) self.assertLen(state, 2) self.assertLen(ref_map, 5) # 1 Foo + 2 Param + 2 Ref m1 = nnx.merge(graphdef, state) # Each variable will own its own array and ref. self.assertIsInstance(m1.a, nnx.Param) def test_mutable_example(self): tree = [nnx.Variable(1.0), nnx.Variable(2.0, ref=True)] assert tree[0].ref == False assert tree[1].ref == True mutable_tree = nnx.vars_as(tree, ref=True) assert isinstance(mutable_tree[0].get_raw_value(), jax.Ref) assert isinstance(mutable_tree[1].get_raw_value(), jax.Ref) def test_mutable_array_split_freeze(self): class Foo(nnx.Module): def __init__(self): self.a = jax.new_ref(1) self.b = self.a m = Foo() ref_map = nnx.graphlib.RefMap() graphdef, state = nnx.graphlib.flatten(m, ref_index=ref_map, graph=True) state = nnx.vars_as(state, hijax=False) self.assertLen(state, 1) m1 = nnx.merge(graphdef, state) self.assertIs(m1.a, m1.b) self.assertIsInstance(m1.a, jax.Ref) def test_update_context(self): m1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.update_context('example'): with nnx.split_context('example') as ctx: graphdef, state = ctx.split(m1) with nnx.merge_context('example', True) as ctx: m2 = ctx.merge(graphdef, state) m_out1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.split_context('example') as ctx: graphdef_out, state_out = ctx.split((m2, m_out1, m2)) self.assertIsInstance( state_out[0]['kernel'].get_value(), nnx.graphlib.NoUpdate ) self.assertIsInstance( state_out[0]['bias'].get_value(), nnx.graphlib.NoUpdate ) self.assertIsInstance( state_out[1]['kernel'].get_value(), nnx.graphlib.ArrayRefOutput ) self.assertIsInstance( state_out[1]['bias'].get_value(), nnx.graphlib.ArrayRefOutput ) # 2 ArrayRefOutput + 2 NoUpdate, however, NoUpdate are empty nodes self.assertLen(jax.tree.leaves(state_out), 2) with nnx.merge_context('example', False) as ctx: m3, m_out2, _ = ctx.merge(graphdef_out, state_out) self.assertIs(m3, m1) self.assertIsNot(m_out2, m_out1) def test_update_context_flatten(self): m1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.update_context('example'): with nnx.split_context('example') as ctx: graphdef, state = ctx.flatten(m1) with nnx.merge_context('example', True) as ctx: m2 = ctx.merge(graphdef, state) m_out1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.split_context('example') as ctx: graphdef_out, state_out = ctx.flatten((m2, m_out1, m2)) state_out_dict = dict(state_out) self.assertIsInstance( state_out_dict[(0, 'kernel')].get_value(), nnx.graphlib.NoUpdate ) self.assertIsInstance( state_out_dict[(0, 'bias')].get_value(), nnx.graphlib.NoUpdate ) self.assertIsInstance( state_out_dict[(1, 'kernel')].get_value(), nnx.graphlib.ArrayRefOutput ) self.assertIsInstance( state_out_dict[(1, 'bias')].get_value(), nnx.graphlib.ArrayRefOutput ) # 2 ArrayRefOutput + 2 NoUpdate, however, NoUpdate are empty nodes self.assertLen(jax.tree.leaves(state_out), 2) with nnx.merge_context('example', False) as ctx: m3, m_out2, _ = ctx.merge(graphdef_out, state_out) self.assertIs(m3, m1) self.assertIsNot(m_out2, m_out1) def test_update_context_to_tree1(self): m1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.update_context('example'): m1_tree = nnx.to_tree((m1,), ctxtag='example') (m2,) = nnx.from_tree(m1_tree, ctxtag='example', is_inner=True) m_out1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) # with nnx.split_context('example') as ctx: # graphdef_out, state_out = ctx.split((m2, m_out1)) out_tree = nnx.to_tree(((m2,), m_out1, m2), ctxtag='example') self.assertIsInstance( out_tree[0][0].states[0]['kernel'].get_value(), nnx.graphlib.NoUpdate ) self.assertIsInstance( out_tree[0][0].states[0]['bias'].get_value(), nnx.graphlib.NoUpdate ) self.assertIsInstance( out_tree[1].states[0]['kernel'].get_value(), nnx.graphlib.ArrayRefOutput ) self.assertIsInstance( out_tree[1].states[0]['bias'].get_value(), nnx.graphlib.ArrayRefOutput ) self.assertEmpty(out_tree[2].states[0]) # Repeated m2 State # 2 ArrayRefOutput + 2 NoUpdate, however, NoUpdate are empty nodes self.assertLen(jax.tree.leaves(out_tree), 2) # with nnx.merge_context('example', False) as ctx: # m3, m_out2 = ctx.merge(graphdef_out, state_out) (m3,), m_out2, _ = nnx.from_tree( out_tree, ctxtag='example', is_inner=False ) self.assertIs(m3, m1) self.assertIsNot(m_out2, m_out1) def test_update_context_to_tree2(self): m1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.update_context('example') as ctx: m1_tree = nnx.to_tree((m1,), ctxtag='example') (m2,) = nnx.from_tree(m1_tree, ctxtag='example', is_inner=True) m_out1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) # with nnx.split_context('example') as ctx: # graphdef_out, state_out = ctx.split((m2, m_out1)) out_tree = nnx.to_tree(((m2,), m_out1, m2), ctxtag='example') self.assertIsInstance( out_tree[0][0].states[0]['kernel'].get_value(), nnx.graphlib.NoUpdate ) self.assertIsInstance( out_tree[0][0].states[0]['bias'].get_value(), nnx.graphlib.NoUpdate ) self.assertIsInstance( out_tree[1].states[0]['kernel'].get_value(), nnx.graphlib.ArrayRefOutput ) self.assertIsInstance( out_tree[1].states[0]['bias'].get_value(), nnx.graphlib.ArrayRefOutput ) self.assertEmpty(out_tree[2].states[0]) # Repeated m2 State # 2 ArrayRefOutput + 2 NoUpdate, however, NoUpdate are empty nodes self.assertLen(jax.tree.leaves(out_tree), 2) # with nnx.merge_context('example', False) as ctx: # m3, m_out2 = ctx.merge(graphdef_out, state_out) (m3,), m_out2, _ = nnx.from_tree( out_tree, ctxtag='example', is_inner=False ) self.assertIs(m3, m1) self.assertIsNot(m_out2, m_out1) def test_update_context_to_tree_trivial_prefix(self): m1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) with nnx.update_context('example'): m1_tree = nnx.to_tree((m1,), ctxtag='example', prefix=0) (m2,) = nnx.from_tree(m1_tree, ctxtag='example', is_inner=True, prefix=0) m_out1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) # with nnx.split_context('example') as ctx: # graphdef_out, state_out = ctx.split((m2, m_out1)) out_tree = nnx.to_tree(((m2,), m_out1, m2), ctxtag='example', prefix=0) self.assertIsInstance( out_tree[0][0].states[0]['kernel'].get_value(), nnx.graphlib.NoUpdate ) self.assertIsInstance( out_tree[0][0].states[0]['bias'].get_value(), nnx.graphlib.NoUpdate ) self.assertIsInstance( out_tree[1].states[0]['kernel'].get_value(), nnx.graphlib.ArrayRefOutput ) self.assertIsInstance( out_tree[1].states[0]['bias'].get_value(), nnx.graphlib.ArrayRefOutput ) self.assertEmpty(out_tree[2].states[0]) # Repeated m2 State # 2 ArrayRefOutput + 2 NoUpdate, however, NoUpdate are empty nodes self.assertLen(jax.tree.leaves(out_tree), 2) # with nnx.merge_context('example', False) as ctx: # m3, m_out2 = ctx.merge(graphdef_out, state_out) (m3,), m_out2, _ = nnx.from_tree( out_tree, ctxtag='example', is_inner=False, prefix=0 ) self.assertIs(m3, m1) self.assertIsNot(m_out2, m_out1) def test_simple_jit(self): m1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) m_out1 = None @nnx.jit def f(m2): nonlocal m_out1 m_out1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) return m_out1 m_out2 = f(m1) self.assertIsNot(m_out1, m_out2) self.assertIsInstance(m_out2.kernel, nnx.Param) self.assertIsInstance(m_out2.kernel[...], jax.Array) def test_jit_mutable(self): @nnx.dataclass class Foo(nnx.Pytree): a: jax.Ref = nnx.data() m1 = Foo(a=jax.new_ref(1)) @nnx.jit def f(m2: Foo): m2.a[...] += 1 return m2 m_out1 = f(m1) self.assertEqual(m_out1.a[...], 2) self.assertIs(m_out1, m1) self.assertIsInstance(m_out1.a, jax.Ref) def test_static(self): class C(nnx.Module): def __init__(self, meta): self.meta = meta n = 0 @jax.jit def f(x): nonlocal n n += 1 f(C(1)) assert n == 1 f(C(1)) assert n == 1 f(C(2)) assert n == 2 f(C(2)) assert n == 2 def test_variable_creation(self): v = nnx.Variable(jnp.array(1), ref=True) self.assertEqual(v[...], 1) self.assertTrue(v.ref) self.assertIsInstance(v.get_raw_value(), jax.Ref) def test_variable_metadata(self): v = nnx.Variable(jnp.array(1), a=2, b=3) self.assertEqual(v.a, 2) self.assertEqual(v.b, 3) def test_object(self): class Params(nnx.Pytree): def __init__(self, din: int, dout: int): self.w = nnx.Param(jnp.zeros((din, dout), jnp.float32)) self.b = nnx.Param(jnp.zeros((dout,), jnp.float32)) self.count = nnx.Variable(jnp.array(0)) params = Params(3, 4) params = nnx.vars_as(params, ref=True) paths_leaves, treedef = jax.tree.flatten_with_path(params) paths, leaves = zip(*paths_leaves) self.assertLen(paths_leaves, 3) self.assertEqual(leaves[0].shape, (4,)) # b self.assertEqual(leaves[1].shape, ()) # count self.assertEqual(leaves[2].shape, (3, 4)) # w self.assertEqual( paths[0], (jax.tree_util.GetAttrKey('b'), jax.tree_util.GetAttrKey('value')), ) self.assertEqual( paths[1], (jax.tree_util.GetAttrKey('count'), jax.tree_util.GetAttrKey('value')), ) self.assertEqual( paths[2], (jax.tree_util.GetAttrKey('w'), jax.tree_util.GetAttrKey('value')), ) params = jax.tree.unflatten(treedef, leaves) self.assertEqual(params.w.shape, (3, 4)) self.assertEqual(params.b.shape, (4,)) self.assertEqual(params.count.shape, ()) self.assertIsInstance(params.w, nnx.Variable) self.assertIsInstance(params.w[...], jax.Array) self.assertIsInstance(params.b, nnx.Variable) self.assertIsInstance(params.b[...], jax.Array) self.assertIsInstance(params.count, nnx.Variable) self.assertIsInstance(params.count[...], jax.Array) @jax.jit def linear(params: Params, x: jax.Array): params.count[...] += 1 return x @ params.w[...] + params.b[...][None] x = jnp.ones((1, 3)) y = linear(params, x) self.assertEqual(y.shape, (1, 4)) self.assertEqual(params.count[...], 1) y = linear(params, x) self.assertEqual(params.count[...], 2) def test_object_state(self): class Params(nnx.Pytree): def __init__(self, din: int, dout: int): self.w = jnp.zeros((din, dout), jnp.float32) self.b = jnp.zeros((dout,), jnp.float32) self.count = nnx.data(0) params = Params(3, 4) with self.assertRaises(flax.errors.TraceContextError): @jax.jit def f(): params.count = 1 f() @jax.jit def f(params: Params): params.count = 1 return params params = f(params) self.assertEqual(params.count, 1) def test_rngs_create(self): rngs = nnx.Rngs(0) paths_leaves = jax.tree.leaves_with_path(rngs) paths, leaves = zip(*paths_leaves) self.assertLen(paths_leaves, 2) self.assertEqual(leaves[0].shape, ()) # key self.assertEqual(leaves[1].shape, ()) # count self.assertEqual( paths[0], ( jax.tree_util.GetAttrKey('default'), jax.tree_util.GetAttrKey('count'), jax.tree_util.GetAttrKey('value'), ), ) self.assertEqual( paths[1], ( jax.tree_util.GetAttrKey('default'), jax.tree_util.GetAttrKey('key'), jax.tree_util.GetAttrKey('value'), ), ) def test_rngs_call(self): rngs = nnx.Rngs(0) key = rngs() self.assertIsInstance(key, jax.Array) class TestOptimizer(absltest.TestCase): def test_optimize_arrays(self): class Model(nnx.Module): def __init__(self, rngs): self.w = jax.random.uniform(rngs(), (2, 4)) self.count = jnp.array(0) def __call__(self, x): self.count += 1 return x @ self.w x = jax.random.normal(jax.random.key(0), (5, 2)) y = jnp.ones((5, 4)) wrt = lambda path, x: path[-1] == 'w' model = Model(nnx.Rngs(1)) optimizer = nnx.Optimizer(model, tx=optax.adam(1e-3), wrt=wrt) @jax.jit def train_step(model, optimizer, x, y): graphdef, params, nondiff = nnx.split(model, wrt, ...) def loss_fn(params): model = nnx.merge(graphdef, params, nondiff) return jnp.mean((model(x) - y) ** 2), nnx.state(model, nnx.Not(wrt)) (loss, updates), grads = jax.value_and_grad(loss_fn, has_aux=True)(params) nnx.update(model, updates) optimizer.update(model, grads) return loss, model, optimizer loss, model, optimizer = train_step(model, optimizer, x, y) self.assertNotEqual(loss, 0.0) self.assertEqual(model.count[...], 1) self.assertEqual(optimizer.step[...], 1) @nnx.var_defaults(hijax=True) def test_optimize_hijax(self): class Model(nnx.Module): def __init__(self, rngs): self.w = nnx.Variable(jax.random.uniform(rngs(), (2, 4))) self.count = nnx.Variable(jnp.array(0)) def __call__(self, x): self.count[...] += 1 return x @ self.w x = jax.random.normal(jax.random.key(0), (5, 2)) y = jnp.ones((5, 4)) wrt = lambda path, x: path[-1] == 'w' model = Model(nnx.Rngs(1)) optimizer = nnx.Optimizer(model, tx=optax.adam(1e-3), wrt=wrt) @jax.jit def train_step(model, optimizer, x, y): graphdef, params, nondiff = nnx.split(model, wrt, ...) def loss_fn(params): model = nnx.merge(graphdef, params, nondiff) return jnp.mean((model(x) - y) ** 2) loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, hijax=False)) optimizer.update(params, grads) return loss loss = train_step(model, optimizer, x, y) self.assertNotEqual(loss, 0.0) class TestHijaxVariables(parameterized.TestCase): def test_variable_to_hijax(self): v_low = nnx.Param(jnp.array(1), a='hi') v_hi = nnx.vars_as(v_low, hijax=True) self.assertTrue(v_hi.hijax) self.assertEqual(v_hi[...], 1) self.assertIsInstance(v_hi, nnx.Param) v_hi[...] = 2 self.assertEqual(v_hi[...], 2) @jax.jit def set(v_hi, a): self.assertIsInstance(v_hi, nnx.Param) v_hi[...] = a self.assertEqual(v_hi.a, 'hi') self.assertTrue(v_hi.hijax) v_hi[...] += 5 return v_hi + 2 y = set(v_hi, 10) self.assertEqual(v_hi[...], 15) self.assertEqual(y, 17) v_low = nnx.vars_as(v_hi, hijax=False) self.assertIsInstance(v_low, nnx.Param) self.assertFalse(v_low.hijax) self.assertEqual(v_low[...], 15) def test_from_metadata(self): value = 1 metadata = { 'a': 'hi', 'hijax': False, 'ref': False, } v_low = nnx.Param.from_metadata(value, metadata) self.assertIsInstance(v_low, nnx.Param) self.assertFalse(v_low.hijax) metadata['hijax'] = True v_hi = nnx.Param.from_metadata(value, metadata) self.assertIsInstance(v_hi, nnx.Param) self.assertTrue(v_hi.hijax) def test_variable_to_hijax_clean(self): v_low = nnx.Param(jnp.array([1]), tag='hello') print() print(v_low) assert not v_low.hijax v_hi = nnx.vars_as(v_low, hijax=True) v_hi[...] = jnp.array([2]) assert v_hi.hijax print(v_hi) assert v_hi[...] == 2 @jax.jit def set(v_hi, a): v_hi[...] = a print(v_hi) assert v_hi.tag == 'hello' set(v_hi, 10) assert v_hi[...] == 10 v_low = nnx.vars_as(v_hi, hijax=False) assert not v_low.hijax assert v_low[...] == 10 def test_pytree_value(self): v = nnx.Variable({'a': jnp.array(0), 'b': jnp.array(2)}, hijax=True) @jax.jit def inc_and_double(v): v['a'] += 1 v['b'] *= 2 inc_and_double(v) self.assertEqual(v['a'], 1) self.assertEqual(v['b'], 4) def test_hijax_dynamic_structure(self): x = jnp.ones((4, 5)) metrics = nnx.Variable({}, hijax=True) @jax.jit def f(x, metrics: nnx.Variable): metrics['x_sum'] = jnp.sum(x) self.assertEmpty(metrics) f(x, metrics) self.assertIn('x_sum', metrics) self.assertEqual(metrics['x_sum'], 20) def test_hijax_and_pytree(self): class Foo(nnx.Pytree): def __init__(self, din, dout, rngs: nnx.Rngs): self.w = nnx.Param(rngs.uniform((din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) self.count = nnx.Variable(0) foo = Foo(2, 4, nnx.Rngs(1)) assert not foo.w.hijax assert not foo.b.hijax foo = nnx.vars_as(foo, hijax=True) assert foo.w.hijax assert foo.b.hijax @jax.jit def forward(foo, x): foo.count[...] += 1 return x @ foo.w + foo.b[None] x = jnp.ones((1, 2)) y = forward(foo, x) assert y.shape == (1, 4) assert foo.count[...] == 1 def test_use_hijax(self): v_low = nnx.Param(1, a='hi') self.assertFalse(v_low.hijax) v_hi = nnx.Param(1, a='hi', hijax=True) self.assertTrue(v_hi.hijax) with nnx.var_defaults(hijax=True): v2 = nnx.Param(1, a='hi') self.assertIs(type(v2), nnx.variablelib.HijaxVariable) self.assertTrue(v2.hijax) @nnx.var_defaults(hijax=True) def test_hijax_rngs(self): rngs = nnx.Rngs(0) self.assertIs(type(rngs.default.key), nnx.variablelib.HijaxVariable) self.assertIs(type(rngs.default.count), nnx.variablelib.HijaxVariable) @jax.jit def f(rngs: nnx.Rngs): return rngs() k1 = f(rngs) k2 = f(rngs) assert k1 != k2 @absltest.skip(reason='not yet supported') def test_return_hijax_from_transform(self): @jax.jit def create_var(): return nnx.Param(1, hijax=True) v = create_var() self.assertTrue(v.hijax) @absltest.skip('not yet supported') @nnx.var_defaults(hijax=True) def test_lower(self): v = nnx.Param(jnp.ones((2, 3))) @jax.jit def f(v): v[...] += 1 return v[...] e = f.lower(v) y = e.out_info[2] self.assertEqual(y.shape, ()) @nnx.var_defaults(hijax=True) def test_eval_shape(self): v = nnx.Param(jnp.array(0)) def f(v): v[...] += 1 return v[...] y = jax.eval_shape(f, v) self.assertEqual(y.shape, ()) @nnx.var_defaults(hijax=True) def test_no_qdd_grad(self): v = nnx.Param(jnp.array(3.0), hijax=False) self.assertFalse(v.hijax) def f(v): return v[...] ** 2 grad = jax.grad(f)(v) self.assertIsInstance(grad, nnx.Param) self.assertEqual(grad[...], 6.0) @nnx.var_defaults(hijax=True) def test_no_qdd_grad_new(self): x = jnp.array(3.0) def f(x): v = nnx.Param(x, hijax=False) self.assertFalse(v.hijax) return v[...] ** 2 grad = jax.grad(f)(x) self.assertIsInstance(grad, jax.Array) self.assertEqual(grad, 6.0) @parameterized.product( hijax=[True, False], ref=[True, False], ) def test_variable_properties(self, hijax, ref): v = nnx.Variable(jnp.array(1), hijax=hijax, ref=ref) self.assertEqual(v.hijax, hijax) self.assertEqual(v.ref, ref) if hijax: self.assertIsInstance(v, nnx.variablelib.HijaxVariable) else: self.assertNotIsInstance(v, nnx.variablelib.HijaxVariable) if ref: self.assertIsInstance(v.get_raw_value(), jax.Ref) else: self.assertNotIsInstance(v.get_raw_value(), jax.Ref) @parameterized.product( hijax=[True, False], ref=[True, False], ) def test_variable_copy_properties(self, hijax, ref): v_original = nnx.Variable(jnp.array(1)) v = v_original.copy(hijax=hijax, ref=ref) self.assertEqual(v.hijax, hijax) self.assertEqual(v.ref, ref) if hijax: self.assertIsInstance(v, nnx.variablelib.HijaxVariable) else: self.assertNotIsInstance(v, nnx.variablelib.HijaxVariable) if ref: self.assertIsInstance(v.get_raw_value(), jax.Ref) else: self.assertNotIsInstance(v.get_raw_value(), jax.Ref) @parameterized.product( hijax=[True, False], ref=[True, False], ) def test_variable_vars_as_properties(self, hijax, ref): v_original = nnx.Variable(jnp.array(1)) v = nnx.vars_as(v_original, hijax=hijax, ref=ref) self.assertEqual(v.hijax, hijax) self.assertEqual(v.ref, ref) if hijax: self.assertIsInstance(v, nnx.variablelib.HijaxVariable) else: self.assertNotIsInstance(v, nnx.variablelib.HijaxVariable) if ref: self.assertIsInstance(v.get_raw_value(), jax.Ref) else: self.assertNotIsInstance(v.get_raw_value(), jax.Ref) class TestVarDefaults(absltest.TestCase): def test_defaults(self): defaults = nnx.var_defaults() self.assertIsInstance(defaults, nnx.variablelib.VarDefaults) # Default values might depend on config/env, but generally hijax=False, ref=False initially self.assertFalse(defaults.hijax) self.assertFalse(defaults.ref) def test_context_manager_hijax(self): with nnx.var_defaults(hijax=True): self.assertTrue(nnx.var_defaults().hijax) v = nnx.Variable(1) self.assertTrue(v.hijax) self.assertFalse(nnx.var_defaults().hijax) v2 = nnx.Variable(1) self.assertFalse(v2.hijax) def test_context_manager_ref(self): with nnx.var_defaults(ref=True): self.assertTrue(nnx.var_defaults().ref) v = nnx.Variable(1) self.assertTrue(v.ref) self.assertFalse(nnx.var_defaults().ref) v2 = nnx.Variable(1) self.assertFalse(v2.ref) def test_context_manager_nested(self): with nnx.var_defaults(hijax=True, ref=False): self.assertTrue(nnx.var_defaults().hijax) self.assertFalse(nnx.var_defaults().ref) with nnx.var_defaults(ref=True): self.assertTrue(nnx.var_defaults().hijax) self.assertTrue(nnx.var_defaults().ref) with nnx.var_defaults(hijax=False): self.assertFalse(nnx.var_defaults().hijax) self.assertTrue(nnx.var_defaults().ref) self.assertTrue(nnx.var_defaults().hijax) self.assertFalse(nnx.var_defaults().ref) def test_mapping_protocol(self): defaults = nnx.var_defaults() self.assertIn('hijax', defaults) self.assertIn('ref', defaults) self.assertEqual(len(defaults), 2) self.assertEqual(list(defaults), ['hijax', 'ref']) self.assertEqual(defaults['hijax'], defaults.hijax) self.assertEqual(defaults['ref'], defaults.ref) def test_decorator(self): @nnx.var_defaults(hijax=True, ref=True) def f(): return nnx.var_defaults() defaults = f() self.assertTrue(defaults.hijax) self.assertTrue(defaults.ref) self.assertFalse(nnx.var_defaults().hijax) def test_variable_init_override(self): with nnx.var_defaults(hijax=True): v = nnx.Variable(1, hijax=False) self.assertFalse(v.hijax) with nnx.var_defaults(ref=True): v = nnx.Variable(1, ref=False) self.assertFalse(v.ref) class HijaxTransformCoverageTest(absltest.TestCase): # ------------ # grad # ------------ # with differentiable hijax arguments (immutable variable) def test_hitypes_as_grad_args(self): v = nnx.Variable((jnp.array(2.0), jnp.array(3.0)), hijax=False) def loss_fn(v): x = v[0] return x ** 2 grads = jax.grad(loss_fn)(v) np.testing.assert_allclose(grads[0], 4.0) def test_hitypes_as_nondiff_grad_args(self): v = nnx.Variable((jnp.array(2.0), jnp.array(3.0)), hijax=False) x = jnp.array(3.0) def loss_fn(x, v): y = v[1] return x ** 2 + y grad = jax.grad(loss_fn)(x, v) np.testing.assert_allclose(grad, 6.0) def test_hitypes_as_captured_args(self): v = nnx.Variable((jnp.array(2.0), jnp.array(3.0)), hijax=False) def loss_fn(x): y = v[1] return x ** 2 + y grad = jax.grad(loss_fn)(jnp.array(4.0)) np.testing.assert_allclose(grad, 8.0) # with differentiable mutable hijax arguments @absltest.skip("Not yet implemented") def test_mutable_hitypes_as_grad_args(self): v = nnx.Variable(jnp.array(2.0), hijax=True) def loss_fn(v): return v[...] ** 2 grads = jax.grad(loss_fn)(v) # NOTE: unclear what the tangent type will be here # with non-differentiable mutable hijax arguments def test_mutable_hitypes_as_nondiff_grad_args(self): v = nnx.Variable(jnp.array(2.0), hijax=True) x = jnp.array(3.0) def loss_fn(x, v): v[...] = jax.lax.stop_gradient(x * 2) return x ** 2 + v[...] grad = jax.grad(loss_fn)(x, v) np.testing.assert_allclose(v[...], 6.0) np.testing.assert_allclose(grad, 6.0) # with mutable hijax captured arguments def test_mutable_hitypes_as_captured_args(self): v = nnx.Variable(jnp.array(2.0), hijax=True) def loss_fn(x): v[...] = jax.lax.stop_gradient(x * 3) return x ** 2 + v[...] grad = jax.grad(loss_fn)(jnp.array(4.0)) np.testing.assert_allclose(v[...], 12.0) np.testing.assert_allclose(grad, 8.0) #------------ # scan #------------ # with hijax carry arguments (immutable variable) @absltest.skip("scan not yet supported for hijax Variables") def test_hitypes_as_scan_carry(self): v = nnx.Variable((jnp.array(1.0), jnp.array(2.0)), hijax=True, mutable=False) def body(v, _): x, y = v return nnx.Variable((x + 1.0, y + 2.0), hijax=True, mutable=False), None v_out, _ = jax.lax.scan(body, v, None, length=5) x, y = v_out[...] np.testing.assert_allclose(x, 6.0) np.testing.assert_allclose(y, 12.0) # with hijax extensive arguments (immutable variable) @absltest.skip("scan not yet supported for hijax Variables") def test_hitypes_as_scan_extensive(self): v = nnx.Variable((jnp.arange(5), -jnp.arange(5)), hijax=True, mutable=False) def body(_, v_i): x, y = v_i v_i = nnx.Variable((x * 2, y * 2), hijax=True, mutable=False) return None, v_i _, v_out = jax.lax.scan(body, None, v) x, y = v_out np.testing.assert_allclose(x, jnp.arange(5) * 2) np.testing.assert_allclose(y, -jnp.arange(5) * 2) # with hijax captured arguments (immutable variable) @absltest.skip("scan not yet supported for hijax Variables") def test_hitypes_as_scan_captured(self): v = nnx.Variable((jnp.array(3.0), jnp.array(4.0)), hijax=True, mutable=False) carry0 = jnp.array(1.0) xs = jnp.arange(5, dtype=jnp.float32) def body(carry, x): a, b = v carry = a * carry + b y = a * x + b return carry, nnx.Variable(y, hijax=True, mutable=False) carry, ys_v = jax.lax.scan(body, carry0, xs) ys = ys_v[...] np.testing.assert_allclose(carry, 727.0) np.testing.assert_allclose(ys, 3.0 * xs + 4.0) # with mutable hijax carry arguments @absltest.skip("has_qdd not yet supported for mutable Variable in scan carry") def test_mutable_hitypes_as_scan_carry(self): v = nnx.Variable(jnp.array(1.0), hijax=True) def body(v, _): v[...] = v[...] * 2 return v, None v_out, _ = jax.lax.scan(body, v, None, length=5) np.testing.assert_allclose(v_out[...], 32.0) # with mutable hijax extensive arguments @absltest.skip("Variable doesn't have shape attribute needed for scan extensive") def test_mutable_hitypes_as_scan_extensive(self): vs = [nnx.Variable(jnp.float32(i), hijax=True) for i in range(5)] def body(_, v_i): val = v_i[...] v_i[...] = val * 2 return None, v_i _, vs_out = jax.lax.scan(body, None, vs) for i, v in enumerate(vs_out): np.testing.assert_allclose(v[...], i * 2) # with mutable hijax captured arguments def test_mutable_hitypes_as_scan_captured(self): v = nnx.Variable(jnp.array(3.0), hijax=True) def body(_, __): v[...] = v[...] + 1.0 return None, None jax.lax.scan(body, None, None, length=5) np.testing.assert_allclose(v[...], 8.0) def test_hijax_variable_in_jit_graph_updates_false(self): v = nnx.Variable(jnp.array(1.0), hijax=True) @nnx.jit(graph=True, graph_updates=False) def f(v, v2): self.assertIs(v, v2) v[...] += 1.0 return v[...] * 2 y = f(v, v) np.testing.assert_allclose(v[...], 2.0) np.testing.assert_allclose(y, 4.0) y = f(v, v) np.testing.assert_allclose(v[...], 3.0) np.testing.assert_allclose(y, 6.0) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/nn/attention_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax, jax.numpy as jnp from jax.lax import Precision from flax import linen from flax import nnx from flax.nnx.nn.attention import combine_masks from flax.typing import Dtype, PrecisionLike import numpy as np import typing as tp from absl.testing import parameterized from absl.testing import absltest try: # JAX v0.8.0 and newer from jax import enable_x64 except ImportError: from jax.experimental import enable_x64 class TestMultiHeadAttention(parameterized.TestCase): def test_basic(self): module = nnx.MultiHeadAttention( num_heads=2, in_features=3, qkv_features=6, out_features=6, rngs=nnx.Rngs(0), ) y = module(jnp.ones((1, 7, 3)), decode=False) assert y.shape == (1, 7, 6) def test_multihead_sow_attention_weights(self): class Model(nnx.Module): attention_kwargs: dict def __init__(self, attention_kwargs, rng): self.attention_layers = nnx.data([ nnx.MultiHeadAttention(**attention_kwargs, rngs=rng) for i in range(3) ]) def __call__(self, x, sow_weights=False): x = self.attention_layers[0](x, sow_weights=sow_weights) x = self.attention_layers[1](x) x = self.attention_layers[2](x, sow_weights=sow_weights) return x rng = nnx.Rngs(0) x = jnp.ones((4, 6, 8)) module = Model( dict( in_features=8, num_heads=8, kernel_init=nnx.initializers.ones_init(), bias_init=nnx.initializers.zeros_init(), deterministic=False, ), rng, ) module.set_attributes(decode=False) _, intermediates = nnx.capture(module, nnx.Intermediate)(x, True) assert intermediates['attention_layers'][0]['attention_weights'][ 0 ].shape == (4, 8, 6, 6) assert 1 not in intermediates['attention_layers'] assert intermediates['attention_layers'][2]['attention_weights'][ 0 ].shape == (4, 8, 6, 6) _, intermediates = nnx.capture(module, nnx.Intermediate)(x) assert not intermediates # empty def test_autoregressive_decode_with_x64(self): with enable_x64(): x = jnp.ones((1, 4, 4)) module = nnx.MultiHeadAttention( in_features=4, num_heads=2, qkv_features=4, decode=True, rngs=nnx.Rngs(0), ) module.init_cache(x.shape, dtype=x.dtype) assert module.cached_key.shape == (1, 4, 2, 2) assert module.cached_value.shape == (1, 4, 2, 2) y1 = module(x[:, :1, :]) y2 = module(x[:, 1:2, :]) assert y1.shape == (1, 1, 4) assert y2.shape == (1, 1, 4) @parameterized.product(keep_rngs=[True, False]) def test_keep_rngs(self, keep_rngs): rngs = nnx.Rngs(42) module = nnx.MultiHeadAttention( in_features=4, num_heads=2, qkv_features=4, decode=True, rngs=rngs, dropout_rate=0.5, keep_rngs=keep_rngs ) if keep_rngs: assert module.rngs is not None else: assert module.rngs is None if keep_rngs: _, _, nondiff = nnx.split(module, nnx.Param, ...) assert isinstance(nondiff['rngs']['count'], nnx.RngCount) assert isinstance(nondiff['rngs']['key'], nnx.RngKey) else: nnx.split(module, nnx.Param) @parameterized.product(use_padding=[True, False], is_cross_attention=[True, False]) def test_causal_mask_equivalence( self, use_padding: bool, is_cross_attention: bool ): batch_size = 1 num_heads = 2 q_len = 2 kv_len = 4 if is_cross_attention else q_len head_dim = 4 q = jax.random.normal( key=jax.random.key(0), shape=(batch_size, 1, q_len, num_heads, head_dim) ) k = jax.random.normal( key=jax.random.key(1), shape=(batch_size, 1, kv_len, num_heads, head_dim) ) v = jax.random.normal( key=jax.random.key(2), shape=(batch_size, 1, kv_len, num_heads, head_dim) ) causal_mask = jnp.tril(jnp.ones( shape=(q_len, kv_len), dtype=jnp.bool_ ) ) causal_mask = jnp.broadcast_to( array=causal_mask, shape=(batch_size, 1, num_heads, q_len, kv_len) ) padding_mask = None if use_padding: padding_mask = jnp.ones( shape=(batch_size, 1, 1, q_len, kv_len), dtype=jnp.bool_, ) padding_mask = padding_mask.at[..., -2:].set(False) manual_mask = combine_masks(padding_mask, causal_mask, dtype=q.dtype) # Jax.nn path with precombined mask and is_causal = False attn_jax = nnx.dot_product_attention( query=q, key=k, value=v, mask=manual_mask, is_causal=False, deterministic=True, module=None, ) class DummyModule(nnx.Module): pass # nnx path with padding mask and is_causal = True (internally combines them) dummy = DummyModule() def _run(m): return nnx.dot_product_attention( query=q, key=k, value=v, mask=padding_mask, is_causal=True, deterministic=True, module=m, ) attn_manual, _ = nnx.capture(_run, nnx.Intermediate)(dummy) np.testing.assert_allclose(attn_jax, attn_manual, atol=1e-6) # TODO: add all possible constructor argument values to parameterized.product class TestLinenConsistency(parameterized.TestCase): @parameterized.product( use_bias=[True, False], dtype=[jnp.float32, jnp.float16], param_dtype=[jnp.float32, jnp.float16], precision=[Precision.DEFAULT, Precision.HIGH, Precision.HIGHEST], decode=[True, False], normalize_qk=[True, False], ) def test_nnx_attention_equivalence( self, use_bias: bool, dtype: tp.Optional[Dtype], param_dtype: Dtype, precision: PrecisionLike, decode: bool, normalize_qk: bool, ): key = jax.random.key(42) rngs = nnx.Rngs(42) num_heads = 2 in_features = 3 qkv_features = 6 out_features = 6 x = jax.numpy.ones((1, in_features)) model_nnx = nnx.MultiHeadAttention( num_heads=num_heads, in_features=in_features, qkv_features=qkv_features, out_features=out_features, use_bias=use_bias, dtype=dtype, param_dtype=param_dtype, precision=precision, decode=decode, normalize_qk=normalize_qk, rngs=rngs, ) model = linen.MultiHeadDotProductAttention( num_heads=num_heads, qkv_features=qkv_features, out_features=out_features, use_bias=use_bias, dtype=dtype, param_dtype=param_dtype, precision=precision, decode=decode, normalize_qk=normalize_qk, ) variables = model.init(key, x) for qkvo in ('query', 'key', 'value', 'out'): getattr(model_nnx, qkvo).kernel[...] = variables['params'][qkvo]['kernel'] if use_bias: getattr(model_nnx, qkvo).bias[...] = variables['params'][qkvo]['bias'] if decode: model_nnx.init_cache(x.shape, dtype=dtype) out_nnx = model_nnx(x) out, cache = model.apply(variables, x, mutable=['cache']) np.testing.assert_array_equal(out, out_nnx) class TestKVFeatures(parameterized.TestCase): def test_varying_num_features(self): key = jax.random.key(42) rngs = nnx.Rngs(42) num_heads = 2 in_features = 3 in_kv_features = 4 qkv_features = 6 out_features = 6 x = jax.numpy.ones((1, in_features)) y = jax.random.normal(key, (1, in_kv_features)) layer = nnx.MultiHeadAttention( num_heads=num_heads, in_features=in_features, qkv_features=qkv_features, out_features=out_features, in_kv_features=in_kv_features, rngs=rngs, decode=False ) self.assertIsNotNone(layer(x, y)) class TestGQADotProductAttention(parameterized.TestCase): def test_gqa_shapes(self): B, T, S = 2, 4, 5 D = 8 num_heads_q = 6 num_heads_kv = 3 k1, k2, k3 = jax.random.split(jax.random.key(0), 3) query = jax.random.normal(k1, (B, T, num_heads_q, D)) key = jax.random.normal(k2, (B, S, num_heads_kv, D)) value = jax.random.normal(k3, (B, S, num_heads_kv, D)) output = nnx.dot_product_attention(query, key, value) expected_shape = (B, T, num_heads_q, D) self.assertEqual(output.shape, expected_shape) def test_gqa_invalid_heads(self): B, T, D = 1, 4, 8 query = jnp.ones((B, T, 5, D)) key = jnp.ones((B, T, 2, D)) value = key with self.assertRaisesRegex(ValueError, "must be a multiple"): nnx.dot_product_attention(query, key, value) def test_gqa_multihead_attention(self): in_feat = 128 n_heads = 32 n_kv_heads = 8 qkv_feat = 2048 head_dim = qkv_feat // n_heads model = nnx.MultiHeadAttention( num_heads=n_heads, in_features=in_feat, qkv_features=qkv_feat, num_kv_heads=n_kv_heads, rngs=nnx.Rngs(0), ) assert model.query.kernel.shape == (in_feat, n_heads, head_dim) assert model.key.kernel.shape == (in_feat, n_kv_heads, head_dim) assert model.value.kernel.shape == (in_feat, n_kv_heads, head_dim) x = jnp.ones((1, 10, in_feat)) y = model(x, decode=False) assert y.shape == (1, 10, in_feat) model.init_cache((1, 10, in_feat)) assert model.cached_key.shape == (1, 10, n_kv_heads, head_dim) x_token = jnp.ones((1, 1, in_feat)) y_token = model(x_token, decode=True) assert y_token.shape == (1, 1, in_feat) assert model.cache_index == 1 def test_gqa_parity_with_jax(self): class DummyModule(nnx.Module): pass dummy_module = DummyModule() B, T, S, D = 2, 8, 8, 16 num_heads_q = 4 num_heads_kv = 2 rng = jax.random.key(42) k1, k2, k3 = jax.random.split(rng, 3) query = jax.random.normal(k1, (B, T, num_heads_q, D)) key = jax.random.normal(k2, (B, S, num_heads_kv, D)) value = jax.random.normal(k3, (B, S, num_heads_kv, D)) jax_out = jax.nn.dot_product_attention(query, key, value) # NNX should handle broadcasting internally def _run(m): return nnx.dot_product_attention(query, key, value, module=m) nnx_out, _ = nnx.capture(_run, nnx.Intermediate)(dummy_module) np.testing.assert_allclose(nnx_out, jax_out, atol=1e-3, rtol=1e-3) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/nn/conv_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence from functools import partial import typing as tp import jax from absl.testing import absltest from absl.testing import parameterized from jax import numpy as jnp from jax.lax import Precision import numpy as np from flax import linen from flax import nnx from flax.typing import PaddingLike, Dtype, PrecisionLike class TestConvLinenConsistency(parameterized.TestCase): @parameterized.product( strides=[None, (2, 3)], padding=['VALID', 'CIRCULAR', 'REFLECT', (4, 2)], input_dilation=[(2, 3)], kernel_dilation=[(2, 3)], feature_group_count=[3], use_bias=[True, False], use_mask=[False, True], dtype=[jnp.float32], param_dtype=[jnp.float16], precision=[Precision.HIGHEST], preferred_element_type=[None, jnp.float32], ) def test_nnx_linen_conv_equivalence( self, strides: tp.Union[None, int, tp.Sequence[int]], padding: PaddingLike, input_dilation: tp.Union[None, int, tp.Sequence[int]], kernel_dilation: tp.Union[None, int, tp.Sequence[int]], feature_group_count: int, use_bias: bool, use_mask: bool, dtype: tp.Optional[Dtype], param_dtype: Dtype, precision: PrecisionLike, preferred_element_type: tp.Optional[Dtype], ): key = jax.random.key(42) rngs = nnx.Rngs(42) IN_FEATURES = 3 OUT_FEATURES = 6 INPUT_SHAPE = (24, 9, IN_FEATURES) kernel_size = (7, 4) if use_mask: mask = jnp.tril(jnp.ones((7, 4, 1, 6))) else: mask = None # Cannot use string padding specification for transpose conv if isinstance(input_dilation, Sequence) or ( isinstance(input_dilation, int) and input_dilation > 1 ): padding = (4, 2) x = jax.numpy.ones(INPUT_SHAPE) model_nnx = nnx.Conv( IN_FEATURES, OUT_FEATURES, kernel_size, strides, padding=padding, input_dilation=input_dilation, kernel_dilation=kernel_dilation, feature_group_count=feature_group_count, use_bias=use_bias, mask=mask, dtype=dtype, param_dtype=param_dtype, precision=precision, preferred_element_type=preferred_element_type, rngs=rngs, ) if preferred_element_type is not None: conv_general_dilated = partial( jax.lax.conv_general_dilated, preferred_element_type=preferred_element_type, ) else: conv_general_dilated = None model = linen.Conv( OUT_FEATURES, kernel_size=kernel_size, strides=strides, padding=padding, input_dilation=input_dilation, kernel_dilation=kernel_dilation, feature_group_count=feature_group_count, use_bias=use_bias, mask=mask, dtype=dtype, param_dtype=param_dtype, precision=precision, conv_general_dilated=conv_general_dilated, ) variables = model.init(key, x) model_nnx.kernel[...] = variables['params']['kernel'] if use_bias: model_nnx.bias[...] = variables['params']['bias'] out_nnx = model_nnx(x) out = model.apply(variables, x) assert isinstance(out, jax.Array) np.testing.assert_array_equal(out, out_nnx) @parameterized.product( strides=[None, (2, 3)], padding=['VALID', 'CIRCULAR', (4, 2)], kernel_dilation=[(2, 3)], use_bias=[True, False], use_mask=[False, True], dtype=[jnp.float32], param_dtype=[jnp.float16], precision=[Precision.HIGHEST], preferred_element_type=[None, jnp.float32], ) def test_nnx_linen_convtranspose_equivalence( self, strides: tp.Union[None, tp.Sequence[int]], padding: PaddingLike, kernel_dilation: tp.Union[None, tp.Sequence[int]], use_bias: bool, use_mask: bool, dtype: tp.Optional[Dtype], param_dtype: Dtype, precision: PrecisionLike, preferred_element_type: tp.Optional[Dtype], ): key = jax.random.key(42) rngs = nnx.Rngs(42) IN_FEATURES = 3 OUT_FEATURES = 6 INPUT_SHAPE = (24, 9, IN_FEATURES) kernel_size = (7, 4) if use_mask: mask = jnp.tril(jnp.ones((7, 4, 3, 6))) else: mask = None x = jax.numpy.ones(INPUT_SHAPE) model_nnx = nnx.ConvTranspose( IN_FEATURES, OUT_FEATURES, kernel_size, strides, padding=padding, kernel_dilation=kernel_dilation, use_bias=use_bias, mask=mask, dtype=dtype, param_dtype=param_dtype, precision=precision, preferred_element_type=preferred_element_type, rngs=rngs, ) model = linen.ConvTranspose( OUT_FEATURES, kernel_size=kernel_size, strides=strides, padding=padding, kernel_dilation=kernel_dilation, use_bias=use_bias, mask=mask, dtype=dtype, param_dtype=param_dtype, precision=precision, preferred_element_type=preferred_element_type, ) variables = model.init(key, x) model_nnx.kernel[...] = variables['params']['kernel'] if use_bias: assert model_nnx.bias is not None model_nnx.bias[...] = variables['params']['bias'] out_nnx = model_nnx(x) out = model.apply(variables, x) assert isinstance(out, jax.Array) np.testing.assert_array_equal(out, out_nnx) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/nn/embed_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import typing as tp import jax from absl.testing import absltest from absl.testing import parameterized from jax import numpy as jnp import numpy as np from flax import linen from flax import nnx from flax.typing import Dtype class TestLinenConsistency(parameterized.TestCase): @parameterized.product( input_dtype=[jnp.int16, jnp.int32], num_embeddings=[1, 7], dtype=[jnp.float32, jnp.float16], param_dtype=[jnp.float32, jnp.float16], ) def test_nnx_linen_equivalence( self, input_dtype: tp.Optional[Dtype], num_embeddings: int, dtype: tp.Optional[Dtype], param_dtype: Dtype, ): key = jax.random.key(42) rngs = nnx.Rngs(42) IN_FEATURES = 32 NUM_EMBEDDINGS = num_embeddings x = jax.numpy.arange(NUM_EMBEDDINGS, dtype=input_dtype) model_nnx = nnx.eval_shape( lambda rngs: nnx.Embed( NUM_EMBEDDINGS, IN_FEATURES, dtype=dtype, param_dtype=param_dtype, rngs=rngs, ), rngs, ) model = linen.Embed( NUM_EMBEDDINGS, IN_FEATURES, dtype=dtype, param_dtype=param_dtype ) variables = model.init(key, x) model_nnx.embedding.set_value(variables['params']['embedding']) out_nnx = model_nnx(x) out = model.apply(variables, x) assert isinstance(out, jax.Array) np.testing.assert_array_equal(out, out_nnx) x = jax.numpy.ones((10,), dtype=input_dtype) * 10 out_nnx = model_nnx(x) out = model.apply(variables, x) assert isinstance(out, jax.Array) np.testing.assert_array_equal(out, out_nnx) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/nn/linear_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from functools import partial import typing as tp import jax import jax.numpy as jnp from absl.testing import absltest from absl.testing import parameterized from jax.lax import Precision import numpy as np from flax import linen from flax import nnx from flax.typing import Dtype, PrecisionLike, Shape class TestLinearGeneral(parameterized.TestCase): @parameterized.product( dtype=[jnp.float32, jnp.float16], param_dtype=[jnp.float32, jnp.float16], precision=[Precision.DEFAULT, Precision.HIGH, Precision.HIGHEST], preferred_element_type=[None, jnp.float32], ) def test_basic( self, dtype, param_dtype, precision, preferred_element_type, ): module = nnx.LinearGeneral( 2, 3, rngs=nnx.Rngs(0), dtype=dtype, param_dtype=param_dtype, precision=precision, preferred_element_type=preferred_element_type, ) y = module(jnp.ones((1, 2))) assert y.shape == (1, 3) if preferred_element_type is not None: assert y.dtype == preferred_element_type assert module.kernel.shape == (2, 3) assert module.kernel.dtype == param_dtype assert module.bias is not None assert module.bias.shape == (3,) def test_basic_multi_features(self): module = nnx.LinearGeneral(2, (3, 4), rngs=nnx.Rngs(0)) y = module(jnp.ones((1, 2))) assert y.shape == (1, 3, 4) assert module.kernel.shape == (2, 3, 4) assert module.bias is not None assert module.bias.shape == (3, 4) class TestLinenConsistency(parameterized.TestCase): @parameterized.product( use_bias=[True, False], dtype=[jnp.float32, jnp.float16], param_dtype=[jnp.float32, jnp.float16], precision=[Precision.DEFAULT, Precision.HIGH, Precision.HIGHEST], preferred_element_type=[None, jnp.float32], ) def test_nnx_linear_equivalence( self, use_bias: bool, dtype: tp.Optional[Dtype], param_dtype: Dtype, precision: PrecisionLike, preferred_element_type: tp.Optional[Dtype], ): key = jax.random.key(42) rngs = nnx.Rngs(42) IN_FEATURES = 32 OUT_FEATURES = 64 x = jax.numpy.ones((1, IN_FEATURES)) model_nnx = nnx.eval_shape( lambda rngs: nnx.Linear( IN_FEATURES, OUT_FEATURES, use_bias=use_bias, dtype=dtype, param_dtype=param_dtype, precision=precision, preferred_element_type=preferred_element_type, rngs=rngs, ), rngs, ) if preferred_element_type is not None: dot_general = partial( jax.lax.dot_general, preferred_element_type=preferred_element_type, ) else: dot_general = None model = linen.Dense( OUT_FEATURES, use_bias=use_bias, dtype=dtype, param_dtype=param_dtype, precision=precision, dot_general=dot_general, ) variables = model.init(key, x) model_nnx.kernel.set_value(variables['params']['kernel']) if use_bias: model_nnx.bias.set_value(variables['params']['bias']) out_nnx = model_nnx(x) out = model.apply(variables, x) assert isinstance(out, jax.Array) np.testing.assert_array_equal(out, out_nnx) @parameterized.product( einsum_str=['defab,bcef->adefc', 'd...ab,bc...->ad...c'], bias_shape=[None, (6, 7, 5)], dtype=[jnp.float32, jnp.float16], param_dtype=[jnp.float32, jnp.float16], precision=[Precision.DEFAULT, Precision.HIGH, Precision.HIGHEST], preferred_element_type=[None, jnp.float32], ) def test_nnx_einsum_equivalence( self, einsum_str, bias_shape: tp.Optional[Shape], dtype: tp.Optional[Dtype], param_dtype: Dtype, precision: PrecisionLike, preferred_element_type: tp.Optional[Dtype], ): key = jax.random.key(42) rngs = nnx.Rngs(42) INPUT_SHAPE = (8, 6, 7, 3, 4) KERNEL_SHAPE = (4, 5, 6, 7) x = jax.random.normal(key, INPUT_SHAPE) model_nnx = nnx.Einsum( einsum_str, KERNEL_SHAPE, bias_shape, dtype=dtype, param_dtype=param_dtype, precision=precision, preferred_element_type=preferred_element_type, rngs=rngs, ) model = linen.Einsum( KERNEL_SHAPE, einsum_str, use_bias=True if bias_shape is not None else False, dtype=dtype, param_dtype=param_dtype, precision=precision, preferred_element_type=preferred_element_type, ) variables = model.init(key, x) variables['params']['kernel'] = model_nnx.kernel[...] if bias_shape is not None: assert model_nnx.bias is not None variables['params']['bias'] = model_nnx.bias[...] out_nnx = model_nnx(x) out = model.apply(variables, x) assert isinstance(out, jax.Array) np.testing.assert_array_equal(out, out_nnx) variables = model.init(key, x) model_nnx.kernel.set_value(variables['params']['kernel']) if bias_shape is not None: assert model_nnx.bias is not None model_nnx.bias.set_value(variables['params']['bias']) out_nnx = model_nnx(x) out = model.apply(variables, x) assert isinstance(out, jax.Array) np.testing.assert_array_equal(out, out_nnx) def test_einsum_op(self): def custom_einsum(*args, **kwargs): out = jnp.einsum(*args, **kwargs) return out.reshape((1, *out.shape)) model = nnx.Einsum('ab,bc->ac', (3, 4), einsum_op=custom_einsum, rngs=nnx.Rngs(42)) y = model(jnp.ones((2, 3))) assert y.shape == (1, 2, 4) class TestPReLUConsistency(parameterized.TestCase): @parameterized.product( dtype=[jnp.float32, jnp.float16], param_dtype=[jnp.float32, jnp.float16], ) def test_equivalence(self, dtype, param_dtype): key = jax.random.key(42) x = jnp.linspace(-10, 10, 20, dtype=dtype) negative_slope_init = 0.02 nnx_prelu = nnx.PReLU(negative_slope_init=negative_slope_init, param_dtype=param_dtype) linen_prelu = linen.PReLU(negative_slope_init=negative_slope_init, param_dtype=param_dtype) variables = linen_prelu.init(key, x) expected = linen_prelu.apply(variables, x) output = nnx_prelu(x) np.testing.assert_array_equal(output, expected) # Check gradients @jax.jit def nnx_loss_function(model): y = model(x) return y.mean() @jax.jit def linen_loss_function(variables): y = linen_prelu.apply(variables, x) return y.mean() expected_loss, expected_grads = jax.value_and_grad(linen_loss_function)(variables) loss, grads = jax.value_and_grad(nnx_loss_function)(nnx_prelu) np.testing.assert_array_equal(loss, expected_loss) np.testing.assert_array_equal( expected_grads['params']['negative_slope'], grads.negative_slope[...] ) class TestLayersSameGraph(parameterized.TestCase): @parameterized.product( module_args_kwargs_initargs=[ (nnx.LinearGeneral, (2, (3, 4)), ("kernel_init", "bias_init")), (nnx.Linear, (2, 4), ("kernel_init", "bias_init")), (nnx.Einsum, ("ik,kj->ij", (5, 4), 5), ("kernel_init", "bias_init")), (nnx.Conv, (2, 4, 3), ("kernel_init", "bias_init")), (nnx.ConvTranspose, (2, 4, 3), ("kernel_init", "bias_init")), (nnx.Embed, (2, 4), ("embedding_init",)), ( nnx.MultiHeadAttention, (8, 5, 16), ("kernel_init", "out_kernel_init", "bias_init", "out_bias_init"), ), (nnx.BatchNorm, (3,), ("scale_init", "bias_init")), (nnx.LayerNorm, (3,), ("scale_init", "bias_init")), (nnx.RMSNorm, (3,), ("scale_init",)), (nnx.GroupNorm, (6, 3), ("scale_init", "bias_init")), (nnx.InstanceNorm, (6,), ("scale_init", "bias_init")), ( nnx.LSTMCell, (4, 5), ("kernel_init", "recurrent_kernel_init", "bias_init"), ), ( nnx.OptimizedLSTMCell, (4, 5), ("kernel_init", "recurrent_kernel_init", "bias_init"), ), ( nnx.SimpleCell, (4, 5), ("kernel_init", "recurrent_kernel_init", "bias_init"), ), ( nnx.GRUCell, (4, 5), ("kernel_init", "recurrent_kernel_init", "bias_init"), ), ], ) def test(self, module_args_kwargs_initargs): module_cls, args, init_argnames = module_args_kwargs_initargs kwargs = {"rngs": nnx.Rngs(0)} init_zeros = nnx.initializers.zeros init_ones = nnx.initializers.ones init1_kwargs = {k: init_zeros for k in init_argnames} init2_kwargs = {k: init_ones for k in init_argnames} mod1 = module_cls(*args, **init1_kwargs, **kwargs) mod2 = module_cls(*args, **init2_kwargs, **kwargs) g1, g2 = nnx.graphdef(mod1), nnx.graphdef(mod2) assert g1 == g2 class TestLayersParamsMetadata(parameterized.TestCase): @parameterized.product( module_args_kwargs_initargs=[ (nnx.LinearGeneral, (2, (3, 4)), (("kernel", 2, ()), ("bias", 1, ()))), (nnx.Linear, (2, 4), (("kernel", 2, ()), ("bias", 1, ()))), (nnx.Einsum, ("ik,kj->ij", (5, 4), 5), (("kernel", 2, ()), ("bias", 1, ()))), (nnx.Conv, (2, 4, 3), (("kernel", 2, ()), ("bias", 1, ()))), (nnx.ConvTranspose, (2, 4, 3), (("kernel", 2, ()), ("bias", 1, ()))), (nnx.Embed, (2, 4), (("embedding", 2, ()), )), ( partial(nnx.MultiHeadAttention, normalize_qk=True), (8, 5, 16), ( ("kernel", 2, (("query", "kernel"), ("key", "kernel"), ("value", "kernel"))), ("out_kernel", 2, (("out", "kernel"), )), ("bias", 1, (("query", "bias"), ("key", "bias"), ("value", "bias"))), ("out_bias", 1, (("out", "bias"), )), ("query_ln_scale", 1, (("query_ln", "scale"), )), ("key_ln_scale", 1, (("key_ln", "scale"), )), ), ), (nnx.BatchNorm, (3,), (("scale", 1, ()), ("bias", 1, ()))), (nnx.LayerNorm, (3,), (("scale", 1, ()), ("bias", 1, ()))), (nnx.RMSNorm, (3,), (("scale", 1, ()), )), (nnx.GroupNorm, (6, 3), (("scale", 1, ()), ("bias", 1, ()))), (nnx.InstanceNorm, (6,), (("scale", 1, ()), ("bias", 1, ()))), ( nnx.LoRA, (3, 2, 4), ( ("a", 2, ((None, "lora_a"), )), ("b", 2, ((None, "lora_b"), )), ) ), ( partial(nnx.LoRALinear, lora_rank=4), (3, 2), ( ("a", 2, (("lora", "lora_a"), )), ("b", 2, (("lora", "lora_b"), )), ) ), ( nnx.LSTMCell, (4, 5), ( ( "kernel", 2, ((name, "kernel") for name in ["ii", "if_", "ig", "io"]) ), ( "recurrent_kernel", 2, ((name, "kernel") for name in ["hi", "hf", "hg", "ho"]) ), ( "bias", 1, ((name, "bias") for name in ["hi", "hf", "hg", "ho"]) ), ), ), ( nnx.OptimizedLSTMCell, (4, 5), ( ("kernel", 2, (("dense_i", "kernel"), )), ("recurrent_kernel", 2, (("dense_h", "kernel"), )), ("bias", 1, (("dense_h", "bias"), )), ) ), ( nnx.SimpleCell, (4, 5), ( ("kernel", 2, (("dense_i", "kernel"), )), ("bias", 1, (("dense_i", "bias"), )), ("recurrent_kernel", 2, (("dense_h", "kernel"), )), ) ), ( nnx.GRUCell, (4, 5), ( ("kernel", 2, (("dense_i", "kernel"), )), ("bias", 1, (("dense_i", "bias"), )), ("recurrent_kernel", 2, (("dense_h", "kernel"), )), ) ), ], ) def test(self, module_args_kwargs_initargs): module_cls, args, metadata_argnames = module_args_kwargs_initargs kwargs = {"rngs": nnx.Rngs(0)} out_sharding = ("din", "dout") metadata_kwargs = { f"{key}_metadata": {"out_sharding": out_sharding[:le]} for key, le, _ in metadata_argnames } mesh = jax.make_mesh( (1, 1), out_sharding, axis_types=(jax.sharding.AxisType.Auto,) * len(out_sharding), ) with jax.set_mesh(mesh): module = module_cls(*args, **metadata_kwargs, **kwargs) for key, le, attrs in metadata_argnames: attrs = attrs if attrs else ((None, key), ) for attr_name, param_name in attrs: attr = getattr(module, attr_name) if attr_name is not None else module param = getattr(attr, param_name) self.assertIsNotNone(param.out_sharding) self.assertEqual(param.out_sharding, out_sharding[:le]) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/nn/lora_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax from absl.testing import absltest import numpy as np from flax import nnx from jax import numpy as jnp class TestLora(absltest.TestCase): def test_basic(self): module = nnx.LoRA(3, 2, 4, rngs=nnx.Rngs(0)) x = jax.random.normal(jax.random.key(0), (1, 3)) y = module(x) assert y.shape == (1, 4) assert module.lora_a.shape == (3, 2) assert module.lora_b.shape == (2, 4) np.testing.assert_allclose(y, x @ module.lora_a @ module.lora_b) def test_lora_base_module(self): rngs = nnx.Rngs(0) linear = nnx.Linear(3, 4, use_bias=False, rngs=rngs) module = nnx.LoRA(3, 2, 4, base_module=linear, rngs=rngs) x = jax.random.normal(jax.random.key(0), (1, 3)) y = module(x) assert y.shape == (1, 4) assert module.base_module == linear assert module.base_module.kernel.shape == (3, 4) assert module.base_module.bias == None assert module.lora_a.shape == (3, 2) assert module.lora_b.shape == (2, 4) np.testing.assert_allclose( y, x @ linear.kernel + x @ module.lora_a @ module.lora_b ) def test_layer_swap_lora(self): class MLP(nnx.Module): def __init__(self, dim, rngs: nnx.Rngs): self.linear1 = nnx.Linear(dim, dim, rngs=rngs) self.linear2 = nnx.Linear(dim, dim, rngs=rngs) def __call__(self, x): x = self.linear1(x) return self.linear2(x) rngs = nnx.Rngs(0) model = MLP(3, rngs=rngs) x = jax.random.normal(jax.random.key(0), (1, 3)) y = model(x) # Replace one of the linear layers as LoRA linear layer. model.linear2 = nnx.LoRA(3, 4, 3, base_module=model.linear2, rngs=rngs) lora_y = model(x) assert y.shape == (1, 3) assert lora_y.shape == (1, 3) np.testing.assert_allclose(y, lora_y) a, b = model.linear2.lora_a[...], model.linear2.lora_b[...] np.testing.assert_allclose(y + model.linear1(x) @ a @ b, lora_y) def test_layer_swap_loralinear(self): class MLP(nnx.Module): def __init__(self, dim, rngs: nnx.Rngs): self.linear1 = nnx.Linear(dim, dim, rngs=rngs) self.linear2 = nnx.Linear(dim, dim, rngs=rngs) def __call__(self, x): x = self.linear1(x) return self.linear2(x) rngs = nnx.Rngs(0) model = MLP(3, rngs=rngs) x = jax.random.normal(jax.random.key(0), (1, 3)) y = model(x) # Replace one of the linear layers as LoRA linear layer. _, state = nnx.split( model.linear2 ) # To keep the kernel and bias of linear2 model.linear2 = nnx.LoRALinear(3, 3, lora_rank=4, rngs=rngs) nnx.update(model.linear2, state) lora_y = model(x) assert y.shape == (1, 3) assert lora_y.shape == (1, 3) np.testing.assert_allclose(y, lora_y) a, b = model.linear2.lora.lora_a[...], model.linear2.lora.lora_b[...] np.testing.assert_allclose(y + model.linear1(x) @ a @ b, lora_y) def test_lora_param_type(self): rngs = nnx.Rngs(0) model = nnx.LoRA(3, 4, 2, lora_param_type=nnx.LoRAParam, rngs=rngs) _, lora_params, params = nnx.split(model, nnx.LoRAParam, nnx.Param) assert params == {} assert ('lora_a' in lora_params) and ('lora_b' in lora_params) np.testing.assert_allclose(lora_params['lora_a'][...], model.lora_a[...]) model = nnx.LoRA(3, 4, 2, lora_param_type=nnx.Param, rngs=rngs) _, params, lora_params = nnx.split(model, nnx.Param, nnx.LoRAParam) assert ('lora_a' in params) and ('lora_b' in params) np.testing.assert_allclose(params['lora_a'][...], model.lora_a[...]) assert lora_params == {} def test_dtype(self): rngs = nnx.Rngs(0) model = nnx.LoRA(3, 4, 2, dtype=jnp.float16, param_dtype=jnp.float32, rngs=rngs) assert model.lora_a.dtype == jnp.float32 y = model(jnp.ones((1, 3)).astype(jnp.float32)) assert y.dtype == jnp.float16 if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/nn/normalization_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import typing as tp import jax import jax.numpy as jnp from absl.testing import absltest from absl.testing import parameterized import numpy as np from flax import linen from flax import nnx from flax.typing import Dtype class TestLinenConsistency(parameterized.TestCase): @parameterized.product( dtype=[jnp.float32, jnp.float16], param_dtype=[jnp.float32, jnp.float16], use_fast_variance=[True, False], mask=[None, np.array([True, False, True, False, True])], ) def test_nnx_linen_batchnorm_equivalence( self, dtype: tp.Optional[Dtype], param_dtype: Dtype, use_fast_variance: bool, mask: tp.Optional[np.ndarray], ): class NNXModel(nnx.Module): def __init__(self, dtype, param_dtype, use_fast_variance, rngs): self.norm_layer = nnx.BatchNorm( 5, use_running_average=False, dtype=dtype, param_dtype=param_dtype, use_fast_variance=use_fast_variance, promote_dtype=lambda x, **kwargs: x, # ensure same behavior as Linen rngs=rngs, ) self.linear = nnx.Linear( 5, 4, dtype=dtype, param_dtype=param_dtype, rngs=rngs ) def __call__(self, x, *, mask=None): x = self.norm_layer(x, mask=mask) x = self.linear(x) return x class LinenModel(linen.Module): dtype: tp.Optional[Dtype] = None param_dtype: Dtype = jnp.float32 use_fast_variance: bool = True def setup(self): self.norm_layer = linen.BatchNorm( use_running_average=False, dtype=self.dtype, param_dtype=self.param_dtype, use_fast_variance=use_fast_variance, ) self.linear = linen.Dense( 4, dtype=self.dtype, param_dtype=self.param_dtype ) def __call__(self, x, *, mask=None): x = self.norm_layer(x, mask=mask) x = self.linear(x) return x rngs = nnx.Rngs(42) x = jax.random.normal(jax.random.key(0), (10, 5)) linen_model = LinenModel( dtype=dtype, param_dtype=param_dtype, use_fast_variance=use_fast_variance ) variables: dict = linen_model.init(jax.random.key(1), x) linen_out, batch_stats = linen_model.apply( variables, x, mask=mask, mutable=['batch_stats'] ) nnx_model = NNXModel( dtype=dtype, param_dtype=param_dtype, use_fast_variance=use_fast_variance, rngs=rngs, ) nnx_model.linear.kernel[...] = variables['params']['linear']['kernel'] nnx_model.linear.bias[...] = variables['params']['linear']['bias'] linen_out, updates = linen_model.apply( variables, x, mask=mask, mutable=['batch_stats'] ) variables.update(updates) nnx_out = nnx_model(x, mask=mask) np.testing.assert_array_equal(linen_out, nnx_out) # Compare BatchNorm parameters np.testing.assert_array_equal( variables['params']['norm_layer']['scale'], nnx_model.norm_layer.scale[...], ) np.testing.assert_array_equal( variables['params']['norm_layer']['bias'], nnx_model.norm_layer.bias[...] ) np.testing.assert_array_equal( variables['batch_stats']['norm_layer']['mean'], nnx_model.norm_layer.mean[...], ) np.testing.assert_array_equal( variables['batch_stats']['norm_layer']['var'], nnx_model.norm_layer.var[...], ) @parameterized.product( dtype=[jnp.float32, jnp.float16], param_dtype=[jnp.float32, jnp.float16], use_fast_variance=[True, False], mask=[None, np.array([True, False, True, False, True])], ) def test_nnx_linen_layernorm_equivalence( self, dtype: tp.Optional[Dtype], param_dtype: Dtype, use_fast_variance: bool, mask: tp.Optional[np.ndarray], ): class NNXModel(nnx.Module): def __init__(self, dtype, param_dtype, use_fast_variance, rngs): self.norm_layer = nnx.LayerNorm( 5, dtype=dtype, param_dtype=param_dtype, use_fast_variance=use_fast_variance, promote_dtype=lambda x, **kwargs: x, # ensure same behavior as Linen rngs=rngs, ) self.linear = nnx.Linear( 5, 4, dtype=dtype, param_dtype=param_dtype, rngs=rngs ) def __call__(self, x, *, mask=None): x = self.norm_layer(x, mask=mask) x = self.linear(x) return x class LinenModel(linen.Module): dtype: tp.Optional[Dtype] = None param_dtype: Dtype = jnp.float32 use_fast_variance: bool = True def setup(self): self.norm_layer = linen.LayerNorm( dtype=self.dtype, param_dtype=self.param_dtype, use_fast_variance=self.use_fast_variance, ) self.linear = linen.Dense( 4, dtype=self.dtype, param_dtype=self.param_dtype ) def __call__(self, x, *, mask=None): x = self.norm_layer(x, mask=mask) x = self.linear(x) return x rngs = nnx.Rngs(42) x = jax.random.normal(jax.random.key(0), (10, 5)) linen_model = LinenModel( dtype=dtype, param_dtype=param_dtype, use_fast_variance=use_fast_variance ) variables = linen_model.init(jax.random.key(1), x) linen_out = linen_model.apply(variables, x, mask=mask) assert isinstance(linen_out, jax.Array) nnx_model = NNXModel( dtype=dtype, param_dtype=param_dtype, use_fast_variance=use_fast_variance, rngs=rngs, ) nnx_model.linear.kernel[...] = variables['params']['linear']['kernel'] nnx_model.linear.bias[...] = variables['params']['linear']['bias'] nnx_out = nnx_model(x, mask=mask) np.testing.assert_array_equal(linen_out, nnx_out) @parameterized.product( dtype=[jnp.float32, jnp.float16], param_dtype=[jnp.float32, jnp.float16], use_fast_variance=[True, False], mask=[None, np.array([True, False, True, False, True])], ) def test_nnx_linen_rmsnorm_equivalence( self, dtype: tp.Optional[Dtype], param_dtype: Dtype, use_fast_variance: bool, mask: tp.Optional[np.ndarray], ): class NNXModel(nnx.Module): def __init__(self, dtype, param_dtype, use_fast_variance, rngs): self.norm_layer = nnx.RMSNorm( 5, dtype=dtype, param_dtype=param_dtype, use_fast_variance=use_fast_variance, promote_dtype=lambda x, **kwargs: x, # ensure same behavior as Linen rngs=rngs, ) self.linear = nnx.Linear( 5, 4, dtype=dtype, param_dtype=param_dtype, rngs=rngs ) def __call__(self, x, *, mask=None): x = self.norm_layer(x, mask=mask) x = self.linear(x) return x class LinenModel(linen.Module): dtype: tp.Optional[Dtype] = None param_dtype: Dtype = jnp.float32 use_fast_variance: bool = True def setup(self): self.norm_layer = linen.RMSNorm( dtype=self.dtype, param_dtype=self.param_dtype, use_fast_variance=self.use_fast_variance, ) self.linear = linen.Dense( 4, dtype=self.dtype, param_dtype=self.param_dtype ) def __call__(self, x, *, mask=None): x = self.norm_layer(x, mask=mask) x = self.linear(x) return x rngs = nnx.Rngs(42) x = jax.random.normal(jax.random.key(0), (10, 5)) linen_model = LinenModel( dtype=dtype, param_dtype=param_dtype, use_fast_variance=use_fast_variance ) variables = linen_model.init(jax.random.key(1), x) linen_out = linen_model.apply(variables, x, mask=mask) nnx_model = NNXModel( dtype=dtype, param_dtype=param_dtype, use_fast_variance=use_fast_variance, rngs=rngs, ) nnx_model.linear.kernel.set_value(variables['params']['linear']['kernel']) nnx_model.linear.bias.set_value(variables['params']['linear']['bias']) nnx_out = nnx_model(x, mask=mask) assert isinstance(linen_out, jax.Array) np.testing.assert_array_equal(linen_out, nnx_out) @parameterized.product( dtype=[jnp.float32, jnp.float16], param_dtype=[jnp.float32, jnp.float16], use_fast_variance=[True, False], mask=[None, np.array([True, False, True, False, True, False])], ) def test_nnx_linen_groupnorm_equivalence( self, dtype: tp.Optional[Dtype], param_dtype: Dtype, use_fast_variance: bool, mask: tp.Optional[np.ndarray], ): class NNXModel(nnx.Module): def __init__(self, dtype, param_dtype, use_fast_variance, rngs): self.norm_layer = nnx.GroupNorm( 6, num_groups=3, dtype=dtype, param_dtype=param_dtype, use_fast_variance=use_fast_variance, promote_dtype=lambda x, **kwargs: x, # ensure same behavior as Linen rngs=rngs, ) self.linear = nnx.Linear( 6, 4, dtype=dtype, param_dtype=param_dtype, rngs=rngs ) def __call__(self, x, *, mask=None): x = self.norm_layer(x, mask=mask) x = self.linear(x) return x class LinenModel(linen.Module): dtype: tp.Optional[Dtype] = None param_dtype: Dtype = jnp.float32 use_fast_variance: bool = True def setup(self): self.norm_layer = linen.GroupNorm( num_groups=3, dtype=self.dtype, param_dtype=self.param_dtype, use_fast_variance=self.use_fast_variance, ) self.linear = linen.Dense( 4, dtype=self.dtype, param_dtype=self.param_dtype ) def __call__(self, x, *, mask=None): x = self.norm_layer(x, mask=mask) x = self.linear(x) return x rngs = nnx.Rngs(42) x = jax.random.normal(jax.random.key(0), (10, 6)) linen_model = LinenModel( dtype=dtype, param_dtype=param_dtype, use_fast_variance=use_fast_variance ) variables = linen_model.init(jax.random.key(1), x) linen_out = linen_model.apply(variables, x, mask=mask) nnx_model = NNXModel( dtype=dtype, param_dtype=param_dtype, use_fast_variance=use_fast_variance, rngs=rngs, ) nnx_model.linear.kernel[...] = variables['params']['linear']['kernel'] nnx_model.linear.bias[...] = variables['params']['linear']['bias'] nnx_out = nnx_model(x, mask=mask) assert isinstance(linen_out, jax.Array) np.testing.assert_array_equal(linen_out, nnx_out) @parameterized.product( dtype=[jnp.float32, jnp.float16], param_dtype=[jnp.float32, jnp.float16], scale_init=[ nnx.initializers.ones, nnx.initializers.constant(10.0), nnx.initializers.constant(0.5), ], ) def test_nnx_linen_weightnorm_equivalence( self, dtype: tp.Optional[Dtype], param_dtype: Dtype, scale_init: nnx.Initializer, ): class NNXModel(nnx.Module): def __init__(self, dtype, param_dtype, rngs): self.dense = nnx.Linear( 8, 4, dtype=dtype, param_dtype=param_dtype, rngs=rngs ) self.normed = nnx.WeightNorm( self.dense, use_scale=True, scale_init=scale_init, feature_axes=-1, dtype=dtype, param_dtype=param_dtype, rngs=rngs, ) def __call__(self, x, *, mask=None): return self.normed(x) class LinenModel(linen.Module): dtype: tp.Optional[Dtype] = None param_dtype: Dtype = jnp.float32 def setup(self): self.dense = linen.Dense( 4, dtype=self.dtype, param_dtype=self.param_dtype ) self.weight_norm = linen.WeightNorm( self.dense, variable_filter={'kernel'}, scale_init=scale_init ) def __call__(self, x, *, mask=None): return self.weight_norm(x) rngs = nnx.Rngs(42) x = jax.random.normal(jax.random.key(0), (10, 8)) linen_model = LinenModel(dtype=dtype, param_dtype=param_dtype) variables = linen_model.init(jax.random.key(1), x) nnx_model = NNXModel(dtype=dtype, param_dtype=param_dtype, rngs=rngs) nnx_model.dense.kernel.set_value(variables['params']['dense']['kernel']) nnx_model.dense.bias.set_value(variables['params']['dense']['bias']) linen_out = linen_model.apply(variables, x) nnx_out = nnx_model(x) np.testing.assert_array_equal( variables['params']['weight_norm']['dense/kernel/scale'], nnx_model.normed.scales[('kernel',)]) np.testing.assert_array_equal(linen_out, nnx_out) @parameterized.product( dtype=[jnp.float32, jnp.float16], param_dtype=[jnp.float32, jnp.float16], use_fast_variance=[True, False], mask=[None, np.array([True, False, True, False, True, False])], ) def test_nnx_linen_instancenorm_equivalence( self, dtype: tp.Optional[Dtype], param_dtype: Dtype, use_fast_variance: bool, mask: tp.Optional[np.ndarray], ): class NNXModel(nnx.Module): def __init__(self, dtype, param_dtype, use_fast_variance, rngs): self.norm_layer = nnx.InstanceNorm( 6, dtype=dtype, param_dtype=param_dtype, use_fast_variance=use_fast_variance, rngs=rngs, ) self.linear = nnx.Linear( 6, 4, dtype=dtype, param_dtype=param_dtype, rngs=rngs ) def __call__(self, x, *, mask=None): x = self.norm_layer(x, mask=mask) x = self.linear(x) return x class LinenModel(linen.Module): dtype: tp.Optional[Dtype] = None param_dtype: Dtype = jnp.float32 use_fast_variance: bool = True def setup(self): self.norm_layer = linen.InstanceNorm( dtype=self.dtype, param_dtype=self.param_dtype, use_fast_variance=self.use_fast_variance, ) self.linear = linen.Dense( 4, dtype=self.dtype, param_dtype=self.param_dtype ) def __call__(self, x, *, mask=None): x = self.norm_layer(x, mask=mask) x = self.linear(x) return x rngs = nnx.Rngs(42) x = jax.random.normal(jax.random.key(0), (10, 6)) linen_model = LinenModel( dtype=dtype, param_dtype=param_dtype, use_fast_variance=use_fast_variance ) variables = linen_model.init(jax.random.key(1), x) linen_out = linen_model.apply(variables, x, mask=mask) nnx_model = NNXModel( dtype=dtype, param_dtype=param_dtype, use_fast_variance=use_fast_variance, rngs=rngs, ) nnx_model.linear.kernel[...] = variables['params']['linear']['kernel'] nnx_model.linear.bias[...] = variables['params']['linear']['bias'] nnx_out = nnx_model(x, mask=mask) assert isinstance(linen_out, jax.Array) np.testing.assert_array_equal(linen_out, nnx_out) @parameterized.product( dtype=[jnp.float32, jnp.float16], param_dtype=[jnp.float32, jnp.float16], n_steps=[1, 10], update_stats=[True, False], ) def test_nnx_linen_spectralnorm_equivalence( self, dtype: tp.Optional[Dtype], param_dtype: Dtype, n_steps: int, update_stats: bool, ): class NNXModel(nnx.Module): def __init__(self, dtype, param_dtype, rngs): self.seq = nnx.Sequential( nnx.Linear( 5, 4, dtype=dtype, param_dtype=param_dtype, rngs=rngs ), nnx.relu, nnx.BatchNorm( 4, dtype=dtype, param_dtype=param_dtype, rngs=rngs, use_running_average=not update_stats, ), ) self.norm_layer = nnx.SpectralNorm( self.seq, n_steps=n_steps, dtype=dtype, param_dtype=param_dtype, update_stats=update_stats, rngs=rngs, ) def __call__(self, x): return self.norm_layer(x) class LinenModel(linen.Module): dtype: tp.Optional[Dtype] = None param_dtype: Dtype = jnp.float32 def setup(self): self.seq = linen.Sequential([ linen.Dense( 4, dtype=self.dtype, param_dtype=self.param_dtype ), linen.relu, linen.BatchNorm( dtype=self.dtype, param_dtype=self.param_dtype, use_running_average=not update_stats, ), ]) self.norm_layer = linen.SpectralNorm(self.seq, n_steps=n_steps) def __call__(self, x): return self.norm_layer(x, update_stats=update_stats) rngs = nnx.Rngs(42) x = jax.random.normal(jax.random.key(0), (10, 5)) linen_model = LinenModel(dtype=dtype, param_dtype=param_dtype) variables = linen_model.init(jax.random.key(1), x) nnx_model = NNXModel( dtype=dtype, param_dtype=param_dtype, rngs=rngs ) # Setup the same weights and batch stats var_params_seq_0 = variables['params']['seq']['layers_0'] nnx_model.seq.layers[0].kernel.set_value(var_params_seq_0['kernel']) nnx_model.seq.layers[0].bias.set_value(var_params_seq_0['bias']) var_params_seq_2 = variables['params']['seq']['layers_2'] nnx_model.seq.layers[2].scale.set_value(var_params_seq_2['scale']) nnx_model.seq.layers[2].bias.set_value(var_params_seq_0['bias']) var_norm_layer = variables['batch_stats']['norm_layer'] nnx_model.norm_layer.batch_stats[ ('layers', 0, 'kernel', 'u') ].set_value(var_norm_layer['seq/layers_0/kernel/u']) nnx_model.norm_layer.batch_stats[ ('layers', 0, 'kernel', 'sigma') ].set_value(var_norm_layer['seq/layers_0/kernel/sigma']) linen_out = linen_model.apply(variables, x, mutable=['batch_stats']) nnx_out = nnx_model(x) np.testing.assert_array_equal(linen_out[0], nnx_out) np.testing.assert_array_equal( nnx_model.norm_layer.batch_stats[("layers", 0, "kernel", "u")], linen_out[1]['batch_stats']['norm_layer']['seq/layers_0/kernel/u'], ) np.testing.assert_array_equal( nnx_model.norm_layer.batch_stats[("layers", 0, "kernel", "sigma")], linen_out[1]['batch_stats']['norm_layer']['seq/layers_0/kernel/sigma'], ) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/nn/recurrent_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax, jax.numpy as jnp from jax import random from flax import linen from flax import nnx from flax.nnx.nn import initializers import numpy as np from absl.testing import absltest class TestLSTMCell(absltest.TestCase): def test_basic(self): module = nnx.LSTMCell( in_features=3, hidden_features=4, rngs=nnx.Rngs(0), ) x = jnp.ones((2, 3)) carry = module.initialize_carry(x.shape, nnx.Rngs(0)) new_carry, y = module(carry, x) self.assertEqual(y.shape, (2, 4)) def test_lstm_sequence(self): """Test LSTMCell over a sequence of inputs.""" module = nnx.LSTMCell( in_features=3, hidden_features=4, rngs=nnx.Rngs(0), ) x = random.normal(random.PRNGKey(1), (5, 2, 3)) # seq_len, batch, feature carry = module.initialize_carry(x.shape[1:], nnx.Rngs(0)) outputs = [] for t in range(x.shape[0]): carry, y = module(carry, x[t]) outputs.append(y) outputs = jnp.stack(outputs) self.assertEqual(outputs.shape, (5, 2, 4)) def test_lstm_with_different_dtypes(self): """Test LSTMCell with different data types.""" module = nnx.LSTMCell( in_features=3, hidden_features=4, dtype=jnp.bfloat16, param_dtype=jnp.bfloat16, rngs=nnx.Rngs(0), ) x = jnp.ones((2, 3), dtype=jnp.bfloat16) carry = module.initialize_carry(x.shape, nnx.Rngs(0)) new_carry, y = module(carry, x) self.assertEqual(y.dtype, jnp.bfloat16) self.assertEqual(y.shape, (2, 4)) def test_lstm_with_custom_activations(self): """Test LSTMCell with custom activation functions.""" module = nnx.LSTMCell( in_features=3, hidden_features=4, gate_fn=jax.nn.relu, activation_fn=jax.nn.elu, rngs=nnx.Rngs(0), ) x = jnp.ones((1, 3)) carry = module.initialize_carry(x.shape, nnx.Rngs(0)) new_carry, y = module(carry, x) self.assertEqual(y.shape, (1, 4)) def test_lstm_initialize_carry(self): """Test the initialize_carry method.""" module = nnx.LSTMCell( in_features=3, hidden_features=4, rngs=nnx.Rngs(0), ) x_shape = (1, 3) carry = module.initialize_carry( x_shape, nnx.Rngs(0), carry_init=initializers.ones ) c, h = carry self.assertTrue(jnp.all(c == 1.0)) self.assertTrue(jnp.all(h == 1.0)) self.assertEqual(c.shape, (1, 4)) self.assertEqual(h.shape, (1, 4)) def test_lstm_with_variable_sequence_length(self): """Test LSTMCell with variable sequence lengths.""" module = nnx.LSTMCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)) # Simulate a batch with variable sequence lengths x = jnp.array( [ [[1, 2, 3], [4, 5, 6], [0, 0, 0]], # Sequence length 2 [[7, 8, 9], [10, 11, 12], [13, 14, 15]], # Sequence length 3 ] ) # Shape: (batch_size=2, max_seq_length=3, features=3) seq_lengths = jnp.array([2, 3]) # Actual lengths for each sequence batch_size = x.shape[0] max_seq_length = x.shape[1] carry = module.initialize_carry((batch_size, 3), nnx.Rngs(0)) outputs = [] for t in range(max_seq_length): input_t = x[:, t, :] carry, y = module(carry, input_t) outputs.append(y) outputs = jnp.stack( outputs, axis=1 ) # Shape: (batch_size, max_seq_length, hidden_features) # Zero out outputs beyond the actual sequence lengths mask = jnp.arange(max_seq_length)[None, :] < seq_lengths[:, None] outputs = outputs * mask[:, :, None] self.assertEqual(outputs.shape, (2, 3, 4)) def test_lstm_stateful(self): """Test that LSTMCell maintains state across calls.""" module = nnx.LSTMCell( in_features=3, hidden_features=4, rngs=nnx.Rngs(0), ) x1 = jnp.ones((1, 3)) x2 = jnp.ones((1, 3)) * 2 carry = module.initialize_carry(x1.shape, nnx.Rngs(0)) carry, y1 = module(carry, x1) carry, y2 = module(carry, x2) self.assertEqual(y1.shape, (1, 4)) self.assertEqual(y2.shape, (1, 4)) def test_lstm_equivalence_with_flax_linen(self): """Test that nnx.LSTMCell produces the same outputs as flax.linen.LSTMCell.""" in_features = 3 hidden_features = 4 key = random.PRNGKey(42) x = random.normal(key, (1, in_features)) # Initialize nnx.LSTMCell rngs_nnx = nnx.Rngs(0) module_nnx = nnx.LSTMCell( in_features=in_features, hidden_features=hidden_features, rngs=rngs_nnx, ) carry_nnx = module_nnx.initialize_carry(x.shape, rngs_nnx) # Initialize flax.linen.LSTMCell module_linen = linen.LSTMCell( features=hidden_features, ) carry_linen = module_linen.initialize_carry(random.PRNGKey(0), x.shape) variables_linen = module_linen.init(random.PRNGKey(1), carry_linen, x) # Copy parameters from flax.linen.LSTMCell to nnx.LSTMCell params_linen = variables_linen['params'] # Map the parameters from linen to nnx # Assuming the parameter names and shapes are compatible # For a precise mapping, you might need to adjust parameter names # Get the parameters from nnx module nnx_params = module_nnx.__dict__ # Map parameters from linen to nnx for gate in ['i', 'f', 'g', 'o']: # Input kernels (input to gate) if gate == 'f': nnx_layer = getattr(module_nnx, f'if_') else: nnx_layer = getattr(module_nnx, f'i{gate}') linen_params = params_linen[f'i{gate}'] nnx_layer.kernel[...] = linen_params['kernel'] if nnx_layer.use_bias: nnx_layer.bias[...] = linen_params['bias'] # Hidden kernels (hidden state to gate) nnx_layer = getattr(module_nnx, f'h{gate}') linen_params = params_linen[f'h{gate}'] nnx_layer.kernel[...] = linen_params['kernel'] if nnx_layer.use_bias: nnx_layer.bias[...] = linen_params['bias'] # Run both modules new_carry_nnx, y_nnx = module_nnx(carry_nnx, x) new_carry_linen, y_linen = module_linen.apply( variables_linen, carry_linen, x ) # Compare outputs np.testing.assert_allclose(y_nnx, y_linen, atol=1e-5) # Compare carries for c_nnx, c_linen in zip(new_carry_nnx, new_carry_linen): np.testing.assert_allclose(c_nnx, c_linen, atol=1e-5) class TestRNN(absltest.TestCase): def test_rnn_with_lstm_cell(self): """Test RNN module using LSTMCell.""" # Initialize the LSTMCell cell = nnx.LSTMCell( in_features=3, hidden_features=4, rngs=nnx.Rngs(0), ) # Initialize the RNN module with the LSTMCell rnn = nnx.RNN(cell) # Create input data (batch_size=2, seq_length=5, features=3) x = jnp.ones((2, 5, 3)) # Initialize the carry carry = cell.initialize_carry((2, 3), nnx.Rngs(0)) # Run the RNN module outputs = rnn(x, initial_carry=carry) self.assertEqual( outputs.shape, (2, 5, 4) ) # Output features should match hidden_features def test_rnn_with_gru_cell(self): """Test RNN module using GRUCell.""" # Initialize the GRUCell cell = nnx.GRUCell( in_features=3, hidden_features=4, rngs=nnx.Rngs(1), ) # Initialize the RNN module with the GRUCell rnn = nnx.RNN(cell) # Create input data (batch_size=2, seq_length=5, features=3) x = jnp.ones((2, 5, 3)) # Initialize the carry carry = cell.initialize_carry((2, 3), nnx.Rngs(1)) # Run the RNN module outputs = rnn(x, initial_carry=carry) self.assertEqual( outputs.shape, (2, 5, 4) ) # Output features should match hidden_features def test_rnn_time_major(self): """Test RNN module with time_major=True.""" # Initialize the LSTMCell cell = nnx.LSTMCell( in_features=3, hidden_features=4, rngs=nnx.Rngs(2), ) # Initialize the RNN module with time_major=True rnn = nnx.RNN(cell, time_major=True) # Create input data (seq_length=5, batch_size=2, features=3) x = jnp.ones((5, 2, 3)) # Initialize the carry carry = cell.initialize_carry(x.shape[1:2] + x.shape[2:], nnx.Rngs(2)) # Run the RNN module outputs = rnn(x, initial_carry=carry) self.assertEqual( outputs.shape, (5, 2, 4) ) # Output features should match hidden_features def test_rnn_reverse(self): """Test RNN module with reverse=True.""" # Initialize the LSTMCell cell = nnx.LSTMCell( in_features=3, hidden_features=4, rngs=nnx.Rngs(3), ) # Initialize the RNN module with reverse=True rnn = nnx.RNN(cell, reverse=True) # Create input data (batch_size=2, seq_length=5, features=3) x = jnp.tile(jnp.arange(5), (2, 1)).reshape( 2, 5, 1 ) # Distinct values to check reversal x = jnp.concatenate([x, x, x], axis=-1) # Shape: (2, 5, 3) # Run the RNN module outputs = rnn(x) # Check if the outputs are in reverse order outputs_reversed = outputs[:, ::-1, :] # Since we used distinct input values, we can compare outputs to check reversal # For simplicity, just check the shapes here self.assertEqual(outputs.shape, (2, 5, 4)) self.assertEqual(outputs_reversed.shape, (2, 5, 4)) def test_rnn_with_seq_lengths(self): """Test RNN module with variable sequence lengths.""" # Initialize the LSTMCell cell = nnx.LSTMCell( in_features=3, hidden_features=4, rngs=nnx.Rngs(4), ) # Initialize the RNN module rnn = nnx.RNN(cell, return_carry=True) # Create input data with padding (batch_size=2, seq_length=5, features=3) x = jnp.array( [ [ [1, 1, 1], [2, 2, 2], [3, 3, 3], [0, 0, 0], [0, 0, 0], ], # Sequence length 3 [ [4, 4, 4], [5, 5, 5], [6, 6, 6], [7, 7, 7], [8, 8, 8], ], # Sequence length 5 ] ) # Shape: (2, 5, 3) seq_lengths = jnp.array([3, 5]) # Actual lengths for each sequence # Initialize the carry carry = cell.initialize_carry((2, 3), nnx.Rngs(4)) # Run the RNN module final_carry, outputs = rnn(x, initial_carry=carry, seq_lengths=seq_lengths) self.assertEqual(outputs.shape, (2, 5, 4)) self.assertEqual( final_carry[0].shape, (2, 4) ) # c: (batch_size, hidden_features) self.assertEqual( final_carry[1].shape, (2, 4) ) # h: (batch_size, hidden_features) # Todo: a better test by matching the outputs with the expected values def test_rnn_with_keep_order(self): """Test RNN module with reverse=True and keep_order=True.""" # Initialize the LSTMCell cell = nnx.LSTMCell( in_features=3, hidden_features=4, rngs=nnx.Rngs(5), ) # Initialize the RNN module with reverse=True and keep_order=True rnn = nnx.RNN(cell, reverse=True, keep_order=True) # Create input data (batch_size=2, seq_length=5, features=3) x = jnp.tile(jnp.arange(5), (2, 1)).reshape( 2, 5, 1 ) # Distinct values to check reversal x = jnp.concatenate([x, x, x], axis=-1) # Shape: (2, 5, 3) # Initialize the carry carry = cell.initialize_carry((2, 3), nnx.Rngs(5)) # Run the RNN module outputs = rnn(x, initial_carry=carry) # Check if the outputs are in the original order despite processing in reverse self.assertEqual(outputs.shape, (2, 5, 4)) def test_rnn_equivalence_with_flax_linen(self): """Test that nnx.RNN produces the same outputs as flax.linen.RNN.""" in_features = 3 hidden_features = 4 seq_length = 5 batch_size = 2 key = random.PRNGKey(42) # Create input data x = random.normal(key, (batch_size, seq_length, in_features)) # Initialize nnx.LSTMCell and RNN rngs_nnx = nnx.Rngs(0) cell_nnx = nnx.LSTMCell( in_features=in_features, hidden_features=hidden_features, rngs=rngs_nnx, ) rnn_nnx = nnx.RNN(cell_nnx) # Initialize flax.linen.LSTMCell and RNN cell_linen = linen.LSTMCell(features=hidden_features) rnn_linen = linen.RNN(cell_linen) carry_linen = cell_linen.initialize_carry(random.PRNGKey(0), x[:, 0].shape) variables_linen = rnn_linen.init(random.PRNGKey(1), x) # Copy parameters from flax.linen to nnx params_linen = variables_linen['params']['cell'] # Copy cell parameters for gate in ['i', 'f', 'g', 'o']: # Input kernels if gate == 'f': nnx_layer = getattr(cell_nnx, f'if_') else: nnx_layer = getattr(cell_nnx, f'i{gate}') linen_params = params_linen[f'i{gate}'] nnx_layer.kernel[...] = linen_params['kernel'] if nnx_layer.use_bias: nnx_layer.bias[...] = linen_params['bias'] # Hidden kernels nnx_layer = getattr(cell_nnx, f'h{gate}') linen_params = params_linen[f'h{gate}'] nnx_layer.kernel[...] = linen_params['kernel'] if nnx_layer.use_bias: nnx_layer.bias[...] = linen_params['bias'] # Initialize carries carry_nnx = cell_nnx.initialize_carry((batch_size, in_features), rngs_nnx) # Run nnx.RNN outputs_nnx = rnn_nnx(x, initial_carry=carry_nnx) # Run flax.linen.RNN outputs_linen = rnn_linen.apply( variables_linen, x, initial_carry=carry_linen ) # Compare outputs np.testing.assert_allclose(outputs_nnx, outputs_linen, atol=1e-5) def test_rnn_with_unroll(self): """Test RNN module with unroll parameter.""" # Initialize the LSTMCell cell = nnx.LSTMCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(6)) # Initialize the RNN module with unroll=2 rnn = nnx.RNN(cell, unroll=2) # Create input data (batch_size=2, seq_length=6, features=3) x = jnp.ones((2, 6, 3)) # Initialize the carry carry = cell.initialize_carry((2, 3), nnx.Rngs(6)) # Run the RNN module outputs = rnn(x, initial_carry=carry) self.assertEqual( outputs.shape, (2, 6, 4) ) # Output features should match hidden_features def test_rnn_with_custom_cell(self): """Test RNN module with a custom RNN cell.""" class CustomRNNCell(nnx.Module): """A simple custom RNN cell.""" in_features: int hidden_features: int def __init__(self, in_features, hidden_features, rngs): self.in_features = in_features self.hidden_features = hidden_features self.rngs = rngs self.dense = nnx.Linear( in_features=in_features + hidden_features, out_features=hidden_features, rngs=rngs, ) def __call__(self, carry, inputs): h = carry x = jnp.concatenate([inputs, h], axis=-1) new_h = jax.nn.tanh(self.dense(x)) return new_h, new_h def initialize_carry(self, input_shape, rngs): batch_size = input_shape[0] h = jnp.zeros((batch_size, self.hidden_features)) return h @property def num_feature_axes(self) -> int: return 1 # Initialize the custom RNN cell cell = CustomRNNCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(7)) # Initialize the RNN module rnn = nnx.RNN(cell) # Create input data (batch_size=2, seq_length=5, features=3) x = jnp.ones((2, 5, 3)) # Initialize the carry carry = cell.initialize_carry((2, 3), cell.rngs) # Run the RNN module outputs = rnn(x, initial_carry=carry) self.assertEqual( outputs.shape, (2, 5, 4) ) # Output features should match hidden_features def test_rnn_with_different_dtypes(self): """Test RNN module with different data types.""" # Initialize the LSTMCell with float16 cell = nnx.LSTMCell( in_features=3, hidden_features=4, dtype=jnp.float16, param_dtype=jnp.float16, rngs=nnx.Rngs(8), ) # Initialize the RNN module rnn = nnx.RNN(cell) # Create input data (batch_size=2, seq_length=5, features=3) x = jnp.ones((2, 5, 3), dtype=jnp.float16) # Initialize the carry carry = cell.initialize_carry((2, 3), nnx.Rngs(8)) # Run the RNN module outputs = rnn(x, initial_carry=carry) self.assertEqual(outputs.dtype, jnp.float16) self.assertEqual(outputs.shape, (2, 5, 4)) def test_rnn_with_variable_batch_size(self): """Test RNN module with variable batch sizes.""" # Initialize the LSTMCell cell = nnx.LSTMCell( in_features=3, hidden_features=4, rngs=nnx.Rngs(9), ) # Initialize the RNN module rnn = nnx.RNN(cell) for batch_size in [1, 2, 5]: # Create input data (batch_size, seq_length=5, features=3) x = jnp.ones((batch_size, 5, 3)) # Initialize the carry carry = cell.initialize_carry((batch_size, 3), nnx.Rngs(9)) # Run the RNN module outputs = rnn(x, initial_carry=carry) self.assertEqual(outputs.shape, (batch_size, 5, 4)) def test_recurrent_dropout(self): class LSTMWithRecurrentDropout(nnx.OptimizedLSTMCell): def __init__( self, *, rngs: nnx.Rngs, in_features: int, hidden_features: int, dropout_rate: float, **kwargs, ): super().__init__( in_features=in_features, hidden_features=hidden_features, rngs=rngs, keep_rngs=True, **kwargs, ) self.recurrent_dropout = nnx.Dropout( rate=dropout_rate, rng_collection='recurrent_dropout', rngs=rngs ) def __call__(self, carry, x): h, c = carry new_h, new_c = super().__call__((h, c), x) new_h = jax.tree.map(self.recurrent_dropout, new_h) return new_h, new_c class RNNWithRecurrentDropout(nnx.Module): def __init__( self, *, rngs: nnx.Rngs, in_features: int, hidden_features: int = 32, dropout_rate: float = 0.5, recurrent_dropout_rate: float = 0.25, ): cell = LSTMWithRecurrentDropout( in_features=in_features, hidden_features=hidden_features, rngs=rngs, dropout_rate=recurrent_dropout_rate, ) self.lstm = nnx.RNN(cell, broadcast_rngs='recurrent_dropout') self.dropout = nnx.Dropout(dropout_rate, rngs=rngs) self.dense = nnx.Linear( in_features=hidden_features, out_features=1, rngs=rngs ) def __call__(self, x): x = self.lstm(x) x = self.dropout(x) x = x[:, -1, :] # Use only the final hidden state return self.dense(x) model = RNNWithRecurrentDropout( in_features=32, hidden_features=64, dropout_rate=0.2, recurrent_dropout_rate=0.1, rngs=nnx.Rngs(0, recurrent_dropout=1), ) x = jnp.ones((8, 10, 32)) self.assertEqual(model.lstm.cell.recurrent_dropout.rngs.count[...], 0) y = model(x) self.assertEqual(y.shape, (8, 1)) self.assertEqual(model.lstm.cell.recurrent_dropout.rngs.count[...], 1) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/nn/stochastic_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax.numpy as jnp import numpy as np from flax import nnx import pytest class TestStochastic: def test_dropout_internal_rngs(self): n = 0 m1 = nnx.Dropout(rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=0)) m2 = nnx.Dropout(rate=0.5, deterministic=False) rngs2 = nnx.Rngs(dropout=0).fork() @nnx.jit def f(m, x, rngs=None): nonlocal n n += 1 return m(x, rngs=rngs) x = jnp.ones((1, 10)) assert m1.rngs is not None and m1.rngs.count[...] == 0 y1 = f(m1, x) assert n == 1 assert m1.rngs.count[...] == 1 y2 = f(m2, x, rngs=rngs2) assert n == 2 assert rngs2.dropout.count[...] == 1 np.testing.assert_allclose(y1, y2) y1 = f(m1, x) assert m1.rngs.count[...] == 2 y2 = f(m2, x, rngs=rngs2) assert rngs2.dropout.count[...] == 2 np.testing.assert_allclose(y1, y2) assert n == 2 def test_dropout_rng_override(self): m1 = nnx.Dropout(rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=0)) m2 = nnx.Dropout(rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=1)) x = jnp.ones((1, 10)) y1 = m1(x) y2 = m2(x) with pytest.raises(AssertionError): np.testing.assert_allclose(y1, y2) y2 = m2(x, rngs=nnx.Rngs(dropout=0).fork()) np.testing.assert_allclose(y1, y2) def test_dropout_arg_override(self): m = nnx.Dropout(rate=0.5) x = jnp.ones((1, 10)) # deterministic call arg provided m(x, deterministic=True) # deterministic constructor arg provided m.set_attributes(deterministic=True) y = m(x) # both deterministic call and constructor arg provided with pytest.raises(AssertionError): np.testing.assert_allclose( y, m(x, deterministic=False, rngs=nnx.Rngs(dropout=0)) ) # no rng arg provided m.set_attributes(deterministic=False) with pytest.raises( ValueError, match='`deterministic` is False, but no `rngs` argument was provided to Dropout', ): m(x) def test_dropout_arg_override_view(self): m = nnx.Dropout(rate=0.5) x = jnp.ones((1, 10)) # deterministic call arg provided m(x, deterministic=True) # deterministic constructor arg provided new_m = nnx.view(m, deterministic=True) y = new_m(x) # both deterministic call and constructor arg provided with pytest.raises(AssertionError): np.testing.assert_allclose( y, new_m(x, deterministic=False, rngs=nnx.Rngs(dropout=0)) ) # no rng arg provided new_m = nnx.view(m, deterministic=False) with pytest.raises( ValueError, match='`deterministic` is False, but no `rngs` argument was provided to Dropout', ): new_m(x) ================================================ FILE: tests/nnx/optimizer_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from absl.testing import absltest from absl.testing import parameterized from flax import nnx import jax import jax.numpy as jnp import numpy as np import optax def assert_equal(path, x, y): np.testing.assert_array_equal(x, y, err_msg=f'Mismatch at path: {path}') def assert_not_equal(path, x, y): np.testing.assert_( np.any(np.not_equal(x, y)), msg=f'Unexpected match at path: {path}' ) class Model(nnx.Module): def __init__(self, in_features, out_features, rngs): self.linear1 = nnx.Linear(in_features, 3, rngs=rngs) self.linear2 = nnx.Linear(3, out_features, rngs=rngs) def __call__(self, x): return self.linear2(self.linear1(x)) class TestOptimizer(parameterized.TestCase): @parameterized.parameters( {'module_cls': nnx.Linear}, {'module_cls': Model}, ) def test_split_merge(self, module_cls): x = jax.random.normal(jax.random.key(0), (1, 2)) model = module_cls(2, 4, rngs=nnx.Rngs(0)) tx = optax.adam(1e-3) optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) out = model(x) graphdef, optimizer = nnx.split(optimizer) optimizer = nnx.merge(graphdef, optimizer) np.testing.assert_allclose(out, model(x)) def test_update(self): model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) optimizer = nnx.Optimizer(model, optax.adamw(0.1), wrt=nnx.Param) def loss_fn(model): params = nnx.state(model) loss = sum(jnp.sum(x**2) for x in jax.tree.leaves(params)) return loss grads = nnx.grad(loss_fn)(model) optimizer.update(model, grads) def test_sharding_propagation(self): with jax.set_mesh( jax.make_mesh( (1, 1), ('a', 'b'), axis_types=(jax.sharding.AxisType.Auto,) * len(('a', 'b')), ) ): model = nnx.Linear( 2, 3, rngs=nnx.Rngs(0), kernel_init=nnx.with_partitioning( nnx.initializers.lecun_normal(), sharding=('a', 'b'), ), use_bias=False, ) optimizer = nnx.Optimizer(model, optax.adamw(0.1), wrt=nnx.Param) state = nnx.state(optimizer) partition_spec = nnx.get_partition_spec(state) self.assertEqual(state['opt_state'][0]['mu']['kernel'].out_sharding, ('a', 'b')) self.assertEqual( partition_spec['opt_state'][0]['mu']['kernel'].get_value(), jax.sharding.PartitionSpec('a', 'b'), ) @parameterized.product( module_cls=[nnx.Linear, Model], jit_decorator=[lambda f: f, nnx.jit, jax.jit], optimizer=[optax.sgd, optax.adam], ) def test_jit(self, module_cls, jit_decorator, optimizer): x = jax.random.normal(jax.random.key(0), (1, 2)) y = jnp.ones((1, 4)) model = module_cls(2, 4, rngs=nnx.Rngs(0)) tx = optimizer( 1e-3 ) # TODO: this doesn't work with adam optimizer for some reason state = nnx.ModelAndOptimizer(model, tx) if jit_decorator == jax.jit: model_static, model_state = nnx.split(model) loss_fn = lambda graphdef, state, x, y: ( (nnx.merge(graphdef, state)(x) - y) ** 2 ).mean() initial_loss = loss_fn(model_static, model_state, x, y) def jax_jit_train_step(graphdef, state, x, y): state = nnx.merge(graphdef, state) model_static, model_state = nnx.split(model) grads = jax.grad(loss_fn, argnums=1)(model_static, model_state, x, y) state.update(grads) return nnx.split(state) graphdef, state = jit_decorator(jax_jit_train_step)( *nnx.split(state), x, y ) state = nnx.merge(graphdef, state) new_loss = loss_fn(*nnx.split(state.model), x, y) else: loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean() initial_loss = loss_fn(state.model, x, y) def nnx_jit_train_step(optimizer: nnx.Optimizer, x, y): grads = nnx.grad(loss_fn)(optimizer.model, x, y) optimizer.update(grads) jit_decorator(nnx_jit_train_step)(state, x, y) new_loss = loss_fn(state.model, x, y) self.assertTrue(new_loss < initial_loss) @parameterized.product( module_cls=[nnx.Linear, Model], jit_decorator=[lambda f: f, nnx.jit, jax.jit], optimizer=[optax.lbfgs], ) def test_jit_linesearch(self, module_cls, jit_decorator, optimizer): x = jax.random.normal(jax.random.key(0), (1, 2)) y = jnp.ones((1, 4)) model = module_cls(2, 4, rngs=nnx.Rngs(0)) tx = optimizer(1e-3) state = nnx.ModelAndOptimizer(model, tx) if jit_decorator == jax.jit: model_static, model_state = nnx.split(state.model) loss_fn = lambda graphdef, state, x, y: ( (nnx.merge(graphdef, state)(x) - y) ** 2 ).mean() initial_loss = loss_fn(model_static, model_state, x, y) def jax_jit_train_step(graphdef, state, x, y): state = nnx.merge(graphdef, state) model_static, model_state = nnx.split(state.model) grads = jax.grad(loss_fn, argnums=1)(model_static, model_state, x, y) state.update( grads, grad=grads, value=initial_loss, value_fn=lambda state: loss_fn(model_static, state, x, y), ) return nnx.split(state) graphdef, state = jit_decorator(jax_jit_train_step)( *nnx.split(state), x, y ) state = nnx.merge(graphdef, state) new_loss = loss_fn(*nnx.split(state.model), x, y) else: graphdef = nnx.graphdef(model) loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean() loss_fn_split = lambda state: loss_fn(nnx.merge(graphdef, state), x, y) initial_loss = loss_fn(state.model, x, y) def nnx_jit_train_step(optimizer: nnx.Optimizer, x, y): grads = nnx.grad(loss_fn)(optimizer.model, x, y) optimizer.update( grads, grad=grads, value=initial_loss, value_fn=loss_fn_split ) jit_decorator(nnx_jit_train_step)(state, x, y) new_loss = loss_fn(state.model, x, y) self.assertTrue(new_loss < initial_loss) @parameterized.product( module_cls=[nnx.Linear, Model], optimizer=[optax.sgd, optax.adam], ) def test_metrics(self, module_cls, optimizer): class TrainState(nnx.ModelAndOptimizer): def __init__(self, model, tx, metrics): self.metrics = metrics super().__init__(model, tx) def update(self, *, grads, **updates): # type: ignore[signature-mismatch] self.metrics.update(**updates) super().update(grads) x = jax.random.normal(jax.random.key(0), (1, 2)) y = jnp.ones((1, 4)) model = module_cls(2, 4, rngs=nnx.Rngs(0)) tx = optax.adam(1e-3) metrics = nnx.metrics.Average() state = TrainState(model, tx, metrics) loss_fn = lambda model: ((model(x) - y) ** 2).mean() grads = nnx.grad(loss_fn)(state.model) state.update(grads=grads, values=loss_fn(state.model)) initial_loss = state.metrics.compute() state.update(grads=grads, values=loss_fn(state.model)) self.assertTrue(state.metrics.compute() < initial_loss) @parameterized.parameters( {'variable': nnx.Param}, {'variable': nnx.LoRAParam}, {'variable': (nnx.Param, nnx.LoRAParam)}, ) def test_wrt_update(self, variable): in_features = 4 out_features = 10 model = nnx.LoRA( in_features=in_features, lora_rank=2, out_features=out_features, base_module=Model( in_features=in_features, out_features=out_features, rngs=nnx.Rngs(0) ), rngs=nnx.Rngs(1), ) state = nnx.Optimizer(model, optax.adam(1e-3), wrt=variable) prev_variables, prev_other_variables = nnx.clone(nnx.state(model, variable, ...)) x = jnp.ones((1, 4)) y = jnp.ones((1, 10)) loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean() grad_fn = nnx.grad(loss_fn, argnums=nnx.DiffState(0, variable)) def step(): grads = grad_fn(model, x, y) initial_loss = loss_fn(model, x, y) state.update(model, grads) self.assertTrue(loss_fn(model, x, y) < initial_loss) # Since lora_b is initialized to zeros by default, the gradient flow to lora_a # will be zeroed out in first call. Thus, run the step twice to make sure # lora_a is updated. for _ in range(2): step() # make sure only the Variable's filtered in `wrt` are changed, and the others are unchanged variables, other_variables = nnx.state(model, variable, ...) jax.tree.map_with_path(assert_not_equal, prev_variables, variables) if other_variables: jax.tree.map_with_path( assert_equal, prev_other_variables, other_variables ) @parameterized.parameters( {'variable': nnx.Param}, # {'variable': nnx.LoRAParam}, {'variable': (nnx.Param, nnx.LoRAParam)}, ) def test_wrt_update_linesearch(self, variable): in_features = 4 out_features = 10 model = nnx.LoRA( in_features=in_features, lora_rank=2, out_features=out_features, base_module=Model( in_features=in_features, out_features=out_features, rngs=nnx.Rngs(0) ), rngs=nnx.Rngs(1), ) state = nnx.Optimizer(model, optax.lbfgs(), wrt=variable) prev_variables, prev_other_variables = nnx.clone(nnx.state(model, variable, ...)) x = jnp.ones((1, 4)) y = jnp.ones((1, 10)) loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean() grad_fn = nnx.grad(loss_fn, argnums=nnx.DiffState(0, variable)) graphdef = nnx.graphdef(model) loss_fn_split = lambda state: loss_fn(nnx.merge(graphdef, state), x, y) def step(): grads = grad_fn(model, x, y) initial_loss = loss_fn(model, x, y) state.update( model, grads, grad=grads, value_fn=loss_fn_split, value=initial_loss ) self.assertTrue(loss_fn(model, x, y) < initial_loss) # Since lora_b is initialized to zeros by default, the gradient flow to lora_a # will be zeroed out in first call. Thus, run the step twice to make sure # lora_a is updated. for _ in range(2): step() # make sure only the Variable's filtered in `wrt` are changed, and the others are unchanged variables, other_variables = nnx.state(model, variable, ...) jax.tree.map_with_path(assert_not_equal, prev_variables, variables) if other_variables: jax.tree.map_with_path( assert_equal, prev_other_variables, other_variables ) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/partitioning_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from absl.testing import absltest from flax import nnx import jax import jax.numpy as jnp class TestPartitioning(absltest.TestCase): def test_partition(self): m = nnx.Dict( a=nnx.List([nnx.Param(jnp.array(1)), nnx.BatchStat(jnp.array(2))]), b=nnx.Param(jnp.array(2)), c=100, ) graphdef, params, rest = nnx.split(m, nnx.Param, ...) self.assertLen(params, 2) self.assertLen(rest, 1) # check params self.assertEqual(params['a'][0][...], m.a[0][...]) self.assertEqual(params['b'][...], m.b[...]) # check rest self.assertEqual(rest['a'][1][...], m.a[1][...]) m2 = nnx.merge(graphdef, params, rest) self.assertEqual(m2.a[0][...], m.a[0][...]) self.assertEqual(m2.a[1][...], m.a[1][...]) self.assertEqual(m2.b[...], m.b[...]) self.assertEqual(m2.c, 100) def test_complete_partitioning(self): m = nnx.Dict( a=nnx.List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), ) # no error nnx.split(m, nnx.Param, nnx.BatchStat, nnx.Variable) def test_complete_partitioning_plus_ellipsis(self): m = nnx.Dict( a=nnx.List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), ) # no error if additional ... is passed at the end nnx.split(m, nnx.Param, nnx.BatchStat, nnx.Variable, ...) def test_inclomplete_partition_error(self): m = nnx.Dict( a=nnx.List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), ) with self.assertRaisesRegex( ValueError, 'Non-exhaustive filters, got a non-empty remainder' ): nnx.split(m, nnx.Param) def test_ellipsis_not_last_error(self): m = nnx.Dict( a=nnx.List([nnx.Param(1), nnx.Param(2), nnx.Variable(3)]), b=nnx.Dict(c=nnx.Param(1), d=nnx.BatchStat(2)), ) with self.assertRaisesRegex( ValueError, '`...` or `True` can only be used as the last filters' ): nnx.split(m, ..., nnx.Param) def test_update_from(self): m = nnx.Dict( a=nnx.List([nnx.Param(jnp.array(1)), nnx.BatchStat(jnp.array(3))]), b=nnx.Param(jnp.array(2)), c=100, ) state = nnx.split( m, )[1] state = jax.tree.map(lambda x: x * 2, state) nnx.update(m, state) self.assertEqual(m.a[0][...], 2) self.assertEqual(m.a[1][...], 6) self.assertEqual(m.b[...], 4) self.assertEqual(m.c, 100) def test_update_from_with_array_leaf(self): m = nnx.Dict( a=nnx.List([nnx.Param(jnp.array(1)), nnx.BatchStat(jnp.array(3))]), b=nnx.Param(jnp.array(2)), c=nnx.Variable(jax.numpy.array(100)), ) graphdef, state = nnx.split(m) state = jax.tree.map(lambda x: x * 2, state) nnx.update(m, state) self.assertEqual(m.a[0][...], 2) self.assertEqual(m.a[1][...], 6) self.assertEqual(m.b[...], 4) self.assertEqual(m.c[...], 200) def test_grad_example(self): m = nnx.Dict( a=nnx.List([nnx.Param(jnp.array(1.0)), nnx.BatchStat(jnp.array(-10))]), b=nnx.Param(jnp.array(2.0)), c=100, ) params = nnx.state(m, nnx.Param) def loss(params): return sum(2 * p for p in jax.tree_util.tree_leaves(params)) grads = jax.grad(loss)(params) nnx.update(m, grads) self.assertEqual(m.a[0][...], 2.0) self.assertEqual(m.a[1][...], -10) self.assertEqual(m.b[...], 2.0) self.assertEqual(m.c, 100) def test_get_paritition(self): m = nnx.Dict( a=nnx.List([nnx.Param(jnp.array(10.0)), nnx.Param(jnp.array(20.0))]), b=nnx.Param(jnp.array(10.0)), c=7, d=5.0, ) state = nnx.state(m, nnx.Variable) self.assertEqual(state['a'][0][...], m.a[0][...]) self.assertEqual(state['a'][1][...], m.a[1][...]) self.assertEqual(state['b'][...], m.b[...]) self.assertIsNot(state['b'], state['a'][0]) self.assertLen(nnx.to_flat_state(state), 3) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/rngs_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from functools import partial from typing import Any import jax import jax.numpy as jnp import numpy as np from absl.testing import absltest, parameterized from flax import nnx from flax import errors class TestRngs(parameterized.TestCase): def test_call(self): rngs = nnx.Rngs(0) key = rngs() def test_fallback(self): rngs = nnx.Rngs(0) key = rngs.dropout() def test_fallback_error_no_default(self): rngs = nnx.Rngs(some_name=0) with self.assertRaisesRegex(AttributeError, 'No RngStream named'): key = rngs.dropout() def test_rng_stream(self): key0 = jax.random.key(0) rngs = nnx.Rngs(params=key0) self.assertEqual(rngs.params.count[...], 0) key1 = rngs.params() self.assertEqual(rngs.params.count[...], 1) self.assertIs(rngs.params.key[...], key0) self.assertFalse(jnp.allclose(key0, key1)) key2 = rngs.params() self.assertEqual(rngs.params.count[...], 2) self.assertIs(rngs.params.key[...], key0) self.assertFalse(jnp.allclose(key1, key2)) def test_rng_trace_level_constraints(self): rngs = nnx.Rngs(0) @jax.jit def f(): with self.assertRaisesRegex( errors.TraceContextError, 'Cannot mutate RngCount from a different trace level', ): rngs.params() f() rngs1: Any = None @jax.jit def h(): nonlocal rngs1 rngs1 = nnx.Rngs(1) h() self.assertIsInstance(rngs1, nnx.Rngs) with self.assertRaisesRegex( errors.TraceContextError, 'Cannot mutate RngCount from a different trace level', ): rngs1.params() def test_jit_updates(self): class Foo(nnx.Module): def __init__(self, not_rngs): rngs = not_rngs self.linear = nnx.Linear(2, 2, rngs=rngs) self.dropout = nnx.Dropout(0.5, deterministic=False) def __call__(self, x, rngs): x = self.linear(x) x = self.dropout(x, rngs=rngs) return x rngs = nnx.Rngs(0) m = Foo(rngs) # +1 for the Linear kernel, +1 for the Linear bias self.assertEqual(rngs['default'].count[...], 2) @nnx.jit def f(m: Foo, x: jax.Array, not_rngs: nnx.Rngs): rngs = not_rngs x = m(x, rngs) x = m(x, rngs) return x x = jnp.ones((2, 2)) x = f(m, x, rngs) # +1 for the Dropout mask self.assertEqual(rngs['default'].count[...], 4) def test_lifting_rng_state(self): class Foo(nnx.Module): def __init__(self, rngs): self.rngs = rngs self.dropout = nnx.Dropout(0.5, deterministic=False) self.linear = nnx.Linear(2, 3, rngs=rngs) def __call__(self, x): x = self.linear(x) x = self.dropout(x, rngs=self.rngs) return x rngs = nnx.Rngs(params=0, dropout=1) m = Foo(rngs) graphdef, params, rng_counts, dropout_keys, param_keys = nnx.split( m, nnx.Param, nnx.RngCount, 'dropout', 'params' ) self.assertEqual(m.rngs.params.count[...], 2) self.assertEqual(m.rngs['dropout'].count[...], 0) self.assertLen(nnx.to_flat_state(dropout_keys), 1) self.assertLen(nnx.to_flat_state(param_keys), 1) self.assertLen(nnx.to_flat_state(rng_counts), 2) # split dropout keys split_dropout_keys = jax.tree.map( lambda x: jax.random.split(x, 4), dropout_keys ) # replicate params params = jax.tree.map(lambda x: jnp.stack([x] * 4, axis=0), params) @partial( jax.vmap, in_axes=(0, 0, None, None, 0), out_axes=(0, 0, None), ) def f(params, dropout_keys, param_keys, rng_counts, x): m = nnx.merge(graphdef, params, dropout_keys, param_keys, rng_counts) y = m(x) _, params, rng_counts, dropout_keys, param_keys = nnx.split( m, nnx.Param, nnx.RngCount, 'dropout', 'params' ) return y, params, rng_counts x = jnp.ones((4, 1, 2)) y, params, rng_counts = f( params, split_dropout_keys, param_keys, rng_counts, x, ) nnx.update(m, params, dropout_keys, param_keys, rng_counts) self.assertEqual(y.shape, (4, 1, 3)) self.assertEqual(m.rngs.params.count[...], 2) self.assertEqual(m.rngs['dropout'].count[...], 1) @parameterized.parameters(True, False) def test_reseed(self, graph): class Model(nnx.Module): def __init__(self, rngs): self.linear = nnx.Linear(2, 3, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) def __call__(self, x): return self.dropout(self.linear(x)) model = Model(nnx.Rngs(params=0, dropout=42)) x = jnp.ones((1, 2)) y1 = model(x) nnx.reseed(model, graph=graph, dropout=42) y2 = model(x) np.testing.assert_allclose(y1, y2) @parameterized.parameters(True, False) def test_split_rngs(self, graph): rngs = nnx.Rngs(params=0, dropout=1) result = nnx.split_rngs(rngs, splits=5, graph=graph) if graph: self.assertEqual(rngs.params.key.shape, (5,)) self.assertEqual(rngs['dropout'].key.shape, (5,)) nnx.restore_rngs(result) self.assertEqual(rngs.params.key.shape, ()) self.assertEqual(rngs['dropout'].key.shape, ()) else: self.assertEqual(rngs.params.key.shape, ()) self.assertEqual(rngs['dropout'].key.shape, ()) self.assertEqual(result.params.key.shape, (5,)) self.assertEqual(result['dropout'].key.shape, (5,)) @parameterized.parameters(True, False) def test_fork_rngs(self, graph): rngs = nnx.Rngs(params=0, dropout=1) backups = nnx.fork_rngs(rngs, graph=graph) new_key = rngs.params.key.copy() nnx.restore_rngs(backups) self.assertNotEqual(rngs.params.key, new_key) def test_random_helpers(self): rngs = nnx.Rngs(0, params=1) x_nnx = rngs.normal((2, 3)) x_jax = jax.random.normal(jax.random.fold_in(jax.random.key(0), 0), (2, 3)) np.testing.assert_allclose(x_nnx, x_jax) x_nnx = rngs.params.uniform((2, 3)) x_jax = jax.random.uniform(jax.random.fold_in(jax.random.key(1), 0), (2, 3)) np.testing.assert_allclose(x_nnx, x_jax) x_nnx = rngs.lecun_normal()((2, 3)) x_jax = jax.nn.initializers.lecun_normal()( jax.random.fold_in(jax.random.key(0), 1), (2, 3) ) np.testing.assert_allclose(x_nnx, x_jax) x_nnx = rngs.params.lecun_uniform()((2, 3)) x_jax = jax.nn.initializers.lecun_uniform()( jax.random.fold_in(jax.random.key(1), 1), (2, 3) ) np.testing.assert_allclose(x_nnx, x_jax) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/spmd_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4' from absl.testing import absltest from absl.testing import parameterized from flax import nnx import jax import jax.numpy as jnp from jax.sharding import PartitionSpec as P, NamedSharding, AxisType, reshard from jax.experimental.layout import Format, Layout import optax class TestSPMD(parameterized.TestCase): def setUp(self): if jax.device_count() < 4: self.skipTest('At least 4 devices required') def test_init(self): class Foo(nnx.Module): def __init__(self): self.w = nnx.Param( nnx.with_partitioning( lambda: jnp.ones((8, 2)), sharding=('model', 'data'), )() ) def __call__(self, x): return x @ self.w @jax.jit def create_module(): return nnx.split(Foo()) mesh = jax.make_mesh( (2, 2), ('model', 'data'), axis_types=(jax.sharding.AxisType.Auto,) * len(('model', 'data')), ) with jax.set_mesh(mesh): m: Foo = nnx.merge(*create_module()) # type: ignore[invalid-annotation] x = jax.device_put(jnp.zeros((4, 8)), P(None, 'model')) y = m(x) assert m.w.shape == (8, 2) assert m.w.sharding.shard_shape(m.w.shape) == (4, 1) def test_init_all_devices(self): class Foo(nnx.Module): def __init__(self): self.w = nnx.Param( nnx.with_partitioning( lambda: jnp.ones((8, 2)), sharding=('model', 'data'), )() ) def __call__(self, x): return x @ self.w @jax.jit def create_module(): return nnx.split(Foo()) mesh = jax.make_mesh( (1, 1), ('model', 'data'), axis_types=(jax.sharding.AxisType.Auto,) * len(('model', 'data')), ) with jax.set_mesh(mesh): m: Foo = nnx.merge(*create_module()) # type: ignore[invalid-annotation] assert m.w.shape == (8, 2) assert m.w.sharding.shard_shape(m.w.shape) == (8, 2) def test_shard_optimizer_state(self): class Foo(nnx.Module): def __init__(self): self.w = nnx.Param( nnx.with_partitioning( lambda: jnp.ones((8, 2)), sharding=('row', 'col'), )() ) def __call__(self, x): return x @ self.w mesh = jax.make_mesh( (2, 2), ('row', 'col'), axis_types=(jax.sharding.AxisType.Auto,) * len(('row', 'col')), ) with jax.set_mesh(mesh): graphdef, params = nnx.split(Foo()) state = nnx.TrainState.create( graphdef, params=params, tx=optax.adam(1e-3), ) assert state.params['w'].sharding.is_equivalent_to( NamedSharding(mesh, P('row', 'col')), ndim=2) assert state.opt_state[0].mu['w'].sharding.is_equivalent_to( NamedSharding(mesh, P('row', 'col')), ndim=2) assert state.opt_state[0].nu['w'].sharding.is_equivalent_to( NamedSharding(mesh, P('row', 'col')), ndim=2) def test_add_remove_axis_in_transform(self): test = self kadds, kremoves, badds, bremoves = [], [], [], [] class MLP(nnx.Module): @nnx.split_rngs(splits=5) @nnx.vmap( in_axes=(0, 0), transform_metadata={nnx.PARTITION_NAME: 'layers', 'nickname': 'nick'}, ) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear( 4, 4, kernel_init=nnx.with_metadata( nnx.initializers.lecun_normal(), out_sharding=('din', 'dout'), nickname=('in', 'out'), on_add_axis=lambda _, idx, name: kadds.append((idx, name)), on_remove_axis=lambda _, idx, name: kremoves.append((idx, name)), ), bias_init=nnx.with_metadata( nnx.initializers.zeros_init(), # no sharding annotation here! on_add_axis=lambda _, idx, name: badds.append((idx, name)), on_remove_axis=lambda _, idx, name: bremoves.append((idx, name)), ), rngs=rngs, ) @nnx.scan( in_axes=(0, nnx.Carry), transform_metadata={nnx.PARTITION_NAME: 'layers'} ) def __call__(self, x: jax.Array): x = self.linear(x) # test sharding layer axes is not present inside scan test.assertEqual(self.linear.kernel.shape, (4, 4)) test.assertEqual(self.linear.kernel.out_sharding, ('din', 'dout')) # at least a remove_axis was already called to remove the layer axis test.assertEqual(kremoves[-1], (0, 'layers')) test.assertEqual(bremoves[-1], (0, 'layers')) return x, None mesh = jax.make_mesh( (1, 2, 2), ('layers', 'din', 'dout'), axis_types=(jax.sharding.AxisType.Auto,) * len(('layers', 'din', 'dout')), ) with jax.set_mesh(mesh): m = MLP(rngs=nnx.Rngs(0)) self.assertEqual(m.linear.kernel.shape, (5, 4, 4)) self.assertEqual(m.linear.kernel.out_sharding, ('layers', 'din', 'dout')) self.assertEqual(m.linear.kernel.nickname, ('nick', 'in', 'out')) self.assertEqual(m.linear.bias.shape, (5, 4)) # One add_axis called to add the `nnx.vmap` dimension self.assertEqual(kadds, [(0, 'layers')]) self.assertEqual(kremoves, []) self.assertEqual(badds, [(0, 'layers')]) self.assertEqual(bremoves, []) # One remove_axis and one add_axis called when in and out of `nnx.scan` with jax.set_mesh(mesh): _ = m(jnp.ones((5, 4))) self.assertEqual(kadds, [(0, 'layers'), (0, 'layers')]) self.assertEqual(kremoves, [(0, 'layers')]) self.assertEqual(badds, [(0, 'layers'), (0, 'layers')]) self.assertEqual(bremoves, [(0, 'layers')]) def test_transform_metadata_decorator(self): v = nnx.Param( jnp.array(0), out_sharding=('din', 'dout'), eager_sharding=False, ) @nnx.transform_metadata(in_axes=0, out_axes=1, partition='din') def f(v): v[...] += 1 self.assertEqual(v.out_sharding, ('dout',)) v2 = nnx.Param( jnp.array(10), out_sharding=('dmid', 'dout'), eager_sharding=False, ) return v2 v2 = f(v) self.assertEqual(v.out_sharding, ('din', 'dout')) self.assertEqual(v[...], 1) self.assertEqual(v2.out_sharding, ('dmid', 'din', 'dout')) self.assertEqual(v2[...], 10) @parameterized.product(use_eager_sharding=[True, False]) def test_eager_sharding_context(self, use_eager_sharding): rngs = nnx.Rngs(0) with nnx.use_eager_sharding(use_eager_sharding): mesh = jax.make_mesh( (2, 2), ('data', 'model'), axis_types=(jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto), ) with jax.set_mesh(mesh): w = nnx.Param( rngs.lecun_normal()((4, 8)), out_sharding=(None, 'model')) if use_eager_sharding: assert has_sharding_spec(w) else: assert not has_sharding_spec(w) def test_out_sharding_linear_layers(self): mesh = jax.make_mesh((2, 2), ("X", "Y"), axis_types=(AxisType.Explicit, AxisType.Explicit)) with jax.set_mesh(mesh): replicated_array = jnp.arange(4).reshape(2, 2) sharded_array = reshard(replicated_array, P("X", None)) layers = [ nnx.Linear(2, 4, rngs=nnx.Rngs(0)), nnx.LinearGeneral(2, 4, rngs=nnx.Rngs(0)), nnx.Einsum('ab,bc->ac', (2, 4), (4,), rngs=nnx.Rngs(0)), ] for layer in layers: assert 'float32[2@X,4]' in str(jax.typeof(layer(sharded_array))) assert 'float32[2@X,4@Y]' in str(jax.typeof(layer(sharded_array, out_sharding=P("X", "Y")))) def test_out_sharding_embed(self): mesh = jax.make_mesh((2, 2), ("X", "Y"), axis_types=(AxisType.Explicit, AxisType.Explicit)) with jax.set_mesh(mesh): emb = nnx.Embed(num_embeddings=8, features=4, rngs=nnx.Rngs(0)) emb = reshard(emb, P("X")) sharded_array = reshard(jnp.arange(4), P("Y")) self.assertRaises(Exception, emb, sharded_array) self.assertEqual('float32[4@X,4]', str(jax.typeof(emb(sharded_array, out_sharding=P("X"))))) def test_out_sharding_conv(self): mesh = jax.make_mesh((2, 2), ("X", "Y"), axis_types=(AxisType.Explicit, AxisType.Explicit)) with jax.set_mesh(mesh): replicated_array = jnp.arange(32).reshape(2, 4, 4).astype(jnp.float32) sharded_array = reshard(replicated_array, P("X", None, None)) layer = nnx.Conv(4, 8, kernel_size=(3,), rngs=nnx.Rngs(0)) assert 'float32[2@X,4,8]' in str(jax.typeof(layer(sharded_array))) assert 'float32[2@X,4@Y,8]' in str(jax.typeof(layer(sharded_array, out_sharding=P("X", "Y", None)))) def test_out_sharding_embed_attend(self): mesh = jax.make_mesh((2, 2), ("X", "Y"), axis_types=(AxisType.Explicit, AxisType.Explicit)) with jax.set_mesh(mesh): replicated_array = jnp.arange(8).reshape(2, 4).astype(jnp.float32) sharded_array = reshard(replicated_array, P("X", None)) layer = nnx.Embed(num_embeddings=10, features=4, rngs=nnx.Rngs(0)) assert 'float32[2@X,10]' in str(jax.typeof(layer.attend(sharded_array))) assert 'float32[2@X,10@Y]' in str(jax.typeof(layer.attend(sharded_array, out_sharding=P("X", "Y")))) def test_out_sharding_dropout(self): mesh = jax.make_mesh((2, 2), ("X", "Y"), axis_types=(AxisType.Explicit, AxisType.Explicit)) with jax.set_mesh(mesh): replicated_array = jnp.arange(8).reshape(2, 4).astype(jnp.float32) sharded_array = reshard(replicated_array, P("X", None)) layers = [ nnx.Dropout(rate=0.5, rngs=nnx.Rngs(0)), nnx.Dropout(rate=0.5, broadcast_dims=(1,), rngs=nnx.Rngs(0)), ] for layer in layers: assert 'float32[2@X,4]' in str(jax.typeof(layer(sharded_array))) @jax.jit def func(x, rngs): return layer(x, rngs=rngs) assert 'float32[2@X,4]' in str(jax.typeof(func(sharded_array, nnx.Rngs(0)))) @parameterized.product(use_hijax=[True, False]) def test_logical_rules(self, use_hijax): self.enter_context(nnx.var_defaults(hijax=use_hijax)) class Foo(nnx.Module): def __init__(self): self.w = nnx.Param( nnx.with_partitioning( lambda: jnp.ones((8, 2)), sharding=('row-alias', 'col-alias'), sharding_rules=(('row-alias', 'row'),), )() ) self.b = nnx.Param( nnx.with_partitioning( lambda: jnp.zeros((2,)), sharding=('col-alias',) )() ) def __call__(self, x): return x @ self.w + self.b mesh = jax.make_mesh( (1, 2, 2), ('layers', 'row', 'col'), axis_types=(jax.sharding.AxisType.Auto,) * len(('layers', 'row', 'col')), ) with jax.set_mesh(mesh), nnx.logical_axis_rules((('col-alias', 'col'),)): model = Foo() optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) assert model.w.sharding.is_equivalent_to( NamedSharding(mesh, P('row', 'col')), ndim=2) assert optimizer.opt_state[0].mu['w'].sharding.is_equivalent_to( NamedSharding(mesh, P('row', 'col')), ndim=2) assert optimizer.opt_state[0].nu['w'].sharding.is_equivalent_to( NamedSharding(mesh, P('row', 'col')), ndim=2) def test_get_abstract_model(self): class Foo(nnx.Module): def __init__(self, rngs): self.linear = nnx.Linear( 8, 8, rngs=rngs, use_bias=False, kernel_init=nnx.with_partitioning( nnx.initializers.lecun_normal(), (None, 'model'))) self.shared = self.linear.kernel mesh = jax.make_mesh( (2, 2), ('batch', 'model'), axis_types=(jax.sharding.AxisType.Auto,) * len(('batch', 'model')), ) gdef, abs_state = nnx.get_abstract_model(lambda: Foo(nnx.Rngs(0)), mesh) assert len(jax.tree.leaves(abs_state)) == 1 assert jax.tree.leaves(abs_state)[0].sharding.is_equivalent_to( NamedSharding(mesh, P(None, 'model')), ndim=2) @parameterized.parameters('auto', 'explicit', 'mixed') def test_sharding_axis_types(self, mode): if mode == 'auto': axis_types = (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto) elif mode == 'explicit': axis_types = (jax.sharding.AxisType.Explicit, jax.sharding.AxisType.Explicit) else: axis_types = (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Explicit) mesh = jax.make_mesh( (2, 2), ('row', 'col'), axis_types=axis_types, ) if mode == 'mixed': with self.assertRaises(ValueError): nnx.Variable( jnp.ones((4, 4)), out_sharding=('row', 'col'), mesh=mesh, ) else: v = nnx.Variable( jnp.ones((4, 4)), out_sharding=('row', 'col'), mesh=mesh, ) self.assertEqual(v.sharding.mesh, mesh) self.assertEqual(v.sharding.spec, P('row', 'col')) def test_eval_shape_with_explicit_sharding(self): axis_types = (jax.sharding.AxisType.Explicit, jax.sharding.AxisType.Explicit) mesh1 = jax.make_mesh((2, 2), ("a", "b"), axis_types) class Model(nnx.Module): def __init__(self): self.p1 = nnx.Param( reshard(jnp.ones((4,4)), NamedSharding(mesh1, P('a', 'b'))), mesh=mesh1) abs_model = nnx.eval_shape(lambda: Model()) self.assertEqual(abs_model.p1.sharding.spec, P('a', 'b')) def test_eval_shape_with_sharding0(self): # based on https://github.com/google/flax/issues/5110 mesh1 = jax.make_mesh((2, 2), ("a", "b"), (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto)) mesh2 = jax.make_mesh((1, 4), ("c", "d"), (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto)) class Model(nnx.Module): def __init__(self): self.p1 = nnx.Linear(16, 16, rngs=nnx.Rngs(0), kernel_metadata={"out_sharding": ("a", "b"), "mesh": mesh1}) self.p2 = nnx.Linear(16, 16, rngs=nnx.Rngs(0), kernel_metadata={"out_sharding": ("c", "d"), "mesh": mesh2}) abs_model = nnx.eval_shape(lambda: Model()) assert isinstance(abs_model.p1.kernel.sharding, jax.sharding.NamedSharding) assert abs_model.p1.kernel.sharding.mesh.axis_names == mesh1.axis_names assert abs_model.p1.kernel.sharding.spec == jax.P("a", "b") assert isinstance(abs_model.p2.kernel.sharding, jax.sharding.NamedSharding) assert abs_model.p2.kernel.sharding.mesh.axis_names == mesh2.axis_names assert abs_model.p2.kernel.sharding.spec == jax.P("c", "d") def test_eval_shape_with_sharding1(self): class Model(nnx.Module): def __init__(self): self.linear = nnx.Linear(10, 10, rngs=nnx.Rngs(0), kernel_metadata={"out_sharding": ("a", "b")}) mesh = jax.make_mesh((2, 2), ("a", "b"), (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto)) with jax.set_mesh(mesh): abs_model = nnx.eval_shape(lambda: Model()) assert isinstance(abs_model.linear.kernel.sharding, jax.sharding.NamedSharding) assert abs_model.linear.kernel.sharding.mesh.axis_names == mesh.axis_names assert abs_model.linear.kernel.sharding.spec == jax.P("a", "b") @parameterized.product(axis_type_name=['auto', 'explicit']) def test_variable_out_sharding_types(self, axis_type_name): if axis_type_name == 'auto': axis_types = (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto) else: # 'explicit' axis_types = (jax.sharding.AxisType.Explicit, jax.sharding.AxisType.Explicit) mesh = jax.make_mesh( (2, 2), ('data', 'model'), axis_types=axis_types, ) with jax.set_mesh(mesh): value = jnp.ones((4, 4)) # Test with PartitionSpec v_pspec = nnx.Variable(value, out_sharding=P('data', 'model')) self.assertEqual(v_pspec.sharding.spec, P('data', 'model')) # Test with NamedSharding ns = NamedSharding(mesh, P('data', None)) v_namedsharding = nnx.Variable(value, out_sharding=ns) self.assertEqual(v_namedsharding.sharding, ns) # Test with Format if axis_type_name == 'auto': v_format = nnx.Variable(value, out_sharding=Format(Layout(major_to_minor=(1, 0)), ns)) self.assertEqual(v_format.sharding, ns) def test_get_abstract_with_abstract_mesh(self): mesh = jax.make_mesh( (2, 2), ('a', 'b'), axis_types=(jax.sharding.AxisType.Auto,) * 2, ) with jax.set_mesh(mesh): abs_model = nnx.eval_shape( lambda: nnx.Linear( 4, 8, rngs=nnx.Rngs(0), kernel_metadata={'out_sharding': ('a', 'b')}, ) ) abs_model = nnx.abstract_with_sharding(abs_model) self.assertIsInstance(abs_model.kernel, nnx.Param) self.assertEqual(abs_model.kernel.sharding.spec, P('a', 'b')) self.assertEqual( abs_model.kernel.sharding.mesh.axis_names, mesh.axis_names, ) def test_get_abstract_with_per_variable_mesh(self): mesh1 = jax.make_mesh( (2, 2), ('a', 'b'), axis_types=(jax.sharding.AxisType.Auto,) * 2, ) mesh2 = jax.make_mesh( (1, 4), ('c', 'd'), axis_types=(jax.sharding.AxisType.Auto,) * 2, ) class Model(nnx.Module): def __init__(self): self.p1 = nnx.Linear( 4, 8, rngs=nnx.Rngs(0), kernel_metadata={'out_sharding': ('a', 'b'), 'mesh': mesh1}, ) self.p2 = nnx.Linear( 4, 8, rngs=nnx.Rngs(0), kernel_metadata={'out_sharding': ('c', 'd'), 'mesh': mesh2}, ) abs_model = nnx.eval_shape(lambda: Model()) abs_model = nnx.abstract_with_sharding(abs_model) self.assertEqual(abs_model.p1.kernel.sharding.spec, P('a', 'b')) self.assertEqual(abs_model.p1.kernel.sharding.mesh, mesh1) self.assertEqual(abs_model.p2.kernel.sharding.spec, P('c', 'd')) self.assertEqual(abs_model.p2.kernel.sharding.mesh, mesh2) def test_get_abstract_no_sharding_metadata(self): abs_model = nnx.eval_shape(lambda: nnx.Linear(4, 8, rngs=nnx.Rngs(0))) abs_model = nnx.abstract_with_sharding(abs_model) self.assertIsInstance(abs_model.kernel, nnx.Param) self.assertIsNone( getattr(abs_model.kernel.get_value(), 'sharding', None) ) def has_sharding_spec(array): sharding = array.sharding if hasattr(sharding, 'spec'): # For NamedSharding or PositionalSharding return sharding.spec is not None and any( s is not None for s in sharding.spec ) return False if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/state_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from absl.testing import absltest from flax import nnx import jax from jax import numpy as jnp class StateTest(absltest.TestCase): def test_create_state(self): state = nnx.State( {'a': nnx.Param(jnp.array(1)), 'b': {'c': nnx.Param(jnp.array(2))}} ) assert state['a'][...] == 1 assert state['b']['c'][...] == 2 def test_get_attr(self): state = nnx.State( {'a': nnx.Param(jnp.array(1)), 'b': {'c': nnx.Param(jnp.array(2))}} ) assert state.a[...] == 1 assert state.b.c[...] == 2 def test_set_attr(self): state = nnx.State( {'a': nnx.Param(jnp.array(1)), 'b': {'c': nnx.Param(jnp.array(2))}} ) state.a[...] = 3 state.b.c[...] = 4 assert state['a'][...] == 3 assert state['b']['c'][...] == 4 def test_set_attr_variables(self): state = nnx.State( {'a': nnx.Param(jnp.array(1)), 'b': {'c': nnx.Param(jnp.array(2))}} ) state.a[...] = 3 state.b.c[...] = 4 assert isinstance(state.a, nnx.Param) assert state.a[...] == 3 assert isinstance(state.b.c, nnx.Param) assert state.b.c[...] == 4 def test_add_nested_attr(self): state = nnx.State( {'a': nnx.Param(jnp.array(1)), 'b': {'c': nnx.Param(jnp.array(2))}} ) state.b.d = nnx.Param(jnp.array(5)) assert state['b']['d'][...] == 5 def test_delete_nested_attr(self): state = nnx.State( {'a': nnx.Param(jnp.array(1)), 'b': {'c': nnx.Param(jnp.array(2))}} ) del state['b']['c'] assert 'c' not in state['b'] def test_integer_access(self): class Foo(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.layers = nnx.List([ nnx.Linear(1, 2, rngs=rngs), nnx.Linear(2, 3, rngs=rngs) ]) module = Foo(rngs=nnx.Rngs(0)) state = nnx.state(module) assert module.layers[0].kernel.shape == (1, 2) assert state.layers[0].kernel.shape == (1, 2) assert module.layers[1].kernel.shape == (2, 3) assert state.layers[1].kernel.shape == (2, 3) def test_pure_dict(self): module = nnx.Linear(4, 5, rngs=nnx.Rngs(0)) state = nnx.state(module) pure_dict = nnx.to_pure_dict(state) assert isinstance(pure_dict, dict) assert isinstance(pure_dict['kernel'], jax.Array) assert isinstance(pure_dict['bias'], jax.Array) nnx.replace_by_pure_dict(state, jax.tree.map(jnp.zeros_like, pure_dict)) assert isinstance(state, nnx.State) assert isinstance(state['kernel'], nnx.Variable) assert jnp.array_equal(state['kernel'][...], jnp.zeros((4, 5))) assert type(state['kernel']) == nnx.Param nnx.update(module, state) assert jnp.array_equal(module(jnp.ones((3, 4))), jnp.zeros((3, 5))) def test_diff(self): class MLPs(nnx.Module): def __init__(self, dim, rngs: nnx.Rngs, n=4): self.layers = nnx.List() for _ in range(n): self.layers.append(nnx.Linear(dim, dim, rngs=rngs, use_bias=False)) def __call__(self, x): for layer in self.layers: x = layer(x) return x model1 = MLPs(4, rngs=nnx.Rngs(0), n=4) model2 = MLPs(4, rngs=nnx.Rngs(1), n=4) model3 = MLPs(4, rngs=nnx.Rngs(1), n=5) self.assertEqual( nnx.statelib.diff(nnx.state(model2), nnx.state(model1)), nnx.state({}) ) self.assertNotEqual( nnx.statelib.diff(nnx.state(model3), nnx.state(model1)), nnx.state({}) ) self.assertEqual( nnx.statelib.diff(nnx.state(model1), nnx.state(model3)), nnx.state({}) ) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/summary_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import jax.numpy as jnp from absl.testing import absltest from flax import nnx CONSOLE_TEST_KWARGS = dict(force_terminal=False, no_color=True, width=10_000) class SummaryTest(absltest.TestCase): def test_tabulate(self): class Block(nnx.Module): def __init__(self, din, dout, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.bn = nnx.BatchNorm(dout, rngs=rngs) self.dropout = nnx.Dropout(0.2, rngs=rngs) def forward(self, x): return nnx.relu(self.dropout(self.bn(self.linear(x)))) class Foo(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.block1 = Block(32, 128, rngs=rngs) self.block2 = Block(128, 10, rngs=rngs) def __call__(self, x): return self.block2.forward(self.block1.forward(x)) foo = Foo(nnx.Rngs(0)) x = jnp.ones((1, 32)) table_repr_ = nnx.tabulate( foo, x, console_kwargs=CONSOLE_TEST_KWARGS ) table_repr = table_repr_.splitlines() self.assertIn('Foo Summary', table_repr[0]) self.assertIn('path', table_repr[2]) self.assertIn('type', table_repr[2]) self.assertIn('BatchStat', table_repr[2]) self.assertIn('Param', table_repr[2]) self.assertIn('block1/forward', table_repr[6]) self.assertIn('Block', table_repr[6]) self.assertIn('block1/linear', table_repr[8]) self.assertIn('Linear', table_repr[8]) self.assertIn('block1/bn', table_repr[13]) self.assertIn('BatchNorm', table_repr[13]) self.assertIn('block1/dropout', table_repr[18]) self.assertIn('Dropout', table_repr[18]) self.assertIn('block2/forward', table_repr[20]) self.assertIn('Block', table_repr[20]) self.assertIn('block2/linear', table_repr[22]) self.assertIn('Linear', table_repr[22]) self.assertIn('block2/bn', table_repr[27]) self.assertIn('BatchNorm', table_repr[27]) self.assertIn('block2/dropout', table_repr[32]) self.assertIn('Dropout', table_repr[32]) self.assertIn('Total', table_repr[34]) self.assertIn('276 (1.1 KB)', table_repr[34]) self.assertIn('5,790 (23.2 KB)', table_repr[34]) self.assertIn('4 (24 B)', table_repr[34]) self.assertIn('Total Parameters: 6,070 (24.3 KB)', table_repr[37]) def test_multiple_inputs_and_outputs(self): class CustomMLP(nnx.Module): def __init__(self): self.weight = nnx.Param(jnp.ones((4, 8))) self.bias = nnx.Param(jnp.ones(8)) def __call__(self, x, x2): y = x @ self.weight y += self.bias[None] y += x2 return x, y, 2 * y cmlp = CustomMLP() x = jnp.ones((1, 4)) x2 = jnp.ones((1, 8)) table_repr = nnx.tabulate( cmlp, x, x2, console_kwargs=CONSOLE_TEST_KWARGS ).splitlines() self.assertIn('CustomMLP Summary', table_repr[0]) self.assertIn('float32[1,4]', table_repr[4]) self.assertIn('float32[1,8]', table_repr[5]) self.assertIn('float32[1,8]', table_repr[6]) def test_tabulate_empty_dict_first_arg(self): class Model(nnx.Module): def subroutine(self, foo, x): return x def __call__(self, x): return self.subroutine({}, x) model = Model() out = nnx.tabulate( model, jnp.zeros((1, 8)), depth=1, console_kwargs=CONSOLE_TEST_KWARGS ) # Ensure empty dict argument is preserved and array input is shown self.assertIn('{}', out) self.assertIn('float32[1,8]', out) def test_tabulate_empty_dict_last_arg(self): class Model(nnx.Module): def subroutine(self, foo, x): return x def __call__(self, x): return self.subroutine(x, {}) model = Model() out = nnx.tabulate( model, jnp.zeros((1, 8)), depth=1, console_kwargs=CONSOLE_TEST_KWARGS ) # Ensure trailing empty dict is not dropped self.assertIn('{}', out) def test_tabulate_empty_dict_and_none_kwarg(self): class Model(nnx.Module): def subroutine(self, x, *, foo=None): return x def __call__(self, x): # One call with empty dict, one with None _ = self.subroutine(x, foo={}) return self.subroutine(x, foo=None) model = Model() out = nnx.tabulate( model, jnp.zeros((1, 8)), depth=2, console_kwargs=CONSOLE_TEST_KWARGS ) # Distinguish {} and None in output self.assertIn('{}', out) self.assertIn('None', out) def test_tabulate_empty_dict_property(self): class Model(nnx.Module): def __init__(self): self.foo = {} def subroutine(self, foo, x): return x def __call__(self, x): return self.subroutine(self.foo, x) model = Model() out = nnx.tabulate( model, jnp.zeros((1, 1024)), depth=1, console_kwargs=CONSOLE_TEST_KWARGS ) # Should not crash and should show the empty dict argument self.assertIn('{}', out) def test_no_dup_flops(self): class Model(nnx.Module): def g(self, x): return x**2 def __call__(self, x): return self.g(x) m = Model() x = jnp.ones(4) table_rep = nnx.tabulate(m, x, compute_flops=True) table_lines = table_rep.splitlines() self.assertEqual(sum(" g " in l for l in table_lines), 1) def test_flops(self): class Model(nnx.Module): def __init__(self): self.weight = nnx.Param(jnp.ones(4)) def __call__(self, x1): return jnp.sum((x1 * self.weight)**2) m = Model() x = jnp.ones(4) table_repr1 = nnx.tabulate( m, x, compute_flops=True ).splitlines() self.assertIn('flops', table_repr1[2]) self.assertNotIn('vjp_flops', table_repr1[2]) table_repr2 = nnx.tabulate( m, x, compute_flops=True, compute_vjp_flops=True ).splitlines() self.assertIn('vjp_flops', table_repr2[2]) def test_nested(self): class Block(nnx.Module): def __init__(self, rngs): self.linear = nnx.Linear(2, 2, rngs=rngs) self.bn = nnx.BatchNorm(2, rngs=rngs) def __call__(self, x): x = self.linear(x) x = self.bn(x) return nnx.relu(x) class Model(nnx.Module): def __init__(self, rngs): self.block1 = Block(rngs) self.block2 = Block(rngs) def __call__(self, x): x = self.block1(x) x = self.block2(x) return x m = Model(nnx.Rngs(0)) x = jnp.ones((4, 2)) table = nnx.tabulate(m, x, compute_flops=True, compute_vjp_flops=True) # We should see 3 calls per block, plus one overall call self.assertEqual(sum([s.startswith("├─") for s in table.splitlines()]), 7) def test_time_complexity(self): counter = [] class Block(nnx.Module): def __init__(self, rngs): self.linear = nnx.Linear(2, 2, rngs=rngs) def __call__(self, x): counter.append(1) return self.linear(x) class Model(nnx.Module): def __init__(self, rngs): for d in range(10): setattr(self, f"linear{d}", Block(rngs)) def __call__(self, x): for d in range(10): x = getattr(self, f"linear{d}")(x) return x m = Model(nnx.Rngs(0)) x = jnp.ones((4, 2)) nnx.tabulate(m, x, compute_flops=True, compute_vjp_flops=False) self.assertEqual(len(counter), 10) def test_shared(self): class Block(nnx.Module): def __init__(self, linear: nnx.Linear, *, rngs): self.linear = linear self.bn = nnx.BatchNorm(2, rngs=rngs) def __call__(self, x): x = self.linear(x) x = self.bn(x) return nnx.relu(x) class Model(nnx.Module): def __init__(self, rngs): shared = nnx.Linear(2, 2, rngs=rngs) self.block1 = Block(shared, rngs=rngs) self.block2 = Block(shared, rngs=rngs) def __call__(self, x): x = self.block1(x) x = self.block2(x) return x m = Model(nnx.Rngs(0)) x = jnp.ones((4, 2)) table = nnx.tabulate(m, x, compute_vjp_flops=True) # We should see 3 calls per block, plus one overall call, minus the shared call self.assertEqual(sum([s.startswith("├─") for s in table.splitlines()]), 6) def test_tabulate_with_variable_hooks(self): """Test that tabulate works with Variables implementing hooks and custom metadata.""" class Custom: def __repr__(self): return "" class VarWithHooks(nnx.Variable): def on_get_value(self, value): return value def on_set_value(self, value): return value + 1.0 class Model(nnx.Module): def __init__(self): # Variable with hooks self.hooked_param = VarWithHooks(value=jnp.ones((2, 3))) self.hooked_param.set_metadata('description', 'Custom parameter') self.hooked_param.set_metadata('trainable', True) # Variable with custom non-serializable metadata self.custom_param = nnx.Param(jnp.ones((2, 2))) self.custom_param.set_metadata('custom_obj', Custom()) def __call__(self, x): return jnp.dot(x, self.hooked_param[...]) + self.custom_param.sum() module = Model() # Should not raise yaml.representer.RepresenterError table_repr = nnx.tabulate(module, jnp.ones((1, 2)), console_kwargs=CONSOLE_TEST_KWARGS) self.assertIsNotNone(table_repr) # Verify table contains expected content self.assertIn('Model Summary', table_repr) self.assertIn('hooked_param', table_repr) self.assertIn('on_set_value', table_repr) self.assertIn('', table_repr) # Verify metadata is preserved in the module self.assertEqual(module.hooked_param.get_metadata('description'), 'Custom parameter') self.assertEqual(module.hooked_param.get_metadata('trainable'), True) def test_tabulate_concrete_shape(self): class Net(nnx.Module): def __init__(self): self.rngs = nnx.Rngs(0) def __call__(self, x): return self.rngs.uniform((x.shape[0], 10)) net = Net() x = jnp.zeros((4, 8)) nnx.tabulate(net, x, console_kwargs={"width": 200}) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/test_traversals.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.nnx.traversal.""" from absl.testing import absltest from flax.core import freeze from flax.nnx import traversals import jax # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() class TraversalTest(absltest.TestCase): def test_flatten_mapping(self): xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}} flat_xs = traversals.flatten_mapping(xs) self.assertEqual( flat_xs, { ('foo',): 1, ('bar', 'a'): 2, }, ) flat_xs = traversals.flatten_mapping(freeze(xs)) self.assertEqual( flat_xs, { ('foo',): 1, ('bar', 'a'): 2, }, ) flat_xs = traversals.flatten_mapping(xs, sep='/') self.assertEqual( flat_xs, { 'foo': 1, 'bar/a': 2, }, ) def test_unflatten_mapping(self): expected_xs = {'foo': 1, 'bar': {'a': 2}} xs = traversals.unflatten_mapping( { ('foo',): 1, ('bar', 'a'): 2, } ) self.assertEqual(xs, expected_xs) xs = traversals.unflatten_mapping( { 'foo': 1, 'bar/a': 2, }, sep='/', ) self.assertEqual(xs, expected_xs) def test_flatten_mapping_keep_empty(self): ys = {'a': {}} xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}} flat_ys = traversals.flatten_mapping(ys, keep_empty_nodes=True) flat_xs = traversals.flatten_mapping(xs, keep_empty_nodes=True) empty_node = flat_ys[('a',)] self.assertEqual( flat_xs, { ('foo',): 1, ('bar', 'a'): 2, ('bar', 'b'): empty_node, }, ) xs_restore = traversals.unflatten_mapping(flat_xs) self.assertEqual(xs, xs_restore) def test_flatten_mapping_is_leaf(self): xs = {'foo': {'c': 4}, 'bar': {'a': 2, 'b': {}}} flat_xs = traversals.flatten_mapping( xs, is_leaf=lambda k, x: len(k) == 1 and len(x) == 2 ) self.assertEqual( flat_xs, { ('foo', 'c'): 4, ('bar',): {'a': 2, 'b': {}}, }, ) xs_restore = traversals.unflatten_mapping(flat_xs) self.assertEqual(xs, xs_restore) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/transforms_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4' import dataclasses from functools import partial import typing as tp from absl.testing import absltest from absl.testing import parameterized from flax import nnx from flax.nnx.transforms.iteration import pure_jax_fancy_scan from flax.nnx.transforms import general import jax from jax.experimental import checkify, mesh_utils import jax.numpy as jnp import numpy as np import optax from flax import errors class TestJIT(parameterized.TestCase): def test_jit(self): m = nnx.Dict(a=nnx.Param(1)) @nnx.jit def g(m: nnx.Dict): m.a = 2 return 1.0 out = g(m) assert m.a == 2 assert out == 1.0 def test_mutable_array_input_output(self): m = jax.new_ref(jnp.array(1.0)) @nnx.jit def f(m: jax.Ref): m[...] += 1.0 m2 = jax.new_ref(jnp.array(10.0)) return m2, m m2, m_out = f(m) self.assertEqual(m[...], 2.0) self.assertIs(m, m_out) self.assertIsInstance(m2, jax.Ref) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_simple_double_call(self, graph_mode, graph_updates): n = 0 m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) @nnx.jit(graph=graph_mode, graph_updates=graph_updates) def f(m: nnx.Linear, x: jnp.ndarray) -> jnp.ndarray: nonlocal n n += 1 return m(x) x = jnp.ones((1, 2)) y = f(m, x) self.assertEqual(n, 1) self.assertEqual(y.shape, (1, 3)) y = f(m, x) self.assertEqual(n, 1) self.assertEqual(y.shape, (1, 3)) def test_jit_on_init(self): n = 0 class Foo(nnx.Module): @nnx.jit(static_argnums=(1, 2)) def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): nonlocal n n += 1 key = rngs.params() self.w = nnx.Param(jax.random.normal(key, shape=(din, dout))) self.din = din self.dout = dout m = Foo(2, 3, rngs=nnx.Rngs(0)) assert n == 1 assert m.w.shape == (2, 3) assert m.din == 2 assert m.dout == 3 assert isinstance(m.din, int) assert isinstance(m.dout, int) assert isinstance(m.w[...], jax.Array) m = Foo(2, 3, rngs=nnx.Rngs(0)) assert n == 1 @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_jit_on_call(self, graph_mode, graph_updates): n = 0 class Foo(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): key = rngs.params() self.w = nnx.Param(jax.random.normal(key, shape=(din, dout))) self.din = din self.dout = dout @nnx.jit(graph=graph_mode, graph_updates=graph_updates) def __call__(self, x: jax.Array) -> jax.Array: nonlocal n n += 1 return jnp.dot(x, self.w) m = Foo(2, 3, rngs=nnx.Rngs(0)) assert m.w.shape == (2, 3) assert m.din == 2 assert m.dout == 3 assert isinstance(m.din, int) assert isinstance(m.dout, int) assert isinstance(m.w[...], jax.Array) y = m(jnp.ones((1, 2))) assert y.shape == (1, 3) assert n == 1 y = m(jnp.ones((1, 2))) assert n == 1 def test_cached_unflatten(self): n = 0 class Foo(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.a = nnx.Linear(2, 2, rngs=rngs) self.b = nnx.BatchNorm(2, rngs=rngs) @nnx.jit def f(m: Foo): nonlocal n n += 1 m.a, m.b = m.b, m.a # type: ignore m = Foo(rngs=nnx.Rngs(0)) a = m.a b = m.b a_kernel = a.kernel[...] a_bias = a.bias[...] b_scale = b.scale[...] b_bias = b.bias[...] b_mean = b.mean[...] b_var = b.var[...] f(m) assert n == 1 assert m.a is b assert m.b is a np.testing.assert_allclose(a_kernel, a.kernel[...]) np.testing.assert_allclose(a_bias, a.bias[...]) np.testing.assert_allclose(b_scale, b.scale[...]) np.testing.assert_allclose(b_bias, b.bias[...]) np.testing.assert_allclose(b_mean, b.mean[...]) np.testing.assert_allclose(b_var, b.var[...]) f(m) assert n == 2 assert m.a is a assert m.b is b f(m) assert n == 2 assert m.a is b assert m.b is a f(m) assert n == 2 assert m.a is a assert m.b is b @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_jit_custom_vjp(self, graph_mode, graph_updates): @nnx.custom_vjp(graph=graph_mode, graph_updates=graph_updates) def f(x, y): return jnp.sin(x) * y def f_fwd(x, y): return f(x, y), (jnp.cos(x), jnp.sin(x), y) def f_bwd(res, g): cos_x, sin_x, y = res return (cos_x * g * y, sin_x * g) f.defvjp(f_fwd, f_bwd) nnx_out = nnx.jit(f, graph=graph_mode, graph_updates=graph_updates)(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])) jax_out = jax.jit(f)(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])) assert (nnx_out == jax_out).all() def test_cached_unflatten_same_type(self): n = 0 class Foo(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.a = nnx.Linear(2, 2, rngs=rngs) self.b = nnx.Linear(2, 2, rngs=rngs) @nnx.jit def f(m: Foo): nonlocal n n += 1 m.a, m.b = m.b, m.a m = Foo(rngs=nnx.Rngs(0)) a = m.a b = m.b f(m) assert n == 1 assert m.a is b assert m.b is a f(m) assert n == 1 assert m.a is a assert m.b is b def test_objects_in_pytree(self): n = 0 class Foo(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.a = nnx.Linear(2, 2, rngs=rngs) self.b = nnx.Linear(2, 2, rngs=rngs) class FooDict(tp.TypedDict): foo: Foo @nnx.jit def f(tree: tuple[FooDict]): nonlocal n n += 1 m = tree[0]['foo'] m.a, m.b = m.b, m.a m = Foo(rngs=nnx.Rngs(0)) a = m.a b = m.b f(({'foo': m},)) assert n == 1 assert m.a is b assert m.b is a f(({'foo': m},)) assert n == 1 assert m.a is a assert m.b is b def test_cached_unflatten_swap_variables(self): class Foo(nnx.Module): def __init__(self): self.a = nnx.Param(1) self.b = nnx.Param(2) @nnx.jit def f(m: Foo): m.a, m.b = m.b, m.a m = Foo() a = m.a b = m.b f(m) assert m.a is b assert m.b is a def test_cached_unflatten_add_self_reference(self): n = 0 class Foo(nnx.Module): def __init__(self): self.ref: tp.Optional[Foo] = nnx.data(None) # type: ignore[name-error] @nnx.jit def f(m: Foo): nonlocal n n += 1 m.ref = m m = Foo() f(m) assert n == 1 assert m.ref is m f(m) assert n == 2 assert m.ref is m f(m) assert n == 2 assert m.ref is m def test_cached_unflatten_ref_in_output(self): n = 0 class Foo(nnx.Module): def __init__(self): self.ref: tp.Optional[Foo] = nnx.data(None) # type: ignore[name-error] @nnx.jit def f(m: Foo): nonlocal n n += 1 m.ref = m return m m = Foo() m2 = f(m) assert n == 1 assert m.ref is m assert m2 is m m2 = f(m) assert n == 2 assert m.ref is m assert m2 is m m2 = f(m) assert n == 2 assert m.ref is m assert m2 is m def test_apply_shardings(self): n_devices = max(jax.local_device_count() // 2, 1) devices = mesh_utils.create_device_mesh( (n_devices, jax.local_device_count() // n_devices) ) mesh = jax.sharding.Mesh(devices, ('a', 'b')) def sharding(*args): return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*args)) state_sharding = nnx.StateSharding( { nnx.PathContains('kernel'): sharding('a', 'b'), nnx.PathContains('bias'): sharding('b'), } ) m = nnx.Linear(16, 32, rngs=nnx.Rngs(0)) self.assertNotIsInstance(m.kernel.sharding, jax.sharding.NamedSharding) @nnx.jit(in_shardings=(state_sharding,)) def constrain_object(m): pass constrain_object(m) self.assertIsInstance(m.kernel.sharding, jax.sharding.NamedSharding) def test_cache_args(self): m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) @nnx.jit def f(cached_m: nnx.Linear, m: nnx.Linear): self.assertIsNot(cached_m, m) self.assertIs(cached_m.kernel, m.kernel) self.assertIs(cached_m.bias, m.bias) return cached_m cached_f = nnx.cached_partial(f, m) cached_m = cached_f(m) self.assertIsNot(m, cached_m) self.assertIs(m.kernel, cached_m.kernel) self.assertIs(m.bias, cached_m.bias) # test that cached m is reused cached_m2 = cached_f(m) self.assertIs(cached_m, cached_m2) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_jit_wrapped(self, graph_mode, graph_updates): class Foo(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.count = nnx.Variable(jnp.array(0)) @nnx.jit(graph=graph_mode, graph_updates=graph_updates) def __call__(self, x: jax.Array) -> jax.Array: self.count[...] += 1 return x * 2 m = Foo(rngs=nnx.Rngs(0)) x = jnp.array(3.0) @nnx.jit(graph=graph_mode, graph_updates=graph_updates) def f(m: nnx.Linear, x): return m(x) lowered = f.lower(m, x) compiled = lowered.compile() text = compiled.as_text() cost_analysis = compiled.cost_analysis() self.assertIsNotNone(cost_analysis) self.assertIsNotNone(text) y = compiled(m, x) np.testing.assert_allclose(y, 6.0) self.assertEqual(m.count[...], 1) y = compiled(m, x) self.assertEqual(m.count[...], 2) @parameterized.parameters( {'graph_mode': True, 'graph_updates': True, 'static_argnums': (2,), 'static_argnames': None}, {'graph_mode': True, 'graph_updates': True, 'static_argnums': None, 'static_argnames': ('use_relu',)}, {'graph_mode': True, 'graph_updates': False, 'static_argnums': (2,), 'static_argnames': None}, {'graph_mode': True, 'graph_updates': False, 'static_argnums': None, 'static_argnames': ('use_relu',)}, {'graph_mode': False, 'graph_updates': False, 'static_argnums': (2,), 'static_argnames': None}, {'graph_mode': False, 'graph_updates': False, 'static_argnums': None, 'static_argnames': ('use_relu',)}, ) def test_jit_static_args_with_shardings(self, graph_mode, graph_updates, static_argnums, static_argnames): """Test static arguments work correctly with in_shardings.""" n_devices = jax.local_device_count() devices = mesh_utils.create_device_mesh((n_devices,)) mesh = jax.sharding.Mesh(devices, ('data',)) def fn(x, scale, use_relu): y = x * scale if use_relu: y = jnp.maximum(y, 0) return y.sum() x = jnp.linspace(-1.0, 1.0, 16, dtype=jnp.float32).reshape(4, 4) x_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('data')) f = nnx.jit(fn, in_shardings=(x_sharding, None), static_argnums=static_argnums, static_argnames=static_argnames, graph=graph_mode, graph_updates=graph_updates) y_relu = f(x, 0.5, True) y_no_relu = f(x, 0.5, False) self.assertNotEqual(y_relu, y_no_relu) @parameterized.parameters( { 'static_args': {'static_argnums': (2, 3)}, }, { 'static_args': {'static_argnames': ('static_arg1', 'static_arg2')}, }, ) def test_with_sharding_and_static_args(self, static_args): n_devices = max(jax.local_device_count() // 2, 1) devices = mesh_utils.create_device_mesh( (n_devices, jax.local_device_count() // n_devices) ) mesh = jax.sharding.Mesh(devices, ('a', 'b')) def sharding(*args): return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*args)) state_sharding = nnx.StateSharding( { nnx.PathContains('kernel'): sharding('a', 'b'), nnx.PathContains('bias'): sharding('b'), } ) m = nnx.Linear(16, 32, rngs=nnx.Rngs(0)) self.assertNotIsInstance(m.kernel.sharding, jax.sharding.NamedSharding) @nnx.jit( in_shardings=(state_sharding, None), **static_args, ) def constrain_object(m, scale: float, static_arg1: bool, static_arg2: bool): new_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('b', 'a')) m.kernel = jax.lax.with_sharding_constraint(m.kernel, new_sharding) return None constrain_object(m, 0.5, True, True) self.assertEqual(m.kernel.sharding.spec, jax.sharding.PartitionSpec("a", "b")) class TestTreeJIT(parameterized.TestCase): @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_tree_jit_basic(self, graph, graph_updates): m = nnx.Dict(a=nnx.Param(jnp.array(1))) @nnx.jit(graph=graph, graph_updates=graph_updates) def g(m: nnx.Dict): m.a[...] = 2 return 1.0 out = g(m) assert m.a[...] == 2 assert out == 1.0 @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_tree_jit_module(self, graph, graph_updates): m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) @nnx.jit(graph=graph, graph_updates=graph_updates) def f(m, x): return m(x) x = jnp.ones((1, 2)) y = f(m, x) self.assertEqual(y.shape, (1, 3)) def test_tree_jit_variable_update(self): class Foo(nnx.Module): def __init__(self): self.count = nnx.Variable(jnp.array(0)) @nnx.jit(graph=False) def __call__(self, x): self.count[...] += 1 return x * 2 m = Foo() y = m(jnp.array(3.0)) np.testing.assert_allclose(y, 6.0) self.assertEqual(m.count[...], 1) y = m(jnp.array(3.0)) self.assertEqual(m.count[...], 2) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_tree_jit_no_retrace(self, graph, graph_updates): n = 0 m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) @nnx.jit(graph=graph, graph_updates=graph_updates) def f(m, x): nonlocal n n += 1 return m(x) x = jnp.ones((1, 2)) y = f(m, x) self.assertEqual(n, 1) self.assertEqual(y.shape, (1, 3)) y = f(m, x) self.assertEqual(n, 1) self.assertEqual(y.shape, (1, 3)) def test_tree_jit_static_argnums(self): @nnx.jit(graph=False, static_argnums=(1,)) def f(x, use_relu): if use_relu: return jnp.maximum(x, 0) return x x = jnp.array([-1.0, 2.0]) y_relu = f(x, True) np.testing.assert_allclose(y_relu, jnp.array([0.0, 2.0])) y_no_relu = f(x, False) np.testing.assert_allclose(y_no_relu, x) def test_tree_jit_no_input_output_aliasing(self): v = nnx.Param(jnp.array(1.0)) @nnx.jit(graph=False) def f(v): return v with self.assertRaisesRegex(ValueError, 'does not support returning input Variables as outputs'): f(v) def test_tree_jit_no_shared_variable_refs(self): v = nnx.Param(jnp.array(1.0)) @nnx.jit(graph=False) def f(v1, v2): pass with self.assertRaisesRegex( ValueError, 'found at paths' ): f(v, v) def test_tree_jit_new_variable_output_ok(self): @nnx.jit(graph=False) def f(x): return nnx.Param(x + 1) v = f(jnp.array(1.0)) self.assertIsInstance(v, nnx.Param) np.testing.assert_allclose(v[...], 2.0) def test_tree_jit_donate_argnums_unchanged_var(self): v = nnx.Param(jnp.array(1.0)) @nnx.jit(graph=False, donate_argnums=(0,)) def f(v): return v[...] + 1.0 out = f(v) np.testing.assert_allclose(out, 2.0) np.testing.assert_allclose(v[...], 1.0) out = f(v) np.testing.assert_allclose(out, 2.0) np.testing.assert_allclose(v[...], 1.0) def test_tree_jit_donate_argnums_module(self): m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) original_kernel = jnp.copy(m.kernel[...]) @nnx.jit(graph=False, donate_argnums=(0,)) def f(m, x): return m(x) x = jnp.ones((1, 2)) y = f(m, x) self.assertEqual(y.shape, (1, 3)) np.testing.assert_allclose(m.kernel[...], original_kernel) y = f(m, x) self.assertEqual(y.shape, (1, 3)) np.testing.assert_allclose(m.kernel[...], original_kernel) def test_tree_jit_donate_argnums_with_mutation(self): v = nnx.Param(jnp.array(0.0)) @nnx.jit(graph=False, donate_argnums=(0,)) def f(v): v[...] += 1.0 return None f(v) np.testing.assert_allclose(v[...], 1.0) f(v) np.testing.assert_allclose(v[...], 2.0) def test_tree_jit_donate_argnames(self): v = nnx.Param(jnp.array(1.0)) @nnx.jit(graph=False, donate_argnames=('v',)) def f(v): return v[...] + 1.0 out = f(v=v) np.testing.assert_allclose(out, 2.0) np.testing.assert_allclose(v[...], 1.0) out = f(v=v) np.testing.assert_allclose(out, 2.0) np.testing.assert_allclose(v[...], 1.0) def test_tree_jit_donate_selective(self): donated = nnx.Param(jnp.array(1.0)) not_donated = nnx.Param(jnp.array(2.0)) @nnx.jit(graph=False, donate_argnums=(0,)) def f(donated, not_donated): return donated[...] + not_donated[...] out = f(donated, not_donated) np.testing.assert_allclose(out, 3.0) np.testing.assert_allclose(donated[...], 1.0) np.testing.assert_allclose(not_donated[...], 2.0) out = f(donated, not_donated) np.testing.assert_allclose(out, 3.0) @parameterized.parameters(True, False) def test_jit_partial_basic(self, graph_mode): m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) def f(m, x): return m(x) f_jit = nnx.jit_partial(f, m, graph=graph_mode, graph_updates=False) x = jnp.ones((1, 2)) y = f_jit(x) self.assertEqual(y.shape, (1, 3)) @parameterized.parameters(True, False) def test_jit_partial_lower_compile(self, graph_mode): class Foo(nnx.Module): def __init__(self): self.count = nnx.Variable(jnp.array(0)) self.w = nnx.Param(jnp.ones((2, 3))) m = Foo() def f(m, x): m.count[...] += 1 return x @ m.w[...] f_jit = nnx.jit_partial(f, m, graph=graph_mode, graph_updates=False) compiled = f_jit.lower(jnp.ones((1, 2))).compile() x = jnp.ones((1, 2)) y = compiled(x) self.assertEqual(y.shape, (1, 3)) np.testing.assert_allclose(y, jnp.ones((1, 3)) * 2) self.assertEqual(m.count[...], 1) y = compiled(x) self.assertEqual(m.count[...], 2) @parameterized.parameters(True, False) def test_jit_partial_variable_update(self, graph_mode): class Foo(nnx.Module): def __init__(self): self.count = nnx.Variable(jnp.array(0)) m = Foo() def f(m, x): m.count[...] += 1 return x * 2 f_jit = nnx.jit_partial(f, m, graph=graph_mode, graph_updates=False) y = f_jit(jnp.array(3.0)) np.testing.assert_allclose(y, 6.0) self.assertEqual(m.count[...], 1) y = f_jit(jnp.array(3.0)) self.assertEqual(m.count[...], 2) @parameterized.parameters(True, False) def test_jit_partial_multiple_args(self, graph_mode): m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) m2 = nnx.Linear(3, 4, rngs=nnx.Rngs(1)) def f(m1, m2, x): return m2(m1(x)) f_jit = nnx.jit_partial(f, m1, m2, graph=graph_mode, graph_updates=False) x = jnp.ones((1, 2)) y = f_jit(x) self.assertEqual(y.shape, (1, 4)) @parameterized.parameters(True, False) def test_jit_partial_no_retrace(self, graph_mode): n = 0 m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) def f(m, x): nonlocal n n += 1 return m(x) f_jit = nnx.jit_partial(f, m, graph=graph_mode, graph_updates=False) x = jnp.ones((1, 2)) f_jit(x) self.assertEqual(n, 1) f_jit(x) self.assertEqual(n, 1) @parameterized.parameters(True, False) def test_jit_partial_no_retrace_after_mutation(self, graph_mode): n = 0 class Foo(nnx.Module): def __init__(self): self.w = nnx.Param(jnp.ones((2, 3))) m = Foo() def f(m, x): nonlocal n n += 1 return x @ m.w[...] f_jit = nnx.jit_partial(f, m, graph=graph_mode, graph_updates=False) x = jnp.ones((1, 2)) f_jit(x) self.assertEqual(n, 1) m.w[...] = jnp.zeros((2, 3)) y = f_jit(x) self.assertEqual(n, 1) np.testing.assert_allclose(y, jnp.zeros((1, 3))) @parameterized.parameters(True, False) def test_jit_partial_no_partial_args(self, graph_mode): f_partial = nnx.jit_partial(lambda x: x * 2, graph=graph_mode, graph_updates=False) y = f_partial(jnp.array(3.0)) np.testing.assert_allclose(y, 6.0) @parameterized.parameters((False,)) def test_jit_partial_in_shardings_none_broadcast(self, graph_mode): n_devices = jax.local_device_count() devices = mesh_utils.create_device_mesh((n_devices,)) mesh = jax.sharding.Mesh(devices, ('data',)) m = nnx.Linear(4, 3, rngs=nnx.Rngs(0)) def f(m, x): return m(x) f_jit = nnx.jit_partial(f, m, in_shardings=(None, None), graph=graph_mode, graph_updates=False) x = jnp.ones((n_devices, 4)) y = f_jit(x) self.assertEqual(y.shape, (n_devices, 3)) @parameterized.parameters((False,)) def test_jit_partial_in_shardings_named(self, graph_mode): n_devices = jax.local_device_count() devices = mesh_utils.create_device_mesh((n_devices,)) mesh = jax.sharding.Mesh(devices, ('data',)) PS = jax.sharding.PartitionSpec v = nnx.Param(jnp.ones((n_devices, 4))) def f(v, x): return v[...] + x x_sharding = jax.sharding.NamedSharding(mesh, PS('data')) v_sharding = jax.sharding.NamedSharding(mesh, PS('data')) f_jit = nnx.jit_partial( f, v, in_shardings=(v_sharding, x_sharding), graph=graph_mode, graph_updates=False) x = jnp.ones((n_devices, 4)) y = f_jit(x) self.assertEqual(y.shape, (n_devices, 4)) @parameterized.parameters((False,)) def test_jit_partial_mixed_shardings(self, graph_mode): n_devices = jax.local_device_count() devices = mesh_utils.create_device_mesh((n_devices,)) mesh = jax.sharding.Mesh(devices, ('data',)) PS = jax.sharding.PartitionSpec m1 = nnx.Linear(4, 3, rngs=nnx.Rngs(0)) m2 = nnx.Linear(3, 2, rngs=nnx.Rngs(1)) def f(m1, m2, x): return m2(m1(x)) x_sharding = jax.sharding.NamedSharding(mesh, PS('data')) f_jit = nnx.jit_partial( f, m1, m2, in_shardings=(None, None, x_sharding), graph=graph_mode, graph_updates=False) x = jnp.ones((n_devices, 4)) y = f_jit(x) self.assertEqual(y.shape, (n_devices, 2)) @parameterized.parameters(True, False) def test_jit_partial_in_shardings_non_tuple(self, graph_mode): n_devices = jax.local_device_count() devices = mesh_utils.create_device_mesh((n_devices,)) mesh = jax.sharding.Mesh(devices, ('data',)) PS = jax.sharding.PartitionSpec m = nnx.Linear(4, 3, rngs=nnx.Rngs(0)) def f(m, x): return m(x) sharding = jax.sharding.NamedSharding(mesh, PS()) f_jit = nnx.jit_partial(f, m, in_shardings=sharding, graph=graph_mode, graph_updates=False) x = jnp.ones((n_devices, 4)) y = f_jit(x) self.assertEqual(y.shape, (n_devices, 3)) @parameterized.parameters(True, False) def test_jit_partial_train_step(self, graph_mode): model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=nnx.Param) def train_step(model, optimizer, x, y): def loss_fn(model): return jnp.mean((model(x) - y) ** 2) loss, grads = nnx.value_and_grad(loss_fn)(model) optimizer.update(model, grads) return loss train_step_fn = nnx.jit_partial(train_step, model, optimizer, graph=graph_mode, graph_updates=False) for _ in range(2): x, y = jnp.ones((10, 2)), jnp.ones((10, 3)) loss = train_step_fn(x, y) self.assertIsInstance(loss, jax.Array) def test_jit_partial_shared_variable(self): v = nnx.Param(jnp.array(1.0)) class Container(nnx.Module): def __init__(self, v): self.v = v c1 = Container(v) c2 = Container(v) def f(c1, c2, x): c1.v[...] += x return c1.v[...] + c2.v[...] f_jit = nnx.jit_partial(f, c1, c2, graph=True, graph_updates=False) y = f_jit(jnp.array(1.0)) np.testing.assert_allclose(y, 4.0) np.testing.assert_allclose(v[...], 2.0) def test_jit_inconsistent_aliasing(self): v = nnx.Param(jnp.array(1.0)) P = jax.sharding.PartitionSpec @nnx.jit( in_shardings=(P(), P('x')), graph=True, graph_updates=False, ) def f(a, b): return a[...] + b[...] mesh = jax.sharding.Mesh(jax.devices(), ('x',)) with mesh: with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing'): f(v, v) class TestEvalShape(parameterized.TestCase): @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_eval_shape(self, graph, graph_updates): abs_model = nnx.eval_shape( lambda: nnx.Linear(1, 2, rngs=nnx.Rngs(0)), graph=graph, graph_updates=graph_updates, ) self.assertIsInstance(abs_model, nnx.Linear) self.assertIsInstance(abs_model.kernel.get_value(), jax.ShapeDtypeStruct) def test_eval_shape_mutable_array(self): with nnx.var_defaults(hijax=True): abs_model = nnx.eval_shape(lambda: nnx.Linear(1, 2, rngs=nnx.Rngs(0))) self.assertIsInstance(abs_model, nnx.Linear) self.assertIsInstance(abs_model.kernel.get_value(), jax.ShapeDtypeStruct) self.assertEqual(abs_model.kernel.shape, (1, 2)) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_eval_shape_with_module_input(self, graph, graph_updates): model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) def f(m, x): return m(x) x = jnp.ones((4, 2)) out = nnx.eval_shape(f, model, x, graph=graph, graph_updates=graph_updates) self.assertIsInstance(out, jax.ShapeDtypeStruct) self.assertEqual(out.shape, (4, 3)) @parameterized.parameters( (True, False), (False, False), ) def test_eval_shape_no_state_update(self, graph, graph_updates): count = nnx.Variable(jnp.array(0)) def f(c): c[...] += 1 return jnp.ones((3, 4)) * c[...] out = nnx.eval_shape(f, count, graph=graph, graph_updates=graph_updates) self.assertIsInstance(out, jax.ShapeDtypeStruct) self.assertEqual(out.shape, (3, 4)) self.assertEqual(count[...], 0) @parameterized.parameters( (True, False), (False, False), ) def test_eval_shape_no_input_output_aliasing(self, graph, graph_updates): v = nnx.Param(jnp.array(1.0)) def f(v): return v with self.assertRaises(ValueError): nnx.eval_shape(f, v, graph=graph, graph_updates=graph_updates) @parameterized.parameters( (True, False), (False, False), ) def test_eval_shape_no_shared_variable_refs(self, graph, graph_updates): v = nnx.Param(jnp.array(1.0)) def f(v1, v2): v1[...] += 1 return None with self.assertRaises(ValueError): nnx.eval_shape(f, v, v, graph=graph, graph_updates=graph_updates) class TestShardMap(parameterized.TestCase): def test_basic_shardmap(self): n_devices = jax.local_device_count() devices = mesh_utils.create_device_mesh((n_devices,)) mesh = jax.sharding.Mesh(devices, ('a',)) PS = jax.sharding.PartitionSpec state_sharding = nnx.StateSharding( { nnx.PathContains('kernel'): PS(None, 'a'), nnx.PathContains('bias'): PS(), } ) m = nnx.Linear(16, 32, rngs=nnx.Rngs(0)) self.assertNotIsInstance(m.kernel.sharding, jax.sharding.NamedSharding) @nnx.shard_map(mesh=mesh, in_specs=(state_sharding,), out_specs=None) def f(m: nnx.Linear): self.assertEqual( m.kernel.shape, (m.in_features, m.out_features // n_devices) ) self.assertEqual(m.bias.shape, (m.out_features,)) f(m) self.assertIsInstance(m.kernel.sharding, jax.sharding.NamedSharding) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_basic_shardmap_variables(self, graph, graph_updates): n_devices = jax.local_device_count() devices = mesh_utils.create_device_mesh((n_devices,)) mesh = jax.sharding.Mesh(devices, ('a',)) P = jax.sharding.PartitionSpec rngs = nnx.Rngs(0) w = nnx.Param(jax.random.normal(rngs.params(), (16, 32))) b = nnx.Param(jax.random.normal(rngs.params(), (32,))) count = nnx.BatchStat(jnp.array(0)) self.assertNotIsInstance(w.sharding, jax.sharding.NamedSharding) @nnx.shard_map( mesh=mesh, in_specs=(P(None, 'a'), P(), P()), out_specs=None, graph=graph, graph_updates=graph_updates, ) def f(w, b, count): count[...] += 1 self.assertEqual(w.shape, (16, 32 // n_devices)) self.assertEqual(b.shape, (32,)) f(w, b, count) if graph and graph_updates: self.assertIsInstance(w.sharding, jax.sharding.NamedSharding) self.assertIsInstance(b.sharding, jax.sharding.NamedSharding) self.assertEqual(count[...], 1) def test_from_state(self): n_devices = jax.local_device_count() devices = mesh_utils.create_device_mesh((n_devices,)) mesh = jax.sharding.Mesh(devices, ('a',)) PS = jax.sharding.PartitionSpec state_spec = nnx.State( { 'kernel': PS(None, 'a'), 'bias': PS(), } ) state_sharding = nnx.StateSharding(state_spec) m = nnx.Linear(16, 32, rngs=nnx.Rngs(0)) self.assertNotIsInstance(m.kernel.sharding, jax.sharding.NamedSharding) @nnx.shard_map(mesh=mesh, in_specs=(state_sharding,), out_specs=None) def f(m: nnx.Linear): self.assertEqual( m.kernel.shape, (m.in_features, m.out_features // n_devices) ) self.assertEqual(m.bias.shape, (m.out_features,)) f(m) self.assertIsInstance(m.kernel.sharding, jax.sharding.NamedSharding) self.assertIsInstance(m.bias.sharding, jax.sharding.NamedSharding) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_simple_data_parallel(self, graph, graph_updates): P = jax.sharding.PartitionSpec n_devices = jax.local_device_count() mesh = jax.sharding.Mesh(jax.local_devices(), ('data',)) with jax.set_mesh(mesh): m = nnx.Linear( in_features=2, out_features=3, kernel_metadata={'out_sharding': jax.P(None)}, bias_metadata={'out_sharding': jax.P(None)}, rngs=nnx.Rngs(0), ) x = jnp.ones((32, 2)) @nnx.shard_map( mesh=mesh, in_specs=(P(None), P('data')), out_specs=P('data'), graph=graph, graph_updates=graph_updates, ) def f(m, x): self.assertEqual(x.shape, (32 // n_devices, 2)) return m(x) y = f(m, x) self.assertEqual(y.shape, (32, 3)) self.assertIsInstance(y.sharding, jax.sharding.NamedSharding) if graph and graph_updates: self.assertIsInstance(m.kernel.sharding, jax.sharding.NamedSharding) self.assertIsInstance(m.bias.sharding, jax.sharding.NamedSharding) def test_simple_tensor_parallel(self): P = jax.sharding.PartitionSpec mesh = jax.sharding.Mesh(jax.local_devices(), ('model',)) class MLP(nnx.Module): def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs) self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs) def __call__(self, x): return self.linear2(jax.nn.relu(self.linear1(x))) m = MLP(2, 64, 3, rngs=nnx.Rngs(0)) x = jnp.ones((32, 2)) def path_ends_with(path_suffix): return lambda path, value: path[-len(path_suffix) :] == path_suffix model_sharding = nnx.StateSharding( { path_ends_with(('linear1', 'kernel')): P(None, 'model'), path_ends_with(('linear2', 'kernel')): P('model', None), } ) @nnx.shard_map( mesh=mesh, in_specs=(model_sharding, P(None)), out_specs=P(None) ) def f(m, x): y = m(x) return jax.lax.psum(y, 'model') y = f(m, x) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_shardmap_with_sharding_names(self, graph, graph_updates): n_devices = jax.local_device_count() P = jax.sharding.PartitionSpec mesh = jax.sharding.Mesh(jax.local_devices(), ('data',)) with jax.set_mesh(mesh): w = nnx.Param(jnp.ones((8, 4)), out_sharding=('data', None)) b = nnx.Param(jnp.ones((4,)), out_sharding=(None,)) self.assertIsInstance(w.get_raw_value().sharding, jax.sharding.NamedSharding) self.assertEqual(w.out_sharding, ('data', None)) self.assertEqual(b.out_sharding, (None,)) @nnx.shard_map( mesh=mesh, in_specs=(P('data', None), P(None)), out_specs=P('data', None), graph=graph, graph_updates=graph_updates, ) def f(w, b): self.assertEqual(w.shape, (8 // n_devices, 4)) self.assertEqual(b.shape, (4,)) return w + b[None] y = f(w, b) self.assertEqual(y.shape, (8, 4)) self.assertIsInstance(y.sharding, jax.sharding.NamedSharding) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_shardmap_sharding_names_mutation(self, graph, graph_updates): n_devices = jax.local_device_count() P = jax.sharding.PartitionSpec mesh = jax.sharding.Mesh(jax.local_devices(), ('data',)) with jax.set_mesh(mesh): w = nnx.Param(jnp.zeros((8, 4)), out_sharding=('data', None)) count = nnx.BatchStat(jnp.array(0)) @nnx.shard_map( mesh=mesh, in_specs=(P('data', None), P()), out_specs=P('data', None), graph=graph, graph_updates=graph_updates, ) def f(w, count): count[...] += 1 self.assertEqual(w.shape, (8 // n_devices, 4)) return w + 1.0 y = f(w, count) self.assertEqual(count[...], 1) self.assertEqual(y.shape, (8, 4)) np.testing.assert_allclose(w[...], jnp.zeros((8, 4))) def test_shardmap_shared_variable(self): P = jax.sharding.PartitionSpec mesh = jax.sharding.Mesh(jax.local_devices(), ('data',)) v = nnx.Param(jnp.array(1.0)) class Container(nnx.Module): def __init__(self, v): self.v = v c1 = Container(v) c2 = Container(v) @nnx.shard_map( mesh=mesh, in_specs=(P(), P(), P()), out_specs=P(), graph=True, graph_updates=True, ) def f(c1, c2, x): c1.v[...] += x return c1.v[...] + c2.v[...] y = f(c1, c2, jnp.array(1.0)) np.testing.assert_allclose(y, 4.0) np.testing.assert_allclose(v[...], 2.0) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_shardmap_module_variable_update(self, graph, graph_updates): P = jax.sharding.PartitionSpec mesh = jax.sharding.Mesh(jax.local_devices(), ('data',)) class Foo(nnx.Module): def __init__(self): self.count = nnx.Variable(jnp.array(0)) m = Foo() @nnx.shard_map( mesh=mesh, in_specs=(P(), P()), out_specs=P(), graph=graph, graph_updates=graph_updates, ) def f(m, x): m.count[...] += 1 return x * 2 y = f(m, jnp.array(3.0)) np.testing.assert_allclose(y, 6.0) self.assertEqual(m.count[...], 1) y = f(m, jnp.array(3.0)) self.assertEqual(m.count[...], 2) def test_shard_map_inconsistent_aliasing(self): v = nnx.Param(jnp.array(1.0)) P = jax.sharding.PartitionSpec mesh = jax.sharding.Mesh(jax.devices(), ('x',)) @nnx.shard_map( mesh=mesh, in_specs=(P(), P('x')), out_specs=P(), graph=True, graph_updates=False, ) def f(a, b): return a[...] + b[...] with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing'): f(v, v) class TestGrad(parameterized.TestCase): def test_grad(self): p1 = nnx.Param(10.0) p2 = nnx.Param(20.0) m = nnx.Dict( a=nnx.List([p1, p2]), b=p1, c=7, d=5.0, ) @nnx.grad def f(m: nnx.Dict): # sum all params return m['a'][0][...] + m['a'][1][...] + m['b'][...] grads = f(m) assert m.a[0] is m.b assert isinstance(grads, nnx.State) assert grads['a'][0][...] == 2.0 assert issubclass(type(grads['a'][0]), nnx.Variable) assert grads['a'][1][...] == 1.0 assert issubclass(type(grads['a'][1]), nnx.Variable) assert len(nnx.to_flat_state(grads)) == 2 nnx.update(m, grads) assert m['a'][0] is m.b assert m['a'][0][...] == 2.0 assert m['a'][1][...] == 1.0 assert m['b'][...] == 2.0 assert m['c'] == 7 assert m['d'] == 5.0 def test_grad_with_multiple_ref_types(self): m = nnx.Dict( a=nnx.List([nnx.Param(jnp.array(10.0)), nnx.BatchStat(jnp.array(20.0))]), b=nnx.Param(jnp.array(10.0)), c=7, d=5.0, ) @nnx.grad def f(m: nnx.Dict): # sum all params return m.a[0] + m.a[1] + m.b grads = f(m) assert isinstance(grads, nnx.State) assert grads['a'][0][...] == 1.0 assert issubclass(type(grads['a'][0]), nnx.Param) assert len(grads) == 2 nnx.update(m, grads) assert m.a[0][...] == 1.0 assert m.a[1][...] == 20.0 assert m.b[...] == 1.0 assert m.c == 7 assert m.d == 5.0 def test_grad_with_type_predicate(self): m = nnx.Dict( a=nnx.List([nnx.Param(jnp.array(10.0)), nnx.BatchStat(jnp.array(20.0))]), b=nnx.Param(jnp.array(10.0)), c=7, d=5.0, ) @nnx.grad(argnums=nnx.DiffState(0, nnx.BatchStat)) def f(m: nnx.Dict): # sum all params return m.a[0] + m.a[1] + m.b grads = f(m) assert isinstance(grads, nnx.State) assert grads['a'][1][...] == 1.0 assert issubclass(type(grads['a'][1]), nnx.BatchStat) assert len(grads) == 1 nnx.update(m, grads) assert m.a[0][...] == 10.0 assert m.a[1][...] == 1.0 assert m.b[...] == 10.0 assert m.c == 7 assert m.d == 5.0 def test_multiple_inputs(self): rngs = nnx.Rngs(0) m = nnx.Linear(2, 3, rngs=rngs) loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2) grad_fn = nnx.grad(loss_fn) x = jax.random.uniform(rngs(), (1, 2)) y = jnp.ones((1, 3)) grads = grad_fn(m, x, y) assert 'kernel' in grads assert grads['kernel'].shape == (2, 3) assert 'bias' in grads assert grads['bias'].shape == (3,) @parameterized.parameters( { 'loss_fn': lambda m1, m2, x, y: jnp.mean((m2(m1(x)) - y) ** 2), 'argnums': (0, 1), }, { 'loss_fn': lambda x, m1, y, m2: jnp.mean((m2(m1(x)) - y) ** 2), 'argnums': (1, 3), }, ) def test_multiple_graph_nodes(self, loss_fn, argnums): rngs = nnx.Rngs(0) m1 = nnx.Linear(2, 3, rngs=rngs) m2 = nnx.Linear(3, 3, rngs=rngs) grad_fn = nnx.grad(loss_fn, argnums=argnums) x = jax.random.uniform(rngs(), (1, 2)) y = jnp.ones((1, 3)) inputs = [x, y] inputs.insert(argnums[0], m1) inputs.insert(argnums[1], m2) grads_m1, grads_m2 = grad_fn(*inputs) assert 'kernel' in grads_m1 assert grads_m1['kernel'].shape == (2, 3) assert 'bias' in grads_m1 assert grads_m1['bias'].shape == (3,) assert 'kernel' in grads_m2 assert grads_m2['kernel'].shape == (3, 3) assert 'bias' in grads_m2 assert grads_m2['bias'].shape == (3,) def test_multiple_args(self): m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(1)) m1_diffstate = nnx.DiffState(0, nnx.PathContains('kernel')) m2_diffstate = nnx.DiffState(1, nnx.PathContains('bias')) @nnx.grad(argnums=(m1_diffstate, m2_diffstate)) def loss_fn(m1: nnx.Linear, m2: nnx.Linear): return jnp.mean(m1.kernel * m2.kernel) + jnp.mean(m1.bias * m2.bias) grads_m1, grads_m2 = loss_fn(m1, m2) self.assertIn('kernel', grads_m1) self.assertNotIn('bias', grads_m1) self.assertNotIn('kernel', grads_m2) self.assertIn('bias', grads_m2) def test_multiple_args_in_pytrees(self): m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(1)) m1_diffstate = nnx.DiffState(0, nnx.PathContains('kernel')) m2_diffstate = nnx.DiffState(1, nnx.PathContains('bias')) @nnx.grad(argnums=(m1_diffstate, m2_diffstate)) def loss_fn(l1: list[nnx.Linear], l2: list[nnx.Linear]): return jnp.mean(l1[0].kernel * l2[0].kernel) + jnp.mean( l1[0].bias * l2[0].bias ) grads_m1, grads_m2 = loss_fn([m1], [m2]) self.assertIn('kernel', grads_m1[0]) self.assertNotIn('bias', grads_m1[0]) self.assertNotIn('kernel', grads_m2[0]) self.assertIn('bias', grads_m2[0]) def test_value_and_grad_multiple_args_in_pytrees(self): m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(1)) m1_diffstate = nnx.DiffState(0, nnx.PathContains('kernel')) m2_diffstate = nnx.DiffState(1, nnx.PathContains('bias')) @nnx.value_and_grad(argnums=(m1_diffstate, m2_diffstate)) def loss_fn(l1: list[nnx.Linear], l2: list[nnx.Linear]): return jnp.mean(l1[0].kernel * l2[0].kernel) + jnp.mean( l1[0].bias * l2[0].bias ) loss, (grads_m1, grads_m2) = loss_fn([m1], [m2]) self.assertEqual(loss.shape, ()) self.assertIn('kernel', grads_m1[0]) self.assertNotIn('bias', grads_m1[0]) self.assertNotIn('kernel', grads_m2[0]) self.assertIn('bias', grads_m2[0]) def test_value_and_grad_with_aux(self): m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(1)) m1_diffstate = nnx.DiffState(0, nnx.PathContains('kernel')) m2_diffstate = nnx.DiffState(1, nnx.PathContains('bias')) @nnx.value_and_grad(argnums=(m1_diffstate, m2_diffstate), has_aux=True) def loss_fn(l1: list[nnx.Linear], l2: list[nnx.Linear]): loss = jnp.mean(l1[0].kernel * l2[0].kernel) + jnp.mean( l1[0].bias * l2[0].bias ) l1[0].kernel.set_value(jnp.array(-1.0)) m3 = nnx.Linear(2, 3, rngs=nnx.Rngs(2)) return loss, m3 (loss, m3), (grads_m1, grads_m2) = loss_fn([m1], [m2]) self.assertEqual(m1.kernel[...], -1.0) self.assertEqual(loss.shape, ()) self.assertIsInstance(m3, nnx.Linear) self.assertIn('kernel', grads_m1[0]) self.assertNotIn('bias', grads_m1[0]) self.assertNotIn('kernel', grads_m2[0]) self.assertIn('bias', grads_m2[0]) def test_variables_in_grad(self): p1 = nnx.Param(10.0) p2 = nnx.Param(20.0) m = dict(a=[p1, p2], b=p1) @nnx.grad def f(m: dict): return m['a'][0] + m['a'][1] + m['b'] grads = f(m) assert m['a'][0] is m['b'] assert isinstance(grads, dict) assert issubclass(type(grads['a'][0]), nnx.Variable) assert grads['a'][1][...] == 1.0 assert issubclass(type(grads['a'][1]), nnx.Variable) assert len(jax.tree.leaves(grads)) == 2 jax.tree.map( nnx.update, m, grads, is_leaf=lambda x: isinstance(x, nnx.Variable) ) assert m['a'][0] is m['b'] assert m['a'][0][...] == 2.0 assert m['a'][1][...] == 1.0 assert m['b'][...] == 2.0 @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_tree_mode_grad(self, graph, graph_updates): m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) @nnx.grad(graph=graph, graph_updates=graph_updates) def loss_fn(m: nnx.Linear): return jnp.mean(m.kernel) + jnp.mean(m.bias) grads = loss_fn(m) grad_type = nnx.State if graph and graph_updates else nnx.Linear self.assertIsInstance(grads, grad_type) self.assertEqual(grads.kernel.shape, (2, 3)) self.assertEqual(grads.bias.shape, (3,)) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_tree_mode_grad_multiple_inputs(self, graph, graph_updates): rngs = nnx.Rngs(0) m = nnx.Linear(2, 3, rngs=rngs) loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2) grad_fn = nnx.grad(loss_fn, graph=graph, graph_updates=graph_updates) x = jax.random.uniform(rngs(), (1, 2)) y = jnp.ones((1, 3)) grads = grad_fn(m, x, y) grad_type = nnx.State if graph and graph_updates else nnx.Linear self.assertIsInstance(grads, grad_type) self.assertEqual(grads.kernel.shape, (2, 3)) self.assertEqual(grads.bias.shape, (3,)) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_tree_mode_grad_multiple_graph_nodes(self, graph, graph_updates): rngs = nnx.Rngs(0) m1 = nnx.Linear(2, 3, rngs=rngs) m2 = nnx.Linear(3, 3, rngs=rngs) loss_fn = lambda m1, m2, x, y: jnp.mean((m2(m1(x)) - y) ** 2) grad_fn = nnx.grad(loss_fn, argnums=(0, 1), graph=graph, graph_updates=graph_updates) x = jax.random.uniform(rngs(), (1, 2)) y = jnp.ones((1, 3)) grads_m1, grads_m2 = grad_fn(m1, m2, x, y) grad_type = nnx.State if graph and graph_updates else nnx.Linear self.assertIsInstance(grads_m1, grad_type) self.assertEqual(grads_m1.kernel.shape, (2, 3)) self.assertEqual(grads_m1.bias.shape, (3,)) self.assertIsInstance(grads_m2, grad_type) self.assertEqual(grads_m2.kernel.shape, (3, 3)) self.assertEqual(grads_m2.bias.shape, (3,)) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_tree_mode_value_and_grad_with_aux(self, graph, graph_updates): m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) @nnx.value_and_grad(has_aux=True, graph=graph, graph_updates=graph_updates) def loss_fn(m: nnx.Linear): loss = jnp.mean(m.kernel) m.kernel[...] = jnp.ones_like(m.kernel[...]) return loss, {'aux': 1} (loss, aux), grads = loss_fn(m) self.assertEqual(loss.shape, ()) self.assertEqual(aux, {'aux': 1}) grad_type = nnx.State if graph and graph_updates else nnx.Linear self.assertIsInstance(grads, grad_type) self.assertEqual(grads.kernel.shape, (2, 3)) np.testing.assert_allclose(m.kernel[...], jnp.ones((2, 3))) class TestCustomVJP(parameterized.TestCase): @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_basic_call(self, graph, graph_updates): m1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) m2 = nnx.Linear(1, 1, rngs=nnx.Rngs(1)) @nnx.custom_vjp(graph=graph, graph_updates=graph_updates) def f(m1: nnx.Linear, m2: nnx.Linear): y = m1.kernel * m2.kernel return y def f_fwd(m1, m2): y = f(m1, m2) return y, (m1, m2) def f_bwd(res, g): inputs_g, out_g = g m1, m2 = res return inputs_g f.defvjp(f_fwd, f_bwd) y = f(m1, m2) self.assertEqual(y.shape, (1, 1)) @parameterized.parameters( (True,), (False,), ) def test_basic_call_with_state(self, graph): m1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) m2 = nnx.Linear(1, 1, rngs=nnx.Rngs(1)) state = nnx.BatchStat(jnp.array(0.0)) @nnx.custom_vjp( nondiff_argnums=(2,), graph=graph, graph_updates=False, ) def f(m1: nnx.Linear, m2: nnx.Linear, state): y = m1.kernel * m2.kernel state[...] = jnp.array(-1.0) return y def f_fwd(m1, m2, state): y = f(m1, m2, state) return y, (m1, m2) def f_bwd(state, res, g): inputs_g, out_g = g m1, m2 = res return inputs_g f.defvjp(f_fwd, f_bwd) y = f(m1, m2, state) self.assertEqual(state[...], -1.0) self.assertEqual(y.shape, (1, 1)) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_jax_example(self, graph, graph_updates): @dataclasses.dataclass class Foo(nnx.Module): x: nnx.Param[jax.Array] y: nnx.Param[jax.Array] z: int @nnx.custom_vjp(graph=graph, graph_updates=graph_updates) def f(m: Foo): m.z += 1 return jnp.sin(m.x) * m.y # type: ignore def f_fwd(m: Foo): y = f(m) res = (jnp.cos(m.x), jnp.sin(m.x), m) # type: ignore return y, res def f_bwd(res, g): cos_x, sin_x, m = res if graph and graph_updates: (m_g,), out_g = g self.assertIsInstance(m_g, nnx.State) m_g['x'][...] = cos_x * out_g * m.y m_g['y'][...] = sin_x * out_g return (m_g,) else: out_g = g m_g = nnx.clone(m) m_g.x[...] = cos_x * out_g * m.y m_g.y[...] = sin_x * out_g return (m_g,) f.defvjp(f_fwd, f_bwd) m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0) if graph and graph_updates: grads = nnx.grad(f, argnums=nnx.DiffState(0, ...))(m) self.assertIsInstance(grads, nnx.State) else: grads = nnx.grad( f, graph=graph, graph_updates=graph_updates, )(m) self.assertIsInstance(grads, Foo) np.testing.assert_allclose(grads.x[...], jnp.cos(1.0) * 2.0) # type: ignore np.testing.assert_allclose(grads.y[...], jnp.sin(1.0)) # type: ignore if graph and graph_updates: self.assertEqual(m.z, 1) else: self.assertEqual(m.z, 0) def test_diff_state(self): @dataclasses.dataclass class Foo(nnx.Module): x: nnx.Param[jax.Array] y: nnx.Param[jax.Array] z: int x_in_path = nnx.PathContains('x') diff_state = nnx.DiffState(0, x_in_path) @nnx.custom_vjp(nondiff_argnums=(diff_state,)) def f(m: Foo): m.z += 1 return jnp.sin(m.x) * m.y # type: ignore def f_fwd(m: Foo): y = f(m) res = (jnp.cos(m.x), m) # type: ignore return y, res def f_bwd(res, g): (m_g,), out_g = g cos_x, m = res self.assertIsInstance(m_g, nnx.State) self.assertEqual(out_g.shape, ()) self.assertIsInstance(m, Foo) m_g['x'][...] = cos_x * out_g * m.y del m_g['y'] return (m_g,) f.defvjp(f_fwd, f_bwd) m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0) grad: nnx.State = nnx.grad(f, argnums=nnx.DiffState(0, x_in_path))(m) np.testing.assert_allclose(grad['x'][...], jnp.cos(1.0) * 2.0) # type: ignore self.assertEqual(m.z, 1) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_jax_example_with_remat(self, graph, graph_updates): @dataclasses.dataclass class Foo(nnx.Module): x: nnx.Param[jax.Array] y: nnx.Param[jax.Array] z: int @nnx.custom_vjp(graph=graph, graph_updates=graph_updates) @nnx.remat(graph=graph, graph_updates=graph_updates) def f(m: Foo): m.z += 1 return jnp.sin(m.x) * m.y # type: ignore def f_fwd(m: Foo): y = f(m) res = (jnp.cos(m.x), jnp.sin(m.x), m) # type: ignore return y, res def f_bwd(res, g): cos_x, sin_x, m = res if graph and graph_updates: (m_g,), out_g = g self.assertIsInstance(m_g, nnx.State) m_g['x'][...] = cos_x * out_g * m.y m_g['y'][...] = sin_x * out_g return (m_g,) else: out_g = g m_g = jax.tree.map(lambda x: x, m) m_g.x[...] = cos_x * out_g * m.y m_g.y[...] = sin_x * out_g return (m_g,) f.defvjp(f_fwd, f_bwd) m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0) @nnx.jit(graph=graph, graph_updates=graph_updates) def loss_fn(m): return f(m) if graph and graph_updates: grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, ...))(m) self.assertIsInstance(grads, nnx.State) else: grads = nnx.grad( loss_fn, graph=graph, graph_updates=graph_updates, )(m) self.assertIsInstance(grads, Foo) np.testing.assert_allclose(grads.x[...], jnp.cos(1.0) * 2.0) # type: ignore np.testing.assert_allclose(grads.y[...], jnp.sin(1.0)) # type: ignore if graph and graph_updates: self.assertEqual(m.z, 1) else: self.assertEqual(m.z, 0) def test_two_args(self): @dataclasses.dataclass class Foo(nnx.Module): x: nnx.Param[jax.Array] y: nnx.Param[jax.Array] z: int @nnx.custom_vjp def f(m1: Foo, m2: Foo): m1.z += 1 y = jnp.sin(m1.x) * m1.y # type: ignore return y, m2 def f_fwd(m1: Foo, m2: Foo): y, m2 = f(m1, m2) res = (jnp.cos(m1.x), jnp.sin(m1.x), m1) # type: ignore return (y, m2), res def f_bwd(res, g): (m1_g, m2_g), (y_g, _) = g cos_x, sin_x, m = res self.assertIsInstance(m1_g, nnx.State) self.assertIsInstance(m2_g, nnx.State) self.assertEqual(y_g.shape, ()) self.assertIsInstance(m, Foo) m1_g = nnx.State(dict(x=cos_x * y_g * m.y, y=sin_x * y_g)) m2_g = nnx.State(dict(x=m2_g['x'], y=m2_g['y'])) return m1_g, m2_g f.defvjp(f_fwd, f_bwd) m1 = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0) m2 = Foo(nnx.Param(jnp.array(3.0)), nnx.Param(jnp.array(4.0)), 0) def loss_fn(m1, m2): y, m2 = f(m1, m2) return y + m2.x * m2.y m1_grad: nnx.State m2_grad: nnx.State m1_grad, m2_grad = nnx.grad( loss_fn, argnums=(nnx.DiffState(0, ...), nnx.DiffState(1, ...)) )(m1, m2) np.testing.assert_allclose(m1_grad['x'][...], jnp.cos(1.0) * 2.0) # type: ignore np.testing.assert_allclose(m1_grad['y'][...], jnp.sin(1.0)) # type: ignore self.assertEqual(m1.z, 1) np.testing.assert_allclose(m2_grad['x'][...], 4.0) # type: ignore np.testing.assert_allclose(m2_grad['y'][...], 3.0) # type: ignore @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_non_diff_args(self, graph, graph_updates): @dataclasses.dataclass class Foo(nnx.Module): x: nnx.Param[jax.Array] y: nnx.Param[jax.Array] z: int @nnx.custom_vjp( nondiff_argnums=(0, 2), graph=graph, graph_updates=graph_updates, ) def f(a, m: Foo, b): self.assertEqual(a, 1) self.assertEqual(b, 2) m.z += 1 return jnp.sin(m.x) * m.y # type: ignore def f_fwd(a, m: Foo, b): self.assertEqual(a, 1) self.assertEqual(b, 2) y = f(a, m, b) res = (jnp.cos(m.x), jnp.sin(m.x), m) # type: ignore return y, res def f_bwd(a, b, res, g): cos_x, sin_x, m = res self.assertEqual(a, 1) self.assertEqual(b, 2) if graph and graph_updates: (m_g,), out_g = g self.assertIsInstance(m_g, nnx.State) m_g['x'][...] = cos_x * out_g * m.y m_g['y'][...] = sin_x * out_g return (m_g,) else: out_g = g m_g = jax.tree.map(lambda x: x, m) m_g.x[...] = cos_x * out_g * m.y m_g.y[...] = sin_x * out_g return (m_g,) f.defvjp(f_fwd, f_bwd) m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0) def loss_fn(m): a = 1 b = 2 return f(a, m, b) if graph and graph_updates: grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, ...))(m) self.assertIsInstance(grads, nnx.State) else: grads = nnx.grad( loss_fn, graph=graph, graph_updates=graph_updates, )(m) self.assertIsInstance(grads, Foo) np.testing.assert_allclose(grads.x[...], jnp.cos(1.0) * 2.0) # type: ignore np.testing.assert_allclose(grads.y[...], jnp.sin(1.0)) # type: ignore if graph and graph_updates: self.assertEqual(m.z, 1) else: self.assertEqual(m.z, 0) def test_docs_example(self): import jax.numpy as jnp from flax import nnx class Foo(nnx.Module): def __init__(self, x, y): self.x = nnx.Param(x) self.y = nnx.Param(y) @nnx.custom_vjp def f(m: Foo): return jnp.sin(m.x) * m.y # type: ignore def f_fwd(m: Foo): return f(m), (jnp.cos(m.x), jnp.sin(m.x), m) # type: ignore def f_bwd(res, g): ins_g, out_g = g cos_x, sin_x, m = res tangent_m = nnx.State(dict(x=cos_x * out_g * m.y, y=sin_x * out_g)) return (tangent_m,) f.defvjp(f_fwd, f_bwd) m = Foo(x=jnp.array(1.0), y=jnp.array(2.0)) grads = nnx.grad(f)(m) @parameterized.parameters( {'use_custom_vjp': False}, {'use_custom_vjp': True}, ) def test_issue(self, use_custom_vjp: bool): class MyLinear(nnx.Module): def __init__( self, in_features: int, out_features: int, *, rngs: nnx.Rngs ): kernel_init = nnx.initializers.normal(in_features**-0.5) self.kernel = nnx.Param( kernel_init(rngs.params(), (in_features, out_features), jnp.float32) ) self.bias = nnx.Param(jnp.zeros((out_features,), jnp.float32)) self.n = nnx.BatchStat(jnp.array(0, jnp.uint32)) def linear(m: MyLinear, x: jax.Array) -> jax.Array: m.n[...] += 1 y = x @ m.kernel + m.bias return y def linear_fwd(m: MyLinear, x: jax.Array): return linear(m, x), (m, x) def linear_bwd(res, g): m, x = res (m_g, _x_grad), outputs_g = g kernel_grad = outputs_g[None, :] * x[:, None] bias_grad = outputs_g x_grad = m.kernel @ outputs_g assert x_grad.shape == x.shape, 'Shape mismatch for x' assert m.kernel.shape == kernel_grad.shape, 'Shape mismatch for kernel' assert m.bias.shape == bias_grad.shape, 'Shape mismatch for bias' return (m_g, x_grad) if use_custom_vjp: linear = nnx.custom_vjp(linear) linear.defvjp(linear_fwd, linear_bwd) @nnx.jit def loss_fn(x, mod): y = linear(mod, x) return y.mean() mod = MyLinear(10, 5, rngs=nnx.Rngs(0)) self.assertEqual(mod.n[...], 0) x = jax.random.normal(jax.random.key(0), (10,)) loss, grad = nnx.value_and_grad(loss_fn)(x, mod) self.assertEqual(loss.shape, ()) self.assertEqual(grad.shape, (10,)) self.assertEqual(mod.n[...], 1) @parameterized.parameters( (True, False), (False, False), ) def test_tree_mode_basic_call(self, graph, graph_updates): m1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) m2 = nnx.Linear(1, 1, rngs=nnx.Rngs(1)) state = nnx.BatchStat(jnp.array(0.0)) @nnx.custom_vjp( nondiff_argnums=(2,), graph=graph, graph_updates=graph_updates, ) def f(m1: nnx.Linear, m2: nnx.Linear, state): y = m1.kernel * m2.kernel state[...] = jnp.array(-1.0) return y def f_fwd(m1, m2, state): y = f(m1, m2, state) return y, (m1, m2) def f_bwd(state, res, g): m1, m2 = res return g, g f.defvjp(f_fwd, f_bwd) y = f(m1, m2, state) self.assertEqual(state[...], -1.0) self.assertEqual(y.shape, (1, 1)) @parameterized.parameters( (True, False), (False, False), ) def test_tree_mode_jax_example(self, graph, graph_updates): @dataclasses.dataclass class Foo(nnx.Module): x: nnx.Param[jax.Array] y: nnx.Param[jax.Array] @nnx.custom_vjp(graph=graph, graph_updates=graph_updates) def f(m: Foo): return jnp.sin(m.x) * m.y def f_fwd(m: Foo): y = f(m) res = (jnp.cos(m.x), jnp.sin(m.x), m) return y, res def f_bwd(res, g): cos_x, sin_x, m = res m_g = jax.tree.map(lambda x: x, m) m_g.x[...] = cos_x * g * m.y m_g.y[...] = sin_x * g return (m_g,) f.defvjp(f_fwd, f_bwd) m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0))) grads = nnx.grad(f, graph=graph, graph_updates=graph_updates)(m) self.assertIsInstance(grads, Foo) np.testing.assert_allclose(grads.x[...], jnp.cos(1.0) * 2.0) np.testing.assert_allclose(grads.y[...], jnp.sin(1.0)) @parameterized.parameters( (True, False), (False, False), ) def test_tree_mode_with_remat(self, graph, graph_updates): @dataclasses.dataclass class Foo(nnx.Module): x: nnx.Param[jax.Array] y: nnx.Param[jax.Array] @nnx.custom_vjp(graph=graph, graph_updates=graph_updates) @nnx.remat(graph=graph, graph_updates=graph_updates) def f(m: Foo): return jnp.sin(m.x) * m.y def f_fwd(m: Foo): y = f(m) res = (jnp.cos(m.x), jnp.sin(m.x), m) return y, res def f_bwd(res, g): cos_x, sin_x, m = res m_g = jax.tree.map(lambda x: x, m) m_g.x[...] = cos_x * g * m.y m_g.y[...] = sin_x * g return (m_g,) f.defvjp(f_fwd, f_bwd) m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0))) @nnx.jit(graph=graph, graph_updates=graph_updates) def loss_fn(m): return f(m) grads = nnx.grad(loss_fn, graph=graph, graph_updates=graph_updates)(m) self.assertIsInstance(grads, Foo) np.testing.assert_allclose(grads.x[...], jnp.cos(1.0) * 2.0) np.testing.assert_allclose(grads.y[...], jnp.sin(1.0)) @parameterized.parameters( (True, False), (False, False), ) def test_tree_mode_non_diff_args(self, graph, graph_updates): @dataclasses.dataclass class Foo(nnx.Module): x: nnx.Param[jax.Array] y: nnx.Param[jax.Array] @nnx.custom_vjp(nondiff_argnums=(0, 2), graph=graph, graph_updates=graph_updates) def f(a, m: Foo, b): self.assertEqual(a, 1) self.assertEqual(b, 2) return jnp.sin(m.x) * m.y def f_fwd(a, m: Foo, b): self.assertEqual(a, 1) self.assertEqual(b, 2) y = f(a, m, b) res = (jnp.cos(m.x), jnp.sin(m.x), m) return y, res def f_bwd(a, b, res, g): cos_x, sin_x, m = res self.assertEqual(a, 1) self.assertEqual(b, 2) m_g = jax.tree.map(lambda x: x, m) m_g.x[...] = cos_x * g * m.y m_g.y[...] = sin_x * g return (m_g,) f.defvjp(f_fwd, f_bwd) m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0))) def loss_fn(m): a = 1 b = 2 return f(a, m, b) grads = nnx.grad(loss_fn, graph=graph, graph_updates=graph_updates)(m) self.assertIsInstance(grads, Foo) np.testing.assert_allclose(grads.x[...], jnp.cos(1.0) * 2.0) np.testing.assert_allclose(grads.y[...], jnp.sin(1.0)) def test_tree_mode_diffstate_error(self): x_in_path = nnx.PathContains('x') diff_state = nnx.DiffState(0, x_in_path) with self.assertRaisesRegex( ValueError, r'`nondiff_argnums` cannot contain `DiffState` objects', ): nnx.custom_vjp(lambda m: m, nondiff_argnums=(diff_state,), graph=False) def test_grad_inconsistent_aliasing(self): v = nnx.Param(jnp.array(1.0)) def f(v_diff, v_nondiff): return v_diff[...] + v_nondiff[...] with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing'): nnx.grad(f, argnums=0, graph=True, graph_updates=False)(v, v) def test_custom_vjp_inconsistent_aliasing(self): v = nnx.Param(jnp.array(1.0)) @nnx.custom_vjp(nondiff_argnums=(1,), graph=True, graph_updates=False) def f(v_diff, v_nondiff): return v_diff[...] + v_nondiff[...] def f_fwd(v_diff, v_nondiff): return f(v_diff, v_nondiff), () def f_bwd(v_nondiff, res, g): return (nnx.clone(v_nondiff),) f.defvjp(f_fwd, f_bwd) with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing'): f(v, v) def test_custom_vjp_diff_arg_mutation_error(self): @nnx.custom_vjp(graph=True, graph_updates=False) def f(m): m.x[...] += 1 return m.x[...] * m.y[...] def f_fwd(m): return f(m), (m,) def f_bwd(res, g): (m,) = res m_g = nnx.clone(m) m_g.x[...] = g * m.y[...] m_g.y[...] = g * m.x[...] return (m_g,) f.defvjp(f_fwd, f_bwd) @dataclasses.dataclass class Foo(nnx.Module): x: nnx.Param[jax.Array] y: nnx.Param[jax.Array] m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0))) with self.assertRaisesRegex( ValueError, 'Variables in differentiable argument' ): f(m) class TestVjpJvp(parameterized.TestCase): @parameterized.parameters( (True, False), (False, False), ) def test_vjp_basic(self, graph, graph_updates): @dataclasses.dataclass class Foo(nnx.Module): w: nnx.Param[jax.Array] def f(m: Foo, x): return jnp.sum(m.w * x) m = Foo(w=nnx.Param(jnp.array([1.0, 2.0, 3.0]))) x = jnp.array([4.0, 5.0, 6.0]) primals_out, vjp_fn = nnx.vjp( f, m, x, graph=graph, graph_updates=graph_updates, ) np.testing.assert_allclose(primals_out, 1.0 * 4.0 + 2.0 * 5.0 + 3.0 * 6.0) m_grad, x_grad = vjp_fn(jnp.ones_like(primals_out)) self.assertIsInstance(m_grad, Foo) np.testing.assert_allclose(m_grad.w[...], x) np.testing.assert_allclose(x_grad, m.w[...]) @parameterized.parameters( (True, False), (False, False), ) def test_vjp_has_aux(self, graph, graph_updates): @dataclasses.dataclass class Foo(nnx.Module): w: nnx.Param[jax.Array] def f(m: Foo, x): y = jnp.sum(m.w * x) return y, {'input': x} m = Foo(w=nnx.Param(jnp.array([1.0, 2.0]))) x = jnp.array([3.0, 4.0]) primals_out, vjp_fn, aux = nnx.vjp( f, m, x, has_aux=True, graph=graph, graph_updates=graph_updates, ) np.testing.assert_allclose(primals_out, 1.0 * 3.0 + 2.0 * 4.0) np.testing.assert_allclose(aux['input'], x) @parameterized.parameters( (True, False), (False, False), ) def test_vjp_state_propagation(self, graph, graph_updates): @dataclasses.dataclass class Foo(nnx.Module): w: nnx.Param[jax.Array] count: nnx.BatchStat[jax.Array] def f(m: Foo, x): m.count[...] += 1 return jnp.sum(m.w * x) m = Foo( w=nnx.Param(jnp.array([1.0, 2.0])), count=nnx.BatchStat(jnp.array(0)), ) x = jnp.array([3.0, 4.0]) self.assertEqual(m.count[...], 0) primals_out, vjp_fn = nnx.vjp( f, m, x, graph=graph, graph_updates=graph_updates, ) self.assertEqual(m.count[...], 1) @parameterized.parameters( (True, False), (False, False), ) def test_vjp_matches_jax(self, graph, graph_updates): def f(w, x): return jnp.sum(w * x) w = jnp.array([1.0, 2.0, 3.0]) x = jnp.array([4.0, 5.0, 6.0]) jax_primals, jax_vjp_fn = jax.vjp(f, w, x) jax_grads = jax_vjp_fn(jnp.ones_like(jax_primals)) nnx_primals, nnx_vjp_fn = nnx.vjp( f, w, x, graph=graph, graph_updates=graph_updates, ) nnx_grads = nnx_vjp_fn(jnp.ones_like(nnx_primals)) np.testing.assert_allclose(nnx_primals, jax_primals) for ng, jg in zip(nnx_grads, jax_grads): np.testing.assert_allclose(ng, jg) @parameterized.parameters( (True, False), (False, False), ) def test_vjp_decorator(self, graph, graph_updates): @dataclasses.dataclass class Foo(nnx.Module): w: nnx.Param[jax.Array] @nnx.vjp(graph=graph, graph_updates=graph_updates) def f(m: Foo, x): return jnp.sum(m.w * x) m = Foo(w=nnx.Param(jnp.array([1.0, 2.0]))) x = jnp.array([3.0, 4.0]) primals_out, vjp_fn = f(m, x) np.testing.assert_allclose(primals_out, 11.0) m_grad, x_grad = vjp_fn(jnp.ones_like(primals_out)) np.testing.assert_allclose(m_grad.w[...], x) @parameterized.parameters( (True, False), (False, False), ) def test_jvp_basic(self, graph, graph_updates): @dataclasses.dataclass class Foo(nnx.Module): w: nnx.Param[jax.Array] def f(m: Foo, x): return jnp.sum(m.w * x) m = Foo(w=nnx.Param(jnp.array([1.0, 2.0, 3.0]))) x = jnp.array([4.0, 5.0, 6.0]) m_tangent = jax.tree.map(jnp.ones_like, m) x_tangent = jnp.ones_like(x) primals_out, tangent_out = nnx.jvp( f, (m, x), (m_tangent, x_tangent), graph=graph, graph_updates=graph_updates, ) np.testing.assert_allclose(primals_out, 1.0 * 4.0 + 2.0 * 5.0 + 3.0 * 6.0) expected_tangent = jnp.sum(jnp.ones(3) * x + m.w[...] * jnp.ones(3)) np.testing.assert_allclose(tangent_out, expected_tangent) @parameterized.parameters( (True, False), (False, False), ) def test_jvp_has_aux(self, graph, graph_updates): @dataclasses.dataclass class Foo(nnx.Module): w: nnx.Param[jax.Array] def f(m: Foo, x): y = jnp.sum(m.w * x) return y, {'input': x} m = Foo(w=nnx.Param(jnp.array([1.0, 2.0]))) x = jnp.array([3.0, 4.0]) m_tangent = jax.tree.map(jnp.ones_like, m) x_tangent = jnp.ones_like(x) primals_out, tangent_out, aux = nnx.jvp( f, (m, x), (m_tangent, x_tangent), has_aux=True, graph=graph, graph_updates=graph_updates, ) np.testing.assert_allclose(primals_out, 11.0) np.testing.assert_allclose(aux['input'], x) @parameterized.parameters( (True, False), (False, False), ) def test_jvp_state_propagation(self, graph, graph_updates): @dataclasses.dataclass class Foo(nnx.Module): w: nnx.Param[jax.Array] count: nnx.BatchStat[jax.Array] def f(m: Foo, x): m.count[...] += 1 return jnp.sum(m.w * x) m = Foo( w=nnx.Param(jnp.array([1.0, 2.0])), count=nnx.BatchStat(jnp.array(0.0)), ) x = jnp.array([3.0, 4.0]) m_tangent = jax.tree.map(jnp.zeros_like, m) x_tangent = jnp.zeros_like(x) self.assertEqual(m.count[...], 0.0) primals_out, tangent_out = nnx.jvp( f, (m, x), (m_tangent, x_tangent), graph=graph, graph_updates=graph_updates, ) self.assertEqual(m.count[...], 1.0) @parameterized.parameters( (True, False), (False, False), ) def test_jvp_matches_jax(self, graph, graph_updates): def f(w, x): return jnp.sum(w * x) w = jnp.array([1.0, 2.0, 3.0]) x = jnp.array([4.0, 5.0, 6.0]) w_tangent = jnp.ones_like(w) x_tangent = jnp.ones_like(x) jax_primals, jax_tangents = jax.jvp(f, (w, x), (w_tangent, x_tangent)) nnx_primals, nnx_tangents = nnx.jvp( f, (w, x), (w_tangent, x_tangent), graph=graph, graph_updates=graph_updates, ) np.testing.assert_allclose(nnx_primals, jax_primals) np.testing.assert_allclose(nnx_tangents, jax_tangents) @parameterized.parameters( (True, False), (False, False), ) def test_jvp_decorator(self, graph, graph_updates): @dataclasses.dataclass class Foo(nnx.Module): w: nnx.Param[jax.Array] @nnx.jvp(graph=graph, graph_updates=graph_updates) def f(m: Foo, x): return jnp.sum(m.w * x) m = Foo(w=nnx.Param(jnp.array([1.0, 2.0]))) x = jnp.array([3.0, 4.0]) m_tangent = jax.tree.map(jnp.ones_like, m) x_tangent = jnp.ones_like(x) primals_out, tangent_out = f((m, x), (m_tangent, x_tangent)) np.testing.assert_allclose(primals_out, 11.0) expected_tangent = jnp.sum(jnp.ones(2) * x + m.w[...] * jnp.ones(2)) np.testing.assert_allclose(tangent_out, expected_tangent) class TestScan(parameterized.TestCase): def test_basic(self): class Block(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) def __call__(self, x: jax.Array): x = self.linear(x) x = nnx.gelu(x) return x @nnx.split_rngs(splits=5) @nnx.scan(in_axes=(nnx.Carry, 0), length=5) def create_block(_, rngs: nnx.Rngs): return None, Block(rngs=rngs) _, module = create_block(None, nnx.Rngs(0)) assert module.linear.kernel.shape == (5, 3, 3) assert module.linear.bias.shape == (5, 3) @nnx.scan(in_axes=(nnx.Carry, 0, None), length=5) def forward_block(_, block: Block, x: jax.Array): return None, block(x) x = jnp.ones((1, 3)) out, y = forward_block(None, module, x) assert y.shape == (5, 1, 3) assert out is None @parameterized.parameters(True, False) def test_variables_in_scan(self, graph_updates): def block_init(din, dout, rngs): w = nnx.Param(jax.random.normal(rngs.params(), (din, dout))) b = nnx.Param(jnp.zeros((dout,))) return w, b def block_forward(w, b, x): return nnx.gelu(x @ w + b[None]) @nnx.split_rngs(splits=5) @nnx.scan(in_axes=0, out_axes=0, length=5, graph_updates=graph_updates) def create_block(rngs: nnx.Rngs): return block_init(3, 3, rngs) w, b = create_block(nnx.Rngs(0)) assert w.shape == (5, 3, 3) assert b.shape == (5, 3) @nnx.scan( in_axes=(0, 0, nnx.Carry), out_axes=nnx.Carry, graph_updates=graph_updates, ) def stack_forward(w, b, x): return block_forward(w, b, x) x = jnp.ones((1, 3)) y = stack_forward(w, b, x) assert y.shape == (1, 3) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_variables_as_carries_in_scan(self, graph, graph_updates): w = nnx.Param(jax.random.normal(jax.random.key(0), (3, 3))) b = nnx.Param(jnp.zeros((3,))) count = nnx.BatchStat(0) def block_forward(w, b, x): return nnx.gelu(x @ w + b[None]) @nnx.scan( in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0), graph=graph, graph_updates=graph_updates, ) def stack_forward(params, x): w, b, count = params y = block_forward(w, b, x) count[...] += 1 return (w, b, count), y x = jnp.ones((5, 1, 3)) (w, b, count), y = stack_forward((w, b, count), x) assert y.shape == (5, 1, 3) assert count[...] == 5 def test_basic_no_carry(self): class Block(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) def __call__(self, x: jax.Array): x = self.linear(x) x = nnx.gelu(x) return x @nnx.split_rngs(splits=5) @nnx.scan(in_axes=(0,), out_axes=0, length=5) def create_block(rngs: nnx.Rngs): return Block(rngs=rngs) module = create_block(nnx.Rngs(0)) assert module.linear.kernel.shape == (5, 3, 3) assert module.linear.bias.shape == (5, 3) # assert module.node.shape == (2,) @nnx.scan(in_axes=(0, None), out_axes=0, length=5) def forward_block(block: Block, x: jax.Array): return block(x) x = jnp.ones((1, 3)) y = forward_block(module, x) assert y.shape == (5, 1, 3) @parameterized.parameters(True, False) def test_all_carry(self, graph_updates): @dataclasses.dataclass class Foo(nnx.Module): n: nnx.BatchStat[int] foo = Foo(n=nnx.BatchStat(0)) @nnx.scan( in_axes=nnx.Carry, out_axes=nnx.Carry, length=3, graph_updates=graph_updates, ) def loop(foo: Foo): foo.n[...] += 1 return foo foo2 = loop(foo) self.assertIs(foo2.n, foo.n) self.assertEqual(foo.n[...], 3) def test_all_carry_one_argument_error(self): @dataclasses.dataclass class Foo(nnx.Module): n: nnx.BatchStat[int] foo = Foo(n=nnx.BatchStat(0)) @nnx.scan(in_axes=nnx.Carry, out_axes=nnx.Carry, length=3) def loop(foo: Foo, x): ... with self.assertRaisesRegex( ValueError, 'When in_axes=Carry, the function must take exactly one argument', ): loop(foo, 0) def test_all_carry_new_reference_error(self): class Foo(nnx.Module): def __init__(self, n: nnx.BatchStat[int]): self.n = n xs = jnp.arange(3) foo = Foo(n=nnx.BatchStat(0)) @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0)) def loop(foo: Foo, x): x = x + 1 foo = Foo(nnx.BatchStat(foo.n[...] + 1)) # new reference return foo, x with self.assertRaisesRegex( ValueError, 'Carry references must be the same between iterations', ): loop(foo, xs) @parameterized.parameters(True, False) def test_all_scan(self, graph_updates): class Foo(nnx.Module): def __init__(self, n: nnx.BatchStat[jax.Array]): self.n = n xs = jnp.arange(3) foo = Foo(n=nnx.BatchStat(jnp.arange(3))) @nnx.scan(in_axes=0, out_axes=0, graph_updates=graph_updates) def loop(foo: Foo, x): x = x + 1 foo.n[...] += 1 return x ys = loop(foo, xs) np.testing.assert_allclose(ys, jnp.arange(1, 4)) np.testing.assert_allclose(foo.n[...], jnp.arange(1, 4)) def test_all_broadcast(self): class Foo(nnx.Module): def __init__(self, n: nnx.BatchStat[int]): self.n = n xs = jnp.array(1) foo = Foo(n=nnx.BatchStat(2)) @nnx.scan(in_axes=None, out_axes=0, length=4) def loop(foo: Foo, x): return x + foo.n ys = loop(foo, xs) np.testing.assert_allclose(ys, 3) self.assertEqual(ys.shape, (4,)) def test_input_output_carry_mismatch_error(self): with self.assertRaisesRegex( ValueError, 'If one of in_axes or out_axes has Carry, the other must also have Carry', ): @nnx.scan(in_axes=0, out_axes=(nnx.Carry, 0)) def loop(a, b): ... with self.assertRaisesRegex( ValueError, 'If one of in_axes or out_axes has Carry, the other must also have Carry', ): @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=0) def loop(a, b): ... def test_double_carry_error(self): with self.assertRaisesRegex( ValueError, 'Found multiple Carry definitions', ): @nnx.scan(in_axes=(nnx.Carry, nnx.Carry)) def loop(a, b): ... def test_broadcast_in_output_error(self): with self.assertRaisesRegex( ValueError, 'Cannot broadcast output state', ): @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, None)) def loop(a, b): ... with self.assertRaisesRegex( ValueError, 'Cannot broadcast output state. Got StateAxes', ): @nnx.scan( in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, nnx.StateAxes({...: None})) ) def loop(a, b): ... @parameterized.parameters( (True, False), (False, False), ) def test_scan_stateful(self, graph, graph_updates): count = nnx.Variable(jnp.array(0)) @nnx.scan(graph=graph, graph_updates=graph_updates) def f(count, x): count[...] += 1 return count, x + count[...] xs = jnp.arange(5) count_out, ys = f(count, xs) self.assertIs(count_out, count) self.assertEqual(count[...], 5) np.testing.assert_allclose(ys, jnp.array([1, 3, 5, 7, 9])) @parameterized.parameters( (True, False), (False, False), ) def test_scan_carry_identity_error(self, graph, graph_updates): count = nnx.Variable(jnp.array(0)) @nnx.scan(graph=graph, graph_updates=graph_updates) def f(count, x): new_count = nnx.Variable(count[...] + 1) return new_count, x with self.assertRaisesRegex( ValueError, 'scan Variable identity must be preserved', ): f(count, jnp.arange(3)) def test_tree_mode_custom_axes(self): @nnx.scan(in_axes=nnx.Carry, out_axes=nnx.Carry, length=3, graph=False) def loop(x): return x result = loop(jnp.array(1.0)) np.testing.assert_allclose(result, jnp.array(1.0)) @parameterized.parameters(True, False) def test_only_carry(self, graph_updates): class Foo(nnx.Module): def __init__(self): self.c = nnx.BatchStat(jnp.array(0)) @nnx.scan(in_axes=(nnx.Carry,), length=5, graph_updates=graph_updates) def loop(foo: Foo) -> tuple[Foo, jax.Array]: foo.c[...] += 1 return foo, foo.c[...] foo = Foo() foo2, cs = loop(foo) self.assertIs(foo2.c, foo.c) np.testing.assert_allclose(cs, jnp.arange(1, 6)) def test_out_axes(self): state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) class MLP(nnx.Module): @nnx.split_rngs(splits=5) @nnx.vmap(in_axes=(state_axes, state_axes), axis_size=5) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.node = nnx.BatchStat(jnp.ones((2,))) @nnx.scan(in_axes=(state_axes, nnx.Carry), out_axes=(nnx.Carry, 1, 2)) def __call__(self, x: jax.Array): x = self.linear(x) x = nnx.gelu(x) return x, x, x module = MLP(rngs=nnx.Rngs(0)) assert module.linear.kernel.shape == (5, 3, 3) assert module.linear.bias.shape == (5, 3) assert module.node.shape == (2,) x = jnp.ones((1, 3)) c, y1, y2 = module(x) assert c.shape == (1, 3) assert y1.shape == (1, 5, 3) assert y2.shape == (1, 3, 5) def test_in_axes_simple(self): state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) class MLP(nnx.Module): @nnx.vmap(in_axes=(state_axes, 0)) def __init__(self, key: jax.Array): rngs = nnx.Rngs(key) self.linear = nnx.Linear(3, 3, rngs=rngs) self.node = nnx.Variable(jnp.ones((2,))) @nnx.scan(in_axes=(state_axes, nnx.Carry), out_axes=nnx.Carry) def __call__(self, x: jax.Array): x = self.linear(x) x = nnx.gelu(x) return x key = jax.random.split(jax.random.key(0), 5) module = MLP(key=key) x = jnp.ones((1, 3)) y = module(x) assert y.shape == (1, 3) def test_in_axes(self): state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState, nnx.Intermediate): 0, ...: None}) class MLP(nnx.Module): @nnx.split_rngs(splits=5) @nnx.vmap(in_axes=(state_axes, state_axes)) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.node = nnx.Variable(jnp.ones((2,))) @nnx.scan(in_axes=(state_axes, nnx.Carry, 0)) def __call__( self, x: jax.Array, a: jax.Array ) -> tp.Tuple[jax.Array, None]: assert x.shape == a.shape x = x + a x = self.linear(x) x = nnx.gelu(x) self.sow(nnx.Intermediate, "data", x) return x, None module = MLP(rngs=nnx.Rngs(0)) assert module.linear.kernel.shape == (5, 3, 3) assert module.linear.bias.shape == (5, 3) assert module.node.shape == (2,) x = jnp.ones((1, 3)) a = jnp.ones((5, 1, 3)) (y, out), intermediates = nnx.capture(module, nnx.Intermediate)(x, a) assert y.shape == (1, 3) assert out is None assert intermediates['data'][0].shape == (5, 1, 3) def test_in_axes_broadcast(self): test = self state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) class MLP(nnx.Module): @nnx.split_rngs(splits=5) @nnx.vmap(in_axes=(state_axes, state_axes)) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.node = nnx.BatchStat(jnp.ones((2,))) @nnx.scan(in_axes=(state_axes, nnx.Carry, 0, None)) def __call__( self, x: jax.Array, a: jax.Array, b: jax.Array ) -> tp.Tuple[jax.Array, None]: test.assertEqual(x.shape, a.shape) test.assertEqual(x.shape, b.shape) x = x + a + b x = self.linear(x) x = nnx.gelu(x) return x, None module = MLP(rngs=nnx.Rngs(0)) self.assertEqual(module.linear.kernel.shape, (5, 3, 3)) self.assertEqual(module.linear.bias.shape, (5, 3)) self.assertEqual(module.node.shape, (2,)) x = jnp.ones((1, 3)) a = jnp.ones((5, 1, 3)) b = jnp.ones((1, 3)) y, out = module(x, a, b) self.assertEqual(y.shape, (1, 3)) self.assertIsNone(out) def test_complex(self): state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) class MLP(nnx.Module): @nnx.split_rngs(splits=5) @nnx.vmap(in_axes=(state_axes, state_axes)) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.bn = nnx.BatchNorm(3, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) self.node = nnx.Variable(jnp.ones((2,))) @nnx.scan(in_axes=(state_axes, nnx.Carry)) def __call__(self, x: jax.Array): x = self.linear(x) x = self.bn(x) x = self.dropout(x) x = nnx.gelu(x) return x, None module = MLP(rngs=nnx.Rngs(0)) module.set_attributes(deterministic=False, use_running_average=False) assert module.linear.kernel.shape == (5, 3, 3) assert module.linear.bias.shape == (5, 3) assert module.node.shape == (2,) x = jnp.ones((1, 3)) y, _ = module(x) assert y.shape == (1, 3) def test_complex_view(self): state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) class MLP(nnx.Module): @nnx.split_rngs(splits=5) @nnx.vmap(in_axes=(state_axes, state_axes)) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.bn = nnx.BatchNorm(3, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) self.node = nnx.Variable(jnp.ones((2,))) @nnx.scan(in_axes=(state_axes, nnx.Carry)) def __call__(self, x: jax.Array): x = self.linear(x) x = self.bn(x) x = self.dropout(x) x = nnx.gelu(x) return x, None module = MLP(rngs=nnx.Rngs(0)) new_module = nnx.view(module, deterministic=False, use_running_average=False) assert new_module.linear.kernel.shape == (5, 3, 3) assert new_module.linear.bias.shape == (5, 3) assert new_module.node.shape == (2,) x = jnp.ones((1, 3)) y, _ = new_module(x) assert y.shape == (1, 3) def test_complex_broadcast_dropout(self): state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None}) class MLP(nnx.Module): @nnx.split_rngs(splits=5, only='params') @nnx.vmap(in_axes=(state_axes, state_axes)) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.bn = nnx.BatchNorm(3, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) self.node = nnx.Variable(jnp.ones((2,))) @nnx.split_rngs(splits=5, only='params') @nnx.scan(in_axes=(state_axes, nnx.Carry)) def __call__(self, x: jax.Array): x = self.linear(x) x = self.bn(x) x = self.dropout(x) x = nnx.gelu(x) return x, None module = MLP(rngs=nnx.Rngs(params=0, dropout=1)) module.set_attributes(deterministic=False, use_running_average=False) assert module.linear.kernel.shape == (5, 3, 3) assert module.linear.bias.shape == (5, 3) assert module.node.shape == (2,) x = jnp.ones((1, 3)) y, _ = module(x) assert y.shape == (1, 3) def test_complex_broadcast_dropout_view(self): state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None}) class MLP(nnx.Module): @nnx.split_rngs(splits=5, only='params') @nnx.vmap(in_axes=(state_axes, state_axes)) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.bn = nnx.BatchNorm(3, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) self.node = nnx.Variable(jnp.ones((2,))) @nnx.split_rngs(splits=5, only='params') @nnx.scan(in_axes=(state_axes, nnx.Carry)) def __call__(self, x: jax.Array): x = self.linear(x) x = self.bn(x) x = self.dropout(x) x = nnx.gelu(x) return x, None module = MLP(rngs=nnx.Rngs(params=0, dropout=1)) new_module = nnx.view(module, deterministic=False, use_running_average=False) assert new_module.linear.kernel.shape == (5, 3, 3) assert new_module.linear.bias.shape == (5, 3) assert new_module.node.shape == (2,) x = jnp.ones((1, 3)) y, _ = new_module(x) assert y.shape == (1, 3) def test_complex_decorator(self): state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) class Block(nnx.Module): @nnx.split_rngs(splits=5) @nnx.vmap(in_axes=(state_axes, state_axes), axis_size=5) def __init__(self, rngs: nnx.Rngs): self.d = 3 self.linear = nnx.Linear(3, 3, rngs=rngs) self.bn = nnx.BatchNorm(3, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) self.node = nnx.Variable(jnp.ones((2,))) @nnx.scan(in_axes=(state_axes, nnx.Carry)) def __call__(self, x: jax.Array): x = self.linear(x) x = self.bn(x) x = self.dropout(x) x = nnx.gelu(x) return x, None module = Block(rngs=nnx.Rngs(0)) module.set_attributes(deterministic=False, use_running_average=False) assert module.d == 3 assert module.linear.kernel.shape == (5, 3, 3) assert module.linear.bias.shape == (5, 3) assert module.node.shape == (2,) x = jnp.ones((1, 3)) y, out = module(x) assert y.shape == (1, 3) assert out is None def test_complex_decorator_view(self): state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) class Block(nnx.Module): @nnx.split_rngs(splits=5) @nnx.vmap(in_axes=(state_axes, state_axes), axis_size=5) def __init__(self, rngs: nnx.Rngs): self.d = 3 self.linear = nnx.Linear(3, 3, rngs=rngs) self.bn = nnx.BatchNorm(3, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) self.node = nnx.Variable(jnp.ones((2,))) @nnx.scan(in_axes=(state_axes, nnx.Carry)) def __call__(self, x: jax.Array): x = self.linear(x) x = self.bn(x) x = self.dropout(x) x = nnx.gelu(x) return x, None module = Block(rngs=nnx.Rngs(0)) new_module = nnx.view(module, deterministic=False, use_running_average=False) assert new_module.d == 3 assert new_module.linear.kernel.shape == (5, 3, 3) assert new_module.linear.bias.shape == (5, 3) assert new_module.node.shape == (2,) x = jnp.ones((1, 3)) y, out = new_module(x) assert y.shape == (1, 3) assert out is None def test_scan_with_sharding(self): test = self state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) transform_metadata = {nnx.PARTITION_NAME: 'layers'} class MLP(nnx.Module): @nnx.split_rngs(splits=5) @nnx.vmap( in_axes=(state_axes, state_axes), transform_metadata=transform_metadata, ) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear( 3, 3, kernel_init=nnx.with_metadata( nnx.initializers.lecun_normal(), out_sharding=('din', 'dout') ), bias_init=nnx.with_metadata( nnx.initializers.zeros_init(), out_sharding=('dout',) ), rngs=rngs, ) @nnx.scan( in_axes=(state_axes, nnx.Carry), transform_metadata=transform_metadata ) def __call__(self, x: jax.Array): x = self.linear(x) # test sharding layer axes is not present inside scan test.assertEqual(self.linear.kernel.shape, (3, 3)) test.assertEqual(self.linear.kernel.out_sharding, ('din', 'dout')) test.assertEqual(self.linear.bias.shape, (3,)) test.assertEqual(self.linear.bias.out_sharding, ('dout',)) return x, None mesh = jax.make_mesh((1, 1, 1), ('layers', 'din', 'dout'), axis_types=(jax.sharding.AxisType.Auto,) * len(('layers', 'din', 'dout'))) with jax.set_mesh(mesh): m = MLP(rngs=nnx.Rngs(0)) # test sharding layers axes is set self.assertEqual(m.linear.kernel.shape, (5, 3, 3)) self.assertEqual(m.linear.kernel.out_sharding, ('layers', 'din', 'dout')) self.assertEqual(m.linear.bias.shape, (5, 3)) self.assertEqual(m.linear.bias.out_sharding, ('layers', 'dout')) x = jnp.ones((1, 3)) with jax.set_mesh(mesh): y, out = m(x) # test sharding axes is preserved self.assertEqual(m.linear.kernel.shape, (5, 3, 3)) self.assertEqual(m.linear.kernel.out_sharding, ('layers', 'din', 'dout')) self.assertEqual(m.linear.bias.shape, (5, 3)) self.assertEqual(m.linear.bias.out_sharding, ('layers', 'dout')) def test_cache_tracing_simple(self): n = 0 x = jnp.arange(5) count = jnp.array(0) @nnx.scan def f(count, x): nonlocal n n += 1 return count + 1, x**2 count, y = f(count, x) assert n == 1 assert count == 5 np.testing.assert_allclose(y, x**2) count, y = f(count, x) assert n == 1 assert count == 10 def test_cache_tracing_object(self): n = 0 x = jnp.arange(5) count = jnp.array(0) class Foo(nnx.Pytree): @nnx.split_rngs(splits=5) @nnx.vmap(axis_size=5) def __init__(self, rngs: nnx.Rngs): self.x = nnx.Param(jax.random.normal(rngs(), shape=(3,))) foo = Foo(rngs=nnx.Rngs(0)) assert foo.x.shape == (5, 3) @nnx.scan(in_axes=(nnx.Carry, 0, 0)) def f(count, x, foo): nonlocal n n += 1 assert foo.x.shape == (3,) return count + 1, x**2 count, y = f(count, x, foo) assert n == 1 assert count == 5 np.testing.assert_allclose(y, x**2) count, y = f(count, x, foo) assert n == 1 assert count == 10 def test_scan_broadcast_keys(self): params_key = jax.random.split(jax.random.key(0), 3) rngs = nnx.Rngs(params=params_key, dropout=1) state_axes = nnx.StateAxes({'params': 0, ...: None}) @nnx.scan(in_axes=(nnx.Carry, state_axes), length=3) def f(_, rngs: nnx.Rngs): param_key = rngs.params() dropout_key = rngs.dropout() return (), (param_key, dropout_key) _, (param_keys, dropout_keys) = f((), rngs) assert jnp.not_equal(param_keys[0], param_keys[1]) assert jnp.not_equal(param_keys[1], param_keys[2]) assert jnp.equal(dropout_keys[0], dropout_keys[1]) assert jnp.equal(dropout_keys[1], dropout_keys[2]) def test_rnn_example(self): class RNNCell(nnx.Module): def __init__(self, input_size, hidden_size, rngs): self.linear = nnx.Linear( hidden_size + input_size, hidden_size, rngs=rngs ) self.drop = nnx.Dropout(0.1, rngs=rngs) self.hidden_size = hidden_size def __call__(self, carry, x) -> tuple[jax.Array, jax.Array]: carry = self.drop(carry) # recurrent dropout x = nnx.relu(self.linear(jnp.concatenate([carry, x], axis=-1))) return x, x def initial_state(self, batch_size: int): return jnp.zeros((batch_size, self.hidden_size)) cell = RNNCell(20, 20, nnx.Rngs(params=0, dropout=1)) state_axes = nnx.StateAxes({'dropout': None, ...: nnx.Carry}) def rnn_forward(cell: RNNCell, x: jax.Array): carry = cell.initial_state(x.shape[0]) @nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1)) def unroll(cell: RNNCell, carry, x) -> tuple[jax.Array, jax.Array]: return cell(carry, x) _, y = unroll(cell, carry, x) return y x = jnp.ones((16, 10, 20)) y = rnn_forward(cell, x) def test_carry_pytree_sow(self): class CarryAsPytree(nnx.Pytree): def __init__(self, data: jax.Array): self.data = data class Model(nnx.Module): def __init__(self, num_steps): self.fc = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) self.num_steps = num_steps def _step(self, state): new_data = state.data + 1 self.sow(nnx.Intermediate, "data", new_data) state.data = new_data return state def _step2(self, state: tuple[CarryAsPytree, jax.Array, CarryAsPytree]): out = self.fc(state[1]) new_data1 = state[0].data + 1 new_data2 = state[2].data + 1 self.sow(nnx.Intermediate, "data1", new_data1) self.sow(nnx.Intermediate, "data2", new_data2) state[0].data = new_data1 state[2].data = new_data2 return (state[0], out, state[2]) @nnx.jit(static_argnames=("method")) def __call__(self, state, method): state_axes = nnx.StateAxes({nnx.Intermediate: 0, ...: nnx.Carry}) state_final = nnx.scan( method, in_axes=(state_axes, nnx.Carry), out_axes=nnx.Carry, length=self.num_steps, )(self, state) return state_final num_steps = 5 model = Model(num_steps=num_steps) carry = CarryAsPytree(data=jnp.array(0.0)) carry_final, intermediates = nnx.capture(model, nnx.Intermediate)(carry, method=Model._step) self.assertEqual(carry_final.data, num_steps) np.testing.assert_array_equal( intermediates['data'][0], 1.0 + jnp.arange(num_steps) ) carry = ( CarryAsPytree(data=jnp.array(0.0)), jnp.ones((10,)), CarryAsPytree(data=jnp.array(10.0)) ) carry_final, intermediates = nnx.capture(model, nnx.Intermediate)(carry, method=Model._step2) self.assertEqual(carry_final[0].data, num_steps) self.assertEqual(carry_final[2].data, 10 + num_steps) np.testing.assert_array_equal( intermediates['data1'][0], 1.0 + jnp.arange(num_steps) ) np.testing.assert_array_equal( intermediates['data2'][0], 11.0 + jnp.arange(num_steps) ) def test_broadcast_variable_mutation_rejected(self): v = nnx.Variable(jnp.array(1.0)) @nnx.scan( in_axes=(None, nnx.Carry, 0), graph=False, graph_updates=False, ) def fn(v, carry, x): v[...] = v[...] + 1.0 return carry + x, carry with self.assertRaisesRegex(ValueError, 'Broadcast.*mutated'): fn(v, jnp.array(0.0), jnp.arange(3.0)) def test_broadcast_out_axes_rejected(self): @nnx.scan( in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, None), graph=False, graph_updates=False, ) def fn(carry, x): return carry + x, jnp.zeros(3) with self.assertRaisesRegex(ValueError, 'broadcast'): fn(jnp.array(0.0), jnp.arange(3.0)) def test_scan_inconsistent_aliasing(self): v = nnx.Param(jnp.array(0.0)) @nnx.scan( in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0), graph=True, graph_updates=False, ) def f(carry, x): return carry, x[...] with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing'): f(v, v) def test_scan_input_output_aliasing(self): v = nnx.Param(jnp.arange(5)) @nnx.scan(in_axes=0, out_axes=0, graph=True, graph_updates=False) def f(carry): return carry with self.assertRaisesRegex(ValueError, 'does not support returning input Variables as outputs'): f(v) class TestRemat(parameterized.TestCase): @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_remat_basic(self, graph, graph_updates): class RematLinear(nnx.Module): def __init__(self, din: int, dout: int, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) @nnx.remat(graph=graph, graph_updates=graph_updates) def __call__(self, x: jax.Array) -> jax.Array: return self.linear(x) module = RematLinear(2, 3, nnx.Rngs(0)) def loss_fn(module, x): y = module(x) return jnp.sum(y) grad_type = nnx.State if graph and graph_updates else RematLinear loss, grads = nnx.value_and_grad( loss_fn, graph=graph, graph_updates=graph_updates, )(module, jnp.ones((1, 2))) assert loss.shape == () assert isinstance(grads, grad_type) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_remat_variables(self, graph, graph_updates): rngs = nnx.Rngs(0) w = nnx.Param(jax.random.normal(rngs(), (2, 3))) b = nnx.Param(jax.random.normal(rngs(), (3,))) count = nnx.BatchStat(jnp.array(0)) @nnx.remat(graph=graph, graph_updates=graph_updates) def linear(w, b, count, x): count[...] += 1 return x @ w + b[None] def loss_fn(w, b, count, x): return jnp.sum(linear(w, b, count, x)) x = jnp.ones((1, 2)) loss, grads = nnx.value_and_grad( loss_fn, argnums=(0, 1), graph=graph, graph_updates=graph_updates, )(w, b, count, x) assert loss.shape == () assert isinstance(grads, tuple) assert len(grads) == 2 assert count[...] == 1 def test_remat_with_scan_decorator(self): state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) class ScanLinear(nnx.Module): @nnx.split_rngs(splits=5) @nnx.vmap(in_axes=(state_axes, state_axes), axis_size=5) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) @nnx.scan(in_axes=(state_axes, nnx.Carry)) @nnx.remat def __call__(self, x: jax.Array) -> tp.Tuple[jax.Array, None]: x = self.linear(x) return x, None m = ScanLinear(nnx.Rngs(0)) assert m.linear.kernel.shape == (5, 3, 3) assert m.linear.bias.shape == (5, 3) y, _ = m(jnp.ones((1, 3))) assert y.shape == (1, 3) @parameterized.parameters( (True, False), (False, False), ) def test_tree_mode_remat_basic(self, graph, graph_updates): model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) @nnx.remat(graph=graph, graph_updates=graph_updates) def forward(model, x): return model(x) def loss_fn(model, x): y = forward(model, x) return jnp.sum(y) grads = nnx.grad( loss_fn, graph=graph, graph_updates=graph_updates, )(model, jnp.ones((1, 2))) assert grads.kernel.shape == (2, 3) assert grads.bias.shape == (3,) @parameterized.parameters( (True, False), (False, False), ) def test_tree_mode_remat_stateful(self, graph, graph_updates): class Counter(nnx.Variable): pass class Linear(nnx.Module): def __init__(self, din, dout, *, rngs): key = rngs.params() self.w = nnx.Param(jax.random.uniform(key, (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) self.count = Counter(jnp.array(0)) def __call__(self, x): self.count[...] += 1 return x @ self.w + self.b[None] model = Linear(2, 3, rngs=nnx.Rngs(0)) @nnx.remat(graph=graph, graph_updates=graph_updates) def forward(model, x): return model(x) y = forward(model, jnp.ones((1, 2))) assert y.shape == (1, 3) assert model.count[...] == 1 class TestVmap(parameterized.TestCase): @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_vmap_basic(self, graph, graph_updates): class LinearEnsemble(nnx.Module): def __init__(self, num, *, rngs): self.w = nnx.Param(jax.random.uniform(rngs(), (num, 2, 3))) model = LinearEnsemble(5, rngs=nnx.Rngs(0)) x = jnp.ones((2,)) @nnx.vmap(in_axes=(0, None), out_axes=0, graph=graph, graph_updates=graph_updates) def forward(model, x): return x @ model.w y = forward(model, x) assert y.shape == (5, 3) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_vmap_stateful(self, graph, graph_updates): class Counter(nnx.Variable): pass class Linear(nnx.Module): def __init__(self, din, dout, *, rngs): key = rngs.params() self.w = nnx.Param(jax.random.uniform(key, (din, dout))) self.count = Counter(jnp.array(0)) def __call__(self, x): self.count[...] += 1 return x @ self.w model = Linear(2, 3, rngs=nnx.Rngs(0)) @nnx.vmap(in_axes=(None, 0), out_axes=0, graph=graph, graph_updates=graph_updates) def forward(model, x): return model(x) x = jnp.ones((5, 2)) y = forward(model, x) assert y.shape == (5, 3) assert model.count[...] == 1 @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_vmap_variables(self, graph, graph_updates): rngs = nnx.Rngs(0) w = nnx.Param(jax.random.normal(rngs(), (5, 2, 3))) b = nnx.Param(jax.random.normal(rngs(), (5, 3))) @nnx.vmap(in_axes=(0, 0, 1), out_axes=1, graph=graph, graph_updates=graph_updates) def forward(w, b, x): return x @ w + b x = jax.random.uniform(rngs(), (2, 5)) y = forward(w, b, x) assert y.shape == (3, 5) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_vmap_ensemble_forward(self, graph, graph_updates): class Linear(nnx.Module): def __init__(self, din, dout, *, rngs): key = rngs.params() self.w = nnx.Param(jax.random.uniform(key, (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): return x @ self.w + self.b[None] @nnx.vmap(in_axes=0, out_axes=0, graph=graph, graph_updates=graph_updates) def create_ensemble(keys): return Linear(2, 3, rngs=nnx.Rngs(keys)) keys = jax.random.split(jax.random.key(0), 5) ensemble = create_ensemble(keys) assert ensemble.w.shape == (5, 2, 3) assert ensemble.b.shape == (5, 3) @nnx.vmap(in_axes=(0, None), out_axes=0, graph=graph, graph_updates=graph_updates) def forward(model, x): return model(x) x = jnp.ones((1, 2)) y = forward(ensemble, x) assert y.shape == (5, 1, 3) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_vmap_replicate(self, graph, graph_updates): model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) @nnx.vmap(in_axes=(None, 0), out_axes=0, graph=graph, graph_updates=graph_updates) def forward(model, x): return model(x) x = jnp.ones((5, 1, 2)) y = forward(model, x) assert y.shape == (5, 1, 3) def test_basic(self): @nnx.split_rngs(splits=5) @nnx.vmap(in_axes=0, out_axes=0, axis_size=5) def create_block(rngs: nnx.Rngs): return nnx.Linear(2, 3, rngs=rngs) rngs = nnx.Rngs(0) block = create_block(rngs) self.assertEqual(block.kernel.shape, (5, 2, 3)) self.assertEqual(rngs.default.count[...], 1) @nnx.vmap(in_axes=(0, 1), out_axes=1) def forward(block: nnx.Linear, x): self.assertEqual(block.kernel.shape, (2, 3)) self.assertEqual(block.bias.shape, (3,)) self.assertEqual(x.shape, (2,)) return block(x) x = jax.random.uniform(rngs(), (2, 5)) y = forward(block, x) self.assertEqual(y.shape, (3, 5)) def test_basic_variables(self): @nnx.split_rngs(splits=5) @nnx.vmap(in_axes=0, out_axes=0, axis_size=5) def create_block(rngs: nnx.Rngs): w = nnx.Param(jax.random.normal(rngs(), (2, 3))) b = nnx.Param(jax.random.normal(rngs(), (3,))) return w, b rngs = nnx.Rngs(0) w, b = create_block(rngs) self.assertEqual(w.shape, (5, 2, 3)) self.assertEqual(b.shape, (5, 3)) self.assertEqual(rngs.default.count[...], 1) @nnx.vmap(in_axes=(0, 0, 1), out_axes=1) def forward(w, b, x): self.assertEqual(w.shape, (2, 3)) self.assertEqual(b.shape, (3,)) self.assertEqual(x.shape, (2,)) return x @ w + b x = jax.random.uniform(rngs(), (2, 5)) y = forward(w, b, x) self.assertEqual(y.shape, (3, 5)) def test_state_axes(self): class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) def __call__(self, x: jax.Array) -> jax.Array: x = self.linear(x) x = nnx.relu(x) x = self.dropout(x) return x @nnx.vmap( in_axes=0, out_axes=nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}), ) def create_block(rngs: nnx.Rngs): rngs = nnx.clone(rngs) return Block(rngs) rngs = nnx.Rngs(0) initial_key = rngs.default.key[...] backups = nnx.split_rngs(rngs, splits=5) module = create_block(rngs) nnx.restore_rngs(backups) assert rngs.default.count[...] == 1 assert rngs.default.key[...] == initial_key assert not jnp.allclose( module.linear.kernel[0], module.linear.kernel[1], ) assert module.linear.kernel.shape == (5, 3, 3) assert module.linear.bias.shape == (5, 3) x = jnp.ones((5, 1, 3)) @nnx.vmap( in_axes=(nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}), 0), ) def forward_block(module, x): return module(x) backups = nnx.split_rngs(rngs, splits=5) y = forward_block(module, x) nnx.restore_rngs(backups) assert y.shape == (5, 1, 3) assert rngs.default.count[...] == 2 assert rngs.default.key[...] == initial_key y2 = forward_block(module, x) assert not jnp.allclose(y, y2) def test_split_rngs_context_manager(self): class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) def __call__(self, x: jax.Array) -> jax.Array: x = self.linear(x) x = nnx.relu(x) x = self.dropout(x) return x state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes) def create_block(rngs: nnx.Rngs): return Block(rngs) rngs = nnx.Rngs(0) initial_key = rngs.default.key[...] module = create_block(rngs.split(5)) assert rngs.default.count[...] == 1 assert rngs.default.key[...] == initial_key assert not jnp.allclose( module.linear.kernel[0], module.linear.kernel[1], ) assert module.linear.kernel.shape == (5, 3, 3) assert module.linear.bias.shape == (5, 3) x = jnp.ones((5, 1, 3)) @nnx.vmap(in_axes=(state_axes, 0)) def forward_block(module, x): return module(x) y = forward_block(module, x) assert y.shape == (5, 1, 3) assert rngs.default.key[...] == initial_key y2 = forward_block(module, x) assert not jnp.allclose(y, y2) def test_split_rngs_decorator(self): class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) def __call__(self, x: jax.Array) -> jax.Array: x = self.linear(x) x = nnx.relu(x) x = self.dropout(x) return x state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) @nnx.split_rngs(splits=5) @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes) def create_block(rngs: nnx.Rngs): return Block(rngs) rngs = nnx.Rngs(0) initial_key = rngs.default.key[...] module = create_block(rngs) assert rngs.default.count[...] == 1 assert rngs.default.key[...] == initial_key assert not jnp.allclose( module.linear.kernel[0], module.linear.kernel[1], ) assert module.linear.kernel.shape == (5, 3, 3) assert module.linear.bias.shape == (5, 3) x = jnp.ones((5, 1, 3)) @nnx.vmap(in_axes=(state_axes, 0)) def forward_block(module, x): self.assertEqual(x.shape, (1, 3)) return module(x) y = forward_block(module, x) assert y.shape == (5, 1, 3) assert rngs.default.key[...] == initial_key y2 = forward_block(module, x) assert not jnp.allclose(y, y2) def test_state_axes_simple(self): class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(2, 3, rngs=rngs) self.bn = nnx.BatchNorm(3, rngs=rngs) self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs) def __call__(self, x: jax.Array) -> jax.Array: return nnx.relu(self.dropout(self.bn(self.linear(x)))) state_axes = nnx.StateAxes({(nnx.BatchStat, 'dropout'): 0, ...: None}) @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes) def create_block(rngs: nnx.Rngs): return Block(rngs) rngs = nnx.Rngs(params=0, dropout=1) nnx.split_rngs(rngs, splits=5, only='dropout') module = create_block(rngs) assert module.linear.kernel.shape == (2, 3) assert module.bn.scale.shape == (3,) assert module.bn.mean.shape == (5, 3) @nnx.vmap(in_axes=(state_axes, 0), out_axes=0) def forward_block(module, x): return module(x) x = jnp.ones((5, 1, 2)) y = forward_block(module, x) assert y.shape == (5, 1, 3) def test_split_rngs_decorator_simple(self): class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(2, 3, rngs=rngs) self.bn = nnx.BatchNorm(3, rngs=rngs) self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs) def __call__(self, x: jax.Array) -> jax.Array: return nnx.relu(self.dropout(self.bn(self.linear(x)))) state_axes = nnx.StateAxes({(nnx.BatchStat, 'dropout'): 0, ...: None}) @nnx.split_rngs(splits=5, only='dropout') @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes) def create_block(rngs: nnx.Rngs): return Block(rngs) rngs = nnx.Rngs(params=0, dropout=1) module = create_block(rngs) assert module.linear.kernel.shape == (2, 3) assert module.bn.scale.shape == (3,) assert module.bn.mean.shape == (5, 3) assert module.dropout.rngs is not None self.assertEqual(module.dropout.rngs.key.shape, (5,)) @nnx.vmap(in_axes=(state_axes, 0), out_axes=0) def forward_block(module: Block, x): assert module.dropout.rngs is not None self.assertEqual(module.dropout.rngs.key.shape, ()) return module(x) x = jnp.ones((5, 1, 2)) y = forward_block(module, x) assert module.dropout.rngs is not None self.assertEqual(module.dropout.rngs.key.shape, (5,)) assert y.shape == (5, 1, 3) def test_state_axes_super_simple(self): class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(2, 3, rngs=rngs) self.bn = nnx.BatchNorm(3, rngs=rngs) self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs) def __call__(self, x: jax.Array) -> jax.Array: return nnx.relu(self.dropout(self.bn(self.linear(x)))) @nnx.vmap(in_axes=0, out_axes=0) def create_block(rngs: nnx.Rngs): return Block(rngs) rngs = nnx.Rngs(0) nnx.split_rngs(rngs, splits=5) module = create_block(rngs) assert module.linear.kernel.shape == (5, 2, 3) assert module.bn.scale.shape == (5, 3) assert module.bn.mean.shape == (5, 3) @nnx.vmap(in_axes=(0, 0), out_axes=0) def forward_block(module, x): return module(x) x = jnp.ones((5, 1, 2)) y = forward_block(module, x) assert y.shape == (5, 1, 3) def test_replicate(self): din = 3 dout = 10 class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) def __call__(self, x: jax.Array) -> jax.Array: return self.dropout(nnx.relu(self.linear(x))) def create_block(rngs: nnx.Rngs): return Block(rngs) state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None}) @nnx.split_rngs(splits=5) @partial(nnx.vmap, in_axes=(state_axes, 0), out_axes=0) def forward_block(module: Block, x): return module(x) rngs = nnx.Rngs(0) module = create_block(rngs) initial_key = module.dropout.rngs.key[...] assert module.dropout.rngs.count[...] == 0 assert module.linear.kernel.shape == (din, dout) assert module.linear.bias.shape == (dout,) x = jnp.ones((5, 1, din)) y = forward_block(module, x) assert y.shape == (5, 1, dout) assert module.dropout.rngs.count[...] == 1 assert not jnp.allclose(y[0], y[1]) y2 = forward_block(module, x) # dropout is working! assert not jnp.allclose(y, y2) assert module.dropout.rngs.key[...] == initial_key def test_consistent_aliasing_inputs(self): class Foo(nnx.Module): def __init__(self): self.a = nnx.Param(jnp.zeros((5, 5))) m = Foo() @nnx.vmap(in_axes=(0, 1)) def f(m1, m2): pass with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'): f(m, m) def test_consistent_aliasing_input_output(self): class Foo(nnx.Module): def __init__(self): self.a = nnx.Param(jnp.zeros((2, 3))) m = Foo() @partial(nnx.vmap, in_axes=0, out_axes=1) def f(m): return m with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'): m2 = f(m) def test_consistent_aliasing_shared(self): class Shared(nnx.Module): def __init__(self): self.a = nnx.Param(jnp.zeros((3, 3))) class Foo(nnx.Module): def __init__(self, shared: Shared): self.a = shared shared = Shared() m1 = Foo(shared) m2 = Foo(shared) @nnx.vmap(in_axes=(0, 1)) def f(m1, m2): pass with self.assertRaisesRegex( ValueError, r'Inconsistent aliasing detected([\s\S]*)Param([\s\S]*)a:' r' 0([\s\S]*)a: 1', ): f(m1, m2) def test_equivalent_state_axes_mapping(self): m = nnx.Linear(3, 3, rngs=nnx.Rngs(0)) sa1 = nnx.StateAxes({...: 0}) sa2 = nnx.StateAxes({nnx.Param: 0}) @nnx.vmap(in_axes=(0, sa1, sa2)) def f(m1, m2, m3): pass f(m, m, m) def test_equivalent_state_sharding_mapping(self): m = nnx.Linear(4, 4, rngs=nnx.Rngs(0)) mesh = jax.sharding.Mesh(jax.devices(), ('mp',)) sharding = jax.sharding.NamedSharding( mesh, jax.sharding.PartitionSpec('mp') ) sa1 = nnx.StateSharding({...: sharding}) sa2 = nnx.StateSharding({nnx.Param: sharding}) @nnx.jit(in_shardings=(sharding, sa1, sa2)) def f(m1, m2, m3): pass f(m, m, m) def test_captured_module_in_return_error(self): class Foo(nnx.Module): def __init__(self): self.a = jnp.zeros((4, 4)) m = Foo() @nnx.vmap(in_axes=0, out_axes=0) def f(x): return x, m with self.assertRaisesRegex( errors.TraceContextError, r'Trying to extract graph node from different trace level.*Foo', ): x = jnp.zeros((4,)) f(x) def test_vmap_and_cond_passthrough(self): class Broadcast(nnx.Variable[nnx.A]): ... class Vectorized(nnx.Variable[nnx.A]): ... class Env(nnx.Module): def __init__(self): self.broadcast = Broadcast(jnp.array(1)) self.index = Vectorized(jnp.arange(8)) self.step = Vectorized(jnp.zeros((8,), jnp.uint32)) env = Env() @nnx.vmap(in_axes=(nnx.StateAxes({Broadcast: None, Vectorized: 0}),)) def f(env: Env): self.assertEqual(env.step.shape, ()) def increment(env: Env): env.step[...] += 1 def no_nothing(env: Env): pass is_even = env.index % 2 == 0 nnx.cond(is_even, increment, no_nothing, env) f(env) np.testing.assert_array_equal(env.step[...], [1, 0, 1, 0, 1, 0, 1, 0]) def test_vmap_and_cond_passthrough_error(self): class Broadcast(nnx.Variable[nnx.A]): ... class Vectorized(nnx.Variable[nnx.A]): ... class Env(nnx.Module): def __init__(self): self.broadcast = Broadcast(jnp.array(1)) self.index = Vectorized(jnp.arange(8)) self.step = Vectorized(jnp.zeros((8,), jnp.uint32)) env = Env() @nnx.vmap(in_axes=(nnx.StateAxes({Broadcast: None, Vectorized: 0}),)) def f(env: Env): self.assertEqual(env.step.shape, ()) def increment(env: Env): env.step[...] += 1 env.broadcast[...] += 1 def no_nothing(env: Env): pass is_even = env.index % 2 == 0 nnx.cond(is_even, increment, no_nothing, env) with self.assertRaisesRegex( ValueError, r"at vmap.*'broadcast'.*got axis spec None but output was batched on" r' axis 0', ): f(env) def test_example(self): class Model(nnx.Module): def __init__(self, din, dout, *, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) self.bn = nnx.BatchNorm(dout, rngs=rngs) def __call__(self, x): return nnx.relu(self.dropout(self.bn(self.linear(x)))) @nnx.vmap(in_axes=0, out_axes=0) def initialize_ensemble(key): rngs = nnx.Rngs(key) return Model(2, 3, rngs=rngs) keys = jax.random.split(jax.random.key(0), 5) ensemble = initialize_ensemble(keys) self.assertEqual(ensemble.linear.kernel.shape, (5, 2, 3)) @nnx.vmap(in_axes=(0, None), out_axes=0) def forward(model, x): return model(x) x = jnp.ones((4, 2)) y = forward(ensemble, x) self.assertEqual(y.shape, (5, 4, 3)) def test_example_with_vectorization(self): class LinearEnsemble(nnx.Module): def __init__(self, num, rngs): self.w = nnx.Param(jax.random.uniform(rngs(), (num, 2, 3))) model = LinearEnsemble(5, rngs=nnx.Rngs(0)) @nnx.vmap(in_axes=(0, None), out_axes=0) def forward(model, x): self.assertEqual(model.w.shape, (2, 3)) return jnp.dot(x, model.w) x = jnp.ones((4, 2)) y = forward(model, x) self.assertEqual(y.shape, (5, 4, 3)) def test_metadata(self): @nnx.vmap( in_axes=(None,), out_axes=0, axis_size=5, transform_metadata={nnx.spmd.PARTITION_NAME: 'c'}, ) def create_block(rngs: nnx.Rngs): return nnx.Linear( 16, 32, rngs=rngs, kernel_init=nnx.with_partitioning( nnx.initializers.lecun_normal(), ('a', 'b') ), ) mesh = jax.make_mesh((1, 1, 1), ('a', 'b', 'c'), axis_types=(jax.sharding.AxisType.Auto,) * len(('a', 'b', 'c'))) with jax.set_mesh(mesh): m = create_block(nnx.Rngs(0)) self.assertEqual(m.kernel.shape, (5, 16, 32)) self.assertEqual(m.kernel.out_sharding, ('c', 'a', 'b')) def test_state_axes_from_state(self): class Model(nnx.Module): def __init__(self, din, dout, *, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.bn = nnx.BatchNorm(dout, rngs=rngs) model = Model(2, 3, rngs=nnx.Rngs(0)) state = nnx.state(model) state['linear']['kernel'] = 0 state['linear']['bias'] = 1 state['bn']['scale'] = 0 state['bn']['mean'] = 1 state['bn']['var'] = 0 state['bn']['bias'] = None state_axes = nnx.StateAxes(state) self.assertEqual(state_axes.map_prefix(('linear', 'kernel'), None), 0) self.assertEqual(state_axes.map_prefix(('linear', 'bias'), None), 1) self.assertEqual(state_axes.map_prefix(('bn', 'scale'), None), 0) self.assertEqual(state_axes.map_prefix(('bn', 'mean'), None), 1) self.assertEqual(state_axes.map_prefix(('bn', 'var'), None), 0) self.assertEqual(state_axes.map_prefix(('bn', 'bias'), None), None) @nnx.vmap(out_axes=state_axes, axis_size=5) def create_block(): return Model(2, 3, rngs=nnx.Rngs(0)) model = create_block() self.assertEqual(model.linear.kernel.shape, (5, 2, 3)) self.assertEqual(model.linear.bias.shape, (3, 5)) self.assertEqual(model.bn.scale.shape, (5, 3)) self.assertEqual(model.bn.mean.shape, (3, 5)) self.assertEqual(model.bn.var.shape, (5, 3)) self.assertEqual(model.bn.bias.shape, (3,)) def test_vmap_inconsistent_aliasing(self): v = nnx.Param(jnp.arange(3.0)) @nnx.vmap(in_axes=(0, None), graph=True, graph_updates=False) def f(v_mapped, v_broadcast): return v_mapped[...] + v_broadcast[...] with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing'): f(v, v) class TestPmap(parameterized.TestCase): def test_basic_single(self): class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 10, rngs=rngs) self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) def __call__(self, x: jax.Array) -> jax.Array: x = self.linear(x) x = nnx.elu(x) x = self.dropout(x) return x state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) @nnx.split_rngs(splits=1) @nnx.pmap(in_axes=(state_axes,), out_axes=state_axes, axis_size=1, graph=True) def create_block(rngs: nnx.Rngs): return Block(rngs) rngs = nnx.Rngs(0) module = create_block(rngs) initial_key = module.dropout.rngs.key[...] assert module.dropout.rngs.count[0] == 0 assert module.linear.kernel.shape == (1, 3, 10) assert module.linear.bias.shape == (1, 10) x = jnp.ones((1, 1, 3)) @nnx.pmap(in_axes=(state_axes, 0), axis_size=1, graph=True) def forward_block(module, x): return module(x) y = forward_block(module, x) assert y.shape == (1, 1, 10) assert module.dropout.rngs.count[0] == 1 assert module.dropout.rngs.key[...] == initial_key y2 = forward_block(module, x) assert not jnp.allclose(y, y2) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_basic_demo_single(self, graph, graph_updates): class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(20, 20, rngs=rngs) self.dropout = nnx.Dropout(0.2, deterministic=False, rngs=rngs) def __call__(self, x: jax.Array) -> jax.Array: return self.dropout(nnx.relu(self.linear(x))) @nnx.split_rngs(splits=1) @nnx.pmap(axis_size=1, graph=graph, graph_updates=graph_updates) def create_block(rngs: nnx.Rngs): return Block(rngs) @nnx.pmap(axis_size=1, graph=graph, graph_updates=graph_updates) def forward_block(module: Block, x): return module(x) rngs = nnx.Rngs(0) module = create_block(rngs) assert module.dropout.rngs.count[...] == 0 assert module.linear.kernel.shape == (1, 20, 20) assert module.linear.bias.shape == (1, 20) x = jnp.ones((1, 10, 20)) y = forward_block(module, x) assert y.shape == (1, 10, 20) assert module.dropout.rngs.count[...] == 1 y2 = forward_block(module, x) # dropout is working! assert not jnp.allclose(y, y2) def test_replicate_single(self): din = 3 dout = 10 class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) def __call__(self, x: jax.Array) -> jax.Array: return self.dropout(nnx.relu(self.linear(x))) def create_block(rngs: nnx.Rngs): return Block(rngs) state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None}) @nnx.split_rngs(splits=1) @partial(nnx.pmap, in_axes=(state_axes, 0), out_axes=0, axis_size=1, graph=True) def forward_block(module: Block, x): return module(x) rngs = nnx.Rngs(0) module = create_block(rngs) initial_key = module.dropout.rngs.key[...] assert module.dropout.rngs.count[...] == 0 assert module.linear.kernel.shape == (din, dout) assert module.linear.bias.shape == (dout,) x = jnp.ones((1, 5, din)) y = forward_block(module, x) assert y.shape == (1, 5, dout) assert module.dropout.rngs.count[...] == 1 y2 = forward_block(module, x) # dropout is working! assert not jnp.allclose(y, y2) assert module.dropout.rngs.key[...] == initial_key class TestCond(parameterized.TestCase): def test_basic(self): class TimeStep(tp.NamedTuple): step: nnx.Variable[jax.Array] reward: nnx.Variable[jax.Array] @staticmethod def zero(): return TimeStep( step=nnx.Variable(jnp.array(0)), reward=nnx.Variable(jnp.array(0.0)) ) @nnx.dataclass class Foo(nnx.Pytree): timestep: TimeStep = nnx.data() def update(self): def reward_2(self: Foo): self.timestep = TimeStep( step=nnx.Variable(self.timestep.step + 1), reward=nnx.Variable(jnp.array(2.0)), ) def reward_0(self: Foo): self.timestep = TimeStep( step=nnx.Variable(self.timestep.step + 1), reward=nnx.Variable(jnp.array(0.0)), ) nnx.cond(self.timestep.step % 2 == 0, reward_2, reward_0, self) foo = Foo(timestep=TimeStep.zero()) foo.update() self.assertEqual(foo.timestep.step[...], 1) self.assertEqual(foo.timestep.reward[...], 2.0) foo.update() self.assertEqual(foo.timestep.step[...], 2) self.assertEqual(foo.timestep.reward[...], 0.0) foo.update() self.assertEqual(foo.timestep.step[...], 3) self.assertEqual(foo.timestep.reward[...], 2.0) foo.update() self.assertEqual(foo.timestep.step[...], 4) self.assertEqual(foo.timestep.reward[...], 0.0) @parameterized.parameters( (True, False), (False, False), ) def test_basic_variable(self, graph, graph_updates): def collatz(x): def even(x): x[...] = x // 2 def odd(x): x[...] = 3 * x + 1 return nnx.cond( x % 2 == 0, even, odd, x, graph=graph, graph_updates=graph_updates, ) x = nnx.Variable(jnp.array(8)) collatz(x) self.assertEqual(x[...], 4) collatz(x) self.assertEqual(x[...], 2) collatz(x) self.assertEqual(x[...], 1) collatz(x) self.assertEqual(x[...], 4) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_cond_and_vmap(self, graph, graph_updates): class Env(nnx.Pytree): def __init__(self): self.index = nnx.Variable(jnp.arange(8)) self.step = nnx.Variable(jnp.zeros((8,), jnp.uint32)) env = Env() model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) @nnx.vmap(in_axes=(0, None), out_axes=None, graph=graph, graph_updates=graph_updates) def f(env: Env, model: nnx.Linear): self.assertEqual(env.index.shape, ()) def increment(env: Env): env.step[...] += 1 def no_nothing(env: Env): pass is_even = env.index % 2 == 0 nnx.cond( is_even, increment, no_nothing, env, graph=graph, graph_updates=graph_updates, ) f(env, model) np.testing.assert_array_equal( env.step[...], np.array([1, 0, 1, 0, 1, 0, 1, 0], np.uint32) ) @parameterized.parameters( (True, False), (False, False), ) def test_cond_different_variable_per_branch(self, graph, graph_updates): a = nnx.Variable(jnp.array(0)) b = nnx.Variable(jnp.array(0)) def update_a(a, b): a[...] += 1 def update_b(a, b): b[...] += 10 nnx.cond( True, update_a, update_b, a, b, graph=graph, graph_updates=graph_updates, ) self.assertEqual(a[...], 1) self.assertEqual(b[...], 0) nnx.cond( False, update_a, update_b, a, b, graph=graph, graph_updates=graph_updates, ) self.assertEqual(a[...], 1) self.assertEqual(b[...], 10) def test_cond_shared_references(self): @dataclasses.dataclass class Foo(nnx.Module): a: nnx.Variable b: nnx.Variable v = nnx.Variable(jnp.array(0)) m = Foo(a=v, b=v) def true_fn(m): m.a[...] += 1 def false_fn(m): m.b[...] += 2 nnx.cond(True, true_fn, false_fn, m, graph=True, graph_updates=False) np.testing.assert_allclose(m.a[...], 1) np.testing.assert_allclose(m.b[...], 1) nnx.cond(False, true_fn, false_fn, m, graph=True, graph_updates=False) np.testing.assert_allclose(m.a[...], 3) np.testing.assert_allclose(m.b[...], 3) with self.assertRaises(ValueError): nnx.cond(True, true_fn, false_fn, m, graph=False) class TestSwitch(parameterized.TestCase): @parameterized.parameters( (True, False), (False, False), ) def test_basic(self, graph, graph_updates): class RoundTable(nnx.Module): def __init__(self): self.next_index = 0 self.linear = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) self.linear.kernel[...] = jnp.identity(10) self.rounds_count = nnx.Variable(jnp.array(0)) def __call__(self, x): def fn0(m, x): m.rounds_count[...] += 1 return m.linear(x) def fn1(m, x): return m.linear(x) * 2 def fn2(m, x): m.linear.kernel[...] = jnp.zeros((10, 10)) return m.linear(x) y = nnx.switch( self.next_index, (fn0, fn1, fn2), self, x, graph=graph, graph_updates=graph_updates, ) self.next_index = (self.next_index + 1) % 3 return y model = RoundTable() x = jnp.ones((10,)) np.testing.assert_array_equal(model(x), x) assert model.rounds_count[...] == 1 assert model.next_index == 1 np.testing.assert_array_equal(model(x), x * 2) assert model.rounds_count[...] == 1 assert model.next_index == 2 np.testing.assert_array_equal(model(x), jnp.zeros((10,))) assert model.rounds_count[...] == 1 assert model.next_index == 0 np.testing.assert_array_equal(model(x), jnp.zeros((10,))) assert model.rounds_count[...] == 2 assert model.next_index == 1 @parameterized.parameters( (True, False), (False, False), ) def test_switch_variable(self, graph, graph_updates): def add_1(x): x[...] += 1 def add_10(x): x[...] += 10 def add_100(x): x[...] += 100 x = nnx.Variable(jnp.array(0)) nnx.switch(0, (add_1, add_10, add_100), x, graph=graph, graph_updates=graph_updates) self.assertEqual(x[...], 1) nnx.switch(1, (add_1, add_10, add_100), x, graph=graph, graph_updates=graph_updates) self.assertEqual(x[...], 11) nnx.switch(2, (add_1, add_10, add_100), x, graph=graph, graph_updates=graph_updates) self.assertEqual(x[...], 111) def test_switch_shared_references(self): @dataclasses.dataclass class Foo(nnx.Module): a: nnx.Variable b: nnx.Variable v = nnx.Variable(jnp.array(0)) m = Foo(a=v, b=v) def add_a(m): m.a[...] += 1 def add_b(m): m.b[...] += 10 nnx.switch(0, (add_a, add_b), m, graph=True, graph_updates=False) np.testing.assert_allclose(m.a[...], 1) np.testing.assert_allclose(m.b[...], 1) nnx.switch(1, (add_a, add_b), m, graph=True, graph_updates=False) np.testing.assert_allclose(m.a[...], 11) np.testing.assert_allclose(m.b[...], 11) with self.assertRaises(ValueError): nnx.switch(0, (add_a, add_b), m, graph=False) class TestWhileLoop(parameterized.TestCase): @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_basic(self, graph, graph_updates): def fwd_fn(input): m, x, c = input y = m(x) return m, y, c - 1.0 module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) module.kernel[...] = jnp.identity(10) * 2 x = 1e1 * jax.random.normal(jax.random.key(0), (10,)) _, y, _ = nnx.while_loop( lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0), graph=graph, graph_updates=graph_updates) np.testing.assert_array_equal(y, x * 8) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_multiple_objects(self, graph, graph_updates): def fwd_fn(input): m1, (w2,), x, c = input y = m1(x) @ w2 return m1, (w2,), y, c - 1.0 m1 = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) m1.kernel[...] = jnp.identity(10) * 2 w2 = nnx.Variable(jnp.identity(10) * 0.5) x = 1e1 * jax.random.normal(jax.random.key(0), (10,)) _, _, y, _ = nnx.while_loop( lambda input: input[-1] > 0, fwd_fn, (m1, (w2,), x, 3.0), graph=graph, graph_updates=graph_updates) np.testing.assert_allclose(y, x) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_nested_module(self, graph, graph_updates): def fwd_fn(input): m, x, c = input y = m(x) return m, y, c - 1.0 module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) module.kernel[...] = jnp.identity(10) * 2 module = nnx.Sequential(module) x = 1e1 * jax.random.normal(jax.random.key(0), (10,)) _, y, _ = nnx.while_loop( lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0), graph=graph, graph_updates=graph_updates) np.testing.assert_array_equal(y, x * 8) def test_shared_module(self): m1 = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) m2 = nnx.Linear(10, 10, use_bias=False, rngs=nnx.Rngs(0)) m2.kernel = m1.kernel module = nnx.Sequential(m1, m2) self.assertLen(jax.tree.leaves(nnx.state(module)), 2) # only m1 params def fwd_fn(input): m, x, c = input y = m(x) m.layers[0].kernel[...] = jnp.zeros_like(m.layers[0].kernel[...]) return m, y, c - 1.0 x = 1e1 * jax.random.normal(jax.random.key(0), (10,)) _, y, _ = nnx.while_loop( lambda input: input[-1] > 0, fwd_fn, (module, x, 2.0)) self.assertLen(jax.tree.leaves(nnx.state(module)), 2) # only m1 params np.testing.assert_array_equal( m1.kernel[...], jnp.zeros((10, 10)), ) np.testing.assert_array_equal( m2.kernel[...], jnp.zeros((10, 10)), ) np.testing.assert_array_equal(y, jnp.zeros((10,))) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_value_changed(self, graph, graph_updates): def fwd_fn(input): m, x, c = input m.kernel[...] = jnp.zeros_like(m.kernel) y = m(x) return m, y, c - 1.0 module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) x = 1e1 * jax.random.normal(jax.random.key(0), (10,)) _, y, _ = nnx.while_loop( lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0), graph=graph, graph_updates=graph_updates) np.testing.assert_array_equal( module.kernel[...], jnp.zeros((10, 10)), ) np.testing.assert_array_equal(y, jnp.zeros((10,))) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_ref_changed(self, graph, graph_updates): def fwd_fn(input): m, x, c = input y = m(x) m.kernel = nnx.Param(jnp.zeros_like(m.kernel)) return m, y, c - 1.0 module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) x = 1e1 * jax.random.normal(jax.random.key(0), (10,)) with self.assertRaises(ValueError): _, y, _ = nnx.while_loop( lambda input: input[-1] > 0, fwd_fn, (module, x, 2.0), graph=graph, graph_updates=graph_updates) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_structure_changed(self, graph, graph_updates): def fwd_fn(input): m, x, c = input m = nnx.Linear(10, 10, use_bias=False, rngs=nnx.Rngs(1)) m.kernel[...] = jnp.identity(10) * 2 y = m(x) return m, y, c - 1.0 module = nnx.Linear(10, 10, use_bias=True, rngs=nnx.Rngs(0)) x = 1e1 * jax.random.normal(jax.random.key(0), (10,)) with self.assertRaises((ValueError, TypeError)): _, y, _ = nnx.while_loop( lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0), graph=graph, graph_updates=graph_updates) def test_repeated_object(self): m = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) def body_fn(val): count, m, _ = val return count + 1, m, m count, m, _ = nnx.while_loop( lambda val: val[0] < 2, body_fn, (0, m, m), ) def test_immut_fori_loop(self): def immut_fn(i, carry): g_accum = carry grads = jax.tree.map(jnp.ones_like, g_accum) g_accum = jax.tree.map(lambda gm, g: gm + g, g_accum, grads) return g_accum model = nnx.Linear(10, 10, rngs=nnx.Rngs(0), use_bias=False) g_accum = jax.tree.map(jnp.zeros_like, nnx.state(model)) nnx.fori_loop(0, 2, immut_fn, g_accum) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_fori_loop_grad_accum(self, graph, graph_updates): accum = nnx.Variable(jnp.zeros((10, 10))) def accum_fn(i, accum): accum[...] += 1 return accum accum = nnx.fori_loop(0, 3, accum_fn, accum, graph=graph, graph_updates=graph_updates) np.testing.assert_array_equal(accum[...], jnp.full((10, 10), 3.0)) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_fori_loop_basic(self, graph, graph_updates): def fwd_fn(i, input): m, x = input m.kernel[...] = jnp.identity(10) * i return m, m(x) module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) x = jax.random.normal(jax.random.key(0), (10,)) _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x), graph=graph, graph_updates=graph_updates) np.testing.assert_array_equal(y, x * 2 * 3) def test_fori_loop_with_sharing(self): class A(nnx.Pytree): def __init__(self): self.params = nnx.Param(jnp.zeros((10,), dtype=int)) class B(nnx.Pytree): def __init__(self, a: A): self.a = a class C(nnx.Pytree): def __init__(self, a: A): self.a = a class D(nnx.Pytree): def __init__(self): self.a = A() self.b = B(self.a) self.c = C(self.a) def increment(_, d: D) -> D: d.a.params[...] += 1 return d @nnx.jit def rollout(d: D): nnx.fori_loop(0, 10, increment, d) d = D() rollout(d) np.testing.assert_array_equal( d.a.params[...], np.full((10,), 10, dtype=int) ) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_loops_multiple_modules(self, graph, graph_updates): class Foo(nnx.Module): def __init__(self): self.param = nnx.Param(jnp.zeros((1,))) def __call__(self, x): return self.param def loop_fn(inputs): return inputs while_loop_fn = lambda inputs: (*loop_fn(inputs[:-1]), inputs[-1]-1) fori_loop_fn = lambda i, inputs: loop_fn(inputs) a = Foo() b = Foo() nnx.while_loop(lambda input: input[-1] > 0, while_loop_fn, (a, b, 2), graph=graph, graph_updates=graph_updates) nnx.fori_loop(0, 2, fori_loop_fn, (a, b), graph=graph) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_tree_mode_while_loop_stateful(self, graph, graph_updates): class Counter(nnx.Module): def __init__(self): self.count = nnx.Variable(jnp.array(0)) counter = Counter() module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) module.kernel[...] = jnp.identity(10) * 2 x = jax.random.normal(jax.random.key(0), (10,)) def body_fn(val): counter, module, x, i = val counter.count[...] += 1 x = module(x) return counter, module, x, i - 1 counter, module, y, _ = nnx.while_loop( lambda val: val[-1] > 0, body_fn, (counter, module, x, 3), graph=graph, graph_updates=graph_updates, ) np.testing.assert_array_equal(counter.count[...], 3) np.testing.assert_array_equal(y, x * 8) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_tree_mode_while_loop_inside_jit(self, graph, graph_updates): module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) module.kernel[...] = jnp.identity(10) * 2 x = jax.random.normal(jax.random.key(0), (10,)) @nnx.jit(graph=graph, graph_updates=graph_updates) def f(module, x): def body_fn(val): m, x, c = val return m, m(x), c - 1.0 _, y, _ = nnx.while_loop( lambda val: val[-1] > 0, body_fn, (module, x, 3.0), graph=graph, ) return y y = f(module, x) np.testing.assert_array_equal(y, x * 8) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_tree_mode_fori_loop_stateful(self, graph, graph_updates): class Counter(nnx.Module): def __init__(self): self.count = nnx.Variable(jnp.array(0)) counter = Counter() module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) module.kernel[...] = jnp.identity(10) * 2 x = jax.random.normal(jax.random.key(0), (10,)) def body_fn(i, val): counter, module, x = val counter.count[...] += 1 x = module(x) return counter, module, x counter, module, y = nnx.fori_loop( 0, 3, body_fn, (counter, module, x), graph=graph, graph_updates=graph_updates, ) np.testing.assert_array_equal(counter.count[...], 3) np.testing.assert_array_equal(y, x * 8) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_tree_mode_fori_loop_inside_jit(self, graph, graph_updates): module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) module.kernel[...] = jnp.identity(10) * 2 x = jax.random.normal(jax.random.key(0), (10,)) @nnx.jit(graph=graph, graph_updates=graph_updates) def f(module, x): def body_fn(i, val): m, x = val return m, m(x) _, y = nnx.fori_loop( 0, 3, body_fn, (module, x), graph=graph, ) return y y = f(module, x) np.testing.assert_array_equal(y, x * 8) class TestSplitMergeInputs(absltest.TestCase): def test_split_inputs(self): class StatefulLinear(nnx.Linear): def __init__(self, din: int, dout: int, rngs: nnx.Rngs): super().__init__(din, dout, rngs=rngs) self.counter = nnx.BatchStat(jnp.array(0, jnp.uint32)) def __call__(self, x): self.counter[...] += 1 return super().__call__(x) model = StatefulLinear(3, 4, rngs=nnx.Rngs(0)) @general.split_inputs @jax.jit @general.merge_inputs def forward(model, x): return model(x) x = jnp.ones((2, 3)) y = forward(model, x) self.assertEqual(model.counter[...], 1) def test_split_inputs_cond(self): class Counter(nnx.Linear): def __init__(self): self.count = nnx.BatchStat(jnp.array(0, jnp.uint32)) def increment(self): self.count[...] += 1 counter = Counter() @general.merge_inputs def increment(counter: Counter): counter.increment() @general.merge_inputs def no_nothing(counter: Counter): pass general.split_inputs(jax.lax.cond)(True, increment, no_nothing, counter) self.assertEqual(counter.count[...], 1) general.split_inputs(jax.lax.cond)(False, increment, no_nothing, counter) self.assertEqual(counter.count[...], 1) def test_split_inputs_vmap(self): class EnvState(nnx.Variable[nnx.A]): pass class Env(nnx.Pytree): def __init__(self): self.index = EnvState(jnp.arange(8)) self.step = EnvState(jnp.zeros((8,), jnp.uint32)) env = Env() model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) # internally merge_inputs returns (args, out) in_axes = (0, None) out_axes = (in_axes, None) @general.split_inputs @partial(jax.vmap, in_axes=in_axes, out_axes=out_axes) @general.merge_inputs def f(env: Env, model: nnx.Linear): self.assertEqual(env.index.shape, ()) @general.merge_inputs def increment(env: Env): env.step[...] += 1 @general.merge_inputs def no_nothing(env: Env): pass is_even = env.index % 2 == 0 general.split_inputs(jax.lax.cond)(is_even, increment, no_nothing, env) f(env, model) np.testing.assert_array_equal( env.step[...], np.array([1, 0, 1, 0, 1, 0, 1, 0], np.uint32) ) class TestCheckify(parameterized.TestCase): @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_basic(self, graph, graph_updates): @dataclasses.dataclass class Foo(nnx.Module): a: nnx.Param @nnx.jit(graph=graph, graph_updates=graph_updates) def f(m): y = jnp.sin(m.a) # error return m.a + y m = Foo(a=nnx.Param(jnp.inf)) err, out = nnx.checkify( f, errors=checkify.float_checks, graph=graph, graph_updates=graph_updates, )(m) with self.assertRaisesRegex(ValueError, 'nan generated by primitive: sin'): err.throw() @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_checkify_stateful(self, graph, graph_updates): count = nnx.Variable(jnp.array(0)) @nnx.jit(graph=graph, graph_updates=graph_updates) def f(c): c[...] += 1 return c[...] err, out = nnx.checkify( f, graph=graph, graph_updates=graph_updates, )(count) self.assertEqual(count[...], 1) np.testing.assert_allclose(out, 1) class TestBoundMethodTransforms(parameterized.TestCase): def test_remat_with_bound_method_raises(self): class M(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.count = nnx.BatchStat(0) def block(self, x: jax.Array) -> jax.Array: self.count[...] += 1 return self.linear(x) m = M(rngs=nnx.Rngs(0)) with self.assertRaisesRegex(ValueError, 'bound methods'): nnx.remat(m.block) def test_jit_with_bound_method_raises(self): class M(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.linear = nnx.Linear(2, 3, rngs=rngs) def apply(self, x: jax.Array, scale: int): return self.linear(x) * scale m = M(rngs=nnx.Rngs(0)) with self.assertRaisesRegex(ValueError, 'bound methods'): nnx.jit(m.apply, static_argnums=1) def test_vmap_with_bound_method_raises(self): class M(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.linear = nnx.Linear(2, 3, rngs=rngs) def __call__(self, x: jax.Array): return self.linear(x) m = M(rngs=nnx.Rngs(0)) with self.assertRaisesRegex(ValueError, 'bound methods'): nnx.vmap(m.__call__, in_axes=(0,), out_axes=0) def test_eval_shape_with_bound_method_raises(self): class M(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.linear = nnx.Linear(2, 3, rngs=rngs) def __call__(self, x: jax.Array): return self.linear(x) m = M(rngs=nnx.Rngs(0)) x_spec = jax.ShapeDtypeStruct((1, 2), jnp.float32) with self.assertRaisesRegex(ValueError, 'bound methods'): nnx.eval_shape(m.__call__, x_spec) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_grad_with_bound_method_raises(self, graph_mode, graph_updates): class M(nnx.Module): def __init__(self): self.w = nnx.Param(jnp.array(1.0)) def loss(self, s: float): return (self.w * s) ** 2 m = M() with self.assertRaisesRegex(ValueError, 'bound methods'): nnx.grad(m.loss, graph=graph_mode) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_value_and_grad_with_bound_method_raises(self, graph_mode, graph_updates): class TestModel(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.linear = nnx.Linear(2, 1, rngs=rngs) def loss_fn(self, x, y): pred = self.linear(x) return jnp.mean((pred - y) ** 2) model = TestModel(rngs=nnx.Rngs(0)) with self.assertRaisesRegex(ValueError, 'bound methods'): nnx.value_and_grad(model.loss_fn, graph=graph_mode) def test_checkify_with_bound_method_raises(self): """Test that checkify raises error for bound methods.""" class M(nnx.Module): def __call__(self, x: jax.Array): return x + 1 m = M() with self.assertRaisesRegex(ValueError, 'bound methods'): nnx.checkify(m.__call__) def test_pmap_with_bound_method_raises(self): """Test that pmap raises error for bound methods.""" class M(nnx.Module): def __call__(self, x: jax.Array): return x + 1 m = M() with self.assertRaisesRegex(ValueError, 'bound methods'): nnx.pmap(m.__call__) def test_shard_map_with_bound_method_raises(self): """Test that shard_map raises error for bound methods.""" class M(nnx.Module): def __call__(self, x: jax.Array): return x + 1 m = M() mesh = jax.sharding.Mesh(jax.local_devices()[:1], ('data',)) with self.assertRaisesRegex(ValueError, 'bound methods'): nnx.shard_map(m.__call__, mesh=mesh, in_specs=None, out_specs=None) def test_custom_vjp_with_bound_method_raises(self): """Test that custom_vjp raises error for bound methods.""" class M(nnx.Module): def __call__(self, x: jax.Array): return x + 1 m = M() with self.assertRaisesRegex(ValueError, 'bound methods'): nnx.custom_vjp(m.__call__) def test_scan_bound_method_raises(self): class M(nnx.Module): def __call__(self, x: jax.Array): return x + 1 m = M() with self.assertRaisesRegex(ValueError, 'bound methods'): _ = nnx.scan(m.__call__, in_axes=(0,), out_axes=0) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_tree_mode_pmap_basic(self, graph, graph_updates): class LinearEnsemble(nnx.Module): def __init__(self, num, *, rngs): self.w = nnx.Param(jax.random.uniform(rngs(), (num, 2, 3))) model = LinearEnsemble(1, rngs=nnx.Rngs(0)) x = jnp.ones((2,)) @nnx.pmap(in_axes=(0, None), out_axes=0, axis_size=1, graph=graph, graph_updates=graph_updates) def forward(model, x): return x @ model.w y = forward(model, x) assert y.shape == (1, 3) @parameterized.parameters( (True, True), (True, False), (False, False), ) def test_tree_mode_pmap_stateful(self, graph, graph_updates): class Counter(nnx.Variable): pass class Linear(nnx.Module): def __init__(self, din, dout, *, rngs): key = rngs.params() self.w = nnx.Param(jax.random.uniform(key, (din, dout))) self.count = Counter(jnp.array(0)) def __call__(self, x): self.count[...] += 1 return x @ self.w model = Linear(2, 3, rngs=nnx.Rngs(0)) @nnx.pmap(in_axes=(None, 0), out_axes=0, axis_size=1, graph=graph, graph_updates=graph_updates) def forward(model, x): return model(x) x = jnp.ones((1, 2)) y = forward(model, x) assert y.shape == (1, 3) assert model.count.get_value() == 1 def test_tree_mode_pmap_split_merge(self): class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 10, rngs=rngs) self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) def __call__(self, x: jax.Array) -> jax.Array: x = self.linear(x) x = nnx.elu(x) x = self.dropout(x) return x rngs = nnx.Rngs(0) @nnx.split_rngs(splits=1, graph=False) @nnx.pmap(in_axes=0, out_axes=(None, 0, 0, None), axis_size=1, graph=False) def create_block(rngs): block = Block(rngs) graphdef, params_state, rng_state, rest_state = nnx.split( block, nnx.Param, nnx.RngState, ..., ) return graphdef, params_state, rng_state, rest_state graphdef, params_state, rng_state, rest_state = create_block(rngs) assert rng_state.dropout.rngs.count[0] == 0 assert params_state.linear.kernel.shape == (1, 3, 10) assert params_state.linear.bias.shape == (1, 10) x = jnp.ones((1, 1, 3)) @nnx.pmap(in_axes=(0, 0, None, 0), axis_size=1, graph=False) def forward_block(params_state, rng_state, rest_state, x): return nnx.merge(graphdef, params_state, rng_state, rest_state)(x) y = forward_block(params_state, rng_state, rest_state, x) assert y.shape == (1, 1, 10) y2 = forward_block(params_state, rng_state, rest_state, x) assert not jnp.allclose(y, y2) def test_tree_mode_pmap_replicate(self): din = 3 dout = 10 class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) def __call__(self, x: jax.Array) -> jax.Array: return self.dropout(nnx.relu(self.linear(x))) rngs = nnx.Rngs(0) module = Block(rngs) assert module.dropout.rngs.count[...] == 0 assert module.linear.kernel.shape == (din, dout) assert module.linear.bias.shape == (dout,) module = nnx.split_rngs(module, splits=1, graph=False) graphdef, rng_state, rest_state = nnx.split( module, nnx.RngState, ..., ) @nnx.pmap(in_axes=(0, None, 0), out_axes=0, axis_size=1, graph=False) def forward_block(rng_state, rest_state, x): module = nnx.merge(graphdef, rng_state, rest_state) y = module(x) return y x = jnp.ones((1, 5, din)) y = forward_block(rng_state, rest_state, x) assert y.shape == (1, 5, dout) assert module.dropout.rngs.count[0] == 1 y2 = forward_block(rng_state, rest_state, x) assert module.dropout.rngs.count[0] == 2 assert not jnp.allclose(y, y2) class TestPureJaxFancyScan(absltest.TestCase): def test_carry_and_scan(self): def cumsum(carry, x): carry = carry + x return carry, carry final_carry, ys = pure_jax_fancy_scan( cumsum, jnp.array(0.0), jnp.arange(5.0), in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0), ) np.testing.assert_allclose(final_carry, 10.0) np.testing.assert_allclose(ys, jnp.array([0., 1., 3., 6., 10.])) def test_carry_only_output(self): def sum_fn(carry, x): return carry + x result = pure_jax_fancy_scan( sum_fn, jnp.array(0.0), jnp.arange(5.0), in_axes=(nnx.Carry, 0), out_axes=nnx.Carry, ) np.testing.assert_allclose(result, 10.0) def test_broadcast_args(self): def scale_cumsum(carry, scale, x): carry = carry + x * scale return carry, carry final_carry, _ = pure_jax_fancy_scan( scale_cumsum, jnp.array(0.0), jnp.array(2.0), jnp.arange(5.0), in_axes=(nnx.Carry, None, 0), out_axes=(nnx.Carry, 0), ) np.testing.assert_allclose(final_carry, 20.0) def test_pytree_carry(self): def dict_scan(carry, x): carry = {'a': carry['a'] + x['a'], 'b': carry['b'] + x['b']} return carry, carry xs = {'a': jnp.arange(3.0), 'b': jnp.ones(3)} init = {'a': jnp.array(0.0), 'b': jnp.array(0.0)} final_carry, _ = pure_jax_fancy_scan( dict_scan, init, xs, in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0), ) np.testing.assert_allclose(final_carry['a'], 3.0) np.testing.assert_allclose(final_carry['b'], 3.0) def test_no_carry_all_scanned(self): def double(x): return (x * 2,) (ys,) = pure_jax_fancy_scan( double, jnp.arange(5.0), in_axes=(0,), out_axes=(0,), ) np.testing.assert_allclose(ys, jnp.arange(5.0) * 2) def test_reverse(self): def cumsum(carry, x): carry = carry + x return carry, carry final_carry, _ = pure_jax_fancy_scan( cumsum, jnp.array(0.0), jnp.arange(5.0), in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0), reverse=True, ) np.testing.assert_allclose(final_carry, 10.0) def test_pytree_prefix_in_axes(self): def fn(carry, x): carry = carry + x['a'] + x['b'] return carry, carry xs = {'a': jnp.arange(3.0), 'b': jnp.array(1.0)} final_carry, _ = pure_jax_fancy_scan( fn, jnp.array(0.0), xs, in_axes=(nnx.Carry, {'a': 0, 'b': None}), out_axes=(nnx.Carry, 0), ) np.testing.assert_allclose(final_carry, 6.0) def test_nested_carry_rejected(self): with self.assertRaises(ValueError): pure_jax_fancy_scan( lambda x: x, {'a': jnp.array(1.0)}, in_axes=({'a': nnx.Carry},), out_axes=nnx.Carry, ) def test_broadcast_out_axes_rejected(self): with self.assertRaises(ValueError): pure_jax_fancy_scan( lambda c, x: (c, x), jnp.array(0.0), jnp.arange(3.0), in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, None), ) def test_none_broadcast_input(self): def fn(carry, _unused, x): carry = carry + x return carry, carry final_carry, _ = pure_jax_fancy_scan( fn, jnp.array(0.0), None, jnp.arange(3.0), in_axes=(nnx.Carry, None, 0), out_axes=(nnx.Carry, 0), ) np.testing.assert_allclose(final_carry, 3.0) def test_none_nested_in_arg(self): def fn(carry, x): carry = carry + x['a'] return carry, carry xs = {'a': jnp.arange(3.0), 'b': None} final_carry, _ = pure_jax_fancy_scan( fn, jnp.array(0.0), xs, in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0), ) np.testing.assert_allclose(final_carry, 3.0) def test_nested_carry_in_out_axes_rejected(self): with self.assertRaises(ValueError): pure_jax_fancy_scan( lambda c, x: (c, x), jnp.array(0.0), jnp.arange(3.0), in_axes=(nnx.Carry, 0), out_axes=({'a': nnx.Carry},), ) def test_carry_in_in_axes_only_rejected(self): with self.assertRaises(ValueError): pure_jax_fancy_scan( lambda c, x: (c + x,), jnp.array(0.0), jnp.arange(3.0), in_axes=(nnx.Carry, 0), out_axes=(0,), ) def test_carry_in_out_axes_only_rejected(self): with self.assertRaises(ValueError): pure_jax_fancy_scan( lambda x: x, jnp.arange(3.0), in_axes=(0,), out_axes=nnx.Carry, ) def test_non_tuple_carry_only(self): def f(carry): return carry + 1.0 result = pure_jax_fancy_scan( f, jnp.array(0.0), in_axes=nnx.Carry, out_axes=nnx.Carry, length=5, ) np.testing.assert_allclose(result, 5.0) def test_non_tuple_scan_only(self): def f(x): return x * 2 ys = pure_jax_fancy_scan( f, jnp.arange(5.0), in_axes=0, out_axes=0, ) np.testing.assert_allclose(ys, jnp.arange(5.0) * 2) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/nnx/variable_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import typing as tp import jax import jax.numpy as jnp import numpy as np import pytest from absl.testing import absltest, parameterized from flax import nnx A = tp.TypeVar('A') class TestVariable(parameterized.TestCase): def test_pytree(self): r1 = nnx.Param(1) self.assertEqual(r1.get_value(), 1) r2 = jax.tree.map(lambda x: x + 1, r1) self.assertEqual(r1.get_value(), 1) self.assertEqual(r2.get_value(), 2) self.assertIsNot(r1, r2) def test_overloads_module(self): class Linear(nnx.Module): def __init__(self, din, dout, rngs: nnx.Rngs): key = rngs() self.w = nnx.Param(jax.random.normal(key, (din, dout))) self.b = nnx.Param(jax.numpy.zeros((dout,))) def __call__(self, x: jax.Array): return x @ self.w + self.b linear = Linear(3, 4, nnx.Rngs(0)) x = jax.numpy.ones((3,)) y = linear(x) self.assertEqual(y.shape, (4,)) def test_jax_array(self): class Linear(nnx.Module): def __init__(self, din, dout, rngs: nnx.Rngs): key = rngs() self.w = nnx.Param(jax.random.normal(key, (din, dout))) self.b = nnx.Param(jax.numpy.zeros((dout,))) def __call__(self, x: jax.Array): return jnp.dot(x, self.w) + self.b # type: ignore[arg-type] linear = Linear(3, 4, nnx.Rngs(0)) x = jax.numpy.ones((3,)) y = linear(x) self.assertEqual(y.shape, (4,)) def test_proxy_access(self): v = nnx.Param(jnp.ones((2, 3))) t = v.T self.assertEqual(t.shape, (3, 2)) def test_proxy_call(self): class Callable(tp.NamedTuple): value: int def __call__(self, x): return self.value * x v = nnx.Param(Callable(2)) result = v(3) self.assertEqual(result, 6) def test_binary_ops(self): v1 = nnx.Param(jnp.array(2)) v2 = nnx.Param(jnp.array(3)) result = v1 + v2 self.assertEqual(result, 5) self.assertFalse(v1 == v2) v1[...] += v2 self.assertEqual(v1[...], 5) @parameterized.product( v1=[np.array([1, 2]), np.array(2), 3], v2=[np.array([1, 2]), np.array(2), 3], ) def test_eq_op(self, v1, v2): p1 = nnx.Param(jnp.asarray(v1) if isinstance(v1, np.ndarray) else v1) p2 = nnx.Param(jnp.asarray(v2) if isinstance(v2, np.ndarray) else v2) if isinstance(v1, np.ndarray) or isinstance(v2, np.ndarray): self.assertEqual((p1 == p2).all(), (v1 == v2).all()) else: self.assertEqual(p1 == p2, v1 == v2) def test_mutable_array_context(self): initial_mode = nnx.var_defaults().hijax with nnx.var_defaults(hijax=False): v = nnx.Variable(jnp.array(1.0)) self.assertEqual(nnx.var_defaults().hijax, False) self.assertNotIsInstance(v[...], jax.Ref) with nnx.var_defaults(hijax=True): v = nnx.Variable(jnp.array(1.0)) self.assertEqual(nnx.var_defaults().hijax, True) self.assertIsInstance(v[...], jax.Array) v = nnx.Variable(jnp.array(2.0)) self.assertIsInstance(v[...], jax.Array) self.assertEqual(nnx.var_defaults().hijax, False) nnx.var_defaults(hijax=True) v = nnx.Variable(jnp.array(0.0)) self.assertEqual(nnx.var_defaults().hijax, True) self.assertIsInstance(v[...], jax.Array) v = nnx.Variable(jnp.array(1.0)) self.assertEqual(nnx.var_defaults().hijax, initial_mode) self.assertIsInstance(v[...], jax.Array) def test_get_set_metadata(self): v = nnx.Variable(jnp.array(1.0)) self.assertEqual( v.get_metadata(), { 'hijax': False, 'ref': False, 'eager_sharding': True, }, ) v.set_metadata(a=1, b=2) self.assertEqual(v.get_metadata('a'), 1) self.assertEqual(v.get_metadata('b'), 2) v.set_metadata({ 'b': 3, 'c': 4, 'hijax': False, 'ref': False, 'eager_sharding': True, }) self.assertEqual( v.get_metadata(), { 'b': 3, 'c': 4, 'hijax': False, 'ref': False, 'eager_sharding': True, }, ) self.assertEqual(v.get_metadata('b'), 3) self.assertEqual(v.get_metadata('c'), 4) c = v.get_metadata('c') self.assertEqual(c, 4) x = v.get_metadata('x', default=10) self.assertEqual(x, 10) def test_set_module_metadata(self): class Module(nnx.Module): def __init__(self): self.v = nnx.Variable(jnp.array(0.0)) self.p = nnx.Param(jnp.array(1.0)) m = Module() self.assertNotIn('foo', m.v.get_metadata()) self.assertNotIn('foo', m.p.get_metadata()) nnx.set_metadata(m, foo='bar') # Check that foo was added but the default metadata is still there v_metadata = m.v.get_metadata() p_metadata = m.p.get_metadata() self.assertEqual(v_metadata['foo'], 'bar') self.assertEqual(p_metadata['foo'], 'bar') # Check that default metadata is preserved self.assertIn('hijax', v_metadata) self.assertIn('ref', v_metadata) self.assertNotIn('differentiable', m.v.get_metadata()) self.assertNotIn('differentiable', m.p.get_metadata()) nnx.set_metadata(m, differentiable=False, only=nnx.Param) # Check that v still has foo but not differentiable v_metadata = m.v.get_metadata() self.assertEqual(v_metadata['foo'], 'bar') self.assertNotIn('differentiable', v_metadata) # Check that p has both foo and differentiable p_metadata = m.p.get_metadata() self.assertEqual(p_metadata['foo'], 'bar') self.assertEqual(p_metadata['differentiable'], False) @pytest.mark.skip(reason="Ref doesn't support broadcasting yet") def test_broadcasting(self): v = nnx.Param(jnp.array([1.0, 2.0, 3.0])) x = v[None] self.assertEqual(x.shape, (1, 3)) def test_set_metadata_out_sharding(self): v = nnx.Variable(jnp.array(1.0)) v.set_metadata(out_sharding=jax.sharding.PartitionSpec(None)) self.assertEqual( v.get_metadata('out_sharding'), jax.sharding.PartitionSpec(None) ) v.set_metadata('out_sharding', jax.sharding.PartitionSpec('x')) self.assertEqual( v.get_metadata('out_sharding'), jax.sharding.PartitionSpec('x') ) v.set_metadata({ 'out_sharding': jax.sharding.PartitionSpec('y'), 'hijax': False, 'ref': False, 'eager_sharding': True, }) self.assertEqual( v.get_metadata('out_sharding'), jax.sharding.PartitionSpec('y') ) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/pickle_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.errors.""" from absl.testing import absltest from flax.errors import FlaxError, ScopeVariableNotFoundError import pickle class ErrorrsTest(absltest.TestCase): def test_exception_can_be_pickled(self): # tests the new __reduce__ method fixes bug reported in issue #4000 ex = ScopeVariableNotFoundError('varname', 'collection', 'scope') pickled_ex = pickle.dumps(ex) unpicked_ex = pickle.loads(pickled_ex) self.assertIsInstance(unpicked_ex, FlaxError) self.assertIn('varname', str(unpicked_ex)) self.assertIn('#flax.errors.ScopeVariableNotFoundError', str(unpicked_ex)) self.assertNotIn('#flax.errors.FlaxError', str(unpicked_ex)) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/run_all_tests.sh ================================================ #!/bin/bash # export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python PYTEST_OPTS= RUN_DOCTEST=false RUN_MYPY=false RUN_PYTEST=false RUN_PYTYPE=false GH_VENV=false for flag in "$@"; do case $flag in --with-cov) PYTEST_OPTS+="--cov=flax --cov-report=xml --cov-report=term --cov-config=pyproject.toml" ;; --help) echo "Usage:" echo " --with-cov: Also generate pytest coverage." exit ;; --only-doctest) RUN_DOCTEST=true ;; --only-pytest) RUN_PYTEST=true ;; --only-pytype) RUN_PYTYPE=true ;; --only-mypy) RUN_MYPY=true ;; --use-venv) GH_VENV=true ;; *) echo "Unknown flag: $flag" exit 1 ;; esac done # if neither --only-doctest, --only-pytest, --only-pytype, --only-mypy is set, run all tests if ! $RUN_DOCTEST && ! $RUN_PYTEST && ! $RUN_PYTYPE && ! $RUN_MYPY; then RUN_DOCTEST=true RUN_PYTEST=true RUN_PYTYPE=true RUN_MYPY=true fi # Activate cached virtual env for github CI if $GH_VENV; then source $(dirname "$0")/../.venv/bin/activate fi echo "====== test config =======" echo "PYTEST_OPTS: $PYTEST_OPTS" echo "RUN_DOCTEST: $RUN_DOCTEST" echo "RUN_PYTEST: $RUN_PYTEST" echo "RUN_MYPY: $RUN_MYPY" echo "RUN_PYTYPE: $RUN_PYTYPE" echo "GH_VENV: $GH_VENV" echo "WHICH PYTHON: $(which python)" echo "jax: $(python -c 'import jax; print(jax.__version__)')" echo "flax: $(python -c 'import flax; print(flax.__version__)')" echo "==========================" echo "" sh $(dirname "$0")/download_dataset_metadata.sh || exit # Instead of using set -e, we have a manual error trap that # exits for any error code != 5 since pytest returns error code 5 # for no found tests. (We may force minimal test coverage in examples # in the future!) trap handle_errors ERR handle_errors () { ret="$?" if [[ "$ret" == 5 ]]; then echo "error code $ret == no tests found in $egd" else echo "error code $ret" exit 1 fi } # Run embedded tests inside docs if $RUN_DOCTEST; then echo "=== RUNNING DOCTESTS ===" # test doctest sphinx-build -M doctest docs docs/_build -T sphinx-build -M doctest docs_nnx docs_nnx/_build -T # test build html sphinx-build -M html docs docs/_build -T sphinx-build -M html docs_nnx docs_nnx/_build -T # test docstrings pytest -n auto flax \ --doctest-modules \ --suppress-no-test-exit-code \ --ignore=flax/nnx/examples fi # check that flax is running on editable mode # (i.e. no notebook installed flax from pypi) echo "=== CHECKING FLAX IS EDITABLE ===" assert_error="flax is not running on editable mode." (cd docs; python -c "import flax; assert 'site-packages' not in flax.__file__, \"$assert_error\"") # env vars must be set after doctest export JAX_NUMPY_RANK_PROMOTION=raise export FLAX_PROFILE=1 if $RUN_PYTEST; then echo "=== RUNNING PYTESTS ===" # Run battery of core FLAX API tests. echo "XLA_FLAGS='--xla_force_host_platform_device_count=4' pytest -n auto tests $PYTEST_OPTS" XLA_FLAGS='--xla_force_host_platform_device_count=4' pytest -n auto tests $PYTEST_OPTS # Run nnx tests pytest -n auto docs/_ext/codediff_test.py $PYTEST_OPTS $PYTEST_IGNORE pytest -n auto docs_nnx/_ext/codediff_test.py $PYTEST_OPTS $PYTEST_IGNORE # Per-example tests. # # we apply pytest within each example to avoid pytest's annoying test-filename collision. # In pytest foo/bar/baz_test.py and baz/bleep/baz_test.py will collide and error out when # /foo/bar and /baz/bleep aren't set up as packages. for egd in $(find examples -maxdepth 1 -mindepth 1 -type d); do # skip if folder starts with "_" if [[ $egd == *"_"* ]]; then continue fi # skiping examples until tfds issue is resolved # pytest $egd done fi if $RUN_PYTYPE; then echo "=== RUNNING PYTYPE ===" # Validate types in examples. for egd in $(find examples -maxdepth 1 -mindepth 1 -type d); do # skip if folder starts with "_" or is "nnx_toy_examples" if [[ $egd == *"_"* ]] || [[ $egd == *"nnx_toy_examples"* ]]; then continue fi # use cd to make sure pytype cache lives in example dir and doesn't name clash # use *.py to avoid importing configs as a top-level import which leads to import errors # because config files use relative imports (e.g. from config import ...). (cd $egd ; pytype "*.py" --jobs auto --config ../../pyproject.toml) done # Validate types in library code. pytype --jobs auto --config pyproject.toml flax/ fi if $RUN_MYPY; then echo "=== RUNNING MYPY ===" # Validate types in library code. mypy --config pyproject.toml flax/ --show-error-codes fi # Return error code 0 if no real failures happened. echo "finished all tests." ================================================ FILE: tests/serialization_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.struct and flax.serialization.""" import collections import platform from typing import Any, NamedTuple import jax import jax.numpy as jnp import msgpack import numpy as np import optax import pytest from absl.testing import absltest, parameterized from jax import random from jax.tree_util import Partial from flax import linen as nn from flax import serialization, struct from flax.core import freeze from flax.training import train_state # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() @struct.dataclass class Point: x: float y: float meta: Any = struct.field(pytree_node=False) @struct.dataclass class Box: value: int def to_state_dict(box: Box): return box.value def from_state_dict(box: Box, state: Any): return box.replace(value=state) serialization.register_serialization_state( Box, to_state_dict, from_state_dict, override=True ) class OriginalTuple(NamedTuple): value: Any class WrongTuple(NamedTuple): wrong_field: Any class OriginalModule(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense(10)(x) return x class WrongModule(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense(10)(x) x = nn.Dense(10)(x) return x class SerializationTest(parameterized.TestCase): def test_dataclass_serialization(self): p = Point(x=1, y=2, meta={'dummy': True}) state_dict = serialization.to_state_dict(p) self.assertEqual( state_dict, { 'x': 1, 'y': 2, }, ) restored_p = serialization.from_state_dict(p, {'x': 3, 'y': 4}) expected_p = Point(x=3, y=4, meta={'dummy': True}) self.assertEqual(restored_p, expected_p) with self.assertRaises(ValueError): # invalid field serialization.from_state_dict(p, {'z': 3}) with self.assertRaises(ValueError): # missing field serialization.from_state_dict(p, {'x': 3}) def test_pass_through_serialization(self): p = Box(value=123) state_dict = serialization.to_state_dict(p) self.assertEqual( state_dict, 123, ) restored_box = serialization.from_state_dict(p, 123) expected_box = Box(value=123) self.assertEqual(restored_box, expected_box) def test_model_serialization(self): rng = random.key(0) module = nn.Dense(features=1, kernel_init=nn.initializers.ones_init()) x = jnp.ones((1, 1), jnp.float32) initial_params = module.init(rng, x) state = serialization.to_state_dict(initial_params) self.assertEqual( state, { 'params': { 'kernel': np.ones((1, 1)), 'bias': np.zeros((1,)), } }, ) state = { 'params': { 'kernel': np.zeros((1, 1)), 'bias': np.zeros((1,)), } } restored_model = serialization.from_state_dict(initial_params, state) self.assertEqual(restored_model, freeze(state)) def test_partial_serialization(self): add_one = Partial(jnp.add, 1) state = serialization.to_state_dict(add_one) self.assertEqual(state, {'args': {'0': 1}, 'keywords': {}}) restored_add_one = serialization.from_state_dict(add_one, state) self.assertEqual(add_one.args, restored_add_one.args) def test_optimizer_serialization(self): rng = random.key(0) module = nn.Dense(features=1, kernel_init=nn.initializers.ones_init()) x = jnp.ones((1, 1), jnp.float32) initial_params = module.init(rng, x) tx = optax.sgd(0.1, momentum=0.1) tx_state = tx.init(initial_params) state = serialization.to_state_dict(tx_state) expected_state = { '0': { 'trace': { 'params': { 'bias': np.array([0.0], dtype=jnp.float32), 'kernel': np.array([[0.0]], dtype=jnp.float32), } } }, '1': {}, } self.assertEqual(state, expected_state) state = jax.tree_util.tree_map(lambda x: x + 1, expected_state) restored_tx_state = serialization.from_state_dict(tx_state, state) tx_state_plus1 = jax.tree_util.tree_map(lambda x: x + 1, tx_state) self.assertEqual(restored_tx_state, tx_state_plus1) def test_collection_serialization(self): @struct.dataclass class DummyDataClass: x: float @classmethod def initializer(cls, shape): del shape return cls(x=0.0) class StatefulModule(nn.Module): @nn.compact def __call__(self): state = self.variable('state', 'dummy', DummyDataClass.initializer, ()) state.value = state.value.replace(x=state.value.x + 1.0) initial_variables = StatefulModule().init(random.key(0)) _, variables = StatefulModule().apply(initial_variables, mutable=['state']) serialized_state_dict = serialization.to_state_dict(variables) self.assertEqual(serialized_state_dict, {'state': {'dummy': {'x': 2.0}}}) deserialized_state = serialization.from_state_dict( variables, serialized_state_dict ) self.assertEqual(variables, deserialized_state) @parameterized.parameters([ 'byte', 'b', 'ubyte', 'short', 'h', 'ushort', 'i', 'uint', 'intp', 'p', 'uintp', 'long', 'l', 'longlong', 'q', 'ulonglong', 'half', 'e', 'f', 'double', 'd', 'longdouble', 'g', 'cdouble', 'clongdouble', 'm', 'b1', 'int64', 'i8', 'uint64', 'u8', 'float16', 'f2', 'float32', 'f4', 'float64', 'f8', 'float128', 'f16', 'complex64', 'c8', 'complex128', 'c16', 'complex256', 'c32', 'm8', 'int32', 'i4', 'uint32', 'u4', 'int16', 'i2', 'uint16', 'u2', 'int8', 'i1', 'uint8', 'u1', 'single', 'csingle', 'intc', 'uintc', 'int', 'float', 'complex', 'bool', ]) def test_numpy_serialization(self, dtype): np.random.seed(0) if ( (dtype in {'float128', 'f16', 'complex256', 'c32'}) and (platform.system() == 'Darwin') and (platform.machine() == 'arm64') ): pytest.skip( f'Mac M1 does not support dtype {dtype}' ) # skip testing these dtypes if user is on Mac M1 v = np.random.uniform(-100, 100, size=()).astype(dtype)[()] restored_v = serialization.msgpack_restore( serialization.msgpack_serialize(v) ) self.assertEqual(restored_v.dtype, v.dtype) np.testing.assert_array_equal(restored_v, v) for shape in [(), (5,), (10, 10), (1, 20, 30, 1)]: arr = np.random.uniform(-100, 100, size=shape).astype(dtype) restored_arr = serialization.msgpack_restore( serialization.msgpack_serialize(arr) ) self.assertEqual(restored_arr.dtype, arr.dtype) np.testing.assert_array_equal(restored_arr, arr) def test_jax_numpy_serialization(self): jax_dtypes = [ jnp.bool_, jnp.uint8, jnp.uint16, jnp.uint32, jnp.int8, jnp.int16, jnp.int32, jnp.bfloat16, jnp.float16, jnp.float32, jnp.complex64, ] for dtype in jax_dtypes: v = jnp.array(np.random.uniform(-100, 100, size=())).astype(dtype)[()] restored_v = serialization.msgpack_restore( serialization.msgpack_serialize(v) ) self.assertEqual(restored_v.dtype, v.dtype) np.testing.assert_array_equal(restored_v, v) for shape in [(), (5,), (10, 10), (1, 20, 30, 1)]: arr = jnp.array(np.random.uniform(-100, 100, size=shape)).astype(dtype) restored_arr = serialization.msgpack_restore( serialization.msgpack_serialize(arr) ) self.assertEqual(restored_arr.dtype, arr.dtype) np.testing.assert_array_equal(restored_arr, arr) def test_complex_serialization(self): for x in [1j, 1 + 2j]: restored_x = serialization.msgpack_restore( serialization.msgpack_serialize(x) ) self.assertEqual(x, restored_x) def test_restore_chunked(self): old_chunksize = serialization.MAX_CHUNK_SIZE serialization.MAX_CHUNK_SIZE = 91 * 8 try: tmp = np.random.uniform(-100, 100, size=(21, 37)) serialized = serialization.to_bytes(tmp) restored = serialization.msgpack_restore(serialized) finally: serialization.MAX_CHUNK_SIZE = old_chunksize np.testing.assert_array_equal(restored, tmp) def test_restore_unchunked(self): """Check if mgspack_restore works for unchunked inputs.""" def msgpack_serialize_legacy(pytree): """Old implementation that was not chunking.""" return msgpack.packb( pytree, default=serialization._msgpack_ext_pack, strict_types=True ) tmp = np.random.uniform(-100, 100, size=(21, 37)) serialized = msgpack_serialize_legacy(tmp) old_chunksize = serialization.MAX_CHUNK_SIZE serialization.MAX_CHUNK_SIZE = 91 * 8 try: restored = serialization.msgpack_restore(serialized) finally: serialization.MAX_CHUNK_SIZE = old_chunksize np.testing.assert_array_equal(restored, tmp) def test_namedtuple_serialization(self): foo_class = collections.namedtuple('Foo', 'a b c') x1 = foo_class(a=1, b=2, c=3) x1_serialized = serialization.to_bytes(x1) x2 = foo_class(a=0, b=0, c=0) restored_x1 = serialization.from_bytes(x2, x1_serialized) self.assertEqual(type(x1), type(restored_x1)) self.assertEqual(x1, restored_x1) def test_namedtuple_restore_legacy(self): foo_class = collections.namedtuple('Foo', 'a b c') x1 = foo_class(a=1, b=2, c=3) legacy_encoding = { 'name': 'Foo', 'fields': {'0': 'a', '1': 'b', '2': 'c'}, 'values': {'0': 1, '1': 2, '2': 3}, } x2 = foo_class(a=0, b=0, c=0) restored_x1 = serialization.from_state_dict(x2, legacy_encoding) self.assertEqual(type(x1), type(restored_x1)) self.assertEqual(x1, restored_x1) def test_model_serialization_to_bytes(self): rng = random.key(0) module = nn.Dense(features=1, kernel_init=nn.initializers.ones_init()) initial_params = module.init(rng, jnp.ones((1, 1), jnp.float32)) serialized_bytes = serialization.to_bytes(initial_params) restored_params = serialization.from_bytes(initial_params, serialized_bytes) self.assertEqual(restored_params, initial_params) def test_optimizer_serialization_to_bytes(self): rng = random.key(0) module = nn.Dense(features=1, kernel_init=nn.initializers.ones_init()) initial_params = module.init(rng, jnp.ones((1, 1), jnp.float32)) # model = nn.Model(module, initial_params) tx = optax.sgd(0.1, momentum=0.1) tx_state = tx.init(initial_params) serialized_bytes = serialization.to_bytes(tx_state) restored_tx_state = serialization.from_bytes(tx_state, serialized_bytes) self.assertEqual(restored_tx_state, tx_state) def test_serialization_chunking(self): old_chunksize = serialization.MAX_CHUNK_SIZE serialization.MAX_CHUNK_SIZE = 91 * 8 try: tmp = {'a': np.ones((10, 10))} tmp = serialization._chunk_array_leaves_in_place(tmp) finally: serialization.MAX_CHUNK_SIZE = old_chunksize test = jax.tree_util.tree_map(jnp.shape, tmp) ref = { 'a': { '__msgpack_chunked_array__': (), 'chunks': {'0': (91,), '1': (9,)}, 'shape': {'0': (), '1': ()}, } } self.assertEqual(test, ref) def test_serialization_chunking2(self): old_chunksize = serialization.MAX_CHUNK_SIZE serialization.MAX_CHUNK_SIZE = 91 * 8 try: tmp = {'a': np.ones((10, 10))} tmpbytes = serialization.to_bytes(tmp) newtmp = serialization.from_bytes(tmp, tmpbytes) finally: serialization.MAX_CHUNK_SIZE = old_chunksize jax.tree_util.tree_map(np.testing.assert_array_equal, tmp, newtmp) def test_serialization_chunking3(self): old_chunksize = serialization.MAX_CHUNK_SIZE serialization.MAX_CHUNK_SIZE = 91 * 8 try: tmp = {'a': np.ones((10, 10))} tmpbytes = serialization.msgpack_serialize(tmp) newtmp = serialization.msgpack_restore(tmpbytes) finally: serialization.MAX_CHUNK_SIZE = old_chunksize jax.tree_util.tree_map(np.testing.assert_array_equal, tmp, newtmp) @parameterized.parameters( { 'target': [[[1, 2, 3], [4, 5]]], 'wrong_target': [[[1, 2, 3], [4]]], 'msg': ( 'The size of the list and the state dict do not match,' ' got 1 and 2 at path ./0/1' ), }, { 'target': (((1, 2, 3), (4, 5)),), 'wrong_target': (((1, 2, 3), (4,)),), 'msg': ( 'The size of the list and the state dict do not match,' ' got 1 and 2 at path ./0/1' ), }, { 'target': (((1, 2, 3), (OriginalTuple([4, 5]), 6)),), 'wrong_target': (((1, 2, 3), (WrongTuple([4, 5]), 6)),), 'msg': ( 'The field names of the state dict and the named tuple do ' "not match, got {'value'} and {'wrong_field'} at path ./0/1/0" ), }, { 'target': {'a': {'b': {'c': [1, 2, 3], 'd': [4, 5]}}}, 'wrong_target': {'a': {'b': {'c': [1, 2, 3], 'd': [4]}}}, 'msg': ( 'The size of the list and the state dict do not match,' ' got 1 and 2 at path ./a/b/d' ), }, { 'target': {'a': {'b': {'c': [1, 2, 3], 'd': [4, 5]}}}, 'wrong_target': {'a': {'b': {'c': [1, 2, 3], 'e': [4, 5]}}}, 'msg': ( 'The target dict keys and state dict keys do not match, target' " dict contains keys {'e'} which are not present in state dict at" ' path ./a/b' ), }, { 'target': 'original_params', 'wrong_target': 'wrong_params', 'msg': ( 'The target dict keys and state dict keys do not match, target' " dict contains keys {'Dense_1'} which are not present in state" ' dict at path ./params' ), }, { 'target': 'original_train_state', 'wrong_target': 'wrong_train_state', 'msg': ( 'The target dict keys and state dict keys do not match, target' " dict contains keys {'Dense_1'} which are not present in state" ' dict at path ./params/params' ), }, ) def test_serialization_errors(self, target, wrong_target, msg): if target == 'original_params': x = jnp.ones((1, 28, 28, 1)) rng = jax.random.key(1) original_module = OriginalModule() target = original_module.init(rng, x) wrong_module = WrongModule() wrong_target = wrong_module.init(rng, x) elif target == 'original_train_state': x = jnp.ones((1, 28, 28, 1)) rng = jax.random.key(1) original_module = OriginalModule() original_params = original_module.init(rng, x) wrong_module = WrongModule() wrong_params = wrong_module.init(rng, x) tx = optax.sgd(learning_rate=0.1, momentum=0.9) target = train_state.TrainState.create( apply_fn=original_module.apply, params=original_params, tx=tx ) wrong_target = train_state.TrainState.create( apply_fn=wrong_module.apply, params=wrong_params, tx=tx ) encoded_bytes = serialization.to_bytes(target) with self.assertRaisesWithLiteralMatch(ValueError, msg): serialization.from_bytes(wrong_target, encoded_bytes) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/struct_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.struct.""" import dataclasses from typing import Any, Generic, TypeVar import jax from absl.testing import absltest, parameterized from jax._src.tree_util import prefix_errors from flax import struct # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() @struct.dataclass class Point: x: float y: float meta: Any = struct.field(pytree_node=False) class StructTest(parameterized.TestCase): def test_no_extra_fields(self): p = Point(x=1, y=2, meta={}) with self.assertRaises(dataclasses.FrozenInstanceError): p.new_field = 1 def test_mutation(self): p = Point(x=1, y=2, meta={}) new_p = p.replace(x=3) self.assertEqual(new_p, Point(x=3, y=2, meta={})) with self.assertRaises(dataclasses.FrozenInstanceError): p.y = 3 def test_slots(self): @struct.dataclass(frozen=False, slots=True) class SlotsPoint: x: float y: float p = SlotsPoint(x=1., y=2.) p.x = 3. # can assign to existing fields self.assertEqual(p, SlotsPoint(x=3., y=2.)) with self.assertRaises(AttributeError): p.z = 0. # can't create new fields by accident. def test_pytree_nodes(self): p = Point(x=1, y=2, meta={'abc': True}) leaves = jax.tree_util.tree_leaves(p) self.assertEqual(leaves, [1, 2]) new_p = jax.tree_util.tree_map(lambda x: x + x, p) self.assertEqual(new_p, Point(x=2, y=4, meta={'abc': True})) def test_keypath_error(self): # TODO(mattjj): avoid using internal prefix_errors by testing vmap error msg (e,) = prefix_errors(Point(1.0, [2.0], meta={}), Point(1.0, 2.0, meta={})) with self.assertRaisesRegex(ValueError, r'in_axes\.y'): raise e('in_axes') def test_double_wrap_no_op(self): class A: a: int self.assertFalse(hasattr(A, '_flax_dataclass')) A = struct.dataclass(A) self.assertTrue(hasattr(A, '_flax_dataclass')) A = struct.dataclass(A) # no-op self.assertTrue(hasattr(A, '_flax_dataclass')) def test_wrap_pytree_node_no_error(self): @struct.dataclass class A(struct.PyTreeNode): a: int @parameterized.parameters( {'mode': 'dataclass'}, {'mode': 'pytreenode'}, ) def test_kw_only(self, mode): if mode == 'dataclass': @struct.dataclass class A: a: int = 1 @struct.dataclass(kw_only=True) class B(A): b: int elif mode == 'pytreenode': class A(struct.PyTreeNode): a: int = 1 class B(A, struct.PyTreeNode, kw_only=True): b: int obj = B(b=2) self.assertEqual(obj.a, 1) self.assertEqual(obj.b, 2) with self.assertRaisesRegex(TypeError, "non-default argument 'b' follows default argument"): if mode == 'dataclass': @struct.dataclass class B(A): b: int elif mode == 'pytreenode': class B(A, struct.PyTreeNode): b: int def test_metadata_pass_through(self): @struct.dataclass class A: foo: int = struct.field(default=9, metadata={'baz': 9}) assert A.__dataclass_fields__['foo'].metadata == {'baz': 9, 'pytree_node': True} @parameterized.parameters( {'mode': 'dataclass'}, {'mode': 'pytreenode'}, ) def test_mutable(self, mode): if mode == 'dataclass': @struct.dataclass class A: a: int = 1 @struct.dataclass(frozen=False) class B: b: int = 1 elif mode == 'pytreenode': class A(struct.PyTreeNode): a: int = 1 class B(struct.PyTreeNode, frozen=False): b: int = 1 obj = A() with self.assertRaisesRegex(dataclasses.FrozenInstanceError, "cannot assign to field 'a'"): obj.a = 2 obj = B() obj.b = 2 self.assertEqual(obj.b, 2) def test_generic_pytreenode_base_order(self): # PyTreeNode + Generic should work regardless of base order (#5233). T = TypeVar('T') U = TypeVar('U') # Generic after PyTreeNode. class A(struct.PyTreeNode, Generic[T, U]): x: int = 0 self.assertEqual(A.__parameters__, (T, U)) A[int, int] # should not raise # Generic before PyTreeNode. class B(Generic[T, U], struct.PyTreeNode): x: int = 0 self.assertEqual(B.__parameters__, (T, U)) B[int, int] # should not raise # Subclassing a parameterized generic PyTreeNode. class Base(struct.PyTreeNode, Generic[T, U]): x: int = 0 class Sub(Base[int, str]): y: int = 1 obj = Sub(x=1, y=2) self.assertEqual(obj.x, 1) self.assertEqual(obj.y, 2) leaves = jax.tree_util.tree_leaves(obj) self.assertEqual(leaves, [1, 2]) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/tensorboard_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.metrics.tensorboard.""" import itertools import pathlib import tempfile import numpy as np import tensorflow as tf from absl.testing import absltest from tensorboard.backend.event_processing import ( directory_watcher, event_file_loader, ) from tensorboard.util import tensor_util from flax.metrics.tensorboard import SummaryWriter, _flatten_dict def _process_event(event): for value in event.summary.value: yield {'wall_time': event.wall_time, 'step': event.step, 'value': value} def _disk_usage(path: pathlib.Path): """Recursively computes the disk usage of a directory.""" if path.is_file(): return path.stat().st_size elif path.is_dir(): size_bytes = 0 for file in path.iterdir(): size_bytes += _disk_usage(file) return size_bytes else: raise NotImplementedError('What filetype is {file}?') class TensorboardTest(absltest.TestCase): def parse_and_return_summary_value(self, path): """Parse the event file in the given path and return the only summary value.""" event_value_list = [] event_file_generator = directory_watcher.DirectoryWatcher( path, event_file_loader.EventFileLoader ).Load() event_values = itertools.chain.from_iterable( map(_process_event, event_file_generator) ) for value_dict in event_values: event_value_list.append(value_dict) self.assertLen(event_value_list, 1) self.assertEqual(event_value_list[0]['step'], 1) self.assertGreater(event_value_list[0]['wall_time'], 0.0) return event_value_list[0]['value'] def test_summarywriter_flush_after_close(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) summary_writer.close() with self.assertRaises(AttributeError): summary_writer.flush() def test_summarywriter_scalar(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) # Write the scalar and check if the event exists and check data. float_value = 99.1232 summary_writer.scalar(tag='scalar_test', value=float_value, step=1) summary_value = self.parse_and_return_summary_value(path=log_dir) self.assertEqual(summary_value.tag, 'scalar_test') self.assertTrue( np.allclose( tensor_util.make_ndarray(summary_value.tensor).item(), float_value ) ) def test_summarywriter_text(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) text = 'hello world.' summary_writer.text(tag='text_test', textdata=text, step=1) summary_value = self.parse_and_return_summary_value(path=log_dir) self.assertEqual(summary_value.tag, 'text_test') self.assertEqual( tensor_util.make_ndarray(summary_value.tensor).item().decode('utf-8'), text, ) def test_summarywriter_image(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) expected_img = np.random.uniform(low=0.0, high=255.0, size=(30, 30, 3)) expected_img = expected_img.astype(np.uint8) summary_writer.image(tag='image_test', image=expected_img, step=1) summary_value = self.parse_and_return_summary_value(path=log_dir) self.assertEqual(summary_value.tag, 'image_test') actual_img = tf.image.decode_image(summary_value.tensor.string_val[2]) self.assertTrue(np.allclose(actual_img, expected_img)) def test_summarywriter_image_float_pixel_values(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) expected_img = np.random.uniform(low=0.0, high=1.0, size=(30, 30, 3)) summary_writer.image(tag='image_test', image=expected_img, step=1) summary_value = self.parse_and_return_summary_value(path=log_dir) # convert and scale expected_img appropriately to numpy uint8. expected_img = tf.image.convert_image_dtype( image=expected_img, dtype=np.uint8 ) self.assertEqual(summary_value.tag, 'image_test') actual_img = tf.image.decode_image(summary_value.tensor.string_val[2]) self.assertTrue(np.allclose(actual_img, expected_img)) def test_summarywriter_2dimage_scaled(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) img = np.random.uniform(low=0.0, high=255.0, size=(30, 30)) img = img.astype(np.uint8) summary_writer.image(tag='2dimage_test', image=img, step=1) summary_value = self.parse_and_return_summary_value(path=log_dir) self.assertEqual(summary_value.tag, '2dimage_test') actual_img = tf.image.decode_image(summary_value.tensor.string_val[2]) # assert the image was increased in dimension self.assertEqual(actual_img.shape, (30, 30, 3)) def test_summarywriter_single_channel_image_scaled(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) img = np.random.uniform(low=0.0, high=255.0, size=(30, 30, 1)) img = img.astype(np.uint8) summary_writer.image(tag='2dimage_1channel_test', image=img, step=1) summary_value = self.parse_and_return_summary_value(path=log_dir) self.assertEqual(summary_value.tag, '2dimage_1channel_test') actual_img = tf.image.decode_image(summary_value.tensor.string_val[2]) # assert the image was increased in dimension self.assertEqual(actual_img.shape, (30, 30, 3)) def test_summarywriter_multiple_images(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) expected_img = np.random.uniform(low=0.0, high=255.0, size=(2, 30, 30, 3)) expected_img = expected_img.astype(np.uint8) summary_writer.image(tag='multiple_images_test', image=expected_img, step=1) summary_value = self.parse_and_return_summary_value(path=log_dir) self.assertEqual(summary_value.tag, 'multiple_images_test') actual_imgs = [ tf.image.decode_image(s) for s in summary_value.tensor.string_val[2:] ] self.assertTrue(np.allclose(np.stack(actual_imgs, axis=0), expected_img)) def test_summarywriter_multiple_2dimages_scaled(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) img = np.random.uniform(low=0.0, high=255.0, size=(2, 30, 30)) img = img.astype(np.uint8) summary_writer.image(tag='multiple_2dimages_test', image=img, step=1) summary_value = self.parse_and_return_summary_value(path=log_dir) self.assertEqual(summary_value.tag, 'multiple_2dimages_test') actual_imgs = [ tf.image.decode_image(s) for s in summary_value.tensor.string_val[2:] ] # assert the images were increased in dimension self.assertEqual(np.stack(actual_imgs, axis=0).shape, (2, 30, 30, 3)) def test_summarywriter_audio(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) expected_audio_samples = np.random.uniform( low=-1.0, high=1.0, size=(2, 48000, 2) ) summary_writer.audio( tag='audio_test', audiodata=expected_audio_samples, step=1 ) summary_value = self.parse_and_return_summary_value(path=log_dir) self.assertEqual(summary_value.tag, 'audio_test') # Assert two audio files are parsed. self.assertLen(summary_value.tensor.string_val, 2) # Assert values. actual_audio_1 = tf.audio.decode_wav( summary_value.tensor.string_val[0] ).audio self.assertTrue( np.allclose(expected_audio_samples[0], actual_audio_1, atol=1e-04) ) actual_audio_2 = tf.audio.decode_wav( summary_value.tensor.string_val[1] ).audio self.assertTrue( np.allclose(expected_audio_samples[1], actual_audio_2, atol=1e-04) ) def test_summarywriter_audio_sampled_output(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) expected_audio_samples = np.random.uniform( low=-1.0, high=1.0, size=(2, 48000, 2) ) summary_writer.audio( tag='audio_test', audiodata=expected_audio_samples, step=1, max_outputs=1, ) summary_value = self.parse_and_return_summary_value(path=log_dir) self.assertEqual(summary_value.tag, 'audio_test') # Assert only the first audio clip is available. self.assertLen(summary_value.tensor.string_val, 1) # Assert values. actual_audio = tf.audio.decode_wav(summary_value.tensor.string_val[0]).audio self.assertTrue( np.allclose(expected_audio_samples[0], actual_audio, atol=1e-04) ) def test_summarywriter_clipped_audio(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) expected_audio_samples = np.random.uniform( low=-2.0, high=2.0, size=(2, 48000, 2) ) summary_writer.audio( tag='audio_test', audiodata=expected_audio_samples, step=1, max_outputs=1, ) summary_value = self.parse_and_return_summary_value(path=log_dir) self.assertEqual(summary_value.tag, 'audio_test') # Assert one audio files is parsed. self.assertLen(summary_value.tensor.string_val, 1) # actual_audio is clipped. actual_audio = tf.audio.decode_wav(summary_value.tensor.string_val[0]).audio self.assertFalse( np.allclose(expected_audio_samples[0], actual_audio, atol=1e-04) ) clipped_audio = np.clip(np.array(expected_audio_samples[0]), -1, 1) self.assertTrue(np.allclose(clipped_audio, actual_audio, atol=1e-04)) def test_summarywriter_histogram_defaultbins(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) histogram = np.arange(1000) # Histogram will be created for 30 (default) bins. summary_writer.histogram(tag='histogram_test', values=histogram, step=1) summary_value = self.parse_and_return_summary_value(path=log_dir) self.assertEqual(summary_value.tag, 'histogram_test') actual_histogram = tensor_util.make_ndarray(summary_value.tensor) self.assertTrue(actual_histogram.shape, (30, 3)) self.assertTrue( np.allclose(actual_histogram[0], (0.0, 33.3, 34.0), atol=1e-01) ) def test_summarywriter_histogram_2bins(self): log_dir = tempfile.mkdtemp() summary_writer = SummaryWriter(log_dir=log_dir) histogram = np.arange(1000) summary_writer.histogram( tag='histogram_test', values=histogram, step=1, bins=2 ) summary_value = self.parse_and_return_summary_value(path=log_dir) self.assertEqual(summary_value.tag, 'histogram_test') actual_histogram = tensor_util.make_ndarray(summary_value.tensor) self.assertTrue(actual_histogram.shape, (2, 3)) self.assertTrue( np.allclose(actual_histogram[0], (0.0, 499.5, 500.0), atol=1e-01) ) self.assertTrue( np.allclose(actual_histogram[1], (499.5, 999.0, 500.0), atol=1e-01) ) def test_flatten_dict(self): # Valid types according to https://github.com/tensorflow/tensorboard/blob/1204566da5437af55109f7a4af18f9f8b7c4f864/tensorboard/plugins/hparams/summary_v2.py input_hparams = { # Example Invalid Types 'None': None, 'List': [1, 2, 3], 'Tuple': (1, 2, 3), 'Complex': complex('1+1j'), 'np.complex_': np.complex128('1+1j'), # Valid Python Types 'Bool': True, 'Int': 1, 'Float': 1.0, 'Str': 'test', # Valid Numpy Types 'np.bool_': np.bool_(1), 'np.integer': np.int_(1), 'np.floating': np.float64(1.0), 'np.character': np.str_('test'), # Nested dict to flatten 'Nested_Dict': { 'None': None, 'List': [1, 2, 3], 'Tuple': (1, 2, 3), 'Complex': complex('1+1j'), 'np.complex_': np.complex128('1+1j'), 'Bool': True, 'Int': 1, 'Float': 1.0, 'Str': 'test', 'np.bool_': np.bool_(1), 'np.integer': np.int_(1), 'np.floating': np.float64(1.0), 'np.character': np.str_('test'), }, } result_hparams = _flatten_dict(input_hparams) expected_hparams = { 'None': 'None', 'List': '[1, 2, 3]', 'Tuple': '(1, 2, 3)', 'Complex': '(1+1j)', 'np.complex_': '(1+1j)', # Valid Python Types 'Bool': True, 'Int': 1, 'Float': 1.0, 'Str': 'test', # Valid Numpy Types 'np.bool_': np.bool_(1), 'np.integer': np.int_(1), 'np.floating': np.float64(1.0), 'np.character': np.str_('test'), # Nested Dict 'Nested_Dict.None': 'None', 'Nested_Dict.List': '[1, 2, 3]', 'Nested_Dict.Tuple': '(1, 2, 3)', 'Nested_Dict.Complex': '(1+1j)', 'Nested_Dict.np.complex_': '(1+1j)', 'Nested_Dict.Bool': True, 'Nested_Dict.Int': 1, 'Nested_Dict.Float': 1.0, 'Nested_Dict.Str': 'test', 'Nested_Dict.np.bool_': np.bool_(1), 'Nested_Dict.np.integer': np.int_(1), 'Nested_Dict.np.floating': np.float64(1.0), 'Nested_Dict.np.character': np.str_('test'), } self.assertDictEqual(result_hparams, expected_hparams) def test_auto_flush(self): tmp_dir = pathlib.Path(self.create_tempdir().full_path) summary_writer = SummaryWriter(tmp_dir, auto_flush=True) summary_writer.scalar('metric', 123, 1) filesize_before_flush = _disk_usage(tmp_dir) summary_writer.flush() filesize_after_flush = _disk_usage(tmp_dir) self.assertEqual(filesize_before_flush, filesize_after_flush) def test_no_auto_flush(self): tmp_dir = pathlib.Path(self.create_tempdir().full_path) summary_writer = SummaryWriter(tmp_dir, auto_flush=False) summary_writer.scalar('metric', 123, 1) filesize_before_flush = _disk_usage(tmp_dir) summary_writer.flush() filesize_after_flush = _disk_usage(tmp_dir) self.assertLess(filesize_before_flush, filesize_after_flush) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/traceback_util_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.traceback_util.""" import contextlib import sys import traceback import jax from absl.testing import absltest from jax import numpy as jnp from jax import random from jax._src import traceback_util as jax_traceback_util from flax import linen as nn from flax import traceback_util # pylint: disable=arguments-differ,protected-access, g-wrong-blank-lines # __tracebackhide__ is a python >=3.7 feature. TRACEBACKHIDE_SUPPORTED = tuple(sys.version_info)[:3] >= (3, 7, 0) EXPECTED_FILES = (__file__, contextlib.__spec__.origin) class TracebackTest(absltest.TestCase): def test_exclusion_list(self): traceback_util.show_flax_in_tracebacks() exclusion_len_wo_flax = len(jax_traceback_util._exclude_paths) traceback_util.hide_flax_in_tracebacks() exclusion_len_w_flax = len(jax_traceback_util._exclude_paths) self.assertLen( traceback_util._flax_exclusions, exclusion_len_w_flax - exclusion_len_wo_flax, ) def test_simple_exclusion_tracebackhide(self): if not TRACEBACKHIDE_SUPPORTED: return class Test1(nn.Module): @nn.remat @nn.compact def __call__(self, x): return Test2()(x) class Test2(nn.Module): @nn.jit @nn.compact def __call__(self, x): raise ValueError('error here.') return x # pylint: disable=unreachable traceback_util.hide_flax_in_tracebacks() jax.config.update('jax_traceback_filtering', 'tracebackhide') key = random.key(0) try: nn.jit(Test1)().init(key, jnp.ones((5, 3))) except ValueError as e: tb = e.__traceback__ filtered_frames = 0 unfiltered_frames = 0 for f, _ in traceback.walk_tb(tb): if '__tracebackhide__' not in f.f_locals: self.assertIn(f.f_code.co_filename, EXPECTED_FILES) filtered_frames += 1 unfiltered_frames += 1 self.assertEqual(filtered_frames, 3) self.assertGreater(unfiltered_frames, filtered_frames) def test_simple_exclusion_remove_frames(self): class Test1(nn.Module): @nn.remat @nn.compact def __call__(self, x): return Test2()(x) class Test2(nn.Module): @nn.jit @nn.compact def __call__(self, x): raise ValueError('error here.') return x # pylint: disable=unreachable traceback_util.hide_flax_in_tracebacks() jax.config.update('jax_traceback_filtering', 'remove_frames') key = random.key(0) try: nn.jit(Test1)().init(key, jnp.ones((5, 3))) except ValueError as e: tb_filtered = e.__traceback__ tb_unfiltered = e.__cause__.__traceback__ e_cause = e.__cause__ self.assertIsInstance(e_cause, jax_traceback_util.UnfilteredStackTrace) filtered_frames = 0 for _, _ in traceback.walk_tb(tb_filtered): filtered_frames += 1 unfiltered_frames = 0 for _, _ in traceback.walk_tb(tb_unfiltered): unfiltered_frames += 1 self.assertEqual(filtered_frames, 3) self.assertGreater(unfiltered_frames, filtered_frames) def test_dynamic_exclusion(self): if not TRACEBACKHIDE_SUPPORTED: return class Test1(nn.Module): @nn.remat @nn.compact def __call__(self, x): return Test2()(x) class Test2(nn.Module): @nn.jit @nn.compact def __call__(self, x): raise ValueError('error here.') return x # pylint: disable=unreachable key = random.key(0) traceback_util.show_flax_in_tracebacks() jax.config.update('jax_traceback_filtering', 'off') try: nn.jit(Test1)().init(key, jnp.ones((5, 3))) except ValueError as e: tb_all = e.__traceback__ traceback_util.hide_flax_in_tracebacks() jax.config.update('jax_traceback_filtering', 'tracebackhide') try: nn.jit(Test1)().init(key, jnp.ones((5, 3))) except ValueError as e: tb_no_flax = e.__traceback__ traceback_util.show_flax_in_tracebacks() jax.config.update('jax_traceback_filtering', 'tracebackhide') try: nn.jit(Test1)().init(key, jnp.ones((5, 3))) except ValueError as e: tb_w_flax = e.__traceback__ filtered_frames_all = 0 unfiltered_frames_all = 0 for f, _ in traceback.walk_tb(tb_all): if '__tracebackhide__' not in f.f_locals: unfiltered_frames_all += 1 else: filtered_frames_all += 1 filtered_frames_no_flax = 0 unfiltered_frames_no_flax = 0 for f, _ in traceback.walk_tb(tb_no_flax): if '__tracebackhide__' not in f.f_locals: self.assertIn(f.f_code.co_filename, EXPECTED_FILES) unfiltered_frames_no_flax += 1 else: filtered_frames_no_flax += 1 filtered_frames_w_flax = 0 unfiltered_frames_w_flax = 0 for f, _ in traceback.walk_tb(tb_w_flax): if '__tracebackhide__' not in f.f_locals: unfiltered_frames_w_flax += 1 else: filtered_frames_w_flax += 1 self.assertEqual( unfiltered_frames_all + filtered_frames_all, unfiltered_frames_w_flax + filtered_frames_w_flax, ) self.assertEqual( unfiltered_frames_all + filtered_frames_all, unfiltered_frames_no_flax + filtered_frames_no_flax, ) self.assertEqual(unfiltered_frames_no_flax, 3) self.assertGreater(unfiltered_frames_all, unfiltered_frames_w_flax) self.assertGreater(unfiltered_frames_w_flax, unfiltered_frames_no_flax) if __name__ == '__main__': absltest.main() ================================================ FILE: tests/traverse_util_test.py ================================================ # Copyright 2024 The Flax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for flax.traverse_util.""" import collections import jax import jax.numpy as jnp import numpy as np import optax from absl.testing import absltest import flax from flax import traverse_util from flax.core import freeze # Parse absl flags test_srcdir and test_tmpdir. jax.config.parse_flags_with_absl() class Foo: def __init__(self, foo, bar=None): self.foo = foo self.bar = bar def __eq__(self, other): return self.foo == other.foo and self.bar == other.bar Point = collections.namedtuple('Point', ['x', 'y']) class TraversalTest(absltest.TestCase): def test_traversal_id(self): x = 1 traversal = traverse_util.t_identity self.assertEqual(list(traversal.iterate(x)), [1]) y = traversal.update(lambda x: x + x, x) self.assertEqual(y, 2) def test_traverse_item(self): x = {'foo': 1} traversal = traverse_util.t_identity['foo'] self.assertEqual(list(traversal.iterate(x)), [1]) y = traversal.update(lambda x: x + x, x) self.assertEqual(y, {'foo': 2}) def test_traverse_tuple_item(self): x = (1, 2, 3) traversal = traverse_util.t_identity[1] self.assertEqual(list(traversal.iterate(x)), [2]) y = traversal.update(lambda x: x + x, x) self.assertEqual(y, (1, 4, 3)) def test_traverse_tuple_items(self): x = (1, 2, 3, 4) traversal = traverse_util.t_identity[1:3] self.assertEqual(list(traversal.iterate(x)), [2, 3]) y = traversal.update(lambda x: x + x, x) self.assertEqual(y, (1, 4, 6, 4)) def test_traverse_namedtuple_item(self): x = Point(x=1, y=2) traversal = traverse_util.t_identity[1] self.assertEqual(list(traversal.iterate(x)), [2]) y = traversal.update(lambda x: x + x, x) self.assertEqual(y, Point(x=1, y=4)) def test_traverse_attr(self): x = Foo(foo=1) traversal = traverse_util.t_identity.foo self.assertEqual(list(traversal.iterate(x)), [1]) y = traversal.update(lambda x: x + x, x) self.assertEqual(y, Foo(foo=2)) def test_traverse_namedtuple_attr(self): x = Point(x=1, y=2) traversal = traverse_util.t_identity.y self.assertEqual(list(traversal.iterate(x)), [2]) y = traversal.update(lambda x: x + x, x) self.assertEqual(y, Point(x=1, y=4)) def test_traverse_dataclass_attr(self): x = Point(x=1, y=2) traversal = traverse_util.t_identity.y self.assertEqual(list(traversal.iterate(x)), [2]) y = traversal.update(lambda x: x + x, x) self.assertEqual(y, Point(x=1, y=4)) def test_traverse_merge(self): x = [{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 4}] traversal_base = traverse_util.t_identity.each() traversal = traversal_base.merge( traverse_util.TraverseItem('foo'), traverse_util.TraverseItem('bar') ) self.assertEqual(list(traversal.iterate(x)), [1, 2, 3, 4]) y = traversal.update(lambda x: x + x, x) self.assertEqual(y, [{'foo': 2, 'bar': 4}, {'foo': 6, 'bar': 8}]) def test_traverse_each(self): x = [{'foo': 1}, {'foo': 2}] traversal = traverse_util.t_identity.each()['foo'] self.assertEqual(list(traversal.iterate(x)), [1, 2]) y = traversal.update(lambda x: x + x, x) self.assertEqual(y, [{'foo': 2}, {'foo': 4}]) def test_traverse_each_dict(self): x = {'foo': 1, 'bar': 2} traversal = traverse_util.t_identity.each() self.assertEqual(list(traversal.iterate(x)), [1, 2]) y = traversal.update(lambda x: x + x, x) self.assertEqual(y, {'foo': 2, 'bar': 4}) def test_traverse_tree(self): x = [{'foo': 1}, {'bar': 2}] traversal = traverse_util.t_identity.tree() self.assertEqual(list(traversal.iterate(x)), [1, 2]) y = traversal.update(lambda x: x + x, x) self.assertEqual(y, [{'foo': 2}, {'bar': 4}]) def test_traverse_filter(self): x = [1, -2, 3, -4] traversal = traverse_util.t_identity.each().filter(lambda x: x < 0) self.assertEqual(list(traversal.iterate(x)), [-2, -4]) y = traversal.update(lambda x: -x, x) self.assertEqual(y, [1, 2, 3, 4]) def test_traversal_set(self): x = {'foo': [1, 2]} traversal = traverse_util.t_identity['foo'].each() y = traversal.set([3, 4], x) self.assertEqual(y, {'foo': [3, 4]}) with self.assertRaises(ValueError): traversal.set([3], x) # too few values with self.assertRaises(ValueError): traversal.set([3, 4, 5], x) # too many values def test_flatten_dict(self): xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}} flat_xs = traverse_util.flatten_dict(xs) self.assertEqual( flat_xs, { ('foo',): 1, ('bar', 'a'): 2, }, ) flat_xs = traverse_util.flatten_dict(freeze(xs)) self.assertEqual( flat_xs, { ('foo',): 1, ('bar', 'a'): 2, }, ) flat_xs = traverse_util.flatten_dict(xs, sep='/') self.assertEqual( flat_xs, { 'foo': 1, 'bar/a': 2, }, ) def test_unflatten_dict(self): expected_xs = {'foo': 1, 'bar': {'a': 2}} xs = traverse_util.unflatten_dict( { ('foo',): 1, ('bar', 'a'): 2, } ) self.assertEqual(xs, expected_xs) xs = traverse_util.unflatten_dict( { 'foo': 1, 'bar/a': 2, }, sep='/', ) self.assertEqual(xs, expected_xs) def test_flatten_dict_keep_empty(self): xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}} flat_xs = traverse_util.flatten_dict(xs, keep_empty_nodes=True) self.assertEqual( flat_xs, { ('foo',): 1, ('bar', 'a'): 2, ('bar', 'b'): traverse_util.empty_node, }, ) xs_restore = traverse_util.unflatten_dict(flat_xs) self.assertEqual(xs, xs_restore) def test_flatten_dict_is_leaf(self): xs = {'foo': {'c': 4}, 'bar': {'a': 2, 'b': {}}} flat_xs = traverse_util.flatten_dict( xs, is_leaf=lambda k, x: len(k) == 1 and len(x) == 2 ) self.assertEqual( flat_xs, { ('foo', 'c'): 4, ('bar',): {'a': 2, 'b': {}}, }, ) xs_restore = traverse_util.unflatten_dict(flat_xs) self.assertEqual(xs, xs_restore) class ModelParamTraversalTest(absltest.TestCase): def test_only_works_on_model_params(self): traversal = traverse_util.ModelParamTraversal(lambda *_: True) with self.assertRaises(ValueError): list(traversal.iterate([])) def test_param_selection(self): params = { 'x': { 'kernel': 1, 'bias': 2, 'y': { 'kernel': 3, 'bias': 4, }, 'z': {}, }, } expected_params = { 'x': { 'kernel': 2, 'bias': 2, 'y': { 'kernel': 6, 'bias': 4, }, 'z': {}, }, } names = [] def filter_fn(name, _): names.append(name) # track names passed to filter_fn for testing return 'kernel' in name traversal = traverse_util.ModelParamTraversal(filter_fn) values = list(traversal.iterate(params)) configs = [ (params, expected_params), (flax.core.FrozenDict(params), flax.core.FrozenDict(expected_params)), ] for model, expected_model in configs: self.assertEqual(values, [1, 3]) self.assertEqual( set(names), {'/x/kernel', '/x/bias', '/x/y/kernel', '/x/y/bias'} ) new_model = traversal.update(lambda x: x + x, model) self.assertEqual(new_model, expected_model) def test_path_value(self): params_in = {'a': {'b': 10, 'c': 2}} params_out = traverse_util.path_aware_map( lambda path, x: x + 1 if 'b' in path else -x, params_in ) self.assertEqual(params_out, {'a': {'b': 11, 'c': -2}}) def test_path_aware_map_with_multi_transform(self): params = { 'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)}, 'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}, } gradients = jax.tree_util.tree_map(jnp.ones_like, params) # dummy gradients param_labels = traverse_util.path_aware_map( lambda path, x: 'kernel' if 'w' in path else 'bias', params ) tx = optax.multi_transform( {'kernel': optax.sgd(1.0), 'bias': optax.set_to_zero()}, param_labels ) state = tx.init(params) updates, new_state = tx.update(gradients, state, params) new_params = optax.apply_updates(params, updates) self.assertTrue( np.allclose(new_params['linear_1']['b'], params['linear_1']['b']) ) self.assertTrue( np.allclose(new_params['linear_2']['b'], params['linear_2']['b']) ) self.assertFalse( np.allclose(new_params['linear_1']['w'], params['linear_1']['w']) ) self.assertFalse( np.allclose(new_params['linear_2']['w'], params['linear_2']['w']) ) def test_path_aware_map_with_masked(self): params = { 'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)}, 'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}, } gradients = jax.tree_util.tree_map(jnp.ones_like, params) # dummy gradients params_mask = traverse_util.path_aware_map( lambda path, x: 'w' in path, params ) tx = optax.masked(optax.sgd(1.0), params_mask) state = tx.init(params) updates, new_state = tx.update(gradients, state, params) new_params = optax.apply_updates(params, updates) self.assertTrue( np.allclose(new_params['linear_1']['b'], gradients['linear_1']['b']) ) self.assertTrue( np.allclose(new_params['linear_2']['b'], gradients['linear_2']['b']) ) self.assertTrue( np.allclose(new_params['linear_1']['w'], -gradients['linear_1']['w']) ) self.assertTrue( np.allclose(new_params['linear_2']['w'], -gradients['linear_2']['w']) ) def test_path_aware_map_with_empty_nodes(self): params_in = {'a': {'b': 10, 'c': 2}, 'b': {}} params_out = traverse_util.path_aware_map( lambda path, x: x + 1 if 'b' in path else -x, params_in ) self.assertEqual(params_out, {'a': {'b': 11, 'c': -2}, 'b': {}}) if __name__ == '__main__': absltest.main()